[
  {
    "path": ".gitattributes",
    "content": "*.ipynb filter=lfs diff=lfs merge=lfs -text\n*.pt filter=lfs diff=lfs merge=lfs -text\n*.obj filter=lfs diff=lfs merge=lfs -text\n*.dae filter=lfs diff=lfs merge=lfs -text\n*.onnx filter=lfs diff=lfs merge=lfs -text\n*.pkl filter=lfs diff=lfs merge=lfs -text\n*.npz filter=lfs diff=lfs merge=lfs -text\n*.npy filter=lfs diff=lfs merge=lfs -text\nassets/smplx filter=lfs diff=lfs merge=lfs -text\nassets/smpl filter=lfs diff=lfs merge=lfs -text\nassets/test_data filter=lfs diff=lfs merge=lfs -text\nassets/media/*.png filter=lfs diff=lfs merge=lfs -text\nassets/media/*.jpg filter=lfs diff=lfs merge=lfs -text\nassets/media/*.jpeg filter=lfs diff=lfs merge=lfs -text\nassets/videos/*.mp4 filter=lfs diff=lfs merge=lfs -text\n*.gif filter=lfs diff=lfs merge=lfs -text\n*.svg filter=lfs diff=lfs merge=lfs -text\n*.png filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": ".githooks/README.md",
    "content": "# Git hooks\n\nPre-commit runs [ruff](https://docs.astral.sh/ruff/) format on staged Python files (using `train.env` for the correct environment).\n\n**Install ruff** in the holomotion train environment if it is absent:\n\n```bash\nconda activate holomotion_train\npip install ruff\n```\n\nRuff is also listed in `environments/requirements_base.txt` if you install deps from there.\n\n**Enable hooks** (run once from repo root):\n\n```bash\ngit config core.hooksPath .githooks\n```\n\nEnsure the hook is executable: `chmod +x .githooks/pre-commit`\n\n**Requirement:** Run `git commit` from a shell where conda is available so `train.env` can set `Train_CONDA_PREFIX`.\n\n**Skip hook for one commit:** `git commit --no-verify`\n"
  },
  {
    "path": ".githooks/pre-commit",
    "content": "#!/bin/bash\n# Run ruff format on staged Python files before commit.\n# Use from holomotion repo root (standalone clone or submodule).\n# Requires: run git commit from a shell where conda is available (so train.env can set Train_CONDA_PREFIX).\n\nset -e\ncd \"$(git rev-parse --show-toplevel)\"\n\nmapfile -t staged_py < <(git diff --cached --name-only --diff-filter=ACM | grep '\\.py$' || true)\nif [ ${#staged_py[@]} -eq 0 ]; then\n  exit 0\nfi\n\n# train.env sets Train_CONDA_PREFIX; conda must be on PATH when you run git commit\nsource train.env\n\"$Train_CONDA_PREFIX/bin/ruff\" format --config pyproject.toml \"${staged_py[@]}\"\ngit add \"${staged_py[@]}\"\nexit 0\n"
  },
  {
    "path": ".gitignore",
    "content": "# ignore logs and cache\nlogs/\nlogs_eval/\ndata/\noutputs/\n.archive/\ntmp/\n\n# ignore deployment bag_record, install, log\ndeployment/unitree_g1_ros2/bag_record/\ndeployment/unitree_g1_ros2/install/\ndeployment/unitree_g1_ros2/log/\n\n# ignore data and outputs\ndata\ndata/\noutputs/\nbuild/\ninstall/\nlog/\n.DS_Store/\n**.egg-info/\n**.log\n**.LOG\n\n# ignore large files\n*.log\n*.pkl\n*.pt\n*.onnx\n*.npy\n*.npz\n*.zip\n*.tar.gz\n\n*.obj\n*.dae\n*.STL\n\n# ignore video, image, etc.\n*.mp4\n*.avi\n*.mov\n*.png\n*.pdf\n\n__pycache__/\n*.pyc\n*.egg-info\n\n.agents/\n.cursor/\n.cursorignore\n.vscode/\n\n.*_cache/\n\n**/usd/\nassets/isaac/\n\nnot_for_commit/\n\nthirdparties/smpl_models/\n\nMUJOCO_LOG.TXT\n\n# allow certain files\n!deployment/unitree_g1_ros2/src/models/*.onnx\n!deployment/unitree_g1_ros2/src/motion_data/*.pkl\n!assets/smpl/*.pkl\n!assets/smpl/*.npz\n!assets/smplx/*.pkl\n!assets/smplx/*.npz\n!assets/test_data/**\n!assets/media/**\n\n# macOS system files\n.DS_Store\n**/.DS_Store\n"
  },
  {
    "path": ".gitlab-ci.yml",
    "content": "\nworkflow:\n  rules:\n    - if: $CI_PIPELINE_SOURCE == 'merge_request_event'\n\njob1:\n  script:\n    - echo \"This job runs in merge request pipelines\"\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"thirdparties/SMPLSim\"]\npath = thirdparties/SMPLSim\nurl = https://github.com/ZhengyiLuo/SMPLSim\nbranch = master\n\n[submodule \"thirdparties/joints2smpl\"]\npath = thirdparties/joints2smpl\nurl = https://github.com/wangsen1312/joints2smpl.git\nbranch = main\n\n[submodule \"thirdparties/omomo_release\"]\npath = thirdparties/omomo_release\nurl = https://github.com/lijiaman/omomo_release.git\nbranch = main\n\n[submodule \"thirdparties/unitree_ros\"]\npath = thirdparties/unitree_ros\nurl = https://github.com/unitreerobotics/unitree_ros\nbranch = master\n\n[submodule \"thirdparties/unitree_ros2\"]\npath = thirdparties/unitree_ros2\nurl = https://github.com/unitreerobotics/unitree_ros2\nbranch = master\n\n[submodule \"thirdparties/unitree_sdk2_python\"]\npath = thirdparties/unitree_sdk2_python\nurl = https://github.com/unitreerobotics/unitree_sdk2_python.git\n\n[submodule \"thirdparties/cyclonedds\"]\npath = thirdparties/cyclonedds\nurl = https://github.com/eclipse-cyclonedds/cyclonedds\nbranch = 0.10.2\n\n[submodule \"thirdparties/unitree_sdk2\"]\npath = thirdparties/unitree_sdk2\nurl = https://github.com/unitreerobotics/unitree_sdk2.git\n\n[submodule \"thirdparties/GMR\"]\npath = thirdparties/GMR\nurl = https://github.com/YanjieZe/GMR.git\nbranch = master\n[submodule \"thirdparties/HoloMotion_assets\"]\npath = thirdparties/HoloMotion_assets\nurl = https://huggingface.co/datasets/HorizonRobotics/HoloMotion_assets\n[submodule \"thirdparties/smplx\"]\n\tpath = thirdparties/smplx\n\turl = https://github.com/vchoutas/smplx\n[submodule \"thirdparties/GVHMR\"]\n\tpath = thirdparties/GVHMR\n\turl = https://github.com/zju3dv/GVHMR.git\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2025 maiyue01.chen\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": "# Variables\nPY_SRC := holomotion/  # Your Python code directory\nRUFF := ruff  # Assumes Ruff is installed locally\nPYTEST := pytest -v\nTESTS := holomotion/tests/\nCOV := --cov=holomotion/src --cov-report=term-missing\n\n# Directory to lint/format - can be overridden with DIR=path\nDIR ?= holomotion/src\n\n.PHONY: lint format check lint-dir format-dir\n\n# Run Ruff linter on default directory\nlint:\n\t@echo \"Linting with Ruff...\"\n\t@$(RUFF) check $(PY_SRC)\n\n# Format code in default directory\nformat:\n\t@echo \"Formatting with Ruff...\"\n\t@$(RUFF) format $(PY_SRC)\n\t@$(RUFF) check --fix $(PY_SRC)  # Auto-fix lint errors\n\n# Run Ruff linter on specific directory (with fallback)\nlint-dir:\n\t@echo \"Linting directory: $(DIR)\"\n\t@$(RUFF) check $(DIR)\n\n# Format code in specific directory (with fallback)\nformat-dir:\n\t@echo \"Formatting directory: $(DIR)\"\n\t@$(RUFF) format $(DIR)\n\t@$(RUFF) check --fix $(DIR)  # Auto-fix lint errors\n\n# Strict check (for CI)\ncheck:\n\t@$(RUFF) check $(PY_SRC) --exit-non-zero-on-fix\n\n# Run all tests\ntest:\n\t$(PYTEST) $(TESTS)"
  },
  {
    "path": "NOTICE",
    "content": "\n\n=======================================================================\nASAP's MIT License\n=======================================================================\nCode derived from implementations in ASAP should mention its derivation\nand reference the following license:\n\n    MIT License\n\n    Copyright (c) 2025 ASAP Team\n\n    Permission is hereby granted, free of charge, to any person obtaining a copy\n    of this software and associated documentation files (the \"Software\"), to deal\n    in the Software without restriction, including without limitation the rights\n    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    copies of the Software, and to permit persons to whom the Software is\n    furnished to do so, subject to the following conditions:\n\n    The above copyright notice and this permission notice shall be included in all\n    copies or substantial portions of the Software.\n\n    THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    SOFTWARE.\n\n\n\n=======================================================================\nomomo_release's MIT License\n=======================================================================\nCode derived from implementations in omomo_release should mention its derivation\nand reference the following license:\n\n   MIT License\n\n   Copyright (c) 2023 Jiaman Li\n\n   Permission is hereby granted, free of charge, to any person obtaining a copy\n   of this software and associated documentation files (the \"Software\"), to deal\n   in the Software without restriction, including without limitation the rights\n   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n   copies of the Software, and to permit persons to whom the Software is\n   furnished to do so, subject to the following conditions:\n\n   The above copyright notice and this permission notice shall be included in all\n   copies or substantial portions of the Software.\n\n   THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n   SOFTWARE.\n\n=======================================================================\nNVIDIA License\n=======================================================================\n\n   Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited."
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n<img src=\"assets/media/holomotion_logo_text.png\" alt=\"HoloMotion Logo\" width=\"500\"/>\n\n---\n\n[![Python](https://img.shields.io/badge/Python3.11-3776AB?logo=python&logoColor=fff)](#)\n[![Ubuntu](https://img.shields.io/badge/Ubuntu22.04-E95420?logo=ubuntu&logoColor=white)](#)\n[![License](https://img.shields.io/badge/License-Apache_2.0-green?logo=apache&logoColor=white)](./LICENSE)\n\n[![Safari](https://img.shields.io/badge/Website-006CFF?logo=safari&logoColor=fff)](https://horizonrobotics.github.io/robot_lab/holomotion/)\n[![HuggingFace](https://img.shields.io/badge/-HuggingFace-3B4252?style=flat&logo=huggingface&logoColor=)](https://huggingface.co/collections/HorizonRobotics/holomotion)\n[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/HorizonRobotics/HoloMotion)\n[![WeChat](https://img.shields.io/badge/Wechat-7BB32E?logo=wechat&logoColor=white)](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c)\n\n<!-- [![arXiv](https://img.shields.io/badge/arXiv-2025.00000-red?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2025.00000) -->\n<!-- [![arXiv](https://img.shields.io/badge/arXiv-2025.00000-red?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2025.00000) -->\n\n</div>\n\n# HoloMotion: A Foundation Model for Whole-Body Humanoid Control\n\n## NEWS\n- [2026.04.04] The v1.2 version of HoloMotion has been released, we provide pre-trained motion tracking and velocity tracking models for the community to deploy directly.\n\n- [2026.01.06] The v1.1 version of HoloMotion has been released, representing a major step forward toward a fully engineered, stable, and reproducible humanoid motion intelligence system.\n\n- [2025.11.05] The v1.0 version of HoloMotion has been released, and the WeChat user group is now open! Please scan the [QR Code](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c) to join.\n<!-- <p align=\"center\">\n  <img width=\"150\" height=\"230\" src=\"assets/media/wechat_group_20251125.jpg\" hspace=\"10\">\n</p> -->\n\n## Pre-trained Models\n- Motion Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model)\n- Velocity Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model)\n\nPlease read the doc for  [real-world deployment](docs/realworld_deployment.md) for more details on how to use the models. \n\n## Introduction\n\nHoloMotion is a foundation model for humanoid robotics, designed to fullfill robust, real-time, and generalizable whole-body control.\n\nOur framework provides an end-to-end solution, encompassing the entire workflow from data curation and motion retargeting to distributed model training, evaluation, and seamless deployment on physical hardware via ROS2. HoloMotion's modular architecture allows for flexible adaptation and extension, enabling researchers and developers to build and benchmark agents that can imitate, generalize, and master complex whole-body motions.\n\nFor those at the forefront of creating the next generation of humanoid robots, HoloMotion serves as a powerful, extensible, and open-source foundation for achieving whole-body control.\n\n---\n\n### 🛠️ Roadmap: Progress Toward Any Humanoid Control\n\nWe envision HoloMotion as a general-purpose foundation for humanoid motion and control. 
Its development is structured around four core generalization goals: Any Pose, Any Command, Any Terrain, and Any Embodiment. Each goal corresponds to a major version milestone.\n\n| Version  | Target Capability | Description                                                                                                                         |\n| -------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------- |\n| **v1.0** | 🔄 Any Pose       | Achieve robust tracking and imitation of diverse, whole-body human motions, forming the core of the imitation learning capability.  |\n| **v2.0** | ⏳ Any Command    | Enable language- and task-conditioned motion generation, allowing for goal-directed and interactive behaviors.                      |\n| **v3.0** | ⏳ Any Terrain    | Master adaptation to uneven, dynamic, and complex terrains, enhancing real-world operational robustness.                            |\n| **v4.0** | ⏳ Any Embodiment | Generalize control policies across humanoids with varying morphologies and kinematics, achieving true embodiment-level abstraction. |\n\n> Each stage builds on the previous one, moving from motion imitation to instruction following, terrain adaptation, and embodiment-level generalization.\n\n## Pipeline Overview\n\n```mermaid\nflowchart LR\n    A[\"🔧 1. Environment Setup<br/>Dependencies & conda\"]\n\n    subgraph dataFrame [\"DATA\"]\n        B[\"📊 2. Dataset Preparation<br/>Download & curate\"]\n        C[\"🔄 3. Motion Retargeting<br/>Human to robot motion\"]\n        B --> C\n    end\n\n    subgraph modelFrame [\"TRAIN & EVAL\"]\n        D[\"🧠 4. Model Training<br/>Train with HoloMotion\"]\n        E[\"📈 5. Evaluation<br/>Test & export\"]\n        D --> E\n    end\n\n    F[\"🚀 6. Deployment<br/>Deploy to robots\"]\n\n    A --> dataFrame\n    dataFrame --> modelFrame\n    modelFrame --> F\n\n    classDef subgraphStyle fill:#f9f9f9,stroke:#333,stroke-width:2px,stroke-dasharray:5 5,rx:10,ry:10,font-size:16px,font-weight:bold\n    classDef nodeStyle fill:#e1f5fe,stroke:#0277bd,stroke-width:2px,rx:10,ry:10\n\n    class dataFrame,modelFrame subgraphStyle\n    class A,B,C,D,E,F nodeStyle\n```\n\n## Quick Start\n\n### 🔧 1. Environment Setup [[Doc](docs/environment_setup.md)]\n\nSet up your development and deployment environments using Conda. This initial step ensures all dependencies are correctly configured for both training and real-world execution.\n\nIf you only intend to use our pretrained models, you can skip the training environment setup and proceed directly to configure the deployment environment. See the [real-world deployment documentation](docs/realworld_deployment.md) for details.\n\n### 📊 2. Dataset Preparation [[Doc](docs/smpl_data_curation.md)]\n\nAcquire and process large-scale motion datasets. Our tools help you curate high-quality data by converting it to the AMASS-compatible smpl format and filtering out anomalies using kinematic metrics.\n\n### 🔄 3. Motion Retargeting [[Doc](docs/motion_retargeting.md)]\n\nTranslate human motion data into robot-specific kinematic data. Our pipeline leverages [GMR](https://github.com/YanjieZe/GMR) to map human movements onto your robot's morphology, producing optimized HDF5 datasets ready for high-speed, distributed training.\n\n### 🧠 4. Model Training [[Doc](docs/train_motion_tracking.md)]\n\nTrain your foundation model using our reinforcement learning framework. 
HoloMotion supports versatile training tasks, including motion tracking and velocity tracking.\n\n### 📈 5. Evaluation [[Doc](docs/evaluate_motion_tracking.md)]\n\nEvaluate your trained policies in IsaacLab, visualize performance, and export trained models in ONNX format for seamless deployment.\n\n### 🚀 6. Real-world Deployment [[Doc](docs/realworld_deployment.md)]\n\nOur ROS2 package facilitates the deployment of the exported ONNX models, enabling real-time control on hardware such as the Unitree G1.\n\n## Join Us\n\nWe are hiring full-time engineers, new graduates, and interns who are excited about humanoid robots, motion control, and embodied intelligence.\nScan the **WeChat** QR code below to get in touch and send us your resume.\n\n<p align=\"center\">\n  <img width=\"420\" height=\"150\" src=\"assets/media/qr_codes.png\" hspace=\"10\">\n</p>\n\n## Citation\n\n```\n@software{HoloMotion,\n  author = {Maiyue Chen and Kaihui Wang and Bo Zhang and Yi Ren and Zihao Zhu and Xihan Ma and Qijun Huang and Zhiyuan Yang and Yucheng Wang and Zhizhong Su},\n  title = {HoloMotion: A Foundation Model for Whole-Body Humanoid Control},\n  year = {2026},\n  month = apr,\n  version = {1.2.0},\n  url = {https://github.com/HorizonRobotics/HoloMotion},\n  license = {Apache-2.0}\n}\n```\n\n## License\n\nThis project is released under the **[Apache 2.0](./LICENSE)** license.\n\n## Acknowledgements\n\nThis project is built upon and inspired by several outstanding open source projects:\n\n- [GMR](https://github.com/YanjieZe/GMR)\n- [BeyondMimic](https://github.com/HybridRobotics/whole_body_tracking/tree/dcecabd8c24c68f59d143fdf8e3a670f420c972d)\n- [ASAP](https://github.com/LeCAR-Lab/ASAP)\n- [Humanoidverse](https://github.com/LeCAR-Lab/HumanoidVerse)\n- [PHC](https://github.com/ZhengyiLuo/PHC)\n- [ProtoMotion](https://github.com/NVlabs/ProtoMotions/tree/main/protomotions)\n- [Mink](https://github.com/kevinzakka/mink)\n- [PBHC](https://github.com/TeleHuman/PBHC)\n"
  },
  {
    "path": "assets/robots/unitree/G1/29dof/g1_29dof.xml",
    "content": "<mujoco model=\"g1_29dof\">\n  <compiler angle=\"radian\" meshdir=\"../meshes\" />\n\n  <default>\n    <default class=\"torso_motor\">\n      <joint damping=\"0.05\" armature=\"0.01\" frictionloss=\"0.2\"/>\n    </default>\n    <default class=\"leg_motor\">\n      <joint damping=\"0.05\" armature=\"0.01\" frictionloss=\"0.2\"/>\n    </default>\n    <default class=\"ankle_motor\">\n      <joint damping=\"0.05\" armature=\"0.01\" frictionloss=\"0.2\"/>\n    </default>\n    <default class=\"arm_motor\">\n      <joint damping=\"0.05\" armature=\"0.01\" frictionloss=\"0.2\"/>\n    </default>\n    <default class=\"wrist_motor\">\n      <joint damping=\"0.05\" armature=\"0.01\" frictionloss=\"0.1\"/>\n    </default>\n    \n  </default>\n\n  <asset>\n    <mesh name=\"pelvis\" file=\"pelvis.STL\" />\n    <mesh name=\"pelvis_contour_link\" file=\"pelvis_contour_link.STL\" />\n    <mesh name=\"left_hip_pitch_link\" file=\"left_hip_pitch_link.STL\" />\n    <mesh name=\"left_hip_roll_link\" file=\"left_hip_roll_link.STL\" />\n    <mesh name=\"left_hip_yaw_link\" file=\"left_hip_yaw_link.STL\" />\n    <mesh name=\"left_knee_link\" file=\"left_knee_link.STL\" />\n    <mesh name=\"left_ankle_pitch_link\" file=\"left_ankle_pitch_link.STL\" />\n    <mesh name=\"left_ankle_roll_link\" file=\"left_ankle_roll_link.STL\" />\n    <mesh name=\"right_hip_pitch_link\" file=\"right_hip_pitch_link.STL\" />\n    <mesh name=\"right_hip_roll_link\" file=\"right_hip_roll_link.STL\" />\n    <mesh name=\"right_hip_yaw_link\" file=\"right_hip_yaw_link.STL\" />\n    <mesh name=\"right_knee_link\" file=\"right_knee_link.STL\" />\n    <mesh name=\"right_ankle_pitch_link\" file=\"right_ankle_pitch_link.STL\" />\n    <mesh name=\"right_ankle_roll_link\" file=\"right_ankle_roll_link.STL\" />\n    <mesh name=\"waist_yaw_link\" file=\"waist_yaw_link.STL\" />\n    <mesh name=\"waist_roll_link\" file=\"waist_roll_link.STL\" />\n    <mesh name=\"torso_link\" file=\"torso_link.STL\" />\n    <mesh name=\"logo_link\" file=\"logo_link.STL\" />\n    <mesh name=\"head_link\" file=\"head_link.STL\" />\n    <mesh name=\"waist_support_link\" file=\"waist_support_link.STL\" />\n    <mesh name=\"left_shoulder_pitch_link\" file=\"left_shoulder_pitch_link.STL\" />\n    <mesh name=\"left_shoulder_roll_link\" file=\"left_shoulder_roll_link.STL\" />\n    <mesh name=\"left_shoulder_yaw_link\" file=\"left_shoulder_yaw_link.STL\" />\n    <mesh name=\"left_elbow_link\" file=\"left_elbow_link.STL\" />\n    <mesh name=\"left_wrist_roll_link\" file=\"left_wrist_roll_link.STL\" />\n    <mesh name=\"left_wrist_pitch_link\" file=\"left_wrist_pitch_link.STL\" />\n    <mesh name=\"left_wrist_yaw_link\" file=\"left_wrist_yaw_link.STL\" />\n    <mesh name=\"left_rubber_hand\" file=\"left_rubber_hand.STL\" />\n    <mesh name=\"right_shoulder_pitch_link\" file=\"right_shoulder_pitch_link.STL\" />\n    <mesh name=\"right_shoulder_roll_link\" file=\"right_shoulder_roll_link.STL\" />\n    <mesh name=\"right_shoulder_yaw_link\" file=\"right_shoulder_yaw_link.STL\" />\n    <mesh name=\"right_elbow_link\" file=\"right_elbow_link.STL\" />\n    <mesh name=\"right_wrist_roll_link\" file=\"right_wrist_roll_link.STL\" />\n    <mesh name=\"right_wrist_pitch_link\" file=\"right_wrist_pitch_link.STL\" />\n    <mesh name=\"right_wrist_yaw_link\" file=\"right_wrist_yaw_link.STL\" />\n    <mesh name=\"right_rubber_hand\" file=\"right_rubber_hand.STL\" />\n  </asset>\n\n  <worldbody>\n    <body name=\"pelvis\" pos=\"0 0 0.793\">\n      <site name=\"imu\" size=\"0.01\" 
pos=\"0.0 0.0 0.0\" />\n      <inertial pos=\"0 0 -0.07605\" quat=\"1 0 -0.000399148 0\" mass=\"3.813\"\n        diaginertia=\"0.010549 0.0093089 0.0079184\" />\n      <joint name=\"floating_base_joint\" type=\"free\" limited=\"false\" actuatorfrclimited=\"false\" />\n      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\"\n        mesh=\"pelvis\" />\n      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n        mesh=\"pelvis_contour_link\" />\n      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"pelvis_contour_link\" />\n      <body name=\"left_hip_pitch_link\" pos=\"0 0.064452 -0.1027\">\n        <inertial pos=\"0.002741 0.047791 -0.02606\" quat=\"0.954862 0.293964 0.0302556 0.030122\"\n          mass=\"1.35\" diaginertia=\"0.00181517 0.00153422 0.00116212\" />\n        <joint name=\"left_hip_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-2.5307 2.8798\"\n          actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\"\n          mesh=\"left_hip_pitch_link\" />\n        <geom type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"left_hip_pitch_link\" />\n        <body name=\"left_hip_roll_link\" pos=\"0 0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n          <inertial pos=\"0.029812 -0.001045 -0.087934\"\n            quat=\"0.977808 -1.97119e-05 0.205576 -0.0403793\" mass=\"1.52\"\n            diaginertia=\"0.00254986 0.00241169 0.00148755\" />\n          <joint name=\"left_hip_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.5236 2.9671\"\n            actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n            mesh=\"left_hip_roll_link\" />\n          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_roll_link\" />\n          <body name=\"left_hip_yaw_link\" pos=\"0.025001 0 -0.12412\">\n            <inertial pos=\"-0.057709 -0.010981 -0.15078\" quat=\"0.600598 0.15832 0.223482 0.751181\"\n              mass=\"1.702\" diaginertia=\"0.00776166 0.00717575 0.00160139\" />\n            <joint name=\"left_hip_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.7576 2.7576\"\n              actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n              mesh=\"left_hip_yaw_link\" />\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_yaw_link\" />\n            <body name=\"left_knee_link\" pos=\"-0.078273 0.0021489 -0.17734\"\n              quat=\"0.996179 0 0.0873386 0\">\n              <inertial pos=\"0.005457 0.003964 -0.12074\"\n                quat=\"0.923418 -0.0327699 0.0158246 0.382067\" mass=\"1.932\"\n                diaginertia=\"0.0113804 0.0112778 0.00146458\" />\n              <joint name=\"left_knee_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.087267 2.8798\"\n                actuatorfrcrange=\"-139 139\" class=\"leg_motor\" />\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                rgba=\"0.7 0.7 0.7 1\" mesh=\"left_knee_link\" />\n              <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_knee_link\" />\n              <body name=\"left_ankle_pitch_link\" pos=\"0 -9.4445e-05 -0.30001\">\n                <inertial pos=\"-0.007269 0 
0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\"\n                  mass=\"0.074\" diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\" />\n                <joint name=\"left_ankle_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                  range=\"-0.87267 0.5236\" actuatorfrcrange=\"-50 50\"  class=\"ankle_motor\"/>\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                  rgba=\"0.7 0.7 0.7 1\" mesh=\"left_ankle_pitch_link\" />\n                <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_ankle_pitch_link\" />\n                <body name=\"left_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                  <inertial pos=\"0.026505 0 -0.016425\"\n                    quat=\"-0.000481092 0.728482 -0.000618967 0.685065\" mass=\"0.608\"\n                    diaginertia=\"0.00167218 0.0016161 0.000217621\" />\n                  <joint name=\"left_ankle_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                    range=\"-0.2618 0.2618\" actuatorfrcrange=\"-50 50\" class=\"ankle_motor\"/>\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                    rgba=\"0.2 0.2 0.2 1\" mesh=\"left_ankle_roll_link\" />\n                  <geom size=\"0.005\" pos=\"-0.05 0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"-0.05 -0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"0.12 0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"0.12 -0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                </body>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_hip_pitch_link\" pos=\"0 -0.064452 -0.1027\">\n        <inertial pos=\"0.002741 -0.047791 -0.02606\" quat=\"0.954862 -0.293964 0.0302556 -0.030122\"\n          mass=\"1.35\" diaginertia=\"0.00181517 0.00153422 0.00116212\" />\n        <joint name=\"right_hip_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-2.5307 2.8798\"\n          actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\"\n          mesh=\"right_hip_pitch_link\" />\n        <geom type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"right_hip_pitch_link\" />\n        <body name=\"right_hip_roll_link\" pos=\"0 -0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n          <inertial pos=\"0.029812 0.001045 -0.087934\" quat=\"0.977808 1.97119e-05 0.205576 0.0403793\"\n            mass=\"1.52\" diaginertia=\"0.00254986 0.00241169 0.00148755\" />\n          <joint name=\"right_hip_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-2.9671 0.5236\"\n            actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n            mesh=\"right_hip_roll_link\" />\n          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_roll_link\" />\n          <body name=\"right_hip_yaw_link\" pos=\"0.025001 0 -0.12412\">\n            <inertial pos=\"-0.057709 0.010981 -0.15078\" quat=\"0.751181 0.223482 0.15832 0.600598\"\n              mass=\"1.702\" diaginertia=\"0.00776166 0.00717575 0.00160139\" />\n            <joint name=\"right_hip_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.7576 2.7576\"\n              actuatorfrcrange=\"-88 88\" class=\"leg_motor\" />\n            <geom type=\"mesh\" 
contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n              mesh=\"right_hip_yaw_link\" />\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_yaw_link\" />\n            <body name=\"right_knee_link\" pos=\"-0.078273 -0.0021489 -0.17734\"\n              quat=\"0.996179 0 0.0873386 0\">\n              <inertial pos=\"0.005457 -0.003964 -0.12074\"\n                quat=\"0.923439 0.0345276 0.0116333 -0.382012\" mass=\"1.932\"\n                diaginertia=\"0.011374 0.0112843 0.00146452\" />\n              <joint name=\"right_knee_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.087267 2.8798\"\n                actuatorfrcrange=\"-139 139\" class=\"leg_motor\" />\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                rgba=\"0.7 0.7 0.7 1\" mesh=\"right_knee_link\" />\n              <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_knee_link\" />\n              <body name=\"right_ankle_pitch_link\" pos=\"0 9.4445e-05 -0.30001\">\n                <inertial pos=\"-0.007269 0 0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\"\n                  mass=\"0.074\" diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\" />\n                <joint name=\"right_ankle_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                  range=\"-0.87267 0.5236\" actuatorfrcrange=\"-50 50\" class=\"ankle_motor\" />\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                  rgba=\"0.7 0.7 0.7 1\" mesh=\"right_ankle_pitch_link\" />\n                <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_ankle_pitch_link\" />\n                <body name=\"right_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                  <inertial pos=\"0.026505 0 -0.016425\"\n                    quat=\"0.000481092 0.728482 0.000618967 0.685065\" mass=\"0.608\"\n                    diaginertia=\"0.00167218 0.0016161 0.000217621\" />\n                  <joint name=\"right_ankle_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                    range=\"-0.2618 0.2618\" actuatorfrcrange=\"-50 50\" class=\"ankle_motor\" />\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                    rgba=\"0.2 0.2 0.2 1\" mesh=\"right_ankle_roll_link\" />\n                  <geom size=\"0.005\" pos=\"-0.05 0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"-0.05 -0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"0.12 0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                  <geom size=\"0.005\" pos=\"0.12 -0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\" />\n                </body>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n      <body name=\"waist_yaw_link\">\n        <inertial pos=\"0.003964 0 0.018769\" quat=\"-0.0178291 0.628464 0.0282471 0.777121\"\n          mass=\"0.244\" diaginertia=\"0.000158561 0.000124229 9.67669e-05\" />\n        <joint name=\"waist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.618 2.618\"\n          actuatorfrcrange=\"-88 88\" class=\"torso_motor\"/>\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n          mesh=\"waist_yaw_link\" />\n        <body name=\"waist_roll_link\" pos=\"-0.0039635 0 0.035\">\n          <inertial pos=\"0 -0.000236 0.010111\" quat=\"0.99979 0.020492 0 0\" mass=\"0.047\"\n            diaginertia=\"7.515e-06 
6.40206e-06 3.98394e-06\" />\n          <joint name=\"waist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.52 0.52\"\n            actuatorfrcrange=\"-50 50\" class=\"torso_motor\"/>\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n            mesh=\"waist_roll_link\" />\n          <body name=\"torso_link\" pos=\"0 0 0.019\">\n            <inertial pos=\"0.00331658 0.000261533 0.179856\"\n              quat=\"0.999831 0.000376204 0.0179895 -0.00377704\" mass=\"9.598\"\n              diaginertia=\"0.12407 0.111951 0.0325382\" />\n            <joint name=\"waist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.52 0.52\"\n              actuatorfrcrange=\"-50 50\" class=\"torso_motor\"/>\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n              mesh=\"torso_link\" />\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"torso_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" quat=\"1 0 0 0\" type=\"mesh\" contype=\"0\" conaffinity=\"0\"\n              group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"logo_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" quat=\"1 0 0 0\" type=\"mesh\" rgba=\"0.2 0.2 0.2 1\"\n              mesh=\"logo_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\"\n              density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"head_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"head_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" quat=\"1 0 0 0\" type=\"mesh\" contype=\"0\" conaffinity=\"0\"\n              group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"waist_support_link\" />\n            <geom pos=\"0.0039635 0 -0.054\" quat=\"1 0 0 0\" type=\"mesh\" rgba=\"0.7 0.7 0.7 1\"\n              mesh=\"waist_support_link\" />\n            <body name=\"left_shoulder_pitch_link\" pos=\"0.0039563 0.10022 0.23778\"\n              quat=\"0.990264 0.139201 1.38722e-05 -9.86868e-05\">\n              <inertial pos=\"0 0.035892 -0.011628\" quat=\"0.654152 0.0130458 -0.326267 0.68225\"\n                mass=\"0.718\" diaginertia=\"0.000465864 0.000432842 0.000406394\" />\n              <joint name=\"left_shoulder_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                range=\"-3.0892 2.6704\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_pitch_link\" />\n              <geom size=\"0.03 0.025\" pos=\"0 0.04 -0.01\" quat=\"0.707107 0 0.707107 0\"\n                type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\" />\n              <body name=\"left_shoulder_roll_link\" pos=\"0 0.038 -0.013831\"\n                quat=\"0.990268 -0.139172 0 0\">\n                <inertial pos=\"-0.000227 0.00727 -0.063243\"\n                  quat=\"0.701256 -0.0196223 -0.00710317 0.712604\" mass=\"0.643\"\n                  diaginertia=\"0.000691311 0.000618011 0.000388977\" />\n                <joint name=\"left_shoulder_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                  range=\"-1.5882 2.2515\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                  rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_roll_link\" />\n                <geom size=\"0.03 
0.015\" pos=\"-0.004 0.006 -0.053\" type=\"cylinder\"\n                  rgba=\"0.7 0.7 0.7 1\" />\n                <body name=\"left_shoulder_yaw_link\" pos=\"0 0.00624 -0.1032\">\n                  <inertial pos=\"0.010773 -0.002949 -0.072009\"\n                    quat=\"0.716879 -0.0964829 -0.0679942 0.687134\" mass=\"0.734\"\n                    diaginertia=\"0.00106187 0.00103217 0.000400661\" />\n                  <joint name=\"left_shoulder_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\"\n                    range=\"-2.618 2.618\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                    rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_yaw_link\" />\n                  <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_yaw_link\" />\n                  <body name=\"left_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                    <inertial pos=\"0.064956 0.004454 -0.010062\"\n                      quat=\"0.541765 0.636132 0.388821 0.388129\" mass=\"0.6\"\n                      diaginertia=\"0.000443035 0.000421612 0.000259353\" />\n                    <joint name=\"left_elbow_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.0472 2.0944\"\n                      actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                    <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                      rgba=\"0.7 0.7 0.7 1\" mesh=\"left_elbow_link\" />\n                    <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_elbow_link\" />\n                    <body name=\"left_wrist_roll_link\" pos=\"0.1 0.00188791 -0.01\">\n                      <inertial pos=\"0.0171394 0.000537591 4.8864e-07\"\n                        quat=\"0.575338 0.411667 -0.574906 0.411094\" mass=\"0.085445\"\n                        diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\" />\n                      <joint name=\"left_wrist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                        range=\"-1.97222 1.97222\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                        rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_roll_link\" />\n                      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_roll_link\" />\n                      <body name=\"left_wrist_pitch_link\" pos=\"0.038 0 0\">\n                        <inertial pos=\"0.0229999 -0.00111685 -0.00111658\"\n                          quat=\"0.249998 0.661363 0.293036 0.643608\" mass=\"0.48405\"\n                          diaginertia=\"0.000430353 0.000429873 0.000164648\" />\n                        <joint name=\"left_wrist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                          range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" class=\"wrist_motor\" />\n                        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                          rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_pitch_link\" />\n                        <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_pitch_link\" />\n                        <body name=\"left_wrist_yaw_link\" pos=\"0.046 0 0\">\n                          <inertial pos=\"0.0708244 0.000191745 0.00161742\"\n                            quat=\"0.510571 0.526295 0.468078 0.493188\" mass=\"0.254576\"\n                            diaginertia=\"0.000646113 0.000559993 
0.000147566\" />\n                          <joint name=\"left_wrist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\"\n                            range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" class=\"wrist_motor\" />\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                            rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_yaw_link\" />\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_yaw_link\" />\n                          <geom pos=\"0.0415 0.003 0\" quat=\"1 0 0 0\" type=\"mesh\" contype=\"0\"\n                            conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n                            mesh=\"left_rubber_hand\" />\n                        </body>\n                      </body>\n                    </body>\n                  </body>\n                </body>\n              </body>\n            </body>\n            <body name=\"right_shoulder_pitch_link\" pos=\"0.0039563 -0.10021 0.23778\"\n              quat=\"0.990264 -0.139201 1.38722e-05 9.86868e-05\">\n              <inertial pos=\"0 -0.035892 -0.011628\" quat=\"0.68225 -0.326267 0.0130458 0.654152\"\n                mass=\"0.718\" diaginertia=\"0.000465864 0.000432842 0.000406394\" />\n              <joint name=\"right_shoulder_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                range=\"-3.0892 2.6704\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_pitch_link\" />\n              <geom size=\"0.03 0.025\" pos=\"0 -0.04 -0.01\" quat=\"0.707107 0 0.707107 0\"\n                type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\" />\n              <body name=\"right_shoulder_roll_link\" pos=\"0 -0.038 -0.013831\"\n                quat=\"0.990268 0.139172 0 0\">\n                <inertial pos=\"-0.000227 -0.00727 -0.063243\"\n                  quat=\"0.712604 -0.00710317 -0.0196223 0.701256\" mass=\"0.643\"\n                  diaginertia=\"0.000691311 0.000618011 0.000388977\" />\n                <joint name=\"right_shoulder_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                  range=\"-2.2515 1.5882\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                  rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_roll_link\" />\n                <geom size=\"0.03 0.015\" pos=\"-0.004 -0.006 -0.053\" type=\"cylinder\"\n                  rgba=\"0.7 0.7 0.7 1\" />\n                <body name=\"right_shoulder_yaw_link\" pos=\"0 -0.00624 -0.1032\">\n                  <inertial pos=\"0.010773 0.002949 -0.072009\"\n                    quat=\"0.687134 -0.0679942 -0.0964829 0.716879\" mass=\"0.734\"\n                    diaginertia=\"0.00106187 0.00103217 0.000400661\" />\n                  <joint name=\"right_shoulder_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\"\n                    range=\"-2.618 2.618\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                    rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_yaw_link\" />\n                  <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_yaw_link\" />\n                  <body name=\"right_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                    <inertial pos=\"0.064956 -0.004454 
-0.010062\"\n                      quat=\"0.388129 0.388821 0.636132 0.541765\" mass=\"0.6\"\n                      diaginertia=\"0.000443035 0.000421612 0.000259353\" />\n                    <joint name=\"right_elbow_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.0472 2.0944\"\n                      actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                    <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                      rgba=\"0.7 0.7 0.7 1\" mesh=\"right_elbow_link\" />\n                    <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_elbow_link\" />\n                    <body name=\"right_wrist_roll_link\" pos=\"0.1 -0.00188791 -0.01\">\n                      <inertial pos=\"0.0171394 -0.000537591 4.8864e-07\"\n                        quat=\"0.411667 0.575338 -0.411094 0.574906\" mass=\"0.085445\"\n                        diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\" />\n                      <joint name=\"right_wrist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\"\n                        range=\"-1.97222 1.97222\" actuatorfrcrange=\"-25 25\" class=\"arm_motor\" />\n                      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                        rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_roll_link\" />\n                      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_roll_link\" />\n                      <body name=\"right_wrist_pitch_link\" pos=\"0.038 0 0\">\n                        <inertial pos=\"0.0229999 0.00111685 -0.00111658\"\n                          quat=\"0.643608 0.293036 0.661363 0.249998\" mass=\"0.48405\"\n                          diaginertia=\"0.000430353 0.000429873 0.000164648\" />\n                        <joint name=\"right_wrist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\"\n                          range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" class=\"wrist_motor\" />\n                        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                          rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_pitch_link\" />\n                        <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_pitch_link\" />\n                        <body name=\"right_wrist_yaw_link\" pos=\"0.046 0 0\">\n                          <inertial pos=\"0.0708244 -0.000191745 0.00161742\"\n                            quat=\"0.493188 0.468078 0.526295 0.510571\" mass=\"0.254576\"\n                            diaginertia=\"0.000646113 0.000559993 0.000147566\" />\n                          <joint name=\"right_wrist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\"\n                            range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" class=\"wrist_motor\" />\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\"\n                            rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_yaw_link\" />\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_yaw_link\" />\n                          <geom pos=\"0.0415 -0.003 0\" quat=\"1 0 0 0\" type=\"mesh\" contype=\"0\"\n                            conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\"\n                            mesh=\"right_rubber_hand\" />\n                        </body>\n                      </body>\n                    </body>\n                  </body>\n                </body>\n              </body>\n            </body>\n          </body>\n        
</body>\n      </body>\n    </body>\n  </worldbody>\n\n  <actuator>\n    <motor name=\"left_hip_pitch\" joint=\"left_hip_pitch_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"left_hip_roll\" joint=\"left_hip_roll_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"left_hip_yaw\" joint=\"left_hip_yaw_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"left_knee\" joint=\"left_knee_joint\" ctrlrange=\"-139 139\" />\n    <motor name=\"left_ankle_pitch\" joint=\"left_ankle_pitch_joint\" ctrlrange=\"-50 50\" />\n    <motor name=\"left_ankle_roll\" joint=\"left_ankle_roll_joint\" ctrlrange=\"-50 50\" />\n\n    <motor name=\"right_hip_pitch\" joint=\"right_hip_pitch_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"right_hip_roll\" joint=\"right_hip_roll_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"right_hip_yaw\" joint=\"right_hip_yaw_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"right_knee\" joint=\"right_knee_joint\" ctrlrange=\"-139 139\" />\n    <motor name=\"right_ankle_pitch\" joint=\"right_ankle_pitch_joint\" ctrlrange=\"-50 50\" />\n    <motor name=\"right_ankle_roll\" joint=\"right_ankle_roll_joint\" ctrlrange=\"-50 50\" />\n\n    <motor name=\"waist_yaw\" joint=\"waist_yaw_joint\" ctrlrange=\"-88 88\" />\n    <motor name=\"waist_roll\" joint=\"waist_roll_joint\" ctrlrange=\"-50 50\" />\n    <motor name=\"waist_pitch\" joint=\"waist_pitch_joint\" ctrlrange=\"-50 50\" />\n\n    <motor name=\"left_shoulder_pitch\" joint=\"left_shoulder_pitch_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"left_shoulder_roll\" joint=\"left_shoulder_roll_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"left_shoulder_yaw\" joint=\"left_shoulder_yaw_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"left_elbow\" joint=\"left_elbow_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"left_wrist_roll\" joint=\"left_wrist_roll_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"left_wrist_pitch\" joint=\"left_wrist_pitch_joint\" ctrlrange=\"-5 5\" />\n    <motor name=\"left_wrist_yaw\" joint=\"left_wrist_yaw_joint\" ctrlrange=\"-5 5\" />\n\n    <motor name=\"right_shoulder_pitch\" joint=\"right_shoulder_pitch_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"right_shoulder_roll\" joint=\"right_shoulder_roll_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"right_shoulder_yaw\" joint=\"right_shoulder_yaw_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"right_elbow\" joint=\"right_elbow_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"right_wrist_roll\" joint=\"right_wrist_roll_joint\" ctrlrange=\"-25 25\" />\n    <motor name=\"right_wrist_pitch\" joint=\"right_wrist_pitch_joint\" ctrlrange=\"-5 5\" />\n    <motor name=\"right_wrist_yaw\" joint=\"right_wrist_yaw_joint\" ctrlrange=\"-5 5\" />\n  </actuator>\n\n  <sensor>\n    <jointpos name=\"left_hip_pitch_pos\" joint=\"left_hip_pitch_joint\" />\n    <jointpos name=\"left_hip_roll_pos\" joint=\"left_hip_roll_joint\" />\n    <jointpos name=\"left_hip_yaw_pos\" joint=\"left_hip_yaw_joint\" />\n    <jointpos name=\"left_knee_pos\" joint=\"left_knee_joint\" />\n    <jointpos name=\"left_ankle_pitch_pos\" joint=\"left_ankle_pitch_joint\" />\n    <jointpos name=\"left_ankle_roll_pos\" joint=\"left_ankle_roll_joint\" />\n    <jointpos name=\"right_hip_pitch_pos\" joint=\"right_hip_pitch_joint\" />\n    <jointpos name=\"right_hip_roll_pos\" joint=\"right_hip_roll_joint\" />\n    <jointpos name=\"right_hip_yaw_pos\" joint=\"right_hip_yaw_joint\" />\n    <jointpos name=\"right_knee_pos\" joint=\"right_knee_joint\" />\n    <jointpos 
name=\"right_ankle_pitch_pos\" joint=\"right_ankle_pitch_joint\" />\n    <jointpos name=\"right_ankle_roll_pos\" joint=\"right_ankle_roll_joint\" />\n    <jointpos name=\"waist_yaw_pos\" joint=\"waist_yaw_joint\" />\n    <jointpos name=\"waist_roll_pos\" joint=\"waist_roll_joint\" />\n    <jointpos name=\"waist_pitch_pos\" joint=\"waist_pitch_joint\" />\n    <jointpos name=\"left_shoulder_pitch_pos\" joint=\"left_shoulder_pitch_joint\" />\n    <jointpos name=\"left_shoulder_roll_pos\" joint=\"left_shoulder_roll_joint\" />\n    <jointpos name=\"left_shoulder_yaw_pos\" joint=\"left_shoulder_yaw_joint\" />\n    <jointpos name=\"left_elbow_pos\" joint=\"left_elbow_joint\" />\n    <jointpos name=\"left_wrist_roll_pos\" joint=\"left_wrist_roll_joint\" />\n    <jointpos name=\"left_wrist_pitch_pos\" joint=\"left_wrist_pitch_joint\" />\n    <jointpos name=\"left_wrist_yaw_pos\" joint=\"left_wrist_yaw_joint\" />\n    <jointpos name=\"right_shoulder_pitch_pos\" joint=\"right_shoulder_pitch_joint\" />\n    <jointpos name=\"right_shoulder_roll_pos\" joint=\"right_shoulder_roll_joint\" />\n    <jointpos name=\"right_shoulder_yaw_pos\" joint=\"right_shoulder_yaw_joint\" />\n    <jointpos name=\"right_elbow_pos\" joint=\"right_elbow_joint\" />\n    <jointpos name=\"right_wrist_roll_pos\" joint=\"right_wrist_roll_joint\" />\n    <jointpos name=\"right_wrist_pitch_pos\" joint=\"right_wrist_pitch_joint\" />\n    <jointpos name=\"right_wrist_yaw_pos\" joint=\"right_wrist_yaw_joint\" />\n\n    <jointvel name=\"left_hip_pitch_vel\" joint=\"left_hip_pitch_joint\" />\n    <jointvel name=\"left_hip_roll_vel\" joint=\"left_hip_roll_joint\" />\n    <jointvel name=\"left_hip_yaw_vel\" joint=\"left_hip_yaw_joint\" />\n    <jointvel name=\"left_knee_vel\" joint=\"left_knee_joint\" />\n    <jointvel name=\"left_ankle_pitch_vel\" joint=\"left_ankle_pitch_joint\" />\n    <jointvel name=\"left_ankle_roll_vel\" joint=\"left_ankle_roll_joint\" />\n    <jointvel name=\"right_hip_pitch_vel\" joint=\"right_hip_pitch_joint\" />\n    <jointvel name=\"right_hip_roll_vel\" joint=\"right_hip_roll_joint\" />\n    <jointvel name=\"right_hip_yaw_vel\" joint=\"right_hip_yaw_joint\" />\n    <jointvel name=\"right_knee_vel\" joint=\"right_knee_joint\" />\n    <jointvel name=\"right_ankle_pitch_vel\" joint=\"right_ankle_pitch_joint\" />\n    <jointvel name=\"right_ankle_roll_vel\" joint=\"right_ankle_roll_joint\" />\n    <jointvel name=\"waist_yaw_vel\" joint=\"waist_yaw_joint\" />\n    <jointvel name=\"waist_roll_vel\" joint=\"waist_roll_joint\" />\n    <jointvel name=\"waist_pitch_vel\" joint=\"waist_pitch_joint\" />\n    <jointvel name=\"left_shoulder_pitch_vel\" joint=\"left_shoulder_pitch_joint\" />\n    <jointvel name=\"left_shoulder_roll_vel\" joint=\"left_shoulder_roll_joint\" />\n    <jointvel name=\"left_shoulder_yaw_vel\" joint=\"left_shoulder_yaw_joint\" />\n    <jointvel name=\"left_elbow_vel\" joint=\"left_elbow_joint\" />\n    <jointvel name=\"left_wrist_roll_vel\" joint=\"left_wrist_roll_joint\" />\n    <jointvel name=\"left_wrist_pitch_vel\" joint=\"left_wrist_pitch_joint\" />\n    <jointvel name=\"left_wrist_yaw_vel\" joint=\"left_wrist_yaw_joint\" />\n    <jointvel name=\"right_shoulder_pitch_vel\" joint=\"right_shoulder_pitch_joint\" />\n    <jointvel name=\"right_shoulder_roll_vel\" joint=\"right_shoulder_roll_joint\" />\n    <jointvel name=\"right_shoulder_yaw_vel\" joint=\"right_shoulder_yaw_joint\" />\n    <jointvel name=\"right_elbow_vel\" joint=\"right_elbow_joint\" />\n    <jointvel name=\"right_wrist_roll_vel\" 
joint=\"right_wrist_roll_joint\" />\n    <jointvel name=\"right_wrist_pitch_vel\" joint=\"right_wrist_pitch_joint\" />\n    <jointvel name=\"right_wrist_yaw_vel\" joint=\"right_wrist_yaw_joint\" />\n\n    <jointactuatorfrc name=\"left_hip_pitch_torque\" joint=\"left_hip_pitch_joint\" />\n    <jointactuatorfrc name=\"left_hip_roll_torque\" joint=\"left_hip_roll_joint\" />\n    <jointactuatorfrc name=\"left_hip_yaw_torque\" joint=\"left_hip_yaw_joint\" />\n    <jointactuatorfrc name=\"left_knee_torque\" joint=\"left_knee_joint\" />\n    <jointactuatorfrc name=\"left_ankle_pitch_torque\" joint=\"left_ankle_pitch_joint\" />\n    <jointactuatorfrc name=\"left_ankle_roll_torque\" joint=\"left_ankle_roll_joint\" />\n    <jointactuatorfrc name=\"right_hip_pitch_torque\" joint=\"right_hip_pitch_joint\" />\n    <jointactuatorfrc name=\"right_hip_roll_torque\" joint=\"right_hip_roll_joint\" />\n    <jointactuatorfrc name=\"right_hip_yaw_torque\" joint=\"right_hip_yaw_joint\" />\n    <jointactuatorfrc name=\"right_knee_torque\" joint=\"right_knee_joint\" />\n    <jointactuatorfrc name=\"right_ankle_pitch_torque\" joint=\"right_ankle_pitch_joint\" />\n    <jointactuatorfrc name=\"right_ankle_roll_torque\" joint=\"right_ankle_roll_joint\" />\n    <jointactuatorfrc name=\"waist_yaw_torque\" joint=\"waist_yaw_joint\" />\n    <jointactuatorfrc name=\"waist_roll_torque\" joint=\"waist_roll_joint\" />\n    <jointactuatorfrc name=\"waist_pitch_torque\" joint=\"waist_pitch_joint\" />\n    <jointactuatorfrc name=\"left_shoulder_pitch_torque\" joint=\"left_shoulder_pitch_joint\" />\n    <jointactuatorfrc name=\"left_shoulder_roll_torque\" joint=\"left_shoulder_roll_joint\" />\n    <jointactuatorfrc name=\"left_shoulder_yaw_torque\" joint=\"left_shoulder_yaw_joint\" />\n    <jointactuatorfrc name=\"left_elbow_torque\" joint=\"left_elbow_joint\" />\n    <jointactuatorfrc name=\"left_wrist_roll_torque\" joint=\"left_wrist_roll_joint\" />\n    <jointactuatorfrc name=\"left_wrist_pitch_torque\" joint=\"left_wrist_pitch_joint\" />\n    <jointactuatorfrc name=\"left_wrist_yaw_torque\" joint=\"left_wrist_yaw_joint\" />\n    <jointactuatorfrc name=\"right_shoulder_pitch_torque\" joint=\"right_shoulder_pitch_joint\" />\n    <jointactuatorfrc name=\"right_shoulder_roll_torque\" joint=\"right_shoulder_roll_joint\" />\n    <jointactuatorfrc name=\"right_shoulder_yaw_torque\" joint=\"right_shoulder_yaw_joint\" />\n    <jointactuatorfrc name=\"right_elbow_torque\" joint=\"right_elbow_joint\" />\n    <jointactuatorfrc name=\"right_wrist_roll_torque\" joint=\"right_wrist_roll_joint\" />\n    <jointactuatorfrc name=\"right_wrist_pitch_torque\" joint=\"right_wrist_pitch_joint\" />\n    <jointactuatorfrc name=\"right_wrist_yaw_torque\" joint=\"right_wrist_yaw_joint\" />\n\n    <framequat name=\"imu_quat\" objtype=\"site\" objname=\"imu\" />\n    <gyro name=\"imu_gyro\" site=\"imu\" />\n    <accelerometer name=\"imu_acc\" site=\"imu\" />\n\n    <framepos name=\"frame_pos\" objtype=\"site\" objname=\"imu\" />\n    <framelinvel name=\"frame_vel\" objtype=\"site\" objname=\"imu\" />\n  </sensor>\n\n</mujoco>"
  },
  {
    "path": "assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.urdf",
    "content": "<robot name=\"g1_29dof_rev_1_0\">\r\n  <material name=\"dark\">\r\n    <color rgba=\"0.2 0.2 0.2 1\"/>\r\n  </material>\r\n  <material name=\"white\">\r\n    <color rgba=\"0.7 0.7 0.7 1\"/>\r\n  </material>\r\n\r\n  <mujoco>\r\n    <compiler meshdir=\"../meshes\" discardvisual=\"false\"/>\r\n  </mujoco>\r\n\r\n  <!-- [CAUTION] uncomment when convert to mujoco -->\r\n  <!-- <link name=\"world\"></link>\r\n  <joint name=\"floating_base_joint\" type=\"floating\">\r\n    <parent link=\"world\"/>\r\n    <child link=\"pelvis\"/>\r\n  </joint> -->\r\n\r\n  <link name=\"pelvis\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 -0.07605\" rpy=\"0 0 0\"/>\r\n      <mass value=\"3.813\"/>\r\n      <inertia ixx=\"0.010549\" ixy=\"0\" ixz=\"2.1E-06\" iyy=\"0.0093089\" iyz=\"0\" izz=\"0.0079184\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n  </link>\r\n  <link name=\"pelvis_contour_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.001\"/>\r\n      <inertia ixx=\"1e-7\" ixy=\"0\" ixz=\"0\" iyy=\"1e-7\" iyz=\"0\" izz=\"1e-7\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis_contour_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis_contour_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"pelvis_contour_joint\" type=\"fixed\">\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"pelvis_contour_link\"/>\r\n  </joint>\r\n\r\n  <!-- Legs -->\r\n  <link name=\"left_hip_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.002741 0.047791 -0.02606\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.35\"/>\r\n      <inertia ixx=\"0.001811\" ixy=\"3.68E-05\" ixz=\"-3.44E-05\" iyy=\"0.0014193\" iyz=\"0.000171\" izz=\"0.0012812\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.064452 -0.1027\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"left_hip_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-2.5307\" upper=\"2.8798\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"left_hip_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.029812 -0.001045 -0.087934\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.52\"/>\r\n      <inertia ixx=\"0.0023773\" ixy=\"-3.8E-06\" ixz=\"-0.0003908\" iyy=\"0.0024123\" iyz=\"1.84E-05\" izz=\"0.0016595\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      
<geometry>\r\n        <mesh filename=\"../meshes/left_hip_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.052 -0.030465\" rpy=\"0 -0.1749 0\"/>\r\n    <parent link=\"left_hip_pitch_link\"/>\r\n    <child link=\"left_hip_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.5236\" upper=\"2.9671\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"left_hip_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.057709 -0.010981 -0.15078\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.702\"/>\r\n      <inertia ixx=\"0.0057774\" ixy=\"-0.0005411\" ixz=\"-0.0023948\" iyy=\"0.0076124\" iyz=\"-0.0007072\" izz=\"0.003149\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.025001 0 -0.12412\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_hip_roll_link\"/>\r\n    <child link=\"left_hip_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.7576\" upper=\"2.7576\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"left_knee_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005457 0.003964 -0.12074\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.932\"/>\r\n      <inertia ixx=\"0.011329\" ixy=\"4.82E-05\" ixz=\"-4.49E-05\" iyy=\"0.011277\" iyz=\"-0.0007146\" izz=\"0.0015168\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_knee_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_knee_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_knee_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.078273 0.0021489 -0.17734\" rpy=\"0 0.1749 0\"/>\r\n    <parent link=\"left_hip_yaw_link\"/>\r\n    <child link=\"left_knee_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.087267\" upper=\"2.8798\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"left_ankle_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.007269 0 0.011137\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.074\"/>\r\n      <inertia ixx=\"8.4E-06\" ixy=\"0\" ixz=\"-2.9E-06\" iyy=\"1.89E-05\" iyz=\"0\" izz=\"1.26E-05\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_ankle_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -9.4445E-05 -0.30001\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_knee_link\"/>\r\n    <child link=\"left_ankle_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    
<limit lower=\"-0.87267\" upper=\"0.5236\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"left_ankle_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.026505 0 -0.016425\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.608\"/>\r\n      <inertia ixx=\"0.0002231\" ixy=\"2E-07\" ixz=\"8.91E-05\" iyy=\"0.0016161\" iyz=\"-1E-07\" izz=\"0.0016667\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.05 0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"-0.05 -0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 -0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_ankle_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 -0.017558\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_ankle_pitch_link\"/>\r\n    <child link=\"left_ankle_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.2618\" upper=\"0.2618\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"right_hip_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.002741 -0.047791 -0.02606\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.35\"/>\r\n      <inertia ixx=\"0.001811\" ixy=\"-3.68E-05\" ixz=\"-3.44E-05\" iyy=\"0.0014193\" iyz=\"-0.000171\" izz=\"0.0012812\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.064452 -0.1027\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"right_hip_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-2.5307\" upper=\"2.8798\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"right_hip_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.029812 0.001045 -0.087934\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.52\"/>\r\n      <inertia ixx=\"0.0023773\" ixy=\"3.8E-06\" ixz=\"-0.0003908\" iyy=\"0.0024123\" iyz=\"-1.84E-05\" izz=\"0.0016595\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.052 
-0.030465\" rpy=\"0 -0.1749 0\"/>\r\n    <parent link=\"right_hip_pitch_link\"/>\r\n    <child link=\"right_hip_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-2.9671\" upper=\"0.5236\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"right_hip_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.057709 0.010981 -0.15078\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.702\"/>\r\n      <inertia ixx=\"0.0057774\" ixy=\"0.0005411\" ixz=\"-0.0023948\" iyy=\"0.0076124\" iyz=\"0.0007072\" izz=\"0.003149\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.025001 0 -0.12412\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_hip_roll_link\"/>\r\n    <child link=\"right_hip_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.7576\" upper=\"2.7576\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"right_knee_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005457 -0.003964 -0.12074\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.932\"/>\r\n      <inertia ixx=\"0.011329\" ixy=\"-4.82E-05\" ixz=\"4.49E-05\" iyy=\"0.011277\" iyz=\"0.0007146\" izz=\"0.0015168\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_knee_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_knee_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_knee_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.078273 -0.0021489 -0.17734\" rpy=\"0 0.1749 0\"/>\r\n    <parent link=\"right_hip_yaw_link\"/>\r\n    <child link=\"right_knee_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.087267\" upper=\"2.8798\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"right_ankle_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.007269 0 0.011137\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.074\"/>\r\n      <inertia ixx=\"8.4E-06\" ixy=\"0\" ixz=\"-2.9E-06\" iyy=\"1.89E-05\" iyz=\"0\" izz=\"1.26E-05\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_ankle_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 9.4445E-05 -0.30001\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_knee_link\"/>\r\n    <child link=\"right_ankle_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.87267\" upper=\"0.5236\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"right_ankle_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.026505 0 -0.016425\" 
rpy=\"0 0 0\"/>\r\n      <mass value=\"0.608\"/>\r\n      <inertia ixx=\"0.0002231\" ixy=\"-2E-07\" ixz=\"8.91E-05\" iyy=\"0.0016161\" iyz=\"1E-07\" izz=\"0.0016667\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.05 0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"-0.05 -0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 -0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_ankle_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 -0.017558\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_ankle_pitch_link\"/>\r\n    <child link=\"right_ankle_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.2618\" upper=\"0.2618\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n\r\n  <!-- Torso -->\r\n  <link name=\"waist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.003494 0.000233 0.018034\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.214\"/>\r\n      <inertia ixx=\"0.00010673\" ixy=\"2.703E-06\" ixz=\"-7.631E-06\" iyy=\"0.00010422\" iyz=\"-2.01E-07\" izz=\"0.0001625\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/waist_yaw_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <joint name=\"waist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"waist_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"waist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 2.3E-05 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.086\"/>\r\n      <inertia ixx=\"7.079E-06\" ixy=\"0\" ixz=\"0\" iyy=\"6.339E-06\" iyz=\"0\" izz=\"8.245E-06\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/waist_roll_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <joint name=\"waist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.0039635 0 0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"waist_yaw_link\"/>\r\n    <child link=\"waist_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.52\" upper=\"0.52\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"torso_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.000931 0.000346 0.15082\" rpy=\"0 0 0\"/>\r\n      <mass value=\"6.78\"/>\r\n      <inertia ixx=\"0.05905\" ixy=\"3.3302E-05\" ixz=\"-0.0017715\" iyy=\"0.047014\" iyz=\"-2.2399E-05\" izz=\"0.025652\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh 
filename=\"../meshes/torso_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/torso_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"waist_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"waist_roll_link\"/>\r\n    <child link=\"torso_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.52\" upper=\"0.52\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n\r\n  <!-- LOGO -->\r\n  <joint name=\"logo_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0039635 0 -0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"logo_link\"/>\r\n  </joint>\r\n  <link name=\"logo_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.001\"/>\r\n      <inertia ixx=\"1e-7\" ixy=\"0\" ixz=\"0\" iyy=\"1e-7\" iyz=\"0\" izz=\"1e-7\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/logo_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/logo_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n\r\n  <!-- Head -->\r\n  <link name=\"head_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005267 0.000299 0.449869\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.036\"/>\r\n      <inertia ixx=\"0.004085051\" ixy=\"-2.543E-06\" ixz=\"-6.9455E-05\" iyy=\"0.004185212\" iyz=\"-3.726E-06\" izz=\"0.001807911\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/head_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/head_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"head_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0039635 0 -0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"head_link\"/>\r\n  </joint>\r\n\r\n\r\n  <!-- IMU -->\r\n  <link name=\"imu_in_torso\"></link>\r\n  <joint name=\"imu_in_torso_joint\" type=\"fixed\">\r\n    <origin xyz=\"-0.03959 -0.00224 0.14792\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"imu_in_torso\"/>\r\n  </joint>\r\n\r\n  <link name=\"imu_in_pelvis\"></link>\r\n  <joint name=\"imu_in_pelvis_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.04525 0 -0.08339\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"imu_in_pelvis\"/>\r\n  </joint>\r\n\r\n  <!-- d435 -->\r\n  <link name=\"d435_link\"></link>\r\n  <joint name=\"d435_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0576235 0.01753 0.42987\" rpy=\"0 0.8307767239493009 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"d435_link\"/>\r\n  </joint>\r\n\r\n  <!-- mid360 -->\r\n  <link name=\"mid360_link\"></link>\r\n  <joint name=\"mid360_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0002835 0.00003 0.41618\" rpy=\"0 0.04014257279586953 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"mid360_link\"/>\r\n  </joint>\r\n\r\n  <!-- Arm -->\r\n  <link 
name=\"left_shoulder_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0.035892 -0.011628\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.718\"/>\r\n      <inertia ixx=\"0.0004291\" ixy=\"-9.2E-06\" ixz=\"6.4E-06\" iyy=\"0.000453\" iyz=\"2.26E-05\" izz=\"0.000423\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0.04 -0.01\" rpy=\"0 1.5707963267948966 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.05\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.0039563 0.10022 0.24778\" rpy=\"0.27931 5.4949E-05 -0.00019159\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"left_shoulder_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-3.0892\" upper=\"2.6704\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_shoulder_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.000227 0.00727 -0.063243\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.643\"/>\r\n      <inertia ixx=\"0.0006177\" ixy=\"-1E-06\" ixz=\"8.7E-06\" iyy=\"0.0006912\" iyz=\"-5.3E-06\" izz=\"0.0003894\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.004 0.006 -0.053\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.03\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.038 -0.013831\" rpy=\"-0.27925 0 0\"/>\r\n    <parent link=\"left_shoulder_pitch_link\"/>\r\n    <child link=\"left_shoulder_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-1.5882\" upper=\"2.2515\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_shoulder_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.010773 -0.002949 -0.072009\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.734\"/>\r\n      <inertia ixx=\"0.0009988\" ixy=\"7.9E-06\" ixz=\"0.0001412\" iyy=\"0.0010605\" iyz=\"-2.86E-05\" izz=\"0.0004354\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.00624 -0.1032\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_shoulder_roll_link\"/>\r\n    <child link=\"left_shoulder_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_elbow_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.064956 0.004454 -0.010062\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.6\"/>\r\n      <inertia ixx=\"0.0002891\" ixy=\"6.53E-05\" ixz=\"1.72E-05\" iyy=\"0.0004152\" iyz=\"-5.6E-06\" 
izz=\"0.0004197\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_elbow_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_elbow_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_elbow_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.015783 0 -0.080518\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_shoulder_yaw_link\"/>\r\n    <child link=\"left_elbow_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-1.0472\" upper=\"2.0944\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <joint name=\"left_wrist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.100 0.00188791 -0.010\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <parent link=\"left_elbow_link\"/>\r\n    <child link=\"left_wrist_roll_link\"/>\r\n    <limit effort=\"25\" velocity=\"37\" lower=\"-1.972222054\" upper=\"1.972222054\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.01713944778 0.00053759094 0.00000048864\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08544498\"/>\r\n      <inertia ixx=\"0.00004821544023\" ixy=\"-0.00000424511021\" ixz=\"0.00000000510599\" iyy=\"0.00003722899093\" iyz=\"-0.00000000123525\" izz=\"0.00005482106541\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_wrist_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.038 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <parent link=\"left_wrist_roll_link\"/>\r\n    <child link=\"left_wrist_pitch_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02299989837 -0.00111685314 -0.00111658096\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.48404956\"/>\r\n      <inertia ixx=\"0.00016579646273\" ixy=\"-0.00001231206746\" ixz=\"0.00001231699194\" iyy=\"0.00042954057410\" iyz=\"0.00000081417712\" izz=\"0.00042953697654\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_wrist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.046 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <parent link=\"left_wrist_pitch_link\"/>\r\n    <child link=\"left_wrist_yaw_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02200381568 0.00049485096 0.00053861123\" 
rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08457647\"/>\r\n      <inertia ixx=\"0.00004929128828\" ixy=\"-0.00000045735494\" ixz=\"0.00000445867591\" iyy=\"0.00005973338134\" iyz=\"0.00000043217198\" izz=\"0.00003928083826\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hand_palm_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0415 0.003 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_wrist_yaw_link\"/>\r\n    <child link=\"left_rubber_hand\"/>\r\n  </joint>\r\n  <link name=\"left_rubber_hand\">\r\n    <inertial>\r\n      <origin xyz=\"0.05361310808 -0.00295905240 0.00215413091\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.170\"/>\r\n      <inertia ixx=\"0.00010099485234748\" ixy=\"0.00003618590790516\" ixz=\"-0.00000074301518642\" iyy=\"0.00028135871571621\" iyz=\"0.00000330189743286\" izz=\"0.00021894770413514\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_rubber_hand.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <link name=\"right_shoulder_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 -0.035892 -0.011628\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.718\"/>\r\n      <inertia ixx=\"0.0004291\" ixy=\"9.2E-06\" ixz=\"6.4E-06\" iyy=\"0.000453\" iyz=\"-2.26E-05\" izz=\"0.000423\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 -0.04 -0.01\" rpy=\"0 1.5707963267948966 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.05\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.0039563 -0.10021 0.24778\" rpy=\"-0.27931 5.4949E-05 0.00019159\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"right_shoulder_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-3.0892\" upper=\"2.6704\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_shoulder_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.000227 -0.00727 -0.063243\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.643\"/>\r\n      <inertia ixx=\"0.0006177\" ixy=\"1E-06\" ixz=\"8.7E-06\" iyy=\"0.0006912\" iyz=\"5.3E-06\" izz=\"0.0003894\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.004 -0.006 -0.053\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.03\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.038 -0.013831\" rpy=\"0.27925 0 0\"/>\r\n    <parent link=\"right_shoulder_pitch_link\"/>\r\n    
<child link=\"right_shoulder_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-2.2515\" upper=\"1.5882\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_shoulder_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.010773 0.002949 -0.072009\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.734\"/>\r\n      <inertia ixx=\"0.0009988\" ixy=\"-7.9E-06\" ixz=\"0.0001412\" iyy=\"0.0010605\" iyz=\"2.86E-05\" izz=\"0.0004354\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.00624 -0.1032\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_shoulder_roll_link\"/>\r\n    <child link=\"right_shoulder_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_elbow_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.064956 -0.004454 -0.010062\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.6\"/>\r\n      <inertia ixx=\"0.0002891\" ixy=\"-6.53E-05\" ixz=\"1.72E-05\" iyy=\"0.0004152\" iyz=\"5.6E-06\" izz=\"0.0004197\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_elbow_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_elbow_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_elbow_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.015783 0 -0.080518\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_shoulder_yaw_link\"/>\r\n    <child link=\"right_elbow_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-1.0472\" upper=\"2.0944\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <joint name=\"right_wrist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.100 -0.00188791 -0.010\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <parent link=\"right_elbow_link\"/>\r\n    <child link=\"right_wrist_roll_link\"/>\r\n    <limit effort=\"25\" velocity=\"37\" lower=\"-1.972222054\" upper=\"1.972222054\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.01713944778 -0.00053759094 0.00000048864\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08544498\"/>\r\n      <inertia ixx=\"0.00004821544023\" ixy=\"0.00000424511021\" ixz=\"0.00000000510599\" iyy=\"0.00003722899093\" iyz=\"0.00000000123525\" izz=\"0.00005482106541\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_wrist_pitch_joint\" type=\"revolute\">\r\n    <origin 
xyz=\"0.038 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <parent link=\"right_wrist_roll_link\"/>\r\n    <child link=\"right_wrist_pitch_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02299989837 0.00111685314 -0.00111658096\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.48404956\"/>\r\n      <inertia ixx=\"0.00016579646273\" ixy=\"0.00001231206746\" ixz=\"0.00001231699194\" iyy=\"0.00042954057410\" iyz=\"-0.00000081417712\" izz=\"0.00042953697654\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_wrist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.046 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <parent link=\"right_wrist_pitch_link\"/>\r\n    <child link=\"right_wrist_yaw_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02200381568 -0.00049485096 0.00053861123\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08457647\"/>\r\n      <inertia ixx=\"0.00004929128828\" ixy=\"0.00000045735494\" ixz=\"0.00000445867591\" iyy=\"0.00005973338134\" iyz=\"-0.00000043217198\" izz=\"0.00003928083826\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hand_palm_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0415 -0.003 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_wrist_yaw_link\"/>\r\n    <child link=\"right_rubber_hand\"/>\r\n  </joint>\r\n  <link name=\"right_rubber_hand\">\r\n    <inertial>\r\n      <origin xyz=\"0.05361310808 0.00295905240 0.00215413091\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.170\"/>\r\n      <inertia ixx=\"0.00010099485234748\" ixy=\"-0.00003618590790516\" ixz=\"-0.00000074301518642\" iyy=\"0.00028135871571621\" iyz=\"-0.00000330189743286\" izz=\"0.00021894770413514\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_rubber_hand.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n</robot>"
  },
  {
    "path": "assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.xml",
    "content": "<mujoco model=\"g1_29dof_rev_1_0\">\n  <compiler angle=\"radian\" meshdir=\"../meshes\"/>\n\n  <statistic meansize=\"0.144785\" extent=\"1.23314\" center=\"0.025392 2.0634e-05 -0.245975\"/>\n  <default>\n    <joint damping=\"0.00\" armature=\"0.001\" frictionloss=\"0.03\"/>\n  </default>\n\n  <asset>\n    <mesh name=\"pelvis\" file=\"pelvis.STL\"/>\n    <mesh name=\"pelvis_contour_link\" file=\"pelvis_contour_link.STL\"/>\n    <mesh name=\"left_hip_pitch_link\" file=\"left_hip_pitch_link.STL\"/>\n    <mesh name=\"left_hip_roll_link\" file=\"left_hip_roll_link.STL\"/>\n    <mesh name=\"left_hip_yaw_link\" file=\"left_hip_yaw_link.STL\"/>\n    <mesh name=\"left_knee_link\" file=\"left_knee_link.STL\"/>\n    <mesh name=\"left_ankle_pitch_link\" file=\"left_ankle_pitch_link.STL\"/>\n    <mesh name=\"left_ankle_roll_link\" file=\"left_ankle_roll_link.STL\"/>\n    <mesh name=\"right_hip_pitch_link\" file=\"right_hip_pitch_link.STL\"/>\n    <mesh name=\"right_hip_roll_link\" file=\"right_hip_roll_link.STL\"/>\n    <mesh name=\"right_hip_yaw_link\" file=\"right_hip_yaw_link.STL\"/>\n    <mesh name=\"right_knee_link\" file=\"right_knee_link.STL\"/>\n    <mesh name=\"right_ankle_pitch_link\" file=\"right_ankle_pitch_link.STL\"/>\n    <mesh name=\"right_ankle_roll_link\" file=\"right_ankle_roll_link.STL\"/>\n    <mesh name=\"waist_yaw_link\" file=\"waist_yaw_link_rev_1_0.STL\"/>\n    <mesh name=\"waist_roll_link\" file=\"waist_roll_link_rev_1_0.STL\"/>\n    <mesh name=\"torso_link\" file=\"torso_link_rev_1_0.STL\"/>\n    <mesh name=\"logo_link\" file=\"logo_link.STL\"/>\n    <mesh name=\"head_link\" file=\"head_link.STL\"/>\n    <mesh name=\"left_shoulder_pitch_link\" file=\"left_shoulder_pitch_link.STL\"/>\n    <mesh name=\"left_shoulder_roll_link\" file=\"left_shoulder_roll_link.STL\"/>\n    <mesh name=\"left_shoulder_yaw_link\" file=\"left_shoulder_yaw_link.STL\"/>\n    <mesh name=\"left_elbow_link\" file=\"left_elbow_link.STL\"/>\n    <mesh name=\"left_wrist_roll_link\" file=\"left_wrist_roll_link.STL\"/>\n    <mesh name=\"left_wrist_pitch_link\" file=\"left_wrist_pitch_link.STL\"/>\n    <mesh name=\"left_wrist_yaw_link\" file=\"left_wrist_yaw_link.STL\"/>\n    <mesh name=\"left_rubber_hand\" file=\"left_rubber_hand.STL\"/>\n    <mesh name=\"right_shoulder_pitch_link\" file=\"right_shoulder_pitch_link.STL\"/>\n    <mesh name=\"right_shoulder_roll_link\" file=\"right_shoulder_roll_link.STL\"/>\n    <mesh name=\"right_shoulder_yaw_link\" file=\"right_shoulder_yaw_link.STL\"/>\n    <mesh name=\"right_elbow_link\" file=\"right_elbow_link.STL\"/>\n    <mesh name=\"right_wrist_roll_link\" file=\"right_wrist_roll_link.STL\"/>\n    <mesh name=\"right_wrist_pitch_link\" file=\"right_wrist_pitch_link.STL\"/>\n    <mesh name=\"right_wrist_yaw_link\" file=\"right_wrist_yaw_link.STL\"/>\n    <mesh name=\"right_rubber_hand\" file=\"right_rubber_hand.STL\"/>\n  </asset>\n\n  <worldbody>\n    <body name=\"pelvis\" pos=\"0 0 0.793\">\n      <inertial pos=\"0 0 -0.07605\" quat=\"1 0 -0.000399148 0\" mass=\"3.813\" diaginertia=\"0.010549 0.0093089 0.0079184\"/>\n      <joint name=\"floating_base_joint\" type=\"free\" limited=\"false\" actuatorfrclimited=\"false\"/>\n      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"pelvis\"/>\n      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"pelvis_contour_link\"/>\n      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" 
mesh=\"pelvis_contour_link\"/>\n      <site name=\"imu_in_pelvis\" size=\"0.01\" pos=\"0.04525 0 -0.08339\"/>\n      <body name=\"left_hip_pitch_link\" pos=\"0 0.064452 -0.1027\">\n        <inertial pos=\"0.002741 0.047791 -0.02606\" quat=\"0.954862 0.293964 0.0302556 0.030122\" mass=\"1.35\" diaginertia=\"0.00181517 0.00153422 0.00116212\"/>\n        <joint name=\"left_hip_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-2.5307 2.8798\" actuatorfrcrange=\"-88 88\"/>\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"left_hip_pitch_link\"/>\n        <geom type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"left_hip_pitch_link\"/>\n        <body name=\"left_hip_roll_link\" pos=\"0 0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n          <inertial pos=\"0.029812 -0.001045 -0.087934\" quat=\"0.977808 -1.97119e-05 0.205576 -0.0403793\" mass=\"1.52\" diaginertia=\"0.00254986 0.00241169 0.00148755\"/>\n          <joint name=\"left_hip_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.5236 2.9671\" actuatorfrcrange=\"-139 139\"/>\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_roll_link\"/>\n          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_roll_link\"/>\n          <body name=\"left_hip_yaw_link\" pos=\"0.025001 0 -0.12412\">\n            <inertial pos=\"-0.057709 -0.010981 -0.15078\" quat=\"0.600598 0.15832 0.223482 0.751181\" mass=\"1.702\" diaginertia=\"0.00776166 0.00717575 0.00160139\"/>\n            <joint name=\"left_hip_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.7576 2.7576\" actuatorfrcrange=\"-88 88\"/>\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_yaw_link\"/>\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_hip_yaw_link\"/>\n            <body name=\"left_knee_link\" pos=\"-0.078273 0.0021489 -0.17734\" quat=\"0.996179 0 0.0873386 0\">\n              <inertial pos=\"0.005457 0.003964 -0.12074\" quat=\"0.923418 -0.0327699 0.0158246 0.382067\" mass=\"1.932\" diaginertia=\"0.0113804 0.0112778 0.00146458\"/>\n              <joint name=\"left_knee_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.087267 2.8798\" actuatorfrcrange=\"-139 139\"/>\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_knee_link\"/>\n              <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_knee_link\"/>\n              <body name=\"left_ankle_pitch_link\" pos=\"0 -9.4445e-05 -0.30001\">\n                <inertial pos=\"-0.007269 0 0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\" mass=\"0.074\" diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\"/>\n                <joint name=\"left_ankle_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.87267 0.5236\" actuatorfrcrange=\"-50 50\"/>\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_ankle_pitch_link\"/>\n                <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_ankle_pitch_link\"/>\n                <body name=\"left_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                  <inertial pos=\"0.026505 0 -0.016425\" quat=\"-0.000481092 0.728482 -0.000618967 0.685065\" mass=\"0.608\" diaginertia=\"0.00167218 0.0016161 0.000217621\"/>\n                  <joint name=\"left_ankle_roll_joint\" pos=\"0 0 0\" axis=\"1 0 
0\" range=\"-0.2618 0.2618\" actuatorfrcrange=\"-50 50\"/>\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"left_ankle_roll_link\"/>\n                  <geom size=\"0.005\" pos=\"-0.05 0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"-0.05 -0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"0.12 0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"0.12 -0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                </body>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_hip_pitch_link\" pos=\"0 -0.064452 -0.1027\">\n        <inertial pos=\"0.002741 -0.047791 -0.02606\" quat=\"0.954862 -0.293964 0.0302556 -0.030122\" mass=\"1.35\" diaginertia=\"0.00181517 0.00153422 0.00116212\"/>\n        <joint name=\"right_hip_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-2.5307 2.8798\" actuatorfrcrange=\"-88 88\"/>\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"right_hip_pitch_link\"/>\n        <geom type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"right_hip_pitch_link\"/>\n        <body name=\"right_hip_roll_link\" pos=\"0 -0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n          <inertial pos=\"0.029812 0.001045 -0.087934\" quat=\"0.977808 1.97119e-05 0.205576 0.0403793\" mass=\"1.52\" diaginertia=\"0.00254986 0.00241169 0.00148755\"/>\n          <joint name=\"right_hip_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-2.9671 0.5236\" actuatorfrcrange=\"-139 139\"/>\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_roll_link\"/>\n          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_roll_link\"/>\n          <body name=\"right_hip_yaw_link\" pos=\"0.025001 0 -0.12412\">\n            <inertial pos=\"-0.057709 0.010981 -0.15078\" quat=\"0.751181 0.223482 0.15832 0.600598\" mass=\"1.702\" diaginertia=\"0.00776166 0.00717575 0.00160139\"/>\n            <joint name=\"right_hip_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.7576 2.7576\" actuatorfrcrange=\"-88 88\"/>\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_yaw_link\"/>\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_hip_yaw_link\"/>\n            <body name=\"right_knee_link\" pos=\"-0.078273 -0.0021489 -0.17734\" quat=\"0.996179 0 0.0873386 0\">\n              <inertial pos=\"0.005457 -0.003964 -0.12074\" quat=\"0.923439 0.0345276 0.0116333 -0.382012\" mass=\"1.932\" diaginertia=\"0.011374 0.0112843 0.00146452\"/>\n              <joint name=\"right_knee_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.087267 2.8798\" actuatorfrcrange=\"-139 139\"/>\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_knee_link\"/>\n              <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_knee_link\"/>\n              <body name=\"right_ankle_pitch_link\" pos=\"0 9.4445e-05 -0.30001\">\n                <inertial pos=\"-0.007269 0 0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\" mass=\"0.074\" diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\"/>\n                <joint name=\"right_ankle_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.87267 0.5236\" 
actuatorfrcrange=\"-50 50\"/>\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_ankle_pitch_link\"/>\n                <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_ankle_pitch_link\"/>\n                <body name=\"right_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                  <inertial pos=\"0.026505 0 -0.016425\" quat=\"0.000481092 0.728482 0.000618967 0.685065\" mass=\"0.608\" diaginertia=\"0.00167218 0.0016161 0.000217621\"/>\n                  <joint name=\"right_ankle_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.2618 0.2618\" actuatorfrcrange=\"-50 50\"/>\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"right_ankle_roll_link\"/>\n                  <geom size=\"0.005\" pos=\"-0.05 0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"-0.05 -0.025 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"0.12 0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                  <geom size=\"0.005\" pos=\"0.12 -0.03 -0.03\" rgba=\"0.2 0.2 0.2 1\"/>\n                </body>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n      <body name=\"waist_yaw_link\">\n        <inertial pos=\"0.003494 0.000233 0.018034\" quat=\"0.289697 0.591001 -0.337795 0.672821\" mass=\"0.214\" diaginertia=\"0.000163531 0.000107714 0.000102205\"/>\n        <joint name=\"waist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.618 2.618\" actuatorfrcrange=\"-88 88\"/>\n        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"waist_yaw_link\"/>\n        <body name=\"waist_roll_link\" pos=\"-0.0039635 0 0.044\">\n          <inertial pos=\"0 2.3e-05 0\" quat=\"0.5 0.5 -0.5 0.5\" mass=\"0.086\" diaginertia=\"8.245e-06 7.079e-06 6.339e-06\"/>\n          <joint name=\"waist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-0.52 0.52\" actuatorfrcrange=\"-50 50\"/>\n          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"waist_roll_link\"/>\n          <body name=\"torso_link\">\n            <inertial pos=\"0.00203158 0.000339683 0.184568\" quat=\"0.999803 -6.03319e-05 0.0198256 0.00131986\" mass=\"7.818\" diaginertia=\"0.121847 0.109825 0.0273735\"/>\n            <joint name=\"waist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-0.52 0.52\" actuatorfrcrange=\"-50 50\"/>\n            <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"torso_link\"/>\n            <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"torso_link\"/>\n            <geom pos=\"0.0039635 0 -0.044\" quat=\"1 0 0 0\" type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"logo_link\"/>\n            <geom pos=\"0.0039635 0 -0.044\" quat=\"1 0 0 0\" type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"logo_link\"/>\n            <geom pos=\"0.0039635 0 -0.044\" type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.2 0.2 0.2 1\" mesh=\"head_link\"/>\n            <geom pos=\"0.0039635 0 -0.044\" type=\"mesh\" rgba=\"0.2 0.2 0.2 1\" mesh=\"head_link\"/>\n            <site name=\"imu_in_torso\" size=\"0.01\" pos=\"-0.03959 -0.00224 0.14792\"/>\n            <body name=\"left_shoulder_pitch_link\" pos=\"0.0039563 0.10022 0.24778\" quat=\"0.990264 
0.139201 1.38722e-05 -9.86868e-05\">\n              <inertial pos=\"0 0.035892 -0.011628\" quat=\"0.654152 0.0130458 -0.326267 0.68225\" mass=\"0.718\" diaginertia=\"0.000465864 0.000432842 0.000406394\"/>\n              <joint name=\"left_shoulder_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-3.0892 2.6704\" actuatorfrcrange=\"-25 25\"/>\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_pitch_link\"/>\n              <geom size=\"0.03 0.025\" pos=\"0 0.04 -0.01\" quat=\"0.707107 0 0.707107 0\" type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\"/>\n              <body name=\"left_shoulder_roll_link\" pos=\"0 0.038 -0.013831\" quat=\"0.990268 -0.139172 0 0\">\n                <inertial pos=\"-0.000227 0.00727 -0.063243\" quat=\"0.701256 -0.0196223 -0.00710317 0.712604\" mass=\"0.643\" diaginertia=\"0.000691311 0.000618011 0.000388977\"/>\n                <joint name=\"left_shoulder_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-1.5882 2.2515\" actuatorfrcrange=\"-25 25\"/>\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_roll_link\"/>\n                <geom size=\"0.03 0.015\" pos=\"-0.004 0.006 -0.053\" type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\"/>\n                <body name=\"left_shoulder_yaw_link\" pos=\"0 0.00624 -0.1032\">\n                  <inertial pos=\"0.010773 -0.002949 -0.072009\" quat=\"0.716879 -0.0964829 -0.0679942 0.687134\" mass=\"0.734\" diaginertia=\"0.00106187 0.00103217 0.000400661\"/>\n                  <joint name=\"left_shoulder_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.618 2.618\" actuatorfrcrange=\"-25 25\"/>\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_yaw_link\"/>\n                  <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_shoulder_yaw_link\"/>\n                  <body name=\"left_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                    <inertial pos=\"0.064956 0.004454 -0.010062\" quat=\"0.541765 0.636132 0.388821 0.388129\" mass=\"0.6\" diaginertia=\"0.000443035 0.000421612 0.000259353\"/>\n                    <joint name=\"left_elbow_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.0472 2.0944\" actuatorfrcrange=\"-25 25\"/>\n                    <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_elbow_link\"/>\n                    <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_elbow_link\"/>\n                    <body name=\"left_wrist_roll_link\" pos=\"0.1 0.00188791 -0.01\">\n                      <inertial pos=\"0.0171394 0.000537591 4.8864e-07\" quat=\"0.575338 0.411667 -0.574906 0.411094\" mass=\"0.085445\" diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\"/>\n                      <joint name=\"left_wrist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-1.97222 1.97222\" actuatorfrcrange=\"-25 25\"/>\n                      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_roll_link\"/>\n                      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_roll_link\"/>\n                      <body name=\"left_wrist_pitch_link\" pos=\"0.038 0 0\">\n                        <inertial pos=\"0.0229999 -0.00111685 -0.00111658\" quat=\"0.249998 0.661363 0.293036 0.643608\" mass=\"0.48405\" diaginertia=\"0.000430353 
0.000429873 0.000164648\"/>\n                        <joint name=\"left_wrist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\"/>\n                        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_pitch_link\"/>\n                        <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_pitch_link\"/>\n                        <body name=\"left_wrist_yaw_link\" pos=\"0.046 0 0\">\n                          <inertial pos=\"0.0708244 0.000191745 0.00161742\" quat=\"0.510571 0.526295 0.468078 0.493188\" mass=\"0.254576\" diaginertia=\"0.000646113 0.000559993 0.000147566\"/>\n                          <joint name=\"left_wrist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\"/>\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_yaw_link\"/>\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_wrist_yaw_link\"/>\n                          <!-- <body name=\"left_rubber_hand\" pos=\"0.0415 0.003 0\"> -->\n                            <!-- <inertial pos=\"0.05361310808 -0.00295905240 0.00215413091\" quat=\"1 0 0 0\" mass=\"0.170\" diaginertia=\"0.00010099485234748 0.00028135871571621 0.00021894770413514\"/> -->\n                            <!-- <joint name=\"left_hand_palm_joint\" pos=\"0 0 0\" type=\"fixed\"/> -->\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_rubber_hand\"/>\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"left_rubber_hand\"/>\n                          <!-- </body> -->\n                        </body>\n                      </body>\n                    </body>\n                  </body>\n                </body>\n              </body>\n            </body>\n            <body name=\"right_shoulder_pitch_link\" pos=\"0.0039563 -0.10021 0.24778\" quat=\"0.990264 -0.139201 1.38722e-05 9.86868e-05\">\n              <inertial pos=\"0 -0.035892 -0.011628\" quat=\"0.68225 -0.326267 0.0130458 0.654152\" mass=\"0.718\" diaginertia=\"0.000465864 0.000432842 0.000406394\"/>\n              <joint name=\"right_shoulder_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-3.0892 2.6704\" actuatorfrcrange=\"-25 25\"/>\n              <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_pitch_link\"/>\n              <geom size=\"0.03 0.025\" pos=\"0 -0.04 -0.01\" quat=\"0.707107 0 0.707107 0\" type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\"/>\n              <body name=\"right_shoulder_roll_link\" pos=\"0 -0.038 -0.013831\" quat=\"0.990268 0.139172 0 0\">\n                <inertial pos=\"-0.000227 -0.00727 -0.063243\" quat=\"0.712604 -0.00710317 -0.0196223 0.701256\" mass=\"0.643\" diaginertia=\"0.000691311 0.000618011 0.000388977\"/>\n                <joint name=\"right_shoulder_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-2.2515 1.5882\" actuatorfrcrange=\"-25 25\"/>\n                <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_roll_link\"/>\n                <geom size=\"0.03 0.015\" pos=\"-0.004 -0.006 -0.053\" type=\"cylinder\" rgba=\"0.7 0.7 0.7 1\"/>\n                <body name=\"right_shoulder_yaw_link\" pos=\"0 -0.00624 -0.1032\">\n     
             <inertial pos=\"0.010773 0.002949 -0.072009\" quat=\"0.687134 -0.0679942 -0.0964829 0.716879\" mass=\"0.734\" diaginertia=\"0.00106187 0.00103217 0.000400661\"/>\n                  <joint name=\"right_shoulder_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-2.618 2.618\" actuatorfrcrange=\"-25 25\"/>\n                  <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_yaw_link\"/>\n                  <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_shoulder_yaw_link\"/>\n                  <body name=\"right_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                    <inertial pos=\"0.064956 -0.004454 -0.010062\" quat=\"0.388129 0.388821 0.636132 0.541765\" mass=\"0.6\" diaginertia=\"0.000443035 0.000421612 0.000259353\"/>\n                    <joint name=\"right_elbow_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.0472 2.0944\" actuatorfrcrange=\"-25 25\"/>\n                    <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_elbow_link\"/>\n                    <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_elbow_link\"/>\n                    <body name=\"right_wrist_roll_link\" pos=\"0.1 -0.00188791 -0.01\">\n                      <inertial pos=\"0.0171394 -0.000537591 4.8864e-07\" quat=\"0.411667 0.575338 -0.411094 0.574906\" mass=\"0.085445\" diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\"/>\n                      <joint name=\"right_wrist_roll_joint\" pos=\"0 0 0\" axis=\"1 0 0\" range=\"-1.97222 1.97222\" actuatorfrcrange=\"-25 25\"/>\n                      <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_roll_link\"/>\n                      <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_roll_link\"/>\n                      <body name=\"right_wrist_pitch_link\" pos=\"0.038 0 0\">\n                        <inertial pos=\"0.0229999 0.00111685 -0.00111658\" quat=\"0.643608 0.293036 0.661363 0.249998\" mass=\"0.48405\" diaginertia=\"0.000430353 0.000429873 0.000164648\"/>\n                        <joint name=\"right_wrist_pitch_joint\" pos=\"0 0 0\" axis=\"0 1 0\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\"/>\n                        <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_pitch_link\"/>\n                        <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_pitch_link\"/>\n                        <body name=\"right_wrist_yaw_link\" pos=\"0.046 0 0\">\n                          <inertial pos=\"0.0708244 -0.000191745 0.00161742\" quat=\"0.493188 0.468078 0.526295 0.510571\" mass=\"0.254576\" diaginertia=\"0.000646113 0.000559993 0.000147566\"/>\n                          <joint name=\"right_wrist_yaw_joint\" pos=\"0 0 0\" axis=\"0 0 1\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\"/>\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_yaw_link\"/>\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_wrist_yaw_link\"/>\n                          <!-- <body name=\"right_rubber_hand\" pos=\"0.0415 -0.003 0\"> -->\n                            <!-- <inertial pos=\"0.05361310808 0.00295905240 0.00215413091\" quat=\"1 0 0 0\" mass=\"0.170\" diaginertia=\"0.00010099485234748 0.00028135871571621 
0.00021894770413514\"/> -->\n                            <!-- <joint name=\"right_hand_palm_joint\" pos=\"0 0 0\" type=\"fixed\"/> -->\n                          <geom type=\"mesh\" contype=\"0\" conaffinity=\"0\" group=\"1\" density=\"0\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_rubber_hand\"/>\n                          <geom type=\"mesh\" rgba=\"0.7 0.7 0.7 1\" mesh=\"right_rubber_hand\"/>\n                          <!-- </body> -->\n                        </body>\n                      </body>\n                    </body>\n                  </body>\n                </body>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n\n  <actuator>\n    <motor name=\"left_hip_pitch_joint\" joint=\"left_hip_pitch_joint\"/>\n    <motor name=\"left_hip_roll_joint\" joint=\"left_hip_roll_joint\"/>\n    <motor name=\"left_hip_yaw_joint\" joint=\"left_hip_yaw_joint\"/>\n    <motor name=\"left_knee_joint\" joint=\"left_knee_joint\"/>\n    <motor name=\"left_ankle_pitch_joint\" joint=\"left_ankle_pitch_joint\"/>\n    <motor name=\"left_ankle_roll_joint\" joint=\"left_ankle_roll_joint\"/>\n    <motor name=\"right_hip_pitch_joint\" joint=\"right_hip_pitch_joint\"/>\n    <motor name=\"right_hip_roll_joint\" joint=\"right_hip_roll_joint\"/>\n    <motor name=\"right_hip_yaw_joint\" joint=\"right_hip_yaw_joint\"/>\n    <motor name=\"right_knee_joint\" joint=\"right_knee_joint\"/>\n    <motor name=\"right_ankle_pitch_joint\" joint=\"right_ankle_pitch_joint\"/>\n    <motor name=\"right_ankle_roll_joint\" joint=\"right_ankle_roll_joint\"/>\n    <motor name=\"waist_yaw_joint\" joint=\"waist_yaw_joint\"/>\n    <motor name=\"waist_roll_joint\" joint=\"waist_roll_joint\"/>\n    <motor name=\"waist_pitch_joint\" joint=\"waist_pitch_joint\"/>\n    <motor name=\"left_shoulder_pitch_joint\" joint=\"left_shoulder_pitch_joint\"/>\n    <motor name=\"left_shoulder_roll_joint\" joint=\"left_shoulder_roll_joint\"/>\n    <motor name=\"left_shoulder_yaw_joint\" joint=\"left_shoulder_yaw_joint\"/>\n    <motor name=\"left_elbow_joint\" joint=\"left_elbow_joint\"/>\n    <motor name=\"left_wrist_roll_joint\" joint=\"left_wrist_roll_joint\"/>\n    <motor name=\"left_wrist_pitch_joint\" joint=\"left_wrist_pitch_joint\"/>\n    <motor name=\"left_wrist_yaw_joint\" joint=\"left_wrist_yaw_joint\"/>\n    <motor name=\"right_shoulder_pitch_joint\" joint=\"right_shoulder_pitch_joint\"/>\n    <motor name=\"right_shoulder_roll_joint\" joint=\"right_shoulder_roll_joint\"/>\n    <motor name=\"right_shoulder_yaw_joint\" joint=\"right_shoulder_yaw_joint\"/>\n    <motor name=\"right_elbow_joint\" joint=\"right_elbow_joint\"/>\n    <motor name=\"right_wrist_roll_joint\" joint=\"right_wrist_roll_joint\"/>\n    <motor name=\"right_wrist_pitch_joint\" joint=\"right_wrist_pitch_joint\"/>\n    <motor name=\"right_wrist_yaw_joint\" joint=\"right_wrist_yaw_joint\"/>\n  </actuator>\n\n  <sensor>\n    <gyro name=\"imu-torso-angular-velocity\" site=\"imu_in_torso\" noise=\"5e-4\" cutoff=\"34.9\"/>\n    <accelerometer name=\"imu-torso-linear-acceleration\" site=\"imu_in_torso\" noise=\"1e-2\" cutoff=\"157\"/>\n    <gyro name=\"imu-pelvis-angular-velocity\" site=\"imu_in_pelvis\" noise=\"5e-4\" cutoff=\"34.9\"/>\n    <accelerometer name=\"imu-pelvis-linear-acceleration\" site=\"imu_in_pelvis\" noise=\"1e-2\" cutoff=\"157\"/>\n  </sensor>\n\n\n  <!-- setup scene -->\n  <statistic center=\"1.0 0.7 1.0\" extent=\"0.8\"/>\n  <visual>\n    <headlight diffuse=\"0.6 0.6 0.6\" ambient=\"0.1 0.1 0.1\" 
specular=\"0.9 0.9 0.9\"/>\n    <rgba haze=\"0.15 0.25 0.35 1\"/>\n    <global azimuth=\"-140\" elevation=\"-20\"/>\n  </visual>\n  <asset>\n    <texture type=\"skybox\" builtin=\"flat\" rgb1=\"0 0 0\" rgb2=\"0 0 0\" width=\"512\" height=\"3072\"/>\n    <texture type=\"2d\" name=\"groundplane\" builtin=\"checker\" mark=\"edge\" rgb1=\"0.2 0.3 0.4\" rgb2=\"0.1 0.2 0.3\" markrgb=\"0.8 0.8 0.8\" width=\"300\" height=\"300\"/>\n    <material name=\"groundplane\" texture=\"groundplane\" texuniform=\"true\" texrepeat=\"5 5\" reflectance=\"0.2\"/>\n  </asset>\n  <worldbody>\n    <light pos=\"1 0 3.5\" dir=\"0 0 -1\" directional=\"true\"/>\n    <geom name=\"floor\" size=\"0 0 0.05\" type=\"plane\" material=\"groundplane\"/>\n  </worldbody>\n</mujoco>"
  },
  {
    "path": "assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0_s100.urdf",
    "content": "<robot name=\"g1_29dof_rev_1_0\">\r\n  <material name=\"dark\">\r\n    <color rgba=\"0.2 0.2 0.2 1\"/>\r\n  </material>\r\n  <material name=\"white\">\r\n    <color rgba=\"0.7 0.7 0.7 1\"/>\r\n  </material>\r\n  <material name=\"red\">\r\n    <color rgba=\"1.0 0.0 0.0 1\"/>\r\n  </material>\r\n  <material name=\"blue\">\r\n    <color rgba=\"0.0 0.0 1.0 1\"/>\r\n  </material>\r\n  <material name=\"green\">\r\n    <color rgba=\"0.0 1.0 0.0 1\"/>\r\n  </material>\r\n  <material name=\"yellow\">\r\n    <color rgba=\"1.0 1.0 0.0 1\"/>\r\n  </material>\r\n  <material name=\"purple\">\r\n    <color rgba=\"1.0 0.0 1.0 1\"/>\r\n  </material>\r\n\r\n  <mujoco>\r\n    <compiler meshdir=\"../meshes\" discardvisual=\"false\"/>\r\n  </mujoco>\r\n\r\n  <!-- [CAUTION] uncomment when convert to mujoco -->\r\n  <!-- <link name=\"world\"></link>\r\n  <joint name=\"floating_base_joint\" type=\"floating\">\r\n    <parent link=\"world\"/>\r\n    <child link=\"pelvis\"/>\r\n  </joint> -->\r\n\r\n  <link name=\"pelvis\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 -0.07605\" rpy=\"0 0 0\"/>\r\n      <mass value=\"3.813\"/>\r\n      <inertia ixx=\"0.010549\" ixy=\"0\" ixz=\"2.1E-06\" iyy=\"0.0093089\" iyz=\"0\" izz=\"0.0079184\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n  </link>\r\n  <link name=\"pelvis_contour_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.001\"/>\r\n      <inertia ixx=\"1e-7\" ixy=\"0\" ixz=\"0\" iyy=\"1e-7\" iyz=\"0\" izz=\"1e-7\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis_contour_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/pelvis_contour_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"pelvis_contour_joint\" type=\"fixed\">\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"pelvis_contour_link\"/>\r\n  </joint>\r\n\r\n  <!-- Legs -->\r\n  <link name=\"left_hip_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.002741 0.047791 -0.02606\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.35\"/>\r\n      <inertia ixx=\"0.001811\" ixy=\"3.68E-05\" ixz=\"-3.44E-05\" iyy=\"0.0014193\" iyz=\"0.000171\" izz=\"0.0012812\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.064452 -0.1027\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"left_hip_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-2.5307\" upper=\"2.8798\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"left_hip_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.029812 -0.001045 -0.087934\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.52\"/>\r\n      
<inertia ixx=\"0.0023773\" ixy=\"-3.8E-06\" ixz=\"-0.0003908\" iyy=\"0.0024123\" iyz=\"1.84E-05\" izz=\"0.0016595\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.052 -0.030465\" rpy=\"0 -0.1749 0\"/>\r\n    <parent link=\"left_hip_pitch_link\"/>\r\n    <child link=\"left_hip_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.5236\" upper=\"2.9671\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"left_hip_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.057709 -0.010981 -0.15078\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.702\"/>\r\n      <inertia ixx=\"0.0057774\" ixy=\"-0.0005411\" ixz=\"-0.0023948\" iyy=\"0.0076124\" iyz=\"-0.0007072\" izz=\"0.003149\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hip_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.025001 0 -0.12412\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_hip_roll_link\"/>\r\n    <child link=\"left_hip_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.7576\" upper=\"2.7576\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"left_knee_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005457 0.003964 -0.12074\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.932\"/>\r\n      <inertia ixx=\"0.011329\" ixy=\"4.82E-05\" ixz=\"-4.49E-05\" iyy=\"0.011277\" iyz=\"-0.0007146\" izz=\"0.0015168\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_knee_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_knee_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_knee_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.078273 0.0021489 -0.17734\" rpy=\"0 0.1749 0\"/>\r\n    <parent link=\"left_hip_yaw_link\"/>\r\n    <child link=\"left_knee_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.087267\" upper=\"2.8798\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"left_ankle_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.007269 0 0.011137\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.074\"/>\r\n      <inertia ixx=\"8.4E-06\" ixy=\"0\" ixz=\"-2.9E-06\" iyy=\"1.89E-05\" iyz=\"0\" izz=\"1.26E-05\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n 
     <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_ankle_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -9.4445E-05 -0.30001\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_knee_link\"/>\r\n    <child link=\"left_ankle_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.87267\" upper=\"0.5236\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"left_ankle_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.026505 0 -0.016425\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.608\"/>\r\n      <inertia ixx=\"0.0002231\" ixy=\"2E-07\" ixz=\"8.91E-05\" iyy=\"0.0016161\" iyz=\"-1E-07\" izz=\"0.0016667\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_ankle_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.05 0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"-0.05 -0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 -0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_ankle_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 -0.017558\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_ankle_pitch_link\"/>\r\n    <child link=\"left_ankle_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.2618\" upper=\"0.2618\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"right_hip_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.002741 -0.047791 -0.02606\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.35\"/>\r\n      <inertia ixx=\"0.001811\" ixy=\"-3.68E-05\" ixz=\"-3.44E-05\" iyy=\"0.0014193\" iyz=\"-0.000171\" izz=\"0.0012812\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.064452 -0.1027\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"right_hip_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-2.5307\" upper=\"2.8798\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"right_hip_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.029812 0.001045 -0.087934\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.52\"/>\r\n      <inertia ixx=\"0.0023773\" ixy=\"3.8E-06\" ixz=\"-0.0003908\" iyy=\"0.0024123\" iyz=\"-1.84E-05\" izz=\"0.0016595\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        
<mesh filename=\"../meshes/right_hip_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.052 -0.030465\" rpy=\"0 -0.1749 0\"/>\r\n    <parent link=\"right_hip_pitch_link\"/>\r\n    <child link=\"right_hip_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-2.9671\" upper=\"0.5236\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"right_hip_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.057709 0.010981 -0.15078\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.702\"/>\r\n      <inertia ixx=\"0.0057774\" ixy=\"0.0005411\" ixz=\"-0.0023948\" iyy=\"0.0076124\" iyz=\"0.0007072\" izz=\"0.003149\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_hip_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hip_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.025001 0 -0.12412\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_hip_roll_link\"/>\r\n    <child link=\"right_hip_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.7576\" upper=\"2.7576\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"right_knee_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005457 -0.003964 -0.12074\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.932\"/>\r\n      <inertia ixx=\"0.011329\" ixy=\"-4.82E-05\" ixz=\"4.49E-05\" iyy=\"0.011277\" iyz=\"0.0007146\" izz=\"0.0015168\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_knee_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_knee_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_knee_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.078273 -0.0021489 -0.17734\" rpy=\"0 0.1749 0\"/>\r\n    <parent link=\"right_hip_yaw_link\"/>\r\n    <child link=\"right_knee_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.087267\" upper=\"2.8798\" effort=\"139\" velocity=\"20\"/>\r\n  </joint>\r\n  <link name=\"right_ankle_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.007269 0 0.011137\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.074\"/>\r\n      <inertia ixx=\"8.4E-06\" ixy=\"0\" ixz=\"-2.9E-06\" iyy=\"1.89E-05\" iyz=\"0\" izz=\"1.26E-05\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint 
name=\"right_ankle_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 9.4445E-05 -0.30001\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_knee_link\"/>\r\n    <child link=\"right_ankle_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.87267\" upper=\"0.5236\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"right_ankle_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.026505 0 -0.016425\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.608\"/>\r\n      <inertia ixx=\"0.0002231\" ixy=\"-2E-07\" ixz=\"8.91E-05\" iyy=\"0.0016161\" iyz=\"1E-07\" izz=\"0.0016667\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_ankle_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.05 0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"-0.05 -0.025 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n    <collision>\r\n      <origin xyz=\"0.12 -0.03 -0.03\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <sphere radius=\"0.005\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_ankle_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 -0.017558\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_ankle_pitch_link\"/>\r\n    <child link=\"right_ankle_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.2618\" upper=\"0.2618\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n\r\n  <!-- Torso -->\r\n  <link name=\"waist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.003494 0.000233 0.018034\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.214\"/>\r\n      <inertia ixx=\"0.00010673\" ixy=\"2.703E-06\" ixz=\"-7.631E-06\" iyy=\"0.00010422\" iyz=\"-2.01E-07\" izz=\"0.0001625\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/waist_yaw_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <joint name=\"waist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"waist_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"88\" velocity=\"32\"/>\r\n  </joint>\r\n  <link name=\"waist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 2.3E-05 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.086\"/>\r\n      <inertia ixx=\"7.079E-06\" ixy=\"0\" ixz=\"0\" iyy=\"6.339E-06\" iyz=\"0\" izz=\"8.245E-06\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/waist_roll_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <joint name=\"waist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"-0.0039635 0 0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"waist_yaw_link\"/>\r\n    <child link=\"waist_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-0.52\" upper=\"0.52\" effort=\"35\" 
velocity=\"30\"/>\r\n  </joint>\r\n  <link name=\"torso_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.000931 0.000346 0.15082\" rpy=\"0 0 0\"/>\r\n      <mass value=\"6.78\"/>\r\n      <inertia ixx=\"0.05905\" ixy=\"3.3302E-05\" ixz=\"-0.0017715\" iyy=\"0.047014\" iyz=\"-2.2399E-05\" izz=\"0.025652\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/torso_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/torso_link_rev_1_0.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"waist_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"waist_roll_link\"/>\r\n    <child link=\"torso_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-0.52\" upper=\"0.52\" effort=\"35\" velocity=\"30\"/>\r\n  </joint>\r\n\r\n  <!-- LOGO -->\r\n  <joint name=\"logo_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0039635 0 -0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"logo_link\"/>\r\n  </joint>\r\n  <link name=\"logo_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.001\"/>\r\n      <inertia ixx=\"1e-7\" ixy=\"0\" ixz=\"0\" iyy=\"1e-7\" iyz=\"0\" izz=\"1e-7\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/logo_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/logo_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n\r\n  <!-- Head -->\r\n  <link name=\"head_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.005267 0.000299 0.449869\" rpy=\"0 0 0\"/>\r\n      <mass value=\"1.036\"/>\r\n      <inertia ixx=\"0.004085051\" ixy=\"-2.543E-06\" ixz=\"-6.9455E-05\" iyy=\"0.004185212\" iyz=\"-3.726E-06\" izz=\"0.001807911\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/head_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"dark\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/head_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"head_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0039635 0 -0.044\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"head_link\"/>\r\n  </joint>\r\n\r\n\r\n  <!-- IMU -->\r\n  <link name=\"imu_in_torso\"></link>\r\n  <joint name=\"imu_in_torso_joint\" type=\"fixed\">\r\n    <origin xyz=\"-0.03959 -0.00224 0.14792\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"imu_in_torso\"/>\r\n  </joint>\r\n\r\n  <link name=\"imu_in_pelvis\"></link>\r\n  <joint name=\"imu_in_pelvis_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.04525 0 -0.08339\" rpy=\"0 0 0\"/>\r\n    <parent link=\"pelvis\"/>\r\n    <child link=\"imu_in_pelvis\"/>\r\n  </joint>\r\n\r\n  <!-- d435 -->\r\n  <link name=\"d435_link\"></link>\r\n  <joint name=\"d435_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0576235 0.01753 0.42987\" rpy=\"0 0.8307767239493009 
0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"d435_link\"/>\r\n  </joint>\r\n\r\n  <!-- mid360 -->\r\n  <link name=\"mid360_link\"></link>\r\n  <joint name=\"mid360_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0002835 0.00003 0.41618\" rpy=\"0 0.04014257279586953 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"mid360_link\"/>\r\n  </joint>\r\n    <!-- S100 Processor - Using STL -->\r\n  <link name=\"shell_support\">\r\n    <inertial>\r\n      <origin xyz=\"0.0 0.0 -0.0\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.6\"/>\r\n      <inertia ixx=\"0.000943\" ixy=\"-2.1e-08\" ixz=\"-2.70e-05\"\r\n              iyy=\"0.000991\" iyz=\"1.8e-08\"  izz=\"0.001600\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"meshes/shell_support_s100_transformed.stl\" scale=\"0.001 0.001 0.001\"/>\r\n      </geometry>\r\n      <material name=\"purple\"/>\r\n    </visual>\r\n    <collision>\r\n     <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"meshes/shell_support_s100_transformed.stl\" scale=\"0.001 0.001 0.001\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"s100_joint\" type=\"fixed\">\r\n    <origin xyz=\"-0.09 -0.005 0.06\" rpy=\"0 0 0\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"shell_support\"/>\r\n  </joint>\r\n  <!-- Arm -->\r\n  <link name=\"left_shoulder_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 0.035892 -0.011628\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.718\"/>\r\n      <inertia ixx=\"0.0004291\" ixy=\"-9.2E-06\" ixz=\"6.4E-06\" iyy=\"0.000453\" iyz=\"2.26E-05\" izz=\"0.000423\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0.04 -0.01\" rpy=\"0 1.5707963267948966 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.05\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.0039563 0.10022 0.24778\" rpy=\"0.27931 5.4949E-05 -0.00019159\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"left_shoulder_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-3.0892\" upper=\"2.6704\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_shoulder_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.000227 0.00727 -0.063243\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.643\"/>\r\n      <inertia ixx=\"0.0006177\" ixy=\"-1E-06\" ixz=\"8.7E-06\" iyy=\"0.0006912\" iyz=\"-5.3E-06\" izz=\"0.0003894\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.004 0.006 -0.053\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.03\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.038 -0.013831\" rpy=\"-0.27925 0 0\"/>\r\n    <parent link=\"left_shoulder_pitch_link\"/>\r\n    <child link=\"left_shoulder_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    
<limit lower=\"-1.5882\" upper=\"2.2515\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_shoulder_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.010773 -0.002949 -0.072009\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.734\"/>\r\n      <inertia ixx=\"0.0009988\" ixy=\"7.9E-06\" ixz=\"0.0001412\" iyy=\"0.0010605\" iyz=\"-2.86E-05\" izz=\"0.0004354\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_shoulder_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 0.00624 -0.1032\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_shoulder_roll_link\"/>\r\n    <child link=\"left_shoulder_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"left_elbow_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.064956 0.004454 -0.010062\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.6\"/>\r\n      <inertia ixx=\"0.0002891\" ixy=\"6.53E-05\" ixz=\"1.72E-05\" iyy=\"0.0004152\" iyz=\"-5.6E-06\" izz=\"0.0004197\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_elbow_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_elbow_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_elbow_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.015783 0 -0.080518\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_shoulder_yaw_link\"/>\r\n    <child link=\"left_elbow_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-1.0472\" upper=\"2.0944\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <joint name=\"left_wrist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.100 0.00188791 -0.010\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <parent link=\"left_elbow_link\"/>\r\n    <child link=\"left_wrist_roll_link\"/>\r\n    <limit effort=\"25\" velocity=\"37\" lower=\"-1.972222054\" upper=\"1.972222054\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.01713944778 0.00053759094 0.00000048864\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08544498\"/>\r\n      <inertia ixx=\"0.00004821544023\" ixy=\"-0.00000424511021\" ixz=\"0.00000000510599\" iyy=\"0.00003722899093\" iyz=\"-0.00000000123525\" izz=\"0.00005482106541\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_wrist_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.038 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <parent 
link=\"left_wrist_roll_link\"/>\r\n    <child link=\"left_wrist_pitch_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02299989837 -0.00111685314 -0.00111658096\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.48404956\"/>\r\n      <inertia ixx=\"0.00016579646273\" ixy=\"-0.00001231206746\" ixz=\"0.00001231699194\" iyy=\"0.00042954057410\" iyz=\"0.00000081417712\" izz=\"0.00042953697654\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_wrist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.046 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <parent link=\"left_wrist_pitch_link\"/>\r\n    <child link=\"left_wrist_yaw_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"left_wrist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02200381568 0.00049485096 0.00053861123\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08457647\"/>\r\n      <inertia ixx=\"0.00004929128828\" ixy=\"-0.00000045735494\" ixz=\"0.00000445867591\" iyy=\"0.00005973338134\" iyz=\"0.00000043217198\" izz=\"0.00003928083826\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"left_hand_palm_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0415 0.003 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"left_wrist_yaw_link\"/>\r\n    <child link=\"left_rubber_hand\"/>\r\n  </joint>\r\n  <link name=\"left_rubber_hand\">\r\n    <inertial>\r\n      <origin xyz=\"0.05361310808 -0.00295905240 0.00215413091\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.170\"/>\r\n      <inertia ixx=\"0.00010099485234748\" ixy=\"0.00003618590790516\" ixz=\"-0.00000074301518642\" iyy=\"0.00028135871571621\" iyz=\"0.00000330189743286\" izz=\"0.00021894770413514\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/left_rubber_hand.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n  <link name=\"right_shoulder_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0 -0.035892 -0.011628\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.718\"/>\r\n      <inertia ixx=\"0.0004291\" ixy=\"9.2E-06\" ixz=\"6.4E-06\" iyy=\"0.000453\" iyz=\"-2.26E-05\" izz=\"0.000423\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 -0.04 -0.01\" rpy=\"0 
1.5707963267948966 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.05\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.0039563 -0.10021 0.24778\" rpy=\"-0.27931 5.4949E-05 0.00019159\"/>\r\n    <parent link=\"torso_link\"/>\r\n    <child link=\"right_shoulder_pitch_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-3.0892\" upper=\"2.6704\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_shoulder_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"-0.000227 -0.00727 -0.063243\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.643\"/>\r\n      <inertia ixx=\"0.0006177\" ixy=\"1E-06\" ixz=\"8.7E-06\" iyy=\"0.0006912\" iyz=\"5.3E-06\" izz=\"0.0003894\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"-0.004 -0.006 -0.053\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <cylinder radius=\"0.03\" length=\"0.03\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.038 -0.013831\" rpy=\"0.27925 0 0\"/>\r\n    <parent link=\"right_shoulder_pitch_link\"/>\r\n    <child link=\"right_shoulder_roll_link\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <limit lower=\"-2.2515\" upper=\"1.5882\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_shoulder_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.010773 0.002949 -0.072009\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.734\"/>\r\n      <inertia ixx=\"0.0009988\" ixy=\"-7.9E-06\" ixz=\"0.0001412\" iyy=\"0.0010605\" iyz=\"2.86E-05\" izz=\"0.0004354\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_shoulder_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_shoulder_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0 -0.00624 -0.1032\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_shoulder_roll_link\"/>\r\n    <child link=\"right_shoulder_yaw_link\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <limit lower=\"-2.618\" upper=\"2.618\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <link name=\"right_elbow_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.064956 -0.004454 -0.010062\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.6\"/>\r\n      <inertia ixx=\"0.0002891\" ixy=\"-6.53E-05\" ixz=\"1.72E-05\" iyy=\"0.0004152\" iyz=\"5.6E-06\" izz=\"0.0004197\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_elbow_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_elbow_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_elbow_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.015783 0 -0.080518\" rpy=\"0 0 
0\"/>\r\n    <parent link=\"right_shoulder_yaw_link\"/>\r\n    <child link=\"right_elbow_link\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <limit lower=\"-1.0472\" upper=\"2.0944\" effort=\"25\" velocity=\"37\"/>\r\n  </joint>\r\n  <joint name=\"right_wrist_roll_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.100 -0.00188791 -0.010\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"1 0 0\"/>\r\n    <parent link=\"right_elbow_link\"/>\r\n    <child link=\"right_wrist_roll_link\"/>\r\n    <limit effort=\"25\" velocity=\"37\" lower=\"-1.972222054\" upper=\"1.972222054\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_roll_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.01713944778 -0.00053759094 0.00000048864\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08544498\"/>\r\n      <inertia ixx=\"0.00004821544023\" ixy=\"0.00000424511021\" ixz=\"0.00000000510599\" iyy=\"0.00003722899093\" iyz=\"0.00000000123525\" izz=\"0.00005482106541\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_roll_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_wrist_pitch_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.038 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 1 0\"/>\r\n    <parent link=\"right_wrist_roll_link\"/>\r\n    <child link=\"right_wrist_pitch_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_pitch_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02299989837 0.00111685314 -0.00111658096\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.48404956\"/>\r\n      <inertia ixx=\"0.00016579646273\" ixy=\"0.00001231206746\" ixz=\"0.00001231699194\" iyy=\"0.00042954057410\" iyz=\"-0.00000081417712\" izz=\"0.00042953697654\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_pitch_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_wrist_yaw_joint\" type=\"revolute\">\r\n    <origin xyz=\"0.046 0 0\" rpy=\"0 0 0\"/>\r\n    <axis xyz=\"0 0 1\"/>\r\n    <parent link=\"right_wrist_pitch_link\"/>\r\n    <child link=\"right_wrist_yaw_link\"/>\r\n    <limit effort=\"5\" velocity=\"22\" lower=\"-1.614429558\" upper=\"1.614429558\"/>\r\n  </joint>\r\n  <link name=\"right_wrist_yaw_link\">\r\n    <inertial>\r\n      <origin xyz=\"0.02200381568 -0.00049485096 0.00053861123\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.08457647\"/>\r\n      <inertia ixx=\"0.00004929128828\" ixy=\"0.00000045735494\" ixz=\"0.00000445867591\" iyy=\"0.00005973338134\" iyz=\"-0.00000043217198\" izz=\"0.00003928083826\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n    <collision>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n  
      <mesh filename=\"../meshes/right_wrist_yaw_link.STL\"/>\r\n      </geometry>\r\n    </collision>\r\n  </link>\r\n  <joint name=\"right_hand_palm_joint\" type=\"fixed\">\r\n    <origin xyz=\"0.0415 -0.003 0\" rpy=\"0 0 0\"/>\r\n    <parent link=\"right_wrist_yaw_link\"/>\r\n    <child link=\"right_rubber_hand\"/>\r\n  </joint>\r\n  <link name=\"right_rubber_hand\">\r\n    <inertial>\r\n      <origin xyz=\"0.05361310808 0.00295905240 0.00215413091\" rpy=\"0 0 0\"/>\r\n      <mass value=\"0.170\"/>\r\n      <inertia ixx=\"0.00010099485234748\" ixy=\"-0.00003618590790516\" ixz=\"-0.00000074301518642\" iyy=\"0.00028135871571621\" iyz=\"-0.00000330189743286\" izz=\"0.00021894770413514\"/>\r\n    </inertial>\r\n    <visual>\r\n      <origin xyz=\"0 0 0\" rpy=\"0 0 0\"/>\r\n      <geometry>\r\n        <mesh filename=\"../meshes/right_rubber_hand.STL\"/>\r\n      </geometry>\r\n      <material name=\"white\"/>\r\n    </visual>\r\n  </link>\r\n</robot>"
  },
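  {
    "path": "assets/robots/unitree/G1/29dof/urdf/print_joint_limits_example.py",
    "content": "\"\"\"Illustrative sketch: list the revolute-joint limits defined in the G1 URDF\nabove (lower/upper in rad, effort in N*m, velocity in rad/s) using only the\nstandard library. This file is a hypothetical addition for illustration, not\npart of the verified pipeline, and the URDF filename below is assumed.\n\"\"\"\n\nimport xml.etree.ElementTree as ET\n\nURDF_PATH = \"assets/robots/unitree/G1/29dof/urdf/g1_29dof.urdf\"  # assumed location\n\ntree = ET.parse(URDF_PATH)\nfor joint in tree.getroot().iter(\"joint\"):\n    if joint.get(\"type\") != \"revolute\":\n        continue\n    limit = joint.find(\"limit\")\n    if limit is not None:\n        # e.g. right_elbow_joint {'lower': '-1.0472', 'upper': '2.0944', 'effort': '25', 'velocity': '37'}\n        print(joint.get(\"name\"), dict(limit.attrib))\n"
  },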
  {
    "path": "assets/robots/unitree/G1/29dof/scene_29dof.xml",
    "content": "<mujoco model=\"g1_29dof scene\">\n  <include file=\"g1_29dof.xml\"/>\n\n  <!-- setup scene -->\n  <statistic center=\"0.0 0.0 1.0\" extent=\"0.8\"/>\n  <visual>\n    <headlight diffuse=\"0.75 0.75 0.75\" ambient=\"0.18 0.18 0.18\" specular=\"0.95 0.95 0.95\"/>\n    <rgba haze=\"0.15 0.25 0.35 1\"/>\n    <global azimuth=\"-140\" elevation=\"-20\" offwidth=\"2080\" offheight=\"1170\"/>\n  </visual>\n  <asset>\n     <texture type=\"skybox\" builtin=\"gradient\" rgb1=\"1 1 1\" rgb2=\"1 1 1\" width=\"800\" height=\"800\"/>\n    <texture type=\"2d\" name=\"groundplane\" builtin=\"checker\" mark=\"edge\" rgb1=\"1.0 1.0 1.0\" rgb2=\"0.6 0.8 1.0\" markrgb=\"0 0 0\"\n      width=\"300\" height=\"300\"/>\n    <material name=\"groundplane\" texture=\"groundplane\" texuniform=\"true\" texrepeat=\"5 5\" reflectance=\"0\"/>\n    <texture type=\"skybox\" builtin=\"gradient\" rgb1=\".4 .5 .6\" rgb2=\"0 0 0\" width=\"100\" height=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture name=\"texplane\" builtin=\"checker\" height=\"512\" width=\"512\" rgb1=\".2 .3 .4\" rgb2=\".1 .15 .2\" type=\"2d\" />\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"0.01\" specular=\"0.1\" texrepeat=\"1 1\" texture=\"texplane\" texuniform=\"true\" />\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n   <worldbody>\n    <geom name=\"floor\" size=\"0 0 0.01\" type=\"plane\" material=\"groundplane\" contype=\"1\" conaffinity=\"0\" priority=\"1\" condim=\"3\"/>\n\n        <light diffuse=\"0.65 0.65 0.65\" pos=\"-3 -3 5\" dir=\"3 3 -5\" castshadow=\"true\"/>\n\n  </worldbody>\n</mujoco>\n"
  },
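  {
    "path": "assets/robots/unitree/G1/29dof/load_scene_example.py",
    "content": "\"\"\"Illustrative sketch: load and step the 29-DoF G1 MuJoCo scene above.\n\nThis file is a hypothetical addition, not part of the verified pipeline. It\nonly assumes the `mujoco` Python package and the scene_29dof.xml shipped in\nthis directory (the include of g1_29dof.xml is resolved relative to the XML).\n\"\"\"\n\nimport mujoco\n\n# Resolve the scene relative to the repo root; adjust if running elsewhere.\nmodel = mujoco.MjModel.from_xml_path(\"assets/robots/unitree/G1/29dof/scene_29dof.xml\")\ndata = mujoco.MjData(model)\n\n# Step the passive dynamics briefly and report basic model sizes.\nfor _ in range(100):\n    mujoco.mj_step(model, data)\nprint(f\"nq={model.nq} nv={model.nv} sim_time={data.time:.3f}s\")\n"
  },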
  {
    "path": "assets/test_data/motion_retargeting/ACCAD/Male1Walking_c3d/Walk_B10_-_Walk_turn_left_45_stageii.npz",
    "content": "version https://git-lfs.github.com/spec/v1\noid sha256:738f96eb1767e281d78631ca697079adefb5f171d581acd622a27740f2503b4e\nsize 5876184\n"
  },
  {
    "path": "deploy.env",
    "content": "export CONDA_BASE=$(conda info --base)\nexport Deploy_CONDA_PREFIX=\"$CONDA_BASE/envs/holomotion_deploy\"\n\nexport CUDA_HOME=$Deploy_CONDA_PREFIX\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$Deploy_CONDA_PREFIX/lib/\nexport LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$Deploy_CONDA_PREFIX/lib/stubs\nexport LIBRARY_PATH=$Deploy_CONDA_PREFIX/lib:$LIBRARY_PATH\nexport LIBRARY_PATH=$Deploy_CONDA_PREFIX/lib/stubs:$LIBRARY_PATH\nexport HYDRA_FULL_ERROR=1\n\necho \"--------------------------------\"\necho \"Deploy_CONDA_PREFIX: $Deploy_CONDA_PREFIX\"\necho \"CUDA_HOME: $CUDA_HOME\"\necho \"LD_LIBRARY_PATH: $LD_LIBRARY_PATH\"\necho \"LIBRARY_PATH: $LIBRARY_PATH\"\necho \"HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR\"\necho \"--------------------------------\""
  },
  {
    "path": "deployment/deploy_environment.sh",
    "content": "#!/bin/bash\n##############################################################################\n# HoloMotion Environment Deployment Script\n#\n# This script sets up the complete environment for HoloMotion humanoid robot\n# system deployment. It handles:\n# 1. Conda environment creation with all dependencies (CUDA, PyTorch, etc.)\n# 2. Special dependencies (unitree_sdk2_python)  \n# 3. ROS2 workspace compilation\n#\n# Prerequisites:\n# - Anaconda/Miniconda installed\n# - ROS2 Humble installed at /opt/ros/humble/\n# - Unitree ROS2 SDK at ~/unitree_ros2/\n#\n# Usage:\n#   chmod +x deploy_environment.sh\n#   ./deploy_environment.sh [environment_name]\n#\n# Arguments:\n#   environment_name: Optional. Name for the conda environment (default: holomotion_deploy)\n#\n# Examples:\n#   ./deploy_environment.sh                    # Uses default name 'holomotion_deploy'\n#   ./deploy_environment.sh my_robot_env      # Uses custom name 'my_robot_env'\n#\n# Author: HoloMotion Team\n##############################################################################\n\nset -e  # Exit on any error\n\n# Parse command line arguments\nENV_NAME=\"${1:-holomotion_deploy}\"\n\necho \"🚀 Starting HoloMotion Environment Deployment...\"\necho \"📝 Environment name: $ENV_NAME\"\n\n# Get script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nPROJECT_ROOT=\"$(dirname \"$(dirname \"$SCRIPT_DIR\")\")\"\n\necho \"📁 Project root: $PROJECT_ROOT\"\necho \"📁 Script directory: $SCRIPT_DIR\"\n\n# Step 1: Create conda environment with all dependencies\necho \"\"\necho \"📦 Step 1: Creating conda environment with all dependencies...\"\nif conda env list | grep -q \"^$ENV_NAME \"; then\n    echo \"⚠️  Environment '$ENV_NAME' already exists. Removing it...\"\n    conda env remove -n \"$ENV_NAME\" -y\nfi\n\necho \"🔧 Creating new environment from environment_deploy.yaml...\"\necho \"   This will install: PyTorch (CUDA), NumPy, SciPy, ONNX Runtime, and all other dependencies...\"\ncd \"$PROJECT_ROOT\"\nconda env create -f holomotion/environment_deploy.yaml -n \"$ENV_NAME\"\n\necho \"✅ Conda environment with all dependencies created successfully!\"\n\n# Step 2: Install unitree_sdk2_python\necho \"\"\necho \"📦 Step 2: Installing unitree_sdk2_python...\"\n\n# Function to run commands in conda environment\nrun_in_env() {\n    conda run -n \"$ENV_NAME\" \"$@\"\n}\n\necho \"🔧 Installing unitree_sdk2_python...\"\nif [ ! 
-d \"$HOME/unitree_sdk2_python\" ]; then\n    echo \"📥 Cloning unitree_sdk2_python repository...\"\n    git clone https://github.com/unitreerobotics/unitree_sdk2_python.git \"$HOME/unitree_sdk2_python\"\nfi\n\necho \"🔧 Installing unitree_sdk2_python in development mode...\"\ncd \"$HOME/unitree_sdk2_python\"\nrun_in_env pip install -e .\n\necho \"✅ unitree_sdk2_python installed successfully!\"\n\n# Step 3: Setup ROS2 workspace\necho \"\"\necho \"📦 Step 3: Setting up ROS2 workspace...\"\n\n# Ensure conda environment is completely deactivated for ROS2 compilation\necho \"🔧 Ensuring conda environment is completely deactivated...\"\n\n# Initialize conda for this script\neval \"$(conda shell.bash hook)\"\n\n# Deactivate any active conda environments\nwhile [[ \"$CONDA_DEFAULT_ENV\" != \"\" && \"$CONDA_DEFAULT_ENV\" != \"base\" ]]; do\n    echo \"  Deactivating conda environment: $CONDA_DEFAULT_ENV\"\n    conda deactivate\ndone\n\n# If we're still in base environment, deactivate it too\nif [[ \"$CONDA_DEFAULT_ENV\" == \"base\" ]]; then\n    echo \"  Deactivating base conda environment\"\n    conda deactivate\nfi\n\necho \"  ✅ Conda environment fully deactivated\"\n\n# Check ROS2 installation\nif [ ! -f \"/opt/ros/humble/setup.bash\" ]; then\n    echo \"❌ ROS2 Humble not found at /opt/ros/humble/\"\n    echo \"   Please install ROS2 Humble first: https://docs.ros.org/en/humble/Installation.html\"\n    exit 1\nfi\n\n# Check Unitree ROS2 SDK\nif [ ! -f \"$HOME/unitree_ros2/setup.sh\" ]; then\n    echo \"❌ Unitree ROS2 SDK not found at ~/unitree_ros2/\"\n    echo \"   Please install Unitree ROS2 SDK first\"\n    exit 1\nfi\n\necho \"🔧 Compiling ROS2 workspace...\"\ncd \"$PROJECT_ROOT/holomotion/deployment/unitree_g1_ros2_29dof\"\n\n# Create necessary directories\necho \"📁 Creating required directories...\"\nmkdir -p src/models\nmkdir -p src/motion_data\n\n# Clean previous build\nrm -rf build install log\n\n# Source ROS2 and Unitree setup\nsource /opt/ros/humble/setup.bash\nsource ~/unitree_ros2/setup.sh\n\n# Build workspace\necho \"🏗️  Building workspace with colcon...\"\ncolcon build\n\necho \"✅ ROS2 workspace compiled successfully!\"\n\necho \"\"\necho \"🎉 Deployment completed successfully!\"\necho \"\"\necho \"📋 Summary of installed packages:\"\necho \"   ✅ PyTorch 2.3.1 with CUDA 12.1 support\"  \necho \"   ✅ ONNX Runtime for neural network inference\"\necho \"   ✅ SMPLX for humanoid motion processing\"\necho \"   ✅ Scientific computing packages (NumPy, SciPy, etc.)\"\necho \"   ✅ Unitree SDK2 Python bindings\"\necho \"   ✅ ROS2 workspace compiled\"\necho \"\"\necho \"📋 To run the system:\"\necho \"1. Activate the conda environment:\"\necho \"   conda activate $ENV_NAME\"\necho \"\"\necho \"2. Launch the system:\"\necho \"   cd $PROJECT_ROOT/holomotion/deployment/unitree_g1_ros2_29dof\"\necho \"   bash launch_holomotion.sh\"\necho \"\"\necho \"✅ Environment '$ENV_NAME' setup complete!\"\necho \"🚀 Ready for robot deployment!\"\n"
  },
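  {
    "path": "deployment/check_deploy_env_example.py",
    "content": "\"\"\"Illustrative sketch: post-install sanity check for the conda environment\ncreated by deploy_environment.sh. The filename is hypothetical; it assumes only\nthe torch and onnxruntime packages that the script's summary lists.\n\nRun inside the env, e.g.:\n    conda run -n holomotion_deploy python deployment/check_deploy_env_example.py\n\"\"\"\n\nimport onnxruntime as ort\nimport torch\n\n# CUDA availability confirms the GPU build of PyTorch was installed.\nprint(\"torch:\", torch.__version__, \"| cuda available:\", torch.cuda.is_available())\n# Providers show whether ONNX Runtime can use GPU execution on this machine.\nprint(\"onnxruntime providers:\", ort.get_available_providers())\n"
  },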
  {
    "path": "deployment/holomotion_teleop/holomotion_teleop_node.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nSingle-process teleoperation pipeline.\n\nThis node reads raw Pico body tracking data from XRoboToolkit, converts it to\nSMPL, applies GMR retargeting, and publishes a 65D observation vector to the\nrobot over ZMQ.\n\nData flow:\n    xrobotoolkit_sdk (body_poses 24x7)\n        -> body_poses_to_smpl_pose_trans\n        -> SMPL_Parser / humanoid_fk\n        -> GMR\n        -> latest_obs(65)\n        -> ZMQ PUB\n\nMessage format:\n    [topic_bytes][1280-byte JSON header][binary payload]\n\nDefault payload fields:\n    - latest_obs: (65,) float32\n    - frame_index: (1,) int64\n    - timestamp_realtime: (1,) float64\n    - timestamp_monotonic: (1,) float64\n    - timestamp_ns: (1,) int64\n    - pico_dt: (1,) float32\n    - pico_fps: (1,) float32\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nimport json\nimport logging\nimport os\nimport subprocess\nimport sys\nimport threading\nimport time\nimport traceback\nfrom types import SimpleNamespace\nfrom typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport zmq\nfrom scipy.spatial.transform import Rotation as R\n\n\nFILE_DIR = os.path.dirname(os.path.abspath(__file__))\nHOLOMOTION_ROOT_DIR = os.path.abspath(os.path.join(FILE_DIR, \"..\", \"..\"))\nSMPL_ASSET_DIR = os.path.join(HOLOMOTION_ROOT_DIR, \"assets\", \"smpl\")\nfor extra_path in (\n    FILE_DIR,\n    os.path.join(FILE_DIR, \"GMR\"),\n    os.path.join(FILE_DIR, \"SMPLSim\"),\n):\n    if extra_path not in sys.path:\n        sys.path.insert(0, extra_path)\n\n\ntry:\n    import xrobotoolkit_sdk as xrt\nexcept ImportError:\n    xrt = None\n\nfrom third_party.GMR.general_motion_retargeting.motion_retarget import GeneralMotionRetargeting as GMR\nfrom smpl_sim.smpllib.smpl_parser import SMPL_Parser\n\n\nMIRROR_POSE = False\nMIRROR_AXIS = \"x\"\nHEADER_SIZE = 1280\nOUT_TOPIC = b\"obs65\"\n\nGMR_LR_SWAP_PAIRS = [\n    (\"left_hip\", \"right_hip\"),\n    (\"left_knee\", \"right_knee\"),\n    (\"left_foot\", \"right_foot\"),\n    (\"left_shoulder\", \"right_shoulder\"),\n    (\"left_elbow\", \"right_elbow\"),\n    (\"left_wrist\", \"right_wrist\"),\n]\n\nSMPL_PARENTS_24 = np.array(\n    [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21],\n    dtype=np.int32,\n)\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotation matrices to quaternions in wxyz format.\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            torch.stack([m21 - m12, 
q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    floor = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,\n        :,\n    ].reshape(batch_dim + (4,))\n\n\ndef axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert axis-angle vectors to rotation matrices.\n    Input shape: (..., 3)\n    Output shape: (..., 3, 3)\n    \"\"\"\n    orig_shape = axis_angle.shape[:-1]\n    aa = axis_angle.reshape(-1, 3)\n\n    theta = torch.linalg.norm(aa, dim=-1, keepdim=True)\n    axis = aa / torch.clamp(theta, min=1e-8)\n\n    x = axis[:, 0]\n    y = axis[:, 1]\n    z = axis[:, 2]\n    zeros = torch.zeros_like(x)\n\n    K = torch.stack(\n        [\n            zeros, -z, y,\n            z, zeros, -x,\n            -y, x, zeros,\n        ],\n        dim=-1,\n    ).reshape(-1, 3, 3)\n\n    eye = torch.eye(3, dtype=aa.dtype, device=aa.device).unsqueeze(0).expand(aa.shape[0], -1, -1)\n    sin_theta = torch.sin(theta).unsqueeze(-1)\n    cos_theta = torch.cos(theta).unsqueeze(-1)\n    axis_outer = axis.unsqueeze(-1) @ axis.unsqueeze(-2)\n\n    small = (theta.squeeze(-1) < 1e-8).unsqueeze(-1).unsqueeze(-1)\n    rot = cos_theta * eye + (1.0 - cos_theta) * axis_outer + sin_theta * K\n    rot = torch.where(small, eye, rot)\n    return rot.reshape(orig_shape + (3, 3))\n\n\nclass Humanoid_Batch_V2:\n    \"\"\"\n    Minimal per-frame SMPL kinematics helper used by this script only.\n    Keeping it local avoids importing the much larger training/visualization module.\n    \"\"\"\n\n    def __init__(self, device: torch.device = torch.device(\"cpu\")):\n        self.device = device\n        self.smpl_24_parents = [\n            -1, 0, 0, 0, 1, 2, 3,\n            4, 5, 6, 7, 8, 9, 9,\n            9, 12, 13, 14, 16, 17,\n            18, 19, 20, 21,\n        ]\n\n    @staticmethod\n    def _relative_link_position(joints_world: torch.Tensor, root_pos: torch.Tensor) -> torch.Tensor:\n        return joints_world - root_pos.unsqueeze(0)\n\n    def _relative_link_pose(self, full_pose_aa: torch.Tensor) -> torch.Tensor:\n        joint_count = full_pose_aa.shape[0]\n        assert joint_count == len(self.smpl_24_parents), (\n            f\"Joint count mismatch: {joint_count} vs {len(self.smpl_24_parents)}\"\n        )\n\n        rotation_local = axis_angle_to_matrix(full_pose_aa)\n        rotation_global = torch.empty_like(rotation_local)\n        for joint_idx in range(joint_count):\n            parent = self.smpl_24_parents[joint_idx]\n            if parent == -1:\n                rotation_global[joint_idx] = rotation_local[joint_idx]\n            else:\n                rotation_global[joint_idx] = rotation_global[parent] @ rotation_local[joint_idx]\n        return rotation_global\n\n    def step_per_frame(\n        self,\n        full_pose_aa: torch.Tensor,\n        root_pos: torch.Tensor,\n        joints: torch.Tensor,\n    ) -> SimpleNamespace:\n        global_joints_position = joints\n        global_joints2root_pos = self._relative_link_position(joints[1:, :], root_pos)\n        global_joints_rotation_mat = self._relative_link_pose(full_pose_aa)\n\n        return SimpleNamespace(\n            
global_joints2root_pos=global_joints2root_pos,\n            global_joints_rotation_mat=global_joints_rotation_mat,\n            global_joints_position=global_joints_position,\n        )\n\n\nhumanoid_fk = Humanoid_Batch_V2()\n\n\n@dataclass\nclass PicoToSmplConfig:\n    quat_scalar_first: bool = False\n    apply_global_y_180: bool = True\n    apply_root_rx90: bool = True\n    root_align_degrees: float = 90.0\n    root_align_axis: str = \"x\"\n\n\ndef body_poses_to_smpl_pose_trans(\n    body_poses: np.ndarray,\n    parents: np.ndarray = SMPL_PARENTS_24,\n    cfg: Optional[PicoToSmplConfig] = None,\n) -> Tuple[np.ndarray, np.ndarray]:\n    if cfg is None:\n        cfg = PicoToSmplConfig()\n\n    body_poses = np.asarray(body_poses, dtype=np.float32)\n    if body_poses.shape != (24, 7):\n        raise ValueError(f\"body_poses shape must be (24,7), got {body_poses.shape}\")\n\n    positions = body_poses[:, 0:3].astype(np.float32)\n    qx, qy, qz, qw = body_poses[:, 3], body_poses[:, 4], body_poses[:, 5], body_poses[:, 6]\n    global_quats_sfirst = np.stack([qw, qx, qy, qz], axis=1).astype(np.float32)\n    global_rots = R.from_quat(global_quats_sfirst, scalar_first=True)\n\n    if cfg.apply_global_y_180:\n        global_rots = global_rots * R.from_euler(\"y\", 180.0, degrees=True)\n\n    local_rots = []\n    for i in range(24):\n        parent = int(parents[i])\n        if parent == -1:\n            local_rots.append(global_rots[i])\n        else:\n            local_rots.append(global_rots[parent].inv() * global_rots[i])\n\n    pose_aa_24x3 = np.stack([rot.as_rotvec() for rot in local_rots], axis=0).astype(np.float32)\n    trans = positions[0].astype(np.float32)\n\n    if cfg.apply_root_rx90:\n        rot_align = R.from_euler(cfg.root_align_axis, cfg.root_align_degrees, degrees=True).as_matrix().astype(\n            np.float32\n        )\n        root_matrix = R.from_rotvec(pose_aa_24x3[0]).as_matrix().astype(np.float32)\n        pose_aa_24x3[0] = R.from_matrix(rot_align @ root_matrix).as_rotvec().astype(np.float32)\n        trans = (rot_align @ trans.reshape(3, 1)).reshape(3).astype(np.float32)\n\n    return pose_aa_24x3, trans\n\n\ndef _mirror_matrix(mirror_axis: str) -> np.ndarray:\n    if mirror_axis == \"x\":\n        return np.diag([-1.0, 1.0, 1.0]).astype(np.float32)\n    if mirror_axis == \"y\":\n        return np.diag([1.0, -1.0, 1.0]).astype(np.float32)\n    if mirror_axis == \"z\":\n        return np.diag([1.0, 1.0, -1.0]).astype(np.float32)\n    raise ValueError(f\"mirror_axis must be one of x/y/z, got {mirror_axis}\")\n\n\ndef safe_normalize_quat_wxyz(q: np.ndarray, eps: float = 1e-8) -> np.ndarray:\n    q = np.asarray(q, dtype=np.float32).reshape(4,)\n    n = float(np.linalg.norm(q))\n    if n < eps:\n        return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)\n    return (q / n).astype(np.float32)\n\n\ndef mirror_pos_and_quat_wxyz(pos: np.ndarray, quat_wxyz: np.ndarray, mirror_axis: str) -> Tuple[np.ndarray, np.ndarray]:\n    pos = np.asarray(pos, dtype=np.float32).reshape(3,)\n    q = safe_normalize_quat_wxyz(quat_wxyz)\n    M = _mirror_matrix(mirror_axis)\n\n    pos_m = (M @ pos).astype(np.float32)\n    q_xyzw = np.array([q[1], q[2], q[3], q[0]], dtype=np.float32)\n    rot_m = R.from_quat(q_xyzw).as_matrix().astype(np.float32)\n    rot_m = (M @ rot_m @ M).astype(np.float32)\n    q_m_xyzw = R.from_matrix(rot_m).as_quat().astype(np.float32)\n    quat_m_wxyz = np.array([q_m_xyzw[3], q_m_xyzw[0], q_m_xyzw[1], q_m_xyzw[2]], dtype=np.float32)\n    return pos_m, 
safe_normalize_quat_wxyz(quat_m_wxyz)\n\n\ndef mirror_and_swap_gmr_input(gmr_input: Dict[str, Any], mirror_axis: str = \"x\") -> Dict[str, Any]:\n    mirrored: Dict[str, Any] = {}\n    for key, (pos, quat) in gmr_input.items():\n        mirrored[key] = mirror_pos_and_quat_wxyz(pos, quat, mirror_axis)\n\n    out = dict(mirrored)\n    for a, b in GMR_LR_SWAP_PAIRS:\n        if a in out and b in out:\n            out[a], out[b] = out[b], out[a]\n    return out\n\n\ndef pack_numpy_message(payload: dict, topic: bytes = OUT_TOPIC, version: int = 1) -> bytes:\n    fields = []\n    binary_data = []\n\n    for key, value in payload.items():\n        if not isinstance(value, np.ndarray):\n            continue\n        if value.dtype == np.float32:\n            dtype_str = \"f32\"\n        elif value.dtype == np.float64:\n            dtype_str = \"f64\"\n        elif value.dtype == np.int32:\n            dtype_str = \"i32\"\n        elif value.dtype == np.int64:\n            dtype_str = \"i64\"\n        elif value.dtype == np.uint8:\n            dtype_str = \"u8\"\n        elif value.dtype == bool:\n            dtype_str = \"bool\"\n        else:\n            dtype_str = \"f32\"\n            value = value.astype(np.float32)\n\n        if not value.flags[\"C_CONTIGUOUS\"]:\n            value = np.ascontiguousarray(value)\n        if value.dtype.byteorder == \">\":\n            value = value.astype(value.dtype.newbyteorder(\"<\"))\n\n        fields.append({\"name\": key, \"dtype\": dtype_str, \"shape\": list(value.shape)})\n        binary_data.append(value.tobytes())\n\n    header = {\"v\": version, \"endian\": \"le\", \"count\": 1, \"fields\": fields}\n    header_bytes = json.dumps(header, separators=(\",\", \":\")).encode(\"utf-8\")\n    if len(header_bytes) > HEADER_SIZE:\n        raise ValueError(f\"Header too large: {len(header_bytes)} > {HEADER_SIZE}\")\n    header_bytes = header_bytes.ljust(HEADER_SIZE, b\"\\x00\")\n    return topic + header_bytes + b\"\".join(binary_data)\n\n\nclass PicoReader:\n    def __init__(self):\n        self._stop = threading.Event()\n        self._thread = threading.Thread(target=self._run, daemon=True)\n        self._last_stamp_ns = None\n        self._fps_ema = 0.0\n        self._latest = None\n        self._lock = threading.Lock()\n\n    def start(self):\n        self._thread.start()\n\n    def stop(self):\n        self._stop.set()\n        self._thread.join(timeout=1.0)\n\n    def get_latest(self):\n        with self._lock:\n            return self._latest\n\n    def _run(self):\n        last_report = time.time()\n        while not self._stop.is_set():\n            if not xrt.is_body_data_available():\n                time.sleep(0.001)\n                continue\n\n            stamp_ns = xrt.get_time_stamp_ns()\n            prev_stamp_ns = self._last_stamp_ns\n            if prev_stamp_ns is not None and stamp_ns == prev_stamp_ns:\n                time.sleep(0.000001)\n                continue\n\n            device_dt = ((stamp_ns - prev_stamp_ns) * 1e-9) if prev_stamp_ns is not None else 0.0\n            if device_dt > 0.0:\n                inst_fps = 1.0 / device_dt\n                self._fps_ema = inst_fps if self._fps_ema == 0.0 else (0.9 * self._fps_ema + 0.1 * inst_fps)\n            self._last_stamp_ns = stamp_ns\n\n            t_realtime = time.time()\n            t_monotonic = time.monotonic()\n            try:\n                body_poses = xrt.get_body_joints_pose()\n                body_poses_np = np.asarray(body_poses, dtype=np.float32)\n                
if body_poses_np.shape != (24, 7):\n                    print(f\"[PicoReader] WARNING: unexpected body_poses shape: {body_poses_np.shape}\")\n\n                sample = {\n                    \"body_poses_np\": body_poses_np,\n                    \"timestamp_realtime\": t_realtime,\n                    \"timestamp_monotonic\": t_monotonic,\n                    \"timestamp_ns\": int(stamp_ns),\n                    \"dt\": float(device_dt),\n                    \"fps\": float(self._fps_ema),\n                }\n                with self._lock:\n                    self._latest = sample\n\n                now = time.time()\n                if now - last_report >= 5.0:\n                    print(\n                        f\"[PicoReader] shape={body_poses_np.shape}, \"\n                        f\"dt_ts={device_dt * 1000.0:.2f} ms, fps={self._fps_ema:.2f}\"\n                    )\n                    last_report = now\n            except Exception as exc:\n                print(f\"[PicoReader] read error: {exc}\")\n\n\nclass ZmqObsSender:\n    def __init__(self, uri: str, logger, topic: bytes = OUT_TOPIC, mode: str = \"bind\", conflate: bool = True):\n        self.logger = logger\n        self.topic = topic\n        self._context = zmq.Context()\n        self._socket = self._context.socket(zmq.PUB)\n        self._socket.setsockopt(zmq.SNDHWM, 1)\n        if conflate and hasattr(zmq, \"CONFLATE\"):\n            self._socket.setsockopt(zmq.CONFLATE, 1)\n\n        if mode == \"bind\":\n            self._socket.bind(uri)\n        elif mode == \"connect\":\n            self._socket.connect(uri)\n        else:\n            raise ValueError(\"mode must be 'bind' or 'connect'\")\n\n        self._last_send_time = None\n        self._send_freq_log = []\n        self._frame_count = 0\n        self.logger.info(f\"[ZMQOut] sender ready: mode={mode}, uri={uri}, topic={topic.decode('utf-8')}\")\n\n    def send(self, latest_obs: np.ndarray, frame_index: int, sample_meta: dict):\n        payload = {\n            \"latest_obs\": np.asarray(latest_obs, dtype=np.float32),\n            \"frame_index\": np.array([frame_index], dtype=np.int64),\n            \"timestamp_realtime\": np.array([sample_meta[\"timestamp_realtime\"]], dtype=np.float64),\n            \"timestamp_monotonic\": np.array([sample_meta[\"timestamp_monotonic\"]], dtype=np.float64),\n            \"timestamp_ns\": np.array([sample_meta[\"timestamp_ns\"]], dtype=np.int64),\n            \"pico_dt\": np.array([sample_meta[\"dt\"]], dtype=np.float32),\n            \"pico_fps\": np.array([sample_meta[\"fps\"]], dtype=np.float32),\n        }\n        packet = pack_numpy_message(payload, topic=self.topic)\n        self._socket.send(packet)\n\n        now = time.time()\n        if self._last_send_time is not None:\n            dt = now - self._last_send_time\n            if dt > 0:\n                self._send_freq_log.append(1.0 / dt)\n                self._frame_count += 1\n                if self._frame_count >= 50:\n                    avg_freq = sum(self._send_freq_log) / len(self._send_freq_log)\n                    self.logger.info(f\"Average ZMQ send rate: {avg_freq:.2f} Hz\")\n                    self._send_freq_log.clear()\n                    self._frame_count = 0\n        self._last_send_time = now\n\n    def stop(self):\n        self._socket.close(0)\n        self._context.term()\n        self.logger.info(\"🛑 ZMQ obs sender stopped\")\n\n\nclass VRNodeXRTPicoGMRZmqOut:\n    def __init__(\n        self,\n        robot_zmq_uri: str,\n        
robot_zmq_mode: str = \"bind\",\n        loop_hz: float = 55.0,\n        timing_log_every: int = 100,\n        save_obs_path: str = \"\",\n    ):\n        self.device = \"cpu\"\n        logging.getLogger(\"websockets\").setLevel(logging.WARNING)\n        self.info(f\"✅ VRNodeXRTPicoGMRZmqOut running on device={self.device}\")\n        self.info(\"starting xrt pico -> gmr -> robot zmq node\")\n\n        self.gmr = GMR(src_human=\"smplx\", tgt_robot=\"unitree_g1\")\n        self.smpl_parser = SMPL_Parser(model_path=SMPL_ASSET_DIR, gender=\"neutral\")\n        if hasattr(self.smpl_parser, \"to\"):\n            self.smpl_parser = self.smpl_parser.to(self.device)\n\n        self.betas = torch.zeros(1, 10, device=self.device)\n        self.gmr_input_data: Dict[str, Any] = {}\n        self.prev_dof_pos = None\n        self.lasttime = None\n        self.timing_log_every = max(1, timing_log_every)\n        self.save_obs_path = save_obs_path\n        self.mirror_pose = MIRROR_POSE\n        self.mirror_axis = MIRROR_AXIS\n        self.tick_count = 0\n        self.frame_index = 0\n        self.timing_sums_ms = defaultdict(float)\n        self.saved_obs = []\n        self.latest_sample = None\n\n        self.reader = PicoReader()\n        self.reader.start()\n        self.sender = ZmqObsSender(uri=robot_zmq_uri, logger=self, mode=robot_zmq_mode)\n        self.start_loop(loop_hz)\n\n    def info(self, msg): print(f\"[INFO] {msg}\")\n    def error(self, msg): print(f\"[ERROR] {msg}\")\n    def warning(self, msg): print(f\"[WARN] {msg}\")\n    def debug(self, msg): print(f\"[DEBUG] {msg}\")\n\n    def _accumulate_timing(self, name: str, start_time: float) -> float:\n        elapsed_ms = (time.perf_counter() - start_time) * 1000.0\n        self.timing_sums_ms[name] += elapsed_ms\n        return elapsed_ms\n\n    def _maybe_log_timing(self):\n        if self.tick_count <= 0 or self.tick_count % self.timing_log_every != 0:\n            return\n        avg_parts = []\n        for key in (\"body_poses_to_smpl\", \"smpl_to_joints\", \"gmr_retarget\", \"postprocess_send\", \"tick_total\"):\n            if key in self.timing_sums_ms:\n                avg_ms = self.timing_sums_ms[key] / self.timing_log_every\n                avg_parts.append(f\"{key}={avg_ms:.2f}ms\")\n        if avg_parts:\n            self.info(\"[Timing] \" + \", \".join(avg_parts))\n        self.timing_sums_ms.clear()\n\n    def process_smpl_pose_trans_to_gmr_input(self, smpl_pose_aa, smpl_trans) -> Dict[str, Any]:\n        stage_start = time.perf_counter()\n        if not isinstance(smpl_pose_aa, torch.Tensor):\n            smpl_pose_aa = torch.tensor(smpl_pose_aa, dtype=torch.float32)\n        if not isinstance(smpl_trans, torch.Tensor):\n            smpl_trans = torch.tensor(smpl_trans, dtype=torch.float32)\n\n        pose = smpl_pose_aa.to(self.device, dtype=torch.float32)\n        trans = smpl_trans.to(self.device, dtype=torch.float32)\n        if pose.ndim == 2:\n            pose = pose.unsqueeze(0)\n        if trans.ndim == 1:\n            trans = trans.unsqueeze(0)\n\n        verts, joints = self.smpl_parser.get_joints_verts(pose, self.betas, trans)\n        # joints[..., 2] -= verts[0, :, 2].min().item()\n\n        pose = pose.squeeze(0)\n        trans = trans.squeeze(0)\n        joints = joints.squeeze(0)\n        motion_state = humanoid_fk.step_per_frame(pose, trans, joints)\n\n        global_joints_position = motion_state.global_joints_position\n        global_joints_rotation_mat = motion_state.global_joints_rotation_mat\n        
global_joints_qua_wxyz = matrix_to_quaternion(global_joints_rotation_mat)\n\n        smpl_to_gmr = {\n            \"pelvis\": 0,\n            \"spine3\": 9,\n            \"left_hip\": 1,\n            \"right_hip\": 2,\n            \"left_knee\": 4,\n            \"right_knee\": 5,\n            \"left_foot\": 10,\n            \"right_foot\": 11,\n            \"left_shoulder\": 16,\n            \"right_shoulder\": 17,\n            \"left_elbow\": 18,\n            \"right_elbow\": 19,\n            \"left_wrist\": 20,\n            \"right_wrist\": 21,\n        }\n\n        gmr_input_data: Dict[str, Any] = {}\n        for name, idx in smpl_to_gmr.items():\n            pos = global_joints_position[idx].detach().cpu().numpy()\n            quat = global_joints_qua_wxyz[idx].detach().cpu().numpy()\n            gmr_input_data[name] = (pos, quat)\n\n        if self.mirror_pose:\n            gmr_input_data = mirror_and_swap_gmr_input(gmr_input_data, mirror_axis=self.mirror_axis)\n\n        self._accumulate_timing(\"smpl_to_joints\", stage_start)\n        return gmr_input_data\n\n    def process_xrt_frame_to_gmr_input(self, sample: dict):\n        body_poses = np.asarray(sample[\"body_poses_np\"], dtype=np.float32)\n        if body_poses.shape != (24, 7):\n            raise ValueError(f\"[XRT] body_poses_np must have shape (24,7), got {body_poses.shape}\")\n\n        stage_start = time.perf_counter()\n        pose_aa, trans = body_poses_to_smpl_pose_trans(\n            body_poses,\n            cfg=PicoToSmplConfig(\n                apply_global_y_180=True,\n                apply_root_rx90=True,\n                root_align_axis=\"x\",\n                root_align_degrees=90.0,\n            ),\n        )\n        self._accumulate_timing(\"body_poses_to_smpl\", stage_start)\n        self.gmr_input_data = self.process_smpl_pose_trans_to_gmr_input(pose_aa, trans)\n\n    def process_gmr_output(self):\n        stage_start = time.perf_counter()\n        qpos = self.gmr.retarget(self.gmr_input_data)\n        self._accumulate_timing(\"gmr_retarget\", stage_start)\n\n        stage_start = time.perf_counter()\n        root_pos = qpos[:3]\n        root_rot = qpos[3:7]\n        dof_pos = qpos[7:]\n\n        now = time.time()\n        delta_time = 1 / 50 if self.lasttime is None else (now - self.lasttime)\n        self.lasttime = now\n\n        if self.prev_dof_pos is None:\n            dof_vel = np.zeros_like(dof_pos, dtype=np.float32)\n        else:\n            dof_vel = (dof_pos - self.prev_dof_pos) / max(delta_time, 1e-6)\n        self.prev_dof_pos = dof_pos.copy()\n\n        latest_obs = np.concatenate([dof_pos, dof_vel, root_pos, root_rot], axis=0).astype(np.float32)\n        self.publish_data(latest_obs)\n        self.sender.send(latest_obs, self.frame_index, self.latest_sample)\n        self.saved_obs.append(latest_obs.copy())\n        self.frame_index += 1\n        self._accumulate_timing(\"postprocess_send\", stage_start)\n        return latest_obs\n\n    def publish_data(self, motion_state: np.ndarray):\n        if motion_state.size != 65:\n            self.error(f\"Output dim {motion_state.size} != expected 65\")\n            return\n        if np.isnan(motion_state).any():\n            self.error(\"NaN detected\")\n            return\n\n    def save_observations(self):\n        if not self.save_obs_path:\n            return\n        if len(self.saved_obs) == 0:\n            self.warning(f\"[SaveObs] no observations to save: {self.save_obs_path}\")\n            return\n\n        obs_array = 
np.stack(self.saved_obs, axis=0).astype(np.float32)\n        save_dir = os.path.dirname(self.save_obs_path)\n        if save_dir:\n            os.makedirs(save_dir, exist_ok=True)\n\n        if self.save_obs_path.endswith(\".npy\"):\n            np.save(self.save_obs_path, obs_array)\n        else:\n            np.savez_compressed(\n                self.save_obs_path,\n                latest_obs=obs_array,\n                columns=np.array([\"dof_pos(29)\", \"dof_vel(29)\", \"root_pos(3)\", \"root_rot_wxyz(4)\"], dtype=object),\n            )\n        self.info(f\"[SaveObs] saved {obs_array.shape[0]} frames to {self.save_obs_path}\")\n\n    def start_loop(self, hz=50):\n        self.info(f\"Starting main loop at {hz} Hz\")\n        interval = 1.0 / hz\n\n        def loop():\n            next_time = time.time()\n            while True:\n                self._tick()\n                next_time += interval\n                sleep_time = next_time - time.time()\n                if sleep_time > 0:\n                    time.sleep(sleep_time)\n                else:\n                    next_time = time.time()\n\n        threading.Thread(target=loop, daemon=True).start()\n\n    def _tick(self):\n        tick_start = time.perf_counter()\n        sample = self.reader.get_latest()\n        if sample is not None:\n            try:\n                self.latest_sample = sample\n                self.process_xrt_frame_to_gmr_input(sample)\n                self.process_gmr_output()\n            except Exception as exc:\n                self.error(f\"[tick_error] {exc}\")\n                self.error(traceback.format_exc())\n                return\n        elif self.prev_dof_pos is not None:\n            try:\n                self.process_gmr_output()\n            except Exception as exc:\n                self.error(f\"[tick_error] {exc}\")\n                self.error(traceback.format_exc())\n                return\n\n        self.tick_count += 1\n        self._accumulate_timing(\"tick_total\", tick_start)\n        self._maybe_log_timing()\n\n    def stop(self):\n        self.reader.stop()\n        self.sender.stop()\n        self.save_observations()\n        try:\n            if xrt is not None and hasattr(xrt, \"close\"):\n                xrt.close()\n        except Exception:\n            pass\n\n\ndef init_xrt(start_service: bool = True):\n    if xrt is None:\n        raise ImportError(\"XRoboToolkit SDK not available. 
Install xrobotoolkit_sdk first.\")\n    if start_service:\n        subprocess.Popen([\"bash\", \"/opt/apps/roboticsservice/runService.sh\"])\n    xrt.init()\n    print(\"Waiting for body tracking data...\")\n    while not xrt.is_body_data_available():\n        print(\"waiting for body data...\")\n        time.sleep(1)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"XRT Pico -> GMR -> robot ZMQ(65D)\")\n    parser.add_argument(\"--robot-zmq-uri\", default=\"tcp://*:6001\", help=\"Robot-side ZMQ uri for 65D obs output\")\n    parser.add_argument(\"--robot-zmq-mode\", default=\"bind\", choices=[\"bind\", \"connect\"])\n    parser.add_argument(\"--hz\", type=float, default=55.0, help=\"Main loop frequency / publish cap\")\n    parser.add_argument(\"--timing-log-every\", type=int, default=200, help=\"Print average stage timing every N ticks\")\n    parser.add_argument(\"--save-obs-path\", type=str, default=\"\", help=\"Optional path to save emitted 65D observations\")\n    parser.add_argument(\"--skip-start-service\", action=\"store_true\", help=\"Do not auto-run /opt/apps/roboticsservice/runService.sh\")\n    args = parser.parse_args()\n\n    init_xrt(start_service=not args.skip_start_service)\n    node = VRNodeXRTPicoGMRZmqOut(\n        robot_zmq_uri=args.robot_zmq_uri,\n        robot_zmq_mode=args.robot_zmq_mode,\n        loop_hz=args.hz,\n        timing_log_every=args.timing_log_every,\n        save_obs_path=args.save_obs_path,\n    )\n    try:\n        while True:\n            time.sleep(1)\n    except KeyboardInterrupt:\n        node.stop()\n        print(\"🛑 Program terminated by user.\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
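  {
    "path": "deployment/holomotion_teleop/examples/unpack_obs_message_example.py",
    "content": "\"\"\"Illustrative sketch of the receive side of the ZMQ message emitted by\nholomotion_teleop_node.py: [topic bytes][1280-byte NUL-padded JSON header]\n[binary payload]. This file and the function name are hypothetical additions;\nthe framing constants and dtype codes mirror pack_numpy_message() in\nholomotion_teleop_node.py.\n\"\"\"\n\nimport json\n\nimport numpy as np\n\nHEADER_SIZE = 1280\nTOPIC = b\"obs65\"\nDTYPES = {\n    \"f32\": np.float32,\n    \"f64\": np.float64,\n    \"i32\": np.int32,\n    \"i64\": np.int64,\n    \"u8\": np.uint8,\n    \"bool\": np.bool_,\n}\n\n\ndef unpack_numpy_message(packet: bytes) -> dict:\n    \"\"\"Invert pack_numpy_message(): return {field name: ndarray}.\"\"\"\n    assert packet.startswith(TOPIC), \"unexpected topic prefix\"\n    body = packet[len(TOPIC):]\n    # The header is fixed-size JSON, right-padded with NUL bytes.\n    header = json.loads(body[:HEADER_SIZE].rstrip(b\"\\x00\").decode(\"utf-8\"))\n    offset = HEADER_SIZE\n    arrays = {}\n    for field in header[\"fields\"]:\n        dtype = np.dtype(DTYPES[field[\"dtype\"]])\n        count = int(np.prod(field[\"shape\"]))\n        nbytes = count * dtype.itemsize\n        arrays[field[\"name\"]] = np.frombuffer(\n            body[offset:offset + nbytes], dtype=dtype\n        ).reshape(field[\"shape\"])\n        offset += nbytes\n    return arrays\n\n\n# Example with a SUB socket (assumes the node binds tcp://*:6001):\n#   import zmq\n#   sock = zmq.Context().socket(zmq.SUB)\n#   sock.setsockopt(zmq.SUBSCRIBE, TOPIC)\n#   sock.connect(\"tcp://<workstation-ip>:6001\")\n#   payload = unpack_numpy_message(sock.recv())\n#   obs = payload[\"latest_obs\"]  # float32[65]\n"
  },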
  {
    "path": "deployment/holomotion_teleop/holomotion_teleop_setup.md",
    "content": "# Holomotion Teleop\n\nSingle-process pipeline for:\n\n`PICO / XRoboToolkit -> SMPL conversion -> GMR retargeting -> robot ZMQ`\n\n\n## Prerequisites\n\nBefore setting up the Python environment, install XRoboToolkit PC Service manually.\n\n1. On Ubuntu 22.04, download the XRoboToolkit PC Service `.deb` package, or build it from source.\n\n```bash\nsudo dpkg -i XRoboToolkit_PC_Service_1.0.0_ubuntu_22.04_amd64.deb\n```\n\n\n## Environment Setup\n\n```bash\ncd /path/to/holomotion_teleop\nbash setup_holomotion_teleop_x86_ubuntu2204.sh\n```\n\nThis script will:\n\n- create the Conda environment `holomotion_teleop`\n- automatically clone and install `GMR` and `SMPLSim`\n- install runtime dependencies such as `numpy==1.23.5`, `torch`, and `pyzmq`\n- build and install `xrobotoolkit_sdk` from source\n\n\nOptional environment variables:\n\n```bash\nENV_NAME=holomotion_teleop\nPYTHON_VERSION=3.10\nINSTALL_APT_DEPS=auto\nTHIRD_PARTY_DIR=/path/to/third_party\nGMR_SOURCE_DIR=/path/to/GMR\nSMPLSIM_SOURCE_DIR=/path/to/SMPLSim\nXRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind\n```\n\n- `INSTALL_APT_DEPS=auto`: only runs apt installation if required build tools are missing\n- `INSTALL_APT_DEPS=0`: skip apt installation entirely if your machine already has the tools or apt is unusable\n- `INSTALL_APT_DEPS=1`: force the apt installation step\n- `THIRD_PARTY_DIR`: default directory used for auto-cloned third-party repositories\n- `GMR_SOURCE_DIR` / `SMPLSIM_SOURCE_DIR`: point to external source checkouts; if omitted, the script auto-clones them\n\n\n## Input and Output\n\n### Input\n\nThe script reads raw body tracking data directly from `xrobotoolkit_sdk.get_body_joints_pose()`:\n\n- shape: `(24, 7)`\n- row format: `[x, y, z, qx, qy, qz, qw]`\n\n### Output\n\nThe robot-side ZMQ payload contains `latest_obs` as `float32[65]`:\n\n1. `dof_pos[29]`\n2. `dof_vel[29]`\n3. `root_pos[3]`\n4. `root_rot_wxyz[4]`\n\nAdditional metadata is included in the same payload:\n\n- `frame_index`\n- `timestamp_realtime`\n- `timestamp_monotonic`\n- `timestamp_ns`\n- `pico_dt`\n- `pico_fps`\n\n## Next Steps\n\nBefore running teleoperation on the real robot, make sure the operators are already familiar with the offline `.npz` motion-performance workflow and the robot's basic mode-switching behavior. Teleoperation should not be the first time the team tests motion-mode entry on hardware.\n\n\n### Real Robot Workflow\n\nUse the following checklist when running the teleoperation stack on the real robot.\n\n#### 1. Hardware and Network\n\nRequired hardware:\n\n- PICO 4 / PICO 4 Pro headset\n- 2 PICO controllers\n- 2 PICO motion trackers attached to the ankles\n- One workstation running `holomotion_teleop_node.py`\n- One robot computer running the policy / control stack\n- A low-latency Wi-Fi network shared by the PICO headset and the workstation\n\nMake sure the robot, the workstation and the PICO headset are on the same Wi-Fi network. Low network latency is important for stable teleoperation. The PICO-side setup steps below follow the XRoboToolkit / PICO workflow described in the [GR00T VR Teleop Setup (PICO)](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/vr_teleop_setup.html).\n\n#### 2. Install and Configure PICO\n\n1. Install the XRoboToolkit PICO app on the headset.\n   - Enable Developer Mode on the headset.\n   - Open the browser on PICO and download the XRoboToolkit PICO APK.\n   - Install the APK from the downloads page and confirm it appears in the app library.\n2. 
Pair the two PICO motion trackers.\n   - Attach one tracker to each ankle.\n   - Open the motion tracker settings on the headset.\n   - Unpair any old trackers first, then pair both trackers again.\n3. Calibrate the motion trackers on the headset.\n   - Follow the standing calibration step.\n   - Then look down at the foot trackers so the headset cameras can detect them.\n4. Connect the headset to the workstation.\n   - Confirm the headset and workstation are on the same Wi-Fi network.\n   - Open the XRoboToolkit app on the headset.\n   - Enter the workstation IP address into the PC Service field.\n   - Verify the status shows a successful connection.\n5. In XRoboToolkit, enable the required streaming options.\n   - Enable `Head` and `Controller` tracking.\n   - Set `Pico Motion Tracker` to `Full body`.\n   - Enable the `Send` option for data/control streaming.\n\n#### 3. Configure the Robot-Side Policy\n\nBefore starting the robot-side policy, update the robot config file:\n\n`HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml`\n\nRecommended settings:\n\n- `enable_teleop_reference: true`\n- `require_vr_data_for_motion: true`\n- `latest_obs_zmq_uri: \"tcp://<workstation-ip>:6001\"`\n\nReplace `<workstation-ip>` with the actual IP address of the workstation that runs `holomotion_teleop_node.py`.\n\nThis ensures the robot waits for live VR data before switching into motion mode and connects to the correct ZMQ publisher endpoint.\n\n#### 4. Launch Order\n\nStart the system in the following order:\n\n1. Start the robot control / policy stack on the robot computer.\n2. Wait until the control policy is fully initialized, then press `Start` to move the robot into the default pose.\n3. Start XRoboToolkit on the PICO headset and confirm that body-tracking data is being streamed.\n4. Start the teleoperation node on the workstation:\n\n```bash\nconda activate holomotion_teleop\ncd /path/to/holomotion_teleop\npython holomotion_teleop_node.py\n```\n\nIf needed, pass explicit ZMQ arguments such as:\n\n```bash\npython holomotion_teleop_node.py \\\n  --robot-zmq-uri tcp://*:6001 \\\n  --robot-zmq-mode bind \\\n  --hz 50\n```\n\n5. After the robot-side policy is receiving live teleoperation data, perform the runtime mode sequence:\n   - press `A` to enter walking / velocity mode\n   - press `B` to enter teleoperation motion mode\n   - press `Y` whenever you want to leave teleoperation and return to walking mode\n\n#### 5. Runtime Check\n\nBefore enabling motion on the robot:\n\n- confirm XRoboToolkit PC Service is running\n- confirm the PICO headset is connected to the workstation\n- confirm `holomotion_teleop_node.py` is publishing ZMQ data\n- confirm the robot-side policy is using the correct workstation IP in `latest_obs_zmq_uri`\n- confirm the robot-side config keeps `enable_teleop_reference: true`\n- confirm the robot-side config keeps `require_vr_data_for_motion: true`\n- confirm the team has already validated the offline `.npz` motion-performance pipeline before attempting live teleoperation\n\nOnce the ZMQ stream is stable, enable the robot policy and switch into motion mode.\n\n## Optional Arguments\n\n- `--robot-zmq-uri`: robot-side ZMQ endpoint for the 65D output\n- `--robot-zmq-mode`: `bind` or `connect`\n- `--hz`: main loop frequency / processing cap\n- `--timing-log-every`: print average stage timing every N ticks\n- `--save-obs-path`: save emitted 65D observations on exit as `.npy` or `.npz`"
  },
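  {
    "path": "deployment/holomotion_teleop/examples/split_latest_obs_example.py",
    "content": "\"\"\"Illustrative sketch: split the 65D `latest_obs` vector into the named\nblocks documented in holomotion_teleop_setup.md. The file and function name\nare hypothetical additions; the layout (dof_pos[29], dof_vel[29], root_pos[3],\nroot_rot_wxyz[4]) matches the concatenation order in holomotion_teleop_node.py.\n\"\"\"\n\nimport numpy as np\n\n\ndef split_latest_obs(latest_obs: np.ndarray) -> dict:\n    obs = np.asarray(latest_obs, dtype=np.float32).reshape(65)\n    return {\n        \"dof_pos\": obs[0:29],         # joint positions, policy DOF order\n        \"dof_vel\": obs[29:58],        # finite-difference joint velocities\n        \"root_pos\": obs[58:61],       # pelvis position\n        \"root_rot_wxyz\": obs[61:65],  # pelvis orientation quaternion (wxyz)\n    }\n\n\nif __name__ == \"__main__\":\n    blocks = split_latest_obs(np.zeros(65, dtype=np.float32))\n    print({name: arr.shape for name, arr in blocks.items()})\n"
  },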
  {
    "path": "deployment/holomotion_teleop/setup_holomotion_teleop_x86_ubuntu2204.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\n# One-click setup script for the holomotion teleoperation environment.\n#\n# This script automates the manually verified workflow:\n# 1. create/activate conda env\n# 2. clone/install GMR\n# 3. clone/build/install XRoboToolkit pybind SDK\n# 4. clone/install SMPLSim\n# 5. install runtime Python dependencies\n#\n# Usage:\n#   bash setup_gmr_holomotion_teleop_ubuntu2204.sh\n#\n# Optional env vars:\n#   ENV_NAME=holomotion_teleop\n#   PYTHON_VERSION=3.10\n#   INSTALL_APT_DEPS=0             # default disabled; set to 1 only if you need apt\n#   THIRD_PARTY_DIR=/path/to/third_party\n#   GMR_SOURCE_DIR=/path/to/GMR\n#   SMPLSIM_SOURCE_DIR=/path/to/SMPLSim\n#   XRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind\n\nENV_NAME=\"${ENV_NAME:-holomotion_teleop}\"\nPYTHON_VERSION=\"${PYTHON_VERSION:-3.10}\"\nINSTALL_APT_DEPS=\"${INSTALL_APT_DEPS:-0}\"\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nPROJECT_ROOT=\"${SCRIPT_DIR}\"\nTHIRD_PARTY_DIR=\"${THIRD_PARTY_DIR:-$PROJECT_ROOT/third_party}\"\n\nGMR_REPO_URL=\"${GMR_REPO_URL:-https://github.com/YanjieZe/GMR.git}\"\nSMPLSIM_REPO_URL=\"${SMPLSIM_REPO_URL:-https://github.com/ZhengyiLuo/SMPLSim.git}\"\nXRT_PYBIND_REPO_URL=\"${XRT_PYBIND_REPO_URL:-https://github.com/YanjieZe/XRoboToolkit-PC-Service-Pybind.git}\"\nXRT_PC_SERVICE_REPO_URL=\"${XRT_PC_SERVICE_REPO_URL:-https://github.com/XR-Robotics/XRoboToolkit-PC-Service.git}\"\n\nGMR_SOURCE_DIR=\"${GMR_SOURCE_DIR:-$THIRD_PARTY_DIR/GMR}\"\nSMPLSIM_SOURCE_DIR=\"${SMPLSIM_SOURCE_DIR:-$THIRD_PARTY_DIR/SMPLSim}\"\nXRT_PYBIND_REPO_DIR=\"${XRT_PYBIND_REPO_DIR:-$THIRD_PARTY_DIR/XRoboToolkit-PC-Service-Pybind}\"\n\ninfo() {\n  echo \"[INFO] $*\"\n}\n\nwarn() {\n  echo \"[WARN] $*\" >&2\n}\n\nerror() {\n  echo \"[ERROR] $*\" >&2\n  exit 1\n}\n\nrequire_command() {\n  local cmd=\"$1\"\n  local hint=\"${2:-}\"\n  if ! command -v \"$cmd\" >/dev/null 2>&1; then\n    if [[ -n \"$hint\" ]]; then\n      error \"$cmd not found. $hint\"\n    else\n      error \"$cmd not found.\"\n    fi\n  fi\n}\n\nrun_conda_relaxed() {\n  # Some conda activation/deactivation hooks are not compatible with `set -u`\n  # and may reference unset variables such as SETVARS_CALL.\n  set +u\n  \"$@\"\n  local status=$?\n  set -u\n  return $status\n}\n\nshow_env_summary() {\n  info \"project root: $PROJECT_ROOT\"\n  info \"env name: $ENV_NAME\"\n  info \"python version: $PYTHON_VERSION\"\n  info \"install apt deps: $INSTALL_APT_DEPS\"\n  info \"third party dir: $THIRD_PARTY_DIR\"\n  info \"gmr source dir: $GMR_SOURCE_DIR\"\n  info \"smplsim source dir: $SMPLSIM_SOURCE_DIR\"\n  info \"xrt pybind dir: $XRT_PYBIND_REPO_DIR\"\n}\n\ncheck_platform() {\n  if [[ \"$(uname -s)\" != \"Linux\" ]]; then\n    error \"This setup script currently supports Linux only.\"\n  fi\n\n  if [[ -f /etc/os-release ]]; then\n    # shellcheck disable=SC1091\n    source /etc/os-release\n    info \"detected OS: ${PRETTY_NAME:-unknown}\"\n    if [[ \"${ID:-}\" != \"ubuntu\" || \"${VERSION_ID:-}\" != \"22.04\" ]]; then\n      warn \"This script is primarily tested on Ubuntu 22.04. 
Continuing anyway.\"\n    fi\n  fi\n}\n\n# Exit status: 0 when all required build tools are present, 1 when any is absent.\napt_deps_missing() {\n  local missing=0\n  command -v gcc >/dev/null 2>&1 || missing=1\n  command -v g++ >/dev/null 2>&1 || missing=1\n  command -v make >/dev/null 2>&1 || missing=1\n  command -v git >/dev/null 2>&1 || missing=1\n  command -v cmake >/dev/null 2>&1 || missing=1\n  return \"$missing\"\n}\n\ninstall_apt_deps_if_needed() {\n  case \"$INSTALL_APT_DEPS\" in\n    0|false|False|FALSE|no|NO)\n      info \"Skipping apt dependency installation because INSTALL_APT_DEPS=$INSTALL_APT_DEPS\"\n      info \"This matches the manually verified workflow and avoids unrelated apt source failures\"\n      return\n      ;;\n    1|true|True|TRUE|yes|YES)\n      ;;\n    auto|Auto|AUTO)\n      # apt_deps_missing returns non-zero only when a required tool is absent\n      if apt_deps_missing; then\n        info \"All required build tools already present; skipping apt installation (INSTALL_APT_DEPS=auto)\"\n        return\n      fi\n      info \"Missing build tools detected; installing apt dependencies (INSTALL_APT_DEPS=auto)\"\n      ;;\n    *)\n      error \"Unsupported INSTALL_APT_DEPS value: $INSTALL_APT_DEPS (expected 0, 1, or auto)\"\n      ;;\n  esac\n\n  require_command sudo \"Install sudo or run the equivalent apt commands manually.\"\n  require_command apt-get \"This script needs apt-get to install build tools.\"\n\n  info \"Installing apt packages needed for build\"\n  if ! sudo apt-get update; then\n    error \"apt-get update failed. Common causes: broken apt sources, third-party repository timeouts, or proxy/network issues.\"\n  fi\n\n  if ! sudo apt-get install -y build-essential git cmake; then\n    cat >&2 <<'EOF'\n[ERROR] apt package installation failed.\n\nTry one of the following:\n  1. sudo apt --fix-broken install\n  2. disable broken third-party apt repositories temporarily\n  3. rerun with INSTALL_APT_DEPS=0 if gcc/g++/make/git/cmake already exist\nEOF\n    exit 1\n  fi\n}\n\nsetup_conda_env() {\n  require_command conda \"Please install Miniconda or Anaconda first.\"\n  # shellcheck disable=SC1091\n  source \"$(conda info --base)/etc/profile.d/conda.sh\"\n\n  if ! conda env list | awk '{print $1}' | grep -Fx \"$ENV_NAME\" >/dev/null 2>&1; then\n    info \"Creating conda env: $ENV_NAME\"\n    run_conda_relaxed conda create -n \"$ENV_NAME\" \"python=$PYTHON_VERSION\" -y\n  else\n    info \"Conda env already exists: $ENV_NAME\"\n  fi\n\n  run_conda_relaxed conda activate \"$ENV_NAME\"\n}\n\nclone_repo_if_missing() {\n  local repo_dir=\"$1\"\n  local repo_url=\"$2\"\n  local repo_name=\"$3\"\n\n  if [[ ! 
-d \"$repo_dir/.git\" ]]; then\n    info \"Cloning $repo_name\"\n    mkdir -p \"$(dirname \"$repo_dir\")\"\n    git clone \"$repo_url\" \"$repo_dir\"\n  else\n    info \"Using existing $repo_name checkout at $repo_dir\"\n  fi\n}\n\ninstall_gmr() {\n  clone_repo_if_missing \"$GMR_SOURCE_DIR\" \"$GMR_REPO_URL\" \"GMR\"\n  info \"Installing GMR in editable mode\"\n  python -m pip install -e \"$GMR_SOURCE_DIR\"\n}\n\nbuild_xrt_python_sdk() {\n  clone_repo_if_missing \"$XRT_PYBIND_REPO_DIR\" \"$XRT_PYBIND_REPO_URL\" \"XRoboToolkit pybind repository\"\n\n  pushd \"$XRT_PYBIND_REPO_DIR\" >/dev/null\n\n  mkdir -p tmp\n  clone_repo_if_missing \"tmp/XRoboToolkit-PC-Service\" \"$XRT_PC_SERVICE_REPO_URL\" \"XRoboToolkit PC Service source\"\n\n  info \"Building PXREARobotSDK\"\n  pushd tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK >/dev/null\n  bash build.sh\n  popd >/dev/null\n\n  mkdir -p lib include\n  cp tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK/PXREARobotSDK.h include/\n  cp -r tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK/nlohmann include/nlohmann/\n  cp tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK/build/libPXREARobotSDK.so lib/\n\n  info \"Installing pybind11 into conda env\"\n  run_conda_relaxed conda install -y -c conda-forge pybind11\n\n  info \"Reinstalling xrobotoolkit_sdk\"\n  python -m pip uninstall -y xrobotoolkit_sdk || true\n  python setup.py install\n\n  popd >/dev/null\n}\n\ninstall_smplsim() {\n  clone_repo_if_missing \"$SMPLSIM_SOURCE_DIR\" \"$SMPLSIM_REPO_URL\" \"SMPLSim\"\n  info \"Installing SMPLSim in editable mode\"\n  python -m pip install -e \"$SMPLSIM_SOURCE_DIR\"\n}\n\ninstall_runtime_python_deps() {\n  info \"Upgrading pip toolchain\"\n  python -m pip install --upgrade pip setuptools wheel\n\n  info \"Installing runtime Python packages\"\n  python -m pip install pyzmq\n  python -m pip install open3d\n}\n\ninstall_compat_python_deps() {\n  info \"Installing compatibility packages\"\n  python -m pip install chumpy\n  info \"Pinning numpy for chumpy compatibility\"\n  python -m pip install --upgrade \"numpy==1.23.5\"\n}\n\nprint_next_steps() {\n  echo\n  info \"Environment setup complete\"\n  echo\n  info \"Manual prerequisite:\"\n  echo \"  Install XRoboToolkit PC Service manually from the Ubuntu 22.04 .deb package.\"\n  echo \"  Launch xrobotoolkit-pc-service before teleoperation.\"\n  echo\n  info \"Activate with:\"\n  echo \"  conda activate $ENV_NAME\"\n  echo\n  info \"Example command:\"\n  echo \"  python \\\"$PROJECT_ROOT/holomotion_teleop_node.py\\\" \\\\\"\n  echo \"    --robot-zmq-uri tcp://*:6001 \\\\\"\n  echo \"    --robot-zmq-mode bind \\\\\"\n  echo \"    --hz 50 \\\\\"\n  echo \"    --timing-log-every 250\"\n}\n\nmain() {\n  check_platform\n  show_env_summary\n  install_apt_deps_if_needed\n  setup_conda_env\n  install_gmr\n  build_xrt_python_sdk\n  install_runtime_python_deps\n  install_smplsim\n  install_compat_python_deps\n  print_next_steps\n}\n\nmain \"$@\"\n\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof.sh",
    "content": "#!/bin/bash\n\n##############################################################################\n# HoloMotion Deployment Launch Script\n#\n# This script sets up the complete environment and launches the HoloMotion\n# humanoid robot control system for the Unitree G1 robot. It handles:\n# 1. ROS2 environment setup and workspace building\n# 2. Conda environment configuration for GPU/CUDA support\n# 3. Library path configuration for proper linking\n# 4. Launch of the complete HoloMotion control pipeline\n#\n# Prerequisites:\n# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/\n# - Conda environment 'holomotion_deploy' with required packages\n# - Network interface configured for robot communication\n# - Proper permissions for robot hardware access\n#\n# Usage:\n#   ./launch_holomotion_29dof.sh [--record]\n#   --record: Enable topic recording (optional, disabled by default)\n#\n# Author: HoloMotion Team\n# License: See project LICENSE file\n##############################################################################\n\n# Default values\nENABLE_RECORDING=false\n\n# Parse command line arguments\nwhile [[ $# -gt 0 ]]; do\n    case $1 in\n        --record)\n            ENABLE_RECORDING=true\n            shift\n            ;;\n        -h|--help)\n            echo \"Usage: $0 [--record]\"\n            echo \"  --record: Enable topic recording (optional, disabled by default)\"\n            exit 0\n            ;;\n        *)\n            echo \"Unknown option $1\"\n            echo \"Usage: $0 [--record]\"\n            echo \"  --record: Enable topic recording (optional, disabled by default)\"\n            exit 1\n            ;;\n    esac\ndone\n\necho \"Starting HoloMotion 29DOF...\"\necho \"Recording enabled: $ENABLE_RECORDING\"\nrm -rf build/ install/ log/ 2>/dev/null || sudo rm -rf build/ install/ log/\nsource ~/miniconda3/bin/activate\nwhile [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do\n    conda deactivate\ndone\nsource /opt/ros/humble/setup.sh\nsource ~/unitree_ros2/setup.sh\ncolcon build\nsource install/setup.bash\n\nsource ../../deploy.env\n\n# Launch with recording parameter\nros2 launch humanoid_control holomotion_29dof_launch.py enable_recording:=$ENABLE_RECORDING"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof_docker.sh",
    "content": "#!/bin/bash\n\n##############################################################################\n# HoloMotion Deployment Launch Script\n#\n# This script sets up the complete environment and launches the HoloMotion\n# humanoid robot control system for the Unitree G1 robot. It handles:\n# 1. ROS2 environment setup and workspace building\n# 2. Conda environment configuration for GPU/CUDA support\n# 3. Library path configuration for proper linking\n# 4. Launch of the complete HoloMotion control pipeline\n#\n# Prerequisites:\n# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/\n# - Conda environment 'holomotion_deploy' with required packages\n# - Network interface configured for robot communication\n# - Proper permissions for robot hardware access\n#\n# Usage:\n#   ./launch_holomotion_29dof_docker.sh [--record]\n#   --record: Enable topic recording (optional, disabled by default)\n#\n# Author: HoloMotion Team\n# License: See project LICENSE file\n##############################################################################\n\n# Default values\nENABLE_RECORDING=false\n\n# Parse command line arguments\nwhile [[ $# -gt 0 ]]; do\n    case $1 in\n        --record)\n            ENABLE_RECORDING=true\n            shift\n            ;;\n        -h|--help)\n            echo \"Usage: $0 [--record]\"\n            echo \"  --record: Enable topic recording (optional, disabled by default)\"\n            exit 0\n            ;;\n        *)\n            echo \"Unknown option $1\"\n            echo \"Usage: $0 [--record]\"\n            echo \"  --record: Enable topic recording (optional, disabled by default)\"\n            exit 1\n            ;;\n    esac\ndone\n\necho \"Starting HoloMotion 29DOF Docker...\"\necho \"Recording enabled: $ENABLE_RECORDING\"\nsource /root/miniconda3/etc/profile.d/conda.sh\nwhile [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do\n    conda deactivate\ndone\nrm -rf build/ install/ log/\nsource /opt/ros/humble/setup.sh \nsource /root/unitree_ros2/setup.sh\n\ncolcon build\nsource install/setup.bash\n\n# Configure conda environment paths for CUDA and library linking\n# NOTE: Update this path to match your actual conda environment location\nexport CYCLONEDDS_HOME=/root/cyclonedds/install\nexport CMAKE_PREFIX_PATH=$CYCLONEDDS_HOME:$CMAKE_PREFIX_PATH\nsource ../../deploy.env\nexport LD_LIBRARY_PATH=/host_gpu:/cuda_base:/usr/lib/aarch64-linux-gnu/tegra:/usr/lib/aarch64-linux-gnu:/usr/local/cuda/lib64:/lib/aarch64-linux-gnu/:$LD_LIBRARY_PATH\n# Launch with recording parameter\nros2 launch humanoid_control holomotion_29dof_launch.py enable_recording:=$ENABLE_RECORDING"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.8)\nproject(humanoid_control)\n\n# Default to C99\nif(NOT CMAKE_C_STANDARD)\n  set(CMAKE_C_STANDARD 99)\nendif()\n\n# Default to C++14\nif(NOT CMAKE_CXX_STANDARD)\n  set(CMAKE_CXX_STANDARD 17)\nendif()\n\nif(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES \"Clang\")\n  add_compile_options(-Wall -Wextra -Wpedantic)\nendif()\n\n\ninclude_directories(include include/common include/nlohmann)\nlink_directories(src)\n\nset(\n  DEPENDENCY_LIST\n  unitree_go\n  unitree_hg\n  unitree_api\n  rclcpp\n  std_msgs\n  rosbag2_cpp\n  yaml-cpp\n)\n\n# find dependencies\nfind_package(ament_cmake REQUIRED)\nfind_package(ament_cmake_python REQUIRED)\nfind_package(unitree_go REQUIRED)\nfind_package(unitree_hg REQUIRED)\nfind_package(unitree_api REQUIRED)\nfind_package(rclcpp REQUIRED)\nfind_package(std_msgs REQUIRED)\nfind_package(rosbag2_cpp REQUIRED)\nfind_package(yaml-cpp REQUIRED)\n\n\n# Main control executable\nadd_executable(\n  humanoid_control\n  src/main_node.cpp\n  src/common/motor_crc_hg.cpp\n  src/common/wireless_controller.cpp\n)\n\n\nament_target_dependencies(humanoid_control ${DEPENDENCY_LIST})\n\n\n# Install Python modules\nament_python_install_package(humanoid_policy)\n\n# Install Python scripts as executables\n\ninstall(PROGRAMS\n  humanoid_policy/policy_node_29dof.py\n  DESTINATION lib/${PROJECT_NAME}\n  RENAME policy_node_29dof\n)\n\n# Install your models directory\ninstall(DIRECTORY\n  models/\n  DESTINATION share/${PROJECT_NAME}/models\n)\n\ninstall(TARGETS\n  humanoid_control\n  DESTINATION lib/${PROJECT_NAME})\n\n# motion folder\ninstall(DIRECTORY\n  config/\n  DESTINATION share/${PROJECT_NAME}/config\n)\n\ninstall(DIRECTORY\n  motion_data/\n  DESTINATION share/${PROJECT_NAME}/motion_data\n)\n# Install launch files\ninstall(\n  DIRECTORY launch/\n  DESTINATION share/${PROJECT_NAME}\n)\n\nif(BUILD_TESTING)\n  find_package(ament_lint_auto REQUIRED)\n  ament_lint_auto_find_test_dependencies()\nendif()\n\nament_package()\n"
  },
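Because the `install(PROGRAMS ... RENAME policy_node_29dof)` rule above strips the `.py` suffix, the node is invoked by its renamed target after a workspace build. A minimal sketch of that cycle, assuming a sourced ROS 2 Humble shell; the `config_path` value is illustrative, not a path shipped by the package:

```bash
# Build just this package and overlay it (standard colcon workflow).
colcon build --packages-select humanoid_control
source install/setup.bash

# RENAME in CMakeLists.txt drops the .py suffix from the installed script.
# config_path is a parameter the node declares; the path here is hypothetical.
ros2 run humanoid_control policy_node_29dof --ros-args \
  -p config_path:=/path/to/g1_29dof_holomotion.yaml
```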
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml",
    "content": "device: \"cuda\"\npolicy_freq: 50 # Hz\ncontrol_freq: 500 # Hz\nlowstate_topic: \"/lowstate\"\naction_topic: \"/humanoid/action\"\n\n# walking policy\nvelocity_tracking_model_folder: \"velocity_tracking_model\"\n# motion policy\nmotion_tracking_model_folder: \"motion_tracking_model\"  \n\n# motion data\nmotion_clip_dir: \"motion_data\" \n\n\ncpu_affinity_main: \"\"\ncpu_affinity_zmq_sub: \"\"\n\n# VR / ZMQ: the robot acts as a SUB socket and receives latest_obs from the sender.\nvr:\n  enable_teleop_reference: false\n  latest_obs_zmq_uri: \"tcp://192.168.124.29:6001\"\n  latest_obs_zmq_topic: \"obs65\"\n  latest_obs_zmq_mode: \"connect\"\n  latest_obs_zmq_conflate: true\n  zmq_jitter_delay_frames: 5\n  max_data_age: 0.6\n  require_vr_data_for_motion: false\n  timing_debug_enabled: false\n  timing_debug_log_interval_sec: 5.0\n  timing_debug_log_per_loop: false\n\ncomplete_dof_order:\n  - left_hip_pitch_joint\n  - left_hip_roll_joint\n  - left_hip_yaw_joint\n  - left_knee_joint\n  - left_ankle_pitch_joint\n  - left_ankle_roll_joint\n  - right_hip_pitch_joint\n  - right_hip_roll_joint\n  - right_hip_yaw_joint\n  - right_knee_joint\n  - right_ankle_pitch_joint\n  - right_ankle_roll_joint\n  - waist_yaw_joint\n  - waist_roll_joint\n  - waist_pitch_joint\n  - left_shoulder_pitch_joint\n  - left_shoulder_roll_joint\n  - left_shoulder_yaw_joint\n  - left_elbow_joint\n  - left_wrist_roll_joint\n  - left_wrist_pitch_joint\n  - left_wrist_yaw_joint\n  - right_shoulder_pitch_joint\n  - right_shoulder_roll_joint\n  - right_shoulder_yaw_joint\n  - right_elbow_joint\n  - right_wrist_roll_joint\n  - right_wrist_pitch_joint\n  - right_wrist_yaw_joint\n\npolicy_dof_order:\n  - left_hip_pitch_joint\n  - left_hip_roll_joint\n  - left_hip_yaw_joint\n  - left_knee_joint\n  - left_ankle_pitch_joint\n  - left_ankle_roll_joint\n  - right_hip_pitch_joint\n  - right_hip_roll_joint\n  - right_hip_yaw_joint\n  - right_knee_joint\n  - right_ankle_pitch_joint\n  - right_ankle_roll_joint\n  - waist_yaw_joint\n  - waist_roll_joint\n  - waist_pitch_joint\n  - left_shoulder_pitch_joint\n  - left_shoulder_roll_joint\n  - left_shoulder_yaw_joint\n  - left_elbow_joint\n  - left_wrist_roll_joint\n  - left_wrist_pitch_joint\n  - left_wrist_yaw_joint\n  - right_shoulder_pitch_joint\n  - right_shoulder_roll_joint\n  - right_shoulder_yaw_joint\n  - right_elbow_joint\n  - right_wrist_roll_joint\n  - right_wrist_pitch_joint\n  - right_wrist_yaw_joint\n\ndof2motor_idx_mapping:\n  # https://support.unitree.com/home/zh/G1_developer/about_G1\n  left_hip_pitch_joint: 0\n  left_hip_roll_joint: 1\n  left_hip_yaw_joint: 2\n  left_knee_joint: 3\n  left_ankle_pitch_joint: 4\n  left_ankle_roll_joint: 5\n  right_hip_pitch_joint: 6\n  right_hip_roll_joint: 7\n  right_hip_yaw_joint: 8\n  right_knee_joint: 9\n  right_ankle_pitch_joint: 10\n  right_ankle_roll_joint: 11\n  waist_yaw_joint: 12\n  waist_roll_joint: 13\n  waist_pitch_joint: 14\n  left_shoulder_pitch_joint: 15\n  left_shoulder_roll_joint: 16\n  left_shoulder_yaw_joint: 17\n  left_elbow_joint: 18\n  left_wrist_roll_joint: 19\n  left_wrist_pitch_joint: 20\n  left_wrist_yaw_joint: 21\n  right_shoulder_pitch_joint: 22\n  right_shoulder_roll_joint: 23\n  right_shoulder_yaw_joint: 24\n  right_elbow_joint: 25\n  right_wrist_roll_joint: 26\n  right_wrist_pitch_joint: 27\n  right_wrist_yaw_joint: 28\n\n\ndefault_joint_angles:\n  # Left leg joints (indices 0-5)\n  left_hip_pitch_joint: -0.312\n  left_hip_roll_joint: 0.0\n  left_hip_yaw_joint: 0.0\n  left_knee_joint: 0.669\n  
left_ankle_pitch_joint: -0.33\n  left_ankle_roll_joint: 0.0\n\n  # Right leg joints (indices 6-11)\n  right_hip_pitch_joint: -0.312\n  right_hip_roll_joint: 0.0\n  right_hip_yaw_joint: 0.0\n  right_knee_joint: 0.669\n  right_ankle_pitch_joint: -0.33\n  right_ankle_roll_joint: 0.0\n\n  # Waist joints (indices 12-14)\n  waist_yaw_joint: 0.0\n  waist_roll_joint: 0.0\n  waist_pitch_joint: 0.2\n\n  # Left arm joints (indices 15-21)\n  left_shoulder_pitch_joint: 0.2\n  left_shoulder_roll_joint: 0.2\n  left_shoulder_yaw_joint: 0.0\n  left_elbow_joint: 0.6\n  left_wrist_roll_joint: 0.0\n  left_wrist_pitch_joint: 0.0\n  left_wrist_yaw_joint: 0.0\n\n  # Right arm joints (indices 22-28)\n  right_shoulder_pitch_joint: 0.2\n  right_shoulder_roll_joint: -0.2\n  right_shoulder_yaw_joint: 0.0\n  right_elbow_joint: 0.6\n  right_wrist_roll_joint: 0.0\n  right_wrist_pitch_joint: 0.0\n  right_wrist_yaw_joint: 0.0\n\n# Joint limits\njoint_limits:\n  position:\n    # Left leg joints\n    left_hip_pitch_joint: [-2.5307, 2.8798]\n    left_hip_roll_joint: [-0.5236, 2.9671]\n    left_hip_yaw_joint: [-2.7576, 2.7576]\n    left_knee_joint: [-0.087267, 2.8798]\n    left_ankle_pitch_joint: [-0.87267, 0.5236]\n    left_ankle_roll_joint: [-0.2618, 0.2618]\n\n    # Right leg joints\n    right_hip_pitch_joint: [-2.5307, 2.8798]\n    right_hip_roll_joint: [-2.9671, 0.5236]\n    right_hip_yaw_joint: [-2.7576, 2.7576]\n    right_knee_joint: [-0.087267, 2.8798]\n    right_ankle_pitch_joint: [-0.87267, 0.5236]\n    right_ankle_roll_joint: [-0.2618, 0.2618]\n\n    # Waist joints\n    waist_yaw_joint: [-2.618, 2.618]\n    waist_roll_joint: [-0.52, 0.52]\n    waist_pitch_joint: [-0.52, 0.52]\n\n    # Left arm joints\n    left_shoulder_pitch_joint: [-3.0892, 2.6704]\n    left_shoulder_roll_joint: [-1.5882, 2.2515]\n    left_shoulder_yaw_joint: [-2.618, 2.618]\n    left_elbow_joint: [-1.0472, 2.0944]\n    left_wrist_roll_joint: [-1.972222054, 1.972222054]\n    left_wrist_pitch_joint: [-1.614429558, 1.614429558]\n    left_wrist_yaw_joint: [-1.614429558, 1.614429558]\n\n    # Right arm joints\n    right_shoulder_pitch_joint: [-3.0892, 2.6704]\n    right_shoulder_roll_joint: [-2.2515, 1.5882]\n    right_shoulder_yaw_joint: [-2.618, 2.618]\n    right_elbow_joint: [-1.0472, 2.0944]\n    right_wrist_roll_joint: [-1.972222054, 1.972222054]\n    right_wrist_pitch_joint: [-1.614429558, 1.614429558]\n    right_wrist_yaw_joint: [-1.614429558, 1.614429558]\n\n  velocity:\n    # Left leg joints\n    left_hip_pitch_joint: 32\n    left_hip_roll_joint: 20\n    left_hip_yaw_joint: 32\n    left_knee_joint: 20\n    left_ankle_pitch_joint: 30\n    left_ankle_roll_joint: 30\n\n    # Right leg joints\n    right_hip_pitch_joint: 32\n    right_hip_roll_joint: 20\n    right_hip_yaw_joint: 32\n    right_knee_joint: 20\n    right_ankle_pitch_joint: 30\n    right_ankle_roll_joint: 30\n\n    # Waist joints\n    waist_yaw_joint: 32\n    waist_roll_joint: 30\n    waist_pitch_joint: 30\n\n    # Left arm joints\n    left_shoulder_pitch_joint: 37\n    left_shoulder_roll_joint: 37\n    left_shoulder_yaw_joint: 37\n    left_elbow_joint: 37\n    left_wrist_roll_joint: 37\n    left_wrist_pitch_joint: 22\n    left_wrist_yaw_joint: 22\n\n    # Right arm joints\n    right_shoulder_pitch_joint: 37\n    right_shoulder_roll_joint: 37\n    right_shoulder_yaw_joint: 37\n    right_elbow_joint: 37\n    right_wrist_roll_joint: 37\n    right_wrist_pitch_joint: 22\n    right_wrist_yaw_joint: 22\n\n  effort:\n    # Left leg joints\n    left_hip_pitch_joint: 88\n    
left_hip_roll_joint: 139\n    left_hip_yaw_joint: 88\n    left_knee_joint: 139\n    left_ankle_pitch_joint: 35\n    left_ankle_roll_joint: 35\n\n    # Right leg joints\n    right_hip_pitch_joint: 88\n    right_hip_roll_joint: 139\n    right_hip_yaw_joint: 88\n    right_knee_joint: 139\n    right_ankle_pitch_joint: 35\n    right_ankle_roll_joint: 35\n\n    # Waist joints\n    waist_yaw_joint: 88\n    waist_roll_joint: 35\n    waist_pitch_joint: 35\n\n    # Left arm joints\n    left_shoulder_pitch_joint: 25\n    left_shoulder_roll_joint: 25\n    left_shoulder_yaw_joint: 25\n    left_elbow_joint: 25\n    left_wrist_roll_joint: 25\n    left_wrist_pitch_joint: 5\n    left_wrist_yaw_joint: 5\n\n    # Right arm joints\n    right_shoulder_pitch_joint: 25\n    right_shoulder_roll_joint: 25\n    right_shoulder_yaw_joint: 25\n    right_elbow_joint: 25\n    right_wrist_roll_joint: 25\n    right_wrist_pitch_joint: 5\n    right_wrist_yaw_joint: 5\n\nlimit_scales:\n  position: 2.0 # 2.0 doubles the nominal limit range\n  velocity: 2.0\n  effort: 2.0\n\n# Gains for the move-to-default-position controller\n# joint_names and default_position are auto-generated from complete_dof_order and default_joint_angles\n# Only kp and kd arrays need to be specified here\n\nkp:\n  [ 350.0, 200.0, 200.0, 300.0, 300.0, 150.0,\n    350.0, 200.0, 200.0, 300.0, 300.0, 150.0,\n    200.0, 200.0, 200.0,\n    40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0,\n    40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0 ]\n\nkd:\n  [ 5.0, 5.0, 5.0, 10.0, 5.0, 5.0,\n    5.0, 5.0, 5.0, 10.0, 5.0, 5.0,\n    5.0, 5.0, 5.0,\n    3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,\n    3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 ]\n"
  },
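The config keeps two parallel orderings on purpose: `policy_dof_order` fixes the layout of policy observations/actions, while `dof2motor_idx_mapping` translates joint names to Unitree motor channel IDs for the low-level command. A minimal sketch of the remapping, assuming the YAML above is on disk; the variable names are illustrative, not the deployment code's actual identifiers:

```python
import numpy as np
import yaml

with open("g1_29dof_holomotion.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

policy_order = cfg["policy_dof_order"]
motor_idx = cfg["dof2motor_idx_mapping"]

# Index i of a policy-ordered vector goes to motor channel motor_idx[policy_order[i]].
policy2motor = np.array([motor_idx[name] for name in policy_order], dtype=np.int64)

actions = np.zeros(len(policy_order), dtype=np.float32)  # policy-ordered actions
motor_cmd = np.zeros(29, dtype=np.float32)               # motor-ID-ordered command
motor_cmd[policy2motor] = actions
```

For this 29-DOF config the two orders happen to coincide, so `policy2motor` is the identity; keeping the mapping explicit lets the config reorder joints without touching code.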
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/__init__.py",
    "content": ""
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/holomotion_fk_root_only.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom __future__ import annotations\n\nimport logging\nimport time\nfrom typing import Callable, Dict, Sequence\n\nimport numpy as np\nimport torch\n\n\ndef _xyzw_to_wxyz(q: np.ndarray) -> np.ndarray:\n    return np.concatenate([q[..., 3:4], q[..., 0:3]], axis=-1)\n\n\ndef _wxyz_to_xyzw(q: np.ndarray) -> np.ndarray:\n    return np.concatenate([q[..., 1:4], q[..., 0:1]], axis=-1)\n\n\ndef _quat_conjugate_wxyz(q: np.ndarray) -> np.ndarray:\n    out = np.array(q, copy=True)\n    out[..., 1:4] *= -1.0\n    return out\n\n\ndef _quat_mul_wxyz(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:\n    w1 = q1[..., 0]\n    x1 = q1[..., 1]\n    y1 = q1[..., 2]\n    z1 = q1[..., 3]\n    w2 = q2[..., 0]\n    x2 = q2[..., 1]\n    y2 = q2[..., 2]\n    z2 = q2[..., 3]\n    return np.stack(\n        [\n            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,\n            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,\n            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,\n            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,\n        ],\n        axis=-1,\n    )\n\n\ndef _standardize_quaternion_wxyz(q: np.ndarray) -> np.ndarray:\n    return np.where(q[..., 0:1] < 0.0, -q, q)\n\n\ndef _axis_angle_from_wxyz(q: np.ndarray) -> np.ndarray:\n    q = _standardize_quaternion_wxyz(q)\n    q = q / np.linalg.norm(q, axis=-1, keepdims=True).clip(min=1.0e-9)\n    quat_w = q[..., 0]\n    quat_xyz = q[..., 1:4]\n    mag = np.linalg.norm(quat_xyz, axis=-1)\n    half_angle = np.arctan2(mag, quat_w)\n    angle = 2.0 * half_angle\n    use_taylor = np.abs(angle) <= 1.0e-6\n    angle_safe = np.where(use_taylor, 1.0, angle)\n    sin_half_over_angle = np.where(\n        use_taylor,\n        0.5 - angle * angle / 48.0,\n        np.sin(half_angle) / angle_safe,\n    )\n    return quat_xyz / sin_half_over_angle[..., None]\n\n\ndef _grad_t(x: np.ndarray, dt: float) -> np.ndarray:\n    if dt <= 0.0:\n        raise ValueError(f\"Invalid dt: {dt}\")\n    if x.shape[1] < 2:\n        return np.zeros_like(x)\n    grad = np.empty_like(x)\n    inv_dt = 1.0 / dt\n    grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt\n    grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt\n    if x.shape[1] > 2:\n        grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt)\n    return grad\n\n\nclass HoloMotionFKRootOnly(torch.nn.Module):\n    \"\"\"Root-only online FK.\n\n    This lightweight variant is intended for policy-time VR reference building when\n    only the root body pose/velocity are consumed by observation terms.\n    \"\"\"\n\n    def __init__(\n        self,\n        dof_names: Sequence[str],\n        device: torch.device | str = \"cpu\",\n        dtype: torch.dtype = torch.float32,\n        timing_logger_enabled: bool = False,\n        timing_log_interval_sec: float = 5.0,\n        timing_log_per_call: bool = False,\n        timing_name: str = \"HoloMotionFKRootOnly\",\n        timing_log_fn: Callable[[str], None] | 
None = None,\n    ) -> None:\n        super().__init__()\n        self.body_names = [\"root\"]\n        self.dof_names = list(dof_names)\n        self.num_bodies = 1\n        self.num_dof = len(self.dof_names)\n        if self.num_dof <= 0:\n            raise ValueError(\"dof_names must not be empty\")\n        self._device = torch.device(device)\n        self._dtype = dtype\n        if self._dtype == torch.float64:\n            self._np_dtype = np.float64\n        else:\n            self._np_dtype = np.float32\n        self._timing_logger_enabled = bool(timing_logger_enabled)\n        self._timing_log_interval_sec = float(timing_log_interval_sec)\n        self._timing_log_per_call = bool(timing_log_per_call)\n        self._timing_name = str(timing_name)\n        self._timing_logger = logging.getLogger(__name__)\n        self._timing_log_fn = timing_log_fn\n        self._timing_last_log_time = None\n        self._timing_count = 0\n        self._timing_sum_ms = {}\n        self._timing_max_ms = {}\n        self.last_timing_ms = {}\n        self._gaussian_kernel_cache: Dict[tuple[float, str], np.ndarray] = {}\n\n    def set_timing_logger(\n        self,\n        enabled: bool,\n        interval_sec: float | None = None,\n        per_call: bool | None = None,\n        log_fn: Callable[[str], None] | None = None,\n    ) -> None:\n        self._timing_logger_enabled = bool(enabled)\n        if interval_sec is not None:\n            self._timing_log_interval_sec = float(interval_sec)\n        if per_call is not None:\n            self._timing_log_per_call = bool(per_call)\n        if log_fn is not None:\n            self._timing_log_fn = log_fn\n\n    def _timing_ms(self, t0: float) -> float:\n        return (time.perf_counter() - t0) * 1000.0\n\n    def _to_numpy(self, x: torch.Tensor | np.ndarray) -> np.ndarray:\n        if isinstance(x, np.ndarray):\n            return np.asarray(x, dtype=self._np_dtype)\n        if not isinstance(x, torch.Tensor):\n            return np.asarray(x, dtype=self._np_dtype)\n        if x.device.type != \"cpu\" or x.dtype != self._dtype:\n            x = x.detach().to(device=\"cpu\", dtype=self._dtype)\n        else:\n            x = x.detach()\n        return x.numpy()\n\n    def _to_output_tensor(self, x: np.ndarray) -> torch.Tensor:\n        tensor = torch.from_numpy(np.ascontiguousarray(x))\n        if self._device.type != \"cpu\" or tensor.dtype != self._dtype:\n            tensor = tensor.to(device=self._device, dtype=self._dtype)\n        return tensor\n\n    def _get_gaussian_kernel(self, sigma: float) -> np.ndarray | None:\n        if sigma <= 0.0:\n            return None\n        key = (float(sigma), np.dtype(self._np_dtype).str)\n        kernel = self._gaussian_kernel_cache.get(key, None)\n        if kernel is not None:\n            return kernel\n        radius = int(4.0 * sigma + 0.5)\n        kernel_x = np.arange(-radius, radius + 1, dtype=self._np_dtype)\n        kernel = np.exp(-0.5 * np.square(kernel_x / sigma)).astype(\n            self._np_dtype, copy=False\n        )\n        kernel /= kernel.sum(dtype=self._np_dtype)\n        self._gaussian_kernel_cache[key] = kernel\n        return kernel\n\n    def _gaussian_filter_time(self, x: np.ndarray, kernel: np.ndarray | None) -> np.ndarray:\n        if kernel is None or x.shape[1] < 2:\n            return x\n        radius = kernel.shape[0] // 2\n        padded = np.pad(x, ((0, 0), (radius, radius), (0, 0)), mode=\"edge\")\n        windows = np.lib.stride_tricks.sliding_window_view(\n            
padded, window_shape=kernel.shape[0], axis=1\n        )\n        return np.tensordot(windows, kernel, axes=([-1], [0])).astype(\n            x.dtype, copy=False\n        )\n\n    def _log_timing_message(self, message: str) -> None:\n        if self._timing_log_fn is not None:\n            self._timing_log_fn(message)\n        else:\n            self._timing_logger.info(message)\n\n    def _record_timing(self, sample: Dict[str, float]) -> None:\n        self.last_timing_ms = dict(sample)\n        if not self._timing_logger_enabled:\n            return\n\n        self._timing_count += 1\n        for key, value in sample.items():\n            v = float(value)\n            self._timing_sum_ms[key] = self._timing_sum_ms.get(key, 0.0) + v\n            self._timing_max_ms[key] = max(self._timing_max_ms.get(key, v), v)\n        if self._timing_log_per_call:\n            self._log_timing_message(\n                (\n                    f\"[{self._timing_name}][Timing] \"\n                    f\"total={sample['total_ms']:.3f}ms \"\n                    f\"input={sample['input_ms']:.3f}ms \"\n                    f\"quat={sample['quat_ms']:.3f}ms \"\n                    f\"linvel={sample['linvel_ms']:.3f}ms \"\n                    f\"angvel={sample['angvel_ms']:.3f}ms \"\n                    f\"smooth={sample['smooth_ms']:.3f}ms \"\n                    f\"output={sample['output_ms']:.3f}ms\"\n                )\n            )\n\n        now = time.time()\n        if self._timing_last_log_time is None:\n            self._timing_last_log_time = now\n            return\n        if now - self._timing_last_log_time < self._timing_log_interval_sec:\n            return\n        if self._timing_count == 0:\n            self._timing_last_log_time = now\n            return\n\n        keys = [\n            \"total_ms\",\n            \"input_ms\",\n            \"quat_ms\",\n            \"linvel_ms\",\n            \"angvel_ms\",\n            \"smooth_ms\",\n            \"output_ms\",\n        ]\n        self._log_timing_message(\n            f\"[{self._timing_name}][Timing-Agg] \"\n            + \" \".join(\n                f\"{key}=mean:{self._timing_sum_ms.get(key, 0.0) / self._timing_count:.3f}ms/\"\n                f\"max:{self._timing_max_ms.get(key, 0.0):.3f}ms\"\n                for key in keys\n            )\n            + f\" n={self._timing_count}\"\n        )\n        self._timing_count = 0\n        self._timing_sum_ms.clear()\n        self._timing_max_ms.clear()\n        self._timing_last_log_time = now\n\n    @torch.inference_mode()\n    def forward(\n        self,\n        root_pos: torch.Tensor,\n        root_quat: torch.Tensor,\n        dof_pos: torch.Tensor,\n        fps: float,\n        quat_format: str = \"xyzw\",\n        sub_batch_size: int = 64,\n        vel_smoothing_sigma: float = 2.0,\n        compute_velocity: bool = True,\n    ) -> Dict[str, torch.Tensor]:\n        t_total = time.perf_counter()\n        del sub_batch_size\n        del compute_velocity  # kept for call-site compatibility\n\n        if fps <= 0.0:\n            raise ValueError(f\"Invalid fps: {fps}\")\n        if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:\n            raise ValueError(\"Inputs must be (B, T, ...)\")\n        if (\n            root_pos.shape[:2] != root_quat.shape[:2]\n            or root_pos.shape[:2] != dof_pos.shape[:2]\n        ):\n            raise ValueError(\"Mismatched batch/time shapes among inputs\")\n        if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:\n            
raise ValueError(\n                \"root_pos must be (B,T,3) and root_quat must be (B,T,4)\"\n            )\n        if dof_pos.shape[-1] != self.num_dof:\n            raise ValueError(\n                f\"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}\"\n            )\n\n        t_input = time.perf_counter()\n        root_pos_np = self._to_numpy(root_pos)\n        root_quat_np = self._to_numpy(root_quat)\n        dof_pos_np = self._to_numpy(dof_pos)\n        input_ms = self._timing_ms(t_input)\n\n        t_quat = time.perf_counter()\n        if quat_format == \"xyzw\":\n            root_quat_xyzw_np = root_quat_np\n            root_quat_wxyz_np = _xyzw_to_wxyz(root_quat_np)\n        elif quat_format == \"wxyz\":\n            root_quat_wxyz_np = root_quat_np\n            root_quat_xyzw_np = _wxyz_to_xyzw(root_quat_np)\n        else:\n            raise ValueError(f\"Unsupported quat_format: {quat_format}\")\n        quat_ms = self._timing_ms(t_quat)\n\n        dt = 1.0 / fps\n        kernel = self._get_gaussian_kernel(float(vel_smoothing_sigma))\n        t_linvel = time.perf_counter()\n        root_vel_np = _grad_t(root_pos_np, dt)\n        linvel_ms = self._timing_ms(t_linvel)\n\n        t_angvel = time.perf_counter()\n        root_angvel_np = np.zeros_like(root_pos_np)\n        if root_quat_wxyz_np.shape[1] >= 2:\n            q1 = root_quat_wxyz_np[:, 1:]\n            q0_inv = _quat_conjugate_wxyz(root_quat_wxyz_np[:, :-1])\n            q_rel = _quat_mul_wxyz(q1, q0_inv)\n            root_angvel_np[:, :-1] = _axis_angle_from_wxyz(q_rel) / dt\n        angvel_ms = self._timing_ms(t_angvel)\n\n        t_smooth = time.perf_counter()\n        if kernel is not None and root_pos_np.shape[1] >= 2:\n            vel_and_ang_np = np.concatenate([root_vel_np, root_angvel_np], axis=-1)\n            vel_and_ang_np = self._gaussian_filter_time(vel_and_ang_np, kernel)\n            root_vel_np = vel_and_ang_np[..., :3]\n            root_angvel_np = vel_and_ang_np[..., 3:6]\n        smooth_ms = self._timing_ms(t_smooth)\n\n        t_output = time.perf_counter()\n        out = {\n            \"global_translation\": self._to_output_tensor(root_pos_np[:, :, None, :]),\n            \"global_rotation_quat\": self._to_output_tensor(\n                root_quat_xyzw_np[:, :, None, :]\n            ),\n            \"global_velocity\": self._to_output_tensor(root_vel_np[:, :, None, :]),\n            \"global_angular_velocity\": self._to_output_tensor(\n                root_angvel_np[:, :, None, :]\n            ),\n            \"dof_pos\": self._to_output_tensor(dof_pos_np),\n            \"dof_vel\": self._to_output_tensor(np.zeros_like(dof_pos_np)),\n        }\n        output_ms = self._timing_ms(t_output)\n        self._record_timing(\n            {\n                \"total_ms\": self._timing_ms(t_total),\n                \"input_ms\": input_ms,\n                \"quat_ms\": quat_ms,\n                \"linvel_ms\": linvel_ms,\n                \"angvel_ms\": angvel_ms,\n                \"smooth_ms\": smooth_ms,\n                \"output_ms\": output_ms,\n            }\n        )\n        return out\n"
  },
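A minimal usage sketch for the root-only FK above, with dummy `(B, T, ...)` tensors; the identity-quaternion stream and placeholder joint names are illustrative only, while the shapes and `xyzw` convention follow the `forward` signature:

```python
import torch
from humanoid_policy.holomotion_fk_root_only import HoloMotionFKRootOnly

dof_names = [f"joint_{i}" for i in range(29)]  # placeholder names; 29 DoF as in the G1 config
fk = HoloMotionFKRootOnly(dof_names, device="cpu")

B, T = 1, 10
root_pos = torch.zeros(B, T, 3)
root_quat = torch.zeros(B, T, 4)
root_quat[..., 3] = 1.0  # identity rotation in xyzw layout
dof_pos = torch.zeros(B, T, len(dof_names))

out = fk(root_pos, root_quat, dof_pos, fps=50.0, quat_format="xyzw")
print(out["global_translation"].shape)  # torch.Size([1, 10, 1, 3]): singleton body axis
print(out["global_velocity"].shape)     # torch.Size([1, 10, 1, 3]): smoothed finite differences
print(out["dof_vel"].abs().max())       # tensor(0.): dof_vel is not computed in this variant
```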
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/__init__.py",
    "content": "from .obs_builder import PolicyObsBuilder, get_gravity_orientation\n\n__all__ = [\n    \"PolicyObsBuilder\",\n    \"get_gravity_orientation\",\n]\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/obs_builder.py",
    "content": "import numpy as np\nimport torch\n\nfrom typing import Dict, List, Sequence, Any, Optional\n\n\ndef get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:\n    \"\"\"Calculate gravity orientation from quaternion.\n\n    Args:\n        quaternion: Array-like [w, x, y, z]\n\n    Returns:\n        np.ndarray of shape (3,) representing gravity projection.\n    \"\"\"\n    qw = float(quaternion[0])\n    qx = float(quaternion[1])\n    qy = float(quaternion[2])\n    qz = float(quaternion[3])\n\n    gravity_orientation = np.zeros(3, dtype=np.float32)\n    gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy)\n    gravity_orientation[1] = -2.0 * (qz * qy + qw * qx)\n    gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n    return gravity_orientation\n\n\nclass _CircularBuffer:\n    \"\"\"History buffer for batched tensor data (batch==1 in our eval/deploy).\n\n    Stores history in oldest->newest order when accessed via .buffer.\n    \"\"\"\n\n    def __init__(self, max_len: int, feat_dim: int):\n        if max_len < 1:\n            raise ValueError(f\"max_len must be >= 1, got {max_len}\")\n        self._max_len = int(max_len)\n        self._feat_dim = int(feat_dim)\n        self._pointer = -1\n        self._num_pushes = 0\n        self._buffer: torch.Tensor = torch.zeros(\n            (self._max_len, 1, self._feat_dim),\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n\n    @property\n    def buffer(self) -> torch.Tensor:\n        \"\"\"Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1.\"\"\"\n        if self._num_pushes == 0:\n            raise RuntimeError(\n                \"Attempting to read from an empty history buffer.\"\n            )\n        # roll such that oldest is at index=0 along the history axis\n        rolled = torch.roll(\n            self._buffer, shifts=self._max_len - self._pointer - 1, dims=0\n        )\n        return torch.transpose(rolled, 0, 1)  # [1, max_len, feat]\n\n    def append(self, data: torch.Tensor) -> None:\n        \"\"\"Append one step: data shape [1, feat_dim] on the configured device.\"\"\"\n        if (\n            data.ndim != 2\n            or data.shape[0] != 1\n            or data.shape[1] != self._feat_dim\n        ):\n            raise ValueError(\n                f\"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}\"\n            )\n        self._pointer = (self._pointer + 1) % self._max_len\n        self._buffer[self._pointer] = data\n        if self._num_pushes == 0:\n            # duplicate first push across entire history for warm start\n            self._buffer[:] = data\n        self._num_pushes += 1\n\n\nclass PolicyObsBuilder:\n    \"\"\"Builds policy observations from Unitree lowstate with temporal history.\n\n    Designed to be shared between MuJoCo sim2sim evaluation and ROS2 deployment.\n    History management is internal and produces a flattened vector of size\n    sum_i(context_length * feat_i) across the configured observation items.\n\n    Supports two command modes:\n    - \"motion_tracking\": uses reference motion states\n    - \"velocity_tracking\": uses velocity commands [vx, vy, vyaw]\n    \"\"\"\n\n    def __init__(\n        self,\n        dof_names_onnx: Sequence[str],\n        default_angles_onnx: np.ndarray,\n        evaluator: Optional[Any] = None,\n        obs_policy_cfg: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        self.dof_names_onnx: List[str] = list(dof_names_onnx)\n        self.num_actions: int = 
len(self.dof_names_onnx)\n        self.evaluator = evaluator\n        self.obs_policy_cfg = obs_policy_cfg\n\n        if default_angles_onnx.shape[0] != self.num_actions:\n            raise ValueError(\n                \"default_angles_onnx length must match num actions\"\n            )\n        self.default_angles_onnx = default_angles_onnx.astype(np.float32)\n        self.default_angles_dict: Dict[str, float] = {\n            name: float(self.default_angles_onnx[idx])\n            for idx, name in enumerate(self.dof_names_onnx)\n        }\n\n        # Build observation schema from config if provided\n        self.term_specs: List[Dict[str, Any]] = []\n\n        for term_dict in self.obs_policy_cfg[\"atomic_obs_list\"]:\n            for name, cfg in term_dict.items():\n                term_dict = {**cfg}\n                term_dict[\"name\"] = name\n                self.term_specs.append(term_dict)\n\n        # Buffers are created lazily after first dimension inference\n        self._buffers: Dict[str, _CircularBuffer] = {}\n\n    def reset(self) -> None:\n        for buf in self._buffers.values():\n            buf._pointer = -1\n            buf._num_pushes = 0\n            buf._buffer.zero_()\n\n    def _compute_term(\n        self,\n        name: str,\n    ) -> np.ndarray:\n        # Prefer evaluator-provided methods; no legacy fallbacks\n        if self.evaluator is not None:\n            meth = getattr(self.evaluator, f\"_get_obs_{name}\", None)\n            if callable(meth):\n                out = meth()\n                return np.asarray(out, dtype=np.float32).reshape(-1)\n        raise ValueError(\n            f\"Unknown observation term '{name}' or evaluator method missing.\"\n        )\n\n    def build_policy_obs(self) -> np.ndarray:\n        \"\"\"Append one step using evaluator-provided observation terms and return flattened obs.\"\"\"\n        # Compute per-term outputs\n        values: Dict[str, np.ndarray] = {}\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            scale = float(spec.get(\"scale\", 1.0))\n            values[name] = self._compute_term(name) * scale\n\n        # Lazily initialize buffers with inferred feature dims\n        if len(self._buffers) == 0:\n            for spec in self.term_specs:\n                name = spec[\"name\"]\n                hist_len = int(spec.get(\"history_length\", 0))\n                if hist_len <= 0:\n                    continue\n                feat_dim = int(values[name].reshape(-1).shape[0])\n                self._buffers[name] = _CircularBuffer(\n                    hist_len,\n                    feat_dim,\n                )\n\n        # Append current step to buffers (skip terms without history)\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            if name in self._buffers:\n                item = torch.as_tensor(\n                    values[name],\n                    dtype=torch.float32,\n                    device=\"cpu\",\n                )[None, :]\n                self._buffers[name].append(item)\n\n        # Assemble flat list according to term ordering and history flatten rules\n        flat_list: List[np.ndarray] = []\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            if name in self._buffers:\n                buf = self._buffers[name].buffer[0]  # [hist, feat]\n                arr = buf.reshape(-1).detach().cpu().numpy()\n                flat_list.append(arr.astype(np.float32))\n            else:\n                # no 
history -> use computed value directly\n                flat_list.append(values[name].reshape(-1).astype(np.float32))\n\n        if len(flat_list) == 0:\n            return np.zeros(0, dtype=np.float32)\n\n        return np.concatenate(flat_list, axis=0).astype(np.float32)\n"
  },
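To make the evaluator contract concrete: `_compute_term` resolves each configured term to a `_get_obs_<name>` method on the evaluator, and a term with `history_length > 0` contributes `history_length * feat_dim` entries to the flattened vector. A minimal sketch with a stub evaluator and a made-up two-term config (the term names, scales, and dimensions are hypothetical):

```python
import numpy as np
from humanoid_policy.obs_builder import PolicyObsBuilder

class StubEvaluator:
    # Each term name must map to a _get_obs_<name> method.
    def _get_obs_base_ang_vel(self):
        return np.zeros(3, dtype=np.float32)

    def _get_obs_dof_pos(self):
        return np.zeros(29, dtype=np.float32)

obs_cfg = {
    "atomic_obs_list": [
        {"base_ang_vel": {"scale": 0.25, "history_length": 4}},
        {"dof_pos": {"scale": 1.0, "history_length": 4}},
    ]
}

builder = PolicyObsBuilder(
    dof_names_onnx=[f"joint_{i}" for i in range(29)],  # placeholder names
    default_angles_onnx=np.zeros(29, dtype=np.float32),
    evaluator=StubEvaluator(),
    obs_policy_cfg=obs_cfg,
)

obs = builder.build_policy_obs()
print(obs.shape)  # (128,) = 4*3 + 4*29, one flattened history block per term
```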
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/policy_node_29dof.py",
    "content": "#! /your_dir/miniconda3/envs/holomotion_deploy/bin/python\n\"\"\"\nHoloMotion Policy Node\n\nThis module implements the main policy execution node for the HoloMotion humanoid robot system using ZMQ latest_obs transport.\nIt handles neural network policy inference, motion sequence management, remote controller input,\nand robot state coordination for humanoid behaviors including velocity tracking and motion tracking.\n\nThe policy node serves as the high-level decision maker that:\n- Processes sensor observations and builds state representations\n- Executes trained neural network policies for motion generation (velocity tracking and motion tracking)\n- Manages multiple motion sequences (motion clips) loaded from offline files\n- Handles remote controller input for motion selection\n- Coordinates with the main control node for safe operation\n\nKey Features:\n- Dual policy support: velocity tracking and motion tracking\n- Offline motion file loading (.npz format)\n- Runtime policy switching with button controls\n- Separate hyperparameters (kps, kds, action_scale, default_angles) for each model\n\nAuthor: HoloMotion Team\nLicense: See project LICENSE file\n\"\"\"\nimport os\nimport torch\nimport time\nimport json\nimport threading\nfrom collections import deque\n\nimport easydict\nimport numpy as np\nimport onnx\nimport onnxruntime\nimport rclpy\nimport zmq\nimport yaml\nfrom ament_index_python.packages import get_package_share_directory\nfrom omegaconf import OmegaConf\nfrom rclpy.node import Node\nfrom rclpy.qos import QoSProfile\nfrom std_msgs.msg import Float32MultiArray, String\nfrom unitree_hg.msg import LowState\n\nfrom humanoid_policy.obs_builder import PolicyObsBuilder\nfrom humanoid_policy.utils.remote_controller_filter import KeyMap, RemoteController\nfrom humanoid_policy.holomotion_fk_root_only import HoloMotionFKRootOnly\n\n\ndef _parse_cpu_affinity_str(s):\n    \"\"\"Parse '0,1' or '2' -> [0,1] or [2]. Empty/invalid -> [].\"\"\"\n    s = str(s).strip()\n    if not s:\n        return []\n    out = []\n    for x in s.split(\",\"):\n        x = x.strip()\n        if x.isdigit():\n            out.append(int(x))\n    return out\n\n\ndef set_thread_cpu_affinity(cpu_ids):\n    \"\"\"Pin current thread to given CPU core IDs (Linux only).\n    cpu_ids: list of int, e.g. [0], [0,1]. 
Returns True if set successfully.\"\"\"\n    if not cpu_ids:\n        return False\n    try:\n        import ctypes\n        libc = ctypes.CDLL(\"libc.so.6\")\n        CPU_SETSIZE = 1024\n        ncpubits = 8 * ctypes.sizeof(ctypes.c_ulong)\n        nlongs = (CPU_SETSIZE + ncpubits - 1) // ncpubits\n        class CpuSetT(ctypes.Structure):\n            _fields_ = [(\"__bits\", ctypes.c_ulong * nlongs)]\n        libc.pthread_self.restype = ctypes.c_ulong\n        libc.pthread_setaffinity_np.argtypes = [\n            ctypes.c_ulong, ctypes.c_size_t, ctypes.POINTER(CpuSetT)\n        ]\n        cs = CpuSetT()\n        for i in range(nlongs):\n            cs.__bits[i] = 0\n        for c in cpu_ids:\n            if 0 <= c < CPU_SETSIZE:\n                idx = c // ncpubits\n                bit = c % ncpubits\n                cs.__bits[idx] |= 1 << bit\n        tid = libc.pthread_self()\n        sz = ctypes.sizeof(CpuSetT)\n        ret = libc.pthread_setaffinity_np(tid, sz, ctypes.byref(cs))\n        return ret == 0\n    except Exception:\n        return False\n\n\nHEADER_SIZE = 1280\nDEFAULT_ZMQ_TOPIC = b\"obs65\"\n_DTYPE_BY_NAME = {\n    \"f32\": np.float32,\n    \"f64\": np.float64,\n    \"i32\": np.int32,\n    \"i64\": np.int64,\n    \"u8\": np.uint8,\n    \"bool\": np.bool_,\n}\n\n\ndef _decode_zmq_topic(topic_value) -> bytes:\n    if isinstance(topic_value, bytes):\n        return topic_value\n    return str(topic_value).encode(\"utf-8\")\n\n\ndef _coerce_config_bool(value, default: bool = False) -> bool:\n    if value is None:\n        return default\n    if isinstance(value, (bool, np.bool_)):\n        return bool(value)\n    if isinstance(value, str):\n        value = value.strip().lower()\n        if value in {\"1\", \"true\", \"yes\", \"y\", \"on\"}:\n            return True\n        if value in {\"0\", \"false\", \"no\", \"n\", \"off\", \"\"}:\n            return False\n    return bool(value)\n\n\ndef _infer_onnx_dim(dim, default: int = 1) -> int:\n    if isinstance(dim, int) and dim > 0:\n        return dim\n    return int(default)\n\n\ndef _infer_numpy_dtype_from_onnx_type(type_str: str):\n    type_str = str(type_str).lower()\n    if \"float16\" in type_str:\n        return np.float16\n    if \"float64\" in type_str or \"double\" in type_str:\n        return np.float64\n    if \"int64\" in type_str:\n        return np.int64\n    if \"int32\" in type_str:\n        return np.int32\n    if \"bool\" in type_str:\n        return np.bool_\n    return np.float32\n\n\ndef unpack_numpy_message(packet: bytes, expected_topic: bytes | None = None) -> dict:\n    if expected_topic is not None:\n        if not packet.startswith(expected_topic):\n            raise ValueError(\"ZMQ packet topic prefix mismatch\")\n        packet = packet[len(expected_topic) :]\n\n    if len(packet) < HEADER_SIZE:\n        raise ValueError(f\"ZMQ packet too short: {len(packet)} < {HEADER_SIZE}\")\n\n    header_bytes = packet[:HEADER_SIZE].rstrip(b\"\\x00\")\n    if not header_bytes:\n        raise ValueError(\"ZMQ packet has empty header\")\n    header = json.loads(header_bytes.decode(\"utf-8\"))\n\n    payload = memoryview(packet)[HEADER_SIZE:]\n    result = {}\n    offset = 0\n    for field in header.get(\"fields\", []):\n        name = str(field[\"name\"])\n        dtype_name = str(field[\"dtype\"])\n        shape = tuple(int(x) for x in field.get(\"shape\", []))\n        if dtype_name not in _DTYPE_BY_NAME:\n            raise ValueError(f\"Unsupported dtype in ZMQ packet: {dtype_name}\")\n\n        dtype = 
np.dtype(_DTYPE_BY_NAME[dtype_name]).newbyteorder(\"<\")\n        count = int(np.prod(shape, dtype=np.int64)) if len(shape) > 0 else 1\n        nbytes = count * dtype.itemsize\n        end = offset + nbytes\n        if end > len(payload):\n            raise ValueError(\n                f\"ZMQ packet field '{name}' exceeds payload size: end={end}, payload={len(payload)}\"\n            )\n        arr = np.frombuffer(payload[offset:end], dtype=dtype, count=count)\n        if len(shape) > 0:\n            arr = arr.reshape(shape)\n        else:\n            arr = arr.reshape(())\n        result[name] = np.array(arr, copy=True)\n        offset = end\n    return result\n\n\nclass LatestObsBuffer:\n    \"\"\"Thread-safe buffer for delayed latest_obs access.\"\"\"\n\n    def __init__(self, max_queue_size: int = 20):\n        self._lock = threading.Lock()\n        self._data = None\n        self._timestamp = None\n        self._sender_timestamp = None\n        self._frame_index = None\n        self._data_queue = deque(maxlen=max_queue_size)\n        self._timestamp_queue = deque(maxlen=max_queue_size)\n        self._sender_timestamp_queue = deque(maxlen=max_queue_size)\n        self._frame_index_queue = deque(maxlen=max_queue_size)\n\n    def set(\n        self,\n        arr: np.ndarray,\n        sender_timestamp: float | None = None,\n        frame_index: int | None = None,\n    ):\n        with self._lock:\n            current_time = time.time()\n            arr_copy = np.asarray(arr, dtype=np.float32).copy()\n            self._data = arr_copy\n            self._timestamp = current_time\n            self._sender_timestamp = sender_timestamp\n            self._frame_index = frame_index\n            self._data_queue.append(arr_copy)\n            self._timestamp_queue.append(current_time)\n            self._sender_timestamp_queue.append(sender_timestamp)\n            self._frame_index_queue.append(frame_index)\n\n    def get_with_age_and_delay(self, max_age: float = 0.1, delay_steps: int = 0):\n        \"\"\"Return a delayed frame and report whether it is stale.\"\"\"\n        with self._lock:\n            if len(self._data_queue) == 0:\n                if self._data is None or self._timestamp is None:\n                    return None, None, True, None, None\n                current_time = time.time()\n                age = current_time - self._timestamp\n                return (\n                    self._data,\n                    self._timestamp,\n                    age > max_age,\n                    self._frame_index,\n                    self._sender_timestamp,\n                )\n\n            if delay_steps < 0:\n                delay_steps = 0\n            idx = len(self._data_queue) - 1 - delay_steps\n            if idx < 0:\n                idx = 0\n\n            data = self._data_queue[idx]\n            ts = self._timestamp_queue[idx]\n            frame_index = self._frame_index_queue[idx]\n            sender_timestamp = self._sender_timestamp_queue[idx]\n\n        current_time = time.time()\n        age = current_time - ts\n        is_stale = age > max_age\n        return data, ts, is_stale, frame_index, sender_timestamp\n\n    def get_queue_stats(self):\n        with self._lock:\n            if len(self._data_queue) < 2:\n                return {\"queue_size\": len(self._data_queue), \"avg_interval\": None}\n            intervals = []\n            for i in range(1, len(self._timestamp_queue)):\n                interval = self._timestamp_queue[i] - self._timestamp_queue[i - 1]\n         
       intervals.append(interval)\n            avg_interval = float(np.mean(intervals)) if intervals else None\n            return {\n                \"queue_size\": len(self._data_queue),\n                \"avg_interval\": avg_interval,\n                \"expected_freq\": 1.0 / avg_interval if avg_interval and avg_interval > 0 else None,\n            }\n\n\nclass ZmqLatestObsSubscriber:\n    \"\"\"Background ZMQ SUB receiver for latest_obs packets.\"\"\"\n\n    def __init__(\n        self,\n        uri: str,\n        topic: bytes,\n        buffer: LatestObsBuffer,\n        logger,\n        mode: str = \"connect\",\n        cpu_affinity=None,\n        conflate: bool = True,\n    ):\n        self.uri = uri\n        self.topic = topic\n        self.buffer = buffer\n        self.logger = logger\n        self.mode = str(mode).strip().lower()\n        self.cpu_affinity = cpu_affinity or []\n        self.conflate = bool(conflate)\n\n        self._thread = None\n        self._stop_event = threading.Event()\n        self._context = None\n        self._socket = None\n        self._poller = None\n        self._recv_count = 0\n\n    def _process_packet(self, packet: bytes):\n        payload = unpack_numpy_message(packet, expected_topic=self.topic)\n        latest_obs = payload.get(\"latest_obs\", None)\n        if latest_obs is None:\n            raise ValueError(\"ZMQ packet missing latest_obs field\")\n\n        frame_index = payload.get(\"frame_index\", None)\n        if frame_index is not None:\n            frame_index = int(np.asarray(frame_index).reshape(-1)[0])\n\n        sender_timestamp = payload.get(\"timestamp_realtime\", None)\n        if sender_timestamp is not None:\n            sender_timestamp = float(np.asarray(sender_timestamp).reshape(-1)[0])\n\n        self.buffer.set(\n            np.asarray(latest_obs, dtype=np.float32),\n            sender_timestamp=sender_timestamp,\n            frame_index=frame_index,\n        )\n        self._recv_count += 1\n        if self._recv_count == 1:\n            self.logger.info(\n                f\"[ZMQ] first latest_obs packet received from {self.uri}, \"\n                f\"topic={self.topic.decode('utf-8', errors='ignore')}\"\n            )\n\n    def _run(self):\n        if self.cpu_affinity and set_thread_cpu_affinity(self.cpu_affinity):\n            self.logger.info(f\"[ZMQ] subscriber thread pinned to CPUs {self.cpu_affinity}\")\n\n        self._context = zmq.Context()\n        self._socket = self._context.socket(zmq.SUB)\n        self._socket.setsockopt(zmq.RCVHWM, 1)\n        self._socket.setsockopt(zmq.SUBSCRIBE, self.topic)\n        if self.conflate and hasattr(zmq, \"CONFLATE\"):\n            self._socket.setsockopt(zmq.CONFLATE, 1)\n\n        if self.mode == \"bind\":\n            self._socket.bind(self.uri)\n        elif self.mode == \"connect\":\n            self._socket.connect(self.uri)\n        else:\n            raise ValueError(\"latest_obs_zmq_mode must be 'bind' or 'connect'\")\n\n        self._poller = zmq.Poller()\n        self._poller.register(self._socket, zmq.POLLIN)\n        self.logger.info(\n            f\"[ZMQ] latest_obs subscriber ready: mode={self.mode}, uri={self.uri}, \"\n            f\"topic={self.topic.decode('utf-8', errors='ignore')}, conflate={self.conflate}\"\n        )\n\n        try:\n            while not self._stop_event.is_set():\n                events = dict(self._poller.poll(50))\n                if self._socket not in events:\n                    continue\n                try:\n                   
 packet = self._socket.recv(flags=zmq.NOBLOCK)\n                except zmq.Again:\n                    continue\n                self._process_packet(packet)\n        except Exception as exc:\n            if not self._stop_event.is_set():\n                self.logger.error(f\"[ZMQ] subscriber loop failed: {exc}\")\n        finally:\n            try:\n                if self._poller is not None and self._socket is not None:\n                    self._poller.unregister(self._socket)\n            except Exception:\n                pass\n            try:\n                if self._socket is not None:\n                    self._socket.close(0)\n            except Exception:\n                pass\n            try:\n                if self._context is not None:\n                    self._context.term()\n            except Exception:\n                pass\n            self._socket = None\n            self._context = None\n            self._poller = None\n\n    def start(self):\n        if self._thread is not None:\n            return\n        self._stop_event.clear()\n        self._thread = threading.Thread(target=self._run, daemon=True)\n        self._thread.start()\n        self.logger.info(\"[ZMQ] subscriber thread started\")\n\n    def stop(self):\n        self._stop_event.set()\n        if self._thread:\n            self._thread.join(timeout=2.0)\n            self._thread = None\n        self.logger.info(\"[ZMQ] subscriber thread stopped\")\n\n\nclass HoloMotionPolicyNode(Node):\n    \"\"\"Main policy execution node for HoloMotion humanoid robot control with dual policy support.\n\n    This node implements the high-level control logic for a humanoid robot capable of\n    performing both velocity tracking and motion sequence execution. It supports two\n    neural network policies and allows runtime switching between them.\n\n    Key Features:\n    - Dual neural network policy inference (velocity + motion) using ONNX Runtime\n    - Runtime policy switching with A/B/Y button controls\n    - Velocity tracking mode with joystick control\n    - Motion tracking mode with motion clip sequence selection\n    - Safety-aware state machine with motion prerequisites\n    - Real-time observation processing and action generation\n\n    Policy Control:\n    - A button: Enable policy (defaults to velocity mode)\n    - B button: Switch from velocity to motion mode\n    - Y button: Switch from motion back to velocity mode\n    \n    Input Controls:\n    - Motion mode:  B button (for mode switch)\n    - Velocity mode: Y button (for mode switch) + Joystick +UP/DOWN/LEFT/RIGHT (for motion selection)\n\n    State Machine:\n    - ZERO_TORQUE: Initial safe state, waiting for activation\n    - MOVE_TO_DEFAULT: Ready state, allows policy operations\n    - Policy execution with mode switching\n    - Emergency stop handling\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize the policy node with configuration, models, and ROS2 interfaces.\n\n        Sets up the complete policy execution pipeline including:\n        - Configuration loading from YAML file\n        - Neural network model initialization\n        - Motion data loading for all sequences\n        - ROS2 publishers, subscribers, and timers\n        - State machine initialization\n\n        The node starts in a safe state and waits for proper robot state\n        before allowing motion execution.\n        \"\"\"\n        super().__init__(\"policy_node\")\n\n        # Get config path from ROS parameter\n        config_path = 
self.declare_parameter(\"config_path\", \"\").value\n        with open(config_path, \"r\", encoding=\"utf-8\") as config_file:\n            self.config_yaml = easydict.EasyDict(yaml.safe_load(config_file))\n        # Read policy frequency from config, default to 50 Hz if not specified\n        policy_freq = self.config_yaml.get(\"policy_freq\", 50)\n        self.dt = 1.0 / policy_freq\n        self.get_logger().info(f\"Policy frequency set to: {policy_freq} Hz (dt = {self.dt:.4f} s)\")\n        # Initialize basic parameters - will be updated after config loading\n        self.actions_dim = 29  # Default value, will be updated from config\n        self.real_dof_names = []  # Will be loaded from config\n        self.current_motion_clip_index = 0  # Current motion clip index\n        # Button state tracking for preventing multiple triggers\n        self.last_button_states = {\n            KeyMap.up: 0,\n            KeyMap.down: 0,\n            KeyMap.left: 0,\n            KeyMap.right: 0,\n            KeyMap.A: 0,\n            KeyMap.B: 0,\n            KeyMap.Y: 0,\n        }\n        # Safety check related flags\n        self.policy_enabled = False  # Controls whether policy is enabled\n        # Robot state related flags\n        self.robot_state_ready = False  # Marks whether MOVE_TO_DEFAULT state is received, allowing key operations\n        self._setup_subscribers()\n        self._setup_publishers()\n        self._setup_timers()\n        # Initialize variables for dual policy\n        self.velocity_policy_session = None\n        self.motion_policy_session = None\n        self.use_kv_cache = False\n        self.motion_kv_cache = None\n        self.motion_kv_input_name = None\n        self.motion_kv_output_name = None\n        self.motion_step_idx_input_name = None\n        self.current_policy_mode = \"velocity\"\n        self.velocity_config = None\n        self.motion_config = None\n        self.motion_frame_idx = 0\n        self.ref_dof_pos = None\n        self.ref_dof_vel = None\n        self.ref_raw_bodylink_pos = None\n        self.ref_raw_bodylink_rot = None\n        self.n_motion_frames = 0\n\n        self.external_latest_obs = None\n        self.external_obs_received = False\n        self.last_external_obs_time = None\n        self._latest_sender_timestamp = None\n        self.latest_obs_flag = False\n        self.latest_obs_expected_dim = 65\n        self.external_fut_dof_pos_queue = None\n        self.external_fut_dof_vel_queue = None\n        self.external_fut_root_pos_queue = None\n        self.external_fut_root_rot_queue = None\n        self.external_fut_frame_idx_queue = None\n        self._prev_external_dof_pos = None\n        self._prev_external_dof_vel = None\n        self._prev_external_root_pos = None\n        self._prev_external_root_rot = None\n        self._prev_external_frame_idx = None\n        self.max_data_age = 0.6\n        self.stale_data_warning_count = 0\n        self.last_poll_time = None\n        self._last_vr_status_log_time = None\n        self.latest_obs_zmq_uri = self.declare_parameter(\n            \"latest_obs_zmq_uri\", \"tcp://192.168.124.29:6001\"\n        ).value\n        self.latest_obs_zmq_topic = self.declare_parameter(\n            \"latest_obs_zmq_topic\", DEFAULT_ZMQ_TOPIC.decode(\"utf-8\")\n        ).value\n        self.latest_obs_zmq_mode = self.declare_parameter(\n            \"latest_obs_zmq_mode\", \"connect\"\n        ).value\n        self.latest_obs_zmq_conflate = self.declare_parameter(\n            \"latest_obs_zmq_conflate\", 
True\n        ).value\n        self.zmq_jitter_delay_frames = self.declare_parameter(\n            \"zmq_jitter_delay_frames\", 5\n        ).value\n        self.require_vr_data_for_motion = self.declare_parameter(\n            \"require_vr_data_for_motion\", True\n        ).value\n        self.enable_teleop_reference = self.declare_parameter(\n            \"enable_teleop_reference\", True\n        ).value\n        self._cpu_affinity_main_str = self.declare_parameter(\n            \"cpu_affinity_main\", \"\"\n        ).value\n        self._cpu_affinity_zmq_sub_str = self.declare_parameter(\n            \"cpu_affinity_zmq_sub\", \"\"\n        ).value\n        self.timing_debug_enabled = self.declare_parameter(\n            \"timing_debug_enabled\", True\n        ).value\n        self.timing_debug_log_interval_sec = self.declare_parameter(\n            \"timing_debug_log_interval_sec\", 5.0\n        ).value\n        self.timing_debug_log_per_loop = self.declare_parameter(\n            \"timing_debug_log_per_loop\", False\n        ).value\n        self._timing_debug_last_log_time = None\n        self._timing_debug_samples = deque(maxlen=500)\n        self._root_only_fk_keybody_warned = False\n\n        _vr = getattr(self.config_yaml, \"vr\", None) or {}\n        if _vr:\n            self.latest_obs_zmq_uri = str(_vr.get(\"latest_obs_zmq_uri\", self.latest_obs_zmq_uri))\n            self.latest_obs_zmq_topic = str(\n                _vr.get(\"latest_obs_zmq_topic\", self.latest_obs_zmq_topic)\n            )\n            self.latest_obs_zmq_mode = str(\n                _vr.get(\"latest_obs_zmq_mode\", self.latest_obs_zmq_mode)\n            )\n            self.latest_obs_zmq_conflate = bool(\n                _vr.get(\"latest_obs_zmq_conflate\", self.latest_obs_zmq_conflate)\n            )\n            self.zmq_jitter_delay_frames = int(\n                _vr.get(\"zmq_jitter_delay_frames\", self.zmq_jitter_delay_frames)\n            )\n            self.max_data_age = float(_vr.get(\"max_data_age\", self.max_data_age))\n            self.require_vr_data_for_motion = bool(\n                _vr.get(\"require_vr_data_for_motion\", self.require_vr_data_for_motion)\n            )\n            self.enable_teleop_reference = bool(\n                _vr.get(\"enable_teleop_reference\", self.enable_teleop_reference)\n            )\n            self.timing_debug_enabled = bool(\n                _vr.get(\"timing_debug_enabled\", self.timing_debug_enabled)\n            )\n            self.timing_debug_log_interval_sec = float(\n                _vr.get(\n                    \"timing_debug_log_interval_sec\",\n                    self.timing_debug_log_interval_sec,\n                )\n            )\n            self.timing_debug_log_per_loop = bool(\n                _vr.get(\"timing_debug_log_per_loop\", self.timing_debug_log_per_loop)\n            )\n\n        self._cpu_affinity_main_str = str(\n            getattr(self.config_yaml, \"cpu_affinity_main\", self._cpu_affinity_main_str)\n        )\n        self._cpu_affinity_zmq_sub_str = str(\n            getattr(\n                self.config_yaml,\n                \"cpu_affinity_zmq_sub\",\n                self._cpu_affinity_zmq_sub_str,\n            )\n        )\n        self._ros_latest_obs_buffer = None\n        self._npz_replay_frame_index = None\n        self._external_seen_frames = 0\n        self._vr_ready_logged = False\n\n        self._latest_obs_buffer = LatestObsBuffer()\n        self._latest_obs_zmq_topic_bytes = 
_decode_zmq_topic(self.latest_obs_zmq_topic)\n        if str(self.latest_obs_zmq_mode).strip().lower() == \"connect\":\n            uri_str = str(self.latest_obs_zmq_uri)\n            if \"*\" in uri_str or \"0.0.0.0\" in uri_str:\n                self.get_logger().warn(\n                    \"[ZMQ] connect mode requires a concrete peer address. \"\n                    \"Do not use '*' or '0.0.0.0'; use the sender IP instead, \"\n                    \"for example tcp://192.168.124.29:6001.\"\n                )\n        zmq_cpu_affinity = _parse_cpu_affinity_str(self._cpu_affinity_zmq_sub_str)\n        self._zmq_subscriber = ZmqLatestObsSubscriber(\n            uri=self.latest_obs_zmq_uri,\n            topic=self._latest_obs_zmq_topic_bytes,\n            mode=self.latest_obs_zmq_mode,\n            conflate=bool(self.latest_obs_zmq_conflate),\n            buffer=self._latest_obs_buffer,\n            logger=self.get_logger(),\n            cpu_affinity=zmq_cpu_affinity if zmq_cpu_affinity else None,\n        )\n        self._zmq_subscriber.start()\n        self.get_logger().info(\n            f\"ZMQ latest_obs subscriber started: mode={self.latest_obs_zmq_mode}, \"\n            f\"uri={self.latest_obs_zmq_uri}, topic={self.latest_obs_zmq_topic}, \"\n            f\"jitter_delay={self.zmq_jitter_delay_frames}\"\n        )\n\n        self.dof_names_ref_motion = []\n        self.num_actions = 29\n        self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32)\n\n        self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.target_dof_pos_onnx = self.default_angles_onnx.copy()\n        self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n\n        self._lowstate_msg = None\n        self.target_dof_pos_real = None\n        self.motion_in_progress = False\n        self._keybody_indices_by_term_name = {}\n        self.fk = None\n        self.fk_initialized = False\n        self.motion_action_ema_filter_enabled = False\n        self.motion_action_ema_filter_alpha = 1.0\n        self._motion_filtered_actions_onnx = None\n\n    def _is_vr_ready_for_motion(self) -> bool:\n        \"\"\"Return whether the ZMQ reference stream is ready for motion mode.\"\"\"\n        if not getattr(self, \"enable_teleop_reference\", True):\n            return False\n        if not (\n            getattr(self, \"external_obs_received\", False)\n            and getattr(self, \"external_latest_obs\", None) is not None\n        ):\n            return False\n        n_fut = int(getattr(self, \"n_fut_frames\", 0) or 0)\n        if n_fut <= 0:\n            return True\n        delay = int(getattr(self, \"zmq_jitter_delay_frames\", 0) or 0)\n        needed = n_fut + max(delay, 0) + 1\n        return int(getattr(self, \"_external_seen_frames\", 0)) >= needed\n\n    \n    def _init_keybody_indices_cache(self):\n        if self.motion_config is None:\n            raise ValueError(\"motion_config is not loaded; cannot init keybody index cache\")\n\n        atomic_list = self._get_policy_atomic_obs_list(self.motion_config)[\"atomic_obs_list\"]\n        body_names = [str(name) for name in self.motion_config.robot.body_names]\n        body_name_to_idx = {body_name: idx for idx, body_name in enumerate(body_names)}\n\n        cache = {}\n        for term_dict in atomic_list:\n            term_name = str(list(term_dict.keys())[0])\n            
term_cfg = term_dict[term_name]\n            params = {}\n            if isinstance(term_cfg, dict):\n                params = term_cfg.get(\"params\", {}) or {}\n                if not isinstance(params, dict):\n                    raise ValueError(\n                        f\"Observation term '{term_name}' params must be a dict, got {type(params)}\"\n                    )\n            needs_keybody = (\"keybody\" in term_name) or (\"keybody_names\" in params)\n            if not needs_keybody:\n                continue\n\n            keybody_names = params.get(\"keybody_names\", None)\n            if keybody_names is None:\n                keybody_idxs = np.arange(len(body_names), dtype=np.int64)\n            else:\n                keybody_names = [str(name) for name in keybody_names]\n                missing_names = [\n                    name for name in keybody_names if name not in body_name_to_idx\n                ]\n                if len(missing_names) > 0:\n                    raise ValueError(\n                        f\"Unknown keybody_names in '{term_name}': {missing_names}. \"\n                        f\"Available body names: {body_names}\"\n                    )\n                keybody_idxs = np.asarray(\n                    [body_name_to_idx[name] for name in keybody_names],\n                    dtype=np.int64,\n                )\n\n            cache[term_name] = keybody_idxs\n\n        self._keybody_indices_by_term_name = cache\n\n    def _get_policy_atomic_obs_list(self, config):\n        \"\"\"Resolve the atomic obs list used to build the ONNX policy input.\n\n        Aligns with MuJoCo sim2sim eval ordering by honoring modules.actor.obs_schema\n        when available, to guarantee the policy input term order matches training/export.\n        \"\"\"\n\n        def _to_plain_obs_cfg(cfg):\n            if OmegaConf.is_config(cfg):\n                plain_cfg = OmegaConf.to_container(cfg, resolve=True)\n            elif cfg is None:\n                plain_cfg = {}\n            else:\n                plain_cfg = dict(cfg)\n            if plain_cfg is None:\n                plain_cfg = {}\n            if not isinstance(plain_cfg, dict):\n                raise ValueError(\n                    f\"Observation term config must be a mapping, got {type(plain_cfg)}\"\n                )\n            return plain_cfg\n\n        def _get_actor_atomic_obs_entries():\n            obs_cfg = config.get(\"obs\", None)\n            if obs_cfg is None:\n                raise ValueError(\"Missing config.obs for policy obs\")\n            obs_groups = obs_cfg.get(\"obs_groups\", None)\n            if obs_groups is None:\n                raise ValueError(\"Missing config.obs.obs_groups for policy obs\")\n\n            if obs_groups.get(\"policy\", None) is not None:\n                entries = []\n                for term_dict in obs_groups.policy.atomic_obs_list:\n                    term_name = str(list(term_dict.keys())[0])\n                    entries.append(\n                        (\n                            \"policy\",\n                            term_name,\n                            _to_plain_obs_cfg(term_dict[term_name]),\n                        )\n                    )\n                return entries\n\n            if obs_groups.get(\"unified\", None) is not None:\n                entries = []\n                for term_dict in obs_groups.unified.atomic_obs_list:\n                    term_name = str(list(term_dict.keys())[0])\n                    if 
term_name.startswith(\"actor_\"):\n                        entries.append(\n                            (\n                                \"unified\",\n                                term_name,\n                                _to_plain_obs_cfg(term_dict[term_name]),\n                            )\n                        )\n                if not entries:\n                    raise ValueError(\n                        \"obs_groups.unified found but contains no actor_* terms.\"\n                    )\n                return entries\n\n            raise ValueError(\n                \"Unsupported obs config : expected obs_groups.policy or obs_groups.unified.\"\n            )\n\n        def _get_actor_obs_schema_terms():\n            modules_cfg = config.get(\"modules\", None)\n            if modules_cfg is None:\n                return []\n            actor_cfg = modules_cfg.get(\"actor\", None)\n            if actor_cfg is None:\n                return []\n            obs_schema = actor_cfg.get(\"obs_schema\", None)\n            if obs_schema is None:\n                return []\n\n            if OmegaConf.is_config(obs_schema):\n                obs_schema_plain = OmegaConf.to_container(obs_schema, resolve=True)\n            else:\n                obs_schema_plain = obs_schema\n            if not isinstance(obs_schema_plain, dict):\n                return []\n\n            ordered_terms = []\n\n            def _collect_terms(node):\n                if node is None:\n                    return\n                if isinstance(node, dict):\n                    if \"terms\" in node and isinstance(node[\"terms\"], list):\n                        ordered_terms.extend(str(term) for term in node[\"terms\"])\n                        return\n                    for v in node.values():\n                        _collect_terms(v)\n                    return\n                if isinstance(node, list):\n                    for v in node:\n                        _collect_terms(v)\n                    return\n\n            _collect_terms(obs_schema_plain)\n            return ordered_terms\n\n        actor_atomic_entries = _get_actor_atomic_obs_entries()\n        schema_terms = _get_actor_obs_schema_terms()\n\n        if len(schema_terms) == 0:\n            return {\n                \"atomic_obs_list\": [\n                    {term_name: term_cfg}\n                    for _, term_name, term_cfg in actor_atomic_entries\n                ]\n            }\n\n        by_full_key = {}\n        by_leaf_key = {}\n        ambiguous_leaf_keys = set()\n        for group_name, term_name, term_cfg in actor_atomic_entries:\n            full_key = f\"{group_name}/{term_name}\"\n            by_full_key[full_key] = (term_name, term_cfg)\n            if term_name in by_leaf_key:\n                ambiguous_leaf_keys.add(term_name)\n            else:\n                by_leaf_key[term_name] = (term_name, term_cfg)\n\n        ordered_atomic_list = []\n        for schema_term in schema_terms:\n            schema_term_key = str(schema_term)\n            if schema_term_key in by_full_key:\n                term_name, term_cfg = by_full_key[schema_term_key]\n                ordered_atomic_list.append({term_name: term_cfg})\n                continue\n\n            leaf_key = schema_term_key.split(\"/\")[-1]\n            if leaf_key in ambiguous_leaf_keys:\n                raise ValueError(\n                    f\"Ambiguous obs_schema term '{schema_term_key}': \"\n                    f\"multiple atomic obs share leaf key 
'{leaf_key}'.\"\n                )\n            if leaf_key not in by_leaf_key:\n                available = sorted(list(by_leaf_key.keys()))\n                raise ValueError(\n                    f\"obs_schema term '{schema_term_key}' not found in atomic_obs_list. \"\n                    f\"Available terms: {available}\"\n                )\n            term_name, term_cfg = by_leaf_key[leaf_key]\n            ordered_atomic_list.append({term_name: term_cfg})\n\n        return {\"atomic_obs_list\": ordered_atomic_list}\n\n    def _find_actor_place_holder_ndim(self):\n        n_dim = 0\n        atomic_list = self._get_policy_atomic_obs_list(self.motion_config)[\n            \"atomic_obs_list\"\n        ]\n        for obs_dict in atomic_list:\n            name = str(list(obs_dict.keys())[0])\n            if name == \"place_holder\" or name == \"actor_place_holder\":\n                cfg = obs_dict[name]\n                params = cfg.get(\"params\", {}) if isinstance(cfg, dict) else {}\n                n_dim = int(params.get(\"n_dim\", 0))\n        return n_dim\n\n    def _init_obs_buffers(self):\n        \"\"\"Initialize observation builders for both velocity and motion policies.\n        \n        Each obs_builder uses its own model's dof_names_onnx and default_angles_onnx\n        to ensure correct observation computation for each policy.\n        \"\"\"\n        # Use velocity model's parameters for velocity obs_builder\n        self.velocity_obs_builder = PolicyObsBuilder(\n            dof_names_onnx=self.velocity_dof_names_onnx,\n            default_angles_onnx=self.velocity_default_angles_onnx,\n            evaluator=self,\n            obs_policy_cfg=self._get_policy_atomic_obs_list(self.velocity_config),\n        )\n\n        # Use motion model's parameters for motion obs_builder\n        self.motion_obs_builder = PolicyObsBuilder(\n            dof_names_onnx=self.motion_dof_names_onnx,\n            default_angles_onnx=self.motion_default_angles_onnx,\n            evaluator=self,\n            obs_policy_cfg=self._get_policy_atomic_obs_list(self.motion_config),\n        )\n\n        if hasattr(self, \"n_fut_frames\") and int(self.n_fut_frames) > 0:\n            n_fut = int(self.n_fut_frames)\n            self.external_fut_dof_pos_queue = np.zeros((n_fut, self.num_actions), dtype=np.float32)\n            self.external_fut_dof_vel_queue = np.zeros((n_fut, self.num_actions), dtype=np.float32)\n            self.external_fut_root_pos_queue = np.zeros((n_fut, 3), dtype=np.float32)\n            self.external_fut_root_rot_queue = np.zeros((n_fut, 4), dtype=np.float32)\n            self._fk_root_pos_seq_np = np.zeros((1, n_fut + 1, 3), dtype=np.float32)\n            self._fk_root_rot_seq_np = np.zeros((1, n_fut + 1, 4), dtype=np.float32)\n            self._fk_dof_pos_seq_np = np.zeros(\n                (1, n_fut + 1, self.num_actions), dtype=np.float32\n            )\n            self._fk_root_pos_seq_tensor = torch.from_numpy(self._fk_root_pos_seq_np)\n            self._fk_root_rot_seq_tensor = torch.from_numpy(self._fk_root_rot_seq_np)\n            self._fk_dof_pos_seq_tensor = torch.from_numpy(self._fk_dof_pos_seq_np)\n            self.external_fut_frame_idx_queue = np.full((n_fut,), -1, dtype=np.int32)\n            self.get_logger().info(\n                f\"Initialized VR future frame queues: n_fut_frames={n_fut}, num_actions={self.num_actions}\"\n            )\n        else:\n            self.external_fut_dof_pos_queue = None\n            self.external_fut_dof_vel_queue = None\n           
 self.external_fut_root_pos_queue = None\n            self.external_fut_root_rot_queue = None\n            self.external_fut_frame_idx_queue = None\n            self._fk_root_pos_seq_np = None\n            self._fk_root_rot_seq_np = None\n            self._fk_dof_pos_seq_np = None\n            self._fk_root_pos_seq_tensor = None\n            self._fk_root_rot_seq_tensor = None\n            self._fk_dof_pos_seq_tensor = None\n\n        # Set default obs_builder to velocity mode\n        self.obs_builder = self.velocity_obs_builder\n\n    def _reset_counter(self):\n        \"\"\"Reset motion timing counters to start of sequence.\"\"\"\n        self.motion_frame_idx = 0\n        self.motion_step_idx = 0\n        if self.use_kv_cache and self.motion_kv_cache is not None:\n            self.motion_kv_cache.fill(0)\n\n    def _switch_to_velocity_mode(self, reason: str = \"\"):\n        \"\"\"Switch to velocity tracking mode and clear action cache.\n        \n        Uses velocity model's default_angles_onnx to ensure correct initialization.\n        Also publishes velocity model's control parameters (kps/kds).\n        \"\"\"\n        self.current_policy_mode = \"velocity\"\n        self.latest_obs_flag = False\n        self.motion_in_progress = False\n        self._fk_vr_out = None\n        self._use_fk_vr = False\n        self._reset_motion_action_ema_filter()\n        self._reset_counter()\n        self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        # Use velocity model's default angles\n        self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy()\n        # Publish velocity model's control parameters\n        self._publish_control_params()\n        if reason:\n            self.get_logger().info(f\"Switched to velocity tracking mode ({reason})\")\n        else:\n            self.get_logger().info(\"Switched to velocity tracking mode\")\n\n    def _is_button_pressed(self, button_key):\n        \"\"\"Check if button was just pressed (rising edge detection).\"\"\"\n        current_state = self.remote_controller.button[button_key]\n        last_state = self.last_button_states[button_key]  \n        # Update the last state\n        self.last_button_states[button_key] = current_state\n        # Return True only on rising edge (0 -> 1)\n        return current_state == 1 and last_state == 0\n\n    def load_policy(self):\n        \"\"\"Load both velocity and motion policy models using ONNX Runtime.\"\"\"\n        self.get_logger().info(\"Loading dual policies...\")\n        \n        providers = [\n            (\n                \"CUDAExecutionProvider\",\n                {\n                    \"device_id\": 0,\n                },\n            ),\n            \"CPUExecutionProvider\",\n        ]\n        onnx_threads = int(self.config_yaml.get(\"onnx_intra_op_threads\", 2))\n        sess_options = onnxruntime.SessionOptions()\n        sess_options.intra_op_num_threads = onnx_threads\n        sess_options.inter_op_num_threads = 1\n        # Load velocity policy from model folder\n        velocity_model_folder = self.config_yaml.velocity_tracking_model_folder\n        velocity_model_path = os.path.join(\n            get_package_share_directory(\"humanoid_control\"),\n            \"models\",\n            velocity_model_folder,\n            \"exported\",\n        )\n        # Find ONNX file in exported folder\n        velocity_onnx_files = [f for f in os.listdir(velocity_model_path) if f.endswith('.onnx')]\n        if not velocity_onnx_files:\n            raise 
FileNotFoundError(f\"No ONNX files found in {velocity_model_path}\")\n        \n        velocity_onnx_path = os.path.join(velocity_model_path, velocity_onnx_files[0])\n        self.get_logger().info(f\"Loading velocity policy from {velocity_onnx_path}\")\n        \n        self.velocity_policy_session = onnxruntime.InferenceSession(\n            str(velocity_onnx_path), sess_options=sess_options, providers=providers\n        )\n        self.get_logger().info(\n            f\"Velocity policy loaded successfully using: \"\n            f\"{self.velocity_policy_session.get_providers()}\"\n        )\n        # Load motion policy from model folder\n        motion_model_folder = self.config_yaml.motion_tracking_model_folder\n        motion_model_path = os.path.join(\n            get_package_share_directory(\"humanoid_control\"),\n            \"models\",\n            motion_model_folder,\n            \"exported\",\n        )\n        # Find ONNX file in exported folder\n        motion_onnx_files = [f for f in os.listdir(motion_model_path) if f.endswith('.onnx')]\n        if not motion_onnx_files:\n            raise FileNotFoundError(f\"No ONNX files found in {motion_model_path}\")\n        \n        motion_onnx_path = os.path.join(motion_model_path, motion_onnx_files[0])\n        self.get_logger().info(f\"Loading motion policy from {motion_onnx_path}\")\n        \n        self.motion_policy_session = onnxruntime.InferenceSession(\n            str(motion_onnx_path), sess_options=sess_options, providers=providers\n        )\n        self.get_logger().info(\n            f\"Motion policy loaded successfully using: \"\n            f\"{self.motion_policy_session.get_providers()}\"\n        )\n        # Set input/output names for both policies\n        self.velocity_input_name = self.velocity_policy_session.get_inputs()[0].name\n        self.velocity_output_name = self.velocity_policy_session.get_outputs()[0].name\n        self.motion_input_name = self.motion_policy_session.get_inputs()[0].name\n        self.motion_output_name = self.motion_policy_session.get_outputs()[0].name\n        \n        self.get_logger().info(\n            f\"Velocity policy - Input: {self.velocity_input_name}, \"\n            f\"Output: {self.velocity_output_name}\"\n        )\n        self.get_logger().info(\n            f\"Motion policy - Input: {self.motion_input_name}, \"\n            f\"Output: {self.motion_output_name}\"\n        )\n        # Store ONNX paths for metadata reading\n        self.velocity_onnx_path = velocity_onnx_path\n        self.motion_onnx_path = motion_onnx_path\n        self.get_logger().info(\"Initializing KV-Cache for Motion Policy...\")\n        \n        self.motion_kv_input_name = None\n        self.motion_kv_output_name = None\n        self.motion_kv_shape = None\n        self.motion_step_idx_input_name = None\n        self.motion_kv_dtype = np.float32\n        \n        for node in self.motion_policy_session.get_inputs():\n            name = node.name\n            shape = node.shape\n            node_type = node.type\n            self.get_logger().info(f\"Motion policy input: name={name}, shape={shape}, type={node_type}\")\n            if \"obs\" in name:\n                self.motion_input_name = name\n            elif \"past_key_values\" in name:\n                self.motion_kv_input_name = name\n                self.motion_kv_shape = shape\n                if isinstance(node_type, str) and \"float16\" in node_type:\n                    self.motion_kv_dtype = np.float16\n            elif 
\"step_idx\" in name or name == \"step_idx\":\n                self.motion_step_idx_input_name = name\n            elif \"current_pos\" in name or name == \"current_pos\":\n                self.motion_step_idx_input_name = name\n            elif (\n                self.motion_step_idx_input_name is None\n                and isinstance(node_type, str)\n                and \"int64\" in node_type\n                and name not in (self.motion_input_name, self.motion_kv_input_name)\n            ):\n                self.motion_step_idx_input_name = name\n\n        motion_outputs = self.motion_policy_session.get_outputs()\n        action_output_name = None\n        kv_output_name = None\n        for node in motion_outputs:\n            self.get_logger().info(f\"Motion policy output: name={node.name}, shape={node.shape}, type={node.type}\")\n            if \"present_key_values\" in node.name:\n                kv_output_name = node.name\n            elif \"actions\" in node.name:\n                action_output_name = node.name\n        if action_output_name is None:\n            for node in motion_outputs:\n                if kv_output_name is not None and node.name == kv_output_name:\n                    continue\n                action_output_name = node.name\n                break\n        if action_output_name is None:\n            action_output_name = motion_outputs[0].name\n        self.motion_output_name = action_output_name\n        self.motion_kv_output_name = kv_output_name\n        if self.motion_kv_input_name is not None and self.motion_kv_output_name is None:\n            self.get_logger().warn(\n                \"Motion policy has past_key_values input but no present_key_values output was found. \"\n                \"KV cache will not update and transformer performance will degrade.\"\n            )\n\n        if self.motion_kv_input_name and self.motion_kv_shape:\n            shape = [d if isinstance(d, int) else 1 for d in self.motion_kv_shape]\n            \n            self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)\n            self.motion_model_context_len = int(shape[3]) if len(shape) > 3 else 0\n            self.motion_max_context_len = int(\n                self.motion_config.get(\"algo\", {})\n                .get(\"config\", {})\n                .get(\"num_steps_per_env\", 0)\n            )\n            if self.motion_max_context_len > 0 and self.motion_model_context_len > 0:\n                self.motion_effective_context_len = min(\n                    self.motion_max_context_len, self.motion_model_context_len\n                )\n            else:\n                self.motion_effective_context_len = self.motion_model_context_len\n            self.use_kv_cache = True\n            self.get_logger().info(\n                f\"KV-Cache initialized with shape {shape} \"\n                f\"(model_ctx={self.motion_model_context_len}, \"\n                f\"effective_ctx={self.motion_effective_context_len})\"\n            )\n        else:\n            self.use_kv_cache = False\n            self.motion_kv_cache = None\n            self.motion_model_context_len = 0\n            self.motion_effective_context_len = 0\n            self.get_logger().warn(\"No KV-Cache inputs found in Motion Policy model!\")\n        self.get_logger().info(\"Dual policies loaded successfully\")\n\n    def load_model_config(self):\n        \"\"\"Load config.yaml from both velocity and motion model folders.\"\"\"\n        # Load velocity model config\n        velocity_model_folder = 
self.config_yaml.velocity_tracking_model_folder\n        velocity_config_dir = os.path.join(\n            get_package_share_directory(\"humanoid_control\"),\n            \"models\",\n            velocity_model_folder,\n        )\n        # Try different config file names for velocity model\n        config_names = [\"config.yaml\"]\n        velocity_config_path = None\n        \n        for config_name in config_names:\n            potential_path = os.path.join(velocity_config_dir, config_name)\n            if os.path.exists(potential_path):\n                velocity_config_path = potential_path\n                break\n        \n        if velocity_config_path is None:\n            raise FileNotFoundError(\n                f\"No config file found in {velocity_config_dir}. Tried: {config_names}\"\n            )\n\n        self.get_logger().info(\n            f\"Loading velocity model config from {velocity_config_path}\"\n        )\n        self.velocity_config = OmegaConf.load(velocity_config_path)\n\n        # Load motion model config\n        motion_model_folder = self.config_yaml.motion_tracking_model_folder\n        motion_config_dir = os.path.join(\n            get_package_share_directory(\"humanoid_control\"),\n            \"models\",\n            motion_model_folder,\n        )\n        # Try different config file names for motion model\n        motion_config_path = None\n        \n        for config_name in config_names:\n            potential_path = os.path.join(motion_config_dir, config_name)\n            if os.path.exists(potential_path):\n                motion_config_path = potential_path\n                break\n        \n        if motion_config_path is None:\n            raise FileNotFoundError(\n                f\"No config file found in {motion_config_dir}. 
Tried: {config_names}\"\n            )\n\n        self.get_logger().info(f\"Loading motion model config from {motion_config_path}\")\n        self.motion_config = OmegaConf.load(motion_config_path)\n        self._load_motion_action_ema_filter_cfg()\n        self.actor_place_holder_ndim = self._find_actor_place_holder_ndim()\n        self.n_fut_frames = int(self.motion_config.obs.n_fut_frames)\n        self.torso_body_idx = self.motion_config.robot.body_names.index(\"torso_link\")\n        self.get_logger().info(\"Both model configs loaded successfully\")\n\n    def _load_motion_action_ema_filter_cfg(self) -> None:\n        actuator_cfg = self.motion_config.get(\"robot\", {}).get(\"actuators\", {})\n        enabled_raw = actuator_cfg.get(\"ema_filter_enabled\", None)\n        alpha_raw = actuator_cfg.get(\"ema_filter_alpha\", None)\n\n        if enabled_raw is None or alpha_raw is None:\n            self.motion_action_ema_filter_enabled = False\n            self.motion_action_ema_filter_alpha = 1.0\n            self.get_logger().info(\n                \"[Motion EMA] ema_filter_enabled/ema_filter_alpha not found in motion config; EMA disabled.\"\n            )\n            return\n\n        self.motion_action_ema_filter_enabled = _coerce_config_bool(\n            enabled_raw, default=False\n        )\n        self.motion_action_ema_filter_alpha = float(alpha_raw)\n        if not 0.0 <= self.motion_action_ema_filter_alpha <= 1.0:\n            raise ValueError(\n                \"motion_config robot.actuators.ema_filter_alpha must be within [0, 1], \"\n                f\"got {self.motion_action_ema_filter_alpha}.\"\n            )\n        self.get_logger().info(\n            \"[Motion EMA] Loaded from motion config: \"\n            f\"enabled={self.motion_action_ema_filter_enabled}, \"\n            f\"alpha={self.motion_action_ema_filter_alpha:.4f}\"\n        )\n\n    def _reset_motion_action_ema_filter(self) -> None:\n        self._motion_filtered_actions_onnx = None\n\n    def _apply_motion_action_ema_filter(\n        self, raw_actions: np.ndarray\n    ) -> np.ndarray:\n        raw_actions = np.asarray(raw_actions, dtype=np.float32).reshape(-1)\n        if not self.motion_action_ema_filter_enabled:\n            return raw_actions.copy()\n\n        if self._motion_filtered_actions_onnx is None:\n            self._motion_filtered_actions_onnx = raw_actions.copy()\n            return self._motion_filtered_actions_onnx.copy()\n\n        alpha = float(self.motion_action_ema_filter_alpha)\n        filtered_actions = (\n            alpha * raw_actions\n            + (1.0 - alpha) * self._motion_filtered_actions_onnx\n        ).astype(np.float32, copy=False)\n        self._motion_filtered_actions_onnx = filtered_actions.copy()\n        return self._motion_filtered_actions_onnx.copy()\n\n    def _build_dummy_input_from_onnx_node(self, node, fallback_last_dim: int | None = None):\n        shape = list(getattr(node, \"shape\", []) or [])\n        if not shape:\n            shape = [1]\n        inferred_shape = [_infer_onnx_dim(dim, default=1) for dim in shape]\n        if fallback_last_dim is not None and len(inferred_shape) >= 2:\n            last_dim = shape[-1]\n            if not isinstance(last_dim, int) or last_dim <= 0:\n                inferred_shape[-1] = int(fallback_last_dim)\n        dtype = _infer_numpy_dtype_from_onnx_type(getattr(node, \"type\", \"tensor(float)\"))\n        return np.zeros(inferred_shape, dtype=dtype)\n\n    def _warmup_motion_policy(self, num_iters: int = 2) -> None:\n  
      if self.motion_policy_session is None:\n            return\n\n        try:\n            input_nodes = {node.name: node for node in self.motion_policy_session.get_inputs()}\n            obs_node = input_nodes.get(self.motion_input_name, None)\n            if obs_node is None:\n                raise ValueError(\n                    f\"Motion warmup failed to find obs input '{self.motion_input_name}'.\"\n                )\n\n            motion_obs_dim = None\n            try:\n                motion_obs_dim = int(\n                    self.motion_obs_builder.build_policy_obs().shape[0]\n                )\n            except Exception:\n                motion_obs_dim = None\n\n            obs_dummy = self._build_dummy_input_from_onnx_node(\n                obs_node, fallback_last_dim=motion_obs_dim\n            )\n            output_names = [self.motion_output_name]\n            if self.motion_kv_output_name:\n                output_names.append(self.motion_kv_output_name)\n\n            local_kv_cache = None\n            if self.use_kv_cache and self.motion_kv_input_name is not None:\n                if self.motion_kv_cache is not None:\n                    local_kv_cache = np.zeros_like(self.motion_kv_cache)\n                else:\n                    shape = [\n                        _infer_onnx_dim(dim, default=1)\n                        for dim in (self.motion_kv_shape or [])\n                    ]\n                    local_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)\n\n            for warmup_step in range(max(1, int(num_iters))):\n                input_feed = {self.motion_input_name: obs_dummy}\n                if self.use_kv_cache and self.motion_kv_input_name is not None:\n                    input_feed[self.motion_kv_input_name] = local_kv_cache\n                if self.motion_step_idx_input_name is not None:\n                    step_node = input_nodes.get(self.motion_step_idx_input_name, None)\n                    step_dtype = np.int64\n                    if step_node is not None:\n                        step_dtype = _infer_numpy_dtype_from_onnx_type(\n                            getattr(step_node, \"type\", \"tensor(int64)\")\n                        )\n                    input_feed[self.motion_step_idx_input_name] = np.array(\n                        [warmup_step], dtype=step_dtype\n                    )\n\n                warmup_output = self.motion_policy_session.run(output_names, input_feed)\n                if (\n                    local_kv_cache is not None\n                    and self.motion_kv_output_name\n                    and len(warmup_output) > 1\n                ):\n                    local_kv_cache = warmup_output[1]\n\n            if self.motion_kv_cache is not None:\n                self.motion_kv_cache.fill(0)\n            self.motion_step_idx = 0\n            self.get_logger().info(\n                f\"[Warmup] Motion policy warmup completed ({max(1, int(num_iters))} iterations, KV cache kept clean).\"\n            )\n        except Exception as exc:\n            if self.motion_kv_cache is not None:\n                self.motion_kv_cache.fill(0)\n            self.motion_step_idx = 0\n            self.get_logger().warn(f\"[Warmup] Motion policy warmup skipped: {exc}\")\n\n    def update_config_parameters(self):\n        \"\"\"Update configuration parameters from loaded configs.\"\"\"\n        # Check if both models have the same basic parameters\n        velocity_actions_dim = self.velocity_config.get(\"robot\", {}).get(\"actions_dim\", 
29)\n        motion_actions_dim = self.motion_config.get(\"robot\", {}).get(\"actions_dim\", 29)\n\n        velocity_dof_names = self.velocity_config.get(\"robot\", {}).get(\"dof_names\", [])\n        motion_dof_names = self.motion_config.get(\"robot\", {}).get(\"dof_names\", [])\n\n        # Verify that both models have compatible configurations\n        if velocity_actions_dim != motion_actions_dim:\n            self.get_logger().warn(\n                f\"Different actions_dim: velocity={velocity_actions_dim}, \"\n                f\"motion={motion_actions_dim}\"\n            )\n\n        if velocity_dof_names != motion_dof_names:\n            self.get_logger().warn(\"Different dof_names between models\")\n            self.get_logger().warn(f\"Velocity dof_names: {len(velocity_dof_names)} items\")\n            self.get_logger().warn(f\"Motion dof_names: {len(motion_dof_names)} items\")\n\n        # Use velocity config as the primary source for basic parameters\n        config = self.velocity_config\n        # Update basic parameters\n        self.actions_dim = config.get(\"robot\", {}).get(\"actions_dim\", 29)\n        self.real_dof_names = config.get(\"robot\", {}).get(\"dof_names\", [])\n        self.dof_names_ref_motion = list(config.robot.dof_names)\n        self.num_actions = len(self.dof_names_ref_motion)\n\n        # Update arrays with correct sizes\n        self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32)\n        self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.target_dof_pos_onnx = self.default_angles_onnx.copy()\n        self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n\n        self.get_logger().info(\n            f\"Updated config parameters: actions_dim={self.actions_dim}, \"\n            f\"dof_names={len(self.real_dof_names)}\"\n        )\n\n    def load_motion_data(self):\n        \"\"\"Load motion clip data from .npz files.\"\"\"\n        motion_clips_dir = os.path.join(\n            get_package_share_directory(\"humanoid_control\"),\n            self.config_yaml.motion_clip_dir,\n        )\n\n        self.get_logger().info(f\"Looking for motion clip data in: {motion_clips_dir}\")\n        self.get_logger().info(f\"Directory exists: {os.path.exists(motion_clips_dir)}\")\n\n        if not os.path.exists(motion_clips_dir):\n            self.get_logger().warn(f\"Motion clips directory not found: {motion_clips_dir}\")\n            return\n\n        # Only collect .npz files\n        motion_clip_files = [f for f in os.listdir(motion_clips_dir) if f.endswith(\".npz\")]\n        motion_clip_files.sort()\n        self.get_logger().info(\n            f\"Found {len(motion_clip_files)} motion clip files (.npz): {motion_clip_files}\"\n        )\n        if not motion_clip_files:\n            self.get_logger().warn(\n                f\"No motion clip files (.npz) found in directory: {motion_clips_dir}\"\n            )\n            return\n\n        # Load each .npz file\n        self.all_motion_data = []\n        self.motion_file_names = []\n        for motion_clip_file in motion_clip_files:\n            motion_path = os.path.join(motion_clips_dir, motion_clip_file)\n            motion_data_dict = dict(np.load(motion_path, allow_pickle=True))\n\n            self.all_motion_data.append(\n                {\n                    
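# Map the npz \"ref_*\" arrays onto the shorter runtime keys used below.\n                    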
\"dof_pos\": motion_data_dict[\"ref_dof_pos\"],\n                    \"dof_vel\": motion_data_dict[\"ref_dof_vel\"],\n                    \"global_translation\": motion_data_dict[\n                        \"ref_global_translation\"\n                    ],\n                    \"global_rotation_quat\": motion_data_dict[\n                        \"ref_global_rotation_quat\"\n                    ],\n                    \"global_velocity\": motion_data_dict[\"ref_global_velocity\"],\n                    \"global_angular_velocity\": motion_data_dict[\"ref_global_angular_velocity\"],\n                    \"n_frames\": motion_data_dict[\"ref_dof_pos\"].shape[0],\n                }\n            )\n            self.motion_file_names.append(motion_clip_file)\n        \n        if not self.all_motion_data:\n            self.get_logger().error(\"Failed to load any motion clip files\")\n            return\n\n        # Initialize with the first motion clip\n        self.current_motion_clip_index = 0\n        self._load_current_motion()\n        \n        self.get_logger().info(f\"Loaded {len(self.all_motion_data)} motion clips successfully\")\n        self.get_logger().info(\n            f\"Current motion clip: {self.motion_file_names[self.current_motion_clip_index]}\"\n        )\n\n    def _load_current_motion(self):\n        \"\"\"Load the current selected motion clip data.\"\"\"\n        if not self.all_motion_data:\n            return\n            \n        self.motion_frame_idx = 0\n        current_motion = self.all_motion_data[self.current_motion_clip_index]\n        self.ref_dof_pos = current_motion[\"dof_pos\"]\n        self.ref_dof_vel = current_motion[\"dof_vel\"]\n        self.ref_raw_bodylink_pos = current_motion[\"global_translation\"]\n        self.ref_raw_bodylink_rot = current_motion[\"global_rotation_quat\"]\n        self.ref_global_velocity = current_motion[\"global_velocity\"]\n        self.ref_global_angular_velocity = current_motion[\"global_angular_velocity\"]\n\n        self.n_motion_frames = current_motion[\"n_frames\"]\n        if self.ref_dof_pos is None or self.ref_dof_vel is None:\n            raise ValueError(\"Motion clip is missing ref_dof_pos/ref_dof_vel arrays\")\n        if self.ref_raw_bodylink_pos is None or self.ref_raw_bodylink_rot is None:\n            raise ValueError(\n                \"Motion clip is missing ref_global_translation/ref_global_rotation_quat arrays\"\n            )\n        if int(self.ref_dof_pos.shape[1]) != int(len(self.dof_names_ref_motion)):\n            raise ValueError(\n                \"ref_dof_pos DOF dimension mismatch: \"\n                f\"ref_dof_pos.shape[1]={int(self.ref_dof_pos.shape[1])} \"\n                f\"but len(dof_names_ref_motion)={int(len(self.dof_names_ref_motion))}\"\n            )\n        if int(self.ref_raw_bodylink_pos.shape[1]) != int(\n            len(self.motion_config.robot.body_names)\n        ):\n            raise ValueError(\n                \"ref_global_translation body dimension mismatch: \"\n                f\"ref_raw_bodylink_pos.shape[1]={int(self.ref_raw_bodylink_pos.shape[1])} \"\n                f\"but len(motion_config.robot.body_names)={int(len(self.motion_config.robot.body_names))}\"\n            )\n\n        self.motion_in_progress = True\n        self.get_logger().info(\n            f\"Loaded motion clip {self.current_motion_clip_index}: \"\n            f\"{self.motion_file_names[self.current_motion_clip_index]} ({self.n_motion_frames} frames)\"\n        )\n\n    def _setup_subscribers(self):\n 
       \"\"\"Set up ROS2 subscribers for robot state and remote controller input.\"\"\"\n        self.remote_controller = RemoteController()\n        self.low_state_sub = self.create_subscription(\n            LowState,\n            self.config_yaml.lowstate_topic,\n            self._low_state_callback,\n            QoSProfile(depth=10),\n        )\n\n        # Add robot_state topic subscription\n        self.robot_state_sub = self.create_subscription(\n            String,\n            \"/robot_state\",\n            self._robot_state_callback,\n            QoSProfile(depth=10),\n        )\n\n        self.latest_obs_ros_sub = self.create_subscription(\n            Float32MultiArray,\n            \"latest_obs_ros\",\n            self._latest_obs_ros_callback,\n            QoSProfile(depth=10),\n        )\n\n    def _latest_obs_ros_callback(self, msg: Float32MultiArray):\n        \"\"\"Receive replayed latest_obs_ros messages for offline validation.\"\"\"\n        data = np.asarray(msg.data, dtype=np.float32)\n        if data.size == 66:\n            frame_idx = int(data[0])\n            obs = data[1:66]\n            self._ros_latest_obs_buffer = (frame_idx, obs)\n        elif data.size >= 65:\n            self._ros_latest_obs_buffer = (None, data[:65])\n\n    def _setup_publishers(self):\n        \"\"\"Set up ROS2 publishers for action commands and status information.\"\"\"\n        self.action_pub = self.create_publisher(\n            Float32MultiArray,\n            self.config_yaml.action_topic,\n            QoSProfile(depth=10),\n        )\n        # Add publishers for kps and kds parameters\n        self.kps_pub = self.create_publisher(\n            Float32MultiArray,\n            \"/humanoid/kps\",\n            QoSProfile(depth=10),\n        )\n        self.kds_pub = self.create_publisher(\n            Float32MultiArray,\n            \"/humanoid/kds\",\n            QoSProfile(depth=10),\n        )\n        # Add publisher for policy mode status\n        self.policy_mode_pub = self.create_publisher(\n            String,\n            \"policy_mode\",\n            QoSProfile(depth=10),\n        )\n        self.latest_obs_pub = self.create_publisher(\n            Float32MultiArray,\n            \"latest_obs\",\n            QoSProfile(depth=10),\n        )\n\n    def _setup_timers(self):\n        \"\"\"Set up ROS2 timer for main execution loop.\"\"\"\n        # Create a one-time timer to call setup after ROS2 initialization\n        self.create_timer(0.1, self._delayed_setup)\n        self.create_timer(self.dt, self.run)\n\n\n    def _delayed_setup(self):\n        \"\"\"Call setup after ROS2 initialization is complete.\"\"\"\n        if not hasattr(self, '_setup_completed'):\n            self.get_logger().info(\"Starting policy node setup...\")\n            try:\n                self.setup()\n                self._setup_completed = True\n                self.get_logger().info(\"Policy node setup completed successfully\")\n            except Exception as e:\n                self.get_logger().error(f\"Policy node setup failed: {e}\")\n                # Cancel the timer to avoid repeated attempts\n                return\n\n\n    def _robot_state_callback(self, msg: String):\n        \"\"\"Handle robot state messages for safety coordination.\n\n        Processes robot state updates from the main control node to ensure\n        safe operation. 
Button operations are only allowed when the robot\n        is in MOVE_TO_DEFAULT state.\n\n        Args:\n            msg: String message containing robot state information\n                Valid states: ZERO_TORQUE, MOVE_TO_DEFAULT, EMERGENCY_STOP, POLICY\n        \"\"\"\n        robot_state = msg.data\n        # Only allow button operations when robot state is MOVE_TO_DEFAULT\n        if robot_state == \"MOVE_TO_DEFAULT\":\n            self.robot_state_ready = True\n        elif robot_state == \"ZERO_TORQUE\":\n            self.robot_state_ready = False\n        elif robot_state == \"EMERGENCY_STOP\":\n            self.robot_state_ready = False\n\n    # =========== Properties ===========\n\n    @property\n    def robot_root_rot_quat_wxyz(self):\n        return np.array(self._lowstate_msg.imu_state.quaternion, dtype=np.float32)\n\n    @property\n    def robot_root_ang_vel(self):\n        return np.array(self._lowstate_msg.imu_state.gyroscope, dtype=np.float32)\n\n    @property\n    def robot_dof_pos_by_name(self):\n        \"\"\"Get DOF positions by name.\"\"\"\n        if self._lowstate_msg is None:\n            return {}\n        return {\n            self.real_dof_names[i]: float(self._lowstate_msg.motor_state[i].q)\n            for i in range(self.actions_dim)\n        }\n\n    @property\n    def robot_dof_vel_by_name(self):\n        \"\"\"Get DOF velocities by name.\"\"\"\n        if self._lowstate_msg is None:\n            return {}\n        return {\n            self.real_dof_names[i]: float(self._lowstate_msg.motor_state[i].dq)\n            for i in range(self.actions_dim)\n        }\n\n    @property\n    def ref_motion_frame_idx(self):\n        return min(self.motion_frame_idx, self.n_motion_frames - 1)\n\n    @property\n    def ref_dof_pos_raw(self):\n        if not self.latest_obs_flag:\n            return self.ref_dof_pos[self.ref_motion_frame_idx]\n        if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:\n            if self._prev_external_dof_pos is not None:\n                return self._prev_external_dof_pos\n            return self.external_fut_dof_pos_queue[0]\n        if self.external_latest_obs is None:\n            return self.ref_dof_pos[self.ref_motion_frame_idx]\n        return self.external_latest_obs[0, :29]\n\n    @property\n    def ref_dof_vel_raw(self):\n        if not self.latest_obs_flag:\n            return self.ref_dof_vel[self.ref_motion_frame_idx]\n        if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:\n            if self._prev_external_dof_vel is not None:\n                return self._prev_external_dof_vel\n            return self.external_fut_dof_vel_queue[0]\n        if self.external_latest_obs is None:\n            return self.ref_dof_vel[self.ref_motion_frame_idx]\n        return self.external_latest_obs[0, 29:58]\n\n    @property\n    def ref_dof_pos_onnx_order(self):\n        return self.ref_dof_pos_raw[self.ref_to_onnx]\n\n    @property\n    def ref_dof_vel_onnx_order(self):\n        return self.ref_dof_vel_raw[self.ref_to_onnx]\n\n    @property\n    def ref_root_pos_raw(self):\n        if not self.latest_obs_flag:\n            return np.asarray(\n                self.ref_raw_bodylink_pos[self.ref_motion_frame_idx, self.root_body_idx],\n                dtype=np.float32,\n            )\n        if self.n_fut_frames > 0 and self.external_fut_root_pos_queue is not None:\n            if self._prev_external_root_pos is not None:\n                return 
self._prev_external_root_pos.astype(np.float32)\n            return self.external_fut_root_pos_queue[0].astype(np.float32)\n        if self.external_latest_obs is None:\n            return np.zeros(3, dtype=np.float32)\n        return self.external_latest_obs[0, 58:61].astype(np.float32)\n\n    @property\n    def root_body_idx(self):\n        return 0\n\n    @property\n    def last_valid_ref_motion_frame_idx(self):\n        return self.n_motion_frames - 1\n\n    # =========== Policy Observation Methods ===========\n    def _xyzw_to_wxyz(self, q_xyzw: np.ndarray) -> np.ndarray:\n        \"\"\"Convert quaternions from xyzw to wxyz order.\"\"\"\n        q_xyzw = np.asarray(q_xyzw, dtype=np.float32)\n        if q_xyzw.shape[-1] != 4:\n            raise ValueError(f\"_xyzw_to_wxyz expects (...,4) but got shape {q_xyzw.shape}\")\n        # q_xyzw[..., 0:3] -> xyz, q_xyzw[..., 3:4] -> w\n        w = q_xyzw[..., 3:4]\n        xyz = q_xyzw[..., 0:3]\n        return np.concatenate([w, xyz], axis=-1)\n\n    def _standardize_quaternion_wxyz(self, q_wxyz: np.ndarray) -> np.ndarray:\n        \"\"\"Standardize quaternion sign so that w >= 0.\"\"\"\n        q_wxyz = np.asarray(q_wxyz, dtype=np.float32)\n        if q_wxyz.shape[-1] != 4:\n            raise ValueError(f\"_standardize_quaternion_wxyz expects (...,4) but got shape {q_wxyz.shape}\")\n        mask = q_wxyz[..., 0:1] < 0.0\n        q_wxyz = np.where(mask, -q_wxyz, q_wxyz)\n        return q_wxyz\n\n    def _quat_rotate_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray:\n        \"\"\"Rotate vector(s) v by quaternion(s) q_wxyz: v + 2w(q x v) + 2 q x (q x v).\"\"\"\n        q_wxyz = np.asarray(q_wxyz, dtype=np.float32)\n        v = np.asarray(v, dtype=np.float32)\n        qvec = q_wxyz[..., 1:4]\n        w = q_wxyz[..., 0:1]\n        t = 2.0 * np.cross(qvec, v)\n        return v + w * t + np.cross(qvec, t)\n\n    def _quat_rotate_inv_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray:\n        \"\"\"Rotate vector(s) v by the conjugate (inverse) of unit quaternion(s) q_wxyz.\"\"\"\n        q_wxyz = np.asarray(q_wxyz, dtype=np.float32)\n        n = int(np.prod(q_wxyz.shape[:-1])) if q_wxyz.ndim > 1 else 1\n        q_conj = self._q_conj_buffer[:n].reshape(q_wxyz.shape)\n        q_conj[..., 0] = q_wxyz[..., 0]\n        q_conj[..., 1:4] = -q_wxyz[..., 1:4]\n        return self._quat_rotate_wxyz(q_conj, v)\n\n    def _quat_rotate_inv_wxyz_single(\n        self, q_wxyz: np.ndarray, v: np.ndarray, out: np.ndarray\n    ) -> np.ndarray:\n        \"\"\"Rotate one 3D vector by the inverse quaternion into a preallocated output.\"\"\"\n        q_conj = self._q_conj_buffer[0]\n        q_conj[0] = q_wxyz[0]\n        q_conj[1] = -q_wxyz[1]\n        q_conj[2] = -q_wxyz[2]\n        q_conj[3] = -q_wxyz[3]\n        qvec = q_conj[1:4]\n        w = q_conj[0]\n        self._cross_t_buffer[:] = np.cross(qvec, v)\n        self._cross_t_buffer *= 2.0\n        out[:] = v + w * self._cross_t_buffer\n        self._cross_t_buffer[:] = np.cross(qvec, self._cross_t_buffer)\n        out += self._cross_t_buffer\n        return out\n\n    def _get_future_frame_indices(self) -> np.ndarray:\n        \"\"\"Clamp current frame + future offsets to the last valid reference frame.\"\"\"\n        frame_idx = self.ref_motion_frame_idx\n        last_valid = self.last_valid_ref_motion_frame_idx\n        np.minimum(\n            frame_idx + self._future_frame_offsets,\n            last_valid,\n            out=self._future_frame_indices_buffer,\n        )\n        return self._future_frame_indices_buffer\n\n    def _cache_fk_vr_for_obs(self):\n        \"\"\"Cache FK outputs used repeatedly during observation construction.\"\"\"\n        fk = getattr(self, \"_fk_vr_out\", None)\n        if not getattr(self, \"latest_obs_flag\", False) or fk is None:\n        
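    # FK cache is unavailable (teleop inactive or FK not yet run); fall back to raw reference data.\n        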
    self._use_fk_vr = False\n            return\n        self._use_fk_vr = True\n        T = self.n_fut_frames_int\n        rb = self.root_body_idx\n        np.copyto(self._fk_vel_0_root, fk[\"global_velocity\"][0, 0, rb])\n        np.copyto(self._fk_angvel_0_root, fk[\"global_angular_velocity\"][0, 0, rb])\n        np.copyto(self._fk_quat_0_root, fk[\"global_rotation_quat\"][0, 0, rb])\n        self._fk_quat_0_root_wxyz[0] = self._fk_quat_0_root[3]\n        self._fk_quat_0_root_wxyz[1:4] = self._fk_quat_0_root[:3]\n        if self._fk_quat_0_root_wxyz[0] < 0.0:\n            self._fk_quat_0_root_wxyz *= -1.0\n        trans_0 = fk[\"global_translation\"][0, 0]\n        if self._fk_trans_0 is None or self._fk_trans_0.shape != trans_0.shape:\n            self._fk_trans_0 = np.empty_like(trans_0)\n        np.copyto(self._fk_trans_0, trans_0)\n        if T > 0:\n            np.copyto(self._fk_vel_fut[:T], fk[\"global_velocity\"][0, 1 : 1 + T, rb])\n            np.copyto(self._fk_angvel_fut[:T], fk[\"global_angular_velocity\"][0, 1 : 1 + T, rb])\n            np.copyto(self._fk_quat_fut[:T], fk[\"global_rotation_quat\"][0, 1 : 1 + T, rb])\n            self._fk_quat_fut_wxyz[:T, 0] = self._fk_quat_fut[:T, 3]\n            self._fk_quat_fut_wxyz[:T, 1:4] = self._fk_quat_fut[:T, :3]\n            neg = self._fk_quat_fut_wxyz[:T, 0] < 0.0\n            self._fk_quat_fut_wxyz[:T][neg] *= -1.0\n            trans_fut = fk[\"global_translation\"][0, 1 : 1 + T]\n            if self._fk_trans_fut is None or self._fk_trans_fut.shape != trans_fut.shape:\n                self._fk_trans_fut = np.empty_like(trans_fut)\n            np.copyto(self._fk_trans_fut, trans_fut)\n            self._fill_vr_base_linvel_angvel_fut()\n\n    def _fill_vr_base_linvel_angvel_fut(self):\n        \"\"\"Rotate future linear and angular velocity buffers in one pass.\"\"\"\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return\n        vel_T6 = self._vel_fut_T6[:T]\n        vel_T6[:, :3] = self._fk_vel_fut[:T]\n        vel_T6[:, 3:6] = self._fk_angvel_fut[:T]\n        q = self._fk_quat_fut_wxyz[:T]\n        q_conj = self._q_conj_buffer[:T].reshape(T, 4)\n        q_conj[:, 0] = q[:, 0]\n        q_conj[:, 1:4] = -q[:, 1:4]\n        qvec = q_conj[:, 1:4]\n        w = q_conj[:, 0:1]\n        rt = self._rot_t_buffer[:T]\n        rc = self._rot_cross_buffer[:T]\n        rt[:] = np.cross(qvec, vel_T6[:, :3])\n        rt *= 2.0\n        rc[:] = np.cross(qvec, rt)\n        self._base_linvel_fut_buffer[:T] = vel_T6[:, :3] + w * rt + rc\n        rt[:] = np.cross(qvec, vel_T6[:, 3:6])\n        rt *= 2.0\n        rc[:] = np.cross(qvec, rt)\n        self._base_angvel_fut_buffer[:T] = vel_T6[:, 3:6] + w * rt + rc\n\n    def _prepare_vr_fk_tensors(\n        self,\n        cur_root_pos: np.ndarray,\n        cur_root_rot: np.ndarray,\n        cur_dof_pos: np.ndarray,\n        n_fut: int,\n    ):\n        \"\"\"Fill preallocated FK input buffers and return torch views without reallocation.\"\"\"\n        if (\n            n_fut <= 0\n            or self._fk_root_pos_seq_np is None\n            or self._fk_root_rot_seq_np is None\n            or self._fk_dof_pos_seq_np is None\n        ):\n            raise ValueError(\"VR FK sequence buffers are not initialized\")\n\n        np.copyto(self._fk_root_pos_seq_np[0, 0], cur_root_pos)\n        np.copyto(self._fk_root_rot_seq_np[0, 0], cur_root_rot)\n        np.copyto(self._fk_dof_pos_seq_np[0, 0], cur_dof_pos)\n        np.copyto(\n            self._fk_root_pos_seq_np[0, 1 : 1 + n_fut],\n 
           self.external_fut_root_pos_queue[:n_fut],\n        )\n        np.copyto(\n            self._fk_root_rot_seq_np[0, 1 : 1 + n_fut],\n            self.external_fut_root_rot_queue[:n_fut],\n        )\n        np.copyto(\n            self._fk_dof_pos_seq_np[0, 1 : 1 + n_fut],\n            self.external_fut_dof_pos_queue[:n_fut],\n        )\n        return (\n            self._fk_root_pos_seq_tensor,\n            self._fk_root_rot_seq_tensor,\n            self._fk_dof_pos_seq_tensor,\n        )\n\n    def _get_future_root_quat_wxyz(self) -> np.ndarray:\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable; future_root_quat_wxyz will return zeros.\"\n            )\n            return self._future_root_quat_wxyz_buffer\n\n        fut_idx = self._get_future_frame_indices()\n        q_root_xyzw = np.asarray(\n            self.ref_raw_bodylink_rot[fut_idx, self.root_body_idx],\n            dtype=np.float32,\n        )\n        q_root_wxyz = self._future_root_quat_wxyz_buffer\n        q_root_wxyz[:, 0] = q_root_xyzw[:, 3]\n        q_root_wxyz[:, 1] = q_root_xyzw[:, 0]\n        q_root_wxyz[:, 2] = q_root_xyzw[:, 1]\n        q_root_wxyz[:, 3] = q_root_xyzw[:, 2]\n        neg_mask = q_root_wxyz[:, 0] < 0.0\n        q_root_wxyz[neg_mask] *= -1.0\n        return self._future_root_quat_wxyz_buffer\n\n    def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray:\n        keybody_idxs = self._keybody_indices_by_term_name.get(term_name, None)\n        if keybody_idxs is None:\n            raise ValueError(\n                f\"Keybody indices for term '{term_name}' were not cached. \"\n                \"Ensure the term exists in motion policy obs and cache is initialized.\"\n            )\n        return keybody_idxs\n\n    def _get_obs_actor_velocity_command(self):\n        return self._get_obs_velocity_command()\n\n    def _get_obs_actor_projected_gravity(self):\n        return self._get_obs_projected_gravity()\n\n    def _get_obs_actor_rel_robot_root_ang_vel(self):\n        return self._get_obs_rel_robot_root_ang_vel()\n\n    def _get_obs_actor_dof_pos(self):\n        return self._get_obs_dof_pos()\n\n    def _get_obs_actor_dof_vel(self):\n        return self._get_obs_dof_vel()\n        \n    def _get_obs_actor_last_action(self):\n        return self._get_obs_last_action()\n\n    def _get_obs_actor_ref_gravity_projection_cur(self):\n        return self._get_obs_ref_gravity_projection_cur()\n\n    def _get_obs_actor_ref_gravity_projection_fut(self):\n        return self._get_obs_ref_gravity_projection_fut()\n\n    def _get_obs_actor_ref_base_linvel_cur(self):\n        return self._get_obs_ref_base_linvel_cur()\n\n    def _get_obs_actor_ref_base_linvel_fut(self):\n        return self._get_obs_ref_base_linvel_fut()\n\n    def _get_obs_actor_ref_base_angvel_cur(self):\n        return self._get_obs_ref_base_angvel_cur()\n\n    def _get_obs_actor_ref_base_angvel_fut(self):\n        return self._get_obs_ref_base_angvel_fut()\n\n    def _get_obs_actor_ref_dof_pos_cur(self):\n        return self._get_obs_ref_dof_pos_cur()\n\n    def _get_obs_actor_ref_dof_pos_fut(self):\n        return self._get_obs_ref_dof_pos_fut()\n\n    def _get_obs_actor_ref_root_height_cur(self):\n        return self._get_obs_ref_root_height_cur()\n\n    def _get_obs_actor_ref_root_height_fut(self):\n        return self._get_obs_ref_root_height_fut()\n\n    def 
_get_obs_actor_ref_keybody_rel_pos_cur(self):\n        return self._get_obs_ref_keybody_rel_pos_cur()\n\n    def _get_obs_actor_ref_keybody_rel_pos_fut(self):\n        return self._get_obs_ref_keybody_rel_pos_fut()\n\n    def _get_obs_velocity_command(self):\n        \"\"\"Get velocity command observation (reuses pre-allocated array).\"\"\"\n        self._velocity_cmd_obs[1] = self.vx\n        self._velocity_cmd_obs[2] = self.vy\n        self._velocity_cmd_obs[3] = self.vyaw\n        self._velocity_cmd_obs[0] = float(\n            np.linalg.norm(self._velocity_cmd_obs[1:4]) > 0.1\n        )\n        return self._velocity_cmd_obs\n\n    def _get_obs_projected_gravity(self):\n        return get_gravity_orientation(self.robot_root_rot_quat_wxyz)\n\n    def _get_obs_rel_robot_root_ang_vel(self):\n        return self.robot_root_ang_vel\n\n    def _get_obs_dof_pos(self):\n        \"\"\"Get DOF position observation (pre-allocated buffer + index lookup, no dict/list).\"\"\"\n        if self._lowstate_msg is None:\n            return self._dof_pos_obs_buffer[: len(self.motion_dof_names_onnx)]\n        if self.current_policy_mode == \"motion\":\n            buf = self._dof_pos_obs_buffer\n            ms = self._lowstate_msg.motor_state\n            def_angles = self.motion_default_angles_onnx\n            for i, ri in enumerate(self.motion_dof_real_indices):\n                buf[i] = ms[ri].q - def_angles[i]\n            return buf[: len(self.motion_dof_names_onnx)]\n        def_angles = self.velocity_default_angles_onnx\n        for i, ri in enumerate(self.velocity_dof_real_indices):\n            self._dof_pos_obs_buffer[i] = (\n                self._lowstate_msg.motor_state[ri].q - def_angles[i]\n            )\n        return self._dof_pos_obs_buffer[: len(self.velocity_dof_names_onnx)]\n\n    def _get_obs_dof_vel(self):\n        \"\"\"Get DOF velocity observation (pre-allocated buffer + index lookup, no dict/list).\"\"\"\n        if self._lowstate_msg is None:\n            return self._dof_vel_obs_buffer[: len(self.motion_dof_names_onnx)]\n        if self.current_policy_mode == \"motion\":\n            buf = self._dof_vel_obs_buffer\n            ms = self._lowstate_msg.motor_state\n            for i, ri in enumerate(self.motion_dof_real_indices):\n                buf[i] = ms[ri].dq\n            return buf[: len(self.motion_dof_names_onnx)]\n        for i, ri in enumerate(self.velocity_dof_real_indices):\n            self._dof_vel_obs_buffer[i] = self._lowstate_msg.motor_state[ri].dq\n        return self._dof_vel_obs_buffer[: len(self.velocity_dof_names_onnx)]\n\n    def _get_obs_last_action(self):\n        return self.actions_onnx.copy()\n\n    def _get_obs_ref_motion_states(self):\n        return np.concatenate(\n            [self.ref_dof_pos_onnx_order, self.ref_dof_vel_onnx_order]\n        )\n\n    def _get_obs_ref_dof_pos_fut(self):\n        \"\"\"Get future DOF position observation (reuses pre-allocated buffer).\"\"\"\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"latest_obs_flag\", False):\n            if (\n                getattr(self, \"external_fut_dof_pos_queue\", None) is not None\n                and self.external_fut_dof_pos_queue.shape[0] >= T\n            ):\n                pos_fut = self._pos_fut_buffer\n                pos_fut[:, :] = self.external_fut_dof_pos_queue[:T].T\n                pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0)  # [T, N]\n                return 
\n    def _get_obs_ref_dof_pos_fut(self):\n        \"\"\"Get future DOF position observation (reuses pre-allocated buffer).\"\"\"\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"latest_obs_flag\", False):\n            if (\n                getattr(self, \"external_fut_dof_pos_queue\", None) is not None\n                and self.external_fut_dof_pos_queue.shape[0] >= T\n            ):\n                pos_fut = self._pos_fut_buffer\n                pos_fut[:, :] = self.external_fut_dof_pos_queue[:T].T\n                pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0)  # [T, N]\n                return pos_fut_onnx.reshape(-1).astype(np.float32)\n            return np.zeros(self.num_actions * T, dtype=np.float32)\n        if not hasattr(self, \"ref_dof_pos\") or self.ref_dof_pos is None:\n            self.get_logger().warn(\n                \"[VR] ref_dof_pos is unavailable and latest_obs is not active; returning zeros for ref_dof_pos_fut.\"\n            )\n            return np.zeros(self.num_actions * T, dtype=np.float32)\n        fut_idx = self._get_future_frame_indices()\n        pos_fut = self._pos_fut_buffer\n        pos_fut[:, :] = self.ref_dof_pos[fut_idx].T\n        # Reorder to ONNX and flatten per training layout\n        pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0)  # [T, N]\n        return pos_fut_onnx.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_root_height_fut(self):\n        \"\"\"Get future root height observation (reuses pre-allocated buffer).\"\"\"\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if self.latest_obs_flag and self.external_fut_root_pos_queue is not None:\n            root_pos_fut = self.external_fut_root_pos_queue[:T, 2].astype(np.float32)\n            return root_pos_fut.reshape(-1)\n        if not hasattr(self, \"ref_raw_bodylink_pos\") or self.ref_raw_bodylink_pos is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_height_fut.\"\n            )\n            return np.zeros(T, dtype=np.float32)\n        fut_idx = self._get_future_frame_indices()\n        h_fut = self._h_fut_buffer\n        h_fut[0, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, 2]\n        return h_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_root_pos_fut(self):\n        \"\"\"Get future root position observation (reuses pre-allocated buffer).\"\"\"\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if self.latest_obs_flag and self.external_fut_root_pos_queue is not None:\n            pos_fut = self.external_fut_root_pos_queue[:T].astype(np.float32)\n            return pos_fut.reshape(-1)\n        if not hasattr(self, \"ref_raw_bodylink_pos\") or self.ref_raw_bodylink_pos is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_pos_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        fut_idx = self._get_future_frame_indices()\n        pos_fut = self._root_pos_fut_buffer\n        pos_fut[:, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, :]\n        return pos_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_dof_pos_cur(self):\n        return self.ref_dof_pos_onnx_order\n\n    def _get_obs_ref_dof_vel_cur(self):\n        return self.ref_dof_vel_onnx_order\n\n    def _get_obs_ref_root_height_cur(self):\n        if not self.latest_obs_flag:\n            return self.ref_raw_bodylink_pos[\n                self.ref_motion_frame_idx, self.root_body_idx, 2\n            ]\n        return float(self.ref_root_pos_raw[2])\n\n    def _get_obs_ref_root_pos_cur(self):\n        return self.ref_root_pos_raw.astype(np.float32)\n\n    def _get_obs_ref_gravity_projection_cur(self):\n        if getattr(self, \"_use_fk_vr\", False):\n            return get_gravity_orientation(self._fk_quat_0_root_wxyz)\n        if getattr(self, \"latest_obs_flag\", 
False) and getattr(\n            self, \"external_latest_obs\", None\n        ) is not None:\n            q_root_wxyz = self.external_latest_obs[0, 61:65].astype(np.float32)\n            q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)\n            return get_gravity_orientation(q_root_wxyz)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for gravity_projection_cur.\"\n            )\n            return np.zeros(3, dtype=np.float32)\n        q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx]\n        q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)\n        q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)\n        return get_gravity_orientation(q_root_wxyz)\n\n    def _get_obs_ref_gravity_projection_fut(self):\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"_use_fk_vr\", False):\n            q_root_wxyz = self._fk_quat_fut_wxyz[:T]\n            gravity_fut = self._gravity_fut_buffer\n            qw = q_root_wxyz[:, 0]\n            qx = q_root_wxyz[:, 1]\n            qy = q_root_wxyz[:, 2]\n            qz = q_root_wxyz[:, 3]\n            gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy)\n            gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx)\n            gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n            return gravity_fut.reshape(-1).astype(np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_gravity_projection_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        q_root_wxyz = self._get_future_root_quat_wxyz()\n        gravity_fut = self._gravity_fut_buffer\n        qw = q_root_wxyz[:, 0]\n        qx = q_root_wxyz[:, 1]\n        qy = q_root_wxyz[:, 2]\n        qz = q_root_wxyz[:, 3]\n        gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy)\n        gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx)\n        gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n        return gravity_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_base_linvel_cur(self):\n        if getattr(self, \"_use_fk_vr\", False):\n            self._quat_rotate_inv_wxyz_single(\n                self._fk_quat_0_root_wxyz, self._fk_vel_0_root, self._rotated_3vec_buffer\n            )\n            return self._rotated_3vec_buffer\n        if getattr(self, \"latest_obs_flag\", False) and getattr(\n            self, \"external_latest_obs\", None\n        ) is not None:\n            return np.zeros(3, dtype=np.float32)\n        if not hasattr(self, \"ref_global_velocity\") or self.ref_global_velocity is None:\n            self.get_logger().warn(\n                \"[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur.\"\n            )\n            return np.zeros(3, dtype=np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur.\"\n            )\n            return np.zeros(3, 
dtype=np.float32)\n        q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx]\n        q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)\n        q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)\n        v_root_w = np.asarray(\n            self.ref_global_velocity[self.ref_motion_frame_idx, self.root_body_idx],\n            dtype=np.float32,\n        )\n        v_root = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w)\n        return np.asarray(v_root, dtype=np.float32).reshape(3)\n\n    def _get_obs_ref_base_linvel_fut(self):\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"_use_fk_vr\", False):\n            return self._base_linvel_fut_buffer[:T].reshape(-1).astype(np.float32)\n\n        if not hasattr(self, \"ref_global_velocity\") or self.ref_global_velocity is None:\n            self.get_logger().warn(\n                \"[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        fut_idx = self._get_future_frame_indices()\n        q_root_wxyz = self._get_future_root_quat_wxyz()\n        v_root_w = np.asarray(\n            self.ref_global_velocity[fut_idx, self.root_body_idx],\n            dtype=np.float32,\n        )\n        base_linvel_fut = self._base_linvel_fut_buffer\n        base_linvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w)\n        return base_linvel_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_base_angvel_cur(self):\n        if getattr(self, \"_use_fk_vr\", False):\n            self._quat_rotate_inv_wxyz_single(\n                self._fk_quat_0_root_wxyz,\n                self._fk_angvel_0_root,\n                self._rotated_angvel_cur_buffer,\n            )\n            return self._rotated_angvel_cur_buffer\n        if getattr(self, \"latest_obs_flag\", False) and getattr(\n            self, \"external_latest_obs\", None\n        ) is not None:\n            return np.zeros(3, dtype=np.float32)\n        if not hasattr(self, \"ref_global_angular_velocity\") or self.ref_global_angular_velocity is None:\n            self.get_logger().warn(\n                \"[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur.\"\n            )\n            return np.zeros(3, dtype=np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur.\"\n            )\n            return np.zeros(3, dtype=np.float32)\n        q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx]\n        q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)\n        q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)\n        w_root_w = np.asarray(\n            self.ref_global_angular_velocity[self.ref_motion_frame_idx, self.root_body_idx],\n            dtype=np.float32,\n        )\n        
w_root = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w)\n        return np.asarray(w_root, dtype=np.float32).reshape(3)\n\n    def _get_obs_ref_base_angvel_fut(self):\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"_use_fk_vr\", False):\n            return self._base_angvel_fut_buffer[:T].reshape(-1).astype(np.float32)\n\n        if not hasattr(self, \"ref_global_angular_velocity\") or self.ref_global_angular_velocity is None:\n            self.get_logger().warn(\n                \"[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut.\"\n            )\n            return np.zeros(3 * T, dtype=np.float32)\n        fut_idx = self._get_future_frame_indices()\n        q_root_wxyz = self._get_future_root_quat_wxyz()\n        w_root_w = np.asarray(\n            self.ref_global_angular_velocity[fut_idx, self.root_body_idx],\n            dtype=np.float32,\n        )\n        base_angvel_fut = self._base_angvel_fut_buffer\n        base_angvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w)\n        return base_angvel_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_keybody_rel_pos_cur(self):\n        if getattr(self, \"_use_fk_vr\", False) and self._fk_trans_0 is not None:\n            keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_cur\")\n            n_keybodies = int(keybody_idxs.shape[0])\n            if n_keybodies == 0:\n                return np.zeros(0, dtype=np.float32)\n            if not self._root_only_fk_has_required_keybodies(keybody_idxs):\n                return np.zeros(3 * n_keybodies, dtype=np.float32)\n            root_pos = self._fk_trans_0[self.root_body_idx]\n            keybody_pos = self._fk_trans_0[keybody_idxs]\n            rel_pos_w = keybody_pos - root_pos[None, :]\n            rel_pos_root = self._quat_rotate_inv_wxyz(self._fk_quat_0_root_wxyz, rel_pos_w)\n            return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1)\n\n        if getattr(self, \"latest_obs_flag\", False) and getattr(\n            self, \"external_latest_obs\", None\n        ) is not None:\n            keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_cur\")\n            n_keybodies = int(keybody_idxs.shape[0])\n            if n_keybodies == 0:\n                return np.zeros(0, dtype=np.float32)\n            return np.zeros(3 * n_keybodies, dtype=np.float32)\n\n        if not hasattr(self, \"ref_raw_bodylink_pos\") or self.ref_raw_bodylink_pos is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur.\"\n            )\n            keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_cur\")\n            n_keybodies = int(keybody_idxs.shape[0])\n            if n_keybodies == 0:\n                return np.zeros(0, dtype=np.float32)\n            return np.zeros(3 * n_keybodies, dtype=np.float32)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            
self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur.\"\n            )\n            keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_cur\")\n            n_keybodies = int(keybody_idxs.shape[0])\n            if n_keybodies == 0:\n                return np.zeros(0, dtype=np.float32)\n            return np.zeros(3 * n_keybodies, dtype=np.float32)\n\n        keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_cur\")\n        n_keybodies = int(keybody_idxs.shape[0])\n        if n_keybodies == 0:\n            return np.zeros(0, dtype=np.float32)\n\n        frame_idx = self.ref_motion_frame_idx\n        ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[frame_idx], dtype=np.float32)\n        ref_root_global_pos = ref_body_global_pos[self.root_body_idx]\n        q_root_xyzw = self.ref_raw_bodylink_rot[frame_idx, self.root_body_idx]\n        q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)\n        q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)\n\n        rel_pos_w = ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :]\n        rel_pos_root = self._quat_rotate_inv_wxyz(q_root_wxyz, rel_pos_w)\n        return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1)\n\n    def _get_obs_ref_keybody_rel_pos_fut(self):\n        T = self.n_fut_frames_int\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n        if getattr(self, \"_use_fk_vr\", False) and self._fk_trans_fut is not None:\n            keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_fut\")\n            n_keybodies = int(keybody_idxs.shape[0])\n            if n_keybodies == 0:\n                return np.zeros((T, 0), dtype=np.float32).reshape(-1)\n            if not self._root_only_fk_has_required_keybodies(keybody_idxs):\n                return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)\n            ref_body = self._fk_trans_fut[:T]  # (T, num_bodies, 3)\n            ref_root = ref_body[:, self.root_body_idx, :]  # (T, 3)\n            if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies:\n                self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)\n                self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)\n            elif (\n                self._keybody_rel_pos_w_buffer is None\n                or self._keybody_rel_pos_w_buffer.shape[0] < T\n                or self._keybody_rel_pos_w_buffer.shape[1] != n_keybodies\n            ):\n                self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)\n            rel_pos_fut = self._keybody_rel_pos_fut_buffer\n            np.subtract(\n                ref_body[:, keybody_idxs, :],\n                ref_root[:, None, :],\n                out=self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :],\n            )\n            rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz(\n                self._fk_quat_fut_wxyz[:T, None, :],\n                self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :],\n            )\n            return rel_pos_fut.reshape(-1).astype(np.float32)\n        keybody_idxs = self._get_ref_keybody_indices(\"actor_ref_keybody_rel_pos_fut\")\n        n_keybodies = int(keybody_idxs.shape[0])\n        if not hasattr(self, \"ref_raw_bodylink_pos\") or self.ref_raw_bodylink_pos is None:\n            self.get_logger().warn(\n           
     \"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut.\"\n            )\n            if n_keybodies == 0:\n                return np.zeros((T, 0), dtype=np.float32).reshape(-1)\n            return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)\n        if not hasattr(self, \"ref_raw_bodylink_rot\") or self.ref_raw_bodylink_rot is None:\n            self.get_logger().warn(\n                \"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut.\"\n            )\n            if n_keybodies == 0:\n                return np.zeros((T, 0), dtype=np.float32).reshape(-1)\n            return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)\n\n        if n_keybodies == 0:\n            return np.zeros((T, 0), dtype=np.float32).reshape(-1)\n        fut_idx = self._get_future_frame_indices()\n        q_root_wxyz = self._get_future_root_quat_wxyz()\n        ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[fut_idx], dtype=np.float32)\n        ref_root_global_pos = ref_body_global_pos[:, self.root_body_idx, :]\n        rel_pos_w = ref_body_global_pos[:, keybody_idxs, :] - ref_root_global_pos[:, None, :]\n        if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies:\n            self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)\n        rel_pos_fut = self._keybody_rel_pos_fut_buffer\n        rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz(q_root_wxyz[:, None, :], rel_pos_w)\n        return rel_pos_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_place_holder(self):\n        return np.zeros(self.actor_place_holder_ndim, dtype=np.float32)\n\n    # =========== Policy Obeservation Methods ===========\n\n    def _warmup_fk_for_vr(self):\n        \"\"\"Run one FK warmup step when entering VR motion mode.\"\"\"\n        try:\n            if (\n                getattr(self, \"fk\", None) is None\n                or not getattr(self, \"fk_initialized\", False)\n            ):\n                return\n            if getattr(self, \"external_latest_obs\", None) is None:\n                return\n            if getattr(self, \"external_fut_dof_pos_queue\", None) is None:\n                return\n\n            n_fut = int(getattr(self, \"n_fut_frames\", 0))\n            if (\n                n_fut <= 0\n                or self.external_fut_root_pos_queue is None\n                or self.external_fut_root_rot_queue is None\n            ):\n                return\n\n            latest = self.external_latest_obs[0]\n            cur_root_pos = latest[58:61]\n            cur_root_rot = latest[61:65]\n            cur_dof_pos = latest[0:29]\n            root_pos_tensor, root_rot_tensor, dof_pos_tensor = (\n                self._prepare_vr_fk_tensors(\n                    cur_root_pos=cur_root_pos,\n                    cur_root_rot=cur_root_rot,\n                    cur_dof_pos=cur_dof_pos,\n                    n_fut=n_fut,\n                )\n            )\n\n            fk_out = self.fk(\n                root_pos=root_pos_tensor,\n                root_quat=root_rot_tensor,\n                dof_pos=dof_pos_tensor,\n                fps=float(1.0 / self.dt),\n                quat_format=\"wxyz\",\n                vel_smoothing_sigma=0.0,\n                compute_velocity=False,\n            )\n            self._fk_vr_out = {\n                k: v.detach().cpu().numpy() for k, v in fk_out.items()\n            }\n        
except Exception as e:\n            self.get_logger().warn(f\"[VR] FK warmup failed, fallback to zeros: {e}\")\n\n    def _low_state_callback(self, ls_msg: LowState):\n        \"\"\"Process low-level robot state and remote controller input.\n\n        Main callback that handles:\n        - Remote controller input processing\n        - Motion selection based on button presses\n        - Safety state checking\n        - Velocity command extraction\n\n        Motion Button Mapping:\n        - A button: Enable policy (defaults to velocity mode)\n        - B button: Switch from velocity to motion mode\n        - Y button: Switch from motion back to velocity mode\n        - UP/DOWN/LEFT/RIGHT: Motion clip selection (only in velocity tracking mode)\n\n        Args:\n            ls_msg: LowState message containing robot sensor data and remote controller input\n        \"\"\"\n        self._lowstate_msg = ls_msg\n        self.remote_controller.set(ls_msg.wireless_remote)\n\n        # A button: Toggle policy enable state (default to velocity mode)\n        if (\n            self._is_button_pressed(KeyMap.A) and self.robot_state_ready\n        ):\n            self.policy_enabled = True\n            self.current_policy_mode = \"velocity\"  # Default to velocity mode\n            self.latest_obs_flag = False\n            self._reset_motion_action_ema_filter()\n            self._reset_counter()\n            if hasattr(self, \"use_kv_cache\") and self.use_kv_cache:\n                self.motion_kv_cache.fill(0)\n            self.motion_step_idx = 0\n            # Initialize with velocity model's default angles\n            self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n            self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy()\n            # Publish velocity model's control parameters (kps/kds)\n            self._publish_control_params()\n            self.get_logger().info(\n                f\"Policy enabled in {self.current_policy_mode} tracking mode\"\n            )\n\n        # B button: Switch to motion tracking mode (only when policy is enabled)\n        if (\n            self._is_button_pressed(KeyMap.B)\n            and self.robot_state_ready\n            and self.policy_enabled\n            and self.current_policy_mode == \"velocity\"  # Only allow switch from velocity mode\n        ):\n            vr_data_available = bool(\n                getattr(self, \"enable_teleop_reference\", True)\n                and getattr(self, \"external_obs_received\", False)\n                and getattr(self, \"external_latest_obs\", None) is not None\n            )\n            vr_ready = self._is_vr_ready_for_motion()\n            if (\n                self.enable_teleop_reference\n                and self.require_vr_data_for_motion\n                and not vr_ready\n            ):\n                self.get_logger().warn(\n                    \"require_vr_data_for_motion=True but the VR queue is not ready yet; staying in velocity mode.\"\n                )\n            else:\n                # Don't automatically switch to next motion clip - keep current selection\n                if hasattr(self, \"all_motion_data\") and self.all_motion_data:\n                    # Load the current motion clip data (don't change current_motion_clip_index)\n                    self._load_current_motion()\n\n                self.current_policy_mode = \"motion\"\n                self._reset_motion_action_ema_filter()\n                self._reset_counter()\n                if hasattr(self, 
\"use_kv_cache\") and self.use_kv_cache:\n                    self.motion_kv_cache.fill(0)\n                    self.get_logger().info(\"Motion KV-Cache reset.\")\n\n                self.motion_step_idx = 0\n                self.get_logger().info(\"Motion Step Index reset to 0.\")\n\n                # Clear any pending actions to prevent conflicts between policies\n                # Use motion model's default angles\n                self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n                self.target_dof_pos_onnx = self.motion_default_angles_onnx.copy()\n\n                # Publish motion model's control parameters (kps/kds)\n                self._publish_control_params()\n\n                self.latest_obs_flag = bool(vr_data_available)\n                source_mode = \"ZMQ latest_obs\" if self.latest_obs_flag else \"offline motion\"\n\n                self.get_logger().info(\n                    f\"Switched to motion tracking mode ({source_mode}) - motion clip index: {self.current_motion_clip_index}\"\n                )\n                if self.latest_obs_flag:\n                    self.get_logger().info(\"[VR] Reference trajectory source: ZMQ latest_obs\")\n                    self._warmup_fk_for_vr()\n                self.motion_in_progress = True\n\n        if (\n            self._is_button_pressed(KeyMap.Y)\n            and self.robot_state_ready\n            and self.policy_enabled\n            and self.current_policy_mode == \"motion\"  # Only allow switch from motion mode\n        ):\n            self._switch_to_velocity_mode()\n\n        # Get velocity commands only in velocity tracking mode\n        if self.current_policy_mode == \"velocity\":\n            self.vx, self.vy, self.vyaw = self.remote_controller.get_velocity_commands()\n        else:\n            # In motion tracking mode, ignore joystick input\n            self.vx, self.vy, self.vyaw = 0.0, 0.0, 0.0\n\n        # Handle motion clip selection in velocity tracking mode (UP/DOWN/LEFT/RIGHT)\n        if (\n            self.current_policy_mode == \"velocity\"\n            and self.policy_enabled\n            and self.robot_state_ready\n        ):\n            # Handle motion clip selection with UP/DOWN/LEFT/RIGHT buttons\n            if self._is_button_pressed(KeyMap.up):\n                # Switch to previous motion clip\n                if hasattr(self, \"all_motion_data\") and self.all_motion_data:\n                    self.current_motion_clip_index = (\n                        self.current_motion_clip_index - 1\n                    ) % len(self.all_motion_data)\n                    self.get_logger().info(\n                        f\"Selected previous motion clip: \"\n                        f\"{self.motion_file_names[self.current_motion_clip_index]}\"\n                    )\n            elif self._is_button_pressed(KeyMap.down):\n                # Switch to next motion clip\n                if hasattr(self, \"all_motion_data\") and self.all_motion_data:\n                    self.current_motion_clip_index = (\n                        self.current_motion_clip_index + 1\n                    ) % len(self.all_motion_data)\n                    self.get_logger().info(\n                        f\"Selected next motion clip: \"\n                        f\"{self.motion_file_names[self.current_motion_clip_index]}\"\n                    )\n            elif self._is_button_pressed(KeyMap.left):\n                # Select first motion clip\n                if hasattr(self, \"all_motion_data\") and 
self.all_motion_data:\n                    self.current_motion_clip_index = 0\n                    self.get_logger().info(\n                        f\"Selected first motion clip: \"\n                        f\"{self.motion_file_names[self.current_motion_clip_index]}\"\n                    )\n            elif self._is_button_pressed(KeyMap.right):\n                # Select last motion clip\n                if hasattr(self, \"all_motion_data\") and self.all_motion_data:\n                    self.current_motion_clip_index = len(self.all_motion_data) - 1\n                    self.get_logger().info(\n                        f\"Selected last motion clip: \"\n                        f\"{self.motion_file_names[self.current_motion_clip_index]}\"\n                    )\n\n    def run(self):\n        \"\"\"Main execution loop for policy inference and action publication.\"\"\"\n        # Only run if setup is completed\n        if not hasattr(self, \"_setup_completed\") or not self._setup_completed:\n            return\n        t_loop_start = time.perf_counter()\n        now = time.time()\n        t_io = time.perf_counter()\n        buf = getattr(self, \"_ros_latest_obs_buffer\", None)\n        if buf is not None:\n            self._ros_latest_obs_buffer = None\n            frame_idx, obs_arr = buf\n            if frame_idx is not None:\n                self._npz_replay_frame_index = frame_idx\n            self._store_external_latest_obs(obs_arr[None, :])\n        self._poll_zmq_latest_obs()\n        if getattr(self, \"current_policy_mode\", None) == \"motion\":\n            if self._last_vr_status_log_time is None:\n                self._last_vr_status_log_time = now\n            elif now - self._last_vr_status_log_time >= 5.0:\n                vr_available = bool(\n                    getattr(self, \"external_obs_received\", False)\n                    and getattr(self, \"external_latest_obs\", None) is not None\n                )\n                queue_stats = self._latest_obs_buffer.get_queue_stats()\n                freq = queue_stats.get(\"expected_freq\")\n                if vr_available:\n                    # Build the frequency suffix first so the ternary cannot\n                    # drop the \"[VR-STATUS]\" prefix when freq is unknown.\n                    freq_str = f\"{freq:.1f}Hz\" if freq else \"unknown\"\n                    self.get_logger().info(\n                        \"[VR-STATUS] ZMQ latest_obs streaming | \"\n                        f\"buffer_size={queue_stats['queue_size']} \"\n                        f\"expected_freq={freq_str}\"\n                    )\n                else:\n                    self.get_logger().warn(\n                        \"[VR-STATUS] No new ZMQ latest_obs received in the last 5 seconds; using offline reference or the last buffered VR state.\"\n                    )\n                self._last_vr_status_log_time = now\n        if (\n            getattr(self, \"require_vr_data_for_motion\", False)\n            and getattr(self, \"policy_enabled\", False)\n            and not getattr(self, \"_vr_ready_logged\", False)\n            and self._is_vr_ready_for_motion()\n        ):\n            self.get_logger().info(\n                f\"[VR] VR queue is ready for motion mode (seen_frames={int(getattr(self, '_external_seen_frames', 0))}, \"\n                f\"n_fut={int(getattr(self, 'n_fut_frames', 0) or 0)}, \"\n                f\"delay={int(getattr(self, 'zmq_jitter_delay_frames', 0) or 0)})\"\n            )\n            self._vr_ready_logged = True\n        self._publish_latest_obs()\n        io_ms = self._timing_ms(t_io)\n        policy_timing = self._run_without_profiling()\n 
       _run_elapsed = 0.0\n        if policy_timing is not None:\n            _run_elapsed = float(policy_timing.get(\"policy_total_ms\", 0.0)) / 1000.0\n        if (\n            getattr(self, \"current_policy_mode\", None) == \"motion\"\n            and getattr(self, \"latest_obs_flag\", False)\n            and _run_elapsed > 0.5\n            and not getattr(self, \"_vr_cold_start_logged\", False)\n        ):\n            self._vr_cold_start_logged = True\n            self.get_logger().info(\n                \"[VR] The first motion step is a cold start (FK/ONNX initialization) and may take about 1 second.\"\n            )\n        if (\n            getattr(self, \"current_policy_mode\", None) == \"motion\"\n            and getattr(self, \"latest_obs_flag\", False)\n            and _run_elapsed > 1.15 * self.dt\n            and _run_elapsed <= 0.5\n        ):\n            self._policy_slow_count = getattr(self, \"_policy_slow_count\", 0) + 1\n            if self._policy_slow_count == 1 or self._policy_slow_count % 50 == 0:\n                self.get_logger().warn(\n                    f\"[VR] Policy step latency {_run_elapsed*1000:.1f} ms exceeds the target {self.dt*1000:.1f} ms. \"\n                    f\"Estimated /humanoid/action rate: {1.0/_run_elapsed:.1f} Hz (target {1.0/self.dt:.0f} Hz). \"\n                    \"The main bottleneck is usually FK or ONNX inference; if the system settles near 30 Hz, consider setting policy_freq to 30.\"\n                )\n        if policy_timing is not None:\n            sample = dict(policy_timing)\n            sample[\"io_ms\"] = io_ms\n            sample[\"loop_total_ms\"] = self._timing_ms(t_loop_start)\n            self._record_timing_sample(sample)\n\n    def _read_onnx_metadata(self, onnx_model_path: str) -> dict:\n        \"\"\"Read model metadata from ONNX file and parse into Python types.\"\"\"\n        model = onnx.load(str(onnx_model_path))\n        meta = {p.key: p.value for p in model.metadata_props}\n\n        def _parse_floats(csv_str: str):\n            return np.array(\n                [float(x) for x in csv_str.split(\",\") if x != \"\"],\n                dtype=np.float32,\n            )\n\n        result = {}\n        result[\"action_scale\"] = _parse_floats(meta[\"action_scale\"])\n        result[\"kps\"] = _parse_floats(meta[\"joint_stiffness\"])\n        result[\"kds\"] = _parse_floats(meta[\"joint_damping\"])\n        result[\"default_joint_pos\"] = _parse_floats(meta[\"default_joint_pos\"])\n        result[\"joint_names\"] = [x for x in meta[\"joint_names\"].split(\",\") if x != \"\"]\n        return result\n\n    def _store_external_latest_obs(self, arr: np.ndarray):\n        \"\"\"Store latest_obs and maintain the current/future frame queues.\"\"\"\n        if arr.ndim == 1:\n            arr = arr[None, :]\n        if arr.shape[1] < self.latest_obs_expected_dim:\n            self.get_logger().warn(\n                f\"Received latest_obs dim={arr.shape[1]}, expected >= {self.latest_obs_expected_dim}\"\n            )\n            return\n        clipped = arr[:, : self.latest_obs_expected_dim].astype(np.float32, copy=False)\n        current_time = time.time()\n        self.external_latest_obs = clipped\n        self.external_obs_received = True\n        self.last_external_obs_time = current_time\n        self._external_seen_frames = int(getattr(self, \"_external_seen_frames\", 0)) + 1\n\n        latest_root_pos = clipped[0, 58:61]\n        latest_root_rot = clipped[0, 61:65]\n        latest_dof_pos = clipped[0, :29]\n      
  latest_dof_vel = clipped[0, 29:58]\n\n        if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:\n            raw_idx = getattr(self, \"_npz_replay_frame_index\", None)\n            try:\n                latest_frame_idx = int(raw_idx) if raw_idx is not None else -1\n            except Exception:\n                latest_frame_idx = -1\n\n            if self._prev_external_dof_pos is None:\n                self._prev_external_dof_pos = np.empty_like(self.external_fut_dof_pos_queue[0])\n                self._prev_external_dof_vel = np.empty_like(self.external_fut_dof_vel_queue[0])\n                self._prev_external_root_pos = np.empty_like(self.external_fut_root_pos_queue[0])\n                if self.external_fut_root_rot_queue is not None:\n                    self._prev_external_root_rot = np.empty_like(\n                        self.external_fut_root_rot_queue[0]\n                    )\n            np.copyto(self._prev_external_dof_pos, self.external_fut_dof_pos_queue[0])\n            np.copyto(self._prev_external_dof_vel, self.external_fut_dof_vel_queue[0])\n            np.copyto(self._prev_external_root_pos, self.external_fut_root_pos_queue[0])\n            if self.external_fut_root_rot_queue is not None:\n                np.copyto(self._prev_external_root_rot, self.external_fut_root_rot_queue[0])\n            if self.external_fut_frame_idx_queue is not None:\n                try:\n                    self._prev_external_frame_idx = int(self.external_fut_frame_idx_queue[0])\n                except Exception:\n                    self._prev_external_frame_idx = -1\n\n            self.external_fut_dof_pos_queue[:-1] = self.external_fut_dof_pos_queue[1:]\n            self.external_fut_dof_pos_queue[-1] = latest_dof_pos\n            self.external_fut_dof_vel_queue[:-1] = self.external_fut_dof_vel_queue[1:]\n            self.external_fut_dof_vel_queue[-1] = latest_dof_vel\n            self.external_fut_root_pos_queue[:-1] = self.external_fut_root_pos_queue[1:]\n            self.external_fut_root_pos_queue[-1] = latest_root_pos\n            if self.external_fut_root_rot_queue is not None:\n                self.external_fut_root_rot_queue[:-1] = self.external_fut_root_rot_queue[1:]\n                self.external_fut_root_rot_queue[-1] = latest_root_rot\n            if self.external_fut_frame_idx_queue is not None:\n                self.external_fut_frame_idx_queue[:-1] = self.external_fut_frame_idx_queue[1:]\n                self.external_fut_frame_idx_queue[-1] = latest_frame_idx\n\n    def _poll_zmq_latest_obs(self):\n        \"\"\"Poll the ZMQ latest_obs buffer with stale-data checks and delay.\"\"\"\n        current_time = time.time()\n\n        data, timestamp, is_stale, frame_index, sender_timestamp = self._latest_obs_buffer.get_with_age_and_delay(\n            max_age=self.max_data_age,\n            delay_steps=int(getattr(self, \"zmq_jitter_delay_frames\", 0)),\n        )\n\n        if data is None:\n            return\n\n        if frame_index is not None:\n            self._npz_replay_frame_index = int(frame_index)\n        self._latest_sender_timestamp = sender_timestamp\n\n        if is_stale:\n            self.stale_data_warning_count += 1\n            if self.stale_data_warning_count % 50 == 0:\n                age_ms = (current_time - timestamp) * 1000.0\n                self.get_logger().warn(\n                    f\"ZMQ latest_obs is stale: age={age_ms:.1f}ms \"\n                    f\"(threshold={self.max_data_age*1000:.1f}ms), \"\n                  
  f\"stale_count={self.stale_data_warning_count}\"\n                )\n                queue_stats = self._latest_obs_buffer.get_queue_stats()\n                if queue_stats.get(\"expected_freq\"):\n                    self.get_logger().warn(\n                        f\"latest_obs buffer: size={queue_stats['queue_size']}, \"\n                        f\"avg_interval={queue_stats['avg_interval']*1000:.1f}ms, \"\n                        f\"expected_freq={queue_stats['expected_freq']:.1f}Hz\"\n                    )\n        else:\n            if self.stale_data_warning_count > 0:\n                self.stale_data_warning_count = 0\n\n        if self.last_poll_time is not None:\n            poll_interval = current_time - self.last_poll_time\n            if poll_interval > 0.03:\n                self.get_logger().debug(\n                    f\"Policy poll interval {poll_interval*1000:.1f}ms (>30ms)\"\n                )\n        self.last_poll_time = current_time\n\n        self._store_external_latest_obs(np.asarray(data, dtype=np.float32))\n\n        if (\n            getattr(self, \"enable_teleop_reference\", True)\n            and getattr(self, \"require_vr_data_for_motion\", False)\n            and not getattr(self, \"latest_obs_flag\", False)\n            and self._is_vr_ready_for_motion()\n        ):\n            self.latest_obs_flag = True\n            if not getattr(self, \"_vr_fk_started_logged\", False):\n                self.get_logger().info(\n                    \"[VR] ZMQ data is ready; the main thread will build the reference trajectory from live ZMQ input.\"\n                )\n                self._vr_fk_started_logged = True\n\n    def _publish_latest_obs(self):\n        \"\"\"Publish the latest_obs topic for debugging or reuse.\"\"\"\n        if self.external_latest_obs is None:\n            return\n        try:\n            msg = Float32MultiArray()\n            msg.data = self.external_latest_obs[0].tolist()\n            self.latest_obs_pub.publish(msg)\n        except Exception as e:\n            self.get_logger().error(f\"Failed to publish latest_obs: {e}\")\n\n    def _apply_onnx_metadata(self):\n        \"\"\"Apply PD/scale/defaults from ONNX metadata as authoritative values.\n        Load separate metadata for velocity and motion models.\"\"\"\n        # Load velocity model metadata\n        velocity_meta = self._read_onnx_metadata(self.velocity_onnx_path)\n        self.velocity_dof_names_onnx = velocity_meta[\"joint_names\"]\n        self.velocity_action_scale_onnx = velocity_meta[\"action_scale\"].astype(np.float32)\n        self.velocity_kps_onnx = velocity_meta[\"kps\"].astype(np.float32)\n        self.velocity_kds_onnx = velocity_meta[\"kds\"].astype(np.float32)\n        self.velocity_default_angles_onnx = velocity_meta[\"default_joint_pos\"].astype(np.float32)\n        \n        # Load motion model metadata\n        motion_meta = self._read_onnx_metadata(self.motion_onnx_path)\n        self.motion_dof_names_onnx = motion_meta[\"joint_names\"]\n        self.motion_action_scale_onnx = motion_meta[\"action_scale\"].astype(np.float32)\n        self.motion_kps_onnx = motion_meta[\"kps\"].astype(np.float32)\n        self.motion_kds_onnx = motion_meta[\"kds\"].astype(np.float32)\n        self.motion_default_angles_onnx = motion_meta[\"default_joint_pos\"].astype(np.float32)\n        \n        # Use velocity model metadata as default (for backward compatibility)\n        self.dof_names_onnx = self.velocity_dof_names_onnx\n        self.action_scale_onnx = 
self.velocity_action_scale_onnx\n        self.kps_onnx = self.velocity_kps_onnx\n        self.kds_onnx = self.velocity_kds_onnx\n        self.default_angles_onnx = self.velocity_default_angles_onnx\n        self.default_angles_dict = {\n            name: float(self.default_angles_onnx[idx])\n            for idx, name in enumerate(self.dof_names_onnx)\n        }\n\n    def _build_dof_mappings(self):\n        # Map ONNX <-> MJCF for control\n        \n        # Check if all ONNX names exist in real_dof_names (use velocity as reference)\n        missing_names = [name for name in self.velocity_dof_names_onnx if name not in self.real_dof_names]\n        if missing_names:\n            self.get_logger().warn(f\"Missing names in real_dof_names: {missing_names}\")\n        \n        # Build mappings for velocity model\n        self.velocity_onnx_to_real = [\n            self.velocity_dof_names_onnx.index(name) for name in self.real_dof_names\n        ]\n        self.velocity_kps_real = self.velocity_kps_onnx[self.velocity_onnx_to_real].astype(np.float32)\n        self.velocity_kds_real = self.velocity_kds_onnx[self.velocity_onnx_to_real].astype(np.float32)\n        \n        # Build mappings for motion model\n        self.motion_onnx_to_real = [\n            self.motion_dof_names_onnx.index(name) for name in self.real_dof_names\n        ]\n        self.motion_kps_real = self.motion_kps_onnx[self.motion_onnx_to_real].astype(np.float32)\n        self.motion_kds_real = self.motion_kds_onnx[self.motion_onnx_to_real].astype(np.float32)\n        \n        # Use velocity model mappings as default (for backward compatibility)\n        self.onnx_to_real = self.velocity_onnx_to_real\n        self.kps_real = self.velocity_kps_real\n        self.kds_real = self.velocity_kds_real\n        self.default_angles_mu = self.velocity_default_angles_onnx[self.velocity_onnx_to_real].astype(np.float32)\n        self.action_scale_mu = self.velocity_action_scale_onnx[self.velocity_onnx_to_real].astype(np.float32)\n        \n        # Build ref_to_onnx mapping (for motion model)\n        self.ref_to_onnx = [\n            self.dof_names_ref_motion.index(name) for name in self.motion_dof_names_onnx\n        ]\n        \n        # Pre-compute default angles dictionaries for efficient observation building\n        self.velocity_default_angles_dict = {\n            name: float(self.velocity_default_angles_onnx[idx])\n            for idx, name in enumerate(self.velocity_dof_names_onnx)\n        }\n        self.motion_default_angles_dict = {\n            name: float(self.motion_default_angles_onnx[idx])\n            for idx, name in enumerate(self.motion_dof_names_onnx)\n        }\n        \n        # Pre-compute dof_names_onnx arrays for each mode (avoid repeated selection)\n        self.velocity_dof_names_onnx_array = np.array(self.velocity_dof_names_onnx)\n        self.motion_dof_names_onnx_array = np.array(self.motion_dof_names_onnx)\n        self.motion_dof_real_indices = [\n            self.real_dof_names.index(n) for n in self.motion_dof_names_onnx\n        ]\n        self.velocity_dof_real_indices = [\n            self.real_dof_names.index(n) for n in self.velocity_dof_names_onnx\n        ]\n        n_dof = max(len(self.motion_dof_names_onnx), len(self.velocity_dof_names_onnx))\n        self._dof_pos_obs_buffer = np.zeros(n_dof, dtype=np.float32)\n        self._dof_vel_obs_buffer = np.zeros(n_dof, dtype=np.float32)\n        \n        # Pre-allocate arrays for future frame observations\n        if hasattr(self, 
\"n_fut_frames\") and self.n_fut_frames is not None:\n            self.n_fut_frames_int = int(self.n_fut_frames)\n            if self.n_fut_frames_int > 0:\n                self._pos_fut_buffer = np.zeros(\n                    (len(self.dof_names_ref_motion), self.n_fut_frames_int), dtype=np.float32\n                )\n                self._h_fut_buffer = np.zeros((1, self.n_fut_frames_int), dtype=np.float32)\n                self._root_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)\n            else:\n                self.n_fut_frames_int = 0\n        else:\n            self.n_fut_frames_int = 0\n\n        self._future_frame_offsets = np.arange(1, self.n_fut_frames_int + 1, dtype=np.int64)\n        self._future_frame_indices_buffer = np.zeros(self.n_fut_frames_int, dtype=np.int64)\n        self._future_root_quat_wxyz_buffer = np.zeros((self.n_fut_frames_int, 4), dtype=np.float32)\n        self._gravity_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)\n        self._base_linvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)\n        self._base_angvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)\n        self._keybody_rel_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 0, 3), dtype=np.float32)\n        self._keybody_rel_pos_w_buffer = None\n        max_t = max(1, self.n_fut_frames_int)\n        self._vel_fut_T6 = np.zeros((max_t, 6), dtype=np.float32)\n        self._rot_t_buffer = np.zeros((max_t, 3), dtype=np.float32)\n        self._rot_cross_buffer = np.zeros((max_t, 3), dtype=np.float32)\n        self._use_fk_vr = False\n        self._fk_vel_0_root = np.zeros(3, dtype=np.float32)\n        self._fk_angvel_0_root = np.zeros(3, dtype=np.float32)\n        self._fk_quat_0_root = np.zeros(4, dtype=np.float32)\n        self._fk_trans_0 = None\n        max_t = max(1, self.n_fut_frames_int)\n        self._fk_vel_fut = np.zeros((max_t, 3), dtype=np.float32)\n        self._fk_angvel_fut = np.zeros((max_t, 3), dtype=np.float32)\n        self._fk_quat_fut = np.zeros((max_t, 4), dtype=np.float32)\n        self._fk_trans_fut = None\n        self._q_conj_buffer = np.zeros((max_t + 1, 4), dtype=np.float32)\n        self._rotated_3vec_buffer = np.zeros(3, dtype=np.float32)\n        self._rotated_angvel_cur_buffer = np.zeros(3, dtype=np.float32)\n        self._cross_t_buffer = np.zeros(3, dtype=np.float32)\n        self._fk_quat_0_root_wxyz = np.zeros(4, dtype=np.float32)\n        self._fk_quat_fut_wxyz = np.zeros((max_t, 4), dtype=np.float32)\n        \n        # Pre-allocate velocity command observation array\n        self._velocity_cmd_obs = np.zeros(4, dtype=np.float32)\n        \n        # Publish kps and kds parameters (use velocity as default)\n        self._publish_control_params()\n\n    def _publish_control_params(self):\n        \"\"\"Publish kps and kds control parameters based on current policy mode.\n        \n        Called during initialization and mode switching to ensure control node\n        receives the correct parameters for the current policy mode.\n        \"\"\"\n        try:\n            # Use appropriate parameters based on current policy mode\n            if self.current_policy_mode == \"motion\":\n                current_kps = self.motion_kps_real\n                current_kds = self.motion_kds_real\n            else:  # velocity mode\n                current_kps = self.velocity_kps_real\n                current_kds = self.velocity_kds_real\n            \n            # Publish kps\n          
  kps_msg = Float32MultiArray()\n            kps_msg.data = current_kps.tolist()\n            self.kps_pub.publish(kps_msg)\n            \n            # Publish kds\n            kds_msg = Float32MultiArray()\n            kds_msg.data = current_kds.tolist()\n            self.kds_pub.publish(kds_msg)\n            \n            self.get_logger().info(\n                f\"Published control parameters ({self.current_policy_mode} mode): \"\n                f\"kps={len(current_kps)}, kds={len(current_kds)}\"\n            )\n        except Exception as e:\n            self.get_logger().error(f\"Failed to publish control parameters: {e}\")\n\n    def _publish_policy_mode(self):\n        \"\"\"Publish current policy mode status.\"\"\"\n        try:\n            mode_msg = String()\n            mode_msg.data = f\"{self.current_policy_mode}_{'enabled' if self.policy_enabled else 'disabled'}\"\n            self.policy_mode_pub.publish(mode_msg)\n        except Exception as e:\n            self.get_logger().error(f\"Failed to publish policy mode: {e}\")\n\n    def _timing_ms(self, t0: float) -> float:\n        return (time.perf_counter() - t0) * 1000.0\n\n    def _record_timing_sample(self, sample: dict):\n        if not getattr(self, \"timing_debug_enabled\", False):\n            return\n        self._timing_debug_samples.append(sample)\n        if getattr(self, \"timing_debug_log_per_loop\", False):\n            self.get_logger().info(\n                \"[Timing] \"\n                f\"loop_total={sample['loop_total_ms']:.2f}ms \"\n                f\"io={sample['io_ms']:.2f}ms \"\n                f\"policy_total={sample['policy_total_ms']:.2f}ms \"\n                f\"fk={sample['fk_ms']:.2f}ms \"\n                f\"obs={sample['obs_ms']:.2f}ms \"\n                f\"onnx={sample['onnx_ms']:.2f}ms \"\n                f\"post={sample['post_ms']:.2f}ms\"\n            )\n\n        now = time.time()\n        last = getattr(self, \"_timing_debug_last_log_time\", None)\n        interval = float(getattr(self, \"timing_debug_log_interval_sec\", 5.0))\n        if last is None:\n            self._timing_debug_last_log_time = now\n            return\n        if now - last < interval:\n            return\n        if len(self._timing_debug_samples) == 0:\n            self._timing_debug_last_log_time = now\n            return\n\n        keys = [\n            \"loop_total_ms\",\n            \"io_ms\",\n            \"policy_total_ms\",\n            \"fk_ms\",\n            \"obs_ms\",\n            \"onnx_ms\",\n            \"post_ms\",\n        ]\n        stats = {}\n        for key in keys:\n            vals = np.array(\n                [float(s.get(key, 0.0)) for s in self._timing_debug_samples],\n                dtype=np.float64,\n            )\n            stats[key] = (float(np.mean(vals)), float(np.max(vals)))\n        self.get_logger().info(\n            \"[Timing-Agg] \"\n            + \" \".join(\n                f\"{key}=mean:{stats[key][0]:.2f}ms/max:{stats[key][1]:.2f}ms\"\n                for key in keys\n            )\n            + f\" n={len(self._timing_debug_samples)}\"\n        )\n        self._timing_debug_samples.clear()\n        self._timing_debug_last_log_time = now\n\n    def _root_only_fk_has_required_keybodies(self, keybody_idxs: np.ndarray) -> bool:\n        if keybody_idxs.size == 0:\n            return True\n        available_bodies = 0 if self._fk_trans_0 is None else int(self._fk_trans_0.shape[0])\n        if available_bodies <= int(np.max(keybody_idxs)):\n            if not 
self._root_only_fk_keybody_warned:\n                self.get_logger().warn(\n                    \"[RootOnlyFK] FK output only contains root body, but obs schema still \"\n                    \"requests non-root keybody positions. Returning zeros for keybody obs.\"\n                )\n                self._root_only_fk_keybody_warned = True\n            return False\n        return True\n\n    def _run_without_profiling(self):\n        \"\"\"Run the main loop without performance profiling.\"\"\"\n        if self._lowstate_msg is None or not self.policy_enabled:\n            return None\n\n        timing_info = {\n            \"policy_total_ms\": 0.0,\n            \"fk_ms\": 0.0,\n            \"obs_ms\": 0.0,\n            \"onnx_ms\": 0.0,\n            \"post_ms\": 0.0,\n        }\n        _t_policy_start = time.perf_counter()\n\n        if self.current_policy_mode == \"motion\":\n            if self.latest_obs_flag:\n                current_time = time.time()\n                if self.last_external_obs_time is None:\n                    data_age = float(\"inf\")\n                else:\n                    data_age = current_time - self.last_external_obs_time\n\n                if data_age > self.max_data_age:\n                    self.get_logger().warn(\n                        f\"ZMQ latest_obs is stale: age={data_age*1000:.1f}ms > {self.max_data_age*1000:.1f}ms; \"\n                        \"switching to velocity tracking mode for safety.\"\n                    )\n                    self._switch_to_velocity_mode(reason=\"VR latest_obs stale\")\n                    return None\n\n            if not self.latest_obs_flag and (\n                not hasattr(self, \"n_motion_frames\") or not hasattr(self, \"ref_dof_pos\")\n            ):\n                self.get_logger().warn(\"Motion data not loaded, skipping policy execution\")\n                return None\n\n            if (\n                self.latest_obs_flag\n                and self.fk is not None\n                and self.external_fut_dof_pos_queue is not None\n            ):\n                try:\n                    n_fut = int(getattr(self, \"n_fut_frames\", 0))\n                    if (\n                        n_fut > 0\n                        and self.external_fut_root_pos_queue is not None\n                        and self.external_fut_root_rot_queue is not None\n                    ):\n                        t_fk = time.perf_counter()\n                        cur_root_pos = self.ref_root_pos_raw.astype(np.float32)\n                        cur_root_rot = (\n                            self._prev_external_root_rot\n                            if self._prev_external_root_rot is not None\n                            else self.external_fut_root_rot_queue[0].astype(np.float32)\n                        )\n                        cur_dof_pos = self.ref_dof_pos_raw.astype(np.float32)\n                        root_pos_tensor, root_rot_tensor, dof_pos_tensor = (\n                            self._prepare_vr_fk_tensors(\n                                cur_root_pos=cur_root_pos,\n                                cur_root_rot=cur_root_rot,\n                                cur_dof_pos=cur_dof_pos,\n                                n_fut=n_fut,\n                            )\n                        )\n                        fk_out = self.fk(\n                            root_pos=root_pos_tensor,\n                            root_quat=root_rot_tensor,\n                            dof_pos=dof_pos_tensor,\n                            
fps=float(1.0 / self.dt),\n                            quat_format=\"wxyz\",\n                            vel_smoothing_sigma=0.0,\n                            compute_velocity=False,\n                        )\n                        self._fk_vr_out = {\n                            k: v.detach().cpu().numpy() for k, v in fk_out.items()\n                        }\n                        timing_info[\"fk_ms\"] = self._timing_ms(t_fk)\n                    else:\n                        self._fk_vr_out = None\n                except Exception as e:\n                    self.get_logger().error(\n                        f\"VR FK computation failed; falling back to offline reference: {e}\"\n                    )\n                    self._fk_vr_out = None\n\n            self.obs_builder = self.motion_obs_builder\n            # Use motion model metadata\n            current_action_scale = self.motion_action_scale_onnx\n            current_default_angles = self.motion_default_angles_onnx\n            current_onnx_to_real = self.motion_onnx_to_real\n        else:  # velocity mode\n            self.obs_builder = self.velocity_obs_builder\n            # Use velocity model metadata\n            current_action_scale = self.velocity_action_scale_onnx\n            current_default_angles = self.velocity_default_angles_onnx\n            current_onnx_to_real = self.velocity_onnx_to_real\n\n        t_obs = time.perf_counter()\n        if self.current_policy_mode == \"motion\":\n            self._cache_fk_vr_for_obs()\n        policy_obs_np = self.obs_builder.build_policy_obs()[None, :].astype(\n            np.float32, copy=False\n        )\n        timing_info[\"obs_ms\"] = self._timing_ms(t_obs)\n        # Run ONNX inference with the appropriate policy session and correct input/output names\n        t_onnx = time.perf_counter()\n        if self.current_policy_mode == \"velocity\":\n            input_feed = {self.velocity_input_name: policy_obs_np}\n            onnx_output = self.velocity_policy_session.run([self.velocity_output_name], input_feed)\n        else:  # motion mode\n            if self.use_kv_cache:\n                if self.motion_kv_cache is None:\n                    shape = [\n                        d if isinstance(d, int) else 1\n                        for d in self.motion_kv_shape\n                    ]\n                    self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)\n                # if (\n                #     self.motion_effective_context_len > 0\n                #     and self.motion_step_idx > 0\n                #     and self.motion_step_idx % self.motion_effective_context_len == 0\n                # ):\n                #     self.motion_kv_cache.fill(0.0)\n\n                input_feed = {\n                    self.motion_input_name: policy_obs_np,\n                    self.motion_kv_input_name: self.motion_kv_cache,\n                }\n                if self.motion_step_idx_input_name is not None:\n                    step_idx = self.motion_step_idx\n                    # if self.motion_effective_context_len > 0:\n                    #     step_idx = (\n                    #         self.motion_step_idx\n                    #         % self.motion_effective_context_len\n                    #     )\n                    input_feed[self.motion_step_idx_input_name] = np.array(\n                        [step_idx], dtype=np.int64\n                    )\n\n                output_names = [self.motion_output_name]\n                if self.motion_kv_output_name:\n   
                 output_names.append(self.motion_kv_output_name)\n                onnx_output = self.motion_policy_session.run(\n                    output_names, input_feed\n                )\n                if len(onnx_output) > 1:\n                    self.motion_kv_cache = onnx_output[1]\n                self.motion_step_idx += 1\n            else:\n                input_feed = {self.motion_input_name: policy_obs_np}\n                onnx_output = self.motion_policy_session.run(\n                    [self.motion_output_name], input_feed\n                )\n        timing_info[\"onnx_ms\"] = self._timing_ms(t_onnx)\n\n        t_post = time.perf_counter()\n        raw_actions_onnx = np.asarray(onnx_output[0], dtype=np.float32).reshape(-1)\n        if self.current_policy_mode == \"motion\":\n            self.actions_onnx = self._apply_motion_action_ema_filter(raw_actions_onnx)\n        else:\n            self.actions_onnx = raw_actions_onnx.copy()\n        # Use the appropriate metadata based on current policy mode\n        self.target_dof_pos_onnx = (\n            self.actions_onnx * current_action_scale + current_default_angles\n        )\n        self.target_dof_pos_real = self.target_dof_pos_onnx[current_onnx_to_real]\n        # Action processing and publishing\n        self._process_and_publish_actions()\n        if self.current_policy_mode == \"motion\":\n            if (\n                not getattr(self, \"latest_obs_flag\", False)\n                and self.motion_frame_idx >= self.n_motion_frames\n                and self.motion_in_progress\n            ):\n                self.get_logger().info(\"Motion action completed (offline reference)\")\n                self.motion_in_progress = False\n\n        # Publish policy mode status\n        self._publish_policy_mode()\n        timing_info[\"post_ms\"] = self._timing_ms(t_post)\n        timing_info[\"policy_total_ms\"] = self._timing_ms(_t_policy_start)\n        return timing_info\n    \n    def _process_and_publish_actions(self):\n        \"\"\"Process and publish action commands.\"\"\"\n        if self.target_dof_pos_real is not None:\n            action_msg = Float32MultiArray()\n            action_msg.data = self.target_dof_pos_real.tolist()\n\n            # Check for NaN values\n            if np.isnan(self.target_dof_pos_real).any():\n                self.get_logger().error(\"Action contains NaN values\")\n\n            self.action_pub.publish(action_msg)\n\n        self.motion_frame_idx += 1\n\n    def setup(self):\n        \"\"\"Set up the evaluator by loading all required components.\"\"\"\n        main_affinity = _parse_cpu_affinity_str(\n            getattr(self, \"_cpu_affinity_main_str\", \"\") or \"\"\n        )\n        if main_affinity and set_thread_cpu_affinity(main_affinity):\n            self.get_logger().info(f\"[Policy] main thread pinned to CPUs {main_affinity}\")\n        self.load_model_config()  # Load config first\n        self.update_config_parameters()  # Update parameters from config\n        # Initialize FK for online VR reference reconstruction\n        self._init_fk()\n        self.load_policy()        # Then load policies\n        self._apply_onnx_metadata()\n        self._init_obs_buffers()\n        self._build_dof_mappings()\n        self._warmup_motion_policy()\n        self._init_keybody_indices_cache()\n        # Always load motion data since we support both modes\n        self.load_motion_data()\n        self.get_logger().info(\"Synchronous root-only policy setup completed\")\n\n    def 
_init_fk(self):\n        \"\"\"Initialize lightweight root-only FK for synchronous VR reference updates.\"\"\"\n        try:\n            self.get_logger().info(\n                \"Initializing root-only FK (no URDF, sync main-thread mode)\"\n            )\n            self.fk = HoloMotionFKRootOnly(\n                dof_names=self.dof_names_ref_motion,\n                device=\"cpu\",\n                timing_logger_enabled=True,\n                timing_log_interval_sec=5.0,\n                timing_log_per_call=False,\n                timing_name=\"FKRootOnlyVR\",\n                timing_log_fn=self.get_logger().info,\n            )\n            try:\n                ndof = len(self.fk.dof_names)\n                root_pos_dummy = torch.zeros((1, 4, 3), dtype=torch.float32)\n                root_quat_dummy = torch.zeros((1, 4, 4), dtype=torch.float32)\n                root_quat_dummy[..., 0] = 1.0\n                dof_pos_dummy = torch.zeros((1, 4, ndof), dtype=torch.float32)\n                _ = self.fk(\n                    root_pos=root_pos_dummy,\n                    root_quat=root_quat_dummy,\n                    dof_pos=dof_pos_dummy,\n                    fps=float(1.0 / self.dt),\n                    quat_format=\"wxyz\",\n                    vel_smoothing_sigma=0.0,\n                    compute_velocity=False,\n                )\n                self.get_logger().info(\"[FK] Root-only warmup completed (B=1,T=4)\")\n            except Exception as e_dummy:\n                self.get_logger().warn(f\"[FK] Root-only warmup failed (ignored): {e_dummy}\")\n\n            self.fk_initialized = True\n            self.get_logger().info(\n                f\"Root-only FK initialized successfully with {len(self.fk.dof_names)} dofs\"\n            )\n        except Exception as e:\n            self.get_logger().error(f\"Failed to initialize root-only FK: {e}\")\n            self.fk = None\n            self.fk_initialized = False\n\n    def destroy_node(self):\n        try:\n            if getattr(self, \"_zmq_subscriber\", None) is not None:\n                self._zmq_subscriber.stop()\n        except Exception:\n            pass\n\n        super().destroy_node()\n\ndef get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:\n    \"\"\"Calculate gravity orientation from quaternion.\n\n    Args:\n        quaternion: Array-like [w, x, y, z]\n\n    Returns:\n        np.ndarray of shape (3,) representing gravity projection.\n    \"\"\"\n    qw = float(quaternion[0])\n    qx = float(quaternion[1])\n    qy = float(quaternion[2])\n    qz = float(quaternion[3])\n\n    gravity_orientation = np.zeros(3, dtype=np.float32)\n    gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy)\n    gravity_orientation[1] = -2.0 * (qz * qy + qw * qx)\n    gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n    return gravity_orientation\n\n\ndef main():\n    \"\"\"Main entry point for the policy node.\"\"\"\n    rclpy.init()\n    policy_node = HoloMotionPolicyNode()\n    rclpy.spin(policy_node)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
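The motion-mode inference path above reduces, per tick, to: raw ONNX actions → optional EMA filter → `actions * action_scale + default_angles` → reorder through the `onnx_to_real` index array into real joint order. Below is a minimal numpy sketch of that post-processing with toy values; the metadata arrays and the first-order EMA form are illustrative assumptions (the body of `_apply_motion_action_ema_filter` is not shown in this file):

```python
import numpy as np

# Hypothetical metadata for a 3-DoF toy model (real values come from the
# ONNX sidecar config loaded in _apply_onnx_metadata).
action_scale = np.array([0.25, 0.25, 0.25], dtype=np.float32)
default_angles = np.array([0.0, -0.3, 0.6], dtype=np.float32)
onnx_to_real = np.array([2, 0, 1])  # ONNX joint order -> robot joint order

ema_alpha = 0.8  # assumed smoothing weight; the real filter may differ
filtered = np.zeros(3, dtype=np.float32)

def step(raw_actions: np.ndarray) -> np.ndarray:
    """Mirror the motion-mode post-processing: EMA -> scale+offset -> reorder."""
    global filtered
    # Assumed first-order EMA; stands in for _apply_motion_action_ema_filter.
    filtered = ema_alpha * raw_actions + (1.0 - ema_alpha) * filtered
    target_onnx = filtered * action_scale + default_angles
    return target_onnx[onnx_to_real]  # joint targets in real robot order

print(step(np.array([0.1, -0.2, 0.3], dtype=np.float32)))
```

The reorder step is what lets the policy and the robot disagree about joint ordering while sharing a single action vector.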
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/__init__.py",
    "content": ""
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/command_helper.py",
    "content": "from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo\nfrom unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG\nfrom typing import Union\n\n\nclass MotorMode:\n    PR = 0  # Series Control for Pitch/Roll Joints\n    AB = 1  # Parallel Control for A/B Joints\n\n\ndef create_damping_cmd(cmd: Union[LowCmdGo, LowCmdHG]):\n    size = len(cmd.motor_cmd)\n    for i in range(size):\n        cmd.motor_cmd[i].q = 0\n        cmd.motor_cmd[i].qd = 0\n        cmd.motor_cmd[i].kp = 0\n        cmd.motor_cmd[i].kd = 8\n        cmd.motor_cmd[i].tau = 0\n\n\ndef create_zero_cmd(cmd: Union[LowCmdGo, LowCmdHG]):\n    size = len(cmd.motor_cmd)\n    for i in range(size):\n        cmd.motor_cmd[i].q = 0\n        cmd.motor_cmd[i].qd = 0\n        cmd.motor_cmd[i].kp = 0\n        cmd.motor_cmd[i].kd = 0\n        cmd.motor_cmd[i].tau = 0\n\n\ndef init_cmd_hg(cmd: LowCmdHG, mode_machine: int, mode_pr: int):\n    cmd.mode_machine = mode_machine\n    cmd.mode_pr = mode_pr\n    size = len(cmd.motor_cmd)\n    for i in range(size):\n        cmd.motor_cmd[i].mode = 1\n        cmd.motor_cmd[i].q = 0\n        cmd.motor_cmd[i].qd = 0\n        cmd.motor_cmd[i].kp = 0\n        cmd.motor_cmd[i].kd = 0\n        cmd.motor_cmd[i].tau = 0\n\n\ndef init_cmd_go(cmd: LowCmdGo, weak_motor: list):\n    cmd.head[0] = 0xFE\n    cmd.head[1] = 0xEF\n    cmd.level_flag = 0xFF\n    cmd.gpio = 0\n    PosStopF = 2.146e9\n    VelStopF = 16000.0\n    size = len(cmd.motor_cmd)\n    for i in range(size):\n        if i in weak_motor:\n            cmd.motor_cmd[i].mode = 1\n        else:\n            cmd.motor_cmd[i].mode = 0x0A\n        cmd.motor_cmd[i].q = PosStopF\n        cmd.motor_cmd[i].qd = VelStopF\n        cmd.motor_cmd[i].kp = 0\n        cmd.motor_cmd[i].kd = 0\n        cmd.motor_cmd[i].tau = 0\n"
  },
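These helpers mutate a `LowCmd` message in place rather than returning a new one. A minimal usage sketch for the HG (G1) interface follows, assuming `unitree_sdk2py`'s generated default factory `unitree_hg_msg_dds__LowCmd_` is available; `mode_machine=5` is a placeholder, since the real value is echoed back by the robot in `LowState`:

```python
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from humanoid_policy.utils.command_helper import (
    MotorMode,
    create_damping_cmd,
    init_cmd_hg,
)

cmd = unitree_hg_msg_dds__LowCmd_()

# mode_machine is normally copied from the first LowState message received
# from the robot; 5 here is just a placeholder.
init_cmd_hg(cmd, mode_machine=5, mode_pr=MotorMode.PR)

# ... run the control loop, filling cmd.motor_cmd[i].q each tick ...

# On exit, command pure damping (kp=0, kd=8) so the robot sags safely.
create_damping_cmd(cmd)
```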
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/maths.py",
    "content": "import torch\nimport numpy as np\nimport random\nimport os\n\n\n@torch.jit.script\ndef normalize(x, eps: float = 1e-9):\n    return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)\n\n\n@torch.jit.script\ndef torch_rand_float(lower, upper, shape, device):\n    # type: (float, float, Tuple[int, int], str) -> Tensor\n    return (upper - lower) * torch.rand(*shape, device=device) + lower\n\n\n@torch.jit.script\ndef copysign(a, b):\n    # type: (float, Tensor) -> Tensor\n    a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])\n    return torch.abs(a) * torch.sign(b)\n\n\ndef set_seed(seed, torch_deterministic=False):\n    \"\"\"set seed across modules\"\"\"\n    if seed == -1 and torch_deterministic:\n        seed = 42\n    elif seed == -1:\n        seed = np.random.randint(0, 10000)\n    print(\"Setting seed: {}\".format(seed))\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n    if torch_deterministic:\n        # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n        os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n        torch.backends.cudnn.benchmark = False\n        torch.backends.cudnn.deterministic = True\n        torch.use_deterministic_algorithms(True)\n    else:\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cudnn.deterministic = False\n\n    return seed\n\n\ndef to_torch(x, dtype=torch.float, device=\"cuda:0\", requires_grad=False):\n    return torch.tensor(\n        x, dtype=dtype, device=device, requires_grad=requires_grad\n    )\n\n\n@torch.compile\ndef quat_mul_legacy(\n    a: torch.Tensor, b: torch.Tensor, w_last: bool = True\n) -> torch.Tensor:\n    \"\"\"Multiply two quaternions.\n\n    Args:\n        a (torch.Tensor): (..., 4) quaternion.\n        b (torch.Tensor): (..., 4) quaternion.\n        w_last (bool): Whether the scalar part w is the last element.\n                      If True, format is [x, y, z, w]; if False, format is [w, x, y, z].\n\n    Returns:\n        torch.Tensor: (..., 4) quaternion result of a * b.\n    \"\"\"\n    assert a.shape == b.shape\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 4)\n\n    if w_last:\n        # Format: [x, y, z, w]\n        x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]\n        x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]\n    else:\n        # Format: [w, x, y, z]\n        w1, x1, y1, z1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]\n        w2, x2, y2, z2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]\n\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    if w_last:\n        quat = torch.stack([x, y, z, w], dim=-1).view(shape)\n    else:\n        quat = torch.stack([w, x, y, z], dim=-1).view(shape)\n\n    return quat\n\n\n@torch.jit.script\ndef quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiply two quaternions together.\n\n    Args:\n        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).\n        q2: The second quaternion in (w, x, y, z). 
Shape is (..., 4).\n\n    Returns:\n        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).\n\n    Raises:\n        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.\n    \"\"\"\n    # check input is correct\n    if q1.shape != q2.shape:\n        msg = f\"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}.\"\n        raise ValueError(msg)\n    # reshape to (N, 4) for multiplication\n    shape = q1.shape\n    q1 = q1.reshape(-1, 4)\n    q2 = q2.reshape(-1, 4)\n    # extract components from quaternions\n    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]\n    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]\n    # perform multiplication\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    return torch.stack([w, x, y, z], dim=-1).view(shape)\n\n\n@torch.jit.script\ndef quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:\n    \"\"\"Apply a quaternion rotation to a vector.\n\n    Args:\n        quat: The quaternion in (w, x, y, z). Shape is (..., 4).\n        vec: The vector in (x, y, z). Shape is (..., 3).\n\n    Returns:\n        The rotated vector in (x, y, z). 
Shape is (..., 3).\n    \"\"\"\n    # store shape\n    shape = vec.shape\n    # reshape to (N, 3) for multiplication\n    quat = quat.reshape(-1, 4)\n    vec = vec.reshape(-1, 3)\n    # extract components from quaternions\n    xyz = quat[:, 1:]\n    t = xyz.cross(vec, dim=-1) * 2\n    return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n# @torch.jit.script\ndef quat_rotate_inverse(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a - b + c\n\n\n@torch.jit.script\ndef quat_conjugate(a):\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    # return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)\n    return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape)\n\n\n@torch.jit.script\ndef quat_unit(a):\n    return normalize(a)\n\n\n@torch.jit.script\ndef quat_from_angle_axis(angle, axis):\n    theta = (angle / 2).unsqueeze(-1)\n    xyz = normalize(axis) * theta.sin()\n    w = theta.cos()\n    return quat_unit(torch.cat([xyz, w], dim=-1))\n\n\n@torch.jit.script\ndef normalize_angle(x):\n    return torch.atan2(torch.sin(x), torch.cos(x))\n\n\n@torch.jit.script\ndef get_basis_vector(q, v):\n    return quat_rotate(q, v)\n\n\ndef get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3):\n    \"\"\"Construct arguments to `Vec` according to axis index.\"\"\"\n    zs = np.zeros((n_dims,))\n    assert axis_idx < n_dims, (\n        \"the axis dim should be within the vector dimensions\"\n    )\n    zs[axis_idx] = 1.0\n    params = np.where(zs == 1.0, value, zs)\n    params[0] = x_value\n    return list(params.astype(dtype))\n\n\n# @torch.jit.script\n# def copysign(a, b):\n#     a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])\n#     return torch.abs(a) * torch.sign(b)\n\n\n@torch.jit.script\ndef get_euler_xyz(q: torch.Tensor) -> tuple:\n    qx, qy, qz, qw = 0, 1, 2, 3  # assumes xyzw (w-last) quaternion layout\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n        q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1,\n        # copysign is scripted as (float, Tensor); pass a plain float, as in rotations.py\n        copysign(np.pi / 2.0, sinp),\n        torch.asin(sinp),\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)\n\n\n@torch.jit.script\ndef quat_from_euler_xyz(roll, pitch, yaw):\n    cy = 
torch.cos(yaw * 0.5)\n    sy = torch.sin(yaw * 0.5)\n    cr = torch.cos(roll * 0.5)\n    sr = torch.sin(roll * 0.5)\n    cp = torch.cos(pitch * 0.5)\n    sp = torch.sin(pitch * 0.5)\n\n    qw = cy * cr * cp + sy * sr * sp\n    qx = cy * sr * cp - sy * cr * sp\n    qy = cy * cr * sp + sy * sr * cp\n    qz = sy * cr * cp - cy * sr * sp\n\n    return torch.stack([qx, qy, qz, qw], dim=-1)\n\n\ndef torch_rand_float(lower, upper, shape, device):\n    return (upper - lower) * torch.rand(*shape, device=device) + lower\n\n\n# @torch.jit.script\n@torch.compile\ndef torch_random_dir_2(shape, device):\n    angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)\n    return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)\n\n\n@torch.jit.script\ndef tensor_clamp(t, min_t, max_t):\n    return torch.max(torch.min(t, max_t), min_t)\n\n\n@torch.jit.script\ndef scale(x, lower, upper):\n    return 0.5 * (x + 1.0) * (upper - lower) + lower\n\n\n@torch.jit.script\ndef unscale(x, lower, upper):\n    return (2.0 * x - upper - lower) / (upper - lower)\n\n\ndef unscale_np(x, lower, upper):\n    return (2.0 * x - upper - lower) / (upper - lower)\n\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    min_theta = 1e-5\n    qx, _, _, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = torch.abs(sin_theta) > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    cos_half_theta = torch.sum(q0 * q1, dim=-1)\n\n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratio_b = torch.sin(t * half_theta) / sin_half_theta\n\n    new_q = ratio_a * q0 + ratio_b * q1\n\n    new_q = torch.where(\n        torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q\n    )\n    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n\n@torch.jit.script\ndef my_quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef calc_heading(q):\n    
# calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    # this is the x axis heading\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = my_quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n\n@torch.jit.script\ndef calc_heading_quat(q):\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis)\n    return heading_q\n\n\n@torch.jit.script\ndef calc_heading_quat_inv(q):\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis)\n    return heading_q\n\n\n@torch.compile\ndef axis_angle_from_quat(\n    quat: torch.Tensor,\n    w_last: bool = True,\n) -> torch.Tensor:\n    \"\"\"Compute axis-angle (log map) vector from a quaternion.\n\n    Args:\n        quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].\n        w_last (bool): Whether the scalar part w is the last element.\n\n    Returns:\n        torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi].\n\n    Notes:\n        - The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability.\n        - Uses a stable small-angle handling to avoid NaNs and gradient issues.\n    \"\"\"\n    # Handle different quaternion formats\n    if w_last:\n        # Quaternion is [q_x, q_y, q_z, q_w]\n        quat_w_orig = quat[..., -1:]\n    else:\n        # Quaternion is [q_w, q_x, q_y, q_z]\n        quat_w_orig = quat[..., 0:1]\n\n    # Normalize quaternion to have w > 0\n    quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0))\n\n    # Ensure unit quaternion for stability\n    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min(\n        1.0e-9\n    )\n\n    # Recompute quat_xyz and quat_w after potential sign flip\n    if w_last:\n        quat_w = quat[..., -1:]\n        quat_xyz = quat[..., :3]\n    else:\n        quat_w = quat[..., 0:1]\n        quat_xyz = quat[..., 1:4]\n\n    mag = torch.linalg.norm(quat_xyz, dim=-1)\n    half_angle = torch.atan2(mag, quat_w.squeeze(-1))\n    angle = 2.0 * half_angle\n    # check whether to apply Taylor approximation\n    use_taylor = angle.abs() <= 1.0e-6\n    # To prevent NaN gradients with torch.where, we compute both branches and blend\n    # based on the condition.\n    # See: https://pytorch.org/docs/1.9.0/generated/torch.where.html#torch-where\n    # \"However, if you need the gradients to flow through the branches, please use torch.lerp\"\n    # Although we are not using lerp, the principle of avoiding sharp branches is the same.\n    sin_half_angles_over_angles_approx = 0.5 - angle * angle / 48\n    # Clamp angle to avoid division by zero in the non-taylor branch when angle is exactly 0.\n    angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle)\n    sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe\n\n    sin_half_angles_over_angles = torch.where(\n        use_taylor,\n        sin_half_angles_over_angles_approx,\n        
sin_half_angles_over_angles_exact,\n    )\n    return quat_xyz / sin_half_angles_over_angles[..., None]\n\n\n@torch.jit.script\ndef quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:\n    \"\"\"Computes the inverse of a quaternion.\n\n    Args:\n        q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).\n        eps: A small value to avoid division by zero. Defaults to 1e-9.\n\n    Returns:\n        The inverse quaternion in (w, x, y, z). Shape is (N, 4).\n    \"\"\"\n    return quat_conjugate(q) / q.pow(2).sum(dim=-1, keepdim=True).clamp(\n        min=eps\n    )\n\n\n# --------------------- WXYZ helpers (torch) ---------------------\ndef xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert quaternion from XYZW to WXYZ.\n    Args:\n        q (torch.Tensor): (..., 4) quaternion in XYZW.\n    Returns:\n        torch.Tensor: (..., 4) quaternion in WXYZ.\n    \"\"\"\n    return torch.cat([q[..., 3:4], q[..., 0:3]], dim=-1)\n\n\ndef wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert quaternion from WXYZ to XYZW.\n    Args:\n        q (torch.Tensor): (..., 4) quaternion in WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) quaternion in XYZW.\n    \"\"\"\n    return torch.cat([q[..., 1:4], q[..., 0:1]], dim=-1)\n\n\n@torch.compile\ndef quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Hamilton product in WXYZ layout using the fused implementation.\n    Args:\n        q1 (torch.Tensor): (..., 4) WXYZ.\n        q2 (torch.Tensor): (..., 4) WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) WXYZ.\n    \"\"\"\n    # quat_mul in this module already expects (w, x, y, z) and takes no\n    # w_last flag; passing one would raise a TypeError at call time.\n    return quat_mul(q1, q2)\n\n\ndef subtract_frame_transforms(\n    t01: torch.Tensor,\n    q01: torch.Tensor,\n    t02: torch.Tensor = None,\n    q02: torch.Tensor = None,\n):\n    r\"\"\"Subtract transformations between two reference frames into a stationary frame.\n\n    It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \\times T_{02}`,\n    where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.\n\n    Args:\n        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).\n        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).\n        t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).\n            Defaults to None, in which case the position is assumed to be zero.\n        q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).\n            Defaults to None, in which case the orientation is assumed to be identity.\n\n    Returns:\n        A tuple containing the position and orientation of frame 2 w.r.t. 
frame 1.\n        Shape of the tensors are (N, 3) and (N, 4) respectively.\n    \"\"\"\n    # compute orientation\n    q10 = quat_inv(q01)\n    if q02 is not None:\n        q12 = quat_mul(q10, q02)\n    else:\n        q12 = q10\n    # compute translation\n    if t02 is not None:\n        t12 = quat_apply(q10, t02 - t01)\n    else:\n        t12 = quat_apply(q10, -t01)\n    return t12, q12\n\n\n@torch.compile\ndef quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Normalize quaternion in WXYZ layout.\n    Args:\n        q_wxyz (torch.Tensor): (..., 4) WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) normalized WXYZ.\n    \"\"\"\n    return q_wxyz / torch.linalg.norm(q_wxyz, dim=-1, keepdim=True).clamp_min(\n        1.0e-9\n    )\n\n\n@torch.jit.script\ndef matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).\n\n    Returns:\n        Rotation matrices. The shape is (..., 3, 3).\n\n    Reference:\n        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n"
  },
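`maths.py` keeps two quaternion conventions side by side: w-first helpers (`quat_mul`, `quat_apply`, `quat_inv`, `matrix_from_quat`) and older w-last ones (`quat_rotate`, `quat_from_angle_axis`). A quick self-consistency check on the w-first path, rotating vectors two ways (this assumes the `humanoid_policy` package is importable):

```python
import torch
from humanoid_policy.utils.maths import matrix_from_quat, normalize, quat_apply

torch.manual_seed(0)
q = normalize(torch.randn(8, 4))  # random unit quaternions in (w, x, y, z)
v = torch.randn(8, 3)

rotated_by_quat = quat_apply(q, v)
rotated_by_matrix = torch.einsum("nij,nj->ni", matrix_from_quat(q), v)

# Both paths implement the same rotation, so they must agree.
assert torch.allclose(rotated_by_quat, rotated_by_matrix, atol=1e-5)
```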
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/motor_crc.py",
    "content": "import struct\nimport numpy as np\nfrom ctypes import Structure, c_uint8, c_float, c_uint32, Array\n\n\ndef crc32_core(data_array, length):\n    CRC32 = 0xFFFFFFFF\n    dwPolynomial = 0x04C11DB7\n\n    for i in range(length):\n        data = data_array[i]\n        for bit in range(32):  # Process all 32 bits\n            if (CRC32 >> 31) & 1:  # Check MSB before shift\n                CRC32 = ((CRC32 << 1) & 0xFFFFFFFF) ^ dwPolynomial\n            else:\n                CRC32 = (CRC32 << 1) & 0xFFFFFFFF\n\n            if (data >> (31 - bit)) & 1:  # Match C++ bit processing order\n                CRC32 ^= dwPolynomial\n\n    return CRC32\n\n\ndef calc_crc(cmd) -> int:\n    \"\"\"Calculate CRC for LowCmd message\"\"\"\n    buffer = bytearray()\n\n    # Pack header (mode_pr, mode_machine + 2 padding)\n    buffer.extend(struct.pack(\"<BBxx\", cmd.mode_pr, cmd.mode_machine))\n\n    # Pack motor commands\n    for motor in cmd.motor_cmd:\n        buffer.extend(\n            struct.pack(\n                \"<B3xfffffI\",\n                motor.mode,\n                motor.q,\n                motor.dq,\n                motor.tau,\n                motor.kp,\n                motor.kd,\n                motor.reserve,\n            )\n        )\n\n    # Pack reserve (4 bytes)\n    buffer.extend(struct.pack(\"<4B\", *cmd.reserve))\n\n    # Convert to uint32 array (little-endian)\n    uint32_array = struct.unpack(f\"<{len(buffer) // 4}I\", buffer)\n\n    # Calculate with fixed length (246 for LowCmd struct size)\n    return crc32_core(uint32_array, 246)\n"
  },
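`calc_crc` re-serializes the command into the firmware's packed byte layout (4-byte header, 35 motor slots at 28 bytes each, 4 reserve bytes) before running the word-wise bitwise CRC. A hedged sketch that exercises the packing path with a stand-in command object instead of the real IDL message; the field names simply mirror what `calc_crc` reads:

```python
from types import SimpleNamespace
from humanoid_policy.utils.motor_crc import calc_crc, crc32_core

def mock_motor():
    # Stand-in for one MotorCmd slot; fields match calc_crc's struct.pack call.
    return SimpleNamespace(mode=1, q=0.0, dq=0.0, tau=0.0, kp=0.0, kd=0.0, reserve=0)

# Stand-in for the unitree_hg LowCmd message (35 motor slots).
cmd = SimpleNamespace(
    mode_pr=0,
    mode_machine=5,
    motor_cmd=[mock_motor() for _ in range(35)],
    reserve=[0, 0, 0, 0],
)

print(hex(calc_crc(cmd)))          # CRC the firmware will verify
print(hex(crc32_core([0x0], 1)))   # one zero word through the raw bit loop
```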
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/remote_controller_filter.py",
    "content": "import struct\n\n\nclass KeyMap:\n    R1 = 0\n    L1 = 1\n    start = 2\n    select = 3\n    R2 = 4\n    L2 = 5\n    F1 = 6\n    F2 = 7\n    A = 8\n    B = 9\n    X = 10\n    Y = 11\n    up = 12\n    right = 13\n    down = 14\n    left = 15\n\n\nclass RemoteController:\n    def __init__(self):\n        self.lx = 0\n        self.ly = 0\n        self.rx = 0\n        self.ry = 0\n        self.button = [0] * 16\n        # 添加滤波器参数\n        self.alpha = 0.3  # 滤波系数 (0-1之间，越小滤波效果越强)\n        self.deadzone = 0.05  # 死区阈值\n\n        # 添加上一次的状态值用于滤波\n        self.lx_prev = 0\n        self.ly_prev = 0\n        self.rx_prev = 0\n        self.ry_prev = 0\n        self.smooth = 0.1  # 平滑系数，值越小越平滑\n        self.dead_zone = 0.01  # 死区阈值\n\n        # 速度映射参数\n        self.max_linear_speed_x = 0.5  # 最大线速度 (m/s)\n        self.max_linear_speed_y = 0.2  # 最大线速度 (m/s)\n        self.max_angular_speed = 0.7  # 最大角速度 (rad/s)\n\n        # 速度阈值参数 - 当速度小于阈值时设为0\n        self.velocity_threshold_x = 0.1  # 前进/后退速度阈值 (m/s)\n        self.velocity_threshold_y = 0.1  # 左右平移速度阈值 (m/s)\n        self.velocity_threshold_yaw = 0.1  # 转向角速度阈值 (rad/s)\n\n    def apply_filter_and_deadzone(self, value, prev_value):\n        # 结合死区判断和平滑处理\n        if abs(value) < self.dead_zone:\n            value = 0.0\n        return prev_value * (1 - self.smooth) + value * self.smooth\n\n    def set(self, data):\n        # wireless_remote\n        keys = struct.unpack(\"H\", data[2:4])[0]\n        for i in range(16):\n            self.button[i] = (keys & (1 << i)) >> i\n\n        # 读取原始值\n        lx_raw = struct.unpack(\"f\", data[4:8])[0]\n        rx_raw = struct.unpack(\"f\", data[8:12])[0]\n        ry_raw = struct.unpack(\"f\", data[12:16])[0]\n        ly_raw = struct.unpack(\"f\", data[20:24])[0]\n\n        # 应用滤波和死区\n        self.lx = self.apply_filter_and_deadzone(lx_raw, self.lx_prev)\n        self.ly = self.apply_filter_and_deadzone(ly_raw, self.ly_prev)\n        self.rx = self.apply_filter_and_deadzone(rx_raw, self.rx_prev)\n        self.ry = self.apply_filter_and_deadzone(ry_raw, self.ry_prev)\n\n        # 更新前一次的值\n        self.lx_prev = self.lx\n        self.ly_prev = self.ly\n        self.rx_prev = self.rx\n        self.ry_prev = self.ry\n\n    def get_velocity_commands(self):\n        \"\"\"\n        将摇杆值转换为速度命令\n        Returns:\n            tuple: (vx, vy, vyaw)\n            - vx: 前进/后退速度 (m/s)，由左摇杆前后(ly)控制\n            - vy: 左右平移速度 (m/s)，由左摇杆左右(lx)控制\n            - vyaw: 转向角速度 (rad/s)，由右摇杆左右(rx)控制\n        \"\"\"\n        # 前进/后退速度，使用左摇杆的y轴\n        vx = (\n            self.ly * self.max_linear_speed_x\n        )  # 注意：通常需要取反，因为向前推摇杆时ly为负\n\n        # 限制x速度最小值为-0.5\n        if vx < -0.5:\n            vx = -0.5\n\n        # 左右平移速度，使用左摇杆的x轴\n        vy = (\n            -self.lx * self.max_linear_speed_y\n        )  # 注意：可能需要取反，取决于坐标系定义\n\n        # 转向角速度，使用右摇杆的x轴\n        vyaw = (\n            -self.rx * self.max_angular_speed\n        )  # 注意：可能需要取反，取决于坐标系定义\n\n        # 应用速度阈值 - 当速度小于阈值时设为0\n        if abs(vx) < self.velocity_threshold_x:\n            vx = 0.0\n        if abs(vy) < self.velocity_threshold_y:\n            vy = 0.0\n        if abs(vyaw) < self.velocity_threshold_yaw:\n            vyaw = 0.0\n\n        return vx, vy, vyaw\n"
  },
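`apply_filter_and_deadzone` is a first-order low-pass, `y <- (1 - smooth) * y_prev + smooth * x`, applied after the deadzone check, so with `smooth = 0.1` a full stick deflection converges geometrically toward 1.0 with a time constant of roughly 10 ticks. A quick step-response sketch:

```python
from humanoid_policy.utils.remote_controller_filter import RemoteController

rc = RemoteController()
value = 0.0
for tick in range(30):
    value = rc.apply_filter_and_deadzone(1.0, value)  # full stick deflection
    if tick in (0, 9, 29):
        print(f"tick {tick:2d}: {value:.3f}")
# tick  0: 0.100, tick  9: 0.651, tick 29: 0.958 (geometric approach to 1.0)
```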
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotation_helper.py",
    "content": "import numpy as np\nfrom scipy.spatial.transform import Rotation as R\n\n\ndef get_gravity_orientation(quaternion):\n    qw = quaternion[0]\n    qx = quaternion[1]\n    qy = quaternion[2]\n    qz = quaternion[3]\n\n    gravity_orientation = np.zeros(3)\n\n    gravity_orientation[0] = 2 * (-qz * qx + qw * qy)\n    gravity_orientation[1] = -2 * (qz * qy + qw * qx)\n    gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)\n\n    return gravity_orientation\n\n\ndef transform_imu_data(waist_yaw, waist_yaw_omega, imu_quat, imu_omega):\n    RzWaist = R.from_euler(\"z\", waist_yaw).as_matrix()\n    R_torso = R.from_quat(\n        [imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]\n    ).as_matrix()\n    R_pelvis = np.dot(R_torso, RzWaist.T)\n    w = np.dot(RzWaist, imu_omega[0]) - np.array([0, 0, waist_yaw_omega])\n    return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w\n"
  },
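`get_gravity_orientation` is the closed form of `R^T . [0, 0, -1]`, i.e. the world gravity direction expressed in the body frame for a w-first IMU quaternion. A cross-check against scipy (which uses xyzw ordering):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R
from humanoid_policy.utils.rotation_helper import get_gravity_orientation

rng = np.random.default_rng(0)
quat_wxyz = rng.normal(size=4)
quat_wxyz /= np.linalg.norm(quat_wxyz)

closed_form = get_gravity_orientation(quat_wxyz)

# scipy takes xyzw; inv().apply() maps world-frame vectors into the body frame.
w, x, y, z = quat_wxyz
reference = R.from_quat([x, y, z, w]).inv().apply([0.0, 0.0, -1.0])

assert np.allclose(closed_form, reference, atol=1e-8)
```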
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotations.py",
    "content": "import torch\nfrom torch import Tensor\nimport torch.nn.functional as F\nfrom humanoid_policy.utils.maths import (\n    normalize,\n    copysign,\n)\nfrom typing import Tuple\nimport numpy as np\nfrom typing import List, Optional\n\n\n@torch.jit.script\ndef quat_unit(a):\n    return normalize(a)\n\n\n@torch.jit.script\ndef quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:\n    shape = b.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 3)\n    if w_last:\n        xyz = a[:, :3]\n        w = a[:, 3:]\n    else:\n        xyz = a[:, 1:]\n        w = a[:, :1]\n    t = xyz.cross(b, dim=-1) * 2\n    return (b + w * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor:\n    quat_yaw = quat.clone().view(-1, 4)\n    quat_yaw[:, :2] = 0.0\n    quat_yaw = normalize(quat_yaw)\n    return quat_apply(quat_yaw, vec, w_last)\n\n\n@torch.jit.script\ndef wrap_to_pi(angles):\n    angles %= 2 * np.pi\n    angles -= 2 * np.pi * (angles > np.pi)\n    return angles\n\n\n@torch.jit.script\ndef quat_conjugate(a: Tensor, w_last: bool) -> Tensor:\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    if w_last:\n        return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)\n    else:\n        return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape)\n\n\n@torch.jit.script\ndef quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:\n    shape = b.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 3)\n    if w_last:\n        xyz = a[:, :3]\n        w = a[:, 3:]\n    else:\n        xyz = a[:, 1:]\n        w = a[:, :1]\n    t = xyz.cross(b, dim=-1) * 2\n    return (b + w * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    shape = q.shape\n    if w_last:\n        q_w = q[:, -1]\n        q_vec = q[:, :3]\n    else:\n        q_w = q[:, 0]\n        q_vec = q[:, 1:]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    shape = q.shape\n    if w_last:\n        q_w = q[:, -1]\n        q_vec = q[:, :3]\n    else:\n        q_w = q[:, 0]\n        q_vec = q[:, 1:]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a - b + c\n\n\n@torch.jit.script\ndef quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]:\n    \"\"\"\n    The (angle, axis) representation of the rotation. 
The axis is normalized to unit length.\n    The angle is guaranteed to be between [0, pi].\n    \"\"\"\n    if w_last:\n        w = x[..., -1]\n        axis = x[..., :3]\n    else:\n        w = x[..., 0]\n        axis = x[..., 1:]\n    s = 2 * (w**2) - 1\n    angle = s.clamp(-1, 1).arccos()  # just to be safe\n    axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)\n    return angle, axis\n\n\n@torch.jit.script\ndef quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor:\n    theta = (angle / 2).unsqueeze(-1)\n    xyz = normalize(axis) * theta.sin()\n    w = theta.cos()\n    if w_last:\n        return quat_unit(torch.cat([xyz, w], dim=-1))\n    else:\n        return quat_unit(torch.cat([w, xyz], dim=-1))\n\n\n@torch.jit.script\ndef vec_to_heading(h_vec):\n    h_theta = torch.atan2(h_vec[..., 1], h_vec[..., 0])\n    return h_theta\n\n\n@torch.jit.script\ndef heading_to_quat(h_theta, w_last: bool):\n    axis = torch.zeros(\n        h_theta.shape\n        + [\n            3,\n        ],\n        device=h_theta.device,\n    )\n    axis[..., 2] = 1\n    heading_q = quat_from_angle_axis(h_theta, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor:\n    basis_vec = torch.zeros(q.shape[0], 3, device=q.device)\n    basis_vec[:, axis] = 1\n    return quat_rotate(q, basis_vec, w_last)\n\n\n@torch.jit.script\ndef normalize_angle(x):\n    return torch.atan2(torch.sin(x), torch.cos(x))\n\n\n@torch.jit.script\ndef get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    return quat_rotate(q, v, w_last)\n\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # type: (Tensor) -> Tuple[Tensor, Tensor]\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    # ZL: could have issues.\n    min_theta = 1e-5\n    qx, qy, qz, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = torch.abs(sin_theta) > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    cos_half_theta = torch.sum(q0 * q1, dim=-1)\n\n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratioB = torch.sin(t * half_theta) / sin_half_theta\n\n    new_q = ratioA * q0 + ratioB * q1\n\n    new_q = torch.where(\n        torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q\n    )\n    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # type: (Tensor, Tensor) -> Tensor\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n\n@torch.jit.script\ndef my_quat_rotate(q, v):\n    
shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef calc_heading(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    # this is the x axis heading\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = my_quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # type: (Tensor) -> Tensor\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n\n@torch.jit.script\ndef calc_heading_quat(q, w_last):\n    # type: (Tensor, bool) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef calc_heading_quat_inv(q, w_last):\n    # type: (Tensor, bool) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef quat_inverse(x, w_last):\n    # type: (Tensor, bool) -> Tensor\n    \"\"\"\n    The inverse of the rotation\n    \"\"\"\n    return quat_conjugate(x, w_last=w_last)\n\n\n@torch.jit.script\ndef get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]:\n    if w_last:\n        qx, qy, qz, qw = 0, 1, 2, 3\n    else:\n        qw, qx, qy, qz = 0, 1, 2, 3\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n        q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)\n\n\n# @torch.jit.script\ndef get_euler_xyz_in_tensor(q):\n    qx, qy, qz, qw = 0, 1, 2, 3\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n        q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis 
rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return torch.stack((roll, pitch, yaw), dim=-1)\n\n\n@torch.jit.script\ndef quat_pos(x):\n    \"\"\"\n    make all the real part of the quaternion positive\n    \"\"\"\n    q = x\n    z = (q[..., 3:] < 0).float()\n    q = (1 - 2 * z) * q\n    return q\n\n\n@torch.jit.script\ndef is_valid_quat(q):\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n    return (w * w + x * x + y * y + z * z).allclose(torch.ones_like(w))\n\n\n@torch.jit.script\ndef quat_normalize(q):\n    \"\"\"\n    Construct 3D rotation from quaternion (the quaternion needs not to be normalized).\n    \"\"\"\n    q = quat_unit(quat_pos(q))  # normalized to positive and unit quaternion\n    return q\n\n\n@torch.jit.script\ndef quat_mul(a, b, w_last: bool):\n    assert a.shape == b.shape\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 4)\n\n    if w_last:\n        x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n    else:\n        w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    if w_last:\n        quat = torch.stack([x, y, z, w], dim=-1).view(shape)\n    else:\n        quat = torch.stack([w, x, y, z], dim=-1).view(shape)\n\n    return quat\n\n\n@torch.jit.script\ndef quat_mul_norm(x, y, w_last):\n    # type: (Tensor, Tensor, bool) -> Tensor\n    \"\"\"\n    Combine two sets of 3D rotations together using the ** operator. The shape needs to be\n    broadcastable\n    \"\"\"\n    return quat_unit(quat_mul(x, y, w_last))\n\n\n@torch.jit.script\ndef quat_identity(shape: List[int]):\n    \"\"\"\n    Construct 3D identity rotation given shape\n    \"\"\"\n    w = torch.ones(shape + [1])\n    xyz = torch.zeros(shape + [3])\n    q = torch.cat([xyz, w], dim=-1)\n    return quat_normalize(q)\n\n\n@torch.jit.script\ndef quat_identity_like(x):\n    \"\"\"\n    Construct identity 3D rotation with the same shape\n    \"\"\"\n    return quat_identity(x.shape[:-1])\n\n\n@torch.jit.script\ndef transform_from_rotation_translation(\n    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None\n):\n    \"\"\"\n    Construct a transform from a quaternion and 3D translation. 
Only one of them can be None.\n    \"\"\"\n    assert r is not None or t is not None, (\n        \"rotation and translation can't be all None\"\n    )\n    if r is None:\n        assert t is not None\n        r = quat_identity(list(t.shape))\n    if t is None:\n        t = torch.zeros(list(r.shape) + [3])\n    return torch.cat([r, t], dim=-1)\n\n\n@torch.jit.script\ndef transform_rotation(x):\n    \"\"\"Get rotation from transform\"\"\"\n    return x[..., :4]\n\n\n@torch.jit.script\ndef transform_translation(x):\n    \"\"\"Get translation from transform\"\"\"\n    return x[..., 4:]\n\n\n@torch.jit.script\ndef transform_mul(x, y):\n    \"\"\"\n    Combine two transformation together\n    \"\"\"\n    z = transform_from_rotation_translation(\n        r=quat_mul_norm(\n            transform_rotation(x), transform_rotation(y), w_last=True\n        ),\n        t=quat_rotate(\n            transform_rotation(x), transform_translation(y), w_last=True\n        )\n        + transform_translation(x),\n    )\n    return z\n\n\n##################################### FROM PHC rotation_conversions.py #####################################\n@torch.jit.script\ndef quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\n@torch.jit.script\ndef axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as axis/angle to quaternions.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n    half_angles = angles * 0.5\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    quaternions = torch.cat(\n        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles],\n        dim=-1,\n    )\n    return quaternions\n\n\n# @torch.jit.script\ndef wxyz_to_xyzw(quat):\n    return quat[..., [1, 2, 3, 0]]\n\n\n# @torch.jit.script\ndef xyzw_to_wxyz(quat):\n    return quat[..., [3, 0, 1, 2]]\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> 
torch.Tensor:\n    \"\"\"\n    w x y z\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack(\n                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1\n            ),\n            torch.stack(\n                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1\n            ),\n            torch.stack(\n                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1\n            ),\n            torch.stack(\n                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1\n            ),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,\n        :,  # pyre-ignore[16]\n    ].reshape(batch_dim + (4,))\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef quat_w_first(rot):\n    rot = torch.cat([rot[..., [-1]], rot[..., :-1]], -1)\n    return rot\n\n\n@torch.jit.script\ndef quat_from_euler_xyz(roll, pitch, yaw):\n    cy = torch.cos(yaw * 0.5)\n    sy = torch.sin(yaw * 0.5)\n    cr = torch.cos(roll * 0.5)\n    sr = torch.sin(roll * 0.5)\n    cp = torch.cos(pitch * 0.5)\n    sp = torch.sin(pitch * 0.5)\n\n    qw = cy * cr * cp + sy * sr * sp\n    qx = cy * sr * cp - sy * cr * sp\n    qy = cy * cr * sp + sy * sr * cp\n    qz = sy * cr * cp - cy * sr * sp\n\n    return torch.stack([qx, qy, qz, qw], dim=-1)\n"
  },
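The PHC-derived converters at the bottom of `rotations.py` all use the w-first convention. A round-trip sanity sketch through `axis_angle_to_quaternion`, `quaternion_to_matrix`, and `matrix_to_quaternion`, comparing up to quaternion sign since `q` and `-q` encode the same rotation:

```python
import torch
from humanoid_policy.utils.rotations import (
    axis_angle_to_quaternion,
    matrix_to_quaternion,
    quaternion_to_matrix,
)

torch.manual_seed(0)
axis_angle = torch.randn(16, 3) * 0.5

q = axis_angle_to_quaternion(axis_angle)          # (w, x, y, z), unit norm
q_back = matrix_to_quaternion(quaternion_to_matrix(q))

# matrix_to_quaternion may return -q (same rotation); compare up to sign.
sign = torch.sign((q * q_back).sum(dim=-1, keepdim=True))
assert torch.allclose(q, sign * q_back, atol=1e-5)
```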
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc.h",
    "content": "/*****************************************************************\n Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.\n******************************************************************/\n\n#ifndef _MOTOR_CRC_H_\n#define _MOTOR_CRC_H_\n\n#include <stdint.h>\n#include <array>\n#include \"rclcpp/rclcpp.hpp\"\n#include \"unitree_go/msg/low_cmd.hpp\"\n#include \"unitree_go/msg/motor_cmd.hpp\"\n#include \"unitree_go/msg/bms_cmd.hpp\"\n\nconstexpr int HIGHLEVEL = 0xee;\nconstexpr int LOWLEVEL = 0xff;\nconstexpr int TRIGERLEVEL = 0xf0;\nconstexpr double PosStopF = (2.146E+9f);\nconstexpr double VelStopF = (16000.0f);\n\n// joint index\nconstexpr int FR_0 = 0;      \nconstexpr int FR_1 = 1;      \nconstexpr int FR_2 = 2;\n\nconstexpr int FL_0 = 3;\nconstexpr int FL_1 = 4;\nconstexpr int FL_2 = 5;\n\nconstexpr int RR_0 = 6;\nconstexpr int RR_1 = 7;\nconstexpr int RR_2 = 8;\n\nconstexpr int RL_0 = 9;\nconstexpr int RL_1 = 10;\nconstexpr int RL_2 = 11;\n\n\ntypedef struct\n{\n\tuint8_t off; // off 0xA5\n\tstd::array<uint8_t, 3> reserve;\n} BmsCmd;\n\n\n\ntypedef struct\n{\n\tuint8_t mode; // desired working mode\n\tfloat q;\t  // desired angle (unit: radian)\n\tfloat dq;\t  // desired velocity (unit: radian/second)\n\tfloat tau;\t  // desired output torque (unit: N.m)\n\tfloat Kp;\t  // desired position stiffness (unit: N.m/rad )\n\tfloat Kd;\t  // desired velocity stiffness (unit: N.m/(rad/s) )\n\tstd::array<uint32_t, 3> reserve;\n} MotorCmd; // motor control\n\n\n\ntypedef struct\n{\n\tstd::array<uint8_t, 2> head;\n\tuint8_t levelFlag;\n\tuint8_t frameReserve;\n\t\t\n\tstd::array<uint32_t, 2> SN;\n\tstd::array<uint32_t, 2> version;\n\tuint16_t bandWidth;\n\tstd::array<MotorCmd, 20> motorCmd;\n\tBmsCmd bms;\n\tstd::array<uint8_t, 40> wirelessRemote;\n\tstd::array<uint8_t, 12> led;\n\tstd::array<uint8_t, 2> fan;\n\tuint8_t gpio;\n\tuint32_t reserve;\n\t\n\tuint32_t crc;\n} LowCmd;           \n\nuint32_t crc32_core(uint32_t* ptr, uint32_t len);\nvoid get_crc(unitree_go::msg::LowCmd& msg);\n\n\n\n\n\n#endif"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc_hg.h",
    "content": "/*****************************************************************\n Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.\n******************************************************************/\n\n#ifndef _MOTOR_CRC_H_\n#define _MOTOR_CRC_H_\n\n#include <stdint.h>\n#include <array>\n#include \"rclcpp/rclcpp.hpp\"\n#include \"unitree_hg/msg/low_cmd.hpp\"\n#include \"unitree_hg/msg/motor_cmd.hpp\"\n\ntypedef struct\n{\n\tuint8_t mode; // desired working mode\n\tfloat q;\t  // desired angle (unit: radian)\n\tfloat dq;\t  // desired velocity (unit: radian/second)\n\tfloat tau;\t  // desired output torque (unit: N.m)\n\tfloat Kp;\t  // desired position stiffness (unit: N.m/rad )\n\tfloat Kd;\t  // desired velocity stiffness (unit: N.m/(rad/s) )\n\tuint32_t reserve = 0;\n} MotorCmd; // motor control\n\ntypedef struct\n{\n\tuint8_t modePr;\n\tuint8_t modeMachine;\n\tstd::array<MotorCmd, 35> motorCmd;\n\tstd::array<uint32_t, 4> reserve;\n\tuint32_t crc;\n} LowCmd;\n\nuint32_t crc32_core(uint32_t *ptr, uint32_t len);\nvoid get_crc(unitree_hg::msg::LowCmd &msg);\n\n#endif"
  },
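Both `get_crc` implementations in this package hash `(sizeof(LowCmd) >> 2) - 1` 32-bit words, i.e. the entire packed command except its trailing `crc` field. A ctypes mirror of the hg structs makes that arithmetic concrete; this is a sketch that assumes the target compiler uses natural alignment, which ctypes also applies by default:

```python
import ctypes

class MotorCmd(ctypes.Structure):
    _fields_ = [
        ("mode", ctypes.c_uint8),   # + 3 bytes padding before the first float
        ("q", ctypes.c_float),
        ("dq", ctypes.c_float),
        ("tau", ctypes.c_float),
        ("Kp", ctypes.c_float),
        ("Kd", ctypes.c_float),
        ("reserve", ctypes.c_uint32),
    ]

class LowCmd(ctypes.Structure):
    _fields_ = [
        ("modePr", ctypes.c_uint8),
        ("modeMachine", ctypes.c_uint8),  # + 2 bytes padding
        ("motorCmd", MotorCmd * 35),
        ("reserve", ctypes.c_uint32 * 4),
        ("crc", ctypes.c_uint32),
    ]

assert ctypes.sizeof(MotorCmd) == 28
assert ctypes.sizeof(LowCmd) == 1004
# crc32_core(ptr, (sizeof(LowCmd) >> 2) - 1) therefore hashes 250 words:
print((ctypes.sizeof(LowCmd) >> 2) - 1)  # 250
```

With 35 `MotorCmd` entries of 28 bytes each, the hg `LowCmd` comes to 1004 bytes, so the CRC runs over the first 250 words and the 251st word receives the result.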
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/include/common/ros2_sport_client.h",
    "content": "#ifndef _ROS2_SPORT_CLIENT_\n#define _ROS2_SPORT_CLIENT_\n#include<iostream>\n#include \"nlohmann/json.hpp\"\n#include \"unitree_api/msg/request.hpp\"\n\n\n#pragma pack(1)\nconst int32_t ROBOT_SPORT_API_ID_DAMP = 1001;\nconst int32_t ROBOT_SPORT_API_ID_BALANCESTAND = 1002;\nconst int32_t ROBOT_SPORT_API_ID_STOPMOVE = 1003;\nconst int32_t ROBOT_SPORT_API_ID_STANDUP = 1004;\nconst int32_t ROBOT_SPORT_API_ID_STANDDOWN = 1005;\nconst int32_t ROBOT_SPORT_API_ID_RECOVERYSTAND = 1006;\nconst int32_t ROBOT_SPORT_API_ID_EULER = 1007;\nconst int32_t ROBOT_SPORT_API_ID_MOVE = 1008;\nconst int32_t ROBOT_SPORT_API_ID_SIT = 1009;\nconst int32_t ROBOT_SPORT_API_ID_RISESIT = 1010;\nconst int32_t ROBOT_SPORT_API_ID_SWITCHGAIT = 1011;\nconst int32_t ROBOT_SPORT_API_ID_TRIGGER = 1012;\nconst int32_t ROBOT_SPORT_API_ID_BODYHEIGHT = 1013;\nconst int32_t ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT = 1014;\nconst int32_t ROBOT_SPORT_API_ID_SPEEDLEVEL = 1015;\nconst int32_t ROBOT_SPORT_API_ID_HELLO = 1016;\nconst int32_t ROBOT_SPORT_API_ID_STRETCH = 1017;\nconst int32_t ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW = 1018;\nconst int32_t ROBOT_SPORT_API_ID_CONTINUOUSGAIT = 1019;\nconst int32_t ROBOT_SPORT_API_ID_CONTENT = 1020;\nconst int32_t ROBOT_SPORT_API_ID_WALLOW = 1021;\nconst int32_t ROBOT_SPORT_API_ID_DANCE1 = 1022;\nconst int32_t ROBOT_SPORT_API_ID_DANCE2 = 1023;\nconst int32_t ROBOT_SPORT_API_ID_GETBODYHEIGHT = 1024;\nconst int32_t ROBOT_SPORT_API_ID_GETFOOTRAISEHEIGHT = 1025;\nconst int32_t ROBOT_SPORT_API_ID_GETSPEEDLEVEL = 1026;\nconst int32_t ROBOT_SPORT_API_ID_SWITCHJOYSTICK = 1027;\nconst int32_t ROBOT_SPORT_API_ID_POSE = 1028;\nconst int32_t ROBOT_SPORT_API_ID_SCRAPE = 1029;\nconst int32_t ROBOT_SPORT_API_ID_FRONTFLIP = 1030;\nconst int32_t ROBOT_SPORT_API_ID_FRONTJUMP = 1031;\nconst int32_t ROBOT_SPORT_API_ID_FRONTPOUNCE = 1032;\n\ntypedef struct\n{\n    float timeFromStart;\n    float x;\n    float y;\n    float yaw;\n    float vx;\n    float vy;\n    float vyaw;\n} PathPoint;\n\nclass SportClient\n{\npublic:\n    /*\n     * @brief Damp\n     * @api: 1001\n     */\n    void Damp(unitree_api::msg::Request &req);\n\n    /*\n     * @brief BalanceStand\n     * @api: 1002\n     */\n    void BalanceStand(unitree_api::msg::Request &req);\n\n    /*\n     * @brief StopMove\n     * @api: 1003\n     */\n    void StopMove(unitree_api::msg::Request &req);\n\n    /*\n     * @brief StandUp\n     * @api: 1004\n     */\n    void StandUp(unitree_api::msg::Request &req);\n\n    /*\n     * @brief StandDown\n     * @api: 1005\n     */\n    void StandDown(unitree_api::msg::Request &req);\n\n    /*\n     * @brief RecoveryStand\n     * @api: 1006\n     */\n    void RecoveryStand(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Euler\n     * @api: 1007\n     */\n    void Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw);\n\n    /*\n     * @brief Move\n     * @api: 1008\n     */\n    void Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw);\n\n    /*\n     * @brief Sit\n     * @api: 1009\n     */\n    void Sit(unitree_api::msg::Request &req);\n\n    /*\n     * @brief RiseSit\n     * @api: 1010\n     */\n    void RiseSit(unitree_api::msg::Request &req);\n\n    /*\n     * @brief SwitchGait\n     * @api: 1011\n     */\n    void SwitchGait(unitree_api::msg::Request &req, int d);\n\n    /*\n     * @brief Trigger\n     * @api: 1012\n     */\n    void Trigger(unitree_api::msg::Request &req);\n\n    /*\n     * @brief BodyHeight\n     * @api: 1013\n     */\n    void 
BodyHeight(unitree_api::msg::Request &req, float height);\n\n    /*\n     * @brief FootRaiseHeight\n     * @api: 1014\n     */\n    void FootRaiseHeight(unitree_api::msg::Request &req, float height);\n\n    /*\n     * @brief SpeedLevel\n     * @api: 1015\n     */\n    void SpeedLevel(unitree_api::msg::Request &req, int level);\n\n    /*\n     * @brief Hello\n     * @api: 1016\n     */\n    void Hello(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Stretch\n     * @api: 1017\n     */\n    void Stretch(unitree_api::msg::Request &req);\n\n    /*\n     * @brief TrajectoryFollow\n     * @api: 1018\n     */\n    void TrajectoryFollow(unitree_api::msg::Request &req, std::vector<PathPoint> &path);\n\n    /*\n     * @brief SwitchJoystick\n     * @api: 1027\n     */\n    void SwitchJoystick(unitree_api::msg::Request &req, bool flag);\n\n    /*\n     * @brief ContinuousGait\n     * @api: 1019\n     */\n    void ContinuousGait(unitree_api::msg::Request &req, bool flag);\n\n    /*\n     * @brief Wallow\n     * @api: 1021\n     */\n    void Wallow(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Content\n     * @api: 1020\n     */\n    void Content(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Pose\n     * @api: 1028\n     */\n    void Pose(unitree_api::msg::Request &req, bool flag);\n\n    /*\n     * @brief Scrape\n     * @api: 1029\n     */\n    void Scrape(unitree_api::msg::Request &req);\n\n    /*\n     * @brief FrontFlip\n     * @api: 1030\n     */\n    void FrontFlip(unitree_api::msg::Request &req);\n\n    /*\n     * @brief FrontJump\n     * @api: 1031\n     */\n    void FrontJump(unitree_api::msg::Request &req);\n\n    /*\n     * @brief FrontPounce\n     * @api: 1032\n     */\n    void FrontPounce(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Dance1\n     * @api: 1022\n     */\n    void Dance1(unitree_api::msg::Request &req);\n\n    /*\n     * @brief Dance2\n     * @api: 1023\n     */\n    void Dance2(unitree_api::msg::Request &req);\n};\n\n#endif"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/include/common/wireless_controller.h",
    "content": "#pragma once\n\n#include \"unitree_go/msg/wireless_controller.hpp\"\n#include <array>\n#include <cstring>\n\nclass KeyMap {\npublic:\n    static const int R1;\n    static const int L1;\n    static const int start;\n    static const int select;\n    static const int R2;\n    static const int L2;\n    static const int F1;\n    static const int F2;\n    static const int A;\n    static const int B;\n    static const int X;\n    static const int Y;\n    static const int up;\n    static const int right;\n    static const int down;\n    static const int left;\n};\n\nclass RemoteController {\npublic:\n    // Constructor\n    RemoteController();\n\n    // Add overloaded set method for raw data\n    void set(const std::array<unsigned char, 40>& data);\n    \n    // Keep original method for compatibility\n    void set(const unitree_go::msg::WirelessController::SharedPtr msg);\n\n    // Member variables\n    double lx;\n    double ly;\n    double rx;\n    double ry;\n    int button[16];\n};\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py",
    "content": "\"\"\"\nHoloMotion ROS2 Launch Configuration\n\nThis module defines the ROS2 launch configuration for the HoloMotion humanoid robot control system.\nIt sets up a complete robotics pipeline including robot control, motion policy execution, and data recording\nfor the Unitree G1 humanoid robot.\n\nThe launch file coordinates three main components:\n1. Main control node (C++) - Handles low-level robot control and communication\n2. Policy node (Python) - Executes motion policies and high-level decision making\n3. Recording node - Captures sensor data and commands for analysis\n\nKey Features:\n- Configures network interface for robot communication\n- Sets up CycloneDDS middleware with specific network interface\n- Launches coordinated multi-node system with shared configuration\n- Automatically records operational data with timestamped bags\n\nAuthor: HoloMotion Team\nLicense: See project LICENSE file\n\"\"\"\n\nfrom datetime import datetime\nimport os\n\nfrom launch import LaunchDescription\nfrom launch.actions import SetEnvironmentVariable, DeclareLaunchArgument\nfrom launch.substitutions import LaunchConfiguration\nfrom launch_ros.actions import Node\nfrom ament_index_python.packages import get_package_share_directory\nfrom launch.actions import ExecuteProcess\nfrom launch.conditions import IfCondition\n\n\ndef generate_launch_description():\n    \"\"\"\n    Generate the complete launch description for the HoloMotion humanoid control system.\n\n    This function creates a comprehensive ROS2 launch configuration that coordinates\n    multiple nodes required for humanoid robot operation. It sets up the necessary\n    environment, launches control nodes, and optionally initiates data recording.\n\n    Network Configuration:\n        - Uses specific network interface (eth0) for robot communication\n        - Configures CycloneDDS middleware to use designated network interface\n        - Ensures proper isolation and communication with the robot hardware\n\n    Node Architecture:\n        1. Main Control Node (C++):\n           - Handles real-time robot control and sensor data processing\n           - Manages low-level motor commands and feedback loops\n           - Interfaces directly with robot hardware via configured network\n\n        2. Policy Node (Python):\n           - Executes trained motion policies for humanoid locomotion\n           - Processes high-level commands and translates to robot actions\n           - Handles motion planning and behavior coordination\n\n        3. 
Recording Node (Optional):\n           - Automatically captures all relevant system data when enabled\n           - Records sensor states, commands, and system metrics\n           - Creates timestamped bag files for later analysis\n\n    Configuration:\n        - Robot: Unitree G1 with 29 DOF configuration\n        - Config file: g1_29dof_holomotion.yaml\n        - Recording format: MCAP for efficient data storage\n        - Recording: Disabled by default, can be enabled with --record parameter\n\n    Recorded Topics (when recording enabled):\n        - /lowcmd: Low-level motor commands sent to robot\n        - /lowstate: Robot sensor feedback and joint states\n        - /humanoid/action: High-level action commands from policy\n\n    Parameters:\n        - enable_recording: Boolean flag to enable/disable topic recording (default: false)\n\n    Returns:\n        LaunchDescription: Complete ROS2 launch configuration with all nodes,\n                          environment variables, and optional recording setup\n\n    Raises:\n        FileNotFoundError: If the configuration file cannot be located\n        PermissionError: If unable to create recording directory\n\n    Example:\n        Launch without recording (default):\n        $ ros2 launch humanoid_control holomotion_29dof.launch.py\n\n        Launch with recording enabled:\n        $ ros2 launch humanoid_control holomotion_29dof.launch.py enable_recording:=true\n\n        Or using the shell script:\n        $ ./launch_holomotion_29dof.sh --record\n    \"\"\"\n    # Declare launch arguments\n    enable_recording_arg = DeclareLaunchArgument(\n        \"enable_recording\",\n        default_value=\"false\",\n        description=\"Enable topic recording (true/false)\",\n    )\n\n    network_interface = \"eth0\"\n    config_name = \"g1_29dof_holomotion.yaml\"\n\n    pkg_dir = get_package_share_directory(\"humanoid_control\")\n    config_file = os.path.join(pkg_dir, \"config\", config_name)\n    # Allow overriding python interpreter via env var (set by the shell script)\n\n    python_executable = os.environ[\"Deploy_CONDA_PREFIX\"] + \"/bin/python\"\n    print(f\"Using Python executable: {python_executable}\")\n\n    return LaunchDescription(\n        [\n            # Declare launch arguments\n            enable_recording_arg,\n            # Main control node (C++)\n            SetEnvironmentVariable(\n                name=\"CYCLONEDDS_URI\",\n                value=f\"<CycloneDDS><Domain><General><NetworkInterfaceAddress>{network_interface}</NetworkInterfaceAddress></General></Domain></CycloneDDS>\",\n            ),\n            Node(\n                package=\"humanoid_control\",\n                executable=\"humanoid_control\",\n                name=\"main_node\",\n                parameters=[{\"config_path\": config_file}],\n                output=\"screen\",\n            ),\n            # Policy node (Python)\n            Node(\n                package=\"humanoid_control\",\n                executable=\"policy_node_29dof\",\n                name=\"policy_node\",\n                parameters=[{\"config_path\": config_file}],\n                output=\"screen\",\n                prefix=python_executable,\n            ),\n            # Recording node (conditional)\n            ExecuteProcess(\n                cmd=[\n                    \"ros2\",\n                    \"bag\",\n                    \"record\",\n                    \"--storage\",\n                    \"mcap\",\n                    \"-o\",\n                    (\n                      
  \"./bag_record/\"\n                        + datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n                        + \"_\"\n                        + config_name.split(\".\")[0]\n                    ),\n                    \"/lowcmd\",\n                    \"/lowstate\",\n                    \"/humanoid/action\",\n                ],\n                output=\"screen\",\n                condition=IfCondition(LaunchConfiguration(\"enable_recording\")),\n            ),\n        ]\n    )\n"
  },
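For reference, the two strings this launch file assembles, shown in isolation (illustrative only; the values are the ones hard-coded above):

```python
from datetime import datetime

network_interface = "eth0"
config_name = "g1_29dof_holomotion.yaml"

# Pins CycloneDDS traffic to the robot-facing interface:
cyclonedds_uri = (
    "<CycloneDDS><Domain><General>"
    f"<NetworkInterfaceAddress>{network_interface}</NetworkInterfaceAddress>"
    "</General></Domain></CycloneDDS>"
)

# Timestamped bag output, e.g. ./bag_record/2024-01-31_12-00-00_g1_29dof_holomotion
bag_dir = (
    "./bag_record/"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    + "_"
    + config_name.split(".")[0]
)
print(cyclonedds_uri, bag_dir, sep="\n")
```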
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/models/.gitkeep",
    "content": ""
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/motion_data/.gitkeep",
    "content": ""
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/package.xml",
    "content": "<?xml version=\"1.0\"?>\n<?xml-model href=\"http://download.ros.org/schema/package_format3.xsd\" schematypens=\"http://www.w3.org/2001/XMLSchema\"?>\n<package format=\"3\">\n  <name>humanoid_control</name>\n  <version>0.0.0</version>\n  <description>Humanoid locomotion control package from Horizon Robotics</description>\n  <maintainer email=\"unitree@todo.todo\">unitree</maintainer>\n  <license>TODO: License declaration</license>\n\n  <!-- Build tool dependencies -->\n  <buildtool_depend>ament_cmake</buildtool_depend>\n  <buildtool_depend>ament_cmake_python</buildtool_depend>\n\n  <!-- C++ dependencies -->\n  <depend>rclcpp</depend>\n  <depend>sensor_msgs</depend>\n  <depend>unitree_hg</depend>\n\n  <!-- Python dependencies -->\n  <depend>rclpy</depend>\n  <depend>python3-numpy</depend>\n  <depend>python3-torch</depend>\n  <depend>python3-yaml</depend>\n\n  <exec_depend>ros2launch</exec_depend>\n\n  <!-- Keep Python test dependencies -->\n  <test_depend>ament_copyright</test_depend>\n  <test_depend>ament_flake8</test_depend>\n  <test_depend>ament_pep257</test_depend>\n  <test_depend>python3-pytest</test_depend>\n\n  <export>\n    <build_type>ament_cmake</build_type>\n  </export>\n</package>\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/resource/humanoid_control",
    "content": ""
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/setup.cfg",
    "content": "[develop]\nscript_dir=$base/lib/humanoid_control\n[install]\ninstall_scripts=$base/lib/humanoid_control\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/setup.py",
    "content": "from setuptools import setup, find_packages\nimport os\n\n\npackage_name = \"humanoid_control\"\n\ndata_files = [\n    (\n        \"share/ament_index/resource_index/packages\",\n        [\"resource/\" + package_name],\n    ),\n    (\"share/\" + package_name, [\"package.xml\"]),\n]\n# Add files from config, launch and model directories\nfor dir_name in [\"config\", \"launch\", \"models\"]:\n    if os.path.exists(dir_name):  # Only process if directory exists\n        for root, dirs, files in os.walk(dir_name):\n            install_dir = os.path.join(\"share\", package_name, root)\n            list_entry = (install_dir, [os.path.join(root, f) for f in files])\n            data_files.append(list_entry)\n\nsetup(\n    name=package_name,\n    version=\"0.0.1\",\n    packages=find_packages(),\n    data_files=data_files,\n    install_requires=[\"setuptools\"],\n    zip_safe=True,\n    maintainer=\"Horizon Robotics\",\n    maintainer_email=\"maiyue01.chen@horizon.auto\",\n    description=\"Humanoid locomotion control package from Horizon Robotics\",\n    license=\"Apache License 2.0\",\n    tests_require=[\"pytest\"],\n    entry_points={\n        \"console_scripts\": [\n            \"policy_node_performance = humanoid_policy.policy_node_performance:main\",\n        ],\n    },\n)\n"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc.cpp",
    "content": "#include \"motor_crc.h\"\n\n\nvoid get_crc(unitree_go::msg::LowCmd& msg)\n{\n    LowCmd raw{};\n    memcpy(&raw.head[0], &msg.head[0], 2);\n\n    raw.levelFlag=msg.level_flag;\n    raw.frameReserve=msg.frame_reserve;\n\n    memcpy(&raw.SN[0],&msg.sn[0], 8);\n    memcpy(&raw.version[0], &msg.version[0], 8);\n\n\traw.bandWidth=msg.bandwidth;\n\n\n    for(int i = 0; i<20; i++)\n    {\n        raw.motorCmd[i].mode=msg.motor_cmd[i].mode;\n        raw.motorCmd[i].q=msg.motor_cmd[i].q;\n        raw.motorCmd[i].dq=msg.motor_cmd[i].dq;\n        raw.motorCmd[i].tau=msg.motor_cmd[i].tau;\n        raw.motorCmd[i].Kp=msg.motor_cmd[i].kp;\n        raw.motorCmd[i].Kd=msg.motor_cmd[i].kd;\n\n        memcpy(&raw.motorCmd[i].reserve[0], &msg.motor_cmd[i].reserve[0], 12);\n    }\n\n    raw.bms.off=msg.bms_cmd.off;\n    memcpy(&raw.bms.reserve[0],&msg.bms_cmd.reserve[0],  3);\n\n\n    memcpy(&raw.wirelessRemote[0], &msg.wireless_remote[0], 40);\n\n    memcpy(&raw.led[0], &msg.led[0],  12);  // go2\n    memcpy(&raw.fan[0], &msg.fan[0],  2);\n    raw.gpio=msg.gpio;    // go2\n\n\traw.reserve=msg.reserve;\n\n    raw.crc=crc32_core((uint32_t *)&raw, (sizeof(LowCmd)>>2)-1);\n    msg.crc=raw.crc;\n\n    \n}\n\n\nuint32_t crc32_core(uint32_t* ptr, uint32_t len)\n{\n    uint32_t xbit = 0;\n    uint32_t data = 0;\n    uint32_t CRC32 = 0xFFFFFFFF;\n    const uint32_t dwPolynomial = 0x04c11db7;\n    for (uint32_t i = 0; i < len; i++)\n    {\n        xbit = 1 << 31;\n        data = ptr[i];\n        for (uint32_t bits = 0; bits < 32; bits++)\n        {\n            if (CRC32 & 0x80000000)\n            {\n                CRC32 <<= 1;\n                CRC32 ^= dwPolynomial;\n            }\n            else\n                CRC32 <<= 1;\n            if (data & xbit)\n                CRC32 ^= dwPolynomial;\n\n            xbit >>= 1;\n        }\n    }\n    return CRC32;\n}"
  },
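`crc32_core` is a bit-serial CRC-32 with polynomial 0x04C11DB7, initial value 0xFFFFFFFF, no bit reflection and no final XOR, processed over 32-bit words MSB-first (the MPEG-2 variant). A Python port, handy for verifying recorded `/lowcmd` frames offline; it assumes the robot's CPU is little-endian so the struct's words unpack with `<`:

```python
import struct

def crc32_core(words):
    """Port of the C++ crc32_core above: CRC-32/MPEG-2 over 32-bit words."""
    poly = 0x04C11DB7
    crc = 0xFFFFFFFF
    for word in words:
        for bit in range(31, -1, -1):
            top = crc & 0x80000000
            crc = ((crc << 1) ^ poly if top else crc << 1) & 0xFFFFFFFF
            if word & (1 << bit):
                crc ^= poly
    return crc

def lowcmd_crc(raw: bytes) -> int:
    """Hash a raw LowCmd buffer, skipping its trailing crc word,
    exactly as get_crc does with (sizeof(LowCmd) >> 2) - 1."""
    n_words = len(raw) // 4 - 1
    return crc32_core(struct.unpack_from(f"<{n_words}I", raw))
```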
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc_hg.cpp",
    "content": "#include \"motor_crc_hg.h\"\n\nvoid get_crc(unitree_hg::msg::LowCmd &msg)\n{\n    LowCmd raw{};\n\n    raw.modePr = msg.mode_pr;\n    raw.modeMachine = msg.mode_machine;\n\n    for (int i = 0; i < 35; i++)\n    {\n        raw.motorCmd[i].mode = msg.motor_cmd[i].mode;\n        raw.motorCmd[i].q = msg.motor_cmd[i].q;\n        raw.motorCmd[i].dq = msg.motor_cmd[i].dq;\n        raw.motorCmd[i].tau = msg.motor_cmd[i].tau;\n        raw.motorCmd[i].Kp = msg.motor_cmd[i].kp;\n        raw.motorCmd[i].Kd = msg.motor_cmd[i].kd;\n\n        raw.motorCmd[i].reserve = msg.motor_cmd[i].reserve;\n    }\n\n    memcpy(&raw.reserve[0], &msg.reserve[0], 4);\n\n    raw.crc = crc32_core((uint32_t *)&raw, (sizeof(LowCmd) >> 2) - 1);\n    msg.crc = raw.crc;\n}\n\nuint32_t crc32_core(uint32_t *ptr, uint32_t len)\n{\n    uint32_t xbit = 0;\n    uint32_t data = 0;\n    uint32_t CRC32 = 0xFFFFFFFF;\n    const uint32_t dwPolynomial = 0x04c11db7;\n    for (uint32_t i = 0; i < len; i++)\n    {\n        xbit = 1 << 31;\n        data = ptr[i];\n        for (uint32_t bits = 0; bits < 32; bits++)\n        {\n            if (CRC32 & 0x80000000)\n            {\n                CRC32 <<= 1;\n                CRC32 ^= dwPolynomial;\n            }\n            else\n                CRC32 <<= 1;\n            if (data & xbit)\n                CRC32 ^= dwPolynomial;\n\n            xbit >>= 1;\n        }\n    }\n    return CRC32;\n}"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/src/common/ros2_sport_client.cpp",
    "content": "#include \"ros2_sport_client.h\"\n\nvoid SportClient::Damp(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_DAMP;\n}\n\nvoid SportClient::BalanceStand(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_BALANCESTAND;\n}\n\nvoid SportClient::StopMove(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_STOPMOVE;\n}\n\nvoid SportClient::StandUp(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDUP;\n}\n\nvoid SportClient::StandDown(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDDOWN;\n}\n\nvoid SportClient::RecoveryStand(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_RECOVERYSTAND;\n}\n\nvoid SportClient::Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw)\n{\n    nlohmann::json js;\n    js[\"x\"] = roll;\n    js[\"y\"] = pitch;\n    js[\"z\"] = yaw;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_EULER;\n}\n\nvoid SportClient::Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw)\n{\n    nlohmann::json js;\n    js[\"x\"] = vx;\n    js[\"y\"] = vy;\n    js[\"z\"] = vyaw;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_MOVE;\n}\n\nvoid SportClient::Sit(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_SIT;\n}\n\nvoid SportClient::RiseSit(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_RISESIT;\n}\n\nvoid SportClient::SwitchGait(unitree_api::msg::Request &req, int d)\n{\n    nlohmann::json js;\n    js[\"data\"] = d;\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHGAIT;\n    req.parameter = js.dump();\n}\n\nvoid SportClient::Trigger(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_TRIGGER;\n}\n\nvoid SportClient::BodyHeight(unitree_api::msg::Request &req, float height)\n{\n    nlohmann::json js;\n    js[\"data\"] = height;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_BODYHEIGHT;\n}\n\nvoid SportClient::FootRaiseHeight(unitree_api::msg::Request &req, float height)\n{\n    nlohmann::json js;\n    js[\"data\"] = height;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT;\n}\n\nvoid SportClient::SpeedLevel(unitree_api::msg::Request &req, int level)\n{\n    nlohmann::json js;\n    js[\"data\"] = level;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_SPEEDLEVEL;\n}\n\nvoid SportClient::Hello(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_HELLO;\n}\n\nvoid SportClient::Stretch(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_STRETCH;\n}\n\n\nvoid SportClient::TrajectoryFollow(unitree_api::msg::Request &req, std::vector<PathPoint> &path)\n{\n    nlohmann::json js_path;\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW;\n    for (int i = 0; i < 30; i++)\n    {\n        nlohmann::json js_point;\n        js_point[\"t_from_start\"] = path[i].timeFromStart;\n        js_point[\"x\"] = path[i].x;\n        js_point[\"y\"] = path[i].y;\n        js_point[\"yaw\"] = path[i].yaw;\n        js_point[\"vx\"] = path[i].vx;\n        js_point[\"vy\"] = path[i].vy;\n        js_point[\"vyaw\"] = path[i].vyaw;\n     
   js_path.push_back(js_point);\n    }\n    req.parameter =js_path.dump();\n}\n\nvoid SportClient::SwitchJoystick(unitree_api::msg::Request &req, bool flag)\n{\n    nlohmann::json js;\n    js[\"data\"] = flag;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHJOYSTICK;\n}\n\nvoid SportClient::ContinuousGait(unitree_api::msg::Request &req, bool flag)\n{\n    nlohmann::json js;\n    js[\"data\"] = flag;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTINUOUSGAIT;\n}\n\nvoid SportClient::Wallow(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_WALLOW;\n}\n\nvoid SportClient::Content(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTENT;\n}\n\nvoid SportClient::Pose(unitree_api::msg::Request &req, bool flag)\n{\n    nlohmann::json js;\n    js[\"data\"] = flag;\n    req.parameter = js.dump();\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_POSE;\n}\n\nvoid SportClient::Scrape(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_SCRAPE;\n}\n\nvoid SportClient::FrontFlip(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTFLIP;\n}\n\nvoid SportClient::FrontJump(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTJUMP;\n}\n\nvoid SportClient::FrontPounce(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTPOUNCE;\n}\n\nvoid SportClient::Dance1(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_DANCE1;\n}\n\nvoid SportClient::Dance2(unitree_api::msg::Request &req)\n{\n    req.header.identity.api_id = ROBOT_SPORT_API_ID_DANCE2;\n}"
  },
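Every method above reduces to the same pattern: set `req.header.identity.api_id` and, for parameterized calls, serialize the arguments into `req.parameter` as a JSON string. A sketch of the `Move` payload in plain Python, with no ROS dependencies; `make_move_request` is a hypothetical helper mirroring the fields the C++ writes:

```python
import json

ROBOT_SPORT_API_ID_MOVE = 1008

def make_move_request(vx: float, vy: float, vyaw: float) -> dict:
    """What SportClient::Move writes into a unitree_api Request:
    the api_id plus a JSON-encoded parameter string."""
    return {
        "api_id": ROBOT_SPORT_API_ID_MOVE,
        "parameter": json.dumps({"x": vx, "y": vy, "z": vyaw}),
    }

print(make_move_request(0.3, 0.0, 0.1))
# {'api_id': 1008, 'parameter': '{"x": 0.3, "y": 0.0, "z": 0.1}'}
```

Note that `TrajectoryFollow` unconditionally serializes `path[0]` through `path[29]`, so callers must supply at least 30 `PathPoint`s or the loop reads past the end of the vector.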
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/src/common/wireless_controller.cpp",
    "content": "#include \"common/wireless_controller.h\"\n#include <cstring>\n\n// Define static constants\nconst int KeyMap::R1 = 0;\nconst int KeyMap::L1 = 1;\nconst int KeyMap::start = 2;\nconst int KeyMap::select = 3;\nconst int KeyMap::R2 = 4;\nconst int KeyMap::L2 = 5;\nconst int KeyMap::F1 = 6;\nconst int KeyMap::F2 = 7;\nconst int KeyMap::A = 8;\nconst int KeyMap::B = 9;\nconst int KeyMap::X = 10;\nconst int KeyMap::Y = 11;\nconst int KeyMap::up = 12;\nconst int KeyMap::right = 13;\nconst int KeyMap::down = 14;\nconst int KeyMap::left = 15;\n\n// Implement RemoteController methods\nRemoteController::RemoteController() {\n  lx = 0;\n  ly = 0;\n  rx = 0;\n  ry = 0;\n  std::fill(button, button + 16, 0);\n}\n\nvoid RemoteController::set(const std::array<unsigned char, 40> &data) {\n  // Debug print raw bytes\n  //   printf(\"Raw data bytes: \");\n  //   for (int i = 0; i < 40; i++) {\n  //     printf(\"%02x \", data[i]);\n  //   }\n  //   printf(\"\\n\");\n\n  // Extract keys from bytes 2-3\n  uint16_t keys = (data[3] << 8) | data[2];\n  //   printf(\"Keys value: 0x%04x\\n\", keys);\n\n  for (int i = 0; i < 16; i++) {\n    button[i] = (keys & (1 << i)) >> i;\n  }\n\n  // Extract and print floats before memcpy\n  float lx_temp, rx_temp, ry_temp, ly_temp;\n  std::memcpy(&lx_temp, &data[4], 4);  // bytes 4-7\n  std::memcpy(&rx_temp, &data[8], 4);  // bytes 8-11\n  std::memcpy(&ry_temp, &data[12], 4); // bytes 12-15\n  std::memcpy(&ly_temp, &data[20], 4); // bytes 20-23\n\n  //   printf(\"Values before assignment: lx=%f, ly=%f, rx=%f, ry=%f\\n\", lx_temp,\n  //  ly_temp, rx_temp, ry_temp);\n\n  // Assign to class members\n  lx = lx_temp;\n  rx = rx_temp;\n  ry = ry_temp;\n  ly = ly_temp;\n\n  //   printf(\"Values after assignment: lx=%f, ly=%f, rx=%f, ry=%f\\n\", lx, ly,\n  //   rx,\n  //  ry);\n}\n\nvoid RemoteController::set(\n    const unitree_go::msg::WirelessController::SharedPtr msg) {\n  uint16_t keys = msg->keys;\n  for (int i = 0; i < 16; i++) {\n    button[i] = (keys & (1 << i)) >> i;\n  }\n  lx = msg->lx;\n  rx = msg->rx;\n  ry = msg->ry;\n  ly = msg->ly;\n}"
  },
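The raw-bytes overload of `RemoteController::set` fixes the packet layout: a 16-bit key bitfield at bytes 2-3 (little-endian) and little-endian floats for the sticks at offsets 4 (lx), 8 (rx), 12 (ry) and 20 (ly). A Python equivalent for decoding the same 40-byte `wireless_remote` blob, e.g. out of a recorded bag (`decode_wireless_remote` and `KEY_NAMES` are illustrative names, ordered to match the `KeyMap` constants):

```python
import struct

KEY_NAMES = ["R1", "L1", "start", "select", "R2", "L2", "F1", "F2",
             "A", "B", "X", "Y", "up", "right", "down", "left"]

def decode_wireless_remote(data: bytes):
    """Mirror of RemoteController::set for the 40-byte wireless_remote blob."""
    assert len(data) == 40
    (keys,) = struct.unpack_from("<H", data, 2)          # bytes 2-3
    buttons = {name: (keys >> i) & 1 for i, name in enumerate(KEY_NAMES)}
    lx, rx, ry = struct.unpack_from("<3f", data, 4)      # bytes 4-15
    (ly,) = struct.unpack_from("<f", data, 20)           # bytes 20-23
    return buttons, lx, ly, rx, ry
```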
  {
    "path": "deployment/unitree_g1_ros2_29dof/src/src/main_node.cpp",
    "content": "/**\n * This example demonstrates how to use ROS2 to send low-level motor commands of\n *unitree g1 robot 29 dof\n **/\n #include \"common/motor_crc_hg.h\"\n #include \"common/wireless_controller.h\"\n #include \"rclcpp/rclcpp.hpp\"\n #include \"unitree_go/msg/wireless_controller.hpp\"\n #include \"unitree_hg/msg/low_cmd.hpp\"\n #include \"unitree_hg/msg/low_state.hpp\"\n #include \"unitree_hg/msg/motor_cmd.hpp\"\n #include <ament_index_cpp/get_package_share_directory.hpp>\n #include <map>\n #include <sstream>\n #include <std_msgs/msg/float32_multi_array.hpp>\n #include <std_msgs/msg/string.hpp>\n #include <string>\n #include <vector>\n #include <yaml-cpp/yaml.h>\n #include <thread>\n \n #define INFO_IMU 0   // Set 1 to info IMU states\n #define INFO_MOTOR 0 // Set 1 to info motor states\n \n enum PRorAB { PR = 0, AB = 1 };\n \n using std::placeholders::_1;\n \n const int G1_NUM_MOTOR = 29;\n \n enum class RobotState { ZERO_TORQUE, MOVE_TO_DEFAULT, EMERGENCY_STOP, POLICY };\n enum class EmergencyStopPhase { DAMPING, DISABLE };  // New enum for emergency stop phases\n \n // Create a humanoid_controller class for low state receive\n class humanoid_controller : public rclcpp::Node {\n public:\n   humanoid_controller() : Node(\"humanoid_controller\") {\n     RCLCPP_INFO(this->get_logger(), \"Using main_node !!!\");\n \n     // Get config path from ROS parameter\n     std::string config_path =\n         this->declare_parameter<std::string>(\"config_path\", \"\");\n \n     RCLCPP_INFO(this->get_logger(), \"Config file path: %s\",\n                 config_path.c_str());\n \n     // Load configuration\n     loadConfig(config_path);\n     RCLCPP_INFO(this->get_logger(),\n                 \"Entered ZERO_TORQUE state, press start to switch to \"\n                 \"MOVE_TO_DEFAULT state, press A to switch to POLICY state, \"\n                 \"press select to emergency stop. 
Waiting for start signal...\");\n \n     lowstate_subscriber_ = this->create_subscription<unitree_hg::msg::LowState>(\n         \"/lowstate\", 10,\n         std::bind(&humanoid_controller::LowStateHandler, this, _1));\n \n     policy_action_subscriber_ =\n         this->create_subscription<std_msgs::msg::Float32MultiArray>(\n             \"/humanoid/action\", 10,\n             std::bind(&humanoid_controller::PolicyActionHandler, this, _1));\n \n     // Add subscribers for kps and kds parameters from policy node\n     kps_subscriber_ =\n         this->create_subscription<std_msgs::msg::Float32MultiArray>(\n             \"/humanoid/kps\", 10,\n             std::bind(&humanoid_controller::KpsHandler, this, _1));\n     \n     kds_subscriber_ =\n         this->create_subscription<std_msgs::msg::Float32MultiArray>(\n             \"/humanoid/kds\", 10,\n             std::bind(&humanoid_controller::KdsHandler, this, _1));\n \n     lowcmd_publisher_ =\n         this->create_publisher<unitree_hg::msg::LowCmd>(\"/lowcmd\", 10);\n \n     robot_state_publisher_ = \n         this->create_publisher<std_msgs::msg::String>(\"/robot_state\", 10);\n \n     timer_ =\n         this->create_wall_timer(std::chrono::milliseconds(timer_dt),\n                                 std::bind(&humanoid_controller::Control, this));\n \n     time_ = 0;\n     duration_ = 3; // 3 s\n   }\n \n private:\n   std::map<std::string, int> dof2motor_idx;\n   std::map<std::string, double> default_dof_pos;\n   std::map<std::string, double> kps;\n   std::map<std::string, double> kds;\n   std::vector<std::string> complete_dof_order;\n   std::vector<std::string> policy_dof_order;\n   \n   // Policy-subscribed control parameters\n   std::vector<float> policy_kps_data;\n   std::vector<float> policy_kds_data;\n   bool kps_received_ = false;\n   bool kds_received_ = false;\n   RemoteController remote_controller;\n   std::map<std::string, double> target_dof_pos;\n   std::vector<float> policy_action_data;\n \n   // Optional arrays from YAML for start (MOVE_TO_DEFAULT) behavior\n   bool has_joint_arrays_ = false;\n   std::vector<std::string> joint_names_array_;\n   std::vector<double> default_position_array_;\n   std::vector<double> kp_array_;\n   std::vector<double> kd_array_;\n   \n   std::map<std::string, double> move_to_default_kps;\n   std::map<std::string, double> move_to_default_kds;\n \n   RobotState current_state_ = RobotState::ZERO_TORQUE;\n \n   bool should_shutdown_ = false;\n \n   // Add safety limit parameters using existing structure in YAML\n   std::map<std::string, std::pair<double, double>> joint_position_limits; // min, max\n   std::map<std::string, double> joint_velocity_limits;\n   std::map<std::string, double> joint_effort_limits;\n   \n   // Scaling coefficients for limits\n   double position_limit_scale = 1.0;\n   double velocity_limit_scale = 1.0;\n   double effort_limit_scale = 1.0;\n \n   EmergencyStopPhase emergency_stop_phase_ = EmergencyStopPhase::DAMPING;\n   double emergency_stop_time_ = 0.0;\n   double emergency_damping_duration_ = 2.0;  // 1 second of damping before disabling\n \n   // Add a helper function to calculate expected torque\n   double calculateExpectedTorque(const std::string& dof_name, double q_des, double q, double dq) {\n     double kp = kps[dof_name];\n     double kd = kds[dof_name];\n     // dq_des is assumed to be 0 in your control scheme\n     return kp * (q_des - q) + kd * (0.0 - dq);\n   }\n   \n   // Add a helper function to scale kp and kd to limit torque\n   std::pair<double, double> 
limitTorque(const std::string& dof_name, double q_des, double q, double dq) {\n     double kp = kps[dof_name];\n     double kd = kds[dof_name];\n     \n     // Calculate expected torque\n     double expected_torque = calculateExpectedTorque(dof_name, q_des, q, dq);\n     double abs_expected_torque = std::abs(expected_torque);\n     \n     // Check if torque would exceed limit\n     if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) {\n       double max_torque = joint_effort_limits[dof_name] * effort_limit_scale;\n       \n       if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) {\n         // Scale both kp and kd by the same factor to preserve damping characteristics\n         double scale_factor = max_torque / abs_expected_torque;\n         return std::make_pair(kp * scale_factor, kd * scale_factor);\n       }\n     }\n     \n     // If no scaling needed, return original values\n     return std::make_pair(kp, kd);\n   }\n   \n   // Add a helper function to scale custom kp and kd to limit torque\n   std::pair<double, double> limitTorqueWithCustomGains(\n     const std::string& dof_name, \n     double q_des, \n     double q, \n     double dq,\n     double custom_kp,\n     double custom_kd) {\n     \n     // Calculate expected torque\n     double expected_torque = custom_kp * (q_des - q) + custom_kd * (0.0 - dq);\n     double abs_expected_torque = std::abs(expected_torque);\n     \n     // Check if torque would exceed limit\n     if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) {\n       double max_torque = joint_effort_limits[dof_name] * effort_limit_scale;\n       \n       if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) {\n         // Scale both kp and kd by the same factor to preserve damping characteristics\n         double scale_factor = max_torque / abs_expected_torque;\n         return std::make_pair(custom_kp * scale_factor, custom_kd * scale_factor);\n       }\n     }\n     \n     // If no scaling needed, return original values\n     return std::make_pair(custom_kp, custom_kd);\n   }\n \n   void loadConfig(const std::string &config_path) {\n     try {\n       YAML::Node config = YAML::LoadFile(config_path);\n \n       // Load motor indices\n       auto indices = config[\"dof2motor_idx_mapping\"];\n       for (const auto &it : indices) {\n         dof2motor_idx[it.first.as<std::string>()] = it.second.as<int>();\n       }\n \n       // Load default angles\n       auto angles = config[\"default_joint_angles\"];\n       for (const auto &it : angles) {\n         default_dof_pos[it.first.as<std::string>()] = it.second.as<double>();\n       }\n       // Set target dof pos to default dof pos\n       for (const auto &it : default_dof_pos) {\n         target_dof_pos[it.first] = it.second;\n       }\n \n       // Note: kps and kds are now received from policy node via ROS topics\n       // No longer loading from config file to avoid conflicts\n \n       // Load dof order\n       for (const auto &it : config[\"complete_dof_order\"]) {\n         complete_dof_order.push_back(it.as<std::string>());\n       }\n       for (const auto &it : config[\"policy_dof_order\"]) {\n         policy_dof_order.push_back(it.as<std::string>());\n       }\n \n       // Load control frequency\n       control_freq_ = config[\"control_freq\"].as<double>();\n       control_dt_ = 1.0 / control_freq_;\n       timer_dt = static_cast<int>(control_dt_ * 1000);\n       RCLCPP_INFO(this->get_logger(), \"Control frequency set to: %f Hz\",\n             
      control_freq_);\n \n       // Load joint limits\n       auto pos_limits = config[\"joint_limits\"][\"position\"];\n       for (const auto &it : pos_limits) {\n         std::string dof_name = it.first.as<std::string>();\n         auto limits = it.second.as<std::vector<double>>();\n         joint_position_limits[dof_name] = std::make_pair(limits[0], limits[1]);\n       }\n \n       auto vel_limits = config[\"joint_limits\"][\"velocity\"];\n       for (const auto &it : vel_limits) {\n         joint_velocity_limits[it.first.as<std::string>()] = it.second.as<double>();\n       }\n \n       auto effort_limits = config[\"joint_limits\"][\"effort\"];\n       for (const auto &it : effort_limits) {\n         joint_effort_limits[it.first.as<std::string>()] = it.second.as<double>();\n       }\n \n       // Load joint limits scaling coefficients (optional, default to 1.0)\n       position_limit_scale = config[\"limit_scales\"][\"position\"].as<double>(1.0);\n       velocity_limit_scale = config[\"limit_scales\"][\"velocity\"].as<double>(1.0);\n       effort_limit_scale = config[\"limit_scales\"][\"effort\"].as<double>(1.0);\n       \n       RCLCPP_INFO(this->get_logger(), \"Joint limit scales - Position: %f, Velocity: %f, Effort: %f\",\n                  position_limit_scale, velocity_limit_scale, effort_limit_scale);\n \n       // Optional: arrays for joint configuration on Start\n       // If kp and kd arrays are provided, use them with joint names and positions\n       // Auto-generate joint_names and default_position from complete_dof_order and default_joint_angles if not provided\n       if (config[\"kp\"] && config[\"kd\"]) {\n         joint_names_array_.clear();\n         default_position_array_.clear();\n         kp_array_.clear();\n         kd_array_.clear();\n \n         // Auto-generate joint_names and default_position from existing config if not explicitly provided\n         if (config[\"joint_names\"] && config[\"default_position\"]) {\n           // Use explicitly provided arrays\n           for (const auto &it : config[\"joint_names\"]) {\n             joint_names_array_.push_back(it.as<std::string>());\n           }\n           for (const auto &it : config[\"default_position\"]) {\n             default_position_array_.push_back(it.as<double>());\n           }\n         } else {\n           // Auto-generate from complete_dof_order and default_joint_angles\n           for (const auto &dof_name : complete_dof_order) {\n             joint_names_array_.push_back(dof_name);\n             if (default_dof_pos.find(dof_name) != default_dof_pos.end()) {\n               default_position_array_.push_back(default_dof_pos[dof_name]);\n             } else {\n               RCLCPP_WARN(this->get_logger(), \"Default position not found for joint %s, using 0.0\", dof_name.c_str());\n               default_position_array_.push_back(0.0);\n             }\n           }\n           RCLCPP_INFO(this->get_logger(), \"Auto-generated joint_names and default_position from complete_dof_order and default_joint_angles\");\n         }\n \n         // Load kp and kd arrays\n         for (const auto &it : config[\"kp\"]) {\n           kp_array_.push_back(it.as<double>());\n         }\n         for (const auto &it : config[\"kd\"]) {\n           kd_array_.push_back(it.as<double>());\n         }\n \n         // Basic validation\n         if (joint_names_array_.size() == default_position_array_.size() &&\n             joint_names_array_.size() == kp_array_.size() &&\n             joint_names_array_.size() == 
kd_array_.size()) {\n           has_joint_arrays_ = true;\n \n           // Store MoveToDefault-specific kps/kds and default positions\n           for (size_t i = 0; i < joint_names_array_.size(); ++i) {\n             const std::string &name = joint_names_array_[i];\n             double pos = default_position_array_[i];\n             double kp_v = kp_array_[i];\n             double kd_v = kd_array_[i];\n             default_dof_pos[name] = pos;\n             \n             // Store MoveToDefault kp/kd\n             move_to_default_kps[name] = kp_v;\n             move_to_default_kds[name] = kd_v;\n           }\n \n           RCLCPP_INFO(this->get_logger(), \"Using joint arrays for Start behavior (size: %zu)\", joint_names_array_.size());\n         } else {\n           RCLCPP_WARN(this->get_logger(), \"joint_names/default_position/kp/kd size mismatch; ignoring arrays\");\n         }\n       }\n     } catch (const YAML::Exception &e) {\n       RCLCPP_ERROR(this->get_logger(), \"Error parsing config file: %s\",\n                    e.what());\n     }\n   }\n \n   void Control() {\n     // First check if we're already in emergency stop\n     if (current_state_ == RobotState::EMERGENCY_STOP) {\n         emergency_stop_time_ += control_dt_;\n         \n         if (emergency_stop_phase_ == EmergencyStopPhase::DAMPING) {\n             SendDampedEmergencyStop();\n             if (emergency_stop_time_ >= emergency_damping_duration_) {\n                 emergency_stop_phase_ = EmergencyStopPhase::DISABLE;\n                 RCLCPP_INFO(this->get_logger(), \"Damping complete, disabling motors\");\n             }\n         } else {\n             SendFinalEmergencyStop();\n             if (timer_) {\n                 timer_->cancel();\n             }\n             rclcpp::shutdown();\n             return;\n         }\n         \n         get_crc(low_command);\n         lowcmd_publisher_->publish(low_command);\n         return;  // Exit early, ignore all other commands\n     }\n \n     // If not in emergency stop, check for emergency stop command first\n     if (remote_controller.button[KeyMap::select] == 1) {\n         current_state_ = RobotState::EMERGENCY_STOP;\n         should_shutdown_ = true;\n         publishRobotState();\n         return;\n     }\n \n     // Process other commands only if not in emergency stop\n     if (remote_controller.button[KeyMap::L1] == 1 &&\n         current_state_ != RobotState::ZERO_TORQUE) {\n         RCLCPP_INFO(this->get_logger(), \"Switching to ZERO_TORQUE state\");\n         current_state_ = RobotState::ZERO_TORQUE;\n         publishRobotState();\n     }\n \n     // Start button only works in ZERO_TORQUE state\n     if (remote_controller.button[KeyMap::start] == 1) {\n         if (current_state_ == RobotState::ZERO_TORQUE) {\n             RCLCPP_INFO(this->get_logger(), \"Switching to MOVE_TO_DEFAULT state\");\n             current_state_ = RobotState::MOVE_TO_DEFAULT;\n             time_ = 0.0;\n             publishRobotState();\n         } else {\n             RCLCPP_INFO(this->get_logger(), \n                 \"Start button only works in ZERO_TORQUE state. 
Current state: %d\", \n                 static_cast<int>(current_state_));\n         }\n     }\n \n     // A button only works in MOVE_TO_DEFAULT state\n     if (remote_controller.button[KeyMap::A] == 1) {\n         if (current_state_ == RobotState::MOVE_TO_DEFAULT) {\n             // Check if kps and kds parameters have been received from policy node\n             if (!kps_received_ || !kds_received_) {\n                 RCLCPP_ERROR(this->get_logger(), \n                             \"Cannot switch to POLICY state. Control parameters not received from policy node! kps_received: %s, kds_received: %s\", \n                             kps_received_ ? \"true\" : \"false\", \n                             kds_received_ ? \"true\" : \"false\");\n                 return;\n             }\n             \n             // Check lower body joint positions before allowing transition\n             bool positions_ok = true;\n             std::stringstream deviation_msg;\n             const double position_threshold = 0.4;\n \n             // List of lower body joints to check\n             std::vector<std::string> lower_body_joints = {\n                 \"left_hip_yaw\", \"left_hip_roll\", \"left_hip_pitch\", \"left_knee\", \"left_ankle_pitch\", \"left_ankle_roll\",\n                 \"right_hip_yaw\", \"right_hip_roll\", \"right_hip_pitch\", \"right_knee\", \"right_ankle_pitch\", \"right_ankle_roll\"\n             };\n \n             for (int i = 0; i < G1_NUM_MOTOR; ++i) {\n                 std::string dof_name = complete_dof_order[i];\n                 \n                 // Skip if not a lower body joint\n                 if (std::find(lower_body_joints.begin(), lower_body_joints.end(), dof_name) == lower_body_joints.end()) {\n                     continue;\n                 }\n \n                 double current_pos = motor[i].q;\n                 double default_pos = default_dof_pos[dof_name];\n                 double diff = std::abs(current_pos - default_pos);\n \n                 if (diff > position_threshold) {\n                     positions_ok = false;\n                     deviation_msg << dof_name << \"(\" << diff << \"), \";\n                 }\n             }\n \n             if (positions_ok) {\n                 RCLCPP_INFO(this->get_logger(), \"Switching to POLICY state\");\n                 current_state_ = RobotState::POLICY;\n                 time_ = 0.0;\n                 publishRobotState();\n                 \n             } else {\n                 RCLCPP_WARN(this->get_logger(), \n                     \"Cannot switch to POLICY state. Lower body joints with large deviations: %s\", \n                     deviation_msg.str().c_str());\n                 \n             }\n         } else {\n             RCLCPP_INFO(this->get_logger(), \n                 \"A button only works in MOVE_TO_DEFAULT state. 
Current state: %d\", \n                 static_cast<int>(current_state_));\n         }\n     }\n \n     // Normal state machine logic\n     switch (current_state_) {\n         case RobotState::ZERO_TORQUE:\n             SendZeroTorqueCommand();\n             get_crc(low_command);\n             lowcmd_publisher_->publish(low_command);\n             break;\n \n         case RobotState::MOVE_TO_DEFAULT:\n             SendDefaultPositionCommand();\n             get_crc(low_command);\n             lowcmd_publisher_->publish(low_command);\n             break;\n \n         case RobotState::POLICY:\n             SendPolicyCommand();\n             get_crc(low_command);\n             lowcmd_publisher_->publish(low_command);\n             break;\n \n         case RobotState::EMERGENCY_STOP:\n             // Emergency stop is handled at the beginning of the function\n             // This case should not be reached due to early return\n             break;\n     }\n     \n     // Publish current robot state\n     publishRobotState();\n   }\n \n   void SendZeroTorqueCommand() {\n     low_command.mode_pr = mode_;\n     low_command.mode_machine = mode_machine;\n \n     for (int i = 0; i < G1_NUM_MOTOR; ++i) {\n       low_command.motor_cmd[i].mode = 1; // Enable\n       low_command.motor_cmd[i].q = 0.0;\n       low_command.motor_cmd[i].dq = 0.0;\n       low_command.motor_cmd[i].kp = 0.0;\n       low_command.motor_cmd[i].kd = 0.0;\n       low_command.motor_cmd[i].tau = 0.0;\n     }\n   }\n \n   void SendDefaultPositionCommand() {\n     time_ += control_dt_;\n     low_command.mode_pr = mode_;\n     low_command.mode_machine = mode_machine;\n \n     // Print kp/kd values on first execution\n     static bool first_move_to_default = true;\n     if (first_move_to_default) {\n       RCLCPP_INFO(this->get_logger(), \"=== First MOVE_TO_DEFAULT execution ===\");\n       first_move_to_default = false;\n     }\n \n     if (has_joint_arrays_) {\n       // Use provided arrays and dof2motor mapping to command motors\n       double ratio = clamp(time_ / duration_, 0.0, 1.0);\n       for (size_t j = 0; j < joint_names_array_.size(); ++j) {\n         const std::string &dof_name = joint_names_array_[j];\n         if (dof2motor_idx.find(dof_name) == dof2motor_idx.end()) {\n           continue; // skip unknown names\n         }\n         int motor_idx = dof2motor_idx[dof_name];\n \n         double target_final = default_position_array_[j];\n         double target_pos = (1. 
- ratio) * motor[motor_idx].q + ratio * target_final;\n \n         // Current state\n         double current_pos = motor[motor_idx].q;\n         double current_vel = motor[motor_idx].dq;\n \n         // Use MoveToDefault specialized kp/kd\n         double kp_to_use = move_to_default_kps[dof_name];\n         double kd_to_use = move_to_default_kds[dof_name];\n         \n         // Print kp/kd values on first execution\n         if (time_ <= control_dt_ * 2) { // Print for first few iterations\n           RCLCPP_INFO(this->get_logger(), \"MoveToDefault - %s: kp=%.2f, kd=%.2f\", \n                      dof_name.c_str(), kp_to_use, kd_to_use);\n         }\n         \n         // Apply torque limiting with MoveToDefault gains\n         auto [limited_kp, limited_kd] = limitTorqueWithCustomGains(\n           dof_name, target_pos, current_pos, current_vel, kp_to_use, kd_to_use);\n \n         low_command.motor_cmd[motor_idx].mode = 1;\n         low_command.motor_cmd[motor_idx].tau = 0.0;\n         low_command.motor_cmd[motor_idx].q = target_pos;\n         low_command.motor_cmd[motor_idx].dq = 0.0;\n         low_command.motor_cmd[motor_idx].kp = limited_kp;\n         low_command.motor_cmd[motor_idx].kd = limited_kd;\n       }\n     } else {\n       // Fall back to map-driven order with default MoveToDefault gains\n       // Use default kp/kd values for MoveToDefault since policy kps/kds are not available yet\n       const double default_move_kp = 50.0;  // Default stiffness for MoveToDefault\n       const double default_move_kd = 5.0;   // Default damping for MoveToDefault\n       \n       // Print default kp/kd values on first execution\n       if (time_ <= control_dt_ * 2) { // Print for first few iterations\n         RCLCPP_INFO(this->get_logger(), \"MoveToDefault (fallback) - Using default kp=%.2f, kd=%.2f\", \n                    default_move_kp, default_move_kd);\n       }\n       \n       for (int i = 0; i < G1_NUM_MOTOR; ++i) {\n         std::string dof_name = complete_dof_order[i];\n         double ratio = clamp(time_ / duration_, 0.0, 1.0);\n         double target_pos = (1. - ratio) * motor[i].q + ratio * default_dof_pos[dof_name];\n \n         // Current state\n         double current_pos = motor[i].q;\n         double current_vel = motor[i].dq;\n \n         // Use default MoveToDefault gains with torque limiting\n         auto [limited_kp, limited_kd] = limitTorqueWithCustomGains(\n           dof_name, target_pos, current_pos, current_vel, default_move_kp, default_move_kd);\n \n         low_command.motor_cmd[i].mode = 1;\n         low_command.motor_cmd[i].tau = 0.0;\n         low_command.motor_cmd[i].q = target_pos;\n         low_command.motor_cmd[i].dq = 0.0;\n         low_command.motor_cmd[i].kp = limited_kp;\n         low_command.motor_cmd[i].kd = limited_kd;\n       }\n     }\n   }\n \n   void SendPolicyCommand() {\n     time_ += control_dt_;\n     low_command.mode_pr = mode_;\n     low_command.mode_machine = mode_machine;\n \n     // Print kp/kd values on first execution\n     static bool first_policy_command = true;\n     if (first_policy_command) {\n       RCLCPP_INFO(this->get_logger(), \"=== First POLICY command execution ===\");\n       first_policy_command = false;\n     }\n \n     // Check if kps and kds parameters have been received from policy node\n     if (!kps_received_ || !kds_received_) {\n       RCLCPP_ERROR(this->get_logger(), \n                   \"Policy control parameters not received! kps_received: %s, kds_received: %s\", \n                   kps_received_ ? 
\"true\" : \"false\", \n                   kds_received_ ? \"true\" : \"false\");\n       RCLCPP_ERROR(this->get_logger(), \"Cannot execute POLICY commands without control parameters. Triggering emergency stop.\");\n       current_state_ = RobotState::EMERGENCY_STOP;\n       should_shutdown_ = true;\n       publishRobotState();\n       return;\n     }\n \n     for (const auto &pair : target_dof_pos) {\n       const std::string &dof_name = pair.first;\n       const double &target_pos = pair.second;\n       int motor_idx = dof2motor_idx[dof_name];\n       \n       \n       // Get policy kp/kd values\n       double policy_kp = kps[dof_name];\n       double policy_kd = kds[dof_name];\n       \n       // Print kp/kd values on first execution\n       if (time_ <= control_dt_ * 2) { // Print for first few iterations\n         RCLCPP_INFO(this->get_logger(), \"Policy - %s: kp=%.2f, kd=%.2f\", \n                    dof_name.c_str(), policy_kp, policy_kd);\n       }\n       \n       // Use policy kp/kd values directly without torque limiting\n       low_command.motor_cmd[motor_idx].mode = 1;\n       low_command.motor_cmd[motor_idx].tau = 0.0;\n       low_command.motor_cmd[motor_idx].q = target_pos;\n       low_command.motor_cmd[motor_idx].dq = 0.0;\n       low_command.motor_cmd[motor_idx].kp = policy_kp;\n       low_command.motor_cmd[motor_idx].kd = policy_kd;\n     }\n   }\n \n   void SendDampedEmergencyStop() {\n     low_command.mode_pr = mode_;\n     low_command.mode_machine = mode_machine;\n \n     // Use default damping value for emergency stop since kds may not be available\n     const double default_emergency_kd = 10.0; // Higher damping for faster stopping\n \n     for (int i = 0; i < G1_NUM_MOTOR; ++i) {\n       std::string dof_name = complete_dof_order[i];\n       low_command.motor_cmd[i].mode = 1; // Keep enabled\n       low_command.motor_cmd[i].q = motor[i].q; // Current position\n       low_command.motor_cmd[i].dq = 0.0; // Target zero velocity\n       low_command.motor_cmd[i].kp = 0.0; // No position control\n       low_command.motor_cmd[i].kd = default_emergency_kd; // Use default damping\n       low_command.motor_cmd[i].tau = 0.0;\n     }\n   }\n \n   void SendFinalEmergencyStop() {\n     low_command.mode_pr = mode_;\n     low_command.mode_machine = mode_machine;\n \n     for (int i = 0; i < G1_NUM_MOTOR; ++i) {\n       low_command.motor_cmd[i].mode = 0; // Disable\n       low_command.motor_cmd[i].q = 0.0;\n       low_command.motor_cmd[i].dq = 0.0;\n       low_command.motor_cmd[i].kp = 0.0;\n       low_command.motor_cmd[i].kd = 0.0;\n       low_command.motor_cmd[i].tau = 0.0;\n     }\n   }\n \n   void LowStateHandler(unitree_hg::msg::LowState::SharedPtr message) {\n     mode_machine = (int)message->mode_machine;\n     imu = message->imu_state;\n     for (int i = 0; i < G1_NUM_MOTOR; i++) {\n       motor[i] = message->motor_state[i];\n     }\n \n     // Check joint limits for all joints\n     bool limits_exceeded = false;\n     std::string exceeded_msg;\n     \n     // Trigger emergency stop if any limits are exceeded\n     if (limits_exceeded) {\n       RCLCPP_ERROR(this->get_logger(), \"%s\", exceeded_msg.c_str());\n       RCLCPP_ERROR(this->get_logger(), \"Joint limits exceeded! 
Triggering emergency stop.\");\n       current_state_ = RobotState::EMERGENCY_STOP;\n       should_shutdown_ = true;\n       publishRobotState();\n     }\n \n     remote_controller.set(message->wireless_remote);\n   }\n \n   void PolicyActionHandler(\n       const std_msgs::msg::Float32MultiArray::SharedPtr message) {\n     // RCLCPP_INFO(this->get_logger(), \"PolicyActionHandler called!\");\n     policy_action_data = message->data;\n \n     // Check if message size matches expected size\n     if (policy_action_data.size() != policy_dof_order.size()) {\n       RCLCPP_ERROR(this->get_logger(), \n                   \"Policy action data size mismatch: got %zu, expected %zu\", \n                   policy_action_data.size(), policy_dof_order.size());\n       current_state_ = RobotState::EMERGENCY_STOP;\n       should_shutdown_ = true;\n       publishRobotState();\n       return;\n     }\n \n     // set target dof pos\n     for (size_t i = 0; i < policy_dof_order.size(); i++) {\n       const auto &dof_name = policy_dof_order[i];\n \n       double calculated_pos = policy_action_data[i];\n       \n       // Check if the target position is within joint limits (with scaling)\n       if (joint_position_limits.find(dof_name) != joint_position_limits.end()) {\n         // Calculate the middle point of the range\n         double mid_pos = (joint_position_limits[dof_name].first + joint_position_limits[dof_name].second) / 2.0;\n         // Calculate the half-range and scale it\n         double half_range = (joint_position_limits[dof_name].second - joint_position_limits[dof_name].first) / 2.0;\n         double scaled_half_range = half_range * position_limit_scale;\n         \n         // Calculate scaled min and max by expanding from midpoint\n         double min_pos = mid_pos - scaled_half_range;\n         double max_pos = mid_pos + scaled_half_range;\n         \n         if (calculated_pos < min_pos || calculated_pos > max_pos) {\n           // RCLCPP_WARN(this->get_logger(), \n           //            \"Target position would exceed limit for joint %s: %f (scaled limits: %f, %f)\", \n           //            dof_name.c_str(), calculated_pos, min_pos, max_pos);\n           // Clamp the position to within limits\n           calculated_pos = std::clamp(calculated_pos, min_pos, max_pos);\n         }\n       }\n       \n       // Set the target position (clamped to safe values if needed)\n       target_dof_pos[dof_name] = calculated_pos;\n     }\n   }\n \n   void KpsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) {\n     policy_kps_data = message->data;\n     \n     // Check if message size matches expected size\n     if (policy_kps_data.size() != policy_dof_order.size()) {\n       RCLCPP_ERROR(this->get_logger(), \n                   \"Policy kps data size mismatch: got %zu, expected %zu\", \n                   policy_kps_data.size(), policy_dof_order.size());\n       current_state_ = RobotState::EMERGENCY_STOP;\n       should_shutdown_ = true;\n       publishRobotState();\n       return;\n     }\n     \n     // Update kps map with policy data\n     for (size_t i = 0; i < policy_dof_order.size(); i++) {\n       const auto &dof_name = policy_dof_order[i];\n       kps[dof_name] = policy_kps_data[i];\n     }\n     \n     // Mark kps as received only after the data has been validated and stored\n     kps_received_ = true;\n     \n     RCLCPP_INFO(this->get_logger(), \"Received kps parameters from policy node (size: %zu)\", policy_kps_data.size());\n   }\n \n   void KdsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) {\n     policy_kds_data = 
message->data;\n     \n     // Check if message size matches expected size\n     if (policy_kds_data.size() != policy_dof_order.size()) {\n       RCLCPP_ERROR(this->get_logger(), \n                   \"Policy kds data size mismatch: got %zu, expected %zu\", \n                   policy_kds_data.size(), policy_dof_order.size());\n       current_state_ = RobotState::EMERGENCY_STOP;\n       should_shutdown_ = true;\n       publishRobotState();\n       return;\n     }\n     \n     // Update kds map with policy data\n     for (size_t i = 0; i < policy_dof_order.size(); i++) {\n       const auto &dof_name = policy_dof_order[i];\n       kds[dof_name] = policy_kds_data[i];\n     }\n     \n     // Mark kds as received only after the data has been validated and stored\n     kds_received_ = true;\n     \n     RCLCPP_INFO(this->get_logger(), \"Received kds parameters from policy node (size: %zu)\", policy_kds_data.size());\n   }\n \n   double clamp(double value, double low, double high) {\n     if (value < low)\n       return low;\n     if (value > high)\n       return high;\n     return value;\n   }\n \n   std::string robotStateToString(RobotState state) {\n     switch (state) {\n       case RobotState::ZERO_TORQUE:\n         return \"ZERO_TORQUE\";\n       case RobotState::MOVE_TO_DEFAULT:\n         return \"MOVE_TO_DEFAULT\";\n       case RobotState::EMERGENCY_STOP:\n         return \"EMERGENCY_STOP\";\n       case RobotState::POLICY:\n         return \"POLICY\";\n       default:\n         return \"UNKNOWN\";\n     }\n   }\n \n   void publishRobotState() {\n     std_msgs::msg::String state_msg;\n     state_msg.data = robotStateToString(current_state_);\n     robot_state_publisher_->publish(state_msg);\n   }\n \n   rclcpp::TimerBase::SharedPtr timer_; // ROS2 timer\n   rclcpp::Publisher<unitree_hg::msg::LowCmd>::SharedPtr\n       lowcmd_publisher_; // ROS2 Publisher\n   rclcpp::Subscription<unitree_hg::msg::LowState>::SharedPtr\n       lowstate_subscriber_; // ROS2 Subscriber\n   rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr\n       policy_action_subscriber_;\n   rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr\n       kps_subscriber_;\n   rclcpp::Subscription<std_msgs::msg::Float32MultiArray>::SharedPtr\n       kds_subscriber_;\n   rclcpp::Publisher<std_msgs::msg::String>::SharedPtr robot_state_publisher_;\n   unitree_hg::msg::LowCmd low_command; // Unitree hg lowcmd message\n   unitree_hg::msg::IMUState imu;       // Unitree hg IMU message\n   unitree_hg::msg::MotorState\n       motor[G1_NUM_MOTOR]; // Unitree hg motor state message\n   double control_freq_;\n   double control_dt_;\n   int timer_dt;\n   double time_; // Running time count\n   double duration_;\n   PRorAB mode_ = PRorAB::PR;\n   int mode_machine;\n   RemoteController remote_controller; // Wireless remote parser (updated in LowStateHandler)\n }; // End of humanoid_controller class\n \n int main(int argc, char **argv) {\n   rclcpp::init(argc, argv);                            // Initialize rclcpp\n   auto node = std::make_shared<humanoid_controller>(); // Create a ROS2 node\n   rclcpp::spin(node);                                  // Run ROS2 node\n   rclcpp::shutdown();                                  // Exit\n   return 0;\n }"
  },
  {
    "path": "deployment/unitree_g1_ros2_29dof/start_container.sh",
    "content": "#!/bin/bash\n\ndocker kill holomotion_orin_deploy\ndocker rm holomotion_orin_deploy\necho \"Old holomotion_orin_deploy container removed !\"\n\n# Initialize variable as empty\nholomotion_repo_path=\"\"\n\n# Loop until the user provides a non-empty string\nwhile [[ -z \"$holomotion_repo_path\" ]]; do\n  read -p \"Please enter the holomotion local repository path: \" holomotion_repo_path\n  \n  if [[ -z \"$holomotion_repo_path\" ]]; then\n    echo \"Input cannot be empty.\"\n  fi\ndone\n\n# Validate the directory exists before running Docker\nif [ ! -d \"$holomotion_repo_path\" ]; then\n    echo \"Error: Directory '$holomotion_repo_path' does not exist.\"\n    exit 1\nfi\n\necho \"Mounting path: $holomotion_repo_path\"\n\nsudo docker run -it \\\n  --name holomotion_orin_deploy \\\n  --runtime nvidia \\\n  --gpus all \\\n  --privileged \\\n  --network host \\\n  -e \"ACCEPT_EULA=Y\" \\\n  -v \"$holomotion_repo_path:/home/unitree/holomotion\" \\\n  -v \"/usr/local/cuda-11.4/targets/aarch64-linux/lib:/cuda_base:ro\" \\\n  -v \"/usr/lib/aarch64-linux-gnu/libcudnn.so.8.6.0:/host_gpu/libcudnn.so.8.6.0:ro\" \\\n  -v \"/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.6.0:/host_gpu/libcudnn_ops_infer.so.8.6.0:ro\" \\\n  -v \"/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.6.0:/host_gpu/libcudnn_cnn_infer.so.8.6.0:ro\" \\\n  horizonrobotics/holomotion:orin_foxy_jp5.1_humble_deploy_zmq_20260319 \\\n  bash -c \"ln -sf /host_gpu/libcudnn.so.8.6.0 /host_gpu/libcudnn.so.8 && \\\n           ln -sf /host_gpu/libcudnn_ops_infer.so.8.6.0 /host_gpu/libcudnn_ops_infer.so.8 && \\\n           ln -sf /host_gpu/libcudnn_cnn_infer.so.8.6.0 /host_gpu/libcudnn_cnn_infer.so.8 && \\\n           source /root/miniconda3/bin/activate && conda activate holomotion_deploy && exec bash\""
  },
  {
    "path": "docs/environment_setup.md",
    "content": "# Environment Setup\n\n## Step 1: Setup Conda\n\nThis project uses conda to manage Python environments. We recommend using [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install#linux-installer).\n\n**For users in China:** Configure the conda mirror following [TUNA](https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/) for faster downloads.\n\n## Step 2: Setup Third-party Dependencies\n\n### 2.1 Download SMPL/SMPLX Models\n\nWe use SMPL/SMPLX models to retarget mocap data into robot motion data. Register your account and download the models from:\n\n- [SMPL](https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.1.0.zip)\n- [SMPLX](https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip)\n\nPlace both zip files (`SMPL_python_v.1.1.0.zip` and `models_smplx_v1_1.zip`) in the `thirdparties/` folder, then extract:\n\n```shell\nmkdir thirdparties/smpl_models\nunzip thirdparties/SMPL_python_v.1.1.0.zip -d thirdparties/smpl_models/\nunzip thirdparties/models_smplx_v1_1.zip -d thirdparties/smpl_models/\n```\n\nThe resulting file structure for smpl models would be:\n\n```shell\nthirdparties/\n├── smpl_models\n   ├── models\n   └── SMPL_python_v.1.1.0\n```\n\n### 2.2 Pull Submodules\n\nAfter cloning this repository, run the following command to get all submodule dependencies:\n\n```shell\ngit submodule update --init --recursive\n```\n\n### 2.3 Create Asset Symlinks\n\nThis project uses symbolic links to connect robot and SMPL assets from submodules to the main `assets` directory. Symlinks are created automatically when you clone the repository.\n\n### 2.4 Verify Third-party File Structure\n\nAfter completing the above steps, your file structure should look like this:\n\n```shell\nthirdparties/\n├── HoloMotion_assets\n├── GMR\n├── smplx\n├── joints2smpl\n├── omomo_release\n├── smpl_models\n├── SMPLSim\n├── unitree_ros\n└── unitree_ros2\n```\n\n## Step 3: Create the Conda Environment\n\nCreate the conda environment named `holomotion_train` and `holomotion_deploy`:\n\n```shell\nconda env create -f environments/environment_train_isaaclab_cu118.yaml\n# for newer GPUs like RTX 5090, use environment_train_isaaclab_cu128.yaml\nconda env create -f environments/environment_deploy.yaml\n```\n\n\nInstall smplx and GMR into the conda environment:\n\n```shell\ncd thirdparties\n\nconda activate holomotion_train\n\npip install -e ./smplx\n\n# use --no-deps to avoid pulling GMR's dependencies\npip install -e ./GMR --no-deps\n```\n\n## Step 4: Configure the Environment Variables\n\nHoloMotion uses `train.env` and `deploy.env` files to export environment variables in the shell entry scripts. Please make sure the `Train_CONDA_PREFIX` and the `Deploy_CONDA_PREFIX` variables in `train.env` and `deploy.env` are correctly setup. You can manually source these files and check the output in the shell.\n\nTake the `train.env` for example:\n\n```shell\nsource train.env\n```\n\nThese `.env` files will be sourced in the shell scripts (in `holomotion/scripts`) to correctly find and utilize your conda environments.\n"
  },
  {
    "path": "docs/evaluate_motion_tracking.md",
    "content": "## Evaluate the Motion Tracking Model\n\nAfter training for a while and saving model checkpoints, it is necessary to run the evaluation pipeline to get to know your model performance both visually and quantitatively. HoloMotion also bakes the model exporting process for later deployment in the evaluation pipeline.\n\n**Overall Workflow:**\n\n```mermaid\nflowchart LR\nA[Trained Checkpoints]\nB[HDF5 Database]\nC[Evaluation Config]\nD[Evaluation Entry]\nE[Offline Evaluation]\nF[Calculate Metrics]\nG[MuJoCo Visualization]\n\nA --> D\nB --> D\nC --> D\nD --> E\nE --> F\nE --> G\nclassDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5\nclassDef normal fill:#c9d9f5, rx:10, ry:10\nclass A,B dashed\nclass C,D,E,F,G normal\n```\n\n### 1 Offline Evaluation\n\n```bash\nbash ./holomotion/scripts/evaluation/eval_motion_tracking.sh\n```\n\nUpdate the evaluation script by setting `checkpoint_path` (e.g., `logs/Holomotion/model_1000.pt`) and `eval_h5_dataset_path`.\n\n### 2 Calculate Metrics\n\nProcess the `.npz` files generated in the previous step and convert them into a final quantitative JSON metrics report:\n\n```bash\nbash ./holomotion/scripts/evaluation/calc_offline_eval_metrics.sh\n```\n\n- `npz_dir`: Path to the folder containing `.npz` result files. \n- `dataset_suffix`: Evaluation dataset name, set to differentiate different datasets.\n\n### 3 MuJoCo Visualization\n\nGenerate video outputs to validate the motion tracking quality from the `.npz` result files by setting the `motion_npz_root` to the evaluation npz folder. Note that in order to properly visualize the recorded robot data, you should set the `+key_prefix=\"robot_\"` .\n\n```bash\nbash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh\n```\n\n- `motion_npz_root`: Path to the folder containing `.npz` result files. \n- `video_rendering/{motion_name}.mp4` files in the corresponding `.npz` result files.\n\n### 4 Export Trained Model to ONNX\n\nTo deploy our policy to real world robots, we need to convert the pytorch module into ONNX format, which is supported by most inference frameworks.\nAfter running the evaluation script, the `.onnx` file will be generated and saved to the checkpoint directory:\n\n```\nlogs/HoloMotion/your_checkpoint_dir/\n├── config.yaml\n├── exported\n│   └── model_10000.onnx\n└── model_10000.pt\n```\n"
  },
  {
    "path": "docs/holomotion_motion_file_spec.md",
    "content": "## HoloMotion NPZ Format — Keys and Values\n\nThis document lists the exact keys saved in a HoloMotion NPZ and their value types/shapes.\n\n- Prefix policy\n\n  - ref\\_\\*: reference motion (source-of-truth produced by preprocessing)\n  - ft*ref*_: filtered reference motion (post-filtering; never overwrites ref\\__)\n  - robot\\_\\*: robot states (only present in offline evaluation exports)\n  - Legacy (no prefix): kept only for backward-compat; new files prefer ref\\_\\*\n\n- metadata\n\n  - type: JSON string\n  - fields:\n    - motion_key: str\n    - raw_motion_key: str\n    - motion_fps: float\n    - num_frames: int\n    - wallclock_len: float (seconds, approx (num_frames - 1) / motion_fps)\n    - num_dofs: int\n    - num_bodies: int\n    - clip_length: int (original clip length in frames)\n    - valid_prefix_len: int (contiguous valid frames from the start)\n\n- ref_dof_pos\n\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; reference motion)\n\n- ref_dof_vel\n\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; reference motion)\n\n- ref_global_translation\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (meters, world frame; reference motion)\n\n- ref_global_rotation_quat\n\n  - dtype: float32\n  - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; reference motion)\n\n- ref_global_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (m/s, world frame; reference motion)\n\n- ref_global_angular_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (rad/s, world frame; reference motion)\n\n- ft_ref_dof_pos\n\n  - dtype: float32\n  - shape: [T, num_dofs] (filtered reference motion)\n\n- ft_ref_dof_vel\n\n  - dtype: float32\n  - shape: [T, num_dofs] (derived from filtered positions)\n\n- ft_ref_global_translation\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (filtered reference motion)\n\n- ft_ref_global_rotation_quat\n\n  - dtype: float32\n  - shape: [T, num_bodies, 4] (filtered, normalized XYZW)\n\n- ft_ref_global_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (derived from filtered positions)\n\n- ft_ref_global_angular_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (derived from filtered quaternions)\n\n- robot_dof_pos\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; robot)\n\n- robot_dof_vel\n\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; robot)\n\n- robot_global_translation\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (meters, world frame; robot)\n\n- robot_global_rotation_quat\n\n  - dtype: float32\n  - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; robot)\n\n- robot_global_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (m/s, world frame; robot)\n\n- robot_global_angular_velocity\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (rad/s, world frame; robot)\n\n- dof_pos (deprecated legacy key)\n\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; ref or robot)\n  - deprecated\n\n- dof_vels (deprecated legacy key)\n\n  - dtype: float32\n  - shape: [T, num_dofs] (URDF joint order; ref or robot)\n  - deprecated\n\n- global_translation (deprecated legacy key)\n\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (meters, world frame; ref or robot)\n  - deprecated\n\n- global_rotation_quat (deprecated legacy key)\n\n  - dtype: float32\n  - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; ref or robot)\n  - deprecated\n\n- global_velocity (deprecated legacy key)\n\n  - dtype: float32\n  - 
shape: [T, num_bodies, 3] (m/s, world frame; ref or robot)\n  - deprecated\n\n- global_angular_velocity (deprecated legacy key)\n  - dtype: float32\n  - shape: [T, num_bodies, 3] (rad/s, world frame; ref or robot)\n  - deprecated\n\nNotes:\n\n- T == num_frames from metadata.\n- All arrays are float32.\n"
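As a usage illustration, the keys above can be read back with NumPy (a minimal sketch; the file path is illustrative, and it assumes the metadata entry was saved as the plain JSON string described above):

```python
import json
import numpy as np

data = np.load("path/to/motion.npz")        # illustrative path
meta = json.loads(data["metadata"].item())  # metadata is stored as a JSON string
T = meta["num_frames"]

ref_dof_pos = data["ref_dof_pos"]           # [T, num_dofs], float32, URDF joint order
assert ref_dof_pos.shape == (T, meta["num_dofs"])
print(meta["motion_key"], meta["motion_fps"], ref_dof_pos.dtype)
```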
  },
  {
    "path": "docs/motion_retargeting.md",
    "content": "# Motion Retargeting\n\nTransform human motion data into robot-compatible joint trajectories for following training. We support GMR for retargeting (https://github.com/YanjieZe/GMR)\n\n## Prerequisites\n\nBefore running the motion retargeting pipeline, ensure you have:\n\n### 1. Environment Setup\n\nPlease make sure the smplx and GMR are properly installed according to [[environment setup doc](./environment_setup.md)].\n\n### 2. Data Preparation\n\nPlace your AMASS motion data in `/assets/test_data/motion_retargeting/{dataset_name}`\nor modify 'amass_dir' in 'script/motion_retargeting/\\*.sh' !Please check all related path in .sh and .yaml are right!\n\n### 3. Model Preparation\n\n    Put SMPLX models under following path\n    thirdparties/\n    └── GMR/\n        ├── assets/\n        │   └── body_models/\n        │       └── smplx/\n        │           ├── SMPLX_FEMALE.npz\n        │           ├── SMPLX_FEMALE.pkl\n        │           ├── SMPLX_MALE.npz\n        │           ├── SMPLX_MALE.pkl\n        │           ├── SMPLX_NETURAL.npz\n        │           └── SMPLX_NETURAL.pkl\n\n### 4. Path Verification\n\nCheck data paths in the configuration scripts:\n\n- `holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh`\n\nBefore using GMR, it is recommended to run `bash ./holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh` first, which can help reduce singular solutions to some extent.\n\n## Quick Start\n### 1. Motion Retargeting\n\n```bash\nbash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh\n```\n\n> Reminder: set device = \"cuda:0\" to \"cpu\" in \"smplx_to_robot_dataset.py\" if facing cuda error\n\nAfter GMR retargeting, we further need to convert the dataset into a HoloMotion-compatible npz format, please run:\n\n```bash\nbash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh\n```\n\n### 2. Motion Visualization\n\nGenerate video outputs to validate retargeting quality of the HoloMotion npz files:\n\n```bash\nbash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh\n```\n\n**Output**: `video_rendering/{motion_name}.mp4` files in the retargeted data directories\n\n### 3. Pack to HDF5 for Training\nAfter retargeting, we need to pack the npz files into a compact HDF5 database:\n```bash\nbash ./holomotion/scripts/motion_retargeting/pack_hdf5_dataset.sh\n```"
  },
  {
    "path": "docs/mujoco_sim2sim.md",
    "content": "# Sim2Sim Verification\n\nAfter generating the ONNX file from the evaluation stage, you can verify the performance of your Isaac-trained policy in another simulator, such as Mujoco to test its performance before deploying to the real robot.\n\nThe entry script is `holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh`, you should set these variables before running:\n\n- `robot_xml_path`: The scene mjcf .xml file for the robot\n- `ONNX_PATH`: The exported ONNX model file\n- `motion_npz_path`: The npz file containing the reference motion\n"
  },
  {
    "path": "docs/realworld_deployment.md",
    "content": "# HoloMotion Deployment Guide\n\nThis guide describes how to set up the deployment environment and run the trained policy on a physical Unitree G1 robot.\n\n## Robot Configuration for 29 DOF\nThe 29 DOF configuration includes:\n\n- 12 leg joints (6 per leg)\n- 3 waist joints (yaw, roll, pitch)\n- 14 arm joints (7 per arm)\n\n---\n\n## Deployment Options\n\nThis guide provides two deployment methods:\n\n| Deployment Method                               | Target Platform   |\n| ----------------------------------------------- | ----------------- |\n| [Laptop Deployment](#laptop-deployment)         | Laptop/Desktop PC |\n| [PC2 Docker Deployment](#pc2-docker-deployment) | G1 Robot's PC2    |\n\nChoose the appropriate method based on your setup:\n\n- **For laptop/desktop deployment**: Follow the [Laptop Deployment](#laptop-deployment) section\n- **For PC2 on robot hardware**: Follow the [PC2 Docker Deployment](#pc2-docker-deployment) section\n\n### ⚠️ Important Safety Notice\n\n> **For safety reasons, it is strongly recommended to remove the dexterous hands before running the policy.**\n\n---\n\n## Laptop Deployment\n\n### Quick Environment Setup\n\n#### Prerequisites\n\nEnsure the following are installed before proceeding:\n\n- Anaconda or Miniconda\n- ROS 2 Humble installed at `/opt/ros/humble`\n- MCAP for efficient ROS 2 data recording\n- Unitree ROS 2 SDK installed at `~/unitree_ros2/`\n\n#### One-Click Deployment\n\n```bash\ncd <your_holomotion_repo_path>/deployment\nchmod +x deploy_environment.sh\n./deploy_environment.sh\n```\n\nThis script will:\n\n- Create a new conda environment (with CUDA support if available)\n- Install Python packages from `requirements/requirements_deploy.txt`\n- Install Unitree SDK Python bindings\n- Build the ROS 2 workspace under `unitree_g1_ros2_29dof/`\n\n---\n\n### Deploy on Physical G1 Robot (Laptop)\n\n### Setup Overview\n\nThe deployment process consists of two types of steps:\n\n| **One-Time Setup** (per computer) | **Every Run** (each time you use the robot) |\n| --------------------------------- | ------------------------------------------- |\n| Step 1: Network Configuration     | Step 3: Power On & Initialize Robot         |\n| Step 2: Launch Script Setup       | Step 4: Launch Policy Controller            |\n\n> **Note**: Once you complete Steps 1-2, you only need to do Steps 3-4 for each robot session!\n\n### Step 1: Connect and Configure Network\n\n#### Prerequisites for Network Setup:\n\n1. **Power on the robot** and wait for it to fully boot\n2. **Use an Ethernet cable** to connect your PC to the robot's LAN port\n3. 
**Ensure both devices are powered on** during configuration\n\n#### Network Configuration:\n\nConfigure your PC's network interface with the following static IP settings:\n\n- **Static IP**: `192.168.123.222`\n- **Netmask**: `255.255.255.0`\n- **Gateway**: (leave empty)\n\n#### Automatic Configuration Script:\n\nYou can use the following script to configure it automatically (use command `nmcli con show` to check your actual connection name):\n\n<details>\n<summary>Click to view set_static_ip.sh</summary>\n\n```bash\n#!/bin/bash\n\n# Replace with your actual connection name (use `nmcli con show` to check)\nCON_NAME=\"Wired connection 1\"\nIP_ADDRESS=\"192.168.123.222\"\nNETMASK=\"24\"\nGATEWAY=\"\"\n\nnmcli con modify \"$CON_NAME\" ipv4.addresses \"$IP_ADDRESS/$NETMASK\"\nnmcli con modify \"$CON_NAME\" ipv4.method manual\n\nif [ -n \"$GATEWAY\" ]; then\n  nmcli con modify \"$CON_NAME\" ipv4.gateway \"$GATEWAY\"\nfi\n\nnmcli con modify \"$CON_NAME\" ipv4.dns \"\"\nnmcli con down \"$CON_NAME\" && nmcli con up \"$CON_NAME\"\n```\n\n</details>\n\n---\n\n### Step 2: Prepare Launch Script\n\n#### Configure Network Interface:\n\n1. **Check your network interface name** (while connected to the robot):\n\n   ```bash\n   ifconfig\n   ```\n\n   Look for the interface connected to the robot (e.g., `enxf8e43ba00afd`, `eth0`, `enp0s31f6`)\n\n2. **Update the launch configuration**:\n   ```bash\n   # Edit the launch file\n   nano <your_holomotion_repo_path>/deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py\n   ```\n   Find and update the `network_interface` parameter with your actual interface name.\n\n\n---\n\n### Step 3: Power On and Initialize the Robot\n\n> **Do this every time** you want to run the robot.\n\n#### Robot Initialization Sequence for 29 DOF:\n\n1. **Power on the robot** - Start the robot in the **hanging position**\n2. **Wait for zero torque mode** - The robot will automatically enter zero torque mode (joints feel loose)\n3. **Connect your computer** - Use the same Ethernet cable to connect to the robot's LAN port\n4. **Enter debugging mode** - On the remote controller, press `L2 + R2` simultaneously. Note: the new deployment automatically enters this mode on startup, so manual entry is usually not required.\n\n---\n\n### Step 4: Launch the Policy Controller\n\n#### Preflight Checklist\n\nBefore running, ensure the following are ready.\n\n- Model folders configured in `g1_29dof_holomotion.yaml` exist\n  - `motion_tracking_model_folder`: under `src/models/`\n  - `velocity_tracking_model_folder`: under `src/models/`\n- Motion data directory exists and contains .npz files (retargeted results)\n  - `motion_clip_dir`: under `src/motion_data/`\n- Config file path used by launch is correct\n\n#### Motion Reference Source\n\nMotion tracking supports two reference sources:\n\n- **Offline motion mode**: the robot executes the selected `.npz` motion clip from `motion_clip_dir`.\n- **Online teleoperation mode**: the robot uses live `latest_obs` data streamed from the teleoperation workstation / VR pipeline. 
See [Holomotion teleop setup](../deployment/holomotion_teleop/holomotion_teleop_setup.md) for Pico / XRoboToolkit, ZMQ publishing, and launch order on the workstation.\n\nThe mode is selected by YAML settings in `g1_29dof_holomotion.yaml`:\n\n- **Offline motion mode**\n  - `vr.enable_teleop_reference: false`\n  - `vr.require_vr_data_for_motion: false`\n  - Result: even if ZMQ data is still arriving, the robot ignores it and motion tracking uses offline `.npz` clips only.\n- **Online teleoperation mode**\n  - `vr.enable_teleop_reference: true`\n  - `vr.require_vr_data_for_motion: true`\n  - `vr.latest_obs_zmq_uri: \"tcp://<workstation-ip>:6001\"`\n  - Result: motion mode waits for live teleoperation data before entering and uses the incoming `latest_obs` stream as the motion reference.\n\nIf you want teleoperation to be available but not mandatory, you can also use:\n\n- `vr.enable_teleop_reference: true`\n- `vr.require_vr_data_for_motion: false`\n\nIn that configuration, motion mode can still start without waiting for VR readiness, but live teleoperation data may be used when available at mode entry.\n\n#### One-click start\n\n```bash\ncd <your_holomotion_repo_path>/deployment/unitree_g1_ros2_29dof\nbash launch_holomotion_29dof.sh\n```\n\n> **Success indicator**: On startup, the robot joints should remain in zero torque state and feel free to move.\n\n#### Motion Control Modes\n\nThe 29 DOF robot operates in two main modes:\n\n| Mode    | How to Enter                                                          | Controls                                                                                                                     | Switch             |\n| ------- | --------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------ |\n| Velocity tracking | 1) Press Start to stand up, then press A<br/>2) From motion tracking: press Y | Left stick: move (vx, vy)<br/>Right stick: rotate (yaw)<br/>D-Pad: select motion clip (Left=first, Right=last, Up=prev, Down=next) | B: enter motion tracking   |\n| Motion tracking | Press B                                                               | Executes selected motion clip or online teleoperation automatically                                                                                        | Y: back to velocity tracking |\n\n#### Control Flow\n\nHere is the robot control flowchart for 29 DOF:\n\n```mermaid\nflowchart TD\n    subgraph prepPhase [\"Setup Phase\"]\n        direction TB\n        A[Set Robot<br/>to Hanging Position] --> B[Power On and<br/>Zero Torque Mode]\n        B --> C[L2+R2: Enter Debug Mode]\n        C --> D[Launch Program]\n    end\n\n    %% Main flow\n    prepPhase --> E[Start: Stand Up]\n    E --> F[Lower Robot to Ground]\n    F --> G[A: Enter Velocity tracking Mode]\n\n    %% Velocity tracking mode controls\n    G --> H[Velocity tracking Mode]\n    H --> H1[Left Stick: Move]\n    H --> H2[Right Stick: Rotate]\n    H --> H3[D-Pad: Select Motion Clip]\n    H --> H4[B: Enter Motion tracking Mode]\n\n    %% Motion tracking mode\n    H4 --> I[Motion tracking Mode]\n    I --> I1[Execute Motion Clip or Teleoperation]\n    I --> I2[Y: Back to Velocity tracking]\n\n    %% Mode switching\n    I2 --> H\n\n    %% Emergency stop\n    D --> N[Select: Emergency Stop]\n    E --> N\n    F --> N\n    G --> N\n    H --> N\n    I --> N\n    N --> O[Close Program]\n\n    classDef startEnd 
fill:#e1f5fe,stroke:#01579b,stroke-width:2px\n    classDef control fill:#f3e5f5,stroke:#4a148c,stroke-width:2px\n    classDef velocityTracking fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px\n    classDef motionTracking fill:#fff3e0,stroke:#e65100,stroke-width:2px\n    classDef emergency fill:#ffebee,stroke:#b71c1c,stroke-width:2px,stroke-dasharray: 5 5\n    classDef preparationFrame fill:#f9f9f9,stroke:#666,stroke-width:2px,stroke-dasharray: 5 5\n\n    class A,O startEnd\n    class B,C,D,E,F,G control\n    class H,H1,H2,H3,H4 velocityTracking\n    class I,I1,I2 motionTracking\n    class N emergency\n    class prepPhase preparationFrame\n```\n\n#### Configuration Files (used by Step 4)\n\n**System Configuration**\n\n- File: `HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml`\n- Key parameters:\n  - `motion_tracking_model_folder`: motion tracking model folder under `models/`\n  - `velocity_tracking_model_folder`: velocity tracking model folder under `models/`\n  - `motion_clip_dir`: motion clip data folder under `src/`\n  - `vr.enable_teleop_reference`: enable or disable live teleoperation reference\n  - `vr.require_vr_data_for_motion`: whether motion mode must wait for live teleoperation data\n  - `vr.latest_obs_zmq_uri`: teleoperation ZMQ endpoint used in online mode\n\n**Pre-trained Models**\n\nWe provide pre-trained motion tracking and velocity tracking models that you can download and use:\n\n- **Motion Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model)\n- **Velocity Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model)\n\nTo use the velocity tracking model:\n\n1. Download the `holomotion_v1.2_velocity_tracking_model` folder from the Hugging Face repository\n2. Place the downloaded folder under `models/` (e.g., `models/holomotion_v1.2_velocity_tracking_model/`)\n3. Update `velocity_tracking_model_folder` in the `g1_29dof_holomotion.yaml` to point to this folder\n\n**Adding New Motion Tracking Models**\n\n1. Create a new folder under `models/` based on the following example model folder structure (e.g., `models/your_model_dir_name/`)\n2. Update `motion_tracking_model_folder` in the `g1_29dof_holomotion.yaml`\n3. Ensure the motion clip data files are in the `motion_clip_dir`\n\nExample model folder structure (motion model):\n\n```bash\nHoloMotion/deployment/unitree_g1_ros2_29dof/src/models/your_model_dir_name\n├── config.yaml\n└── exported\n    └── your_model_name.onnx\n```\n\n---\n\n### Safety Notice\n\nThis deployment is intended for demonstration only. It is not a production-grade control system. Do not interfere with the robot during operation. 
If unexpected behavior occurs, exit control immediately via the controller or keyboard to ensure safety.\n\nTo stop the control process, press `Select` or use `Ctrl+C` in the terminal.\n\n---\n\n## PC2 Docker Deployment\n\n### Setup Overview\n\nThe deployment process consists of two types of steps:\n\n| **One-Time Setup** (per PC2) | **Every Run** (each time you use the robot) |\n| ----------------------------- | -------------------------------------------- |\n| Step 1: Configure Docker      | Step 4: Start Docker Container              |\n| Step 2: Load Docker Image     | Step 5: Power On & Initialize Robot         |\n| Step 3: Configure Launch File Network Interface     | Step 6: Launch Policy Controller            |\n\n> **Note**: Once you complete Steps 1-3, you only need to do Steps 4-6 for each robot session!\n\n### System Requirements\n\n- **Platform**: NVIDIA Jetson Orin\n- **JetPack**: 5.1\n- **Ubuntu**: 20.04\n- **ROS 2**: Foxy\n- **Docker**: Installed with NVIDIA Container Runtime support\n\n### Step 1: Configure Docker for NVIDIA Runtime\n\nModify `/etc/docker/daemon.json`:\n\n```json\n{\n  \"runtimes\": {\n    \"nvidia\": {\n      \"path\": \"nvidia-container-runtime\",\n      \"runtimeArgs\": []\n    }\n  },\n  \"default-runtime\": \"nvidia\"\n}\n```\n\nRestart Docker and verify:\n\n```bash\nsudo systemctl restart docker\nsudo docker info | grep -i runtime\n```\n\n### Step 2: Load Docker Image\n\nPull the image from dockerhub with:\n\n```bash\ndocker pull horizonrobotics/holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319\n```\n\nOr if you have the image locally, tag it appropriately:\n\n```bash\ndocker tag <your_image_name> holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319\n```\n\n### Step 3: Configure Launch File Network Interface:\n\n\n1. **Check your network interface name on the robot**:\n\n   ```bash\n   ifconfig\n   ```\n\n   Look for the interface with IP `192.168.123.164`. The interface is typically `eth0`.\n\n2. **Update the launch configuration** if your interface is not `eth0`:\n\n   ```bash\n   nano <your_holomotion_repo_path>/deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py\n   ```\n\n   Find line 103 and update the `network_interface` parameter:\n   ```python\n   network_interface = \"eth0\"  # Change to your actual interface name\n   ```\n\n### Step 4: Start Docker Container\n\n> **Important**: Before running Docker commands, ensure your user is added to the docker group. If you encounter permission errors, add your user to the docker group:\n\n```bash\nsudo usermod -aG docker $USER\n```\n\nAfter adding your user to the docker group, you need to log out and log back in (or restart your session) for the changes to take effect. Verify with:\n\n```bash\ngroups\n```\n\nYou should see `docker` in the list of groups.\n\n> **Important**: You need to run this step **every time** you want to use the robot. 
The script will automatically remove any existing container and start a fresh one.\n\n```bash\ncd <your_holomotion_repo_path>/deployment/unitree_g1_ros2_29dof\nbash start_container.sh\n```\n\n**When prompted, enter the holomotion repository path:**\n\n- The script will ask: `Please enter the holomotion local repository path:`\n- Enter the full path to your holomotion repository, for example:\n  - `/home/unitree/HoloMotion` (if the repository is at this location)\n  - Or the actual path where your holomotion repository is located\n\n\n### Step 5: Power On and Initialize Robot\n\n> **Do this every time** before launching the policy controller.\n\n1. Put the robot in hanging position\n2. Wait for zero torque mode\n3. Press `L2 + R2` on remote controller for debug mode \n\n### Step 6: Launch Policy Controller\n\n> **Do this every time** you want to run the robot (after Steps 4 and 5).\n\n**Preflight Checklist**\n\nBefore running, ensure the following are ready.\n\n- Model folders configured in `g1_29dof_holomotion.yaml` exist\n  - `motion_tracking_model_folder`: under `src/models/`\n  - `velocity_tracking_model_folder`: under `src/models/`\n- Motion data directory exists and contains .npz files (retargeted results)\n  - `motion_clip_dir`: under `src/motion_data/`\n- Config file path used by launch is correct\n- Motion reference source is configured as intended (see [Motion Reference Source](#motion-reference-source))\n\n**Pre-trained Models**\n\nYou can download and use the pre-trained velocity tracking model. Refer to the [Pre-trained Models](#configuration-files-used-by-step-4) section above for general instructions. \n\n> **Note**: The model folder should be placed in your local repository before starting the Docker container, as the repository is mounted into the container.\n\n**One-click start in the docker**\n\n```bash\ncd /home/unitree/holomotion/deployment/unitree_g1_ros2_29dof\nbash launch_holomotion_29dof_docker.sh\n```\n\n> **Note**: The control flow is the same as described in the [Control Flow](#control-flow) section above.\n\n\n---\n\n### Safety Notice\n\nThis deployment is intended for demonstration only. It is not a production-grade control system. Do not interfere with the robot during operation. If unexpected behavior occurs, exit control immediately via the controller or keyboard to ensure safety.\n\nTo stop the control process:\n\n- Press `Select` on the remote controller, or\n- Use `Ctrl+C` in the terminal (inside Docker container)\n"
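As a concrete illustration of the [Motion Reference Source](#motion-reference-source) settings, the two modes differ only in the `vr` block of `g1_29dof_holomotion.yaml` (a minimal sketch; surrounding keys are omitted and the workstation address is the doc's placeholder):

```yaml
# Offline motion mode: ignore incoming teleop data, play .npz clips only.
vr:
  enable_teleop_reference: false
  require_vr_data_for_motion: false

# Online teleoperation mode: wait for live teleop data and use it as the reference.
# vr:
#   enable_teleop_reference: true
#   require_vr_data_for_motion: true
#   latest_obs_zmq_uri: "tcp://<workstation-ip>:6001"
```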
  },
  {
    "path": "docs/smpl_data_curation.md",
    "content": "# Dataset Preparation Guide\n\nThis guide describes the workflow and setup for preparing datasets to train the motion tracking model.\nWe use **AMASS-compatible SMPL-format motion capture data** as the training input.\n\n---\n\n## Overview\n\nThe dataset preparation pipeline has the following steps:\n\n1. **Download datasets**\n   - To train with diverse and rich motion data, you first need to collect raw motion capture datasets from various sources.\n   - Then place all downloaded datasets under the data/raw_datasets directory in their original structure.\n2. **Convert datasets to AMASS format**\n   - To ensure that all motion data is compatible with the AMASS-style .npz format used by the training pipeline, you need to convert the raw datasets.\n   - Then run the conversion script to generate .npz files under data/amass_compatible_datasets/.\n3. **Filter datasets**\n   - To improve data quality by removing abnormal, noisy, or unwanted motion samples, you can optionally run the filtering step.\n   - Then run the filtering script to generate filtered .yaml files under holomotion/config/data_curation/.\n4. **Visualize Prepared Data**\n   - Use the included visualization utility to preview and inspect the generated AMASS-compatible `.npz` motion files.\n   - Quickly check for anomalies or errors before training.\n5. **Generate Motion from Monocular Video**\n   - You can also generate SMPL-format motion capture files **directly from monocular RGB videos** using GVHMR.\n   - This allows you to create training data or test the model with real-world video footage.\n   - Pipeline are given follow.\n\n### Directory Structure After Full Setup\n\n```\ndata/\n├── raw_datasets/\n│   ├── humanact12/\n│   ├── OMOMO/\n│   ├── MotionX/\n│   └── ZJU_Mocap/\n├── amass_compatible_datasets/\n│   ├── amass/\n│   │   ├── ACCAD/\n│   │   ├── BioMotionLab_NTroje/\n│   │   ├── ...\n│   ├── humanact12/\n│   ├── OMOMO/\n│   ├── MotionX/\n│   └── ZJU_Mocap/\n├── dataset_labels/\n│   ├── humanact12.jsonl\n│   ├── OMOMO.jsonl\n│   ├── MotionX.jsonl\n│   ├── ZJU_Mocap.jsonl\n│   ├── amass.jsonl\n```\n\n---\n\n## Step-by-Step Instructions\n\n### 1. Download Datasets\n\nDownload and extract the datasets into the `data/` folder as follows:\n\n- `data/amass_compatible_datasets/amass/` (required)\n  - [AMASS dataset](https://amass.is.tue.mpg.de/download.php) — choose **SMPL-X G** format.\n- `data/raw_datasets/humanact12/` (optional)\n  - [HumanAct12](https://github.com/EricGuo5513/action-to-motion?tab=readme-ov-file)\n- `data/raw_datasets/OMOMO/` (optional)\n  - [OMOMO dataset](https://github.com/lijiaman/omomo_release?tab=readme-ov-file)\n- `data/raw_datasets/MotionX/` (optional)\n  - [MotionX dataset](https://github.com/IDEA-Research/Motion-X)\n- `data/raw_datasets/ZJU_Mocap/` (optional)\n  - [EasyMocap](https://github.com/zju3dv/EasyMocap)\n\n---\n\n### 2. Convert Optional Datasets to AMASS Format (optional)\n\nSkip this step if you only use amass dataset.\n\nStep:\n\n1. Initialize Submodules:\n    Some datasets require external repositories or models for proper conversion. These are included as **submodules** under:\n    ```\n    thirdparties/\n    ```\n    Initialize them with:\n    ```bash\n    git submodule update --init --recursive\n    ```\n    If you need to modify or update these submodules, refer to their individual `README` files.\n\n\n2. 
Download SMPL Model:\n    - Download `SMPL_NEUTRAL.npz` from the [SMPL official website](https://smpl.is.tue.mpg.de/download.php).\n    - Place the file into:\n    ```\n    ./assets/smpl/\n    ```\n\n3. Modify `thirdparties/joints2smpl/src/customloss.py`:\n    Before running the pipeline, make sure to modify the `body_fitting_loss_3d` function in `thirdparties/joints2smpl/src/customloss.py` to include the following change:\n    ```python\n    joint3d_loss = (joint_loss_weight ** 2) * joint3d_loss_part.sum(dim=-1)\n    ```\n\n4. Modify `thirdparties/joints2smpl/src/smplify.py`:\n    Next, ensure the following modification in the `__call__` function of `SMPLify3D` inside `thirdparties/joints2smpl/src/smplify.py`:\n    ```python\n    init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach()\n    ```\n\n5. Start Conversion:\n    Run the provided script to convert all available datasets to AMASS `.npz` files:\n\n    ```bash\n    bash holomotion/scripts/data_curation/convert_to_amass.sh\n    ```\n\n    This script reads from `data/{dataset}/` and writes to `data/amass_compatible_datasets/{dataset}/`.\n\n---\n\n### 3. Filter Datasets (optional)\n\nSkip this step if you prefer to use all available data for training.\n\n#### Why filter?\n\nThe raw datasets may contain motions that are irrelevant, undesirable, or of poor quality for training. This step helps improve the overall quality and consistency of your dataset.\n\n#### Filtering criteria\n\nThe filtering process excludes samples based on the following rules:\n\n- **Upstairs/Downstairs motion:**  \n  Paths containing keywords like stairs, staircase, upstairs, downstairs, or motions with large upward/downward Z translation and velocity are excluded.\n\n- **Sitting motion:**  \n  Paths containing sitting/Sitting keywords or frames that match a reference sitting pose are excluded.\n\n- **Known abnormal datasets:**  \n  Samples from datasets known to be abnormal, such as aist, are excluded.\n\n- **Unrealistic velocity:**  \n  Motions where the mean velocity exceeds a threshold (default: 100.0) are excluded.\n\n#### How to run\n\nYou can use the `-l` option to specify which datasets to filter (space-separated list).  \nRun the filtering script to identify and exclude abnormal or unwanted samples:\n\n```bash\nbash holomotion/scripts/data_curation/filter_smpl_data.sh -l \"amass humanact12 OMOMO MotionX ZJU_Mocap\"\n```\n\nThe output `.yaml` files will be placed in `holomotion/config/data_curation/`.\n\n---\n\n## Notes\n\n- Paths are relative to the project root.\n- The AMASS dataset must be manually requested from their website.\n- Dataset conversion and filtering may take time depending on your hardware.\n\n---\n\nThis guide assumes that you only need the basic configuration to run the complete pipeline. For further customization, refer to the relevant scripts in the repository and optional steps in the documentation.\n\n## Video2SMPL Instructions\n### 1. 
Environment setup\nCreate a new conda environment named 'gvhmr' following the official installation guide:  \n[[GVHMR setup doc](../thirdparties/GVHMR/docs/INSTALL.md)]\n> Reminder: If you encounter 'hmr4d' missing module errors during runtime, install the following dependency separately in the gvhmr env:\n```bash\npip install hmr4d\n```\n\nRename the SMPL model files as follows:  \n`basicmodel_{GENDER}_lbs_*.pkl`  →  `SMPL_{GENDER}.pkl`  \nThen place the SMPL and SMPL-X model files into the directory structure below:\n\n```\nthirdparties/GVHMR/inputs/checkpoints/\n├── body_models/smplx\n│   └── SMPLX_{GENDER}.npz\n└── body_models/smpl\n    └── SMPL_{GENDER}.pkl\n```\n\n### 2. Video to SMPL motion data\nConfirm that all input videos have a frame rate of 30 FPS to avoid motion acceleration or deceleration (a quick check is sketched at the end of this guide).\n```bash\nbash ./holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh\n```\n> Reminder: Set the directory in the .sh file to an absolute path.\n\n### 3. Visualize generated SMPL motion data\nVisualize the SMPL motion sequences generated by GVHMR for inspection and debugging.\n```bash\nbash ./holomotion/scripts/data_curation/visualize_smpl_gvhmr.sh\n```\n\n### 4. Convert SMPL data to SMPLX\nMotion data from GVHMR is in SMPL format.\nUse `./thirdparties/GMR/scripts/smpl_to_smplx.py` to convert it to SMPL-X for retargeting."
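The 30 FPS requirement in step 2 can be checked with ffprobe (a minimal sketch; it assumes ffmpeg/ffprobe is installed, and the input path is illustrative):

```bash
# Prints the video stream's frame rate as a fraction, e.g. 30/1.
ffprobe -v error -select_streams v:0 \
  -show_entries stream=r_frame_rate -of csv=p=0 your_video.mp4
```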
  },
  {
    "path": "docs/train_motion_tracking.md",
    "content": "## Train the Motion Tracking Model\n\nAfter completing motion retargeting, you can train a motion tracking model with HoloMotion using the following process.\n\n**Overall Workflow:**\n\n```mermaid\nflowchart LR\nA[Motion Retargeting] --> B[HDF5 Database]\nB --> C[Training Config]\nC --> D[Training Entry]\nD --> E[Distributed PPO Training]\n\nclassDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5\nclassDef normal fill:#c9d9f5, rx:10, ry:10\nclass A dashed\nclass B,C,D,E normal\n```\n\n### 1. Train the Motion Tracking Model\n\nThe training entry point is `holomotion/src/training/train.py`, which uses the training config to start distributed training across multiple GPUs.\n\n#### 2.1 Prepare the Training Config\n\nUse the demo config at `holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking.yaml` as a template. Key configuration groups to modify (configs are located in the `holomotion/config/` directory):\n\n- **`/algo`**: Algorithm settings (PPO) and network configurations\n- **`/robot`**: Robot-specific config including DOF, body links, and control parameters\n- **`/env`**: Environment settings including motion sampling and curriculum learning\n- **`/env/observations`**: Observation dimensions, noise, and scaling for the policy\n- **`/env/rewards`**: Reward function definitions\n- **`/env/domain_randomization`**: Domain randomization settings (start with `NO_domain_rand`)\n- **`/env/terrain`**: Terrain configuration\n- **`/modules`**: The policy network modules definitions\n\n```yaml\n# @package _global_\n\ndefaults:\n  - /training: train_base\n  - /algo: ppo\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: motion_tracking\n  - /env/terminations: termination_motion_tracking\n  - /env/observations: motion_tracking/obs_motion_tracking_tf-moe\n  - /env/rewards: motion_tracking/rew_motion_tracking\n  - /env/domain_randomization: domain_rand_medium\n  - /env/terrain: isaaclab_plane\n  - /modules: motion_tracking/motion_tracking_tf-moe\n\nproject_name: HoloMotion\n```\n\n#### 2.2 Train your Policy\n\nReview and modify the training script at `holomotion/scripts/training/train_motion_tracking.sh`. Ensure `config_name` match your training config and LMDB database directory.\n\nStart training by running:\n\n```shell\nbash holomotion/scripts/training/train_motion_tracking.sh\n\n# or\n\nbash holomotion/scripts/training/train_velocity_tracking.sh\n```\n\nNote that IsaacLab relies on internet connections to pull assets from Nvidia's cloud storage. If you encountered stuck at scene creation, it is very likely that you can't access the cloud-hosted assets. Turn on your proxy and try again can solve the issue.\n\n### Training Tips\n\n#### How to use less GPU ?\n\nTraining requires significant GPU memory. Reduce `num_envs` if your GPU has limited GRAM. This will reduce both the rollout burden and the PPO training consumption, at the risk of significantly less stable policy optimization process.\n\n#### How to start multiple training session ?\n\nIn cases where you would like to start multiple training sessions, you should explicitly add the `--main_process_port=port_number` option in the training entry bash script to avoid port conflict of the accelerate backend. 
Note that this `port_number` **cannot** be `0`.\n\nIf you would like to run training on a specific GPU, just modify the GPU id in the `export CUDA_VISIBLE_DEVICES=\"X\"` statement.\n\n#### How to set the save/log intervals?\n\nYou may want more or less frequent logging and model dumping. You can alter these intervals by adding the following options:\n\n- `algo.config.save_interval=X`: The checkpoint will be saved every `X` learning iterations.\n- `algo.config.log_interval=Y`: The logging information will be displayed every `Y` learning iterations.\n\n#### Where is the checkpoint dumped?\n\nBy default, the model checkpoint will be dumped into a folder named `logs/HoloMotion`. You can change this path by explicitly setting `project_name=X`, which results in dumping the checkpoints into the `logs/X` directory.\n\n#### How to resume training from a checkpoint?\n\nTo resume training from a pretrained checkpoint, find the checkpoint in the log directory and add an option like: `checkpoint=logs/HoloMotion/20250728_214414-train_unitree_g1_21dof_teacher/model_X.pt`\n"
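Putting several of the options above together (a sketch; it assumes the entry script forwards extra Hydra overrides to `train.py`, otherwise append them to the Python command inside the script):

```shell
export CUDA_VISIBLE_DEVICES="0"
bash holomotion/scripts/training/train_motion_tracking.sh \
  algo.config.save_interval=250 \
  algo.config.log_interval=10 \
  checkpoint=logs/HoloMotion/20250728_214414-train_unitree_g1_21dof_teacher/model_X.pt
```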
  },
  {
    "path": "environments/environment_deploy.yaml",
    "content": "name: holomotion_deploy\nchannels:\n  - pytorch\n  - nvidia  \n  - conda-forge\n  - defaults\n\ndependencies:\n  # Python runtime\n  - python=3.10\n\n  # PyTorch with CUDA support\n  - pytorch-cuda=12.1\n  - cudnn>=9  \n  - cudatoolkit>=11.7,<12\n  - pytorch==2.3.1\n  - torchvision==0.18.1\n  - torchaudio==2.3.1\n\n  # Scientific computing packages (via conda for better compatibility)\n  - numpy==1.24.3\n  - scipy\n  - matplotlib\n  - pandas\n\n  # System utilities and development tools\n  - sshpass=1.06\n  - git\n  - curl\n  - wget\n  - pyyaml\n  - easydict\n  - joblib\n\n  # Basic Python package management\n  - pip\n  - setuptools\n  - wheel\n\n  # Install additional packages via pip\n  - pip:\n    - -r environments/requirements_deploy.txt"
  },
  {
    "path": "environments/environment_train_isaaclab_cu118.yaml",
    "content": "name: holomotion_train\nchannels:\n  - conda-forge\ndependencies:\n  - python=3.11\n  - pip\n  - mesalib\n  - pip:\n      - -r requirements_torch_cu118.txt\n      - -r requirements_base.txt\n      - -e ../thirdparties/SMPLSim\n      - -e ../\n"
  },
  {
    "path": "environments/environment_train_isaaclab_cu128.yaml",
    "content": "name: holomotion_train\nchannels:\n  - conda-forge\ndependencies:\n  - python=3.11\n  - pip\n  - mesalib\n  - pip:\n      - -r requirements_torch_cu128.txt\n      - -r requirements_base.txt\n      - -e ../thirdparties/SMPLSim\n      - -e ../\n"
  },
  {
    "path": "environments/requirements_base.txt",
    "content": "isaacsim[all,extscache]==5.0.0\nisaaclab[isaacsim,all]==2.2.0\n\n\nsetuptools\nwheel\nnumpy==1.26.0\nsmplx==0.1.28\nhydra-core==1.3.2\neasydict\ntqdm\nopen3d\nlxml\nray\nipdb\njoblib\nscipy\njupyter\nloguru\ntensorboard\nmujoco\nmink\ndm_control\nloop_rate_limiters\nqpsolvers[quadprog,proxqp]\naccelerate\ntabulate\nmatplotlib\npandas\ntermcolor\nrich\npytorch-tcn\neinops\nonnxruntime-gpu\nonnx\npre-commit\nruff\npytest\nimageio>=2.9\nimageio-ffmpeg\nopencv-python\nnatsort\npsutil\nredis[hiredis]\nchumpy\npyvirtualdisplay\npynput\nxxhash\nh5py>=3.8\npygame\ntensordict==0.11.0\npytorch_kinematics\nonnxscript\n\n# human_body_prior 依赖\nhuman-body-prior"
  },
  {
    "path": "environments/requirements_deploy.txt",
    "content": "# Machine Learning Runtime  \nonnxruntime-gpu\n\n# SMPL/SMPLX support\nsmplx==0.1.28\n\n# Configuration management\nhydra-core==1.3.2\nomegaconf\n# Progress and logging\ntqdm\nloguru\ntermcolor\nrich\n\n# Data processing\nlmdb\neinops\n\n# Protobuf (specific version for compatibility)\nprotobuf==3.20.3\nonnx\n\n# Development tools\nipdb\n\nmujoco\npygame\n# Note: The following packages are installed via conda in environment_deploy.yaml:\n# - torch, torchvision, torchaudio (with CUDA support)\n# - numpy, scipy, matplotlib, pandas  \n# - pyyaml, easydict, joblib\n# - system utilities (git, curl, wget, sshpass) "
  },
  {
    "path": "environments/requirements_torch_cu118.txt",
    "content": "--extra-index-url https://download.pytorch.org/whl/cu118\n--extra-index-url https://pypi.nvidia.com\n\ntorch==2.7.0\ntorchvision==0.22.0\ntorchaudio==2.7.0\n"
  },
  {
    "path": "environments/requirements_torch_cu128.txt",
    "content": "--extra-index-url https://download.pytorch.org/whl/cu128\n--extra-index-url https://pypi.nvidia.com\n\ntorch==2.10.0+cu128\ntorchvision==0.25.0+cu128\ntorchaudio==2.10.0+cu128\n"
  },
  {
    "path": "environments/requirements_torch_cu130.txt",
    "content": "--extra-index-url https://download.pytorch.org/whl/cu130\n--extra-index-url https://pypi.nvidia.com\n\ntorch==2.10.0+cu130\ntorchvision==0.25.0+cu130\ntorchaudio==2.10.0+cu130\n"
  },
  {
    "path": "holomotion/config/algo/ppo.yaml",
    "content": "# @package _global_\n\nalgo:\n  _target_: holomotion.src.algo.ppo.PPO\n  _recursive_: false\n  config:\n    # --- General Settings ---\n    enable_online_eval: false\n    num_learning_iterations: 10001\n    log_interval: 5\n    save_interval: 500\n    export_policy: true\n    onnx_name_suffix: null\n    use_kv_cache: true\n    eval_interval: null\n    load_optimizer: true\n    headless: ${headless}\n    # ---\n\n    # --- Accelerate Settings ---\n    mixed_precision: null # \"fp16\", \"bf16\", or null. Use \"bf16\" for A100/H100, \"fp16\" for older GPUs\n    dynamo_backend: \"inductor\" # \"inductor\", \"aot_eager\", \"cudagraphs\", or null. Enables automatic model compilation during prepare()\n    # ---\n\n    # --- PPO Related Settings ---\n    init_at_random_ep_len: true\n    num_steps_per_env: 32\n    num_learning_epochs: 3\n    num_mini_batches: 4\n    clip_param: 0.2\n    gamma: 0.99\n    lam: 0.95\n    value_loss_coef: 1.0\n    entropy_coef: 5.0e-3\n    anneal_entropy: false\n    zero_entropy_point: 1.0\n    max_grad_norm: 1.0\n    use_clipped_value_loss: true\n    desired_kl: 0.01\n    init_noise_std: 1.0\n\n    # --- Optimizer Settings ---\n    optimizer_type: AdamW # Options: \"AdamW\", \"Adam\"\n    schedule: adaptive\n    actor_learning_rate: 3.0e-4\n    critic_learning_rate: 5.0e-4\n    adaptive_lr:\n      adapt_critic: false\n      lr_scaler: 1.2\n      kl_high_factor: 2.0\n      kl_low_factor: 0.5\n      min_learning_rate: 1.0e-6\n      max_learning_rate: 1.0\n\n    distributed_update:\n      mode: scalable # Options: \"legacy\", \"scalable\"\n      lr_scale:\n        mode: sqrt_world_size # Options: \"none\", \"sqrt_world_size\", \"linear_world_size\"\n        reference_world_size: 1\n        max_scale: null\n      kl_early_stop:\n        enabled: true\n        signal: window_mean # Shared windowed KL control signal\n        window_size: 3\n        factor: 1.8\n        min_updates: 1\n\n    # Distributed training settings\n    normalize_advantage_per_mini_batch: false # Use global advantage norm for DDP\n    global_advantage_norm: true # Sync advantages across all ranks\n\n    # --- Sampling Strategy ---\n    sampling_strategy: uniform\n    curriculum:\n      p_a_ratio: 0.5\n      ema_alpha_signal: 0.2\n      ema_alpha_rel_improve: 0.2\n      relative_eps: 1.0e-6\n      dump_whole_window_scores_json: false\n      dump_whole_window_scores_every_swaps: 10\n\n    weighted_bin:\n      bin_regex_patterns: []\n      dump_sampled_keys: false\n      dump_sampled_keys_interval: 1000\n\n    # --- Module Settings ---\n    module_dict:\n      actor: ${modules.actor}\n      critic: ${modules.critic}\n\n    symmetry_loss:\n      enabled: false\n      coef: 0.1\n      dof_sign_by_name: ${robot.dof_sign_by_name}\n"
  },
  {
    "path": "holomotion/config/algo/ppo_tf.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - ppo\n\nalgo:\n  _target_: holomotion.src.algo.ppo_tf.PPOTF\n  config:\n    num_steps_per_env: 32\n    kl_coef: 0.0\n    schedule: adaptive\n    actor_learning_rate: 3.0e-5\n    critic_learning_rate: 5.0e-5\n\n    num_learning_epochs: 3\n    num_mini_batches: 24\n\n    clip_param: 0.2\n    entropy_coef: 5.0e-3\n    desired_kl: 0.01\n    noise_std_type: log\n    fix_sigma: false\n    init_noise_std: 1.0\n    min_sigma: 0.01\n    max_sigma: 1.2\n    aux_state_pred:\n      enabled: true\n      w_keybody_contact: 1.0e-2\n      w_base_lin_vel: 1.0e-2\n      w_ref_keybody_rel_pos: 1.0e-1\n      w_robot_keybody_rel_pos: 1.0e-1\n\n      min_std: 0.01\n      max_std: 2.0\n\n      keybody_contact_names:\n        - left_hip_pitch_link\n        - right_hip_pitch_link\n        - left_knee_link\n        - right_knee_link\n        - left_ankle_roll_link\n        - right_ankle_roll_link\n        - left_elbow_link\n        - right_elbow_link\n        - left_wrist_yaw_link\n        - right_wrist_yaw_link\n\n      keybody_rel_pos_names:\n        - left_knee_link\n        - right_knee_link\n        - left_ankle_roll_link\n        - right_ankle_roll_link\n        - left_elbow_link\n        - right_elbow_link\n        - left_wrist_yaw_link\n        - right_wrist_yaw_link\n\n    dead_expert_margin_to_topk:\n      enabled: true\n      weight: 10.0\n\n    aux_router_command_recon:\n      enabled: false\n      weight: 0.0\n      hidden_dim: 0\n      term_prefix: actor_ref_\n\n    aux_router_switch_penalty:\n      enabled: false\n      weight: 0.0\n\n    router_expert_orthogonal:\n      enabled: false\n      weight: 0.0\n      min_active_usage: 1.0e-3\n      eps: 1.0e-8\n\n    selected_expert_margin_to_unselected:\n      enabled: false\n      weight: 0.0\n      target: 0.0\n\n    moe_router:\n      routing_score_fn: softmax\n      routing_scale: 1.0\n      use_dynamic_bias: false\n      bias_update_rate: 0.001\n      expert_bias_clip: 0.0\n"
  },
  {
    "path": "holomotion/config/data_curation/joints2smpl.yaml",
    "content": ""
  },
  {
    "path": "holomotion/config/data_curation/smplify_base.yaml",
    "content": ""
  },
  {
    "path": "holomotion/config/env/domain_randomization/NO_domain_rand.yaml",
    "content": "# @package _global_\n\ndomain_rand:\n  action_delay:\n    enabled: false\n\n  erfi:\n    enabled: false\n\n  motion_init_perturb:\n    root_pose_perturb_range:\n      x: [0.0, 0.0]\n      y: [0.0, 0.0]\n      z: [0.0, 0.0]\n      roll: [0.0, 0.0]\n      pitch: [0.0, 0.0]\n      yaw: [0.0, 0.0]\n    root_vel_perturb_range:\n      x: [0.0, 0.0]\n      y: [0.0, 0.0]\n      z: [0.0, 0.0]\n      roll: [0.0, 0.0]\n      pitch: [0.0, 0.0]\n      yaw: [0.0, 0.0]\n    dof_pos_perturb_range: [0.0, 0.0]\n    dof_vel_perturb_range: [0.0, 0.0]\n\n  obs_noise:\n    actor_ref_gravity_projection_cur:\n      n_min: 0\n      n_max: 0\n    actor_ref_gravity_projection_fut:\n      n_min: 0\n      n_max: 0\n    actor_ref_base_linvel_cur:\n      n_min: 0\n      n_max: 0\n      n_min_z: 0\n      n_max_z: 0\n    actor_ref_base_linvel_fut:\n      n_min: 0\n      n_max: 0\n      n_min_z: 0\n      n_max_z: 0\n    actor_ref_base_angvel_cur:\n      n_min: 0\n      n_max: 0\n      n_min_z: 0\n      n_max_z: 0\n    actor_ref_base_angvel_fut:\n      n_min: 0\n      n_max: 0\n      n_min_z: 0\n      n_max_z: 0\n    actor_ref_dof_pos_cur:\n      n_min: 0\n      n_max: 0\n    actor_ref_dof_pos_fut:\n      n_min: 0\n      n_max: 0\n    actor_ref_root_height_cur:\n      n_min: 0\n      n_max: 0\n    actor_ref_root_height_fut:\n      n_min: 0\n      n_max: 0\n    actor_ref_keybody_rel_pos_cur:\n      n_min: 0\n      n_max: 0\n    actor_ref_keybody_rel_pos_fut:\n      n_min: 0\n      n_max: 0\n\n    actor_projected_gravity:\n      n_min: 0.0\n      n_max: 0.0\n    actor_rel_robot_root_ang_vel:\n      n_min: 0.0\n      n_max: 0.0\n    actor_dof_pos:\n      n_min: 0.0\n      n_max: 0.0\n    actor_dof_vel:\n      n_min: 0.0\n      n_max: 0.0"
  },
  {
    "path": "holomotion/config/env/domain_randomization/domain_rand_medium.yaml",
    "content": "# @package _global_\n\ndomain_rand:\n  action_delay:\n    enabled: true\n    min_delay: 0\n    max_delay: 2\n\n  erfi:\n    enabled: false\n    rfi_probability: 0.5\n    rfi_lim: 0.1\n    randomize_rfi_lim: true\n    rfi_lim_range: [0.5, 1.5]\n    rao_lim: 0.1\n\n  obs_noise:\n    actor_ref_gravity_projection_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_gravity_projection_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_base_linvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_linvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_angvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_base_angvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_dof_pos_cur:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_dof_pos_fut:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_root_height_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_root_height_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_fut:\n      n_min: -0.1\n      n_max: 0.1\n\n    actor_projected_gravity:\n      n_min: -0.1\n      n_max: 0.1\n    actor_rel_robot_root_ang_vel:\n      n_min: -0.2\n      n_max: 0.2\n    actor_dof_pos:\n      n_min: -0.01\n      n_max: 0.01\n    actor_dof_vel:\n      n_min: -0.5\n      n_max: 0.5\n\n  motion_init_perturb:\n    root_pose_perturb_range:\n      x: [-0.05, 0.05]\n      y: [-0.05, 0.05]\n      z: [-0.01, 0.01]\n      roll: [-0.1, 0.1]\n      pitch: [-0.1, 0.1]\n      yaw: [-0.2, 0.2]\n    root_vel_perturb_range:\n      x: [-0.5, 0.5]\n      y: [-0.5, 0.5]\n      z: [-0.2, 0.2]\n      roll: [-0.5, 0.5]\n      pitch: [-0.5, 0.5]\n      yaw: [-0.2, 0.2]\n    dof_pos_perturb_range: [-0.1, 0.1]\n    dof_vel_perturb_range: [0.0, 0.0]\n\n  default_dof_pos_bias:\n    mode: startup\n    params:\n      joint_names: [\".*\"]\n      pos_distribution_params: [-0.01, 0.01]\n      operation: add\n      distribution: uniform\n\n  rigid_body_com:\n    mode: startup\n    params:\n      body_names: torso_link\n      com_range:\n        x: [-0.075, 0.075]\n        y: [-0.1, 0.1]\n        z: [-0.1, 0.1]\n\n  randomize_mass:\n    mode: startup\n    params:\n      body_names:\n        - \"pelvis\"\n        - \"torso_link\"\n      mass_range: [-1.0, 2.0]\n\n  rigid_body_material:\n    mode: startup\n    params:\n      body_names: \".*\"\n      static_friction_range: [0.3, 1.6]\n      dynamic_friction_range: [0.3, 1.2]\n      restitution_range: [0.0, 0.5]\n      num_buckets: 64\n\n  push_by_setting_velocity:\n    mode: interval\n    interval_range_s: [1.0, 3.0]\n    params:\n      velocity_range:\n        x: [-0.5, 0.5]\n        y: [-0.5, 0.5]\n        z: [-0.2, 0.2]\n        roll: [-0.52, 0.52]\n        pitch: [-0.52, 0.52]\n        yaw: [-0.78, 0.78]\n\n  randomize_actuator_gains:\n    mode: startup\n    params:\n      asset_name: robot\n      body_names: \".*\"\n      stiffness_distribution_params: [0.9, 1.1]\n      damping_distribution_params: [0.9, 1.1]\n      operation: scale\n      distribution: uniform\n"
  },
  {
    "path": "holomotion/config/env/domain_randomization/domain_rand_small.yaml",
    "content": "# @package _global_\n\ndomain_rand:\n  action_delay:\n    enabled: false\n    min_delay: 0\n    max_delay: 0\n\n  erfi:\n    enabled: false\n    rfi_probability: 0.5\n    rfi_lim: 0.1\n    randomize_rfi_lim: true\n    rfi_lim_range: [0.5, 1.5]\n    rao_lim: 0.1\n\n  obs_noise:\n    actor_ref_gravity_projection_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_gravity_projection_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_base_linvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_linvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_angvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_base_angvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_dof_pos_cur:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_dof_pos_fut:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_root_height_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_root_height_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_fut:\n      n_min: -0.1\n      n_max: 0.1\n\n    actor_projected_gravity:\n      n_min: -0.1\n      n_max: 0.1\n    actor_rel_robot_root_ang_vel:\n      n_min: -0.2\n      n_max: 0.2\n    actor_dof_pos:\n      n_min: -0.01\n      n_max: 0.01\n    actor_dof_vel:\n      n_min: -0.5\n      n_max: 0.5\n\n  motion_init_perturb:\n    root_pose_perturb_range:\n      x: [-0.05, 0.05]\n      y: [-0.05, 0.05]\n      z: [-0.01, 0.01]\n      roll: [-0.1, 0.1]\n      pitch: [-0.1, 0.1]\n      yaw: [-0.2, 0.2]\n    root_vel_perturb_range:\n      x: [-0.3, 0.3]\n      y: [-0.3, 0.3]\n      z: [-0.1, 0.1]\n      roll: [-0.3, 0.3]\n      pitch: [-0.3, 0.3]\n      yaw: [-0.4, 0.4]\n    dof_pos_perturb_range: [-0.1, 0.1]\n    dof_vel_perturb_range: [0.0, 0.0]\n\n  default_dof_pos_bias:\n    mode: startup\n    params:\n      joint_names: [\".*\"]\n      pos_distribution_params: [-0.01, 0.01]\n      operation: add\n      distribution: uniform\n\n  rigid_body_com:\n    mode: startup\n    params:\n      body_names: torso_link\n      com_range:\n        x: [-0.025, 0.025]\n        y: [-0.05, 0.05]\n        z: [-0.05, 0.05]\n\n  randomize_mass:\n    mode: startup\n    params:\n      body_names:\n        - \"pelvis\"\n        - \"torso_link\"\n      mass_range: [-1.0, 2.0]\n\n  rigid_body_material:\n    mode: startup\n    params:\n      body_names: \".*\"\n      static_friction_range: [0.3, 1.6]\n      dynamic_friction_range: [0.3, 1.2]\n      restitution_range: [0.0, 0.5]\n      num_buckets: 64\n\n  push_by_setting_velocity:\n    mode: interval\n    interval_range_s: [1.0, 3.0]\n    params:\n      velocity_range:\n        x: [-0.5, 0.5]\n        y: [-0.5, 0.5]\n        z: [-0.2, 0.2]\n        roll: [-0.52, 0.52]\n        pitch: [-0.52, 0.52]\n        yaw: [-0.78, 0.78]\n\n  randomize_actuator_gains:\n    mode: startup\n    params:\n      asset_name: robot\n      body_names: \".*\"\n      stiffness_distribution_params: [0.9, 1.1]\n      damping_distribution_params: [0.9, 1.1]\n      operation: scale\n      distribution: uniform\n"
  },
  {
    "path": "holomotion/config/env/domain_randomization/domain_rand_strong.yaml",
    "content": "# @package _global_\n\ndomain_rand:\n  action_delay:\n    enabled: true\n    min_delay: 0\n    max_delay: 4\n\n  erfi:\n    enabled: true\n    rfi_probability: 0.5\n    rfi_lim: 0.1\n    randomize_rfi_lim: true\n    rfi_lim_range: [0.5, 1.5]\n    rao_lim: 0.1\n\n  obs_noise:\n    actor_ref_gravity_projection_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_gravity_projection_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_base_linvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_linvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.05\n      n_max_z: 0.05\n    actor_ref_base_angvel_cur:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_base_angvel_fut:\n      n_min: -0.1\n      n_max: 0.1\n      n_min_z: -0.1\n      n_max_z: 0.1\n    actor_ref_dof_pos_cur:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_dof_pos_fut:\n      n_min: -0.05\n      n_max: 0.05\n    actor_ref_root_height_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_root_height_fut:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_cur:\n      n_min: -0.1\n      n_max: 0.1\n    actor_ref_keybody_rel_pos_fut:\n      n_min: -0.1\n      n_max: 0.1\n\n    actor_projected_gravity:\n      n_min: -0.1\n      n_max: 0.1\n    actor_rel_robot_root_ang_vel:\n      n_min: -0.2\n      n_max: 0.2\n    actor_dof_pos:\n      n_min: -0.01\n      n_max: 0.01\n    actor_dof_vel:\n      n_min: -0.5\n      n_max: 0.5\n\n  motion_init_perturb:\n    root_pose_perturb_range:\n      x: [-0.05, 0.05]\n      y: [-0.05, 0.05]\n      z: [-0.01, 0.01]\n      roll: [-0.1, 0.1]\n      pitch: [-0.1, 0.1]\n      yaw: [-0.2, 0.2]\n    root_vel_perturb_range:\n      x: [-0.5, 0.5]\n      y: [-0.5, 0.5]\n      z: [-0.2, 0.2]\n      roll: [-0.5, 0.5]\n      pitch: [-0.5, 0.5]\n      yaw: [-0.2, 0.2]\n    dof_pos_perturb_range: [-0.1, 0.1]\n    dof_vel_perturb_range: [0.0, 0.0]\n\n  default_dof_pos_bias:\n    mode: startup\n    params:\n      joint_names: [\".*\"]\n      pos_distribution_params: [-0.01, 0.01]\n      operation: add\n      distribution: uniform\n\n  rigid_body_com:\n    mode: startup\n    params:\n      body_names: torso_link\n      com_range:\n        x: [-0.075, 0.075]\n        y: [-0.1, 0.1]\n        z: [-0.1, 0.1]\n\n  randomize_mass:\n    mode: startup\n    params:\n      body_names:\n        - \"pelvis\"\n        - \"torso_link\"\n      mass_range: [-1.0, 2.0]\n\n  rigid_body_material:\n    mode: startup\n    params:\n      body_names: \".*\"\n      static_friction_range: [0.3, 1.6]\n      dynamic_friction_range: [0.3, 1.2]\n      restitution_range: [0.0, 0.5]\n      num_buckets: 64\n\n  push_by_setting_velocity:\n    mode: interval\n    interval_range_s: [1.0, 3.0]\n    params:\n      velocity_range:\n        x: [-0.5, 0.5]\n        y: [-0.5, 0.5]\n        z: [-0.2, 0.2]\n        roll: [-0.52, 0.52]\n        pitch: [-0.52, 0.52]\n        yaw: [-0.78, 0.78]\n\n  randomize_actuator_gains:\n    mode: startup\n    params:\n      asset_name: robot\n      body_names: \".*\"\n      stiffness_distribution_params: [0.9, 1.1]\n      damping_distribution_params: [0.9, 1.1]\n      operation: scale\n      distribution: uniform\n"
  },
  {
    "path": "holomotion/config/env/motion_tracking.yaml",
    "content": "# @package _global_\n\nenv:\n  _target_: holomotion.src.env.motion_tracking.MotionTrackingEnv\n  _recursive_: False\n  config:\n    experiment_name: ${experiment_name}\n    num_envs: ${num_envs}\n    env_spacing: 2.5 # meters\n    replicate_physics: true\n    headless: ${headless}\n    num_processes: ${num_processes}\n    main_process: ${main_process}\n    process_id: ${process_id}\n    ckpt_dir: null\n    disable_ref_viz: false\n    eval_log_dir: null\n    save_rendering_dir: null\n\n    robot: ${robot}\n    domain_rand: ${domain_rand}\n    rewards: ${rewards}\n    terrain: ${terrain}\n    obs: ${obs}\n    terminations: ${terminations}\n\n    simulation:\n      episode_length_s: 10 # Long episodes for fluid motion-based termination\n      sim_freq: 200\n      control_decimation: 4\n      physx:\n        bounce_threshold_velocity: 0.5\n        gpu_max_rigid_patch_count: 327680 # 10 * 2**15\n\n    scene:\n      terrain: ${terrain}\n      lighting:\n        distant_light_intensity: 3000.0\n        dome_light_intensity: 1000.0\n      contact_sensor:\n        history_length: 3\n        force_threshold: 10.0\n        track_air_time: true\n        debug_vis: false\n\n    actions:\n      dof_pos:\n        type: joint_position\n        params:\n          asset_name: robot\n          joint_names:\n            - \".*\"\n          use_default_offset: true\n          scale: ${robot.actuators.action_scale}\n\n    commands:\n      ref_motion:\n        type: MotionCommandCfg\n        params:\n          command_obs_name: bydmmc_ref_motion\n          motion_lib_cfg: ${robot.motion}\n          urdf_dof_names: ${robot.dof_names}\n          urdf_body_names: ${robot.body_names}\n          arm_dof_names: ${robot.arm_dof_names}\n          waist_dof_names: ${robot.waist_dof_names}\n          leg_dof_names: ${robot.leg_dof_names}\n          arm_body_names: ${robot.arm_body_names}\n          torso_body_names: ${robot.torso_body_names}\n          leg_body_names: ${robot.leg_body_names}\n          anchor_bodylink_name: ${robot.anchor_body}\n          asset_name: robot\n          debug_vis: true\n          root_pose_perturb_range: ${domain_rand.motion_init_perturb.root_pose_perturb_range}\n          root_vel_perturb_range: ${domain_rand.motion_init_perturb.root_vel_perturb_range}\n          dof_pos_perturb_range: ${domain_rand.motion_init_perturb.dof_pos_perturb_range}\n          dof_vel_perturb_range: ${domain_rand.motion_init_perturb.dof_vel_perturb_range}\n          resample_time_interval_s: 100\n          n_fut_frames: ${obs.n_fut_frames}\n          target_fps: 50\n\n    normalization:\n      clip_actions: 100.0\n      clip_observations: 100.0\n\n    resample_motion_when_training: True\n\n    curriculum:\n      enabled: false\n\n      robot_friction_completion_rate:\n        enabled: True\n        func: robot_friction_range_by_completion_rate\n        params:\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n          static_friction_target: [0.3, 1.6]\n          dynamic_friction_target: [0.3, 1.2]\n          body_names: \".*\"\n          restitution_range: [0.0, 0.5]\n          num_buckets: 64\n\n      rigid_body_com_completion_rate:\n        enabled: True\n        func: rigid_body_com_by_completion_rate\n        params:\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n          state_prefix: \"_cr_curr\"\n          asset_name: \"robot\"\n          body_names: \"torso_link\"\n          com_range_target:\n            x: [-0.025, 
0.025]\n            y: [-0.05, 0.05]\n            z: [-0.05, 0.05]\n\n      default_dof_pos_bias_completion_rate:\n        enabled: True\n        func: default_dof_pos_bias_by_completion_rate\n        params:\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n          state_prefix: \"_cr_curr\"\n          joint_names:\n            - \".*\"\n          pos_distribution_params_target: [-0.01, 0.01]\n          operation: add\n          distribution: uniform\n\n      push_by_setting_velocity_completion_rate:\n        enabled: True\n        func: isaaclab_mdp.modify_term_cfg\n        params:\n          address: \"events.push_by_setting_velocity.params\"\n          modify_fn: push_by_setting_velocity_range_by_completion_rate\n          modify_params:\n            num_updates: 3\n            cr_thresholds: [0.20, 0.30, 0.40]\n            velocity_range_target:\n              x: [-0.5, 0.5]\n              y: [-0.5, 0.5]\n              z: [-0.2, 0.2]\n              roll: [-0.52, 0.52]\n              pitch: [-0.52, 0.52]\n              yaw: [-0.78, 0.78]\n\n      randomize_actuator_gains_completion_rate:\n        enabled: True\n        func: randomize_actuator_gains_by_completion_rate\n        params:\n          num_updates: 3\n          cr_thresholds: [0.20, 0.30, 0.40]\n          asset_name: \"robot\"\n          body_names: \".*\"\n          stiffness_distribution_params_target: [0.9, 1.1]\n          damping_distribution_params_target: [0.9, 1.1]\n          operation: scale\n          distribution: uniform\n\n      action_rate_l2_completion_rate:\n        enabled: true\n        func: reward_term_weight_by_completion_rate\n        params:\n          reward_term_name: \"action_rate_l2\"\n          final_weight: -0.1\n          start_scale: 0.1\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n\n      joint_pos_limits_completion_rate:\n        enabled: true\n        func: reward_term_weight_by_completion_rate\n        params:\n          reward_term_name: \"joint_pos_limits\"\n          final_weight: -10.0\n          start_scale: 0.1\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n\n      undesired_contacts_completion_rate:\n        enabled: true\n        func: reward_term_weight_by_completion_rate\n        params:\n          reward_term_name: \"undesired_contacts\"\n          final_weight: -0.1\n          start_scale: 0.1\n          num_updates: 5\n          cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]\n\n"
  },
  {
    "path": "holomotion/config/env/observations/motion_tracking/obs_motion_tracking_mlp.yaml",
    "content": "# @package _global_\n\nobs:\n  context_length: 32\n  n_fut_frames: 10\n  target_fps: 50\n  actor_obs_prefix: \"ref_\"\n  critic_obs_prefix: \"ref_\"\n\n  obs_groups:\n    unified:\n      atomic_obs_list:\n        - actor_ref_gravity_projection_cur:\n            func: ref_gravity_projection_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max}\n\n        - actor_ref_gravity_projection_fut:\n            func: ref_gravity_projection_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max}\n\n        # Reference base linear velocity\n        - actor_ref_base_linvel_cur:\n            func: ref_base_linvel_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z}\n\n        - actor_ref_base_linvel_fut:\n            func: ref_base_linvel_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z}\n\n        - actor_ref_base_angvel_cur:\n            func: ref_base_angvel_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z}\n\n        - actor_ref_base_angvel_fut:\n            func: ref_base_angvel_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max}\n                n_min_z: 
${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z}\n\n        - actor_ref_dof_pos_cur:\n            func: ref_dof_pos_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max}\n\n        - actor_ref_dof_pos_fut:\n            func: ref_dof_pos_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max}\n\n        - actor_ref_motion_filter_cutoff_hz:\n            func: ref_motion_filter_cutoff_hz\n\n        - actor_ref_root_height_cur:\n            func: ref_root_height_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max}\n\n        - actor_ref_root_height_fut:\n            func: ref_root_height_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max}\n\n        - actor_ref_keybody_rel_pos_cur:\n            func: ref_keybody_rel_pos_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n              keybody_names:\n                - \"left_knee_link\"\n                - \"right_knee_link\"\n                - \"left_ankle_roll_link\"\n                - \"right_ankle_roll_link\"\n                - \"left_elbow_link\"\n                - \"right_elbow_link\"\n                - \"left_wrist_yaw_link\"\n                - \"right_wrist_yaw_link\"\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max}\n\n        - actor_ref_keybody_rel_pos_fut:\n            func: ref_keybody_rel_pos_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n              keybody_names:\n                - \"left_knee_link\"\n                - \"right_knee_link\"\n                - \"left_ankle_roll_link\"\n                - \"right_ankle_roll_link\"\n                - \"left_elbow_link\"\n                - \"right_elbow_link\"\n                - \"left_wrist_yaw_link\"\n                - \"right_wrist_yaw_link\"\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: 
${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max}\n\n        - actor_projected_gravity:\n            func: projected_gravity\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min}\n                n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max}\n\n        - actor_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min}\n                n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max}\n\n        - actor_dof_pos:\n            func: dof_pos\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min}\n                n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max}\n\n        - actor_dof_vel:\n            func: dof_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min}\n                n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max}\n\n        - actor_last_action:\n            func: last_action\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n\n        - critic_ref_dof_pos_cur:\n            func: ref_dof_pos_cur\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_dof_pos_fut:\n            func: ref_dof_pos_fut\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_root_height_fut:\n            func: ref_root_height_fut\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_root_height_cur:\n            func: ref_root_height_cur\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_global_anchor_diff:\n            func: global_anchor_diff\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_pos:\n            func: ref_motion_cur_heading_aligned_root_pos\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_pos:\n            func: ref_motion_fut_heading_aligned_root_pos\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_rot6d:\n            func: ref_motion_cur_heading_aligned_root_rot6d\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_rot6d:\n            func: ref_motion_fut_heading_aligned_root_rot6d\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_lin_vel:\n            func: ref_motion_cur_heading_aligned_root_lin_vel\n           
 params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_lin_vel:\n            func: ref_motion_fut_heading_aligned_root_lin_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_ang_vel:\n            func: ref_motion_cur_heading_aligned_root_ang_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_ang_vel:\n            func: ref_motion_fut_heading_aligned_root_ang_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_rel_robot_root_lin_vel:\n            func: rel_robot_root_lin_vel\n\n        - critic_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n\n        - critic_global_robot_bodylink_lin_vel_flat:\n            func: global_robot_bodylink_lin_vel_flat\n\n        - critic_global_robot_bodylink_ang_vel_flat:\n            func: global_robot_bodylink_ang_vel_flat\n\n        - critic_root_rel_robot_bodylink_pos_flat:\n            func: root_rel_robot_bodylink_pos_flat\n\n        - critic_root_rel_robot_bodylink_rot_mat_flat:\n            func: root_rel_robot_bodylink_rot_mat_flat\n\n        - critic_dof_pos:\n            func: dof_pos\n\n        - critic_dof_vel:\n            func: dof_vel\n\n        - critic_last_action:\n            func: last_action\n\n      enable_corruption: true\n      concatenate_terms: false\n"
  },
  {
    "path": "holomotion/config/env/observations/motion_tracking/obs_motion_tracking_tf-moe.yaml",
    "content": "# @package _global_\n\nobs:\n  context_length: 1\n  n_fut_frames: 10\n  target_fps: 50\n  actor_obs_prefix: \"ref_\"\n  critic_obs_prefix: \"ref_\"\n\n  obs_groups:\n    unified:\n      atomic_obs_list:\n        - actor_ref_gravity_projection_cur:\n            func: ref_gravity_projection_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max}\n\n        - actor_ref_gravity_projection_fut:\n            func: ref_gravity_projection_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max}\n\n        # Reference base linear velocity\n        - actor_ref_base_linvel_cur:\n            func: ref_base_linvel_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z}\n\n        - actor_ref_base_linvel_fut:\n            func: ref_base_linvel_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z}\n\n        - actor_ref_base_angvel_cur:\n            func: ref_base_angvel_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max}\n                n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z}\n\n        - actor_ref_base_angvel_fut:\n            func: ref_base_angvel_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max}\n                n_min_z: 
${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z}\n                n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z}\n\n        - actor_ref_dof_pos_cur:\n            func: ref_dof_pos_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max}\n\n        - actor_ref_dof_pos_fut:\n            func: ref_dof_pos_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max}\n\n        - actor_ref_motion_filter_cutoff_hz:\n            func: ref_motion_filter_cutoff_hz\n\n        - actor_ref_root_height_cur:\n            func: ref_root_height_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max}\n\n        - actor_ref_root_height_fut:\n            func: ref_root_height_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max}\n\n        - actor_ref_keybody_rel_pos_cur:\n            func: ref_keybody_rel_pos_cur\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n              keybody_names:\n                - \"left_knee_link\"\n                - \"right_knee_link\"\n                - \"left_ankle_roll_link\"\n                - \"right_ankle_roll_link\"\n                - \"left_elbow_link\"\n                - \"right_elbow_link\"\n                - \"left_wrist_yaw_link\"\n                - \"right_wrist_yaw_link\"\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max}\n\n        - actor_ref_keybody_rel_pos_fut:\n            func: ref_keybody_rel_pos_fut\n            params:\n              ref_prefix: ${obs.actor_obs_prefix}\n              keybody_names:\n                - \"left_knee_link\"\n                - \"right_knee_link\"\n                - \"left_ankle_roll_link\"\n                - \"right_ankle_roll_link\"\n                - \"left_elbow_link\"\n                - \"right_elbow_link\"\n                - \"left_wrist_yaw_link\"\n                - \"right_wrist_yaw_link\"\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: 
${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min}\n                n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max}\n\n        - actor_projected_gravity:\n            func: projected_gravity\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min}\n                n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max}\n\n        - actor_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min}\n                n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max}\n\n        - actor_dof_pos:\n            func: dof_pos\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min}\n                n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max}\n\n        - actor_dof_vel:\n            func: dof_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min}\n                n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max}\n\n        - actor_last_action:\n            func: last_action\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n\n        - critic_ref_dof_pos_cur:\n            func: ref_dof_pos_cur\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_dof_pos_fut:\n            func: ref_dof_pos_fut\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_root_height_fut:\n            func: ref_root_height_fut\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_root_height_cur:\n            func: ref_root_height_cur\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_global_anchor_diff:\n            func: global_anchor_diff\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_pos:\n            func: ref_motion_cur_heading_aligned_root_pos\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_pos:\n            func: ref_motion_fut_heading_aligned_root_pos\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_rot6d:\n            func: ref_motion_cur_heading_aligned_root_rot6d\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_rot6d:\n            func: ref_motion_fut_heading_aligned_root_rot6d\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_lin_vel:\n            func: ref_motion_cur_heading_aligned_root_lin_vel\n           
 params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_lin_vel:\n            func: ref_motion_fut_heading_aligned_root_lin_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_cur_heading_aligned_root_ang_vel:\n            func: ref_motion_cur_heading_aligned_root_ang_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_ref_motion_fut_heading_aligned_root_ang_vel:\n            func: ref_motion_fut_heading_aligned_root_ang_vel\n            params:\n              ref_prefix: ${obs.critic_obs_prefix}\n\n        - critic_rel_robot_root_lin_vel:\n            func: rel_robot_root_lin_vel\n\n        - critic_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n\n        - critic_global_robot_bodylink_lin_vel_flat:\n            func: global_robot_bodylink_lin_vel_flat\n\n        - critic_global_robot_bodylink_ang_vel_flat:\n            func: global_robot_bodylink_ang_vel_flat\n\n        - critic_root_rel_robot_bodylink_pos_flat:\n            func: root_rel_robot_bodylink_pos_flat\n\n        - critic_root_rel_robot_bodylink_rot_mat_flat:\n            func: root_rel_robot_bodylink_rot_mat_flat\n\n        - critic_dof_pos:\n            func: dof_pos\n\n        - critic_dof_vel:\n            func: dof_vel\n\n        - critic_last_action:\n            func: last_action\n\n      enable_corruption: true\n      concatenate_terms: false\n"
  },
  {
    "path": "holomotion/config/env/observations/velocity_tracking/obs_velocity_tracking.yaml",
    "content": "# @package _global_\n\nobs:\n  context_length: 8\n  n_fut_frames: 0\n  target_fps: 50\n\n  obs_groups:\n    unified:\n      atomic_obs_list:\n        # Actor terms (sequence-style; flatten at serializer)\n        - actor_velocity_command:\n            func: velocity_command\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            mirror_func: mirror_velocity_command\n            mirror_config: {}\n        - actor_projected_gravity:\n            func: projected_gravity\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: -0.1\n                n_max: 0.1\n            mirror_func: mirror_vec3\n            mirror_config: {}\n        - actor_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: -0.2\n                n_max: 0.2\n            mirror_func: mirror_axial_vec3\n            mirror_config: {}\n        - actor_dof_pos:\n            func: dof_pos\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: -0.01\n                n_max: 0.01\n            mirror_func: mirror_dof\n            mirror_config: {}\n        - actor_dof_vel:\n            func: dof_vel\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            noise:\n              type: AdditiveUniformNoiseCfg\n              params:\n                n_min: -0.5\n                n_max: 0.5\n            mirror_func: mirror_dof\n            mirror_config: {}\n        - actor_last_action:\n            func: last_action\n            history_length: ${obs.context_length}\n            flatten_history_dim: false\n            mirror_func: mirror_dof\n            mirror_config: {}\n\n        # Critic terms\n        - critic_velocity_command:\n            func: velocity_command\n        - critic_rel_robot_root_lin_vel:\n            func: rel_robot_root_lin_vel\n        - critic_rel_robot_root_ang_vel:\n            func: rel_robot_root_ang_vel\n        - critic_root_rel_robot_bodylink_pos_flat:\n            func: root_rel_robot_bodylink_pos_flat\n            params:\n              keybody_names: ${robot.key_bodies}\n        - critic_root_rel_robot_bodylink_rot_mat_flat:\n            func: root_rel_robot_bodylink_rot_mat_flat\n            params:\n              keybody_names: ${robot.key_bodies}\n        - critic_dof_pos:\n            func: dof_pos\n        - critic_dof_vel:\n            func: dof_vel\n        - critic_last_action:\n            func: last_action\n      enable_corruption: true\n      concatenate_terms: false\n"
  },
  {
    "path": "holomotion/config/env/rewards/motion_tracking/rew_motion_tracking.yaml",
    "content": "# @package _global_\n\nrewards:\n  _config:\n    reward_prefix: \"ref_\"\n\n  is_alive:\n    weight: 0.5\n    params: {}\n\n  root_pos_xy_tracking_exp:\n    weight: 1.0\n    params:\n      std: 0.2\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  root_rot_tracking_exp:\n    weight: 0.5\n    params:\n      std: 0.4\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  root_rel_keybodylink_pos_tracking_l2_exp:\n    weight: 1.0\n    params:\n      keybody_names: ${robot.key_bodies}\n      std: 0.3\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  root_rel_keybodylink_rot_tracking_l2_exp:\n    weight: 2.0\n    params:\n      keybody_names: ${robot.key_bodies}\n      std: 0.4\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  global_keybodylink_lin_vel_tracking_l2_exp:\n    weight: 1.0\n    params:\n      keybody_names: ${robot.key_bodies}\n      std: 1.0\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  global_keybodylink_ang_vel_tracking_l2_exp:\n    weight: 1.0\n    params:\n      keybody_names: ${robot.key_bodies}\n      std: 3.14\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  action_rate_l2:\n    weight: -0.1\n    params: {}\n\n  # joint_acc_l2:\n  #   weight: -1.0e-6\n  #   params: {}\n\n  joint_pos_limits:\n    weight: -10.0\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names:\n          - \".*\"\n\n  undesired_contacts:\n    weight: -0.1\n    params:\n      threshold: 1.0\n      sensor_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names:\n          - ${robot.undesired_contacts_regrex}\n"
  },
  {
    "path": "holomotion/config/env/rewards/velocity_tracking/rew_velocity_tracking.yaml",
    "content": "# @package _global_\n\nrewards:\n  stand_still_action_rate:\n    weight: -1.0\n    params:\n      command_name: base_velocity\n\n  feet_contact_without_cmd:\n    weight: 1.0\n    params:\n      command_name: base_velocity\n      sensor_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names:\n          - \".*ankle_roll.*\"\n\n  track_stand_still_exp:\n    weight: 5.0\n    params:\n      command_name: base_velocity\n      std: 0.2\n\n  track_lin_vel_xy_heading_aligned_frame_exp:\n    weight: 3.0\n    params:\n      std: 0.5\n      command_name: base_velocity\n\n  track_ang_vel_z_heading_aligned_frame_exp:\n    weight: 3.0\n    params:\n      std: 0.5\n      command_name: base_velocity\n\n  is_alive:\n    weight: 0.15\n    params: {}\n\n  lin_vel_z_l2:\n    weight: -1.0\n    params: {}\n\n  ang_vel_xy_l2:\n    weight: -5.0e-2\n    params: {}\n\n  joint_acc_l2:\n    weight: -1.0e-6\n    params: {}\n\n  action_rate_l2:\n    weight: -1.0e-1\n    params: {}\n\n  joint_pos_limits:\n    weight: -5.0\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names: [\".*\"]\n\n  feet_air_time_v4:\n    weight: 1.0\n    params:\n      threshold: 0.5\n      sensor_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names: [\".*ankle_roll.*\"]\n      command_name: base_velocity\n\n  fly:\n    weight: -1.0\n    params:\n      threshold: 1.0\n      sensor_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names: [\".*ankle_roll.*\"]\n\n  feet_too_near:\n    weight: -10.0\n    threshold: 0.2\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names:\n          - .*ankle_roll.*\n\n  joint_deviation_l1_arms:\n    weight: -0.3\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names:\n          - .*_hip_roll.*\n          - .*waist_roll.*\n          - .*waist_pitch.*\n          - .*_shoulder_roll.*\n          - .*_shoulder_yaw.*\n          - .*_wrist.*\n\n  joint_deviation_l1_legs_yaw:\n    weight: -0.15\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names:\n          - .*waist_yaw.*\n          - .*_hip_yaw.*\n          - .*_elbow.*\n          - .*_ankle.*\n\n  joint_deviation_l1_legs:\n    weight: -0.02\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        joint_names:\n          - .*_shoulder_pitch.*\n          - .*_hip_pitch.*\n          - .*_knee.*\n\n  flat_orientation_l2:\n    weight: -5.0\n    params: {}\n\n  base_height_l2:\n    weight: -10.0\n    params:\n      target_height: 0.78\n\n  feet_slide:\n    weight: -1.0\n    params:\n      asset_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: robot\n        body_names: [\".*ankle_roll.*\"]\n      sensor_cfg:\n        _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names: [\".*ankle_roll.*\"]\n\n  undesired_contacts:\n    weight: -1.0\n    params:\n      threshold: 1.0\n      sensor_cfg:\n        _target_: 
isaaclab.managers.scene_entity_cfg.SceneEntityCfg\n        name: contact_forces\n        body_names:\n          - ${robot.undesired_contacts_regrex}\n\n  torso_xy_ang_vel_l2_penalty:\n    weight: -1.0\n    params: {}\n\n  torso_upright_l2_penalty:\n    weight: -1.0\n    params: {}\n"
  },
  {
    "path": "holomotion/config/env/terminations/NO_termination.yaml",
    "content": "# @package _global_\n\nterminations: {}\n"
  },
  {
    "path": "holomotion/config/env/terminations/termination_motion_tracking.yaml",
    "content": "# @package _global_\n\nterminations:\n\n  time_out:\n    time_out: true\n\n  ref_gravity_projection_far:\n    params:\n      threshold: 0.8\n      ref_prefix: ${rewards._config.reward_prefix}\n\n  keybody_ref_z_far:\n    params:\n      threshold: 0.25\n      ref_prefix: ${rewards._config.reward_prefix}\n      keybody_names:\n        - pelvis\n        - left_ankle_roll_link\n        - right_ankle_roll_link\n        - left_wrist_yaw_link\n        - right_wrist_yaw_link\n\n  # keybody_ref_pos_far:\n  #   params:\n  #     threshold: 0.5\n  #     ref_prefix: ${rewards._config.reward_prefix}\n  #     keybody_names:\n  #       - pelvis"
  },
  {
    "path": "holomotion/config/env/terminations/termination_velocity_tracking.yaml",
    "content": "# @package _global_\n\nterminations:\n  time_out:\n    time_out: true\n\n  root_height_below_minimum:\n    params:\n      minimum_height: 0.2\n\n  bad_orientation:\n    params:\n      limit_angle: 0.8\n"
  },
  {
    "path": "holomotion/config/env/terrain/isaaclab_plane.yaml",
    "content": "# @package _global_\n\n# Simple flat terrain generated via IsaacLab height-field TerrainGenerator.\n# Uses random-uniform height field with zero noise as a flat patch.\nterrain:\n  terrain_type: generator\n  prim_path: /World/ground\n  static_friction: 1.0\n  dynamic_friction: 1.0\n  restitution: 0.0\n  friction_combine_mode: multiply\n  restitution_combine_mode: multiply\n  debug_vis: false\n  max_init_terrain_level: 0\n\n  # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.\n  # When false, env origins are placed on a regular grid as in the default importer.\n  random_spawn: true\n  # Keep random spawn points away from terrain edges to avoid spawning onto the outer border.\n  random_spawn_margin: 2.0\n\n  # TerrainGeneratorCfg parameters.\n  generator:\n    num_rows: 1\n    num_cols: 1\n    size: [10.0, 10.0]\n    border_width: 1000.0\n    horizontal_scale: 0.1\n    vertical_scale: 0.005\n    slope_threshold: null\n    difficulty_range: [0.0, 0.0]\n    color_scheme: height\n    sub_terrains:\n      plane:\n        type: plane\n        proportion: 1.0\n\n  # Offline visual material configuration (PreviewSurface, no MDL/Nucleus).\n  visual_material:\n    type: color\n    diffuse_color: [0.25, 0.25, 0.25]\n    metallic: 0.0\n    roughness: 0.5\n"
  },
  {
    "path": "holomotion/config/env/terrain/isaaclab_rough.yaml",
    "content": "# @package _global_\n\n# Rough height-field terrain for locomotion training.\n# Uses random-uniform height field to create continuous noise-like terrain.\nterrain:\n  terrain_type: generator\n  prim_path: /World/ground\n  static_friction: 1.0\n  dynamic_friction: 1.0\n  restitution: 0.0\n  friction_combine_mode: multiply\n  restitution_combine_mode: multiply\n  debug_vis: false\n  max_init_terrain_level: 4\n\n  # Randomize spawn position within each sub-terrain (recommended for locomotion).\n  random_spawn: true\n\n  random_spawn_margin: 4.0\n\n  # TerrainGeneratorCfg parameters.\n  generator:\n    num_rows: 4 # Number of terrain rows (difficulty levels)\n    num_cols: 4 # Number of terrain columns (types)\n    size: [20.0, 20.0] # Size of each sub-terrain in meters [length, width]\n    border_width: 1000.0 # Border around terrain in meters\n    horizontal_scale: 0.1 # Resolution in x-y plane\n    vertical_scale: 0.005 # Height resolution\n    slope_threshold: null # Slopes above this become vertical\n    difficulty_range: [0.0, 1.0] # Min and max difficulty\n    color_scheme: height # Use material shading instead of vertex colors\n    sub_terrains:\n      rough:\n        type: random_uniform\n        proportion: 1.0\n        noise_range: [0.0, 0.04]\n        noise_step: 0.05\n        downsampled_scale: 1.0\n\n  # Offline visual material configuration (PreviewSurface, no MDL/Nucleus).\n  visual_material:\n    type: color\n    diffuse_color: [0.25, 0.25, 0.25]\n    metallic: 0.0\n    roughness: 0.5\n"
  },
  {
    "path": "holomotion/config/env/velocity_tracking.yaml",
    "content": "# @package _global_\n\nenv:\n  _target_: holomotion.src.env.velocity_tracking.VelocityTrackingEnv\n  _recursive_: False\n  config:\n    experiment_name: ${experiment_name}\n    num_envs: ${num_envs}\n    env_spacing: 2.5\n    replicate_physics: true\n    headless: ${headless}\n    num_processes: ${num_processes}\n    main_process: ${main_process}\n    process_id: ${process_id}\n    ckpt_dir: null\n    disable_ref_viz: false\n    eval_log_dir: null\n    save_rendering_dir: null\n\n    robot: ${robot}\n    domain_rand: ${domain_rand}\n    rewards: ${rewards}\n    terrain: ${terrain}\n    obs: ${obs}\n    terminations: ${terminations}\n\n    simulation:\n      episode_length_s: 20\n      sim_freq: 200\n      control_decimation: 4\n      physx:\n        bounce_threshold_velocity: 0.5\n        gpu_max_rigid_patch_count: 327680\n\n    scene:\n      terrain: ${terrain}\n      lighting:\n        distant_light_intensity: 3000.0\n        dome_light_intensity: 1000.0\n      contact_sensor:\n        history_length: 3\n        force_threshold: 10.0\n        track_air_time: true\n        debug_vis: false\n\n    actions:\n      dof_pos:\n        type: joint_position\n        params:\n          asset_name: robot\n          joint_names:\n            - \".*\"\n          use_default_offset: true\n          scale: ${robot.actuators.action_scale}\n\n    commands:\n      base_velocity:\n        type: HoloMotionUniformVelocityCommandCfg\n        params:\n          asset_name: robot\n          resampling_time_range: [3, 10.0]\n          rel_standing_envs: 0.20\n          rel_yaw_envs: 0.30  # actual prob for sampled yaw-only is 0.3 * (1-0.2) = 0.24\n          rel_heading_envs: 1.0\n          heading_command: false\n          heading_control_stiffness: 0.5\n          debug_vis: true\n          ranges:\n            lin_vel_x: [-0.6, 1.0]\n            lin_vel_y: [-0.5, 0.5]\n            ang_vel_z: [-1.0, 1.0]\n            heading: [-3.14, 3.14]\n          # limit_ranges:\n          #   lin_vel_x: [-0.5, 1.0]\n          #   lin_vel_y: [-0.3, 0.3]\n          #   ang_vel_z: [-0.2, 0.2]\n          #   heading: [-3.14, 3.14]\n"
  },
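  {
    "path": "docs/examples/velocity_command_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the base_velocity command mixing above:\n# rel_standing_envs=0.20 of resamples command a full stop, and of the remaining\n# envs a fraction rel_yaw_envs=0.30 get a yaw-only command, matching the config\n# comment: P(yaw-only) = 0.30 * (1 - 0.20) = 0.24. This is not\n# HoloMotionUniformVelocityCommandCfg, just the sampling logic it describes.\nimport random\n\n\ndef sample_command(ranges, rel_standing=0.20, rel_yaw=0.30):\n    if random.random() < rel_standing:\n        # Standing env: all command components are zero.\n        return {'lin_vel_x': 0.0, 'lin_vel_y': 0.0, 'ang_vel_z': 0.0}\n    cmd = {k: random.uniform(*ranges[k]) for k in ('lin_vel_x', 'lin_vel_y', 'ang_vel_z')}\n    if random.random() < rel_yaw:\n        # Yaw-only env: keep ang_vel_z, zero the linear velocity command.\n        cmd['lin_vel_x'] = cmd['lin_vel_y'] = 0.0\n    return cmd\n\n\nif __name__ == '__main__':\n    ranges = {'lin_vel_x': (-0.6, 1.0), 'lin_vel_y': (-0.5, 0.5), 'ang_vel_z': (-1.0, 1.0)}\n    print(sample_command(ranges))\n"
  },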
  {
    "path": "holomotion/config/evaluation/eval_isaaclab.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: motion_tracking\n  - /env/terrain: isaaclab_plane\n  - /env/terminations: NO_termination\n  - /env/domain_randomization: NO_domain_rand\n\nproject_name: ???\nexperiment_name: ???\n\nnum_envs: ???\nheadless: ???\n\nmotion_h5_path: null\ncheckpoint: null\nlog_dir: null\nckpt_pt_names: null\n\nnum_processes: ???\nmain_process: ???\nprocess_id: ???\n\ntimestamp: ${now:%Y%m%d_%H%M%S}\nbase_dir: logs\nexperiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}\nsave_dir: ${experiment_dir}/.hydra\noutput_dir: ${experiment_dir}/output\nexperiment_save_dir: ???\n\nexport_policy: false\nexport_only: false\n\ndump_npzs: false\ncalc_per_clip_metrics: false\ngenerate_report: false\ndof_mode: \"23\"\n\nobs:\n  critic_obs_prefix: \"ref_\"\n\nrewards:\n  _config:\n    reward_prefix: \"ref_\"\n\nalgo:\n  config:\n    dynamo_backend: null\n    sampling_strategy: uniform\n    seed: 114514\n\nenv:\n  config:\n    seed: 42\n    simulation:\n      episode_length_s: 36000\n\nrobot:\n  motion:\n    backend: \"hdf5_v2\"\n    cache_max_num_clips: ${num_envs}\n    train_hdf5_roots: ${robot.motion.val_hdf5_roots}\n    val_hdf5_roots: ${motion_h5_path}\n    max_frame_length: 10000  # 20s\n    min_frame_length: 1\n    # handpicked_motion_names: ${handpicked_motion_names}\n    world_frame_normalization: false\n    dataloader:\n      num_workers: 2\n      pin_memory: true\n      persistent_workers: false\n      prefetch_factor: 1\n      timeout: 600\n      batch_progress_bar: true\n\nterrain:\n  terrain_type: generator\n  prim_path: /World/ground\n  static_friction: 1.0\n  dynamic_friction: 1.0\n  restitution: 0.0\n  friction_combine_mode: multiply\n  restitution_combine_mode: multiply\n  debug_vis: false\n  max_init_terrain_level: 0\n\n  # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.\n  # When false, env origins are placed on a regular grid as in the default importer.\n  random_spawn: true\n  # Keep random spawn points away from terrain edges to avoid spawning onto the outer border.\n  random_spawn_margin: 2.0\n\n  # TerrainGeneratorCfg parameters.\n  generator:\n    num_rows: 1\n    num_cols: 1\n    size: [10.0, 10.0]\n    border_width: 1000.0\n    horizontal_scale: 0.1\n    vertical_scale: 0.005\n    slope_threshold: null\n    difficulty_range: [0.0, 0.0]\n    color_scheme: height\n    sub_terrains:\n      plane:\n        type: plane\n        proportion: 1.0\n\n  # Offline visual material configuration (PreviewSurface, no MDL/Nucleus).\n  visual_material:\n    type: color\n    diffuse_color: [0.25, 0.25, 0.25]\n    metallic: 0.0\n    roughness: 0.5\n"
  },
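  {
    "path": "docs/examples/mandatory_values_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the ??? fields in these configs\n# (num_envs, headless, project_name, ...): they are OmegaConf mandatory values,\n# and reading one before it is provided on the command line raises an error.\nfrom omegaconf import OmegaConf\nfrom omegaconf.errors import MissingMandatoryValue\n\ncfg = OmegaConf.create({'num_envs': '???', 'headless': '???'})\ntry:\n    _ = cfg.num_envs\nexcept MissingMandatoryValue:\n    print('num_envs is mandatory; pass e.g. num_envs=4 on the CLI')\n\ncfg.num_envs = 4  # a CLI override like `num_envs=4` fills the slot\nassert cfg.num_envs == 4\n"
  },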
  {
    "path": "holomotion/config/evaluation/eval_mujoco_sim2sim.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n\n# Evaluation toggles\nheadless: false # true to run without GUI\nrecord_video: false # true to save MP4 recordings\nvideo_width: 1280\nvideo_height: 720\nvideo_fps: 30\ncamera_tracking: true # true to make camera follow robot root body\ncamera_height_offset:\n  0.3 # small offset (meters) above robot root for camera lookat point\n  # NOTE: This offsets where camera LOOKS AT, not camera position\n  # Use small values (0.2-0.5m) for proper framing, not large values\ncamera_distance:\n  4.0 # camera distance from lookat point (meters)\n  # Larger values = camera further away, smaller values = closer\ncamera_azimuth: 150.0 # default viewer/offscreen azimuth (deg), side-ish view\ncamera_elevation: -20.0 # default viewer/offscreen elevation (deg), slight downward angle\n\n# Offline evaluation pipeline (dataset mode)\nmotion_npz_dir: null\ndump_npzs: true\ndump_onnx_io_npy: false\ncalc_per_clip_metrics: false\ngenerate_report: false\nmetric_calculation: \"per_clip\" # \"per_clip\" (Macro) or \"per_frame\" (Micro)\ndof_mode: \"23\" # \"29\" for full DoF, \"23\" for reduced DoF\nfailure_pos_err_thresh_m: 0.25\nray_actors_per_gpu: 16 # persistent Ray actors per GPU for batch eval\nray_multi_ckpt_mode: \"split\" # \"split\" or \"per_checkpoint\" for multi-ONNX eval\nckpt_onnx_root_dir: null # optional ONNX root directory for multi-checkpoint eval\nckpt_onnx_names: null # optional list of ONNX file names to evaluate\nray_parallel_metrics_postprocess: true # parallelize per-checkpoint metrics/report/export with Ray when evaluating multiple ckpts\nray_metrics_postprocess_num_cpus: 24 # Ray resource accounting per checkpoint-postprocess task (0 = don't reserve CPUs)\nmetrics_threadpool_max_workers: 24 # per-checkpoint ThreadPoolExecutor workers inside metrics.py (null => auto=min(num_files, 24))\n\n# Termination / scheduling\nmax_policy_steps: 0 # 0 = unlimited; used in headless if no motion\npolicy_action_delay_step: 0 # max random action delay in 50 Hz policy steps; 0 disables delay\naction_delay_type: \"episode\" # \"episode\" samples once per reset, \"step\" re-samples every policy step\nunitree_viewer_dt: 0.0167 # ~60 Hz viewer sync\nunitree_domain_id: 0\nunitree_interface: \"lo\"\nunitree_use_joystick: false\nunitree_joystick_type: \"xbox\"\nunitree_print_scene_information: false\n\n# Debug options\ndebug_anchor_obs: false # when true, dump anchor pose/obs debug CSV for sim2sim\ndebug_anchor_obs_interval: 50 # log every N policy steps (>=1)\nuse_isaac_root_alignment: true # align free root state to IsaacLab reference root at frame 0\nisaac_action_playback: false # when true, use per-frame actions recorded from IsaacLab instead of ONNX policy\n\nrobot_xml_path: ???\n\nuse_gpu: true\n"
  },
  {
    "path": "holomotion/config/evaluation/eval_velocity_tracking.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: velocity_tracking\n  - /env/terrain: isaaclab_plane\n  - /env/terminations: NO_termination\n  - /env/domain_randomization: NO_domain_rand\n\nproject_name: ???\nexperiment_name: ???\n\nnum_envs: ???\nheadless: ???\n\nmotion_h5_path: null\ncheckpoint: null\n\nnum_processes: ???\nmain_process: ???\nprocess_id: ???\n\ntimestamp: ${now:%Y%m%d_%H%M%S}\nbase_dir: logs\nexperiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}\nsave_dir: ${experiment_dir}/.hydra\noutput_dir: ${experiment_dir}/output\nexperiment_save_dir: ???\n\ndump_npzs: false\nexport_policy: true\n\nalgo:\n  config:\n    dynamo_backend: null\n    # Video recording for offline evaluation (env.render() -> MP4 at target_fps)\n    record_video: false # enable MP4 recording during offline evaluation\n\nenv:\n  config:\n    simulation:\n      episode_length_s: 3600\n\nrobot:\n  motion:\n    cache_max_num_clips: ${num_envs}\n    max_frame_length: 10000\n    min_frame_length: 0\n    hdf5_root: ${motion_h5_path}\n    val_hdf5_root: ${motion_h5_path}\n    dataloader:\n      num_workers: 0\n\n# terrain:\n#   terrain_type: usd\n#   usd_path: assets/isaac/4.1/Isaac/Environments/Grid/gridroom_black.usd\n# usd_path: assets/isaac/4.1/Isaac/Environments/Terrains/rough_plane.usd\n\nterrain:\n  terrain_type: generator\n  prim_path: /World/ground\n  static_friction: 1.0\n  dynamic_friction: 1.0\n  restitution: 0.0\n  friction_combine_mode: multiply\n  restitution_combine_mode: multiply\n  debug_vis: false\n  max_init_terrain_level: 0\n\n  # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.\n  # When false, env origins are placed on a regular grid as in the default importer.\n  random_spawn: true\n\n  # TerrainGeneratorCfg parameters.\n  generator:\n    num_rows: 1\n    num_cols: 1\n    size: [8.0, 8.0]\n    border_width: 10.0\n    horizontal_scale: 0.1\n    vertical_scale: 0.005\n    slope_threshold: null\n    difficulty_range: [0.0, 0.0]\n    color_scheme: height\n    sub_terrains:\n      plane:\n        type: random_uniform\n        proportion: 1.0\n        noise_range: [0.0, 0.0]\n        noise_step: 0.25\n        downsampled_scale: 0.5\n\n  # Offline visual material configuration (PreviewSurface, no MDL/Nucleus).\n  visual_material:\n    type: color\n    diffuse_color: [0.25, 0.25, 0.25]\n    metallic: 0.0\n    roughness: 0.5\n"
  },
  {
    "path": "holomotion/config/modules/motion_tracking/motion_tracking_mlp.yaml",
    "content": "# @package _global_\n\nmodules:\n  actor:\n    type: MLP\n\n    hidden_norm: none\n    layer_config:\n      hidden_dims:\n        - 2048\n        - 1024\n        - 512\n        - 256\n      activation: SiLU\n\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    # Observation schema for motion tracking, from the actor's perspective.\n    obs_schema:\n      flattened_obs:\n        seq_len: ${obs.context_length}\n        terms:\n          - unified/actor_ref_gravity_projection_cur\n          - unified/actor_ref_base_linvel_cur\n          - unified/actor_ref_base_angvel_cur\n          - unified/actor_ref_dof_pos_cur\n          - unified/actor_ref_root_height_cur\n          - unified/actor_projected_gravity\n          - unified/actor_rel_robot_root_ang_vel\n          - unified/actor_dof_pos\n          - unified/actor_dof_vel\n          - unified/actor_last_action\n      flattened_obs_fut:\n        seq_len: ${obs.n_fut_frames}\n        terms:\n          - unified/actor_ref_dof_pos_fut\n          - unified/actor_ref_root_height_fut\n          - unified/actor_ref_gravity_projection_fut\n          - unified/actor_ref_base_linvel_fut\n          - unified/actor_ref_base_angvel_fut\n\n    output_dim: robot_action_dim\n\n  critic:\n    type: MLP\n\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    hidden_norm: rmsnorm\n\n    layer_config:\n      hidden_dims:\n        - 2048\n        - 2048\n        - 2048\n        - 2048\n      activation: SiLU\n\n    obs_schema:\n      flattened_obs:\n        seq_len: 1\n        terms:\n          - unified/critic_ref_dof_pos_cur\n          - unified/critic_global_anchor_diff\n          - unified/critic_ref_motion_cur_heading_aligned_root_pos\n          - unified/critic_ref_motion_cur_heading_aligned_root_rot6d\n          - unified/critic_ref_motion_cur_heading_aligned_root_lin_vel\n          - unified/critic_ref_motion_cur_heading_aligned_root_ang_vel\n          - unified/critic_rel_robot_root_lin_vel\n          - unified/critic_rel_robot_root_ang_vel\n          - unified/critic_global_robot_bodylink_lin_vel_flat\n          - unified/critic_global_robot_bodylink_ang_vel_flat\n          - unified/critic_root_rel_robot_bodylink_pos_flat\n          - unified/critic_root_rel_robot_bodylink_rot_mat_flat\n          - unified/critic_dof_pos\n          - unified/critic_dof_vel\n          - unified/critic_last_action\n      flattened_obs_fut:\n        seq_len: ${obs.n_fut_frames}\n        terms:\n          - unified/critic_ref_dof_pos_fut\n          - unified/critic_ref_root_height_fut\n          - unified/critic_ref_motion_fut_heading_aligned_root_pos\n          - unified/critic_ref_motion_fut_heading_aligned_root_rot6d\n          - 
unified/critic_ref_motion_fut_heading_aligned_root_lin_vel\n          - unified/critic_ref_motion_fut_heading_aligned_root_ang_vel\n\n    output_dim: 1\n"
  },
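  {
    "path": "docs/examples/ema_obs_norm_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the obs_norm block shared by these\n# module configs: an EMA running mean/variance normalizer with clipping. This\n# is minimal standalone code, not the repo's normalizer; DDP syncing\n# (sync_interval_steps) and the cumulative mode are omitted.\nimport torch\n\n\nclass EmaObsNorm:\n    def __init__(self, dim: int, momentum: float = 1e-4, epsilon: float = 1e-8, clip_range: float = 10.0):\n        self.mean = torch.zeros(dim)\n        self.var = torch.ones(dim)\n        self.momentum, self.epsilon, self.clip = momentum, epsilon, clip_range\n\n    def update(self, obs: torch.Tensor) -> None:\n        # update_method: ema -- blend batch statistics in with a small momentum.\n        self.mean = (1 - self.momentum) * self.mean + self.momentum * obs.mean(0)\n        self.var = (1 - self.momentum) * self.var + self.momentum * obs.var(0, unbiased=False)\n\n    def normalize(self, obs: torch.Tensor) -> torch.Tensor:\n        out = (obs - self.mean) / torch.sqrt(self.var + self.epsilon)\n        return out.clamp(-self.clip, self.clip)  # enable_clipping / clip_range\n\n\nif __name__ == '__main__':\n    norm = EmaObsNorm(dim=8)\n    batch = torch.randn(64, 8) * 5.0\n    norm.update(batch)\n    print(norm.normalize(batch).abs().max())  # never exceeds clip_range\n"
  },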
  {
    "path": "holomotion/config/modules/motion_tracking/motion_tracking_tf-moe.yaml",
    "content": "# @package _global_\n\nmodules:\n  actor:\n    type: ReferenceRoutedGroupedMoETransformerPolicy\n\n    use_checkpointing: false # use gradient checkpointing to save GRAM significantly\n\n    # MoE-specific hyperparameters\n    num_fine_experts: 16\n    num_shared_experts: 1\n    top_k: 2\n    moe_loss_coef: 0.0\n    routing_score_fn: ${algo.config.moe_router.routing_score_fn}\n    routing_scale: ${algo.config.moe_router.routing_scale}\n    use_dynamic_bias: ${algo.config.moe_router.use_dynamic_bias}\n    bias_update_rate: ${algo.config.moe_router.bias_update_rate}\n    expert_bias_clip: ${algo.config.moe_router.expert_bias_clip}\n\n    # Transformer hyperparameters - smaller model for stability\n    obs_embed_mlp_hidden: 2048\n    router_embed_mlp_hidden: 2048\n    d_model: 512\n    n_heads: 8\n    n_kv_heads: 4\n    use_gated_attn: true\n\n    n_layers: 3\n    ff_mult: 2.0\n    ff_mult_dense: 4\n    attn_dropout: 0.0\n    mlp_dropout: 0.0\n    max_ctx_len: 32\n\n    # Auxiliary dynamics prediction weights (0.0 = disabled)\n    aux_sys_id_weight: 0.0\n    aux_dynamics_weight: 0.0\n\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    # Observation schema for motion tracking, from the actor's perspective.\n    obs_schema:\n      flattened_obs:\n        seq_len: ${obs.context_length}\n        terms:\n          - unified/actor_ref_gravity_projection_cur\n          - unified/actor_ref_base_linvel_cur\n          - unified/actor_ref_base_angvel_cur\n          - unified/actor_ref_dof_pos_cur\n          - unified/actor_ref_root_height_cur\n          - unified/actor_projected_gravity\n          - unified/actor_rel_robot_root_ang_vel\n          - unified/actor_dof_pos\n          - unified/actor_dof_vel\n          - unified/actor_last_action\n      flattened_obs_fut:\n        seq_len: ${obs.n_fut_frames}\n        terms:\n          - unified/actor_ref_dof_pos_fut\n          - unified/actor_ref_root_height_fut\n          - unified/actor_ref_gravity_projection_fut\n          - unified/actor_ref_base_linvel_fut\n          - unified/actor_ref_base_angvel_fut\n\n    output_dim: robot_action_dim\n\n  critic:\n    type: MLP\n\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    hidden_norm: rmsnorm\n\n    layer_config:\n      hidden_dims:\n        - 2048\n        - 2048\n        - 2048\n        - 2048\n      activation: SiLU\n\n    obs_schema:\n      flattened_obs:\n        seq_len: 1\n        terms:\n          - unified/critic_ref_dof_pos_cur\n          - unified/critic_global_anchor_diff\n          - unified/critic_ref_motion_cur_heading_aligned_root_pos\n          - unified/critic_ref_motion_cur_heading_aligned_root_rot6d\n          - 
unified/critic_ref_motion_cur_heading_aligned_root_lin_vel\n          - unified/critic_ref_motion_cur_heading_aligned_root_ang_vel\n          - unified/critic_rel_robot_root_lin_vel\n          - unified/critic_rel_robot_root_ang_vel\n          - unified/critic_global_robot_bodylink_lin_vel_flat\n          - unified/critic_global_robot_bodylink_ang_vel_flat\n          - unified/critic_root_rel_robot_bodylink_pos_flat\n          - unified/critic_root_rel_robot_bodylink_rot_mat_flat\n          - unified/critic_dof_pos\n          - unified/critic_dof_vel\n          - unified/critic_last_action\n      flattened_obs_fut:\n        seq_len: ${obs.n_fut_frames}\n        terms:\n          - unified/critic_ref_dof_pos_fut\n          - unified/critic_ref_root_height_fut\n          - unified/critic_ref_motion_fut_heading_aligned_root_pos\n          - unified/critic_ref_motion_fut_heading_aligned_root_rot6d\n          - unified/critic_ref_motion_fut_heading_aligned_root_lin_vel\n          - unified/critic_ref_motion_fut_heading_aligned_root_ang_vel\n\n    output_dim: 1\n"
  },
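  {
    "path": "docs/examples/topk_moe_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the MoE hyperparameters above\n# (num_fine_experts=16, num_shared_experts=1, top_k=2): shared experts always\n# run, while a router picks the top-k fine experts per sample. The real\n# ReferenceRoutedGroupedMoETransformerPolicy routes on reference features and\n# supports dynamic bias; none of that is reproduced here.\nimport torch\nimport torch.nn as nn\n\n\nclass TopKMoE(nn.Module):\n    def __init__(self, d_model: int = 512, num_fine: int = 16, num_shared: int = 1, top_k: int = 2):\n        super().__init__()\n        self.fine = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_fine))\n        self.shared = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_shared))\n        self.router = nn.Linear(d_model, num_fine)\n        self.top_k = top_k\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, d_model)\n        shared_out = sum(expert(x) for expert in self.shared)  # always active\n        scores = self.router(x).softmax(-1)  # softmax routing scores (assumed)\n        weights, idx = scores.topk(self.top_k, dim=-1)\n        weights = weights / weights.sum(-1, keepdim=True)  # renormalize top-k\n        fine_out = []\n        for b in range(x.shape[0]):  # naive per-sample dispatch, for clarity only\n            fine_out.append(sum(weights[b, k] * self.fine[idx[b, k]](x[b]) for k in range(self.top_k)))\n        return shared_out + torch.stack(fine_out)\n\n\nif __name__ == '__main__':\n    print(TopKMoE()(torch.randn(4, 512)).shape)  # torch.Size([4, 512])\n"
  },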
  {
    "path": "holomotion/config/modules/velocity_tracking/velocity_tracking_mlp.yaml",
    "content": "# @package _global_\n\nmodules:\n  actor:\n    type: MLP\n    fix_sigma: false\n    noise_std_type: scalar\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    hidden_norm: rmsnorm\n    layer_config:\n      hidden_dims:\n        - 512\n        - 512\n        - 512\n      activation: SiLU\n\n    obs_schema:\n      flattened_obs:\n        seq_len: ${obs.context_length}\n        terms:\n          - unified/actor_velocity_command\n          - unified/actor_projected_gravity\n          - unified/actor_rel_robot_root_ang_vel\n          - unified/actor_dof_pos\n          - unified/actor_dof_vel\n          - unified/actor_last_action\n\n    output_dim: robot_action_dim\n\n  critic:\n    type: MLP\n\n    obs_norm:\n      enabled: true\n      epsilon: 1.0e-8 # Reduced for better stability in DDP\n      update_method: ema # ema or cumulative\n      ema_momentum: 1.0e-4\n      update_at_train: true\n      update_at_eval: false\n      enable_clipping: true # Enable clipping for DDP stability\n      clip_range: 10.0 # Reduced clip range for better stability\n      sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout\n\n    hidden_norm: rmsnorm\n    layer_config:\n      hidden_dims:\n        - 512\n        - 512\n        - 512\n      activation: SiLU\n\n    obs_schema:\n      flattened_obs:\n        seq_len: 1\n        terms:\n          - unified/critic_velocity_command\n          - unified/critic_rel_robot_root_lin_vel\n          - unified/critic_rel_robot_root_ang_vel\n          - unified/critic_root_rel_robot_bodylink_pos_flat\n          - unified/critic_root_rel_robot_bodylink_rot_mat_flat\n          - unified/critic_dof_pos\n          - unified/critic_dof_vel\n          - unified/critic_last_action\n\n    output_dim: 1\n"
  },
  {
    "path": "holomotion/config/motion_retargeting/gmr_to_holomotion.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n\nhydra:\n  job:\n    chdir: false\n\nio:\n  src_dir: ???\n  robot_config: ???\n  out_root: ???\n  ref_dir: holomotion/src/motion_retargeting/utils\n\nprocessing:\n  target_fps: 50\n  fast_interpolate: true\n  skip_existing: true\n  debug_mode: false\n\nray:\n  num_workers: 0\n  ray_address: \"\"\n\nnaming:\n  emit_prefixed: true\n  emit_legacy: false\n\npreprocess:\n  # Available stages:\n  # ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys',\n  #  'slicing','apply_butterworth_filter','add_padding','tagging']\n  # Empty list [] means no preprocessing stages applied\n  pipeline: []\n\nslicing:\n  window_size: 500\n  overlap: 50\n\nfiltering:\n  type: butterworth\n  butter_cutoff_hz: 3.0\n  butter_order: 4\n\npadding:\n  # Robot config path for FK and default joint angles\n  # If empty, uses io.robot_config\n  robot_config_path: ${io.robot_config}\n  # Duration of stand-still padding before/after motion (seconds)\n  stand_still_time: 1.0\n  # Duration of transition between default pose and motion (seconds)\n  transition_time: 1.5\n\ntagging:\n  # when empty, write tags to: <out_root>/kinematic_tags.json\n  output_json_path: \"\"\n"
  },
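  {
    "path": "docs/examples/butterworth_filter_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the `filtering` block above: a\n# zero-phase Butterworth low-pass (order 4, 3 Hz cutoff) applied to 50 Hz\n# retargeted joint trajectories with scipy. Standalone code, not the repo's\n# apply_butterworth_filter stage.\nimport numpy as np\nfrom scipy.signal import butter, filtfilt\n\n\ndef butterworth_smooth(dof_pos: np.ndarray, fs: float = 50.0, cutoff_hz: float = 3.0, order: int = 4) -> np.ndarray:\n    # dof_pos: (T, num_dofs). filtfilt runs the filter forward then backward,\n    # so the smoothed trajectory has no phase lag relative to the original.\n    b, a = butter(order, cutoff_hz, btype='low', fs=fs)\n    return filtfilt(b, a, dof_pos, axis=0)\n\n\nif __name__ == '__main__':\n    t = np.linspace(0.0, 2.0, 100)[:, None]  # 2 s at 50 Hz\n    noisy = np.sin(2 * np.pi * 1.0 * t) + 0.1 * np.random.randn(100, 1)\n    print(butterworth_smooth(noisy).shape)  # (100, 1)\n"
  },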
  {
    "path": "holomotion/config/motion_retargeting/holomotion_preprocess.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n\nhydra:\n  job:\n    chdir: false\n\nio:\n  src_root: ???\n  out_root: ???\n\npreprocess:\n  # Available stages:\n  # ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys',\n  #  'slicing','apply_butterworth_filter','add_padding','tagging']\n  pipeline: []  \n\nslicing:\n  window_size: 500\n  overlap: 50\n\nfiltering:\n  type: butterworth\n  butter_cutoff_hz: 3.0\n  butter_order: 4\n\npadding:\n  # Robot config path for FK and default joint angles\n  robot_config_path: \"\"\n  # Duration of stand-still padding before/after motion (seconds)\n  stand_still_time: 1.0\n  # Duration of transition between default pose and motion (seconds)\n  transition_time: 1.0\n\ntagging:\n  # when empty, write tags to: <out_root>/kinematic_tags.json\n  output_json_path: \"\"\n\nray:\n  enabled: true\n  num_workers: 2  # 0 = use all available CPUs\n  ray_address: \"\"  # empty = local; otherwise connect to existing Ray cluster\n\n"
  },
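  {
    "path": "docs/examples/padding_stage_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the `padding` stage configured above:\n# hold the default pose for stand_still_time, blend linearly into the first\n# motion frame over transition_time, and mirror the same scheme at the end.\n# Standalone code, not the repo's add_padding implementation (which also\n# handles the root pose via FK).\nimport numpy as np\n\n\ndef pad_motion(dof_pos, default_pose, fps=50, stand_still_time=1.0, transition_time=1.0):\n    # dof_pos: (T, D) joint trajectory; default_pose: (D,) default joint angles.\n    hold = np.repeat(default_pose[None], int(stand_still_time * fps), axis=0)\n    alpha = np.linspace(0.0, 1.0, int(transition_time * fps))[:, None]\n    blend_in = (1 - alpha) * default_pose[None] + alpha * dof_pos[0][None]\n    blend_out = (1 - alpha) * dof_pos[-1][None] + alpha * default_pose[None]\n    return np.concatenate([hold, blend_in, dof_pos, blend_out, hold], axis=0)\n\n\nif __name__ == '__main__':\n    motion = np.random.randn(300, 29)  # 6 s of 29-DoF motion at 50 Hz\n    padded = pad_motion(motion, np.zeros(29))\n    print(padded.shape)  # (300 + 2*50 + 2*50, 29) = (500, 29)\n"
  },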
  {
    "path": "holomotion/config/motion_retargeting/kinematic_filter.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n\nhydra:\n  job:\n    chdir: false\n\nio:\n  dataset_root: \"\" # absolute path to dataset root containing kinematic_tags.json\n\nfiltering:\n  output_yaml: \"\" # optional; defaults to <dataset_root>/excluded_kinematic_motion_names.yaml\n\nschema:\n  across: union\n  thresholds:\n    kinematic_features.root_linear_speed.max: { op: \">\", value: 10.0 }\n    kinematic_features.root_angular_speed.max: { op: \">\", value: 20.0 }\n    kinematic_features.root_delta_z.max: { op: \">\", value: 2.0 }\n    kinematic_features.jerk.max: { op: \">\", value: 2000.0 }\n"
  },
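  {
    "path": "docs/examples/kinematic_filter_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the threshold schema above: a clip is\n# excluded when any (`across: union`) configured feature violates its bound,\n# e.g. kinematic_features.root_linear_speed.max > 10.0. The tag-file layout\n# assumed here ({motion_name: {kinematic_features: ...}}) may differ from the\n# real kinematic_tags.json.\nimport json\nimport operator\n\nOPS = {'>': operator.gt, '>=': operator.ge, '<': operator.lt, '<=': operator.le}\n\n\ndef dotted_get(record, dotted_key):\n    # Resolve 'kinematic_features.jerk.max' against nested dicts.\n    for part in dotted_key.split('.'):\n        record = record[part]\n    return record\n\n\ndef excluded_motions(tags_json_path, thresholds):\n    with open(tags_json_path) as f:\n        tags = json.load(f)\n    return [\n        name\n        for name, record in tags.items()\n        if any(OPS[rule['op']](dotted_get(record, key), rule['value'])\n               for key, rule in thresholds.items())  # union: one violation suffices\n    ]\n\n\nif __name__ == '__main__':\n    thresholds = {\n        'kinematic_features.root_linear_speed.max': {'op': '>', 'value': 10.0},\n        'kinematic_features.jerk.max': {'op': '>', 'value': 2000.0},\n    }\n    print(excluded_motions('kinematic_tags.json', thresholds))\n"
  },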
  {
    "path": "holomotion/config/motion_retargeting/pack_hdf5_database.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n\nhydra:\n  job:\n    chdir: false\n\n# IO\nprecomputed_npz_root: ???\nhdf5_root: ???\n\n# Runtime\n# Optimal parameters for distributed JuiceFS training with millions of clips:\n# - chunks_t: Larger chunks (2048-4096) reduce metadata overhead and improve sequential read performance on JuiceFS\n#   Balance: larger chunks = better sequential I/O, but too large wastes memory\nchunks_t: 1024\n# - compression: lzf provides fast decompression suitable for training workloads\n#   Alternatives: gzip (better compression, slower), none (fastest but largest files)\ncompression: lzf  # lzf|gzip|none\n# - shard_target_gb: Larger shards (5-10 GB) reduce shard count and metadata overhead for millions of clips\n#   Balance: fewer shards = less metadata overhead, better for distributed access; more shards = better parallelism\n#   For millions of clips, 5-10 GB reduces shard count while maintaining good parallelism\nshard_target_gb: 1.0\n# - num_jobs: Parallel workers for packing process (should match available CPU cores)\nnum_jobs: 16\ndebug_local_mode: false\n\n\n"
  },
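  {
    "path": "docs/examples/hdf5_packing_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of what chunks_t / compression mean for a\n# packed shard: per-clip datasets chunked along the time axis and compressed\n# with lzf. The real shard schema lives in the packing script; the layout below\n# is an assumption for illustration.\nimport h5py\nimport numpy as np\n\n\ndef write_clip(shard: h5py.File, key: str, dof_pos: np.ndarray, chunks_t: int = 1024) -> None:\n    t, d = dof_pos.shape\n    shard.create_dataset(\n        f'{key}/dof_pos',\n        data=dof_pos,\n        chunks=(min(chunks_t, t), d),  # chunk along time for sequential reads\n        compression='lzf',  # fast decompression, suits training workloads\n    )\n\n\nif __name__ == '__main__':\n    with h5py.File('shard_000.h5', 'w') as f:\n        write_clip(f, 'clip_0001', np.zeros((300, 29), dtype=np.float32))\n        print(f['clip_0001/dof_pos'].chunks)  # (300, 29): capped at clip length\n"
  },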
  {
    "path": "holomotion/config/motion_retargeting/pack_hdf5_v2.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n\nhydra:\n  job:\n    chdir: false\n\n# IO\nholomotion_npz_root: ???\nhdf5_root: ???\n\n# Runtime\nchunks_t: 1024\ncompression: lzf  # lzf|gzip|none\nshard_target_gb: 1.0\nshard_target_mode: h5_filesize  # h5_filesize|npz_filesize|uncompressed_nbytes\nnum_jobs: 16\ndebug_local_mode: false\n"
  },
  {
    "path": "holomotion/config/motion_retargeting/unitree_G1_29dof_retargeting.yaml",
    "content": "robot:\n  humanoid_type: unitree/G1/29dof\n\n  asset:\n    smpl_dir: \"assets/smpl\"\n    assetRoot: \"./\"\n    assetFileName: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml\"\n    training_mjcfName: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml\"\n\nvideo_dir: ${motion_npz_root}/video_rendering\nskip_frames: 3 # when skip frames=1, it means 30hz, when skip frames=2, it means 15hz, etc.\nshow_markers: False\nmax_workers: 12\n"
  },
  {
    "path": "holomotion/config/mujoco_eval/sim2sim.yaml",
    "content": "# @package _group_\n\ndefaults:\n  - _self_\n\nenabled: false\nmodel_type: \"holomotion\"\n\n# Evaluation toggles\nheadless: false\nrecord_video: false\nvideo_width: 1280\nvideo_height: 720\nvideo_fps: 30\ncamera_tracking: true\ncamera_height_offset: 0.3\ncamera_distance: 4.0\ncamera_azimuth: 150.0\ncamera_elevation: -20.0\n\n# Input/output\nrobot_xml_path: null\nmotion_npz_dir: null\nmotion_npz_path: null\nckpt_onnx_path: null\nckpt_onnx_root_dir: null\nckpt_onnx_names: null\n\n# Offline evaluation pipeline\ndump_npzs: true\ncalc_per_clip_metrics: false\ngenerate_report: false\nmetric_calculation: \"per_clip\"\ndof_mode: \"23\"\nfailure_pos_err_thresh_m: 0.25\nray_actors_per_gpu: 16\nray_multi_ckpt_mode: \"split\"\n\n# Runtime\nuse_gpu: true\n"
  },
  {
    "path": "holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml",
    "content": "# @package _global_\n\nrobot:\n  humanoid_type: unitree/G1/29dof\n\n  dof_obs_size: 29\n  actions_dim: 29\n  num_bodies: 30\n  num_extend_bodies: 0\n  undesired_contacts_regrex: \"^(?!left_ankle_roll_link$)(?!right_ankle_roll_link$)(?!left_wrist_yaw_link$)(?!right_wrist_yaw_link$).+$\"\n  torso_name: \"torso_link\"\n  anchor_body: \"torso_link\"\n\n  key_bodies:\n    - \"pelvis\"\n    - \"left_hip_roll_link\"\n    - \"left_knee_link\"\n    - \"left_ankle_pitch_link\"\n    - \"right_hip_roll_link\"\n    - \"right_knee_link\"\n    - \"right_ankle_pitch_link\"\n    - \"torso_link\"\n    - \"left_shoulder_roll_link\"\n    - \"left_elbow_link\"\n    - \"left_wrist_yaw_link\"\n    - \"right_shoulder_roll_link\"\n    - \"right_elbow_link\"\n    - \"right_wrist_yaw_link\"\n\n  key_dofs:\n    - \"left_knee_joint\"\n    - \"right_knee_joint\"\n    - \"left_elbow_joint\"\n    - \"right_elbow_joint\"\n\n  dof_names:\n    - \"left_hip_pitch_joint\"\n    - \"left_hip_roll_joint\"\n    - \"left_hip_yaw_joint\"\n    - \"left_knee_joint\"\n    - \"left_ankle_pitch_joint\"\n    - \"left_ankle_roll_joint\"\n    - \"right_hip_pitch_joint\"\n    - \"right_hip_roll_joint\"\n    - \"right_hip_yaw_joint\"\n    - \"right_knee_joint\"\n    - \"right_ankle_pitch_joint\"\n    - \"right_ankle_roll_joint\"\n    - \"waist_yaw_joint\"\n    - \"waist_roll_joint\"\n    - \"waist_pitch_joint\"\n    - \"left_shoulder_pitch_joint\"\n    - \"left_shoulder_roll_joint\"\n    - \"left_shoulder_yaw_joint\"\n    - \"left_elbow_joint\"\n    - \"left_wrist_roll_joint\"\n    - \"left_wrist_pitch_joint\"\n    - \"left_wrist_yaw_joint\"\n    - \"right_shoulder_pitch_joint\"\n    - \"right_shoulder_roll_joint\"\n    - \"right_shoulder_yaw_joint\"\n    - \"right_elbow_joint\"\n    - \"right_wrist_roll_joint\"\n    - \"right_wrist_pitch_joint\"\n    - \"right_wrist_yaw_joint\"\n\n  # ========== Unified DOF Groupings ==========\n  # Main anatomical groupings for DOF\n  arm_dof_names:\n    - \"left_shoulder_pitch_joint\"\n    - \"left_shoulder_roll_joint\"\n    - \"left_shoulder_yaw_joint\"\n    - \"left_elbow_joint\"\n    - \"left_wrist_roll_joint\"\n    - \"left_wrist_pitch_joint\"\n    - \"left_wrist_yaw_joint\"\n    - \"right_shoulder_pitch_joint\"\n    - \"right_shoulder_roll_joint\"\n    - \"right_shoulder_yaw_joint\"\n    - \"right_elbow_joint\"\n    - \"right_wrist_roll_joint\"\n    - \"right_wrist_pitch_joint\"\n    - \"right_wrist_yaw_joint\"\n\n  waist_dof_names:\n    - \"waist_yaw_joint\"\n    - \"waist_roll_joint\"\n    - \"waist_pitch_joint\"\n\n  leg_dof_names:\n    - \"left_hip_pitch_joint\"\n    - \"left_hip_roll_joint\"\n    - \"left_hip_yaw_joint\"\n    - \"left_knee_joint\"\n    - \"left_ankle_pitch_joint\"\n    - \"left_ankle_roll_joint\"\n    - \"right_hip_pitch_joint\"\n    - \"right_hip_roll_joint\"\n    - \"right_hip_yaw_joint\"\n    - \"right_knee_joint\"\n    - \"right_ankle_pitch_joint\"\n    - \"right_ankle_roll_joint\"\n\n  # Side-specific groupings for DOF\n  left_arm_dof_names:\n    - \"left_shoulder_pitch_joint\"\n    - \"left_shoulder_roll_joint\"\n    - \"left_shoulder_yaw_joint\"\n    - \"left_elbow_joint\"\n    - \"left_wrist_roll_joint\"\n    - \"left_wrist_pitch_joint\"\n    - \"left_wrist_yaw_joint\"\n\n  right_arm_dof_names:\n    - \"right_shoulder_pitch_joint\"\n    - \"right_shoulder_roll_joint\"\n    - \"right_shoulder_yaw_joint\"\n    - \"right_elbow_joint\"\n    - \"right_wrist_roll_joint\"\n    - \"right_wrist_pitch_joint\"\n    - \"right_wrist_yaw_joint\"\n\n  
left_leg_dof_names:\n    - \"left_hip_pitch_joint\"\n    - \"left_hip_roll_joint\"\n    - \"left_hip_yaw_joint\"\n    - \"left_knee_joint\"\n    - \"left_ankle_pitch_joint\"\n    - \"left_ankle_roll_joint\"\n\n  right_leg_dof_names:\n    - \"right_hip_pitch_joint\"\n    - \"right_hip_roll_joint\"\n    - \"right_hip_yaw_joint\"\n    - \"right_knee_joint\"\n    - \"right_ankle_pitch_joint\"\n    - \"right_ankle_roll_joint\"\n\n  # Combined groupings for DOF (for backward compatibility and common usage)\n  upper_body_dof_names: ${robot.arm_dof_names} # Alias for arm_dof_names\n\n  lower_body_dof_names:\n    - \"left_hip_pitch_joint\"\n    - \"left_hip_roll_joint\"\n    - \"left_hip_yaw_joint\"\n    - \"left_knee_joint\"\n    - \"left_ankle_pitch_joint\"\n    - \"left_ankle_roll_joint\"\n    - \"right_hip_pitch_joint\"\n    - \"right_hip_roll_joint\"\n    - \"right_hip_yaw_joint\"\n    - \"right_knee_joint\"\n    - \"right_ankle_pitch_joint\"\n    - \"right_ankle_roll_joint\"\n    - \"waist_yaw_joint\"\n    - \"waist_roll_joint\"\n    - \"waist_pitch_joint\"\n\n  # ========== Unified Body Groupings ==========\n  # Main anatomical groupings for bodies\n  arm_body_names:\n    - \"left_shoulder_pitch_link\"\n    - \"left_shoulder_roll_link\"\n    - \"left_shoulder_yaw_link\"\n    - \"left_elbow_link\"\n    - \"left_wrist_roll_link\"\n    - \"left_wrist_pitch_link\"\n    - \"left_wrist_yaw_link\"\n    - \"right_shoulder_pitch_link\"\n    - \"right_shoulder_roll_link\"\n    - \"right_shoulder_yaw_link\"\n    - \"right_elbow_link\"\n    - \"right_wrist_roll_link\"\n    - \"right_wrist_pitch_link\"\n    - \"right_wrist_yaw_link\"\n\n  head_hand_bodies:\n    - \"torso_link\"\n    - \"left_wrist_yaw_link\"\n    - \"right_wrist_yaw_link\"\n\n  torso_body_names:\n    - \"waist_yaw_link\"\n    - \"waist_roll_link\"\n    - \"torso_link\"\n\n  leg_body_names:\n    - \"left_hip_pitch_link\"\n    - \"left_hip_roll_link\"\n    - \"left_hip_yaw_link\"\n    - \"left_knee_link\"\n    - \"left_ankle_pitch_link\"\n    - \"left_ankle_roll_link\"\n    - \"right_hip_pitch_link\"\n    - \"right_hip_roll_link\"\n    - \"right_hip_yaw_link\"\n    - \"right_knee_link\"\n    - \"right_ankle_pitch_link\"\n    - \"right_ankle_roll_link\"\n\n  # Side-specific groupings for bodies\n  left_arm_body_names:\n    - \"left_shoulder_pitch_link\"\n    - \"left_shoulder_roll_link\"\n    - \"left_shoulder_yaw_link\"\n    - \"left_elbow_link\"\n    - \"left_wrist_roll_link\"\n    - \"left_wrist_pitch_link\"\n    - \"left_wrist_yaw_link\"\n\n  right_arm_body_names:\n    - \"right_shoulder_pitch_link\"\n    - \"right_shoulder_roll_link\"\n    - \"right_shoulder_yaw_link\"\n    - \"right_elbow_link\"\n    - \"right_wrist_roll_link\"\n    - \"right_wrist_pitch_link\"\n    - \"right_wrist_yaw_link\"\n\n  left_leg_body_names:\n    - \"left_hip_pitch_link\"\n    - \"left_hip_roll_link\"\n    - \"left_hip_yaw_link\"\n    - \"left_knee_link\"\n    - \"left_ankle_pitch_link\"\n    - \"left_ankle_roll_link\"\n\n  right_leg_body_names:\n    - \"right_hip_pitch_link\"\n    - \"right_hip_roll_link\"\n    - \"right_hip_yaw_link\"\n    - \"right_knee_link\"\n    - \"right_ankle_pitch_link\"\n    - \"right_ankle_roll_link\"\n\n  body_names:\n    - \"pelvis\"\n    - \"left_hip_pitch_link\"\n    - \"left_hip_roll_link\"\n    - \"left_hip_yaw_link\"\n    - \"left_knee_link\"\n    - \"left_ankle_pitch_link\"\n    - \"left_ankle_roll_link\"\n    - \"right_hip_pitch_link\"\n    - \"right_hip_roll_link\"\n    - \"right_hip_yaw_link\"\n    - 
\"right_knee_link\"\n    - \"right_ankle_pitch_link\"\n    - \"right_ankle_roll_link\"\n    - \"waist_yaw_link\"\n    - \"waist_roll_link\"\n    - \"torso_link\"\n    - \"left_shoulder_pitch_link\"\n    - \"left_shoulder_roll_link\"\n    - \"left_shoulder_yaw_link\"\n    - \"left_elbow_link\"\n    - \"left_wrist_roll_link\"\n    - \"left_wrist_pitch_link\"\n    - \"left_wrist_yaw_link\"\n    - \"right_shoulder_pitch_link\"\n    - \"right_shoulder_roll_link\"\n    - \"right_shoulder_yaw_link\"\n    - \"right_elbow_link\"\n    - \"right_wrist_roll_link\"\n    - \"right_wrist_pitch_link\"\n    - \"right_wrist_yaw_link\"\n\n  init_state:\n    pos: [0.0, 0.0, 0.8] # x,y,z [m]\n    rot: [0.0, 0.929, 0.341, 0.298] # x,y,z,w [quat]\n    lin_vel: [0.0, 0.0, 0.0] # x,y,z [m/s]\n    ang_vel: [0.0, 0.0, 0.0] # x,y,z [rad/s]\n    default_joint_angles: # = target angles [rad] when action = 0.0\n      left_hip_pitch_joint: -0.312\n      left_hip_roll_joint: 0.0\n      left_hip_yaw_joint: 0.0\n      left_knee_joint: 0.669\n      left_ankle_pitch_joint: -0.363\n      left_ankle_roll_joint: 0.0\n\n      right_hip_pitch_joint: -0.312\n      right_hip_roll_joint: 0.0\n      right_hip_yaw_joint: 0.0\n      right_knee_joint: 0.669\n      right_ankle_pitch_joint: -0.363\n      right_ankle_roll_joint: 0.0\n\n      waist_yaw_joint: 0.\n      waist_roll_joint: 0.\n      waist_pitch_joint: 0.1\n\n      left_shoulder_pitch_joint: 0.2\n      left_shoulder_roll_joint: 0.2\n      left_shoulder_yaw_joint: 0.0\n      left_elbow_joint: 0.6\n      left_wrist_roll_joint: 0.0\n      left_wrist_pitch_joint: 0.0\n      left_wrist_yaw_joint: 0.0\n\n      right_shoulder_pitch_joint: 0.2\n      right_shoulder_roll_joint: -0.2\n      right_shoulder_yaw_joint: 0.0\n      right_elbow_joint: 0.6\n      right_wrist_roll_joint: 0.0\n      right_wrist_pitch_joint: 0.0\n      right_wrist_yaw_joint: 0.0\n\n  actuators:\n    actuator_type: unitree_erfi # implicit, unitree, or unitree_erfi\n    ema_filter_enabled: false\n    ema_filter_alpha: 1.0\n    all_joints:\n      joint_names_expr:\n        - \".*_hip_yaw_joint\"\n        - \".*_hip_roll_joint\"\n        - \".*_hip_pitch_joint\"\n        - \".*_knee_joint\"\n        - \".*_ankle_pitch_joint\"\n        - \".*_ankle_roll_joint\"\n        - \"waist_yaw_joint\"\n        - \"waist_roll_joint\"\n        - \"waist_pitch_joint\"\n        - \".*_shoulder_pitch_joint\"\n        - \".*_shoulder_roll_joint\"\n        - \".*_shoulder_yaw_joint\"\n        - \".*_elbow_joint\"\n        - \".*_wrist_roll_joint\"\n        - \".*_wrist_pitch_joint\"\n        - \".*_wrist_yaw_joint\"\n\n      effort_limit_sim:\n        \".*_hip_yaw_joint\": 88.0\n        \".*_hip_roll_joint\": 139.0\n        \".*_hip_pitch_joint\": 88.0\n        \".*_knee_joint\": 139.0\n        \".*_ankle_pitch_joint\": 50.0\n        \".*_ankle_roll_joint\": 50.0\n        \"waist_yaw_joint\": 88.0\n        \"waist_roll_joint\": 50.0\n        \"waist_pitch_joint\": 50.0\n        \".*_shoulder_pitch_joint\": 25.0\n        \".*_shoulder_roll_joint\": 25.0\n        \".*_shoulder_yaw_joint\": 25.0\n        \".*_elbow_joint\": 25.0\n        \".*_wrist_roll_joint\": 25.0\n        \".*_wrist_pitch_joint\": 5.0\n        \".*_wrist_yaw_joint\": 5.0\n\n      velocity_limit_sim:\n        \".*_hip_yaw_joint\": 32.0\n        \".*_hip_roll_joint\": 20.0\n        \".*_hip_pitch_joint\": 32.0\n        \".*_knee_joint\": 20.0\n        \".*_ankle_pitch_joint\": 37.0\n        \".*_ankle_roll_joint\": 37.0\n        \"waist_yaw_joint\": 32.0\n        
\"waist_roll_joint\": 37.0\n        \"waist_pitch_joint\": 37.0\n        \".*_shoulder_pitch_joint\": 37.0\n        \".*_shoulder_roll_joint\": 37.0\n        \".*_shoulder_yaw_joint\": 37.0\n        \".*_elbow_joint\": 37.0\n        \".*_wrist_roll_joint\": 37.0\n        \".*_wrist_pitch_joint\": 22.0\n        \".*_wrist_yaw_joint\": 22.0\n\n      stiffness:\n        \".*_hip_pitch_joint\": 40.17923847\n        \".*_hip_roll_joint\": 99.09842778\n        \".*_hip_yaw_joint\": 40.17923847\n        \".*_knee_joint\": 99.09842778\n        \".*_ankle_pitch_joint\": 28.50124620\n        \".*_ankle_roll_joint\": 28.50124620\n        \"waist_yaw_joint\": 40.17923847\n        \"waist_roll_joint\": 28.50124620\n        \"waist_pitch_joint\": 28.50124620\n        \".*_shoulder_pitch_joint\": 14.25062310\n        \".*_shoulder_roll_joint\": 14.25062310\n        \".*_shoulder_yaw_joint\": 14.25062310\n        \".*_elbow_joint\": 14.25062310\n        \".*_wrist_roll_joint\": 14.25062309787429\n        \".*_wrist_pitch_joint\": 16.77832748089279\n        \".*_wrist_yaw_joint\": 16.77832748089279\n\n      damping:\n        \".*_hip_pitch_joint\": 2.55788977\n        \".*_hip_roll_joint\": 6.30880185\n        \".*_hip_yaw_joint\": 2.55788977\n        \".*_knee_joint\": 6.30880185\n        \".*_ankle_pitch_joint\": 1.81444569\n        \".*_ankle_roll_joint\": 1.81444569\n        \"waist_yaw_joint\": 2.55788977\n        \"waist_roll_joint\": 1.81444569\n        \"waist_pitch_joint\": 1.81444569\n        \".*_shoulder_pitch_joint\": 0.90722284\n        \".*_shoulder_roll_joint\": 0.90722284\n        \".*_shoulder_yaw_joint\": 0.90722284\n        \".*_elbow_joint\": 0.90722284\n        \".*_wrist_roll_joint\": 0.907222843292423\n        \".*_wrist_pitch_joint\": 1.06814150219\n        \".*_wrist_yaw_joint\": 1.06814150219\n\n      armature:\n        \".*_hip_pitch_joint\": 0.010177520\n        \".*_hip_roll_joint\": 0.025101925\n        \".*_hip_yaw_joint\": 0.010177520\n        \".*_knee_joint\": 0.025101925\n        \".*_ankle_pitch_joint\": 0.007219450\n        \".*_ankle_roll_joint\": 0.007219450\n        \"waist_yaw_joint\": 0.010177520\n        \"waist_roll_joint\": 0.007219450\n        \"waist_pitch_joint\": 0.007219450\n        \".*_shoulder_pitch_joint\": 0.003609725\n        \".*_shoulder_roll_joint\": 0.003609725\n        \".*_shoulder_yaw_joint\": 0.003609725\n        \".*_elbow_joint\": 0.003609725\n        \".*_wrist_roll_joint\": 0.003609725\n        \".*_wrist_pitch_joint\": 0.00425\n        \".*_wrist_yaw_joint\": 0.00425\n\n    action_scale:\n      \".*_hip_pitch_joint\": 0.548\n      \".*_hip_roll_joint\": 0.351\n      \".*_hip_yaw_joint\": 0.548\n      \".*_knee_joint\": 0.351\n      \".*_ankle_pitch_joint\": 0.439\n      \".*_ankle_roll_joint\": 0.439\n      \"waist_yaw_joint\": 0.548\n      \"waist_roll_joint\": 0.439\n      \"waist_pitch_joint\": 0.439\n      \".*_shoulder_pitch_joint\": 0.439\n      \".*_shoulder_roll_joint\": 0.439\n      \".*_shoulder_yaw_joint\": 0.439\n      \".*_elbow_joint\": 0.439\n      \".*_wrist_roll_joint\": 0.439\n      \".*_wrist_pitch_joint\": 0.075\n      \".*_wrist_yaw_joint\": 0.075\n\n  dof_sign_by_name:\n    left_hip_pitch_joint: 1.0\n    left_hip_roll_joint: -1.0\n    left_hip_yaw_joint: -1.0\n    left_knee_joint: 1.0\n    left_ankle_pitch_joint: 1.0\n    left_ankle_roll_joint: -1.0\n    right_hip_pitch_joint: 1.0\n    right_hip_roll_joint: -1.0\n    right_hip_yaw_joint: -1.0\n    right_knee_joint: 1.0\n    right_ankle_pitch_joint: 1.0\n    
right_ankle_roll_joint: -1.0\n    waist_yaw_joint: -1.0\n    waist_roll_joint: -1.0\n    waist_pitch_joint: 1.0\n    left_shoulder_pitch_joint: 1.0\n    left_shoulder_roll_joint: -1.0\n    left_shoulder_yaw_joint: -1.0\n    left_elbow_joint: 1.0\n    left_wrist_roll_joint: -1.0\n    left_wrist_pitch_joint: 1.0\n    left_wrist_yaw_joint: -1.0\n    right_shoulder_pitch_joint: 1.0\n    right_shoulder_roll_joint: -1.0\n    right_shoulder_yaw_joint: -1.0\n    right_elbow_joint: 1.0\n    right_wrist_roll_joint: -1.0\n    right_wrist_pitch_joint: 1.0\n    right_wrist_yaw_joint: -1.0\n\n  asset:\n    collapse_fixed_joints: True\n    replace_cylinder_with_capsule: True\n    flip_visual_attachments: False\n    max_angular_velocity: 1000.\n    max_linear_velocity: 1000.\n    density: 0.001\n    angular_damping: 0.\n    linear_damping: 0.\n\n    asset_root: \"./\"\n    urdf_file: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.urdf\"\n    assetFileName: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml\"\n    fix_base_link: false\n    force_usd_conversion: true\n\n  extend_config: []\n\n  motion:\n    asset:\n      assetRoot: \"./\"\n      assetFileName: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml\"\n\n    sampling_strategy: ${algo.config.sampling_strategy}\n    weighted_bin: ${algo.config.weighted_bin}\n    curriculum: ${algo.config.curriculum}\n    dump_sampled_motion_keys: false\n    dump_sampled_motion_keys_interval: 1\n    dump_sampled_motion_keys_dir: \"sampled_motion_cache_keys\"\n\n    max_frame_length: 300 # 6s\n    min_frame_length: 50 # 1s\n    handpicked_motion_names: null\n    excluded_motion_names: null\n\n    world_frame_normalization: true\n\n    backend: \"hdf5_v2\" # hdf5, hdf5_v2\n    train_hdf5_roots: ${train_hdf5_roots}\n    val_hdf5_roots: ${train_hdf5_roots}\n\n    dof_names: ${robot.dof_names}\n    body_names: ${robot.body_names}\n    key_bodies: ${robot.key_bodies}\n\n    extend_config: ${robot.extend_config}\n\n    dataloader:\n      num_workers: 2\n      prefetch_factor: 1\n      pin_memory: true\n      persistent_workers: false\n      timeout: 600\n\n    fk_robot_file_path: ${robot.asset.urdf_file}\n    fk_vel_smoothing_sigma: 2.0\n\n    online_filter:\n      enabled: false\n      butter_order: 4\n      butter_cutoff_hz_pool: []\n\n    cache:\n      batch_progress_bar: false\n      max_num_clips: ${num_envs} # Batch size for motion clips\n      device: \"cuda\" # \"cuda\" or \"cpu\"; cuda stages on GPU\n      swap_interval_steps: ${robot.motion.max_frame_length} # Swap cache every N steps\n      allowed_prefixes:\n        - \"ref_\"\n        - \"ft_ref_\"\n"
  },
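  {
    "path": "docs/examples/pd_action_sketch.py",
    "content": "# Illustrative sketch (hypothetical) tying together the joint_position action\n# config (use_default_offset: true, per-joint action_scale) and the PD gains\n# above: the policy action becomes a joint target around the default pose,\n# tracked by a PD torque law. This is not the IsaacLab actuator model.\nimport numpy as np\n\n\ndef pd_torque(action, q, qd, q_default, action_scale, kp, kd):\n    # All arrays are (num_dofs,). For left_knee_joint the config above gives\n    # action_scale 0.351, stiffness ~99.098 and damping ~6.309.\n    q_target = q_default + action_scale * action  # use_default_offset: true\n    return kp * (q_target - q) - kd * qd\n\n\nif __name__ == '__main__':\n    tau = pd_torque(\n        action=np.array([0.5]), q=np.array([0.669]), qd=np.array([0.0]),\n        q_default=np.array([0.669]), action_scale=np.array([0.351]),\n        kp=np.array([99.09842778]), kd=np.array([6.30880185]),\n    )\n    print(tau)  # kp * 0.351 * 0.5 ~= 17.39 N*m toward the offset target\n"
  },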
  {
    "path": "holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab_s100.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - unitree/G1/29dof/29dof_training_isaaclab\n\nrobot:\n  asset:\n    urdf_file: \"assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0_s100.urdf\"\n\n\n\n"
  },
  {
    "path": "holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_mlp.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - /training: train_base\n  - /algo: ppo\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: motion_tracking\n  - /env/terminations: termination_motion_tracking\n  - /env/observations: motion_tracking/obs_motion_tracking_mlp\n  - /env/rewards: motion_tracking/rew_motion_tracking\n  - /env/domain_randomization: domain_rand_medium\n  - /env/terrain: isaaclab_rough\n  - /modules: motion_tracking/motion_tracking_mlp\n\nproject_name: HoloMotionMotrackV1.2\n\n# checkpoint: ???\n\ntrain_hdf5_roots:\n  - data/h5v2_datasets/AMASS_test\n\n"
  },
  {
    "path": "holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_tf-moe.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - /training: train_base\n  - /algo: ppo_tf\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: motion_tracking\n  - /env/terminations: termination_motion_tracking\n  - /env/observations: motion_tracking/obs_motion_tracking_tf-moe\n  - /env/rewards: motion_tracking/rew_motion_tracking\n  - /env/domain_randomization: domain_rand_medium\n  - /env/terrain: isaaclab_rough\n  - /modules: motion_tracking/motion_tracking_tf-moe\n\nproject_name: HoloMotionMotrackV1.2\n\n# checkpoint: ???\n\ntrain_hdf5_roots:\n  - data/h5v2_datasets/AMASS_test\n"
  },
  {
    "path": "holomotion/config/training/train_base.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - _self_\n  - /mujoco_eval: sim2sim\n\nproject_name: ???\nexperiment_name: ???\n\nnum_envs: ???\nheadless: ???\n\nmotion_h5_path: ???\ncheckpoint: null\n\nnum_processes: ???\nmain_process: ???\nprocess_id: ???\n\ntimestamp: ${now:%Y%m%d_%H%M%S}\nbase_dir: logs\nexperiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}\nsave_dir: ${experiment_dir}/.hydra\noutput_dir: ${experiment_dir}/output\nexperiment_save_dir: ???\n"
  },
  {
    "path": "holomotion/config/training/velocity_tracking/train_g1_29dof_velocity_tracking_mlp.yaml",
    "content": "# @package _global_\n\ndefaults:\n  - /training: train_base\n  - /algo: ppo\n  - /robot: unitree/G1/29dof/29dof_training_isaaclab\n  - /env: velocity_tracking\n  - /env/terminations: termination_velocity_tracking\n  - /env/observations: velocity_tracking/obs_velocity_tracking\n  - /env/rewards: velocity_tracking/rew_velocity_tracking\n  - /env/domain_randomization: domain_rand_medium\n  - /env/terrain: isaaclab_rough\n  - /modules: velocity_tracking/velocity_tracking_mlp\n\nproject_name: HoloMotionVelocityTrackingG1\n\n# checkpoint: ???\n\ntrain_hdf5_roots:\n  - /horizon-bucket/robot_lab/users/maiyue01.chen/h5_datasets/h5_unitree_walk_20160119\n\nenv:\n  config:\n    commands:\n      base_velocity:\n        params:\n          resampling_time_range: [3.0, 6.0]\n          rel_standing_envs: 0.2\n          rel_yaw_envs: 0.2\n          heading_command: false\n          heading_control_stiffness: 0.5\n          ranges:\n            lin_vel_x: [-1.0, 1.0]\n            lin_vel_y: [-0.5, 0.5]\n            ang_vel_z: [-1.0, 1.0]\n            heading: [-3.14, 3.14]\n\nalgo:\n  config:\n    symmetry_loss:\n      enabled: true\n      coef: 0.1\n\nrobot:\n  init_state:\n    default_joint_angles:\n      waist_pitch_joint: 0.0\n"
  },
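  {
    "path": "docs/examples/symmetry_loss_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of a left/right symmetry loss like the one\n# enabled above (coef 0.1): mirror the observation, run the policy, mirror the\n# action back, and penalize the mismatch. The mirroring matrices encode which\n# joints swap sides and which flip sign; everything here, including the\n# function names, is an assumption rather than the repo's implementation.\nimport torch\n\n\ndef symmetry_loss(policy, obs, obs_mirror_mat, act_mirror_mat, coef=0.1):\n    act = policy(obs)\n    act_from_mirrored = policy(obs @ obs_mirror_mat.T)\n    return coef * torch.mean((act_from_mirrored @ act_mirror_mat.T - act) ** 2)\n\n\nif __name__ == '__main__':\n    policy = torch.nn.Linear(6, 2)\n    obs_m = torch.diag(torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0, -1.0]))\n    act_m = torch.diag(torch.tensor([1.0, -1.0]))\n    print(symmetry_loss(policy, torch.randn(8, 6), obs_m, act_m))\n"
  },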
  {
    "path": "holomotion/scripts/data_curation/convert_to_amass.sh",
    "content": "source train.env\n\n# 默认原始数据路径\nDATA_ROOT=\"./data/raw_datasets\"\n\n# 如果传入参数就覆盖默认\nif [ ! -z \"$1\" ]; then\n    DATA_ROOT=\"$1\"\nfi\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/data_curation/data_smplify.py \\\n    --data_root \"$DATA_ROOT\"\n"
  },
  {
    "path": "holomotion/scripts/data_curation/filter_smpl_data.sh",
    "content": "source train.env\n\n# default json lisy\ndefault_jsonl_list=(\"humanact12\" \"MotionX\" \"OMOMO\" \"ZJU_Mocap\" \"amass\")\njsonl_list=(\"${default_jsonl_list[@]}\")\n\n# extract command line params\nwhile getopts \"l:\" opt; do\n    case $opt in\n    l)\n        # 用户输入的 jsonl_list\n        IFS=' ' read -r -a jsonl_list <<<\"$OPTARG\"\n        ;;\n    *)\n        echo \"Usage: $0 [-l \\\"file1 file2 ...\\\"]\"\n        exit 1\n        ;;\n    esac\ndone\n\necho \"Running label_data.py first...\"\n${Train_CONDA_PREFIX}/bin/python \\\n    ./holomotion/src/data_curation/filter/label_data.py \\\n    --jsonl_list \"${jsonl_list[@]}\"\n\necho \"label_data.py finished.\"\necho \"==============================\"\n\nfor json in \"${jsonl_list[@]}\"; do\n    echo \"Processing $json\"\n\n    #\n    if [[ \"$json\" == \"amass\" ]]; then\n        parent_folder=\"./data/amass_compatible_datasets/amass\"\n    else\n        parent_folder=\"./data/amass_compatible_datasets\"\n    fi\n\n    # 生成路径\n    json_path=\"./data/dataset_labels/${json}.jsonl\"\n    yaml_path=\"./holomotion/config/data_curation/${json}_excluded.yaml\"\n\n    # 调用 python 脚本\n    ${Train_CONDA_PREFIX}/bin/python \\\n        ./holomotion/src/data_curation/filter/filter.py \\\n        --parent_folder \"$parent_folder\" \\\n        --json_path \"$json_path\" \\\n        --yaml_path \"$yaml_path\"\n\n    echo \"Finished $json\"\n    echo \"-----------------------\"\ndone\n\necho \"All done\"\n"
  },
  {
    "path": "holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh",
    "content": "export CONDA_BASE=$(conda info --base)\nexport Train_CONDA_PREFIX=\"$CONDA_BASE/envs/gvhmr\"\n\nvideo_folder_root=\"holomotion_abs_path/data/video_data\"\nnpz_data_root=\"holomotion_abs_path/data/gvhmr_converted/gvhmr_result\"\nout_dir=\"holomotion_abs_path/data/gvhmr_converted/collected_smpl\"\n\ncd thirdparties/GVHMR/\n\n$Train_CONDA_PREFIX/bin/python ../../holomotion/src/data_curation/video_to_smpl_gvhmr.py \\\n    --folder=${video_folder_root} \\\n    --output_root=${npz_data_root} \\\n    -s\n\nmkdir -p \"${out_dir}\"\nfor subdir in \"${npz_data_root}\"/*; do\n    if [[ ! -d \"${subdir}\" ]]; then\n        continue\n    fi\n\n    sub_name=$(basename \"${subdir}\")\n    src_npz=\"${subdir}/smpl.npz\"\n\n    if [[ ! -f \"${src_npz}\" ]]; then\n        echo \"[SKIP] ${sub_name}: smpl.npz not found\"\n        continue\n    fi\n\n    dst_npz=\"${out_dir}/${sub_name}_smpl.npz\"\n\n    cp -f \"${src_npz}\" \"${dst_npz}\"\n    echo \"[COPY] ${src_npz} -> ${dst_npz}\"\ndone"
  },
  {
    "path": "holomotion/scripts/data_curation/visualize_smpl_npz.sh",
    "content": "export CONDA_BASE=$(conda info --base)\nexport Train_CONDA_PREFIX=\"$CONDA_BASE/envs/gvhmr\"\n\n$Train_CONDA_PREFIX/bin/python ../../holomotion/src/data_curation/visualize_smpl_npz.py"
  },
  {
    "path": "holomotion/scripts/evaluation/calc_offline_eval_metrics.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nnpz_dir=\"your_npz_dir\"\ndataset_suffix=\"HoloMotion_eval\"\nmetric_calculation=\"per_clip\"   # Options: \"per_clip\" or \"per_frame\"\ndof_mode=\"23\"  # Options: \"29\" for full DoF, \"23\" for upper body only\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/evaluation/metrics.py \\\n    --npz_dir=${npz_dir} \\\n    --dataset_suffix=${dataset_suffix} \\\n    --failure_pos_err_thresh_m=0.25 \\\n    --metric_calculation=${metric_calculation} \\\n    --dof_mode=${dof_mode}\n"
  },
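  {
    "path": "docs/examples/metric_aggregation_sketch.py",
    "content": "# Illustrative sketch (hypothetical) of the metric_calculation switch used by\n# metrics.py: 'per_clip' (macro) averages each clip's mean error so every clip\n# counts equally, while 'per_frame' (micro) pools all frames so long clips\n# dominate the average.\nimport numpy as np\n\n\ndef aggregate(per_clip_errors, mode='per_clip'):\n    # per_clip_errors: list of (T_i,) arrays of per-frame errors.\n    if mode == 'per_clip':\n        return float(np.mean([e.mean() for e in per_clip_errors]))  # macro\n    return float(np.concatenate(per_clip_errors).mean())  # micro\n\n\nif __name__ == '__main__':\n    clips = [np.full(10, 1.0), np.full(90, 2.0)]\n    print(aggregate(clips, 'per_clip'))  # 1.5: each clip weighted equally\n    print(aggregate(clips, 'per_frame'))  # 1.9: frames weighted equally\n"
  },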
  {
    "path": "holomotion/scripts/evaluation/eval_motion_tracking.sh",
    "content": "#!/bin/bash\n\n# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nsource train.env\nexport CUDA_VISIBLE_DEVICES=\"0\"\n\nHEADLESS=true\nCONFIG_NAME=\"eval_isaaclab\"\n\nCKPT_PATH=\"logs/HoloMotionMotrackV1.2/your_log_dir/model_xxx.pt\"\n\neval_h5_dataset_path=\"['data/h5v2_datasets/lafan1']\"\n\nnum_envs=4\n\n\n${Train_CONDA_PREFIX}/bin/accelerate launch \\\n    holomotion/src/evaluation/eval_motion_tracking_single.py \\\n    --config-name=evaluation/${CONFIG_NAME} \\\n    headless=${HEADLESS} \\\n    num_envs=${num_envs} \\\n    export_policy=true \\\n    dump_npzs=true \\\n    calc_per_clip_metrics=true \\\n    generate_report=true \\\n    motion_h5_path=${eval_h5_dataset_path} \\\n    +use_kv_cache=true \\\n    export_only=false \\\n    checkpoint=$CKPT_PATH \\\n    project_name=\"HoloMotionMoTrack\"\n"
  },
  {
    "path": "holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nexport CUDA_VISIBLE_DEVICES=\"0\"\n\nexport HEADLESS=false\nif $HEADLESS; then\n    export MUJOCO_GL=\"osmesa\"\n    export RECORD_VIDEO=true\nelse\n    export MUJOCO_GL=\"egl\"\n    export RECORD_VIDEO=false\nfi\n\nmodel_type=\"${model_type:-holomotion}\"\n\nrobot_xml_path=\"assets/robots/unitree/G1/29dof/scene_29dof.xml\"\n\nONNX_PATH=\"your_onnx_model.onnx\"\n\nexport motion_npz_path=\"your_npz.npz\"\n\n${Train_CONDA_PREFIX}/bin/python holomotion/src/evaluation/eval_mujoco_sim2sim.py \\\n    record_video=$RECORD_VIDEO \\\n    headless=$HEADLESS \\\n    camera_tracking=true \\\n    camera_distance=7.0 \\\n    +model_type=${model_type} \\\n    use_gpu=true \\\n    dump_npzs=true \\\n    dump_onnx_io_npy=false \\\n    calc_per_clip_metrics=true \\\n    generate_report=true \\\n    ray_actors_per_gpu=12 \\\n    policy_action_delay_step=0 \\\n    action_delay_type=step \\\n    +ckpt_onnx_path=\"$ONNX_PATH\" \\\n    +motion_npz_path='${oc.env:motion_npz_path}' \\\n    robot_xml_path=${robot_xml_path}\n"
  },
  {
    "path": "holomotion/scripts/evaluation/eval_velocity_tracking.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nsource train.env\nexport CUDA_VISIBLE_DEVICES=\"0\"\n\nconfig_name=\"eval_velocity_tracking\"\n\nnum_envs=1\n\nckpt_path=\"logs/HoloMotionVelocityTracking/xxxxx-train_g1_29dof_velocity_tracking/model_xxx.pt\"\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/evaluation/eval_velocity_tracking.py \\\n    --config-name=evaluation/${config_name} \\\n    project_name=\"HoloMotionVelocityTracking\" \\\n    num_envs=${num_envs} \\\n    headless=false \\\n    experiment_name=${config_name} \\\n    checkpoint=${ckpt_path} \\\n    +env.config.commands.base_velocity.params.resampling_time_range=[3,5]\n"
  },
  {
    "path": "holomotion/scripts/evaluation/mean_process_5metrics.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport argparse\nimport csv\nimport glob\nimport json\nimport os\nimport re\n\nimport numpy as np\nimport pandas as pd\nfrom tabulate import tabulate\n\n\n# 需要统计的指标 Key\nMETRICS = [\n    \"mpjpe_g\",\n    \"mpjpe_l\",\n    \"whole_body_joints_dist\",\n    \"root_vel_error\",\n    \"root_r_error\",\n    \"root_p_error\",\n    \"root_y_error\",\n    \"root_height_error\",\n    \"mean_dof_vel\",\n    \"mean_dof_acc\",\n    \"mean_dof_torque\",\n    \"mean_action_rate\",\n    \"success\",\n    \"mean_torque_jump_norm\",\n    \"p95_torque_jump_norm\",\n    \"mean_torque_jump_ratio\",\n    \"p95_torque_jump_ratio\",\n]\n\n# 表头映射 (Json Key -> 表格显示名称)\nCOLUMN_MAPPING = {\n    \"mpjpe_g\": \"Global Bodylink Pos Err\",\n    \"mpjpe_l\": \"Local Bodylink Pos Err\",\n    \"whole_body_joints_dist\": \"Dof Position Err\",\n    \"root_vel_error\": \"Root Vel Err\",\n    \"root_r_error\": \"Root Roll Err\",\n    \"root_p_error\": \"Root Pitch Err\",\n    \"root_y_error\": \"Root Yaw Err\",\n    \"root_height_error\": \"Root Height Err\",\n    \"mean_dof_vel\": \"Mean Dof Vel\",\n    \"mean_dof_acc\": \"Mean Dof Acc\",\n    \"mean_dof_torque\": \"Mean Dof Torque\",\n    \"mean_action_rate\": \"Mean Action Rate\",\n    \"success\": \"Success Rate\",\n    \"mean_torque_jump_norm\": \"Mean Torque Jump Norm\",\n    \"p95_torque_jump_norm\": \"P95 Torque Jump Norm\",\n    \"mean_torque_jump_ratio\": \"Mean Torque Jump Ratio\",\n    \"p95_torque_jump_ratio\": \"P95 Torque Jump Ratio\",\n}\n\n\ndef get_dataset_name(motion_key):\n    if not isinstance(motion_key, str):\n        return \"Unknown\"\n\n    match_old = re.search(r\"clips_([a-zA-Z0-9]+)_\", motion_key)\n    if match_old:\n        return match_old.group(1)\n\n    match_new = re.search(r\"v1.1_eval_([a-zA-Z0-9]+)_\", motion_key)\n    if match_new:\n        return match_new.group(1)\n\n    return motion_key.split(\"_\")[0]\n\n\ndef process_data(folder_path):\n    folder_path = os.path.expanduser(folder_path)\n    search_pattern = os.path.join(folder_path, \"*.json\")\n    json_files = glob.glob(search_pattern)\n    json_files = [\n        file for file in json_files if \"batch_\" not in os.path.basename(file)\n    ]\n\n    if not json_files:\n        raise FileNotFoundError(\n            f\"No .json files found in directory: {folder_path}\"\n        )\n\n    all_records = []\n\n    for file_path in json_files:\n        model_name = os.path.splitext(os.path.basename(file_path))[0]\n        with open(file_path, \"r\", encoding=\"utf-8\") as f:\n            data = json.load(f)\n\n        # 结构兼容性处理\n        if isinstance(data, dict) and \"per_clip\" in data:\n            clips_data = data[\"per_clip\"]\n        elif isinstance(data, list):\n            clips_data = data\n        elif isinstance(data, dict) and \"motion_key\" in data:\n            clips_data = [data]\n        else:\n            
continue\n\n        for entry in clips_data:\n            if \"motion_key\" not in entry:\n                continue\n\n            dataset_name = get_dataset_name(entry[\"motion_key\"])\n\n            record = {\"Method\": model_name, \"Dataset\": dataset_name}\n\n            for metric in METRICS:\n                val = entry.get(metric, None)\n                if val is not None:\n                    record[metric] = val\n\n            all_records.append(record)\n\n    if not all_records:\n        raise ValueError(\n            f\"No valid per-clip metric records extracted from: {folder_path}\"\n        )\n\n    df = pd.DataFrame(all_records)\n    df = df.reindex(columns=[\"Method\", \"Dataset\", *METRICS])\n    grouped_ds = df.groupby([\"Method\", \"Dataset\"])[METRICS]\n    df_mean_ds = grouped_ds.mean().reset_index()\n    df_median_ds = grouped_ds.median().reset_index()\n\n    # Macro-Mean calculation\n    df_mean_total = (\n        df_mean_ds.groupby([\"Method\"])[METRICS].mean().reset_index()\n    )\n\n    # Macro-Median calculation\n    df_median_total = (\n        df_median_ds.groupby([\"Method\"])[METRICS].mean().reset_index()\n    )\n\n    df_mean_total[\"Dataset\"] = \"Total (Macro)\"\n    df_median_total[\"Dataset\"] = \"Total (Macro)\"\n\n    final_mean = pd.concat([df_mean_ds, df_mean_total], ignore_index=True)\n    final_median = pd.concat(\n        [df_median_ds, df_median_total], ignore_index=True\n    )\n\n    return final_mean, final_median\n\n\ndef highlight_best(val, best_val):\n    \"\"\"Return a highlighted HTML string when value is best.\"\"\"\n    if val is None or pd.isna(val):\n        return str(val)\n\n    val_float = float(val)\n    best_val_float = float(best_val)\n    formatted_val = f\"{val_float:.4f}\"\n    if np.isclose(val_float, best_val_float, atol=1e-6):\n        return f\"<b><span style='color: green'>{formatted_val}</span></b>\"\n    return formatted_val\n\n\ndef generate_report(\n    df,\n    folder_path,\n    file_name=\"result_table_mean.md\",\n    title=\"Evaluation Results (Mean)\",\n):\n    out_md = os.path.join(folder_path, file_name)\n\n    all_datasets = df[\"Dataset\"].unique().tolist()\n\n    # Sort datasets alphabetically, keeping the Total row last\n    total_key = \"Total (Macro)\"\n    if total_key in all_datasets:\n        all_datasets.remove(total_key)\n        all_datasets.sort()\n        all_datasets.append(total_key)\n    else:\n        all_datasets.sort()\n\n    md_content_accumulator = f\"# {title}\\n\\n\"\n    md_content_accumulator += (\n        \"> **Note:** 'Total (Macro)' represents the **Macro-Average**, \"\n        \"calculated as the arithmetic mean of the scores across all datasets, \"\n        \"treating each dataset equally regardless of sample size.\\n\\n\"\n    )\n\n    for ds_name in all_datasets:\n        sub_df = df[df[\"Dataset\"] == ds_name].copy()\n\n        for metric in METRICS:\n            if metric in sub_df.columns:\n                if metric == \"success\":\n                    best_val = sub_df[metric].max()\n                else:\n                    best_val = sub_df[metric].min()\n                sub_df[metric] = sub_df[metric].apply(\n                    lambda x, best_val=best_val: highlight_best(x, best_val)\n                )\n\n        sub_df = sub_df.drop(columns=[\"Dataset\"])\n        sub_df.rename(columns=COLUMN_MAPPING, inplace=True)\n\n        cols = list(sub_df.columns)\n        if \"Method\" in cols:\n            cols.insert(0, cols.pop(cols.index(\"Method\")))\n            sub_df = sub_df[cols]\n\n        
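# Append one markdown section per dataset; best values were highlighted above.\n        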
md_content_accumulator += f\"### Dataset: {ds_name}\\n\"\n        # Generate the markdown table with DataFrame.to_markdown\n        table_str = sub_df.to_markdown(index=False)\n        md_content_accumulator += table_str + \"\\n\\n\"\n\n    with open(out_md, \"w\", encoding=\"utf-8\") as f:\n        f.write(md_content_accumulator)\n\n    return os.path.abspath(out_md)\n\n\ndef _format_metric_values_for_cli(sub_df: pd.DataFrame) -> pd.DataFrame:\n    cli_df = sub_df.copy()\n    for metric in METRICS:\n        if metric in cli_df.columns:\n            cli_df[metric] = cli_df[metric].apply(\n                lambda x: f\"{float(x):.4f}\" if pd.notna(x) else \"nan\"\n            )\n    return cli_df\n\n\ndef _print_cli_tables(df: pd.DataFrame, title: str, folder_path: str) -> None:\n    total_key = \"Total (Macro)\"\n    all_datasets = df[\"Dataset\"].unique().tolist()\n    dataset_order = sorted([d for d in all_datasets if d != total_key])\n    if total_key in all_datasets:\n        dataset_order.append(total_key)\n\n    merged_df = df.copy()\n    merged_df[\"Dataset\"] = pd.Categorical(\n        merged_df[\"Dataset\"], categories=dataset_order, ordered=True\n    )\n    merged_df = merged_df.sort_values(\n        by=[\"Dataset\", \"Method\"], kind=\"stable\"\n    ).reset_index(drop=True)\n    merged_df[\"Dataset\"] = merged_df[\"Dataset\"].astype(str)\n\n    merged_df = _format_metric_values_for_cli(merged_df)\n    merged_df.rename(columns=COLUMN_MAPPING, inplace=True)\n\n    metric_display_cols = [\n        COLUMN_MAPPING[m] for m in METRICS if COLUMN_MAPPING[m] in merged_df\n    ]\n    # table_cols = [\"Dataset\", \"Method\"] + metric_display_cols\n    table_cols = [\"Dataset\"] + metric_display_cols\n    merged_df = merged_df[table_cols]\n\n    output_tsv_path = os.path.join(\n        folder_path, \"sub_dataset_macro_mean_metrics.tsv\"\n    )\n    with open(output_tsv_path, \"w\", encoding=\"utf-8\", newline=\"\") as f:\n        writer = csv.writer(f, delimiter=\"\\t\", lineterminator=\"\\n\")\n        writer.writerow(merged_df.columns.tolist())\n        writer.writerows(merged_df.values.tolist())\n\n    table_str = tabulate(\n        merged_df.values.tolist(),\n        headers=merged_df.columns.tolist(),\n        tablefmt=\"simple_outline\",\n        colalign=(\"left\",) * len(merged_df.columns),\n    )\n\n    block = (\n        \"\\n\"\n        + \"=\" * 80\n        + f\"\\n{title}\\n\"\n        + \"=\" * 80\n        + f\"\\n\\n{table_str}\\n\"\n        + \"=\" * 80\n        + \"\\n\"\n    )\n    print(block)\n    metric_log_path = os.path.join(folder_path, \"metric.log\")\n    with open(metric_log_path, \"a\", encoding=\"utf-8\") as file:\n        file.write(block)\n\n\ndef generate_macro_mean_report_from_json_dir(folder_path: str) -> str:\n    mean_df, _ = process_data(folder_path)\n    report_path = generate_report(\n        df=mean_df,\n        folder_path=folder_path,\n        file_name=\"result_table_macro_mean.md\",\n        title=\"Evaluation Results (Macro-Averaging Mean)\",\n    )\n    _print_cli_tables(\n        df=mean_df,\n        title=\"DATASET-WISE METRICS (MACRO-AVERAGING MEAN)\",\n        folder_path=folder_path,\n    )\n    return report_path\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dir\", type=str, help=\"Directory of metric JSON files\"\n    )\n    args = parser.parse_args()\n\n    out_md = generate_macro_mean_report_from_json_dir(args.dir)\n    print(f\"Report generated: {out_md}\")\n"
  },
  {
    "path": "holomotion/scripts/evaluation/multi_model_metrics_analysis.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nsource train.env\nmetrics_json_dir=\"logs/Holomotion/metrics_output_dataset\"\n\n${Train_CONDA_PREFIX}/bin/python \\\n  holomotion/src/evaluation/multi_model_metrics_report.py \\\n  --json_dir=\"$metrics_json_dir\"\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh",
    "content": "#!/usr/bin/env bash\n# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\n# This file was originally adapted from the [GMR] repository:\n# https://github.com/YanjieZe/GMR/blob/master/general_motion_retargeting/motion_retarget.py\n#\n\nset -euo pipefail\n\nREPO_ROOT=\"$(pwd)\"\nTARGET_FILE=\"${1:-$REPO_ROOT/thirdparties/GMR/general_motion_retargeting/motion_retarget.py}\"\n\nif [[ ! -f \"$TARGET_FILE\" ]]; then\n    echo \"Target file not found: $TARGET_FILE\" >&2\n    exit 1\nfi\n\npython - \"$TARGET_FILE\" <<'PY'\nfrom pathlib import Path\nimport ast\nimport sys\nimport textwrap\n\n\nPATCH_MARKERS = (\n    \"self.first_frame_damping = max(float(damping), 2.0)\",\n    \"self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)\",\n    \"def _solve_task_group(\",\n)\n\n\nPATCHED_INIT = \"\"\"\ndef __init__(\n    self,\n    src_human: str,\n    tgt_robot: str,\n    actual_human_height: float = None,\n    solver: str=\"daqp\", # change from \"quadprog\" to \"daqp\".\n    damping: float=5e-1, # change from 1e-1 to 1e-2.\n    verbose: bool=True,\n    use_velocity_limit: bool=False,\n) -> None:\n\n    # load the robot model\n    self.xml_file = str(ROBOT_XML_DICT[tgt_robot])\n    if verbose:\n        print(\"Use robot model: \", self.xml_file)\n    self.model = mj.MjModel.from_xml_path(self.xml_file)\n\n    # Print DoF names in order\n    print(\"[GMR] Robot Degrees of Freedom (DoF) names and their order:\")\n    self.robot_dof_names = {}\n    for i in range(self.model.nv):  # 'nv' is the number of DoFs\n        dof_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, self.model.dof_jntid[i])\n        self.robot_dof_names[dof_name] = i\n        if verbose:\n            print(f\"DoF {i}: {dof_name}\")\n\n    print(\"[GMR] Robot Body names and their IDs:\")\n    self.robot_body_names = {}\n    for i in range(self.model.nbody):  # 'nbody' is the number of bodies\n        body_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, i)\n        self.robot_body_names[body_name] = i\n        if verbose:\n            print(f\"Body ID {i}: {body_name}\")\n\n    print(\"[GMR] Robot Motor (Actuator) names and their IDs:\")\n    self.robot_motor_names = {}\n    for i in range(self.model.nu):  # 'nu' is the number of actuators (motors)\n        motor_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_ACTUATOR, i)\n        self.robot_motor_names[motor_name] = i\n        if verbose:\n            print(f\"Motor ID {i}: {motor_name}\")\n\n    # Load the IK config\n    with open(IK_CONFIG_DICT[src_human][tgt_robot]) as f:\n        ik_config = json.load(f)\n    if verbose:\n        print(\"Use IK config: \", IK_CONFIG_DICT[src_human][tgt_robot])\n\n    # compute the scale ratio based on given human height and the assumption in the IK config\n    if actual_human_height is not None:\n        ratio = actual_human_height / ik_config[\"human_height_assumption\"]\n    else:\n        ratio = 1.0\n\n    # 
adjust the human scale table\n    for key in ik_config[\"human_scale_table\"].keys():\n        ik_config[\"human_scale_table\"][key] = ik_config[\"human_scale_table\"][key] * ratio\n\n    # used for retargeting\n    self.ik_match_table1 = ik_config[\"ik_match_table1\"]\n    self.ik_match_table2 = ik_config[\"ik_match_table2\"]\n    self.human_root_name = ik_config[\"human_root_name\"]\n    self.robot_root_name = ik_config[\"robot_root_name\"]\n    self.use_ik_match_table1 = ik_config[\"use_ik_match_table1\"]\n    self.use_ik_match_table2 = ik_config[\"use_ik_match_table2\"]\n    self.human_scale_table = ik_config[\"human_scale_table\"]\n    self.ground = ik_config[\"ground_height\"] * np.array([0, 0, 1])\n\n    self.max_iter = 10\n\n    self.solver = solver\n    self.damping = damping\n    self.first_frame_damping = max(float(damping), 2.0)\n    self.first_frame_max_iter = max(int(self.max_iter), 10)\n    self._is_first_frame = True\n\n    self.human_body_to_task1 = {}\n    self.human_body_to_task2 = {}\n    self.pos_offsets1 = {}\n    self.rot_offsets1 = {}\n    self.pos_offsets2 = {}\n    self.rot_offsets2 = {}\n    self._arm_task_original_orientation_costs = {}\n    self._first_frame_arm_orientation_cost = 1.0\n\n    self.task_errors1 = {}\n    self.task_errors2 = {}\n\n    self.ik_limits = [mink.ConfigurationLimit(self.model)]\n    if use_velocity_limit:\n        VELOCITY_LIMITS = {k: 3*np.pi for k in self.robot_motor_names.keys()}\n        self.ik_limits.append(mink.VelocityLimit(self.model, VELOCITY_LIMITS))\n\n    self.setup_retarget_configuration()\n\n    self.ground_offset = 0.0\n\"\"\"\n\n\nPATCHED_SETUP = \"\"\"\ndef setup_retarget_configuration(self):\n    self.configuration = mink.Configuration(self.model)\n    self._default_qpos = self.configuration.data.qpos.copy()\n    self.posture_task = mink.PostureTask(self.model, cost=1e-2)\n    self.posture_task.set_target(self._default_qpos)\n    self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)\n    self.prev_posture_task.set_target(self._default_qpos)\n\n    self.tasks1 = []\n    self.tasks2 = []\n\n    for frame_name, entry in self.ik_match_table1.items():\n        body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry\n        if pos_weight != 0 or rot_weight != 0:\n            task = mink.FrameTask(\n                frame_name=frame_name,\n                frame_type=\"body\",\n                position_cost=pos_weight,\n                orientation_cost=rot_weight,\n                lm_damping=1,\n            )\n            self.human_body_to_task1[body_name] = task\n            self.pos_offsets1[body_name] = np.array(pos_offset) - self.ground\n            self.rot_offsets1[body_name] = R.from_quat(\n                rot_offset, scalar_first=True\n            )\n            self.tasks1.append(task)\n            self.task_errors1[task] = []\n            if self._is_arm_body(body_name):\n                self._arm_task_original_orientation_costs[task] = float(\n                    rot_weight\n                )\n\n    for frame_name, entry in self.ik_match_table2.items():\n        body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry\n        if pos_weight != 0 or rot_weight != 0:\n            task = mink.FrameTask(\n                frame_name=frame_name,\n                frame_type=\"body\",\n                position_cost=pos_weight,\n                orientation_cost=rot_weight,\n                lm_damping=1,\n            )\n            self.human_body_to_task2[body_name] = task\n            
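# Position offsets are expressed relative to the configured ground height.\n            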
self.pos_offsets2[body_name] = np.array(pos_offset) - self.ground\n            self.rot_offsets2[body_name] = R.from_quat(\n                rot_offset, scalar_first=True\n            )\n            self.tasks2.append(task)\n            self.task_errors2[task] = []\n            if self._is_arm_body(body_name):\n                self._arm_task_original_orientation_costs[task] = float(\n                    rot_weight\n                )\n\"\"\"\n\n\nPATCHED_RETARGET_BLOCK = \"\"\"\n@staticmethod\ndef _is_arm_body(body_name):\n    return any(\n        token in body_name\n        for token in (\n            \"left_shoulder\",\n            \"right_shoulder\",\n            \"left_elbow\",\n            \"right_elbow\",\n            \"left_wrist\",\n            \"right_wrist\",\n        )\n    )\n\ndef _set_first_frame_arm_task_costs(self, enabled):\n    for task, original_orientation_cost in (\n        self._arm_task_original_orientation_costs.items()\n    ):\n        orientation_cost = (\n            self._first_frame_arm_orientation_cost\n            if enabled\n            else original_orientation_cost\n        )\n        task.set_orientation_cost(orientation_cost)\n\ndef _solve_task_group(\n    self,\n    tasks,\n    error_fn,\n    *,\n    damping,\n    max_iter,\n    include_posture,\n    include_prev_posture,\n):\n    solve_tasks = list(tasks)\n    if include_posture:\n        solve_tasks.append(self.posture_task)\n    if include_prev_posture:\n        solve_tasks.append(self.prev_posture_task)\n\n    curr_error = error_fn()\n    dt = self.configuration.model.opt.timestep\n    vel = mink.solve_ik(\n        self.configuration,\n        solve_tasks,\n        dt,\n        self.solver,\n        damping,\n        limits=self.ik_limits,\n    )\n    self.configuration.integrate_inplace(vel, dt)\n    next_error = error_fn()\n    num_iter = 0\n    while curr_error - next_error > 0.001 and num_iter < max_iter:\n        curr_error = next_error\n        dt = self.configuration.model.opt.timestep\n        vel = mink.solve_ik(\n            self.configuration,\n            solve_tasks,\n            dt,\n            self.solver,\n            damping,\n            limits=self.ik_limits,\n        )\n        self.configuration.integrate_inplace(vel, dt)\n        next_error = error_fn()\n        num_iter += 1\n\ndef retarget(self, human_data, offset_to_ground=False):\n    prev_q = self.configuration.data.qpos.copy()\n    # Update the task targets\n    self.update_targets(human_data, offset_to_ground)\n    include_posture = self._is_first_frame\n    include_prev_posture = True\n    solve_damping = (\n        self.first_frame_damping if self._is_first_frame else self.damping\n    )\n    solve_max_iter = (\n        self.first_frame_max_iter if self._is_first_frame else self.max_iter\n    )\n    self.prev_posture_task.set_target(prev_q)\n    if self._is_first_frame:\n        self._set_first_frame_arm_task_costs(True)\n\n    if self.use_ik_match_table1:\n        self._solve_task_group(\n            self.tasks1,\n            self.error1,\n            damping=solve_damping,\n            max_iter=solve_max_iter,\n            include_posture=include_posture,\n            include_prev_posture=include_prev_posture,\n        )\n\n    if self.use_ik_match_table2:\n        self._solve_task_group(\n            self.tasks2,\n            self.error2,\n            damping=solve_damping,\n            max_iter=solve_max_iter,\n            include_posture=include_posture,\n            include_prev_posture=include_prev_posture,\n     
   )\n\n    if self._is_first_frame:\n        self._set_first_frame_arm_task_costs(False)\n    self._is_first_frame = False\n    return self.configuration.data.qpos.copy()\n\"\"\"\n\n\ndef indent_block(src: str, indent: str = \"    \") -> str:\n    body = textwrap.dedent(src).strip(\"\\n\")\n    return \"\\n\".join(indent + line if line else \"\" for line in body.splitlines()) + \"\\n\"\n\n\ndef find_class(module: ast.Module, class_name: str) -> ast.ClassDef:\n    for node in module.body:\n        if isinstance(node, ast.ClassDef) and node.name == class_name:\n            return node\n    raise SystemExit(f\"Class {class_name!r} not found in target file.\")\n\n\ndef find_method(class_node: ast.ClassDef, method_name: str) -> ast.FunctionDef:\n    for node in class_node.body:\n        if isinstance(node, ast.FunctionDef) and node.name == method_name:\n            return node\n    raise SystemExit(f\"Method {method_name!r} not found in class {class_node.name}.\")\n\n\ndef apply_replacement(lines, node: ast.AST, replacement: str):\n    start = node.lineno - 1\n    end = node.end_lineno\n    return lines[:start] + [indent_block(replacement)] + lines[end:]\n\n\npath = Path(sys.argv[1])\ntext = path.read_text(encoding=\"utf-8\")\n\nif all(marker in text for marker in PATCH_MARKERS):\n    print(f\"Patch already present: {path}\")\n    raise SystemExit(0)\n\nmodule = ast.parse(text)\nclass_node = find_class(module, \"GeneralMotionRetargeting\")\ninit_node = find_method(class_node, \"__init__\")\nsetup_node = find_method(class_node, \"setup_retarget_configuration\")\nretarget_node = find_method(class_node, \"retarget\")\n\nlines = text.splitlines(keepends=True)\nfor node, replacement in sorted(\n    [\n        (retarget_node, PATCHED_RETARGET_BLOCK),\n        (setup_node, PATCHED_SETUP),\n        (init_node, PATCHED_INIT),\n    ],\n    key=lambda item: item[0].lineno,\n    reverse=True,\n):\n    lines = apply_replacement(lines, node, replacement)\n\nnew_text = \"\".join(lines)\nast.parse(new_text)\npath.write_text(new_text, encoding=\"utf-8\")\nprint(f\"Patch applied to: {path}\")\nPY"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/pack_hdf5_v2.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\nexport CUDA_VISIBLE_DEVICES=\"\"\n\nholomotion_npz_root='[\"data/holomotion_retargeted/AMASS_test\"]'\nhdf5_root=\"data/h5v2_datasets/AMASS_test\"\n\nrobot_config=\"unitree/G1/29dof/29dof_training_isaaclab\"\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/motion_retargeting/pack_hdf5_v2.py \\\n    robot=$robot_config \\\n    holomotion_npz_root=${holomotion_npz_root} \\\n    hdf5_root=$hdf5_root\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_holomotion_preprocessing.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nholo_src_dir=\"src_holomotion_npz_dir\"\nholo_tgt_dir=\"output_holomotion_npz_dir\"\n\npipeline=\"['filename_as_motionkey','legacy_to_ref_keys','tagging']\"\n\nrobot_config=\"holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml\"\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/motion_retargeting/holomotion_preprocess.py \\\n    padding.robot_config_path=${robot_config} \\\n    io.src_root=${holo_src_dir} \\\n    io.out_root=${holo_tgt_dir} \\\n    preprocess.pipeline=${pipeline} \\\n    ray.enabled=true \\\n    padding.stand_still_time=20.0 \\\n    ray.num_workers=2\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_kinematic_filter.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\ndataset_root=\"data/holomotion_retargeted/processed_datasets/AMASS_test\"\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/motion_retargeting/kinematic_filter.py \\\n    io.dataset_root=${dataset_root}"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_bvh.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nbvh_src_dir=\"data/lafan1_bvh\"\ngmr_tgt_dir=\"data/gmr_retargeted/lafan1/\"\n\n# Step 1: retargeting to robot dataset from smplx format\n# create gmr_tgt_dir if not exists\nif [ ! -d \"$gmr_tgt_dir\" ]; then\n    mkdir -p $gmr_tgt_dir\nfi\n\n$Train_CONDA_PREFIX/bin/python \\\n    thirdparties/GMR/scripts/bvh_to_robot_dataset.py \\\n    --src_folder ${bvh_src_dir}/ \\\n    --tgt_folder ${gmr_tgt_dir}/ \\\n    --robot unitree_g1\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nsmplx_src_dir=\"assets/test_data/motion_retargeting/\"\ngmr_tgt_dir=\"data/gmr_retargeted/AMASS_test/\"\n\n\n# create gmr_tgt_dir if not exists\nif [ ! -d \"$gmr_tgt_dir\" ]; then\n    mkdir -p $gmr_tgt_dir\nfi\n\n$Train_CONDA_PREFIX/bin/python \\\n    thirdparties/GMR/scripts/smplx_to_robot_dataset.py \\\n    --src_folder=${smplx_src_dir}/ \\\n    --tgt_folder=${gmr_tgt_dir}/ \\\n    --num_cpus=16 \\\n    --robot=unitree_g1\n\n\n\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\ndir_name=\"AMASS_test\"\ngmr_tgt_dir=\"data/gmr_retargeted/${dir_name}\"\nholo_retargeted_dir=\"data/holomotion_retargeted/processed_datasets/${dir_name}\"\n\nrobot_cfg=\"holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml\"\n\npreprocess_pipeline=\"['filename_as_motionkey','legacy_to_ref_keys','slicing','add_padding','tagging']\"\n\n${Train_CONDA_PREFIX}/bin/python \\\n    holomotion/src/motion_retargeting/gmr_to_holomotion.py \\\n    io.robot_config=${robot_cfg} \\\n    io.src_dir=${gmr_tgt_dir} \\\n    io.out_root=${holo_retargeted_dir} \\\n    processing.target_fps=50 \\\n    preprocess.pipeline=${preprocess_pipeline} \\\n    ray.num_workers=16\n"
  },
  {
    "path": "holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nsource train.env\n\nexport MUJOCO_GL=\"osmesa\"\n\nmotion_npz_root=\"path_to_your_npz_dir\"\n\nexport motion_name=\"all\"\n\n\n$Train_CONDA_PREFIX/bin/python holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py \\\n    +key_prefix=\"robot_\" \\\n    +draw_ref_body_spheres=true \\\n    +ref_key_prefix=\"ref_\" \\\n    +motion_npz_root=${motion_npz_root} \\\n    skip_frames=6 \\\n    max_workers=11 \\\n    +motion_name='${oc.env:motion_name}'\n"
  },
  {
    "path": "holomotion/scripts/training/train_motion_tracking.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nsource train.env\n\nexport CUDA_VISIBLE_DEVICES=0\n\nif [[ $(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\\n' | wc -l) -eq 1 ]]; then\n    USE_MULTI_GPU=false\nelse\n    USE_MULTI_GPU=true\nfi\n\n\nconfig_name=\"train_g1_29dof_motion_tracking_mlp\"\n# config_name=\"train_g1_29dof_motion_tracking_tf-moe\"\n\nnum_envs=4096\n\nCOMMON_ARGS=(\n    \"holomotion/src/training/train.py\"\n    \"--config-name=training/motion_tracking/${config_name}\"\n    \"num_envs=${num_envs}\"\n    \"headless=true\"\n    \"experiment_name=${config_name}\"\n)\n\ntrap cleanup SIGINT SIGTERM\nif [[ \"${USE_MULTI_GPU}\" == \"true\" ]]; then\n    ${Train_CONDA_PREFIX}/bin/accelerate launch \\\n        --multi_gpu \\\n        \"${COMMON_ARGS[@]}\"\nelse\n    ${Train_CONDA_PREFIX}/bin/accelerate launch \\\n        \"${COMMON_ARGS[@]}\"\nfi\nwait ${TRAIN_PID}\ntrap - SIGINT SIGTERM\n"
  },
  {
    "path": "holomotion/scripts/training/train_velocity_tracking.sh",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nsource train.env\n\nexport CUDA_VISIBLE_DEVICES=0\n\nconfig_name=\"train_g1_29dof_velocity_tracking_mlp\"\n\nnum_envs=4096\n\nCOMMON_ARGS=(\n    \"holomotion/src/training/train.py\"\n    \"--config-name=training/velocity_tracking/${config_name}\"\n    \"experiment_name=${config_name}\"\n    \"num_envs=${num_envs}\"\n    \"headless=true\"\n)\n\ntrap cleanup SIGINT SIGTERM\nif [[ $(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\\n' | wc -l) -eq 1 ]]; then\n    ${Train_CONDA_PREFIX}/bin/accelerate launch \\\n        --multi_gpu \\\n        \"${COMMON_ARGS[@]}\"\nelse\n    ${Train_CONDA_PREFIX}/bin/accelerate launch \\\n        \"${COMMON_ARGS[@]}\"\nfi\nwait ${TRAIN_PID}\ntrap - SIGINT SIGTERM\n"
  },
  {
    "path": "holomotion/src/algo/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/algo/algo_base.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\nimport random\nimport statistics\nimport sys\nimport time\nfrom collections import deque\nfrom typing import Any, Dict\n\nimport numpy as np\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.utils import (\n    ProjectConfiguration,\n    TorchDynamoPlugin,\n    load_checkpoint_in_model,\n    load_state_dict,\n)\nfrom hydra.utils import get_class\nfrom loguru import logger\nfrom tensordict import TensorDict\n\nfrom holomotion.src.algo.algo_utils import AlgoLogger\n\n\nclass BaseOnpolicyRL:\n    \"\"\"Base class for on-policy RL algorithms in HoloMotion.\"\"\"\n\n    def __init__(\n        self,\n        env_config,\n        config,\n        log_dir=None,\n        headless: bool = True,\n        is_offline_eval: bool = False,\n    ) -> None:\n        self.config = config\n        self.env_config = env_config\n        self.log_dir = log_dir\n        self.headless = headless\n        self.is_offline_eval = is_offline_eval\n\n        self._setup_accelerator()\n        self.algo_logger = AlgoLogger(\n            self.accelerator,\n            self.log_dir,\n            is_main_process=self.is_main_process,\n        )\n        self._setup_environment()\n        self._setup_configs()\n        self._setup_seeding()\n        self._setup_data_buffers()\n        self._setup_algo_components()\n        self._setup_models_and_optimizer()\n\n    def _setup_accelerator(self) -> None:\n        if not self.is_offline_eval:\n            os.makedirs(self.log_dir, exist_ok=True)\n\n        accelerator_kwargs = {}\n        mixed_precision = self.config.get(\"mixed_precision\", None)\n        if mixed_precision in (\"fp16\", \"bf16\"):\n            accelerator_kwargs[\"mixed_precision\"] = mixed_precision\n        dynamo_backend = self.config.get(\"dynamo_backend\", None)\n        if os.environ.get(\"TORCH_COMPILE_DISABLE\", \"0\") == \"1\":\n            dynamo_backend = None\n        if dynamo_backend in (\"inductor\", \"aot_eager\", \"cudagraphs\"):\n            dynamo_dynamic = bool(self.config.get(\"dynamo_dynamic\", True))\n            dynamo_fullgraph = bool(self.config.get(\"dynamo_fullgraph\", False))\n            dynamo_mode = self.config.get(\"dynamo_mode\", \"default\")\n            accelerator_kwargs[\"dynamo_plugin\"] = TorchDynamoPlugin(\n                backend=str(dynamo_backend),\n                mode=str(dynamo_mode),\n                fullgraph=bool(dynamo_fullgraph),\n                dynamic=bool(dynamo_dynamic),\n            )\n\n        accelerator_kwargs[\"log_with\"] = \"tensorboard\"\n        project_config = ProjectConfiguration(\n            project_dir=self.log_dir,\n            logging_dir=self.log_dir,\n        )\n        accelerator_kwargs[\"project_config\"] = project_config\n        self.accelerator = Accelerator(**accelerator_kwargs)\n        self.local_rank = getattr(\n            
self.accelerator, \"local_process_index\", None\n        )\n        if self.local_rank is None:\n            self.local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n\n        self.device = self.accelerator.device\n        if torch.cuda.is_available() and self.device.type == \"cuda\":\n            dev_index = self.device.index\n            if dev_index is None:\n                dev_index = int(self.local_rank)\n                self.device = torch.device(\"cuda\", dev_index)\n            else:\n                dev_index = int(dev_index)\n            torch.cuda.set_device(dev_index)\n        self.is_main_process = self.accelerator.is_main_process\n\n        self.accelerator.init_trackers(\n            project_name=\"holomotion\",\n            config={\n                \"precision\": mixed_precision if mixed_precision else \"fp32\",\n                \"dynamo_backend\": dynamo_backend if dynamo_backend else \"none\",\n                \"dynamo_dynamic\": bool(self.config.get(\"dynamo_dynamic\", True))\n                if dynamo_backend\n                else False,\n            },\n        )\n        self._release_cuda_cache()\n\n        logger.remove()\n        log_level = os.environ.get(\"LOGURU_LEVEL\", \"INFO\").upper()\n        if self.log_dir:\n            rank_log_file_name = (\n                \"offline_eval_rank\" if self.is_offline_eval else \"run_rank\"\n            )\n            logger.add(\n                os.path.join(\n                    self.log_dir,\n                    f\"{rank_log_file_name}_{int(self.accelerator.process_index):04d}.log\",\n                ),\n                level=log_level,\n                colorize=False,\n            )\n        if self.is_main_process:\n            logger.add(\n                sys.stdout,\n                level=log_level,\n                colorize=True,\n            )\n            log_file_name = (\n                \"offline_eval.log\" if self.is_offline_eval else \"run.log\"\n            )\n            logger.add(\n                os.path.join(self.log_dir, log_file_name),\n                level=log_level,\n                colorize=False,\n            )\n\n            used_precision = mixed_precision if mixed_precision else \"fp32\"\n            logger.info(\n                f\"Accelerator initialized with precision: {used_precision}\"\n            )\n            if dynamo_backend:\n                logger.info(f\"Accelerator dynamo_backend: {dynamo_backend}\")\n            logger.info(f\"TensorBoard logging enabled at: {self.log_dir}\")\n\n        self.process_rank = self.accelerator.process_index\n        self.gpu_world_size = self.accelerator.num_processes\n        self.gpu_global_rank = self.accelerator.process_index\n        self.is_distributed = self.gpu_world_size > 1\n        env_rank = os.environ.get(\"RANK\", \"unset\")\n        env_local_rank = os.environ.get(\"LOCAL_RANK\", \"unset\")\n        env_world_size = os.environ.get(\"WORLD_SIZE\", \"unset\")\n        env_local_world_size = os.environ.get(\"LOCAL_WORLD_SIZE\", \"unset\")\n        env_node_rank = os.environ.get(\n            \"NODE_RANK\", os.environ.get(\"MACHINE_RANK\", \"unset\")\n        )\n        env_master_addr = os.environ.get(\"MASTER_ADDR\", \"unset\")\n        env_master_port = os.environ.get(\"MASTER_PORT\", \"unset\")\n        env_cuda_visible_devices = os.environ.get(\n            \"CUDA_VISIBLE_DEVICES\", \"unset\"\n        )\n        cuda_device_count = (\n            int(torch.cuda.device_count()) if torch.cuda.is_available() else 0\n        )\n      
  logger.info(\n            \"[Accelerate setup] \"\n            f\"distributed_type={self.accelerator.distributed_type}, \"\n            f\"num_processes={int(self.accelerator.num_processes)}, \"\n            f\"process_index={int(self.accelerator.process_index)}, \"\n            f\"local_process_index={int(self.local_rank)}, \"\n            f\"is_main_process={bool(self.accelerator.is_main_process)}\"\n        )\n        logger.info(\n            \"[Accelerate env] \"\n            f\"RANK={env_rank}, LOCAL_RANK={env_local_rank}, \"\n            f\"WORLD_SIZE={env_world_size}, \"\n            f\"LOCAL_WORLD_SIZE={env_local_world_size}, \"\n            f\"NODE_RANK={env_node_rank}, MASTER_ADDR={env_master_addr}, \"\n            f\"MASTER_PORT={env_master_port}\"\n        )\n        logger.info(\n            \"[Accelerate cuda] \"\n            f\"CUDA_VISIBLE_DEVICES={env_cuda_visible_devices}, \"\n            f\"torch_cuda_device_count={cuda_device_count}, \"\n            f\"selected_device={self.device}\"\n        )\n\n    def _setup_environment(self) -> None:\n        \"\"\"Setup IsaacLab AppLauncher and environment instance.\"\"\"\n        # Device string from accelerator (handles distributed training)\n        device_str = str(self.device)\n\n        # Delayed import to ensure Accelerate is fully initialized before IsaacLab\n        from isaaclab.app import AppLauncher\n\n        # Stagger IsaacSim AppLauncher initialization across distributed ranks\n        # Use local rank per node to stagger independently on each node\n        if self.is_distributed:\n            self.accelerator.wait_for_everyone()\n            base_delay_s = float(\n                os.environ.get(\"HOLOMOTION_ISAAC_STAGGER_SEC\", \"5.0\")\n            )\n            local_rank = int(self.local_rank)\n            delay_s = base_delay_s * float(local_rank)\n            if delay_s > 0.0:\n                logger.info(\n                    f\"[Global Rank {self.gpu_global_rank}, Local Rank {local_rank}] \"\n                    f\"Sleeping {delay_s:.1f}s before IsaacSim AppLauncher init\"\n                )\n            time.sleep(delay_s)\n\n        # Create AppLauncher with accelerator device\n        # Enable cameras only when needed:\n        # - headless & recording: True (offscreen rendering)\n        # - headless & not recording: False (maximize performance)\n        # - with GUI: True\n        _record_video = bool(self.config.get(\"record_video\", False))\n        enable_cameras = _record_video or (not self.headless)\n\n        # Explicitly disable Omniverse multi-GPU rendering to avoid per-process\n        # MGPU context creation across all visible GPUs.\n        kit_args_str = (\n            \"--/renderer/multiGpu/enabled=false \"\n            \"--/renderer/multiGpu/autoEnable=false \"\n            \"--/renderer/multiGpu/maxGpuCount=1\"\n        )\n\n        app_launcher_flags = {\n            \"headless\": self.headless,\n            \"enable_cameras\": enable_cameras,\n            \"video\": _record_video,\n            \"device\": device_str,\n            \"kit_args\": kit_args_str,\n        }\n\n        self._sim_app_launcher = AppLauncher(**app_launcher_flags)\n        self._sim_app = self._sim_app_launcher.app\n\n        logger.info(\n            f\"AppLauncher initialized with flags: {app_launcher_flags}\"\n        )\n\n        env_class = get_class(self.env_config._target_)\n\n        render_mode = (\n            \"rgb_array\"\n            if bool(self.config.get(\"record_video\", False))\n            
else None\n        )\n        self.env = env_class(\n            config=self.env_config.config,\n            device=device_str,\n            headless=self.headless,\n            log_dir=self.log_dir,\n            accelerator=self.accelerator,\n            render_mode=render_mode,\n        )\n\n        _ = self.env.reset_all()\n\n        logger.info(f\"Environment initialized with render_mode: {render_mode}\")\n\n    def _setup_configs(self) -> None:\n        self.num_envs: int = self.env.config.num_envs\n        self.num_privileged_obs = 0\n        self.num_actions = self.env.config.robot.actions_dim\n\n        self.command_name = list(self.env.config.commands.keys())[0]\n        self.command_term = self.env._env.command_manager.get_term(\n            self.command_name\n        )\n        if self.command_name == \"ref_motion\":\n            self.command_term.set_runtime_distributed_context(\n                process_id=int(self.accelerator.process_index),\n                num_processes=int(self.accelerator.num_processes),\n            )\n            self.command_term.setup_dumping_dir(self.log_dir)\n\n        self.save_interval = self.config.save_interval\n        self.log_interval = self.config.log_interval\n        self.num_steps_per_env = self.config.num_steps_per_env\n        self.num_learning_iterations = self.config.num_learning_iterations\n        self.total_learning_iterations = int(self.num_learning_iterations)\n\n    def _setup_seeding(self) -> None:\n        if self.command_name == \"ref_motion\":\n            self.seed = int(self.command_term.cfg.seed)\n            self.base_seed = int(self.seed - int(self.process_rank))\n        else:\n            self.base_seed = int(self.config.get(\"seed\", int(time.time())))\n            self.seed = int(self.base_seed + int(self.process_rank))\n        random.seed(self.seed)\n        np.random.seed(self.seed)\n        torch.manual_seed(self.seed)\n        if torch.cuda.is_available():\n            torch.cuda.manual_seed(self.seed)\n        self.env.seed(self.seed)\n        if self.command_name == \"ref_motion\":\n            self.command_term.set_motion_cache_seed(\n                self.seed, reinitialize=False\n            )\n\n    def _setup_data_buffers(self) -> None:\n        self.tot_timesteps = 0\n        self.tot_time = 0\n        self.current_learning_iteration = 0\n\n        self.start_time = 0\n        self.stop_time = 0\n        self.collection_time = 0\n        self.learn_time = 0\n\n        self.ep_infos = []\n        self.rewbuffer = deque(maxlen=100)\n        self.lenbuffer = deque(maxlen=100)\n\n        self.cur_reward_sum = torch.zeros(\n            self.env.num_envs,\n            dtype=torch.float,\n            device=self.device,\n        )\n        self.cur_episode_length = torch.zeros(\n            self.env.num_envs,\n            dtype=torch.float,\n            device=self.device,\n        )\n\n        self.storage = None\n        self.transition_td = None\n        self._last_rollout_dones = None\n        self._last_rollout_actions = None\n\n    def _setup_algo_components(self) -> None:\n        \"\"\"Hook for algorithm-specific components (AMP, DAgger, PULSE).\"\"\"\n        return\n\n    def _setup_models_and_optimizer(self) -> None:\n        raise NotImplementedError(\n            \"Subclasses must implement _setup_models_and_optimizer.\"\n        )\n\n    def _build_storage(self, obs_td: TensorDict) -> Any:\n        \"\"\"Hook for custom RolloutStorage. 
Override for specialized storage; default no-op.\"\"\"\n        return None\n\n    def _post_env_step_hook(\n        self,\n        rewards: torch.Tensor,\n        dones: torch.Tensor,\n        time_outs: torch.Tensor,\n        infos: Dict[str, Any],\n    ) -> None:\n        \"\"\"Hook after each env step for auxiliary data collection.\"\"\"\n        if self.command_name != \"ref_motion\":\n            return\n        motion_term = self.env._env.command_manager.get_term(\"ref_motion\")\n        if motion_term is None:\n            return\n        motion_term.update_curriculum_reward_accumulators(rewards)\n\n    def _post_update_hook(self, loss_dict: Dict[str, Any]) -> None:\n        \"\"\"Hook after each PPO update for auxiliary losses or logging.\"\"\"\n        return\n\n    def _extra_checkpoint_state(self) -> Dict[str, Any]:\n        \"\"\"Additional state to save in checkpoints.\"\"\"\n        return {}\n\n    def _load_extra_checkpoint_state(\n        self, loaded_dict: Dict[str, Any]\n    ) -> None:\n        \"\"\"Load additional checkpoint state if present.\"\"\"\n        return\n\n    def _build_transition(\n        self,\n        obs_td: TensorDict,\n        actor_out: TensorDict,\n        critic_out: TensorDict,\n    ):\n        raise NotImplementedError(\n            \"Subclasses must implement _build_transition.\"\n        )\n\n    def _post_iteration_hook(self, it: int) -> None:\n        return\n\n    def _post_training_hook(self) -> None:\n        return\n\n    def _release_cuda_cache(self) -> None:\n        if torch.cuda.is_available() and self.device.type == \"cuda\":\n            torch.cuda.empty_cache()\n\n    def _get_additional_log_metrics(self) -> Dict[str, Any]:\n        return {}\n\n    def train_mode(self) -> None:\n        self.actor.train()\n        self.critic.train()\n\n    def _ensure_storage(self, obs_td: TensorDict) -> None:\n        if self.storage is not None:\n            return\n        self.storage = self._build_storage(obs_td)\n        if self.storage is None:\n            raise RuntimeError(\n                \"Storage is not initialized. 
Override _build_storage() or initialize self.storage in subclass.\"\n            )\n\n    def _reset_rollout_forward_state(self) -> None:\n        \"\"\"Hook for algorithm-specific rollout state reset.\"\"\"\n        return\n\n    def _rollout_forward(\n        self,\n        obs_td: TensorDict,\n        *,\n        actor_mode: str = \"sampling\",\n        collect_transition: bool = True,\n        track_episode_stats: bool = True,\n    ) -> TensorDict:\n        update_obs_norm = not self.is_offline_eval\n        with self.accelerator.autocast():\n            actor_out: TensorDict = self.actor(\n                obs_td,\n                actions=None,\n                mode=actor_mode,\n                update_obs_norm=update_obs_norm,\n            )\n            critic_out: TensorDict | None = None\n            if collect_transition:\n                critic_out = self.critic(\n                    obs_td, update_obs_norm=update_obs_norm\n                )\n\n        if collect_transition:\n            self.transition_td = self._build_transition(\n                obs_td,\n                actor_out,\n                critic_out,\n            )\n\n        actions = actor_out.get(\"actions\")\n        self._last_rollout_actions = actions\n        obs_dict, rewards, dones, time_outs, infos = self.env.step(actions)\n\n        next_obs_td = self._wrap_obs_dict(obs_dict)\n        dones = dones.to(self.device)\n        self._last_rollout_dones = dones\n\n        if collect_transition:\n            rewards = rewards.to(self.device)\n            time_outs = time_outs.to(self.device)\n            self.process_env_step(rewards, dones, time_outs, infos)\n\n        if track_episode_stats:\n            rewards_for_stats = rewards.to(self.device)\n            self._track_episode_stats(rewards_for_stats, dones, infos)\n        return next_obs_td\n\n    def _track_episode_stats(\n        self,\n        rewards: torch.Tensor,\n        dones: torch.Tensor,\n        infos: Dict[str, Any],\n    ) -> None:\n        log_info = infos.get(\"log\")\n        if self.is_main_process and isinstance(log_info, dict):\n            cpu_log_info: Dict[str, torch.Tensor] = {}\n            for key, value in log_info.items():\n                cpu_value = self._log_value_to_cpu_tensor(value)\n                if cpu_value is not None and cpu_value.numel() > 0:\n                    cpu_log_info[key] = cpu_value\n            if len(cpu_log_info) > 0:\n                self.ep_infos.append(cpu_log_info)\n        self.cur_reward_sum += rewards\n        self.cur_episode_length += 1\n\n        done_ids = (dones > 0).nonzero(as_tuple=False)\n        self.rewbuffer.extend(\n            self.cur_reward_sum[done_ids][:, 0].cpu().numpy().tolist()\n        )\n        self.lenbuffer.extend(\n            self.cur_episode_length[done_ids][:, 0].cpu().numpy().tolist()\n        )\n        self.cur_reward_sum[done_ids] = 0\n        self.cur_episode_length[done_ids] = 0\n\n    def _compute_returns(self, obs_td: TensorDict) -> None:\n        update_obs_norm = not self.is_offline_eval\n        with self.accelerator.autocast():\n            last_values = (\n                self.critic(obs_td, update_obs_norm=update_obs_norm)\n                .get(\"values\")\n                .detach()\n            )\n            self.storage.compute_returns(\n                last_values,\n                self.gamma,\n                self.lam,\n                normalize_advantage=False,\n            )\n\n        if getattr(self, \"global_advantage_norm\", False):\n            
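# Normalize advantages across processes (when distributed), grouped per command term.\n            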
accelerator = self.accelerator if self.is_distributed else None\n            self.storage.normalize_advantages_global_by_command(\n                command_name=self.command_name,\n                accelerator=accelerator,\n                eps=1.0e-8,\n            )\n\n    def rollout_policy(self, obs_td: TensorDict) -> TensorDict:\n        \"\"\"Collect one rollout with current policy and compute returns.\"\"\"\n        actor_was_training = self.actor.training\n        critic_was_training = self.critic.training\n        self.actor.eval()\n        self.critic.eval()\n        with torch.no_grad():\n            self._reset_rollout_forward_state()\n            for _ in range(self.num_steps_per_env):\n                obs_td = self._rollout_forward(obs_td)\n            self._compute_returns(obs_td)\n        if actor_was_training:\n            self.actor.train()\n        if critic_was_training:\n            self.critic.train()\n        return obs_td\n\n    def learn(self):\n        \"\"\"Main learning loop with runner logic shared across on-policy algorithms.\"\"\"\n        obs_dict = self.env.reset_all()[0]\n        obs_td = self._wrap_obs_dict(obs_dict)\n        self._ensure_storage(obs_td)\n        self.train_mode()\n\n        start_it = self.current_learning_iteration\n        total_it = start_it + int(self.num_learning_iterations)\n        self.total_learning_iterations = total_it\n\n        self.accelerator.wait_for_everyone()\n        if self.is_main_process:\n            logger.info(\n                f\"Starting training for {self.num_learning_iterations} iterations \"\n                f\"from iteration {self.current_learning_iteration}\"\n            )\n\n        for it in range(start_it, total_it):\n            self.current_learning_iteration = it\n            start = time.time()\n            obs_td = self.rollout_policy(obs_td)\n\n            stop = time.time()\n            collection_time = stop - start\n            start = stop\n\n            loss_dict = self.update()\n\n            stop = time.time()\n            learn_time = stop - start\n\n            if self.is_main_process and it % self.log_interval == 0:\n                self._log_iteration(\n                    it=it,\n                    loss_dict=loss_dict,\n                    collection_time=collection_time,\n                    learn_time=learn_time,\n                )\n\n            if self.is_main_process and it % self.save_interval == 0:\n                self.save(\n                    os.path.join(\n                        self.log_dir,\n                        f\"model_{self.current_learning_iteration}.pt\",\n                    )\n                )\n                self._release_cuda_cache()\n\n            self._post_iteration_hook(it)\n            self.ep_infos.clear()\n            self.accelerator.wait_for_everyone()\n\n        final_checkpoint_path = os.path.join(\n            self.log_dir, f\"model_{self.current_learning_iteration}.pt\"\n        )\n        if self.is_main_process:\n            self.save(final_checkpoint_path)\n            self._release_cuda_cache()\n\n        self._post_training_hook()\n\n        if self.log_dir:\n            self.accelerator.wait_for_everyone()\n            self.accelerator.end_training()\n            if self.is_main_process:\n                logger.info(\n                    f\"Training completed. 
Model saved to {self.log_dir}\"\n                )\n\n    def process_env_step(\n        self,\n        rewards: torch.Tensor,\n        dones: torch.Tensor,\n        time_outs: torch.Tensor,\n        infos: Dict[str, Any],\n    ) -> None:\n        \"\"\"Process env step results and append to storage.\n\n        Args:\n            rewards: [N, 1] rewards (env step output).\n            dones: [N, 1] done flags (env step output).\n            time_outs: [N] time out flags (env step output).\n            infos: Environment info dictionary.\n        \"\"\"\n        raw_rewards = rewards.clone().view(-1, 1)\n        rewards = raw_rewards.clone()\n        dones = dones.view(-1, 1)\n\n        # Bootstrapping on time outs\n        rewards += self.gamma * (\n            self.transition_td.values * time_outs[:, None]\n        )\n        self.transition_td.rewards = rewards\n        self.transition_td.dones = dones.to(dtype=torch.bool)\n\n        self.storage.add(self.transition_td)\n        self._post_env_step_hook(raw_rewards, dones, time_outs, infos)\n\n        self.transition_td = None\n\n    def _wrap_obs_dict(self, obs_dict: dict) -> TensorDict:\n        \"\"\"Wrap env obs dict into a native nested TensorDict on device.\"\"\"\n        return TensorDict.from_dict(\n            obs_dict,\n            batch_size=[self.env.num_envs],\n            device=self.device,\n        )\n\n    @staticmethod\n    def _clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:\n        \"\"\"Remove '_orig_mod.' prefix from torch.compile wrapped models.\n\n        Args:\n            state_dict: State dict that may contain '_orig_mod.' prefixed keys\n\n        Returns:\n            Cleaned state dict with prefixes removed\n        \"\"\"\n        cleaned_dict = {}\n        prefix = \"_orig_mod.\"\n        prefix_len = len(prefix)\n        for k, v in state_dict.items():\n            new_k = k[prefix_len:] if k.startswith(prefix) else k\n            cleaned_dict[new_k] = v\n        return cleaned_dict\n\n    def _load_model_state(self, model, state_dict, *, strict: bool = True):\n        \"\"\"Load a state dict into a (possibly compiled) model safely.\n\n        - Always unwrap Accelerate wrappers first.\n        - If the model is a compiled OptimizedModule (has ``_orig_mod``),\n          load into the original module and strip any ``_orig_mod.`` prefixes\n          from the incoming state dict for robustness.\n        \"\"\"\n        target = self.accelerator.unwrap_model(model)\n        cleaned = self._clean_state_dict(state_dict)\n        if hasattr(target, \"_orig_mod\"):\n            target._orig_mod.load_state_dict(cleaned, strict=strict)\n        else:\n            target.load_state_dict(cleaned, strict=strict)\n\n    def _resolve_model_file_path(self, ckpt_path: str, model_name: str) -> str:\n        \"\"\"Resolve per-model Accelerate checkpoint directory from *.pt path.\"\"\"\n        base_path = ckpt_path.replace(\".pt\", \"\")\n        model_path = os.path.join(base_path, model_name)\n        if not os.path.isdir(model_path):\n            raise FileNotFoundError(\n                f\"Missing accelerate checkpoint directory for {model_name}: \"\n                f\"{model_path}\"\n            )\n        return model_path\n\n    def _load_accelerate_model(\n        self, model, model_path: str, *, strict: bool = True\n    ) -> None:\n        \"\"\"Load model params from Accelerate checkpoint directory/file.\"\"\"\n        checkpoint_path = model_path\n        if os.path.isdir(model_path):\n            
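# Prefer single-file weights: safetensors first, then pytorch_model.bin;\n            # anything else falls through to Accelerate's sharded loader below.\n            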
safetensors_path = os.path.join(model_path, \"model.safetensors\")\n            pytorch_bin_path = os.path.join(model_path, \"pytorch_model.bin\")\n            if os.path.isfile(safetensors_path):\n                checkpoint_path = safetensors_path\n            elif os.path.isfile(pytorch_bin_path):\n                checkpoint_path = pytorch_bin_path\n            else:\n                target = self.accelerator.unwrap_model(model)\n                load_checkpoint_in_model(target, model_path, strict=strict)\n                return\n        state_dict = load_state_dict(checkpoint_path)\n        self._load_model_state(model, state_dict, strict=strict)\n\n    def _aggregate_episode_log_metrics(\n        self,\n    ) -> Dict[str, float]:\n        metrics: Dict[str, float] = {}\n        if len(self.ep_infos) == 0:\n            return metrics\n\n        metric_sums: Dict[str, float] = {}\n        metric_counts: Dict[str, int] = {}\n        for ep_info in self.ep_infos:\n            for key, value in ep_info.items():\n                cpu_value = self._log_value_to_cpu_tensor(value)\n                if cpu_value is None or cpu_value.numel() == 0:\n                    continue\n                metric_sums[key] = metric_sums.get(key, 0.0) + float(\n                    cpu_value.sum().item()\n                )\n                metric_counts[key] = metric_counts.get(key, 0) + int(\n                    cpu_value.numel()\n                )\n\n        for key, total in metric_sums.items():\n            count = metric_counts.get(key, 0)\n            if count <= 0:\n                continue\n            mean_value = total / float(count)\n            metric_key = key if \"/\" in key else f\"Episode/{key}\"\n            metrics[metric_key] = mean_value\n\n        return metrics\n\n    @staticmethod\n    def _log_value_to_cpu_tensor(value: Any) -> torch.Tensor | None:\n        if isinstance(value, torch.Tensor):\n            tensor = value.detach()\n            if tensor.ndim == 0:\n                tensor = tensor.unsqueeze(0)\n            return tensor.to(device=\"cpu\", dtype=torch.float32).reshape(-1)\n        if isinstance(value, np.ndarray):\n            return torch.as_tensor(value, dtype=torch.float32).reshape(-1)\n        if isinstance(value, (int, float)):\n            return torch.tensor([float(value)], dtype=torch.float32)\n        return None\n\n    def _log_iteration(\n        self,\n        *,\n        it: int,\n        loss_dict: Dict[str, Any],\n        collection_time: float,\n        learn_time: float,\n        synced_mean_reward: float | None = None,\n        synced_mean_episode_length: float | None = None,\n    ) -> None:\n        if not self.log_dir:\n            return\n\n        world_size = max(1, int(self.gpu_world_size))\n        fps = int(\n            self.num_steps_per_env\n            * self.num_envs\n            * world_size\n            / max(collection_time + learn_time, 1.0e-8)\n        )\n        total_learning_iterations = int(\n            getattr(\n                self,\n                \"total_learning_iterations\",\n                self.current_learning_iteration\n                + int(self.num_learning_iterations),\n            )\n        )\n\n        iteration_metrics: Dict[str, Any] = {\n            \"0-Train/iteration\": int(it),\n            \"0-Train/iterations_total\": total_learning_iterations,\n        }\n\n        for key, value in loss_dict.items():\n            if value is None:\n                continue\n            scalar = float(value)\n            
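# Group loss terms under a common \"Loss/\" prefix for the tracker UI.\n            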
iteration_metrics[f\"Loss/{key}\"] = scalar\n\n        iteration_metrics.update(\n            {\n                \"1-Perf/total_fps\": float(fps),\n                \"1-Perf/collection_time\": float(collection_time),\n                \"1-Perf/learning_time\": float(learn_time),\n            }\n        )\n\n        if (\n            synced_mean_reward is not None\n            and synced_mean_episode_length is not None\n        ):\n            iteration_metrics[\"0-Train/mean_reward\"] = float(\n                synced_mean_reward\n            )\n            iteration_metrics[\"0-Train/mean_episode_length\"] = float(\n                synced_mean_episode_length\n            )\n        elif len(self.rewbuffer) > 0:\n            mean_reward = float(statistics.mean(self.rewbuffer))\n            mean_episode_length = float(statistics.mean(self.lenbuffer))\n            iteration_metrics[\"0-Train/mean_reward\"] = mean_reward\n            iteration_metrics[\"0-Train/mean_episode_length\"] = (\n                mean_episode_length\n            )\n\n        iteration_metrics.update(self._aggregate_episode_log_metrics())\n        iteration_metrics.update(self._get_additional_log_metrics())\n\n        self.algo_logger.log_iteration(\n            step=it,\n            total_learning_iterations=total_learning_iterations,\n            metrics=iteration_metrics,\n        )\n\n    def load(self, ckpt_path):\n        raise NotImplementedError(\"Subclasses must implement load().\")\n\n    def save(self, path, infos=None):\n        raise NotImplementedError(\"Subclasses must implement save().\")\n"
  },
  {
    "path": "holomotion/src/algo/algo_utils.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport os\nfrom collections.abc import Mapping\nfrom typing import Any, Generator\n\nimport torch\nimport torch.nn as nn\nfrom loguru import logger\nfrom tabulate import tabulate\nfrom tensordict import TensorDict, tensorclass\n\n\nclass AlgoLogger:\n    def __init__(\n        self,\n        accelerator,\n        log_dir: str | None,\n        *,\n        is_main_process: bool,\n    ) -> None:\n        self.accelerator = accelerator\n        self.log_dir = log_dir\n        self.is_main_process = bool(is_main_process)\n\n    @staticmethod\n    def _is_scalar_metric(value: Any) -> bool:\n        if isinstance(value, (int, float)):\n            return True\n        if isinstance(value, torch.Tensor):\n            return value.numel() == 1\n        return False\n\n    @staticmethod\n    def _to_scalar(value: Any) -> float:\n        if isinstance(value, torch.Tensor):\n            return float(value.item())\n        return float(value)\n\n    @staticmethod\n    def _format_console_value(value: Any) -> str:\n        if isinstance(value, (int, float)):\n            value_f = float(value)\n            abs_value = abs(value_f)\n            if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4):\n                return f\"{value_f:.4e}\"\n            return f\"{value_f:.4f}\"\n        if isinstance(value, torch.Tensor) and value.numel() == 1:\n            value_f = float(value.item())\n            abs_value = abs(value_f)\n            if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4):\n                return f\"{value_f:.4e}\"\n            return f\"{value_f:.4f}\"\n        return str(value)\n\n    def _build_console_log(\n        self,\n        *,\n        step: int,\n        total_learning_iterations: int | None,\n        console_metrics: Mapping[str, Any],\n    ) -> str:\n        if total_learning_iterations is None:\n            title = f\"TRAINING LOG - Iteration {step}\"\n        else:\n            title = (\n                f\"TRAINING LOG - Iteration {step}/{total_learning_iterations}\"\n            )\n        table_data = [\n            [key, str(console_metrics[key])]\n            for key in sorted(console_metrics.keys())\n        ]\n        log_lines = [\n            \"\\n\" + \"=\" * 80,\n            title,\n            \"=\" * 80,\n            tabulate(\n                table_data,\n                headers=[\"Metric\", \"Value\"],\n                tablefmt=\"simple_outline\",\n                colalign=(\"left\", \"left\"),\n                disable_numparse=True,\n            ),\n            \"=\" * 80,\n            f\"Logging Directory: {os.path.abspath(self.log_dir)}\",\n            \"=\" * 80 + \"\\n\",\n        ]\n        return \"\\n\".join(log_lines)\n\n    def log_iteration(\n        self,\n        *,\n        step: int,\n        metrics: Mapping[str, Any],\n        total_learning_iterations: 
int | None = None,\n    ) -> None:\n        if not self.log_dir or not self.is_main_process:\n            return\n\n        tensorboard_metrics: dict[str, float] = {}\n        for key in sorted(metrics.keys()):\n            value = metrics[key]\n            if value is None or not self._is_scalar_metric(value):\n                continue\n            tensorboard_metrics[key] = self._to_scalar(value)\n\n        if len(tensorboard_metrics) > 0:\n            self.accelerator.log(tensorboard_metrics, step=int(step))\n\n        console_metrics = {\n            key: self._format_console_value(value)\n            for key, value in metrics.items()\n            if value is not None\n        }\n        console_log = self._build_console_log(\n            step=step,\n            total_learning_iterations=total_learning_iterations,\n            console_metrics=console_metrics,\n        )\n        logger.info(console_log)\n\n\n@tensorclass(shadow=True)\nclass PpoTransition:\n    \"\"\"PPO rollout transition tensorclass.\n\n    Batch axes:\n    - N: num_envs (per-step)\n    - B: minibatch_size (for minibatches)\n\n    Shapes (batch dims = [N] or [B]):\n    - obs: TensorDict with leaf tensors [*, ...]\n    - actions, teacher_actions, mu, sigma: [*, A]\n    - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1]\n\n    All float tensors are float32. `dones` is bool.\n    \"\"\"\n\n    FIELD_SPECS = {\n        \"obs\": {\"kind\": \"obs\"},\n        \"actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"teacher_actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"mu\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"sigma\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"actions_log_prob\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"values\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"rewards\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"dones\": {\"shape\": (1,), \"dtype\": torch.bool},\n        \"returns\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"advantages\": {\"shape\": (1,), \"dtype\": torch.float32},\n    }\n\n    obs: TensorDict\n    actions: torch.Tensor\n    teacher_actions: torch.Tensor\n    mu: torch.Tensor\n    sigma: torch.Tensor\n    actions_log_prob: torch.Tensor\n    values: torch.Tensor\n    rewards: torch.Tensor\n    dones: torch.Tensor\n    returns: torch.Tensor\n    advantages: torch.Tensor\n\n\n@tensorclass(shadow=True)\nclass PpoVelocityTransition:\n    \"\"\"PPO rollout transition tensorclass.\n\n    Batch axes:\n    - N: num_envs (per-step)\n    - B: minibatch_size (for minibatches)\n\n    Shapes (batch dims = [N] or [B]):\n    - obs: TensorDict with leaf tensors [*, ...]\n    - actions, teacher_actions, mu, sigma: [*, A]\n    - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1]\n    - velocity_commands: [*, 4]\n\n    All float tensors are float32. 
`dones` is bool.\n    \"\"\"\n\n    FIELD_SPECS = {\n        \"obs\": {\"kind\": \"obs\"},\n        \"actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"teacher_actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"mu\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"sigma\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"actions_log_prob\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"values\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"rewards\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"dones\": {\"shape\": (1,), \"dtype\": torch.bool},\n        \"returns\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"advantages\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"velocity_commands\": {\"shape\": (4,), \"dtype\": torch.float32},\n    }\n\n    obs: TensorDict\n    actions: torch.Tensor\n    teacher_actions: torch.Tensor\n    mu: torch.Tensor\n    sigma: torch.Tensor\n    actions_log_prob: torch.Tensor\n    values: torch.Tensor\n    rewards: torch.Tensor\n    dones: torch.Tensor\n    returns: torch.Tensor\n    advantages: torch.Tensor\n    velocity_commands: torch.Tensor\n\n\n@tensorclass(shadow=True)\nclass PpoAuxTransition:\n    \"\"\"PPO transition with auxiliary state-prediction supervision targets.\"\"\"\n\n    SHAPE_TOKENS = {\"C\": 0, \"K\": 0}\n\n    FIELD_SPECS = {\n        \"obs\": {\"kind\": \"obs\"},\n        \"actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"teacher_actions\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"mu\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"sigma\": {\"shape\": (\"A\",), \"dtype\": torch.float32},\n        \"actions_log_prob\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"values\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"rewards\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"dones\": {\"shape\": (1,), \"dtype\": torch.bool},\n        \"returns\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"advantages\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"gt_base_lin_vel_b\": {\"shape\": (3,), \"dtype\": torch.float32},\n        \"gt_root_height_rel_terrain\": {\"shape\": (1,), \"dtype\": torch.float32},\n        \"gt_keybody_contacts\": {\"shape\": (\"C\",), \"dtype\": torch.float32},\n        \"gt_ref_keybody_rel_pos\": {\n            \"shape\": (\"K\", 3),\n            \"dtype\": torch.float32,\n        },\n        \"gt_robot_keybody_rel_pos\": {\n            \"shape\": (\"K\", 3),\n            \"dtype\": torch.float32,\n        },\n        \"gt_denoise_ref_root_lin_vel\": {\n            \"shape\": (3,),\n            \"dtype\": torch.float32,\n        },\n        \"gt_denoise_ref_root_ang_vel\": {\n            \"shape\": (3,),\n            \"dtype\": torch.float32,\n        },\n        \"gt_denoise_ref_dof_pos\": {\n            \"shape\": (\"A\",),\n            \"dtype\": torch.float32,\n        },\n    }\n\n    obs: TensorDict\n    actions: torch.Tensor\n    teacher_actions: torch.Tensor\n    mu: torch.Tensor\n    sigma: torch.Tensor\n    actions_log_prob: torch.Tensor\n    values: torch.Tensor\n    rewards: torch.Tensor\n    dones: torch.Tensor\n    returns: torch.Tensor\n    advantages: torch.Tensor\n    gt_base_lin_vel_b: torch.Tensor\n    gt_root_height_rel_terrain: torch.Tensor\n    gt_keybody_contacts: torch.Tensor\n    gt_ref_keybody_rel_pos: torch.Tensor\n    gt_robot_keybody_rel_pos: torch.Tensor\n    
gt_denoise_ref_root_lin_vel: torch.Tensor\n    gt_denoise_ref_root_ang_vel: torch.Tensor\n    gt_denoise_ref_dof_pos: torch.Tensor\n\n\nclass RolloutStorage(nn.Module):\n    \"\"\"Rollout storage as a single TensorDict buffer with batch size [T, N].\"\"\"\n\n    def __init__(\n        self,\n        num_envs,\n        num_transitions_per_env,\n        obs_template: TensorDict,\n        actions_shape,\n        device=\"cpu\",\n        command_name: str | None = None,\n        transition_cls: type[PpoTransition] = PpoTransition,\n    ):\n        super().__init__()\n        self.device = device\n        self.num_transitions_per_env = num_transitions_per_env\n        self.num_envs = num_envs\n        self.command_name = command_name\n        self._float_dtype = torch.float32\n        self._dones_dtype = torch.bool\n        self._transition_cls = transition_cls\n\n        obs_template = obs_template.to(self.device)\n        self.data = TensorDict(\n            {},\n            batch_size=[num_transitions_per_env, num_envs],\n            device=self.device,\n        )\n        self._allocate_from_transition(\n            obs_template=obs_template,\n            actions_shape=actions_shape,\n        )\n\n        self.step = 0\n\n    def _resolve_shape(self, spec_shape, actions_shape) -> tuple:\n        if spec_shape is None:\n            return tuple()\n        resolved = []\n        shape_tokens = getattr(self._transition_cls, \"SHAPE_TOKENS\", {})\n        for dim in spec_shape:\n            if dim == \"A\":\n                resolved.extend(tuple(actions_shape))\n            elif isinstance(dim, str) and dim in shape_tokens:\n                resolved.append(int(shape_tokens[dim]))\n            else:\n                resolved.append(int(dim))\n        return tuple(resolved)\n\n    def _allocate_from_transition(\n        self,\n        *,\n        obs_template: TensorDict,\n        actions_shape,\n    ) -> None:\n        specs = getattr(self._transition_cls, \"FIELD_SPECS\", None)\n        if not isinstance(specs, dict):\n            raise ValueError(\n                \"Transition class must define FIELD_SPECS for allocation.\"\n            )\n        for name, spec in specs.items():\n            if spec.get(\"kind\") == \"obs\":\n                leaf_keys = obs_template.keys(\n                    include_nested=True, leaves_only=True\n                )\n                for key in leaf_keys:\n                    value = obs_template.get(key)\n                    if not torch.is_tensor(value):\n                        continue\n                    dtype = (\n                        self._float_dtype\n                        if torch.is_floating_point(value)\n                        else value.dtype\n                    )\n                    key_tuple = key if isinstance(key, tuple) else (key,)\n                    self.data.set(\n                        (\"obs\",) + key_tuple,\n                        torch.empty(\n                            (self.num_transitions_per_env, self.num_envs)\n                            + tuple(value.shape[1:]),\n                            device=self.device,\n                            dtype=dtype,\n                        ),\n                    )\n                continue\n            shape_spec = spec.get(\"shape\")\n            dtype = spec.get(\"dtype\", self._float_dtype)\n            resolved = self._resolve_shape(shape_spec, actions_shape)\n            self.data.set(\n                name,\n                torch.empty(\n                    
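# Buffer layout: [T, N, *field_shape] per transition field.\n                    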
(self.num_transitions_per_env, self.num_envs) + resolved,\n                    device=self.device,\n                    dtype=dtype,\n                ),\n            )\n\n    def _to_storage_tensor(self, tensor: torch.Tensor) -> torch.Tensor:\n        if not torch.is_tensor(tensor):\n            raise TypeError(\"Expected a tensor for RolloutStorage update.\")\n        if tensor.device != self.device:\n            tensor = tensor.to(self.device)\n        if (\n            torch.is_floating_point(tensor)\n            and tensor.dtype != self._float_dtype\n        ):\n            tensor = tensor.to(dtype=self._float_dtype)\n        return tensor\n\n    def add(self, transition: PpoTransition) -> None:\n        if self.step >= self.num_transitions_per_env:\n            raise OverflowError(\"Rollout buffer overflow!\")\n        if not isinstance(transition, self._transition_cls):\n            raise TypeError(\n                \"Transition must match the RolloutStorage transition class.\"\n            )\n        if transition.batch_size is None or len(transition.batch_size) < 1:\n            raise ValueError(\"Transition must have batch size [N].\")\n        if int(transition.batch_size[0]) != int(self.num_envs):\n            raise ValueError(\n                f\"Transition batch size {transition.batch_size} \"\n                f\"does not match num_envs={self.num_envs}.\"\n            )\n\n        td = transition.to_tensordict()\n        td = td.apply(self._to_storage_tensor, inplace=False)\n        if \"dones\" in td.keys():\n            dones = td.get(\"dones\")\n            if torch.is_tensor(dones) and dones.dtype != self._dones_dtype:\n                td.set(\"dones\", dones.to(dtype=self._dones_dtype))\n        self.data[self.step].update_(td)\n\n        self.step += 1\n\n    def clear(self) -> None:\n        self.step = 0\n\n    def compute_returns(\n        self,\n        last_values,\n        gamma,\n        lam,\n        normalize_advantage: bool = False,\n    ) -> None:\n        advantage = 0\n        rewards = self.data[\"rewards\"]\n        values = self.data[\"values\"]\n        dones = self.data[\"dones\"]\n        returns = self.data[\"returns\"]\n        advantages = self.data[\"advantages\"]\n        for step in reversed(range(self.num_transitions_per_env)):\n            if step == self.num_transitions_per_env - 1:\n                next_values = last_values\n            else:\n                next_values = values[step + 1]\n            next_is_not_terminal = 1.0 - dones[step].float()\n            delta = (\n                rewards[step]\n                + next_is_not_terminal * gamma * next_values\n                - values[step]\n            )\n            advantage = delta + next_is_not_terminal * gamma * lam * advantage\n            returns[step] = advantage + values[step]\n\n        advantages.copy_(returns - values)\n        if normalize_advantage:\n            flat = advantages.view(-1)\n            mean = flat.mean()\n            std = flat.std().clamp_min(1.0e-8)\n            advantages.copy_((advantages - mean) / std)\n\n    @torch.no_grad()\n    def normalize_advantages_global(\n        self,\n        *,\n        accelerator=None,\n        eps: float = 1.0e-8,\n    ) -> None:\n        \"\"\"Global advantage normalization over the full rollout buffer.\n\n        This normalizes `self.data[\"advantages\"]` in-place using mean/std over\n        all `[T * N]` samples. 
If `accelerator` is provided, the moments are\n        aggregated across processes via `accelerator.reduce(..., reduction=\"sum\")`.\n        \"\"\"\n        advantages = self.data[\"advantages\"]\n        advantages_flat = advantages.view(-1).float()\n\n        count = torch.tensor(\n            [advantages_flat.numel()], device=self.device, dtype=torch.float32\n        )\n        sum_local = advantages_flat.sum()\n        sqsum_local = (advantages_flat * advantages_flat).sum()\n        if accelerator is not None and int(accelerator.num_processes) > 1:\n            count_g = accelerator.reduce(count, reduction=\"sum\")\n            sum_g = accelerator.reduce(sum_local, reduction=\"sum\")\n            sqsum_g = accelerator.reduce(sqsum_local, reduction=\"sum\")\n        else:\n            count_g = count\n            sum_g = sum_local\n            sqsum_g = sqsum_local\n\n        mean = sum_g / count_g\n        var = (sqsum_g / count_g) - mean * mean\n        std = torch.sqrt(var.clamp_min(eps))\n        advantages.copy_((advantages - mean) / std)\n\n    @torch.no_grad()\n    def normalize_advantages_global_by_move_mask(\n        self,\n        *,\n        accelerator=None,\n        eps: float = 1.0e-8,\n        move_threshold: float = 0.5,\n    ) -> None:\n        \"\"\"Global advantage normalization split by move vs static (velocity commands).\n\n        Assumes:\n        - `advantages`: [T, N, 1]\n        - `velocity_commands`: [T, N, 4], where channel 0 is move_mask in {0,1}.\n        \"\"\"\n        velocity_commands = self.data.get(\"velocity_commands\", None)\n        if velocity_commands is None:\n            raise ValueError(\n                \"velocity_commands is required for global advantage normalization by move mask.\"\n            )\n\n        advantages = self.data[\"advantages\"]\n        advantages_flat = advantages.view(-1).float()\n\n        vel_flat = velocity_commands.view(-1, int(velocity_commands.shape[-1]))\n        move_mask = vel_flat[:, 0] > float(move_threshold)\n        static_mask = ~move_mask\n\n        count_all = torch.tensor(\n            [advantages_flat.numel()], device=self.device, dtype=torch.float32\n        )\n        sum_all_local = advantages_flat.sum()\n        sqsum_all_local = (advantages_flat * advantages_flat).sum()\n        if accelerator is not None and int(accelerator.num_processes) > 1:\n            count_all_g = accelerator.reduce(count_all, reduction=\"sum\")\n            sum_all_g = accelerator.reduce(sum_all_local, reduction=\"sum\")\n            sqsum_all_g = accelerator.reduce(sqsum_all_local, reduction=\"sum\")\n        else:\n            count_all_g = count_all\n            sum_all_g = sum_all_local\n            sqsum_all_g = sqsum_all_local\n\n        mean_all = sum_all_g / count_all_g\n        var_all = (sqsum_all_g / count_all_g) - mean_all * mean_all\n        std_all = torch.sqrt(var_all.clamp_min(eps))\n\n        def _group_stats(\n            mask: torch.Tensor,\n        ) -> tuple[torch.Tensor, torch.Tensor]:\n            if not bool(mask.any().item()):\n                return mean_all, std_all\n            mask_f = mask.to(dtype=torch.float32)\n            count_local = mask_f.sum()\n            sum_local = (advantages_flat * mask_f).sum()\n            sqsum_local = (advantages_flat * advantages_flat * mask_f).sum()\n            if accelerator is not None and int(accelerator.num_processes) > 1:\n                count_g = accelerator.reduce(count_local, reduction=\"sum\")\n                sum_g = 
accelerator.reduce(sum_local, reduction=\"sum\")\n                sqsum_g = accelerator.reduce(sqsum_local, reduction=\"sum\")\n            else:\n                count_g = count_local\n                sum_g = sum_local\n                sqsum_g = sqsum_local\n            if float(count_g.item()) <= 0.0:\n                return mean_all, std_all\n            mean = sum_g / count_g\n            var = (sqsum_g / count_g) - mean * mean\n            std = torch.sqrt(var.clamp_min(eps))\n            return mean, std\n\n        move_mean, move_std = _group_stats(move_mask)\n        static_mean, static_std = _group_stats(static_mask)\n\n        advantages_norm = advantages_flat.clone()\n        if bool(move_mask.any().item()):\n            advantages_norm[move_mask] = (\n                advantages_flat[move_mask] - move_mean\n            ) / move_std\n        if bool(static_mask.any().item()):\n            advantages_norm[static_mask] = (\n                advantages_flat[static_mask] - static_mean\n            ) / static_std\n\n        self.data[\"advantages\"].copy_(advantages_norm.view_as(advantages))\n\n    @torch.no_grad()\n    def normalize_advantages_global_by_command(\n        self,\n        *,\n        command_name: str | None,\n        accelerator=None,\n        eps: float = 1.0e-8,\n    ) -> None:\n        \"\"\"Dispatch global advantage normalization based on command type/name.\"\"\"\n        if (\n            command_name == \"base_velocity\"\n            and self.data.get(\"velocity_commands\", None) is not None\n        ):\n            self.normalize_advantages_global_by_move_mask(\n                accelerator=accelerator, eps=eps\n            )\n            return\n        self.normalize_advantages_global(accelerator=accelerator, eps=eps)\n\n    def iter_minibatches(\n        self,\n        num_mini_batches: int,\n        num_epochs: int,\n    ) -> Generator[PpoTransition, None, None]:\n        if self.step != self.num_transitions_per_env:\n            raise RuntimeError(\n                f\"RolloutStorage buffer not full: step={self.step}, \"\n                f\"expected={self.num_transitions_per_env}. 
\"\n                \"This would mix stale entries from a previous rollout.\"\n            )\n        batch_size = self.num_envs * self.num_transitions_per_env\n        (\n            effective_num_mini_batches,\n            mini_batch_size,\n        ) = self.resolve_mini_batch_partition(batch_size, num_mini_batches)\n\n        indices = torch.randperm(\n            batch_size,\n            requires_grad=False,\n            device=self.device,\n        )[: effective_num_mini_batches * mini_batch_size]\n\n        flat = self.data.flatten(0, 1)  # [T * N, ...]\n\n        for _ in range(num_epochs):\n            for i in range(effective_num_mini_batches):\n                start = i * mini_batch_size\n                end = (i + 1) * mini_batch_size\n                batch_idx = indices[start:end]\n                batch = flat[batch_idx]\n                yield self._transition_cls.from_tensordict(batch)\n\n    @staticmethod\n    def resolve_mini_batch_partition(\n        batch_size: int,\n        num_mini_batches: int,\n    ) -> tuple[int, int]:\n        if batch_size <= 0:\n            raise RuntimeError(\n                \"RolloutStorage minibatch partition requires batch_size > 0.\"\n            )\n        effective_num_mini_batches = max(\n            1, min(int(num_mini_batches), int(batch_size))\n        )\n        mini_batch_size = max(1, batch_size // effective_num_mini_batches)\n        return effective_num_mini_batches, mini_batch_size\n"
  },
  {
    "path": "holomotion/src/algo/ppo.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport inspect\nimport json\nimport math\nimport os\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport numpy as np\nfrom loguru import logger\nfrom tabulate import tabulate\nfrom tqdm import tqdm\nimport imageio\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.algo.algo_base import BaseOnpolicyRL\nfrom holomotion.src.algo.algo_utils import (\n    PpoTransition,\n    PpoVelocityTransition,\n    RolloutStorage,\n)\nfrom holomotion.src.utils.onnx_export import (\n    export_policy_to_onnx as export_policy_to_onnx_common,\n)\nfrom tensordict import TensorDict\n\n\ndef _checkpoint_state_to_cpu(value):\n    if isinstance(value, torch.Tensor):\n        return value.detach().cpu()\n    if isinstance(value, dict):\n        return {k: _checkpoint_state_to_cpu(v) for k, v in value.items()}\n    if isinstance(value, list):\n        return [_checkpoint_state_to_cpu(v) for v in value]\n    if isinstance(value, tuple):\n        return tuple(_checkpoint_state_to_cpu(v) for v in value)\n    return value\n\n\nclass PPO(BaseOnpolicyRL):\n    def _setup_configs(self):\n        super()._setup_configs()\n        self.desired_kl = self.config.desired_kl\n        self.schedule = self.config.schedule\n        self.actor_learning_rate = self.config.get(\n            \"actor_learning_rate\", self.config.get(\"learning_rate\", 3e-4)\n        )\n        self.critic_learning_rate = self.config.get(\n            \"critic_learning_rate\", self.config.get(\"learning_rate\", 3e-4)\n        )\n        self.base_actor_learning_rate = float(self.actor_learning_rate)\n        self.base_critic_learning_rate = float(self.critic_learning_rate)\n        self.actor_beta1 = self.config.get(\"actor_beta1\", 0.9)\n        self.actor_beta2 = self.config.get(\"actor_beta2\", 0.999)\n        self.critic_beta1 = self.config.get(\"critic_beta1\", 0.9)\n        self.critic_beta2 = self.config.get(\"critic_beta2\", 0.999)\n        self.optimizer_type = self.config.optimizer_type\n        self.clip_param = self.config.clip_param\n        self.num_learning_epochs = int(self.config.num_learning_epochs)\n        self.configured_num_mini_batches = int(self.config.num_mini_batches)\n        if self.configured_num_mini_batches < 1:\n            raise ValueError(\"num_mini_batches must be >= 1.\")\n        distributed_update_cfg = self.config.get(\"distributed_update\", {})\n        self.distributed_update_mode = str(\n            distributed_update_cfg.get(\"mode\", \"legacy\")\n        ).lower()\n        if self.distributed_update_mode not in {\"legacy\", \"scalable\"}:\n            raise ValueError(\n                \"distributed_update.mode must be one of \"\n                \"{'legacy', 'scalable'}.\"\n            )\n        self.requested_num_mini_batches = 
self._resolve_num_mini_batches(\n            self.configured_num_mini_batches\n        )\n        self.num_mini_batches = self.requested_num_mini_batches\n        self.gamma = self.config.gamma\n        self.lam = self.config.lam\n        self.value_loss_coef = self.config.value_loss_coef\n        self.initial_entropy_coef = float(self.config.entropy_coef)\n        self.anneal_entropy = bool(self.config.get(\"anneal_entropy\", False))\n        self.zero_entropy_point = float(\n            self.config.get(\"zero_entropy_point\", 1.0)\n        )\n        self._validate_entropy_schedule_config(\n            initial_entropy_coef=self.initial_entropy_coef,\n            anneal_entropy=self.anneal_entropy,\n            zero_entropy_point=self.zero_entropy_point,\n        )\n        self.entropy_coef = self.initial_entropy_coef\n        self.max_grad_norm = self.config.max_grad_norm\n        self.use_clipped_value_loss = self.config.use_clipped_value_loss\n        adaptive_lr_cfg = self.config.get(\"adaptive_lr\", {})\n        self.adaptive_lr_adapt_critic = bool(\n            adaptive_lr_cfg.get(\"adapt_critic\", False)\n        )\n        self.adaptive_lr_factor = float(adaptive_lr_cfg.get(\"lr_scaler\", 1.2))\n        self.adaptive_lr_kl_high_factor = float(\n            adaptive_lr_cfg.get(\"kl_high_factor\", 2.0)\n        )\n        self.adaptive_lr_kl_low_factor = float(\n            adaptive_lr_cfg.get(\"kl_low_factor\", 0.5)\n        )\n        self.adaptive_lr_min = float(\n            adaptive_lr_cfg.get(\"min_learning_rate\", 1.0e-6)\n        )\n        self.adaptive_lr_max = float(\n            adaptive_lr_cfg.get(\"max_learning_rate\", 1.0)\n        )\n        if self.adaptive_lr_factor <= 1.0:\n            raise ValueError(\"adaptive_lr.lr_scaler must be > 1.\")\n        if self.adaptive_lr_kl_high_factor <= 0.0:\n            raise ValueError(\"adaptive_lr.kl_high_factor must be > 0.\")\n        if self.adaptive_lr_kl_low_factor <= 0.0:\n            raise ValueError(\"adaptive_lr.kl_low_factor must be > 0.\")\n        if self.adaptive_lr_min <= 0.0:\n            raise ValueError(\"adaptive_lr.min_learning_rate must be > 0.\")\n        if self.adaptive_lr_max < self.adaptive_lr_min:\n            raise ValueError(\n                \"adaptive_lr.max_learning_rate must be >= \"\n                \"adaptive_lr.min_learning_rate.\"\n            )\n        kl_early_stop_cfg = distributed_update_cfg.get(\"kl_early_stop\", {})\n        self.kl_early_stop_enabled = bool(\n            kl_early_stop_cfg.get(\"enabled\", False)\n        )\n        kl_signal_mode = str(\n            kl_early_stop_cfg.get(\"signal\", \"window_mean\")\n        ).lower()\n        if kl_signal_mode != \"window_mean\":\n            raise ValueError(\n                \"Only distributed_update.kl_early_stop.signal='window_mean' \"\n                \"is supported.\"\n            )\n        self.kl_early_stop_window_size = int(\n            kl_early_stop_cfg.get(\"window_size\", 3)\n        )\n        self.kl_early_stop_factor = float(kl_early_stop_cfg.get(\"factor\", 2.0))\n        self.kl_early_stop_min_updates = int(\n            kl_early_stop_cfg.get(\"min_updates\", 1)\n        )\n        if self.kl_early_stop_window_size < 1:\n            raise ValueError(\n                \"distributed_update.kl_early_stop.window_size must be >= 1.\"\n            )\n        if self.kl_early_stop_factor <= 0.0:\n            raise ValueError(\n                \"distributed_update.kl_early_stop.factor must be > 0.\"\n            )\n     
   if self.kl_early_stop_min_updates < 1:\n            raise ValueError(\n                \"distributed_update.kl_early_stop.min_updates must be >= 1.\"\n            )\n        if self.kl_early_stop_enabled and self.desired_kl is None:\n            raise ValueError(\n                \"distributed_update.kl_early_stop requires desired_kl to be set.\"\n            )\n        self.global_advantage_norm = bool(\n            self.config.get(\"global_advantage_norm\", True)\n        )\n        self.normalize_advantage_per_mini_batch = bool(\n            self.config.get(\"normalize_advantage_per_mini_batch\", False)\n        )\n        self.distributed_lr_scale_factor = self._compute_lr_scale_factor(\n            distributed_update_cfg.get(\"lr_scale\", {})\n        )\n        self.actor_learning_rate = (\n            self.base_actor_learning_rate * self.distributed_lr_scale_factor\n        )\n        self.critic_learning_rate = (\n            self.base_critic_learning_rate * self.distributed_lr_scale_factor\n        )\n        self._last_update_metrics = {\n            \"0-Train/configured_num_mini_batches\": float(\n                self.configured_num_mini_batches\n            ),\n            \"0-Train/requested_num_mini_batches\": float(\n                self.requested_num_mini_batches\n            ),\n            \"0-Train/effective_num_mini_batches\": float(\n                self.requested_num_mini_batches\n            ),\n            \"0-Train/mini_batch_size_per_rank\": 0.0,\n            \"0-Train/num_updates_executed\": 0.0,\n            \"0-Train/lr_scale_factor\": float(self.distributed_lr_scale_factor),\n            \"0-Train/scalable_distributed_update\": float(\n                self.distributed_update_mode == \"scalable\"\n            ),\n            \"0-Train/kl_windowed\": 0.0,\n            \"0-Train/kl_stop_triggered\": 0.0,\n            \"0-Train/kl_stop_analytic\": 0.0,\n        }\n        self._offline_evaluating: bool = False\n\n        motion_cfg = self.env_config.config.robot.motion\n        sampling_strategy_cfg = motion_cfg.get(\"sampling_strategy\", None)\n        if sampling_strategy_cfg is None:\n            sampling_strategy = \"uniform\"\n        else:\n            sampling_strategy = str(sampling_strategy_cfg).lower()\n        valid_strategies = {\"uniform\", \"weighted_bin\", \"curriculum\"}\n        if sampling_strategy not in valid_strategies:\n            raise ValueError(\n                f\"Invalid sampling_strategy '{sampling_strategy}'. 
\"\n                f\"Expected one of {sorted(valid_strategies)}.\"\n            )\n        self.sampling_strategy: str = sampling_strategy\n        self.weighted_bin_cfg = dict(motion_cfg.get(\"weighted_bin\", {}))\n\n        sym_cfg = self.config.get(\"symmetry_loss\", {})\n        self.symmetry_loss_enabled = bool(sym_cfg.get(\"enabled\", False))\n        self.symmetry_loss_coef = float(sym_cfg.get(\"coef\", 0.0))\n        self._sym_dof_perm: torch.Tensor | None = None\n        self._sym_dof_sign: torch.Tensor | None = None\n        self._obs_mirror_map: dict[str, callable] = {}\n        if self._symmetry_loss_active():\n            self._setup_symmetry()\n\n    def _resolve_num_mini_batches(self, base_num_mini_batches: int) -> int:\n        if self.distributed_update_mode == \"legacy\" and self.is_distributed:\n            return max(1, base_num_mini_batches * int(self.gpu_world_size))\n        return max(1, base_num_mini_batches)\n\n    def _compute_lr_scale_factor(self, lr_scale_cfg) -> float:\n        scale_mode = str(lr_scale_cfg.get(\"mode\", \"none\")).lower()\n        if scale_mode not in {\n            \"none\",\n            \"sqrt_world_size\",\n            \"linear_world_size\",\n        }:\n            raise ValueError(\n                \"distributed_update.lr_scale.mode must be one of \"\n                \"{'none', 'sqrt_world_size', 'linear_world_size'}.\"\n            )\n        reference_world_size = float(\n            lr_scale_cfg.get(\"reference_world_size\", 1)\n        )\n        if reference_world_size <= 0.0:\n            raise ValueError(\n                \"distributed_update.lr_scale.reference_world_size must be > 0.\"\n            )\n        runtime_world_size = float(\n            self.gpu_world_size if self.is_distributed else 1\n        )\n        world_ratio = runtime_world_size / reference_world_size\n        if scale_mode == \"none\":\n            scale = 1.0\n        elif scale_mode == \"sqrt_world_size\":\n            scale = math.sqrt(world_ratio)\n        else:\n            scale = world_ratio\n        max_scale = lr_scale_cfg.get(\"max_scale\", None)\n        if max_scale is not None:\n            max_scale = float(max_scale)\n            if max_scale <= 0.0:\n                raise ValueError(\n                    \"distributed_update.lr_scale.max_scale must be > 0 when set.\"\n                )\n            scale = min(scale, max_scale)\n        return float(scale)\n\n    def _symmetry_loss_active(self) -> bool:\n        return bool(\n            getattr(self, \"command_name\", None) == \"base_velocity\"\n            and getattr(self, \"symmetry_loss_enabled\", False)\n            and float(getattr(self, \"symmetry_loss_coef\", 0.0)) > 0.0\n        )\n\n    @staticmethod\n    def _omega_or_obj_to_dict(value):\n        if value is None:\n            return {}\n        if OmegaConf.is_config(value):\n            return OmegaConf.to_container(value, resolve=True)\n        if isinstance(value, dict):\n            return value\n        if hasattr(value, \"__dict__\"):\n            return vars(value)\n        return {}\n\n    def _setup_symmetry(self) -> None:\n        robot_asset = self.env._env.scene[\"robot\"]\n        joint_names = list(getattr(robot_asset, \"joint_names\", []))\n        if len(joint_names) != int(self.num_actions):\n            raise ValueError(\n                \"symmetry_loss requires simulator joint_names to match \"\n                f\"num_actions, got {len(joint_names)} vs {self.num_actions}.\"\n            )\n\n        
name_to_idx = {name: idx for idx, name in enumerate(joint_names)}\n        perm: list[int] = []\n        for name in joint_names:\n            if name.startswith(\"left_\"):\n                mirror_name = \"right_\" + name[len(\"left_\") :]\n            elif name.startswith(\"right_\"):\n                mirror_name = \"left_\" + name[len(\"right_\") :]\n            else:\n                mirror_name = name\n            perm.append(int(name_to_idx.get(mirror_name, name_to_idx[name])))\n\n        sym_cfg = self._omega_or_obj_to_dict(\n            self.config.get(\"symmetry_loss\", {})\n        )\n        sign_by_name = sym_cfg.get(\"dof_sign_by_name\", None)\n        if not sign_by_name:\n            robot_cfg = self._omega_or_obj_to_dict(\n                getattr(\n                    getattr(self.env_config, \"config\", None), \"robot\", None\n                )\n            )\n            sign_by_name = robot_cfg.get(\"dof_sign_by_name\", None)\n        sign_by_name = self._omega_or_obj_to_dict(sign_by_name)\n        if len(sign_by_name) == 0:\n            raise ValueError(\n                \"symmetry_loss requires dof_sign_by_name in algo or robot config.\"\n            )\n\n        sign = [float(sign_by_name.get(name, 1.0)) for name in joint_names]\n        self._sym_dof_perm = torch.tensor(\n            perm, device=self.device, dtype=torch.long\n        )\n        self._sym_dof_sign = torch.tensor(\n            sign, device=self.device, dtype=torch.float32\n        )\n        self._build_obs_mirror_map()\n\n    def _extract_obs_mirror_metadata(self) -> dict[str, dict]:\n        obs_cfg = getattr(\n            getattr(self.env_config, \"config\", None), \"obs\", None\n        )\n        obs_root = self._omega_or_obj_to_dict(obs_cfg)\n        obs_groups = obs_root.get(\"obs_groups\", {})\n        metadata: dict[str, dict] = {}\n        for group_name, group_cfg in obs_groups.items():\n            if not isinstance(group_cfg, dict):\n                continue\n            for term_entry in group_cfg.get(\"atomic_obs_list\", []):\n                if not isinstance(term_entry, dict):\n                    continue\n                for term_name, term_cfg in term_entry.items():\n                    term_cfg = self._omega_or_obj_to_dict(term_cfg)\n                    mirror_func = term_cfg.get(\"mirror_func\", None)\n                    if not mirror_func:\n                        continue\n                    metadata[f\"{group_name}/{term_name}\"] = {\n                        \"mirror_func\": str(mirror_func),\n                        \"mirror_config\": self._omega_or_obj_to_dict(\n                            term_cfg.get(\"mirror_config\", {})\n                        ),\n                    }\n        return metadata\n\n    def _get_actor_schema_terms(self) -> set[str]:\n        module_dict = self._omega_or_obj_to_dict(\n            self.config.get(\"module_dict\", {})\n        )\n        actor_cfg = self._omega_or_obj_to_dict(module_dict.get(\"actor\", {}))\n        actor_schema = self._omega_or_obj_to_dict(\n            actor_cfg.get(\"obs_schema\", {})\n        )\n        actor_terms: set[str] = set()\n        for seq_cfg in actor_schema.values():\n            if not isinstance(seq_cfg, dict):\n                continue\n            for term in seq_cfg.get(\"terms\", []):\n                actor_terms.add(str(term))\n        return actor_terms\n\n    def _build_obs_mirror_map(self) -> None:\n        from holomotion.src.env.isaaclab_components.isaaclab_observation import (\n            
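# Local import: only needed when the symmetry loss is active.\n            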
MirrorFunctions,\n        )\n\n        self._obs_mirror_map = {}\n        if self._sym_dof_perm is None or self._sym_dof_sign is None:\n            return\n\n        term_meta = self._extract_obs_mirror_metadata()\n        actor_terms = self._get_actor_schema_terms()\n        for term in actor_terms:\n            meta = term_meta.get(term)\n            if meta is None:\n                continue\n            mirror_func = str(meta.get(\"mirror_func\", \"\"))\n            if mirror_func == \"mirror_dof\":\n                perm = self._sym_dof_perm\n                sign = self._sym_dof_sign\n\n                def _fn(x, perm=perm, sign=sign):\n                    perm_local = perm.to(device=x.device, dtype=torch.long)\n                    sign_local = sign.to(device=x.device, dtype=x.dtype)\n                    return MirrorFunctions.mirror_dof(\n                        x, perm=perm_local, sign=sign_local\n                    )\n\n            elif mirror_func == \"mirror_vec3\":\n\n                def _fn(x):\n                    return MirrorFunctions.mirror_vec3(x)\n\n            elif mirror_func == \"mirror_axial_vec3\":\n\n                def _fn(x):\n                    return MirrorFunctions.mirror_axial_vec3(x)\n\n            elif mirror_func == \"mirror_velocity_command\":\n\n                def _fn(x):\n                    return MirrorFunctions.mirror_velocity_command(x)\n\n            else:\n                continue\n\n            self._obs_mirror_map[term] = _fn\n\n    @staticmethod\n    def _td_key_to_path(key) -> str:\n        if isinstance(key, tuple):\n            return \"/\".join(str(part) for part in key)\n        return str(key)\n\n    def _mirror_actor_obs(self, obs_td: TensorDict) -> TensorDict:\n        if (\n            not self._symmetry_loss_active()\n            or not isinstance(obs_td, TensorDict)\n            or len(getattr(self, \"_obs_mirror_map\", {})) == 0\n        ):\n            return obs_td\n\n        mirrored = TensorDict(\n            {},\n            batch_size=list(obs_td.batch_size),\n            device=obs_td.device,\n        )\n        for key in obs_td.keys(include_nested=True, leaves_only=True):\n            key_tuple = key if isinstance(key, tuple) else (key,)\n            value = obs_td.get(key_tuple)\n            mirror_fn = self._obs_mirror_map.get(\n                self._td_key_to_path(key_tuple)\n            )\n            mirrored.set(\n                key_tuple,\n                mirror_fn(value) if mirror_fn is not None else value,\n            )\n        return mirrored\n\n    def _mirror_env_action(self, actions: torch.Tensor) -> torch.Tensor:\n        from holomotion.src.env.isaaclab_components.isaaclab_observation import (\n            MirrorFunctions,\n        )\n\n        if not self._symmetry_loss_active():\n            return actions\n        if self._sym_dof_perm is None or self._sym_dof_sign is None:\n            raise RuntimeError(\n                \"Symmetry DOF permutation/signs are not initialized.\"\n            )\n        return MirrorFunctions.mirror_action(\n            actions,\n            perm=self._sym_dof_perm.to(\n                device=actions.device, dtype=torch.long\n            ),\n            sign=self._sym_dof_sign.to(\n                device=actions.device, dtype=actions.dtype\n            ),\n        )\n\n    def _compute_analytic_kl(\n        self,\n        old_mu: torch.Tensor,\n        old_sigma: torch.Tensor,\n        new_mu: torch.Tensor,\n        new_sigma: torch.Tensor,\n        weight: torch.Tensor 
| None = None,\n    ) -> float:\n        with torch.no_grad():\n            kl_vec = torch.sum(\n                torch.log((new_sigma + 1.0e-8) / (old_sigma + 1.0e-8))\n                + (torch.square(old_sigma) + torch.square(old_mu - new_mu))\n                / (2.0 * torch.square(new_sigma) + 1.0e-8)\n                - 0.5,\n                dim=-1,\n            )\n            if weight is None:\n                kl_sum = kl_vec.sum()\n                kl_count = torch.tensor(\n                    float(kl_vec.numel()),\n                    device=self.device,\n                    dtype=torch.float32,\n                )\n            else:\n                kl_weight = weight.to(dtype=torch.float32)\n                kl_sum = (kl_vec * kl_weight).sum()\n                kl_count = kl_weight.sum()\n            if self.is_distributed:\n                kl_sum = self.accelerator.reduce(kl_sum, reduction=\"sum\")\n                kl_count = self.accelerator.reduce(kl_count, reduction=\"sum\")\n            kl_mean = kl_sum / kl_count.clamp_min(1.0)\n        return float(kl_mean.item())\n\n    def _compute_clip_fraction(\n        self,\n        ratio: torch.Tensor,\n        weight: torch.Tensor | None = None,\n    ) -> float:\n        with torch.no_grad():\n            clipped = (\n                (ratio < (1.0 - self.clip_param))\n                | (ratio > (1.0 + self.clip_param))\n            ).to(torch.float32)\n            if weight is None:\n                clip_sum = clipped.sum()\n                clip_count = torch.tensor(\n                    float(clipped.numel()),\n                    device=self.device,\n                    dtype=torch.float32,\n                )\n            else:\n                clip_weight = weight.to(dtype=torch.float32)\n                clip_sum = (clipped * clip_weight).sum()\n                clip_count = clip_weight.sum()\n            if self.is_distributed:\n                clip_sum = self.accelerator.reduce(clip_sum, reduction=\"sum\")\n                clip_count = self.accelerator.reduce(\n                    clip_count, reduction=\"sum\"\n                )\n            clip_fraction = clip_sum / clip_count.clamp_min(1.0)\n        return float(clip_fraction.item())\n\n    def _compute_explained_variance(\n        self,\n        target: torch.Tensor,\n        prediction: torch.Tensor,\n        weight: torch.Tensor | None = None,\n    ) -> float:\n        with torch.no_grad():\n            target_f = target.float()\n            prediction_f = prediction.float()\n            residual = target_f - prediction_f\n            if weight is None:\n                weight_f = torch.ones_like(target_f, dtype=torch.float32)\n            else:\n                weight_f = weight.to(dtype=torch.float32)\n\n            count = weight_f.sum()\n            target_sum = (target_f * weight_f).sum()\n            target_sq_sum = (target_f.square() * weight_f).sum()\n            residual_sum = (residual * weight_f).sum()\n            residual_sq_sum = (residual.square() * weight_f).sum()\n\n            if self.is_distributed:\n                count = self.accelerator.reduce(count, reduction=\"sum\")\n                target_sum = self.accelerator.reduce(\n                    target_sum, reduction=\"sum\"\n                )\n                target_sq_sum = self.accelerator.reduce(\n                    target_sq_sum, reduction=\"sum\"\n                )\n                residual_sum = self.accelerator.reduce(\n                    residual_sum, reduction=\"sum\"\n                )\n       
         residual_sq_sum = self.accelerator.reduce(\n                    residual_sq_sum, reduction=\"sum\"\n                )\n\n            denom = count.clamp_min(1.0)\n            target_mean = target_sum / denom\n            residual_mean = residual_sum / denom\n            target_var = target_sq_sum / denom - target_mean.square()\n            residual_var = residual_sq_sum / denom - residual_mean.square()\n            if float(target_var.item()) <= 1.0e-8:\n                return 0.0\n            explained_variance = 1.0 - residual_var / target_var\n        return float(explained_variance.item())\n\n    def _set_optimizer_learning_rates(self) -> None:\n        for param_group in self.actor_optimizer.param_groups:\n            param_group[\"lr\"] = self.actor_learning_rate\n        for param_group in self.critic_optimizer.param_groups:\n            param_group[\"lr\"] = self.critic_learning_rate\n\n    @staticmethod\n    def _validate_entropy_schedule_config(\n        *,\n        initial_entropy_coef: float,\n        anneal_entropy: bool,\n        zero_entropy_point: float,\n    ) -> None:\n        if float(initial_entropy_coef) < 0.0:\n            raise ValueError(\"entropy_coef must be >= 0.\")\n        if anneal_entropy and not (0.0 < float(zero_entropy_point) <= 1.0):\n            raise ValueError(\n                \"zero_entropy_point must be in (0.0, 1.0] when \"\n                \"anneal_entropy is enabled.\"\n            )\n\n    def _get_effective_entropy_coef(self) -> float:\n        if self.initial_entropy_coef <= 0.0 or not self.anneal_entropy:\n            return float(self.initial_entropy_coef)\n        total_learning_iterations = int(\n            getattr(\n                self,\n                \"total_learning_iterations\",\n                self.current_learning_iteration\n                + int(self.num_learning_iterations),\n            )\n        )\n        total_learning_iterations = max(1, total_learning_iterations)\n        zero_entropy_iteration = float(self.zero_entropy_point) * float(\n            total_learning_iterations\n        )\n        anneal_scale = max(\n            0.0,\n            1.0\n            - float(self.current_learning_iteration) / zero_entropy_iteration,\n        )\n        return float(self.initial_entropy_coef * anneal_scale)\n\n    def _apply_adaptive_lr(self, kl_signal: float | None) -> None:\n        if (\n            self.desired_kl is None\n            or self.schedule != \"adaptive\"\n            or kl_signal is None\n        ):\n            return\n        if kl_signal > self.desired_kl * self.adaptive_lr_kl_high_factor:\n            self.actor_learning_rate = max(\n                self.adaptive_lr_min,\n                self.actor_learning_rate / self.adaptive_lr_factor,\n            )\n            if self.adaptive_lr_adapt_critic:\n                self.critic_learning_rate = max(\n                    self.adaptive_lr_min,\n                    self.critic_learning_rate / self.adaptive_lr_factor,\n                )\n        elif (\n            kl_signal > 0.0\n            and kl_signal < self.desired_kl * self.adaptive_lr_kl_low_factor\n        ):\n            self.actor_learning_rate = min(\n                self.adaptive_lr_max,\n                self.actor_learning_rate * self.adaptive_lr_factor,\n            )\n            if self.adaptive_lr_adapt_critic:\n                self.critic_learning_rate = min(\n                    self.adaptive_lr_max,\n                    self.critic_learning_rate * self.adaptive_lr_factor,\n         
       )\n        self._set_optimizer_learning_rates()\n\n    def _compute_windowed_kl_signal(\n        self, recent_analytic_kls: list[float]\n    ) -> float | None:\n        if len(recent_analytic_kls) < self.kl_early_stop_window_size:\n            return None\n        window = recent_analytic_kls[-self.kl_early_stop_window_size :]\n        return float(sum(window) / len(window))\n\n    def _should_early_stop_for_kl(\n        self,\n        kl_signal: float | None,\n        num_kl_measurements: int,\n    ) -> bool:\n        if not self.kl_early_stop_enabled or self.desired_kl is None:\n            return False\n        if kl_signal is None:\n            return False\n        required_measurements = max(\n            self.kl_early_stop_min_updates, self.kl_early_stop_window_size\n        )\n        if num_kl_measurements < required_measurements:\n            return False\n        return kl_signal > self.desired_kl * self.kl_early_stop_factor\n\n    def _setup_data_buffers(self):\n        super()._setup_data_buffers()\n        self.use_velocity_transition: bool = (\n            self.command_name == \"base_velocity\"\n        )\n        self.transition_cls = (\n            PpoVelocityTransition\n            if self.use_velocity_transition\n            else PpoTransition\n        )\n        self.transition_td: PpoTransition | PpoVelocityTransition | None = None\n\n    def _build_optimizer_kwargs(self, optimizer_class: type) -> dict:\n        if self.optimizer_type != \"AdamW\":\n            return {}\n        signature = inspect.signature(optimizer_class.__init__)\n        parameters = signature.parameters\n        use_fused = bool(\n            self.config.get(\n                \"adamw_use_fused\", bool(self.device.type == \"cuda\")\n            )\n        )\n        use_foreach = bool(self.config.get(\"adamw_use_foreach\", True))\n        if (\n            use_fused\n            and (\"fused\" in parameters)\n            and (self.device.type == \"cuda\")\n        ):\n            return {\"fused\": True}\n        if use_foreach and (\"foreach\" in parameters):\n            return {\"foreach\": True}\n        return {}\n\n    def _setup_models_and_optimizer(self):\n        from holomotion.src.modules.agent_modules import PPOActor, PPOCritic\n\n        # Build sample TensorDict for schema-based assembly\n        sample_obs_dict = self.env.reset_all()[0]\n        sample_td = self._wrap_obs_dict(sample_obs_dict)\n        actor_cfg = OmegaConf.to_container(\n            self.config.module_dict.actor, resolve=True\n        )\n        critic_cfg = OmegaConf.to_container(\n            self.config.module_dict.critic, resolve=True\n        )\n\n        self.actor_type = actor_cfg.get(\"type\", \"MLP\")\n        self.critic_type = critic_cfg.get(\"type\", \"MLP\")\n\n        actor_schema = actor_cfg.get(\"obs_schema\", None)\n        critic_schema = critic_cfg.get(\"obs_schema\", None)\n\n        self.actor = PPOActor(\n            obs_schema=actor_schema,\n            module_config_dict=actor_cfg,\n            num_actions=self.num_actions,\n            init_noise_std=self.config.init_noise_std,\n            obs_example=sample_td,\n        ).to(self.device)\n\n        self.critic = PPOCritic(\n            obs_schema=critic_schema,\n            module_config_dict=critic_cfg,\n            obs_example=sample_td,\n        ).to(self.device)\n\n        if self.is_main_process:\n            actor = self.accelerator.unwrap_model(self.actor)\n            critic = self.accelerator.unwrap_model(self.critic)\n\n 
           logger.info(\"Actor (TensorDict module):\\n{!r}\", actor)\n            logger.info(\n                \"Actor keys: in_keys={} out_keys={}\",\n                list(actor.in_keys),\n                list(actor.out_keys),\n            )\n            logger.info(\"Actor core nn module:\\n{!r}\", actor.actor_module)\n\n            logger.info(\"Critic (TensorDict module):\\n{!r}\", critic)\n            logger.info(\n                \"Critic keys: in_keys={} out_keys={}\",\n                list(critic.in_keys),\n                list(critic.out_keys),\n            )\n            logger.info(\"Critic core nn module:\\n{!r}\", critic.critic_module)\n\n            # Log actor and critic parameter counts (in millions)\n            actor_params = sum(p.numel() for p in self.actor.parameters())\n            critic_params = sum(p.numel() for p in self.critic.parameters())\n            params_table = [\n                [\"Actor\", f\"{actor_params / 1.0e6:.3f}\"],\n                [\"Critic\", f\"{critic_params / 1.0e6:.3f}\"],\n                [\"Total\", f\"{(actor_params + critic_params) / 1.0e6:.3f}\"],\n            ]\n            logger.info(\n                \"Model Summary:\\n\"\n                + tabulate(\n                    params_table,\n                    headers=[\"Model\", \"Params (M)\"],\n                    tablefmt=\"simple_outline\",\n                )\n            )\n\n        optimizer_class = getattr(optim, self.optimizer_type)\n        optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class)\n        self.actor_optimizer = optimizer_class(\n            self.actor.parameters(),\n            lr=self.actor_learning_rate,\n            betas=(self.actor_beta1, self.actor_beta2),\n            **optimizer_kwargs,\n        )\n        self.critic_optimizer = optimizer_class(\n            self.critic.parameters(),\n            lr=self.critic_learning_rate,\n            betas=(self.critic_beta1, self.critic_beta2),\n            **optimizer_kwargs,\n        )\n\n        dynamo_backend = self.config.get(\"dynamo_backend\", None)\n        if dynamo_backend and self.is_main_process:\n            logger.info(\n                f\"Models will be compiled with dynamo_backend='{dynamo_backend}' \"\n                \"during accelerator.prepare()\"\n            )\n        (\n            self.actor,\n            self.critic,\n            self.actor_optimizer,\n            self.critic_optimizer,\n        ) = self.accelerator.prepare(\n            self.actor,\n            self.critic,\n            self.actor_optimizer,\n            self.critic_optimizer,\n        )\n\n    def _build_storage(self, obs_td: TensorDict):\n        return RolloutStorage(\n            self.num_envs,\n            self.num_steps_per_env,\n            obs_template=obs_td,\n            actions_shape=[self.num_actions],\n            device=self.device,\n            command_name=self.command_name,\n            transition_cls=self.transition_cls,\n        )\n\n    def _build_transition(\n        self,\n        obs_td: TensorDict,\n        actor_out: TensorDict,\n        critic_out: TensorDict,\n    ):\n        actions = actor_out.get(\"actions\")\n        actions_log_prob = actor_out.get(\"actions_log_prob\")\n        mu = actor_out.get(\"mu\")\n        sigma = actor_out.get(\"sigma\")\n        values = critic_out.get(\"values\")\n\n        zero_scalar = torch.zeros(\n            self.num_envs,\n            1,\n            device=self.device,\n            dtype=torch.float32,\n        )\n        zero_scalar_bool = 
torch.zeros(\n            self.num_envs,\n            1,\n            device=self.device,\n            dtype=torch.bool,\n        )\n\n        transition_kwargs = {\n            \"obs\": obs_td,\n            \"actions\": actions.detach(),\n            \"teacher_actions\": torch.zeros_like(actions),\n            \"mu\": mu.detach(),\n            \"sigma\": sigma.detach(),\n            \"actions_log_prob\": actions_log_prob[..., None].detach(),\n            \"values\": values.detach(),\n            \"rewards\": zero_scalar.clone(),\n            \"dones\": zero_scalar_bool,\n            \"returns\": zero_scalar.clone(),\n            \"advantages\": zero_scalar.clone(),\n            \"batch_size\": [self.num_envs],\n            \"device\": self.device,\n        }\n\n        if self.use_velocity_transition:\n            import isaaclab.envs.mdp as isaaclab_mdp\n\n            velocity_cmd = isaaclab_mdp.generated_commands(\n                self.env._env, command_name=\"base_velocity\"\n            )\n            if velocity_cmd.shape[-1] > 3:\n                velocity_cmd = velocity_cmd[..., :3]\n            move_mask = (velocity_cmd.norm(dim=-1) > 0.1).to(\n                dtype=velocity_cmd.dtype\n            )\n            transition_kwargs[\"velocity_commands\"] = torch.cat(\n                [move_mask[..., None], velocity_cmd],\n                dim=-1,\n            ).detach()\n\n        return self.transition_cls(**transition_kwargs)\n\n    def _post_iteration_hook(self, it: int) -> None:\n        if self.command_name == \"ref_motion\":\n            motion_cmd = self.env._env.command_manager.get_term(\"ref_motion\")\n            motion_cmd.apply_cache_swap_if_pending_barrier(\n                accelerator=self.accelerator\n            )\n\n    def _post_training_hook(self) -> None:\n        if self.command_name == \"ref_motion\":\n            motion_cmd = self.env._env.command_manager.get_term(\"ref_motion\")\n            if motion_cmd is not None:\n                motion_cmd.close()\n\n    def _get_mean_policy_std(self) -> torch.Tensor:\n        base_actor = self.accelerator.unwrap_model(self.actor)\n        if hasattr(base_actor, \"std\"):\n            return base_actor.std.mean()\n        if hasattr(base_actor, \"log_std\"):\n            return torch.exp(base_actor.log_std).mean()\n        return torch.tensor(0.0, device=self.device)\n\n    def _maybe_override_loaded_actor_sigma(self) -> None:\n        if not bool(self.config.get(\"override_sigma\", False)):\n            return\n\n        sigma_override = self.config.get(\"sigma_override\", None)\n        if sigma_override is None:\n            raise ValueError(\n                \"config.override_sigma is enabled but config.sigma_override is not set.\"\n            )\n\n        actor_unwrapped = self.accelerator.unwrap_model(self.actor)\n        orig_mod = getattr(actor_unwrapped, \"_orig_mod\", None)\n        if orig_mod is not None:\n            actor_unwrapped = orig_mod\n\n        override_sigma = getattr(actor_unwrapped, \"override_sigma\", None)\n        if override_sigma is None:\n            raise AttributeError(\n                f\"{type(actor_unwrapped).__name__} does not implement override_sigma().\"\n            )\n\n        override_sigma(sigma_override)\n        if self.is_main_process:\n            logger.info(\n                \"Reapplied sigma override after checkpoint load: {}\",\n                sigma_override,\n            )\n\n    def _get_additional_log_metrics(self) -> dict[str, float]:\n        \"\"\"Build 
auxiliary training/cache metrics.\"\"\"\n        iteration_metrics = {}\n\n        if \"actor_learning_rate\" in self.__dict__:\n            iteration_metrics[\"0-Train/actor_learning_rate\"] = float(\n                self.actor_learning_rate\n            )\n\n        if \"critic_learning_rate\" in self.__dict__:\n            iteration_metrics[\"0-Train/critic_learning_rate\"] = float(\n                self.critic_learning_rate\n            )\n\n        if \"initial_entropy_coef\" in self.__dict__:\n            iteration_metrics[\"0-Train/entropy_coef_effective\"] = float(\n                self._get_effective_entropy_coef()\n            )\n\n        if \"_last_update_metrics\" in self.__dict__:\n            iteration_metrics.update(self._last_update_metrics)\n\n        mean_std = self._get_mean_policy_std()\n        iteration_metrics[\"0-Train/mean_noise_std\"] = float(mean_std.item())\n\n        if self.command_name != \"ref_motion\":\n            return iteration_metrics\n\n        motion_cmd = self.env._env.command_manager.get_term(\"ref_motion\")\n        cache = motion_cmd._motion_cache\n        iteration_metrics[\"1-Perf/Cache/swap_index\"] = float(cache.swap_index)\n        pool_stats = cache.cache_curriculum_pool_statistics()\n        if pool_stats is not None:\n            core_cache_metric_keys = {\n                \"prioritized_pool_size\": \"1-Perf/Cache/prioritized_pool_size\",\n                \"prioritized_pool_mean_score\": \"1-Perf/Cache/prioritized_pool_mean_score\",\n                \"uniform_pool_mean_score\": \"1-Perf/Cache/uniform_pool_mean_score\",\n                \"entered_prioritized_pool_count\": \"1-Perf/Cache/entered_prioritized_pool_count\",\n                \"exited_prioritized_pool_count\": \"1-Perf/Cache/exited_prioritized_pool_count\",\n            }\n            for src_key, dst_key in core_cache_metric_keys.items():\n                if src_key in pool_stats:\n                    iteration_metrics[dst_key] = float(pool_stats[src_key])\n        return iteration_metrics\n\n    def update(self):\n        mean_value_loss = 0.0\n        mean_surrogate_loss = 0.0\n        mean_entropy = 0.0\n        mean_kl_analytic = 0.0\n        mean_symmetry_loss = 0.0\n        critic_explained_variance = self._compute_explained_variance(\n            target=self.storage.data[\"returns\"],\n            prediction=self.storage.data[\"values\"],\n        )\n\n        batch_size = int(\n            self.storage.num_envs * self.storage.num_transitions_per_env\n        )\n        (\n            effective_num_mini_batches,\n            mini_batch_size,\n        ) = RolloutStorage.resolve_mini_batch_partition(\n            batch_size, self.num_mini_batches\n        )\n        self._last_update_metrics = {\n            \"0-Train/configured_num_mini_batches\": float(\n                self.configured_num_mini_batches\n            ),\n            \"0-Train/requested_num_mini_batches\": float(\n                self.requested_num_mini_batches\n            ),\n            \"0-Train/effective_num_mini_batches\": float(\n                effective_num_mini_batches\n            ),\n            \"0-Train/mini_batch_size_per_rank\": float(mini_batch_size),\n            \"0-Train/num_updates_executed\": 0.0,\n            \"0-Train/lr_scale_factor\": float(self.distributed_lr_scale_factor),\n            \"0-Train/scalable_distributed_update\": float(\n                self.distributed_update_mode == \"scalable\"\n            ),\n            \"0-Train/kl_windowed\": 0.0,\n            
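# kl_*/clip_* entries are placeholders; they are filled in after the update loop below.\n            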
\"0-Train/kl_stop_triggered\": 0.0,\n            \"0-Train/kl_stop_analytic\": 0.0,\n            \"0-Train/kl_analytic_batch_last\": 0.0,\n            \"0-Train/kl_analytic_batch_max\": 0.0,\n            \"0-Train/clip_fraction_batch_mean\": 0.0,\n            \"0-Train/clip_fraction_batch_last\": 0.0,\n        }\n        entropy_coef = self._get_effective_entropy_coef()\n\n        generator = self.storage.iter_minibatches(\n            effective_num_mini_batches,\n            self.num_learning_epochs,\n        )\n        measure_analytic_kl = self.desired_kl is not None\n        num_updates = 0\n        num_kl_measurements = 0\n        kl_stop_triggered = False\n        kl_stop_analytic = 0.0\n        kl_windowed = None\n        recent_analytic_kls: list[float] = []\n        kl_analytic_batch_last = 0.0\n        kl_analytic_batch_max = 0.0\n        clip_fraction_batch_mean = 0.0\n        clip_fraction_batch_last = 0.0\n\n        for batch in generator:\n            obs_batch = batch.obs\n            actions_batch = batch.actions\n            target_values_batch = batch.values\n            advantages_batch = batch.advantages\n            returns_batch = batch.returns\n            old_actions_log_prob_batch = batch.actions_log_prob\n            old_mu_batch = batch.mu\n            old_sigma_batch = batch.sigma\n            with self.accelerator.autocast():\n                actor_out = self.actor(\n                    obs_batch,\n                    actions=actions_batch,\n                    mode=\"logp\",\n                    update_obs_norm=False,\n                )\n                critic_out = self.critic(obs_batch, update_obs_norm=False)\n                actions_log_prob_batch = actor_out.get(\"actions_log_prob\")\n                mu_batch = actor_out.get(\"mu\")\n                sigma_batch = actor_out.get(\"sigma\")\n                entropy_batch = actor_out.get(\"entropy\")\n                value_pred = critic_out.get(\"values\")\n                symmetry_loss = None\n                if self._symmetry_loss_active():\n                    mirrored_obs_batch = self._mirror_actor_obs(obs_batch)\n                    mirrored_actor_out = self.actor(\n                        mirrored_obs_batch,\n                        actions=None,\n                        mode=\"inference\",\n                        update_obs_norm=False,\n                    )\n                    mirrored_actions = mirrored_actor_out.get(\"actions\")\n                    mirrored_actions_back = self._mirror_env_action(\n                        mirrored_actions\n                    )\n                    symmetry_loss = F.mse_loss(\n                        mu_batch.float(),\n                        mirrored_actions_back.float(),\n                    )\n\n            value_batch = value_pred\n            returns_batch_norm = returns_batch\n            target_values_batch_norm = target_values_batch\n\n            analytic_kl = None\n            if measure_analytic_kl:\n                analytic_kl = self._compute_analytic_kl(\n                    old_mu=old_mu_batch.float(),\n                    old_sigma=old_sigma_batch.float(),\n                    new_mu=mu_batch.float(),\n                    new_sigma=sigma_batch.float(),\n                )\n                mean_kl_analytic += analytic_kl\n                num_kl_measurements += 1\n                kl_analytic_batch_last = analytic_kl\n                kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl)\n                recent_analytic_kls.append(analytic_kl)\n 
               if len(recent_analytic_kls) > self.kl_early_stop_window_size:\n                    recent_analytic_kls.pop(0)\n                kl_windowed = self._compute_windowed_kl_signal(\n                    recent_analytic_kls\n                )\n                if self._should_early_stop_for_kl(\n                    kl_windowed, num_kl_measurements\n                ):\n                    kl_stop_triggered = True\n                    kl_stop_analytic = analytic_kl\n                    break\n\n            ratio = torch.exp(\n                actions_log_prob_batch\n                - torch.squeeze(old_actions_log_prob_batch).float()\n            )\n            clip_fraction = self._compute_clip_fraction(ratio)\n            clip_fraction_batch_mean += clip_fraction\n            clip_fraction_batch_last = clip_fraction\n            surrogate = -torch.squeeze(advantages_batch) * ratio\n            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(\n                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param\n            )\n            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()\n\n            if self.use_clipped_value_loss:\n                value_clipped = target_values_batch_norm + (\n                    value_batch - target_values_batch_norm\n                ).clamp(-self.clip_param, self.clip_param)\n                value_losses = (value_batch - returns_batch_norm).pow(2)\n                value_losses_clipped = (\n                    value_clipped - returns_batch_norm\n                ).pow(2)\n                value_loss = torch.max(\n                    value_losses, value_losses_clipped\n                ).mean()\n            else:\n                value_loss = (returns_batch_norm - value_batch).pow(2).mean()\n\n            actor_loss = surrogate_loss\n            critic_loss = self.value_loss_coef * value_loss\n\n            if entropy_coef > 0.0:\n                entropy_loss = entropy_batch.mean()\n                actor_loss = actor_loss - entropy_coef * entropy_loss\n            if symmetry_loss is not None:\n                actor_loss = (\n                    actor_loss + self.symmetry_loss_coef * symmetry_loss\n                )\n\n            self.actor_optimizer.zero_grad()\n            self.critic_optimizer.zero_grad()\n            self.accelerator.backward(actor_loss)\n            self.accelerator.backward(critic_loss)\n\n            if self.max_grad_norm is not None:\n                self.accelerator.clip_grad_norm_(\n                    self.actor.parameters(),\n                    self.max_grad_norm,\n                )\n                self.accelerator.clip_grad_norm_(\n                    self.critic.parameters(),\n                    self.max_grad_norm,\n                )\n\n            self.actor_optimizer.step()\n            self.critic_optimizer.step()\n\n            num_updates += 1\n            mean_value_loss += float(value_loss.item())\n            mean_surrogate_loss += float(surrogate_loss.item())\n            mean_entropy += float(entropy_batch.mean().item())\n            if symmetry_loss is not None:\n                mean_symmetry_loss += float(symmetry_loss.item())\n\n        denom = max(1, num_updates)\n        mean_value_loss /= denom\n        mean_surrogate_loss /= denom\n        mean_entropy /= denom\n        mean_kl_analytic /= max(1, num_kl_measurements)\n        mean_symmetry_loss /= denom\n        clip_fraction_batch_mean /= denom\n        if self.schedule == \"adaptive\":\n            
self._apply_adaptive_lr(kl_windowed)\n        self._last_update_metrics[\"0-Train/num_updates_executed\"] = float(\n            num_updates\n        )\n        self._last_update_metrics[\"0-Train/kl_windowed\"] = float(\n            kl_windowed or 0.0\n        )\n        self._last_update_metrics[\"0-Train/kl_stop_triggered\"] = float(\n            kl_stop_triggered\n        )\n        self._last_update_metrics[\"0-Train/kl_stop_analytic\"] = float(\n            kl_stop_analytic\n        )\n        self._last_update_metrics[\"0-Train/kl_analytic_batch_last\"] = float(\n            kl_analytic_batch_last\n        )\n        self._last_update_metrics[\"0-Train/kl_analytic_batch_max\"] = float(\n            kl_analytic_batch_max\n        )\n        self._last_update_metrics[\"0-Train/clip_fraction_batch_mean\"] = float(\n            clip_fraction_batch_mean\n        )\n        self._last_update_metrics[\"0-Train/clip_fraction_batch_last\"] = float(\n            clip_fraction_batch_last\n        )\n\n        self.storage.clear()\n\n        loss_out = {\n            \"value_function\": mean_value_loss,\n            \"critic_explained_variance\": critic_explained_variance,\n            \"surrogate\": mean_surrogate_loss,\n            \"entropy\": mean_entropy,\n            \"kl_analytic\": mean_kl_analytic,\n        }\n        if self._symmetry_loss_active():\n            loss_out[\"symmetry_loss\"] = mean_symmetry_loss\n        # Reduce losses across processes for consistent logging on rank 0\n        if self.is_distributed:\n            reduced_out = {}\n            for k, v in loss_out.items():\n                if v is None:\n                    reduced_out[k] = None\n                    continue\n                t = torch.tensor(v, device=self.device, dtype=torch.float32)\n                reduced_t = self.accelerator.reduce(t, reduction=\"mean\")\n                reduced_out[k] = float(reduced_t.item())\n            loss_out = reduced_out\n\n        self._post_update_hook(loss_out)\n        return loss_out\n\n    def load(self, ckpt_path):\n        if ckpt_path is None:\n            return None\n        if self.is_main_process:\n            logger.info(f\"Loading checkpoint from {ckpt_path}\")\n\n        actor_model_path = self._resolve_model_file_path(ckpt_path, \"actor\")\n        critic_model_path = self._resolve_model_file_path(ckpt_path, \"critic\")\n        self._load_accelerate_model(self.actor, actor_model_path, strict=True)\n        self._load_accelerate_model(\n            self.critic, critic_model_path, strict=True\n        )\n\n        loaded_dict = torch.load(ckpt_path, map_location=self.device)\n        if not getattr(self, \"is_offline_eval\", False):\n            self._restore_optimizer_state(\n                self.actor_optimizer,\n                loaded_dict[\"actor_optimizer_state_dict\"],\n                optimizer_name=\"actor\",\n            )\n            self._restore_optimizer_state(\n                self.critic_optimizer,\n                loaded_dict[\"critic_optimizer_state_dict\"],\n                optimizer_name=\"critic\",\n            )\n        elif self.is_main_process:\n            logger.info(\n                \"Skipping optimizer state restore during offline evaluation.\"\n            )\n        self.current_learning_iteration = loaded_dict.get(\"iter\", 0)\n        self._maybe_override_loaded_actor_sigma()\n        self._load_extra_checkpoint_state(loaded_dict)\n        return loaded_dict.get(\"infos\", None)\n\n    def _restore_optimizer_state(\n        
self,\n        optimizer,\n        loaded_state_dict,\n        *,\n        optimizer_name: str,\n    ) -> bool:\n        compatible, reason = self._optimizer_state_is_compatible(\n            optimizer, loaded_state_dict\n        )\n        if not compatible:\n            if self.is_main_process:\n                logger.warning(\n                    \"Skipping {} optimizer state restore from checkpoint: {}\",\n                    optimizer_name,\n                    reason,\n                )\n            return False\n\n        try:\n            optimizer.load_state_dict(loaded_state_dict)\n        except ValueError as exc:\n            if self.is_main_process:\n                logger.warning(\n                    \"Skipping {} optimizer state restore from checkpoint: {}\",\n                    optimizer_name,\n                    exc,\n                )\n            return False\n        return True\n\n    def _optimizer_state_is_compatible(\n        self, optimizer, loaded_state_dict\n    ) -> tuple[bool, str | None]:\n        current_state_dict = optimizer.state_dict()\n        current_groups = current_state_dict.get(\"param_groups\")\n        loaded_groups = loaded_state_dict.get(\"param_groups\")\n        if not isinstance(current_groups, list) or not isinstance(\n            loaded_groups, list\n        ):\n            # Structure not inspectable; optimistically assume compatible.\n            return True, None\n        if len(current_groups) != len(loaded_groups):\n            return (\n                False,\n                \"param group count mismatch \"\n                f\"(current={len(current_groups)}, loaded={len(loaded_groups)})\",\n            )\n\n        for group_idx, (current_group, loaded_group) in enumerate(\n            zip(current_groups, loaded_groups)\n        ):\n            current_param_count = len(current_group.get(\"params\", []))\n            loaded_param_count = len(loaded_group.get(\"params\", []))\n            if current_param_count != loaded_param_count:\n                return (\n                    False,\n                    \"param group size mismatch for group \"\n                    f\"{group_idx} (current={current_param_count}, \"\n                    f\"loaded={loaded_param_count})\",\n                )\n        return True, None\n\n    def save(self, path, infos=None):\n        if not self.is_main_process:\n            return\n\n        logger.info(f\"Saving checkpoint to {path}\")\n        # Strip only a trailing \".pt\" (plain str.replace would also mangle\n        # path components that merely contain \".pt\").\n        base_path = path.removesuffix(\".pt\")\n        os.makedirs(\n            os.path.dirname(base_path) if os.path.dirname(base_path) else \".\",\n            exist_ok=True,\n        )\n\n        self.accelerator.save_model(\n            self.actor, os.path.join(base_path, \"actor\")\n        )\n        self.accelerator.save_model(\n            self.critic, os.path.join(base_path, \"critic\")\n        )\n\n        custom_state = {\n            \"actor_optimizer_state_dict\": self.actor_optimizer.state_dict(),\n            \"critic_optimizer_state_dict\": self.critic_optimizer.state_dict(),\n            \"iter\": self.current_learning_iteration,\n            \"infos\": infos,\n        }\n        custom_state.update(self._extra_checkpoint_state())\n        torch.save(_checkpoint_state_to_cpu(custom_state), path)\n\n        if bool(self.config.get(\"export_policy\", False)):\n            export_policy_to_onnx_common(\n                self,\n                path,\n                onnx_name_suffix=self.config.get(\"onnx_name_suffix\", None),\n                use_kv_cache=bool(self.config.get(\"use_kv_cache\", True)),\n            )\n\n    
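# Usage sketch (hypothetical paths, for illustration only): given the save()\n    # above, a checkpoint saved as \"logs/run/model_01000.pt\" produces\n    #\n    #   logs/run/model_01000.pt        <- optimizer states, iter, infos\n    #   logs/run/model_01000/actor/    <- accelerator.save_model() weights\n    #   logs/run/model_01000/critic/\n    #\n    # and load(\"logs/run/model_01000.pt\") restores the weights via\n    # _load_accelerate_model() plus the optimizer states (optimizer restore\n    # is skipped during offline evaluation).\n\n    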
def offline_evaluate_policy(self, dump_npzs: bool = False):\n        \"\"\"Dump NPZs (no metrics) from validation cache using ref_motion command.\n\n        - Iterates validation batches; env i -> clip i (deterministic) starting at frame 0.\n        - Collects robot and reference sequences each step and saves one NPZ per clip (when dump_npzs=True).\n        - NPZs conform to the holomotion_retargeted format keys.\n        - Optionally records viewport MP4(s) aligned with target_fps and rollout length.\n        \"\"\"\n\n        ckpt_path = self.config.checkpoint\n        n_fut_frames = self.env.config.commands.ref_motion.params.get(\n            \"n_fut_frames\", 8\n        )\n        # log_dir is already set to checkpoint directory in eval script\n        model_name = os.path.basename(ckpt_path).removesuffix(\".pt\")\n\n        # Eval modes (freeze normalizers if enabled)\n        self.actor.eval()\n        self.critic.eval()\n\n        # Require ref_motion command and simple cache backend\n        command_name = list(self.env.config.commands.keys())[0]\n        if command_name != \"ref_motion\":\n            logger.warning(\n                \"Offline evaluation only supported for ref_motion command\"\n            )\n            return {}\n        motion_cmd = self.env._env.command_manager.get_term(\"ref_motion\")\n        cache = getattr(motion_cmd, \"_motion_cache\", None)\n        if cache is None:\n            logger.error(\n                \"Offline evaluation requires hdf5_simple cache backend (no LMDB support)\"\n            )\n            return {}\n\n        self._offline_evaluating = True\n\n        # Evaluation flag and cache batch-size adjustment (ensure batch_size == num_envs)\n        motion_cmd._is_evaluating = True\n        num_envs = self.env.num_envs\n        try:\n            if getattr(cache, \"_batch_size\", None) != num_envs:\n                from holomotion.src.training.h5_dataloader import (\n                    MotionClipBatchCache,\n                )\n\n                cache = MotionClipBatchCache(\n                    train_dataset=cache._datasets[\"train\"],\n                    val_dataset=cache._datasets[\"val\"],\n                    batch_size=num_envs,\n                    stage_device=getattr(cache, \"_stage_device\", None),\n                    num_workers=getattr(cache, \"_num_workers\", 0),\n                    prefetch_factor=getattr(cache, \"_prefetch_factor\", None),\n                    pin_memory=getattr(cache, \"_pin_memory\", True),\n                    persistent_workers=getattr(\n                        cache, \"_persistent_workers\", False\n                    ),\n                    sampler_rank=getattr(cache, \"_sampler_rank\", 0),\n                    sampler_world_size=getattr(\n                        cache, \"_sampler_world_size\", 1\n                    ),\n                    allowed_prefixes=getattr(cache, \"_allowed_prefixes\", None),\n                    swap_interval_steps=getattr(\n                        cache, \"swap_interval_steps\", None\n                    ),\n                    force_timeout_on_swap=getattr(\n                        cache, \"force_timeout_on_swap\", True\n                    ),\n                    seed=getattr(cache, \"_seed\", None),\n                    loader_timeout=getattr(cache, \"_loader_timeout\", 0.0),\n                )\n                motion_cmd._motion_cache = cache\n        except Exception as e:\n            logger.warning(\n                f\"Offline eval: failed to rebuild cache to batch_size={num_envs}: {e}\"\n   
         )\n\n        # Derive HDF5 dataset base name (from validation dataset root) for output naming\n        dataset_suffix = None\n        val_dataset = cache._datasets[\"val\"]\n        dataset_root = None\n        if hasattr(val_dataset, \"hdf5_root\"):\n            dataset_root = str(val_dataset.hdf5_root).rstrip(os.sep)\n        elif hasattr(val_dataset, \"ts_roots\"):\n            ts_roots = getattr(val_dataset, \"ts_roots\")\n            if ts_roots:\n                dataset_root = str(ts_roots[0]).rstrip(os.sep)\n        if dataset_root:\n            dataset_suffix = os.path.basename(dataset_root)\n\n        # Output directory (respect existing log_dir derived from checkpoint)\n        suffix = f\"isaaclab_eval_output_{model_name}\"\n        if dataset_suffix is not None:\n            suffix = f\"{suffix}_{dataset_suffix}\"\n        output_dir = os.path.join(self.log_dir, suffix)\n        os.makedirs(output_dir, exist_ok=True)\n        logger.info(f\"Saving evaluation outputs to: {output_dir}\")\n\n        # Switch to validation cache and iterate all batches\n        if hasattr(cache, \"set_mode\"):\n            cache.set_mode(\"val\")\n        # Determine policy/video FPS from command config (align wallclock time)\n        motion_fps = int(getattr(motion_cmd.cfg, \"target_fps\", 50))\n        total_batches = int(getattr(cache, \"num_batches\", 1))\n        with torch.no_grad():\n            for batch_idx in tqdm(\n                range(total_batches), desc=\"Evaluating batches\"\n            ):\n                if batch_idx > 0:\n                    cache.advance()\n                # Reset envs first, then apply deterministic mapping on the active cache batch\n                _ = self.env.reset_all()\n                if hasattr(motion_cmd, \"setup_offline_eval_deterministic\"):\n                    motion_cmd.setup_offline_eval_deterministic(\n                        apply_pending_swap=False\n                    )\n                self._reset_rollout_forward_state()\n\n                # Read current batch metadata AFTER reset + setup\n                current = getattr(cache, \"current_batch\", None)\n                if current is None or not hasattr(current, \"motion_keys\"):\n                    logger.warning(\n                        \"Current cache batch missing motion_keys; skipping batch\"\n                    )\n                    continue\n                motion_keys = list(current.motion_keys)\n                raw_motion_keys = list(\n                    getattr(current, \"raw_motion_keys\", current.motion_keys)\n                )\n\n                # Determine active env count for this batch\n                clip_count = int(cache.clip_count)\n                active_count = min(num_envs, clip_count)\n\n                if active_count > 0:\n                    active_ids = torch.arange(\n                        active_count,\n                        dtype=torch.long,\n                        device=self.device,\n                    )\n                    motion_cmd.force_realign_offline_eval_no_perturb(\n                        active_ids\n                    )\n\n                # Recompute observations after deterministic setup\n                obs_mgr = self.env._env.observation_manager\n                if active_count > 0:\n                    obs_mgr.reset(active_ids)\n                obs_dict = obs_mgr.compute(update_history=True)\n                obs = 
self._wrap_obs_dict(obs_dict)\n\n                # Map env -> motion_key for active envs\n                env_motion_keys = {\n                    int(i): motion_keys[int(i)] for i in range(active_count)\n                }\n                env_raw_motion_keys = {\n                    int(i): raw_motion_keys[int(i)]\n                    for i in range(active_count)\n                }\n\n                # Reference collectors (URDF order)\n                ref_dof_pos = [[] for _ in range(active_count)]\n                ref_dof_vel = [[] for _ in range(active_count)]\n                ref_body_pos = [[] for _ in range(active_count)]\n                ref_body_rot_xyzw = [[] for _ in range(active_count)]\n                ref_body_vel = [[] for _ in range(active_count)]\n                ref_body_ang_vel = [[] for _ in range(active_count)]\n\n                # Robot collectors (URDF order)\n                robot_dof_pos = [[] for _ in range(active_count)]\n                robot_dof_vel = [[] for _ in range(active_count)]\n                robot_body_pos = [[] for _ in range(active_count)]\n                robot_body_rot_xyzw = [[] for _ in range(active_count)]\n                robot_body_vel = [[] for _ in range(active_count)]\n                robot_body_ang_vel = [[] for _ in range(active_count)]\n                robot_dof_acc = [[] for _ in range(active_count)]\n                robot_dof_torque = [[] for _ in range(active_count)]\n                robot_action_rate = [[] for _ in range(active_count)]\n                prev_robot_dof_vel = [None for _ in range(active_count)]\n                prev_robot_actions = [None for _ in range(active_count)]\n                step_dt = float(self.env._env.step_dt)\n\n                # Per-env bookkeeping\n                clip_lengths_np = (\n                    current.lengths.detach().cpu().numpy()\n                    if hasattr(current, \"lengths\")\n                    else np.array(\n                        [getattr(cache, \"max_frame_length\", 1000)]\n                        * active_count\n                    )\n                )\n                # Persist an explicit mapping file for verification\n                try:\n                    mapping_records = []\n                    for i in range(active_count):\n                        mapping_records.append(\n                            {\n                                \"env_id\": int(i),\n                                \"motion_key\": env_motion_keys[int(i)],\n                                \"raw_motion_key\": env_raw_motion_keys[int(i)],\n                                \"clip_length\": int(clip_lengths_np[int(i)]),\n                            }\n                        )\n                    mapping_path = os.path.join(\n                        output_dir, f\"batch_{batch_idx:04d}_mapping.json\"\n                    )\n                    with open(mapping_path, \"w\") as f:\n                        json.dump(mapping_records, f, indent=2)\n                except 
Exception as exc:\n                    # Mapping file is for offline verification only; do not\n                    # abort the rollout, but surface the failure.\n                    logger.warning(\n                        f\"Offline eval: failed to write mapping file: {exc}\"\n                    )\n\n                env_frame_counts = [0 for _ in range(active_count)]\n                encountered_done = [False for _ in range(active_count)]\n                valid_masks = [[] for _ in range(active_count)]\n\n                def _sanitize_key(key: str) -> str:\n                    return (\n                        key.replace(\"/\", \"+\")\n                        .replace(\" \", \"_\")\n                        .replace(\"\\\\\", \"+\")\n                    )\n\n                def _get_out_path(idx: int) -> str:\n                    out_name = f\"{_sanitize_key(env_motion_keys[idx])}.npz\"\n                    return os.path.join(output_dir, out_name)\n\n                def _save_env_npz(idx: int):\n                    if idx >= active_count:\n                        return\n                    # Total collected frames (max_steps is bound just before\n                    # the rollout loop below; this closure runs only at the\n                    # end of the rollout)\n                    total_len = int(min(env_frame_counts[idx], max_steps))\n                    if total_len <= 0:\n                        return\n\n                    # Compute contiguous valid prefix length and slice_len\n                    vm = valid_masks[idx][:total_len]\n                    valid_prefix_len = 0\n                    for b in vm:\n                        if b:\n                            valid_prefix_len += 1\n                        else:\n                            break\n                    clip_len = int(clip_lengths_np[idx])\n                    slice_len = int(min(valid_prefix_len, clip_len, total_len))\n                    if slice_len <= 0:\n                        return\n\n                    # Reference arrays (sliced)\n                    ref_dof_pos_arr = np.stack(\n                        ref_dof_pos[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    ref_dof_vel_arr = np.stack(\n                        ref_dof_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    ref_body_pos_arr = np.stack(\n                        ref_body_pos[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    ref_body_rot_xyzw_arr = np.stack(\n                        ref_body_rot_xyzw[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    ref_body_vel_arr = np.stack(\n                        ref_body_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    ref_body_ang_vel_arr = np.stack(\n                        ref_body_ang_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n\n                    # Robot arrays (sliced)\n                    robot_dof_pos_arr = np.stack(\n                        robot_dof_pos[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_dof_vel_arr = np.stack(\n                        robot_dof_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_dof_acc_arr = np.stack(\n                        robot_dof_acc[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_dof_torque_arr = np.stack(\n                        robot_dof_torque[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_action_rate_arr = np.asarray(\n                        robot_action_rate[idx][:slice_len], dtype=np.float32\n                    )\n                    robot_body_pos_arr = np.stack(\n                        robot_body_pos[idx][:slice_len], axis=0\n               
     ).astype(np.float32)\n                    robot_body_rot_xyzw_arr = np.stack(\n                        robot_body_rot_xyzw[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_body_vel_arr = np.stack(\n                        robot_body_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n                    robot_body_ang_vel_arr = np.stack(\n                        robot_body_ang_vel[idx][:slice_len], axis=0\n                    ).astype(np.float32)\n\n                    # Metadata\n                    motion_fps = int(getattr(motion_cmd.cfg, \"target_fps\", 50))\n                    num_dofs = int(ref_dof_pos_arr.shape[1])\n                    num_bodies = int(ref_body_pos_arr.shape[1])\n                    wallclock_len = (\n                        float(slice_len - 1) / float(motion_fps)\n                        if motion_fps > 0 and slice_len > 0\n                        else 0.0\n                    )\n                    meta = {\n                        \"motion_key\": env_motion_keys[idx],\n                        \"raw_motion_key\": env_raw_motion_keys[idx],\n                        \"motion_fps\": float(motion_fps),\n                        \"num_frames\": int(slice_len),\n                        \"wallclock_len\": float(wallclock_len),\n                        \"num_dofs\": int(num_dofs),\n                        \"num_bodies\": int(num_bodies),\n                        \"clip_length\": int(clip_lengths_np[idx]),\n                        \"valid_prefix_len\": int(valid_prefix_len),\n                    }\n\n                    # Output filename: flattened motion_key\n                    out_path = _get_out_path(idx)\n\n                    np.savez_compressed(\n                        out_path,\n                        metadata=json.dumps(meta),\n                        robot_dof_pos=robot_dof_pos_arr,\n                        robot_dof_vel=robot_dof_vel_arr,\n                        robot_dof_acc=robot_dof_acc_arr,\n                        robot_dof_torque=robot_dof_torque_arr,\n                        robot_action_rate=robot_action_rate_arr,\n                        robot_global_translation=robot_body_pos_arr,\n                        robot_global_rotation_quat=robot_body_rot_xyzw_arr,\n                        robot_global_velocity=robot_body_vel_arr,\n                        robot_global_angular_velocity=robot_body_ang_vel_arr,\n                        ref_dof_pos=ref_dof_pos_arr,\n                        ref_dof_vel=ref_dof_vel_arr,\n                        ref_global_translation=ref_body_pos_arr,\n                        ref_global_rotation_quat=ref_body_rot_xyzw_arr,\n                        ref_global_velocity=ref_body_vel_arr,\n                        ref_global_angular_velocity=ref_body_ang_vel_arr,\n                    )\n\n                max_steps = int(\n                    getattr(cache, \"max_frame_length\", 1000)\n                )  # maximum number of rollout steps to evaluate per batch\n                for rollout_step in tqdm(\n                    range(max_steps), desc=\"Rollout steps\"\n                ):\n                    # PRE-STEP: collect states for all active envs\n                    active = list(range(active_count))\n                    if len(active) > 0:\n                        # Reference step tensors (URDF order)\n                        ref_dp = (\n                            motion_cmd.get_ref_motion_dof_pos_cur_urdf_order()\n                            .detach()\n                        
    .cpu()\n                            .numpy()\n                        )\n                        ref_dv = (\n                            motion_cmd.get_ref_motion_dof_vel_cur_urdf_order()\n                            .detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        ref_bp = (\n                            motion_cmd.get_ref_motion_bodylink_global_pos_cur_urdf_order()\n                            .detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        ref_br = (\n                            motion_cmd.get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order()\n                            .detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        ref_bv = (\n                            motion_cmd.get_ref_motion_bodylink_global_lin_vel_cur_urdf_order()\n                            .detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        ref_bav = (\n                            motion_cmd.get_ref_motion_bodylink_global_ang_vel_cur_urdf_order()\n                            .detach()\n                            .cpu()\n                            .numpy()\n                        )\n\n                        # Robot step tensors (URDF order)\n                        rob_dp = (\n                            motion_cmd.robot_dof_pos_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        rob_dv = (\n                            motion_cmd.robot_dof_vel_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        rob_bp = (\n                            motion_cmd.robot_bodylink_global_pos_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        rob_br = (\n                            motion_cmd.robot_bodylink_global_rot_xyzw_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        rob_bv = (\n                            motion_cmd.robot_bodylink_global_lin_vel_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        rob_bav = (\n                            motion_cmd.robot_bodylink_global_ang_vel_cur_urdf_order.detach()\n                            .cpu()\n                            .numpy()\n                        )\n                        for idx in active:\n                            if prev_robot_dof_vel[idx] is None:\n                                dof_acc_cur = np.zeros_like(\n                                    rob_dv[idx], dtype=np.float32\n                                )\n                            else:\n                                dof_acc_cur = (\n                                    rob_dv[idx] - prev_robot_dof_vel[idx]\n                                ) / step_dt\n                            prev_robot_dof_vel[idx] = rob_dv[idx].copy()\n\n                            ref_dof_pos[idx].append(ref_dp[idx])\n                            ref_dof_vel[idx].append(ref_dv[idx])\n                            
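# Body-link states below are world-frame quantities in URDF body order.\n                            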
ref_body_pos[idx].append(ref_bp[idx])\n                            ref_body_rot_xyzw[idx].append(ref_br[idx])\n                            ref_body_vel[idx].append(ref_bv[idx])\n                            ref_body_ang_vel[idx].append(ref_bav[idx])\n\n                            robot_dof_pos[idx].append(rob_dp[idx])\n                            robot_dof_vel[idx].append(rob_dv[idx])\n                            robot_dof_acc[idx].append(\n                                dof_acc_cur.astype(np.float32)\n                            )\n                            robot_body_pos[idx].append(rob_bp[idx])\n                            robot_body_rot_xyzw[idx].append(rob_br[idx])\n                            robot_body_vel[idx].append(rob_bv[idx])\n                            robot_body_ang_vel[idx].append(rob_bav[idx])\n\n                            # Record valid mask for current frame (before step)\n                            clip_limit = int(clip_lengths_np[idx])\n                            valid_now = (\n                                (idx < active_count)\n                                and (not encountered_done[idx])\n                                and (\n                                    env_frame_counts[idx]\n                                    < clip_limit - n_fut_frames\n                                )\n                            )\n                            valid_masks[idx].append(bool(valid_now))\n\n                            # Increment local frame counter\n                            env_frame_counts[idx] += 1\n\n                    # No mid-rollout finalize; we defer to end using valid masks\n                    # Inference and step (advance sim)\n                    obs = self._rollout_forward(\n                        obs,\n                        actor_mode=\"inference\",\n                        collect_transition=False,\n                        track_episode_stats=False,\n                    )\n                    dones = self._last_rollout_dones\n                    if dones is None:\n                        raise RuntimeError(\n                            \"Rollout forward did not return dones during offline evaluation.\"\n                        )\n                    actions_step = self._last_rollout_actions\n                    if actions_step is None:\n                        raise RuntimeError(\n                            \"Rollout forward did not return actions during offline evaluation.\"\n                        )\n\n                    actions_np = actions_step.detach().cpu().numpy()\n                    torque_urdf = (\n                        motion_cmd.robot.data.applied_torque[\n                            ..., motion_cmd.sim2urdf_dof_idx\n                        ]\n                        .detach()\n                        .cpu()\n                        .numpy()\n                    )\n                    for idx in range(active_count):\n                        if prev_robot_actions[idx] is None:\n                            action_rate_cur = 0.0\n                        else:\n                            action_rate_cur = float(\n                                np.linalg.norm(\n                                    actions_np[idx] - prev_robot_actions[idx]\n                                )\n                                / step_dt\n                            )\n                        prev_robot_actions[idx] = actions_np[idx].copy()\n                        robot_action_rate[idx].append(\n                            np.float32(action_rate_cur)\n 
                       )\n                        robot_dof_torque[idx].append(\n                            torque_urdf[idx].astype(np.float32)\n                        )\n\n                    # Handle RL dones (first-done policy): mark done for future frames\n                    step_dones = (\n                        dones.bool().reshape(-1).detach().cpu().numpy()\n                    )\n                    for idx in range(min(active_count, len(step_dones))):\n                        if step_dones[idx] and not encountered_done[idx]:\n                            encountered_done[idx] = True\n\n                    if rollout_step == max_steps - 1:\n                        # End of rollout: save once per env with full rollout arrays + valid_mask\n                        if dump_npzs and active_count > 0:\n                            out_path_to_last_idx = {}\n                            for idx in range(active_count):\n                                out_path_to_last_idx[_get_out_path(idx)] = idx\n                            save_indices = list(out_path_to_last_idx.values())\n                            max_npz_save_workers = max(\n                                1, min(16, len(save_indices))\n                            )\n                            with ThreadPoolExecutor(\n                                max_workers=max_npz_save_workers\n                            ) as executor:\n                                futures = [\n                                    executor.submit(_save_env_npz, idx)\n                                    for idx in save_indices\n                                ]\n                                for future in tqdm(\n                                    as_completed(futures),\n                                    total=len(futures),\n                                    desc=\"Saving NPZs\",\n                                ):\n                                    future.result()\n                        break\n\n        logger.info(\n            f\"Offline evaluation complete: saved clips to {output_dir}\"\n        )\n        return {\"output_dir\": output_dir}\n"
  },
  {
    "path": "holomotion/src/algo/ppo_tf.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nfrom typing import Generator\n\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom holomotion.src.algo.algo_utils import PpoAuxTransition\nfrom holomotion.src.algo.ppo import PPO\nfrom holomotion.src.modules.agent_modules import (\n    PPOCondTFActor,\n    PPOCritic,\n    PPOTFActor,\n    PPOTFRefRouterActor,\n    PPOTFRefRouterSeqActor,\n    PPOTFRefRouterV3Actor,\n    TensorDictAssembler,\n)\nfrom holomotion.src.modules.network_modules import GroupedMoEBlock\nfrom loguru import logger\nfrom omegaconf import OmegaConf\nfrom tabulate import tabulate\nfrom tensordict import TensorDict\n\n\nclass PPOTF(PPO):\n    \"\"\"Transformer-policy PPO with TensorDict rollout and sequence update.\"\"\"\n\n    @staticmethod\n    def _select_actor_wrapper_cls(actor_cfg: dict):\n        actor_type = str(actor_cfg.get(\"type\", \"\"))\n        use_future_cross_attn = bool(\n            actor_cfg.get(\"use_future_cross_attn\", False)\n        )\n        if actor_type == \"ReferenceRoutedGroupedMoETransformerPolicy\":\n            if use_future_cross_attn:\n                raise ValueError(\n                    \"ReferenceRoutedGroupedMoETransformerPolicy does not \"\n                    \"support use_future_cross_attn=True.\"\n                )\n            return PPOTFRefRouterActor\n        if actor_type == \"ReferenceRoutedGroupedMoETransformerPolicyV2\":\n            if use_future_cross_attn:\n                raise ValueError(\n                    \"ReferenceRoutedGroupedMoETransformerPolicyV2 does not \"\n                    \"support use_future_cross_attn=True.\"\n                )\n            return PPOTFRefRouterSeqActor\n        if actor_type == \"ReferenceRoutedGroupedMoETransformerPolicyV3\":\n            if use_future_cross_attn:\n                raise ValueError(\n                    \"ReferenceRoutedGroupedMoETransformerPolicyV3 does not \"\n                    \"support use_future_cross_attn=True.\"\n                )\n            return PPOTFRefRouterV3Actor\n        if use_future_cross_attn:\n            return PPOCondTFActor\n        return PPOTFActor\n\n    @staticmethod\n    def _summarize_moe_layer_stats(moe_layers) -> dict[str, float | None]:\n        if len(moe_layers) == 0:\n            return {\n                \"moe_active_expert_ratio\": None,\n                \"moe_max_expert_frac\": None,\n                \"moe_least_expert_frac\": None,\n                \"moe_dead_expert_ratio\": None,\n                \"moe_expert_count_cv\": None,\n                \"moe_selected_expert_margin_to_unselected\": None,\n            }\n\n        def _mean_attr(attr_name: str) -> float:\n            values = torch.stack(\n                [\n                    getattr(layer, attr_name).to(torch.float32)\n                    for layer in moe_layers\n                ]\n            )\n            return 
float(values.mean().item())\n\n        return {\n            \"moe_active_expert_ratio\": _mean_attr(\"last_active_expert_ratio\"),\n            \"moe_max_expert_frac\": _mean_attr(\"last_max_expert_frac\"),\n            \"moe_least_expert_frac\": _mean_attr(\"last_min_expert_frac\"),\n            \"moe_dead_expert_ratio\": _mean_attr(\"last_dead_expert_ratio\"),\n            \"moe_expert_count_cv\": _mean_attr(\"last_expert_count_cv\"),\n            \"moe_selected_expert_margin_to_unselected\": _mean_attr(\n                \"last_selected_expert_margin_to_unselected\"\n            ),\n        }\n\n    def _setup_configs(self):\n        super()._setup_configs()\n        aux_cfg = self.config.get(\"aux_state_pred\", {})\n        self.use_aux_state_pred: bool = bool(aux_cfg.get(\"enabled\", False))\n        self.aux_state_pred_w_base_lin_vel = float(\n            aux_cfg.get(\"w_base_lin_vel\", 0.0)\n        )\n        self.aux_state_pred_w_root_height = float(\n            aux_cfg.get(\"w_root_height\", 0.0)\n        )\n        self.aux_state_pred_w_keybody_contact = float(\n            aux_cfg.get(\"w_keybody_contact\", 0.0)\n        )\n        self.aux_state_pred_w_ref_keybody_rel_pos = float(\n            aux_cfg.get(\"w_ref_keybody_rel_pos\", 0.0)\n        )\n        self.aux_state_pred_w_robot_keybody_rel_pos = float(\n            aux_cfg.get(\"w_robot_keybody_rel_pos\", 0.0)\n        )\n        self.aux_state_pred_w_denoise_ref_root_lin_vel = float(\n            aux_cfg.get(\"w_denoise_ref_root_lin_vel\", 0.0)\n        )\n        self.aux_state_pred_w_denoise_ref_root_ang_vel = float(\n            aux_cfg.get(\"w_denoise_ref_root_ang_vel\", 0.0)\n        )\n        self.aux_state_pred_w_denoise_ref_dof_pos = float(\n            aux_cfg.get(\"w_denoise_ref_dof_pos\", 0.0)\n        )\n        self.aux_state_pred_keybody_contact_names = [\n            str(name) for name in aux_cfg.get(\"keybody_contact_names\", [])\n        ]\n        self.aux_state_pred_keybody_rel_pos_names = [\n            str(name) for name in aux_cfg.get(\"keybody_rel_pos_names\", [])\n        ]\n        self.aux_state_pred_num_contact_bodies = int(\n            len(self.aux_state_pred_keybody_contact_names)\n        )\n        self.aux_state_pred_num_keybody_bodies = int(\n            len(self.aux_state_pred_keybody_rel_pos_names)\n        )\n        self.use_aux_root_height = bool(\n            self.use_aux_state_pred and self.aux_state_pred_w_root_height > 0.0\n        )\n        self.use_aux_denoise_ref_root_lin_vel = bool(\n            self.use_aux_state_pred\n            and self.aux_state_pred_w_denoise_ref_root_lin_vel > 0.0\n        )\n        self.use_aux_denoise_ref_root_ang_vel = bool(\n            self.use_aux_state_pred\n            and self.aux_state_pred_w_denoise_ref_root_ang_vel > 0.0\n        )\n        self.use_aux_denoise_ref_dof_pos = bool(\n            self.use_aux_state_pred\n            and self.aux_state_pred_w_denoise_ref_dof_pos > 0.0\n        )\n        self.aux_state_pred_min_std = float(aux_cfg.get(\"min_std\", 1.0e-3))\n        self.aux_state_pred_max_std = float(aux_cfg.get(\"max_std\", 5.0))\n        self.aux_denoise_residual_huber_beta = float(\n            aux_cfg.get(\"denoise_residual_huber_beta\", 0.1)\n        )\n        self.aux_state_pred_raycast_z_offset = float(\n            aux_cfg.get(\"raycast_z_offset\", 1.0)\n        )\n        self.aux_state_pred_raycast_max_dist = float(\n            aux_cfg.get(\"raycast_max_dist\", 20.0)\n        )\n        if 
self.aux_state_pred_min_std <= 0.0:\n            raise ValueError(\"aux_state_pred.min_std must be > 0.\")\n        if self.aux_state_pred_max_std <= self.aux_state_pred_min_std:\n            raise ValueError(\n                \"aux_state_pred.max_std must be > aux_state_pred.min_std.\"\n            )\n        if self.aux_denoise_residual_huber_beta <= 0.0:\n            raise ValueError(\n                \"aux_state_pred.denoise_residual_huber_beta must be > 0.\"\n            )\n        if self.aux_state_pred_w_base_lin_vel < 0.0:\n            raise ValueError(\"aux_state_pred.w_base_lin_vel must be >= 0.\")\n        if self.aux_state_pred_w_root_height < 0.0:\n            raise ValueError(\"aux_state_pred.w_root_height must be >= 0.\")\n        if self.aux_state_pred_w_keybody_contact < 0.0:\n            raise ValueError(\"aux_state_pred.w_keybody_contact must be >= 0.\")\n        if self.aux_state_pred_w_ref_keybody_rel_pos < 0.0:\n            raise ValueError(\n                \"aux_state_pred.w_ref_keybody_rel_pos must be >= 0.\"\n            )\n        if self.aux_state_pred_w_robot_keybody_rel_pos < 0.0:\n            raise ValueError(\n                \"aux_state_pred.w_robot_keybody_rel_pos must be >= 0.\"\n            )\n        if self.aux_state_pred_w_denoise_ref_root_lin_vel < 0.0:\n            raise ValueError(\n                \"aux_state_pred.w_denoise_ref_root_lin_vel must be >= 0.\"\n            )\n        if self.aux_state_pred_w_denoise_ref_root_ang_vel < 0.0:\n            raise ValueError(\n                \"aux_state_pred.w_denoise_ref_root_ang_vel must be >= 0.\"\n            )\n        if self.aux_state_pred_w_denoise_ref_dof_pos < 0.0:\n            raise ValueError(\n                \"aux_state_pred.w_denoise_ref_dof_pos must be >= 0.\"\n            )\n        if self.use_aux_root_height:\n            if self.aux_state_pred_raycast_max_dist <= 0.0:\n                raise ValueError(\n                    \"aux_state_pred.raycast_max_dist must be > 0.\"\n                )\n            if self.aux_state_pred_raycast_z_offset < 0.0:\n                raise ValueError(\n                    \"aux_state_pred.raycast_z_offset must be >= 0.\"\n                )\n        if (\n            self.aux_state_pred_w_keybody_contact > 0.0\n            and self.aux_state_pred_num_contact_bodies == 0\n        ):\n            raise ValueError(\n                \"aux_state_pred.w_keybody_contact > 0 requires \"\n                \"aux_state_pred.keybody_contact_names to be non-empty.\"\n            )\n        if (\n            self.aux_state_pred_w_ref_keybody_rel_pos > 0.0\n            or self.aux_state_pred_w_robot_keybody_rel_pos > 0.0\n        ) and self.aux_state_pred_num_keybody_bodies == 0:\n            raise ValueError(\n                \"aux_state_pred keybody position weights > 0 require \"\n                \"aux_state_pred.keybody_rel_pos_names to be non-empty.\"\n            )\n        if self.use_aux_state_pred and self.command_name != \"ref_motion\":\n            raise ValueError(\n                \"aux_state_pred is only supported for PPOTF motion tracking \"\n                \"(command_name='ref_motion').\"\n            )\n        PpoAuxTransition.SHAPE_TOKENS[\"C\"] = (\n            self.aux_state_pred_num_contact_bodies\n        )\n        PpoAuxTransition.SHAPE_TOKENS[\"K\"] = (\n            self.aux_state_pred_num_keybody_bodies\n        )\n        aux_cmd_cfg = self.config.get(\"aux_router_command_recon\", {})\n        self.use_aux_router_command_recon: bool = bool(\n  
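          # Optional auxiliary head: reconstruct the actor command observation\n            # terms selected by term_prefix (default actor_ref_); the target\n            # schema is assembled in _build_aux_router_command_recon_schema.\n  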
          aux_cmd_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_command_recon_weight = float(\n            aux_cmd_cfg.get(\"weight\", 0.0)\n        )\n        self.aux_router_command_recon_hidden_dim = int(\n            aux_cmd_cfg.get(\"hidden_dim\", 0)\n        )\n        self.aux_router_command_recon_term_prefix = str(\n            aux_cmd_cfg.get(\"term_prefix\", \"actor_ref_\")\n        )\n        aux_switch_cfg = self.config.get(\"aux_router_switch_penalty\", {})\n        self.use_aux_router_switch_penalty = bool(\n            aux_switch_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_switch_penalty_weight = float(\n            aux_switch_cfg.get(\"weight\", 0.0)\n        )\n        self.aux_router_switch_penalty_metric = str(\n            aux_switch_cfg.get(\"metric\", \"js\")\n        ).lower()\n        self.aux_router_switch_penalty_beta = float(\n            aux_switch_cfg.get(\"beta\", 1.0)\n        )\n        aux_router_future_cfg = self.config.get(\"aux_router_future_recon\", {})\n        self.use_aux_router_future_recon = bool(\n            aux_router_future_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_future_recon_weight = float(\n            aux_router_future_cfg.get(\"weight\", 0.0)\n        )\n        self.aux_router_future_recon_hidden_dim = int(\n            aux_router_future_cfg.get(\"hidden_dim\", 0)\n        )\n        self.aux_router_future_recon_huber_beta = float(\n            aux_router_future_cfg.get(\"huber_beta\", 1.0)\n        )\n        dead_margin_cfg = self.config.get(\"dead_expert_margin_to_topk\", {})\n        self.use_dead_expert_margin_to_topk = bool(\n            dead_margin_cfg.get(\"enabled\", False)\n        )\n        self.dead_expert_margin_to_topk_weight = float(\n            dead_margin_cfg.get(\"weight\", 0.0)\n        )\n        orth_cfg = self.config.get(\"router_expert_orthogonal\", {})\n        self.use_router_expert_orthogonal = bool(\n            orth_cfg.get(\"enabled\", False)\n        )\n        self.router_expert_orthogonal_weight = float(\n            orth_cfg.get(\"weight\", 0.0)\n        )\n        self.router_expert_orthogonal_min_active_usage = float(\n            orth_cfg.get(\"min_active_usage\", 1.0e-3)\n        )\n        self.router_expert_orthogonal_eps = float(orth_cfg.get(\"eps\", 1.0e-8))\n        selected_margin_cfg = self.config.get(\n            \"selected_expert_margin_to_unselected\", {}\n        )\n        self.use_selected_expert_margin_to_unselected = bool(\n            selected_margin_cfg.get(\"enabled\", False)\n        )\n        self.selected_expert_margin_to_unselected_weight = float(\n            selected_margin_cfg.get(\"weight\", 0.0)\n        )\n        self.selected_expert_margin_to_unselected_target = float(\n            selected_margin_cfg.get(\"target\", 0.0)\n        )\n        if self.aux_router_switch_penalty_metric not in {\n            \"js\",\n            \"normed_smooth_l1\",\n        }:\n            raise ValueError(\n                \"aux_router_switch_penalty.metric must be one of \"\n                \"{'js', 'normed_smooth_l1'}, got \"\n                f\"{self.aux_router_switch_penalty_metric!r}.\"\n            )\n        if self.aux_router_command_recon_weight < 0.0:\n            raise ValueError(\"aux_router_command_recon.weight must be >= 0.\")\n        if self.aux_router_future_recon_weight < 0.0:\n            raise ValueError(\"aux_router_future_recon.weight must be >= 0.\")\n        if self.aux_router_switch_penalty_weight < 
0.0:\n            raise ValueError(\"aux_router_switch_penalty.weight must be >= 0.\")\n        if self.dead_expert_margin_to_topk_weight < 0.0:\n            raise ValueError(\"dead_expert_margin_to_topk.weight must be >= 0.\")\n        if self.router_expert_orthogonal_weight < 0.0:\n            raise ValueError(\"router_expert_orthogonal.weight must be >= 0.\")\n        if self.router_expert_orthogonal_min_active_usage < 0.0:\n            raise ValueError(\n                \"router_expert_orthogonal.min_active_usage must be >= 0.\"\n            )\n        if self.selected_expert_margin_to_unselected_weight < 0.0:\n            raise ValueError(\n                \"selected_expert_margin_to_unselected.weight must be >= 0.\"\n            )\n        if self.router_expert_orthogonal_eps <= 0.0:\n            raise ValueError(\"router_expert_orthogonal.eps must be > 0.\")\n        if self.aux_router_switch_penalty_beta <= 0.0:\n            raise ValueError(\"aux_router_switch_penalty.beta must be > 0.\")\n        if self.aux_router_future_recon_huber_beta <= 0.0:\n            raise ValueError(\"aux_router_future_recon.huber_beta must be > 0.\")\n        if self.selected_expert_margin_to_unselected_target < 0.0:\n            raise ValueError(\n                \"selected_expert_margin_to_unselected.target must be >= 0.\"\n            )\n        if (\n            self.use_dead_expert_margin_to_topk\n            and self.dead_expert_margin_to_topk_weight == 0.0\n        ):\n            logger.warning(\n                \"dead_expert_margin_to_topk.enabled=True but weight=0.0; \"\n                \"dead-expert margin loss will have no effect.\"\n            )\n        if (\n            self.use_router_expert_orthogonal\n            and not self.use_dead_expert_margin_to_topk\n        ):\n            raise ValueError(\n                \"router_expert_orthogonal.enabled=True requires \"\n                \"dead_expert_margin_to_topk.enabled=True in sparse top-k MoE.\"\n            )\n        if (\n            self.use_router_expert_orthogonal\n            and self.router_expert_orthogonal_weight == 0.0\n        ):\n            logger.warning(\n                \"router_expert_orthogonal.enabled=True but weight=0.0; \"\n                \"orthogonal regularization will have no effect.\"\n            )\n        if (\n            self.use_selected_expert_margin_to_unselected\n            and self.selected_expert_margin_to_unselected_weight == 0.0\n        ):\n            logger.warning(\n                \"selected_expert_margin_to_unselected.enabled=True but \"\n                \"weight=0.0; selected-expert margin loss will have no effect.\"\n            )\n        if (\n            self.use_aux_router_switch_penalty\n            and self.aux_router_switch_penalty_weight == 0.0\n        ):\n            logger.warning(\n                \"aux_router_switch_penalty.enabled=True but weight=0.0; \"\n                \"router switch penalty will have no effect.\"\n            )\n        if (\n            self.use_aux_router_future_recon\n            and self.aux_router_future_recon_weight == 0.0\n        ):\n            logger.warning(\n                \"aux_router_future_recon.enabled=True but weight=0.0; \"\n                \"future reconstruction loss will have no effect.\"\n            )\n        if (\n            self.use_aux_router_command_recon\n            or self.use_aux_router_switch_penalty\n            or self.use_aux_router_future_recon\n        ) and self.command_name != \"ref_motion\":\n            
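# These router auxiliary objectives read reference-motion observation\n            # terms, so they are only meaningful under the ref_motion command.\n            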
raise ValueError(\n                \"aux_router_command_recon, aux_router_future_recon, and \"\n                \"aux_router_switch_penalty are \"\n                \"only supported for PPOTF motion tracking \"\n                \"(command_name='ref_motion').\"\n            )\n        self.aux_command_router_num_moe_layers = 0\n        self.aux_command_router_num_fine_experts = 0\n        self.aux_router_command_recon_assembler: TensorDictAssembler | None = (\n            None\n        )\n        actor_cfg = self.config.get(\"module_dict\", {}).get(\"actor\", {})\n        actor_type = str(actor_cfg.get(\"type\", \"\"))\n        if actor_type in {\n            \"ReferenceRoutedGroupedMoETransformerPolicyV2\",\n            \"ReferenceRoutedGroupedMoETransformerPolicyV3\",\n        }:\n            if self.use_aux_router_command_recon:\n                raise ValueError(\n                    f\"{actor_type} does not support aux_router_command_recon.\"\n                )\n            unsupported_aux_weights = {\n                \"w_root_height\": self.aux_state_pred_w_root_height,\n                \"w_denoise_ref_root_lin_vel\": self.aux_state_pred_w_denoise_ref_root_lin_vel,\n                \"w_denoise_ref_root_ang_vel\": self.aux_state_pred_w_denoise_ref_root_ang_vel,\n                \"w_denoise_ref_dof_pos\": self.aux_state_pred_w_denoise_ref_dof_pos,\n            }\n            enabled_unsupported = [\n                name\n                for name, value in unsupported_aux_weights.items()\n                if float(value) > 0.0\n            ]\n            if enabled_unsupported:\n                raise ValueError(\n                    f\"{actor_type} only supports \"\n                    \"aux_state_pred weights for base_lin_vel, keybody_contact, \"\n                    \"ref_keybody_rel_pos, and robot_keybody_rel_pos. 
Unsupported \"\n                    \"weights: \" + \", \".join(enabled_unsupported)\n                )\n        elif self.use_aux_router_future_recon:\n            raise ValueError(\n                \"aux_router_future_recon requires \"\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 or V3.\"\n            )\n\n    @staticmethod\n    def _unwrap_obs_schema(schema: dict | None) -> dict | None:\n        if schema is None:\n            return None\n        has_terms = any(\n            isinstance(v, dict) and (\"terms\" in v) for v in schema.values()\n        )\n        if has_terms:\n            return schema\n        if len(schema) == 1:\n            only_value = next(iter(schema.values()))\n            if isinstance(only_value, dict):\n                return only_value\n        return schema\n\n    @staticmethod\n    def _schema_term_leaf_name(term: str) -> str:\n        return str(term).split(\"/\")[-1]\n\n    @classmethod\n    def _is_aux_command_term(cls, term: str, term_prefix: str) -> bool:\n        return cls._schema_term_leaf_name(term).startswith(term_prefix)\n\n    @classmethod\n    def _build_aux_router_command_recon_schema(\n        cls, actor_schema: dict, term_prefix: str\n    ) -> dict:\n        command_schema = {}\n        for group_name, seq_cfg in actor_schema.items():\n            terms = [\n                str(term)\n                for term in seq_cfg.get(\"terms\", [])\n                if cls._is_aux_command_term(str(term), term_prefix)\n            ]\n            if len(terms) == 0:\n                continue\n            next_seq_cfg = dict(seq_cfg)\n            next_seq_cfg[\"terms\"] = terms\n            command_schema[group_name] = next_seq_cfg\n        if len(command_schema) == 0:\n            raise ValueError(\n                \"aux_router_command_recon could not find any actor command terms in \"\n                f\"obs_schema with prefix '{term_prefix}'.\"\n            )\n        return command_schema\n\n    @staticmethod\n    def _masked_aux_keybody_mse(\n        pred: torch.Tensor,\n        target: torch.Tensor,\n        valid_tok: torch.Tensor,\n    ) -> torch.Tensor:\n        if pred.shape != target.shape:\n            raise ValueError(\n                \"pred and target must have the same shape for keybody MSE, \"\n                f\"got {tuple(pred.shape)} and {tuple(target.shape)}.\"\n            )\n        if pred.ndim != 4:\n            raise ValueError(\n                \"Keybody MSE expects [B, T, K, 3] tensors, \"\n                f\"got pred with shape {tuple(pred.shape)}.\"\n            )\n        per_token_mse = torch.square(pred - target).mean(dim=(-1, -2))\n        valid_tok = valid_tok.to(per_token_mse.dtype)\n        if valid_tok.shape != per_token_mse.shape:\n            raise ValueError(\n                \"valid_tok must match per-token keybody MSE shape, \"\n                f\"got {tuple(valid_tok.shape)} and \"\n                f\"{tuple(per_token_mse.shape)}.\"\n            )\n        valid_count = valid_tok.sum().clamp_min(1.0)\n        return (per_token_mse * valid_tok).sum() / valid_count\n\n    @staticmethod\n    def _masked_aux_mse(\n        pred: torch.Tensor,\n        target: torch.Tensor,\n        valid_tok: torch.Tensor,\n    ) -> torch.Tensor:\n        if pred.shape != target.shape:\n            raise ValueError(\n                \"pred and target must share the same shape for auxiliary MSE, \"\n                f\"got {tuple(pred.shape)} and {tuple(target.shape)}.\"\n            )\n        if pred.ndim < 3:\n    
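        # Token masking needs a [B, T, ...] layout with at least one\n            # feature dim, so reject lower-rank inputs with a clear error.\n    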
        raise ValueError(\n                \"Auxiliary MSE expects tensors with shape [B, T, ...], \"\n                f\"got {tuple(pred.shape)}.\"\n            )\n        reduce_dims = tuple(range(2, pred.ndim))\n        per_token_mse = torch.square(pred - target).mean(dim=reduce_dims)\n        valid_tok = valid_tok.to(per_token_mse.dtype)\n        if valid_tok.shape != per_token_mse.shape:\n            raise ValueError(\n                \"valid_tok must match per-token auxiliary MSE shape, got \"\n                f\"{tuple(valid_tok.shape)} and {tuple(per_token_mse.shape)}.\"\n            )\n        valid_count = valid_tok.sum().clamp_min(1.0)\n        return (per_token_mse * valid_tok).sum() / valid_count\n\n    @staticmethod\n    def _masked_adjacent_router_js(\n        *,\n        router_features: torch.Tensor,\n        valid_tok: torch.Tensor,\n        num_moe_layers: int,\n        num_fine_experts: int,\n    ) -> torch.Tensor:\n        if router_features.ndim != 3:\n            raise ValueError(\n                \"router_features must have shape [B, T, L*E], got \"\n                f\"{tuple(router_features.shape)}.\"\n            )\n        if valid_tok.ndim != 2:\n            raise ValueError(\n                \"valid_tok must have shape [B, T], got \"\n                f\"{tuple(valid_tok.shape)}.\"\n            )\n        if num_moe_layers <= 0 or num_fine_experts <= 0:\n            raise ValueError(\n                \"num_moe_layers and num_fine_experts must be positive, got \"\n                f\"{num_moe_layers} and {num_fine_experts}.\"\n            )\n        bsz, seq_len, feat_dim = router_features.shape\n        expected_dim = num_moe_layers * num_fine_experts\n        if feat_dim != expected_dim:\n            raise ValueError(\n                \"router_features last dim must equal num_moe_layers * \"\n                \"num_fine_experts, got \"\n                f\"{feat_dim} vs {expected_dim}.\"\n            )\n        if valid_tok.shape != (bsz, seq_len):\n            raise ValueError(\n                \"valid_tok shape mismatch for router temporal loss: expected \"\n                f\"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}.\"\n            )\n        if seq_len <= 1:\n            return router_features.new_zeros(())\n\n        router_p = router_features.reshape(\n            bsz, seq_len, num_moe_layers, num_fine_experts\n        ).to(torch.float32)\n        prev_p = router_p[:, :-1]\n        curr_p = router_p[:, 1:]\n        mix_p = 0.5 * (prev_p + curr_p)\n        eps = 1.0e-20\n        prev_safe = prev_p.clamp_min(eps)\n        curr_safe = curr_p.clamp_min(eps)\n        mix_safe = mix_p.clamp_min(eps)\n        kl_prev = (prev_p * (torch.log(prev_safe) - torch.log(mix_safe))).sum(\n            dim=-1\n        )\n        kl_curr = (curr_p * (torch.log(curr_safe) - torch.log(mix_safe))).sum(\n            dim=-1\n        )\n        js = 0.5 * (kl_prev + kl_curr)\n        adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to(js.dtype)\n        valid_count = adjacent_valid.sum().clamp_min(1.0) * float(\n            num_moe_layers\n        )\n        return (js * adjacent_valid.unsqueeze(-1)).sum() / valid_count\n\n    @staticmethod\n    def _masked_adjacent_router_normed_smooth_l1(\n        *,\n        router_temporal_features: torch.Tensor,\n        valid_tok: torch.Tensor,\n        num_moe_layers: int,\n        num_fine_experts: int,\n        beta: float = 1.0,\n    ) -> torch.Tensor:\n        if router_temporal_features.ndim != 3:\n            raise 
ValueError(\n                \"router_temporal_features must have shape [B, T, L*E], got \"\n                f\"{tuple(router_temporal_features.shape)}.\"\n            )\n        if valid_tok.ndim != 2:\n            raise ValueError(\n                \"valid_tok must have shape [B, T], got \"\n                f\"{tuple(valid_tok.shape)}.\"\n            )\n        if num_moe_layers <= 0 or num_fine_experts <= 0:\n            raise ValueError(\n                \"num_moe_layers and num_fine_experts must be positive, got \"\n                f\"{num_moe_layers} and {num_fine_experts}.\"\n            )\n        if beta <= 0.0:\n            raise ValueError(\n                f\"beta must be positive for SmoothL1, got {beta}.\"\n            )\n        bsz, seq_len, feat_dim = router_temporal_features.shape\n        expected_dim = num_moe_layers * num_fine_experts\n        if feat_dim != expected_dim:\n            raise ValueError(\n                \"router_temporal_features last dim must equal \"\n                \"num_moe_layers * num_fine_experts, got \"\n                f\"{feat_dim} vs {expected_dim}.\"\n            )\n        if valid_tok.shape != (bsz, seq_len):\n            raise ValueError(\n                \"valid_tok shape mismatch for router temporal loss: expected \"\n                f\"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}.\"\n            )\n        if seq_len <= 1:\n            return router_temporal_features.new_zeros(())\n\n        router_logits = router_temporal_features.reshape(\n            bsz, seq_len, num_moe_layers, num_fine_experts\n        ).to(torch.float32)\n        router_logits = router_logits - router_logits.mean(\n            dim=-1, keepdim=True\n        )\n        router_logits = F.normalize(router_logits, p=2.0, dim=-1, eps=1.0e-5)\n        prev_logits = router_logits[:, :-1]\n        curr_logits = router_logits[:, 1:]\n        smooth_l1 = F.smooth_l1_loss(\n            curr_logits,\n            prev_logits,\n            reduction=\"none\",\n            beta=beta,\n        ).mean(dim=(-1, -2))\n        adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to(\n            smooth_l1.dtype\n        )\n        valid_count = adjacent_valid.sum().clamp_min(1.0)\n        return (smooth_l1 * adjacent_valid).sum() / valid_count\n\n    @staticmethod\n    def _masked_aux_gaussian_nll(\n        *,\n        loc: torch.Tensor,\n        log_std: torch.Tensor,\n        target: torch.Tensor,\n        valid_tok: torch.Tensor,\n        min_std: float,\n        max_std: float,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if loc.shape != log_std.shape or loc.shape != target.shape:\n            raise ValueError(\n                \"loc, log_std, and target must share the same shape for \"\n                \"Gaussian aux loss, got \"\n                f\"{tuple(loc.shape)}, {tuple(log_std.shape)}, \"\n                f\"{tuple(target.shape)}.\"\n            )\n        if loc.ndim < 3:\n            raise ValueError(\n                \"Gaussian aux loss expects tensors with shape [B, T, ...], \"\n                f\"got {tuple(loc.shape)}.\"\n            )\n        per_elem_std = torch.clamp(\n            torch.exp(log_std),\n            min=float(min_std),\n            max=float(max_std),\n        )\n        reduce_dims = tuple(range(2, loc.ndim))\n        per_token_nll = 0.5 * (\n            torch.square((target - loc) / per_elem_std)\n            + 2.0 * torch.log(per_elem_std + 1.0e-8)\n        ).sum(dim=reduce_dims)\n        valid_tok = 
valid_tok.to(per_token_nll.dtype)\n        if valid_tok.shape != per_token_nll.shape:\n            raise ValueError(\n                \"valid_tok must match per-token Gaussian loss shape, got \"\n                f\"{tuple(valid_tok.shape)} and {tuple(per_token_nll.shape)}.\"\n            )\n        valid_count = valid_tok.sum().clamp_min(1.0)\n        loss = (per_token_nll * valid_tok).sum() / valid_count\n        per_token_std = per_elem_std.reshape(\n            per_elem_std.shape[0], per_elem_std.shape[1], -1\n        ).mean(dim=-1)\n        mean_std = (per_token_std * valid_tok).sum() / valid_count\n        return loss, mean_std\n\n    @staticmethod\n    def _masked_aux_huber(\n        *,\n        pred: torch.Tensor,\n        target: torch.Tensor,\n        valid_tok: torch.Tensor,\n        beta: float,\n    ) -> torch.Tensor:\n        if pred.shape != target.shape:\n            raise ValueError(\n                \"pred and target must share the same shape for Huber aux loss, \"\n                f\"got {tuple(pred.shape)} and {tuple(target.shape)}.\"\n            )\n        if pred.ndim < 3:\n            raise ValueError(\n                \"Huber aux loss expects tensors with shape [B, T, ...], \"\n                f\"got {tuple(pred.shape)}.\"\n            )\n        per_elem = F.smooth_l1_loss(pred, target, reduction=\"none\", beta=beta)\n        reduce_dims = tuple(range(2, pred.ndim))\n        per_token = per_elem.mean(dim=reduce_dims)\n        valid_tok = valid_tok.to(per_token.dtype)\n        if valid_tok.shape != per_token.shape:\n            raise ValueError(\n                \"valid_tok must match per-token Huber loss shape, got \"\n                f\"{tuple(valid_tok.shape)} and {tuple(per_token.shape)}.\"\n            )\n        valid_count = valid_tok.sum().clamp_min(1.0)\n        return (per_token * valid_tok).sum() / valid_count\n\n    def _compute_aux_router_future_recon_loss(\n        self,\n        *,\n        actor_wrapper: PPOTFActor,\n        actor_out: TensorDict,\n        obs_b: TensorDict,\n        valid_tok: torch.Tensor,\n    ) -> torch.Tensor:\n        future_assembler = actor_wrapper.aux_router_future_recon_assembler\n        if future_assembler is None:\n            raise ValueError(\n                \"aux_router_future_recon is enabled but future assembler was \"\n                \"not initialized on the actor wrapper.\"\n            )\n        aux_router_future_recon_pred = actor_out.get(\"aux_router_future_recon\")\n        bsz, seq_len = int(obs_b.batch_size[0]), int(obs_b.batch_size[1])\n        future_target = future_assembler(obs_b.flatten(0, 1)).reshape(\n            bsz, seq_len, -1\n        )\n        normalized_future_target = actor_wrapper.actor_module.normalize_aux_router_future_recon_target(\n            future_target\n        ).to(aux_router_future_recon_pred.dtype)\n        return self._masked_aux_huber(\n            pred=aux_router_future_recon_pred,\n            target=normalized_future_target,\n            valid_tok=valid_tok,\n            beta=self.aux_router_future_recon_huber_beta,\n        )\n\n    def _compute_routed_expert_orthogonal_loss(\n        self,\n        moe_layer: GroupedMoEBlock,\n        *,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        usage = moe_layer.last_routed_expert_usage.to(\n            device=device, dtype=torch.float32\n        )\n        active_mask = usage > float(\n            self.router_expert_orthogonal_min_active_usage\n       
 )\n        active_idx = torch.nonzero(active_mask, as_tuple=False).squeeze(-1)\n        active_count = torch.tensor(\n            float(active_idx.numel()), device=device, dtype=torch.float32\n        )\n        if active_idx.numel() < 2:\n            zero = torch.zeros((), device=device, dtype=dtype)\n            zero_f = torch.zeros((), device=device, dtype=torch.float32)\n            return zero, active_count, zero_f\n\n        expert_vecs = moe_layer.down_proj.index_select(0, active_idx)\n        expert_vecs = expert_vecs.reshape(active_idx.numel(), -1).to(\n            device=device, dtype=torch.float32\n        )\n        expert_vecs = F.normalize(\n            expert_vecs,\n            p=2.0,\n            dim=-1,\n            eps=float(self.router_expert_orthogonal_eps),\n        )\n        gram = expert_vecs @ expert_vecs.transpose(0, 1)\n        offdiag_mask = ~torch.eye(\n            gram.shape[0], dtype=torch.bool, device=gram.device\n        )\n        offdiag = gram.masked_select(offdiag_mask)\n        if offdiag.numel() == 0:\n            zero = torch.zeros((), device=device, dtype=dtype)\n            zero_f = torch.zeros((), device=device, dtype=torch.float32)\n            return zero, active_count, zero_f\n\n        orth_loss = offdiag.square().sum().to(dtype)\n        mean_offdiag_similarity = offdiag.abs().mean()\n        return orth_loss, active_count, mean_offdiag_similarity\n\n    @staticmethod\n    def _root_relative_body_pos_from_mixed_position_frames(\n        *,\n        body_pos_w: torch.Tensor,\n        root_pos_env: torch.Tensor,\n        root_quat_w: torch.Tensor,\n        env_origins: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Convert world-frame body positions using an env-frame root pose.\n\n        In IsaacLab, `isaaclab_mdp.root_pos_w(env)` is already in the\n        environment frame (simulator world minus `env.scene.env_origins`),\n        while `robot.data.body_pos_w` stays in simulator-world coordinates.\n        \"\"\"\n        if body_pos_w.ndim != 3 or body_pos_w.shape[-1] != 3:\n            raise ValueError(\n                \"body_pos_w must have shape [B, N, 3], \"\n                f\"got {tuple(body_pos_w.shape)}.\"\n            )\n        if root_pos_env.ndim != 2 or root_pos_env.shape[-1] != 3:\n            raise ValueError(\n                \"root_pos_env must have shape [B, 3], \"\n                f\"got {tuple(root_pos_env.shape)}.\"\n            )\n        if root_quat_w.ndim != 2 or root_quat_w.shape[-1] != 4:\n            raise ValueError(\n                \"root_quat_w must have shape [B, 4], \"\n                f\"got {tuple(root_quat_w.shape)}.\"\n            )\n        if env_origins.ndim != 2 or env_origins.shape[-1] != 3:\n            raise ValueError(\n                \"env_origins must have shape [B, 3], \"\n                f\"got {tuple(env_origins.shape)}.\"\n            )\n        if body_pos_w.shape[0] != root_pos_env.shape[0]:\n            raise ValueError(\n                \"Batch size mismatch between body_pos_w and root_pos_env: \"\n                f\"{body_pos_w.shape[0]} vs {root_pos_env.shape[0]}.\"\n            )\n        if body_pos_w.shape[0] != root_quat_w.shape[0]:\n            raise ValueError(\n                \"Batch size mismatch between body_pos_w and root_quat_w: \"\n                f\"{body_pos_w.shape[0]} vs {root_quat_w.shape[0]}.\"\n            )\n        if body_pos_w.shape[0] != env_origins.shape[0]:\n            raise ValueError(\n                \"Batch size mismatch between body_pos_w and 
env_origins: \"\n                f\"{body_pos_w.shape[0]} vs {env_origins.shape[0]}.\"\n            )\n        body_pos_env = body_pos_w - env_origins[:, None, :]\n        rel_pos_env = body_pos_env - root_pos_env[:, None, :]\n        quat_vec = root_quat_w[:, None, 1:].expand_as(rel_pos_env)\n        quat_real = root_quat_w[:, None, :1].expand(\n            -1, rel_pos_env.shape[1], -1\n        )\n        t = 2.0 * torch.cross(quat_vec, rel_pos_env, dim=-1)\n        return rel_pos_env - quat_real * t + torch.cross(quat_vec, t, dim=-1)\n\n    def _setup_models_and_optimizer(self):\n        sample_obs_dict = self.env.reset_all()[0]\n        sample_td = self._wrap_obs_dict(sample_obs_dict)\n\n        actor_cfg = OmegaConf.to_container(\n            self.config.module_dict.actor, resolve=True\n        )\n        critic_cfg = OmegaConf.to_container(\n            self.config.module_dict.critic, resolve=True\n        )\n        actor_cfg[\"noise_std_type\"] = getattr(\n            self.config, \"noise_std_type\", \"log\"\n        )\n        actor_cfg[\"min_sigma\"] = getattr(self.config, \"min_sigma\", 0.1)\n        actor_cfg[\"max_sigma\"] = getattr(self.config, \"max_sigma\", 1.5)\n        actor_cfg[\"fix_sigma\"] = getattr(self.config, \"fix_sigma\", False)\n        self._future_mask_prob = float(actor_cfg.get(\"future_mask_prob\", 0.0))\n        self._future_mask_mode = str(\n            actor_cfg.get(\"future_mask_mode\", \"random_suffix\")\n        ).lower()\n        aux_cfg = self.config.get(\"aux_state_pred\", {})\n        if isinstance(aux_cfg, dict):\n            actor_cfg[\"aux_state_pred\"] = dict(aux_cfg)\n        else:\n            actor_cfg[\"aux_state_pred\"] = OmegaConf.to_container(\n                aux_cfg, resolve=True\n            )\n        aux_cmd_cfg = self.config.get(\"aux_router_command_recon\", {})\n        if isinstance(aux_cmd_cfg, dict):\n            actor_aux_cmd_cfg = dict(aux_cmd_cfg)\n        else:\n            actor_aux_cmd_cfg = OmegaConf.to_container(\n                aux_cmd_cfg, resolve=True\n            )\n        aux_switch_cfg = self.config.get(\"aux_router_switch_penalty\", {})\n        if isinstance(aux_switch_cfg, dict):\n            actor_aux_switch_cfg = dict(aux_switch_cfg)\n        else:\n            actor_aux_switch_cfg = OmegaConf.to_container(\n                aux_switch_cfg, resolve=True\n            )\n        aux_router_future_cfg = self.config.get(\"aux_router_future_recon\", {})\n        if isinstance(aux_router_future_cfg, dict):\n            actor_aux_router_future_cfg = dict(aux_router_future_cfg)\n        else:\n            actor_aux_router_future_cfg = OmegaConf.to_container(\n                aux_router_future_cfg, resolve=True\n            )\n        dead_margin_cfg = self.config.get(\"dead_expert_margin_to_topk\", {})\n        if isinstance(dead_margin_cfg, dict):\n            actor_dead_margin_cfg = dict(dead_margin_cfg)\n        else:\n            actor_dead_margin_cfg = OmegaConf.to_container(\n                dead_margin_cfg, resolve=True\n            )\n        selected_margin_cfg = self.config.get(\n            \"selected_expert_margin_to_unselected\", {}\n        )\n        if isinstance(selected_margin_cfg, dict):\n            actor_selected_margin_cfg = dict(selected_margin_cfg)\n        else:\n            actor_selected_margin_cfg = OmegaConf.to_container(\n                selected_margin_cfg, resolve=True\n            )\n\n        actor_schema = self._unwrap_obs_schema(\n            actor_cfg.get(\"obs_schema\", None)\n   
     )\n        critic_schema = self._unwrap_obs_schema(\n            critic_cfg.get(\"obs_schema\", None)\n        )\n        if actor_schema is None:\n            raise ValueError(\n                \"PPOTF requires actor obs_schema to infer flattened obs dim.\"\n            )\n        if self.use_aux_router_command_recon:\n            aux_command_schema = self._build_aux_router_command_recon_schema(\n                actor_schema, self.aux_router_command_recon_term_prefix\n            )\n            self.aux_router_command_recon_assembler = TensorDictAssembler(\n                aux_command_schema, output_mode=\"flat\"\n            )\n            actor_aux_cmd_cfg[\"output_dim\"] = int(\n                self.aux_router_command_recon_assembler.infer_output_dim(\n                    sample_td\n                )\n            )\n            if self.aux_router_command_recon_hidden_dim > 0:\n                actor_aux_cmd_cfg[\"hidden_dim\"] = (\n                    self.aux_router_command_recon_hidden_dim\n                )\n        actor_cfg[\"aux_router_command_recon\"] = actor_aux_cmd_cfg\n        actor_cfg[\"aux_router_future_recon\"] = actor_aux_router_future_cfg\n        actor_cfg[\"aux_router_switch_penalty\"] = actor_aux_switch_cfg\n        actor_cfg[\"dead_expert_margin_to_topk\"] = actor_dead_margin_cfg\n        actor_cfg[\"selected_expert_margin_to_unselected\"] = (\n            actor_selected_margin_cfg\n        )\n        actor_obs_dim = int(\n            TensorDictAssembler(\n                actor_schema, output_mode=\"flat\"\n            ).infer_output_dim(sample_td)\n        )\n        use_future_cross_attn = bool(\n            actor_cfg.get(\"use_future_cross_attn\", False)\n        )\n        actor_cls = self._select_actor_wrapper_cls(actor_cfg)\n        if use_future_cross_attn:\n            if \"flattened_obs\" not in actor_schema:\n                raise ValueError(\n                    \"use_future_cross_attn=True requires \"\n                    \"actor obs_schema.flattened_obs.\"\n                )\n            if \"flattened_obs_fut\" not in actor_schema:\n                raise ValueError(\n                    \"use_future_cross_attn=True requires \"\n                    \"actor obs_schema.flattened_obs_fut.\"\n                )\n            state_schema = {\"flattened_obs\": actor_schema[\"flattened_obs\"]}\n            future_schema = {\n                \"flattened_obs_fut\": actor_schema[\"flattened_obs_fut\"]\n            }\n            state_obs_dim = int(\n                TensorDictAssembler(\n                    state_schema, output_mode=\"flat\"\n                ).infer_output_dim(sample_td)\n            )\n            future_asm = TensorDictAssembler(future_schema, output_mode=\"seq\")\n            future_token_dim = int(future_asm.infer_output_dim(sample_td))\n            future_seq_len = int(future_asm.seq_len)\n            actor_cfg[\"state_obs_dim\"] = state_obs_dim\n            actor_cfg[\"future_token_dim\"] = future_token_dim\n            actor_cfg[\"future_seq_len\"] = future_seq_len\n            actor_cfg[\"input_dim_override\"] = state_obs_dim\n        else:\n            actor_cfg[\"input_dim_override\"] = actor_obs_dim\n\n        self.actor = actor_cls(\n            obs_schema=actor_schema,\n            module_config_dict=actor_cfg,\n            num_actions=self.num_actions,\n            init_noise_std=self.config.init_noise_std,\n            obs_example=sample_td,\n        ).to(self.device)\n        actor_module_unwrapped = self.actor.actor_module\n   
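     # Cache the MoE routing dimensions off the unwrapped actor module;\n        # the router switch-penalty losses reshape features to [B, T, L, E].\n   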
     self.aux_command_router_num_moe_layers = int(\n            getattr(actor_module_unwrapped, \"_num_moe_layers\", 0)\n        )\n        self.aux_command_router_num_fine_experts = int(\n            getattr(actor_module_unwrapped, \"num_fine_experts\", 0)\n        )\n        if (\n            self.use_aux_router_switch_penalty\n            and self.aux_command_router_num_moe_layers <= 0\n        ):\n            raise ValueError(\n                \"aux_router_switch_penalty requires at least one \"\n                \"GroupedMoEBlock.\"\n            )\n        self.critic = PPOCritic(\n            obs_schema=critic_schema,\n            module_config_dict=critic_cfg,\n            obs_example=sample_td,\n        ).to(self.device)\n\n        if self.is_main_process:\n            actor = self.accelerator.unwrap_model(self.actor)\n            critic = self.accelerator.unwrap_model(self.critic)\n\n            logger.info(\"Actor (TensorDict module):\\n{!r}\", actor)\n            logger.info(\n                \"Actor keys: in_keys={} out_keys={}\",\n                list(actor.in_keys),\n                list(actor.out_keys),\n            )\n            logger.info(\"Actor core nn module:\\n{!r}\", actor.actor_module)\n\n            logger.info(\"Critic (TensorDict module):\\n{!r}\", critic)\n            logger.info(\n                \"Critic keys: in_keys={} out_keys={}\",\n                list(critic.in_keys),\n                list(critic.out_keys),\n            )\n            logger.info(\"Critic core nn module:\\n{!r}\", critic.critic_module)\n\n            actor_params = sum(p.numel() for p in self.actor.parameters())\n            critic_params = sum(p.numel() for p in self.critic.parameters())\n            params_table = [\n                [\"Actor(Transformer)\", f\"{actor_params / 1.0e6:.3f}\"],\n                [\"Critic\", f\"{critic_params / 1.0e6:.3f}\"],\n                [\"Total\", f\"{(actor_params + critic_params) / 1.0e6:.3f}\"],\n            ]\n            logger.info(\n                \"Model Summary:\\n\"\n                + tabulate(\n                    params_table,\n                    headers=[\"Model\", \"Params (M)\"],\n                    tablefmt=\"simple_outline\",\n                )\n            )\n\n        optimizer_class = getattr(optim, self.optimizer_type)\n        optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class)\n        if self.optimizer_type == \"AdamW\":\n            decay_params = []\n            non_decay_params = []\n            for name, p in self.actor.named_parameters():\n                if not p.requires_grad:\n                    continue\n                if (\n                    p.ndim < 2\n                    or (\"log_std\" in name)\n                    or (\"bias\" in name)\n                    or (\"norm\" in name)\n                ):\n                    non_decay_params.append(p)\n                else:\n                    decay_params.append(p)\n            self.actor_optimizer = optimizer_class(\n                [\n                    {\"params\": decay_params, \"weight_decay\": 0.01},\n                    {\"params\": non_decay_params, \"weight_decay\": 0.0},\n                ],\n                lr=self.actor_learning_rate,\n                betas=(self.actor_beta1, self.actor_beta2),\n                **optimizer_kwargs,\n            )\n        else:\n            self.actor_optimizer = optimizer_class(\n                self.actor.parameters(),\n                lr=self.actor_learning_rate,\n                
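# Only the AdamW branch above splits decay vs. no-decay parameter\n                # groups; other optimizer types get one flat parameter group.\n                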
betas=(self.actor_beta1, self.actor_beta2),\n                **optimizer_kwargs,\n            )\n        self.critic_optimizer = optimizer_class(\n            self.critic.parameters(),\n            lr=self.critic_learning_rate,\n            betas=(self.critic_beta1, self.critic_beta2),\n            **optimizer_kwargs,\n        )\n\n        (\n            self.actor,\n            self.critic,\n            self.actor_optimizer,\n            self.critic_optimizer,\n        ) = self.accelerator.prepare(\n            self.actor,\n            self.critic,\n            self.actor_optimizer,\n            self.critic_optimizer,\n        )\n\n        actor_for_kv = self.accelerator.unwrap_model(self.actor)\n        if hasattr(actor_for_kv, \"reset_kv_cache\"):\n            actor_for_kv.reset_kv_cache(self.env.num_envs, self.device)\n        self._kv_reset_pending = torch.zeros(\n            self.env.num_envs, dtype=torch.bool, device=self.device\n        )\n        self._rollout_future_masks = None\n        self._rollout_step_idx = 0\n\n    def _setup_data_buffers(self):\n        super()._setup_data_buffers()\n        self._aux_height_scanner = None\n        self._aux_contact_sensor = None\n        self._aux_contact_body_ids = None\n        self._aux_keybody_body_ids = None\n        if not self.use_aux_state_pred:\n            return\n        if self.use_velocity_transition:\n            raise ValueError(\n                \"aux_state_pred is not supported with velocity \"\n                \"tracking in PPOTF.\"\n            )\n        self.transition_cls = PpoAuxTransition\n        if self.use_aux_root_height:\n            if \"height_scanner\" not in self.env._env.scene.sensors:\n                raise ValueError(\n                    \"aux_state_pred requires a RayCaster sensor \"\n                    \"named 'height_scanner' \"\n                    \"in env.scene.sensors.\"\n                )\n            height_scanner = self.env._env.scene.sensors[\"height_scanner\"]\n            height_scanner.cfg.max_distance = (\n                self.aux_state_pred_raycast_max_dist\n            )\n            height_scanner.cfg.ray_alignment = \"world\"\n            height_scanner.cfg.offset.pos = (\n                0.0,\n                0.0,\n                self.aux_state_pred_raycast_z_offset,\n            )\n            if height_scanner.is_initialized:\n                height_scanner.ray_starts[..., 2] = (\n                    self.aux_state_pred_raycast_z_offset\n                )\n            self._aux_height_scanner = height_scanner\n        if self.aux_state_pred_num_contact_bodies > 0:\n            if \"contact_forces\" not in self.env._env.scene.sensors:\n                raise ValueError(\n                    \"aux_state_pred.keybody_contact_names requires \"\n                    \"a ContactSensor \"\n                    \"named 'contact_forces' in env.scene.sensors.\"\n                )\n            contact_sensor = self.env._env.scene.sensors[\"contact_forces\"]\n            sensor_body_names = list(contact_sensor.body_names)\n            body_ids = []\n            for body_name in self.aux_state_pred_keybody_contact_names:\n                if body_name not in sensor_body_names:\n                    raise ValueError(\n                        f\"Body '{body_name}' not found in contact \"\n                        \"sensor body_names.\"\n                    )\n                body_ids.append(sensor_body_names.index(body_name))\n            self._aux_contact_sensor = contact_sensor\n            
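# Resolve contact-sensor body indices once and keep them as a device\n            # tensor so _build_transition can gather contact times directly.\n            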
self._aux_contact_body_ids = torch.tensor(\n                body_ids, dtype=torch.long, device=self.device\n            )\n        if self.aux_state_pred_num_keybody_bodies > 0:\n            robot_body_names = list(self.env._env.scene[\"robot\"].body_names)\n            body_ids = []\n            for body_name in self.aux_state_pred_keybody_rel_pos_names:\n                if body_name not in robot_body_names:\n                    raise ValueError(\n                        f\"Body '{body_name}' not found in robot body_names.\"\n                    )\n                body_ids.append(robot_body_names.index(body_name))\n            self._aux_keybody_body_ids = torch.tensor(\n                body_ids, dtype=torch.long, device=self.device\n            )\n\n    def _build_transition(\n        self,\n        obs_td: TensorDict,\n        actor_out: TensorDict,\n        critic_out: TensorDict,\n    ):\n        if not self.use_aux_state_pred:\n            return super()._build_transition(obs_td, actor_out, critic_out)\n\n        import isaaclab.envs.mdp as isaaclab_mdp\n\n        actions = actor_out.get(\"actions\")\n        actions_log_prob = actor_out.get(\"actions_log_prob\")\n        mu = actor_out.get(\"mu\")\n        sigma = actor_out.get(\"sigma\")\n        values = critic_out.get(\"values\")\n        zero_scalar = torch.zeros(\n            self.num_envs,\n            1,\n            device=self.device,\n            dtype=torch.float32,\n        )\n        zero_scalar_bool = torch.zeros(\n            self.num_envs,\n            1,\n            device=self.device,\n            dtype=torch.bool,\n        )\n        gt_base_lin_vel_b = isaaclab_mdp.base_lin_vel(self.env._env)\n        if self.use_aux_root_height:\n            root_pos_w = isaaclab_mdp.root_pos_w(self.env._env)\n            if self._aux_height_scanner is None:\n                raise RuntimeError(\n                    \"Aux state prediction expected \"\n                    \"_aux_height_scanner to be initialized.\"\n                )\n            terrain_z = self._aux_height_scanner.data.ray_hits_w[:, 0, 2:3]\n            env_origin_z = self.env._env.scene.env_origins[:, 2:3]\n            terrain_z = torch.where(\n                torch.isfinite(terrain_z), terrain_z, env_origin_z\n            )\n            gt_root_height_rel_terrain = root_pos_w[:, 2:3] - terrain_z\n        else:\n            gt_root_height_rel_terrain = torch.zeros(\n                self.num_envs, 1, device=self.device, dtype=torch.float32\n            )\n        if self.aux_state_pred_num_contact_bodies > 0:\n            if (\n                self._aux_contact_sensor is None\n                or self._aux_contact_body_ids is None\n            ):\n                raise RuntimeError(\n                    \"Aux keybody contact prediction expects contact sensor \"\n                    \"and body ids to be initialized.\"\n                )\n            contact_time = self._aux_contact_sensor.data.current_contact_time[\n                :, self._aux_contact_body_ids\n            ]\n            gt_keybody_contacts = (contact_time > 0.0).to(torch.float32)\n        else:\n            gt_keybody_contacts = torch.zeros(\n                self.num_envs, 0, device=self.device, dtype=torch.float32\n            )\n        command = self.env._env.command_manager.get_term(self.command_name)\n        if self.aux_state_pred_num_keybody_bodies > 0:\n            if self._aux_keybody_body_ids is None:\n                raise RuntimeError(\n                    \"Aux keybody position 
prediction expects body \"\n                    \"ids to be initialized.\"\n                )\n            # Both the ref-motion command and robot asset expose bodies in\n            # simulator order, so the cached robot body indices align here.\n            gt_ref_keybody_rel_pos = (\n                command.get_ref_motion_bodylink_rel_pos_cur()[\n                    :, self._aux_keybody_body_ids, :\n                ]\n            )\n            robot = self.env._env.scene[\"robot\"]\n            robot_keybody_global_pos = robot.data.body_pos_w[\n                :, self._aux_keybody_body_ids, :\n            ]\n            env_origins = self.env._env.scene.env_origins\n            root_pos_w = isaaclab_mdp.root_pos_w(self.env._env)\n            root_quat_w = isaaclab_mdp.root_quat_w(self.env._env)\n            gt_robot_keybody_rel_pos = (\n                self._root_relative_body_pos_from_mixed_position_frames(\n                    body_pos_w=robot_keybody_global_pos,\n                    root_pos_env=root_pos_w,\n                    root_quat_w=root_quat_w,\n                    env_origins=env_origins,\n                )\n            )\n        else:\n            gt_ref_keybody_rel_pos = torch.zeros(\n                self.num_envs, 0, 3, device=self.device, dtype=torch.float32\n            )\n            gt_robot_keybody_rel_pos = torch.zeros(\n                self.num_envs, 0, 3, device=self.device, dtype=torch.float32\n            )\n        gt_denoise_ref_root_lin_vel = torch.zeros(\n            self.num_envs, 3, device=self.device, dtype=torch.float32\n        )\n        gt_denoise_ref_root_ang_vel = torch.zeros(\n            self.num_envs, 3, device=self.device, dtype=torch.float32\n        )\n        gt_denoise_ref_dof_pos = torch.zeros(\n            self.num_envs,\n            actions.shape[-1],\n            device=self.device,\n            dtype=torch.float32,\n        )\n        if (\n            self.use_aux_denoise_ref_root_lin_vel\n            or self.use_aux_denoise_ref_root_ang_vel\n            or self.use_aux_denoise_ref_dof_pos\n        ):\n            try:\n                if self.use_aux_denoise_ref_root_lin_vel:\n                    gt_denoise_ref_root_lin_vel = (\n                        command.get_ref_motion_base_linvel_cur(\n                            prefix=\"ft_ref_\"\n                        )\n                        - command.get_ref_motion_base_linvel_cur(prefix=\"ref_\")\n                    )\n                if self.use_aux_denoise_ref_root_ang_vel:\n                    gt_denoise_ref_root_ang_vel = (\n                        command.get_ref_motion_base_angvel_cur(\n                            prefix=\"ft_ref_\"\n                        )\n                        - command.get_ref_motion_base_angvel_cur(prefix=\"ref_\")\n                    )\n                if self.use_aux_denoise_ref_dof_pos:\n                    gt_denoise_ref_dof_pos = (\n                        command.get_ref_motion_dof_pos_cur(prefix=\"ft_ref_\")\n                        - command.get_ref_motion_dof_pos_cur(prefix=\"ref_\")\n                    )\n                    expected_shape = (self.num_envs, actions.shape[-1])\n                    if tuple(gt_denoise_ref_dof_pos.shape) != expected_shape:\n                        raise ValueError(\n                            \"gt_denoise_ref_dof_pos must match the action-aligned \"\n                            \"DoF shape \"\n                            f\"{expected_shape}, got \"\n                            
f\"{tuple(gt_denoise_ref_dof_pos.shape)}.\"\n                        )\n            except KeyError as exc:\n                raise RuntimeError(\n                    \"Filtered reference tensors are unavailable for \"\n                    \"aux_denoise_* targets. Enable online filtering or \"\n                    \"materialize ft_ref_* tensors in the motion cache.\"\n                ) from exc\n\n        return self.transition_cls(\n            obs=obs_td,\n            actions=actions.detach(),\n            teacher_actions=torch.zeros_like(actions),\n            mu=mu.detach(),\n            sigma=sigma.detach(),\n            actions_log_prob=actions_log_prob[..., None].detach(),\n            values=values.detach(),\n            rewards=zero_scalar.clone(),\n            dones=zero_scalar_bool,\n            returns=zero_scalar.clone(),\n            advantages=zero_scalar.clone(),\n            gt_base_lin_vel_b=gt_base_lin_vel_b.detach(),\n            gt_root_height_rel_terrain=gt_root_height_rel_terrain.detach(),\n            gt_keybody_contacts=gt_keybody_contacts.detach(),\n            gt_ref_keybody_rel_pos=gt_ref_keybody_rel_pos.detach(),\n            gt_robot_keybody_rel_pos=gt_robot_keybody_rel_pos.detach(),\n            gt_denoise_ref_root_lin_vel=gt_denoise_ref_root_lin_vel.detach(),\n            gt_denoise_ref_root_ang_vel=gt_denoise_ref_root_ang_vel.detach(),\n            gt_denoise_ref_dof_pos=gt_denoise_ref_dof_pos.detach(),\n            batch_size=[self.num_envs],\n            device=self.device,\n        )\n\n    def _build_storage(self, obs_td: TensorDict):\n        actor_for_kv = self.accelerator.unwrap_model(self.actor)\n        actor_policy = actor_for_kv.actor_module\n        if bool(getattr(actor_policy, \"use_future_cross_attn\", False)):\n            n_fut = int(getattr(actor_policy, \"future_seq_len\", 0))\n            if n_fut <= 0:\n                raise ValueError(\n                    \"future_seq_len must be positive when \"\n                    \"use_future_cross_attn=True\"\n                )\n            obs_td = obs_td.clone(recurse=False)\n            obs_td.set(\n                \"future_mask\",\n                torch.ones(\n                    self.env.num_envs,\n                    n_fut,\n                    dtype=torch.bool,\n                    device=self.device,\n                ),\n            )\n        return super()._build_storage(obs_td)\n\n    def _sample_iteration_future_masks(self) -> torch.Tensor | None:\n        actor_for_kv = self.accelerator.unwrap_model(self.actor)\n        actor_policy = actor_for_kv.actor_module\n        if not bool(getattr(actor_policy, \"use_future_cross_attn\", False)):\n            return None\n\n        n_fut = int(getattr(actor_policy, \"future_seq_len\", 0))\n        if n_fut <= 0:\n            raise ValueError(\n                \"future_seq_len must be positive when \"\n                \"use_future_cross_attn=True\"\n            )\n        if self._future_mask_mode != \"random_suffix\":\n            raise ValueError(\n                \"Unsupported future_mask_mode: \"\n                f\"{self._future_mask_mode}. 
\"\n                \"Expected 'random_suffix'.\"\n            )\n        num_steps = int(self.num_steps_per_env)\n        num_envs = int(self.env.num_envs)\n\n        keep = torch.ones(\n            num_steps,\n            num_envs,\n            n_fut,\n            dtype=torch.bool,\n            device=self.device,\n        )\n        if bool(getattr(self, \"_offline_evaluating\", False)):\n            return keep\n        if self._future_mask_prob <= 0.0:\n            return keep\n        apply_mask = (\n            torch.rand(num_steps, num_envs, device=self.device)\n            < self._future_mask_prob\n        )\n        keep_len = torch.randint(\n            1,\n            n_fut + 1,\n            (num_steps, num_envs),\n            device=self.device,\n        )\n        full_len = torch.full(\n            (num_steps, num_envs),\n            n_fut,\n            dtype=torch.long,\n            device=self.device,\n        )\n        keep_len = torch.where(apply_mask, keep_len, full_len)\n        token_idx = torch.arange(n_fut, device=self.device, dtype=torch.long)[\n            None, None, :\n        ]\n        return token_idx < keep_len[:, :, None]\n\n    def _reset_rollout_forward_state(self) -> None:\n        actor_for_kv = self.accelerator.unwrap_model(self.actor)\n        actor_for_kv.clear_env_cache(None)\n        actor_policy = actor_for_kv.actor_module\n        actor_policy.reset_routing_stats()\n        actor_policy.set_collect_routing_stats(True)\n        self._kv_reset_pending.zero_()\n        self._rollout_future_masks = self._sample_iteration_future_masks()\n        self._rollout_step_idx = 0\n\n    def _rollout_forward(\n        self,\n        obs_td: TensorDict,\n        *,\n        actor_mode: str = \"sampling\",\n        collect_transition: bool = True,\n        track_episode_stats: bool = True,\n    ) -> TensorDict:\n        if collect_transition and self._rollout_future_masks is not None:\n            if self._rollout_step_idx >= int(\n                self._rollout_future_masks.shape[0]\n            ):\n                raise RuntimeError(\n                    \"Rollout future-mask step index exceeded \"\n                    \"pre-sampled mask length.\"\n                )\n            obs_td = obs_td.clone(recurse=False)\n            obs_td.set(\n                \"future_mask\",\n                self._rollout_future_masks[self._rollout_step_idx],\n            )\n\n        actor_for_kv = self.accelerator.unwrap_model(self.actor)\n        if torch.any(self._kv_reset_pending):\n            env_ids = torch.nonzero(self._kv_reset_pending).squeeze(-1)\n            if env_ids.numel() > 0:\n                actor_for_kv.clear_env_cache(env_ids)\n                self._kv_reset_pending[env_ids] = False\n        next_obs_td = super()._rollout_forward(\n            obs_td,\n            actor_mode=actor_mode,\n            collect_transition=collect_transition,\n            track_episode_stats=track_episode_stats,\n        )\n        if collect_transition and self._rollout_future_masks is not None:\n            self._rollout_step_idx += 1\n        if not collect_transition:\n            dones = self._last_rollout_dones\n            if dones is not None:\n                self._kv_reset_pending |= (\n                    dones.view(-1).to(torch.bool).to(self.device)\n                )\n        return next_obs_td\n\n    def process_env_step(\n        self,\n        rewards: torch.Tensor,\n        dones: torch.Tensor,\n        time_outs: torch.Tensor,\n        infos: dict,\n    ) -> 
None:\n        super().process_env_step(rewards, dones, time_outs, infos)\n        if getattr(self, \"_kv_reset_pending\", None) is not None:\n            self._kv_reset_pending |= (\n                dones.view(-1).to(torch.bool).to(self.device)\n            )\n\n    @staticmethod\n    def _build_episode_causal_mask(dones_seq: torch.Tensor) -> torch.Tensor:\n        \"\"\"Build [N, T, T] mask: causal and within the same episode segment.\"\"\"\n        n, t, _ = dones_seq.shape\n        device = dones_seq.device\n        dones = dones_seq.squeeze(-1).to(torch.long)\n        seg = torch.cumsum(dones, dim=1) - dones\n        same = seg[:, :, None] == seg[:, None, :]\n        causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=device))\n        return same & causal\n\n    @staticmethod\n    def _resolve_sequence_batch_partition(\n        num_envs: int,\n        num_mini_batches: int,\n    ) -> tuple[int, int]:\n        if num_envs <= 0:\n            raise RuntimeError(\n                \"PPOTF sequence batching requires at least one \"\n                \"environment on each rank.\"\n            )\n        effective_num_mini_batches = max(\n            1, min(int(num_mini_batches), int(num_envs))\n        )\n        mini_batch_envs = max(\n            1,\n            (num_envs + effective_num_mini_batches - 1)\n            // effective_num_mini_batches,\n        )\n        return effective_num_mini_batches, mini_batch_envs\n\n    def _sequence_batches(\n        self, num_mini_batches: int, num_epochs: int\n    ) -> Generator[tuple, None, None]:\n        data = self.storage.data\n        obs_seq = data[\"obs\"].transpose(0, 1)\n        actions_seq = data[\"actions\"].transpose(0, 1)\n        values_seq = data[\"values\"].transpose(0, 1)\n        rewards_seq = data[\"rewards\"].transpose(0, 1)\n        returns_seq = data[\"returns\"].transpose(0, 1)\n        adv_seq = data[\"advantages\"].transpose(0, 1)\n        old_logp_seq = data[\"actions_log_prob\"].transpose(0, 1)\n        old_mu_seq = data[\"mu\"].transpose(0, 1)\n        old_sigma_seq = data[\"sigma\"].transpose(0, 1)\n        dones_seq = data[\"dones\"].transpose(0, 1)\n        gt_base_lin_vel_seq = None\n        gt_root_height_seq = None\n        gt_keybody_contact_seq = None\n        gt_ref_keybody_rel_pos_seq = None\n        gt_robot_keybody_rel_pos_seq = None\n        gt_denoise_ref_root_lin_vel_seq = None\n        gt_denoise_ref_root_ang_vel_seq = None\n        gt_denoise_ref_dof_pos_seq = None\n        if self.use_aux_state_pred:\n            gt_base_lin_vel_seq = data[\"gt_base_lin_vel_b\"].transpose(0, 1)\n            gt_root_height_seq = data[\"gt_root_height_rel_terrain\"].transpose(\n                0, 1\n            )\n            gt_keybody_contact_seq = data[\"gt_keybody_contacts\"].transpose(\n                0, 1\n            )\n            gt_ref_keybody_rel_pos_seq = data[\n                \"gt_ref_keybody_rel_pos\"\n            ].transpose(0, 1)\n            gt_robot_keybody_rel_pos_seq = data[\n                \"gt_robot_keybody_rel_pos\"\n            ].transpose(0, 1)\n            gt_denoise_ref_root_lin_vel_seq = data[\n                \"gt_denoise_ref_root_lin_vel\"\n            ].transpose(0, 1)\n            gt_denoise_ref_root_ang_vel_seq = data[\n                \"gt_denoise_ref_root_ang_vel\"\n            ].transpose(0, 1)\n            gt_denoise_ref_dof_pos_seq = data[\n                \"gt_denoise_ref_dof_pos\"\n            ].transpose(0, 1)\n\n        num_envs = int(actions_seq.shape[0])\n        
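# Storage keeps rollouts as [T, N, ...]; the transposes above flip them\n        # to [N, T, ...], so each mini-batch below slices whole per-env\n        # trajectories and the episode-causal mask can span the full window.\n        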
if num_envs <= 0:\n            raise RuntimeError(\n                \"PPOTF sequence batching requires at least one \"\n                \"environment on each rank, \"\n                f\"got num_envs={num_envs}.\"\n            )\n        num_mini_batches, mb_env = self._resolve_sequence_batch_partition(\n            num_envs, num_mini_batches\n        )\n        env_indices = torch.randperm(num_envs, device=self.device)\n\n        for _ in range(num_epochs):\n            for i in range(num_mini_batches):\n                start = i * mb_env\n                if start >= num_envs:\n                    break\n                end = min(num_envs, (i + 1) * mb_env)\n                idx = env_indices[start:end]\n                obs_b = obs_seq[idx]\n                actions_b = actions_seq[idx]\n                values_b = values_seq[idx]\n                rewards_b = rewards_seq[idx]\n                returns_b = returns_seq[idx]\n                adv_b = adv_seq[idx]\n                old_logp_b = old_logp_seq[idx]\n                old_mu_b = old_mu_seq[idx]\n                old_sigma_b = old_sigma_seq[idx]\n                dones_b = dones_seq[idx]\n                gt_base_lin_vel_b = (\n                    gt_base_lin_vel_seq[idx]\n                    if gt_base_lin_vel_seq is not None\n                    else None\n                )\n                gt_root_height_b = (\n                    gt_root_height_seq[idx]\n                    if gt_root_height_seq is not None\n                    else None\n                )\n                gt_keybody_contact_b = (\n                    gt_keybody_contact_seq[idx]\n                    if gt_keybody_contact_seq is not None\n                    else None\n                )\n                gt_ref_keybody_rel_pos_b = (\n                    gt_ref_keybody_rel_pos_seq[idx]\n                    if gt_ref_keybody_rel_pos_seq is not None\n                    else None\n                )\n                gt_robot_keybody_rel_pos_b = (\n                    gt_robot_keybody_rel_pos_seq[idx]\n                    if gt_robot_keybody_rel_pos_seq is not None\n                    else None\n                )\n                gt_denoise_ref_root_lin_vel_b = (\n                    gt_denoise_ref_root_lin_vel_seq[idx]\n                    if gt_denoise_ref_root_lin_vel_seq is not None\n                    else None\n                )\n                gt_denoise_ref_root_ang_vel_b = (\n                    gt_denoise_ref_root_ang_vel_seq[idx]\n                    if gt_denoise_ref_root_ang_vel_seq is not None\n                    else None\n                )\n                gt_denoise_ref_dof_pos_b = (\n                    gt_denoise_ref_dof_pos_seq[idx]\n                    if gt_denoise_ref_dof_pos_seq is not None\n                    else None\n                )\n                attn_mask = self._build_episode_causal_mask(dones_b)\n                yield (\n                    obs_b,\n                    actions_b,\n                    values_b,\n                    adv_b,\n                    returns_b,\n                    rewards_b,\n                    old_logp_b,\n                    old_mu_b,\n                    old_sigma_b,\n                    attn_mask,\n                    gt_base_lin_vel_b,\n                    gt_root_height_b,\n                    gt_keybody_contact_b,\n                    gt_ref_keybody_rel_pos_b,\n                    gt_robot_keybody_rel_pos_b,\n                    gt_denoise_ref_root_lin_vel_b,\n                    
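# NOTE: this tuple order must mirror the unpacking in update().\n                    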
gt_denoise_ref_root_ang_vel_b,\n                    gt_denoise_ref_dof_pos_b,\n                )\n\n    def update(self):\n        actor_unwrapped = self.accelerator.unwrap_model(self.actor)\n        actor_policy = actor_unwrapped.actor_module\n        actor_policy.set_collect_routing_stats(False)\n        mean_value_loss = 0.0\n        mean_surrogate_loss = 0.0\n        mean_entropy = 0.0\n        mean_kl_token = 0.0\n        mean_kl_loss = 0.0\n        mean_kl_analytic = 0.0\n        critic_explained_variance = self._compute_explained_variance(\n            target=self.storage.data[\"returns\"],\n            prediction=self.storage.data[\"values\"],\n        )\n        mean_aux_base_lin_vel_nll = 0.0\n        mean_aux_root_height_nll = 0.0\n        mean_aux_base_lin_vel_std = 0.0\n        mean_aux_root_height_std = 0.0\n        mean_aux_keybody_contact_bce = 0.0\n        mean_aux_keybody_contact_acc = 0.0\n        mean_aux_ref_keybody_rel_pos_mse = 0.0\n        mean_aux_robot_keybody_rel_pos_mse = 0.0\n        mean_aux_denoise_ref_root_lin_vel_huber = 0.0\n        mean_aux_denoise_ref_root_ang_vel_huber = 0.0\n        mean_aux_denoise_ref_dof_pos_huber = 0.0\n        mean_aux_router_command_recon_mse = 0.0\n        mean_aux_router_future_recon_huber = 0.0\n        mean_aux_router_switch_penalty_js = 0.0\n        mean_dead_expert_margin_to_topk_loss = 0.0\n        mean_router_expert_orthogonal_loss = 0.0\n        mean_selected_expert_margin_to_unselected_loss = 0.0\n        moe_layers = [\n            layer\n            for layer in actor_policy.layers\n            if isinstance(layer, GroupedMoEBlock)\n        ]\n\n        (\n            effective_num_mini_batches,\n            mini_batch_envs,\n        ) = self._resolve_sequence_batch_partition(\n            self.storage.num_envs, self.num_mini_batches\n        )\n        self._last_update_metrics = {\n            \"0-Train/configured_num_mini_batches\": float(\n                self.configured_num_mini_batches\n            ),\n            \"0-Train/requested_num_mini_batches\": float(\n                self.requested_num_mini_batches\n            ),\n            \"0-Train/effective_num_mini_batches\": float(\n                effective_num_mini_batches\n            ),\n            \"0-Train/mini_batch_size_per_rank\": float(\n                mini_batch_envs * self.num_steps_per_env\n            ),\n            \"0-Train/mini_batch_num_envs_per_rank\": float(mini_batch_envs),\n            \"0-Train/num_updates_executed\": 0.0,\n            \"0-Train/lr_scale_factor\": float(self.distributed_lr_scale_factor),\n            \"0-Train/scalable_distributed_update\": float(\n                self.distributed_update_mode == \"scalable\"\n            ),\n            \"0-Train/kl_windowed\": 0.0,\n            \"0-Train/kl_stop_triggered\": 0.0,\n            \"0-Train/kl_stop_analytic\": 0.0,\n            \"0-Train/kl_analytic_batch_last\": 0.0,\n            \"0-Train/kl_analytic_batch_max\": 0.0,\n            \"0-Train/clip_fraction_batch_mean\": 0.0,\n            \"0-Train/clip_fraction_batch_last\": 0.0,\n        }\n        entropy_coef = self._get_effective_entropy_coef()\n        generator = self._sequence_batches(\n            effective_num_mini_batches,\n            self.num_learning_epochs,\n        )\n        measure_analytic_kl = self.desired_kl is not None\n        normalize_per_mb = bool(self.normalize_advantage_per_mini_batch)\n        num_updates = 0\n        num_kl_measurements = 0\n        kl_stop_triggered = False\n        
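# Sliding-window KL early stop: each mini-batch appends its analytic\n        # KL to a bounded window; once the windowed signal crosses the\n        # threshold, the update loop below breaks early.\n        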
kl_stop_analytic = 0.0\n        kl_windowed = None\n        recent_analytic_kls: list[float] = []\n        kl_analytic_batch_last = 0.0\n        kl_analytic_batch_max = 0.0\n        clip_fraction_batch_mean = 0.0\n        clip_fraction_batch_last = 0.0\n\n        for (\n            obs_b,\n            actions_b,\n            target_values_b,\n            advantages_b,\n            returns_b,\n            _rewards_b,\n            old_logp_b,\n            old_mu_b,\n            old_sigma_b,\n            attn_mask_b,\n            gt_base_lin_vel_b,\n            gt_root_height_b,\n            gt_keybody_contact_b,\n            gt_ref_keybody_rel_pos_b,\n            gt_robot_keybody_rel_pos_b,\n            gt_denoise_ref_root_lin_vel_b,\n            gt_denoise_ref_root_ang_vel_b,\n            gt_denoise_ref_dof_pos_b,\n        ) in generator:\n            valid_tok = attn_mask_b.diagonal(dim1=1, dim2=2).to(torch.float32)\n            valid_count = valid_tok.sum().clamp_min(1.0)\n\n            if normalize_per_mb:\n                with torch.no_grad():\n                    flat = advantages_b.view(-1).float()\n                    if self.global_advantage_norm and self.is_distributed:\n                        count = torch.tensor(\n                            [flat.numel()],\n                            device=self.device,\n                            dtype=torch.float32,\n                        )\n                        sum_g = self.accelerator.reduce(\n                            flat.sum(), reduction=\"sum\"\n                        )\n                        sqsum_g = self.accelerator.reduce(\n                            (flat * flat).sum(), reduction=\"sum\"\n                        )\n                        count_g = self.accelerator.reduce(\n                            count, reduction=\"sum\"\n                        )\n                        mean = sum_g / count_g\n                        var = (sqsum_g / count_g) - mean * mean\n                        std = torch.sqrt(var.clamp_min(1.0e-8))\n                    else:\n                        mean = flat.mean()\n                        std = flat.std().clamp_min(1.0e-8)\n                    advantages_b = (advantages_b - mean) / std\n\n            b, t = int(obs_b.batch_size[0]), int(obs_b.batch_size[1])\n            critic_obs_flat = obs_b.flatten(0, 1)\n            with self.accelerator.autocast():\n                actor_out = self.actor(\n                    obs_b,\n                    actions=actions_b,\n                    mode=\"sequence_logp\",\n                    attn_mask=attn_mask_b,\n                    update_obs_norm=False,\n                )\n                critic_out = self.critic(\n                    critic_obs_flat, update_obs_norm=False\n                )\n            logp_new_b = actor_out.get(\"actions_log_prob\")\n            mu_b = actor_out.get(\"mu\")\n            sigma_b = actor_out.get(\"sigma\")\n            entropy_b = actor_out.get(\"entropy\")\n            v_pred_flat = critic_out.get(\"values\")\n            value_batch = v_pred_flat.reshape(b, t, -1)\n            returns_batch_norm = returns_b\n            target_values_batch_norm = target_values_b\n\n            analytic_kl = None\n            if measure_analytic_kl:\n                analytic_kl = self._compute_analytic_kl(\n                    old_mu=old_mu_b.float(),\n                    old_sigma=old_sigma_b.float(),\n                    new_mu=mu_b.float(),\n                    new_sigma=sigma_b.float(),\n                    weight=valid_tok,\n 
               )\n                mean_kl_analytic += analytic_kl\n                num_kl_measurements += 1\n                kl_analytic_batch_last = analytic_kl\n                kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl)\n                recent_analytic_kls.append(analytic_kl)\n                if len(recent_analytic_kls) > self.kl_early_stop_window_size:\n                    recent_analytic_kls.pop(0)\n                kl_windowed = self._compute_windowed_kl_signal(\n                    recent_analytic_kls\n                )\n                if self._should_early_stop_for_kl(\n                    kl_windowed, num_kl_measurements\n                ):\n                    kl_stop_triggered = True\n                    kl_stop_analytic = analytic_kl\n                    break\n\n            logp_new = logp_new_b.squeeze(-1).float()\n            logp_old = old_logp_b.squeeze(-1).float()\n            ratio = torch.exp(logp_new - logp_old)\n            clip_fraction = self._compute_clip_fraction(\n                ratio, weight=valid_tok\n            )\n            clip_fraction_batch_mean += clip_fraction\n            clip_fraction_batch_last = clip_fraction\n            adv = advantages_b.squeeze(-1)\n            s1 = ratio * adv\n            s2 = (\n                torch.clamp(\n                    ratio, 1.0 - self.clip_param, 1.0 + self.clip_param\n                )\n                * adv\n            )\n            surrogate_loss = (\n                -torch.min(s1, s2) * valid_tok\n            ).sum() / valid_count\n\n            if self.use_clipped_value_loss:\n                value_clipped = target_values_batch_norm + (\n                    value_batch - target_values_batch_norm\n                ).clamp(-self.clip_param, self.clip_param)\n                value_losses = (value_batch - returns_batch_norm).pow(2)\n                value_losses_clipped = (\n                    value_clipped - returns_batch_norm\n                ).pow(2)\n                v_max = torch.max(value_losses, value_losses_clipped).squeeze(\n                    -1\n                )\n                value_loss = (v_max * valid_tok).sum() / valid_count\n            else:\n                v_err = (returns_batch_norm - value_batch).pow(2).squeeze(-1)\n                value_loss = (v_err * valid_tok).sum() / valid_count\n\n            actor_loss = surrogate_loss\n            critic_loss = self.value_loss_coef * value_loss\n            aux_base_lin_vel_loss = None\n            aux_root_height_loss = None\n            aux_base_lin_vel_std = None\n            aux_root_height_std = None\n            aux_keybody_contact_loss = None\n            aux_keybody_contact_acc = None\n            aux_ref_keybody_rel_pos_loss = None\n            aux_robot_keybody_rel_pos_loss = None\n            aux_denoise_ref_root_lin_vel_loss = None\n            aux_denoise_ref_root_ang_vel_loss = None\n            aux_denoise_ref_dof_pos_loss = None\n            aux_router_command_recon_loss = None\n            aux_router_future_recon_loss = None\n            aux_router_switch_penalty_loss = None\n            dead_expert_margin_to_topk_loss = None\n            router_expert_orthogonal_loss = None\n            selected_expert_margin_to_unselected_loss = None\n            if self.use_aux_state_pred:\n                aux_base_lin_vel_loc = actor_out.get(\"aux_base_lin_vel_loc\")\n                aux_base_lin_vel_log_std = actor_out.get(\n                    \"aux_base_lin_vel_log_std\"\n                )\n                
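# Heteroscedastic aux head: the loss below is the diagonal-Gaussian\n                # NLL up to a constant, 0.5 * ((x - mu) / std)^2 + log(std),\n                # with std clamped for numerical stability.\n                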
aux_base_lin_vel_std = torch.clamp(\n                    torch.exp(aux_base_lin_vel_log_std),\n                    min=self.aux_state_pred_min_std,\n                    max=self.aux_state_pred_max_std,\n                )\n                aux_base_lin_vel_nll = 0.5 * (\n                    torch.square(\n                        (gt_base_lin_vel_b - aux_base_lin_vel_loc)\n                        / aux_base_lin_vel_std\n                    )\n                    + 2.0 * torch.log(aux_base_lin_vel_std + 1.0e-8)\n                ).sum(dim=-1)\n                aux_base_lin_vel_loss = (\n                    aux_base_lin_vel_nll * valid_tok\n                ).sum() / valid_count\n                actor_loss = (\n                    actor_loss\n                    + self.aux_state_pred_w_base_lin_vel\n                    * aux_base_lin_vel_loss\n                )\n                aux_root_height_loc = actor_out.get(\"aux_root_height_loc\")\n                aux_root_height_log_std = actor_out.get(\n                    \"aux_root_height_log_std\"\n                )\n                if self.use_aux_root_height and gt_root_height_b is not None:\n                    aux_root_height_std = torch.clamp(\n                        torch.exp(aux_root_height_log_std),\n                        min=self.aux_state_pred_min_std,\n                        max=self.aux_state_pred_max_std,\n                    )\n                    aux_root_height_nll = 0.5 * (\n                        torch.square(\n                            (gt_root_height_b - aux_root_height_loc)\n                            / aux_root_height_std\n                        )\n                        + 2.0 * torch.log(aux_root_height_std + 1.0e-8)\n                    ).sum(dim=-1)\n                    aux_root_height_loss = (\n                        aux_root_height_nll * valid_tok\n                    ).sum() / valid_count\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_root_height\n                        * aux_root_height_loss\n                    )\n                else:\n                    actor_loss = actor_loss + 0.0 * (\n                        aux_root_height_loc.sum()\n                        + aux_root_height_log_std.sum()\n                    )\n                if (\n                    self.aux_state_pred_num_contact_bodies > 0\n                    and gt_keybody_contact_b is not None\n                ):\n                    aux_keybody_contact_logits = actor_out.get(\n                        \"aux_keybody_contact_logits\"\n                    )\n                    contact_bce = F.binary_cross_entropy_with_logits(\n                        aux_keybody_contact_logits,\n                        gt_keybody_contact_b,\n                        reduction=\"none\",\n                    ).mean(dim=-1)\n                    aux_keybody_contact_loss = (\n                        contact_bce * valid_tok\n                    ).sum() / valid_count\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_keybody_contact\n                        * aux_keybody_contact_loss\n                    )\n                    contact_pred = (aux_keybody_contact_logits > 0.0).to(\n                        gt_keybody_contact_b.dtype\n                    )\n                    contact_acc_tok = (\n                        (contact_pred == gt_keybody_contact_b)\n                        .to(torch.float32)\n                        
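# per-token fraction of keybody contacts classified correctly\n                        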
.mean(dim=-1)\n                    )\n                    aux_keybody_contact_acc = (\n                        contact_acc_tok * valid_tok\n                    ).sum() / valid_count\n                aux_ref_keybody_rel_pos = actor_out.get(\n                    \"aux_ref_keybody_rel_pos\"\n                )\n                aux_robot_keybody_rel_pos = actor_out.get(\n                    \"aux_robot_keybody_rel_pos\"\n                )\n                if (\n                    self.aux_state_pred_num_keybody_bodies > 0\n                    and gt_ref_keybody_rel_pos_b is not None\n                ):\n                    aux_ref_keybody_rel_pos_loss = (\n                        self._masked_aux_keybody_mse(\n                            aux_ref_keybody_rel_pos,\n                            gt_ref_keybody_rel_pos_b,\n                            valid_tok,\n                        )\n                    )\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_ref_keybody_rel_pos\n                        * aux_ref_keybody_rel_pos_loss\n                    )\n                elif aux_ref_keybody_rel_pos.numel() > 0:\n                    actor_loss = (\n                        actor_loss + 0.0 * aux_ref_keybody_rel_pos.sum()\n                    )\n                if (\n                    self.aux_state_pred_num_keybody_bodies > 0\n                    and gt_robot_keybody_rel_pos_b is not None\n                ):\n                    aux_robot_keybody_rel_pos_loss = (\n                        self._masked_aux_keybody_mse(\n                            aux_robot_keybody_rel_pos,\n                            gt_robot_keybody_rel_pos_b,\n                            valid_tok,\n                        )\n                    )\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_robot_keybody_rel_pos\n                        * aux_robot_keybody_rel_pos_loss\n                    )\n                elif aux_robot_keybody_rel_pos.numel() > 0:\n                    actor_loss = (\n                        actor_loss + 0.0 * aux_robot_keybody_rel_pos.sum()\n                    )\n                if self.use_aux_denoise_ref_root_lin_vel:\n                    aux_denoise_ref_root_lin_vel_residual = actor_out.get(\n                        \"aux_denoise_ref_root_lin_vel_residual\"\n                    )\n                    aux_denoise_ref_root_lin_vel_loss = self._masked_aux_huber(\n                        pred=aux_denoise_ref_root_lin_vel_residual,\n                        target=gt_denoise_ref_root_lin_vel_b,\n                        valid_tok=valid_tok,\n                        beta=self.aux_denoise_residual_huber_beta,\n                    )\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_denoise_ref_root_lin_vel\n                        * aux_denoise_ref_root_lin_vel_loss\n                    )\n                if self.use_aux_denoise_ref_root_ang_vel:\n                    aux_denoise_ref_root_ang_vel_residual = actor_out.get(\n                        \"aux_denoise_ref_root_ang_vel_residual\"\n                    )\n                    aux_denoise_ref_root_ang_vel_loss = self._masked_aux_huber(\n                        pred=aux_denoise_ref_root_ang_vel_residual,\n                        target=gt_denoise_ref_root_ang_vel_b,\n                        valid_tok=valid_tok,\n                        
beta=self.aux_denoise_residual_huber_beta,\n                    )\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_denoise_ref_root_ang_vel\n                        * aux_denoise_ref_root_ang_vel_loss\n                    )\n                if self.use_aux_denoise_ref_dof_pos:\n                    aux_denoise_ref_dof_pos_residual = actor_out.get(\n                        \"aux_denoise_ref_dof_pos_residual\"\n                    )\n                    aux_denoise_ref_dof_pos_loss = self._masked_aux_huber(\n                        pred=aux_denoise_ref_dof_pos_residual,\n                        target=gt_denoise_ref_dof_pos_b,\n                        valid_tok=valid_tok,\n                        beta=self.aux_denoise_residual_huber_beta,\n                    )\n                    actor_loss = (\n                        actor_loss\n                        + self.aux_state_pred_w_denoise_ref_dof_pos\n                        * aux_denoise_ref_dof_pos_loss\n                    )\n            if self.use_aux_router_command_recon:\n                if self.aux_router_command_recon_assembler is None:\n                    raise ValueError(\n                        \"aux_router_command_recon is enabled but command \"\n                        \"assembler was not initialized.\"\n                    )\n                aux_router_command_recon_pred = actor_out.get(\n                    \"aux_router_command_recon\"\n                )\n                gt_aux_router_command_recon_b = (\n                    self.aux_router_command_recon_assembler(\n                        obs_b.flatten(0, 1)\n                    ).reshape(b, t, -1)\n                )\n                aux_router_command_recon_loss = self._masked_aux_mse(\n                    aux_router_command_recon_pred,\n                    gt_aux_router_command_recon_b,\n                    valid_tok,\n                )\n                actor_loss = (\n                    actor_loss\n                    + self.aux_router_command_recon_weight\n                    * aux_router_command_recon_loss\n                )\n            if self.use_aux_router_future_recon:\n                aux_router_future_recon_loss = (\n                    self._compute_aux_router_future_recon_loss(\n                        actor_wrapper=actor_unwrapped,\n                        actor_out=actor_out,\n                        obs_b=obs_b,\n                        valid_tok=valid_tok,\n                    )\n                )\n                actor_loss = (\n                    actor_loss\n                    + self.aux_router_future_recon_weight\n                    * aux_router_future_recon_loss\n                )\n            if self.use_aux_router_switch_penalty:\n                if self.aux_router_switch_penalty_metric == \"js\":\n                    aux_router_features = actor_out.get(\"router_features\")\n                    aux_router_switch_penalty_loss = self._masked_adjacent_router_js(\n                        router_features=aux_router_features,\n                        valid_tok=valid_tok,\n                        num_moe_layers=self.aux_command_router_num_moe_layers,\n                        num_fine_experts=self.aux_command_router_num_fine_experts,\n                    )\n                else:\n                    aux_router_temporal_features = actor_out.get(\n                        \"router_temporal_features\"\n                    )\n                    aux_router_switch_penalty_loss = 
self._masked_adjacent_router_normed_smooth_l1(\n                        router_temporal_features=aux_router_temporal_features,\n                        valid_tok=valid_tok,\n                        num_moe_layers=self.aux_command_router_num_moe_layers,\n                        num_fine_experts=self.aux_command_router_num_fine_experts,\n                        beta=self.aux_router_switch_penalty_beta,\n                    )\n                aux_router_switch_penalty_loss = (\n                    aux_router_switch_penalty_loss.to(actor_loss.dtype)\n                )\n                actor_loss = (\n                    actor_loss\n                    + self.aux_router_switch_penalty_weight\n                    * aux_router_switch_penalty_loss\n                )\n            if self.use_dead_expert_margin_to_topk and len(moe_layers) > 0:\n                margin_losses = [\n                    layer.last_dead_expert_margin_to_topk_loss\n                    for layer in moe_layers\n                    if layer.last_dead_expert_margin_to_topk_loss is not None\n                ]\n                if len(margin_losses) > 0:\n                    dead_expert_margin_to_topk_loss = torch.stack(\n                        [\n                            loss.to(actor_loss.device, dtype=actor_loss.dtype)\n                            for loss in margin_losses\n                        ]\n                    ).mean()\n                    actor_loss = (\n                        actor_loss\n                        + self.dead_expert_margin_to_topk_weight\n                        * dead_expert_margin_to_topk_loss\n                    )\n            if self.use_router_expert_orthogonal and len(moe_layers) > 0:\n                orth_losses = []\n                for layer in moe_layers:\n                    layer_orth_loss, _, _ = (\n                        self._compute_routed_expert_orthogonal_loss(\n                            layer,\n                            dtype=actor_loss.dtype,\n                            device=actor_loss.device,\n                        )\n                    )\n                    orth_losses.append(layer_orth_loss)\n                if len(orth_losses) > 0:\n                    router_expert_orthogonal_loss = torch.stack(\n                        orth_losses\n                    ).mean()\n                    actor_loss = (\n                        actor_loss\n                        + self.router_expert_orthogonal_weight\n                        * router_expert_orthogonal_loss\n                    )\n            if (\n                self.use_selected_expert_margin_to_unselected\n                and len(moe_layers) > 0\n            ):\n                selected_margin_losses = [\n                    layer.last_selected_expert_margin_to_unselected_loss\n                    for layer in moe_layers\n                    if layer.last_selected_expert_margin_to_unselected_loss\n                    is not None\n                ]\n                if len(selected_margin_losses) > 0:\n                    selected_expert_margin_to_unselected_loss = torch.stack(\n                        [\n                            loss.to(actor_loss.device, dtype=actor_loss.dtype)\n                            for loss in selected_margin_losses\n                        ]\n                    ).mean()\n                    actor_loss = (\n                        actor_loss\n                        + self.selected_expert_margin_to_unselected_weight\n                        * 
selected_expert_margin_to_unselected_loss\n                    )\n\n            kl_coef = float(\n                getattr(self.config, \"kl_coef\", self.desired_kl or 0.0) or 0.0\n            )\n            if kl_coef > 0.0:\n                delta_logp = logp_new - logp_old\n                kl_token = (\n                    ratio.detach() * delta_logp * valid_tok\n                ).sum() / valid_count\n                kl_loss = kl_coef * kl_token\n                actor_loss = actor_loss + kl_loss\n                mean_kl_token += float(kl_token.item())\n                mean_kl_loss += float(kl_loss.item())\n\n            if entropy_coef > 0.0:\n                ent_tok = entropy_b.squeeze(-1)\n                entropy_loss = (ent_tok * valid_tok).sum() / valid_count\n                actor_loss = actor_loss - entropy_coef * entropy_loss\n\n            self.actor_optimizer.zero_grad()\n            self.critic_optimizer.zero_grad()\n            self.accelerator.backward(actor_loss)\n            self.accelerator.backward(critic_loss)\n\n            if self.max_grad_norm is not None:\n                self.accelerator.clip_grad_norm_(\n                    self.actor.parameters(), self.max_grad_norm\n                )\n                self.accelerator.clip_grad_norm_(\n                    self.critic.parameters(), self.max_grad_norm\n                )\n            self.actor_optimizer.step()\n            self.critic_optimizer.step()\n\n            num_updates += 1\n            mean_value_loss += float(value_loss.item())\n            mean_surrogate_loss += float(surrogate_loss.item())\n            mean_entropy += float(entropy_b.mean().item())\n            if aux_base_lin_vel_loss is not None:\n                mean_aux_base_lin_vel_nll += float(\n                    aux_base_lin_vel_loss.item()\n                )\n            if aux_root_height_loss is not None:\n                mean_aux_root_height_nll += float(aux_root_height_loss.item())\n            if aux_base_lin_vel_std is not None:\n                mean_aux_base_lin_vel_std += float(\n                    aux_base_lin_vel_std.mean().item()\n                )\n            if aux_root_height_std is not None:\n                mean_aux_root_height_std += float(\n                    aux_root_height_std.mean().item()\n                )\n            if aux_keybody_contact_loss is not None:\n                mean_aux_keybody_contact_bce += float(\n                    aux_keybody_contact_loss.item()\n                )\n            if aux_keybody_contact_acc is not None:\n                mean_aux_keybody_contact_acc += float(\n                    aux_keybody_contact_acc.item()\n                )\n            if aux_ref_keybody_rel_pos_loss is not None:\n                mean_aux_ref_keybody_rel_pos_mse += float(\n                    aux_ref_keybody_rel_pos_loss.item()\n                )\n            if aux_robot_keybody_rel_pos_loss is not None:\n                mean_aux_robot_keybody_rel_pos_mse += float(\n                    aux_robot_keybody_rel_pos_loss.item()\n                )\n            if aux_denoise_ref_root_lin_vel_loss is not None:\n                mean_aux_denoise_ref_root_lin_vel_huber += float(\n                    aux_denoise_ref_root_lin_vel_loss.item()\n                )\n            if aux_denoise_ref_root_ang_vel_loss is not None:\n                mean_aux_denoise_ref_root_ang_vel_huber += float(\n                    aux_denoise_ref_root_ang_vel_loss.item()\n                )\n            if aux_denoise_ref_dof_pos_loss is not 
None:\n                mean_aux_denoise_ref_dof_pos_huber += float(\n                    aux_denoise_ref_dof_pos_loss.item()\n                )\n            if aux_router_command_recon_loss is not None:\n                mean_aux_router_command_recon_mse += float(\n                    aux_router_command_recon_loss.item()\n                )\n            if aux_router_future_recon_loss is not None:\n                mean_aux_router_future_recon_huber += float(\n                    aux_router_future_recon_loss.item()\n                )\n            if aux_router_switch_penalty_loss is not None:\n                mean_aux_router_switch_penalty_js += float(\n                    aux_router_switch_penalty_loss.item()\n                )\n            if dead_expert_margin_to_topk_loss is not None:\n                mean_dead_expert_margin_to_topk_loss += float(\n                    dead_expert_margin_to_topk_loss.item()\n                )\n            if router_expert_orthogonal_loss is not None:\n                mean_router_expert_orthogonal_loss += float(\n                    router_expert_orthogonal_loss.item()\n                )\n            if selected_expert_margin_to_unselected_loss is not None:\n                mean_selected_expert_margin_to_unselected_loss += float(\n                    selected_expert_margin_to_unselected_loss.item()\n                )\n\n        actor_policy.apply_dynamic_bias_update_from_stats()\n        denom = max(1, num_updates)\n        mean_value_loss /= denom\n        mean_surrogate_loss /= denom\n        mean_entropy /= denom\n        mean_kl_token /= denom\n        mean_kl_loss /= denom\n        mean_kl_analytic /= max(1, num_kl_measurements)\n        clip_fraction_batch_mean /= denom\n        if self.schedule == \"adaptive\":\n            self._apply_adaptive_lr(kl_windowed)\n        mean_aux_base_lin_vel_nll /= denom\n        mean_aux_root_height_nll /= denom\n        mean_aux_base_lin_vel_std /= denom\n        mean_aux_root_height_std /= denom\n        mean_aux_keybody_contact_bce /= denom\n        mean_aux_keybody_contact_acc /= denom\n        mean_aux_ref_keybody_rel_pos_mse /= denom\n        mean_aux_robot_keybody_rel_pos_mse /= denom\n        mean_aux_denoise_ref_root_lin_vel_huber /= denom\n        mean_aux_denoise_ref_root_ang_vel_huber /= denom\n        mean_aux_denoise_ref_dof_pos_huber /= denom\n        mean_aux_router_command_recon_mse /= denom\n        mean_aux_router_future_recon_huber /= denom\n        mean_aux_router_switch_penalty_js /= denom\n        mean_dead_expert_margin_to_topk_loss /= denom\n        mean_router_expert_orthogonal_loss /= denom\n        mean_selected_expert_margin_to_unselected_loss /= denom\n        self._last_update_metrics[\"0-Train/num_updates_executed\"] = float(\n            num_updates\n        )\n        self._last_update_metrics[\"0-Train/kl_windowed\"] = float(\n            kl_windowed or 0.0\n        )\n        self._last_update_metrics[\"0-Train/kl_stop_triggered\"] = float(\n            kl_stop_triggered\n        )\n        self._last_update_metrics[\"0-Train/kl_stop_analytic\"] = float(\n            kl_stop_analytic\n        )\n        self._last_update_metrics[\"0-Train/kl_analytic_batch_last\"] = float(\n            kl_analytic_batch_last\n        )\n        self._last_update_metrics[\"0-Train/kl_analytic_batch_max\"] = float(\n            kl_analytic_batch_max\n        )\n        self._last_update_metrics[\"0-Train/clip_fraction_batch_mean\"] = float(\n            clip_fraction_batch_mean\n        )\n        
self._last_update_metrics[\"0-Train/clip_fraction_batch_last\"] = float(\n            clip_fraction_batch_last\n        )\n        moe_layers = [\n            layer\n            for layer in actor_unwrapped.actor_module.layers\n            if isinstance(layer, GroupedMoEBlock)\n        ]\n        moe_active_expert_ratio = None\n        moe_max_expert_frac = None\n        moe_least_expert_frac = None\n        moe_dead_expert_ratio = None\n        moe_expert_count_cv = None\n        moe_selected_expert_margin_to_unselected = None\n        moe_last_router_js_step = None\n        moe_last_router_top1_switch_rate = None\n        if len(moe_layers) > 0:\n            moe_metrics = self._summarize_moe_layer_stats(moe_layers)\n            moe_active_expert_ratio = moe_metrics[\"moe_active_expert_ratio\"]\n            moe_max_expert_frac = moe_metrics[\"moe_max_expert_frac\"]\n            moe_least_expert_frac = moe_metrics[\"moe_least_expert_frac\"]\n            moe_dead_expert_ratio = moe_metrics[\"moe_dead_expert_ratio\"]\n            moe_expert_count_cv = moe_metrics[\"moe_expert_count_cv\"]\n            moe_selected_expert_margin_to_unselected = moe_metrics[\n                \"moe_selected_expert_margin_to_unselected\"\n            ]\n            router_shift_stats = actor_policy.get_last_moe_router_shift_stats()\n            js_sum = router_shift_stats[\"js_sum\"]\n            js_count = router_shift_stats[\"js_count\"]\n            top1_switch_sum = router_shift_stats[\"top1_switch_sum\"]\n            top1_switch_count = router_shift_stats[\"top1_switch_count\"]\n            if (\n                js_sum is not None\n                and js_count is not None\n                and top1_switch_sum is not None\n                and top1_switch_count is not None\n            ):\n                js_sum = js_sum.detach().to(self.device, dtype=torch.float32)\n                js_count = js_count.detach().to(\n                    self.device, dtype=torch.float32\n                )\n                top1_switch_sum = top1_switch_sum.detach().to(\n                    self.device, dtype=torch.float32\n                )\n                top1_switch_count = top1_switch_count.detach().to(\n                    self.device, dtype=torch.float32\n                )\n                if self.is_distributed:\n                    js_sum = self.accelerator.reduce(js_sum, reduction=\"sum\")\n                    js_count = self.accelerator.reduce(\n                        js_count, reduction=\"sum\"\n                    )\n                    top1_switch_sum = self.accelerator.reduce(\n                        top1_switch_sum, reduction=\"sum\"\n                    )\n                    top1_switch_count = self.accelerator.reduce(\n                        top1_switch_count, reduction=\"sum\"\n                    )\n                if float(js_count.item()) > 0.0:\n                    moe_last_router_js_step = float((js_sum / js_count).item())\n                if float(top1_switch_count.item()) > 0.0:\n                    moe_last_router_top1_switch_rate = float(\n                        (top1_switch_sum / top1_switch_count).item()\n                    )\n\n        self.storage.clear()\n        loss_out = {\n            \"value_function\": mean_value_loss,\n            \"critic_explained_variance\": critic_explained_variance,\n            \"surrogate\": mean_surrogate_loss,\n            \"entropy\": mean_entropy,\n            \"kl_token\": mean_kl_token,\n            \"kl_loss\": mean_kl_loss,\n            \"kl_analytic\": 
mean_kl_analytic,\n            \"aux_base_lin_vel_nll\": mean_aux_base_lin_vel_nll,\n            \"aux_root_height_nll\": mean_aux_root_height_nll,\n            \"aux_base_lin_vel_std\": mean_aux_base_lin_vel_std,\n            \"aux_root_height_std\": mean_aux_root_height_std,\n            \"aux_keybody_contact_bce\": mean_aux_keybody_contact_bce,\n            \"aux_keybody_contact_acc\": mean_aux_keybody_contact_acc,\n            \"aux_ref_keybody_rel_pos_mse\": mean_aux_ref_keybody_rel_pos_mse,\n            \"aux_robot_keybody_rel_pos_mse\": (\n                mean_aux_robot_keybody_rel_pos_mse\n            ),\n            \"aux_denoise_ref_root_lin_vel_huber\": (\n                mean_aux_denoise_ref_root_lin_vel_huber\n            ),\n            \"aux_denoise_ref_root_ang_vel_huber\": (\n                mean_aux_denoise_ref_root_ang_vel_huber\n            ),\n            \"aux_denoise_ref_dof_pos_huber\": (\n                mean_aux_denoise_ref_dof_pos_huber\n            ),\n            \"aux_router_command_recon_mse\": mean_aux_router_command_recon_mse,\n            \"aux_router_future_recon_huber\": (\n                mean_aux_router_future_recon_huber\n            ),\n            \"aux_router_switch_penalty_js\": (\n                mean_aux_router_switch_penalty_js\n            ),\n            \"dead_expert_margin_to_topk\": (\n                mean_dead_expert_margin_to_topk_loss\n            ),\n            \"router_expert_orthogonal\": mean_router_expert_orthogonal_loss,\n            \"selected_expert_margin_to_unselected\": (\n                mean_selected_expert_margin_to_unselected_loss\n            ),\n            \"moe_active_expert_ratio\": moe_active_expert_ratio,\n            \"moe_max_expert_frac\": moe_max_expert_frac,\n            \"moe_least_expert_frac\": moe_least_expert_frac,\n            \"moe_dead_expert_ratio\": moe_dead_expert_ratio,\n            \"moe_expert_count_cv\": moe_expert_count_cv,\n            \"moe_selected_expert_margin_to_unselected\": (\n                moe_selected_expert_margin_to_unselected\n            ),\n            \"moe_last_router_js_step\": moe_last_router_js_step,\n            \"moe_last_router_top1_switch_rate\": (\n                moe_last_router_top1_switch_rate\n            ),\n        }\n        if self.is_distributed:\n            reduced_out = {}\n            for k, v in loss_out.items():\n                if v is None:\n                    reduced_out[k] = None\n                    continue\n                t = torch.tensor(v, device=self.device, dtype=torch.float32)\n                reduced_t = self.accelerator.reduce(t, reduction=\"mean\")\n                reduced_out[k] = float(reduced_t.item())\n            loss_out = reduced_out\n\n        self._post_update_hook(loss_out)\n        return loss_out\n"
  },
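The trainer above attends across full rollout windows, so environment resets must sever attention. A minimal standalone restatement (for intuition and unit testing, not the repo's import path) of the segment-aware mask that `_build_episode_causal_mask` constructs:

```python
import torch


def episode_causal_mask(dones_seq: torch.Tensor) -> torch.Tensor:
    """[N, T, 1] done flags -> [N, T, T] boolean attention mask."""
    _, t, _ = dones_seq.shape
    dones = dones_seq.squeeze(-1).to(torch.long)
    # Segment id increments only *after* a done step, so the terminal
    # step still attends within the episode it ends.
    seg = torch.cumsum(dones, dim=1) - dones
    same = seg[:, :, None] == seg[:, None, :]
    causal = torch.tril(
        torch.ones(t, t, dtype=torch.bool, device=dones_seq.device)
    )
    return same & causal


# Env terminates at step 1: steps 2..3 cannot attend back across the reset.
dones = torch.tensor([[[0.0], [1.0], [0.0], [0.0]]])
print(episode_causal_mask(dones)[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 1, 1]], dtype=torch.int32)
```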
  {
    "path": "holomotion/src/data_curation/.gitignore",
    "content": "_generated/\n"
  },
  {
    "path": "holomotion/src/data_curation/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/data_curation/data_smplify.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\nimport argparse\nimport os\n\nfrom smplify.smplify_humanact12 import humanact12_to_amass\nfrom smplify.smplify_motionx import motionx_to_amass\nfrom smplify.smplify_omomo import omomo_to_amass\n\nfrom holomotion.holomotion.src.data_curation.smplify.smplify_zjumocap import (\n    zju_to_amass,\n)\n\n\ndef ensure_dir(path):\n    \"\"\"Make sure the dir exist.\n\n    Args:\n        path: The path of the dir.\n\n    \"\"\"\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef main():\n    \"\"\"Convert multiple motion capture datasets to AMASS format.\n\n    This function parses command-line arguments to specify the root directory\n    of raw datasets and an optional save directory. It iterates over the\n    supported datasets (MotionX, ZJU_Mocap, HumanAct12, OMOMO), and if the\n    corresponding data directory exists, converts it to AMASS format and saves\n    it in a unified directory structure.\n\n    Raises:\n        SystemExit: If required command-line arguments are missing.\n\n    Side Effects:\n        Creates output directories and writes converted data files.\n        Prints progress and warning messages to stdout.\n\n    \"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Convert all datasets to AMASS format\"\n    )\n    parser.add_argument(\n        \"--data_root\",\n        type=str,\n        required=True,\n        help=\"Path to the root directory of raw datasets\",\n    )\n    parser.add_argument(\n        \"--save_root\",\n        type=str,\n        default=None,\n        help=\"Path to save the unified data (default: data_root/smplx_data)\",\n    )\n    args = parser.parse_args()\n\n    data_root = os.path.abspath(args.data_root)\n    save_root = args.save_root or \"./data/amass_compatible_datasets\"\n\n    print(f\"Raw data root: {data_root}\")\n    print(f\"Unified data will be saved to: {save_root}\")\n    ensure_dir(save_root)\n\n    datasets = [\n        (\"MotionX\", motionx_to_amass),\n        (\"ZJU_Mocap\", zju_to_amass),\n        (\"humanact12\", humanact12_to_amass),\n        (\"OMOMO\", omomo_to_amass),\n    ]\n\n    for name, func in datasets:\n        data_dir = os.path.join(data_root, name)\n        save_dir = os.path.join(save_root, name)\n        ensure_dir(save_dir)\n        if not os.path.exists(data_dir):\n            print(f\"Warning: {data_dir} does not exist. Skipping {name}.\")\n            continue\n\n        print(f\"Processing {name}...\")\n        func(data_dir, save_dir)\n        print(f\"{name} done. Saved to {save_dir}.\\n\")\n\n    print(\"All datasets processed.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
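The driver above treats each dataset as a `(name, converter)` pair sharing a `func(data_dir, save_dir)` contract, so new sources plug in by extending the `datasets` list. A minimal sketch of that contract; `toy_to_amass` and the `/tmp` paths are hypothetical stand-ins:

```python
import os


def toy_to_amass(data_dir: str, save_dir: str) -> None:
    """Hypothetical converter: a real one writes AMASS-style .npz files."""
    os.makedirs(save_dir, exist_ok=True)
    print(f"convert {data_dir} -> {save_dir}")


# Registering the converter mirrors the datasets list in data_smplify.py.
datasets = [("ToyDataset", toy_to_amass)]
for name, func in datasets:
    func(os.path.join("/tmp/raw", name), os.path.join("/tmp/amass", name))
```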
  {
    "path": "holomotion/src/data_curation/filter/filter.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport argparse\nimport json\nimport os\n\nimport numpy as np\n\n\ndef checksitpose(\n    npz_path, ref_pose_path, threshold=0.75, frame_thresh=1\n) -> bool:\n    \"\"\"Check if the given motion sequence is close to the pose.\n\n    Args:\n        npz_path (str): Path to the .npz file of the motion sequence.\n        ref_pose_path (str): the reference sitting pose.\n        threshold (float, optional): Euclidean distance threshold.\n        frame_thresh (int, optional): Minimum number of frames.\n\n    Returns:\n        bool: True if the sequence contains sitting-like frames.\n\n    \"\"\"\n    count = 0\n    try:\n        sitdata = np.load(ref_pose_path)\n        sitpose = sitdata[\"poses\"][535][:66]  # reference sitting pose\n    except Exception:\n        return False\n\n    sitpose_down = sitpose[3:36]  # lower-body joints only\n\n    bdata = np.load(npz_path)\n    curposes = bdata[\"poses\"]  # shape: (N, 165)\n\n    for pose in curposes:\n        pose_down = pose[3:36]\n        dist = np.linalg.norm(pose_down - sitpose_down)\n        if dist < threshold:\n            count += 1\n        if count >= frame_thresh:\n            return True\n\n    return False\n\n\ndef process_dataset(\n    parent_folder,\n    json_path,\n    output_path,\n    abnormal_path,\n    sit_pose_reference,\n    stair_keywords=None,\n    sit_keywords=None,\n    sit_threshold=0.75,\n    frame_threshold=20,\n    velocity_threshold=100.0,\n):\n    \"\"\"Label the dataset under parent folder.\"\"\"\n    stair_keywords = stair_keywords or [\n        \"stairs\",\n        \"staircase\",\n        \"upstairs\",\n        \"downstairs\",\n    ]\n    sit_keywords = sit_keywords or [\"sitting\", \"Sitting\"]\n    abnormal_dataset = [\"aist\"]\n\n    stairs = sit = untrack = 0\n    filtered_paths = set()\n\n    with (\n        open(json_path) as f_in,\n        open(output_path, \"w\") as f_out_normal,\n        open(abnormal_path, \"w\") as f_out_abnormal,\n    ):\n        for line in f_in:\n            line = line.strip()\n            if not line:\n                continue\n\n            try:\n                content = json.loads(line)\n                path = content.get(\"path\", \"\")\n                npz_path = os.path.join(parent_folder, path)\n\n                # skip the path if it is abnormal\n                if path in filtered_paths:\n                    f_out_abnormal.write(line + \"\\n\")\n                    continue\n\n                up_z = content.get(\"max_up_z_velocity\", 0)\n                down_z = content.get(\"max_down_z_velocity\", 0)\n                max_z = content.get(\"max_z_translation\", 0)\n                min_z = content.get(\"min_z_translation\", 0)\n                mean_v = content.get(\"mean_velocity\", 0)\n\n                # filter by keywords\n                if any(kw in path for kw in stair_keywords):\n                    
f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    stairs += 1\n                    continue\n                elif any(kw in path for kw in sit_keywords):\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    sit += 1\n                    continue\n\n                elif any(kw in path for kw in abnormal_dataset):\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    continue\n\n                if mean_v > velocity_threshold:\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    untrack += 1\n                    continue\n\n                if up_z >= 0.6 and max_z > 0.7:\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    stairs += 1\n                    continue\n\n                elif down_z <= -0.7 and min_z < -0.7:\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.clear()\n                    filtered_paths.add(path)\n                    stairs += 1\n                    continue\n\n                if checksitpose(\n                    npz_path,\n                    sit_pose_reference,\n                    sit_threshold,\n                    frame_threshold,\n                ):\n                    f_out_abnormal.write(line + \"\\n\")\n                    filtered_paths.add(path)\n                    sit += 1\n                    continue\n\n                # normal motion\n                f_out_normal.write(line + \"\\n\")\n\n            except Exception as e:\n                print(f\"Error processing line: {line}\\nException: {e}\")\n\n    print(\n        f\"total abnormal data: upstairs {stairs}, \"\n        f\"sitting {sit}, velocity {untrack}\"\n    )\n\n\ndef jsonl_to_yaml(jsonl_path, yaml_output_path):\n    \"\"\"Convert a JSONL file into a YAML exclusion list.\"\"\"\n    output_set = set()\n\n    with open(jsonl_path) as f:\n        for line in f:\n            if not line.strip():\n                continue\n            try:\n                data = json.loads(line)\n                path = data.get(\"path\", \"\")\n                if path:\n                    clean_path = os.path.splitext(path.strip().lstrip(\"/\"))[0]\n                    new_name = \"0-\" + clean_path.replace(\"/\", \"_\").replace(\n                        \"\\\\\", \"_\"\n                    )\n                    output_set.add(f\"{new_name}\")\n            except json.JSONDecodeError:\n                print(f\"Skipping invalid JSON line: {line.strip()}\")\n                continue\n\n    with open(yaml_output_path, \"w\") as out:\n        out.write(\"[\\n\")\n        for item in sorted(output_set):\n            out.write(f\"  {item},\\n\")\n        out.write(\"]\\n\")\n\n    print(f\"done, total {len(output_set)} items -> {yaml_output_path}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Filter AMASS dataset and save results.\"\n    )\n\n    parser.add_argument(\n        \"--parent_folder\",\n        type=str,\n        default=\"./data/amass_compatible_datasets\",\n        help=\"Path to the 
parent folder of AMASS data\",\n    )\n    parser.add_argument(\n        \"--json_path\",\n        type=str,\n        default=\"./data/dataset_labels/OMOMO.jsonl\",\n        help=\"Path to the input JSONL file\",\n    )\n    parser.add_argument(\n        \"--output_path\",\n        type=str,\n        default=\"./data/dataset_labels/temp.jsonl\",\n        help=\"Path to save the filtered output JSONL\",\n    )\n    parser.add_argument(\n        \"--abnormal_path\",\n        type=str,\n        default=\"./data/dataset_labels/temp2.jsonl\",\n        help=\"Path to save abnormal data JSONL\",\n    )\n    parser.add_argument(\n        \"--sit_pose_reference\",\n        type=str,\n        default=\"./data/amass_compatible_datasets/amass/BioMotionLab_NTroje/rub062/0016_sitting2_poses.npz\",\n        help=\"Path to the reference sitting pose npz\",\n    )\n    parser.add_argument(\n        \"--yaml_path\",\n        type=str,\n        default=\"./holomotion/config/data_curation/base.yaml\",\n        help=\"Path to the output YAML exclusion list\",\n    )\n\n    args = parser.parse_args()\n\n    process_dataset(\n        parent_folder=args.parent_folder,\n        json_path=args.json_path,\n        output_path=args.output_path,\n        abnormal_path=args.abnormal_path,\n        sit_pose_reference=args.sit_pose_reference,\n    )\n    os.makedirs(os.path.dirname(args.yaml_path), exist_ok=True)\n    jsonl_to_yaml(args.abnormal_path, args.yaml_path)\n
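    # Illustrative jsonl_to_yaml output, derived from the rules above\n    # (\"0-\" prefix, extension stripped, path separators replaced by \"_\"),\n    # e.g. for the default sit_pose_reference sequence:\n    #\n    # [\n    #   0-BioMotionLab_NTroje_rub062_0016_sitting2_poses,\n    # ]\n"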
  },
  {
    "path": "holomotion/src/data_curation/filter/label_data.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\nimport argparse\nimport json\nimport os\nimport sys\n\nimport numpy as np\n\nsys.path.append(\n    \"./holomotion/src/data_curation/omomo_release/human_body_prior/src\"\n)\n\n\ndef calc_max_xy_translation(motion_data: dict):\n    \"\"\"Calculate max xy translation.\"\"\"\n    trans = motion_data[\"trans\"]\n    root_trans_offset = trans\n    max_xy_translation = np.max(\n        np.linalg.norm(\n            root_trans_offset[:, :2] - root_trans_offset[0:1, :2],\n            axis=1,\n        )\n    )\n    return max_xy_translation\n\n\ndef calc_max_z_translation(motion_data: dict):\n    \"\"\"Calculate max and min z translation.\"\"\"\n    trans = motion_data[\"trans\"]\n    root_trans_offset = trans\n    max_z_translation = np.max(\n        root_trans_offset[:, 2] - root_trans_offset[0:1, 2]\n    )\n    min_z_translation = np.min(\n        root_trans_offset[:, 2] - root_trans_offset[0:1, 2]\n    )\n    return max_z_translation, min_z_translation\n\n\ndef calc_max_velocity_scale(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate max velocity scale.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    root_vel_norm = np.linalg.norm(est_root_vel, axis=-1)\n    max_velocity_scale = np.max(root_vel_norm)\n    return max_velocity_scale\n\n\ndef calc_mean_velocity_scale(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate mean velocity scale.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    root_vel_norm = np.linalg.norm(est_root_vel, axis=-1)\n    mean_velocity_scale = np.mean(root_vel_norm)\n    return mean_velocity_scale\n\n\ndef calc_std_velocity_scale(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate std velocity scale.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    root_vel_norm = np.linalg.norm(est_root_vel, axis=-1)\n    std_velocity_scale = np.std(root_vel_norm)\n    return std_velocity_scale\n\n\ndef calc_max_vxy_scale(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate smax vx, vy scale.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    root_vel_norm = np.linalg.norm(est_root_vel[:, :2], axis=-1)\n    max_vxy_scale = np.max(root_vel_norm)\n    mean_vxy_scale = np.mean(root_vel_norm)\n    std_vxy_scale = np.std(root_vel_norm)\n    return max_vxy_scale, mean_vxy_scale, std_vxy_scale\n\n\ndef calc_std_accel(motion_data: dict, fps: float = 30.0) -> float:\n    \"\"\"Calculate the standard deviation of root joint acceleration.\n\n    This function computes the per-frame acceleration of the root joint in the\n    XY plane from its translation data and returns the standard deviation\n    of those values.\n\n    Args:\n        
motion_data (dict): A dictionary that must contain a 'trans' key\n        representing global translation of the root joint.\n        Shape should be (T, 3), where T is the number of frames.\n        fps (float): Frames per second of the motion sequence.\n\n    Returns:\n        float: Standard deviation of the acceleration magnitudes\n        on the XY plane. Returns 0.0 if there are fewer than 3 frames.\n\n    \"\"\"\n    trans = motion_data[\"trans\"]  # shape: (T, 3)\n    if trans.shape[0] < 3:\n        return 0.0  # At least 3 frames are needed to compute two differences\n\n    # Compute velocity (frame-to-frame displacement * fps)\n    velocities = np.diff(trans, axis=0) * fps  # shape: (T-1, 3)\n\n    # Compute acceleration (frame-to-frame velocity difference * fps)\n    accelerations = np.diff(velocities, axis=0) * fps  # shape: (T-2, 3)\n\n    # Compute acceleration magnitude in XY plane\n    accel_xy_norm = np.linalg.norm(\n        accelerations[:, :2], axis=1\n    )  # shape: (T-2,)\n\n    # Return standard deviation\n    return np.std(accel_xy_norm)\n\n\ndef calc_max_vz_scale(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate max vz scale.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    root_vel_norm = np.abs(est_root_vel[:, 2])\n    max_vz_scale = np.max(root_vel_norm)\n    mean_vz_scale = np.mean(root_vel_norm)\n    std_vz_scale = np.std(root_vel_norm)\n    return max_vz_scale, mean_vz_scale, std_vz_scale\n\n\ndef calc_vz_scale_with_direction(motion_data: dict, fps: float = 30):\n    \"\"\"Calculate vz scale with direction.\"\"\"\n    root_trans_offset = motion_data[\"trans\"]\n    est_root_vel = np.diff(root_trans_offset * fps, axis=0)\n    vz = est_root_vel[:, 2]\n\n    max_up_vz = np.max(vz[vz > 0]) if np.any(vz > 0) else 0.0\n    max_down_vz = np.min(vz[vz < 0]) if np.any(vz < 0) else 0.0\n    mean_vz = np.mean(vz)\n    std_vz = np.std(vz)\n\n    return max_up_vz, max_down_vz, mean_vz, std_vz\n\n\ndef beyond_upper_dof_limits(\n    motion_data: dict,\n    upper_dof_mapping: dict,\n    upper_dof_max_limits: dict,\n):\n    \"\"\"Check whether or not the motion data is beyond upper dof limits.\"\"\"\n    for dof_name, dof_idx in upper_dof_mapping.items():\n        dof_data = motion_data[\"dof\"][:, dof_idx]\n        max_dof_scale = np.max(dof_data)\n        min_dof_scale = np.min(dof_data)\n        if (\n            max_dof_scale < upper_dof_max_limits[dof_name][0]\n            or max_dof_scale > upper_dof_max_limits[dof_name][1]\n            or min_dof_scale < upper_dof_max_limits[dof_name][0]\n            or min_dof_scale > upper_dof_max_limits[dof_name][1]\n        ):\n            return True\n    return False\n\n\nclass HyperParams:\n    max_xy_translation: float = 2.0\n    max_z_translation: float = 0.3\n    max_velocity_scale: float = 1.0\n    max_vxy_scale: float = 1.2\n    max_vz_scale: float = 0.3\n    upper_dof_mapping: dict = {\n        \"left_shoulder_pitch_joint\": 13,\n        \"left_shoulder_roll_joint\": 14,\n        \"left_shoulder_yaw_joint\": 15,\n        \"left_elbow_joint\": 16,\n        \"right_shoulder_pitch_joint\": 17,\n        \"right_shoulder_roll_joint\": 18,\n        \"right_shoulder_yaw_joint\": 19,\n        \"right_elbow_joint\": 20,\n    }\n    upper_dof_max_limits: dict = {\n        \"left_shoulder_pitch_joint\": [-1.0, 1.0],\n        \"left_shoulder_roll_joint\": [0.0, 0.5],\n        \"left_shoulder_yaw_joint\": [-0.5, 0.5],\n        \"left_elbow_joint\": [0.5, 
1.3],\n        \"right_shoulder_pitch_joint\": [-1.0, 1.0],\n        \"right_shoulder_roll_joint\": [-0.5, 0.0],\n        \"right_shoulder_yaw_joint\": [-0.5, 0.3],\n        \"right_elbow_joint\": [0.5, 1.5],\n    }\n\n\ndef label_data_with_metrics(data_folder, jsonl_path: str, parent_folder: str):\n    \"\"\"Calculate the metrics and write them to a jsonl file.\"\"\"\n    assert jsonl_path.endswith(\".jsonl\")\n    with open(jsonl_path, \"w\") as f_out:\n        for root, _, files in os.walk(data_folder):\n            for file in files:\n                if file.endswith(\".npz\"):\n                    npz_path = os.path.join(root, file)\n                    content = {}\n                    content[\"path\"] = os.path.relpath(npz_path, parent_folder)\n                    data = np.load(npz_path)\n                    fps = data.get(\"mocap_frame_rate\")\n                    if fps is None:\n                        fps = data.get(\"mocap_framerate\")\n                    if fps is None:\n                        continue\n\n                    try:\n                        content[\"max_xy_translation\"] = round(\n                            calc_max_xy_translation(data), 2\n                        )\n                        max_z_translation, min_z_translation = (\n                            calc_max_z_translation(data)\n                        )\n                        content[\"max_z_translation\"] = round(\n                            max_z_translation, 2\n                        )\n                        content[\"min_z_translation\"] = round(\n                            min_z_translation, 2\n                        )\n                        content[\"max_velocity\"] = round(\n                            calc_max_velocity_scale(data, fps), 2\n                        )\n                        content[\"mean_velocity\"] = round(\n                            calc_mean_velocity_scale(data, fps), 2\n                        )\n                        content[\"std_velocity\"] = round(\n                            calc_std_velocity_scale(data, fps), 2\n                        )\n                        content[\"std_accel\"] = round(\n                            calc_std_accel(data, fps), 2\n                        )\n                        max_xy_v, mean_xy_v, std_xy_v = calc_max_vxy_scale(\n                            data, fps\n                        )\n                        content[\"max_xy_velocity\"] = round(max_xy_v, 2)\n                        content[\"mean_xy_velocity\"] = round(mean_xy_v, 2)\n                        content[\"std_xy_velocity\"] = round(std_xy_v, 2)\n                        max_up_z_v, max_down_z_v, mean_z_v, std_z_v = (\n                            calc_vz_scale_with_direction(data, fps)\n                        )\n                        content[\"max_up_z_velocity\"] = round(max_up_z_v, 2)\n                        content[\"max_down_z_velocity\"] = round(max_down_z_v, 2)\n                        content[\"mean_z_velocity\"] = round(mean_z_v, 2)\n                        content[\"std_z_velocity\"] = round(std_z_v, 2)\n                    except Exception as e:\n                        print(f\"Error: {e}\")\n\n                    def convert_to_builtin_type(obj):\n                        if isinstance(obj, dict):\n                            return {\n                                k: convert_to_builtin_type(v)\n                                for k, v in obj.items()\n                            }\n                        elif isinstance(obj, list):\n                            return [convert_to_builtin_type(i) for i in obj]\n                        elif isinstance(obj, np.ndarray):\n                            return obj.tolist()\n                        elif isinstance(obj, (np.float32, np.float64)):\n                            return float(obj)\n                        elif isinstance(obj, (np.int32, np.int64)):\n                            return int(obj)\n                        else:\n                            return obj\n\n                    f_out.write(\n                        json.dumps(convert_to_builtin_type(content)) + \"\\n\"\n                    )\n\n    print(f\"Annotated file saved to: {jsonl_path}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--jsonl_list\",\n        nargs=\"+\",\n        default=[\"humanact12\", \"MotionX\", \"OMOMO\", \"ZJU_Mocap\", \"amass\"],\n        help=\"List of jsonl files to process.\",\n    )\n    args = parser.parse_args()\n\n    amass_folder = \"./data/amass_compatible_datasets/amass\"\n    other_folder = \"./data/amass_compatible_datasets\"\n    caption_folder = \"./data/dataset_labels\"\n    os.makedirs(caption_folder, exist_ok=True)\n\n    for name in args.jsonl_list:\n        file = name + \".jsonl\"\n        if name != \"amass\":\n            label_data_with_metrics(\n                os.path.join(other_folder, name),\n                os.path.join(caption_folder, file),\n                other_folder,\n            )\n        else:\n            label_data_with_metrics(\n                amass_folder, os.path.join(caption_folder, file), amass_folder\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n
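\n# Illustrative record produced per npz (path and values fabricated here\n# purely for documentation; the real fields follow the code above):\n# {\"path\": \"DatasetX/clip_0001.npz\", \"max_xy_translation\": 1.25,\n#  \"max_velocity\": 2.4, \"mean_velocity\": 0.87, \"std_accel\": 2.1, ...}\n"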
  },
  {
    "path": "holomotion/src/data_curation/smpl_npz_to_html.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport argparse\nimport json\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Tuple\n\nimport numpy as np\nfrom scipy.spatial.transform import Rotation as R\n\n\n# -----------------------------\n# Defaults\n# -----------------------------\nDEFAULT_TEMPLATE_PATH = Path(\"index_wooden_static.html\")\nDEFAULT_OUT_HTML = Path(\"vis.html\")\n\nPOSE_JOINTS = 22\nEULER_FIX_DEG = (-90.0, 180.0, 0.0)\nEULER_ORDER = \"xyz\"\n\n# Empirical vertical offset (in meters) to align wooden_static visualization mesh\n# with canonical SMPL coordinates (e.g., GVHMR pipelines).\nWOODEN_SMPL_HEIGHT_OFFSET = 0.2\n\n\n@dataclass(frozen=True)\nclass SmplSequence:\n    \"\"\"A minimal SMPL motion sequence loaded from npz.\"\"\"\n\n    poses: np.ndarray  # (T, 66) = root(3) + body(63), axis-angle\n    trans: np.ndarray  # (T, 3)\n    betas: np.ndarray  # (B,)\n    fps: float\n    gender: str\n\n\ndef parse_args() -> argparse.Namespace:\n    ap = argparse.ArgumentParser(\n        description=\"Generate vis.html from a SMPL npz using a HTML template.\"\n    )\n    ap.add_argument(\"--npz\", type=Path, help=\"Path to input .npz\")\n    ap.add_argument(\n        \"--template\",\n        type=Path,\n        default=DEFAULT_TEMPLATE_PATH,\n        help=\"Path to HTML template\",\n    )\n    ap.add_argument(\n        \"--out\",\n        type=Path,\n        default=DEFAULT_OUT_HTML,\n        help=\"Path to output HTML\",\n    )\n\n    ap.add_argument(\n        \"--pose_joints\",\n        type=int,\n        default=POSE_JOINTS,\n        help=f\"Number of pose joints in poses (default: {POSE_JOINTS}).\",\n    )\n    ap.add_argument(\n        \"--height_axis\",\n        type=int,\n        default=1,\n        choices=[0, 1, 2],\n        help=\"Axis index for height in Th (default: 1 for Y-up).\",\n    )\n    ap.add_argument(\n        \"--height_offset\",\n        type=float,\n        default=WOODEN_SMPL_HEIGHT_OFFSET,\n        help=(\n            \"Subtract from Th height axis (Y-up), in meters. \"\n            \"Default is an empirical offset to align wooden_static mesh \"\n            \"with canonical SMPL coordinates (e.g., GVHMR).\"\n        ),\n    )\n    return ap.parse_args()\n\n\ndef euler_fix_rot(euler_deg=EULER_FIX_DEG, order=EULER_ORDER) -> R:\n    \"\"\"Rotation for world-frame correction: R_new = R_fix * R_old.\"\"\"\n    return R.from_euler(order.lower(), euler_deg, degrees=True)\n\n\ndef _require_key(data: np.lib.npyio.NpzFile, key: str) -> np.ndarray:\n    if key not in data:\n        raise KeyError(\n            f\"Missing key '{key}' in npz. 
Available: {list(data.keys())}\"\n        )\n    return data[key]\n\n\ndef load_npz(path: Path) -> SmplSequence:\n    if not path.exists():\n        raise FileNotFoundError(f\"Missing {path}\")\n\n    data = np.load(path, allow_pickle=False)\n\n    poses = _require_key(data, \"poses\").astype(np.float32)\n    trans = _require_key(data, \"trans\").astype(np.float32)\n    betas = _require_key(data, \"betas\").astype(np.float32)\n\n    fps = float(np.asarray(_require_key(data, \"mocap_framerate\")))\n    gender = str(np.asarray(_require_key(data, \"gender\")))\n\n    return SmplSequence(\n        poses=poses, trans=trans, betas=betas, fps=fps, gender=gender\n    )\n\n\ndef validate_sequence(seq: SmplSequence, pose_joints: int) -> int:\n    \"\"\"Validate shapes and return T.\"\"\"\n    if seq.poses.ndim != 2:\n        raise ValueError(f\"poses must be 2D, got shape={seq.poses.shape}\")\n    if seq.trans.ndim != 2 or seq.trans.shape[1] != 3:\n        raise ValueError(f\"trans must be (T,3), got shape={seq.trans.shape}\")\n\n    T = int(seq.poses.shape[0])\n    exp_dim = int(pose_joints) * 3\n\n    if seq.poses.shape[1] != exp_dim:\n        raise ValueError(\n            f\"unexpected poses shape: {seq.poses.shape}, expected (T,{exp_dim})\"\n        )\n    if seq.trans.shape[0] != T:\n        raise ValueError(\n            f\"poses frames ({T}) != trans frames ({seq.trans.shape[0]})\"\n        )\n\n    return T\n\n\ndef build_smpl_frames(\n    seq: SmplSequence,\n    *,\n    pose_joints: int,\n    height_axis: int,\n    height_offset: float,\n) -> Tuple[list, int]:\n    \"\"\"\n    Build frames in the format expected by index_wooden_static.html template.\n\n    Notes:\n        height_offset is a visualization-only correction to compensate for the\n        vertical origin mismatch between wooden_static mesh and canonical SMPL\n        coordinates (e.g., GVHMR). 
Override via --height_offset if needed.\n    \"\"\"\n    T = validate_sequence(seq, pose_joints)\n\n    rot_fix = euler_fix_rot()\n    root_aa = seq.poses[:, :3]\n    body_aa = seq.poses[:, 3:]\n\n    # root: left-multiply world rotation\n    Rh = (rot_fix * R.from_rotvec(root_aa)).as_rotvec().astype(np.float32)\n\n    # trans: rotate in world frame, then apply visualization height offset\n    Th = rot_fix.apply(seq.trans).astype(np.float32)\n    if height_offset != 0.0:\n        Th[:, int(height_axis)] -= float(height_offset)\n\n    # pad hands (6) -> body(63) + hand(6) = 69\n    poses_js = np.concatenate([body_aa, np.zeros((T, 6), np.float32)], axis=1)\n\n    shapes = seq.betas.reshape(-1).tolist()\n    frames = [\n        [\n            {\n                \"id\": 0,\n                \"gender\": seq.gender,\n                \"Rh\": [Rh[f].tolist()],\n                \"Th\": [Th[f].tolist()],\n                \"poses\": [poses_js[f].tolist()],\n                \"shapes\": shapes,\n            }\n        ]\n        for f in range(T)\n    ]\n\n    return frames, T\n\n\ndef render_html(template_path: Path, frames: list, T: int, fps: float) -> str:\n    template = template_path.read_text(encoding=\"utf-8\")\n\n    smpl_data_json = json.dumps(frames, ensure_ascii=False)\n    caption_html = (\n        \"<div class='caption-overlay'><div class='motion-info'>\"\n        f\"Frames: {T} &nbsp;&nbsp; Framerate: {fps:.1f} fps\"\n        \"</div></div>\"\n    )\n\n    return template.replace(\"{{ smpl_data_json }}\", smpl_data_json).replace(\n        \"{{ caption_html }}\", caption_html\n    )\n\n\ndef main(\n    npz_path: Path,\n    template_path: Path,\n    out_html: Path,\n    *,\n    pose_joints: int,\n    height_axis: int,\n    height_offset: float,\n) -> None:\n    if not template_path.exists():\n        raise FileNotFoundError(f\"Missing {template_path}\")\n\n    seq = load_npz(npz_path)\n    frames, T = build_smpl_frames(\n        seq,\n        pose_joints=pose_joints,\n        height_axis=height_axis,\n        height_offset=height_offset,\n    )\n\n    html = render_html(template_path, frames, T, seq.fps)\n    out_html.write_text(html, encoding=\"utf-8\")\n    print(f\"[OK] wrote {out_html.resolve()}\")\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(\n        args.npz,\n        args.template,\n        args.out,\n        pose_joints=args.pose_joints,\n        height_axis=args.height_axis,\n        height_offset=args.height_offset,\n    )\n"
  },
  {
    "path": "holomotion/src/data_curation/smplify/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/data_curation/smplify/smplify_humanact12.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\nimport random\n\nimport h5py\nimport numpy as np\nimport smplx\nimport torch\nfrom scipy.spatial.transform import Rotation\nfrom tqdm import tqdm\n\nfrom thirdparties.joints2smpl.src import config\nfrom thirdparties.joints2smpl.src.smplify import SMPLify3D\n\nSMPL_MODEL_DIR = \"./assets/smpl/\"\n\ndevice = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n# device = torch.device(\"cpu\")\n\nnum_joints = 22\njoint_category = \"AMASS\"\nnum_smplify_iters = 150\nfix_foot = False\n\n\ndef joints2smpl(file_name, data_dir, save_dir):\n    \"\"\"Convert 3D joint positions to SMPL-X parameters.\n\n    Args:\n        file_name (str): Name of the input .npy joint file\n        data_dir (str): Directory containing input joint files\n        save_dir (str): Directory to save processed output files\n\n    \"\"\"\n    # print(file_name)\n    input_joints = np.load(os.path.join(data_dir, file_name))\n\n    input_joints = input_joints[:, :, [0, 1, 2]]  # amass stands on x, y\n\n    \"\"\"XY at origin\"\"\"\n    input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]]\n\n    \"\"\"Put on Floor\"\"\"\n    floor_height = input_joints[:, :, 2].min()\n    input_joints[:, :, 2] -= floor_height\n\n    batch_size = input_joints.shape[0]\n\n    smplmodel = smplx.create(\n        SMPL_MODEL_DIR,\n        model_type=\"smpl\",\n        gender=\"neutral\",\n        ext=\"npz\",\n        batch_size=batch_size,\n    ).to(device)\n\n    # ## --- load the mean pose as original ----\n    smpl_mean_file = config.SMPL_MEAN_FILE\n\n    file = h5py.File(smpl_mean_file, \"r\")\n    init_mean_pose = (\n        torch.from_numpy(file[\"pose\"][:])\n        .unsqueeze(0)\n        .repeat(batch_size, 1)\n        .float()\n        .to(device)\n    )\n    init_mean_shape = (\n        torch.from_numpy(file[\"shape\"][:])\n        .unsqueeze(0)\n        .repeat(batch_size, 1)\n        .float()\n        .to(device)\n    )\n    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)\n\n    # # #-------------initialize SMPLify\n    smplify = SMPLify3D(\n        smplxmodel=smplmodel,\n        batch_size=batch_size,\n        joints_category=joint_category,\n        num_iters=num_smplify_iters,\n        device=device,\n    )\n\n    keypoints_3d = torch.Tensor(input_joints).to(device).float()\n\n    pred_betas = init_mean_shape\n    pred_pose = init_mean_pose\n    pred_cam_t = cam_trans_zero\n\n    if joint_category == \"AMASS\":\n        confidence_input = torch.ones(num_joints)\n        # make sure the foot and ankle\n        if fix_foot:\n            confidence_input[7] = 1.5\n            confidence_input[8] = 1.5\n            confidence_input[10] = 1.5\n            confidence_input[11] = 1.5\n    else:\n        print(\"Such category not settle down!\")\n\n    (\n        new_opt_vertices,\n        new_opt_joints,\n        new_opt_pose,\n  
      new_opt_betas,\n        new_opt_cam_t,\n        new_opt_joint_loss,\n    ) = smplify(\n        pred_pose.detach(),\n        pred_betas.detach(),\n        pred_cam_t.detach(),\n        keypoints_3d,\n        conf_3d=confidence_input.to(device),\n        # seq_ind=idx\n    )\n\n    poses = new_opt_pose.detach().cpu().numpy()\n    betas = new_opt_betas.mean(axis=0).detach().cpu().numpy()\n    trans = keypoints_3d[:, 0].detach().cpu().numpy()\n\n    target_dim = 165\n    current_dim = poses.shape[-1]\n    pad_dim = target_dim - current_dim\n\n    if pad_dim > 0:\n        pad_array = np.zeros((*poses.shape[:-1], pad_dim), dtype=poses.dtype)\n        poses = np.concatenate([poses, pad_array], axis=-1)\n\n    root_orient = poses[:, :3]\n    root_mat = Rotation.from_rotvec(root_orient).as_matrix()\n    rx_minus_100 = Rotation.from_euler(\"x\", -100, degrees=True).as_matrix()\n    align_r = rx_minus_100 @ root_mat\n    align_axis_angle = Rotation.from_matrix(align_r).as_rotvec()\n    poses[:, :3] = align_axis_angle\n    trans_rotated = rx_minus_100 @ (trans.T)\n    trans_rotated = trans_rotated.T\n    input_joints = input_joints[:, :, [0, 2, 1]]  # jts stands on x, z\n    input_joints[..., 0] *= -1\n    param = {\n        \"poses\": poses,\n        \"trans\": trans_rotated,\n        \"betas\": betas,\n        \"gender\": \"neutral\",\n        \"jtr\": input_joints,\n        \"mocap_frame_rate\": 30,\n    }\n    file_name = file_name.split(\".\")[0] + \".npz\"\n    print(file_name)\n    np.savez_compressed(os.path.join(save_dir, file_name), **param)\n\n\ndef humanact12_to_amass(data_dir, save_dir):\n    \"\"\"Convert HumanAct12 dataset to AMASS-compatible format.\n\n    Args:\n        data_dir (str): Directory containing HumanAct12 .npy joint files\n        save_dir (str): Directory to save processed AMASS .npz files\n\n    \"\"\"\n    os.makedirs(save_dir, exist_ok=True)\n\n    file_list = os.listdir(data_dir)\n    random.shuffle(file_list)\n    for file_name in tqdm(file_list):\n        if os.path.exists(os.path.join(save_dir, file_name)):\n            print(f\"{os.path.join(save_dir, file_name)} already exists\")\n            continue\n        joints2smpl(file_name, data_dir, save_dir)\n\n\nif __name__ == \"__main__\":\n    data_dir = \"\"\n    save_dir = \"\"\n"
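    # Hypothetical example paths (the original leaves them blank); the call\n    # below mirrors the sibling converters in this folder:\n    #   data_dir = \"./data/humanact12/joints\"\n    #   save_dir = \"./data/amass_compatible_datasets/humanact12\"\n    humanact12_to_amass(data_dir, save_dir)\n"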
  },
  {
    "path": "holomotion/src/data_curation/smplify/smplify_motionx.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\n\nimport numpy as np\nfrom scipy.spatial.transform import Rotation\n\n\ndef motionx_to_amass(src_root, dst_root):\n    \"\"\"Convert MotionX format motion data to AMASS format.\n\n    Args:\n        src_root (str): Source directory containing MotionX .npy files\n        dst_root (str): Destination directory for processed AMASS .npz files\n\n    Side effects:\n        Creates directory structure mirroring src_root under dst_root\n        Generates compressed .npz files in destination directory\n        Prints file paths of processed files\n\n    Processed data contains:\n        poses: [T, 165] float array of joint rotations (root first)\n        trans: [T, 3] float array of root translations\n        betas: [10] float array of shape parameters\n        gender: str (always \"neutral\")\n        mocap_frame_rate: int (always 30)\n\n    \"\"\"\n    os.makedirs(dst_root, exist_ok=True)\n    for root, _, files in os.walk(src_root):\n        # print(files)\n        for file in files:\n            src_file_path = os.path.join(root, file)\n            motion = np.load(src_file_path)\n            poses = motion[:, :156]  # 最终 shape: (T, 156)\n            num_frames = poses.shape[0]\n            sl = poses.shape[1]\n\n            pad = np.zeros((num_frames, 165 - sl), dtype=poses.dtype)  # (T, 9)\n            poses = np.concatenate([poses, pad], axis=1)  # (T, 165)\n            align_axis_angle = poses[:, :3]\n            root_orient = poses[:, :3]\n            root_mat = Rotation.from_rotvec(root_orient).as_matrix()\n            rotate_matrix = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])\n            align_r = rotate_matrix @ root_mat\n            align_axis_angle = Rotation.from_matrix(align_r).as_rotvec()\n            poses[:, :3] = align_axis_angle\n\n            trans = motion[:, 309:312]  # (T, 3)\n            trans[:, 2] = trans[:, 2] * (-1)\n            trans_matrix = np.array(\n                [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]\n            )\n            trans = np.dot(trans, trans_matrix)\n            trans = rotate_matrix @ (trans.T)\n            trans = trans.T\n            betas = motion[0, 312:]\n            amass_data = {\n                \"poses\": poses,\n                \"trans\": trans,\n                \"betas\": betas,\n                \"gender\": \"neutral\",\n                \"mocap_frame_rate\": 30,\n            }\n            relative_path = src_file_path.replace(src_root, \"\")\n            file_name = dst_root + relative_path\n            save_dir = file_name.split(\"/\")[-1]\n            save_dir = file_name.replace(save_dir, \"\")\n            os.makedirs(save_dir, exist_ok=True)\n            file_name = file_name.replace(\".npy\", \".npz\")\n            print(file_name)\n            np.savez_compressed(file_name, **amass_data)\n\n\nif __name__ == \"__main__\":\n    src_root = 
\"./data/smplx_322\"\n    dst_root = \"./data/smplx_data/MotionX\"\n    motionx_to_amass(src_root, dst_root)\n"
  },
  {
    "path": "holomotion/src/data_curation/smplify/smplify_omomo.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\n# -----------------------------------------------------------------------------\n# Portions of this file are derived from omomo_release (https://github.com/lijiaman/omomo_release).\n# The original omomo_release code is licensed under the MIT license.\n# -----------------------------------------------------------------------------\n\nimport os\n\nimport numpy as np\nimport pytorch3d.transforms as transforms\nimport torch\nfrom torch.utils import data\n\nfrom thirdparties.omomo_release.manip.data.hand_foot_dataset import (\n    HandFootManipDataset,\n    quat_ik_torch,\n)\n\n\nclass MyHandFootManipDataset(HandFootManipDataset):\n    \"\"\"Modified dataset class for hand-foot manipulation tasks.\n\n    This class overrides the __getitem__ method.\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"Initialize the dataset instance by forwarding all arguments.\n\n        This constructor ensures proper initialization\n        of the HandFootManipDataset parent class.\n        All parameters and keyword arguments are passed through unchanged.\n\n        Args:\n            *args: Variable length argument list for parent class\n            **kwargs: Arbitrary keyword arguments for parent class\n\n        \"\"\"\n        super().__init__(*args, **kwargs)\n\n    def __getitem__(self, index):\n        \"\"\"Retrieve and process a data sample by index.\n\n        Try not to padding when retrieve motion data.\n\n        Args:\n            index (int): Index of the sample to retrieve\n\n        Reference:\n            https://github.com/lijiaman/omomo_release/blob/main/manip/data/hand_foot_dataset.py\n\n        \"\"\"\n        # index = 0 # For debug\n        data_input = self.window_data_dict[index][\"motion\"]\n        data_input = torch.from_numpy(data_input).float()\n\n        seq_name = self.window_data_dict[index][\"seq_name\"]\n        object_name = seq_name.split(\"_\")[1]\n\n        trans2joint = self.window_data_dict[index][\"trans2joint\"]\n\n        if self.use_object_splits:\n            ori_w_idx = self.window_data_dict[index][\"ori_w_idx\"]\n            obj_bps_npy_path = os.path.join(\n                self.dest_obj_bps_npy_folder,\n                seq_name + \"_\" + str(ori_w_idx) + \".npy\",\n            )\n        else:\n            obj_bps_npy_path = os.path.join(\n                self.dest_obj_bps_npy_folder,\n                seq_name + \"_\" + str(index) + \".npy\",\n            )\n        obj_bps_data = np.load(obj_bps_npy_path)  # T X N X 3\n        obj_bps_data = torch.from_numpy(obj_bps_data)\n\n        num_joints = 24\n\n        normalized_jpos = self.normalize_jpos_min_max(\n            data_input[:, : num_joints * 3].reshape(-1, num_joints, 3)\n        )  # T X 22 X 3\n\n        global_joint_rot = data_input[:, 2 * num_joints * 3 :]  # T X (22*6)\n\n        new_data_input = torch.cat(\n          
  (normalized_jpos.reshape(-1, num_joints * 3), global_joint_rot),\n            dim=1,\n        )\n        ori_data_input = torch.cat(\n            (data_input[:, : num_joints * 3], global_joint_rot), dim=1\n        )\n\n        # Add padding.\n        actual_steps = new_data_input.shape[0]\n        # pass\n        paded_new_data_input = new_data_input\n        paded_ori_data_input = ori_data_input\n\n        paded_obj_bps = obj_bps_data.reshape(actual_steps, -1)\n        paded_obj_com_pos = torch.from_numpy(\n            self.window_data_dict[index][\"window_obj_com_pos\"]\n        ).float()\n\n        paded_obj_rot_mat = torch.from_numpy(\n            self.window_data_dict[index][\"obj_rot_mat\"]\n        ).float()\n        paded_obj_scale = torch.from_numpy(\n            self.window_data_dict[index][\"obj_scale\"]\n        ).float()\n        paded_obj_trans = torch.from_numpy(\n            self.window_data_dict[index][\"obj_trans\"]\n        ).float()\n\n        if object_name in [\"mop\", \"vacuum\"]:\n            paded_obj_bottom_rot_mat = torch.from_numpy(\n                self.window_data_dict[index][\"obj_bottom_rot_mat\"]\n            ).float()\n            paded_obj_bottom_scale = torch.from_numpy(\n                self.window_data_dict[index][\"obj_bottom_scale\"]\n            ).float()\n            paded_obj_bottom_trans = (\n                torch.from_numpy(\n                    self.window_data_dict[index][\"obj_bottom_trans\"]\n                )\n                .float()\n                .squeeze(-1)\n            )\n        data_input_dict = {}\n        data_input_dict[\"motion\"] = paded_new_data_input\n        data_input_dict[\"ori_motion\"] = paded_ori_data_input\n\n        data_input_dict[\"obj_bps\"] = paded_obj_bps\n        data_input_dict[\"obj_com_pos\"] = paded_obj_com_pos\n\n        data_input_dict[\"obj_rot_mat\"] = paded_obj_rot_mat\n        data_input_dict[\"obj_scale\"] = paded_obj_scale\n        data_input_dict[\"obj_trans\"] = paded_obj_trans\n\n        if object_name in [\"mop\", \"vacuum\"]:\n            data_input_dict[\"obj_bottom_rot_mat\"] = paded_obj_bottom_rot_mat\n            data_input_dict[\"obj_bottom_scale\"] = paded_obj_bottom_scale\n            data_input_dict[\"obj_bottom_trans\"] = paded_obj_bottom_trans\n        else:\n            data_input_dict[\"obj_bottom_rot_mat\"] = paded_obj_rot_mat\n            data_input_dict[\"obj_bottom_scale\"] = paded_obj_scale\n            data_input_dict[\"obj_bottom_trans\"] = paded_obj_trans\n\n        data_input_dict[\"betas\"] = self.window_data_dict[index][\"betas\"]\n        data_input_dict[\"gender\"] = str(self.window_data_dict[index][\"gender\"])\n\n        data_input_dict[\"seq_name\"] = seq_name\n        data_input_dict[\"obj_name\"] = seq_name.split(\"_\")[1]\n\n        data_input_dict[\"seq_len\"] = actual_steps\n\n        data_input_dict[\"trans2joint\"] = trans2joint\n\n        return data_input_dict\n\n\ndef run_smplx_model(root_trans, aa_rot_rep, betas, gender, fname):\n    \"\"\"Prepare and save SMPL-X motion data in AMASS npz format.\n\n    Processes input motion parameters into SMPL-X compatible format and saves\n    as a compressed npz file.\n\n    Args:\n        root_trans (torch.Tensor): Root translations [BS, T, 3]\n        aa_rot_rep (torch.Tensor): Axis-angle joint rotations\n        [BS, T, num_joints, 3]\n            where num_joints can be either 22 (body only) or 52 (body+hands)\n        betas (torch.Tensor): Shape parameters [BS, 16]\n        gender (list): Gender strings for 
each sample in batch [BS]\n        fname (str): Output filename/path for saving .npz file\n\n    Output npz file contains:\n        poses: [BS*T, 165] float array of pose parameters\n        trans: [BS*T, 3] float array of translations\n        betas: [16] float array of shape parameters (from first sample)\n        gender: str (always \"neutral\")\n        mocap_frame_rate: int (always 30)\n\n    \"\"\"\n    # root_trans: BS X T X 3\n    # aa_rot_rep: BS X T X 22 X 3\n    # betas: BS X 16\n    # gender: BS\n    bs, num_steps, num_joints, _ = aa_rot_rep.shape\n    if num_joints != 52:\n        padding_zeros_hand = torch.zeros(bs, num_steps, 30, 3).to(\n            aa_rot_rep.device\n        )  # BS X T X 30 X 3\n        aa_rot_rep = torch.cat(\n            (aa_rot_rep, padding_zeros_hand), dim=2\n        )  # BS X T X 52 X 3\n\n    aa_rot_rep = aa_rot_rep.reshape(\n        bs * num_steps, -1, 3\n    )  # (BS*T) X n_joints X 3\n    betas = (\n        betas[:, None, :].repeat(1, num_steps, 1).reshape(bs * num_steps, -1)\n    )  # (BS*T) X 16\n    gender = np.asarray(gender)[:, np.newaxis].repeat(num_steps, axis=1)\n    gender = gender.reshape(-1).tolist()  # (BS*T)\n\n    smpl_trans = root_trans.reshape(-1, 3)  # (BS*T) X 3\n    smpl_root_orient = aa_rot_rep[:, 0, :]  # (BS*T) X 3\n    # print(smpl_root_orient.shape)\n    smpl_pose_body = aa_rot_rep[:, 1:22, :].reshape(-1, 63)  # (BS*T) X 63\n    smpl_pose_hand = aa_rot_rep[:, 22:, :].reshape(-1, 90)  # (BS*T) X 90\n    poses = torch.cat(\n        [smpl_root_orient, smpl_pose_body, smpl_pose_hand], dim=-1\n    )\n    target_dim = 165\n    current_dim = poses.shape[-1]\n    pad_dim = target_dim - current_dim\n\n    if pad_dim > 0:\n        pad_tensor = torch.zeros(\n            *poses.shape[:-1], pad_dim, device=poses.device, dtype=poses.dtype\n        )\n        poses_padded = torch.cat([poses, pad_tensor], dim=-1)\n    else:\n        poses_padded = poses  # already 165 or more\n\n    amass_data = {\n        \"poses\": poses_padded.detach().cpu().numpy(),\n        \"trans\": smpl_trans.detach().cpu().numpy(),\n        \"betas\": betas[0].detach().cpu().numpy(),\n        \"gender\": \"neutral\",\n        \"mocap_frame_rate\": 30,\n    }\n    np.savez_compressed(fname, **amass_data)\n\n\ndef process_dataset(dl, dataset, target_folder, split_name: str):\n    \"\"\"Process a motion dataset batch and convert sequences to SMPL-X format.\n\n    Args:\n        dl (DataLoader): PyTorch DataLoader providing batched data\n        dataset (Dataset): Source dataset object (for denormalization)\n        target_folder (str): target folder for data saving\n        split_name (str): Name of data split being processed\n\n    Output files:\n        Saved as: {target_folder}/{split_name}_{object_name}_{index}.npz\n        Where:\n            target_folder: (implied from external context)\n            object_name: Extracted from sequence name\n            index: Incremental sequence counter\n\n    \"\"\"\n    index = 0\n    for data_dict in dl:\n        val_data = data_dict[\"motion\"].cuda()\n        for_vis_gt_data = val_data[:]\n        all_res_list = for_vis_gt_data\n\n        num_seq = all_res_list.shape[0]\n        print(f\"Processing {split_name}, num_seq: {num_seq}\")\n        num_joints = 24\n\n        normalized_global_jpos = all_res_list[:, :, : num_joints * 3].reshape(\n            num_seq, -1, num_joints, 3\n        )\n        global_jpos = dataset.de_normalize_jpos_min_max(\n            normalized_global_jpos.reshape(-1, num_joints, 3)\n        )\n 
       global_jpos = global_jpos.reshape(num_seq, -1, num_joints, 3)\n        global_root_jpos = global_jpos[:, :, 0, :].clone()\n        global_rot_6d = all_res_list[:, :, -22 * 6 :].reshape(\n            num_seq, -1, 22, 6\n        )\n        global_rot_mat = transforms.rotation_6d_to_matrix(global_rot_6d)\n\n        trans2joint = data_dict[\"trans2joint\"].to(all_res_list.device)\n        for idx in range(num_seq):\n            curr_global_rot_mat = global_rot_mat[idx]\n            curr_local_rot_mat = quat_ik_torch(curr_global_rot_mat)\n            curr_local_rot_aa_rep = transforms.matrix_to_axis_angle(\n                curr_local_rot_mat\n            )\n\n            curr_global_root_jpos = global_root_jpos[idx]\n            curr_trans2joint = trans2joint[idx : idx + 1].clone()\n            root_trans = curr_global_root_jpos + curr_trans2joint\n\n            betas = data_dict[\"betas\"][idx]\n            gender = data_dict[\"gender\"][idx]\n            curr_seq_name = data_dict[\"seq_name\"][idx]\n            object_name = curr_seq_name.split(\"_\")[1]\n\n            fname = os.path.join(\n                target_folder, f\"{split_name}_{object_name}_{index}.npz\"\n            )\n            print(fname)\n\n            run_smplx_model(\n                root_trans[None].cuda(),\n                curr_local_rot_aa_rep[None].cuda(),\n                betas.cuda(),\n                [gender],\n                fname,\n            )\n            index += 1\n\n\ndef omomo_to_amass(data_root_folder, target_folder):\n    \"\"\"Convert Omomo dataset to AMASS-compatible SMPL-X format.\n\n    Args:\n        data_root_folder (str): Path to the root directory of Omomo dataset\n        target_folder (str): Output directory for processed AMASS files\n\n    \"\"\"\n    use_object_split = True\n    window_size = 120\n\n    train_dataset = MyHandFootManipDataset(\n        train=True,\n        data_root_folder=data_root_folder,\n        window=window_size,\n        use_object_splits=use_object_split,\n    )\n    val_dataset = MyHandFootManipDataset(\n        train=False,\n        data_root_folder=data_root_folder,\n        window=window_size,\n        use_object_splits=use_object_split,\n    )\n\n    val_ds = val_dataset\n    train_ds = train_dataset\n    val_dl = data.DataLoader(\n        val_ds, batch_size=1, shuffle=False, pin_memory=True, num_workers=0\n    )\n    train_dl = data.DataLoader(\n        train_ds, batch_size=1, shuffle=False, pin_memory=True, num_workers=0\n    )\n\n    process_dataset(train_dl, train_dataset, target_folder, \"train\")\n    process_dataset(val_dl, val_dataset, target_folder, \"val\")\n\n\nif __name__ == \"__main__\":\n    data_root_folder = \"\"\n    target_folder = \"\"\n    omomo_to_amass(data_root_folder, target_folder)\n"
  },
  {
    "path": "holomotion/src/data_curation/smplify/smplify_zjumocap.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\n\nimport numpy as np\nfrom scipy.spatial.transform import Rotation\nfrom tqdm import tqdm\n\n\ndef zju_single_to_amass(\n    param_dir, out_path, gender=\"neutral\", fps=30, rotate=False\n):\n    \"\"\"Convert .npy files into a single AMASS-style .npz file.\n\n    Args:\n        param_dir: Folder containing 0.npy, 1.npy, ....\n        out_path: Output .npz path.\n        gender: Gender to assign ('neutral', 'male', 'female').\n        fps: Mocap frame rate.\n        rotate: whether or not rotate the body\n\n    \"\"\"\n    pose_list = []\n    trans_list = []\n    shape_list = []\n\n    # Get sorted list of npy files\n    files = sorted(\n        [f for f in os.listdir(param_dir) if f.endswith(\".npy\")],\n        key=lambda x: int(os.path.splitext(x)[0]),\n    )\n\n    if rotate:\n        ry_minus_180 = Rotation.from_euler(\"y\", -180, degrees=True).as_matrix()\n    else:\n        ry_minus_180 = Rotation.from_euler(\"y\", 0, degrees=True).as_matrix()\n    for fname in tqdm(files, desc=\"Processing frames\"):\n        fpath = os.path.join(param_dir, fname)\n        data = np.load(fpath, allow_pickle=True).item()\n\n        poses = data[\"poses\"]  # (1, 72)\n        global_orient = data[\"Rh\"]\n\n        root_orient = global_orient\n        root_mat = Rotation.from_rotvec(root_orient).as_matrix()\n        align_r = ry_minus_180 @ root_mat\n        align_axis_angle = Rotation.from_matrix(align_r).as_rotvec()\n        global_orient = align_axis_angle\n        body_pose = poses[:, 3:66]\n        hand_pose = poses[:, 66:72]\n\n        full_pose = np.concatenate(\n            [global_orient, body_pose, hand_pose], axis=1\n        )  # (1, 165)\n\n        pose_list.append(full_pose[0])  # shape: (165,)\n        trans_list.append(data[\"Th\"][0])  # shape: (3,)\n        shape_list.append(data[\"shapes\"][0])  # shape: (10,)\n\n    poses = np.stack(pose_list, axis=0).astype(np.float32)  # (N, 165)\n    trans = np.stack(trans_list, axis=0).astype(np.float32)  # (N, 3)\n    trans_rotated = ry_minus_180 @ (trans.T)\n    trans_rotated = trans_rotated.T\n    betas = shape_list[0].astype(np.float32)  # (10,) same for all frames\n\n    # Save as AMASS-style npz\n    np.savez_compressed(\n        out_path,\n        poses=poses,\n        trans=trans_rotated,\n        betas=betas,\n        gender=gender,\n        mocap_frame_rate=fps,\n    )\n\n    print(f\"Saved AMASS-style file to: {out_path}\")\n    print(f\"Total frames: {poses.shape[0]}\")\n\n\ndef zju_to_amass(input_dir, output_dir):\n    \"\"\"Convert multiple ZJU-formatted folders to AMASS-style .npz files.\n\n    Args:\n        input_dir: Path to ZJU dataset root folder.\n        output_dir: Path to save AMASS-format .npz files.\n\n    \"\"\"\n    os.makedirs(output_dir, exist_ok=True)\n\n    subjects = sorted(\n        [\n            d\n            for d in 
os.listdir(input_dir)\n            if os.path.isdir(os.path.join(input_dir, d))\n        ]\n    )\n\n    for subject in subjects:\n        subject_dir = os.path.join(input_dir, subject)\n\n        new_params_dir = os.path.join(subject_dir, \"new_params\")\n        params_dir = os.path.join(subject_dir, \"params\")\n\n        if os.path.isdir(new_params_dir):\n            param_dir = new_params_dir\n            print(f\"[{subject}] Using new_params\")\n        elif os.path.isdir(params_dir):\n            param_dir = params_dir\n            print(f\"[{subject}] Using params\")\n        else:\n            print(f\"[{subject}] No params found, skipping\")\n            continue\n\n        out_path = os.path.join(output_dir, f\"{subject}.npz\")\n        zju_single_to_amass(param_dir, out_path)\n\n    print(f\"All subjects processed. Output saved to {output_dir}\")\n\n\n# Example usage\nif __name__ == \"__main__\":\n    zju_to_amass(\n        param_dir=\"\",\n        out_path=\"\",\n    )\n"
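    # Hypothetical paths for illustration (not in the original source):\n    #   zju_to_amass(\n    #       input_dir=\"./data/zju_mocap\",\n    #       output_dir=\"./data/amass_compatible_datasets/ZJU_Mocap\",\n    #   )\n"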
  },
  {
    "path": "holomotion/src/data_curation/templates/index_wooden_static.html",
    "content": "<!--\n# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# This file was originally copied from the [HY-Motion-1.0] repository:\n# https://huggingface.co/spaces/tencent/HY-Motion-1.0/tree/main\n# Modifications have been made to fit the needs of this project.\n-->\n<!doctype html>\n<html lang=\"en\">\n\n<head>\n  <meta charset=\"utf-8\" />\n  <title>Motion Visualization</title>\n  <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\" />\n\n  <!-- Font Awesome (icons only). If you want zero external deps, remove and use text icons. -->\n  <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css\">\n\n  <!-- Three.js importmap -->\n  <script type=\"importmap\">\n    {\n      \"imports\": {\n        \"three\": \"https://cdn.jsdelivr.net/npm/three@0.160.0/build/three.module.js\",\n        \"three/addons/\": \"https://cdn.jsdelivr.net/npm/three@0.160.0/examples/jsm/\"\n      }\n    }\n  </script>\n\n  <style>\n    /* =========================\n     * Base / Layout\n     * ========================= */\n    * { box-sizing: border-box; margin: 0; padding: 0; }\n    html, body {\n      width: 100%;\n      height: 100%;\n      overflow: hidden;\n      background: #424242;\n      color: #e2e8f0;\n      font-family: -apple-system, system-ui, \"Segoe UI\", Roboto, Helvetica, Arial, sans-serif;\n    }\n\n    .fullscreen-container {\n      position: fixed;\n      inset: 0;\n      background: #424242;\n      overflow: hidden;\n    }\n\n    #vis3d {\n      position: absolute;\n      inset: 0;\n      background: #424242;\n    }\n\n    #vis3d canvas {\n      display: block;\n      width: 100% !important;\n      height: 100% !important;\n    }\n\n    /* =========================\n     * Caption overlay (top-center)\n     * ========================= */\n    .caption-overlay {\n      position: absolute;\n      top: 20px;\n      left: 50%;\n      transform: translateX(-50%);\n      max-width: 90%;\n      z-index: 100;\n      pointer-events: auto;\n    }\n\n    .motion-info {\n      background-color: rgba(45, 55, 72, 0.85);\n      backdrop-filter: blur(10px);\n      -webkit-backdrop-filter: blur(10px);\n      border-radius: 20px;\n      box-shadow: 0 4px 20px rgba(0, 0, 0, 0.4);\n      overflow: hidden;\n      max-height: 40vh;\n      overflow-y: auto;\n      display: inline-block;\n    }\n\n    .captions-section { padding: 12px 20px; white-space: nowrap; }\n    .caption-item {\n      background: transparent;\n      border: none;\n      margin-bottom: 6px;\n      color: #f0f4f8;\n      font-size: 1em;\n      font-weight: 500;\n      line-height: 1.5;\n      text-align: center;\n    }\n    .caption-item:last-child { margin-bottom: 0; }\n\n    /* =========================\n     * Controls (bottom-center)\n     * ========================= */\n    .control-overlay {\n      position: absolute;\n      left: 50%;\n      bottom: 30px;\n      
transform: translateX(-50%);\n      width: min(600px, 80%);\n      z-index: 120;\n\n      background: rgba(0, 0, 0, 0.4);\n      backdrop-filter: blur(8px);\n      -webkit-backdrop-filter: blur(8px);\n      padding: 14px 18px;\n      border-radius: 12px;\n    }\n\n    .control-row {\n      display: flex;\n      align-items: center;\n      gap: 14px;\n    }\n\n    .progress-container { flex: 1; }\n\n    input[type=\"range\"].progress-slider {\n      width: 100%;\n      height: 8px;\n      border-radius: 4px;\n      background: rgba(255, 255, 255, 0.3);\n      outline: none;\n      cursor: pointer;\n      -webkit-appearance: none;\n      appearance: none;\n    }\n\n    input[type=\"range\"].progress-slider::-webkit-slider-runnable-track {\n      width: 100%;\n      height: 8px;\n      border-radius: 4px;\n      background: rgba(255, 255, 255, 0.3);\n    }\n    input[type=\"range\"].progress-slider::-webkit-slider-thumb {\n      -webkit-appearance: none;\n      appearance: none;\n      width: 18px;\n      height: 18px;\n      border-radius: 50%;\n      background: #4a9eff;\n      border: 2px solid #fff;\n      box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);\n      margin-top: -5px;\n    }\n\n    input[type=\"range\"].progress-slider::-moz-range-track {\n      width: 100%;\n      height: 8px;\n      border-radius: 4px;\n      background: rgba(255, 255, 255, 0.3);\n    }\n    input[type=\"range\"].progress-slider::-moz-range-thumb {\n      width: 18px;\n      height: 18px;\n      border-radius: 50%;\n      background: #4a9eff;\n      border: 2px solid #fff;\n      box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);\n    }\n\n    .frame-counter {\n      font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, \"Liberation Mono\", \"Courier New\", monospace;\n      font-size: 13px;\n      font-weight: 600;\n      color: #fff;\n      text-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);\n      white-space: nowrap;\n      min-width: 92px;\n      text-align: right;\n    }\n\n    /* =========================\n     * Loading overlay\n     * ========================= */\n    .loading-overlay {\n      position: absolute;\n      left: 50%;\n      top: 50%;\n      transform: translate(-50%, -50%);\n      z-index: 200;\n\n      display: flex;\n      align-items: center;\n      gap: 10px;\n\n      background: rgba(0, 0, 0, 0.7);\n      backdrop-filter: blur(8px);\n      -webkit-backdrop-filter: blur(8px);\n      color: #fff;\n      padding: 14px 18px;\n      border-radius: 10px;\n      font-size: 14px;\n    }\n\n    .loading-overlay.hidden { display: none; }\n    .loading-overlay.complete { background: rgba(76, 175, 80, 0.85); }\n    .loading-overlay.failed { background: rgba(244, 67, 54, 0.85); }\n  </style>\n</head>\n\n<body>\n  <div class=\"fullscreen-container\">\n    <div id=\"vis3d\"></div>\n\n    <!-- Caption overlay (generated by Python) -->\n    {{ caption_html }}\n\n    <!-- Minimal progress control -->\n    <div class=\"control-overlay\">\n      <div class=\"control-row\">\n        <div class=\"progress-container\">\n          <input type=\"range\" id=\"progressSlider\" class=\"progress-slider\" min=\"0\" max=\"100\" value=\"0\">\n        </div>\n        <div class=\"frame-counter\">\n          <span id=\"currentFrame\">0</span> / <span id=\"totalFrames\">0</span>\n        </div>\n      </div>\n    </div>\n\n    <!-- Loading status -->\n    <div class=\"loading-overlay\" id=\"loadingStatus\">\n      <i class=\"fas fa-spinner fa-spin\"></i> Loading...\n    </div>\n\n    <!-- Hidden controls (keep for keyboard / 
internal state) -->\n    <div style=\"display:none;\">\n      <button id=\"playPauseBtn\"></button>\n      <button id=\"resetBtn\"></button>\n      <input type=\"range\" id=\"speedSlider\" min=\"0.1\" max=\"3\" step=\"0.1\" value=\"1\">\n      <span id=\"speedValue\">1.0x</span>\n    </div>\n  </div>\n\n  <!-- Embedded SMPL data -->\n  <script type=\"application/json\" id=\"smpl-data-json\">\n{{ smpl_data_json }}\n  </script>\n\n  <script type=\"module\">\n    import * as THREE from 'three';\n    import { OrbitControls } from 'three/addons/controls/OrbitControls.js';\n\n    /**\n     * ============================================================\n     * Ground utilities\n     * ============================================================\n     */\n    function createBaseChessboard(gridSize=50, divisions=50, white=\"#ffffff\", black=\"#3a3a3a\", textureSize=1024) {\n      const adjusted = Math.floor(textureSize / divisions) * divisions;\n      const canvas = document.createElement(\"canvas\");\n      canvas.width = canvas.height = adjusted;\n      const ctx = canvas.getContext(\"2d\");\n      ctx.imageSmoothingEnabled = false;\n\n      const step = adjusted / divisions;\n      for (let i = 0; i < divisions; i++) {\n        for (let j = 0; j < divisions; j++) {\n          ctx.fillStyle = (i + j) % 2 === 0 ? white : black;\n          ctx.fillRect(i * step, j * step, step, step);\n        }\n      }\n\n      const tex = new THREE.CanvasTexture(canvas);\n      tex.wrapS = THREE.RepeatWrapping;\n      tex.wrapT = THREE.RepeatWrapping;\n      tex.magFilter = THREE.NearestFilter;\n      tex.minFilter = THREE.NearestFilter;\n      tex.generateMipmaps = false;\n\n      const geom = new THREE.PlaneGeometry(gridSize, gridSize);\n      const mat = new THREE.MeshStandardMaterial({\n        map: tex,\n        side: THREE.DoubleSide,\n        transparent: true,\n        opacity: 0.85,\n        roughness: 0.9,\n        metalness: 0.1\n      });\n\n      const plane = new THREE.Mesh(geom, mat);\n      plane.receiveShadow = true;\n      return plane;\n    }\n\n    function makeGroundYUp() {\n      const plane = createBaseChessboard(50, 50);\n      plane.rotation.x = -Math.PI / 2; // XZ plane, Y up\n      plane.name = \"ground\";\n      return plane;\n    }\n\n    /**\n     * ============================================================\n     * Scene setup\n     * ============================================================\n     */\n    function setupScene({ scene, camera, renderer, useGround=true }) {\n      scene.background = new THREE.Color(0x424242);\n      scene.fog = new THREE.FogExp2(0x424242, 0.06);\n\n      renderer.shadowMap.enabled = true;\n      renderer.shadowMap.type = THREE.PCFSoftShadowMap;\n\n      // Lighting (keep your style)\n      const hemi = new THREE.HemisphereLight(0xffffff, 0x444444, 1.2);\n      hemi.position.set(0, 2, 0);\n      scene.add(hemi);\n\n      const dir = new THREE.DirectionalLight(0xffffff, 1.5);\n      dir.position.set(3, 5, 4);\n      dir.castShadow = true;\n      dir.shadow.mapSize.width = 2048;\n      dir.shadow.mapSize.height = 2048;\n      dir.shadow.camera.near = 0.5;\n      dir.shadow.camera.far = 50;\n      dir.shadow.camera.left = -10;\n      dir.shadow.camera.right = 10;\n      dir.shadow.camera.top = 10;\n      dir.shadow.camera.bottom = -10;\n      dir.shadow.bias = -0.0001;\n      scene.add(dir);\n\n      const fill = new THREE.DirectionalLight(0xaaccff, 0.5);\n      fill.position.set(-3, 3, -2);\n      scene.add(fill);\n\n      const rim = new 
THREE.DirectionalLight(0xffeedd, 0.4);\n      rim.position.set(0, 4, -5);\n      scene.add(rim);\n\n      // Camera conventions (Y-up, forward Z)\n      camera.up.set(0, 1, 0);\n      camera.position.set(0, 2.5, 5);\n      camera.lookAt(new THREE.Vector3(0, 1, 0));\n\n      if (useGround) {\n        scene.add(makeGroundYUp());\n      }\n    }\n\n    function fitCameraToScene(scene, camera, controls=null, opts={}) {\n      const { margin=1.05, excludeNames=[\"ground\"] } = opts;\n\n      const box = new THREE.Box3();\n      const tmp = new THREE.Box3();\n      let has = false;\n\n      scene.traverse((obj) => {\n        if (!obj || !obj.visible) return;\n        if (obj.isLight) return;\n        if ((obj.type || \"\").endsWith(\"Helper\")) return;\n        if (excludeNames.includes(obj.name)) return;\n\n        if (obj.isMesh) {\n          if (obj.geometry && obj.geometry.type === \"PlaneGeometry\") return;\n          try {\n            tmp.setFromObject(obj);\n            if (!tmp.isEmpty()) {\n              if (!has) { box.copy(tmp); has = true; }\n              else { box.union(tmp); }\n            }\n          } catch (_) {}\n        }\n      });\n\n      if (!has || box.isEmpty()) return;\n\n      const sphere = new THREE.Sphere();\n      box.getBoundingSphere(sphere);\n      const center = sphere.center.clone();\n      const radius = Math.max(sphere.radius, 1e-3);\n\n      const vFov = THREE.MathUtils.degToRad(camera.fov);\n      const hFov = 2 * Math.atan(Math.tan(vFov / 2) * camera.aspect);\n      const distV = radius / Math.sin(vFov / 2);\n      const distH = radius / Math.sin(hFov / 2);\n      const dist = Math.max(distV, distH) * margin;\n\n      const elev = THREE.MathUtils.degToRad(25);\n      const azim = Math.PI / 4;\n      const horiz = Math.cos(elev);\n      const dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.sin(elev), Math.cos(azim) * horiz);\n\n      camera.up.set(0, 1, 0);\n      camera.position.copy(center).add(dir.multiplyScalar(dist));\n      camera.updateProjectionMatrix();\n      camera.lookAt(center);\n\n      if (controls) {\n        controls.target.copy(center);\n        controls.minDistance = Math.max(radius * 0.2, 0.1);\n        controls.maxDistance = Math.max(dist * 3, controls.minDistance + 0.1);\n        controls.update();\n      }\n    }\n\n    /**\n     * ============================================================\n     * Wooden loader\n     * ============================================================\n     */\n    const NUM_SKIN_WEIGHTS = 4;\n\n    const DEFAULT_EDGES = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50];\n\n    async function loadWoodenModel({ basePath=\"https://raw.githubusercontent.com/chingswy/WoodenModel/refs/heads/main/dump_wooden\" } = {}) {\n      const urls = [\n        `${basePath}/v_template.bin`,\n        `${basePath}/faces.bin`,\n        `${basePath}/skinWeights.bin`,\n        `${basePath}/skinIndice.bin`,\n        `${basePath}/j_template.bin`,\n        `${basePath}/uvs.bin`,\n      ];\n\n      // Optional: kintree & joint names\n      let edges = [...DEFAULT_EDGES];\n      try {\n        const resp = await fetch(`${basePath}/kintree.bin`);\n        if (resp.ok) edges = Array.from(new Int32Array(await resp.arrayBuffer()));\n      } catch (_) {}\n\n      let jointNames = null;\n      try {\n        const resp = await fetch(`${basePath}/joint_names.json`);\n        if 
(resp.ok) jointNames = await resp.json();\n      } catch (_) {}\n\n      const bufs = await Promise.all(urls.map(u => fetch(u).then(r => {\n        if (!r.ok) throw new Error(`Failed to fetch: ${u}`);\n        return r.arrayBuffer();\n      })));\n\n      const vTemplate = new Float32Array(bufs[0]);\n      const faces = new Uint16Array(bufs[1]);\n      const skinWeights = new Float32Array(bufs[2]);\n      const skinIndices = new Uint16Array(bufs[3]);\n      const keypoints = new Float32Array(bufs[4]);\n      const uvs = new Float32Array(bufs[5]);\n\n      const geometry = new THREE.BufferGeometry();\n      geometry.setAttribute('position', new THREE.BufferAttribute(vTemplate, 3));\n      geometry.setIndex(new THREE.BufferAttribute(faces, 1));\n      geometry.setAttribute('skinIndex', new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS));\n      geometry.setAttribute('skinWeight', new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS));\n      geometry.setAttribute('uv', new THREE.BufferAttribute(uvs, 2));\n      geometry.computeVertexNormals();\n\n      const numJoints = keypoints.length / 3;\n      while (edges.length < numJoints) edges.push(0);\n\n      // Build bones\n      const bones = [];\n      const root = new THREE.Bone();\n      root.position.set(keypoints[0], keypoints[1], keypoints[2]);\n      root.name = (jointNames && jointNames[0]) ? jointNames[0] : \"Pelvis\";\n      bones.push(root);\n\n      for (let i = 1; i < numJoints; i++) {\n        const bone = new THREE.Bone();\n        const p = edges[i];\n        bone.name = (jointNames && jointNames[i]) ? jointNames[i] : `Joint_${i}`;\n\n        if (p >= 0 && p < i) {\n          bone.position.set(\n            keypoints[3*i]   - keypoints[3*p],\n            keypoints[3*i+1] - keypoints[3*p+1],\n            keypoints[3*i+2] - keypoints[3*p+2],\n          );\n          bones.push(bone);\n          bones[p].add(bone);\n        } else {\n          bone.position.set(0, 0, 0);\n          bones.push(bone);\n          bones[0].add(bone);\n        }\n      }\n\n      const skeleton = new THREE.Skeleton(bones);\n\n      // Texture\n      const texLoader = new THREE.TextureLoader();\n      const baseColor = await texLoader.loadAsync(`${basePath}/Boy_lambert4_BaseColor.webp`);\n      baseColor.flipY = false;\n      baseColor.colorSpace = THREE.SRGBColorSpace;\n\n      const material = new THREE.MeshStandardMaterial({\n        map: baseColor,\n        roughness: 0.6,\n        metalness: 0.2,\n        envMapIntensity: 1.5,\n      });\n\n      const mesh = new THREE.SkinnedMesh(geometry, material);\n      mesh.castShadow = true;\n      mesh.receiveShadow = true;\n      mesh.add(bones[0]);\n      mesh.bind(skeleton);\n\n      return { bones, skeleton, mesh };\n    }\n\n    /**\n     * ============================================================\n     * Playback + UI\n     * ============================================================\n     */\n    let scene, camera, renderer, controls;\n    let infos = null;\n    let currentFrame = 0;\n    let totalFrame = 0;\n\n    const baseIntervalMs = 30;\n    let isPlaying = false;\n    let lastFrameTime = 0;\n    let playbackSpeed = 1.0;\n    let animationId = null;\n\n    let modelsLoaded = false;\n    let expectedModelCount = 0;\n    let loadedModelCount = 0;\n\n    const modelBonesById = {};\n\n    function setLoading(text, state=\"loading\") {\n      const el = document.getElementById(\"loadingStatus\");\n      if (!el) return;\n\n      if (state === \"hidden\") {\n        el.className = 
\"loading-overlay hidden\";\n        return;\n      }\n\n      el.className = \"loading-overlay\";\n      if (state === \"complete\") el.className += \" complete\";\n      if (state === \"failed\") el.className += \" failed\";\n\n      el.innerHTML = text;\n    }\n\n    function updateUI() {\n      document.getElementById(\"currentFrame\").textContent = String(currentFrame);\n      document.getElementById(\"totalFrames\").textContent = String(totalFrame);\n\n      if (totalFrame > 0) {\n        const p = (currentFrame / totalFrame) * 100;\n        document.getElementById(\"progressSlider\").value = String(p);\n      } else {\n        document.getElementById(\"progressSlider\").value = \"0\";\n      }\n    }\n\n    function computeOffsets(batchSize) {\n      const spacing = 2.0;\n      const totalWidth = (batchSize - 1) * spacing;\n      const startX = -totalWidth / 2;\n      const offsets = [];\n      for (let i = 0; i < batchSize; i++) offsets.push(startX + i * spacing);\n      return offsets;\n    }\n\n    function updateFrame() {\n      if (!infos || !modelsLoaded || currentFrame < 0 || currentFrame >= totalFrame) return;\n\n      const info = infos[currentFrame];\n      for (const smpl of info) {\n        if (!(smpl.id in modelBonesById)) return;\n      }\n\n      const offsets = computeOffsets(info.length);\n\n      info.forEach((smpl, b) => {\n        const bones = modelBonesById[smpl.id];\n        const meshContainer = bones[0].parent;\n\n        // global translation\n        meshContainer.position.set(\n          smpl.Th[0][0] - offsets[b],\n          smpl.Th[0][1],\n          smpl.Th[0][2]\n        );\n\n        // root rotation: axis-angle -> quaternion\n        const axis = new THREE.Vector3(smpl.Rh[0][0], smpl.Rh[0][1], smpl.Rh[0][2]);\n        const angle = axis.length();\n        if (angle > 1e-8) axis.normalize();\n        bones[0].quaternion.copy(new THREE.Quaternion().setFromAxisAngle(axis, angle));\n\n        // poses: handle 69 (padded hands) vs others\n        let posesOffset = 0;\n        if (smpl.poses[0].length === 69) posesOffset = -3;\n\n        for (let i = 1; i < bones.length; i++) {\n          const start = posesOffset + 3 * i;\n          if (start + 2 >= smpl.poses[0].length) continue;\n\n          const a = new THREE.Vector3(\n            smpl.poses[0][start],\n            smpl.poses[0][start + 1],\n            smpl.poses[0][start + 2]\n          );\n          const ang = a.length();\n          if (ang > 1e-6) {\n            a.normalize();\n            bones[i].quaternion.copy(new THREE.Quaternion().setFromAxisAngle(a, ang));\n          } else {\n            bones[i].quaternion.set(0, 0, 0, 1);\n          }\n        }\n      });\n\n      updateUI();\n    }\n\n    function playLoop(t) {\n      if (!isPlaying) return;\n      if (t - lastFrameTime >= (baseIntervalMs / playbackSpeed)) {\n        currentFrame += 1;\n        if (currentFrame >= totalFrame) currentFrame = 0;\n        updateFrame();\n        lastFrameTime = t;\n      }\n      animationId = requestAnimationFrame(playLoop);\n    }\n\n    function play() {\n      if (!modelsLoaded || totalFrame <= 0) return;\n      if (isPlaying) return;\n      isPlaying = true;\n      lastFrameTime = performance.now();\n      animationId = requestAnimationFrame(playLoop);\n    }\n\n    function pause() {\n      isPlaying = false;\n      if (animationId) cancelAnimationFrame(animationId);\n      animationId = null;\n    }\n\n    function reset() {\n      pause();\n      currentFrame = 0;\n      updateFrame();\n    }\n\n 
   function initControls() {\n      const slider = document.getElementById(\"progressSlider\");\n\n      let wasPlaying = false;\n      slider.addEventListener(\"mousedown\", () => {\n        if (!modelsLoaded) return;\n        wasPlaying = isPlaying;\n        if (isPlaying) pause();\n      });\n\n      slider.addEventListener(\"input\", (e) => {\n        if (!modelsLoaded) return;\n        const progress = parseFloat(e.target.value);\n        currentFrame = Math.floor((progress / 100) * totalFrame);\n        if (currentFrame >= totalFrame) currentFrame = totalFrame - 1;\n        if (currentFrame < 0) currentFrame = 0;\n        updateFrame();\n      });\n\n      slider.addEventListener(\"mouseup\", () => {\n        if (!modelsLoaded) return;\n        if (wasPlaying) play();\n      });\n\n      document.addEventListener(\"keydown\", (e) => {\n        if (!modelsLoaded) return;\n        switch (e.code) {\n          case \"Space\":\n            e.preventDefault();\n            isPlaying ? pause() : play();\n            break;\n          case \"ArrowLeft\":\n            e.preventDefault();\n            currentFrame = Math.max(0, currentFrame - 1);\n            updateFrame();\n            break;\n          case \"ArrowRight\":\n            e.preventDefault();\n            currentFrame = Math.min(totalFrame - 1, currentFrame + 1);\n            updateFrame();\n            break;\n          case \"Home\":\n            e.preventDefault();\n            reset();\n            break;\n        }\n      });\n    }\n\n    /**\n     * ============================================================\n     * Data load (embedded JSON) + app bootstrap\n     * ============================================================\n     */\n    function loadEmbeddedData() {\n      const el = document.getElementById(\"smpl-data-json\");\n      if (!el) throw new Error(\"SMPL data element not found\");\n\n      const datas = JSON.parse(el.textContent || \"[]\");\n      if (!datas || datas.length === 0) throw new Error(\"No SMPL data available\");\n\n      infos = datas;\n      totalFrame = datas.length;\n      updateUI();\n\n      expectedModelCount = infos[0].length;\n      loadedModelCount = 0;\n      modelsLoaded = false;\n\n      setLoading(`<i class=\"fas fa-spinner fa-spin\"></i> Loading... (0/${expectedModelCount})`, \"loading\");\n\n      // Load one wooden model per id in frame0\n      infos[0].forEach((d) => {\n        loadWoodenModel({ basePath: \"https://raw.githubusercontent.com/chingswy/WoodenModel/refs/heads/main/dump_wooden\" })\n          .then((result) => {\n            scene.add(result.mesh);\n            modelBonesById[d.id] = result.bones;\n\n            loadedModelCount += 1;\n            if (loadedModelCount >= expectedModelCount) {\n              modelsLoaded = true;\n              setLoading(`<i class=\"fas fa-check\"></i> Ready`, \"complete\");\n              setTimeout(() => setLoading(\"\", \"hidden\"), 1200);\n\n              updateFrame();\n              fitCameraToScene(scene, camera, controls, { excludeNames: [\"ground\"] });\n\n              play();\n            } else {\n              setLoading(\n                `<i class=\"fas fa-spinner fa-spin\"></i> Loading... 
(${loadedModelCount}/${expectedModelCount})`,\n                \"loading\"\n              );\n            }\n          })\n          .catch((err) => {\n            console.error(\"Failed to load wooden model:\", err);\n            setLoading(`<i class=\"fas fa-triangle-exclamation\"></i> Failed: ${String(err)}`, \"failed\");\n          });\n      });\n    }\n\n    function init() {\n      const width = window.innerWidth;\n      const height = window.innerHeight;\n\n      scene = new THREE.Scene();\n      camera = new THREE.PerspectiveCamera(45, width / height, 0.1, 50);\n\n      renderer = new THREE.WebGLRenderer({ antialias: true, logarithmicDepthBuffer: true });\n      renderer.toneMapping = THREE.ACESFilmicToneMapping;\n      renderer.toneMappingExposure = 1.0;\n      renderer.outputColorSpace = THREE.SRGBColorSpace;\n      renderer.setPixelRatio(window.devicePixelRatio);\n      renderer.setSize(width, height);\n\n      setupScene({ scene, camera, renderer, useGround: true });\n\n      const container = document.getElementById(\"vis3d\");\n      container.appendChild(renderer.domElement);\n\n      controls = new OrbitControls(camera, renderer.domElement);\n      controls.minDistance = 1;\n      controls.maxDistance = 15;\n      controls.enableDamping = true;\n      controls.dampingFactor = 0.05;\n      controls.target.set(0, 1, 0);\n      controls.update();\n\n      window.addEventListener(\"resize\", () => {\n        const w = window.innerWidth;\n        const h = window.innerHeight;\n        camera.aspect = w / h;\n        camera.updateProjectionMatrix();\n        renderer.setSize(w, h);\n      });\n\n      // Click: toggle play/pause (only after load)\n      let isDragging = false;\n      let mouseDownTime = 0;\n      renderer.domElement.addEventListener(\"mousedown\", () => {\n        isDragging = false;\n        mouseDownTime = Date.now();\n      });\n      renderer.domElement.addEventListener(\"mousemove\", () => {\n        if (Date.now() - mouseDownTime > 150) isDragging = true;\n      });\n      renderer.domElement.addEventListener(\"mouseup\", () => {\n        if (!modelsLoaded) return;\n        if (!isDragging && Date.now() - mouseDownTime < 300) {\n          isPlaying ? pause() : play();\n        }\n      });\n      renderer.domElement.addEventListener(\"dblclick\", () => {\n        if (!modelsLoaded) return;\n        reset();\n      });\n    }\n\n    function animate() {\n      requestAnimationFrame(animate);\n      if (controls && controls.enableDamping) controls.update();\n      renderer.render(scene, camera);\n    }\n\n    // Bootstrap\n    init();\n    initControls();\n    animate();\n\n    try {\n      loadEmbeddedData();\n    } catch (e) {\n      console.error(e);\n      setLoading(`<i class=\"fas fa-triangle-exclamation\"></i> Failed: ${String(e)}`, \"failed\");\n    }\n  </script>\n</body>\n</html>\n"
  },
  {
    "path": "holomotion/src/data_curation/video_to_smpl_gvhmr.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\n# This file was originally copied from the [PBHC] repository:\n# https://github.com/TeleHuman/PBHC\n# Modifications have been made to fit the needs of this project.\n\nimport cv2\nimport torch\nimport pytorch_lightning as pl\nimport numpy as np\nimport argparse\nfrom hmr4d.utils.pylogger import Log\nimport hydra\nfrom hydra import initialize_config_module, compose\nfrom pathlib import Path\nfrom pytorch3d.transforms import quaternion_to_matrix\n\nfrom hmr4d.configs import register_store_gvhmr\nfrom hmr4d.utils.video_io_utils import (\n    get_video_lwh,\n    read_video_np,\n    save_video,\n    merge_videos_horizontal,\n    get_writer,\n    get_video_reader,\n)\nfrom hmr4d.utils.vis.cv2_utils import (\n    draw_bbx_xyxy_on_image_batch,\n    draw_coco17_skeleton_batch,\n)\n\nfrom hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SimpleVO\n\nfrom hmr4d.utils.geo.hmr_cam import (\n    get_bbx_xys_from_xyxy,\n    estimate_K,\n    convert_K_to_K4,\n    create_camera_sensor,\n)\nfrom hmr4d.utils.geo_transform import compute_cam_angvel\nfrom hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL\nfrom hmr4d.utils.net_utils import detach_to_cpu, to_cuda\nfrom hmr4d.utils.smplx_utils import make_smplx\nfrom hmr4d.utils.vis.renderer import (\n    Renderer,\n    get_global_cameras_static,\n    get_ground_params_from_points,\n)\nfrom tqdm import tqdm\nfrom hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay\nfrom einops import einsum, rearrange\n\nimport shutil\n\nimport subprocess\n\nfrom scipy.spatial.transform import Rotation as sRot\n\nCRF = 23  # 17 is lossless, every +6 halves the mp4 size\n\n\ndef get_video_fps(video_path: Path) -> float:\n    cap = cv2.VideoCapture(str(video_path))\n    fps = cap.get(cv2.CAP_PROP_FPS)\n    cap.release()\n    if fps is None or fps <= 1e-6:\n        raise RuntimeError(f\"Failed to read FPS from video: {video_path}\")\n    return float(fps)\n\n\ndef is_close_fps(a: float, b: float, tol: float = 0.02) -> bool:\n    return abs(a - b) <= tol\n\n\ndef transcode_to_30fps_cfr(src: Path, dst: Path, crf: int) -> None:\n    dst.parent.mkdir(parents=True, exist_ok=True)\n    cmd = [\n        \"ffmpeg\",\n        \"-y\",\n        \"-i\",\n        str(src),\n        \"-vf\",\n        \"fps=30\",\n        \"-vsync\",\n        \"cfr\",\n        \"-c:v\",\n        \"libx264\",\n        \"-crf\",\n        str(crf),\n        \"-preset\",\n        \"medium\",\n        \"-c:a\",\n        \"copy\",\n        str(dst),\n    ]\n    subprocess.run(cmd, check=True)\n\n\ndef parse_args_to_cfg(args=None):\n    # Put all args to cfg\n    if args is None:\n        parser = argparse.ArgumentParser()\n        parser.add_argument(\n            \"--video\", type=str, default=\"inputs/demo/dance_3.mp4\"\n        )\n        parser.add_argument(\n            \"--output_root\",\n            type=str,\n            
default=None,\n            help=\"defaults to outputs/demo\",\n        )\n        parser.add_argument(\n            \"-s\",\n            \"--static_cam\",\n            action=\"store_true\",\n            help=\"If true, skip DPVO\",\n        )\n        parser.add_argument(\n            \"--use_dpvo\",\n            action=\"store_true\",\n            help=\"If true, use DPVO. Disabled by default.\",\n        )\n        parser.add_argument(\n            \"--f_mm\",\n            type=int,\n            default=None,\n            help=\"Focal length of fullframe camera in mm. Leave it as None to use default values. \"\n            \"For iPhone 15p, the [0.5x, 1x, 2x, 3x] lenses have typical values [13, 24, 48, 77]. \"\n            \"If the camera zooms in a lot, you can try 135, 200 or even larger values.\",\n        )\n        parser.add_argument(\n            \"--verbose\",\n            action=\"store_true\",\n            help=\"If true, draw intermediate results\",\n        )\n        args = parser.parse_args()\n\n    # Input\n    video_path = Path(args.video)\n    assert video_path.exists(), f\"Video not found at {video_path}\"\n    length, width, height = get_video_lwh(video_path)\n    Log.info(f\"[Input]: {video_path}\")\n    Log.info(f\"(L, W, H) = ({length}, {width}, {height})\")\n    # Cfg\n    with initialize_config_module(\n        version_base=\"1.3\", config_module=\"hmr4d.configs\"\n    ):\n        overrides = [\n            f\"video_name='{video_path.stem}'\",\n            f\"static_cam={args.static_cam}\",\n            f\"verbose={args.verbose}\",\n            f\"use_dpvo={args.use_dpvo}\",\n        ]\n        if args.f_mm is not None:\n            overrides.append(f\"f_mm={args.f_mm}\")\n\n        # Allow changing the output root\n        if args.output_root is not None:\n            overrides.append(f\"output_root='{args.output_root}'\")\n        register_store_gvhmr()\n        cfg = compose(config_name=\"demo\", overrides=overrides)\n\n    # Output\n    Log.info(f\"[Output Dir]: {cfg.output_dir}\")\n    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)\n    Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True)\n\n    # Copy the raw input video to cfg.video_path\n    Log.info(f\"[Prepare Video] {video_path} -> {cfg.video_path}\")\n\n    src_len = get_video_lwh(video_path)[0]\n    dst_path = Path(cfg.video_path)\n\n    need_regen = (not dst_path.exists()) or (\n        get_video_lwh(dst_path)[0] != src_len\n    )\n\n    src_fps = get_video_fps(video_path)\n    Log.info(f\"[Input FPS]: {src_fps:.4f}\")\n\n    if need_regen:\n        if is_close_fps(src_fps, 30.0):\n            Log.info(\"[FPS OK] ~30fps, copy without re-encoding.\")\n            dst_path.parent.mkdir(parents=True, exist_ok=True)\n            shutil.copy2(video_path, dst_path)\n        else:\n            Log.info(\"[FPS CONVERT] transcoding to 30fps constant frame rate.\")\n            transcode_to_30fps_cfr(video_path, Path(cfg.video_path), CRF)\n\n    return cfg\n\n\n@torch.no_grad()\ndef run_preprocess(cfg):\n    Log.info(\"[Preprocess] Start!\")\n    tic = Log.time()\n    video_path = cfg.video_path\n    paths = cfg.paths\n    static_cam = cfg.static_cam\n    verbose = cfg.verbose\n\n    # Get bbx tracking result\n    if not Path(paths.bbx).exists():\n        tracker = Tracker()\n        bbx_xyxy = tracker.get_one_track(video_path).float()  # (L, 4)\n        bbx_xys = get_bbx_xys_from_xyxy(\n            bbx_xyxy, base_enlarge=1.2\n        ).float()  # (L, 3) apply aspect ratio and enlarge\n        
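# save both raw xyxy boxes and square xys crops; reruns load this cache instead\n        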
torch.save({\"bbx_xyxy\": bbx_xyxy, \"bbx_xys\": bbx_xys}, paths.bbx)\n        del tracker\n    else:\n        bbx_xys = torch.load(paths.bbx)[\"bbx_xys\"]\n        Log.info(f\"[Preprocess] bbx (xyxy, xys) from {paths.bbx}\")\n    if verbose:\n        video = read_video_np(video_path)\n        bbx_xyxy = torch.load(paths.bbx)[\"bbx_xyxy\"]\n        video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video)\n        save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay)\n\n    # Get VitPose\n    if not Path(paths.vitpose).exists():\n        vitpose_extractor = VitPoseExtractor()\n        vitpose = vitpose_extractor.extract(video_path, bbx_xys)\n        torch.save(vitpose, paths.vitpose)\n        del vitpose_extractor\n    else:\n        vitpose = torch.load(paths.vitpose)\n        Log.info(f\"[Preprocess] vitpose from {paths.vitpose}\")\n    if verbose:\n        video = read_video_np(video_path)\n        video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5)\n        save_video(video_overlay, paths.vitpose_video_overlay)\n\n    # Get vit features\n    if not Path(paths.vit_features).exists():\n        extractor = Extractor()\n        vit_features = extractor.extract_video_features(video_path, bbx_xys)\n        torch.save(vit_features, paths.vit_features)\n        del extractor\n    else:\n        Log.info(f\"[Preprocess] vit_features from {paths.vit_features}\")\n\n    # Get visual odometry results\n    if not static_cam:  # use slam to get cam rotation\n        if not Path(paths.slam).exists():\n            if not cfg.use_dpvo:\n                simple_vo = SimpleVO(\n                    cfg.video_path,\n                    scale=0.5,\n                    step=8,\n                    method=\"sift\",\n                    f_mm=cfg.f_mm,\n                )\n                vo_results = simple_vo.compute()  # (L, 4, 4), numpy\n                torch.save(vo_results, paths.slam)\n            else:  # DPVO\n                from hmr4d.utils.preproc.slam import SLAMModel\n\n                length, width, height = get_video_lwh(cfg.video_path)\n                K_fullimg = estimate_K(width, height)\n                intrinsics = convert_K_to_K4(K_fullimg)\n                slam = SLAMModel(\n                    video_path,\n                    width,\n                    height,\n                    intrinsics,\n                    buffer=4000,\n                    resize=0.5,\n                )\n                bar = tqdm(total=length, desc=\"DPVO\")\n                while True:\n                    ret = slam.track()\n                    if ret:\n                        bar.update()\n                    else:\n                        break\n                slam_results = slam.process()  # (L, 7), numpy\n                torch.save(slam_results, paths.slam)\n        else:\n            Log.info(f\"[Preprocess] slam results from {paths.slam}\")\n\n    Log.info(f\"[Preprocess] End. 
Time elapsed: {Log.time() - tic:.2f}s\")\n\n\ndef load_data_dict(cfg):\n    paths = cfg.paths\n    length, width, height = get_video_lwh(cfg.video_path)\n    if cfg.static_cam:\n        R_w2c = torch.eye(3).repeat(length, 1, 1)\n    else:\n        traj = torch.load(cfg.paths.slam)\n        if cfg.use_dpvo:  # DPVO\n            traj_quat = torch.from_numpy(traj[:, [6, 3, 4, 5]])\n            R_w2c = quaternion_to_matrix(traj_quat).mT\n        else:  # SimpleVO\n            R_w2c = torch.from_numpy(traj[:, :3, :3])\n    if cfg.f_mm is not None:\n        K_fullimg = create_camera_sensor(width, height, cfg.f_mm)[2].repeat(\n            length, 1, 1\n        )\n    else:\n        K_fullimg = estimate_K(width, height).repeat(length, 1, 1)\n\n    data = {\n        \"length\": torch.tensor(length),\n        \"bbx_xys\": torch.load(paths.bbx)[\"bbx_xys\"],\n        \"kp2d\": torch.load(paths.vitpose),\n        \"K_fullimg\": K_fullimg,\n        \"cam_angvel\": compute_cam_angvel(R_w2c),\n        \"f_imgseq\": torch.load(paths.vit_features),\n    }\n    return data\n\n\ndef save_npz(pred, save_path):\n    out_dir = Path(save_path).parent\n    out_dir.mkdir(parents=True, exist_ok=True)\n    trans = pred[\"transl\"].detach().cpu()\n    body_pose = torch.cat(\n        (\n            pred[\"global_orient\"].detach().cpu(),\n            pred[\"body_pose\"].detach().cpu(),\n        ),\n        dim=1,\n    )\n\n    transform1 = sRot.from_euler(\n        \"xyz\", np.array([np.pi / 2, 0, np.pi]), degrees=False\n    )\n    new_root = (\n        transform1 * sRot.from_rotvec(body_pose[:, :3].numpy())\n    ).as_rotvec()\n    body_pose[:, :3] = torch.from_numpy(new_root)\n    trans = trans @ torch.tensor(transform1.as_matrix().T, dtype=torch.float32)\n\n    out_path = out_dir / \"smpl.npz\"\n    Log.info(f\"npz_path {out_path}\")\n    np.savez(\n        str(out_path),\n        betas=pred[\"betas\"][0].detach().cpu().numpy(),\n        gender=\"neutral\",\n        poses=body_pose.numpy(),\n        trans=trans.numpy(),\n        mocap_framerate=30.0,\n    )\n\n\ndef render_incam(cfg):\n    incam_video_path = Path(cfg.paths.incam_video)\n    if incam_video_path.exists():\n        Log.info(f\"[Render Incam] Video already exists at {incam_video_path}\")\n        return\n\n    pred = torch.load(cfg.paths.hmr4d_results)\n    smplx = make_smplx(\"supermotion\").cuda()\n    smplx2smpl = torch.load(\n        \"hmr4d/utils/body_model/smplx2smpl_sparse.pt\"\n    ).cuda()\n    faces_smpl = make_smplx(\"smpl\").faces\n\n    # smpl\n    smplx_out = smplx(**to_cuda(pred[\"smpl_params_incam\"]))\n    pred_c_verts = torch.stack(\n        [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]\n    )\n    # -- rendering code -- #\n    video_path = cfg.video_path\n    length, width, height = get_video_lwh(video_path)\n    K = pred[\"K_fullimg\"][0]\n\n    # renderer\n    renderer = Renderer(width, height, device=\"cuda\", faces=faces_smpl, K=K)\n    reader = get_video_reader(video_path)  # (F, H, W, 3), uint8, numpy\n    bbx_xys_render = torch.load(cfg.paths.bbx)[\"bbx_xys\"]\n\n    # -- render mesh -- #\n    verts_incam = pred_c_verts\n    writer = get_writer(incam_video_path, fps=30, crf=CRF)\n    for i, img_raw in tqdm(\n        enumerate(reader),\n        total=get_video_lwh(video_path)[0],\n        desc=f\"Rendering Incam\",\n    ):\n        img = renderer.render_mesh(\n            verts_incam[i].cuda(), img_raw, [0.8, 0.8, 0.8]\n        )\n\n        # # bbx\n        # bbx_xys_ = bbx_xys_render[i].cpu().numpy()\n        # 
lu_point = (bbx_xys_[:2] - bbx_xys_[2:] / 2).astype(int)\n        # rd_point = (bbx_xys_[:2] + bbx_xys_[2:] / 2).astype(int)\n        # img = cv2.rectangle(img, lu_point, rd_point, (255, 178, 102), 2)\n\n        writer.write_frame(img)\n    writer.close()\n    reader.close()\n\n\ndef render_global(cfg):\n    global_video_path = Path(cfg.paths.global_video)\n    # Always save NPZ regardless of whether the video already exists\n    pred = torch.load(cfg.paths.hmr4d_results)\n    save_npz(pred[\"smpl_params_global\"], save_path=global_video_path)\n    if global_video_path.exists():\n        Log.info(\n            f\"[Render Global] Video already exists at {global_video_path}\"\n        )\n        return\n\n    debug_cam = False\n    smplx = make_smplx(\"supermotion\").cuda()\n    smplx2smpl = torch.load(\n        \"hmr4d/utils/body_model/smplx2smpl_sparse.pt\"\n    ).cuda()\n    faces_smpl = make_smplx(\"smpl\").faces\n    J_regressor = torch.load(\n        \"hmr4d/utils/body_model/smpl_neutral_J_regressor.pt\"\n    ).cuda()\n\n    # smpl\n    smplx_out = smplx(**to_cuda(pred[\"smpl_params_global\"]))\n\n    # npz already saved above\n\n    pred_ay_verts = torch.stack(\n        [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]\n    )\n\n    def move_to_start_point_face_z(verts):\n        \"XZ to origin, Start from the ground, Face-Z\"\n        # position\n        verts = verts.clone()  # (L, V, 3)\n        offset = einsum(J_regressor, verts[0], \"j v, v i -> j i\")[0]  # (3)\n        offset[1] = verts[:, :, [1]].min()\n        verts = verts - offset\n        # face direction\n        T_ay2ayfz = compute_T_ayfz2ay(\n            einsum(J_regressor, verts[[0]], \"j v, l v i -> l j i\"),\n            inverse=True,\n        )\n        verts = apply_T_on_points(verts, T_ay2ayfz)\n        return verts\n\n    verts_glob = move_to_start_point_face_z(pred_ay_verts)\n    joints_glob = einsum(\n        J_regressor, verts_glob, \"j v, l v i -> l j i\"\n    )  # (L, J, 3)\n    global_R, global_T, global_lights = get_global_cameras_static(\n        verts_glob.cpu(),\n        beta=2.0,\n        cam_height_degree=20,\n        target_center_height=1.0,\n    )\n\n    # -- rendering code -- #\n    video_path = cfg.video_path\n    length, width, height = get_video_lwh(video_path)\n    _, _, K = create_camera_sensor(width, height, 24)  # render as 24mm lens\n\n    # renderer\n    renderer = Renderer(width, height, device=\"cuda\", faces=faces_smpl, K=K)\n    # renderer = Renderer(width, height, device=\"cuda\", faces=faces_smpl, K=K, bin_size=0)\n\n    # -- render mesh -- #\n    scale, cx, cz = get_ground_params_from_points(\n        joints_glob[:, 0], verts_glob\n    )\n    renderer.set_ground(scale * 1.5, cx, cz)\n    color = torch.ones(3).float().cuda() * 0.8\n\n    render_length = length if not debug_cam else 8\n    writer = get_writer(global_video_path, fps=30, crf=CRF)\n    for i in tqdm(range(render_length), desc=f\"Rendering Global\"):\n        cameras = renderer.create_camera(global_R[i], global_T[i])\n        img = renderer.render_with_ground(\n            verts_glob[[i]], color[None], cameras, global_lights\n        )\n        writer.write_frame(img)\n    writer.close()\n\n\nif __name__ == \"__main__\":\n    # Top-level parser to support folder batch mode\n    top_parser = argparse.ArgumentParser()\n    top_parser.add_argument(\"--video\", type=str, default=None)\n    top_parser.add_argument(\"--folder\", \"-f\", type=str, default=None)\n    top_parser.add_argument(\"--output_root\", \"-d\", 
type=str, default=None)\n    top_parser.add_argument(\"-s\", \"--static_cam\", action=\"store_true\")\n    top_parser.add_argument(\"--use_dpvo\", action=\"store_true\")\n    top_parser.add_argument(\"--f_mm\", type=int, default=None)\n    top_parser.add_argument(\"--verbose\", action=\"store_true\")\n    top_args = top_parser.parse_args()\n\n    # Batch mode\n    if top_args.folder is not None:\n        folder = Path(top_args.folder)\n        mp4_paths = sorted(\n            list(folder.glob(\"*.mp4\")) + list(folder.glob(\"*.MP4\"))\n        )\n        Log.info(f\"Found {len(mp4_paths)} .mp4 files in {folder}\")\n        for mp4_path in tqdm(mp4_paths):\n            per_args = argparse.Namespace(\n                video=str(mp4_path),\n                output_root=top_args.output_root,\n                static_cam=top_args.static_cam,\n                use_dpvo=top_args.use_dpvo,\n                f_mm=top_args.f_mm,\n                verbose=top_args.verbose,\n            )\n            try:\n                cfg = parse_args_to_cfg(per_args)\n                paths = cfg.paths\n                Log.info(f\"[GPU]: {torch.cuda.get_device_name()}\")\n                Log.info(f\"[GPU]: {torch.cuda.get_device_properties('cuda')}\")\n                run_preprocess(cfg)\n                data = load_data_dict(cfg)\n                if not Path(paths.hmr4d_results).exists():\n                    Log.info(\"[HMR4D] Predicting\")\n                    model: DemoPL = hydra.utils.instantiate(\n                        cfg.model, _recursive_=False\n                    )\n                    model.load_pretrained_model(cfg.ckpt_path)\n                    model = model.eval().cuda()\n                    tic = Log.sync_time()\n                    pred = model.predict(data, static_cam=cfg.static_cam)\n                    pred = detach_to_cpu(pred)\n                    data_time = data[\"length\"] / 30\n                    Log.info(\n                        f\"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s\"\n                    )\n                    torch.save(pred, paths.hmr4d_results)\n                render_incam(cfg)\n                render_global(cfg)\n                if not Path(paths.incam_global_horiz_video).exists():\n                    Log.info(\"[Merge Videos]\")\n                    merge_videos_horizontal(\n                        [paths.incam_video, paths.global_video],\n                        paths.incam_global_horiz_video,\n                    )\n            except Exception as e:\n                Log.error(f\"Failed on {mp4_path}: {e}\")\n        raise SystemExit(0)\n\n    # Single video mode\n    if top_args.video is None:\n        top_parser.error(\"Must provide --video or --folder\")\n\n    single_args = argparse.Namespace(\n        video=top_args.video,\n        output_root=top_args.output_root,\n        static_cam=top_args.static_cam,\n        use_dpvo=top_args.use_dpvo,\n        f_mm=top_args.f_mm,\n        verbose=top_args.verbose,\n    )\n\n    cfg = parse_args_to_cfg(single_args)\n    paths = cfg.paths\n    Log.info(f\"[GPU]: {torch.cuda.get_device_name()}\")\n    Log.info(f\"[GPU]: {torch.cuda.get_device_properties('cuda')}\")\n\n    # ===== Preprocess and save to disk ===== #\n    run_preprocess(cfg)\n    data = load_data_dict(cfg)\n\n    # ===== HMR4D ===== #\n    if not Path(paths.hmr4d_results).exists():\n        Log.info(\"[HMR4D] Predicting\")\n        model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False)\n        
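# load GVHMR checkpoint weights, run a single prediction pass, and cache the\n        # result to disk so subsequent runs skip inference\n        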
model.load_pretrained_model(cfg.ckpt_path)\n        model = model.eval().cuda()\n        tic = Log.sync_time()\n        pred = model.predict(data, static_cam=cfg.static_cam)\n        pred = detach_to_cpu(pred)\n        data_time = data[\"length\"] / 30\n        Log.info(\n            f\"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s\"\n        )\n        torch.save(pred, paths.hmr4d_results)\n\n    # ===== Render ===== #\n    render_incam(cfg)\n    render_global(cfg)\n    if not Path(paths.incam_global_horiz_video).exists():\n        Log.info(\"[Merge Videos]\")\n        merge_videos_horizontal(\n            [paths.incam_video, paths.global_video],\n            paths.incam_global_horiz_video,\n        )\n"
  },
  {
    "path": "holomotion/src/data_curation/vison_mocap/joints2smpl.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport sys\n\nsys.path.append(\"../../\")\nsys.path.append(\"../../thirdparties/joints2smpl/src\")\n\nimport h5py\nimport numpy as np\nimport smplx\nimport torch\nfrom scipy.spatial.transform import Rotation\n\nfrom thirdparties.joints2smpl.src import config\nfrom thirdparties.joints2smpl.src.smplify import SMPLify3D\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n# device = torch.device(\"cpu\")\n\nnum_joints = 22\njoint_category = \"AMASS\"\nnum_smplify_iters = 300\nfix_foot = False\n\n\ndef joints2smpl(input_joints, save_name):\n    \"\"\"Save the joints as amass-compatible npz file.\n\n    Note:\n        This function depends on the `joints2smpl` repository.\n        To use this function properly,\n        you need to manually modify parts of the internal\n        `joints2smpl` repository code to ensure compatibility.\n\n    \"\"\"\n    # print(file_name)\n    input_joints = input_joints[:, :, [0, 1, 2]]  # amass stands on x, y\n\n    \"\"\"XY at origin\"\"\"\n    input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]]\n\n    \"\"\"Put on Floor\"\"\"\n    floor_height = input_joints[:, :, 2].min()\n    input_joints[:, :, 2] -= floor_height\n\n    batch_size = input_joints.shape[0]\n\n    smplmodel = smplx.create(\n        config.SMPL_MODEL_DIR,\n        model_type=\"smpl\",\n        gender=\"neutral\",\n        ext=\"npz\",\n        batch_size=batch_size,\n    ).to(device)\n\n    # ## --- load the mean pose as original ----\n    smpl_mean_file = config.SMPL_MEAN_FILE\n\n    file = h5py.File(smpl_mean_file, \"r\")\n    init_mean_pose = (\n        torch.from_numpy(file[\"pose\"][:])\n        .unsqueeze(0)\n        .repeat(batch_size, 1)\n        .float()\n        .to(device)\n    )\n    init_mean_shape = (\n        torch.from_numpy(file[\"shape\"][:])\n        .unsqueeze(0)\n        .repeat(batch_size, 1)\n        .float()\n        .to(device)\n    )\n    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)\n\n    # # #-------------initialize SMPLify\n    smplify = SMPLify3D(\n        smplxmodel=smplmodel,\n        batch_size=batch_size,\n        joints_category=joint_category,\n        num_iters=num_smplify_iters,\n        device=device,\n    )\n\n    keypoints_3d = torch.Tensor(input_joints).to(device).float()\n\n    pred_betas = init_mean_shape\n    pred_pose = init_mean_pose\n    pred_cam_t = cam_trans_zero\n\n    if joint_category == \"AMASS\":\n        confidence_input = torch.ones(num_joints)\n        # make sure the foot and ankle\n        if fix_foot:\n            confidence_input[7] = 1.5\n            confidence_input[8] = 1.5\n            confidence_input[10] = 1.5\n            confidence_input[11] = 1.5\n    else:\n        print(\"Such category not settle down!\")\n\n    (\n        new_opt_vertices,\n        new_opt_joints,\n        new_opt_pose,\n        
new_opt_betas,\n        new_opt_cam_t,\n        new_opt_joint_loss,\n    ) = smplify(\n        pred_pose.detach(),\n        pred_betas.detach(),\n        pred_cam_t.detach(),\n        keypoints_3d,\n        conf_3d=confidence_input.to(device),\n        # seq_ind=idx\n    )\n\n    poses = new_opt_pose.detach().cpu().numpy()\n    betas = new_opt_betas.mean(axis=0).detach().cpu().numpy()\n    trans = keypoints_3d[:, 0].detach().cpu().numpy()\n    root_orient = poses[:, :3]\n    root_mat = Rotation.from_rotvec(root_orient).as_matrix()\n    # +90 deg rotation about x: maps the SMPL y-up root frame to z-up\n    rx_90 = Rotation.from_euler(\"xz\", [90, 0], degrees=True).as_matrix()\n    align_r = rx_90 @ root_mat\n    align_axis_angle = Rotation.from_matrix(align_r).as_rotvec()\n    poses[:, :3] = align_axis_angle\n    input_joints = input_joints[:, :, [0, 2, 1]]  # swap to y-up joints (ground plane xz)\n    input_joints[..., 0] *= -1\n    trans_rotated = rx_90 @ (trans.T)\n    trans_rotated = trans_rotated.T\n    # pad the 72-dim SMPL poses with zeros to the 165-dim (SMPL-X style) layout\n    target_dim = 165\n    poses_padding = np.zeros((poses.shape[0], target_dim))\n    if poses.shape[1] < target_dim:\n        poses_padding[:, : poses.shape[1]] = poses\n    else:\n        poses_padding = poses\n    param = {\n        \"poses\": poses_padding,\n        \"trans\": trans_rotated,\n        \"betas\": betas,\n        \"gender\": \"neutral\",\n        \"jtr\": input_joints,\n        \"mocap_frame_rate\": 30,\n    }\n    np.savez_compressed(save_name, **param)\n    print(f\"Successfully saved file: {save_name}\")\n"
  },
  {
    "path": "holomotion/src/data_curation/visualize_smpl_npz.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport argparse\nimport http.server\nimport os\nimport socket\nimport socketserver\nimport subprocess\nimport sys\nimport threading\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Optional\n\nimport webview\n\n\n# -----------------------------\n# UI Shell\n# -----------------------------\nSHELL_HTML = r\"\"\"\n<!doctype html>\n<html>\n<head>\n  <meta charset=\"utf-8\" />\n  <title>SMPL NPZ Viewer</title>\n  <style>\n    html,body{height:100%;margin:0;background:#1f1f1f;color:#fff;font-family:-apple-system,system-ui}\n    .topbar{\n      height:44px; display:flex; align-items:center; gap:12px;\n      padding:0 12px; background:#2a2a2a; border-bottom:1px solid rgba(255,255,255,.08);\n    }\n    .btn{\n      padding:7px 10px;\n      border:1px solid rgba(255,255,255,.25);\n      border-radius:10px;\n      background:rgba(255,255,255,.08);\n      color:#fff;\n      cursor:pointer\n    }\n    .btn:hover{background:rgba(255,255,255,.12)}\n    .status{\n      font-size:13px;opacity:.85;\n      white-space:nowrap;overflow:hidden;text-overflow:ellipsis;\n      flex: 1;\n    }\n    .bar{\n      height:8px;width:240px;\n      background:rgba(255,255,255,.14);\n      border-radius:999px;overflow:hidden;\n      display:none\n    }\n    .bar>div{\n      height:100%;width:30%;\n      background:rgba(255,255,255,.65);\n      border-radius:999px;\n      animation:move 1.1s infinite\n    }\n    @keyframes move{0%{transform:translateX(-120%)}100%{transform:translateX(320%)}}\n    .main{height:calc(100% - 44px)}\n    iframe{width:100%;height:100%;border:0;background:#000}\n  </style>\n</head>\n<body>\n  <div class=\"topbar\">\n    <button class=\"btn\" onclick=\"window.pywebview.api.pick_and_generate()\">Load NPZ</button>\n    <div class=\"bar\" id=\"bar\"><div></div></div>\n    <div class=\"status\" id=\"status\">Select an NPZ file…</div>\n  </div>\n\n  <div class=\"main\">\n    <iframe id=\"viewer\" src=\"about:blank\"></iframe>\n  </div>\n\n  <script>\n    function setBusy(b){\n      document.getElementById('bar').style.display = b ? 'block' : 'none';\n    }\n    function setStatus(t){\n      document.getElementById('status').textContent = t || '';\n    }\n    function showViewer(url){\n      const u = url + (url.includes('?') ? 
'&' : '?') + 't=' + Date.now();\n      document.getElementById('viewer').src = u;\n    }\n  </script>\n</body>\n</html>\n\"\"\"\n\n\n# -----------------------------\n# Config\n# -----------------------------\n@dataclass(frozen=True)\nclass AppConfig:\n    root: Path\n    port: int\n    smpl_npz_to_html: Path\n    template: Path\n    out_html: Path\n    window_title: str\n    width: int\n    height: int\n    auto_pick: bool\n    debug: bool\n\n\ndef parse_args() -> argparse.Namespace:\n    ap = argparse.ArgumentParser(description=\"SMPL NPZ viewer UI.\")\n    ap.add_argument(\n        \"--port\",\n        type=int,\n        default=8000,\n        help=\"Local HTTP port for serving assets.\",\n    )\n    ap.add_argument(\n        \"--smpl_npz_to_html\",\n        type=Path,\n        default=Path(\"smpl_npz_to_html.py\"),\n        help=\"Path to smpl_npz_to_html.py\",\n    )\n    ap.add_argument(\n        \"--template\",\n        type=Path,\n        default=Path(\"templates/index_wooden_static.html\"),\n        help=\"HTML template path\",\n    )\n    ap.add_argument(\n        \"--out\",\n        type=Path,\n        default=Path(\"_generated/vis.html\"),\n        help=\"Output vis.html path\",\n    )\n    ap.add_argument(\n        \"--title\", type=str, default=\"NPZ Viewer\", help=\"Window title\"\n    )\n    ap.add_argument(\"--width\", type=int, default=800, help=\"Window width\")\n    ap.add_argument(\"--height\", type=int, default=600, help=\"Window height\")\n    ap.add_argument(\n        \"--no-auto-pick\",\n        action=\"store_false\",\n        help=\"Do not auto-open file picker at startup\",\n    )\n    ap.add_argument(\n        \"--debug\", action=\"store_true\", help=\"Enable pywebview debug/devtools\"\n    )\n    return ap.parse_args()\n\n\n# -----------------------------\n# Utilities\n# -----------------------------\ndef js_escape(s: str) -> str:\n    return s.replace(\"\\\\\", \"\\\\\\\\\").replace(\"'\", \"\\\\'\")\n\n\ndef ensure_exists(path: Path, what: str) -> None:\n    if not path.exists():\n        raise FileNotFoundError(f\"Missing {what}: {path}\")\n\n\ndef is_port_available(port: int) -> bool:\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:\n        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        try:\n            sock.bind((\"127.0.0.1\", port))\n            return True\n        except OSError:\n            return False\n\n\n# -----------------------------\n# Core: server + generator + UI API\n# -----------------------------\nclass StaticServer:\n    def __init__(self, root: Path, port: int):\n        self.root = root\n        self.port = port\n        self._thread: Optional[threading.Thread] = None\n\n    def start(self) -> None:\n        def _serve():\n            os.chdir(self.root)  # serve assets from project root\n            handler = http.server.SimpleHTTPRequestHandler\n            with socketserver.TCPServer(\n                (\"127.0.0.1\", self.port), handler\n            ) as httpd:\n                httpd.serve_forever()\n\n        self._thread = threading.Thread(target=_serve, daemon=True)\n        self._thread.start()\n\n\nclass MakeVisRunner:\n    def __init__(\n        self,\n        root: Path,\n        smpl_npz_to_html: Path,\n        template: Path,\n        out_html: Path,\n    ):\n        self.root = root\n        self.smpl_npz_to_html = smpl_npz_to_html\n        self.template = template\n        self.out_html = out_html\n\n    def run(self, npz_path: Path) -> None:\n        
ensure_exists(self.smpl_npz_to_html, \"smpl_npz_to_html.py\")\n        ensure_exists(self.template, \"template html\")\n        ensure_exists(npz_path, \"npz file\")\n\n        self.out_html.parent.mkdir(parents=True, exist_ok=True)\n\n        cmd = [\n            sys.executable,\n            str(self.smpl_npz_to_html),\n            \"--npz\",\n            str(npz_path),\n            \"--template\",\n            str(self.template),\n            \"--out\",\n            str(self.out_html),\n        ]\n        subprocess.check_call(cmd, cwd=str(self.root))\n\n\ndef pick_npz_dialog(window) -> Optional[Path]:\n    file_types = (\"NPZ files (*.npz)\", \"All files (*.*)\")\n\n    # Prefer new enum if available; fallback to deprecated constant.\n    try:\n        dialog_open = webview.FileDialog.OPEN  # type: ignore[attr-defined]\n        paths = window.create_file_dialog(\n            dialog_open, allow_multiple=False, file_types=file_types\n        )\n    except Exception:\n        paths = window.create_file_dialog(\n            webview.OPEN_DIALOG, allow_multiple=False, file_types=file_types\n        )\n\n    return Path(paths[0]) if paths else None\n\n\nclass UIAPI:\n    def __init__(self, window, cfg: AppConfig, runner: MakeVisRunner):\n        self.window = window\n        self.cfg = cfg\n        self.runner = runner\n        self._busy = False\n        self._auto_done = False\n\n    def pick_and_generate(self) -> None:\n        if self._busy:\n            return\n        # Claim the busy flag before the blocking file dialog so a second\n        # click cannot race past the check while the dialog is open.\n        self._busy = True\n\n        npz = pick_npz_dialog(self.window)\n        if npz is None:\n            self._busy = False\n            return\n\n        safe_name = js_escape(npz.name)\n        self.window.evaluate_js(\n            f\"setBusy(true); setStatus('Generating: {safe_name}');\"\n        )\n\n        def worker():\n            try:\n                self.runner.run(npz)\n                rel = self.cfg.out_html.relative_to(self.cfg.root).as_posix()\n                self.window.evaluate_js(\n                    f\"setBusy(false); setStatus('Loaded: {safe_name}'); \"\n                    f\"showViewer('http://127.0.0.1:{self.cfg.port}/{rel}');\"\n                )\n            except Exception as e:\n                msg = js_escape(str(e))\n                self.window.evaluate_js(\n                    f\"setBusy(false); setStatus('Failed: {msg}');\"\n                )\n            finally:\n                self._busy = False\n\n        threading.Thread(target=worker, daemon=True).start()\n\n    def auto_pick_once(self) -> None:\n        # Called from window.events.loaded; ensure it runs once.\n        if self._auto_done:\n            return\n        self._auto_done = True\n        if self.cfg.auto_pick:\n            self.pick_and_generate()\n\n\n# -----------------------------\n# Entrypoint\n# -----------------------------\ndef build_config(args: argparse.Namespace) -> AppConfig:\n    root = Path(__file__).resolve().parent\n    smpl_npz_to_html = (\n        (root / args.smpl_npz_to_html).resolve()\n        if not args.smpl_npz_to_html.is_absolute()\n        else args.smpl_npz_to_html\n    )\n    template = (\n        (root / args.template).resolve()\n        if not args.template.is_absolute()\n        else args.template\n    )\n    out_html = (\n        (root / args.out).resolve() if not args.out.is_absolute() else args.out\n    )\n\n    return AppConfig(\n        root=root,\n        port=int(args.port),\n        smpl_npz_to_html=smpl_npz_to_html,\n        template=template,\n        out_html=out_html,\n        window_title=str(args.title),\n        width=int(args.width),\n        height=int(args.height),\n        auto_pick=not bool(args.no_auto_pick),\n        debug=bool(args.debug),\n    )\n\n\ndef main() -> None:\n    args = parse_args()\n    cfg = build_config(args)\n\n    if not is_port_available(cfg.port):\n        raise RuntimeError(\n            f\"Port {cfg.port} is already in use. Try --port 8001\"\n        )\n\n    server = StaticServer(cfg.root, cfg.port)\n    server.start()\n\n    runner = MakeVisRunner(\n        cfg.root, cfg.smpl_npz_to_html, cfg.template, cfg.out_html\n    )\n\n    window = webview.create_window(\n        cfg.window_title, html=SHELL_HTML, width=cfg.width, height=cfg.height\n    )\n    api = UIAPI(window, cfg, runner)\n    window.expose(api.pick_and_generate)\n\n    # Auto pick once on initial load (optional)\n    window.events.loaded += lambda: threading.Thread(\n        target=api.auto_pick_once, daemon=True\n    ).start()\n\n    webview.start(debug=cfg.debug)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
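A note on the GUI pattern in the file above: pywebview work must never block the GUI loop, so the app pairs an exposed Python callable with `evaluate_js` pushes from a daemon thread. The sketch below isolates that round-trip using only the pywebview package; the `run_job`/`setStatus` names are illustrative and not part of the app above.

```python
# Minimal sketch of the expose/evaluate_js round-trip used by UIAPI above,
# assuming only pywebview is installed. Names here are illustrative.
import threading

import webview

HTML = """
<p id="status">idle</p>
<button onclick="pywebview.api.run_job()">Run</button>
<script>
  function setStatus(t) { document.getElementById('status').innerText = t; }
</script>
"""


class Api:
    def run_job(self):
        # Heavy work runs on a daemon thread so the GUI event loop stays
        # responsive; results are pushed back to the page via evaluate_js.
        def worker():
            window.evaluate_js("setStatus('working...')")
            # ... long-running task would go here ...
            window.evaluate_js("setStatus('done')")

        threading.Thread(target=worker, daemon=True).start()


window = webview.create_window("Bridge demo", html=HTML, js_api=Api())
webview.start()
```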
  {
    "path": "holomotion/src/env/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/env/isaaclab_components/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom holomotion.src.env.isaaclab_components.isaaclab_actions import (\n    build_actions_config,\n    ActionsCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_scene import (\n    build_scene_config,\n    MotionTrackingSceneCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_simulator import (\n    build_simulator_config,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (\n    build_motion_tracking_commands_config,\n    MoTrack_CommandsCfg,\n)\n\nfrom holomotion.src.env.isaaclab_components.isaaclab_rewards import (\n    build_rewards_config,\n    RewardsCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_observation import (\n    build_observations_config,\n    ObservationsCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_termination import (\n    build_terminations_config,\n    TerminationsCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_domain_rand import (\n    build_domain_rand_config,\n    EventsCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_curriculum import (\n    build_curriculum_config,\n    CurriculumCfg,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_velocity_tracking_command import (\n    build_velocity_commands_config,\n    VelTrack_CommandsCfg,\n)\n"
  },
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_actions.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom isaaclab.utils import configclass\nimport isaaclab.envs.mdp as mdp\n\n\nclass ActionFunctions:\n    \"\"\"Collection of action function implementations.\"\"\"\n\n    @staticmethod\n    def joint_position_action(\n        asset_name: str = \"robot\",\n        joint_names: list[str] | None = None,\n        use_default_offset: bool = True,\n        scale: float = 1.0,\n    ) -> mdp.JointPositionActionCfg:\n        \"\"\"Joint position control action.\"\"\"\n        if joint_names is None:\n            joint_names = [\".*\"]\n        return mdp.JointPositionActionCfg(\n            asset_name=asset_name,\n            joint_names=joint_names,\n            use_default_offset=use_default_offset,\n            scale=scale,\n        )\n\n    @staticmethod\n    def joint_velocity_action(\n        asset_name: str = \"robot\",\n        joint_names: list[str] | None = None,\n        scale: float = 1.0,\n    ) -> mdp.JointVelocityActionCfg:\n        \"\"\"Joint velocity control action.\"\"\"\n        if joint_names is None:\n            joint_names = [\".*\"]\n        return mdp.JointVelocityActionCfg(\n            asset_name=asset_name,\n            joint_names=joint_names,\n            scale=scale,\n        )\n\n    @staticmethod\n    def joint_effort_action(\n        asset_name: str = \"robot\",\n        joint_names: list[str] | None = None,\n        scale: float = 1.0,\n    ) -> mdp.JointEffortActionCfg:\n        \"\"\"Joint effort control action.\"\"\"\n        if joint_names is None:\n            joint_names = [\".*\"]\n        return mdp.JointEffortActionCfg(\n            asset_name=asset_name,\n            joint_names=joint_names,\n            scale=scale,\n        )\n\n\n@configclass\nclass ActionsCfg:\n    \"\"\"Container for action terms.\"\"\"\n\n    pass\n\n\ndef build_actions_config(actions_config_dict: dict) -> ActionsCfg:\n    \"\"\"Build IsaacLab-compatible ActionsCfg from a config dictionary.\"\"\"\n    actions_cfg = ActionsCfg()\n\n    for action_name, action_config in actions_config_dict.items():\n        action_type = action_config[\"type\"]\n        params = action_config.get(\"params\", {})\n\n        if action_type == \"joint_position\":\n            action_term = ActionFunctions.joint_position_action(**params)\n        elif action_type == \"joint_velocity\":\n            action_term = ActionFunctions.joint_velocity_action(**params)\n        elif action_type == \"joint_effort\":\n            action_term = ActionFunctions.joint_effort_action(**params)\n        else:\n            raise ValueError(f\"Unknown action type: {action_type}\")\n\n        setattr(actions_cfg, action_name, action_term)\n\n    return actions_cfg\n"
  },
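For reference, `build_actions_config` is driven by a plain dict keyed by term name. A hedged usage sketch follows; the `scale` value is illustrative, while the term name `dof_pos` matches the action term that the domain-randomization code later looks up via `env.action_manager.get_term("dof_pos")`.

```python
# Usage sketch for build_actions_config: the resulting ActionsCfg carries one
# action-term attribute per dict entry. Parameter values are illustrative.
from holomotion.src.env.isaaclab_components.isaaclab_actions import (
    build_actions_config,
)

actions_cfg = build_actions_config(
    {
        "dof_pos": {
            "type": "joint_position",
            "params": {
                "asset_name": "robot",
                "joint_names": [".*"],
                "use_default_offset": True,
                "scale": 0.25,
            },
        }
    }
)
# actions_cfg.dof_pos is an mdp.JointPositionActionCfg ready to be attached
# to the `actions` field of a ManagerBasedRLEnvCfg.
```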
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_curriculum.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nfrom isaaclab.envs import ManagerBasedRLEnv\nimport torch\nfrom typing import Sequence\nfrom isaaclab.managers import CurriculumTermCfg\nfrom isaaclab.utils import configclass\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\nfrom typing import Any, Callable, Dict\nfrom loguru import logger\nfrom .isaaclab_domain_rand import DomainRandFunctions\n\n\ndef _completion_rate_curriculum_get_level(\n    env,\n    *,\n    term_tag: str = \"default\",\n    metric_key: str = \"Metrics/ref_motion/Task/Completion_Rate\",\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    apply_on_startup: bool = True,\n    startup_level: int = 0,\n    state_prefix: str = \"_cr_curr\",\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    level_key = f\"{state_prefix}_level\"\n    startup_key = f\"{state_prefix}_startup_applied\"\n    last_up_key = f\"{state_prefix}_last_upgrade_step\"\n    level_start_step_key = f\"{state_prefix}_level_start_step\"\n\n    if not hasattr(base_env, level_key):\n        setattr(base_env, level_key, -1)\n    if not hasattr(base_env, startup_key):\n        setattr(base_env, startup_key, False)\n    if not hasattr(base_env, last_up_key):\n        setattr(base_env, last_up_key, -(10**18))\n    if not hasattr(base_env, level_start_step_key):\n        setattr(base_env, level_start_step_key, 0)\n\n    step = int(\n        getattr(\n            base_env,\n            \"common_step_counter\",\n            getattr(env, \"common_step_counter\", 0),\n        )\n    )\n\n    def _get_completion_stats():\n        metrics = getattr(base_env, \"metrics\", None)\n        if isinstance(metrics, dict) and metric_key in metrics:\n            val = metrics[metric_key]\n            val = float(val.item()) if hasattr(val, \"item\") else float(val)\n            return val, step\n        return None\n\n    def _thr_for_next(next_level: int) -> float:\n        if not cr_thresholds:\n            return 1.0\n        idx = max(0, min(next_level - 1, len(cr_thresholds) - 1))\n        return float(cr_thresholds[idx])\n\n    stats = _get_completion_stats()\n    cur_level = int(getattr(base_env, level_key))\n    changed = False\n\n    # -------- startup init --------\n    if apply_on_startup and not bool(getattr(base_env, startup_key)):\n        init_level = int(max(0, min(int(startup_level), int(num_updates))))\n        setattr(base_env, level_key, max(cur_level, init_level))\n        setattr(base_env, startup_key, True)\n        setattr(base_env, last_up_key, step)\n        setattr(base_env, level_start_step_key, step)\n        cur_level = int(getattr(base_env, level_key))\n        changed = True\n\n    # -------- level upgrade --------\n    if cur_level < int(num_updates):\n        if stats is not 
None:\n            cr_val, _ = stats\n            level_start_step = int(getattr(base_env, level_start_step_key))\n            stayed_steps = int(step - level_start_step)\n\n            cooldown_ok = True\n            if int(cooldown_steps) > 0:\n                last_up = int(getattr(base_env, last_up_key))\n                cooldown_ok = (step - last_up) >= int(cooldown_steps)\n\n            if cooldown_ok and stayed_steps >= int(min_steps_per_level):\n                next_level = min(cur_level + 1, int(num_updates))\n                thr = _thr_for_next(next_level)\n                if float(cr_val) >= float(thr):\n                    setattr(base_env, level_key, next_level)\n                    setattr(base_env, last_up_key, step)\n                    setattr(base_env, level_start_step_key, step)\n                    cur_level = next_level\n                    changed = True\n\n    applied_key = (\n        f\"{state_prefix}_applied_{str(term_tag)}_level_{int(cur_level)}\"\n    )\n    if not hasattr(base_env, applied_key):\n        setattr(base_env, applied_key, False)\n\n    already_applied = bool(getattr(base_env, applied_key))\n    need_apply = bool(changed) or (not already_applied)\n\n    return int(cur_level), stats, bool(changed), bool(need_apply)\n\n\ndef lin_vel_cmd_levels(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    reward_term_name: str = \"track_lin_vel_xy\",\n) -> torch.Tensor:\n    command_term = env.command_manager.get_term(\"base_velocity\")\n    ranges = command_term.cfg.ranges\n    limit_ranges = command_term.cfg.limit_ranges\n\n    reward_term = env.reward_manager.get_term_cfg(reward_term_name)\n    reward = (\n        torch.mean(env.reward_manager._episode_sums[reward_term_name][env_ids])\n        / env.max_episode_length_s\n    )\n\n    if env.common_step_counter % env.max_episode_length == 0:\n        if reward > reward_term.weight * 0.8:\n            delta_command = torch.tensor([-0.1, 0.1], device=env.device)\n            ranges.lin_vel_x = torch.clamp(\n                torch.tensor(ranges.lin_vel_x, device=env.device)\n                + delta_command,\n                limit_ranges.lin_vel_x[0],\n                limit_ranges.lin_vel_x[1],\n            ).tolist()\n            ranges.lin_vel_y = torch.clamp(\n                torch.tensor(ranges.lin_vel_y, device=env.device)\n                + delta_command,\n                limit_ranges.lin_vel_y[0],\n                limit_ranges.lin_vel_y[1],\n            ).tolist()\n\n    return torch.tensor(ranges.lin_vel_x[1], device=env.device)\n\n\ndef ang_vel_cmd_levels(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    reward_term_name: str = \"track_ang_vel_z\",\n) -> torch.Tensor:\n    command_term = env.command_manager.get_term(\"base_velocity\")\n    ranges = command_term.cfg.ranges\n    limit_ranges = command_term.cfg.limit_ranges\n\n    reward_term = env.reward_manager.get_term_cfg(reward_term_name)\n    reward = (\n        torch.mean(env.reward_manager._episode_sums[reward_term_name][env_ids])\n        / env.max_episode_length_s\n    )\n\n    if env.common_step_counter % env.max_episode_length == 0:\n        if reward > reward_term.weight * 0.8:\n            delta_command = torch.tensor([-0.1, 0.1], device=env.device)\n            ranges.ang_vel_z = torch.clamp(\n                torch.tensor(ranges.ang_vel_z, device=env.device)\n                + delta_command,\n                limit_ranges.ang_vel_z[0],\n                limit_ranges.ang_vel_z[1],\n            ).tolist()\n\n    return 
torch.tensor(ranges.ang_vel_z[1], device=env.device)\n\n\ndef robot_friction_range_by_completion_rate(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    *,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n    static_friction_target=(0.3, 1.6),\n    dynamic_friction_target=(0.3, 1.2),\n    enforce_dynamic_le_static: bool = True,\n    asset_name: str = \"robot\",\n    body_names: str = \".*\",\n    restitution_range=(0.0, 0.5),\n    num_buckets: int = 64,\n    anchor_quantile: float = 0.5,\n    min_expand_frac: float = 0.0,\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    def _quantile(lo: float, hi: float, q: float) -> float:\n        lo, hi = float(min(lo, hi)), float(max(lo, hi))\n        q = float(max(0.0, min(1.0, q)))\n        return lo + (hi - lo) * q\n\n    def _compute_ranges(level: int):\n        level_i = int(max(0, min(level, int(num_updates))))\n        frac = (\n            1.0\n            if int(num_updates) <= 0\n            else (level_i / float(int(num_updates)))\n        )\n\n        s_lo_t, s_hi_t = map(float, static_friction_target)\n        d_lo_t, d_hi_t = map(float, dynamic_friction_target)\n        s_lo_t, s_hi_t = min(s_lo_t, s_hi_t), max(s_lo_t, s_hi_t)\n        d_lo_t, d_hi_t = min(d_lo_t, d_hi_t), max(d_lo_t, d_hi_t)\n\n        s_anchor = _quantile(s_lo_t, s_hi_t, anchor_quantile)\n        d_anchor = _quantile(d_lo_t, d_hi_t, anchor_quantile)\n\n        eps = float(min_expand_frac)\n        band = eps + (1.0 - eps) * float(max(frac, 0.0))\n\n        s_lo = s_anchor - (s_anchor - s_lo_t) * band\n        s_hi = s_anchor + (s_hi_t - s_anchor) * band\n        d_lo = d_anchor - (d_anchor - d_lo_t) * band\n        d_hi = d_anchor + (d_hi_t - d_anchor) * band\n\n        s_lo, s_hi = min(s_lo, s_hi), max(s_lo, s_hi)\n        d_lo, d_hi = min(d_lo, d_hi), max(d_lo, d_hi)\n\n        if enforce_dynamic_le_static:\n            d_hi = min(d_hi, s_hi)\n            d_lo = min(d_lo, d_hi)\n\n        return (\n            float(s_lo),\n            float(s_hi),\n            float(d_lo),\n            float(d_hi),\n            float(frac),\n            int(level_i),\n        )\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=\"fric\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    if not need_apply:\n        return float(level)\n\n    s_lo, s_hi, d_lo, d_hi, frac, level_i = _compute_ranges(int(level))\n\n    DomainRandFunctions._get_dr_rigid_body_material(\n        env=env,\n        env_ids=None,\n        asset_name=asset_name,\n        body_names=body_names,\n        static_friction_range=(s_lo, s_hi),\n        dynamic_friction_range=(d_lo, d_hi),\n        restitution_range=tuple(restitution_range),\n        num_buckets=int(num_buckets),\n    )\n\n    setattr(base_env, f\"{state_prefix}_applied_fric_level_{int(level)}\", True)\n    return float(level)\n\n\ndef rigid_body_com_by_completion_rate(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    *,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n    asset_name: str = \"robot\",\n    body_names: str = 
\"torso_link\",\n    com_range_target: dict = {\n        \"x\": (-0.025, 0.025),\n        \"y\": (-0.05, 0.05),\n        \"z\": (-0.05, 0.05),\n    },\n    anchor_quantile: float = 0.5,\n    min_expand_frac: float = 0.0,\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    def _quantile(lo: float, hi: float, q: float) -> float:\n        lo, hi = float(min(lo, hi)), float(max(lo, hi))\n        q = float(max(0.0, min(1.0, q)))\n        return lo + (hi - lo) * q\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=\"com\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    if not need_apply:\n        return float(level)\n\n    level_i = int(max(0, min(int(level), int(num_updates))))\n    frac = (\n        1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates)))\n    )\n    band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float(\n        max(frac, 0.0)\n    )\n\n    com_range = {}\n    for axis, (lo_t, hi_t) in com_range_target.items():\n        lo_t, hi_t = float(lo_t), float(hi_t)\n        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)\n\n        anchor = _quantile(lo_t, hi_t, anchor_quantile)\n        lo = anchor - (anchor - lo_t) * band\n        hi = anchor + (hi_t - anchor) * band\n        com_range[axis] = (float(min(lo, hi)), float(max(lo, hi)))\n\n    DomainRandFunctions._get_dr_rigid_body_com(\n        env=env,\n        env_ids=None,\n        com_range=com_range,\n        asset_name=asset_name,\n        body_names=body_names,\n    )\n\n    setattr(base_env, f\"{state_prefix}_applied_com_level_{int(level)}\", True)\n    return float(level)\n\n\ndef default_dof_pos_bias_by_completion_rate(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    *,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n    asset_name: str = \"robot\",\n    joint_names: list[str] = (\".*\"),\n    pos_distribution_params_target: tuple[float, float] = (-0.01, 0.01),\n    operation: str = \"add\",\n    distribution: str = \"uniform\",\n    anchor_quantile: float = 0.5,\n    min_expand_frac: float = 0.0,\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=\"dof\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    if not need_apply:\n        return float(level)\n\n    def _quantile(lo: float, hi: float, q: float) -> float:\n        lo, hi = float(min(lo, hi)), float(max(lo, hi))\n        q = float(max(0.0, min(1.0, q)))\n        return lo + (hi - lo) * q\n\n    lo_t, hi_t = map(float, pos_distribution_params_target)\n    lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)\n\n    level_i = int(max(0, min(int(level), int(num_updates))))\n    frac = (\n        1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates)))\n    )\n    band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float(\n        max(frac, 0.0)\n    )\n\n    anchor = _quantile(lo_t, hi_t, anchor_quantile)\n    lo = anchor - (anchor - lo_t) * band\n    hi = anchor + (hi_t 
- anchor) * band\n    lo, hi = float(min(lo, hi)), float(max(lo, hi))\n\n    DomainRandFunctions._get_dr_default_dof_pos_bias(\n        env=env,\n        env_ids=None,\n        asset_name=asset_name,\n        joint_names=joint_names,\n        pos_distribution_params=(lo, hi),\n        operation=operation,\n        distribution=distribution,\n    )\n\n    setattr(base_env, f\"{state_prefix}_applied_dof_level_{int(level)}\", True)\n    return float(level)\n\n\ndef push_by_setting_velocity_range_by_completion_rate(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    old_value,\n    *,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n    velocity_range_target: dict = {\n        \"x\": (-0.5, 0.5),\n        \"y\": (-0.5, 0.5),\n        \"z\": (-0.2, 0.2),\n        \"roll\": (-0.52, 0.52),\n        \"pitch\": (-0.52, 0.52),\n        \"yaw\": (-0.78, 0.78),\n    },\n    anchor_quantile: float = 0.5,\n    min_expand_frac: float = 0.0,\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    def _quantile(lo: float, hi: float, q: float) -> float:\n        lo, hi = float(min(lo, hi)), float(max(lo, hi))\n        q = float(max(0.0, min(1.0, q)))\n        return lo + (hi - lo) * q\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=\"push\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    if not need_apply:\n        return isaaclab_mdp.modify_term_cfg.NO_CHANGE\n\n    level_i = int(max(0, min(int(level), int(num_updates))))\n    frac = (\n        1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates)))\n    )\n    band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float(\n        max(frac, 0.0)\n    )\n\n    new_params = dict(old_value) if isinstance(old_value, dict) else old_value\n\n    current_velocity_range = {}\n    for axis, (lo_t, hi_t) in velocity_range_target.items():\n        lo_t, hi_t = float(lo_t), float(hi_t)\n        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)\n\n        anchor = _quantile(lo_t, hi_t, anchor_quantile)\n        lo = anchor - (anchor - lo_t) * band\n        hi = anchor + (hi_t - anchor) * band\n        current_velocity_range[axis] = [float(min(lo, hi)), float(max(lo, hi))]\n\n    if isinstance(new_params, dict) or hasattr(new_params, \"__setitem__\"):\n        if isinstance(new_params, dict):\n            new_params = dict(new_params)\n            new_params[\"velocity_range\"] = current_velocity_range\n        else:\n            new_params[\"velocity_range\"] = current_velocity_range\n    else:\n        setattr(new_params, \"velocity_range\", current_velocity_range)\n\n    setattr(base_env, f\"{state_prefix}_applied_push_level_{int(level)}\", True)\n    return new_params\n\n\ndef randomize_actuator_gains_by_completion_rate(\n    env: ManagerBasedRLEnv,\n    env_ids: Sequence[int],\n    *,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n    asset_name: str = \"robot\",\n    body_names: str = \".*\",\n    stiffness_distribution_params_target: tuple[float, float] = (0.9, 1.1),\n    damping_distribution_params_target: tuple[float, float] = 
(0.9, 1.1),\n    operation: str = \"scale\",\n    distribution: str = \"uniform\",\n    anchor_quantile: float = 0.5,\n    min_expand_frac: float = 0.0,\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=\"gains\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    if not need_apply:\n        return float(level)\n\n    def _quantile(lo: float, hi: float, q: float) -> float:\n        lo, hi = float(min(lo, hi)), float(max(lo, hi))\n        q = float(max(0.0, min(1.0, q)))\n        return lo + (hi - lo) * q\n\n    level_i = int(max(0, min(int(level), int(num_updates))))\n    frac = (\n        1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates)))\n    )\n    band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float(\n        max(frac, 0.0)\n    )\n\n    # stiffness\n    ks_lo_t, ks_hi_t = map(float, stiffness_distribution_params_target)\n    ks_lo_t, ks_hi_t = min(ks_lo_t, ks_hi_t), max(ks_lo_t, ks_hi_t)\n    ks_anchor = _quantile(ks_lo_t, ks_hi_t, anchor_quantile)\n    ks_lo = ks_anchor - (ks_anchor - ks_lo_t) * band\n    ks_hi = ks_anchor + (ks_hi_t - ks_anchor) * band\n    ks_lo, ks_hi = float(min(ks_lo, ks_hi)), float(max(ks_lo, ks_hi))\n\n    # damping\n    kd_lo_t, kd_hi_t = map(float, damping_distribution_params_target)\n    kd_lo_t, kd_hi_t = min(kd_lo_t, kd_hi_t), max(kd_lo_t, kd_hi_t)\n    kd_anchor = _quantile(kd_lo_t, kd_hi_t, anchor_quantile)\n    kd_lo = kd_anchor - (kd_anchor - kd_lo_t) * band\n    kd_hi = kd_anchor + (kd_hi_t - kd_anchor) * band\n    kd_lo, kd_hi = float(min(kd_lo, kd_hi)), float(max(kd_lo, kd_hi))\n\n    DomainRandFunctions._get_dr_randomize_actuator_gains(\n        env=env,\n        env_ids=None,\n        asset_name=asset_name,\n        body_names=body_names,\n        stiffness_distribution_params=(ks_lo, ks_hi),\n        damping_distribution_params=(kd_lo, kd_hi),\n        operation=operation,\n        distribution=distribution,\n    )\n\n    setattr(base_env, f\"{state_prefix}_applied_gains_level_{int(level)}\", True)\n    return float(level)\n\n\ndef reward_term_weight_by_completion_rate(\n    env,\n    env_ids,\n    *,\n    reward_term_name: str,\n    final_weight: float,\n    start_scale: float = 0.1,\n    num_updates: int = 5,\n    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),\n    min_steps_per_level: int = 300,\n    cooldown_steps: int = 0,\n    state_prefix: str = \"_cr_curr\",\n):\n    base_env = getattr(env, \"unwrapped\", env)\n\n    level, stats, changed, need_apply = _completion_rate_curriculum_get_level(\n        env,\n        term_tag=f\"reward_{reward_term_name}\",\n        num_updates=num_updates,\n        cr_thresholds=cr_thresholds,\n        min_steps_per_level=min_steps_per_level,\n        cooldown_steps=cooldown_steps,\n        state_prefix=state_prefix,\n    )\n\n    progress = 1.0 if num_updates <= 0 else float(level) / float(num_updates)\n    start_weight = float(final_weight) * float(start_scale)\n    new_weight = start_weight + progress * (float(final_weight) - start_weight)\n\n    reward_cfg = env.reward_manager.get_term_cfg(reward_term_name)\n    old_weight = float(reward_cfg.weight)\n\n    if not need_apply:\n        return float(level)\n\n    reward_cfg.weight = float(new_weight)\n    
env.reward_manager.set_term_cfg(reward_term_name, reward_cfg)\n\n    setattr(\n        base_env,\n        f\"{state_prefix}_reward_weight_{reward_term_name}\",\n        float(new_weight),\n    )\n    setattr(\n        base_env,\n        f\"{state_prefix}_applied_reward_{reward_term_name}_level_{int(level)}\",\n        True,\n    )\n    return float(level)\n\n\n@configclass\nclass CurriculumCfg:\n    pass\n\n\ndef build_curriculum_config(curriculum_config_dict: dict) -> CurriculumCfg:\n    \"\"\"\n    Build IsaacLab-compatible CurriculumCfg from a config dictionary.\n    \"\"\"\n    if isinstance(curriculum_config_dict, (DictConfig, ListConfig)):\n        curriculum_config_dict = OmegaConf.to_container(\n            curriculum_config_dict, resolve=True\n        )\n\n    curriculum_cfg = CurriculumCfg()\n    cfg_dict: Dict[str, Any] = dict(curriculum_config_dict or {})\n\n    def _resolve_callable(name: Any) -> Callable:\n        if callable(name):\n            return name\n\n        if isinstance(name, str) and name.startswith(\"isaaclab_mdp.\"):\n            name = name.split(\".\", 1)[1]\n\n        fn = globals().get(name)\n        if callable(fn):\n            return fn\n\n        fn = getattr(isaaclab_mdp, name, None)\n        if callable(fn):\n            return fn\n\n        if hasattr(isaaclab_mdp, \"curriculums\"):\n            fn = getattr(isaaclab_mdp.curriculums, name, None)\n            if callable(fn):\n                return fn\n\n        raise ValueError(f\"Unknown curriculum function: {name}\")\n\n    def _normalize_modify_params(x: Any) -> Any:\n        if isinstance(x, list):\n            # many configs express tuples as YAML lists\n            return tuple(_normalize_modify_params(v) for v in x)\n        if isinstance(x, dict):\n            return {k: _normalize_modify_params(v) for k, v in x.items()}\n        return x\n\n    def _fix_params(params: Dict[str, Any]) -> Dict[str, Any]:\n        params = dict(params or {})\n\n        if \"modify_fn\" in params and isinstance(\n            params[\"modify_fn\"], (str, Callable)\n        ):\n            params[\"modify_fn\"] = _resolve_callable(params[\"modify_fn\"])\n\n        if \"modify_params\" in params and isinstance(\n            params[\"modify_params\"], dict\n        ):\n            params[\"modify_params\"] = _normalize_modify_params(\n                params[\"modify_params\"]\n            )\n\n        return params\n\n    global_enabled = cfg_dict.pop(\"enabled\", True)\n    if not global_enabled:\n        return curriculum_cfg\n\n    for term_name, term_cfg in cfg_dict.items():\n        if term_cfg is None:\n            term_cfg = {}\n\n        if isinstance(term_cfg, bool):\n            if not term_cfg:\n                continue\n            term_cfg = {}\n\n        if not isinstance(term_cfg, dict):\n            raise TypeError(\n                f\"[build_curriculum_config] term '{term_name}' must be a dict/bool/None, got {type(term_cfg)}\"\n            )\n\n        if not term_cfg.get(\"enabled\", True):\n            continue\n\n        func_field = term_cfg.get(\"func\", None)\n        if func_field is None:\n            func = _resolve_callable(term_name)\n        else:\n            func = _resolve_callable(func_field)\n\n        params = _fix_params(term_cfg.get(\"params\", {}) or {})\n\n        setattr(\n            curriculum_cfg,\n            term_name,\n            CurriculumTermCfg(func=func, params=params),\n        )\n\n    return curriculum_cfg\n"
  },
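The upgrade rule inside `_completion_rate_curriculum_get_level` combines three gates: the completion-rate metric must clear the next level's threshold, the current level must have been held for `min_steps_per_level`, and any cooldown since the last upgrade must have elapsed. Below is a stateless restatement of just that decision, useful when tuning thresholds; the real term keeps this state on `base_env` attributes rather than in function arguments.

```python
# Stateless restatement of the upgrade rule in
# _completion_rate_curriculum_get_level (illustrative sketch).
def should_upgrade(
    cr_val: float,
    cur_level: int,
    step: int,
    level_start_step: int,
    last_upgrade_step: int,
    *,
    num_updates: int = 5,
    cr_thresholds: tuple[float, ...] = (0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
) -> bool:
    if cur_level >= num_updates:
        return False  # already at the final level
    # Threshold guarding the *next* level (index clamped like _thr_for_next).
    thr = cr_thresholds[max(0, min(cur_level, len(cr_thresholds) - 1))]
    dwell_ok = (step - level_start_step) >= min_steps_per_level
    cooldown_ok = cooldown_steps <= 0 or (
        (step - last_upgrade_step) >= cooldown_steps
    )
    return dwell_ok and cooldown_ok and cr_val >= thr


# 400 steps into level 0 with completion rate 0.12 (> 0.10): upgrade fires.
assert should_upgrade(0.12, 0, 400, 0, -(10**18))
# Same metric but only 100 steps into the level: the dwell gate blocks it.
assert not should_upgrade(0.12, 0, 100, 0, -(10**18))
```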
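All of the `*_by_completion_rate` terms share one range schedule: pick an anchor at `anchor_quantile` inside the target interval, then widen toward the target bounds as `level / num_updates` grows, with `min_expand_frac` as an optional starting floor. Isolated as a pure function, with the static-friction target from the file above:

```python
# Pure-function mirror of the range schedule shared by the
# *_by_completion_rate terms: level 0 collapses to the anchor, and
# level == num_updates spans the full target interval.
def scheduled_range(
    target: tuple[float, float],
    level: int,
    num_updates: int,
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
) -> tuple[float, float]:
    lo_t, hi_t = sorted(map(float, target))
    q = max(0.0, min(1.0, anchor_quantile))
    anchor = lo_t + (hi_t - lo_t) * q
    frac = (
        1.0
        if num_updates <= 0
        else max(0, min(level, num_updates)) / num_updates
    )
    band = min_expand_frac + (1.0 - min_expand_frac) * frac
    lo = anchor - (anchor - lo_t) * band
    hi = anchor + (hi_t - anchor) * band
    return (min(lo, hi), max(lo, hi))


# Static friction target (0.3, 1.6) with the default median anchor:
print(scheduled_range((0.3, 1.6), 0, 5))  # (0.95, 0.95): starts at anchor
print(scheduled_range((0.3, 1.6), 3, 5))  # ~(0.56, 1.34): partially widened
print(scheduled_range((0.3, 1.6), 5, 5))  # ~(0.3, 1.6): full target range
```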
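`build_curriculum_config` resolves each term's `func` from this module's globals first, then from `isaaclab_mdp`; when `func` is omitted, the term key itself is used as the function name, and `enabled: False` terms are dropped. A hedged configuration sketch (parameter values are illustrative):

```python
# Hedged usage sketch for build_curriculum_config; values are illustrative.
from holomotion.src.env.isaaclab_components.isaaclab_curriculum import (
    build_curriculum_config,
)

curriculum_cfg = build_curriculum_config(
    {
        "enabled": True,
        # No explicit "func": the key resolves to the module-level function
        # robot_friction_range_by_completion_rate defined above.
        "robot_friction_range_by_completion_rate": {
            "params": {
                "num_updates": 5,
                "min_steps_per_level": 300,
                "static_friction_target": (0.3, 1.6),
                "dynamic_friction_target": (0.3, 1.2),
            },
        },
        "rigid_body_com_by_completion_rate": {"enabled": False},  # skipped
    }
)
# curriculum_cfg now carries a single CurriculumTermCfg attribute named
# robot_friction_range_by_completion_rate.
```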
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_domain_rand.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport torch\nfrom typing import Literal\n\nimport isaaclab.utils.math as math_utils\nfrom isaaclab.assets import Articulation\n\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom isaaclab.envs.mdp.events import _randomize_prop_by_op\nfrom isaaclab.managers import SceneEntityCfg, EventTermCfg\nfrom isaaclab.utils import configclass\n\n\nfrom isaaclab.envs import ManagerBasedEnv\nfrom isaaclab.managers import EventTermCfg\n\n\nclass DomainRandFunctions:\n    @staticmethod\n    def _get_dr_default_dof_pos_bias(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor | None,\n        asset_name: str = \"robot\",\n        joint_names: list[str] = (\".*\"),\n        pos_distribution_params: tuple[float, float] | None = None,\n        operation: Literal[\"add\", \"scale\", \"abs\"] = \"abs\",\n        distribution: Literal[\n            \"uniform\", \"log_uniform\", \"gaussian\"\n        ] = \"uniform\",\n    ):\n        asset_cfg = SceneEntityCfg(asset_name, joint_names=joint_names)\n        asset_cfg.resolve(env.scene)\n        asset: Articulation = env.scene[asset_name]\n        asset.data.default_joint_pos_nominal = torch.clone(\n            asset.data.default_joint_pos[0]\n        )\n\n        if env_ids is None:\n            env_ids = torch.arange(env.scene.num_envs, device=asset.device)\n\n        if asset_cfg.joint_ids == slice(None):\n            joint_ids = slice(None)\n        else:\n            joint_ids = torch.tensor(\n                asset_cfg.joint_ids,\n                dtype=torch.int,\n                device=asset.device,\n            )\n\n        if pos_distribution_params is not None:\n            pos = asset.data.default_joint_pos.to(asset.device).clone()\n            pos = _randomize_prop_by_op(\n                pos,\n                pos_distribution_params,\n                env_ids,\n                joint_ids,\n                operation=operation,\n                distribution=distribution,\n            )[env_ids][:, joint_ids]\n\n            if env_ids != slice(None) and joint_ids != slice(None):\n                env_ids = env_ids[:, None]\n            asset.data.default_joint_pos[env_ids, joint_ids] = pos\n            env.action_manager.get_term(\"dof_pos\")._offset[\n                env_ids, joint_ids\n            ] = pos\n\n    @staticmethod\n    def _get_dr_rigid_body_com(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor | None,\n        com_range: dict[str, tuple[float, float]],\n        asset_name: str = \"robot\",\n        body_names: str = \"torso_link\",\n    ):\n        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)\n        asset_cfg.resolve(env.scene)\n        return isaaclab_mdp.events.randomize_rigid_body_com(\n            env,\n            env_ids,\n            com_range,\n            asset_cfg,\n        )\n\n    @staticmethod\n    def 
_get_dr_rigid_body_material(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor | None,\n        asset_name: str = \"robot\",\n        body_names: str = \".*\",\n        static_friction_range: tuple[float, float] | None = None,\n        dynamic_friction_range: tuple[float, float] | None = None,\n        restitution_range: tuple[float, float] | None = None,\n        num_buckets: int = 64,\n    ):\n        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)\n        asset_cfg.resolve(env.scene)\n        eveent_cfg = EventTermCfg(\n            func=isaaclab_mdp.events.randomize_rigid_body_material,\n            params={\n                \"asset_cfg\": asset_cfg,\n                \"static_friction_range\": static_friction_range,\n                \"dynamic_friction_range\": dynamic_friction_range,\n                \"restitution_range\": restitution_range,\n                \"num_buckets\": num_buckets,\n            },\n        )\n        material_randomizer = (\n            isaaclab_mdp.events.randomize_rigid_body_material(eveent_cfg, env)\n        )\n        return material_randomizer(env, env_ids, **eveent_cfg.params)\n\n    @staticmethod\n    def _get_dr_push_by_setting_velocity(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor,\n        velocity_range: dict[str, tuple[float, float]],\n    ):\n        return isaaclab_mdp.events.push_by_setting_velocity(\n            env,\n            env_ids,\n            velocity_range,\n        )\n\n    @staticmethod\n    def _get_dr_randomize_actuator_gains(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor,\n        asset_name: str = \"robot\",\n        body_names: str = \".*\",\n        stiffness_distribution_params: tuple[float, float] | None = None,\n        damping_distribution_params: tuple[float, float] | None = None,\n        operation: Literal[\"add\", \"scale\", \"abs\"] = \"abs\",\n        distribution: Literal[\n            \"uniform\", \"log_uniform\", \"gaussian\"\n        ] = \"uniform\",\n    ):\n        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)\n        asset_cfg.resolve(env.scene)\n        return isaaclab_mdp.events.randomize_actuator_gains(\n            env,\n            env_ids,\n            asset_cfg,\n            stiffness_distribution_params,\n            damping_distribution_params,\n            operation=operation,\n            distribution=distribution,\n        )\n\n    @staticmethod\n    def _get_dr_randomize_mass(\n        env: ManagerBasedEnv,\n        env_ids: torch.Tensor,\n        asset_name: str = \"robot\",\n        body_names: str = \".*\",\n        mass_range: tuple[float, float] | None = None,\n    ):\n        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)\n        asset_cfg.resolve(env.scene)\n        return isaaclab_mdp.events.randomize_rigid_body_mass(\n            env,\n            env_ids,\n            mass_distribution_params=mass_range,\n            asset_cfg=asset_cfg,\n            operation=\"add\",\n        )\n\n\n@configclass\nclass EventsCfg:\n    pass\n\n\ndef build_domain_rand_config(domain_rand_config_dict: dict) -> EventsCfg:\n    \"\"\"Build IsaacLab-compatible EventsCfg from a config dictionary.\"\"\"\n    events_cfg = EventsCfg()\n\n    for event_name, cfg in domain_rand_config_dict.items():\n        # Keep non-event config under `domain_rand` available for Hydra\n        # references without forcing it through the Isaac Lab event builder.\n        if not (isinstance(cfg, dict) and \"mode\" in cfg):\n            
continue\n\n        try:\n            func = getattr(DomainRandFunctions, f\"_get_dr_{event_name}\")\n        except AttributeError as exc:\n            raise AttributeError(\n                f\"Unknown domain randomization event '{event_name}'\"\n            ) from exc\n        term = EventTermCfg(\n            func=func,\n            **cfg,\n        )\n        setattr(events_cfg, event_name, term)\n\n    return events_cfg\n"
  },
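A hedged usage sketch for `build_domain_rand_config` follows. Only sub-dicts carrying a `"mode"` key become `EventTermCfg` terms; other keys under the same namespace are passed over so Hydra-level knobs can live alongside the events. All values below are illustrative, and `push_robot_interval_s` is a hypothetical pass-through key.

```python
# Usage sketch for build_domain_rand_config; values are illustrative.
from holomotion.src.env.isaaclab_components.isaaclab_domain_rand import (
    build_domain_rand_config,
)

events_cfg = build_domain_rand_config(
    {
        # Becomes an EventTermCfg wrapping
        # DomainRandFunctions._get_dr_rigid_body_material.
        "rigid_body_material": {
            "mode": "startup",
            "params": {
                "asset_name": "robot",
                "body_names": ".*",
                "static_friction_range": (0.6, 1.0),
                "dynamic_friction_range": (0.4, 0.8),
                "restitution_range": (0.0, 0.2),
                "num_buckets": 64,
            },
        },
        # No "mode" key: skipped by the builder, kept for Hydra references.
        "push_robot_interval_s": 4.0,
    }
)
```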
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_motion_tracking_command.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nfrom dataclasses import MISSING\nfrom typing import Sequence\nimport time\nimport json\n\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional\nimport numpy as np\nfrom tqdm import tqdm\nfrom scipy.spatial.transform import Rotation as sRot\n\nimport isaaclab.envs.mdp as mdp\nimport isaaclab.sim as sim_utils\n\n\nimport isaaclab.utils.math as isaaclab_math\nimport torch\nfrom isaaclab.actuators import ImplicitActuatorCfg\nfrom isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg\nfrom isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg\nfrom isaaclab.envs.mdp.actions import JointEffortActionCfg\nfrom isaaclab.managers import (\n    ActionTermCfg,\n    CommandTerm,\n    CommandTermCfg,\n    EventTermCfg as EventTerm,\n    ObservationGroupCfg,\n    ObservationGroupCfg as ObsGroup,\n    ObservationTermCfg,\n    ObservationTermCfg as ObsTerm,\n    RewardTermCfg,\n    TerminationTermCfg,\n)\nfrom isaaclab.markers import (\n    VisualizationMarkers,\n    VisualizationMarkersCfg,\n)\nfrom holomotion.src.training.h5_dataloader import (\n    Hdf5MotionDataset,\n    Hdf5RootDofDataset,\n    MotionClipBatchCache,\n    build_motion_datasets_from_cfg,\n)\nimport os\nfrom isaaclab.markers.config import SPHERE_MARKER_CFG\nfrom isaaclab.sim import PreviewSurfaceCfg\nfrom isaaclab.scene import InteractiveSceneCfg\nfrom isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns\nfrom isaaclab.sim import PhysxCfg, SimulationCfg\nfrom isaaclab.terrains import TerrainImporterCfg\nfrom isaaclab.utils import configclass\nfrom isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.utils.isaac_utils.rotations import (\n    calc_heading_quat_inv,\n    get_euler_xyz,\n    my_quat_rotate,\n    quat_inverse,\n    quat_mul,\n    quat_rotate,\n    quat_rotate_inverse,\n    quaternion_to_matrix,\n    wrap_to_pi,\n    wxyz_to_xyzw,\n    xyzw_to_wxyz,\n)\nfrom holomotion.src.utils.reference_prefix import (\n    resolve_reference_tensor_key,\n)\nfrom loguru import logger\n\n\nclass RefMotionCommand(CommandTerm):\n    cfg: CommandTermCfg\n\n    def __init__(\n        self,\n        cfg,\n        env: ManagerBasedRLEnv,\n    ):\n        # print(cfg)\n        super().__init__(cfg, env)\n        self._env = env\n        self._is_evaluating = self.cfg.is_evaluating\n        self._runtime_process_id = int(self.cfg.process_id)\n        self._runtime_num_processes = max(1, int(self.cfg.num_processes))\n\n        self._init_robot_handle()\n        self._init_buffers()\n        self._init_motion_lib()\n\n    #     # self._init_tracking_config()\n\n    def _init_tracking_config(self, config):\n        self.log_dict_holomotion = {}\n        self.log_dict_nonreduced_holomotion = {}\n        self.log_dict_nonreduced = {}\n        self.log_dict = {}\n 
       if \"head_hand_bodies\" in config:\n            self.motion_tracking_id = [\n                self.robot.body_names.index(link)\n                for link in config.head_hand_bodies\n            ]\n        if \"leg_body_names\" in config:\n            self.lower_body_id = [\n                self.robot.body_names.index(link)\n                for link in config.leg_body_names\n            ]\n        if \"arm_body_names\" in config:\n            self.upper_body_id = [\n                self.robot.body_names.index(link)\n                for link in config.arm_body_names\n            ]\n        if \"leg_dof_names\" in config:\n            self.lower_body_joint_ids = [\n                config.dof_names.index(link) for link in config.leg_dof_names\n            ]\n        if \"arm_dof_names\" in config:\n            self.upper_body_joint_ids = [\n                config.dof_names.index(link) for link in config.arm_dof_names\n            ]\n\n        if \"waist_dof_names\" in config:\n            self.waist_dof_indices = [\n                config.dof_names.index(link) for link in config.waist_dof_names\n            ]\n\n    @staticmethod\n    def _amp_filter_names_by_prefix(\n        names: Sequence[str], prefix: str, keywords: Sequence[str]\n    ) -> list[str]:\n        return [\n            name\n            for name in names\n            if name.startswith(prefix) and any(key in name for key in keywords)\n        ]\n\n    @staticmethod\n    def _amp_pick_first_name(\n        names: Sequence[str], patterns: Sequence[str]\n    ) -> str | None:\n        for pattern in patterns:\n            for name in names:\n                if pattern in name:\n                    return name\n        return None\n\n    def _resolve_motion_cache_stage_device(\n        self, cache_cfg: Dict[str, object]\n    ) -> Optional[torch.device]:\n        raw_stage_device = cache_cfg.get(\"device\", \"cuda\")\n        if isinstance(raw_stage_device, torch.device):\n            if raw_stage_device.type == \"cpu\":\n                return None\n            if raw_stage_device.type != \"cuda\":\n                raise ValueError(\n                    f\"Unsupported motion cache device: {raw_stage_device}\"\n                )\n            if raw_stage_device.index is not None:\n                return raw_stage_device\n            if not torch.cuda.is_available():\n                return None\n            local_rank_env = os.environ.get(\"LOCAL_RANK\")\n            if local_rank_env is not None:\n                local_rank = int(local_rank_env)\n                device_count = int(torch.cuda.device_count())\n                if 0 <= local_rank < device_count:\n                    return torch.device(\"cuda\", local_rank)\n            return torch.device(\"cuda\", int(torch.cuda.current_device()))\n\n        stage_device = str(raw_stage_device).strip().lower()\n        if stage_device in (\"none\", \"cpu\"):\n            return None\n        if stage_device == \"cuda\":\n            if isinstance(self.device, torch.device):\n                if self.device.type == \"cuda\":\n                    return self.device\n                return None\n            device_str = str(self.device).strip().lower()\n            if device_str.startswith(\"cuda\"):\n                return torch.device(device_str)\n            if not torch.cuda.is_available():\n                return None\n            local_rank_env = os.environ.get(\"LOCAL_RANK\")\n            if local_rank_env is not None:\n                local_rank = int(local_rank_env)\n        
        device_count = int(torch.cuda.device_count())\n                if 0 <= local_rank < device_count:\n                    return torch.device(\"cuda\", local_rank)\n            return torch.device(\"cuda\", int(torch.cuda.current_device()))\n        if stage_device.startswith(\"cuda:\"):\n            return torch.device(stage_device)\n        raise ValueError(\n            f\"Unsupported motion cache device config: {raw_stage_device}\"\n        )\n\n    def _init_motion_lib(self):\n        mcfg = OmegaConf.create(self.cfg.motion_lib_cfg)\n        self.mcfg = mcfg\n        backend = str(mcfg.get(\"backend\", \"hdf5\")).lower()\n        self._motion_cache = None\n        if backend in (\"hdf5\", \"hdf5_simple\"):\n            # Support multi-root configuration while keeping single-root\n            # behavior fully backward compatible.\n            train_hdf5_roots = mcfg.get(\"train_hdf5_roots\", None)\n            val_hdf5_roots = mcfg.get(\"val_hdf5_roots\", None)\n\n            if train_hdf5_roots:\n                train_roots = [str(r) for r in train_hdf5_roots]\n            else:\n                hdf5_root = mcfg.get(\"hdf5_root\")\n                if hdf5_root is None:\n                    raise ValueError(\"hdf5_root is required\")\n                train_roots = [str(hdf5_root)]\n\n            val_hdf5_root = mcfg.get(\"val_hdf5_root\", None)\n            if val_hdf5_roots:\n                val_roots = [str(r) for r in val_hdf5_roots]\n            elif val_hdf5_root is not None and str(val_hdf5_root) != str(\n                train_roots[0]\n            ):\n                val_roots = [str(val_hdf5_root)]\n            else:\n                val_roots = None\n\n            train_manifest_paths = [\n                os.path.join(root, \"manifest.json\") for root in train_roots\n            ]\n            for mp in train_manifest_paths:\n                if not os.path.exists(mp):\n                    raise FileNotFoundError(\n                        f\"HDF5 manifest not found at {mp}. 
\"\n                        \"Please set robot.motion.hdf5_root/train_hdf5_roots to \"\n                        \"the correct path!\"\n                    )\n\n            max_frame_length = int(mcfg.get(\"max_frame_length\", 500))\n            min_frame_length = int(mcfg.get(\"min_frame_length\", 1))\n            world_frame_norm = bool(\n                mcfg.get(\"world_frame_normalization\", True)\n            )\n\n            cache_cfg = mcfg.get(\"cache\", {})\n            allowed_prefixes = cache_cfg.get(\n                \"allowed_prefixes\",\n                [\"ref_\", \"ft_ref_\"],\n            )\n\n            if len(train_manifest_paths) == 1:\n                logger.info(\n                    f\"Loading HDF5 training dataset from {train_manifest_paths[0]}\"\n                )\n            else:\n                logger.info(\n                    f\"Loading HDF5 training dataset from manifests: \"\n                    f\"{train_manifest_paths}\"\n                )\n            train_dataset = Hdf5MotionDataset(\n                manifest_path=train_manifest_paths\n                if len(train_manifest_paths) > 1\n                else train_manifest_paths[0],\n                max_frame_length=max_frame_length,\n                min_window_length=min_frame_length,\n                handpicked_motion_names=mcfg.get(\n                    \"handpicked_motion_names\", None\n                ),\n                excluded_motion_names=mcfg.get(\"excluded_motion_names\", None),\n                world_frame_normalization=world_frame_norm,\n                allowed_prefixes=allowed_prefixes,\n            )\n            if len(train_dataset) == 0:\n                raise ValueError(\n                    \"Training dataset is empty. Check that all manifests \"\n                    \"contain valid clips with length \"\n                    f\">= {min_frame_length}\"\n                )\n            logger.info(f\"Loaded {len(train_dataset)} training motion windows\")\n            train_num_clips = len(train_dataset.clips)\n            train_total_frames = sum(\n                int(meta.get(\"length\", 0))\n                for meta in train_dataset.clips.values()\n            )\n            fps_used = int(self.cfg.target_fps)\n            train_duration_s = (\n                float(train_total_frames) / float(fps_used)\n                if fps_used > 0\n                else 0.0\n            )\n            if len(train_roots) == 1:\n                logger.info(\n                    f\"Train dataset: root={train_roots[0]}, \"\n                    f\"manifest={train_manifest_paths[0]}\"\n                )\n            else:\n                logger.info(\n                    f\"Train dataset: roots={train_roots}, \"\n                    f\"manifests={train_manifest_paths}\"\n                )\n            logger.info(\n                f\"Train clips={train_num_clips}, frames={train_total_frames}, \"\n                f\"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps\"\n            )\n            excluded_names = mcfg.get(\"excluded_motion_names\", None)\n            if excluded_names:\n                excluded_set = set(excluded_names)\n                excluded_clip_keys = [\n                    k for k in train_dataset.clips.keys() if k in excluded_set\n                ]\n                excluded_num_clips = len(excluded_clip_keys)\n                excluded_total_frames = sum(\n                    int(train_dataset.clips[k].get(\"length\", 0))\n                    for k in excluded_clip_keys\n      
          )\n                excluded_duration_s = (\n                    float(excluded_total_frames) / float(fps_used)\n                    if fps_used > 0\n                    else 0.0\n                )\n                left_num_clips = max(0, train_num_clips - excluded_num_clips)\n                left_total_frames = max(\n                    0, train_total_frames - excluded_total_frames\n                )\n                left_duration_s = (\n                    float(left_total_frames) / float(fps_used)\n                    if fps_used > 0\n                    else 0.0\n                )\n                logger.info(\n                    f\"Excluded (by name): clips={excluded_num_clips}, \"\n                    f\"frames={excluded_total_frames}, \"\n                    f\"duration={excluded_duration_s / 3600:.2f}h\"\n                )\n                logger.info(\n                    f\"Remaining after exclusion: clips={left_num_clips}, \"\n                    f\"frames={left_total_frames}, \"\n                    f\"duration={left_duration_s / 3600:.2f}h\"\n                )\n\n            val_dataset = None\n            if val_roots is not None:\n                val_manifest_paths = [\n                    os.path.join(root, \"manifest.json\") for root in val_roots\n                ]\n                for mp in val_manifest_paths:\n                    if not os.path.exists(mp):\n                        raise FileNotFoundError(\n                            f\"HDF5 validation manifest not found at {mp}. \"\n                            \"Please set robot.motion.val_hdf5_root/\"\n                            \"val_hdf5_roots to the correct path!\"\n                        )\n                if len(val_manifest_paths) == 1:\n                    logger.info(\n                        f\"Loading HDF5 validation dataset from {val_manifest_paths[0]}\"\n                    )\n                else:\n                    logger.info(\n                        \"Loading HDF5 validation dataset from manifests: \"\n                        f\"{val_manifest_paths}\"\n                    )\n                val_dataset = Hdf5MotionDataset(\n                    manifest_path=val_manifest_paths\n                    if len(val_manifest_paths) > 1\n                    else val_manifest_paths[0],\n                    max_frame_length=max_frame_length,\n                    min_window_length=min_frame_length,\n                    handpicked_motion_names=mcfg.get(\n                        \"handpicked_motion_names\", None\n                    ),\n                    excluded_motion_names=mcfg.get(\n                        \"excluded_motion_names\", None\n                    ),\n                    world_frame_normalization=world_frame_norm,\n                    allowed_prefixes=allowed_prefixes,\n                )\n                logger.info(\n                    f\"Loaded {len(val_dataset)} validation motion windows\"\n                )\n                val_num_clips = len(val_dataset.clips)\n                val_total_frames = sum(\n                    int(meta.get(\"length\", 0))\n                    for meta in val_dataset.clips.values()\n                )\n                val_duration_s = (\n                    float(val_total_frames) / float(fps_used)\n                    if fps_used > 0\n                    else 0.0\n                )\n                if len(val_roots) == 1:\n                    logger.info(\n                        f\"Val dataset: root={val_roots[0]}, \"\n                        
f\"manifest={val_manifest_paths[0]}\"\n                    )\n                else:\n                    logger.info(\n                        f\"Val dataset: roots={val_roots}, \"\n                        f\"manifests={val_manifest_paths}\"\n                    )\n                logger.info(\n                    f\"Val clips={val_num_clips}, frames={val_total_frames}, \"\n                    f\"duration={val_duration_s / 3600:.1f}h @ {fps_used} fps\"\n                )\n            else:\n                logger.info(\n                    \"Validation dataset: using training dataset \"\n                    \"(no separate val manifest found)\"\n                )\n\n            dataloader_cfg = mcfg.get(\"dataloader\", {})\n            stage_device = self._resolve_motion_cache_stage_device(cache_cfg)\n\n            self._motion_cache = MotionClipBatchCache(\n                train_dataset=train_dataset,\n                val_dataset=val_dataset,\n                batch_size=int(cache_cfg.get(\"max_num_clips\", 1024)),\n                stage_device=stage_device,\n                num_workers=int(dataloader_cfg.get(\"num_workers\", 4)),\n                prefetch_factor=dataloader_cfg.get(\"prefetch_factor\", None),\n                pin_memory=bool(dataloader_cfg.get(\"pin_memory\", True)),\n                persistent_workers=bool(\n                    dataloader_cfg.get(\"persistent_workers\", True)\n                ),\n                batch_progress_bar=bool(\n                    cache_cfg.get(\"batch_progress_bar\", False)\n                ),\n                sampler_rank=int(self.cfg.process_id),\n                sampler_world_size=int(self.cfg.num_processes),\n                allowed_prefixes=allowed_prefixes,\n                swap_interval_steps=int(\n                    cache_cfg.get(\"swap_interval_steps\", max_frame_length)\n                ),\n                seed=int(self.cfg.seed),\n                loader_timeout=float(dataloader_cfg.get(\"timeout\", 0.0)),\n            )\n            cache = self._motion_cache\n            logger.info(\n                \"DataLoader params: \"\n                f\"batch_size={cache._batch_size}, \"\n                f\"num_workers={cache._num_workers}, \"\n                f\"prefetch_factor={cache._prefetch_factor}, \"\n                f\"pin_memory={cache._pin_memory}, \"\n                f\"persistent_workers={cache._persistent_workers}\"\n            )\n            logger.info(\n                \"Sampler/Cache params: \"\n                f\"rank={cache._sampler_rank}/{cache._sampler_world_size}, \"\n                f\"device={cache._stage_device}, \"\n                f\"swap_interval_steps={cache.swap_interval_steps}\"\n            )\n            self._motion_lib = None\n\n        elif backend == \"hdf5_v2\":\n            max_frame_length = int(mcfg.get(\"max_frame_length\", 500))\n            min_frame_length = int(mcfg.get(\"min_frame_length\", 1))\n            world_frame_norm = bool(\n                mcfg.get(\"world_frame_normalization\", True)\n            )\n            cache_cfg = mcfg.get(\"cache\", {})\n            allowed_prefixes = cache_cfg.get(\n                \"allowed_prefixes\",\n                [\"ref_\", \"ft_ref_\"],\n            )\n\n            train_hdf5_roots = mcfg.get(\"train_hdf5_roots\", None)\n            if train_hdf5_roots:\n                train_roots = [str(r) for r in train_hdf5_roots]\n            else:\n                hdf5_root = mcfg.get(\"hdf5_root\", None)\n                train_roots = [str(hdf5_root)] if 
hdf5_root is not None else []\n            train_manifest_paths = [\n                os.path.join(root, \"manifest.json\") for root in train_roots\n            ]\n\n            (\n                train_dataset,\n                val_dataset,\n                cache_kwargs,\n            ) = build_motion_datasets_from_cfg(\n                motion_cfg=mcfg,\n                max_frame_length=max_frame_length,\n                min_window_length=min_frame_length,\n                world_frame_normalization=world_frame_norm,\n                handpicked_motion_names=mcfg.get(\n                    \"handpicked_motion_names\", None\n                ),\n                excluded_motion_names=mcfg.get(\"excluded_motion_names\", None),\n                allowed_prefixes=allowed_prefixes,\n            )\n            if len(train_dataset) == 0:\n                raise ValueError(\n                    \"Training dataset is empty. Check that all HDF5 v2 \"\n                    \"roots contain valid clips with length \"\n                    f\">= {min_frame_length}\"\n                )\n\n            if len(train_manifest_paths) == 1:\n                logger.info(\n                    f\"Loading HDF5 v2 training dataset from {train_manifest_paths[0]}\"\n                )\n            else:\n                logger.info(\n                    \"Loading HDF5 v2 training dataset from manifests: \"\n                    f\"{train_manifest_paths}\"\n                )\n            fps_used = int(self.cfg.target_fps)\n            logger.info(f\"Loaded {len(train_dataset)} training motion windows\")\n            train_num_clips = len(train_dataset.clips)\n            train_total_frames = sum(\n                int(meta.get(\"length\", 0))\n                for meta in train_dataset.clips.values()\n            )\n            train_duration_s = (\n                float(train_total_frames) / float(fps_used)\n                if fps_used > 0\n                else 0.0\n            )\n            logger.info(\n                f\"Train clips={train_num_clips}, frames={train_total_frames}, \"\n                f\"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps\"\n            )\n            if len(train_roots) == 1:\n                logger.info(\n                    f\"Train dataset: root={train_roots[0]}, \"\n                    f\"manifest={train_manifest_paths[0]}\"\n                )\n            elif len(train_roots) > 1:\n                logger.info(\n                    f\"Train dataset: roots={train_roots}, \"\n                    f\"manifests={train_manifest_paths}\"\n                )\n            excluded_names = mcfg.get(\"excluded_motion_names\", None)\n            if excluded_names:\n                excluded_set = set(excluded_names)\n                excluded_clip_keys: List[str] = []\n                if isinstance(train_dataset, Hdf5RootDofDataset):\n                    for key, meta in train_dataset.clips.items():\n                        aliases = train_dataset._build_motion_key_aliases(\n                            key, meta\n                        )\n                        if any(alias in excluded_set for alias in aliases):\n                            excluded_clip_keys.append(key)\n                else:\n                    excluded_clip_keys = [\n                        k\n                        for k in train_dataset.clips.keys()\n                        if k in excluded_set\n                    ]\n                excluded_num_clips = len(excluded_clip_keys)\n                excluded_total_frames = 
sum(\n                    int(train_dataset.clips[k].get(\"length\", 0))\n                    for k in excluded_clip_keys\n                )\n                excluded_duration_s = (\n                    float(excluded_total_frames) / float(fps_used)\n                    if fps_used > 0\n                    else 0.0\n                )\n                remaining_num_clips = train_num_clips - excluded_num_clips\n                remaining_total_frames = (\n                    train_total_frames - excluded_total_frames\n                )\n                remaining_duration_s = train_duration_s - excluded_duration_s\n                logger.info(\n                    \"Excluded (by name): \"\n                    f\"clips={excluded_num_clips}, frames={excluded_total_frames}, \"\n                    f\"duration={excluded_duration_s / 3600:.2f}h\"\n                )\n                logger.info(\n                    \"Remaining after exclusion: \"\n                    f\"clips={remaining_num_clips}, frames={remaining_total_frames}, \"\n                    f\"duration={remaining_duration_s / 3600:.2f}h\"\n                )\n            if val_dataset is None:\n                logger.info(\n                    \"Validation dataset: using training dataset \"\n                    \"(no separate val HDF5 v2 roots found)\"\n                )\n\n            dataloader_cfg = mcfg.get(\"dataloader\", {})\n            stage_device = self._resolve_motion_cache_stage_device(cache_cfg)\n\n            self._motion_cache = MotionClipBatchCache(\n                train_dataset=train_dataset,\n                val_dataset=val_dataset,\n                batch_size=int(cache_cfg.get(\"max_num_clips\", 1024)),\n                stage_device=stage_device,\n                num_workers=int(dataloader_cfg.get(\"num_workers\", 4)),\n                prefetch_factor=dataloader_cfg.get(\"prefetch_factor\", None),\n                pin_memory=bool(dataloader_cfg.get(\"pin_memory\", True)),\n                persistent_workers=bool(\n                    dataloader_cfg.get(\"persistent_workers\", True)\n                ),\n                batch_progress_bar=bool(\n                    cache_cfg.get(\"batch_progress_bar\", False)\n                ),\n                sampler_rank=int(self.cfg.process_id),\n                sampler_world_size=int(self.cfg.num_processes),\n                allowed_prefixes=allowed_prefixes,\n                swap_interval_steps=int(\n                    cache_cfg.get(\"swap_interval_steps\", max_frame_length)\n                ),\n                seed=int(self.cfg.seed),\n                loader_timeout=float(dataloader_cfg.get(\"timeout\", 0.0)),\n                **cache_kwargs,\n            )\n            cache = self._motion_cache\n            logger.info(\n                \"DataLoader params: \"\n                f\"batch_size={cache._batch_size}, \"\n                f\"num_workers={cache._num_workers}, \"\n                f\"prefetch_factor={cache._prefetch_factor}, \"\n                f\"pin_memory={cache._pin_memory}, \"\n                f\"persistent_workers={cache._persistent_workers}\"\n            )\n            logger.info(\n                \"Sampler/Cache params: \"\n                f\"rank={cache._sampler_rank}/{cache._sampler_world_size}, \"\n                f\"device={cache._stage_device}, \"\n                f\"swap_interval_steps={cache.swap_interval_steps}\"\n            )\n            self._motion_lib = None\n\n        else:\n            raise ValueError(f\"Unsupported motion backend: 
{backend}\")\n\n        sampling_strategy_cfg = mcfg.get(\"sampling_strategy\", None)\n        if sampling_strategy_cfg is None:\n            sampling_strategy = \"uniform\"\n        else:\n            sampling_strategy = str(sampling_strategy_cfg).lower()\n        if sampling_strategy == \"weighted_bin\":\n            weighted_bin_cfg = mcfg.get(\"weighted_bin\", {})\n            self._motion_cache.enable_weighted_bin_sampling(\n                cfg=dict(weighted_bin_cfg or {})\n            )\n        elif sampling_strategy == \"curriculum\":\n            curriculum_cfg = dict(mcfg.get(\"curriculum\", {}) or {})\n            self._motion_cache.enable_cache_curriculum_sampling(\n                cfg=curriculum_cfg\n            )\n        elif sampling_strategy not in (\"uniform\", \"curriculum\"):\n            raise ValueError(\n                f\"Invalid sampling_strategy '{sampling_strategy}'. \"\n                \"Expected one of ['curriculum', 'uniform', 'weighted_bin'].\"\n            )\n\n        self._sampling_strategy = sampling_strategy\n\n        self._init_per_env_cache()\n\n    def setup_dumping_dir(self, log_dir: str):\n        mcfg = self.mcfg\n        base_log_dir = str(log_dir)\n\n        if self._sampling_strategy == \"curriculum\":\n            curriculum_dump_dir = os.path.join(\n                base_log_dir, \"cache_curriculum_window_scores\"\n            )\n            self._motion_cache.set_cache_curriculum_dump_dir(\n                curriculum_dump_dir\n            )\n\n        self._dump_sampled_motion_keys_enabled = bool(\n            mcfg.get(\"dump_sampled_motion_keys\", False)\n        )\n        if not self._dump_sampled_motion_keys_enabled:\n            return\n        self._dump_sampled_motion_keys_interval = max(\n            1, int(mcfg.get(\"dump_sampled_motion_keys_interval\", 1))\n        )\n        dump_dir_cfg = \"sampled_motion_cache_keys\"\n        self._dump_sampled_motion_keys_dir = os.path.join(\n            base_log_dir, dump_dir_cfg\n        )\n        if self._dump_sampled_motion_keys_enabled:\n            os.makedirs(self._dump_sampled_motion_keys_dir, exist_ok=True)\n            logger.info(\n                f\"Dumping sampled motion keys to {self._dump_sampled_motion_keys_dir}\"\n            )\n\n    def set_runtime_distributed_context(\n        self, *, process_id: int, num_processes: int\n    ) -> None:\n        self._runtime_process_id = int(process_id)\n        self._runtime_num_processes = max(1, int(num_processes))\n\n    def set_motion_cache_seed(\n        self, seed: int, *, reinitialize: bool = True\n    ) -> None:\n        self._motion_cache.set_seed(int(seed), reinitialize=reinitialize)\n        if reinitialize:\n            self._init_per_env_cache()\n\n    def close(self) -> None:\n        \"\"\"Release motion cache resources for this command term.\"\"\"\n        if self._motion_cache is not None:\n            self._motion_cache.close()\n            self._motion_cache = None\n\n    def _init_per_env_cache(self):\n        \"\"\"Initialize per-env cache for motion tracking.\"\"\"\n        self._clip_indices = torch.zeros(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        self._frame_indices = torch.zeros(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        self._swap_pending = False\n        self._swap_step_counter = 0\n\n        # Initial assignment\n        clip_idx, frame_idx = self._motion_cache.sample_env_assignments(\n            self.num_envs,\n            
self.cfg.n_fut_frames,\n            self.device,\n            deterministic_start=(self._is_evaluating),\n        )\n        self._clip_indices[:] = clip_idx\n        self._frame_indices[:] = frame_idx\n        self._start_frame_indices[:] = frame_idx\n        self._reward_sum_since_assign[:] = 0.0\n        self._step_count_since_assign[:] = 0.0\n        self._update_ref_motion_state_from_cache()\n\n    def _maybe_dump_sampled_motion_keys(self) -> None:\n        if not self._dump_sampled_motion_keys_enabled:\n            return\n\n        swap_index = int(self._motion_cache.swap_index)\n        if swap_index <= 0:\n            return\n        if swap_index % self._dump_sampled_motion_keys_interval != 0:\n            return\n\n        current_batch = self._motion_cache.current_batch\n        window_indices = current_batch.window_indices.detach().cpu().tolist()\n        cache_scores = None\n        cache_selection_counts = None\n        cache_in_prioritized_pool = None\n        curriculum_state_step = None\n        score_bundle = (\n            self._motion_cache.cache_curriculum_scores_for_window_indices(\n                current_batch.window_indices\n            )\n        )\n        if score_bundle is not None:\n            score_tensor, state, version = score_bundle\n            cache_scores = score_tensor.detach().cpu().tolist()\n            cache_selection_counts = (\n                state[\"selection_count\"].detach().cpu().tolist()\n            )\n            cache_in_prioritized_pool = (\n                state[\"in_prioritized_pool\"].detach().cpu().tolist()\n            )\n            curriculum_state_step = int(version)\n        payload = {\n            \"swap_index\": swap_index,\n            \"sampling_strategy\": str(self._sampling_strategy),\n            \"num_keys\": int(len(current_batch.motion_keys)),\n            \"motion_keys\": list(current_batch.motion_keys),\n            \"raw_motion_keys\": list(current_batch.raw_motion_keys),\n            \"window_indices\": window_indices,\n            \"cache_sampling_score\": cache_scores,\n            \"cache_sampling_count\": cache_selection_counts,\n            \"cache_in_prioritized_pool\": cache_in_prioritized_pool,\n            \"curriculum_state_step\": curriculum_state_step,\n        }\n        file_name = (\n            f\"sampled_motion_keys_rank_{self._runtime_process_id:04d}_swap_\"\n            f\"{swap_index:06d}.json\"\n        )\n        output_path = os.path.join(\n            self._dump_sampled_motion_keys_dir, file_name\n        )\n        with open(output_path, \"w\", encoding=\"utf-8\") as handle:\n            json.dump(payload, handle, indent=2)\n            handle.write(\"\\n\")\n\n    def _init_robot_handle(self):\n        self.robot: Articulation = self._env.scene[self.cfg.asset_name]\n        self.anchor_bodylink_name = self.cfg.anchor_bodylink_name\n        self.anchor_bodylink_idx = self.robot.body_names.index(\n            self.anchor_bodylink_name\n        )\n        self.urdf_dof_names = self.cfg.urdf_dof_names\n        self.urdf_body_names = self.cfg.urdf_body_names\n        self.simulator_dof_names = self.robot.joint_names\n        self.simulator_body_names = self.robot.body_names\n        self.urdf2sim_dof_idx = [\n            self.urdf_dof_names.index(dof) for dof in self.simulator_dof_names\n        ]\n        self.urdf2sim_body_idx = [\n            self.urdf_body_names.index(body)\n            for body in self.simulator_body_names\n        ]\n        self.sim2urdf_dof_idx = [\n            
self.simulator_dof_names.index(dof) for dof in self.urdf_dof_names\n        ]\n        self.sim2urdf_body_idx = [\n            self.simulator_body_names.index(body)\n            for body in self.urdf_body_names\n        ]\n\n        self.arm_dof_indices = [\n            self.simulator_dof_names.index(dof)\n            for dof in self.cfg.arm_dof_names\n        ]\n        self.torso_dof_indices = [\n            self.simulator_dof_names.index(dof)\n            for dof in self.cfg.waist_dof_names\n        ]\n        self.leg_dof_indices = [\n            self.simulator_dof_names.index(dof)\n            for dof in self.cfg.leg_dof_names\n        ]\n\n        # Body indices for mpkpe metrics using unified naming\n        self.arm_body_indices = [\n            self.simulator_body_names.index(body)\n            for body in self.cfg.arm_body_names\n        ]\n        self.torso_body_indices = [\n            self.simulator_body_names.index(body)\n            for body in self.cfg.torso_body_names\n        ]\n        self.leg_body_indices = [\n            self.simulator_body_names.index(body)\n            for body in self.cfg.leg_body_names\n        ]\n\n        # Per-env world origins (translation only)\n        # Shape: [num_envs, 3] on the same device as the sim\n        self._env_origins = self._env.scene.env_origins.to(self.device)\n\n        # AMP-style observation indices (RSL reference alignment)\n        urdf_dof_name_to_idx = {\n            name: idx for idx, name in enumerate(self.urdf_dof_names)\n        }\n        sim_dof_name_to_idx = {\n            name: idx for idx, name in enumerate(self.simulator_dof_names)\n        }\n        urdf_body_name_to_idx = {\n            name: idx for idx, name in enumerate(self.urdf_body_names)\n        }\n        sim_body_name_to_idx = {\n            name: idx for idx, name in enumerate(self.simulator_body_names)\n        }\n\n        left_arm_dof_names = list(\n            getattr(self.cfg, \"left_arm_dof_names\", []) or []\n        )\n        right_arm_dof_names = list(\n            getattr(self.cfg, \"right_arm_dof_names\", []) or []\n        )\n        left_leg_dof_names = list(\n            getattr(self.cfg, \"left_leg_dof_names\", []) or []\n        )\n        right_leg_dof_names = list(\n            getattr(self.cfg, \"right_leg_dof_names\", []) or []\n        )\n        if not left_arm_dof_names:\n            left_arm_dof_names = self._amp_filter_names_by_prefix(\n                self.urdf_dof_names,\n                \"left_\",\n                (\"shoulder\", \"elbow\", \"wrist\"),\n            )\n        if not right_arm_dof_names:\n            right_arm_dof_names = self._amp_filter_names_by_prefix(\n                self.urdf_dof_names,\n                \"right_\",\n                (\"shoulder\", \"elbow\", \"wrist\"),\n            )\n        if not left_leg_dof_names:\n            left_leg_dof_names = self._amp_filter_names_by_prefix(\n                self.urdf_dof_names, \"left_\", (\"hip\", \"knee\", \"ankle\")\n            )\n        if not right_leg_dof_names:\n            right_leg_dof_names = self._amp_filter_names_by_prefix(\n                self.urdf_dof_names, \"right_\", (\"hip\", \"knee\", \"ankle\")\n            )\n\n        self._amp_left_arm_urdf_dof_idx = [\n            urdf_dof_name_to_idx[name] for name in left_arm_dof_names\n        ]\n        self._amp_right_arm_urdf_dof_idx = [\n            urdf_dof_name_to_idx[name] for name in right_arm_dof_names\n        ]\n        self._amp_left_leg_urdf_dof_idx = [\n            
urdf_dof_name_to_idx[name] for name in left_leg_dof_names\n        ]\n        self._amp_right_leg_urdf_dof_idx = [\n            urdf_dof_name_to_idx[name] for name in right_leg_dof_names\n        ]\n        self._amp_left_arm_sim_dof_idx = [\n            sim_dof_name_to_idx[name] for name in left_arm_dof_names\n        ]\n        self._amp_right_arm_sim_dof_idx = [\n            sim_dof_name_to_idx[name] for name in right_arm_dof_names\n        ]\n        self._amp_left_leg_sim_dof_idx = [\n            sim_dof_name_to_idx[name] for name in left_leg_dof_names\n        ]\n        self._amp_right_leg_sim_dof_idx = [\n            sim_dof_name_to_idx[name] for name in right_leg_dof_names\n        ]\n\n        left_arm_body_names = list(\n            getattr(self.cfg, \"left_arm_body_names\", []) or []\n        )\n        right_arm_body_names = list(\n            getattr(self.cfg, \"right_arm_body_names\", []) or []\n        )\n        left_leg_body_names = list(\n            getattr(self.cfg, \"left_leg_body_names\", []) or []\n        )\n        right_leg_body_names = list(\n            getattr(self.cfg, \"right_leg_body_names\", []) or []\n        )\n        if not left_arm_body_names:\n            left_arm_body_names = self._amp_filter_names_by_prefix(\n                self.urdf_body_names, \"left_\", (\"shoulder\", \"elbow\", \"wrist\")\n            )\n        if not right_arm_body_names:\n            right_arm_body_names = self._amp_filter_names_by_prefix(\n                self.urdf_body_names, \"right_\", (\"shoulder\", \"elbow\", \"wrist\")\n            )\n        if not left_leg_body_names:\n            left_leg_body_names = self._amp_filter_names_by_prefix(\n                self.urdf_body_names, \"left_\", (\"hip\", \"knee\", \"ankle\")\n            )\n        if not right_leg_body_names:\n            right_leg_body_names = self._amp_filter_names_by_prefix(\n                self.urdf_body_names, \"right_\", (\"hip\", \"knee\", \"ankle\")\n            )\n\n        left_elbow_name = self._amp_pick_first_name(\n            left_arm_body_names, (\"left_elbow\", \"elbow\")\n        )\n        right_elbow_name = self._amp_pick_first_name(\n            right_arm_body_names, (\"right_elbow\", \"elbow\")\n        )\n        left_foot_name = self._amp_pick_first_name(\n            left_leg_body_names,\n            (\"left_ankle_roll\", \"left_ankle_pitch\", \"left_ankle\"),\n        )\n        right_foot_name = self._amp_pick_first_name(\n            right_leg_body_names,\n            (\"right_ankle_roll\", \"right_ankle_pitch\", \"right_ankle\"),\n        )\n\n        self._amp_left_elbow_urdf_body_idx = (\n            urdf_body_name_to_idx[left_elbow_name]\n            if left_elbow_name is not None\n            else None\n        )\n        self._amp_right_elbow_urdf_body_idx = (\n            urdf_body_name_to_idx[right_elbow_name]\n            if right_elbow_name is not None\n            else None\n        )\n        self._amp_left_foot_urdf_body_idx = (\n            urdf_body_name_to_idx[left_foot_name]\n            if left_foot_name is not None\n            else None\n        )\n        self._amp_right_foot_urdf_body_idx = (\n            urdf_body_name_to_idx[right_foot_name]\n            if right_foot_name is not None\n            else None\n        )\n        self._amp_left_elbow_sim_body_idx = (\n            sim_body_name_to_idx[left_elbow_name]\n            if left_elbow_name is not None\n            else None\n        )\n        self._amp_right_elbow_sim_body_idx = (\n            
sim_body_name_to_idx[right_elbow_name]\n            if right_elbow_name is not None\n            else None\n        )\n        self._amp_left_foot_sim_body_idx = (\n            sim_body_name_to_idx[left_foot_name]\n            if left_foot_name is not None\n            else None\n        )\n        self._amp_right_foot_sim_body_idx = (\n            sim_body_name_to_idx[right_foot_name]\n            if right_foot_name is not None\n            else None\n        )\n\n        self._amp_left_hand_local_vec = torch.tensor(\n            [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32\n        )\n        self._amp_right_hand_local_vec = torch.tensor(\n            [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32\n        )\n\n    def _init_buffers(self):\n        self.metrics = {}\n        self.ref_motion_global_frame_ids = torch.zeros(\n            self.num_envs,\n            dtype=torch.long,\n            device=self.device,\n        )\n        # mark envs that timed out (frame id exceeded end frame) in current step\n        self._motion_end_mask = torch.zeros(\n            self.num_envs,\n            dtype=torch.bool,\n            device=self.device,\n        )\n        # counter for number of motion ends per environment\n        self.motion_end_counter = torch.zeros(\n            self.num_envs,\n            dtype=torch.long,\n            device=self.device,\n        )\n        # per-environment cached motion indices\n        self._cached_motion_ids = torch.zeros(\n            self.num_envs,\n            dtype=torch.long,\n            device=self.device,\n        )\n        # env -> cache row indirection (starts as identity mapping)\n        self._env_to_cache_row = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        self._start_frame_indices = torch.zeros(\n            self.num_envs,\n            dtype=torch.long,\n            device=self.device,\n        )\n        self._reward_sum_since_assign = torch.zeros(\n            self.num_envs,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self._mpjpe_sum_since_assign = torch.zeros(\n            self.num_envs,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self._mpkpe_sum_since_assign = torch.zeros(\n            self.num_envs,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self._step_count_since_assign = torch.zeros(\n            self.num_envs,\n            dtype=torch.float32,\n            device=self.device,\n        )\n        self._completion_rate_sum_by_window: Dict[int, float] = {}\n        self._completion_rate_count_by_window: Dict[int, int] = {}\n        self._mpkpe_signal_sum_by_window: Dict[int, float] = {}\n        self._mpkpe_signal_count_by_window: Dict[int, int] = {}\n\n        self.pos_history_buffer = None\n        self.rot_history_buffer = None\n        self.ref_pos_history_buffer = None\n        self.current_accel = None\n        self.ref_body_accel = None\n        self.current_ang_accel = None  # Placeholder for angular acceleration\n\n        self.metrics[\"Task/MPJPE_WholeBody\"] = torch.zeros(\n            self.num_envs, device=self.device\n        )\n        self.metrics[\"Task/MPKPE_WholeBody\"] = torch.zeros(\n            self.num_envs, device=self.device\n        )\n\n    def _record_completion_rate_for_envs(self, env_ids: torch.Tensor) -> None:\n        if env_ids.numel() == 0:\n            return\n\n        selected_clip_indices = 
self._clip_indices[env_ids]\n        lengths = self._motion_cache.lengths_for_indices(selected_clip_indices)\n        window_indices = self._motion_cache.window_indices_for_indices(\n            selected_clip_indices\n        )\n        available_steps = torch.clamp(\n            lengths\n            - int(self.cfg.n_fut_frames)\n            - self._start_frame_indices[env_ids],\n            min=1,\n        )\n        completion_rate = torch.clamp(\n            self._step_count_since_assign[env_ids] / available_steps.float(),\n            min=0.0,\n            max=1.0,\n        )\n        step_den = torch.clamp(self._step_count_since_assign[env_ids], min=1.0)\n        mpkpe_mean = self._mpkpe_sum_since_assign[env_ids] / step_den\n        completion_values = completion_rate.detach().cpu().tolist()\n        mpkpe_values = mpkpe_mean.detach().cpu().tolist()\n        window_values = window_indices.detach().cpu().tolist()\n        for idx, window_index_obj in enumerate(window_values):\n            completion_value = float(completion_values[idx])\n            mpkpe_value = float(mpkpe_values[idx])\n            mpkpe_signal = -mpkpe_value\n            window_index = int(window_index_obj)\n\n            if window_index in self._completion_rate_sum_by_window:\n                self._completion_rate_sum_by_window[window_index] += (\n                    completion_value\n                )\n                self._completion_rate_count_by_window[window_index] += 1\n            else:\n                self._completion_rate_sum_by_window[window_index] = (\n                    completion_value\n                )\n                self._completion_rate_count_by_window[window_index] = 1\n\n            if window_index in self._mpkpe_signal_sum_by_window:\n                self._mpkpe_signal_sum_by_window[window_index] += mpkpe_signal\n                self._mpkpe_signal_count_by_window[window_index] += 1\n            else:\n                self._mpkpe_signal_sum_by_window[window_index] = mpkpe_signal\n                self._mpkpe_signal_count_by_window[window_index] = 1\n\n        self._reward_sum_since_assign[env_ids] = 0.0\n        self._mpjpe_sum_since_assign[env_ids] = 0.0\n        self._mpkpe_sum_since_assign[env_ids] = 0.0\n        self._step_count_since_assign[env_ids] = 0.0\n\n    def _reset_window_curriculum_stats(self) -> None:\n        self._completion_rate_sum_by_window = {}\n        self._completion_rate_count_by_window = {}\n        self._mpkpe_signal_sum_by_window = {}\n        self._mpkpe_signal_count_by_window = {}\n\n    def _build_window_curriculum_stats_from_current_batch(\n        self,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        batch_window_indices = self._motion_cache.current_batch.window_indices\n        row_window_indices = batch_window_indices.detach().to(\n            self.device, dtype=torch.long\n        )\n        count = int(row_window_indices.numel())\n        row_mpkpe_signal = torch.zeros(\n            count, dtype=torch.float32, device=self.device\n        )\n        row_completion_rate = torch.zeros(\n            count, dtype=torch.float32, device=self.device\n        )\n        row_count = torch.zeros(count, dtype=torch.float32, device=self.device)\n\n        window_values = row_window_indices.detach().cpu().tolist()\n        for row_idx, window_index_obj in enumerate(window_values):\n            window_index = int(window_index_obj)\n            completion_count = int(\n                self._completion_rate_count_by_window.get(window_index, 
0)\n            )\n            mpkpe_count = int(\n                self._mpkpe_signal_count_by_window.get(window_index, 0)\n            )\n            if completion_count > 0:\n                row_completion_rate[row_idx] = float(\n                    self._completion_rate_sum_by_window[window_index]\n                ) / float(completion_count)\n            if mpkpe_count > 0:\n                row_mpkpe_signal[row_idx] = float(\n                    self._mpkpe_signal_sum_by_window[window_index]\n                ) / float(mpkpe_count)\n            row_count[row_idx] = float(max(completion_count, mpkpe_count))\n\n        return (\n            row_window_indices,\n            row_mpkpe_signal,\n            row_completion_rate,\n            row_count,\n        )\n\n    def _update_cache_curriculum_state(\n        self,\n        *,\n        accelerator,\n        swap_index: int,\n    ) -> None:\n        if self._sampling_strategy != \"curriculum\":\n            self._reset_window_curriculum_stats()\n            return\n\n        (\n            row_window_indices,\n            row_mpkpe_signal,\n            row_completion_rate,\n            row_count,\n        ) = self._build_window_curriculum_stats_from_current_batch()\n\n        if accelerator is not None and int(accelerator.num_processes) > 1:\n            gather_window_indices = accelerator.gather(row_window_indices)\n            gather_mpkpe_signal = accelerator.gather(row_mpkpe_signal)\n            gather_completion_rate = accelerator.gather(row_completion_rate)\n            gather_count = accelerator.gather(row_count)\n        else:\n            gather_window_indices = row_window_indices\n            gather_mpkpe_signal = row_mpkpe_signal\n            gather_completion_rate = row_completion_rate\n            gather_count = row_count\n\n        self._motion_cache.update_cache_curriculum(\n            window_indices=gather_window_indices,\n            mpkpe_signal_means=gather_mpkpe_signal,\n            completion_rate_means=gather_completion_rate,\n            counts=gather_count,\n            swap_index=int(swap_index),\n        )\n        self._reset_window_curriculum_stats()\n\n    def update_curriculum_reward_accumulators(\n        self, rewards: torch.Tensor\n    ) -> None:\n        reward_flat = rewards.view(-1).to(self.device, dtype=torch.float32)\n        all_ids = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        motion_ids = self._filter_env_ids_for_motion_task(all_ids)\n        if motion_ids.numel() == 0:\n            return\n        self._reward_sum_since_assign[motion_ids] += reward_flat[motion_ids]\n        mpjpe = self.metrics[\"Task/MPJPE_WholeBody\"]\n        mpkpe = self.metrics[\"Task/MPKPE_WholeBody\"]\n        self._mpjpe_sum_since_assign[motion_ids] += mpjpe[motion_ids].to(\n            dtype=torch.float32\n        )\n        self._mpkpe_sum_since_assign[motion_ids] += mpkpe[motion_ids].to(\n            dtype=torch.float32\n        )\n        self._step_count_since_assign[motion_ids] += 1.0\n\n    @property\n    def command(\n        self,\n    ) -> torch.Tensor:\n        # call the corresponding method based on configured command_obs_name\n        return getattr(self, f\"_get_obs_{self.cfg.command_obs_name}\")()\n\n    @property\n    def command_fut(\n        self,\n    ) -> torch.Tensor:\n        # call the corresponding method based on configured command_obs_name\n        return getattr(self, f\"_get_obs_{self.cfg.command_obs_name}_fut\")()\n\n    def reset(\n        
self,\n        env_ids: Sequence[int] | None = None,\n    ) -> dict[str, float]:\n        extras = super().reset(env_ids)\n\n        if env_ids is None:\n            # A slice would break the tensor conversion and .view(-1) below,\n            # so expand \"all envs\" to an explicit index tensor.\n            env_ids = torch.arange(\n                self.num_envs, device=self.device, dtype=torch.long\n            )\n        elif not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(self.device)\n        self._motion_end_mask[env_ids] = False\n        self.motion_end_counter[env_ids] = 0\n\n        # Do not apply cache swap inside per-env reset; defer to PPO barrier.\n        # Always resample only the requested envs here.\n        motion_ids = self._filter_env_ids_for_motion_task(env_ids.view(-1))\n        self._resample_command(motion_ids, eval=self._is_evaluating)\n\n        return extras\n\n    def apply_cache_swap_if_pending_barrier(self, accelerator=None) -> bool:\n        \"\"\"Apply a pending cache swap at a rollout barrier.\n\n        Returns:\n            bool: True if a swap was applied, otherwise False.\n        \"\"\"\n        if not getattr(self, \"_swap_pending\", False):\n            return False\n\n        all_ids = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        motion_ids = self._filter_env_ids_for_motion_task(all_ids)\n        if motion_ids.numel() == 0:\n            # No motion envs active under multi-task: keep ref motion inert.\n            self._swap_pending = False\n            self._swap_step_counter = 0\n            return False\n\n        self._record_completion_rate_for_envs(motion_ids)\n        next_swap_index = int(self._motion_cache.swap_index) + 1\n        self._update_cache_curriculum_state(\n            accelerator=accelerator,\n            swap_index=next_swap_index,\n        )\n\n        # Advance cache and reset counters\n        self._motion_cache.advance()\n        self._maybe_dump_sampled_motion_keys()\n        self._swap_pending = False\n        self._swap_step_counter = 0\n\n        # Reassign motion envs to the new cache batch\n        clip_idx, frame_idx = self._motion_cache.sample_env_assignments(\n            int(motion_ids.numel()),\n            self.cfg.n_fut_frames,\n            self.device,\n            deterministic_start=(self._is_evaluating),\n        )\n        self._clip_indices[motion_ids] = clip_idx\n        self._frame_indices[motion_ids] = frame_idx\n        self._start_frame_indices[motion_ids] = frame_idx\n        self._reward_sum_since_assign[motion_ids] = 0.0\n        self._step_count_since_assign[motion_ids] = 0.0\n        self._update_ref_motion_state_from_cache(env_ids=motion_ids)\n\n        # Realign robot states to the new reference\n        self._align_root_to_ref(motion_ids)\n        self._align_dof_to_ref(motion_ids)\n\n        # Reset per-episode timeout bookkeeping for consistency\n        self._motion_end_mask[motion_ids] = False\n        self.motion_end_counter[motion_ids] = 0\n        return True\n\n    def compute(self, dt: float):\n        all_ids = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        motion_ids = self._filter_env_ids_for_motion_task(all_ids)\n        if motion_ids.numel() == 0:\n            return\n        self._update_metrics()\n        self._update_command()\n\n    def _update_ref_motion_state(self):\n        \"\"\"Update reference motion state (unified API).\"\"\"\n        return self._update_ref_motion_state_from_cache()\n\n    def _update_ref_motion_state_from_cache(\n        
self, env_ids: torch.Tensor | None = None\n    ):\n        \"\"\"Compatibility no-op for cache-backed reference access.\"\"\"\n        del env_ids\n        return None\n\n    def _get_ref_state_array(\n        self,\n        base_key: str,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Gather a reference tensor from the current cache batch.\n\n        Args:\n            base_key: Base key in the motion cache (e.g. \\\"dof_pos\\\", \\\"root_pos\\\").\n            prefix: Optional logical prefix (e.g. \\\"\\\", \\\"ref_\\\", \\\"ft_ref_\\\", \\\"robot_\\\").\n\n        Returns:\n            Tensor of shape ``[num_envs, 1 + n_fut_frames, ...]`` gathered for\n            the envs' current clip/frame assignments.\n        \"\"\"\n        batch_tensors = self._motion_cache.current_batch.tensors\n        tensor_key = resolve_reference_tensor_key(\n            batch_tensors=batch_tensors,\n            base_key=base_key,\n            prefix=prefix,\n        )\n        return self._motion_cache.gather_tensor(\n            tensor_key,\n            clip_indices=self._clip_indices,\n            frame_indices=self._frame_indices,\n            n_future_frames=self.cfg.n_fut_frames,\n        )\n\n    def get_ref_motion_filter_cutoff_hz_cur(self) -> torch.Tensor:\n        try:\n            base = self._get_ref_state_array(\"filter_cutoff_hz\", prefix=\"\")\n        except KeyError:\n            # Older/local datasets may not carry per-clip filter metadata.\n            # Keep the observation available with a neutral default instead of\n            # failing during env construction.\n            return torch.zeros(\n                self.num_envs, 1, device=self.device, dtype=torch.float32\n            )\n        return base[:, 0, ...]\n\n    def _uniform_sample_ref_start_frames(self, env_ids: torch.Tensor):\n        \"\"\"Uniformly sample start frames within cached windows for env_ids.\n\n        Sampling range is [start, end - 1 - n_fut_frames] to ensure required\n        future frames exist. 
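For example, with start=0, end=100 and\n        n_fut_frames=5, start frames are drawn from [0, 94]. 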
If that upper bound is < start, it falls back to start.\n        \"\"\"\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(self.device).long()\n\n        starts = self.ref_motion_global_start_frame_ids[env_ids]\n        ends = self.ref_motion_global_end_frame_ids[env_ids]\n\n        # Ensure room for future frames if requested\n        n_fut = (\n            int(self.cfg.n_fut_frames)\n            if hasattr(self.cfg, \"n_fut_frames\")\n            else 0\n        )\n        max_start = ends - 1 - n_fut\n        max_start = torch.maximum(max_start, starts)\n\n        num_choices = (max_start - starts + 1).clamp(min=1)\n        # Sample offsets uniformly\n        rand = torch.rand_like(starts, dtype=torch.float32)\n        offsets = torch.floor(rand * num_choices.float()).long()\n        sampled = starts + offsets\n\n        self.ref_motion_global_frame_ids[env_ids] = sampled\n\n    def get_ref_motion_dof_pos_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_pos\", prefix)\n        return base[:, 1:, ...][..., self.urdf2sim_dof_idx]\n\n    def _get_immediate_next_ref_state_array(\n        self,\n        base_key: str,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(base_key, prefix)\n        if base.shape[1] < 2:\n            raise ValueError(\n                f\"Immediate-next reference for '{base_key}' requires at least one future frame.\"\n            )\n        return base[:, 1, ...]\n\n    def get_ref_motion_dof_vel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_vel\", prefix)\n        return base[:, 1:, ...][..., self.urdf2sim_dof_idx]\n\n    def get_ref_motion_root_global_pos_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"root_pos\", prefix)\n        return base[:, 1:, ...] 
+ self._env_origins[:, None, :]\n\n    def get_ref_motion_root_global_rot_quat_xyzw_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self._get_ref_state_array(\"root_rot\", prefix)[:, 1:, ...]\n\n    def get_ref_motion_root_global_rot_quat_wxyz_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self.get_ref_motion_root_global_rot_quat_xyzw_fut(\n            prefix=prefix\n        )[..., [3, 0, 1, 2]]\n\n    def get_ref_motion_root_global_lin_vel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"root_vel\", prefix)\n        return base[:, 1:, ...]\n\n    def get_ref_motion_root_global_ang_vel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"root_ang_vel\", prefix)\n        return base[:, 1:, ...]\n\n    def get_ref_motion_bodylink_global_pos_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rg_pos\", prefix)\n        return (\n            base[:, 1:, ...][..., self.urdf2sim_body_idx, :]\n            + self._env_origins[:, None, None, :]\n        )\n\n    def get_ref_motion_bodylink_rel_pos_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        ref_body_global_pos = self.get_ref_motion_bodylink_global_pos_cur(\n            prefix=prefix\n        )  # [B, N, 3]\n        ref_root_global_pos = self.get_ref_motion_root_global_pos_cur(\n            prefix=prefix\n        )  # [B, 3]\n        ref_root_global_rot_wxyz = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_cur(prefix=prefix)\n        )  # [B, 4]\n        rel_pos_w = (\n            ref_body_global_pos - ref_root_global_pos[:, None, :]\n        )  # [B, N, 3]\n        num_bodies = rel_pos_w.shape[1]\n        expanded_ref_root_global_rot_wxyz = ref_root_global_rot_wxyz[\n            :, None, :\n        ].expand(-1, num_bodies, -1)\n        return isaaclab_math.quat_apply_inverse(\n            expanded_ref_root_global_rot_wxyz, rel_pos_w\n        )  # [B, N, 3]\n\n    def get_ref_motion_bodylink_rel_pos_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        ref_body_global_pos_fut = self.get_ref_motion_bodylink_global_pos_fut(\n            prefix=prefix\n        )  # [B, T, N, 3]\n        ref_root_global_pos_fut = self.get_ref_motion_root_global_pos_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        ref_root_global_rot_wxyz_fut = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)\n        )  # [B, T, 4]\n        rel_pos_w_fut = (\n            ref_body_global_pos_fut - ref_root_global_pos_fut[:, :, None, :]\n        )  # [B, T, N, 3]\n        num_bodies = rel_pos_w_fut.shape[2]\n        expanded_ref_root_global_rot_wxyz_fut = ref_root_global_rot_wxyz_fut[\n            :, :, None, :\n        ].expand(-1, -1, num_bodies, -1)\n        return isaaclab_math.quat_apply_inverse(\n            expanded_ref_root_global_rot_wxyz_fut, rel_pos_w_fut\n        )  # [B, T, N, 3]\n\n    def get_ref_motion_bodylink_global_rot_xyzw_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rb_rot\", prefix)\n        return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_lin_vel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> 
torch.Tensor:\n        base = self._get_ref_state_array(\"body_vel\", prefix)\n        return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_ang_vel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"body_ang_vel\", prefix)\n        return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_dof_pos_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_pos\", prefix)\n        return base[:, 0, ...][..., self.urdf2sim_dof_idx]\n\n    def get_ref_motion_dof_pos_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"dof_pos\", prefix)\n        return base[..., self.urdf2sim_dof_idx]\n\n    def get_immediate_next_two_dof_pos(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current and immediate-next DoF positions in simulator DoF order.\n\n        Returns frames [0, 1] of the cached window, i.e. the current frame\n        and the first future frame, hence n_fut_frames must be >= 1.\n        \"\"\"\n        n_fut = int(self.cfg.n_fut_frames)\n        if n_fut < 1:\n            raise ValueError(\n                \"n_fut_frames must be at least 1 for immediate next two DoF positions.\"\n            )\n        base = self._get_ref_state_array(\"dof_pos\", prefix)\n        return base[:, :2, ...][..., self.urdf2sim_dof_idx]\n\n    def get_ref_motion_dof_pos_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_pos\", prefix)\n        return base[:, 0, ...]\n\n    def get_ref_motion_cur_heading_aligned_root_pos(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        # prepare current frame robot root global poses\n        robot_cur_global_root_pos = self.robot.data.root_pos_w\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # wxyz\n        yaw_quat = isaaclab_math.yaw_quat(robot_cur_global_root_rot)\n\n        # transform the current reference root position into the robot's\n        # heading-aligned (yaw-only) frame\n        global_pos_diff = (\n            self.get_ref_motion_root_global_pos_cur(prefix=prefix)\n            - robot_cur_global_root_pos\n        )\n        global_pos_diff_heading_aligned = isaaclab_math.quat_apply_inverse(\n            yaw_quat, global_pos_diff\n        )\n        return global_pos_diff_heading_aligned\n\n    def get_ref_motion_fut_heading_aligned_root_pos(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        # prepare current frame robot root global poses\n        robot_cur_global_root_pos = self.robot.data.root_pos_w  # [B, 3]\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4]\n        yaw_quat = isaaclab_math.yaw_quat(robot_cur_global_root_rot)  # [B, 4]\n\n        # transform the future reference root positions into the robot's\n        # heading-aligned (yaw-only) frame\n        fut_root_global_pos = self.get_ref_motion_root_global_pos_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        num_fut_frames = fut_root_global_pos.shape[1]\n        global_pos_diff = (\n            fut_root_global_pos - robot_cur_global_root_pos[:, None, :]\n        )  # [B, T, 3]\n        expanded_yaw_quat = yaw_quat[:, None, :].expand(\n            -1, num_fut_frames, -1\n        )  # [B, T, 4]\n        fut_root_global_pos_heading_aligned = isaaclab_math.quat_apply_inverse(\n            expanded_yaw_quat, global_pos_diff\n        )  # [B, T, 3]\n        return 
fut_root_global_pos_heading_aligned\n\n    def get_ref_motion_cur_heading_aligned_root_rot6d(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference root rotation (rot6d) in heading-aligned frame.\n\n        Returns:\n            torch.Tensor: [B, 6]\n        \"\"\"\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        heading_quat_inv_wxyz = isaaclab_math.quat_inv(\n            heading_quat_wxyz\n        )  # [B, 4] wxyz\n\n        ref_root_quat_wxyz = self.get_ref_motion_root_global_rot_quat_wxyz_cur(\n            prefix=prefix\n        )  # [B, 4] wxyz\n        ref_root_quat_in_heading_wxyz = isaaclab_math.quat_mul(\n            heading_quat_inv_wxyz, ref_root_quat_wxyz\n        )  # [B, 4] wxyz\n\n        # rot6d: first two columns of rotation matrix (flattened)\n        ref_root_rot6d = isaaclab_math.matrix_from_quat(\n            ref_root_quat_in_heading_wxyz\n        )[..., :2].reshape(ref_root_quat_wxyz.shape[0], 6)  # [B, 6]\n        return ref_root_rot6d\n\n    def get_ref_motion_fut_heading_aligned_root_rot6d(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference root rotations (rot6d) in heading-aligned frame.\n\n        Returns:\n            torch.Tensor: [B, T, 6]\n        \"\"\"\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        heading_quat_inv_wxyz = isaaclab_math.quat_inv(\n            heading_quat_wxyz\n        )  # [B, 4] wxyz\n\n        ref_root_quat_wxyz_fut = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)\n        )  # [B, T, 4] wxyz\n        num_envs, num_fut_frames, _ = ref_root_quat_wxyz_fut.shape\n\n        heading_quat_inv_wxyz_fut = heading_quat_inv_wxyz[:, None, :].expand(\n            -1, num_fut_frames, -1\n        )  # [B, T, 4]\n        ref_root_quat_in_heading_wxyz_fut = isaaclab_math.quat_mul(\n            heading_quat_inv_wxyz_fut, ref_root_quat_wxyz_fut\n        )  # [B, T, 4] wxyz\n\n        ref_root_rot6d_fut = isaaclab_math.matrix_from_quat(\n            ref_root_quat_in_heading_wxyz_fut\n        )[..., :2].reshape(num_envs, num_fut_frames, 6)  # [B, T, 6]\n\n        return ref_root_rot6d_fut\n\n    def get_ref_motion_cur_heading_aligned_root_lin_vel(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference root linear velocity in heading-aligned frame.\n        Returns: [B, 3]\n        \"\"\"\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        ref_root_lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(\n            prefix=prefix\n        )  # [B, 3]\n        ref_root_lin_vel_heading = isaaclab_math.quat_apply_inverse(\n            heading_quat_wxyz, ref_root_lin_vel_w\n        )  # [B, 3]\n        return ref_root_lin_vel_heading\n\n    def get_ref_motion_fut_heading_aligned_root_lin_vel(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference root linear velocity in heading-aligned frame.\n        Returns: [B, T, 3]\n        \"\"\"\n        
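# Rotate the future world-frame root linear velocities into the\n        # robot's current heading-aligned (yaw-only) frame.\n        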
robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        ref_root_lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        num_envs, num_fut_frames, _ = ref_root_lin_vel_w_fut.shape\n        heading_quat_wxyz_fut = heading_quat_wxyz[:, None, :].expand(\n            -1, num_fut_frames, -1\n        )  # [B, T, 4]\n        ref_root_lin_vel_heading_fut = isaaclab_math.quat_apply_inverse(\n            heading_quat_wxyz_fut, ref_root_lin_vel_w_fut\n        )  # [B, T, 3]\n        return ref_root_lin_vel_heading_fut\n\n    def get_ref_motion_cur_heading_aligned_root_ang_vel(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference root angular velocity in heading-aligned frame.\n        Returns: [B, 3]\n        \"\"\"\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        ref_root_ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(\n            prefix=prefix\n        )  # [B, 3]\n        ref_root_ang_vel_heading = isaaclab_math.quat_apply_inverse(\n            heading_quat_wxyz, ref_root_ang_vel_w\n        )  # [B, 3]\n        return ref_root_ang_vel_heading\n\n    def get_ref_motion_fut_heading_aligned_root_ang_vel(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference root angular velocity in heading-aligned frame.\n        Returns: [B, T, 3]\n        \"\"\"\n        robot_cur_global_root_rot = self.robot.data.root_quat_w  # [B, 4] wxyz\n        heading_quat_wxyz = isaaclab_math.yaw_quat(\n            robot_cur_global_root_rot\n        )  # [B, 4] wxyz\n        ref_root_ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        num_envs, num_fut_frames, _ = ref_root_ang_vel_w_fut.shape\n        heading_quat_wxyz_fut = heading_quat_wxyz[:, None, :].expand(\n            -1, num_fut_frames, -1\n        )  # [B, T, 4]\n        ref_root_ang_vel_heading_fut = isaaclab_math.quat_apply_inverse(\n            heading_quat_wxyz_fut, ref_root_ang_vel_w_fut\n        )  # [B, T, 3]\n        return ref_root_ang_vel_heading_fut\n\n    @property\n    def robot_dof_pos_cur_urdf_order(self):\n        return self.robot.data.joint_pos[..., self.sim2urdf_dof_idx]\n\n    def get_ref_motion_dof_vel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_vel\", prefix)\n        return base[:, 0, ...][..., self.urdf2sim_dof_idx]\n\n    def get_ref_motion_dof_vel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"dof_vel\", prefix)\n        return base[..., self.urdf2sim_dof_idx]\n\n    @property\n    def robot_dof_vel_cur_urdf_order(self):\n        return self.robot.data.joint_vel[..., self.sim2urdf_dof_idx]\n\n    def get_ref_motion_dof_vel_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"dof_vel\", prefix)\n        return base[:, 0, ...]\n\n    def get_ref_motion_root_global_pos_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        
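# Cached reference root positions are env-local; add per-env scene\n        # origins to express them in the global simulation frame.\n        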
base = self._get_ref_state_array(\"root_pos\", prefix)\n        return base[:, 0, ...] + self._env_origins\n\n    def get_ref_motion_root_global_pos_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"root_pos\", prefix)\n        return base + self._env_origins\n\n    def get_ref_motion_root_global_rot_quat_xyzw_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self._get_ref_state_array(\"root_rot\", prefix)[:, 0, ...]\n\n    def get_ref_motion_root_global_rot_quat_xyzw_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self._get_immediate_next_ref_state_array(\"root_rot\", prefix)\n\n    def get_ref_motion_root_global_rot_quat_wxyz_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self.get_ref_motion_root_global_rot_quat_xyzw_cur(\n            prefix=prefix\n        )[..., [3, 0, 1, 2]]\n\n    def get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self.get_ref_motion_root_global_rot_quat_xyzw_immediate_next(\n            prefix=prefix\n        )[..., [3, 0, 1, 2]]\n\n    def get_ref_motion_root_global_lin_vel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"root_vel\", prefix)\n        return base[:, 0, ...]\n\n    def get_ref_motion_root_global_lin_vel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self._get_immediate_next_ref_state_array(\"root_vel\", prefix)\n\n    @property\n    def ref_motion_root_global_lin_vel_cur(self) -> torch.Tensor:\n        return self.get_ref_motion_root_global_lin_vel_cur()\n\n    def get_ref_motion_root_global_ang_vel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"root_ang_vel\", prefix)\n        return base[:, 0, ...]\n\n    def get_ref_motion_root_global_ang_vel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        return self._get_immediate_next_ref_state_array(\"root_ang_vel\", prefix)\n\n    def get_ref_motion_gravity_projection_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference gravity projected into reference root frame.\"\"\"\n        g_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]\n        ref_root_rot_wxyz = self.get_ref_motion_root_global_rot_quat_wxyz_cur(\n            prefix=prefix\n        )  # [B, 4]\n        return isaaclab_math.quat_apply_inverse(ref_root_rot_wxyz, g_w)\n\n    def get_ref_motion_gravity_projection_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        g_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]\n        ref_root_rot_wxyz = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n                prefix=prefix\n            )\n        )\n        return isaaclab_math.quat_apply_inverse(ref_root_rot_wxyz, g_w)\n\n    def get_ref_motion_gravity_projection_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference gravity projected into reference root frame.\"\"\"\n        g_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]\n        ref_root_rot_wxyz_fut = (\n            
self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)\n        )  # [B, T, 4]\n        gravity_fut = g_w[:, None, :].expand(\n            -1, ref_root_rot_wxyz_fut.shape[1], -1\n        )  # [B, T, 3]\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz_fut, gravity_fut\n        )  # [B, T, 3]\n\n    def get_ref_motion_base_linvel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference base linear velocity in reference root frame.\"\"\"\n        ref_root_lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(\n            prefix=prefix\n        )  # [B, 3]\n        ref_root_rot_wxyz = self.get_ref_motion_root_global_rot_quat_wxyz_cur(\n            prefix=prefix\n        )  # [B, 4]\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz, ref_root_lin_vel_w\n        )  # [B, 3]\n\n    def get_ref_motion_base_linvel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        ref_root_lin_vel_w = (\n            self.get_ref_motion_root_global_lin_vel_immediate_next(\n                prefix=prefix\n            )\n        )\n        ref_root_rot_wxyz = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n                prefix=prefix\n            )\n        )\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz, ref_root_lin_vel_w\n        )\n\n    def get_ref_motion_base_linvel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference base linear velocity in reference root frame.\"\"\"\n        ref_root_lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        ref_root_rot_wxyz_fut = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)\n        )  # [B, T, 4]\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz_fut, ref_root_lin_vel_w_fut\n        )  # [B, T, 3]\n\n    def get_ref_motion_base_angvel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Current reference base angular velocity in reference root frame.\"\"\"\n        ref_root_ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(\n            prefix=prefix\n        )  # [B, 3]\n        ref_root_rot_wxyz = self.get_ref_motion_root_global_rot_quat_wxyz_cur(\n            prefix=prefix\n        )  # [B, 4]\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz, ref_root_ang_vel_w\n        )  # [B, 3]\n\n    def get_ref_motion_base_angvel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        ref_root_ang_vel_w = (\n            self.get_ref_motion_root_global_ang_vel_immediate_next(\n                prefix=prefix\n            )\n        )\n        ref_root_rot_wxyz = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n                prefix=prefix\n            )\n        )\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz, ref_root_ang_vel_w\n        )\n\n    def get_ref_motion_base_angvel_fut(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        \"\"\"Future reference base angular velocity in reference root frame.\"\"\"\n        ref_root_ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(\n            prefix=prefix\n        )  # [B, T, 3]\n        
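# quat_apply_inverse broadcasts over the [B, T] batch here, rotating\n        # each future angular velocity into its matching future root frame.\n        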
ref_root_rot_wxyz_fut = (\n            self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)\n        )  # [B, T, 4]\n        return isaaclab_math.quat_apply_inverse(\n            ref_root_rot_wxyz_fut, ref_root_ang_vel_w_fut\n        )  # [B, T, 3]\n\n    def get_ref_motion_bodylink_global_pos_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rg_pos\", prefix)\n        return (\n            base[:, 0, ...][..., self.urdf2sim_body_idx, :]\n            + self._env_origins[:, None, :]\n        )\n\n    def get_ref_motion_bodylink_global_pos_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"rg_pos\", prefix)\n        return (\n            base[..., self.urdf2sim_body_idx, :]\n            + self._env_origins[:, None, :]\n        )\n\n    def get_ref_motion_bodylink_global_pos_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rg_pos\", prefix)\n        return base[:, 0, ...] + self._env_origins[:, None, :]\n\n    def get_ref_motion_bodylink_global_rot_wxyz_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        rot_xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_cur(\n            prefix=prefix\n        )\n        return rot_xyzw[..., [3, 0, 1, 2]]\n\n    def get_ref_motion_bodylink_global_rot_xyzw_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rb_rot\", prefix)\n        return base[:, 0, ...][..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_rot_xyzw_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"rb_rot\", prefix)\n        return base[..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_rot_wxyz_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        rot_xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_immediate_next(\n            prefix=prefix\n        )\n        return rot_xyzw[..., [3, 0, 1, 2]]\n\n    def get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"rb_rot\", prefix)\n        return base[:, 0, ...]\n\n    @property\n    def robot_bodylink_global_pos_cur_urdf_order(self):\n        return self.robot.data.body_pos_w[:, self.sim2urdf_body_idx]\n\n    @property\n    def robot_bodylink_global_rot_wxyz_cur_urdf_order(self):\n        return self.robot.data.body_quat_w[:, self.sim2urdf_body_idx]\n\n    @property\n    def robot_bodylink_global_rot_xyzw_cur_urdf_order(self):\n        return self.robot_bodylink_global_rot_wxyz_cur_urdf_order[\n            ..., [1, 2, 3, 0]\n        ]\n\n    @property\n    def robot_bodylink_global_lin_vel_cur_urdf_order(self):\n        return self.robot.data.body_lin_vel_w[:, self.sim2urdf_body_idx]\n\n    @property\n    def robot_bodylink_global_ang_vel_cur_urdf_order(self):\n        return self.robot.data.body_ang_vel_w[:, self.sim2urdf_body_idx]\n\n    def get_ref_motion_bodylink_global_lin_vel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"body_vel\", prefix)\n        return base[:, 0, ...][..., self.urdf2sim_body_idx, 
:]\n\n    def get_ref_motion_bodylink_global_lin_vel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"body_vel\", prefix)\n        return base[..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_lin_vel_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"body_vel\", prefix)\n        return base[:, 0, ...]\n\n    def get_ref_motion_bodylink_global_ang_vel_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"body_ang_vel\", prefix)\n        return base[:, 0, ...][..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_ang_vel_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_immediate_next_ref_state_array(\"body_ang_vel\", prefix)\n        return base[..., self.urdf2sim_body_idx, :]\n\n    def get_ref_motion_bodylink_global_ang_vel_cur_urdf_order(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base = self._get_ref_state_array(\"body_ang_vel\", prefix)\n        return base[:, 0, ...]\n\n    def _build_amp_obs_from_ref_state(\n        self, frame_idx: int, prefix: str = \"ft_ref_\"\n    ) -> torch.Tensor:\n        if (\n            not self._amp_left_arm_urdf_dof_idx\n            or not self._amp_right_arm_urdf_dof_idx\n            or not self._amp_left_leg_urdf_dof_idx\n            or not self._amp_right_leg_urdf_dof_idx\n            or self._amp_left_elbow_urdf_body_idx is None\n            or self._amp_right_elbow_urdf_body_idx is None\n            or self._amp_left_foot_urdf_body_idx is None\n            or self._amp_right_foot_urdf_body_idx is None\n        ):\n            raise ValueError(\n                \"AMP obs indices are not initialized for ref motion.\"\n            )\n\n        dof_pos = self._get_ref_state_array(\"dof_pos\", prefix)[\n            :, frame_idx, ...\n        ]\n        dof_vel = self._get_ref_state_array(\"dof_vel\", prefix)[\n            :, frame_idx, ...\n        ]\n\n        right_arm_pos = dof_pos[:, self._amp_right_arm_urdf_dof_idx]\n        left_arm_pos = dof_pos[:, self._amp_left_arm_urdf_dof_idx]\n        right_leg_pos = dof_pos[:, self._amp_right_leg_urdf_dof_idx]\n        left_leg_pos = dof_pos[:, self._amp_left_leg_urdf_dof_idx]\n        right_arm_vel = dof_vel[:, self._amp_right_arm_urdf_dof_idx]\n        left_arm_vel = dof_vel[:, self._amp_left_arm_urdf_dof_idx]\n        right_leg_vel = dof_vel[:, self._amp_right_leg_urdf_dof_idx]\n        left_leg_vel = dof_vel[:, self._amp_left_leg_urdf_dof_idx]\n\n        root_pos = self._get_ref_state_array(\"root_pos\", prefix)[\n            :, frame_idx, ...\n        ]\n        root_rot = self._get_ref_state_array(\"root_rot\", prefix)[\n            :, frame_idx, ...\n        ]\n        root_inv = quat_inverse(root_rot, w_last=True)\n\n        rg_pos = self._get_ref_state_array(\"rg_pos\", prefix)[:, frame_idx, ...]\n        rb_rot = self._get_ref_state_array(\"rb_rot\", prefix)[:, frame_idx, ...]\n\n        left_elbow_pos = rg_pos[:, self._amp_left_elbow_urdf_body_idx, :]\n        right_elbow_pos = rg_pos[:, self._amp_right_elbow_urdf_body_idx, :]\n        left_elbow_rot = rb_rot[:, self._amp_left_elbow_urdf_body_idx, :]\n        right_elbow_rot = rb_rot[:, self._amp_right_elbow_urdf_body_idx, :]\n\n        left_hand_offset = 
self._amp_left_hand_local_vec.expand(\n            left_elbow_pos.shape[0], -1\n        )\n        right_hand_offset = self._amp_right_hand_local_vec.expand(\n            right_elbow_pos.shape[0], -1\n        )\n        left_hand_world = left_elbow_pos + quat_rotate(\n            left_elbow_rot, left_hand_offset, w_last=True\n        )\n        right_hand_world = right_elbow_pos + quat_rotate(\n            right_elbow_rot, right_hand_offset, w_last=True\n        )\n        left_hand_rel = quat_rotate(\n            root_inv, left_hand_world - root_pos, w_last=True\n        )\n        right_hand_rel = quat_rotate(\n            root_inv, right_hand_world - root_pos, w_last=True\n        )\n\n        left_foot_world = rg_pos[:, self._amp_left_foot_urdf_body_idx, :]\n        right_foot_world = rg_pos[:, self._amp_right_foot_urdf_body_idx, :]\n        left_foot_rel = quat_rotate(\n            root_inv, left_foot_world - root_pos, w_last=True\n        )\n        right_foot_rel = quat_rotate(\n            root_inv, right_foot_world - root_pos, w_last=True\n        )\n\n        return torch.cat(\n            [\n                right_arm_pos,\n                left_arm_pos,\n                right_leg_pos,\n                left_leg_pos,\n                right_arm_vel,\n                left_arm_vel,\n                right_leg_vel,\n                left_leg_vel,\n                left_hand_rel,\n                right_hand_rel,\n                left_foot_rel,\n                right_foot_rel,\n            ],\n            dim=-1,\n        )\n\n    def get_ref_motion_amp_obs_cur(\n        self, prefix: str = \"ft_ref_\"\n    ) -> torch.Tensor:\n        \"\"\"AMP observation aligned with RSL reference (current frame).\"\"\"\n        return self._build_amp_obs_from_ref_state(0, prefix=prefix)\n\n    @property\n    def motion_end_mask(self) -> torch.Tensor:\n        \"\"\"[B] bool: per-step motion-end (timeout) mask.\n\n        Set by `_resample_when_motion_end_cache` just before resampling, so\n        the end-of-motion event is observable within the same step it\n        occurs.\n        \"\"\"\n        return self._motion_end_mask\n\n    @property\n    def global_robot_anchor_pos_cur(self):\n        return self.robot.data.body_pos_w[:, self.anchor_bodylink_idx]\n\n    def get_ref_motion_anchor_bodylink_global_pos_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        pos = self.get_ref_motion_bodylink_global_pos_cur(prefix=prefix)\n        return pos[:, self.anchor_bodylink_idx]\n\n    def get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        rot = self.get_ref_motion_bodylink_global_rot_wxyz_cur(prefix=prefix)\n        return rot[:, self.anchor_bodylink_idx]\n\n    def get_ref_motion_anchor_bodylink_global_pos_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        pos = self.get_ref_motion_bodylink_global_pos_immediate_next(\n            prefix=prefix\n        )\n        return pos[:, self.anchor_bodylink_idx]\n\n    def get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(\n        self,\n        prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        rot = self.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(\n            prefix=prefix\n        )\n        return rot[:, self.anchor_bodylink_idx]\n\n    def _get_obs_bydmmc_ref_motion(\n        self,\n        obs_prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base_pos = 
self._get_ref_state_array(\"dof_pos\", obs_prefix)[:, 0, ...][\n            ..., self.urdf2sim_dof_idx\n        ]\n        base_vel = self._get_ref_state_array(\"dof_vel\", obs_prefix)[:, 0, ...][\n            ..., self.urdf2sim_dof_idx\n        ]\n        num_envs = base_pos.shape[0]\n        cur_ref_dof_pos_flat = base_pos.reshape(num_envs, -1)\n        cur_ref_dof_vel_flat = base_vel.reshape(num_envs, -1)\n        return torch.cat([cur_ref_dof_pos_flat, cur_ref_dof_vel_flat], dim=-1)\n\n    def _get_obs_bydmmc_ref_motion_fut(\n        self,\n        obs_prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base_pos = self._get_ref_state_array(\"dof_pos\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_dof_idx]\n        base_vel = self._get_ref_state_array(\"dof_vel\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_dof_idx]\n        num_envs = base_pos.shape[0]\n        n_fut_frames = int(self.cfg.n_fut_frames)\n        fut_ref_dof_pos_flat = base_pos.reshape(num_envs, n_fut_frames, -1)\n        fut_ref_dof_vel_flat = base_vel.reshape(num_envs, n_fut_frames, -1)\n        rel_fut_ref_motion_state_seq = torch.cat(\n            [fut_ref_dof_pos_flat, fut_ref_dof_vel_flat], dim=-1\n        )\n        return rel_fut_ref_motion_state_seq.reshape(num_envs, -1)\n\n    def _get_obs_vr_ref_motion_states(\n        self,\n        obs_prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base_pos = self._get_ref_state_array(\"dof_pos\", obs_prefix)[:, 0, ...][\n            ..., self.urdf2sim_dof_idx\n        ]\n        num_envs = base_pos.shape[0]\n        cur_ref_dof_pos_flat = base_pos.reshape(num_envs, -1)\n        return torch.cat(\n            [\n                cur_ref_dof_pos_flat,\n                torch.zeros_like(\n                    cur_ref_dof_pos_flat,\n                    device=cur_ref_dof_pos_flat.device,\n                ),\n            ],\n            dim=-1,\n        )\n\n    def _get_obs_vr_ref_motion_fut(\n        self,\n        obs_prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        base_pos = self._get_ref_state_array(\"dof_pos\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_dof_idx]\n        num_envs = base_pos.shape[0]\n        n_fut_frames = int(self.cfg.n_fut_frames)\n        fut_ref_dof_pos_flat = base_pos.reshape(num_envs, n_fut_frames, -1)\n        rel_fut_ref_motion_state_seq = torch.cat(\n            [\n                fut_ref_dof_pos_flat,\n                torch.zeros_like(\n                    fut_ref_dof_pos_flat, device=fut_ref_dof_pos_flat.device\n                ),\n            ],\n            dim=-1,\n        )\n        return rel_fut_ref_motion_state_seq.reshape(num_envs, -1)\n\n    def _get_obs_holomotion_rel_ref_motion_flat(\n        self,\n        obs_prefix: str = \"ref_\",\n    ) -> torch.Tensor:\n        # Gather all needed arrays with obs prefix\n        fut_rg_pos = self._get_ref_state_array(\"rg_pos\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_body_idx, :]\n        fut_rb_rot_xyzw = self._get_ref_state_array(\"rb_rot\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_body_idx, :]\n        fut_root_rot_xyzw = self._get_ref_state_array(\"root_rot\", obs_prefix)[\n            :, 1:, ...\n        ]\n        fut_root_lin_vel = self._get_ref_state_array(\"root_vel\", obs_prefix)[\n            :, 1:, ...\n        ]\n        fut_root_ang_vel = self._get_ref_state_array(\n            \"root_ang_vel\", obs_prefix\n        )[:, 1:, ...]\n        
fut_dof_pos = self._get_ref_state_array(\"dof_pos\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_dof_idx]\n        fut_dof_vel = self._get_ref_state_array(\"dof_vel\", obs_prefix)[\n            :, 1:, ...\n        ][..., self.urdf2sim_dof_idx]\n\n        num_envs, num_fut_timesteps, num_bodies, _ = fut_rg_pos.shape\n        assert num_envs == self.num_envs\n        assert num_fut_timesteps == self.cfg.n_fut_frames\n\n        fut_ref_root_rot_quat = fut_root_rot_xyzw  # [B, T, 4]\n        fut_ref_root_rot_quat_inv = quat_inverse(\n            fut_ref_root_rot_quat, w_last=True\n        )  # [B, T, 4]\n        fut_ref_root_rot_quat_body_flat = (\n            fut_ref_root_rot_quat[:, :, None, :]\n            .repeat(1, 1, num_bodies, 1)\n            .reshape(-1, 4)\n        )\n        fut_ref_root_rot_quat_body_flat_inv = quat_inverse(\n            fut_ref_root_rot_quat_body_flat, w_last=True\n        )\n\n        ref_fut_heading_quat_inv = calc_heading_quat_inv(\n            fut_root_rot_xyzw.reshape(-1, 4),\n            w_last=True,\n        )  # [B*T, 4]\n        ref_fut_quat_rp = quat_mul(\n            ref_fut_heading_quat_inv,\n            fut_root_rot_xyzw.reshape(-1, 4),\n            w_last=True,\n        )  # [B*T, 4]\n\n        ref_fut_roll, ref_fut_pitch, _ = get_euler_xyz(\n            ref_fut_quat_rp,\n            w_last=True,\n        )\n        ref_fut_roll = wrap_to_pi(ref_fut_roll).reshape(\n            num_envs, num_fut_timesteps, -1\n        )  # [B, T, 1]\n        ref_fut_pitch = wrap_to_pi(ref_fut_pitch).reshape(\n            num_envs, num_fut_timesteps, -1\n        )  # [B, T, 1]\n        ref_fut_rp = torch.cat(\n            [ref_fut_roll, ref_fut_pitch], dim=-1\n        )  # [B, T, 2]\n        ref_fut_rp_flat = ref_fut_rp.reshape(num_envs, -1)  # [B, T * 2]\n        # ---\n\n        fut_ref_root_quat_inv_fut_flat = fut_ref_root_rot_quat_inv.reshape(\n            -1, 4\n        )\n        fut_ref_cur_root_rel_base_lin_vel = quat_rotate(\n            fut_ref_root_quat_inv_fut_flat,  # [B*T, 4]\n            fut_root_lin_vel.reshape(-1, 3),  # [B*T, 3]\n            w_last=True,\n        ).reshape(num_envs, -1)  # [B, num_fut_timesteps * 3]\n        fut_ref_cur_root_rel_base_ang_vel = quat_rotate(\n            fut_ref_root_quat_inv_fut_flat,  # [B*T, 4]\n            fut_root_ang_vel.reshape(-1, 3),  # [B*T, 3]\n            w_last=True,\n        ).reshape(num_envs, -1)  # [B, num_fut_timesteps * 3]\n        # ---\n\n        # --- calculate the absolute DoF position and velocity ---\n        fut_ref_dof_pos_flat = fut_dof_pos.reshape(num_envs, -1)\n        fut_ref_dof_vel_flat = fut_dof_vel.reshape(num_envs, -1)\n        # ---\n\n        # --- calculate the future per frame bodylink position and rotation ---\n        fut_ref_global_bodylink_pos = fut_rg_pos  # [B, T, num_bodies, 3]\n        fut_ref_global_bodylink_rot = fut_rb_rot_xyzw  # [B, T, num_bodies, 4]\n\n        # get root-relative bodylink position\n        fut_ref_root_rel_bodylink_pos = quat_rotate(\n            fut_ref_root_rot_quat_body_flat_inv,\n            (\n                fut_ref_global_bodylink_pos\n                - fut_ref_global_bodylink_pos[:, :, 0:1, :]\n            ).reshape(-1, 3),\n            w_last=True,\n        ).reshape(\n            num_envs, num_fut_timesteps, num_bodies, -1\n        )  # [B, num_fut_timesteps, num_bodies, 3]\n\n        # get root-relative bodylink rotation\n        fut_ref_root_rel_bodylink_rot = quat_mul(\n            
fut_ref_root_rot_quat_body_flat_inv,\n            fut_ref_global_bodylink_rot.reshape(-1, 4),\n            w_last=True,\n        )\n        fut_ref_root_rel_bodylink_rot_mat = quaternion_to_matrix(\n            fut_ref_root_rel_bodylink_rot,\n            w_last=True,\n        )[:, :, :2].reshape(\n            num_envs, num_fut_timesteps, num_bodies, -1\n        )  # [B, num_fut_timesteps, num_bodies, 6]\n\n        rel_fut_ref_motion_state_seq = torch.cat(\n            [\n                ref_fut_rp_flat.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, 2]\n                fut_ref_cur_root_rel_base_lin_vel.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, 3]\n                fut_ref_cur_root_rel_base_ang_vel.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, 3]\n                fut_ref_dof_pos_flat.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, num_dofs]\n                fut_ref_dof_vel_flat.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, num_dofs]\n                fut_ref_root_rel_bodylink_pos.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, num_bodies*3]\n                fut_ref_root_rel_bodylink_rot_mat.reshape(\n                    num_envs, num_fut_timesteps, -1\n                ),  # [B, T, num_bodies*6]\n            ],\n            dim=-1,\n        )  # [B, T, 2 + 3 + 3 + num_dofs * 2 + num_bodies * (3 + 6)]\n        return rel_fut_ref_motion_state_seq.reshape(self.num_envs, -1)\n\n    def _resample_command(self, env_ids: Sequence[int], eval=False):\n        \"\"\"Resample command for specified environments.\"\"\"\n        if len(env_ids) == 0:\n            return\n\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(env_ids, device=self.device)\n        else:\n            env_ids = env_ids.to(self.device)\n\n        # env_ids is always a tensor after the conversion above.\n        idxs = self._filter_env_ids_for_motion_task(env_ids.view(-1))\n        if idxs.numel() == 0:\n            return\n\n        self._record_completion_rate_for_envs(idxs)\n        clip_idx, frame_idx = self._motion_cache.sample_env_assignments(\n            len(idxs),\n            self.cfg.n_fut_frames,\n            self.device,\n            deterministic_start=(eval or self._is_evaluating),\n        )\n        self._clip_indices[idxs] = clip_idx\n        self._frame_indices[idxs] = frame_idx\n        self._start_frame_indices[idxs] = frame_idx\n        self._reward_sum_since_assign[idxs] = 0.0\n        self._step_count_since_assign[idxs] = 0.0\n        self._update_ref_motion_state_from_cache(env_ids=idxs)\n        self._align_root_to_ref(idxs)\n        self._align_dof_to_ref(idxs)\n\n    def _filter_env_ids_for_motion_task(\n        self, env_ids: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Filter env_ids to those currently assigned to motion_tracking task.\n\n        In multi-task training, we may keep `ref_motion` registered for observation\n        schemas, but we must avoid applying motion-based state alignment to envs\n        that are not running motion tracking (e.g., velocity tracking only).\n\n   
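     Example: with holo_task_ids = [m, v, m, v], where m is the id of\n        \"motion_tracking\" and v that of another task, env_ids [0, 1, 2, 3]\n        filters to [0, 2].\n\n   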
     Behavior:\n        - If env does not expose multi-task task buffers, return env_ids (legacy).\n        - If env exposes task buffers but has no \"motion_tracking\" task, return empty.\n        - Otherwise, return env_ids where holo_task_ids == holo_task_name_to_id[\"motion_tracking\"].\n        \"\"\"\n        if env_ids.numel() == 0:\n            return env_ids\n\n        task_ids = getattr(self._env, \"holo_task_ids\", None)\n        task_name_to_id = getattr(self._env, \"holo_task_name_to_id\", None)\n        if task_ids is None or task_name_to_id is None:\n            return env_ids\n\n        motion_tid = task_name_to_id.get(\"motion_tracking\", None)\n        if motion_tid is None:\n            return env_ids[:0]\n\n        task_ids_t = task_ids.to(device=self.device, dtype=torch.long).view(-1)\n        env_ids_t = env_ids.to(device=self.device, dtype=torch.long).view(-1)\n        mask = task_ids_t[env_ids_t] == int(motion_tid)\n        return env_ids_t[mask]\n\n    def _align_root_to_ref(self, env_ids):\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)\n        env_ids = self._filter_env_ids_for_motion_task(env_ids)\n        if env_ids.numel() == 0:\n            return\n\n        root_pos = self.get_ref_motion_root_global_pos_cur().clone()\n        root_rot_xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur()\n        root_rot = root_rot_xyzw[..., [3, 0, 1, 2]].clone()\n        root_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone()\n        root_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone()\n\n        pos_rot_range_list = [\n            self.cfg.root_pose_perturb_range.get(key, (0.0, 0.0))\n            for key in [\"x\", \"y\", \"z\", \"roll\", \"pitch\", \"yaw\"]\n        ]\n        pos_rot_ranges = torch.tensor(pos_rot_range_list, device=self.device)\n        pos_rot_rand_deltas = isaaclab_math.sample_uniform(\n            pos_rot_ranges[:, 0],\n            pos_rot_ranges[:, 1],\n            (len(env_ids), 6),\n            device=self.device,\n        )\n        translation_delta = pos_rot_rand_deltas[:, 0:3]\n        rotation_delta = isaaclab_math.quat_from_euler_xyz(\n            pos_rot_rand_deltas[:, 3],\n            pos_rot_rand_deltas[:, 4],\n            pos_rot_rand_deltas[:, 5],\n        )\n\n        root_pos[env_ids] += translation_delta\n        root_rot[env_ids] = isaaclab_math.quat_mul(\n            rotation_delta,\n            root_rot[env_ids],\n        )\n\n        lin_ang_vel_range_list = [\n            self.cfg.root_vel_perturb_range.get(key, (0.0, 0.0))\n            for key in [\"x\", \"y\", \"z\", \"roll\", \"pitch\", \"yaw\"]\n        ]\n        lin_ang_vel_ranges = torch.tensor(\n            lin_ang_vel_range_list, device=self.device\n        )\n\n        lin_ang_vel_rand_deltas = isaaclab_math.sample_uniform(\n            lin_ang_vel_ranges[:, 0],\n            lin_ang_vel_ranges[:, 1],\n            (len(env_ids), 6),\n            device=self.device,\n        )\n        root_lin_vel[env_ids] += lin_ang_vel_rand_deltas[:, :3]\n        root_ang_vel[env_ids] += lin_ang_vel_rand_deltas[:, 3:]\n\n        self.robot.write_root_state_to_sim(\n            torch.cat(\n                [\n                    root_pos[env_ids],\n                    root_rot[env_ids],\n                    root_lin_vel[env_ids],\n           
         root_ang_vel[env_ids],\n                ],\n                dim=-1,\n            ),\n            env_ids=env_ids,\n        )\n\n    def _align_dof_to_ref(self, env_ids):\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)\n        env_ids = self._filter_env_ids_for_motion_task(env_ids)\n        if env_ids.numel() == 0:\n            return\n\n        dof_pos = self.get_ref_motion_dof_pos_cur().clone()\n        dof_vel = self.get_ref_motion_dof_vel_cur().clone()\n\n        dof_pos += isaaclab_math.sample_uniform(\n            *self.cfg.dof_pos_perturb_range,\n            dof_pos.shape,\n            dof_pos.device,\n        )\n        soft_dof_pos_limits = self.robot.data.soft_joint_pos_limits[env_ids]\n        dof_pos[env_ids] = torch.clip(\n            dof_pos[env_ids],\n            soft_dof_pos_limits[:, :, 0],\n            soft_dof_pos_limits[:, :, 1],\n        )\n\n        self.robot.write_joint_state_to_sim(\n            dof_pos[env_ids],\n            dof_vel[env_ids],\n            env_ids=env_ids,\n        )\n\n    def force_realign_root_state_to_ref_no_perturb(self, env_ids) -> None:\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)\n        env_ids = self._filter_env_ids_for_motion_task(env_ids)\n        if env_ids.numel() == 0:\n            return\n\n        root_pos = self.get_ref_motion_root_global_pos_cur().clone()\n        root_rot_xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur()\n        root_rot = root_rot_xyzw[..., [3, 0, 1, 2]].clone()\n        root_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone()\n        root_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone()\n        self.robot.write_root_state_to_sim(\n            torch.cat(\n                [\n                    root_pos[env_ids],\n                    root_rot[env_ids],\n                    root_lin_vel[env_ids],\n                    root_ang_vel[env_ids],\n                ],\n                dim=-1,\n            ),\n            env_ids=env_ids,\n        )\n\n    def force_realign_dof_state_to_ref_no_perturb(self, env_ids) -> None:\n        if not isinstance(env_ids, torch.Tensor):\n            env_ids = torch.tensor(\n                env_ids, device=self.device, dtype=torch.long\n            )\n        else:\n            env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)\n        env_ids = self._filter_env_ids_for_motion_task(env_ids)\n        if env_ids.numel() == 0:\n            return\n\n        dof_pos = self.get_ref_motion_dof_pos_cur().clone()\n        dof_vel = self.get_ref_motion_dof_vel_cur().clone()\n        soft_dof_pos_limits = self.robot.data.soft_joint_pos_limits[env_ids]\n        dof_pos[env_ids] = torch.clip(\n            dof_pos[env_ids],\n            soft_dof_pos_limits[:, :, 0],\n            soft_dof_pos_limits[:, :, 1],\n        )\n\n        self.robot.write_joint_state_to_sim(\n            dof_pos[env_ids],\n            dof_vel[env_ids],\n            env_ids=env_ids,\n        )\n\n    def force_realign_offline_eval_no_perturb(self, env_ids) -> None:\n        self.force_realign_root_state_to_ref_no_perturb(env_ids)\n      
  self.force_realign_dof_state_to_ref_no_perturb(env_ids)\n\n    def _update_command(self):\n        all_ids = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        motion_ids = self._filter_env_ids_for_motion_task(all_ids)\n        if motion_ids.numel() == 0:\n            return\n\n        continue_ids = motion_ids\n        episode_length_buf = getattr(self._env, \"episode_length_buf\", None)\n        if episode_length_buf is not None:\n            continue_mask = episode_length_buf[motion_ids] != 0\n            continue_ids = motion_ids[continue_mask]\n        if continue_ids.numel() > 0:\n            self._frame_indices[continue_ids] += 1\n        self._swap_step_counter += 1\n\n        if self._swap_step_counter >= self._motion_cache.swap_interval_steps:\n            self._swap_pending = True\n\n        # Resample when motion ends\n        self._resample_when_motion_end_cache()\n        self._update_ref_motion_state_from_cache()\n\n    def _resample_when_motion_end_cache(self):\n        \"\"\"Resample environments when motion ends (simple cache mode).\"\"\"\n        all_ids = torch.arange(\n            self.num_envs, dtype=torch.long, device=self.device\n        )\n        motion_ids = self._filter_env_ids_for_motion_task(all_ids)\n        if motion_ids.numel() == 0:\n            return\n\n        lengths = self._motion_cache.lengths_for_indices(self._clip_indices)\n        max_valid_frame = torch.clamp(\n            lengths - 1 - self.cfg.n_fut_frames, min=0\n        )\n        need_resample = (\n            self._frame_indices[motion_ids] > max_valid_frame[motion_ids]\n        )\n\n        if torch.any(need_resample):\n            resample_ids = motion_ids[torch.nonzero(need_resample).squeeze(-1)]\n            # Resample these envs\n            self._record_completion_rate_for_envs(resample_ids)\n            clip_idx, frame_idx = self._motion_cache.sample_env_assignments(\n                len(resample_ids),\n                self.cfg.n_fut_frames,\n                self.device,\n                deterministic_start=self._is_evaluating,\n            )\n            self._clip_indices[resample_ids] = clip_idx\n            self._frame_indices[resample_ids] = frame_idx\n            self._start_frame_indices[resample_ids] = frame_idx\n            self._reward_sum_since_assign[resample_ids] = 0.0\n            self._step_count_since_assign[resample_ids] = 0.0\n            # Realign robot state\n            self._update_ref_motion_state_from_cache(env_ids=resample_ids)\n            self._align_root_to_ref(resample_ids)\n            self._align_dof_to_ref(resample_ids)\n            # Mark motion end\n            self._motion_end_mask[motion_ids] = False\n            self._motion_end_mask[resample_ids] = True\n            self.motion_end_counter[resample_ids] += 1\n\n    def _update_metrics(self):\n        \"\"\"Update metrics for command progress tracking.\"\"\"\n        if not hasattr(self, \"metrics\"):\n            self.metrics = {}\n\n        self._update_mpjpe_metrics()\n        self._update_mpkpe_metrics()\n\n    def _update_mpjpe_metrics(self):\n        \"\"\"Update MPJPE (Mean Per Joint Position Error) metrics.\"\"\"\n        # Get current and reference joint positions\n        current_dof_pos = self.robot.data.joint_pos  # [B, num_dofs]\n        ref_dof_pos = self.get_ref_motion_dof_pos_immediate_next()\n\n        # Compute joint position errors\n        dof_pos_error = torch.abs(\n            current_dof_pos - ref_dof_pos\n        )  # [B, 
num_dofs]\n\n        # MPJPE whole body\n        mpjpe_wholebody = torch.mean(dof_pos_error, dim=-1)  # [B]\n\n        # MPJPE arms (using unified naming)\n        mpjpe_arms = torch.mean(\n            dof_pos_error[:, self.arm_dof_indices], dim=-1\n        )  # [B]\n\n        # MPJPE torso (using unified naming)\n        mpjpe_waist = torch.mean(\n            dof_pos_error[:, self.torso_dof_indices], dim=-1\n        )  # [B]\n\n        # MPJPE legs\n        mpjpe_legs = torch.mean(\n            dof_pos_error[:, self.leg_dof_indices], dim=-1\n        )  # [B]\n\n        # Initialize metric tensors if needed\n        for metric_name in [\n            \"Task/MPJPE_WholeBody\",\n            \"Task/MPJPE_Arms\",\n            \"Task/MPJPE_Waist\",\n            \"Task/MPJPE_Legs\",\n        ]:\n            if metric_name not in self.metrics:\n                self.metrics[metric_name] = torch.zeros(\n                    self.num_envs, device=self.device\n                )\n\n        # Update metric values\n        self.metrics[\"Task/MPJPE_WholeBody\"][:] = mpjpe_wholebody\n        self.metrics[\"Task/MPJPE_Arms\"][:] = mpjpe_arms\n        self.metrics[\"Task/MPJPE_Waist\"][:] = mpjpe_waist\n        self.metrics[\"Task/MPJPE_Legs\"][:] = mpjpe_legs\n\n    def _update_mpkpe_metrics(self):\n        \"\"\"Update MPKPE (Mean Per Keybody Position Error) metrics.\"\"\"\n        # Get current and reference body positions\n        current_body_pos = self.robot.data.body_pos_w  # [B, num_bodies, 3]\n        ref_body_pos = self.get_ref_motion_bodylink_global_pos_immediate_next()\n        # [B, num_bodies, 3]\n\n        # Compute body position errors (L2 norm)\n        body_pos_error = torch.norm(\n            current_body_pos - ref_body_pos, dim=-1\n        )  # [B, num_bodies]\n\n        # MPKPE whole body\n        mpkpe_wholebody = torch.mean(body_pos_error, dim=-1)  # [B]\n\n        # MPKPE arms (using unified naming)\n        mpkpe_arms = torch.mean(\n            body_pos_error[:, self.arm_body_indices], dim=-1\n        )  # [B]\n\n        # MPKPE torso (using unified naming)\n        mpkpe_waist = torch.mean(\n            body_pos_error[:, self.torso_body_indices], dim=-1\n        )  # [B]\n\n        # MPKPE legs\n        mpkpe_legs = torch.mean(\n            body_pos_error[:, self.leg_body_indices], dim=-1\n        )  # [B]\n\n        # Initialize metric tensors if needed\n        for metric_name in [\n            \"Task/MPKPE_WholeBody\",\n            \"Task/MPKPE_Arms\",\n            \"Task/MPKPE_Waist\",\n            \"Task/MPKPE_Legs\",\n        ]:\n            if metric_name not in self.metrics:\n                self.metrics[metric_name] = torch.zeros(\n                    self.num_envs, device=self.device\n                )\n\n        # Update metric values\n        self.metrics[\"Task/MPKPE_WholeBody\"][:] = mpkpe_wholebody\n        self.metrics[\"Task/MPKPE_Arms\"][:] = mpkpe_arms\n        self.metrics[\"Task/MPKPE_Waist\"][:] = mpkpe_waist\n        self.metrics[\"Task/MPKPE_Legs\"][:] = mpkpe_legs\n\n    # --- Pose-error getters for curriculum (WholeBody only) ---\n    def get_wholebody_mpjpe(\n        self,\n    ) -> torch.Tensor:\n        \"\"\"[B] current whole-body MPJPE (URDF joint-space abs error).\"\"\"\n        if not hasattr(self, \"metrics\") or (\n            \"Task/MPJPE_WholeBody\" not in self.metrics\n        ):\n            return torch.zeros(self.num_envs, device=self.device)\n        return self.metrics[\"Task/MPJPE_WholeBody\"]\n\n    def get_wholebody_mpkpe(\n        
self,\n    ) -> torch.Tensor:\n        \"\"\"[B] current whole-body MPKPE (body position error).\"\"\"\n        if not hasattr(self, \"metrics\") or (\n            \"Task/MPKPE_WholeBody\" not in self.metrics\n        ):\n            return torch.zeros(self.num_envs, device=self.device)\n        return self.metrics[\"Task/MPKPE_WholeBody\"]\n\n    def get_current_motion_keys(\n        self,\n    ) -> list[str]:\n        \"\"\"Return motion window keys for the envs' current cached clips.\"\"\"\n        try:\n            if hasattr(self, \"_motion_cache\") and hasattr(\n                self._motion_cache, \"motion_keys_for_indices\"\n            ):\n                return self._motion_cache.motion_keys_for_indices(\n                    self._clip_indices\n                )\n        except Exception:\n            pass\n        return []\n\n    def _set_debug_vis_impl(self, debug_vis: bool):\n        if debug_vis:\n            # Just enable debug mode - visualizers will be created lazily in callback\n            self._debug_vis_enabled = True\n            # Set visibility if visualizers already exist\n            if hasattr(self, \"ref_body_visualizers\"):\n                for visualizer in self.ref_body_visualizers:\n                    visualizer.set_visibility(True)\n        else:\n            self._debug_vis_enabled = False\n            # Set visibility to false\n            if hasattr(self, \"ref_body_visualizers\"):\n                for visualizer in self.ref_body_visualizers:\n                    visualizer.set_visibility(False)\n\n    def setup_offline_eval_from_frame_zero(self):\n        \"\"\"Setup reference frame indices for offline evaluation from frame 0.\"\"\"\n\n        self._frame_indices[:] = 0\n\n        self._update_ref_motion_state()\n\n        logger.info(\n            f\"Offline evaluation setup complete: all {self.num_envs} \"\n            f\"environments set to frame 0 references\"\n        )\n\n    def setup_offline_eval_deterministic(\n        self, apply_pending_swap: bool = True\n    ) -> None:\n        \"\"\"Deterministic multi-env setup for offline evaluation.\n\n        - Optionally apply a pending cache swap.\n        - Set env i -> cache row i mapping for active clips, frame 0.\n        - Update reference state only. 
Robot realignment is handled by caller.\n        \"\"\"\n        if apply_pending_swap and getattr(self, \"_swap_pending\", False):\n            self._motion_cache.advance()\n            self._swap_pending = False\n            self._swap_step_counter = 0\n\n        clip_count = int(self._motion_cache.clip_count)\n        active_count = min(int(self.num_envs), clip_count)\n\n        # Reset indices\n        self._clip_indices[:] = 0\n        self._frame_indices[:] = 0\n\n        if active_count > 0:\n            active_ids = torch.arange(\n                active_count, dtype=torch.long, device=self.device\n            )\n            self._clip_indices[active_ids] = torch.arange(\n                active_count, dtype=torch.long, device=self.device\n            )\n\n        self._update_ref_motion_state_from_cache()\n\n    def _debug_vis_callback(self, event):\n        if not self.robot.is_initialized:\n            return\n\n        # Check if debug visualization is enabled\n        if not getattr(self, \"_debug_vis_enabled\", False):\n            return\n\n        # Check if motion cache/assignments are available\n        if (\n            not hasattr(self, \"_motion_cache\")\n            or self._motion_cache is None\n            or not hasattr(self, \"_clip_indices\")\n            or not hasattr(self, \"_frame_indices\")\n        ):\n            return\n\n        # Create visualizers lazily if they don't exist\n        if not hasattr(self, \"ref_body_visualizers\"):\n            self.ref_body_visualizers = []\n            # Get number of bodies from the reference motion data\n            num_bodies = self.get_ref_motion_bodylink_global_pos_cur().shape[\n                -2\n            ]\n            for i in range(num_bodies):\n                # Reference bodylinks as red spheres\n                self.ref_body_visualizers.append(\n                    VisualizationMarkers(\n                        self.cfg.body_keypoint_visualizer_cfg.replace(\n                            prim_path=f\"/Visuals/Command/ref_body_{i}\"\n                        )\n                    )\n                )\n\n        # Visualize reference body keypoints\n        if len(self.ref_body_visualizers) > 0:\n            ref_body_pos = self.get_ref_motion_bodylink_global_pos_cur()\n            # [B, num_bodies, 3]\n\n            num_bodies = min(\n                len(self.ref_body_visualizers), ref_body_pos.shape[1]\n            )\n\n            for i in range(num_bodies):\n                # Visualize reference bodylinks as spheres (position only)\n                self.ref_body_visualizers[i].visualize(\n                    ref_body_pos[:, i],  # [B, 3]\n                )\n\n\n@configclass\nclass MotionCommandCfg(CommandTermCfg):\n    \"\"\"Configuration for the motion command.\"\"\"\n\n    class_type: type = RefMotionCommand\n\n    command_obs_name: str = MISSING\n    urdf_dof_names: list[str] = MISSING\n    urdf_body_names: list[str] = MISSING\n\n    # DOF name groupings for mpjpe metrics (using unified naming)\n    arm_dof_names: list[str] = MISSING\n    waist_dof_names: list[str] = MISSING\n    leg_dof_names: list[str] = MISSING\n\n    # Body name groupings for mpkpe metrics (using unified naming)\n    arm_body_names: list[str] = MISSING\n    torso_body_names: list[str] = MISSING\n    leg_body_names: list[str] = MISSING\n\n    motion_lib_cfg: dict = MISSING\n    seed: int = MISSING\n    process_id: int = MISSING\n    num_processes: int = MISSING\n    is_evaluating: bool = MISSING\n    resample_time_interval_s: float = 
MISSING\n\n    n_fut_frames: int = MISSING\n    target_fps: int = MISSING\n\n    anchor_bodylink_name: str = \"pelvis\"\n\n    asset_name: str = MISSING\n    debug_vis: bool = False\n\n    root_pose_perturb_range: dict[str, tuple[float, float]] = {}\n    root_vel_perturb_range: dict[str, tuple[float, float]] = {}\n    dof_pos_perturb_range: tuple[float, float] = (-0.1, 0.1)\n    dof_vel_perturb_range: tuple[float, float] = (-1.0, 1.0)\n\n    body_keypoint_visualizer_cfg: VisualizationMarkersCfg = (\n        SPHERE_MARKER_CFG.replace(prim_path=\"/Visuals/Command/ref_keypoint\")\n    )\n    body_keypoint_visualizer_cfg.markers[\"sphere\"].radius = 0.03\n    body_keypoint_visualizer_cfg.markers[\n        \"sphere\"\n    ].visual_material = PreviewSurfaceCfg(\n        diffuse_color=(0.0, 0.0, 1.0)  # blue\n    )\n\n    resampling_time_range: tuple[float, float] = (1.0, 1.0)\n\n\n@configclass\nclass MoTrack_CommandsCfg:\n    pass\n\n\ndef build_motion_tracking_commands_config(command_config_dict: dict):\n    \"\"\"Build isaaclab-compatible CommandsCfg from a config dictionary.\n\n    Args:\n        command_config_dict: Dictionary mapping command names to command configurations.\n                           Each command config should contain the type and parameters.\n\n    Example:\n        command_config_dict = {\n            \"ref_motion\": {\n                \"type\": \"MotionCommandCfg\",\n                \"params\": {\n                    \"command_obs_name\": \"bydmmc_ref_motion\",\n                    \"motion_lib_cfg\": {...},\n                    \"process_id\": 0,\n                    \"num_processes\": 1,\n                    # ... other parameters\n                }\n            }\n        }\n    \"\"\"\n\n    commands_cfg = MoTrack_CommandsCfg()\n\n    # Add command terms dynamically\n    for command_name, command_config in command_config_dict.items():\n        command_type = command_config.get(\"type\", \"MotionCommandCfg\")\n        command_params = command_config.get(\"params\", {})\n\n        # Get the command class type\n        if command_type == \"MotionCommandCfg\":\n            command_cfg = MotionCommandCfg(**command_params)\n        else:\n            raise ValueError(f\"Unknown command type: {command_type}\")\n\n        # Add command to config\n        setattr(commands_cfg, command_name, command_cfg)\n\n    return commands_cfg\n"
  },
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_observation.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport isaaclab.envs.mdp as isaaclab_mdp\nimport isaaclab.sim as sim_utils\nfrom dataclasses import fields as dataclass_fields\nfrom isaaclab.actuators import ImplicitActuatorCfg\nfrom isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg\nfrom isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg\nfrom isaaclab.managers import (\n    ActionTermCfg,\n    CommandTerm,\n    CommandTermCfg,\n    EventTermCfg as EventTerm,\n    ObservationGroupCfg,\n    ObservationGroupCfg as ObsGroup,\n    ObservationTermCfg,\n    ObservationTermCfg as ObsTerm,\n    RewardTermCfg,\n    SceneEntityCfg,\n    TerminationTermCfg,\n)\nimport torch\nfrom isaaclab.markers import (\n    VisualizationMarkers,\n    VisualizationMarkersCfg,\n)\nfrom isaaclab.markers.config import FRAME_MARKER_CFG\nfrom isaaclab.scene import InteractiveSceneCfg\nfrom isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns\nfrom isaaclab.sim import PhysxCfg, SimulationCfg\nfrom isaaclab.terrains import TerrainImporterCfg\nfrom isaaclab.utils import configclass\n\n\nimport isaaclab.utils.math as isaaclab_math\nimport isaaclab.utils.noise as isaaclab_noise\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\n\nfrom holomotion.src.env.isaaclab_components.isaaclab_utils import (\n    resolve_holo_config,\n)\nfrom holomotion.src.utils.frame_utils import (\n    positions_world_to_env_frame,\n    root_relative_positions_from_env_frame,\n)\n\n\ndef _build_noise_cfg(noise_cfg):\n    noise_cfg = resolve_holo_config(noise_cfg)\n    if not (isinstance(noise_cfg, dict) and \"type\" in noise_cfg):\n        return noise_cfg\n\n    noise_cls = getattr(isaaclab_noise, noise_cfg[\"type\"])\n    noise_params = resolve_holo_config(noise_cfg.get(\"params\", {}))\n    if not isinstance(noise_params, dict):\n        return noise_cls(**noise_params)\n\n    noise_params = dict(noise_params)\n    if \"n_min_z\" in noise_params or \"n_max_z\" in noise_params:\n        base_n_min = noise_params[\"n_min\"]\n        base_n_max = noise_params[\"n_max\"]\n        noise_params[\"n_min\"] = torch.tensor(\n            [base_n_min, base_n_min, noise_params.pop(\"n_min_z\", base_n_min)],\n            dtype=torch.float32,\n        )\n        noise_params[\"n_max\"] = torch.tensor(\n            [base_n_max, base_n_max, noise_params.pop(\"n_max_z\", base_n_max)],\n            dtype=torch.float32,\n        )\n\n    return noise_cls(**noise_params)\n\n\nclass MirrorFunctions:\n    \"\"\"Generic observation mirroring utilities.\"\"\"\n\n    @staticmethod\n    def mirror_dof(\n        x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Mirror DOF-aligned tensor [..., A] with permutation and sign.\"\"\"\n        if x.shape[-1] != int(perm.numel()):\n            raise ValueError(\n                f\"mirror_dof expected last dim 
{perm.numel()}, got {x.shape[-1]}\"\n            )\n        if perm.device != x.device or perm.dtype != torch.long:\n            perm = perm.to(device=x.device, dtype=torch.long)\n        if sign.device != x.device or sign.dtype != x.dtype:\n            sign = sign.to(device=x.device, dtype=x.dtype)\n        mirrored = torch.index_select(x, dim=x.ndim - 1, index=perm)\n        sign_view = sign.view(*([1] * (mirrored.ndim - 1)), sign.numel())\n        return mirrored * sign_view\n\n    @staticmethod\n    def mirror_action(\n        actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Mirror action tensor [..., A] in DOF space with permutation and sign.\"\"\"\n        return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign)\n\n    @staticmethod\n    def mirror_vec3(x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Mirror a true vector [..., 3] with sign [1, -1, 1].\"\"\"\n        if x.shape[-1] != 3:\n            raise ValueError(\n                f\"mirror_vec3 expected last dim 3, got {x.shape[-1]}\"\n            )\n        sign = torch.tensor(\n            [1.0, -1.0, 1.0], device=x.device, dtype=x.dtype\n        ).view(*([1] * (x.ndim - 1)), 3)\n        return x * sign\n\n    @staticmethod\n    def mirror_axial_vec3(x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Mirror an axial vector [..., 3] with sign [-1, 1, -1].\"\"\"\n        if x.shape[-1] != 3:\n            raise ValueError(\n                f\"mirror_axial_vec3 expected last dim 3, got {x.shape[-1]}\"\n            )\n        sign = torch.tensor(\n            [-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype\n        ).view(*([1] * (x.ndim - 1)), 3)\n        return x * sign\n\n    @staticmethod\n    def mirror_velocity_command(x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Mirror velocity command [..., 3] or [..., 4] preserving move_mask.\"\"\"\n        last_dim = x.shape[-1]\n        if last_dim == 3:\n            sign = torch.tensor(\n                [1.0, -1.0, -1.0], device=x.device, dtype=x.dtype\n            ).view(*([1] * (x.ndim - 1)), 3)\n            return x * sign\n        if last_dim == 4:\n            sign = torch.tensor(\n                [1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype\n            ).view(*([1] * (x.ndim - 1)), 4)\n            return x * sign\n        raise ValueError(\n            f\"mirror_velocity_command expected last dim 3 or 4, got {last_dim}\"\n        )\n\n\nclass ObservationFunctions:\n    \"\"\"Atomic observation functions.\n\n    The most foundamental observation functions are defined here, aiming to\n    utize the convenient functions from isaaclab apis. For complex observation\n    composition patterns, we'll use the custom observation serizliazer.\n    \"\"\"\n\n    @staticmethod\n    def _get_body_indices(\n        robot: Articulation, keybody_names: list[str] | None\n    ) -> list[int]:\n        \"\"\"Convert body names to indices.\n\n        Args:\n            robot: Robot articulation asset\n            keybody_names: List of body names. 
If None, returns all body indices.\n\n        Returns:\n            List of body indices corresponding to the given names\n        \"\"\"\n        if keybody_names is None:\n            return list(range(robot.num_bodies))\n\n        body_indices = []\n        for name in keybody_names:\n            if name not in robot.body_names:\n                raise ValueError(\n                    f\"Body '{name}' not found in robot.body_names: {robot.body_names}\"\n                )\n            body_indices.append(robot.body_names.index(name))\n\n        return body_indices\n\n    @staticmethod\n    def _slice_future_frames(\n        tensor: torch.Tensor,\n        *,\n        num_frames: int | None,\n        obs_name: str,\n    ) -> torch.Tensor:\n        if num_frames is None:\n            return tensor\n        num_frames = int(num_frames)\n        if num_frames <= 0:\n            raise ValueError(\n                f\"{obs_name} num_frames must be positive, got {num_frames}.\"\n            )\n        if tensor.ndim < 2:\n            raise ValueError(\n                f\"{obs_name} expected future tensor with ndim >= 2, got {tensor.ndim}.\"\n            )\n        if int(tensor.shape[1]) < num_frames:\n            raise ValueError(\n                f\"{obs_name} requested {num_frames} future frames, but only \"\n                f\"{int(tensor.shape[1])} are available.\"\n            )\n        return tensor[:, :num_frames, ...]\n\n    # ------- Robot Head / mid360 States -------\n    @staticmethod\n    def _get_obs_head_pos_quat_vel(\n        env: ManagerBasedRLEnv, robot_asset_name: str = \"robot\"\n    ):\n        \"\"\"Head (mid360) features in torso frame with first-frame anchor.\n\n        Returns [B,13]: pos(3), quat_wxyz->xyzw(4), lin_vel(3), ang_vel(3), all in torso frame and anchored.\n        \"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        body_names = robot_ptr.body_names\n        if body_names is None:\n            raise RuntimeError(\"robot.body_names is empty\")\n        try:\n            torso_idx = body_names.index(\"torso_link\")\n        except ValueError:\n            raise ValueError(\n                f\"'torso_link' not found in body_names: {body_names}\"\n            )\n\n        B = env.num_envs\n        device = env.device\n        # Mid360 extrinsics relative to torso (rotation about Y by pitch)\n        rel_pos_t = torch.tensor(\n            [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device\n        )\n        pitch = torch.tensor(\n            0.04014257279586953, dtype=torch.float, device=device\n        )\n        half = pitch * 0.5\n        # WXYZ\n        rel_quat_wxyz = torch.stack(\n            [\n                torch.cos(half),\n                torch.zeros_like(half),\n                torch.sin(half),\n                torch.zeros_like(half),\n            ],\n            dim=-1,\n        )\n        rel_quat_wxyz = rel_quat_wxyz.expand(B, -1)\n\n        # World pose/vel from torso + extrinsics (WXYZ math)\n        torso_pos_w = robot_ptr.data.body_pos_w[:, torso_idx, :]\n        torso_quat_wxyz = robot_ptr.data.body_quat_w[:, torso_idx, :]\n        torso_lin_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :]\n        torso_ang_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]\n\n        rel_pos = rel_pos_t.expand(B, -1)\n        r_world = isaaclab_math.quat_apply(torso_quat_wxyz, rel_pos)\n        pos_w = torso_pos_w + r_world\n        quat_wxyz = isaaclab_math.quat_mul(torso_quat_wxyz, rel_quat_wxyz)\n        lin_w = torso_lin_w + 
torch.cross(torso_ang_w, r_world, dim=-1)\n        ang_w = torso_ang_w\n\n        # Convert to torso frame (WXYZ math)\n        rel_p = pos_w - torso_pos_w\n        torso_inv_wxyz = isaaclab_math.quat_inv(torso_quat_wxyz)\n        pos_torso = isaaclab_math.quat_apply(torso_inv_wxyz, rel_p)\n        lin_torso = isaaclab_math.quat_apply(\n            torso_inv_wxyz, lin_w - torch.cross(ang_w, rel_p, dim=-1)\n        )\n        ang_torso = isaaclab_math.quat_apply(torso_inv_wxyz, ang_w)\n        quat_torso_wxyz = isaaclab_math.quat_mul(torso_inv_wxyz, quat_wxyz)\n        # export quaternion as XYZW to match common obs format\n        quat_torso_xyzw = quat_torso_wxyz[..., [1, 2, 3, 0]]\n\n        # First-frame anchor normalization (in torso frame)\n        if not hasattr(env, \"head_anchor_set\"):\n            env.head_anchor_set = torch.zeros(\n                B, dtype=torch.bool, device=device\n            )\n            env.head_anchor_pos = torch.zeros(B, 3, device=device)\n            env.head_anchor_quat_wxyz = torch.zeros(B, 4, device=device)\n            env.head_anchor_quat_wxyz[:, 0] = 1.0  # identity W\n        unset = ~env.head_anchor_set\n        if unset.any():\n            env.head_anchor_pos[unset] = pos_torso[unset]\n            env.head_anchor_quat_wxyz[unset] = quat_torso_wxyz[unset]\n            env.head_anchor_set[unset] = True\n        q0_inv = isaaclab_math.quat_inv(env.head_anchor_quat_wxyz)\n        pos_rel = isaaclab_math.quat_apply(\n            q0_inv, pos_torso - env.head_anchor_pos\n        )\n        lin_rel = isaaclab_math.quat_apply(q0_inv, lin_torso)\n        ang_rel = isaaclab_math.quat_apply(q0_inv, ang_torso)\n        quat_rel_wxyz = isaaclab_math.quat_mul(q0_inv, quat_torso_wxyz)\n        quat_rel_xyzw = quat_rel_wxyz[..., [1, 2, 3, 0]]\n        return torch.cat([pos_rel, quat_rel_xyzw, lin_rel, ang_rel], dim=-1)\n\n    @staticmethod\n    def _get_obs_rel_headlink_lin_vel(\n        env: ManagerBasedRLEnv, robot_asset_name: str = \"robot\"\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Headlink relative linear velocity, expressed in the headlink's frame.\n\n        Definitions:\n        - Headlink: a virtual rigid sensor frame fixed to `torso_link` using the\n          extrinsics defined below (translation `rel_pos_t` and rotation `rel_quat_wxyz`).\n        - Relative linear velocity: v_head - v_torso_origin, both measured in the world\n          frame before re-expression. 
For a rigid mount, this equals ω_torso × r_world.\n        - Expression frame: the instantaneous headlink frame (i.e., result is in headlink axes).\n\n        Returns:\n            Tensor of shape [num_envs, 3]: headlink relative linear velocity in headlink frame.\n        \"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        body_names = robot_ptr.body_names\n        if body_names is None:\n            raise RuntimeError(\"robot.body_names is empty\")\n        torso_idx = body_names.index(\"torso_link\")\n\n        num_envs = env.num_envs\n        device = env.device\n        # Headlink extrinsics relative to torso: translation + rotation about Y (pitch)\n        rel_pos_t = torch.tensor(\n            [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device\n        )  # [3]\n        pitch = torch.tensor(\n            0.04014257279586953, dtype=torch.float, device=device\n        )\n        half = pitch * 0.5\n        # Quaternion (WXYZ) for rotation about Y by 'pitch'\n        rel_quat_wxyz = torch.stack(\n            [\n                torch.cos(half),\n                torch.zeros_like(half),\n                torch.sin(half),\n                torch.zeros_like(half),\n            ],\n            dim=-1,\n        ).expand(num_envs, -1)  # [num_envs, 4]\n\n        # Torso world state\n        torso_quat_wxyz = robot_ptr.data.body_quat_w[\n            :, torso_idx, :\n        ]  # [num_envs, 4]\n        torso_lin_w = robot_ptr.data.body_lin_vel_w[\n            :, torso_idx, :\n        ]  # [num_envs, 3]\n        torso_ang_w = robot_ptr.data.body_ang_vel_w[\n            :, torso_idx, :\n        ]  # [num_envs, 3]\n\n        # Headlink world pose from torso + extrinsics\n        rel_pos = rel_pos_t.expand(num_envs, -1)  # [num_envs, 3]\n        r_world = isaaclab_math.quat_apply(\n            torso_quat_wxyz, rel_pos\n        )  # [num_envs, 3]\n        head_quat_wxyz = isaaclab_math.quat_mul(\n            torso_quat_wxyz, rel_quat_wxyz\n        )  # [num_envs, 4]\n\n        # World-frame velocities\n        head_lin_w = torso_lin_w + torch.cross(\n            torso_ang_w, r_world, dim=-1\n        )  # [num_envs, 3]\n        # Relative linear velocity in world frame\n        rel_lin_w = (\n            head_lin_w - torso_lin_w\n        )  # [num_envs, 3] == ω_torso × r_world\n\n        # Re-express in headlink frame\n        head_inv_wxyz = isaaclab_math.quat_inv(head_quat_wxyz)  # [num_envs, 4]\n        rel_lin_head = isaaclab_math.quat_apply(\n            head_inv_wxyz, rel_lin_w\n        )  # [num_envs, 3]\n        return rel_lin_head\n\n    @staticmethod\n    def _get_obs_rel_headlink_ang_vel(\n        env: ManagerBasedRLEnv, robot_asset_name: str = \"robot\"\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Headlink relative angular velocity, expressed in the headlink's frame.\n\n        Definitions:\n        - Headlink: a virtual rigid sensor frame fixed to `torso_link` using the\n          extrinsics defined below.\n        - Relative angular velocity: ω_head - ω_torso, measured in the world frame,\n          then re-expressed in the headlink frame.\n        - For a rigid mount (no neck articulation), ω_head == ω_torso, so the result\n          is identically zero. 
If an articulated head exists, replace ω_head with the\n          head link's world angular velocity before the subtraction.\n\n        Returns:\n            Tensor of shape [num_envs, 3]: headlink relative angular velocity in headlink frame.\n        \"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        body_names = robot_ptr.body_names\n        if body_names is None:\n            raise RuntimeError(\"robot.body_names is empty\")\n        torso_idx = body_names.index(\"torso_link\")\n\n        num_envs = env.num_envs\n        device = env.device\n        # Headlink extrinsics (rotation about Y by pitch)\n        pitch = torch.tensor(\n            0.04014257279586953, dtype=torch.float, device=device\n        )\n        half = pitch * 0.5\n        rel_quat_wxyz = torch.stack(\n            [\n                torch.cos(half),\n                torch.zeros_like(half),\n                torch.sin(half),\n                torch.zeros_like(half),\n            ],\n            dim=-1,\n        ).expand(num_envs, -1)  # [num_envs, 4]\n\n        torso_quat_wxyz = robot_ptr.data.body_quat_w[\n            :, torso_idx, :\n        ]  # [num_envs, 4]\n        torso_ang_w = robot_ptr.data.body_ang_vel_w[\n            :, torso_idx, :\n        ]  # [num_envs, 3]\n\n        # For the rigid mount, ω_head_w == ω_torso_w\n        head_ang_w = torso_ang_w  # [num_envs, 3]\n        rel_ang_w = (\n            head_ang_w - torso_ang_w\n        )  # [num_envs, 3] -> zeros for rigid mount\n\n        # Re-express in headlink frame\n        head_quat_wxyz = isaaclab_math.quat_mul(\n            torso_quat_wxyz, rel_quat_wxyz\n        )  # [num_envs, 4]\n        head_inv_wxyz = isaaclab_math.quat_inv(head_quat_wxyz)  # [num_envs, 4]\n        rel_ang_head = isaaclab_math.quat_apply(\n            head_inv_wxyz, rel_ang_w\n        )  # [num_envs, 3]\n        return rel_ang_head\n\n    # ------- Robot Root States -------\n    @staticmethod\n    def _get_obs_global_robot_root_pos(env: ManagerBasedRLEnv):\n        \"\"\"Asset root position in the environment frame.\n\n        IsaacLab's root position helpers subtract `env.scene.env_origins`, so\n        this is not the raw simulator-world position.\n        \"\"\"\n        return isaaclab_mdp.root_pos_w(env)\n\n    @staticmethod\n    def _get_obs_global_robot_root_rot_wxyz(env: ManagerBasedRLEnv):\n        \"\"\"Asset root orientation (w, x, y, z) in the environment frame.\"\"\"\n        return isaaclab_mdp.root_quat_w(env)\n\n    @staticmethod\n    def _get_obs_global_robot_root_rot_xyzw(env: ManagerBasedRLEnv):\n        \"\"\"Asset root orientation (x, y, z, w) in the environment frame.\"\"\"\n        return ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)[\n            ..., [1, 2, 3, 0]\n        ]\n\n    @staticmethod\n    def _get_obs_global_robot_root_rot_mat(env: ManagerBasedRLEnv):\n        \"\"\"Asset root orientation as the first two columns of the 3x3 rotation matrix (6D representation).\"\"\"\n        return isaaclab_math.matrix_from_quat(\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )[..., :2]  # [num_envs, 3, 2]\n\n    @staticmethod\n    def _get_obs_global_robot_root_lin_vel(env: ManagerBasedRLEnv):\n        \"\"\"Asset root linear velocity in the environment frame.\"\"\"\n        return isaaclab_mdp.root_lin_vel_w(env)  # [num_envs, 3]\n\n    @staticmethod\n    def _get_obs_global_robot_root_ang_vel(env: ManagerBasedRLEnv):\n        \"\"\"Asset root angular velocity in the environment frame.\"\"\"\n        return isaaclab_mdp.root_ang_vel_w(env)  # [num_envs, 3]\n\n    @staticmethod\n    def _get_obs_rel_robot_root_lin_vel(env: ManagerBasedRLEnv):\n        \"\"\"Relative root linear velocity in the root frame.\"\"\"\n        return isaaclab_mdp.base_lin_vel(env)  # [num_envs, 3]\n\n    @staticmethod\n    def _get_obs_rel_robot_root_ang_vel(env: ManagerBasedRLEnv):\n        \"\"\"Relative root angular velocity in the root frame.\"\"\"\n        return isaaclab_mdp.base_ang_vel(env)  # [num_envs, 3]\n\n    @staticmethod\n    def _get_obs_rel_anchor_lin_vel(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        anchor_bodylink_name: str = \"torso_link\",\n    ):\n        \"\"\"Relative anchor linear velocity in the anchor frame.\"\"\"\n        torso_global_rot_quat_wxyz = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, [anchor_bodylink_name]\n            )\n        )  # [num_envs, 1, 4]\n        torso_global_lin_vel = (\n            ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(\n                env, robot_asset_name, [anchor_bodylink_name]\n            )\n        )  # [num_envs, 1, 3]\n        torso_rel_lin_vel = isaaclab_math.quat_apply(\n            isaaclab_math.quat_inv(torso_global_rot_quat_wxyz),\n            torso_global_lin_vel,\n        )  # [num_envs, 1, 3]\n        return torso_rel_lin_vel.squeeze(1)  # [num_envs, 3]\n\n    @staticmethod\n    def _get_obs_projected_gravity(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Gravity vector projected into the robot's root frame.\n\n        Projects the world-frame gravity vector into the robot's base frame\n        using the inverse root orientation quaternion.\n        \"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        g_w: torch.Tensor = robot_ptr.data.GRAVITY_VEC_W  # [num_envs, 3]\n        root_quat_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )  # [num_envs, 4]\n\n        # Project gravity into root frame using inverse quaternion\n        projected_gravity: torch.Tensor = isaaclab_math.quat_apply_inverse(\n            root_quat_wxyz, g_w\n        )  # [num_envs, 3]\n\n        return projected_gravity\n\n    @staticmethod\n    def _get_obs_global_robot_root_yaw(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n    ):\n        \"\"\"Robot's yaw heading in the environment frame (in radians).\"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        return robot_ptr.data.heading_w  # [num_envs, ]\n\n    # @torch.compile\n    @staticmethod\n    def _get_obs_robot_root_heading_aligned_quat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n    ):\n        \"\"\"A quaternion representing only the robot's yaw heading.\"\"\"\n        global_yaw = ObservationFunctions._get_obs_global_robot_root_yaw(\n            env,\n            robot_asset_name,\n        )  # [num_envs, ]\n        zero_roll = torch.zeros_like(global_yaw, device=env.device)\n        zero_pitch = torch.zeros_like(global_yaw, device=env.device)\n        heading_aligned_quat = isaaclab_math.quat_from_euler_xyz(\n            roll=zero_roll,\n            pitch=zero_pitch,\n            yaw=global_yaw,\n        )  # [num_envs, 4]\n        return heading_aligned_quat  # [num_envs, 4]\n\n    # @torch.compile\n    @staticmethod\n    def _get_obs_rel_robot_root_roll_pitch(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n    ):\n        \"\"\"Robot's roll and pitch relative to its heading-aligned frame.\"\"\"\n        heading_aligned_quat = (\n            ObservationFunctions._get_obs_robot_root_heading_aligned_quat(\n                env,\n                robot_asset_name,\n            )\n        )  # [num_envs, 4]\n        robot_quat_in_heading_aligned_frame = isaaclab_math.quat_mul(\n            isaaclab_math.quat_inv(heading_aligned_quat),\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env),\n        )  # [num_envs, 4]\n        rel_roll, rel_pitch, _ = isaaclab_math.get_euler_xyz(\n            robot_quat_in_heading_aligned_frame\n        )  # 3-tuple of [num_envs]\n        return torch.stack([rel_roll, rel_pitch], dim=-1)  # [num_envs, 2]\n\n    # ------- Robot Bodylink States -------\n    @staticmethod\n    def _get_obs_global_robot_bodylink_pos(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Positions of specified bodylinks in the environment frame.\n\n        Body link poses are stored in simulator-world coordinates, so this\n        helper subtracts `env.scene.env_origins` to match IsaacLab's\n        environment-frame root helpers.\n        \"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        keybody_global_pos = positions_world_to_env_frame(\n            robot_ptr.data.body_pos_w[:, keybody_idxs],\n            env.scene.env_origins,\n        )\n        return keybody_global_pos  # [num_envs, num_keybodies, 3]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_wxyz(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Orientations (w, x, y, z) of specified bodylinks in the environment frame.\"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        keybody_global_rot = robot_ptr.data.body_quat_w[:, keybody_idxs]\n        return keybody_global_rot  # [num_envs, num_keybodies, 4]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_xyzw(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Orientations (x, y, z, w) of specified bodylinks in the environment frame.\"\"\"\n        return ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n            env,\n            robot_asset_name,\n            keybody_names,\n        )[..., [1, 2, 3, 0]]  # [num_envs, num_keybodies, 4]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_mat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Orientations of specified bodylinks as the first two columns of the 3x3 rotation matrix (6D representation).\"\"\"\n        keybody_global_rot_wxyz = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env,\n                robot_asset_name,\n                keybody_names,\n            )\n        )\n        return isaaclab_math.matrix_from_quat(keybody_global_rot_wxyz)[\n            ..., :2\n        ]  # [num_envs, num_keybodies, 3, 2]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_lin_vel(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Linear velocities of specified bodylinks in the environment frame.\"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        keybody_global_lin_vel = robot_ptr.data.body_lin_vel_w[:, keybody_idxs]\n        return keybody_global_lin_vel  # [num_envs, num_keybodies, 3]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_ang_vel(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ):\n        \"\"\"Angular velocities of specified bodylinks in the environment frame.\"\"\"\n        robot_ptr = env.scene[robot_asset_name]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        keybody_global_ang_vel = robot_ptr.data.body_ang_vel_w[:, keybody_idxs]\n        return keybody_global_ang_vel  # [num_envs, num_keybodies, 3]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_pos(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 3]\n        \"\"\"Root-relative bodylink positions from environment-frame positions.\"\"\"\n        # Get global states\n        keybody_global_pos: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_bodylink_pos(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n\n        global_root_pos: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_pos(env)\n        )  # [num_envs, 3]\n        root_global_rot_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )  # [num_envs, 4]\n\n        return root_relative_positions_from_env_frame(\n            body_pos_env=keybody_global_pos,\n            root_pos_env=global_root_pos,\n            root_quat_w=root_global_rot_wxyz,\n        )\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_wxyz(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 4]\n        \"\"\"Orientations (w, x, y, z) of specified bodylinks relative to the robot's root frame.\"\"\"\n        # Get global states\n        keybody_global_rot: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n\n        root_global_rot_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )  # [num_envs, 4]\n\n        # Transform to root frame by multiplying with inverse root rotation\n        root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(\n            root_global_rot_wxyz\n        )  # [num_envs, 4]\n        num_bodies = keybody_global_rot.shape[1]\n        rel_rot_root: torch.Tensor = isaaclab_math.quat_mul(\n            root_inv_rot[..., None, :].expand(-1, num_bodies, -1),\n            keybody_global_rot,\n        )  # [num_envs, num_keybodies, 4]\n\n        return rel_rot_root\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_xyzw(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 4]\n        \"\"\"Orientations (x, y, z, w) of specified bodylinks relative to the robot's root frame.\"\"\"\n        return ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(\n            env, robot_asset_name, keybody_names\n        )[\n            ..., [1, 2, 3, 0]\n        ]  # [num_envs, num_keybodies, 4] - convert WXYZ to XYZW\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_mat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 3, 2]\n        \"\"\"Orientations of specified bodylinks relative to the robot's root frame, as the first two columns of the 3x3 rotation matrix (6D representation).\"\"\"\n        keybody_rel_rot_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n\n        return isaaclab_math.matrix_from_quat(keybody_rel_rot_wxyz)[\n            ..., :2\n        ]  # [num_envs, num_keybodies, 3, 2]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_lin_vel(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 3]\n        \"\"\"Linear velocities of specified bodylinks relative to the robot's root frame.\"\"\"\n        # Get global states\n        keybody_global_lin_vel: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        root_global_lin_vel: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_lin_vel(env)\n        )  # [num_envs, 3]\n        root_global_rot_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )  # [num_envs, 4]\n\n        # Compute relative velocity in world frame\n        rel_lin_vel_w = keybody_global_lin_vel - root_global_lin_vel.unsqueeze(\n            1\n        )\n\n        # Transform to root frame by rotating with inverse root rotation\n        root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(\n            root_global_rot_wxyz\n        )  # [num_envs, 4]\n        rel_lin_vel_root: torch.Tensor = isaaclab_math.quat_apply(\n            root_inv_rot.unsqueeze(1), rel_lin_vel_w\n        )  # [num_envs, num_keybodies, 3]\n\n        return rel_lin_vel_root\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_ang_vel(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies, 3]\n        \"\"\"Angular velocities of specified bodylinks relative to the robot's root frame.\"\"\"\n        # Get global states\n        keybody_global_ang_vel: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_bodylink_ang_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        root_global_ang_vel: 
torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_ang_vel(env)\n        )  # [num_envs, 3]\n        root_global_rot_wxyz: torch.Tensor = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )  # [num_envs, 4]\n\n        # Compute relative angular velocity in world frame\n        rel_ang_vel_w = keybody_global_ang_vel - root_global_ang_vel.unsqueeze(\n            1\n        )\n\n        # Transform to root frame by rotating with inverse root rotation\n        root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(\n            root_global_rot_wxyz\n        )  # [num_envs, 4]\n        rel_ang_vel_root: torch.Tensor = isaaclab_math.quat_apply(\n            root_inv_rot.unsqueeze(1), rel_ang_vel_w\n        )  # [num_envs, num_keybodies, 3]\n\n        return rel_ang_vel_root\n\n    # ------- Flat Bodylink Observations -------\n    @staticmethod\n    def _get_obs_global_robot_bodylink_pos_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened positions of specified bodylinks in the environment frame.\"\"\"\n        bodylink_pos = ObservationFunctions._get_obs_global_robot_bodylink_pos(\n            env, robot_asset_name, keybody_names\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_pos.reshape(\n            bodylink_pos.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_wxyz_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 4]\n        \"\"\"Flattened orientations (w, x, y, z) of specified bodylinks in the environment frame.\"\"\"\n        bodylink_rot = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n        return bodylink_rot.reshape(\n            bodylink_rot.shape[0], -1\n        )  # [num_envs, num_keybodies * 4]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_xyzw_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 4]\n        \"\"\"Flattened orientations (x, y, z, w) of specified bodylinks in the environment frame.\"\"\"\n        bodylink_rot = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_xyzw(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n        return bodylink_rot.reshape(\n            bodylink_rot.shape[0], -1\n        )  # [num_envs, num_keybodies * 4]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_rot_mat_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 6]\n        \"\"\"Flattened orientation matrices (6D) of specified bodylinks in the environment frame.\"\"\"\n        bodylink_rot_mat = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_mat(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 6]\n        return bodylink_rot_mat.reshape(\n      
      bodylink_rot_mat.shape[0], -1\n        )  # [num_envs, num_keybodies * 6]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_lin_vel_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened linear velocities of specified bodylinks in the environment frame.\"\"\"\n        bodylink_lin_vel = (\n            ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_lin_vel.reshape(\n            bodylink_lin_vel.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    @staticmethod\n    def _get_obs_global_robot_bodylink_ang_vel_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened angular velocities of specified bodylinks in the environment frame.\"\"\"\n        bodylink_ang_vel = (\n            ObservationFunctions._get_obs_global_robot_bodylink_ang_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_ang_vel.reshape(\n            bodylink_ang_vel.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_pos_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened positions of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_pos = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_pos(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_pos.reshape(\n            bodylink_pos.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_wxyz_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 4]\n        \"\"\"Flattened orientations (w, x, y, z) of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_rot = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n        return bodylink_rot.reshape(\n            bodylink_rot.shape[0], -1\n        )  # [num_envs, num_keybodies * 4]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_xyzw_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 4]\n        \"\"\"Flattened orientations (x, y, z, w) of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_rot = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_xyzw(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 4]\n        return bodylink_rot.reshape(\n            bodylink_rot.shape[0], -1\n  
      )  # [num_envs, num_keybodies * 4]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_rot_mat_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 6]\n        \"\"\"Flattened orientation matrices (6D) of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_rot_mat = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_mat(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 6]\n        return bodylink_rot_mat.reshape(\n            bodylink_rot_mat.shape[0], -1\n        )  # [num_envs, num_keybodies * 6]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_lin_vel_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened linear velocities of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_lin_vel = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_lin_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_lin_vel.reshape(\n            bodylink_lin_vel.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    @staticmethod\n    def _get_obs_root_rel_robot_bodylink_ang_vel_flat(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_keybodies * 3]\n        \"\"\"Flattened angular velocities of specified bodylinks relative to the robot's root frame.\"\"\"\n        bodylink_ang_vel = (\n            ObservationFunctions._get_obs_root_rel_robot_bodylink_ang_vel(\n                env, robot_asset_name, keybody_names\n            )\n        )  # [num_envs, num_keybodies, 3]\n        return bodylink_ang_vel.reshape(\n            bodylink_ang_vel.shape[0], -1\n        )  # [num_envs, num_keybodies * 3]\n\n    # ------- Robot DoF States -------\n    @staticmethod\n    def _get_obs_dof_pos(env: ManagerBasedRLEnv):\n        \"\"\"Joint positions relative to the default joint angles.\"\"\"\n        return isaaclab_mdp.joint_pos_rel(env)  # [num_envs, num_dofs]\n\n    @staticmethod\n    def _get_obs_dof_vel(env: ManagerBasedRLEnv):\n        \"\"\"Joint velocities.\"\"\"\n        return isaaclab_mdp.joint_vel_rel(env)  # [num_envs, num_dofs]\n\n    @staticmethod\n    def _get_obs_last_actions(env: ManagerBasedRLEnv):\n        \"\"\"Last action output by the policy.\"\"\"\n        return isaaclab_mdp.last_action(env)  # [num_envs, num_actions]\n\n    # ------- Reference Motion States -------\n    @staticmethod\n    def _get_obs_ref_motion_states(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        \"\"\"Reference motion states (flattened) via RefMotionCommand schema.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        obs_fn_name = f\"_get_obs_{command.cfg.command_obs_name}\"\n        obs_fn = getattr(command, obs_fn_name)\n        return obs_fn(obs_prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_ref_motion_states_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n 
       ref_prefix: str = \"ref_\",\n    ):\n        \"\"\"Future reference motion states (flattened).\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        obs_fn_name = f\"_get_obs_{command.cfg.command_obs_name}_fut\"\n        obs_fn = getattr(command, obs_fn_name)\n        return obs_fn(obs_prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_vr_ref_motion_states(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command._get_obs_vr_ref_motion_states(obs_prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_vr_ref_motion_states_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        \"\"\"Future reference motion states (flattened).\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command._get_obs_vr_ref_motion_fut(obs_prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_ref_dof_pos_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, num_dofs]\n        \"\"\"Reference current DoF positions in simulator DoF order.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_dof_pos_cur(prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_immediate_next_two_dof_pos(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 2 * num_dofs]\n        \"\"\"Immediate next two DoF positions in simulator DoF order.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_immediate_next_two_dof_pos(prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_ref_motion_cur_heading_aligned_root_pos(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference current heading-aligned root position.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_cur_heading_aligned_root_pos(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_fut_heading_aligned_root_pos(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference heading-aligned root position.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_fut_heading_aligned_root_pos(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_cur_heading_aligned_root_rot6d(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 6]\n        \"\"\"Reference current heading-aligned root rotation (rot6d).\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_cur_heading_aligned_root_rot6d(\n            prefix=ref_prefix\n        )\n\n    
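# NOTE (added): \"rot6d\" here denotes the continuous 6D rotation\n    # representation, i.e. the first two columns of the 3x3 rotation matrix,\n    # matching the matrix_from_quat(...)[..., :2] slices used by the\n    # *_rot_mat observations above. The layout emitted by RefMotionCommand is\n    # assumed to follow the same convention.\n\n    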
@staticmethod\n    def _get_obs_ref_motion_fut_heading_aligned_root_rot6d(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, T, 6]\n        \"\"\"Future reference heading-aligned root rotation (rot6d).\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_fut_heading_aligned_root_rot6d(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_cur_heading_aligned_root_lin_vel(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference current heading-aligned root linear velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_cur_heading_aligned_root_lin_vel(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_fut_heading_aligned_root_lin_vel(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference heading-aligned root linear velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_fut_heading_aligned_root_lin_vel(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_cur_heading_aligned_root_ang_vel(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference current heading-aligned root angular velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_cur_heading_aligned_root_ang_vel(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_motion_fut_heading_aligned_root_ang_vel(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference heading-aligned root angular velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_fut_heading_aligned_root_ang_vel(\n            prefix=ref_prefix\n        )\n\n    @staticmethod\n    def _get_obs_ref_dof_vel_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, num_dofs]\n        \"\"\"Reference current DoF velocities in simulator DoF order.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_dof_vel_cur(prefix=ref_prefix)\n\n    @staticmethod\n    def _get_obs_ref_motion_filter_cutoff_hz(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n    ) -> torch.Tensor:\n        \"\"\"Return clip-level filter metadata; this is prefix-independent.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        return command.get_ref_motion_filter_cutoff_hz_cur()\n\n    @staticmethod\n    def _get_obs_ref_root_height_cur(\n        env: ManagerBasedRLEnv,\n        
ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 1]\n        \"\"\"Reference current root height: world z minus env-origin z.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        world_pos = command.get_ref_motion_root_global_pos_cur(\n            prefix=ref_prefix\n        )  # [B, 3]\n        height = (world_pos[..., 2] - env.scene.env_origins[..., 2]).unsqueeze(\n            -1\n        )  # [B,1]\n        return height\n\n    @staticmethod\n    def _get_obs_ref_dof_pos_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, n_fut_frames, num_dofs]\n        \"\"\"Future reference DoF positions per future frame, in simulator DoF order.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        dof_pos_fut = command.get_ref_motion_dof_pos_fut(\n            prefix=ref_prefix\n        )  # [B, T, D(sim)]\n        dof_pos_fut = ObservationFunctions._slice_future_frames(\n            dof_pos_fut,\n            num_frames=num_frames,\n            obs_name=\"ref_dof_pos_fut\",\n        )\n        return dof_pos_fut\n\n    @staticmethod\n    def _get_obs_ref_gravity_projection_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference gravity projection.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        gravity_projection = command.get_ref_motion_gravity_projection_cur(\n            prefix=ref_prefix\n        )\n        return gravity_projection\n\n    @staticmethod\n    def _get_obs_ref_gravity_projection_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference gravity projection.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        gravity_projection = command.get_ref_motion_gravity_projection_fut(\n            prefix=ref_prefix\n        )\n        gravity_projection = ObservationFunctions._slice_future_frames(\n            gravity_projection,\n            num_frames=num_frames,\n            obs_name=\"ref_gravity_projection_fut\",\n        )\n        return gravity_projection\n\n    @staticmethod\n    def _get_obs_ref_base_linvel_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference base linear velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        base_linvel = command.get_ref_motion_base_linvel_cur(prefix=ref_prefix)\n        return base_linvel\n\n    @staticmethod\n    def _get_obs_ref_base_linvel_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference base linear velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        base_linvel = command.get_ref_motion_base_linvel_fut(prefix=ref_prefix)\n        base_linvel = ObservationFunctions._slice_future_frames(\n            base_linvel,\n            num_frames=num_frames,\n            obs_name=\"ref_base_linvel_fut\",\n        )\n        return base_linvel\n\n    @staticmethod\n    def _get_obs_ref_base_angvel_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ) -> torch.Tensor:  # [num_envs, 3]\n        \"\"\"Reference base angular velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        base_angvel = command.get_ref_motion_base_angvel_cur(prefix=ref_prefix)\n        return base_angvel\n\n    @staticmethod\n    def _get_obs_ref_keybody_rel_pos_cur(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        keybody_names: list[str] | None = None,\n    ) -> torch.Tensor:  # [num_envs, num_bodies, 3] or [num_envs, K * 3]\n        \"\"\"Reference keybody root-relative positions.\n\n        Returns all bodies as [num_envs, num_bodies, 3] when keybody_names is\n        None; otherwise the selected bodies flattened to [num_envs, K * 3].\n        \"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        ref_keybody_rel_pos = command.get_ref_motion_bodylink_rel_pos_cur(\n            prefix=ref_prefix\n        )  # [B, N, 3]\n        if keybody_names is None:\n            return ref_keybody_rel_pos\n        robot_ptr = env.scene[\"robot\"]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        kb_rel_pos = ref_keybody_rel_pos[:, keybody_idxs, :]\n        bs = kb_rel_pos.shape[0]\n        return kb_rel_pos.reshape(bs, -1)\n\n    @staticmethod\n    def _get_obs_ref_keybody_rel_pos_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        keybody_names: list[str] | None = None,\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, T, num_bodies, 3] or [num_envs, T, K * 3]\n        \"\"\"Future reference keybody root-relative positions.\n\n        Returns all bodies as [num_envs, T, num_bodies, 3] when keybody_names\n        is None; otherwise the selected bodies flattened to [num_envs, T, K * 3].\n        \"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        ref_keybody_rel_pos_fut = command.get_ref_motion_bodylink_rel_pos_fut(\n            prefix=ref_prefix\n        )  # [B, T, N, 3]\n        ref_keybody_rel_pos_fut = ObservationFunctions._slice_future_frames(\n            ref_keybody_rel_pos_fut,\n            num_frames=num_frames,\n            obs_name=\"ref_keybody_rel_pos_fut\",\n        )\n        if keybody_names is None:\n            return ref_keybody_rel_pos_fut\n        robot_ptr = env.scene[\"robot\"]\n        keybody_idxs = ObservationFunctions._get_body_indices(\n            robot_ptr, keybody_names\n        )\n        kb_rel_pos_fut = ref_keybody_rel_pos_fut[:, :, keybody_idxs, :]\n        bs, t, _, _ = kb_rel_pos_fut.shape\n        return kb_rel_pos_fut.reshape(bs, t, -1)\n\n    @staticmethod\n    def _get_obs_ref_base_angvel_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, T, 3]\n        \"\"\"Future reference base angular velocity.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        base_angvel = command.get_ref_motion_base_angvel_fut(prefix=ref_prefix)\n        base_angvel = ObservationFunctions._slice_future_frames(\n            base_angvel,\n            num_frames=num_frames,\n            obs_name=\"ref_base_angvel_fut\",\n        )\n        return base_angvel\n\n    
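# NOTE (added): all *_fut observations share the _slice_future_frames\n    # contract: num_frames=None keeps the command's full future horizon,\n    # while a positive num_frames truncates to the first num_frames frames\n    # and raises ValueError if more frames are requested than available.\n\n    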
@staticmethod\n    def _get_obs_ref_dof_vel_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, n_fut_frames * num_dofs]\n        \"\"\"Future reference DoF velocities (flattened over time) in simulator DoF order.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        dof_vel_fut = command.get_ref_motion_dof_vel_fut(\n            prefix=ref_prefix\n        )  # [B, T, D(sim)]\n        dof_vel_fut = ObservationFunctions._slice_future_frames(\n            dof_vel_fut,\n            num_frames=num_frames,\n            obs_name=\"ref_dof_vel_fut\",\n        )\n        B, T, D = dof_vel_fut.shape\n        return dof_vel_fut.reshape(B, T * D)\n\n    @staticmethod\n    def _get_obs_ref_root_height_fut(\n        env: ManagerBasedRLEnv,\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n        num_frames: int | None = None,\n    ) -> torch.Tensor:  # [num_envs, n_fut_frames]\n        \"\"\"Future reference root heights per frame: world z minus env-origin z.\"\"\"\n        command = env.command_manager.get_term(ref_motion_command_name)\n        world_pos = command.get_ref_motion_root_global_pos_fut(\n            prefix=ref_prefix\n        )  # [B, T, 3]\n        world_pos = ObservationFunctions._slice_future_frames(\n            world_pos,\n            num_frames=num_frames,\n            obs_name=\"ref_root_height_fut\",\n        )\n        heights = (\n            world_pos[..., 2] - env.scene.env_origins[:, None, 2]\n        )  # [B, T]\n        return heights[..., None]\n\n    # @torch.compile\n    @staticmethod\n    def _get_obs_global_anchor_diff(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        command = env.command_manager.get_term(ref_motion_command_name)\n        env_ref_motion_anchor_pos = positions_world_to_env_frame(\n            command.get_ref_motion_anchor_bodylink_global_pos_cur(\n                prefix=ref_prefix\n            ),\n            env.scene.env_origins,\n        )\n        global_ref_motino_anchor_rot_wxyz = (\n            command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(\n                prefix=ref_prefix\n            )\n        )\n        global_robot_anchor_pos = (\n            ObservationFunctions._get_obs_global_robot_bodylink_pos(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        global_robot_anchor_rot_wxyz = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        pos_diff, rot_diff = isaaclab_math.subtract_frame_transforms(\n            t01=global_robot_anchor_pos,\n            q01=global_robot_anchor_rot_wxyz,\n            t02=env_ref_motion_anchor_pos,\n            q02=global_ref_motino_anchor_rot_wxyz,\n        )\n        rot_diff_mat = isaaclab_math.matrix_from_quat(rot_diff)\n        return torch.cat(\n            [\n                pos_diff,\n                rot_diff_mat[..., :2].reshape(env.num_envs, -1),\n            ],\n            dim=-1,\n        )  # [num_envs, 9]\n\n    @staticmethod\n    def _get_obs_global_anchor_pos_diff(\n        env: ManagerBasedRLEnv,\n        
robot_asset_name: str = \"robot\",\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        command = env.command_manager.get_term(ref_motion_command_name)\n        env_ref_motion_anchor_pos = positions_world_to_env_frame(\n            command.get_ref_motion_anchor_bodylink_global_pos_cur(\n                prefix=ref_prefix\n            ),\n            env.scene.env_origins,\n        )\n        global_ref_motino_anchor_rot_wxyz = (\n            command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(\n                prefix=ref_prefix\n            )\n        )\n        global_robot_anchor_pos = (\n            ObservationFunctions._get_obs_global_robot_bodylink_pos(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        global_robot_anchor_rot_wxyz = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        pos_diff, _ = isaaclab_math.subtract_frame_transforms(\n            t01=global_robot_anchor_pos,\n            q01=global_robot_anchor_rot_wxyz,\n            t02=env_ref_motion_anchor_pos,\n            q02=global_ref_motino_anchor_rot_wxyz,\n        )\n        return pos_diff\n\n    @staticmethod\n    def _get_obs_global_anchor_rot_diff(\n        env: ManagerBasedRLEnv,\n        robot_asset_name: str = \"robot\",\n        ref_motion_command_name: str = \"ref_motion\",\n        ref_prefix: str = \"ref_\",\n    ):\n        command = env.command_manager.get_term(ref_motion_command_name)\n        env_ref_motion_anchor_pos = positions_world_to_env_frame(\n            command.get_ref_motion_anchor_bodylink_global_pos_cur(\n                prefix=ref_prefix\n            ),\n            env.scene.env_origins,\n        )\n        global_ref_motino_anchor_rot_wxyz = (\n            command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(\n                prefix=ref_prefix\n            )\n        )\n        global_robot_anchor_pos = (\n            ObservationFunctions._get_obs_global_robot_bodylink_pos(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        global_robot_anchor_rot_wxyz = (\n            ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(\n                env, robot_asset_name, [command.anchor_bodylink_name]\n            ).squeeze(1)\n        )\n        _, rot_diff = isaaclab_math.subtract_frame_transforms(\n            t01=global_robot_anchor_pos,\n            q01=global_robot_anchor_rot_wxyz,\n            t02=env_ref_motion_anchor_pos,\n            q02=global_ref_motino_anchor_rot_wxyz,\n        )\n        rot_diff_mat = isaaclab_math.matrix_from_quat(rot_diff)\n        return rot_diff_mat[..., :2].reshape(env.num_envs, -1)\n\n    @staticmethod\n    def _get_obs_velocity_command(\n        env: ManagerBasedRLEnv,\n    ):\n        \"\"\"Velocity command.\n\n        This function should return the velocity command which\n        has already been serialized into flattened vectors. 
Note that we also\n        add a model switch mask dimension, when commands are small, the mode\n        is set to 0, otherwise it is set to 1.\n        \"\"\"\n        velocity_command = isaaclab_mdp.generated_commands(\n            env,\n            command_name=\"base_velocity\",\n        )\n        # Some IsaacLab velocity commands may append extra channels (e.g., heading).\n        # For velocity-tracking PPO we only use (vx, vy, yaw_rate) to keep the\n        # observation contract stable.\n        if velocity_command.shape[-1] > 3:\n            velocity_command = velocity_command[..., :3]\n        move_mask = (velocity_command.norm(dim=-1) > 0.1).to(\n            dtype=velocity_command.dtype\n        )\n        return torch.cat(\n            [\n                move_mask[..., None],\n                velocity_command,\n            ],\n            dim=-1,\n        )  # [num_envs, 4]\n\n    @staticmethod\n    def _get_obs_place_holder(env: ManagerBasedRLEnv, n_dim: int):\n        return torch.zeros(env.num_envs, n_dim, device=env.device)\n\n    @staticmethod\n    def _get_obs_ref_headling_aligned_vel_cmd(\n        env: ManagerBasedRLEnv, ref_prefix: str = \"ref_\"\n    ):\n        heading_aligned_lin_vel_xyz = ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_lin_vel(\n            env, ref_prefix=ref_prefix\n        )\n        heading_aligned_ang_vel_xyz = ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_ang_vel(\n            env, ref_prefix=ref_prefix\n        )\n        heading_aligned_vel_cmd = torch.cat(\n            [\n                heading_aligned_lin_vel_xyz[:, :2],\n                heading_aligned_ang_vel_xyz[:, 2:3],\n            ],\n            dim=-1,\n        )\n        move_mask = (heading_aligned_vel_cmd.norm(dim=-1) > 0.1).to(\n            dtype=heading_aligned_vel_cmd.dtype\n        )\n        heading_aligned_vel_cmd = torch.cat(\n            [\n                move_mask[..., None],\n                heading_aligned_vel_cmd,\n            ],\n            dim=-1,\n        )\n        return heading_aligned_vel_cmd\n\n    @staticmethod\n    def _get_obs_heading_aligned_root_ang_vel(env: ManagerBasedRLEnv):\n        root_global_ang_vel = (\n            ObservationFunctions._get_obs_global_robot_root_ang_vel(env)\n        )\n        root_global_rot_wxyz = (\n            ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)\n        )\n        heading_quat_wxyz = isaaclab_math.yaw_quat(root_global_rot_wxyz)\n        heading_aligned_root_ang_vel = isaaclab_math.quat_apply_inverse(\n            heading_quat_wxyz, root_global_ang_vel\n        )\n        return heading_aligned_root_ang_vel\n\n\n@configclass\nclass ObservationsCfg:\n    pass\n\n\ndef build_observations_config(obs_config_dict: dict):\n    \"\"\"Build isaaclab-compatible ObservationsCfg from a config dictionary.\"\"\"\n\n    if isinstance(obs_config_dict, (DictConfig, ListConfig)):\n        obs_config_dict = OmegaConf.to_container(obs_config_dict, resolve=True)\n\n    obs_cfg = ObservationsCfg()\n    obs_term_field_names = {\n        field.name for field in dataclass_fields(ObservationTermCfg)\n    }\n\n    # Create observation groups dynamically\n    for group_name, group_cfg in obs_config_dict.items():\n        group_cfg = resolve_holo_config(group_cfg)\n\n        isaaclab_obs_group_cfg = ObsGroup()\n\n        for key, value in group_cfg.items():\n            if key == \"atomic_obs_list\":\n                continue\n            if hasattr(isaaclab_obs_group_cfg, key):\n      
\n    if isinstance(obs_config_dict, (DictConfig, ListConfig)):\n        obs_config_dict = OmegaConf.to_container(obs_config_dict, resolve=True)\n\n    obs_cfg = ObservationsCfg()\n    obs_term_field_names = {\n        field.name for field in dataclass_fields(ObservationTermCfg)\n    }\n\n    # Create observation groups dynamically\n    for group_name, group_cfg in obs_config_dict.items():\n        group_cfg = resolve_holo_config(group_cfg)\n\n        isaaclab_obs_group_cfg = ObsGroup()\n\n        for key, value in group_cfg.items():\n            if key == \"atomic_obs_list\":\n                continue\n            if hasattr(isaaclab_obs_group_cfg, key):\n                setattr(isaaclab_obs_group_cfg, key, value)\n\n        # Add observation terms to the group\n        for obs_term_dict in group_cfg[\"atomic_obs_list\"]:\n            for obs_name, obs_params in obs_term_dict.items():\n                obs_params = resolve_holo_config(obs_params)\n                func_name = obs_params.get(\"func\", obs_name)\n                method_name = f\"_get_obs_{func_name}\"\n\n                if hasattr(ObservationFunctions, method_name):\n                    func = getattr(ObservationFunctions, method_name)\n                elif hasattr(isaaclab_mdp, func_name):\n                    func = getattr(isaaclab_mdp, func_name)\n                else:\n                    raise ValueError(\n                        f\"Unknown observation function: {func_name}\"\n                    )\n\n                obs_term_kwargs = {\"func\": func}\n                try:\n                    params_cfg = obs_params.get(\"params\", {})\n                except AttributeError:\n                    # Keep params_cfg bound even when obs_params has no .get()\n                    print(f\"No params found for {obs_name}\")\n                    params_cfg = {}\n\n                obs_term_kwargs[\"params\"] = resolve_holo_config(params_cfg)\n\n                noise_cfg = obs_params.get(\"noise\")\n                if noise_cfg is not None:\n                    obs_term_kwargs[\"noise\"] = _build_noise_cfg(noise_cfg)\n\n                for field_name in obs_term_field_names:\n                    if field_name in {\"func\", \"params\", \"noise\"}:\n                        continue\n                    if field_name in obs_params:\n                        obs_term_kwargs[field_name] = obs_params[field_name]\n\n                obs_term = ObsTerm(**obs_term_kwargs)\n\n                # Add observation term to group\n                setattr(isaaclab_obs_group_cfg, obs_name, obs_term)\n\n        # Add group to main observations config\n        setattr(obs_cfg, group_name, isaaclab_obs_group_cfg)\n\n    return obs_cfg\n"
  },
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_rewards.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport torch\nfrom isaaclab.assets import Articulation\nfrom isaaclab.envs import ManagerBasedRLEnv\nfrom isaaclab.managers import ManagerTermBase, RewardTermCfg, SceneEntityCfg\nfrom isaaclab.sensors import ContactSensor\nfrom isaaclab.utils import configclass\nimport isaaclab.utils.math as isaaclab_math\n\nfrom holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (\n    RefMotionCommand,\n)\nfrom holomotion.src.utils.frame_utils import (\n    positions_world_to_env_frame,\n    root_relative_positions_from_env_frame,\n    root_relative_positions_from_mixed_position_frames,\n)\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom hydra.utils import instantiate as hydra_instantiate\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\n\nfrom loguru import logger\nfrom holomotion.src.env.isaaclab_components.isaaclab_utils import (\n    _get_body_indices,\n    resolve_holo_config,\n    _get_dof_indices,\n)\n\n\ndef _joint_ids_to_tensor(\n    joint_ids: slice | list[int] | tuple[int, ...] | torch.Tensor | None,\n    num_joints: int,\n    device: torch.device | str,\n) -> torch.Tensor:\n    \"\"\"Convert joint indices to a dense tensor in articulation order.\"\"\"\n    if joint_ids is None:\n        return torch.arange(num_joints, device=device, dtype=torch.long)\n    if isinstance(joint_ids, slice):\n        if joint_ids == slice(None):\n            return torch.arange(num_joints, device=device, dtype=torch.long)\n        return torch.arange(num_joints, device=device, dtype=torch.long)[\n            joint_ids\n        ]\n    if isinstance(joint_ids, torch.Tensor):\n        return joint_ids.to(device=device, dtype=torch.long).flatten()\n    return torch.tensor(joint_ids, device=device, dtype=torch.long)\n\n\ndef _select_effort_limit_vector(\n    asset: Articulation,\n    selected_joint_ids: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Build a per-joint effort-limit vector from instantiated actuators.\"\"\"\n    num_joints = asset.data.applied_torque.shape[1]\n    device = asset.data.applied_torque.device\n    dtype = asset.data.applied_torque.dtype\n\n    effort_limit_vec = torch.zeros(num_joints, device=device, dtype=dtype)\n    is_filled = torch.zeros(num_joints, device=device, dtype=torch.bool)\n\n    for actuator in asset.actuators.values():\n        actuator_joint_ids = _joint_ids_to_tensor(\n            actuator.joint_indices, num_joints=num_joints, device=device\n        )\n        actuator_effort_limit = torch.as_tensor(\n            actuator.effort_limit, device=device, dtype=dtype\n        )\n        if actuator_effort_limit.ndim == 0:\n            actuator_effort_limit = actuator_effort_limit.expand(\n                actuator_joint_ids.numel()\n            )\n        elif actuator_effort_limit.ndim == 2:\n            if actuator_effort_limit.shape[0] > 1:\n                reference = 
actuator_effort_limit[0].unsqueeze(0)\n                if not torch.allclose(\n                    actuator_effort_limit,\n                    reference.expand_as(actuator_effort_limit),\n                ):\n                    raise ValueError(\n                        \"normed_torque_rate requires actuator effort limits to be static across envs.\"\n                    )\n            actuator_effort_limit = actuator_effort_limit[0]\n        elif actuator_effort_limit.ndim != 1:\n            raise ValueError(\n                \"normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors.\"\n            )\n\n        if actuator_effort_limit.numel() != actuator_joint_ids.numel():\n            raise ValueError(\n                \"normed_torque_rate found mismatched actuator joint indices and effort limits.\"\n            )\n\n        effort_limit_vec[actuator_joint_ids] = actuator_effort_limit\n        is_filled[actuator_joint_ids] = True\n\n    if not torch.all(is_filled[selected_joint_ids]):\n        missing_joint_ids = selected_joint_ids[~is_filled[selected_joint_ids]]\n        raise ValueError(\n            \"normed_torque_rate could not resolve actuator effort limits for \"\n            f\"joint ids {missing_joint_ids.tolist()}.\"\n        )\n\n    selected_effort_limits = effort_limit_vec[selected_joint_ids]\n    if not torch.all(torch.isfinite(selected_effort_limits)):\n        raise ValueError(\n            \"normed_torque_rate requires finite actuator effort limits for all selected joints.\"\n        )\n    if not torch.all(selected_effort_limits > 0.0):\n        raise ValueError(\n            \"normed_torque_rate requires strictly positive actuator effort limits for all selected joints.\"\n        )\n\n    return selected_effort_limits\n\n\ndef key_dof_position_tracking_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    key_dofs: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keydof_idxs = _get_dof_indices(command.robot, key_dofs)\n    ref_dof_pos = command.get_ref_motion_dof_pos_immediate_next(\n        prefix=ref_prefix\n    )\n    error = torch.sum(\n        torch.square(\n            command.robot.data.joint_pos[:, keydof_idxs]\n            - ref_dof_pos[:, keydof_idxs]\n        ),\n        dim=-1,\n    )\n    return torch.exp(-error / std**2)\n\n\ndef key_dof_velocity_tracking_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    key_dofs: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keydof_idxs = _get_dof_indices(command.robot, key_dofs)\n    ref_dof_vel = command.get_ref_motion_dof_vel_immediate_next(\n        prefix=ref_prefix\n    )\n    error = torch.sum(\n        torch.square(\n            command.robot.data.joint_vel[:, keydof_idxs]\n            - ref_dof_vel[:, keydof_idxs]\n        ),\n        dim=-1,\n    )\n    return torch.exp(-error / std**2)\n\n\ndef motion_global_anchor_position_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    ref_motion_command: RefMotionCommand = env.command_manager.get_term(\n        command_name\n    )\n    ref_anchor_pos = ref_motion_command.get_ref_motion_anchor_bodylink_global_pos_immediate_next(\n     
   prefix=ref_prefix\n    )\n    robot_anchor_pos = ref_motion_command.global_robot_anchor_pos_cur\n    error = torch.sum(\n        torch.square(ref_anchor_pos - robot_anchor_pos),\n        dim=-1,\n    )\n    return torch.exp(-error / std**2)\n\n\ndef motion_global_anchor_orientation_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    ref_anchor_quat = (\n        command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n    error = (\n        isaaclab_math.quat_error_magnitude(\n            ref_anchor_quat,\n            command.robot.data.body_quat_w[:, command.anchor_bodylink_idx],\n        )\n        ** 2\n    )\n    return torch.exp(-error / std**2)\n\n\ndef motion_relative_body_position_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    # Get body indexes based on body names (similar to whole_body_tracking implementation)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Get reference and robot anchor positions/orientations\n    ref_anchor_pos = command.get_ref_motion_root_global_pos_immediate_next(\n        prefix=ref_prefix\n    )  # [B, 3]\n    ref_anchor_quat = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4] (w,x,y,z)\n    robot_anchor_pos = command.robot.data.body_pos_w[\n        :, command.anchor_bodylink_idx\n    ]  # [B, 3]\n    robot_anchor_quat = command.robot.data.body_quat_w[\n        :, command.anchor_bodylink_idx\n    ]  # [B, 4] (w,x,y,z)\n\n    # Get reference body positions in global frame\n    ref_body_pos_global = (\n        command.get_ref_motion_bodylink_global_pos_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, num_bodies, 3]\n\n    # Transform reference body positions to be relative to robot's current anchor\n    # This follows the same logic as the whole_body_tracking implementation\n\n    # Select relevant body indices first\n    ref_body_pos_selected = ref_body_pos_global[\n        :, keybody_idxs\n    ]  # [B, selected_bodies, 3]\n\n    # Expand anchor positions/orientations to match number of selected bodies\n    num_bodies = len(keybody_idxs)\n    ref_anchor_pos_exp = ref_anchor_pos[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 3]\n    ref_anchor_quat_exp = ref_anchor_quat[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 4]\n    robot_anchor_pos_exp = robot_anchor_pos[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 3]\n    robot_anchor_quat_exp = robot_anchor_quat[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 4]\n\n    # Create delta transformation (preserving z from reference, aligning xy to robot)\n    delta_pos = robot_anchor_pos_exp.clone()\n    delta_pos[..., 2] = ref_anchor_pos_exp[..., 2]  # Keep reference Z height\n\n    delta_ori = isaaclab_math.yaw_quat(\n        isaaclab_math.quat_mul(\n            robot_anchor_quat_exp,\n            isaaclab_math.quat_inv(ref_anchor_quat_exp),\n        )\n    )\n\n    # Transform reference body positions to 
relative frame\n    ref_body_pos_relative = delta_pos + isaaclab_math.quat_apply(\n        delta_ori, ref_body_pos_selected - ref_anchor_pos_exp\n    )\n\n    # Get robot body positions\n    robot_body_pos = command.robot.data.body_pos_w[:, keybody_idxs]\n\n    # Compute error\n    error = torch.sum(\n        torch.square(ref_body_pos_relative - robot_body_pos),\n        dim=-1,\n    )\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef root_rel_keybodylink_pos_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track root-relative keybody positions using environment-frame positions.\n\n    IsaacLab MDP root position helpers are expressed in the environment frame\n    (simulation-world position minus `env.scene.env_origins`). This reward\n    converts body positions into the same environment frame before computing\n    root-relative vectors.\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    # Get body indexes based on body names (similar to whole_body_tracking implementation)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Get reference and robot root positions/orientations\n    ref_root_pos_env = positions_world_to_env_frame(\n        command.get_ref_motion_root_global_pos_immediate_next(\n            prefix=ref_prefix\n        ),\n        env.scene.env_origins,\n    )  # [B, 3]\n    ref_root_quat_w = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4] (w,x,y,z)\n    robot_root_pos_env = isaaclab_mdp.root_pos_w(env)  # [B, 3]\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4] (w,x,y,z)\n\n    # Select relevant body indices first\n    ref_body_pos_env = positions_world_to_env_frame(\n        command.get_ref_motion_bodylink_global_pos_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs],\n        env.scene.env_origins,\n    )\n    robot_body_pos_root_rel = (\n        root_relative_positions_from_mixed_position_frames(\n            body_pos_w=command.robot.data.body_pos_w[:, keybody_idxs],\n            root_pos_env=robot_root_pos_env,\n            root_quat_w=robot_root_quat_w,\n            env_origins=env.scene.env_origins,\n        )\n    )\n    ref_body_pos_root_rel = root_relative_positions_from_env_frame(\n        body_pos_env=ref_body_pos_env,\n        root_pos_env=ref_root_pos_env,\n        root_quat_w=ref_root_quat_w,\n    )\n\n    # Compute error\n    error = torch.sum(\n        torch.square(ref_body_pos_root_rel - robot_body_pos_root_rel),\n        dim=-1,\n    )\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef motion_relative_body_orientation_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    # Get body indexes based on body names (similar to whole_body_tracking implementation)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Get reference and robot anchor orientations\n    ref_anchor_quat = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4] (w,x,y,z)\n    robot_anchor_quat = 
command.robot.data.body_quat_w[\n        :, command.anchor_bodylink_idx\n    ]  # [B, 4] (w,x,y,z)\n\n    # Get reference body orientations in global frame\n    ref_body_quat_global = (\n        command.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, num_bodies, 4]\n\n    # Select relevant body indices\n    ref_body_quat_selected = ref_body_quat_global[\n        :, keybody_idxs\n    ]  # [B, selected_bodies, 4]\n\n    # Expand anchor orientations to match number of selected bodies\n    num_bodies = len(keybody_idxs)\n    ref_anchor_quat_exp = ref_anchor_quat[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 4]\n    robot_anchor_quat_exp = robot_anchor_quat[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, num_bodies, 4]\n\n    # Compute relative orientation transformation (only yaw component)\n    delta_ori = isaaclab_math.yaw_quat(\n        isaaclab_math.quat_mul(\n            robot_anchor_quat_exp,\n            isaaclab_math.quat_inv(ref_anchor_quat_exp),\n        )\n    )\n\n    # Transform reference body orientations to relative frame\n    ref_body_quat_relative = isaaclab_math.quat_mul(\n        delta_ori, ref_body_quat_selected\n    )\n\n    # Get robot body orientations\n    robot_body_quat = command.robot.data.body_quat_w[:, keybody_idxs]\n\n    # Compute error\n    error = (\n        isaaclab_math.quat_error_magnitude(\n            ref_body_quat_relative, robot_body_quat\n        )\n        ** 2\n    )\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef motion_global_body_linear_velocity_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    # Get body indexes based on body names (similar to whole_body_tracking implementation)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Direct comparison of global velocities (no coordinate transformation needed)\n    ref_lin_vel = (\n        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )\n    robot_lin_vel = command.robot.data.body_lin_vel_w[:, keybody_idxs]\n    error = torch.sum(torch.square(ref_lin_vel - robot_lin_vel), dim=-1)\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef motion_global_body_angular_velocity_error_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    # Get body indexes based on body names (similar to whole_body_tracking implementation)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Direct comparison of global angular velocities (no coordinate transformation needed)\n    ref_ang_vel = (\n        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )\n    robot_ang_vel = command.robot.data.body_ang_vel_w[:, keybody_idxs]\n    error = torch.sum(torch.square(ref_ang_vel - robot_ang_vel), dim=-1)\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef root_pos_xy_tracking_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    
ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    ref_root_pos = command.get_ref_motion_root_global_pos_immediate_next(\n        prefix=ref_prefix\n    )\n    error = torch.sum(\n        torch.square(\n            ref_root_pos[:, :2] - command.robot.data.root_pos_w[:, :2]\n        ),\n        dim=-1,\n    )\n    return torch.exp(-error / std**2)\n\n\ndef root_rot_tracking_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    ref_root_quat = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n    error = (\n        isaaclab_math.quat_error_magnitude(\n            ref_root_quat,\n            isaaclab_mdp.root_quat_w(env),\n        )\n        ** 2\n    )\n    return torch.exp(-error / std**2)\n\n\ndef root_pos_rel_z_tracking_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    robot_root_z = command.robot.data.root_pos_w[:, 2]\n    ref_root_z = command.get_ref_motion_root_global_pos_immediate_next(\n        prefix=ref_prefix\n    )[:, 2]\n    dz_rel = robot_root_z - ref_root_z\n    error = torch.square(dz_rel)\n    return torch.exp(-error / std**2)\n\n\ndef root_lin_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track root linear velocity in each entity's own root frame.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n\n    # [B, 3], [B, 4]\n    robot_root_lin_vel_w = isaaclab_mdp.root_lin_vel_w(env)\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)\n    ref_root_lin_vel_w = (\n        command.get_ref_motion_root_global_lin_vel_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n    ref_root_quat_w = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n\n    # Project to respective root frames\n    robot_root_lin_vel = isaaclab_math.quat_apply(\n        isaaclab_math.quat_inv(robot_root_quat_w),\n        robot_root_lin_vel_w,\n    )  # [B, 3]\n    ref_root_lin_vel = isaaclab_math.quat_apply(\n        isaaclab_math.quat_inv(ref_root_quat_w),\n        ref_root_lin_vel_w,\n    )  # [B, 3]\n\n    error = torch.sum(\n        torch.square(ref_root_lin_vel - robot_root_lin_vel), dim=-1\n    )\n    return torch.exp(-error / std**2)\n\n\ndef root_ang_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track root angular velocity in each entity's own root frame.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n\n    # [B, 3], [B, 4]\n    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)\n    ref_root_ang_vel_w = (\n        command.get_ref_motion_root_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n    ref_root_quat_w = (\n        
command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )\n\n    # Project to respective root frames\n    robot_root_ang_vel = isaaclab_math.quat_apply(\n        isaaclab_math.quat_inv(robot_root_quat_w),\n        robot_root_ang_vel_w,\n    )  # [B, 3]\n    ref_root_ang_vel = isaaclab_math.quat_apply(\n        isaaclab_math.quat_inv(ref_root_quat_w),\n        ref_root_ang_vel_w,\n    )  # [B, 3]\n\n    error = torch.sum(\n        torch.square(ref_root_ang_vel - robot_root_ang_vel), dim=-1\n    )\n    return torch.exp(-error / std**2)\n\n\ndef root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track keybody positions using per-entity heading-aligned frames.\n\n    For each of robot and reference:\n    - subtract own root position (root-relative in world)\n    - rotate by own yaw-only inverse (heading-aligned frame)\n    Then compare these root-relative, heading-aligned positions.\n\n    All positions are first converted into IsaacLab's environment frame\n    (simulation world minus `env.scene.env_origins`) so robot root and body\n    positions use the same translation convention.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Root states in environment frame\n    ref_root_pos = positions_world_to_env_frame(\n        command.get_ref_motion_root_global_pos_immediate_next(\n            prefix=ref_prefix\n        ),\n        env.scene.env_origins,\n    )  # [B, 3]\n    ref_root_quat = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4]\n    robot_root_pos = isaaclab_mdp.root_pos_w(env)  # [B, 3]\n    robot_root_quat = isaaclab_mdp.root_quat_w(env)  # [B, 4]\n\n    # Body positions in environment frame\n    robot_body_pos = positions_world_to_env_frame(\n        command.robot.data.body_pos_w[:, keybody_idxs],\n        env.scene.env_origins,\n    )  # [B, N, 3]\n    ref_body_pos = positions_world_to_env_frame(\n        command.get_ref_motion_bodylink_global_pos_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs],\n        env.scene.env_origins,\n    )  # [B, N, 3]\n\n    # Expand for broadcasting\n    num_bodies = len(keybody_idxs)\n    ref_root_pos_exp = ref_root_pos[:, None, :].expand(-1, num_bodies, -1)\n    ref_root_quat_exp = ref_root_quat[:, None, :].expand(-1, num_bodies, -1)\n    robot_root_pos_exp = robot_root_pos[:, None, :].expand(-1, num_bodies, -1)\n    robot_root_quat_exp = robot_root_quat[:, None, :].expand(\n        -1, num_bodies, -1\n    )\n\n    # Yaw-only delta orientation (root frames)\n    delta_ori = isaaclab_math.yaw_quat(\n        isaaclab_math.quat_mul(\n            robot_root_quat_exp, isaaclab_math.quat_inv(ref_root_quat_exp)\n        )\n    )  # [B, N, 4]\n\n    # Keep origin at root: compare root-relative vectors after yaw alignment\n    robot_rel = robot_body_pos - robot_root_pos_exp  # [B, N, 3]\n    ref_rel = ref_body_pos - ref_root_pos_exp  # [B, N, 3]\n    ref_rel_aligned = isaaclab_math.quat_apply(delta_ori, ref_rel)  # [B, N, 3]\n\n    # Compare in world (root-relative)\n    error = torch.sum(\n        torch.square(ref_rel_aligned - robot_rel), dim=-1\n    )  # [B, N]\n 
   return torch.exp(-error.mean(-1) / std**2)\n\n\ndef root_rel_keybodylink_rot_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track root-relative keybody rotations in each entity's root frame.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Root orientations\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]\n    ref_root_quat_w = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4]\n\n    # Body orientations (world)\n    robot_body_quat_w = command.robot.data.body_quat_w[\n        :, keybody_idxs\n    ]  # [B, N, 4]\n    ref_body_quat_w = (\n        command.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )  # [B, N, 4]\n\n    # Relative (q_rel = q_root^{-1} * q_body)\n    num_bodies = len(keybody_idxs)\n    robot_root_quat_inv_exp = isaaclab_math.quat_inv(robot_root_quat_w)[\n        :, None, :\n    ].expand(-1, num_bodies, -1)\n    ref_root_quat_inv_exp = isaaclab_math.quat_inv(ref_root_quat_w)[\n        :, None, :\n    ].expand(-1, num_bodies, -1)\n\n    robot_rel_quat = isaaclab_math.quat_mul(\n        robot_root_quat_inv_exp,\n        robot_body_quat_w,\n    )  # [B, N, 4]\n    ref_rel_quat = isaaclab_math.quat_mul(\n        ref_root_quat_inv_exp,\n        ref_body_quat_w,\n    )  # [B, N, 4]\n\n    error = (\n        isaaclab_math.quat_error_magnitude(ref_rel_quat, robot_rel_quat) ** 2\n    )  # [B, N]\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef root_rel_keybodylink_lin_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track keybody linear velocities with motion_relative frame alignment.\n\n    Compute rigid-body-relative velocities for both entities w.r.t. 
their\n    roots, yaw-align reference to robot using root quats, then compare in\n    world space.\n\n    Root/body positions used for rigid-body radius vectors are first converted\n    into IsaacLab's environment frame (simulation world minus\n    `env.scene.env_origins`) so the translation convention matches\n    `isaaclab_mdp.root_pos_w(env)`.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Root states\n    robot_root_pos_w = isaaclab_mdp.root_pos_w(env)  # [B, 3]\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]\n    robot_root_lin_vel_w = isaaclab_mdp.root_lin_vel_w(env)  # [B, 3]\n    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]\n\n    ref_root_pos_w = positions_world_to_env_frame(\n        command.get_ref_motion_root_global_pos_immediate_next(\n            prefix=ref_prefix\n        ),\n        env.scene.env_origins,\n    )  # [B, 3]\n    ref_root_quat_w = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4]\n    ref_root_lin_vel_w = (\n        command.get_ref_motion_root_global_lin_vel_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 3]\n    ref_root_ang_vel_w = (\n        command.get_ref_motion_root_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 3]\n\n    # Body states (world)\n    robot_body_pos_w = positions_world_to_env_frame(\n        command.robot.data.body_pos_w[:, keybody_idxs],\n        env.scene.env_origins,\n    )  # [B, N, 3]\n    robot_body_lin_vel_w = command.robot.data.body_lin_vel_w[\n        :, keybody_idxs\n    ]  # [B, N, 3]\n    ref_body_pos_w = positions_world_to_env_frame(\n        command.get_ref_motion_bodylink_global_pos_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs],\n        env.scene.env_origins,\n    )  # [B, N, 3]\n    ref_body_lin_vel_w = (\n        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )  # [B, N, 3]\n\n    # Rigid-body relative (world)\n    robot_r_w = robot_body_pos_w - robot_root_pos_w[:, None, :]\n    ref_r_w = ref_body_pos_w - ref_root_pos_w[:, None, :]\n\n    robot_cross = torch.cross(\n        robot_root_ang_vel_w[:, None, :], robot_r_w, dim=-1\n    )  # [B, N, 3]\n    ref_cross = torch.cross(\n        ref_root_ang_vel_w[:, None, :], ref_r_w, dim=-1\n    )  # [B, N, 3]\n\n    robot_v_rel_w = (\n        robot_body_lin_vel_w - robot_root_lin_vel_w[:, None, :] - robot_cross\n    )  # [B, N, 3]\n    ref_v_rel_w = (\n        ref_body_lin_vel_w - ref_root_lin_vel_w[:, None, :] - ref_cross\n    )  # [B, N, 3]\n    # Yaw-only delta orientation from root quats; rotate reference velocities\n    num_bodies = len(keybody_idxs)\n    robot_root_quat_exp = robot_root_quat_w[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, N, 4]\n    ref_root_quat_exp = ref_root_quat_w[:, None, :].expand(\n        -1, num_bodies, -1\n    )  # [B, N, 4]\n    delta_ori = isaaclab_math.yaw_quat(\n        isaaclab_math.quat_mul(\n            robot_root_quat_exp, isaaclab_math.quat_inv(ref_root_quat_exp)\n        )\n    )  # [B, N, 4]\n\n    ref_v_rel_aligned_w = isaaclab_math.quat_apply(delta_ori, ref_v_rel_w)\n\n    error = torch.sum(\n        torch.square(ref_v_rel_aligned_w - robot_v_rel_w), dim=-1\n    )  # [B, N]\n    return 
torch.exp(-error.mean(-1) / std**2)\n\n\ndef root_rel_keybodylink_ang_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track root-relative keybody angular velocities in root frames.\n\n    Uses w_rel_w = w_body - w_root, then rotates into each entity's root frame.\n\n    Returns: [B]\n    \"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    # Root orientations and angular velocities\n    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]\n    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]\n    ref_root_quat_w = (\n        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4]\n    ref_root_ang_vel_w = (\n        command.get_ref_motion_root_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 3]\n\n    # Body angular velocities (world)\n    robot_body_ang_vel_w = command.robot.data.body_ang_vel_w[\n        :, keybody_idxs\n    ]  # [B, N, 3]\n    ref_body_ang_vel_w = (\n        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )  # [B, N, 3]\n\n    # Relative (world), then rotate\n    robot_w_rel_w = robot_body_ang_vel_w - robot_root_ang_vel_w[:, None, :]\n    ref_w_rel_w = ref_body_ang_vel_w - ref_root_ang_vel_w[:, None, :]\n\n    num_bodies = len(keybody_idxs)\n    robot_root_quat_inv_exp = isaaclab_math.quat_inv(robot_root_quat_w)[\n        :, None, :\n    ].expand(-1, num_bodies, -1)\n    ref_root_quat_inv_exp = isaaclab_math.quat_inv(ref_root_quat_w)[\n        :, None, :\n    ].expand(-1, num_bodies, -1)\n\n    robot_w_rel = isaaclab_math.quat_apply(\n        robot_root_quat_inv_exp,\n        robot_w_rel_w,\n    )  # [B, N, 3]\n    ref_w_rel = isaaclab_math.quat_apply(\n        ref_root_quat_inv_exp,\n        ref_w_rel_w,\n    )  # [B, N, 3]\n\n    error = torch.sum(torch.square(ref_w_rel - robot_w_rel), dim=-1)  # [B, N]\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef global_keybodylink_lin_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track global keybody linear velocities.\"\"\"\n    command: RefMotionCommand = env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    ref_global_keybody_lin_vel = (\n        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )  # [B, N, 3]\n    robot_keybody_lin_vel = command.robot.data.body_lin_vel_w[\n        :, keybody_idxs\n    ]  # [B, N, 3]\n\n    error = torch.sum(\n        torch.square(ref_global_keybody_lin_vel - robot_keybody_lin_vel),\n        dim=-1,\n    )  # [B, N]\n    return torch.exp(-error.mean(-1) / std**2)\n\n\ndef global_keybodylink_ang_vel_tracking_l2_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Track global keybody angular velocities.\"\"\"\n    command: RefMotionCommand = 
env.command_manager.get_term(command_name)\n    keybody_idxs = _get_body_indices(command.robot, keybody_names)\n\n    ref_global_keybody_ang_vel = (\n        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(\n            prefix=ref_prefix\n        )[:, keybody_idxs]\n    )  # [B, N, 3]\n    robot_keybody_ang_vel = command.robot.data.body_ang_vel_w[\n        :, keybody_idxs\n    ]  # [B, N, 3]\n\n    error = torch.sum(\n        torch.square(ref_global_keybody_ang_vel - robot_keybody_ang_vel),\n        dim=-1,\n    )  # [B, N]\n    return torch.exp(-error.mean(-1) / std**2)\n\n\n#  @torch.compile\ndef feet_contact_time(\n    env: ManagerBasedRLEnv,\n    sensor_cfg: SceneEntityCfg,\n    threshold: float,\n) -> torch.Tensor:\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    first_air = contact_sensor.compute_first_air(env.step_dt, env.physics_dt)[\n        :, sensor_cfg.body_ids\n    ]\n    last_contact_time = contact_sensor.data.last_contact_time[\n        :, sensor_cfg.body_ids\n    ]\n    reward = torch.sum((last_contact_time < threshold) * first_air, dim=-1)\n    return reward\n\n\ndef track_lin_vel_xy_yaw_frame_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Track linear velocity (xy) in the gravity-aligned yaw frame using exponential kernel.\n\n    This mirrors the implementation in IsaacLab locomotion velocity MDP.\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    vel_yaw = isaaclab_math.quat_apply_inverse(\n        isaaclab_math.yaw_quat(asset.data.root_quat_w),\n        asset.data.root_lin_vel_w[:, :3],\n    )\n    lin_vel_error = torch.sum(\n        torch.square(\n            env.command_manager.get_command(command_name)[:, :2]\n            - vel_yaw[:, :2]\n        ),\n        dim=1,\n    )\n    return torch.exp(-lin_vel_error / (std**2))\n\n\ndef feet_slide(\n    env: ManagerBasedRLEnv,\n    sensor_cfg: SceneEntityCfg,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Penalize feet sliding when in contact using contact forces and foot linear velocity.\"\"\"\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    contacts = (\n        contact_sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]\n        .norm(dim=-1)\n        .max(dim=1)[0]\n        > 1.0\n    )\n    asset: Articulation = env.scene[asset_cfg.name]\n    body_vel = asset.data.body_lin_vel_w[:, asset_cfg.body_ids, :2]\n    reward = torch.sum(body_vel.norm(dim=-1) * contacts, dim=1)\n    return reward\n\n\ndef feet_slide_ang_vel(\n    env: ManagerBasedRLEnv,\n    sensor_cfg: SceneEntityCfg,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Penalize feet yaw slipping when in contact using contact forces and foot yaw angular velocity.\"\"\"\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    contacts = (\n        contact_sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]\n        .norm(dim=-1)\n        .max(dim=1)[0]\n        > 1.0\n    )\n    asset: Articulation = env.scene[asset_cfg.name]\n    body_ang_vel = asset.data.body_ang_vel_w[:, asset_cfg.body_ids, 2:3]\n    reward = torch.sum(body_ang_vel.norm(dim=-1) * contacts, dim=1)\n    return reward\n
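\n\n# A minimal registration sketch for the two slide penalties above; the weight\n# and body-name pattern are hypothetical, not values used by this project:\n#\n#   feet_slide_penalty = RewardTermCfg(\n#       func=feet_slide,\n#       weight=-0.25,\n#       params={\n#           \"sensor_cfg\": SceneEntityCfg(\n#               \"contact_forces\", body_names=\".*ankle.*\"\n#           ),\n#           \"asset_cfg\": SceneEntityCfg(\"robot\", body_names=\".*ankle.*\"),\n#       },\n#   )\n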
\n\ndef foot_clearance_reward(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg,\n    target_height: float,\n    std: float,\n    tanh_mult: float,\n    sensor_cfg: SceneEntityCfg,\n) -> torch.Tensor:\n    \"\"\"Reward swinging feet clearing a target height with velocity-shaped kernel.\n\n    Only rewards feet that are swinging (not in contact) and are close to the target height.\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    foot_z = asset.data.body_pos_w[:, asset_cfg.body_ids, 2]  # [B, N]\n\n    delta_z = target_height - foot_z\n    delta_z = torch.clamp(delta_z, min=0.0)  # error is zero at/above target\n\n    foot_z_error = torch.square(delta_z)  # [B, N]\n\n    # Only reward swinging feet (not in contact)\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    is_contact = (\n        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0\n    )  # [B, N]\n    is_swinging = ~is_contact\n\n    # Gate reward by horizontal velocity to ensure feet are actually moving\n    foot_horizontal_vel = torch.norm(\n        asset.data.body_lin_vel_w[:, asset_cfg.body_ids, :2], dim=2\n    )  # [B, N]\n    velocity_gate = torch.tanh(tanh_mult * foot_horizontal_vel)  # [B, N]\n\n    # Reward: high when error is low (at target height) and foot is swinging\n    reward_per_foot = (\n        torch.exp(-foot_z_error / std**2) * velocity_gate * is_swinging.float()\n    )\n    return torch.sum(reward_per_foot, dim=1)\n\n\ndef feet_gait(\n    env: ManagerBasedRLEnv,\n    period: float,\n    offset: list[float],\n    sensor_cfg: SceneEntityCfg,\n    threshold: float = 0.5,\n    command_name: str | None = None,\n) -> torch.Tensor:\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    is_contact = (\n        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0\n    )\n\n    global_phase = (\n        (env.episode_length_buf * env.step_dt) % period / period\n    ).unsqueeze(1)\n    phases = []\n    for offset_ in offset:\n        phase = (global_phase + offset_) % 1.0\n        phases.append(phase)\n    leg_phase = torch.cat(phases, dim=-1)\n\n    reward = torch.zeros(env.num_envs, dtype=torch.float, device=env.device)\n    for i in range(len(sensor_cfg.body_ids)):\n        is_stance = leg_phase[:, i] < threshold\n        reward += ~(is_stance ^ is_contact[:, i])\n\n    if command_name is not None:\n        cmd_norm = torch.norm(\n            env.command_manager.get_command(command_name), dim=1\n        )\n        reward *= cmd_norm > 0.1\n    return reward\n
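\n\n# Worked example (hypothetical values): with period=0.8 and offset=[0.0, 0.5],\n# the two feet are half a gait cycle apart; with threshold=0.5 each foot is\n# expected on the ground while its phase is in [0, 0.5) (stance) and airborne\n# while its phase is in [0.5, 1.0) (swing), so the reward increments for every\n# foot whose contact state matches its phase.\n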
\n\njoint_deviation_l1_arms = isaaclab_mdp.joint_deviation_l1\njoint_deviation_l1_arms_roll = isaaclab_mdp.joint_deviation_l1\n\njoint_deviation_l1_waists = isaaclab_mdp.joint_deviation_l1\n\njoint_deviation_l1_legs = isaaclab_mdp.joint_deviation_l1\njoint_deviation_l1_legs_yaw = isaaclab_mdp.joint_deviation_l1\n\njoint_deviation_l1_stand_still = isaaclab_mdp.joint_deviation_l1\n\n\ndef joint_deviation_l2(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Penalize joint positions that deviate from the default ones.\"\"\"\n    # extract the used quantities (to enable type-hinting)\n    asset: Articulation = env.scene[asset_cfg.name]\n    # compute squared deviation from the default joint positions\n    angle = (\n        asset.data.joint_pos[:, asset_cfg.joint_ids]\n        - asset.data.default_joint_pos[:, asset_cfg.joint_ids]\n    )\n    return torch.sum(torch.square(angle), dim=1)\n\n\njoint_deviation_l2_arms_roll = joint_deviation_l2\njoint_deviation_l2_arms = joint_deviation_l2\njoint_deviation_l2_waists = joint_deviation_l2\njoint_deviation_l2_legs = joint_deviation_l2\njoint_deviation_l2_shoulder_roll = joint_deviation_l2\njoint_deviation_l2_hip_roll = joint_deviation_l2\n\n\ndef energy(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Penalize the energy used by the robot's joints.\"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n\n    qvel = asset.data.joint_vel[:, asset_cfg.joint_ids]\n    qfrc = asset.data.applied_torque[:, asset_cfg.joint_ids]\n    return torch.sum(torch.abs(qvel) * torch.abs(qfrc), dim=-1)\n\n\ndef positive_work(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Penalize only the positive mechanical work (energy injected) by the joints.\"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n\n    qvel = asset.data.joint_vel[:, asset_cfg.joint_ids]\n    qfrc = asset.data.applied_torque[:, asset_cfg.joint_ids]\n\n    # Calculate raw mechanical power (positive = motoring, negative = braking)\n    power = qfrc * qvel\n\n    # Only keep positive values, zero out negative (braking) work\n    positive_power = torch.relu(power)\n\n    return torch.sum(positive_power, dim=-1)\n\n\nclass normed_positive_work(ManagerTermBase):\n    \"\"\"Penalize positive joint work normalized by effort and velocity limits.\"\"\"\n\n    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):\n        super().__init__(cfg, env)\n        self._asset_name: str | None = None\n        self._joint_ids: torch.Tensor | None = None\n        self._inv_effort_limit: torch.Tensor | None = None\n\n    def _maybe_build_cache(\n        self,\n        env: ManagerBasedRLEnv,\n        asset_cfg: SceneEntityCfg,\n    ) -> Articulation:\n        asset: Articulation = env.scene[asset_cfg.name]\n        joint_ids = _joint_ids_to_tensor(\n            getattr(asset_cfg, \"joint_ids\", None),\n            num_joints=asset.data.applied_torque.shape[1],\n            device=asset.data.applied_torque.device,\n        )\n        cache_needs_refresh = (\n            self._asset_name != asset_cfg.name\n            or self._joint_ids is None\n            or not torch.equal(self._joint_ids, joint_ids)\n            or self._inv_effort_limit is None\n            or self._inv_effort_limit.shape != (joint_ids.numel(),)\n            or self._inv_effort_limit.device\n            != asset.data.applied_torque.device\n            or self._inv_effort_limit.dtype != asset.data.applied_torque.dtype\n        )\n        if not cache_needs_refresh:\n            return asset\n\n        effort_limit = _select_effort_limit_vector(asset, joint_ids)\n        self._asset_name = asset_cfg.name\n        self._joint_ids = joint_ids\n        self._inv_effort_limit = effort_limit.reciprocal()\n        return asset\n\n    def __call__(\n        self,\n        env: ManagerBasedRLEnv,\n        asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    ) -> torch.Tensor:\n        asset = self._maybe_build_cache(env, asset_cfg)\n        joint_ids = self._joint_ids\n        inv_effort_limit = self._inv_effort_limit\n        assert joint_ids is not None\n        assert inv_effort_limit is not None\n\n        current_torque = asset.data.applied_torque[:, joint_ids]\n        current_joint_vel = asset.data.joint_vel[:, joint_ids]\n    
    joint_vel_limits = asset.data.joint_vel_limits[:, joint_ids]\n\n        if not torch.all(torch.isfinite(joint_vel_limits)) or not torch.all(\n            joint_vel_limits > 0.0\n        ):\n            raise ValueError(\n                \"normed_positive_work requires finite, strictly positive \"\n                \"joint velocity limits for all selected joints.\"\n            )\n\n        normalized_power = (current_torque * inv_effort_limit) * (\n            current_joint_vel / joint_vel_limits\n        )\n        return torch.sum(torch.relu(normalized_power), dim=-1)\n\n\nclass normed_torque_rate(ManagerTermBase):\n    \"\"\"Penalize joint torque-rate changes normalized by actuator effort limits.\"\"\"\n\n    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):\n        super().__init__(cfg, env)\n        self._asset_name: str | None = None\n        self._joint_ids: torch.Tensor | None = None\n        self._inv_effort_limit: torch.Tensor | None = None\n        self._prev_applied_torque: torch.Tensor | None = None\n        self._needs_reseed = torch.ones(\n            self.num_envs, device=self.device, dtype=torch.bool\n        )\n\n    def reset(self, env_ids=None) -> None:\n        if env_ids is None:\n            self._needs_reseed[:] = True\n            return\n        if isinstance(env_ids, slice):\n            self._needs_reseed[env_ids] = True\n            return\n        env_ids_tensor = torch.as_tensor(\n            env_ids, device=self.device, dtype=torch.long\n        )\n        self._needs_reseed[env_ids_tensor] = True\n\n    def _maybe_build_cache(\n        self,\n        env: ManagerBasedRLEnv,\n        asset_cfg: SceneEntityCfg,\n    ) -> Articulation:\n        asset: Articulation = env.scene[asset_cfg.name]\n        joint_ids = _joint_ids_to_tensor(\n            getattr(asset_cfg, \"joint_ids\", None),\n            num_joints=asset.data.applied_torque.shape[1],\n            device=asset.data.applied_torque.device,\n        )\n        cache_needs_refresh = (\n            self._asset_name != asset_cfg.name\n            or self._joint_ids is None\n            or not torch.equal(self._joint_ids, joint_ids)\n            or self._prev_applied_torque is None\n            or self._prev_applied_torque.shape\n            != (env.num_envs, joint_ids.numel())\n            or self._prev_applied_torque.device\n            != asset.data.applied_torque.device\n            or self._prev_applied_torque.dtype\n            != asset.data.applied_torque.dtype\n        )\n        if not cache_needs_refresh:\n            return asset\n\n        effort_limit = _select_effort_limit_vector(asset, joint_ids)\n        self._asset_name = asset_cfg.name\n        self._joint_ids = joint_ids\n        self._inv_effort_limit = effort_limit.reciprocal()\n        self._prev_applied_torque = torch.zeros(\n            env.num_envs,\n            joint_ids.numel(),\n            device=asset.data.applied_torque.device,\n            dtype=asset.data.applied_torque.dtype,\n        )\n        self._needs_reseed = torch.ones(\n            env.num_envs,\n            device=asset.data.applied_torque.device,\n            dtype=torch.bool,\n        )\n        return asset\n\n    def __call__(\n        self,\n        env: ManagerBasedRLEnv,\n        asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    ) -> torch.Tensor:\n        asset = self._maybe_build_cache(env, asset_cfg)\n        joint_ids = self._joint_ids\n        inv_effort_limit = self._inv_effort_limit\n        prev_applied_torque = 
self._prev_applied_torque\n        assert joint_ids is not None\n        assert inv_effort_limit is not None\n        assert prev_applied_torque is not None\n\n        current_torque = asset.data.applied_torque[:, joint_ids]\n        reward = torch.zeros(\n            env.num_envs,\n            device=current_torque.device,\n            dtype=current_torque.dtype,\n        )\n\n        reseed_mask = self._needs_reseed.clone()\n        if hasattr(env, \"episode_length_buf\"):\n            reseed_mask |= env.episode_length_buf == 0\n\n        active_mask = ~reseed_mask\n        if torch.any(active_mask):\n            delta = (\n                current_torque[active_mask] - prev_applied_torque[active_mask]\n            ) * inv_effort_limit\n            reward[active_mask] = torch.sum(delta.square(), dim=1)\n\n        prev_applied_torque.copy_(current_torque)\n        self._needs_reseed[reseed_mask] = False\n\n        return reward\n\n\ndef track_stand_still_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Track stand still joint position using exponential kernel when command velocity is low.\n\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n\n    error = torch.sum(\n        torch.square(asset.data.joint_pos - asset.data.default_joint_pos),\n        dim=1,\n    )\n    # Use generated velocity commands (vx, vy, yaw_rate). Some command terms may\n    # expose additional channels (e.g., heading) via get_command().\n    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)\n    if cmd.shape[-1] > 3:\n        cmd = cmd[..., :3]\n    cmd_norm = torch.norm(cmd, dim=1)\n    return torch.exp(-error / std**2) * (cmd_norm < 0.1)\n\n\ndef stand_still_joint_deviation_l1(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"Penalize L1 joint deviation from default pose when command velocity is low.\n\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n\n    # L1 error: sum(|q - q_default|)\n    error = torch.sum(\n        torch.abs(asset.data.joint_pos - asset.data.default_joint_pos),\n        dim=1,\n    )\n\n    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)\n    if cmd.shape[-1] > 3:\n        cmd = cmd[..., :3]\n    cmd_norm = torch.norm(cmd, dim=1)\n\n    # Return error (to be penalized with negative weight) only when standing still\n    return error * (cmd_norm < 0.1)\n\n\ndef stand_still_action_rate(\n    env: ManagerBasedRLEnv,\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)\n    if cmd.shape[-1] > 3:\n        cmd = cmd[..., :3]\n    stand_still = torch.norm(cmd, dim=1) < 0.1\n    return (\n        torch.sum(\n            torch.square(\n                env.action_manager.action - env.action_manager.prev_action\n            ),\n            dim=1,\n        )\n        * stand_still\n    )\n\n\ndef stand_still_dof_vel_l2(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)\n    if cmd.shape[-1] > 3:\n        cmd = cmd[..., :3]\n    stand_still = torch.norm(cmd, dim=1) < 0.1\n    return (\n      
  torch.sum(\n            torch.square(env.scene[asset_cfg.name].data.joint_vel),\n            dim=1,\n        )\n        * stand_still\n    )\n\n\nclass action_acc(ManagerTermBase):\n    \"\"\"Penalize the change in action-rate using a stateful second difference.\"\"\"\n\n    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):\n        super().__init__(cfg, env)\n        self._prev_action: torch.Tensor | None = None\n        self._prev_action_rate: torch.Tensor | None = None\n        self._needs_reseed = torch.ones(\n            self.num_envs, device=self.device, dtype=torch.bool\n        )\n        self._needs_prev_rate = torch.ones(\n            self.num_envs, device=self.device, dtype=torch.bool\n        )\n\n    def reset(self, env_ids=None) -> None:\n        if env_ids is None:\n            self._needs_reseed[:] = True\n            self._needs_prev_rate[:] = True\n            return\n        if isinstance(env_ids, slice):\n            self._needs_reseed[env_ids] = True\n            self._needs_prev_rate[env_ids] = True\n            return\n        env_ids_tensor = torch.as_tensor(\n            env_ids, device=self.device, dtype=torch.long\n        )\n        self._needs_reseed[env_ids_tensor] = True\n        self._needs_prev_rate[env_ids_tensor] = True\n\n    def _maybe_build_cache(\n        self, env: ManagerBasedRLEnv\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        current_action = env.action_manager.action\n        cache_needs_refresh = (\n            self._prev_action is None\n            or self._prev_action_rate is None\n            or self._prev_action.shape != current_action.shape\n            or self._prev_action.device != current_action.device\n            or self._prev_action.dtype != current_action.dtype\n            or self._prev_action_rate.shape != current_action.shape\n            or self._prev_action_rate.device != current_action.device\n            or self._prev_action_rate.dtype != current_action.dtype\n        )\n        if cache_needs_refresh:\n            self._prev_action = torch.zeros_like(current_action)\n            self._prev_action_rate = torch.zeros_like(current_action)\n            self._needs_reseed = torch.ones(\n                env.num_envs,\n                device=current_action.device,\n                dtype=torch.bool,\n            )\n            self._needs_prev_rate = torch.ones(\n                env.num_envs,\n                device=current_action.device,\n                dtype=torch.bool,\n            )\n\n        assert self._prev_action is not None\n        assert self._prev_action_rate is not None\n        return self._prev_action, self._prev_action_rate\n\n    def __call__(self, env: ManagerBasedRLEnv) -> torch.Tensor:\n        current_action = env.action_manager.action\n        prev_action, prev_action_rate = self._maybe_build_cache(env)\n        reward = torch.zeros(\n            env.num_envs,\n            device=current_action.device,\n            dtype=current_action.dtype,\n        )\n\n        reseed_mask = self._needs_reseed.clone()\n        if hasattr(env, \"episode_length_buf\"):\n            reseed_mask |= env.episode_length_buf == 0\n\n        if torch.any(reseed_mask):\n            prev_action[reseed_mask] = current_action[reseed_mask]\n            # Boolean indexing returns a copy, so assign instead of in-place\n            # zero_() on the indexed view.\n            prev_action_rate[reseed_mask] = 0.0\n            self._needs_prev_rate[reseed_mask] = True\n\n        active_mask = ~reseed_mask\n        if torch.any(active_mask):\n            current_action_rate = (\n                current_action[active_mask] - 
prev_action[active_mask]\n            )\n            ready_mask = ~self._needs_prev_rate[active_mask]\n            if torch.any(ready_mask):\n                action_acc_value = (\n                    current_action_rate[ready_mask]\n                    - prev_action_rate[active_mask][ready_mask]\n                )\n                reward[\n                    active_mask.nonzero(as_tuple=False).flatten()[ready_mask]\n                ] = torch.sum(action_acc_value.square(), dim=1)\n\n            prev_action[active_mask] = current_action[active_mask]\n            prev_action_rate[active_mask] = current_action_rate\n            self._needs_prev_rate[active_mask] = False\n\n        self._needs_reseed[reseed_mask] = False\n        return reward\n\n\naction_acc_l2 = action_acc\n\n\ndef feet_stumble(\n    env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg\n) -> torch.Tensor:\n    # extract the used quantities (to enable type-hinting)\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    forces_z = torch.abs(\n        contact_sensor.data.net_forces_w[:, sensor_cfg.body_ids, 2]\n    )\n    forces_xy = torch.linalg.norm(\n        contact_sensor.data.net_forces_w[:, sensor_cfg.body_ids, :2], dim=2\n    )\n    # Penalize feet hitting vertical surfaces\n    reward = torch.any(forces_xy > 4 * forces_z, dim=1).float()\n    return reward\n\n\ndef feet_too_near(\n    env: ManagerBasedRLEnv,\n    threshold: float = 0.2,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    asset: Articulation = env.scene[asset_cfg.name]\n    feet_pos = asset.data.body_pos_w[:, asset_cfg.body_ids, :]\n    distance = torch.norm(feet_pos[:, 0] - feet_pos[:, 1], dim=-1)\n    return (threshold - distance).clamp(min=0)\n\n\ndef feet_contact_without_cmd(\n    env: ManagerBasedRLEnv,\n    sensor_cfg: SceneEntityCfg,\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"\n    Reward for feet contact when the command is zero.\n    \"\"\"\n    # asset: Articulation = env.scene[asset_cfg.name]\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    is_contact = (\n        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0\n    )\n\n    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)\n    if cmd.shape[-1] > 3:\n        cmd = cmd[..., :3]\n    command_norm = torch.norm(cmd, dim=1)\n    reward = torch.sum(is_contact, dim=-1).float()\n    return reward * (command_norm < 0.1)\n\n\ndef torso_xy_ang_vel_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n\n    # World-frame torso angular velocity: [B, 3]\n    torso_ang_vel_w: torch.Tensor = robot_ptr.data.body_ang_vel_w[\n        :, torso_idx, :\n    ]\n\n    # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading.\n    # Build yaw-only quaternion from stored heading_w (shape [B]).\n    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]\n    zero = torch.zeros_like(heading_yaw, device=env.device)\n    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(\n        roll=zero,\n        pitch=zero,\n        yaw=heading_yaw,\n    )  # [B, 4]\n\n    # Re-express torso angular velocity in heading-aligned frame.\n    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)\n    torso_ang_vel_h: torch.Tensor = 
isaaclab_math.quat_apply(\n        heading_inv_wxyz,\n        torso_ang_vel_w,\n    )  # [B, 3]\n\n    # Penalize lateral components (x, y) with squared magnitude.\n    penalty: torch.Tensor = torch.sum(\n        torch.square(torso_ang_vel_h[:, :2]),\n        dim=-1,\n    )  # [B]\n    return penalty\n\n\ndef torso_upright_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_rot_quat_w = robot_ptr.data.body_quat_w[:, torso_idx, :]\n\n    # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading.\n    # Build yaw-only quaternion from stored heading_w (shape [B]).\n    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]\n    zero = torch.zeros_like(heading_yaw, device=env.device)\n    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(\n        roll=zero,\n        pitch=zero,\n        yaw=heading_yaw,\n    )  # [B, 4]\n\n    # Re-express torso orientation in heading-aligned frame.\n    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)\n    torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul(\n        heading_inv_wxyz,\n        torso_rot_quat_w,\n    )  # [B, 4]\n\n    # Get the roll and pitch; only positive pitch is kept (one-sided penalty).\n    roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h)\n    pitch *= pitch > 0.0\n    rollpitch = torch.stack([roll * 2.0, pitch], dim=-1)\n\n    # Penalize squared roll/pitch magnitude.\n    penalty: torch.Tensor = torch.sum(\n        torch.square(rollpitch),\n        dim=-1,\n    )  # [B]\n    return penalty\n\n\n
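# NOTE: `torso_upright_l2_penalty` above is one-sided in pitch: the\n# `pitch *= pitch > 0.0` mask zeroes out negative pitch, so only one pitch\n# direction is penalized (e.g. pitch = -0.2 contributes nothing while\n# pitch = +0.2 does), and roll is always double-weighted. The v2 variant\n# below penalizes both pitch directions symmetrically around `target_pitch`.\n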
def torso_upright_l2_penalty_v2(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    target_pitch: float = 0.0,\n    roll_scale: float = 2.0,\n    pitch_scale: float = 1.0,\n) -> torch.Tensor:\n    \"\"\"Penalize torso roll/pitch deviation in a heading-aligned frame (symmetric).\n\n    Compared to `torso_upright_l2_penalty`, this version penalizes *both* forward\n    and backward pitch w.r.t. `target_pitch`.\n\n    Returns: [B]\n    \"\"\"\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_rot_quat_w: torch.Tensor = robot_ptr.data.body_quat_w[\n        :, torso_idx, :\n    ]  # [B, 4]\n\n    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]\n    zero = torch.zeros_like(heading_yaw, device=env.device)\n    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(\n        roll=zero,\n        pitch=zero,\n        yaw=heading_yaw,\n    )  # [B, 4]\n\n    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)\n    torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul(\n        heading_inv_wxyz,\n        torso_rot_quat_w,\n    )  # [B, 4]\n\n    roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h)  # [B]\n    roll_err: torch.Tensor = roll_scale * roll\n    pitch_err: torch.Tensor = pitch_scale * (pitch - target_pitch)\n    roll_pitch = torch.stack([roll_err, pitch_err], dim=-1)  # [B, 2]\n\n    penalty: torch.Tensor = torch.sum(torch.square(roll_pitch), dim=-1)  # [B]\n    return penalty\n\n\ndef stand_still_torso_upright_exp_v2(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    cmd_threshold: float = 0.1,\n    target_pitch: float = 0.0,\n    roll_scale: float = 2.0,\n    pitch_scale: float = 1.0,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"Reward torso uprightness under stand-still commands using an exp kernel.\n\n    Reward:\n        exp(-penalty / std^2)  if ||cmd|| <= cmd_threshold else 0\n    where penalty is computed by `torso_upright_l2_penalty_v2`.\n\n    Returns: [B]\n    \"\"\"\n    command = env.command_manager.get_command(command_name)\n    stand_still_flag: torch.Tensor = (\n        torch.norm(command, dim=1) <= cmd_threshold\n    )\n\n    penalty = torso_upright_l2_penalty_v2(\n        env,\n        asset_cfg=asset_cfg,\n        target_pitch=target_pitch,\n        roll_scale=roll_scale,\n        pitch_scale=pitch_scale,\n    )  # [B]\n    reward = torch.exp(-penalty / std**2)  # [B]\n    return reward * stand_still_flag.to(dtype=reward.dtype)\n\n\ndef torso_linacc_xy_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n\n    # World-frame torso linear acceleration: [B, 3]\n    torso_linacc_w = robot_ptr.data.body_lin_acc_w[:, torso_idx, :]\n\n    # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading.\n    # Build yaw-only quaternion from stored heading_w (shape [B]).\n    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]\n    zero = torch.zeros_like(heading_yaw, device=env.device)\n    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(\n        roll=zero,\n        pitch=zero,\n        yaw=heading_yaw,\n    )  # [B, 4]\n\n    # Re-express torso linear acceleration in heading-aligned frame.\n    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)\n    torso_linacc_h: torch.Tensor = isaaclab_math.quat_apply(\n        heading_inv_wxyz,\n        torso_linacc_w,\n    )  # [B, 3]\n\n    # Penalize lateral components (x, y) with squared magnitude.\n    penalty: torch.Tensor = torch.sum(\n        torch.square(torso_linacc_h[:, :2]),\n        dim=-1,\n    )  # [B]\n    return penalty\n\n\n
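# NOTE: Illustrative sketch only; not referenced by the reward terms in this\n# module. Several penalties above repeat the same heading-aligned\n# re-expression; factored out, it would look roughly like this, using the\n# isaaclab_math utilities already imported in this file.\ndef _heading_aligned_vector_sketch(\n    robot_ptr, vec_w: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"Re-express a world-frame vector [B, 3] in the heading-aligned frame.\"\"\"\n    heading_yaw = robot_ptr.data.heading_w  # [B]\n    zero = torch.zeros_like(heading_yaw)\n    heading_quat_wxyz = isaaclab_math.quat_from_euler_xyz(\n        roll=zero, pitch=zero, yaw=heading_yaw\n    )  # [B, 4]\n    heading_inv_wxyz = isaaclab_math.quat_inv(heading_quat_wxyz)\n    return isaaclab_math.quat_apply(heading_inv_wxyz, vec_w)  # [B, 3]\n\n\n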
def track_lin_vel_xy_heading_aligned_frame_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"\n    Track linear velocity (xy) in the heading-aligned frame using exponential kernel.\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    vel_yaw = isaaclab_math.quat_apply_inverse(\n        isaaclab_math.yaw_quat(asset.data.root_quat_w),\n        asset.data.root_lin_vel_w[:, :3],\n    )\n    command = env.command_manager.get_command(command_name)\n    stand_still_envs = torch.norm(command, dim=1) <= 0.1\n\n    # stand-still envs track a zero-translation target\n    zero_lin_vel_envs = stand_still_envs\n    tracking_targets = torch.where(\n        zero_lin_vel_envs[:, None], 0.0, command[:, :2]\n    )\n    lin_vel_error = torch.sum(\n        torch.square(tracking_targets - vel_yaw[:, :2]),\n        dim=1,\n    )\n    return torch.exp(-lin_vel_error / std**2)\n\n\ndef track_lin_vel_xy_heading_aligned_frame_exp_v2(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    asset: Articulation = env.scene[asset_cfg.name]\n    vel_yaw = isaaclab_math.quat_apply_inverse(\n        isaaclab_math.yaw_quat(asset.data.root_quat_w),\n        asset.data.root_lin_vel_w[:, :3],\n    )\n    command = env.command_manager.get_command(command_name)\n\n    yaw_envs = (torch.norm(command[:, :2], dim=1) < 0.1) & (\n        torch.abs(command[:, 2]) > 0.1\n    )\n    stand_still_envs = torch.norm(command, dim=1) <= 0.1\n\n    # treat yaw-only envs as zero-translation targets too\n    # (vx, vy are approx 0 by definition)\n    zero_lin_vel_envs = stand_still_envs | yaw_envs\n    tracking_targets = torch.where(\n        zero_lin_vel_envs[:, None], 0.0, command[:, :2]\n    )\n    lin_vel_error = torch.sum(\n        torch.square(tracking_targets - vel_yaw[:, :2]),\n        dim=1,\n    )\n\n    # Encourage zero linear velocity for stand-still envs and more precise\n    # zero-translation tracking for yaw-only envs.\n    reward_weights = torch.where(yaw_envs, 2.0, 1.0) + torch.where(\n        stand_still_envs, 10.0, 0.0\n    )\n\n    return reward_weights * torch.exp(-lin_vel_error / std**2)\n\n\n
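# For reference: with the weights above, a stand-still env gets\n# 1.0 + 10.0 = 11.0, a yaw-only env gets 2.0, and any other env gets 1.0\n# (stand-still and yaw-only are mutually exclusive, since yaw-only requires\n# |yaw_rate| > 0.1 while stand-still requires ||cmd|| <= 0.1).\n\n\n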
def track_ang_vel_z_heading_aligned_frame_exp_v2(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"\n    Track angular velocity (z) in the heading-aligned frame using exponential kernel.\n    Note that the z-component of angular velocity is identical in the world and\n    heading-aligned frames, since the two frames differ only by a yaw rotation about z.\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    command = env.command_manager.get_command(command_name)\n\n    yaw_envs = (torch.norm(command[:, :2], dim=1) < 0.1) & (\n        torch.abs(command[:, 2]) > 0.1\n    )\n    stand_still_envs = torch.norm(command, dim=1) <= 0.1\n\n    # set the tracking targets to 0.0 for stand still environments\n    tracking_targets = torch.where(stand_still_envs, 0.0, command[:, 2])\n\n    ang_vel_error = torch.square(\n        tracking_targets - asset.data.root_ang_vel_w[:, 2]\n    )\n\n    # Encourage zero angular velocity for stand-still envs and more precise\n    # angular velocity tracking for yaw-only envs.\n    reward_weights = torch.where(yaw_envs, 2.0, 1.0) + torch.where(\n        stand_still_envs, 10.0, 0.0\n    )\n    return reward_weights * torch.exp(-ang_vel_error / std**2)\n\n\ndef track_ang_vel_z_heading_aligned_frame_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"\n    Track angular velocity (z) in the heading-aligned frame using exponential kernel.\n    Note that the z-component of angular velocity is identical in the world and\n    heading-aligned frames, since the two frames differ only by a yaw rotation about z.\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    command = env.command_manager.get_command(command_name)\n    stand_still_envs = torch.norm(command, dim=1) <= 0.1\n    tracking_targets = torch.where(stand_still_envs, 0.0, command[:, 2])\n    ang_vel_error = torch.square(\n        tracking_targets - asset.data.root_ang_vel_w[:, 2]\n    )\n    return torch.exp(-ang_vel_error / std**2)\n\n\ndef smoothed_track_ang_vel_z_heading_aligned_frame_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n) -> torch.Tensor:\n    \"\"\"\n    Track a smoothed angular velocity (z) in the heading-aligned frame using\n    exponential kernel. The measured yaw rate is averaged over the available\n    observation history before computing the tracking error.\n    Returns: [B]\n    \"\"\"\n    asset: Articulation = env.scene[asset_cfg.name]\n    command = env.command_manager.get_command(command_name)\n    hist_robot_heading_aligned_ang_vel_z = env.observation_manager.compute()[\n        \"unified\"\n    ][\"rew_heading_aligned_root_ang_vel\"][..., 2]\n    ep_len = env.episode_length_buf\n    obs_window_len = hist_robot_heading_aligned_ang_vel_z.shape[1]\n    # Average over at most obs_window_len steps; the clamp avoids division\n    # by zero on the very first step of an episode.\n    smooth_window = torch.minimum(\n        torch.full_like(ep_len, obs_window_len), ep_len\n    ).clamp(min=1)\n    smoothed_robot_heading_aligned_ang_vel_z = (\n        hist_robot_heading_aligned_ang_vel_z.sum(dim=1) / smooth_window\n    )\n    ang_vel_error = torch.square(\n        command[:, 2] - smoothed_robot_heading_aligned_ang_vel_z\n    )\n    return torch.exp(-ang_vel_error / std**2)\n\n\ndef feet_air_time(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    sensor_cfg: SceneEntityCfg,\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]\n    contact_time = contact_sensor.data.current_contact_time[\n        :, sensor_cfg.body_ids\n    ]\n    in_contact = contact_time > 0.0\n    in_mode_time = torch.where(in_contact, contact_time, air_time)\n    single_stance = torch.sum(in_contact.int(), dim=1) == 1\n    reward = torch.min(\n        torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1\n    )[0]\n    reward = torch.clamp(reward, max=threshold)\n    # no reward for zero command\n    command = env.command_manager.get_command(command_name)\n    reward *= (\n        torch.norm(command[:, :2], dim=1) + torch.abs(command[:, 2])\n    ) > 0.1\n    return reward\n\n\n
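# Example wiring (illustrative sketch): registering `feet_air_time` as a\n# reward term. The sensor name matches the scene's `contact_forces` sensor;\n# the body-name regex and weight are assumptions and should match your\n# robot's foot link names and tuning.\n#\n#   feet_air_time_term = RewardTermCfg(\n#       func=feet_air_time,\n#       weight=0.25,\n#       params={\n#           \"threshold\": 0.4,\n#           \"sensor_cfg\": SceneEntityCfg(\n#               \"contact_forces\", body_names=\".*_ankle_roll_link\"\n#           ),\n#           \"command_name\": \"base_velocity\",\n#       },\n#   )\n\n\n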
def feet_air_time_v2(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    sensor_cfg: SceneEntityCfg,\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]\n    contact_time = contact_sensor.data.current_contact_time[\n        :, sensor_cfg.body_ids\n    ]\n    in_contact = contact_time > 0.0\n    in_mode_time = torch.where(in_contact, contact_time, air_time)\n    single_stance = torch.sum(in_contact.int(), dim=1) == 1\n    reward = torch.min(\n        torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1\n    )[0]\n    reward = torch.clamp(reward, max=threshold)\n    # no reward for zero command\n    command = env.command_manager.get_command(command_name)\n    stand_still_envs_flag = torch.norm(command, dim=1) <= 0.1\n    ang_z_only_mask = (torch.norm(command[:, :2], dim=1) <= 0.1) & (\n        torch.abs(command[:, 2]) > 0.1\n    )\n    # Stand still: 0.0, yaw-only: 1.0 + 5.0 = 6.0, other: 1.0\n    reward_weights = torch.where(\n        stand_still_envs_flag, 0.0, 1.0\n    ) + torch.where(ang_z_only_mask, 5.0, 0.0)\n    return reward * reward_weights\n\n\ndef feet_air_time_v3(\n    env: ManagerBasedRLEnv,\n    command_name: str,\n    sensor_cfg: SceneEntityCfg,\n    threshold: float,\n) -> torch.Tensor:\n    \"\"\"Reward long steps taken by the feet using L2-kernel.\n\n    This function rewards the agent for taking steps that are longer than a threshold. This helps ensure\n    that the robot lifts its feet off the ground and takes steps. The reward is computed as the sum of\n    the time for which the feet are in the air.\n\n    If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero.\n    \"\"\"\n    # extract the used quantities (to enable type-hinting)\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    # compute the reward\n    first_contact = contact_sensor.compute_first_contact(env.step_dt)[\n        :, sensor_cfg.body_ids\n    ]\n    last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]\n    reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)\n    # no reward for stand still commands, larger reward for yaw-only commands\n    commands = env.command_manager.get_command(command_name)\n    stand_still_envs = torch.norm(commands, dim=1) <= 0.1\n    yaw_only_envs = (torch.norm(commands[:, :2], dim=1) <= 0.1) & (\n        torch.abs(commands[:, 2]) > 0.1\n    )\n    reward_weights = torch.where(stand_still_envs, 0.0, 1.0) + torch.where(\n        yaw_only_envs, 4.0, 0.0\n    )\n\n    return reward * reward_weights\n\n\ndef feet_air_time_v4(\n    env: ManagerBasedRLEnv,\n    command_name: str,\n    sensor_cfg: SceneEntityCfg,\n    threshold: float,\n) -> torch.Tensor:\n    \"\"\"Reward long steps taken by the feet using L2-kernel.\n\n    This function rewards the agent for taking steps that are longer than a threshold. This helps ensure\n    that the robot lifts its feet off the ground and takes steps. The reward is computed as the sum of\n    the time for which the feet are in the air.\n\n    If the commands are small (i.e.
 the agent is not supposed to take a step), then the reward is zero.\n    \"\"\"\n    # extract the used quantities (to enable type-hinting)\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    # compute the reward\n    first_contact = contact_sensor.compute_first_contact(env.step_dt)[\n        :, sensor_cfg.body_ids\n    ]\n    last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]\n    reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)\n    # no reward for stand still commands\n    commands = env.command_manager.get_command(command_name)\n    stand_still_envs = torch.norm(commands, dim=1) <= 0.1\n    reward_weights = torch.where(stand_still_envs, 0.0, 1.0)\n\n    return reward * reward_weights\n\n\ndef yaw_rate_only_movement_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"Penalize world-frame root XY linear velocity when the linear command is near zero.\n\n    When the commanded vx, vy are small, penalize the squared magnitude of the\n    root linear velocity (vx, vy) in the world frame.\n\n    Returns: [B]\n    \"\"\"\n    # Velocity command: [B, 3] (vx, vy, yaw_rate). Some command terms may\n    # expose extra channels (e.g., heading) via generated_commands().\n    command = env.command_manager.get_command(command_name)\n\n    # Gate envs whose linear command is near zero. This covers both yaw-only\n    # and stand-still commands; the yaw rate itself is not checked here.\n    yaw_only_mask: torch.Tensor = (\n        torch.norm(command[:, :2], dim=1) <= 0.1\n    )  # [B]\n\n    # Penalize global (world-frame) root linear velocity in x/y.\n    asset: Articulation = env.scene[asset_cfg.name]\n    root_lin_vel_w: torch.Tensor = asset.data.root_lin_vel_w  # [B, 3]\n    penalty: torch.Tensor = torch.sum(\n        torch.square(root_lin_vel_w[:, :2]),\n        dim=1,\n    )  # [B]\n    return penalty * yaw_only_mask.to(dtype=penalty.dtype)\n\n\ndef fly(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    sensor_cfg: SceneEntityCfg,\n) -> torch.Tensor:\n    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]\n    net_contact_forces = contact_sensor.data.net_forces_w_history\n    is_contact = (\n        torch.max(\n            torch.norm(net_contact_forces[:, :, sensor_cfg.body_ids], dim=-1),\n            dim=1,\n        )[0]\n        > threshold\n    )\n    return torch.sum(is_contact, dim=-1) < 0.5\n\n\ndef stand_still_torso_lin_vel_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_lin_vel_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :]\n    penalty = torch.sum(torch.square(torso_lin_vel_w), dim=-1)\n    command = env.command_manager.get_command(command_name)\n    stand_still_flag = torch.norm(command, dim=1) <= 0.1\n    return penalty * stand_still_flag.to(dtype=penalty.dtype)\n\n\ndef stand_still_torso_ang_vel_l2_penalty(\n    env: ManagerBasedRLEnv,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_ang_vel_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]\n    penalty = torch.sum(torch.square(torso_ang_vel_w), dim=-1)\n    command = 
env.command_manager.get_command(command_name)\n    stand_still_flag = torch.norm(command, dim=1) <= 0.1\n    return penalty * stand_still_flag.to(dtype=penalty.dtype)\n\n\ndef stand_still_torso_lin_vel_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"Reward staying still (zero torso linear velocity) when commanded to stand.\n\n    Uses exponential kernel: exp(-||v||^2 / std^2)\n\n    Args:\n        env: Environment instance\n        std: Standard deviation for exponential kernel\n        asset_cfg: Robot asset configuration\n        command_name: Name of velocity command\n\n    Returns:\n        Reward tensor of shape [B], active only when stand still commanded\n    \"\"\"\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_lin_vel_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :]\n    error = torch.sum(torch.square(torso_lin_vel_w), dim=-1)\n    reward = torch.exp(-error / std**2)\n    command = env.command_manager.get_command(command_name)\n    stand_still_flag = torch.norm(command, dim=1) <= 0.1\n    return reward * stand_still_flag.to(dtype=reward.dtype)\n\n\ndef stand_still_torso_ang_vel_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"Reward staying still (zero torso angular velocity) when commanded to stand.\n\n    Uses exponential kernel: exp(-||omega||^2 / std^2)\n\n    Args:\n        env: Environment instance\n        std: Standard deviation for exponential kernel\n        asset_cfg: Robot asset configuration\n        command_name: Name of velocity command\n\n    Returns:\n        Reward tensor of shape [B], active only when stand still commanded\n    \"\"\"\n    robot_ptr = env.scene[asset_cfg.name]\n    torso_idx = robot_ptr.body_names.index(\"torso_link\")\n    torso_ang_vel_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]\n    error = torch.sum(torch.square(torso_ang_vel_w), dim=-1)\n    reward = torch.exp(-error / std**2)\n    command = env.command_manager.get_command(command_name)\n    stand_still_flag = torch.norm(command, dim=1) <= 0.1\n    return reward * stand_still_flag.to(dtype=reward.dtype)\n\n\ndef yaw_rate_only_movement_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    command_name: str = \"base_velocity\",\n) -> torch.Tensor:\n    \"\"\"Reward minimal XY translation during yaw-only commands.\n\n    When vx, vy commands are small, reward staying in place using exponential kernel.\n    Uses exponential kernel: exp(-||v_xy||^2 / std^2)\n\n    Args:\n        env: Environment instance\n        std: Standard deviation for exponential kernel\n        asset_cfg: Robot asset configuration\n        command_name: Name of velocity command\n\n    Returns:\n        Reward tensor of shape [B], active only during yaw-only commands\n    \"\"\"\n    command = env.command_manager.get_command(command_name)\n    yaw_only_mask: torch.Tensor = torch.norm(command[:, :2], dim=1) <= 0.1\n\n    asset: Articulation = env.scene[asset_cfg.name]\n    root_lin_vel_w: torch.Tensor = asset.data.root_lin_vel_w\n    error: torch.Tensor = torch.sum(torch.square(root_lin_vel_w[:, :2]), dim=1)\n    reward: torch.Tensor = torch.exp(-error / std**2)\n    return reward * 
yaw_only_mask.to(dtype=reward.dtype)\n\n\ndef yaw_rate_only_hip_yaw_usage_exp(\n    env: ManagerBasedRLEnv,\n    std: float,\n    command_name: str = \"base_velocity\",\n    hip_yaw_dofs: list[str] | None = None,\n    asset_cfg: SceneEntityCfg = SceneEntityCfg(\"robot\"),\n    lin_threshold: float = 0.1,\n    yaw_threshold: float = 0.1,\n    command_tanh_mult: float = 1.0,\n) -> torch.Tensor:\n    \"\"\"Encourage using hip_yaw joint(s) during yaw-rate-only commands.\n\n    Active only when commanded to rotate in place (vx, vy small and |yaw_rate| large).\n    Rewards hip_yaw joint velocity magnitude using a saturating exponential kernel:\n        r = (1 - exp(-mean(qd_hip_yaw^2) / std^2)) * tanh(command_tanh_mult * |cmd_yaw|)\n\n    Shapes:\n    - command: [B, 3] (vx, vy, yaw_rate)\n    - asset.data.joint_vel: [B, num_dofs]\n    - return: [B]\n    \"\"\"\n    command: torch.Tensor = env.command_manager.get_command(command_name)\n    yaw_only_mask: torch.Tensor = (\n        torch.norm(command[:, :2], dim=1) <= lin_threshold\n    ) & (torch.abs(command[:, 2]) > yaw_threshold)  # [B]\n\n    asset: Articulation = env.scene[asset_cfg.name]\n    if hip_yaw_dofs is None:\n        raise ValueError(\n            \"yaw_rate_only_hip_yaw_usage_exp requires hip_yaw_dofs (joint names in \"\n            f\"robot.joint_names). Got hip_yaw_dofs=None. robot.joint_names={asset.joint_names}\"\n        )\n    hip_yaw_joint_ids: list[int] = _get_dof_indices(asset, hip_yaw_dofs)\n\n    hip_yaw_vel: torch.Tensor = asset.data.joint_vel[\n        :, hip_yaw_joint_ids\n    ]  # [B, N]\n    activity_sq: torch.Tensor = torch.mean(\n        torch.square(hip_yaw_vel), dim=-1\n    )  # [B]\n    usage_reward: torch.Tensor = 1.0 - torch.exp(-activity_sq / std**2)  # [B]\n\n    cmd_yaw_abs: torch.Tensor = torch.abs(command[:, 2])  # [B]\n    cmd_weight: torch.Tensor = torch.tanh(\n        command_tanh_mult * cmd_yaw_abs\n    )  # [B]\n\n    reward: torch.Tensor = usage_reward * cmd_weight\n    return reward * yaw_only_mask.to(dtype=reward.dtype)\n\n\n@configclass\nclass RewardsCfg:\n    pass\n\n\nclass TaskGatedReward:\n    \"\"\"Callable wrapper to gate reward terms by task_id.\"\"\"\n\n    def __init__(self, func, task_name: str):\n        self.func = func\n        self.task_name = task_name\n        self.__name__ = f\"TaskGatedReward[{task_name}]\"\n\n    def __call__(self, env: ManagerBasedRLEnv, *args, **kwargs):\n        task_ids = getattr(env, \"holo_task_ids\", None)\n        mapping = getattr(env, \"holo_task_name_to_id\", None)\n        if task_ids is None or mapping is None:\n            return torch.zeros(env.num_envs, device=env.device)\n        target = mapping.get(self.task_name, None)\n        if target is None:\n            return torch.zeros(env.num_envs, device=env.device)\n        mask = task_ids == target\n        if not torch.any(mask):\n            return torch.zeros(env.num_envs, device=env.device)\n\n        inner_args = kwargs.pop(\"args\", None)\n        inner_kwargs = kwargs.pop(\"kwargs\", None)\n        call_args = args if inner_args is None else (*args, *inner_args)\n        call_kwargs = (\n            kwargs if inner_kwargs is None else {**kwargs, **inner_kwargs}\n        )\n\n        reward = self.func(env, *call_args, **call_kwargs)\n        mask = mask.to(device=reward.device, dtype=reward.dtype)\n        return reward * mask\n\n\ndef build_rewards_config(reward_config_dict: dict):\n    if isinstance(reward_config_dict, (DictConfig, ListConfig)):\n        reward_config_dict = 
OmegaConf.to_container(\n            reward_config_dict, resolve=True\n        )\n\n    rewards_cfg = RewardsCfg()\n\n    # Detect grouped (multi-task) vs flat (legacy) layout\n    def _is_grouped(cfg: dict) -> bool:\n        for k, v in cfg.items():\n            if k == \"_config\":\n                continue\n            if isinstance(v, dict) and \"weight\" in v:\n                return False\n            return True\n        return False\n\n    is_grouped = _is_grouped(reward_config_dict)\n\n    if not is_grouped:\n        for reward_name, reward_cfg in reward_config_dict.items():\n            if reward_name == \"_config\":\n                continue\n            reward_cfg = resolve_holo_config(reward_cfg)\n            base_params = resolve_holo_config(reward_cfg[\"params\"])\n            method_name = f\"{reward_name}\"\n            func = globals().get(method_name, None)\n            if func is None:\n                func = getattr(isaaclab_mdp, reward_name, None)\n            if func is None:\n                raise ValueError(f\"Unknown reward function: {reward_name}\")\n            params = dict(base_params)\n            setattr(\n                rewards_cfg,\n                reward_name,\n                RewardTermCfg(\n                    func=func,\n                    weight=reward_cfg[\"weight\"],\n                    params=params,\n                ),\n            )\n        return rewards_cfg\n\n    # Grouped: rewards: {task_name: {term: ...}}\n    for task_name, task_group in reward_config_dict.items():\n        if task_name.startswith(\"_\"):\n            continue\n        if not isinstance(task_group, dict):\n            raise ValueError(f\"Expected dict for task group {task_name}\")\n        for reward_name, reward_cfg in task_group.items():\n            reward_cfg = resolve_holo_config(reward_cfg)\n            base_params = resolve_holo_config(reward_cfg[\"params\"])\n            method_name = f\"{reward_name}\"\n            func = globals().get(method_name, None)\n            if func is None:\n                func = getattr(isaaclab_mdp, reward_name, None)\n            if func is None:\n                raise ValueError(f\"Unknown reward function: {reward_name}\")\n            if task_name != \"common\":\n                func = TaskGatedReward(func, task_name)\n                params = {\"args\": [], \"kwargs\": base_params}\n            else:\n                params = base_params\n            flat_name = f\"{task_name}.{reward_name}\"\n            setattr(\n                rewards_cfg,\n                flat_name,\n                RewardTermCfg(\n                    func=func,\n                    weight=reward_cfg[\"weight\"],\n                    params=params,\n                ),\n            )\n\n    return rewards_cfg\n"
  },
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_scene.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nimport copy\nimport os\nimport time\nfrom dataclasses import MISSING\n\nimport isaaclab.sim as sim_utils\nfrom isaaclab.actuators import ImplicitActuatorCfg\nfrom isaaclab.assets import ArticulationCfg, AssetBaseCfg\nfrom isaaclab.scene import InteractiveSceneCfg\nfrom isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns\nfrom isaaclab.terrains import TerrainImporterCfg\nfrom isaaclab.utils import configclass\nfrom loguru import logger\nfrom holomotion.src.env.isaaclab_components.isaaclab_terrain import (\n    build_terrain_config,\n)\nfrom holomotion.src.env.isaaclab_components.unitree_actuators import (\n    UnitreeActuator,\n    UnitreeActuatorCfg,\n    UnitreeErfiActuator,\n    UnitreeErfiActuatorCfg,\n)\n\n\nclass SceneFunctions:\n    \"\"\"Collection of scene component builders.\"\"\"\n\n    @staticmethod\n    def build_robot_config(\n        config: dict,\n        domain_rand_config: dict | None = None,\n        main_process: bool = True,\n        process_id: int = 0,\n        num_processes: int = 1,\n    ) -> ArticulationCfg:\n        \"\"\"Build robot articulation configuration.\n\n        Args:\n            config: Robot configuration dictionary\n            main_process: Whether this is the main process (from compiled config)\n            process_id: Process ID/rank (from compiled config)\n            num_processes: Total number of processes (from compiled config)\n        \"\"\"\n        urdf_path = config.asset.urdf_file\n        init_pos = config.init_state.pos\n        default_joint_positions = config.init_state.default_joint_angles\n        root_link_name = config.get(\"root_name\", \"pelvis\")\n        prim_path = \"{ENV_REGEX_NS}/Robot\"\n\n        actuator_type = config.actuators.get(\"actuator_type\", \"implicit\")\n        if actuator_type in {\"unitree\", \"unitree_erfi\"}:\n            actuators = _build_unitree_actuator_cfg(\n                config.actuators, domain_rand_config or {}\n            )\n        else:\n            actuators = {\n                \"all_joints\": ImplicitActuatorCfg(\n                    **config.actuators.all_joints\n                )\n            }\n\n        logger.info(f\"Using {actuator_type} actuators\")\n        logger.info(f\"Actuators: {actuators}\")\n\n        if not os.path.exists(urdf_path):\n            raise FileNotFoundError(f\"URDF file not found: {urdf_path}\")\n\n        # Configure USD output directory. 
Optionally isolate per rank to avoid races.\n        usd_base_dir = os.path.join(os.path.dirname(urdf_path), \"usd\")\n        unique_usd_per_rank = True\n        if num_processes > 1 and unique_usd_per_rank:\n            usd_dir = os.path.join(usd_base_dir, f\"rank_{process_id}\")\n        else:\n            usd_dir = usd_base_dir\n        os.makedirs(usd_dir, exist_ok=True)\n        logger.info(f\"Using URDF path: {urdf_path}\")\n        logger.info(f\"Using USD directory: {usd_dir}\")\n\n        force_usd_conversion = config.asset.get(\"force_usd_conversion\", True)\n        if num_processes > 1 and unique_usd_per_rank:\n            # Ensure each rank generates its own USD into its isolated directory\n            force_usd_conversion = True\n\n        # Handle DDP\n        if num_processes > 1:\n            logger.info(\n                f\"[Process {process_id}/{num_processes}] Distributed training detected\"\n            )\n\n            if unique_usd_per_rank:\n                logger.info(\n                    f\"[Process {process_id}] Using per-rank USD dir; forcing USD conversion: {force_usd_conversion}\"\n                )\n            else:\n                # Only main process should convert USD to avoid file conflicts\n                if main_process:\n                    logger.info(\n                        f\"[Process {process_id}] Main process - Force USD conversion: {force_usd_conversion}\"\n                    )\n                else:\n                    logger.info(\n                        f\"[Process {process_id}] Non-main process - Skipping USD conversion, waiting for main process\"\n                    )\n                    force_usd_conversion = False\n\n                    # Wait for USD files to be created by main process\n                    urdf_basename = os.path.splitext(\n                        os.path.basename(urdf_path)\n                    )[0]\n                    expected_usd_file = os.path.join(\n                        usd_dir, f\"{urdf_basename}.usd\"\n                    )\n\n                    logger.info(\n                        f\"[Process {process_id}] Waiting for main process to create USD files at {expected_usd_file}...\"\n                    )\n                    max_wait = 60\n                    wait_interval = 1\n                    waited = 0\n\n                    while (\n                        not os.path.exists(expected_usd_file)\n                        and waited < max_wait\n                    ):\n                        time.sleep(wait_interval)\n                        waited += wait_interval\n\n                    if os.path.exists(expected_usd_file):\n                        logger.info(\n                            f\"[Process {process_id}] USD file found, proceeding with loading\"\n                        )\n                    else:\n                        logger.warning(\n                            f\"[Process {process_id}] USD file not found after {max_wait}s, proceeding anyway\"\n                        )\n        else:\n            logger.info(\n                f\"Single process training. 
Force USD conversion: {force_usd_conversion}\"\n            )\n\n        articulation_cfg = ArticulationCfg(\n            prim_path=prim_path,\n            spawn=sim_utils.UrdfFileCfg(\n                asset_path=os.path.abspath(urdf_path),\n                usd_dir=os.path.abspath(usd_dir),\n                force_usd_conversion=force_usd_conversion,\n                fix_base=False,\n                merge_fixed_joints=True,\n                root_link_name=root_link_name,\n                replace_cylinders_with_capsules=True,\n                activate_contact_sensors=True,\n                rigid_props=sim_utils.RigidBodyPropertiesCfg(\n                    disable_gravity=False,\n                    retain_accelerations=False,\n                    linear_damping=0.0,\n                    angular_damping=0.0,\n                    max_linear_velocity=1000.0,\n                    max_angular_velocity=1000.0,\n                    max_depenetration_velocity=1.0,\n                ),\n                articulation_props=sim_utils.ArticulationRootPropertiesCfg(\n                    enabled_self_collisions=True,\n                    solver_position_iteration_count=8,\n                    solver_velocity_iteration_count=4,\n                ),\n                joint_drive=sim_utils.UrdfConverterCfg.JointDriveCfg(\n                    gains=sim_utils.UrdfConverterCfg.JointDriveCfg.PDGainsCfg(\n                        stiffness=0,\n                        damping=0,\n                    )\n                ),\n            ),\n            init_state=ArticulationCfg.InitialStateCfg(\n                pos=init_pos,\n                joint_pos=default_joint_positions,\n                joint_vel={\".*\": 0.0},\n            ),\n            soft_joint_pos_limit_factor=0.9,\n            actuators=actuators,\n        )\n\n        return articulation_cfg\n\n    @staticmethod\n    def build_lighting_config(\n        config: dict,\n    ) -> tuple[AssetBaseCfg, AssetBaseCfg]:\n        \"\"\"Build lighting configuration.\"\"\"\n        distant_light_intensity = config.get(\"distant_light_intensity\", 3000.0)\n        dome_light_intensity = config.get(\"dome_light_intensity\", 1000.0)\n        distant_light_color = config.get(\n            \"distant_light_color\", (0.75, 0.75, 0.75)\n        )\n        dome_light_color = config.get(\"dome_light_color\", (0.13, 0.13, 0.13))\n\n        light = AssetBaseCfg(\n            prim_path=\"/World/light\",\n            spawn=sim_utils.DistantLightCfg(\n                color=distant_light_color, intensity=distant_light_intensity\n            ),\n        )\n        sky_light = AssetBaseCfg(\n            prim_path=\"/World/skyLight\",\n            spawn=sim_utils.DomeLightCfg(\n                color=dome_light_color, intensity=dome_light_intensity\n            ),\n        )\n        return light, sky_light\n\n    @staticmethod\n    def build_contact_sensor_config(config: dict) -> ContactSensorCfg:\n        \"\"\"Build contact sensor configuration.\"\"\"\n        prim_path = config.get(\"prim_path\", \"{ENV_REGEX_NS}/Robot/.*\")\n        history_length = config.get(\"history_length\", 3)\n        force_threshold = config.get(\"force_threshold\", 10.0)\n        track_air_time = config.get(\"track_air_time\", True)\n        debug_vis = config.get(\"debug_vis\", False)\n\n        return ContactSensorCfg(\n            prim_path=prim_path,\n            history_length=history_length,\n            track_air_time=track_air_time,\n            force_threshold=force_threshold,\n            
debug_vis=debug_vis,\n        )\n\n\n@configclass\nclass MotionTrackingSceneCfg(InteractiveSceneCfg):\n    \"\"\"Scene configuration for motion tracking environment.\"\"\"\n\n    pass\n\n\ndef build_scene_config(\n    scene_config_dict: dict,\n    main_process: bool = True,\n    process_id: int = 0,\n    num_processes: int = 1,\n) -> MotionTrackingSceneCfg:\n    \"\"\"Build IsaacLab-compatible scene configuration from config dictionary.\n\n    Args:\n        scene_config_dict: Scene configuration dictionary\n        main_process: Whether this is the main process (from compiled config)\n        process_id: Process ID/rank (from compiled config)\n        num_processes: Total number of processes (from compiled config)\n    \"\"\"\n    scene_cfg = MotionTrackingSceneCfg()\n\n    # Basic scene properties\n    scene_cfg.num_envs = scene_config_dict.get(\"num_envs\", MISSING)\n    scene_cfg.env_spacing = scene_config_dict.get(\"env_spacing\", 2.5)\n    scene_cfg.replicate_physics = scene_config_dict.get(\n        \"replicate_physics\", True\n    )\n\n    # Build robot configuration with process info\n    if \"robot\" in scene_config_dict:\n        robot_config = scene_config_dict[\"robot\"]\n        scene_cfg.robot = SceneFunctions.build_robot_config(\n            robot_config,\n            domain_rand_config=scene_config_dict.get(\"domain_rand\", {}),\n            main_process=main_process,\n            process_id=process_id,\n            num_processes=num_processes,\n        )\n\n    # Build terrain configuration\n    if \"terrain\" in scene_config_dict:\n        terrain_config = scene_config_dict[\"terrain\"]\n        scene_cfg.terrain = build_terrain_config(\n            terrain_config, scene_env_spacing=scene_cfg.env_spacing\n        )\n        if \"robot\" in scene_config_dict:\n            scene_cfg.height_scanner = RayCasterCfg(\n                prim_path=\"{ENV_REGEX_NS}/Robot\",\n                offset=RayCasterCfg.OffsetCfg(pos=(0.0, 0.0, 1.0)),\n                ray_alignment=\"world\",\n                pattern_cfg=patterns.GridPatternCfg(\n                    resolution=1.0, size=(1.0e-3, 1.0e-3)\n                ),\n                debug_vis=False,\n                mesh_prim_paths=[str(scene_cfg.terrain.prim_path)],\n                max_distance=1.0e6,\n            )\n\n    # Build lighting configuration\n    if \"lighting\" in scene_config_dict:\n        lighting_config = scene_config_dict[\"lighting\"]\n        light, sky_light = SceneFunctions.build_lighting_config(\n            lighting_config\n        )\n        scene_cfg.light = light\n        scene_cfg.sky_light = sky_light\n\n    # Build contact sensor configuration\n    if \"contact_sensor\" in scene_config_dict:\n        contact_config = scene_config_dict[\"contact_sensor\"]\n        scene_cfg.contact_forces = SceneFunctions.build_contact_sensor_config(\n            contact_config\n        )\n\n    return scene_cfg\n\n\ndef _cfg_to_kwargs(cfg: object) -> dict:\n    return {\n        key: copy.deepcopy(value)\n        for key, value in vars(cfg).items()\n        if not key.startswith(\"_\")\n    }\n\n\ndef _build_unitree_actuator_cfg(\n    config: dict, domain_rand_config: dict\n) -> dict[str, object]:\n    base_cfg = unitree_actuator_config_hardcoded[\"all_joints\"]\n    base_kwargs = _cfg_to_kwargs(base_cfg)\n    action_delay_cfg = copy.deepcopy(\n        domain_rand_config.get(\"action_delay\", {})\n    )\n    if action_delay_cfg.get(\"enabled\", False):\n        delay_kwargs = {\n            \"min_delay\": 
int(action_delay_cfg.get(\"min_delay\", 0)),\n            \"max_delay\": int(action_delay_cfg.get(\"max_delay\", 0)),\n        }\n    else:\n        delay_kwargs = {\"min_delay\": 0, \"max_delay\": 0}\n\n    if config.get(\"actuator_type\", \"unitree\") == \"unitree_erfi\":\n        erfi_cfg = copy.deepcopy(domain_rand_config.get(\"erfi\", {}))\n        actuator_filter_kwargs = {\n            \"ema_filter_enabled\": bool(\n                config.get(\"ema_filter_enabled\", False)\n            ),\n            \"ema_filter_alpha\": config.get(\"ema_filter_alpha\", 1.0),\n        }\n        erfi_kwargs = {\n            \"erfi_enabled\": bool(erfi_cfg.get(\"enabled\", False)),\n            \"rfi_probability\": erfi_cfg.get(\"rfi_probability\", 0.5),\n            \"rfi_lim\": erfi_cfg.get(\"rfi_lim\", 0.1),\n            \"randomize_rfi_lim\": erfi_cfg.get(\"randomize_rfi_lim\", True),\n            \"rfi_lim_range\": erfi_cfg.get(\"rfi_lim_range\", (0.5, 1.5)),\n            \"rao_lim\": erfi_cfg.get(\"rao_lim\", 0.1),\n        }\n        actuator_kwargs = {\n            **base_kwargs,\n            **delay_kwargs,\n            **actuator_filter_kwargs,\n            **erfi_kwargs,\n        }\n        actuator_cfg = UnitreeErfiActuatorCfg(**actuator_kwargs)\n        actuator_cfg.class_type = UnitreeErfiActuator\n    else:\n        actuator_kwargs = {**base_kwargs, **delay_kwargs}\n        actuator_cfg = UnitreeActuatorCfg(**actuator_kwargs)\n        actuator_cfg.class_type = UnitreeActuator\n\n    return {\"all_joints\": actuator_cfg}\n\n\nunitree_actuator_config_hardcoded = {\n    \"all_joints\": UnitreeActuatorCfg(\n        joint_names_expr=[\n            \".*_hip_yaw_joint\",\n            \".*_hip_roll_joint\",\n            \".*_hip_pitch_joint\",\n            \".*_knee_joint\",\n            \".*_ankle_pitch_joint\",\n            \".*_ankle_roll_joint\",\n            \"waist_roll_joint\",\n            \"waist_pitch_joint\",\n            \"waist_yaw_joint\",\n            \".*_shoulder_pitch_joint\",\n            \".*_shoulder_roll_joint\",\n            \".*_shoulder_yaw_joint\",\n            \".*_elbow_joint\",\n            \".*_wrist_roll_joint\",\n            \".*_wrist_pitch_joint\",\n            \".*_wrist_yaw_joint\",\n        ],\n        min_delay=0,\n        max_delay=0,\n        effort_limit={\n            \".*_hip_yaw_joint\": 88,\n            \".*_hip_roll_joint\": 139,\n            \".*_hip_pitch_joint\": 88,\n            \".*_knee_joint\": 139,\n            \".*_ankle_pitch_joint\": 50,\n            \".*_ankle_roll_joint\": 50,\n            \"waist_roll_joint\": 50,\n            \"waist_pitch_joint\": 50,\n            \"waist_yaw_joint\": 88,\n            \".*_shoulder_pitch_joint\": 25,\n            \".*_shoulder_roll_joint\": 25,\n            \".*_shoulder_yaw_joint\": 25,\n            \".*_elbow_joint\": 25,\n            \".*_wrist_roll_joint\": 25,\n            \".*_wrist_pitch_joint\": 5,\n            \".*_wrist_yaw_joint\": 5,\n        },\n        velocity_limit={\n            \".*_hip_yaw_joint\": 32,\n            \".*_hip_roll_joint\": 20,\n            \".*_hip_pitch_joint\": 32,\n            \".*_knee_joint\": 20,\n            \".*_ankle_pitch_joint\": 37,\n            \".*_ankle_roll_joint\": 37,\n            \"waist_roll_joint\": 37,\n            \"waist_pitch_joint\": 37,\n            \"waist_yaw_joint\": 32,\n            \".*_shoulder_pitch_joint\": 37,\n            \".*_shoulder_roll_joint\": 37,\n            \".*_shoulder_yaw_joint\": 37,\n            
\".*_elbow_joint\": 37,\n            \".*_wrist_roll_joint\": 37,\n            \".*_wrist_pitch_joint\": 22,\n            \".*_wrist_yaw_joint\": 22,\n        },\n        stiffness={\n            \".*_hip_yaw_joint\": 40.1792384737,\n            \".*_hip_roll_joint\": 99.0984277823,\n            \".*_hip_pitch_joint\": 40.1792384737,\n            \".*_knee_joint\": 99.0984277823,\n            \".*_ankle_pitch_joint\": 28.5012461974,\n            \".*_ankle_roll_joint\": 28.5012461974,\n            \"waist_roll_joint\": 28.5012461974,\n            \"waist_pitch_joint\": 28.5012461974,\n            \"waist_yaw_joint\": 40.1792384737,\n            \".*_shoulder_pitch_joint\": 14.2506230987,\n            \".*_shoulder_roll_joint\": 14.2506230987,\n            \".*_shoulder_yaw_joint\": 14.2506230987,\n            \".*_elbow_joint\": 14.2506230987,\n            \".*_wrist_roll_joint\": 14.2506230987,\n            \".*_wrist_pitch_joint\": 16.7783274819,\n            \".*_wrist_yaw_joint\": 16.7783274819,\n        },\n        damping={\n            \".*_hip_yaw_joint\": 2.5578897651,\n            \".*_hip_roll_joint\": 6.30880185368,\n            \".*_hip_pitch_joint\": 2.5578897651,\n            \".*_knee_joint\": 6.30880185368,\n            \".*_ankle_pitch_joint\": 1.81444568664,\n            \".*_ankle_roll_joint\": 1.81444568664,\n            \"waist_roll_joint\": 1.81444568664,\n            \"waist_pitch_joint\": 1.81444568664,\n            \"waist_yaw_joint\": 2.5578897651,\n            \".*_shoulder_pitch_joint\": 0.907222843318,\n            \".*_shoulder_roll_joint\": 0.907222843318,\n            \".*_shoulder_yaw_joint\": 0.907222843318,\n            \".*_elbow_joint\": 0.907222843318,\n            \".*_wrist_roll_joint\": 0.907222843318,\n            \".*_wrist_pitch_joint\": 1.06814150222,\n            \".*_wrist_yaw_joint\": 1.06814150222,\n        },\n        armature={\n            \".*_hip_yaw_joint\": 0.01017752,\n            \".*_hip_roll_joint\": 0.025101925,\n            \".*_hip_pitch_joint\": 0.01017752,\n            \".*_knee_joint\": 0.025101925,\n            \".*_ankle_pitch_joint\": 0.00721945,\n            \".*_ankle_roll_joint\": 0.00721945,\n            \"waist_roll_joint\": 0.00721945,\n            \"waist_pitch_joint\": 0.00721945,\n            \"waist_yaw_joint\": 0.01017752,\n            \".*_shoulder_pitch_joint\": 0.003609725,\n            \".*_shoulder_roll_joint\": 0.003609725,\n            \".*_shoulder_yaw_joint\": 0.003609725,\n            \".*_elbow_joint\": 0.003609725,\n            \".*_wrist_roll_joint\": 0.003609725,\n            \".*_wrist_pitch_joint\": 0.00425,\n            \".*_wrist_yaw_joint\": 0.00425,\n        },\n        friction=0,\n        dynamic_friction=0,\n        viscous_friction=0,\n        X1={\n            \".*_hip_yaw_joint\": 22.63,\n            \".*_hip_roll_joint\": 14.5,\n            \".*_hip_pitch_joint\": 22.63,\n            \".*_knee_joint\": 14.5,\n            \".*_ankle_pitch_joint\": 30.86,\n            \".*_ankle_roll_joint\": 30.86,\n            \"waist_roll_joint\": 30.86,\n            \"waist_pitch_joint\": 30.86,\n            \"waist_yaw_joint\": 22.63,\n            \".*_shoulder_pitch_joint\": 30.86,\n            \".*_shoulder_roll_joint\": 30.86,\n            \".*_shoulder_yaw_joint\": 30.86,\n            \".*_elbow_joint\": 30.86,\n            \".*_wrist_roll_joint\": 30.86,\n            \".*_wrist_pitch_joint\": 15.3,\n            \".*_wrist_yaw_joint\": 15.3,\n        },\n        X2={\n            
\".*_hip_yaw_joint\": 35.52,\n            \".*_hip_roll_joint\": 22.7,\n            \".*_hip_pitch_joint\": 35.52,\n            \".*_knee_joint\": 22.7,\n            \".*_ankle_pitch_joint\": 40.13,\n            \".*_ankle_roll_joint\": 40.13,\n            \"waist_roll_joint\": 40.13,\n            \"waist_pitch_joint\": 40.13,\n            \"waist_yaw_joint\": 35.52,\n            \".*_shoulder_pitch_joint\": 40.13,\n            \".*_shoulder_roll_joint\": 40.13,\n            \".*_shoulder_yaw_joint\": 40.13,\n            \".*_elbow_joint\": 40.13,\n            \".*_wrist_roll_joint\": 40.13,\n            \".*_wrist_pitch_joint\": 24.76,\n            \".*_wrist_yaw_joint\": 24.76,\n        },\n        Y1={\n            \".*_hip_yaw_joint\": 71,\n            \".*_hip_roll_joint\": 111,\n            \".*_hip_pitch_joint\": 71,\n            \".*_knee_joint\": 111,\n            \".*_ankle_pitch_joint\": 24.8,\n            \".*_ankle_roll_joint\": 24.8,\n            \"waist_roll_joint\": 24.8,\n            \"waist_pitch_joint\": 24.8,\n            \"waist_yaw_joint\": 71,\n            \".*_shoulder_pitch_joint\": 24.8,\n            \".*_shoulder_roll_joint\": 24.8,\n            \".*_shoulder_yaw_joint\": 24.8,\n            \".*_elbow_joint\": 24.8,\n            \".*_wrist_roll_joint\": 24.8,\n            \".*_wrist_pitch_joint\": 4.8,\n            \".*_wrist_yaw_joint\": 4.8,\n        },\n        Y2={\n            \".*_hip_yaw_joint\": 83.3,\n            \".*_hip_roll_joint\": 131,\n            \".*_hip_pitch_joint\": 83.3,\n            \".*_knee_joint\": 131,\n            \".*_ankle_pitch_joint\": 31.9,\n            \".*_ankle_roll_joint\": 31.9,\n            \"waist_roll_joint\": 31.9,\n            \"waist_pitch_joint\": 31.9,\n            \"waist_yaw_joint\": 83.3,\n            \".*_shoulder_pitch_joint\": 31.9,\n            \".*_shoulder_roll_joint\": 31.9,\n            \".*_shoulder_yaw_joint\": 31.9,\n            \".*_elbow_joint\": 31.9,\n            \".*_wrist_roll_joint\": 31.9,\n            \".*_wrist_pitch_joint\": 8.6,\n            \".*_wrist_yaw_joint\": 8.6,\n        },\n        Fs={\n            \".*_hip_yaw_joint\": 1.6,\n            \".*_hip_roll_joint\": 2.4,\n            \".*_hip_pitch_joint\": 1.6,\n            \".*_knee_joint\": 2.4,\n            \".*_ankle_pitch_joint\": 0.6,\n            \".*_ankle_roll_joint\": 0.6,\n            \"waist_roll_joint\": 0.6,\n            \"waist_pitch_joint\": 0.6,\n            \"waist_yaw_joint\": 1.6,\n            \".*_shoulder_pitch_joint\": 0.6,\n            \".*_shoulder_roll_joint\": 0.6,\n            \".*_shoulder_yaw_joint\": 0.6,\n            \".*_elbow_joint\": 0.6,\n            \".*_wrist_roll_joint\": 0.6,\n            \".*_wrist_pitch_joint\": 0.6,\n            \".*_wrist_yaw_joint\": 0.6,\n        },\n        Fd={\n            \".*_hip_yaw_joint\": 0.16,\n            \".*_hip_roll_joint\": 0.24,\n            \".*_hip_pitch_joint\": 0.16,\n            \".*_knee_joint\": 0.24,\n            \".*_ankle_pitch_joint\": 0.06,\n            \".*_ankle_roll_joint\": 0.06,\n            \"waist_roll_joint\": 0.06,\n            \"waist_pitch_joint\": 0.06,\n            \"waist_yaw_joint\": 0.16,\n            \".*_shoulder_pitch_joint\": 0.06,\n            \".*_shoulder_roll_joint\": 0.06,\n            \".*_shoulder_yaw_joint\": 0.06,\n            \".*_elbow_joint\": 0.06,\n            \".*_wrist_roll_joint\": 0.06,\n            \".*_wrist_pitch_joint\": 0.06,\n            \".*_wrist_yaw_joint\": 0.06,\n        },\n        Va={\n            
\".*_hip_yaw_joint\": 0.01,\n            \".*_hip_roll_joint\": 0.01,\n            \".*_hip_pitch_joint\": 0.01,\n            \".*_knee_joint\": 0.01,\n            \".*_ankle_pitch_joint\": 0.01,\n            \".*_ankle_roll_joint\": 0.01,\n            \"waist_roll_joint\": 0.01,\n            \"waist_pitch_joint\": 0.01,\n            \"waist_yaw_joint\": 0.01,\n            \".*_shoulder_pitch_joint\": 0.01,\n            \".*_shoulder_roll_joint\": 0.01,\n            \".*_shoulder_yaw_joint\": 0.01,\n            \".*_elbow_joint\": 0.01,\n            \".*_wrist_roll_joint\": 0.01,\n            \".*_wrist_pitch_joint\": 0.01,\n            \".*_wrist_yaw_joint\": 0.01,\n        },\n    )\n}\n"
  },
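  {
    "path": "holomotion/examples/unitree_tn_curve_sketch.py",
    "content": "# Hypothetical usage sketch -- this file and its X1/X2 values are\n# illustrative and not part of the original source tree. It reproduces,\n# with scalar math instead of torch, the torque-speed (T-N) limit and the\n# friction model that UnitreeActuator in unitree_actuators.py applies.\n# Y1/Y2/Fs/Fd/Va mirror the knee-joint entries of the config above; the\n# X1/X2 knee-point and no-load speeds are assumed demo numbers.\nimport math\n\nY1 = 111.0  # peak torque, torque and speed in the same direction [N*m]\nY2 = 131.0  # peak torque, torque and speed in opposite directions [N*m]\nX1 = 14.0  # assumed: max speed at full torque (T-N knee point) [rad/s]\nX2 = 20.0  # assumed: no-load speed [rad/s]\nFS = 2.4  # static friction coefficient\nFD = 0.24  # dynamic friction coefficient\nVA = 0.01  # velocity at which friction is fully activated [rad/s]\n\n\ndef clip_effort(effort: float, joint_vel: float) -> float:\n    \"\"\"Clip a desired torque to the piecewise T-N curve.\"\"\"\n    # Same direction -> Y1 branch, opposite direction -> Y2 branch.\n    max_effort = Y1 if joint_vel * effort > 0 else Y2\n    if abs(joint_vel) >= X1:\n        # Linear falloff between the knee point X1 and the no-load speed X2.\n        k = -max_effort / (X2 - X1)\n        max_effort = max(0.0, k * (abs(joint_vel) - X1) + max_effort)\n    return max(-max_effort, min(effort, max_effort))\n\n\ndef friction_torque(joint_vel: float) -> float:\n    \"\"\"Friction subtracted from the applied effort after clipping.\"\"\"\n    return FS * math.tanh(joint_vel / VA) + FD * joint_vel\n\n\nif __name__ == \"__main__\":\n    for vel in (0.0, 5.0, 15.0, 19.0, -15.0):\n        tau = clip_effort(200.0, vel) - friction_torque(vel)\n        print(f\"vel={vel:+5.1f} rad/s -> applied torque {tau:7.2f} N*m\")\n"
  },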
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_simulator.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom isaaclab.sim import SimulationCfg, PhysxCfg\n\n\ndef build_simulator_config(sim_config_dict: dict) -> SimulationCfg:\n    \"\"\"Build simulation configuration from config dictionary.\"\"\"\n    policy_freq = sim_config_dict.get(\"policy_freq\", 50)\n    sim_freq = sim_config_dict.get(\"sim_freq\", 200)\n    decimation = int(sim_freq / policy_freq)\n    dt = 1.0 / sim_freq\n    device = sim_config_dict.get(\"device\", \"cuda\")\n\n    # PhysX configuration\n    physx_config = sim_config_dict.get(\"physx\", {})\n    physx = PhysxCfg(\n        bounce_threshold_velocity=physx_config.get(\n            \"bounce_threshold_velocity\", 0.2\n        ),\n        gpu_max_rigid_patch_count=physx_config.get(\n            \"gpu_max_rigid_patch_count\", int(10 * 2**15)\n        ),\n    )\n\n    return SimulationCfg(\n        dt=dt,\n        render_interval=decimation,\n        physx=physx,\n        device=device,\n    )\n"
  },
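  {
    "path": "holomotion/examples/sim_timing_sketch.py",
    "content": "# Hypothetical usage sketch -- this file is illustrative and not part of\n# the original source tree. It mirrors the frequency bookkeeping done by\n# build_simulator_config in isaaclab_simulator.py without importing\n# IsaacLab: the policy acts once every `decimation` physics steps, so\n# sim_freq should be an integer multiple of policy_freq.\n\n\ndef sim_timing(sim_config: dict) -> tuple[float, int]:\n    \"\"\"Return (physics dt in seconds, physics steps per policy step).\"\"\"\n    policy_freq = sim_config.get(\"policy_freq\", 50)\n    sim_freq = sim_config.get(\"sim_freq\", 200)\n    decimation = int(sim_freq / policy_freq)\n    dt = 1.0 / sim_freq\n    return dt, decimation\n\n\nif __name__ == \"__main__\":\n    dt, decimation = sim_timing({\"policy_freq\": 50, \"sim_freq\": 200})\n    print(dt, decimation)  # 0.005 4 -> policy acts every 4 * 5 ms = 20 ms\n"
  },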
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_termination.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport inspect\n\nimport isaaclab.envs.mdp as isaaclab_mdp\nimport isaaclab.utils.math as isaaclab_math\nimport torch\nfrom isaaclab.envs import ManagerBasedRLEnv\nfrom isaaclab.managers import TerminationTermCfg\nfrom isaaclab.utils import configclass\n\nfrom holomotion.src.env.isaaclab_components import (\n    isaaclab_motion_tracking_command as motion_tracking_command,\n    isaaclab_utils,\n)\n\n\ndef _list_supported_terminations() -> list[str]:\n    custom_terminations = {\n        name\n        for name, obj in globals().items()\n        if (\n            inspect.isfunction(obj)\n            and obj.__module__ == __name__\n            and not name.startswith(\"_\")\n        )\n    }\n    native_terminations = {\n        name\n        for name in dir(isaaclab_mdp.terminations)\n        if (\n            not name.startswith(\"_\")\n            and callable(getattr(isaaclab_mdp.terminations, name))\n        )\n    }\n    return sorted(custom_terminations | native_terminations)\n\n\ndef _resolve_termination_func(name: str):\n    func = globals().get(name)\n    if inspect.isfunction(func) and func.__module__ == __name__:\n        return func\n\n    func = getattr(isaaclab_mdp.terminations, name, None)\n    if callable(func):\n        return func\n\n    supported = _list_supported_terminations()\n    raise ValueError(\n        f\"Unknown termination function: {name}. 
Supported: {supported}\"\n    )\n\n\ndef global_bodylink_pos_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Any body link position deviates more than threshold (world frame).\"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next(\n        prefix=ref_prefix\n    )  # [B, Nb, 3]\n    robot_pos_w = command.robot.data.body_pos_w  # [B, Nb, 3]\n\n    keybody_idxs = isaaclab_utils._get_body_indices(\n        command.robot, keybody_names\n    )\n\n    if keybody_idxs is not None and len(keybody_idxs) > 0:\n        idxs = torch.as_tensor(\n            keybody_idxs,\n            device=ref_pos_w.device,\n            dtype=torch.long,\n        )\n        ref_pos_w = ref_pos_w[:, idxs]\n        robot_pos_w = robot_pos_w[:, idxs]\n\n    error = torch.norm(ref_pos_w - robot_pos_w, dim=-1)  # [B, Nb]\n    return torch.any(error > threshold, dim=-1)  # [B]\n\n\ndef anchor_ref_z_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Anchor link z difference exceeds threshold (world frame).\"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    ref_z = command.get_ref_motion_anchor_bodylink_global_pos_immediate_next(\n        prefix=ref_prefix\n    )[:, -1]\n    robot_z = command.global_robot_anchor_pos_cur[:, -1]\n    return (ref_z - robot_z).abs() > threshold\n\n\ndef ref_gravity_projection_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    asset_name: str = \"robot\",\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Difference in projected gravity z-component exceeds threshold.\n\n    Project world gravity into the anchor body frames using inverse\n    quaternion rotation and compare z-components.\n    \"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    g_w = env.scene[asset_name].data.GRAVITY_VEC_W  # [B, 3]\n\n    # Reference anchor orientation (wxyz) from motion cache\n    ref_anchor_quat_wxyz = (\n        command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(\n            prefix=ref_prefix\n        )\n    )  # [B, 4]\n\n    motion_projected_gravity_b = isaaclab_math.quat_apply_inverse(\n        ref_anchor_quat_wxyz, g_w\n    )  # [B, 3]\n\n    # Robot anchor orientation (wxyz) from sim\n    robot_anchor_quat_wxyz = command.robot.data.body_quat_w[\n        :, command.anchor_bodylink_idx\n    ]  # [B, 4]\n\n    robot_projected_gravity_b = isaaclab_math.quat_apply_inverse(\n        robot_anchor_quat_wxyz, g_w\n    )  # [B, 3]\n\n    return (\n        motion_projected_gravity_b[:, 2] - robot_projected_gravity_b[:, 2]\n    ).abs() > threshold\n\n\ndef keybody_ref_pos_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    
ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Any key body link z difference exceeds threshold (world frame).\"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next(\n        prefix=ref_prefix\n    )  # [B, Nb, 3]\n    robot_pos_w = command.robot.data.body_pos_w  # [B, Nb, 3]\n\n    keybody_idxs = isaaclab_utils._get_body_indices(\n        command.robot, keybody_names\n    )\n\n    if keybody_idxs is not None and len(keybody_idxs) > 0:\n        idxs = torch.as_tensor(\n            keybody_idxs,\n            device=ref_pos_w.device,\n            dtype=torch.long,\n        )\n        ref_pos_w = ref_pos_w[:, idxs]\n        robot_pos_w = robot_pos_w[:, idxs]\n\n    error = torch.norm(ref_pos_w - robot_pos_w, dim=-1)  # [B, Nb]\n    return torch.any(error > threshold, dim=-1)  # [B]\n\n\ndef keybody_ref_z_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    command_name: str = \"ref_motion\",\n    keybody_names: list[str] | None = None,\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Any key body link z difference exceeds threshold (world frame).\"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next(\n        prefix=ref_prefix\n    )  # [B, Nb, 3]\n    robot_pos_w = command.robot.data.body_pos_w  # [B, Nb, 3]\n\n    keybody_idxs = isaaclab_utils._get_body_indices(\n        command.robot, keybody_names\n    )\n\n    if keybody_idxs is not None and len(keybody_idxs) > 0:\n        idxs = torch.as_tensor(\n            keybody_idxs,\n            device=ref_pos_w.device,\n            dtype=torch.long,\n        )\n        ref_pos_w = ref_pos_w[:, idxs]\n        robot_pos_w = robot_pos_w[:, idxs]\n\n    error_z = (ref_pos_w[..., 2] - robot_pos_w[..., 2]).abs()  # [B, Nb]\n    return torch.any(error_z > threshold, dim=-1)  # [B]\n\n\ndef wholebody_mpjpe_far(\n    env: ManagerBasedRLEnv,\n    threshold: float,\n    command_name: str = \"ref_motion\",\n    ref_prefix: str = \"ref_\",\n) -> torch.Tensor:\n    \"\"\"Mean whole-body DOF position error exceeds threshold.\"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    ref_dof_pos = command.get_ref_motion_dof_pos_immediate_next(\n        prefix=ref_prefix\n    )\n    robot_dof_pos = command.robot.data.joint_pos\n    mean_dof_error = torch.mean(torch.abs(robot_dof_pos - ref_dof_pos), dim=-1)\n    return mean_dof_error > threshold\n\n\ndef motion_end(\n    env: ManagerBasedRLEnv,\n    command_name: str = \"ref_motion\",\n) -> torch.Tensor:\n    \"\"\"Terminate when reference motion frames exceed their end frames.\n\n    Returns a boolean mask of shape [num_envs].\n    \"\"\"\n    command: motion_tracking_command.RefMotionCommand = (\n        env.command_manager.get_term(command_name)\n    )\n    result = command.motion_end_mask.clone().bool()\n    return result\n\n\n@configclass\nclass TerminationsCfg:\n    pass\n\n\ndef build_terminations_config(\n    termination_config_dict: dict,\n) -> TerminationsCfg:\n    terminations_cfg = TerminationsCfg()\n\n    for termination_name, termination_cfg in termination_config_dict.items():\n        termination_cfg = isaaclab_utils.resolve_holo_config(termination_cfg)\n        func = 
_resolve_termination_func(termination_name)\n        params = isaaclab_utils.resolve_holo_config(\n            termination_cfg.get(\"params\", {})\n        )\n\n        term_cfg = TerminationTermCfg(\n            func=func,\n            params=params,\n            time_out=(termination_name == \"time_out\")\n            or termination_cfg.get(\"time_out\", False),\n        )\n        setattr(terminations_cfg, termination_name, term_cfg)\n\n    return terminations_cfg\n"
  },
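  {
    "path": "holomotion/examples/termination_lookup_sketch.py",
    "content": "# Hypothetical usage sketch -- this file is illustrative and not part of\n# the original source tree. It reduces the two-stage lookup performed by\n# _resolve_termination_func in isaaclab_termination.py to plain Python:\n# custom terminations defined in the module win, otherwise the name falls\n# back to IsaacLab's built-ins. `library` stands in for\n# isaaclab.envs.mdp.terminations; the term names are illustrative.\nimport types\n\n\ndef motion_end():\n    \"\"\"Stand-in for a custom termination defined locally.\"\"\"\n\n\n# Stand-in for the IsaacLab terminations module.\nlibrary = types.SimpleNamespace(time_out=lambda: None)\n\n_custom = {\"motion_end\": motion_end}\n\n\ndef resolve(name: str):\n    # 1) custom terminations defined locally take precedence\n    func = _custom.get(name)\n    if callable(func):\n        return func\n    # 2) otherwise fall back to the library's built-in terminations\n    func = getattr(library, name, None)\n    if callable(func):\n        return func\n    supported = sorted(set(_custom) | {\"time_out\"})\n    raise ValueError(\n        f\"Unknown termination function: {name}. Supported: {supported}\"\n    )\n\n\nif __name__ == \"__main__\":\n    print(resolve(\"motion_end\").__name__)  # motion_end\n    print(resolve(\"time_out\"))  # <lambda>\n"
  },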
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_terrain.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport os\n\nimport isaaclab.sim as sim_utils\nimport isaaclab.terrains as terrain_gen\nimport numpy as np\nimport torch\nfrom isaaclab.terrains import TerrainImporter, TerrainImporterCfg\nfrom isaaclab.terrains.height_field import (\n    HfDiscreteObstaclesTerrainCfg,\n    HfPyramidSlopedTerrainCfg,\n    HfRandomUniformTerrainCfg,\n    HfTerrainBaseCfg,\n)\nfrom isaaclab.terrains.height_field.utils import height_field_to_mesh\nfrom isaaclab.utils import configclass\nfrom loguru import logger\n\n\ndef _convert_range_like_params(params: dict) -> dict:\n    \"\"\"Convert list values for common range/size keys to tuples.\n\n    This helps map Hydra YAML list values into IsaacLab config classes that\n    expect tuples (e.g. ``*_range``).\n    \"\"\"\n    converted = {}\n    for key, value in params.items():\n        if isinstance(value, list) and (\n            key.endswith(\"_range\") or key in (\"size\", \"difficulty_range\")\n        ):\n            converted[key] = tuple(value)\n        else:\n            converted[key] = value\n    return converted\n\n\n@height_field_to_mesh\ndef plane_terrain(difficulty: float, cfg: HfTerrainBaseCfg) -> np.ndarray:\n    \"\"\"Generate a truly flat height-field patch.\n\n    This is a lightweight alternative to using ``random_uniform`` with a zero\n    noise range.\n    The ``difficulty`` parameter is ignored.\n    \"\"\"\n    width_pixels = int(cfg.size[0] / cfg.horizontal_scale)\n    length_pixels = int(cfg.size[1] / cfg.horizontal_scale)\n    return np.zeros((width_pixels, length_pixels), dtype=np.int16)\n\n\n@configclass\nclass HfPlaneTerrainCfg(HfTerrainBaseCfg):\n    \"\"\"Configuration for a flat height-field plane terrain.\"\"\"\n\n    function = plane_terrain\n\n\nclass RandomSpawnTerrainImporter(TerrainImporter):\n    \"\"\"Terrain importer that spawns robots randomly within each sub-terrain.\"\"\"\n\n    _terrain_width: float | None = None\n    _terrain_length: float | None = None\n    _spawn_margin: float = 0.0\n\n    def _compute_env_origins_curriculum(\n        self, num_envs: int, origins: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Compute env origins with random (x, y) positions.\n\n        This overrides the default curriculum-based distribution to add random\n        offsets within each sub-terrain's bounds.\n\n        Args:\n            num_envs: Number of environments.\n            origins: Terrain origins tensor of shape (num_rows, num_cols, 3).\n\n        Returns:\n            Environment origins tensor of shape (num_envs, 3).\n        \"\"\"\n        num_rows, num_cols = origins.shape[:2]\n\n        # Get sub-terrain size from terrain generator config\n        if self.cfg.terrain_generator is None:\n            raise ValueError(\n                \"terrain_generator config is required for random spawning\"\n            )\n        sub_terrain_size = 
self.cfg.terrain_generator.size\n        terrain_width, terrain_length = (\n            sub_terrain_size[0],\n            sub_terrain_size[1],\n        )\n\n        spawn_margin = float(getattr(self.cfg, \"random_spawn_margin\", 0.0))\n        spawn_margin = max(0.0, spawn_margin)\n        # Clamp margin to avoid invalid sampling ranges.\n        max_margin = 0.5 * min(float(terrain_width), float(terrain_length))\n        if spawn_margin >= max_margin:\n            logger.warning(\n                f\"random_spawn_margin={spawn_margin} is too large \"\n                f\"for sub-terrain size={sub_terrain_size}. \"\n                \"Clamping to 0.0.\"\n            )\n            spawn_margin = 0.0\n\n        # Maximum initial level possible for the terrains\n        if self.cfg.max_init_terrain_level is None:\n            max_init_level = num_rows - 1\n        else:\n            max_init_level = min(self.cfg.max_init_terrain_level, num_rows - 1)\n\n        # Store maximum terrain level possible\n        self.max_terrain_level = num_rows\n\n        # Use default curriculum-based assignment\n        self.terrain_levels = torch.randint(\n            0, max_init_level + 1, (num_envs,), device=self.device\n        )\n        self.terrain_types = torch.div(\n            torch.arange(num_envs, device=self.device),\n            (num_envs / num_cols),\n            rounding_mode=\"floor\",\n        ).to(torch.long)\n\n        # Create environment origins tensor starting from terrain origins\n        env_origins = torch.zeros(num_envs, 3, device=self.device)\n        env_origins[:] = origins[self.terrain_levels, self.terrain_types]\n\n        # Add random (x, y) offsets within each sub-terrain's bounds\n        # Offset range: [-size/2 + margin, size/2 - margin] for both x and y\n        x_min = -terrain_width / 2 + spawn_margin\n        x_max = terrain_width / 2 - spawn_margin\n        y_min = -terrain_length / 2 + spawn_margin\n        y_max = terrain_length / 2 - spawn_margin\n        x_offsets = torch.empty(num_envs, device=self.device).uniform_(\n            x_min, x_max\n        )\n        y_offsets = torch.empty(num_envs, device=self.device).uniform_(\n            y_min, y_max\n        )\n\n        env_origins[:, 0] += x_offsets\n        env_origins[:, 1] += y_offsets\n\n        # Store terrain size for use in update_env_origins\n        self._terrain_width = terrain_width\n        self._terrain_length = terrain_length\n        self._spawn_margin = spawn_margin\n\n        return env_origins\n\n    def update_env_origins(\n        self,\n        env_ids: torch.Tensor,\n        move_up: torch.Tensor,\n        move_down: torch.Tensor,\n    ):\n        \"\"\"Update env origins when terrain levels change.\"\"\"\n        # Check if grid-like spawning\n        if self.terrain_origins is None:\n            return\n\n        # Update terrain level for the envs\n        self.terrain_levels[env_ids] += 1 * move_up - 1 * move_down\n        # Robots that solve the last level are sent to a random one\n        # The minimum level is zero\n        self.terrain_levels[env_ids] = torch.where(\n            self.terrain_levels[env_ids] >= self.max_terrain_level,\n            torch.randint_like(\n                self.terrain_levels[env_ids], self.max_terrain_level\n            ),\n            torch.clip(self.terrain_levels[env_ids], 0),\n        )\n\n        # Update the env origins with terrain origins\n        self.env_origins[env_ids] = self.terrain_origins[\n            self.terrain_levels[env_ids], 
self.terrain_types[env_ids]\n        ]\n\n        # Add random (x, y) offsets within each sub-terrain's bounds\n        if self._terrain_width is None or self._terrain_length is None:\n            return\n\n        num_updated = len(env_ids)\n        x_min = -self._terrain_width / 2 + self._spawn_margin\n        x_max = self._terrain_width / 2 - self._spawn_margin\n        y_min = -self._terrain_length / 2 + self._spawn_margin\n        y_max = self._terrain_length / 2 - self._spawn_margin\n        x_offsets = torch.empty(num_updated, device=self.device).uniform_(\n            x_min, x_max\n        )\n        y_offsets = torch.empty(num_updated, device=self.device).uniform_(\n            y_min, y_max\n        )\n\n        self.env_origins[env_ids, 0] += x_offsets\n        self.env_origins[env_ids, 1] += y_offsets\n\n\ndef build_terrain_config(\n    config: dict, scene_env_spacing: float = None\n) -> TerrainImporterCfg:\n    \"\"\"Build terrain configuration.\n\n    Preferred usage in Holomotion is via the IsaacLab terrain generator API\n    with height-field sub-terrains fully specified from Hydra configs.\n\n    For backward compatibility only, two legacy modes are still supported:\n\n    * ``terrain_type=\\\"plane\\\"``: simple infinite plane using Isaac Sim's grid.\n    * ``terrain_type=\\\"usd\\\"``: load terrain from a local USD file.\n\n    All paths are offline by construction. Visual materials must use local\n    data:\n\n    * ``visual_material.type=\\\"color\\\"`` maps to :class:`PreviewSurfaceCfg`\n      with ``diffuse_color``, ``metallic`` and ``roughness``.\n    * ``visual_material.type=\\\"mdl\\\"`` is accepted only for local MDL files and\n      never uses NVIDIA Nucleus. When paths are invalid, a neutral color\n      material is used instead.\n\n    Args:\n        config: Terrain configuration dictionary with fields:\n\n            * ``terrain_type``: ``\\\"generator\\\"`` (preferred), ``\\\"plane\\\"`` or\n              ``\\\"usd\\\"`` (legacy).\n            * ``generator`` (required when ``terrain_type=\\\"generator\\\"``):\n              high-level :class:`TerrainGeneratorCfg` parameters such as\n              ``num_rows``, ``num_cols``, ``size``, ``border_width``,\n              ``horizontal_scale``, ``vertical_scale``, ``slope_threshold``,\n              ``difficulty_range``, ``color_scheme``.\n            * ``height_field`` (required when ``terrain_type=\\\"generator\\\"``):\n              height-field sub-terrain configuration with:\n\n              - ``type``: ``\\\"plane\\\"``, ``\\\"random_uniform\\\"``,\n                ``\\\"discrete_obstacles\\\"`` or ``\\\"pyramid_sloped\\\"``.\n              - Remaining keys are forwarded to the corresponding\n                :class:`HfRandomUniformTerrainCfg` or\n                :class:`HfDiscreteObstaclesTerrainCfg`.\n            * ``random_spawn`` (optional): if True, spawns robots at random\n              (x, y) positions within each sub-terrain's bounds.\n            * ``random_spawn_margin`` (optional): if set, keeps random spawn\n              points at least this many meters away from sub-terrain edges\n              (helps avoid spawning near the outer border where robots may fall\n              off).\n            * ``visual_material`` (optional): offline visual material config.\n            * ``static_friction``, ``dynamic_friction``, ``restitution``, etc.\n\n        scene_env_spacing: Environment spacing from scene config (used only\n            when ``terrain_type=\\\"plane\\\"`` is selected).\n\n    Returns:\n     
   TerrainImporterCfg configured according to the input parameters\n    \"\"\"\n    prim_path = config.get(\"prim_path\", \"/World/ground\")\n    static_friction = config.get(\"static_friction\", 1.0)\n    dynamic_friction = config.get(\"dynamic_friction\", 1.0)\n    restitution = config.get(\"restitution\", 0.0)\n    friction_combine_mode = config.get(\"friction_combine_mode\", \"multiply\")\n    restitution_combine_mode = config.get(\n        \"restitution_combine_mode\", \"multiply\"\n    )\n\n    terrain_type = config.get(\"terrain_type\", \"generator\")\n\n    if terrain_type == \"usd\":\n        usd_path = config.get(\"usd_path\")\n        if usd_path is None:\n            raise ValueError(\n                \"'usd_path' must be specified for terrain_type 'usd'\"\n            )\n        terrain_cfg = TerrainImporterCfg(\n            prim_path=prim_path,\n            terrain_type=\"usd\",\n            usd_path=usd_path,\n            collision_group=-1,\n            physics_material=sim_utils.RigidBodyMaterialCfg(\n                friction_combine_mode=friction_combine_mode,\n                restitution_combine_mode=restitution_combine_mode,\n                static_friction=static_friction,\n                dynamic_friction=dynamic_friction,\n                restitution=restitution,\n            ),\n            debug_vis=config.get(\"debug_vis\", False),\n        )\n        return terrain_cfg\n\n    if terrain_type == \"plane\":\n        env_spacing = (\n            scene_env_spacing if scene_env_spacing is not None else 2.5\n        )\n        terrain_cfg = TerrainImporterCfg(\n            prim_path=prim_path,\n            terrain_type=\"plane\",\n            collision_group=-1,\n            env_spacing=env_spacing,\n            physics_material=sim_utils.RigidBodyMaterialCfg(\n                friction_combine_mode=friction_combine_mode,\n                restitution_combine_mode=restitution_combine_mode,\n                static_friction=static_friction,\n                dynamic_friction=dynamic_friction,\n                restitution=restitution,\n            ),\n            debug_vis=config.get(\"debug_vis\", False),\n        )\n        return terrain_cfg\n\n    if terrain_type != \"generator\":\n        raise ValueError(\n            f\"Unsupported terrain_type '{terrain_type}'. 
\"\n            \"Expected 'generator', 'plane', or 'usd'.\"\n        )\n\n    generator_cfg_dict = config.get(\"generator\")\n    if generator_cfg_dict is None:\n        raise ValueError(\n            \"When 'terrain_type' is 'generator', a 'generator' dict must be \"\n            \"provided in terrain config.\"\n        )\n\n    # Optional new path: multiple sub-terrains defined under\n    # generator.sub_terrains.\n    sub_terrains_cfg_dict = generator_cfg_dict.get(\"sub_terrains\")\n    sub_terrains_cfg = None\n\n    if sub_terrains_cfg_dict is not None:\n        if not isinstance(sub_terrains_cfg_dict, dict):\n            raise ValueError(\n                \"Expected 'generator.sub_terrains' to be a mapping from names \"\n                \"to sub-terrain configs.\"\n            )\n        sub_terrains_cfg = {}\n        for sub_name, sub_cfg_dict in sub_terrains_cfg_dict.items():\n            if not isinstance(sub_cfg_dict, dict):\n                raise ValueError(\n                    f\"Sub-terrain '{sub_name}' must be a dictionary with at \"\n                    \"least a 'type' field.\"\n                )\n            sub_type = sub_cfg_dict.get(\"type\", \"random_uniform\")\n            sub_proportion = sub_cfg_dict.get(\"proportion\", 1.0)\n            sub_params_raw = {\n                key: value\n                for key, value in sub_cfg_dict.items()\n                if key not in (\"type\", \"proportion\")\n            }\n            sub_params = _convert_range_like_params(sub_params_raw)\n\n            if sub_type == \"random_uniform\":\n                hf_cfg = HfRandomUniformTerrainCfg(\n                    proportion=sub_proportion, **sub_params\n                )\n            elif sub_type == \"plane\":\n                hf_cfg = HfPlaneTerrainCfg(\n                    proportion=sub_proportion, **sub_params\n                )\n            elif sub_type == \"discrete_obstacles\":\n                hf_cfg = HfDiscreteObstaclesTerrainCfg(\n                    proportion=sub_proportion, **sub_params\n                )\n            elif sub_type == \"pyramid_sloped\":\n                hf_cfg = HfPyramidSlopedTerrainCfg(\n                    proportion=sub_proportion, **sub_params\n                )\n            else:\n                raise ValueError(\n                    f\"Unknown sub_terrains['{sub_name}'].type '{sub_type}'. \"\n                    \"Expected 'plane', 'random_uniform', 'discrete_obstacles',\"\n                    \" or 'pyramid_sloped'.\"\n                )\n            sub_terrains_cfg[sub_name] = hf_cfg\n\n    # Deprecated path: single height_field block at top-level.\n    if sub_terrains_cfg is None:\n        height_field_cfg_dict = config.get(\"height_field\")\n        if height_field_cfg_dict is None:\n            raise ValueError(\n                \"When 'terrain_type' is 'generator', either \"\n                \"'generator.sub_terrains' or a 'height_field' dict must be \"\n                \"provided in terrain config.\"\n            )\n\n        logger.warning(\n            \"Terrain config is using deprecated 'height_field' key. 
\"\n            \"Please migrate to 'generator.sub_terrains' for multi-sub-terrain \"\n            \"support.\"\n        )\n\n        hf_type = height_field_cfg_dict.get(\"type\", \"random_uniform\")\n        hf_params_raw = {\n            key: value\n            for key, value in height_field_cfg_dict.items()\n            if key != \"type\"\n        }\n        hf_params = _convert_range_like_params(hf_params_raw)\n\n        if hf_type == \"random_uniform\":\n            height_field_cfg = HfRandomUniformTerrainCfg(**hf_params)\n        elif hf_type == \"discrete_obstacles\":\n            height_field_cfg = HfDiscreteObstaclesTerrainCfg(**hf_params)\n        else:\n            raise ValueError(\n                f\"Unknown height_field.type '{hf_type}'. \"\n                \"Expected 'random_uniform' or 'discrete_obstacles'.\"\n            )\n        sub_terrains_cfg = {\"main\": height_field_cfg}\n\n    # Build TerrainGeneratorCfg from Hydra config.\n    generator_params = _convert_range_like_params(\n        {\n            key: value\n            for key, value in generator_cfg_dict.items()\n            if key != \"sub_terrains\"\n        }\n    )\n    terrain_generator = terrain_gen.TerrainGeneratorCfg(\n        **{\n            key: value\n            for key, value in generator_params.items()\n            if key\n            in (\n                \"size\",\n                \"border_width\",\n                \"border_height\",\n                \"num_rows\",\n                \"num_cols\",\n                \"horizontal_scale\",\n                \"vertical_scale\",\n                \"slope_threshold\",\n                \"difficulty_range\",\n                \"color_scheme\",\n                \"curriculum\",\n                \"seed\",\n                \"use_cache\",\n                \"cache_dir\",\n            )\n        },\n        sub_terrains=sub_terrains_cfg,\n    )\n\n    # Configure visual material for offline use\n    visual_material = None\n    if \"visual_material\" in config:\n        visual_material_dict = config[\"visual_material\"]\n        material_type = visual_material_dict.get(\"type\", \"color\")\n\n        if material_type == \"color\":\n            # Use PreviewSurfaceCfg with diffuse_color (no internet needed)\n            diffuse_color_raw = visual_material_dict.get(\n                \"diffuse_color\", (0.8, 0.8, 0.8)\n            )\n            # Convert list to tuple if needed (YAML loads lists).\n            # Ensure it's a tuple of floats as required by PreviewSurfaceCfg\n            if isinstance(diffuse_color_raw, list):\n                diffuse_color = tuple(float(x) for x in diffuse_color_raw)\n            elif isinstance(diffuse_color_raw, tuple):\n                diffuse_color = tuple(float(x) for x in diffuse_color_raw)\n            else:\n                diffuse_color = diffuse_color_raw\n            metallic = float(visual_material_dict.get(\"metallic\", 0.0))\n            roughness = float(visual_material_dict.get(\"roughness\", 0.5))\n            visual_material = sim_utils.PreviewSurfaceCfg(\n                diffuse_color=diffuse_color,\n                metallic=metallic,\n                roughness=roughness,\n            )\n        elif material_type == \"none\":\n            # No visual material, rely on vertex colors (e.g. 
from height map)\n            visual_material = None\n        elif material_type == \"mdl\":\n            # Use MdlFileCfg with local mdl_path\n            mdl_path = visual_material_dict.get(\"mdl_path\")\n            if mdl_path is None:\n                logger.warning(\n                    \"visual_material type is 'mdl' but no mdl_path specified. \"\n                    \"Falling back to color material to avoid internet \"\n                    \"requirements.\"\n                )\n                visual_material = sim_utils.PreviewSurfaceCfg(\n                    diffuse_color=(0.5, 0.5, 0.5)\n                )\n            else:\n                # Resolve relative paths\n                if not os.path.isabs(mdl_path):\n                    if os.path.exists(mdl_path):\n                        resolved_mdl_path = os.path.abspath(mdl_path)\n                    else:\n                        workspace_root = os.path.abspath(\n                            os.path.join(\n                                os.path.dirname(__file__), \"../../../..\"\n                            )\n                        )\n                        resolved_mdl_path = os.path.join(\n                            workspace_root, mdl_path\n                        )\n                else:\n                    resolved_mdl_path = mdl_path\n\n                if os.path.exists(resolved_mdl_path):\n                    visual_material = sim_utils.MdlFileCfg(\n                        mdl_path=resolved_mdl_path\n                    )\n                else:\n                    logger.warning(\n                        f\"MDL file not found at {resolved_mdl_path}. \"\n                        \"Falling back to color material to avoid internet \"\n                        \"requirements.\"\n                    )\n                    visual_material = sim_utils.PreviewSurfaceCfg(\n                        diffuse_color=(0.5, 0.5, 0.5)\n                    )\n        else:\n            logger.warning(\n                f\"Unknown visual_material type: {material_type}. \"\n                \"Using default color material.\"\n            )\n            visual_material = sim_utils.PreviewSurfaceCfg(\n                diffuse_color=(0.5, 0.5, 0.5)\n            )\n\n    # Configure random spawning within sub-terrains if requested\n    random_spawn = config.get(\"random_spawn\", False)\n    terrain_importer_class = (\n        RandomSpawnTerrainImporter if random_spawn else TerrainImporter\n    )\n\n    terrain_cfg = TerrainImporterCfg(\n        prim_path=prim_path,\n        terrain_type=\"generator\",\n        terrain_generator=terrain_generator,\n        max_init_terrain_level=config.get(\n            \"max_init_terrain_level\",\n            terrain_generator.num_rows - 1,\n        ),\n        collision_group=-1,\n        visual_material=visual_material,\n        physics_material=sim_utils.RigidBodyMaterialCfg(\n            friction_combine_mode=friction_combine_mode,\n            restitution_combine_mode=restitution_combine_mode,\n            static_friction=static_friction,\n            dynamic_friction=dynamic_friction,\n            restitution=restitution,\n        ),\n        debug_vis=config.get(\"debug_vis\", False),\n        class_type=terrain_importer_class,\n    )\n\n    if random_spawn:\n        terrain_cfg.random_spawn_margin = float(\n            config.get(\"random_spawn_margin\", 0.0)\n        )\n\n    return terrain_cfg\n"
  },
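  {
    "path": "holomotion/examples/terrain_config_sketch.py",
    "content": "# Hypothetical usage sketch -- this file is illustrative and not part of\n# the original source tree. It spells out a terrain config dict in the\n# shape documented by build_terrain_config in isaaclab_terrain.py, using\n# the preferred 'generator' + 'sub_terrains' path. All numeric values are\n# illustrative, not tuned defaults.\nterrain_config = {\n    \"terrain_type\": \"generator\",\n    \"static_friction\": 1.0,\n    \"dynamic_friction\": 1.0,\n    \"restitution\": 0.0,\n    # Spawn robots at random (x, y) offsets, at least 0.5 m away from the\n    # edges of each sub-terrain.\n    \"random_spawn\": True,\n    \"random_spawn_margin\": 0.5,\n    # Offline visual material (maps to PreviewSurfaceCfg).\n    \"visual_material\": {\"type\": \"color\", \"diffuse_color\": [0.5, 0.5, 0.5]},\n    \"generator\": {\n        \"num_rows\": 4,\n        \"num_cols\": 4,\n        \"size\": [8.0, 8.0],  # converted to a tuple by the builder\n        \"border_width\": 2.0,\n        \"horizontal_scale\": 0.1,\n        \"vertical_scale\": 0.005,\n        \"sub_terrains\": {\n            \"flat\": {\"type\": \"plane\", \"proportion\": 0.5},\n            \"rough\": {\n                \"type\": \"random_uniform\",\n                \"proportion\": 0.5,\n                \"noise_range\": [0.0, 0.05],\n                \"noise_step\": 0.005,\n            },\n        },\n    },\n}\n\n# Usage (inside an environment where IsaacLab is installed):\n# from holomotion.src.env.isaaclab_components.isaaclab_terrain import (\n#     build_terrain_config,\n# )\n# terrain_cfg = build_terrain_config(terrain_config)\n"
  },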
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_utils.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport torch\nfrom isaaclab.assets import Articulation\nfrom isaaclab.envs import ManagerBasedRLEnv\nfrom isaaclab.managers import RewardTermCfg, SceneEntityCfg\nfrom isaaclab.sensors import ContactSensor\nfrom isaaclab.utils import configclass\nimport isaaclab.utils.math as isaaclab_math\n\nfrom holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (\n    RefMotionCommand,\n)\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom hydra.utils import instantiate as hydra_instantiate\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\n\nfrom loguru import logger\n\n\ndef _get_dof_indices(\n    robot: Articulation,\n    key_dofs: list[str] | None,\n) -> list[int] | None:\n    if key_dofs is None:\n        return list(range(len(robot.joint_names)))\n    dof_indices = []\n    for name in key_dofs:\n        if name not in robot.joint_names:\n            raise ValueError(\n                f\"DOF '{name}' not found in robot.joint_names: {robot.joint_names}\"\n            )\n        dof_indices.append(robot.joint_names.index(name))\n    return dof_indices\n\n\ndef _get_body_indices(\n    robot: Articulation,\n    keybody_names: list[str] | None,\n) -> list[int] | None:\n    \"\"\"Convert body names to indices.\n\n    Args:\n        robot: Robot articulation asset\n        keybody_names: List of body names. 
If None, all body indices are used.\n\n    Returns:\n        List of body indices corresponding to the given names, or all body indices if keybody_names is None\n    \"\"\"\n    if keybody_names is None:\n        return list(range(len(robot.body_names)))\n\n    body_indices = []\n    for name in keybody_names:\n        if name not in robot.body_names:\n            raise ValueError(\n                f\"Body '{name}' not found in robot.body_names: {robot.body_names}\"\n            )\n        body_indices.append(robot.body_names.index(name))\n\n    return body_indices\n\n\ndef resolve_holo_config(value):\n    def _sanitize_config_object(obj):\n        for attr, attr_value in vars(obj).items():\n            sanitized_value = resolve_holo_config(attr_value)\n            setattr(obj, attr, sanitized_value)\n        return obj\n\n    if isinstance(value, (DictConfig, ListConfig)):\n        value = OmegaConf.to_container(value, resolve=True)\n\n    if isinstance(value, dict):\n        if \"_target_\" in value:\n            instantiated = hydra_instantiate(value)\n            if hasattr(instantiated, \"__dict__\") and not callable(\n                instantiated\n            ):\n                return _sanitize_config_object(instantiated)\n            return instantiated\n        return {key: resolve_holo_config(item) for key, item in value.items()}\n\n    if isinstance(value, list):\n        return [resolve_holo_config(item) for item in value]\n\n    if hasattr(value, \"__dict__\") and not callable(value):\n        return _sanitize_config_object(value)\n\n    return value\n"
  },
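  {
    "path": "holomotion/examples/resolve_config_sketch.py",
    "content": "# Hypothetical usage sketch -- this file is illustrative and not part of\n# the original source tree. It shows the core behaviour of\n# resolve_holo_config from isaaclab_utils.py: dicts carrying a Hydra\n# '_target_' key are instantiated in place while everything else is\n# recursed into. This simplified version handles only dicts and lists,\n# omitting the OmegaConf and attribute-sanitization branches;\n# collections.OrderedDict is just a stand-in target so the example stays\n# self-contained. Requires hydra-core.\nfrom hydra.utils import instantiate\n\ncfg = {\n    \"params\": {\n        \"threshold\": 0.5,\n        # This sub-dict is replaced by an instantiated object.\n        \"noise\": {\"_target_\": \"collections.OrderedDict\"},\n    },\n}\n\n\ndef resolve(value):\n    if isinstance(value, dict):\n        if \"_target_\" in value:\n            return instantiate(value)\n        return {key: resolve(item) for key, item in value.items()}\n    if isinstance(value, list):\n        return [resolve(item) for item in value]\n    return value\n\n\nif __name__ == \"__main__\":\n    print(resolve(cfg))\n    # {'params': {'threshold': 0.5, 'noise': OrderedDict()}}\n"
  },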
  {
    "path": "holomotion/src/env/isaaclab_components/isaaclab_velocity_tracking_command.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom __future__ import annotations\n\nfrom dataclasses import MISSING\n\nfrom isaaclab.managers import CommandTermCfg\nfrom isaaclab.utils import configclass\nimport torch\n\nfrom isaaclab.assets import Articulation\nfrom isaaclab.managers import CommandTerm\nfrom isaaclab.markers import VisualizationMarkers\nfrom isaaclab.envs import ManagerBasedEnv\nimport isaaclab.utils.math as math_utils\nfrom isaaclab.markers import VisualizationMarkersCfg\nfrom isaaclab.markers.config import (\n    BLUE_ARROW_X_MARKER_CFG,\n    GREEN_ARROW_X_MARKER_CFG,\n)\n\nfrom typing import Sequence\n\n\nclass HoloMotionUniformVelocityCommand(CommandTerm):\n    r\"\"\"Command generator that generates a velocity command in SE(2) from uniform distribution.\n\n    The command comprises of a linear velocity in x and y direction and an angular velocity around\n    the z-axis. It is given in the robot's base frame.\n\n    If the :attr:`cfg.heading_command` flag is set to True, the angular velocity is computed from the heading\n    error similar to doing a proportional control on the heading error. The target heading is sampled uniformly\n    from the provided range. Otherwise, the angular velocity is sampled uniformly from the provided range.\n\n    Mathematically, the angular velocity is computed as follows from the heading command:\n\n    .. 
math::\n\n        \\omega_z = k_{\\text{heading}} \\, \\text{wrap_to_pi}(\\theta_{\\text{target}} - \\theta_{\\text{current}})\n\n    where :math:`k_{\\text{heading}}` is the configured ``heading_control_stiffness`` and the\n    result is clipped to the ``ang_vel_z`` range.\n    \"\"\"\n\n    cfg: HoloMotionUniformVelocityCommandCfg\n    \"\"\"The configuration of the command generator.\"\"\"\n\n    def __init__(\n        self, cfg: HoloMotionUniformVelocityCommandCfg, env: ManagerBasedEnv\n    ):\n        \"\"\"Initialize the command generator.\n\n        Args:\n            cfg: The configuration of the command generator.\n            env: The environment.\n\n        Raises:\n            ValueError: If the heading command is active but the heading range is not provided.\n        \"\"\"\n        # initialize the base class\n        super().__init__(cfg, env)\n\n        # check configuration\n        if self.cfg.heading_command and self.cfg.ranges.heading is None:\n            raise ValueError(\n                \"The velocity command has heading commands active (heading_command=True) but the `ranges.heading`\"\n                \" parameter is set to None.\"\n            )\n\n        # obtain the robot asset\n        # -- robot\n        self.robot: Articulation = env.scene[cfg.asset_name]\n\n        # create buffers to store the command\n        # -- command: x vel, y vel, yaw vel, heading\n        self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device)\n        self.heading_target = torch.zeros(self.num_envs, device=self.device)\n        self.is_heading_env = torch.zeros(\n            self.num_envs, dtype=torch.bool, device=self.device\n        )\n        self.is_standing_env = torch.zeros_like(self.is_heading_env)\n        self.is_yaw_env = torch.zeros_like(self.is_heading_env)\n        # -- metrics\n        self.metrics[\"error_vel_xy\"] = torch.zeros(\n            self.num_envs, device=self.device\n        )\n        self.metrics[\"error_vel_yaw\"] = torch.zeros(\n            self.num_envs, device=self.device\n        )\n\n    def __str__(self) -> str:\n        \"\"\"Return a string representation of the command generator.\"\"\"\n        msg = \"HoloMotionUniformVelocityCommand:\\n\"\n        msg += f\"\\tCommand dimension: {tuple(self.command.shape[1:])}\\n\"\n        msg += f\"\\tResampling time range: {self.cfg.resampling_time_range}\\n\"\n        msg += f\"\\tHeading command: {self.cfg.heading_command}\\n\"\n        if self.cfg.heading_command:\n            msg += f\"\\tHeading probability: {self.cfg.rel_heading_envs}\\n\"\n        msg += f\"\\tStanding probability: {self.cfg.rel_standing_envs}\\n\"\n        msg += f\"\\tYaw-only probability: {self.cfg.rel_yaw_envs}\"\n        return msg\n\n    \"\"\"\n    Properties\n    \"\"\"\n\n    @property\n    def command(self) -> torch.Tensor:\n        \"\"\"The desired base velocity command in the base frame. 
Shape is (num_envs, 3).\"\"\"\n        return self.vel_command_b\n\n    \"\"\"\n    Implementation specific functions.\n    \"\"\"\n\n    def _update_metrics(self):\n        # time for which the command was executed\n        max_command_time = self.cfg.resampling_time_range[1]\n        max_command_step = max_command_time / self._env.step_dt\n        # logs data\n        self.metrics[\"error_vel_xy\"] += (\n            torch.norm(\n                self.vel_command_b[:, :2]\n                - self.robot.data.root_lin_vel_b[:, :2],\n                dim=-1,\n            )\n            / max_command_step\n        )\n        self.metrics[\"error_vel_yaw\"] += (\n            torch.abs(\n                self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]\n            )\n            / max_command_step\n        )\n\n    def _resample_command(self, env_ids: Sequence[int]):\n        # sample velocity commands\n        r = torch.empty(len(env_ids), device=self.device)\n        # -- linear velocity - x direction\n        self.vel_command_b[env_ids, 0] = r.uniform_(*self.cfg.ranges.lin_vel_x)\n        # -- linear velocity - y direction\n        self.vel_command_b[env_ids, 1] = r.uniform_(*self.cfg.ranges.lin_vel_y)\n        # -- ang vel yaw - rotation around z\n        self.vel_command_b[env_ids, 2] = r.uniform_(*self.cfg.ranges.ang_vel_z)\n        # heading target\n        if self.cfg.heading_command:\n            self.heading_target[env_ids] = r.uniform_(*self.cfg.ranges.heading)\n            # update heading envs\n            self.is_heading_env[env_ids] = (\n                r.uniform_(0.0, 1.0) <= self.cfg.rel_heading_envs\n            )\n        self.is_yaw_env[env_ids] = (\n            r.uniform_(0.0, 1.0) <= self.cfg.rel_yaw_envs\n        )\n        if self.cfg.heading_command:\n            # yaw-only envs should follow directly sampled yaw commands (not heading control)\n            self.is_heading_env[env_ids] &= ~self.is_yaw_env[env_ids]\n        # update standing envs\n        self.is_standing_env[env_ids] = (\n            r.uniform_(0.0, 1.0) <= self.cfg.rel_standing_envs\n        )\n\n    def _update_command(self):\n        \"\"\"Post-processes the velocity command.\n\n        This function sets velocity command to zero for standing environments and computes angular\n        velocity from heading direction if the heading_command flag is set.\n        \"\"\"\n        # Compute angular velocity from heading direction\n        if self.cfg.heading_command:\n            # resolve indices of heading envs\n            env_ids = self.is_heading_env.nonzero(as_tuple=False).flatten()\n            # compute angular velocity\n            heading_error = math_utils.wrap_to_pi(\n                self.heading_target[env_ids]\n                - self.robot.data.heading_w[env_ids]\n            )\n            self.vel_command_b[env_ids, 2] = torch.clip(\n                self.cfg.heading_control_stiffness * heading_error,\n                min=self.cfg.ranges.ang_vel_z[0],\n                max=self.cfg.ranges.ang_vel_z[1],\n            )\n        yaw_env_ids = self.is_yaw_env.nonzero(as_tuple=False).flatten()\n        self.vel_command_b[yaw_env_ids, :2] = 0.0\n\n        # Enforce standing (i.e., zero velocity command) for standing envs\n        # TODO: check if conversion is needed\n        standing_env_ids = self.is_standing_env.nonzero(\n            as_tuple=False\n        ).flatten()\n        self.vel_command_b[standing_env_ids, :] = 0.0\n\n    def _set_debug_vis_impl(self, debug_vis: bool):\n      
  # set visibility of markers\n        # note: parent only deals with callbacks. not their visibility\n        if debug_vis:\n            # create markers if necessary for the first time\n            if not hasattr(self, \"goal_vel_visualizer\"):\n                # -- goal\n                self.goal_vel_visualizer = VisualizationMarkers(\n                    self.cfg.goal_vel_visualizer_cfg\n                )\n                # -- current\n                self.current_vel_visualizer = VisualizationMarkers(\n                    self.cfg.current_vel_visualizer_cfg\n                )\n            # set their visibility to true\n            self.goal_vel_visualizer.set_visibility(True)\n            self.current_vel_visualizer.set_visibility(True)\n        else:\n            if hasattr(self, \"goal_vel_visualizer\"):\n                self.goal_vel_visualizer.set_visibility(False)\n                self.current_vel_visualizer.set_visibility(False)\n\n    def _debug_vis_callback(self, event):\n        # check if robot is initialized\n        # note: this is needed in-case the robot is de-initialized. we can't access the data\n        if not self.robot.is_initialized:\n            return\n        # get marker location\n        # -- base state\n        base_pos_w = self.robot.data.root_pos_w.clone()\n        base_pos_w[:, 2] += 0.5\n        # -- resolve the scales and quaternions\n        vel_des_arrow_scale, vel_des_arrow_quat = (\n            self._resolve_xy_velocity_to_arrow(self.command[:, :2])\n        )\n        vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(\n            self.robot.data.root_lin_vel_b[:, :2]\n        )\n        # display markers\n        self.goal_vel_visualizer.visualize(\n            base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale\n        )\n        self.current_vel_visualizer.visualize(\n            base_pos_w, vel_arrow_quat, vel_arrow_scale\n        )\n\n    \"\"\"\n    Internal helpers.\n    \"\"\"\n\n    def _resolve_xy_velocity_to_arrow(\n        self, xy_velocity: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Converts the XY base velocity command to arrow direction rotation.\"\"\"\n        # obtain default scale of the marker\n        default_scale = self.goal_vel_visualizer.cfg.markers[\"arrow\"].scale\n        # arrow-scale\n        arrow_scale = torch.tensor(default_scale, device=self.device).repeat(\n            xy_velocity.shape[0], 1\n        )\n        arrow_scale[:, 0] *= torch.linalg.norm(xy_velocity, dim=1) * 3.0\n        # arrow-direction\n        heading_angle = torch.atan2(xy_velocity[:, 1], xy_velocity[:, 0])\n        zeros = torch.zeros_like(heading_angle)\n        arrow_quat = math_utils.quat_from_euler_xyz(\n            zeros, zeros, heading_angle\n        )\n        # convert everything back from base to world frame\n        base_quat_w = self.robot.data.root_quat_w\n        arrow_quat = math_utils.quat_mul(base_quat_w, arrow_quat)\n\n        return arrow_scale, arrow_quat\n\n\n@configclass\nclass HoloMotionUniformVelocityCommandCfg(CommandTermCfg):\n    \"\"\"Configuration for the uniform velocity command generator.\"\"\"\n\n    class_type: type = HoloMotionUniformVelocityCommand\n\n    asset_name: str = MISSING\n    \"\"\"Name of the asset in the environment for which the commands are generated.\"\"\"\n\n    heading_command: bool = False\n    \"\"\"Whether to use heading command or angular velocity command. 
Defaults to False.\n\n    If True, the angular velocity command is computed from the heading error, where the\n    target heading is sampled uniformly from provided range. Otherwise, the angular velocity\n    command is sampled uniformly from provided range.\n    \"\"\"\n\n    heading_control_stiffness: float = 1.0\n    \"\"\"Scale factor to convert the heading error to angular velocity command. Defaults to 1.0.\"\"\"\n\n    rel_standing_envs: float = 0.0\n    \"\"\"The sampled probability of environments that should be standing still. Defaults to 0.0.\"\"\"\n\n    rel_yaw_envs: float = 0.0\n    \"\"\"The sampled probability of environments that should receive yaw-only commands. Defaults to 0.0.\n\n    For yaw-only environments, the command is post-processed to:\n    - enforce vx=vy=0\n\n    This is sampled independently from :attr:`rel_standing_envs`. If an environment is both yaw-only\n    and standing, standing still overrides to zero command.\n    \"\"\"\n\n    rel_heading_envs: float = 1.0\n    \"\"\"The sampled probability of environments where the robots follow the heading-based angular velocity command\n    (the others follow the sampled angular velocity command). Defaults to 1.0.\n\n    This parameter is only used if :attr:`heading_command` is True.\n    \"\"\"\n\n    @configclass\n    class Ranges:\n        \"\"\"Uniform distribution ranges for the velocity commands.\"\"\"\n\n        lin_vel_x: tuple[float, float] = MISSING\n        \"\"\"Range for the linear-x velocity command (in m/s).\"\"\"\n\n        lin_vel_y: tuple[float, float] = MISSING\n        \"\"\"Range for the linear-y velocity command (in m/s).\"\"\"\n\n        ang_vel_z: tuple[float, float] = MISSING\n        \"\"\"Range for the angular-z velocity command (in rad/s).\"\"\"\n\n        heading: tuple[float, float] | None = None\n        \"\"\"Range for the heading command (in rad). Defaults to None.\n\n        This parameter is only used if :attr:`~HoloMotionUniformVelocityCommandCfg.heading_command` is True.\n        \"\"\"\n\n    ranges: Ranges = MISSING\n    \"\"\"Distribution ranges for the velocity commands.\"\"\"\n\n    goal_vel_visualizer_cfg: VisualizationMarkersCfg = (\n        GREEN_ARROW_X_MARKER_CFG.replace(\n            prim_path=\"/Visuals/Command/velocity_goal\"\n        )\n    )\n    \"\"\"The configuration for the goal velocity visualization marker. Defaults to GREEN_ARROW_X_MARKER_CFG.\"\"\"\n\n    current_vel_visualizer_cfg: VisualizationMarkersCfg = (\n        BLUE_ARROW_X_MARKER_CFG.replace(\n            prim_path=\"/Visuals/Command/velocity_current\"\n        )\n    )\n    \"\"\"The configuration for the current velocity visualization marker. 
Defaults to BLUE_ARROW_X_MARKER_CFG.\"\"\"\n\n    # Set the scale of the visualization markers to (0.5, 0.5, 0.5)\n    goal_vel_visualizer_cfg.markers[\"arrow\"].scale = (0.5, 0.5, 0.5)\n    current_vel_visualizer_cfg.markers[\"arrow\"].scale = (0.5, 0.5, 0.5)\n\n\n@configclass\nclass VelTrack_CommandsCfg:\n    pass\n\n\ndef _convert_ranges_dict_to_object(\n    ranges_dict: dict,\n) -> HoloMotionUniformVelocityCommandCfg.Ranges:\n    \"\"\"Convert a dict of ranges to a proper Ranges object with tuples.\"\"\"\n    ranges_kwargs = {}\n    for key, value in ranges_dict.items():\n        if value is None:\n            ranges_kwargs[key] = None\n        elif isinstance(value, (list, tuple)):\n            ranges_kwargs[key] = tuple(value)\n        else:\n            ranges_kwargs[key] = value\n    return HoloMotionUniformVelocityCommandCfg.Ranges(**ranges_kwargs)\n\n\ndef build_velocity_commands_config(\n    command_config_dict: dict,\n) -> VelTrack_CommandsCfg:\n    \"\"\"Build a CommandsCfg for velocity-tracking commands.\n\n    Expected format:\n    {\n      \"base_velocity\": {\n        \"type\": \"HoloMotionUniformVelocityCommandCfg\",\n        \"params\": { ... }  # kwargs for HoloMotionUniformVelocityCommandCfg\n      }\n    }\n\n    Only \"HoloMotionUniformVelocityCommandCfg\" is currently supported; any\n    other type raises a ValueError. For ranges and limit_ranges, pass them\n    as dicts with keys like lin_vel_x, lin_vel_y, ang_vel_z, heading.\n    \"\"\"\n    commands_cfg = VelTrack_CommandsCfg()\n\n    for name, cfg in command_config_dict.items():\n        command_type = cfg.get(\"type\", \"HoloMotionUniformVelocityCommandCfg\")\n        params = cfg.get(\"params\", {}).copy()\n\n        if \"ranges\" in params and isinstance(params[\"ranges\"], dict):\n            params[\"ranges\"] = _convert_ranges_dict_to_object(params[\"ranges\"])\n\n        if \"limit_ranges\" in params and isinstance(\n            params[\"limit_ranges\"], dict\n        ):\n            params[\"limit_ranges\"] = _convert_ranges_dict_to_object(\n                params[\"limit_ranges\"]\n            )\n\n        if command_type == \"HoloMotionUniformVelocityCommandCfg\":\n            term_cfg = HoloMotionUniformVelocityCommandCfg(**params)\n        else:\n            raise ValueError(f\"Unknown velocity command type: {command_type}\")\n\n        setattr(commands_cfg, name, term_cfg)\n\n    return commands_cfg\n"
  },
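  {
    "path": "holomotion/examples/velocity_command_config_sketch.py",
    "content": "# Hypothetical usage sketch -- this file is illustrative and not part of\n# the original source tree. It shows a command config dict in the format\n# that build_velocity_commands_config in\n# isaaclab_velocity_tracking_command.py expects: ranges are passed as\n# plain dicts and lists become tuples inside the Ranges dataclass. All\n# numbers are illustrative.\ncommand_config = {\n    \"base_velocity\": {\n        \"type\": \"HoloMotionUniformVelocityCommandCfg\",\n        \"params\": {\n            \"asset_name\": \"robot\",\n            \"resampling_time_range\": [5.0, 10.0],\n            \"heading_command\": True,\n            \"heading_control_stiffness\": 0.5,\n            \"rel_standing_envs\": 0.1,\n            \"rel_heading_envs\": 1.0,\n            \"rel_yaw_envs\": 0.0,\n            \"ranges\": {\n                \"lin_vel_x\": [-1.0, 1.0],\n                \"lin_vel_y\": [-0.5, 0.5],\n                \"ang_vel_z\": [-1.0, 1.0],\n                \"heading\": [-3.14, 3.14],\n            },\n        },\n    },\n}\n\n# Usage (inside an environment where IsaacLab is installed):\n# commands_cfg = build_velocity_commands_config(command_config)\n"
  },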
  {
    "path": "holomotion/src/env/isaaclab_components/unitree_actuators.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\n# This file is modified from the unitree_rl_lab repository:\n# https://github.com/unitreerobotics/unitree_rl_lab\n\nfrom __future__ import annotations\n\nimport json\nimport os\nfrom pathlib import Path\nimport torch\nfrom dataclasses import MISSING\nfrom typing import Sequence\n\nfrom isaaclab.actuators import DelayedPDActuator, DelayedPDActuatorCfg\nfrom isaaclab.utils import configclass\nfrom isaaclab.utils.types import ArticulationActions\n\nfrom loguru import logger\n\n\nclass UnitreeActuator(DelayedPDActuator):\n    \"\"\"Unitree actuator class that implements a torque-speed curve for the actuators.\n\n    The torque-speed curve is defined as follows:\n\n            Torque Limit, N·m\n                ^\n    Y2──────────|\n                |──────────────Y1\n                |              │\\\n                |              │ \\\n                |              │  \\\n                |              |   \\\n    ------------+--------------|------> velocity: rad/s\n                              X1   X2\n\n    - Y1: Peak Torque Test (Torque and Speed in the Same Direction)\n    - Y2: Peak Torque Test (Torque and Speed in the Opposite Direction)\n    - X1: Maximum Speed at Full Torque (T-N Curve Knee Point)\n    - X2: No-Load Speed Test\n\n    - Fs: Static friction coefficient\n    - Fd: Dynamic friction coefficient\n    - Va: Velocity at which the friction is fully activated\n    \"\"\"\n\n    cfg: UnitreeActuatorCfg\n\n    armature: torch.Tensor\n    \"\"\"The armature of the actuator joints. 
Shape is (num_envs, num_joints).\n        armature = J2 + J1 * i2 ^ 2 + Jr * (i1 * i2) ^ 2\n    \"\"\"\n\n    def __init__(self, cfg: UnitreeActuatorCfg, *args, **kwargs):\n        super().__init__(cfg, *args, **kwargs)\n\n        self._joint_vel = torch.zeros_like(self.computed_effort)\n        self._effort_y1 = self._parse_joint_parameter(cfg.Y1, 1e9)\n        self._effort_y2 = self._parse_joint_parameter(cfg.Y2, cfg.Y1)\n        self._velocity_x1 = self._parse_joint_parameter(cfg.X1, 1e9)\n        self._velocity_x2 = self._parse_joint_parameter(cfg.X2, 1e9)\n        self._friction_static = self._parse_joint_parameter(cfg.Fs, 0.0)\n        self._friction_dynamic = self._parse_joint_parameter(cfg.Fd, 0.0)\n        self._activation_vel = self._parse_joint_parameter(cfg.Va, 0.01)\n\n    def compute(\n        self,\n        control_action: ArticulationActions,\n        joint_pos: torch.Tensor,\n        joint_vel: torch.Tensor,\n    ) -> ArticulationActions:\n        # save current joint vel\n        self._joint_vel[:] = joint_vel\n        # calculate the desired joint torques\n        control_action = super().compute(control_action, joint_pos, joint_vel)\n\n        # apply friction model on the torque\n        self.applied_effort -= (\n            self._friction_static\n            * torch.tanh(joint_vel / self._activation_vel)\n            + self._friction_dynamic * joint_vel\n        )\n\n        control_action.joint_positions = None\n        control_action.joint_velocities = None\n        control_action.joint_efforts = self.applied_effort\n\n        return control_action\n\n    def _clip_effort(self, effort: torch.Tensor) -> torch.Tensor:\n        # check if the effort is the same direction as the joint velocity\n        same_direction = (self._joint_vel * effort) > 0\n        max_effort = torch.where(\n            same_direction, self._effort_y1, self._effort_y2\n        )\n        # check if the joint velocity is less than the max speed at full torque\n        max_effort = torch.where(\n            self._joint_vel.abs() < self._velocity_x1,\n            max_effort,\n            self._compute_effort_limit(max_effort),\n        )\n        return torch.clip(effort, -max_effort, max_effort)\n\n    def _compute_effort_limit(self, max_effort):\n        k = -max_effort / (self._velocity_x2 - self._velocity_x1)\n        limit = k * (self._joint_vel.abs() - self._velocity_x1) + max_effort\n        return limit.clip(min=0.0)\n\n\nclass UnitreeErfiActuator(UnitreeActuator):\n    \"\"\"Unitree actuator with per-env ERFI-50 torque perturbations.\n\n    On environment reset, each env is assigned either step-wise random force\n    injection (RFI) or episode-level random actuation offset (RAO). 
During\n    rollout, only the selected mode is applied for that env.\n    \"\"\"\n\n    cfg: UnitreeErfiActuatorCfg\n\n    def __init__(self, cfg: UnitreeErfiActuatorCfg, *args, **kwargs):\n        super().__init__(cfg, *args, **kwargs)\n        self._ema_filter_alpha = float(cfg.ema_filter_alpha)\n        if not 0.0 <= self._ema_filter_alpha <= 1.0:\n            raise ValueError(\n                \"ema_filter_alpha must be within [0, 1], \"\n                f\"got {self._ema_filter_alpha}.\"\n            )\n        self._ema_filter_debug_dump_path = (\n            cfg.ema_filter_debug_dump_path\n            or os.environ.get(\"HOLOMOTION_EMA_FILTER_DEBUG_DUMP_PATH\")\n        )\n        self._ema_filter_debug_stop_after_dump = self._parse_bool_env(\n            \"HOLOMOTION_EMA_FILTER_DEBUG_STOP_AFTER_DUMP\",\n            cfg.ema_filter_debug_stop_after_dump,\n        )\n        self._ema_filter_debug_dumped = False\n        self._ema_filter_state = torch.zeros_like(self.computed_effort)\n        self._ema_filter_initialized = torch.zeros(\n            self._num_envs, dtype=torch.bool, device=self._device\n        )\n        self._mode_is_rfi = torch.zeros(\n            self._num_envs, dtype=torch.bool, device=self._device\n        )\n        self._rfi_lim_scale = torch.ones_like(self.computed_effort)\n        self._rao_scale = torch.zeros_like(self.computed_effort)\n\n    def reset(self, env_ids: Sequence[int] | slice | None):\n        super().reset(env_ids)\n        env_ids_tensor = self._env_ids_to_tensor(env_ids)\n        if env_ids_tensor.numel() == 0:\n            return\n        if self.cfg.ema_filter_enabled:\n            self._ema_filter_state[env_ids_tensor] = 0.0\n            self._ema_filter_initialized[env_ids_tensor] = False\n\n        if not self.cfg.erfi_enabled:\n            self._mode_is_rfi[env_ids_tensor] = False\n            self._rfi_lim_scale[env_ids_tensor] = 1.0\n            self._rao_scale[env_ids_tensor] = 0.0\n            return\n\n        sampled_is_rfi = (\n            torch.rand(env_ids_tensor.numel(), device=self._device)\n            < self.cfg.rfi_probability\n        )\n        self._mode_is_rfi[env_ids_tensor] = sampled_is_rfi\n\n        if self.cfg.randomize_rfi_lim:\n            self._rfi_lim_scale[env_ids_tensor] = self._sample_uniform(\n                self.cfg.rfi_lim_range[0],\n                self.cfg.rfi_lim_range[1],\n                (env_ids_tensor.numel(), self.num_joints),\n            )\n        else:\n            self._rfi_lim_scale[env_ids_tensor] = 1.0\n\n        self._rao_scale[env_ids_tensor] = self._sample_uniform(\n            -self.cfg.rao_lim,\n            self.cfg.rao_lim,\n            (env_ids_tensor.numel(), self.num_joints),\n        )\n\n        rfi_env_ids = env_ids_tensor[sampled_is_rfi]\n        if rfi_env_ids.numel() > 0:\n            self._rao_scale[rfi_env_ids] = 0.0\n\n    def compute(\n        self,\n        control_action: ArticulationActions,\n        joint_pos: torch.Tensor,\n        joint_vel: torch.Tensor,\n    ) -> ArticulationActions:\n        control_action = self._filter_joint_position_action(control_action)\n        if not self.cfg.erfi_enabled:\n            return super().compute(control_action, joint_pos, joint_vel)\n\n        if control_action.joint_efforts is None:\n            base_joint_efforts = torch.zeros_like(joint_pos)\n        else:\n            base_joint_efforts = control_action.joint_efforts.clone()\n\n        effort_limit = self.effort_limit.to(base_joint_efforts)\n        rfi_noise = 
self._sample_uniform(-1.0, 1.0, base_joint_efforts.shape)\n        rfi_term = (\n            rfi_noise * self.cfg.rfi_lim * self._rfi_lim_scale * effort_limit\n        )\n        rao_term = self._rao_scale * effort_limit\n        mode_is_rfi = self._mode_is_rfi.unsqueeze(-1)\n        control_action_with_erfi = ArticulationActions(\n            joint_positions=control_action.joint_positions,\n            joint_velocities=control_action.joint_velocities,\n            joint_efforts=base_joint_efforts\n            + torch.where(mode_is_rfi, rfi_term, rao_term),\n            joint_indices=control_action.joint_indices,\n        )\n\n        return super().compute(control_action_with_erfi, joint_pos, joint_vel)\n\n    def _filter_joint_position_action(\n        self, control_action: ArticulationActions\n    ) -> ArticulationActions:\n        if not self.cfg.ema_filter_enabled:\n            self._maybe_dump_ema_filter_debug_skip(\"ema_filter_disabled\")\n            return control_action\n        if control_action.joint_positions is None:\n            self._maybe_dump_ema_filter_debug_skip(\"joint_positions_none\")\n            return control_action\n\n        raw_joint_positions = control_action.joint_positions\n        previous_filtered_joint_positions = self._ema_filter_state.clone()\n        needs_init = ~self._ema_filter_initialized\n        filtered_joint_positions = raw_joint_positions.clone()\n        if torch.any(~needs_init):\n            filtered_joint_positions = torch.where(\n                needs_init.unsqueeze(-1),\n                raw_joint_positions,\n                self._ema_filter_alpha * raw_joint_positions\n                + (1.0 - self._ema_filter_alpha) * self._ema_filter_state,\n            )\n        self._maybe_dump_ema_filter_debug_verification(\n            raw_joint_positions=raw_joint_positions,\n            filtered_joint_positions=filtered_joint_positions,\n            previous_filtered_joint_positions=previous_filtered_joint_positions,\n            needs_init=needs_init,\n        )\n        self._ema_filter_state[:] = filtered_joint_positions\n        self._ema_filter_initialized[:] = True\n\n        return ArticulationActions(\n            joint_positions=filtered_joint_positions,\n            joint_velocities=control_action.joint_velocities,\n            joint_efforts=control_action.joint_efforts,\n            joint_indices=control_action.joint_indices,\n        )\n\n    def _maybe_dump_ema_filter_debug_verification(\n        self,\n        raw_joint_positions: torch.Tensor,\n        filtered_joint_positions: torch.Tensor,\n        previous_filtered_joint_positions: torch.Tensor,\n        needs_init: torch.Tensor,\n    ) -> None:\n        if (\n            self._ema_filter_debug_dumped\n            or not self._ema_filter_debug_dump_path\n        ):\n            return\n        rank = os.environ.get(\"RANK\")\n        if rank is not None and rank != \"0\":\n            return\n        initialized_env_ids = torch.nonzero(\n            ~needs_init, as_tuple=False\n        ).flatten()\n        if initialized_env_ids.numel() == 0:\n            return\n\n        env_idx = int(initialized_env_ids[0].item())\n        raw = raw_joint_positions[env_idx].detach().cpu()\n        prev = previous_filtered_joint_positions[env_idx].detach().cpu()\n        actual = filtered_joint_positions[env_idx].detach().cpu()\n        expected = (\n            self._ema_filter_alpha * raw\n            + (1.0 - self._ema_filter_alpha) * prev\n        )\n        matched = 
torch.allclose(actual, expected, atol=1.0e-6, rtol=1.0e-6)\n\n        dump_path = Path(self._ema_filter_debug_dump_path)\n        dump_path.parent.mkdir(parents=True, exist_ok=True)\n        dump_path.write_text(\n            json.dumps(\n                {\n                    \"alpha\": self._ema_filter_alpha,\n                    \"env_index\": env_idx,\n                    \"matched\": bool(matched),\n                    \"raw_joint_positions\": raw.tolist(),\n                    \"previous_filtered_joint_positions\": prev.tolist(),\n                    \"expected_filtered_joint_positions\": expected.tolist(),\n                    \"actual_filtered_joint_positions\": actual.tolist(),\n                    \"pid\": os.getpid(),\n                    \"rank\": rank or \"0\",\n                },\n                indent=2,\n            )\n        )\n        self._ema_filter_debug_dumped = True\n        logger.info(\"Wrote EMA verification dump to {}\", dump_path)\n        if self._ema_filter_debug_stop_after_dump:\n            raise RuntimeError(f\"EMA verification dump written to {dump_path}\")\n\n    def _maybe_dump_ema_filter_debug_skip(self, reason: str) -> None:\n        self._maybe_dump_ema_filter_debug_payload(\n            {\n                \"applied\": False,\n                \"reason\": reason,\n            }\n        )\n\n    def _maybe_dump_ema_filter_debug_payload(self, payload: dict) -> None:\n        if (\n            self._ema_filter_debug_dumped\n            or not self._ema_filter_debug_dump_path\n        ):\n            return\n        rank = os.environ.get(\"RANK\")\n        if rank is not None and rank != \"0\":\n            return\n\n        dump_path = Path(self._ema_filter_debug_dump_path)\n        dump_path.parent.mkdir(parents=True, exist_ok=True)\n        dump_path.write_text(\n            json.dumps(\n                {\n                    **payload,\n                    \"alpha\": self._ema_filter_alpha,\n                    \"pid\": os.getpid(),\n                    \"rank\": rank or \"0\",\n                },\n                indent=2,\n            )\n        )\n        self._ema_filter_debug_dumped = True\n        logger.info(\"Wrote EMA verification dump to {}\", dump_path)\n        if self._ema_filter_debug_stop_after_dump:\n            raise RuntimeError(f\"EMA verification dump written to {dump_path}\")\n\n    @staticmethod\n    def _parse_bool_env(name: str, default: bool) -> bool:\n        raw_value = os.environ.get(name)\n        if raw_value is None:\n            return bool(default)\n        return raw_value.strip().lower() in {\"1\", \"true\", \"yes\", \"on\"}\n\n    def _env_ids_to_tensor(\n        self, env_ids: Sequence[int] | slice | None\n    ) -> torch.Tensor:\n        if env_ids is None or env_ids == slice(None):\n            return torch.arange(self._num_envs, device=self._device)\n        if isinstance(env_ids, torch.Tensor):\n            return env_ids.to(device=self._device, dtype=torch.long).flatten()\n        return torch.tensor(env_ids, device=self._device, dtype=torch.long)\n\n    def _sample_uniform(\n        self, low: float, high: float, shape: tuple[int, ...]\n    ) -> torch.Tensor:\n        return torch.empty(shape, device=self._device).uniform_(low, high)\n\n\n@configclass\nclass UnitreeActuatorCfg(DelayedPDActuatorCfg):\n    \"\"\"\n    Configuration for Unitree actuators.\n    \"\"\"\n\n    class_type: type = UnitreeActuator\n\n    X1: float = 1e9\n    \"\"\"Maximum Speed at Full Torque(T-N Curve Knee Point) Unit: 
rad/s\"\"\"\n\n    X2: float = 1e9\n    \"\"\"No-Load Speed Test Unit: rad/s\"\"\"\n\n    Y1: float = MISSING\n    \"\"\"Peak Torque Test(Torque and Speed in the Same Direction) Unit: N*m\"\"\"\n\n    Y2: float | None = None\n    \"\"\"Peak Torque Test(Torque and Speed in the Opposite Direction) Unit: N*m\"\"\"\n\n    Fs: float = 0.0\n    \"\"\" Static friction coefficient \"\"\"\n\n    Fd: float = 0.0\n    \"\"\" Dynamic friction coefficient \"\"\"\n\n    Va: float = 0.01\n    \"\"\" Velocity at which the friction is fully activated \"\"\"\n\n\n@configclass\nclass UnitreeErfiActuatorCfg(UnitreeActuatorCfg):\n    \"\"\"Configuration for Unitree actuators with ERFI-50 perturbations.\"\"\"\n\n    class_type: type = UnitreeErfiActuator\n\n    erfi_enabled: bool = False\n    \"\"\"Whether ERFI perturbations are enabled for this actuator.\"\"\"\n\n    ema_filter_enabled: bool = False\n    \"\"\"Whether to apply EMA filtering to incoming joint-position actions.\"\"\"\n\n    ema_filter_alpha: float = 1.0\n    \"\"\"EMA mixing factor using filtered = alpha * raw + (1 - alpha) * prev.\"\"\"\n\n    ema_filter_debug_dump_path: str | None = None\n    \"\"\"Optional JSON path for a one-shot EMA verification dump during runtime.\"\"\"\n\n    ema_filter_debug_stop_after_dump: bool = False\n    \"\"\"Whether to stop execution after writing the EMA verification dump.\"\"\"\n\n    rfi_probability: float = 0.5\n    \"\"\"Probability of assigning RFI to an environment on reset.\"\"\"\n\n    rfi_lim: float = 0.1\n    \"\"\"Base RFI limit, expressed as a ratio of joint effort limits.\"\"\"\n\n    randomize_rfi_lim: bool = True\n    \"\"\"Whether to randomize the per-episode RFI limit scale.\"\"\"\n\n    rfi_lim_range: tuple[float, float] = (0.5, 1.5)\n    \"\"\"Multiplicative range for per-episode RFI scaling.\"\"\"\n\n    rao_lim: float = 0.1\n    \"\"\"RAO limit, expressed as a ratio of joint effort limits.\"\"\"\n\n\n@configclass\nclass UnitreeActuatorCfg_M107_15(UnitreeActuatorCfg):\n    X1 = 14.0\n    X2 = 25.6\n    Y1 = 150.0\n    Y2 = 182.8\n\n    armature = 0.063259741\n\n\n@configclass\nclass UnitreeActuatorCfg_M107_24(UnitreeActuatorCfg):\n    X1 = 8.8\n    X2 = 16\n    Y1 = 240\n    Y2 = 292.5\n\n    armature = 0.160478022\n\n\n@configclass\nclass UnitreeActuatorCfg_Go2HV(UnitreeActuatorCfg):\n    X1 = 13.5\n    X2 = 30\n    Y1 = 20.2\n    Y2 = 23.4\n\n\n@configclass\nclass UnitreeActuatorCfg_N7520_14p3(UnitreeActuatorCfg):\n    # Decimal point cannot be used as variable name, use `p` instead\n    X1 = 22.63\n    X2 = 35.52\n    Y1 = 71\n    Y2 = 83.3\n\n    Fs = 1.6\n    Fd = 0.16\n\n    \"\"\"\n    | rotor  | 0.489e-4 kg·m²\n    | gear_1 | 0.098e-4 kg·m² | ratio | 4.5\n    | gear_2 | 0.533e-4 kg·m² | ratio | 48/22+1\n    \"\"\"\n    armature = 0.01017752\n\n\n@configclass\nclass UnitreeActuatorCfg_N7520_22p5(UnitreeActuatorCfg):\n    # Decimal point cannot be used as variable name, use `p` instead\n    X1 = 14.5\n    X2 = 22.7\n    Y1 = 111.0\n    Y2 = 131.0\n\n    Fs = 2.4\n    Fd = 0.24\n\n    \"\"\"\n    | rotor  | 0.489e-4 kg·m²\n    | gear_1 | 0.109e-4 kg·m² | ratio | 4.5\n    | gear_2 | 0.738e-4 kg·m² | ratio | 5.0\n    \"\"\"\n    armature = 0.025101925\n\n\n@configclass\nclass UnitreeActuatorCfg_N5010_16(UnitreeActuatorCfg):\n    X1 = 27.0\n    X2 = 41.5\n    Y1 = 9.5\n    Y2 = 17.0\n\n    \"\"\"\n    | rotor  | 0.084e-4 kg·m²\n    | gear_1 | 0.015e-4 kg·m² | ratio | 4\n    | gear_2 | 0.068e-4 kg·m² | ratio | 4\n    \"\"\"\n    armature = 0.0021812\n\n\n@configclass\nclass 
UnitreeActuatorCfg_N5020_16(UnitreeActuatorCfg):\n    X1 = 30.86\n    X2 = 40.13\n    Y1 = 24.8\n    Y2 = 31.9\n\n    Fs = 0.6\n    Fd = 0.06\n\n    \"\"\"\n    | rotor  | 0.139e-4 kg·m²\n    | gear_1 | 0.017e-4 kg·m² | ratio | 46/18+1\n    | gear_2 | 0.169e-4 kg·m² | ratio | 56/16+1\n    \"\"\"\n    armature = 0.003609725\n\n\n@configclass\nclass UnitreeActuatorCfg_W4010_25(UnitreeActuatorCfg):\n    X1 = 15.3\n    X2 = 24.76\n    Y1 = 4.8\n    Y2 = 8.6\n\n    Fs = 0.6\n    Fd = 0.06\n\n    \"\"\"\n    | rotor  | 0.068e-4 kg·m²\n    | gear_1 |                | ratio | 5\n    | gear_2 |                | ratio | 5\n    \"\"\"\n    armature = 0.00425\n"
  },
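The `_clip_effort` / `_compute_effort_limit` pair above implements the documented torque-speed curve: below the knee speed X1 the limit is Y1 (torque and velocity in the same direction) or Y2 (opposite direction), and between X1 and the no-load speed X2 it decays linearly to zero. A standalone sketch of that piecewise limit, using the M107-15 constants from the configs above (the function name is illustrative):

```python
# Sketch of UnitreeActuator's torque-speed limit, with M107-15 constants.
import torch


def effort_limit(joint_vel, effort, Y1=150.0, Y2=182.8, X1=14.0, X2=25.6):
    # Y1 applies when torque and velocity agree in sign, Y2 when they oppose.
    same_direction = (joint_vel * effort) > 0
    max_effort = torch.where(
        same_direction,
        torch.full_like(effort, Y1),
        torch.full_like(effort, Y2),
    )
    # Past the knee X1, the limit decays linearly, reaching zero at X2.
    k = -max_effort / (X2 - X1)
    decayed = (k * (joint_vel.abs() - X1) + max_effort).clip(min=0.0)
    max_effort = torch.where(joint_vel.abs() < X1, max_effort, decayed)
    return torch.clip(effort, -max_effort, max_effort)


vel = torch.tensor([5.0, 20.0, 30.0])
cmd = torch.tensor([200.0, 200.0, 200.0])
print(effort_limit(vel, cmd))
# tensor([150.0000,  72.4138,   0.0000]): full limit, linear decay, past no-load speed
```

And a one-shot check of the EMA action filter from `UnitreeErfiActuator`, which applies `filtered = alpha * raw + (1 - alpha) * prev` and seeds the state with the raw action on the first step after reset (toy values):

```python
import torch

alpha = 0.3
prev = torch.tensor([0.0, 1.0])   # previously filtered action
raw = torch.tensor([1.0, 1.0])    # incoming raw action
filtered = alpha * raw + (1.0 - alpha) * prev
print(filtered)  # tensor([0.3000, 1.0000])
```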
  {
    "path": "holomotion/src/env/motion_tracking.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport torch\nimport time\nimport os\nimport yaml\nfrom collections import deque\nfrom functools import wraps\nfrom easydict import EasyDict\nimport random\nimport numpy as np\nfrom isaaclab.actuators import ImplicitActuatorCfg\nfrom isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg\nfrom isaaclab.sim import PhysxCfg, SimulationCfg\nfrom isaaclab.utils import configclass\nfrom isaaclab.utils.io import dump_yaml\nfrom loguru import logger\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.env.isaaclab_components import (\n    ActionsCfg,\n    VelTrack_CommandsCfg,\n    MoTrack_CommandsCfg,\n    EventsCfg,\n    MotionTrackingSceneCfg,\n    ObservationsCfg,\n    RewardsCfg,\n    TerminationsCfg,\n    CurriculumCfg,\n    build_actions_config,\n    build_motion_tracking_commands_config,\n    build_velocity_commands_config,\n    build_domain_rand_config,\n    build_curriculum_config,\n    build_observations_config,\n    build_rewards_config,\n    build_scene_config,\n    build_terminations_config,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_observation import (\n    ObservationFunctions,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_utils import (\n    resolve_holo_config,\n)\n\n# from holomotion.src.modules.agent_modules import ObsSeqSerializer\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom isaaclab.envs.mdp.events import _randomize_prop_by_op\nfrom isaaclab.managers import SceneEntityCfg, EventTermCfg\nfrom isaaclab.utils import configclass\n\n\nfrom isaaclab.envs import ManagerBasedEnv\nfrom isaaclab.managers import EventTermCfg\nfrom isaaclab.managers import EventTermCfg as EventTerm\n\n\nimport isaaclab.utils.math as math_utils\nfrom isaaclab.assets import Articulation\nfrom isaaclab.envs.mdp.events import _randomize_prop_by_op\nfrom isaaclab.managers import SceneEntityCfg\nfrom typing import TYPE_CHECKING, Literal\n\n\ndef _joint_ids_to_tensor(\n    joint_ids: slice | list[int] | tuple[int, ...] 
| torch.Tensor | None,\n    num_joints: int,\n    device: torch.device | str,\n) -> torch.Tensor:\n    if joint_ids is None:\n        return torch.arange(num_joints, device=device, dtype=torch.long)\n    if isinstance(joint_ids, slice):\n        if joint_ids == slice(None):\n            return torch.arange(num_joints, device=device, dtype=torch.long)\n        return torch.arange(num_joints, device=device, dtype=torch.long)[\n            joint_ids\n        ]\n    if isinstance(joint_ids, torch.Tensor):\n        return joint_ids.to(device=device, dtype=torch.long).flatten()\n    return torch.tensor(joint_ids, device=device, dtype=torch.long)\n\n\ndef _select_effort_limit_vector(\n    asset: Articulation,\n    selected_joint_ids: torch.Tensor,\n) -> torch.Tensor:\n    num_joints = asset.data.applied_torque.shape[1]\n    device = asset.data.applied_torque.device\n    dtype = asset.data.applied_torque.dtype\n\n    effort_limit_vec = torch.zeros(num_joints, device=device, dtype=dtype)\n    is_filled = torch.zeros(num_joints, device=device, dtype=torch.bool)\n\n    for actuator in asset.actuators.values():\n        actuator_joint_ids = _joint_ids_to_tensor(\n            actuator.joint_indices, num_joints=num_joints, device=device\n        )\n        actuator_effort_limit = torch.as_tensor(\n            actuator.effort_limit, device=device, dtype=dtype\n        )\n        if actuator_effort_limit.ndim == 0:\n            actuator_effort_limit = actuator_effort_limit.expand(\n                actuator_joint_ids.numel()\n            )\n        elif actuator_effort_limit.ndim == 2:\n            if actuator_effort_limit.shape[0] > 1:\n                reference = actuator_effort_limit[0].unsqueeze(0)\n                if not torch.allclose(\n                    actuator_effort_limit,\n                    reference.expand_as(actuator_effort_limit),\n                ):\n                    raise ValueError(\n                        \"normed_torque_rate requires actuator effort limits to be static across envs.\"\n                    )\n            actuator_effort_limit = actuator_effort_limit[0]\n        elif actuator_effort_limit.ndim != 1:\n            raise ValueError(\n                \"normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors.\"\n            )\n\n        if actuator_effort_limit.numel() != actuator_joint_ids.numel():\n            raise ValueError(\n                \"normed_torque_rate found mismatched actuator joint indices and effort limits.\"\n            )\n\n        effort_limit_vec[actuator_joint_ids] = actuator_effort_limit\n        is_filled[actuator_joint_ids] = True\n\n    if not torch.all(is_filled[selected_joint_ids]):\n        missing_joint_ids = selected_joint_ids[~is_filled[selected_joint_ids]]\n        raise ValueError(\n            \"normed_torque_rate could not resolve actuator effort limits for \"\n            f\"joint ids {missing_joint_ids.tolist()}.\"\n        )\n\n    selected_effort_limits = effort_limit_vec[selected_joint_ids]\n    if not torch.all(torch.isfinite(selected_effort_limits)):\n        raise ValueError(\n            \"normed_torque_rate requires finite actuator effort limits for all selected joints.\"\n        )\n    if not torch.all(selected_effort_limits > 0.0):\n        raise ValueError(\n            \"normed_torque_rate requires strictly positive actuator effort limits for all selected joints.\"\n        )\n\n    return selected_effort_limits\n\n\nclass MotionTrackingEnv:\n    \"\"\"IsaacLab-based Motion Tracking 
Environment.\n\n    This environment integrates motion tracking capabilities with IsaacLab's\n    manager-based architecture, supporting curriculum learning, domain randomization,\n    and various termination conditions.\n\n    This is a wrapper class that handles Isaac Sim initialization and delegates\n    to an internal ManagerBasedRLEnv instance.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        device: torch.device = None,\n        log_dir: str = None,\n        render_mode: str | None = None,\n        headless: bool = True,\n        accelerator=None,\n    ):\n        \"\"\"Initialize the Motion Tracking Environment.\n\n        Args:\n            config: Configuration for the environment\n            device: Device for tensor operations\n            log_dir: Logging directory\n            render_mode: Render mode for the environment\n            headless: Whether to run in headless mode\n            accelerator: Accelerator instance for distributed training (optional)\n        \"\"\"\n        self.config = config\n        self._device = device\n        self.accelerator = accelerator\n\n        self.log_dir = log_dir\n        self.headless = headless\n        self.init_done = False\n        self.is_evaluating = False\n        self.render_mode = render_mode\n\n        # self._init_motion_tracking_components()\n        self._init_isaaclab_env()\n        # self._init_serializers()\n        self._completion_total_queue = deque(maxlen=1000)\n        self._completion_success_queue = deque(maxlen=1000)\n        self.metrics = {}\n        self._robot_prev_joint_vel = None\n        self._robot_prev_applied_torque = None\n        self._robot_torque_rate_inv_effort_limit = None\n        self._robot_torque_rate_needs_reseed = None\n\n    @property\n    def num_envs(self):\n        return self._env.num_envs\n\n    @property\n    def device(self):\n        return self._env.device\n\n    def _init_isaaclab_env(self):\n        _device = self._device\n\n        curriculum = CurriculumCfg()\n\n        # Determine per-process seed if provided; else create a deterministic per-rank default\n        seed_val = getattr(self.config, \"seed\", None)\n        if seed_val is None:\n            if self.accelerator is not None:\n                pid = self.accelerator.process_index\n            else:\n                pid = int(self.config.get(\"process_id\", 0))\n            seed_val = int(time.time()) + pid\n\n        _robot_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.robot, resolve=True)\n        )\n        _terrain_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.terrain, resolve=True)\n        )\n        _obs_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.obs, resolve=True)\n        )\n        _rewards_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.rewards, resolve=True)\n        )\n        _domain_rand_config_dict = (\n            EasyDict(\n                OmegaConf.to_container(\n                    self.config.domain_rand,\n                    resolve=True,\n                )\n            )\n            if self.config.domain_rand is not None\n            else {}\n        )\n        _terminations_config_dict = (\n            EasyDict(\n                OmegaConf.to_container(\n                    self.config.terminations,\n                    resolve=True,\n                )\n            )\n            if self.config.terminations is not None\n            else {}\n        )\n        
_scene_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.scene,\n                resolve=True,\n            )\n        )\n        _commands_config_dict = OmegaConf.to_container(\n            self.config.commands,\n            resolve=True,\n        )\n\n        _simulation_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.simulation,\n                resolve=True,\n            )\n        )\n        _actions_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.actions,\n                resolve=True,\n            )\n        )\n        if getattr(self.config, \"curriculum\", None) is not None:\n            _curriculum_config_dict = EasyDict(\n                OmegaConf.to_container(self.config.curriculum, resolve=True)\n            )\n        else:\n            _curriculum_config_dict = {}\n\n        @configclass\n        class MotionTrackingEnvCfg(ManagerBasedRLEnvCfg):\n            seed: int = seed_val\n            scene_config_dict = {\n                \"num_envs\": self.config.num_envs,\n                \"env_spacing\": self.config.env_spacing,\n                \"replicate_physics\": self.config.replicate_physics,\n                \"robot\": _robot_config_dict,\n                \"terrain\": _terrain_config_dict,\n                \"domain_rand\": _domain_rand_config_dict,\n                \"lighting\": _scene_config_dict.lighting,\n                \"contact_sensor\": _scene_config_dict.contact_sensor,\n            }\n\n            decimation: int = _simulation_config_dict.control_decimation\n            episode_length_s: int = _simulation_config_dict.episode_length_s\n            sim_freq = _simulation_config_dict.sim_freq\n            dt = 1.0 / sim_freq\n            physx = PhysxCfg(\n                bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity,\n                gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count,\n                enable_stabilization=True,\n            )\n\n            if self.accelerator is not None:\n                main_process = self.accelerator.is_main_process\n                process_id = self.accelerator.process_index\n                num_processes = self.accelerator.num_processes\n            else:\n                main_process = self.config.get(\"main_process\", True)\n                process_id = self.config.get(\"process_id\", 0)\n                num_processes = self.config.get(\"num_processes\", 1)\n            scene: MotionTrackingSceneCfg = build_scene_config(\n                scene_config_dict,\n                main_process=main_process,\n                process_id=process_id,\n                num_processes=num_processes,\n            )\n\n            sim: SimulationCfg = SimulationCfg(\n                dt=dt,\n                render_interval=decimation,\n                physx=physx,\n                device=_device,\n                enable_scene_query_support=True,\n            )\n            sim.physics_material = scene.terrain.physics_material\n\n            viewer: ViewerCfg = ViewerCfg(origin_type=\"world\")\n\n            motion_cmds = {}\n            vel_cmds = {}\n            for k, v in _commands_config_dict.items():\n                if (\n                    isinstance(v, dict)\n                    and v.get(\"type\", \"\") == \"MotionCommandCfg\"\n                ):\n                    motion_cmds[k] = v\n                else:\n                    vel_cmds[k] = 
v\n\n            # Populate RefMotionCommand distributed params when present.\n            if \"ref_motion\" in motion_cmds:\n                if self.accelerator is not None:\n                    cmd_process_id = self.accelerator.process_index\n                    cmd_num_processes = self.accelerator.num_processes\n                else:\n                    cmd_process_id = getattr(self.config, \"process_id\", 0)\n                    cmd_num_processes = getattr(\n                        self.config, \"num_processes\", 1\n                    )\n                motion_cmds[\"ref_motion\"][\"params\"].update(\n                    {\n                        \"seed\": int(seed_val),\n                        \"process_id\": cmd_process_id,\n                        \"num_processes\": cmd_num_processes,\n                        \"is_evaluating\": self.is_evaluating,\n                    }\n                )\n\n            # Build a unified commands cfg that may contain both motion and velocity terms.\n            if motion_cmds:\n                commands: MoTrack_CommandsCfg = (\n                    build_motion_tracking_commands_config(motion_cmds)\n                )\n            else:\n                commands: MoTrack_CommandsCfg = MoTrack_CommandsCfg()\n            if vel_cmds:\n                vel_commands: VelTrack_CommandsCfg = (\n                    build_velocity_commands_config(vel_cmds)\n                )\n                for name in vel_cmds.keys():\n                    setattr(commands, name, getattr(vel_commands, name))\n            observations: ObservationsCfg = build_observations_config(\n                _obs_config_dict.obs_groups\n            )\n            rewards: RewardsCfg = build_rewards_config(_rewards_config_dict)\n\n            if _terminations_config_dict:\n                terminations: TerminationsCfg = build_terminations_config(\n                    _terminations_config_dict\n                )\n            else:\n                terminations: TerminationsCfg = TerminationsCfg()\n\n            if _domain_rand_config_dict:\n                events: EventsCfg = build_domain_rand_config(\n                    _domain_rand_config_dict\n                )\n            else:\n                events: EventsCfg = EventsCfg()\n\n            if \"base_velocity\" in vel_cmds:\n                events.reset_base = EventTerm(\n                    func=isaaclab_mdp.reset_root_state_uniform,\n                    mode=\"reset\",\n                    params={\n                        \"pose_range\": {\n                            \"x\": (-0.5, 0.5),\n                            \"y\": (-0.5, 0.5),\n                            \"yaw\": (-3.14, 3.14),\n                        },\n                        \"velocity_range\": {\n                            \"x\": (0.0, 0.0),\n                            \"y\": (0.0, 0.0),\n                            \"z\": (0.0, 0.0),\n                            \"roll\": (0.0, 0.0),\n                            \"pitch\": (0.0, 0.0),\n                            \"yaw\": (0.0, 0.0),\n                        },\n                    },\n                )\n                events.reset_robot_joints = EventTerm(\n                    func=isaaclab_mdp.reset_joints_by_scale,\n                    mode=\"reset\",\n                    params={\n                        \"position_range\": (1.0, 1.0),\n                        \"velocity_range\": (-1.0, 1.0),\n                    },\n                )\n\n            curriculum: CurriculumCfg = build_curriculum_config(\n  
              _curriculum_config_dict\n            )\n            actions: ActionsCfg = build_actions_config(_actions_config_dict)\n            sim: SimulationCfg = SimulationCfg(\n                dt=dt,\n                render_interval=decimation,\n                physx=physx,\n                device=_device,\n                enable_scene_query_support=True,\n            )\n            sim.physx.gpu_max_rigid_patch_count = 10 * 2**15\n            sim.physx.enable_stabilization = True\n            sim.physics_material = scene.terrain.physics_material\n\n        isaaclab_env_cfg = MotionTrackingEnvCfg()\n\n        isaaclab_envconfig_dump_path = os.path.join(\n            self.log_dir, \"isaaclab_env_cfg.yaml\"\n        )\n        dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg)\n\n        self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode)\n\n        logger.info(\"IsaacLab environment initialized !\")\n        return self._env\n\n    def _init_motion_tracking_components(self):\n        self.n_fut_frames = self.config.commands.ref_motion.params.n_fut_frames\n        self.target_fps = self.config.commands.ref_motion.params.target_fps\n        self._init_serializers()\n\n    def step(self, actor_state: dict):\n        obs_dict, rewards, terminated, time_outs, infos = self._env.step(\n            actor_state\n        )\n        # IsaacLab separates terminated vs time_outs, combine them for consistency\n        dones = terminated | time_outs\n        self._update_completion_rate_stats(terminated, time_outs, infos)\n        self._update_robot_metrics(infos)\n        return obs_dict, rewards, dones, time_outs, infos\n\n    def _update_robot_metrics(self, infos: dict) -> None:\n        \"\"\"Log robot low-level metrics (scalar means) for TensorBoard/console.\"\"\"\n        if (\"log\" not in infos) or (not isinstance(infos[\"log\"], dict)):\n            infos[\"log\"] = {}\n\n        dt = float(self._env.step_dt)\n        action = self._env.action_manager.action  # [B, A]\n        prev_action = self._env.action_manager.prev_action  # [B, A]\n        action_rate = torch.norm(action - prev_action, dim=-1) / dt  # [B]\n\n        robot = self._env.scene[\"robot\"]\n        dof_vel = robot.data.joint_vel  # [B, Nd]\n        dof_torque = robot.data.applied_torque  # [B, Nd]\n\n        if self._robot_prev_joint_vel is None or (\n            self._robot_prev_joint_vel.shape != dof_vel.shape\n        ):\n            self._robot_prev_joint_vel = dof_vel.clone()\n\n        dof_acc = (dof_vel - self._robot_prev_joint_vel) / dt  # [B, Nd]\n        self._robot_prev_joint_vel = dof_vel.clone()\n\n        if self._robot_prev_applied_torque is None or (\n            self._robot_prev_applied_torque.shape != dof_torque.shape\n        ):\n            joint_ids = torch.arange(\n                dof_torque.shape[1], device=dof_torque.device, dtype=torch.long\n            )\n            effort_limit = _select_effort_limit_vector(robot, joint_ids)\n            self._robot_torque_rate_inv_effort_limit = (\n                effort_limit.reciprocal()\n            )\n            self._robot_prev_applied_torque = torch.zeros_like(dof_torque)\n            self._robot_torque_rate_needs_reseed = torch.ones(\n                dof_torque.shape[0], device=dof_torque.device, dtype=torch.bool\n            )\n\n        normed_torque_rate = torch.zeros(\n            dof_torque.shape[0],\n            device=dof_torque.device,\n            dtype=dof_torque.dtype,\n        )\n        reseed_mask = 
self._robot_torque_rate_needs_reseed.clone()\n        if hasattr(self._env, \"episode_length_buf\"):\n            reseed_mask |= self._env.episode_length_buf == 0\n\n        active_mask = ~reseed_mask\n        if torch.any(active_mask):\n            delta = (\n                dof_torque[active_mask]\n                - self._robot_prev_applied_torque[active_mask]\n            ) * self._robot_torque_rate_inv_effort_limit\n            normed_torque_rate[active_mask] = torch.sum(delta.square(), dim=1)\n\n        self._robot_prev_applied_torque.copy_(dof_torque)\n        self._robot_torque_rate_needs_reseed[reseed_mask] = False\n\n        dof_acc_norm = torch.norm(dof_acc, dim=-1)  # [B]\n        dof_torque_norm = torch.norm(dof_torque, dim=-1)  # [B]\n        energy = torch.sum(\n            torch.abs(dof_vel) * torch.abs(dof_torque), dim=-1\n        )  # [B]\n\n        self.metrics[\"Robot/Action_Rate\"] = action_rate.mean()\n        self.metrics[\"Robot/DOF_Acc\"] = dof_acc_norm.mean()\n        self.metrics[\"Robot/DOF_Torque\"] = dof_torque_norm.mean()\n        self.metrics[\"Robot/Energy\"] = energy.mean()\n        self.metrics[\"Robot/Normed_Torque_Rate\"] = normed_torque_rate.mean()\n\n        infos[\"log\"][\"Metrics/Robot/Action_Rate\"] = self.metrics[\n            \"Robot/Action_Rate\"\n        ]\n        infos[\"log\"][\"Metrics/Robot/DOF_Acc\"] = self.metrics[\"Robot/DOF_Acc\"]\n        infos[\"log\"][\"Metrics/Robot/DOF_Torque\"] = self.metrics[\n            \"Robot/DOF_Torque\"\n        ]\n        infos[\"log\"][\"Metrics/Robot/Energy\"] = self.metrics[\"Robot/Energy\"]\n        infos[\"log\"][\"Metrics/Robot/Normed_Torque_Rate\"] = self.metrics[\n            \"Robot/Normed_Torque_Rate\"\n        ]\n\n    def _update_completion_rate_stats(\n        self,\n        terminated: torch.Tensor,\n        time_outs: torch.Tensor,\n        infos: dict,\n    ) -> None:\n        \"\"\"Log completion rate over recent done batches.\n\n        Definition:\n        - Completed: time_outs==True and terminated==False.\n        - Failed: terminated==True.\n        The rolling window stores per-step done counts (only when any done occurs).\n        \"\"\"\n        done_mask = (terminated | time_outs).reshape(-1).bool()\n        if torch.any(done_mask):\n            done_count = int(done_mask.sum().item())\n            completed_mask = (\n                time_outs.reshape(-1).bool()\n                & ~terminated.reshape(-1).bool()\n                & done_mask\n            )\n            completed_count = int(completed_mask.sum().item())\n            self._completion_total_queue.append(done_count)\n            self._completion_success_queue.append(completed_count)\n\n        denom = sum(self._completion_total_queue)\n        completion_rate = (\n            float(sum(self._completion_success_queue)) / float(denom)\n            if denom > 0\n            else 0.0\n        )\n        if (\"log\" not in infos) or (not isinstance(infos[\"log\"], dict)):\n            infos[\"log\"] = {}\n        infos[\"log\"][\"Metrics/ref_motion/Task/Completion_Rate\"] = torch.tensor(\n            completion_rate, device=self.device, dtype=torch.float32\n        )\n        self.metrics[\"Metrics/ref_motion/Task/Completion_Rate\"] = (\n            completion_rate\n        )\n\n    def reset_idx(self, env_ids: torch.Tensor):\n        return self._env.reset(env_ids=env_ids)\n\n    def reset_all(self):\n        env_ids = torch.arange(self.num_envs, device=self.device)\n        out = self._env.reset(env_ids=env_ids)\n        
return out\n\n    def set_is_evaluating(self):\n        logger.info(\"Setting environment to evaluation mode\")\n        self.is_evaluating = True\n\n    def seed(self, seed: int):\n        self._env.seed(seed)\n"
  },
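The completion-rate logging in `_update_completion_rate_stats` above keeps two bounded deques of per-step done counts; the reported rate is successes over all dones currently in the window. A dependency-free sketch of that bookkeeping (function name and numbers are illustrative):

```python
# Sketch of the rolling completion-rate window: per-step done counts are
# appended only when at least one episode finishes; the rate is windowed.
from collections import deque

total_q, success_q = deque(maxlen=1000), deque(maxlen=1000)


def update(done_count: int, completed_count: int) -> float:
    if done_count > 0:
        total_q.append(done_count)
        success_q.append(completed_count)
    denom = sum(total_q)
    return sum(success_q) / denom if denom > 0 else 0.0


print(update(4, 3))  # 0.75: three of four finished episodes timed out cleanly
print(update(0, 0))  # 0.75: steps without dones do not enter the window
print(update(2, 0))  # 0.5: two new failures drag the windowed rate down
```

The `Normed_Torque_Rate` metric in `_update_robot_metrics` scales per-joint torque deltas by the actuator effort limits before summing squares per env (envs on their first post-reset step are masked out). A tiny worked example with made-up tensors:

```python
import torch

effort_limit = torch.tensor([100.0, 50.0])   # per-joint limits, N*m
prev = torch.tensor([[10.0, 5.0]])           # previous applied torque, [B, Nd]
curr = torch.tensor([[20.0, 10.0]])          # current applied torque, [B, Nd]
delta = (curr - prev) / effort_limit         # normalized per-joint deltas
print(torch.sum(delta.square(), dim=1))      # tensor([0.0200]): 0.1^2 + 0.1^2
```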
  {
    "path": "holomotion/src/env/velocity_tracking.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport torch\nimport time\nimport os\nimport yaml\nfrom collections import deque\nfrom functools import wraps\nfrom easydict import EasyDict\nimport random\nimport numpy as np\nfrom isaaclab.actuators import ImplicitActuatorCfg\nfrom isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg\nfrom isaaclab.sim import PhysxCfg, SimulationCfg\nfrom isaaclab.utils import configclass\nfrom isaaclab.utils.io import dump_yaml\nfrom loguru import logger\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.env.isaaclab_components import (\n    ActionsCfg,\n    VelTrack_CommandsCfg,\n    MoTrack_CommandsCfg,\n    EventsCfg,\n    MotionTrackingSceneCfg,\n    ObservationsCfg,\n    RewardsCfg,\n    TerminationsCfg,\n    CurriculumCfg,\n    build_actions_config,\n    build_motion_tracking_commands_config,\n    build_velocity_commands_config,\n    build_domain_rand_config,\n    build_curriculum_config,\n    build_observations_config,\n    build_rewards_config,\n    build_scene_config,\n    build_terminations_config,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_observation import (\n    ObservationFunctions,\n)\nfrom holomotion.src.env.isaaclab_components.isaaclab_utils import (\n    resolve_holo_config,\n)\nimport isaaclab.envs.mdp as isaaclab_mdp\nfrom isaaclab.envs.mdp.events import _randomize_prop_by_op\nfrom isaaclab.managers import SceneEntityCfg, EventTermCfg\nfrom isaaclab.utils import configclass\n\n\nfrom isaaclab.envs import ManagerBasedEnv\nfrom isaaclab.managers import EventTermCfg\nfrom isaaclab.managers import EventTermCfg as EventTerm\n\n\nimport isaaclab.utils.math as math_utils\nfrom isaaclab.assets import Articulation\nfrom isaaclab.envs.mdp.events import _randomize_prop_by_op\nfrom isaaclab.managers import SceneEntityCfg\nfrom typing import TYPE_CHECKING, Literal\n\n\nclass VelocityTrackingEnv:\n    \"\"\"IsaacLab-based Motion Tracking Environment.\n\n    This environment integrates motion tracking capabilities with IsaacLab's\n    manager-based architecture, supporting curriculum learning, domain randomization,\n    and various termination conditions.\n\n    This is a wrapper class that handles Isaac Sim initialization and delegates\n    to an internal ManagerBasedRLEnv instance.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        device: torch.device = None,\n        log_dir: str = None,\n        render_mode: str | None = None,\n        headless: bool = True,\n        accelerator=None,\n    ):\n        \"\"\"Initialize the Motion Tracking Environment.\n\n        Args:\n            config: Configuration for the environment\n            device: Device for tensor operations\n            log_dir: Logging directory\n            render_mode: Render mode for the environment\n            headless: Whether to run in headless mode\n            accelerator: Accelerator instance for 
distributed training (optional)\n        \"\"\"\n        self.config = config\n        self._device = device\n        self.accelerator = accelerator\n\n        self.log_dir = log_dir\n        self.headless = headless\n        self.init_done = False\n        self.is_evaluating = False\n        self.render_mode = render_mode\n\n        # self._init_motion_tracking_components()\n        self._init_isaaclab_env()\n        # self._init_serializers()\n        self._completion_total_queue = deque(maxlen=1000)\n        self._completion_success_queue = deque(maxlen=1000)\n        self.metrics = {}\n        self._robot_prev_joint_vel = None\n\n    @property\n    def num_envs(self):\n        return self._env.num_envs\n\n    @property\n    def device(self):\n        return self._env.device\n\n    def _init_isaaclab_env(self):\n        _device = self._device\n\n        # curriculum = CurriculumCfg()\n\n        # Determine per-process seed if provided; else create a deterministic per-rank default\n        seed_val = getattr(self.config, \"seed\", None)\n        if seed_val is None:\n            if self.accelerator is not None:\n                pid = self.accelerator.process_index\n            else:\n                pid = int(self.config.get(\"process_id\", 0))\n            seed_val = int(time.time()) + pid\n\n        _robot_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.robot, resolve=True)\n        )\n        _terrain_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.terrain, resolve=True)\n        )\n        _obs_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.obs, resolve=True)\n        )\n        _rewards_config_dict = EasyDict(\n            OmegaConf.to_container(self.config.rewards, resolve=True)\n        )\n        _domain_rand_config_dict = (\n            EasyDict(\n                OmegaConf.to_container(\n                    self.config.domain_rand,\n                    resolve=True,\n                )\n            )\n            if self.config.domain_rand is not None\n            else {}\n        )\n        _terminations_config_dict = (\n            EasyDict(\n                OmegaConf.to_container(\n                    self.config.terminations,\n                    resolve=True,\n                )\n            )\n            if self.config.terminations is not None\n            else {}\n        )\n        _scene_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.scene,\n                resolve=True,\n            )\n        )\n        _commands_config_dict = OmegaConf.to_container(\n            self.config.commands,\n            resolve=True,\n        )\n\n        # Headless + no rendering: disable base_velocity debug visualization.\n        # In k8s headless runs, IsaacSim/IsaacLab command debug_vis may wedge\n        # during/after simulation start (seen on velocity-tracking only).\n        # Keep an escape hatch for debugging/video.\n        allow_debug_vis = (not self.headless) or (self.render_mode is not None)\n        force_debug_vis = bool(\n            int(os.environ.get(\"HOLOMOTION_VELCMD_DEBUG_VIS\", \"0\"))\n        )\n        if (\n            (not allow_debug_vis)\n            and (not force_debug_vis)\n            and isinstance(_commands_config_dict, dict)\n            and (\"base_velocity\" in _commands_config_dict)\n        ):\n            bv = _commands_config_dict.get(\"base_velocity\", {})\n            bv_params = bv.get(\"params\", {})\n            if 
isinstance(bv_params, dict) and bool(\n                bv_params.get(\"debug_vis\", False)\n            ):\n                bv_params[\"debug_vis\"] = False\n                bv[\"params\"] = bv_params\n                _commands_config_dict[\"base_velocity\"] = bv\n                logger.warning(\n                    \"Disabled base_velocity debug_vis for headless non-render runs. \"\n                    \"Set HOLOMOTION_VELCMD_DEBUG_VIS=1 to force-enable.\"\n                )\n\n        _simulation_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.simulation,\n                resolve=True,\n            )\n        )\n        _actions_config_dict = EasyDict(\n            OmegaConf.to_container(\n                self.config.actions,\n                resolve=True,\n            )\n        )\n\n        @configclass\n        class VelocityTrackingEnvCfg(ManagerBasedRLEnvCfg):\n            seed: int = seed_val\n            scene_config_dict = {\n                \"num_envs\": self.config.num_envs,\n                \"env_spacing\": self.config.env_spacing,\n                \"replicate_physics\": self.config.replicate_physics,\n                \"robot\": _robot_config_dict,\n                \"terrain\": _terrain_config_dict,\n                \"domain_rand\": _domain_rand_config_dict,\n                \"lighting\": _scene_config_dict.lighting,\n                \"contact_sensor\": _scene_config_dict.contact_sensor,\n            }\n\n            decimation: int = _simulation_config_dict.control_decimation\n            episode_length_s: int = _simulation_config_dict.episode_length_s\n            sim_freq = _simulation_config_dict.sim_freq\n            dt = 1.0 / sim_freq\n            physx = PhysxCfg(\n                bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity,\n                gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count,\n                enable_stabilization=True,\n            )\n\n            if self.accelerator is not None:\n                main_process = self.accelerator.is_main_process\n                process_id = self.accelerator.process_index\n                num_processes = self.accelerator.num_processes\n            else:\n                main_process = self.config.get(\"main_process\", True)\n                process_id = self.config.get(\"process_id\", 0)\n                num_processes = self.config.get(\"num_processes\", 1)\n            scene: MotionTrackingSceneCfg = build_scene_config(\n                scene_config_dict,\n                main_process=main_process,\n                process_id=process_id,\n                num_processes=num_processes,\n            )\n\n            sim: SimulationCfg = SimulationCfg(\n                dt=dt,\n                render_interval=decimation,\n                physx=physx,\n                device=_device,\n                enable_scene_query_support=True,\n            )\n            sim.physics_material = scene.terrain.physics_material\n\n            viewer: ViewerCfg = ViewerCfg(origin_type=\"world\")\n\n            command_name = list(_commands_config_dict.keys())[0]\n            commands: VelTrack_CommandsCfg = build_velocity_commands_config(\n                _commands_config_dict\n            )\n            observations: ObservationsCfg = build_observations_config(\n                _obs_config_dict.obs_groups\n            )\n            rewards: RewardsCfg = build_rewards_config(_rewards_config_dict)\n\n            if 
_terminations_config_dict:\n                terminations: TerminationsCfg = build_terminations_config(\n                    _terminations_config_dict\n                )\n            else:\n                terminations: TerminationsCfg = TerminationsCfg()\n\n            if _domain_rand_config_dict:\n                events: EventsCfg = build_domain_rand_config(\n                    _domain_rand_config_dict\n                )\n            else:\n                events: EventsCfg = EventsCfg()\n\n            events.reset_base = EventTerm(\n                func=isaaclab_mdp.reset_root_state_uniform,\n                mode=\"reset\",\n                params={\n                    \"pose_range\": {\n                        \"x\": (-0.5, 0.5),\n                        \"y\": (-0.5, 0.5),\n                        \"yaw\": (-3.14, 3.14),\n                    },\n                    \"velocity_range\": {\n                        \"x\": (0.0, 0.0),\n                        \"y\": (0.0, 0.0),\n                        \"z\": (0.0, 0.0),\n                        \"roll\": (0.0, 0.0),\n                        \"pitch\": (0.0, 0.0),\n                        \"yaw\": (0.0, 0.0),\n                    },\n                },\n            )\n            events.reset_robot_joints = EventTerm(\n                func=isaaclab_mdp.reset_joints_by_scale,\n                mode=\"reset\",\n                params={\n                    \"position_range\": (1.0, 1.0),\n                    \"velocity_range\": (-1.0, 1.0),\n                },\n            )\n\n            # curriculum: CurriculumCfg = build_curriculum_config(\n            #     getattr(self.config, \"curriculum\", {})\n            # )\n\n            actions: ActionsCfg = build_actions_config(_actions_config_dict)\n            sim: SimulationCfg = SimulationCfg(\n                dt=dt,\n                render_interval=decimation,\n                physx=physx,\n                device=_device,\n                enable_scene_query_support=True,\n            )\n            sim.physx.gpu_max_rigid_patch_count = 10 * 2**15\n            sim.physx.enable_stabilization = True\n            sim.physics_material = scene.terrain.physics_material\n\n        isaaclab_env_cfg = VelocityTrackingEnvCfg()\n\n        isaaclab_envconfig_dump_path = os.path.join(\n            self.log_dir, \"isaaclab_env_cfg.yaml\"\n        )\n        dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg)\n\n        logger.info(\n            \"Constructing IsaacLab ManagerBasedRLEnv (velocity_tracking) ...\"\n        )\n        self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode)\n        logger.info(\n            \"IsaacLab ManagerBasedRLEnv constructed (velocity_tracking).\"\n        )\n\n        logger.info(\"IsaacLab environment initialized !\")\n        return self._env\n\n    def _init_motion_tracking_components(self):\n        self._init_serializers()\n\n    def step(self, actor_state: dict):\n        obs_dict, rewards, terminated, time_outs, infos = self._env.step(\n            actor_state\n        )\n        # IsaacLab separates terminated vs time_outs, combine them for consistency\n        dones = terminated | time_outs\n        self._update_completion_rate_stats(terminated, time_outs, infos)\n        return obs_dict, rewards, dones, time_outs, infos\n\n    def _update_completion_rate_stats(\n        self,\n        terminated: torch.Tensor,\n        time_outs: torch.Tensor,\n        infos: dict,\n    ) -> None:\n        \"\"\"Log completion rate over recent done 
batches.\n\n        Definition:\n        - Completed: time_outs==True and terminated==False.\n        - Failed: terminated==True.\n        The rolling window stores per-step done counts (only when any done occurs).\n        \"\"\"\n        done_mask = (terminated | time_outs).reshape(-1).bool()\n        if torch.any(done_mask):\n            done_count = int(done_mask.sum().item())\n            completed_mask = (\n                time_outs.reshape(-1).bool()\n                & ~terminated.reshape(-1).bool()\n                & done_mask\n            )\n            completed_count = int(completed_mask.sum().item())\n            self._completion_total_queue.append(done_count)\n            self._completion_success_queue.append(completed_count)\n\n        denom = sum(self._completion_total_queue)\n        completion_rate = (\n            float(sum(self._completion_success_queue)) / float(denom)\n            if denom > 0\n            else 0.0\n        )\n        if (\"log\" not in infos) or (not isinstance(infos[\"log\"], dict)):\n            infos[\"log\"] = {}\n        infos[\"log\"][\"Task/Completion_Rate\"] = torch.tensor(\n            completion_rate, device=self.device, dtype=torch.float32\n        )\n\n    def reset_idx(self, env_ids: torch.Tensor):\n        return self._env.reset(env_ids=env_ids)\n\n    def reset_all(self):\n        env_ids = torch.arange(self.num_envs, device=self.device)\n        out = self._env.reset(env_ids=env_ids)\n        return out\n\n    def set_is_evaluating(self):\n        logger.info(\"Setting environment to evaluation mode\")\n        self.is_evaluating = True\n\n    def seed(self, seed: int):\n        self._env.seed(seed)\n"
  },
  {
    "path": "holomotion/src/evaluation/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/evaluation/eval_motion_tracking.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport os\nimport argparse\nimport subprocess\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\nfrom loguru import logger\n\n\ndef find_checkpoints_to_evaluate(\n    eval_h5_dataset_path: str,\n    root_dir: Path,\n    target_checkpoints: Optional[List[str]],\n    config_name: str,\n) -> List[Tuple[str, str]]:\n    \"\"\"Scan all model subdirectories and collect checkpoints that need evaluation.\n\n    Behavior:\n        - If `target_checkpoints` is provided and non-empty:\n            only these checkpoint stems are considered (e.g. ['model_17500']).\n        - If `target_checkpoints` is None or empty:\n            all checkpoints matching 'model_*.pt' under each model directory will be considered.\n    Returns:\n        A list of (checkpoint_path, config_name) tuples to be evaluated.\n    \"\"\"\n    checkpoints_to_evaluate: List[Tuple[str, str]] = []\n    dataset_path = Path(eval_h5_dataset_path)\n    dataset_suffix = (\n        dataset_path.name if dataset_path.name else \"dataset_unknown\"\n    )\n\n    if root_dir.is_file():\n        checkpoint_file = root_dir\n\n        model_dir_path = checkpoint_file.parent\n        checkpoint_stem = checkpoint_file.stem\n\n        eval_out_dir = (\n            model_dir_path\n            / f\"isaaclab_eval_output_{checkpoint_stem}_{dataset_suffix}\"\n        )\n\n        cfg_name = f\"evaluation/{config_name}\"\n        return [(str(checkpoint_file), cfg_name)]\n\n    if not root_dir.is_dir():\n        logger.error(\n            f\"Checkpoint root directory '{root_dir}' does not exist or is not a directory.\"\n        )\n        return []\n\n    if target_checkpoints:\n        logger.info(\n            f\"Searching for explicit target checkpoints: {target_checkpoints}\"\n        )\n\n    # Iterate over each model directory directly under root_dir\n    for model_dir_path in root_dir.iterdir():\n        if not model_dir_path.is_dir():\n            continue\n\n        if target_checkpoints:\n            # Use only the requested checkpoint stems\n            candidate_files = [\n                model_dir_path / f\"{stem}.pt\" for stem in target_checkpoints\n            ]\n        else:\n            candidate_files = sorted(model_dir_path.glob(\"model_*.pt\"))\n\n        if not candidate_files:\n            continue\n\n        for checkpoint_file in candidate_files:\n            if not checkpoint_file.is_file():\n                logger.debug(f\"Target checkpoint not found: {checkpoint_file}\")\n                continue\n\n            checkpoint_stem = checkpoint_file.stem\n            eval_out_dir = (\n                model_dir_path\n                / f\"isaaclab_eval_output_{checkpoint_stem}_{dataset_suffix}\"\n            )\n            if eval_out_dir.is_dir():\n                logger.debug(\n                    f\"Skipping {checkpoint_file.name}, output exists.\"\n   
                continue\n\n            # Construct Hydra config name from the folder name\n            cfg_name = f\"evaluation/{config_name}\"\n            checkpoints_to_evaluate.append((str(checkpoint_file), cfg_name))\n\n    checkpoints_to_evaluate.sort(key=lambda x: x[0])\n    return checkpoints_to_evaluate\n\n\ndef main(\n    checkpoint_dir: str,\n    target_checkpoints: Optional[List[str]],\n    eval_h5_dataset_path: str,\n    config_name: str,\n    num_envs: str,\n) -> None:\n    \"\"\"Entry point for batch evaluation.\n\n    Args:\n        checkpoint_dir: Root directory containing model subdirectories, or a\n            single checkpoint file.\n        target_checkpoints: Optional list of checkpoint stems to evaluate.\n        eval_h5_dataset_path: Path to the evaluation H5 dataset.\n        config_name: Hydra config name used for each evaluation run.\n        num_envs: Number of environments, forwarded to the single-evaluation\n            script.\n    \"\"\"\n    root_path = Path(checkpoint_dir)\n\n    checkpoints_to_evaluate = find_checkpoints_to_evaluate(\n        eval_h5_dataset_path=eval_h5_dataset_path,\n        root_dir=root_path,\n        target_checkpoints=target_checkpoints,\n        config_name=config_name,\n    )\n\n    if not checkpoints_to_evaluate:\n        logger.warning(\n            f\"No pending evaluations found under '{checkpoint_dir}'.\"\n        )\n        return\n\n    logger.info(\n        f\"Found {len(checkpoints_to_evaluate)} checkpoints to evaluate.\"\n    )\n\n    for i, (ckpt_path, cfg_name) in enumerate(\n        checkpoints_to_evaluate, start=1\n    ):\n        logger.info(\n            f\"[{i}/{len(checkpoints_to_evaluate)}] Evaluating: {cfg_name}/{ckpt_path}\"\n        )\n\n        command = [\n            \"bash\",\n            \"holomotion/scripts/evaluation/eval_motion_tracking_single.sh\",\n            ckpt_path,\n            cfg_name,\n            eval_h5_dataset_path,\n            num_envs,\n        ]\n        subprocess.run(command)\n\n\ndef parse_args() -> argparse.Namespace:\n    \"\"\"Parse CLI arguments for the batch evaluation script.\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Batch motion-tracking evaluation.\"\n    )\n    parser.add_argument(\"--checkpoint_dir\", type=str, required=True)\n    parser.add_argument(\n        \"--target_checkpoints\", type=str, nargs=\"*\", default=None\n    )\n    parser.add_argument(\"--config_name\", type=str, required=True)\n    parser.add_argument(\"--eval_h5_dataset_path\", type=str, required=True)\n    parser.add_argument(\"--num_envs\", type=str, required=True)\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(\n        checkpoint_dir=args.checkpoint_dir,\n        target_checkpoints=args.target_checkpoints,\n        eval_h5_dataset_path=args.eval_h5_dataset_path,\n        config_name=args.config_name,\n        num_envs=args.num_envs,\n    )\n"
  },
  {
    "path": "holomotion/src/evaluation/eval_motion_tracking_single.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\nimport re\nfrom pathlib import Path\n\nimport hydra\nfrom hydra.utils import get_class\nfrom loguru import logger\nfrom omegaconf import ListConfig, OmegaConf\n\nfrom holomotion.src.evaluation.metrics import run_evaluation\nfrom holomotion.src.utils.config import compile_config\nfrom holomotion.src.utils.onnx_export import export_policy_to_onnx\n\n\ndef load_training_config(\n    checkpoint_path: str, eval_config: OmegaConf\n) -> OmegaConf:\n    \"\"\"Load training config from checkpoint directory.\n\n    Args:\n        checkpoint_path: Path to the checkpoint file.\n        eval_config: Full evaluation config (including command line overrides).\n\n    Returns:\n        Merged config with training config as base.\n    \"\"\"\n    checkpoint = Path(checkpoint_path)\n    config_path = checkpoint.parent / \"config.yaml\"\n\n    if not config_path.exists():\n        config_path = checkpoint.parent.parent / \"config.yaml\"\n        if not config_path.exists():\n            logger.warning(\n                f\"Training config not found at {config_path}, using evaluation config\"\n            )\n            return eval_config\n\n    logger.info(f\"Loading training config from {config_path}\")\n    with open(config_path) as file:\n        train_config = OmegaConf.load(file)\n\n    # Apply eval_overrides from training config if they exist\n    if train_config.get(\"eval_overrides\") is not None:\n        train_config = OmegaConf.merge(\n            train_config, train_config.eval_overrides\n        )\n\n    # Set checkpoint path\n    train_config.checkpoint = checkpoint_path\n    train_config.algo.config.checkpoint = checkpoint_path\n\n    # For evaluation, merge eval_config into train_config\n    config = OmegaConf.merge(train_config, eval_config)\n\n    # force set the terminations and domain rand with eval_config's\n    config.env.config.terminations = eval_config.env.config.terminations\n    config.env.config.domain_rand = eval_config.env.config.domain_rand\n    obs_groups = config.env.config.obs.obs_groups\n    if \"policy\" in obs_groups:\n        obs_groups.policy.enable_corruption = False\n    if \"critic\" in obs_groups:\n        obs_groups.critic.enable_corruption = False\n    if \"unified\" in obs_groups:\n        obs_groups.unified.enable_corruption = False\n\n    return config\n\n\ndef _infer_dataset_suffix(output_dir: str, checkpoint_path: str) -> str:\n    output_name = Path(output_dir).name\n    model_name = Path(checkpoint_path).stem\n    expected_prefix = f\"isaaclab_eval_output_{model_name}_\"\n    if output_name.startswith(expected_prefix):\n        return output_name[len(expected_prefix) :]\n    return output_name\n\n\ndef _checkpoint_sort_key(checkpoint_path: Path):\n    match = re.search(r\"model_(\\d+)\\.pt$\", checkpoint_path.name)\n    if match is not None:\n        return (0, int(match.group(1)), 
checkpoint_path.name)\n    return (1, checkpoint_path.name)\n\n\ndef _normalize_ckpt_pt_names(ckpt_pt_names) -> list[str]:\n    if ckpt_pt_names is None:\n        return []\n\n    if isinstance(ckpt_pt_names, ListConfig):\n        raw_names = list(ckpt_pt_names)\n    elif isinstance(ckpt_pt_names, (list, tuple)):\n        raw_names = list(ckpt_pt_names)\n    else:\n        raise TypeError(\n            f\"ckpt_pt_names must be a list/tuple, got {type(ckpt_pt_names)}\"\n        )\n\n    normalized_names = []\n    for name in raw_names:\n        name_str = str(name).strip()\n        if name_str == \"\":\n            continue\n        if not name_str.endswith(\".pt\"):\n            name_str = f\"{name_str}.pt\"\n        normalized_names.append(name_str)\n    return normalized_names\n\n\ndef _resolve_export_ckpt_paths(config: OmegaConf) -> list[Path]:\n    log_dir_value = config.get(\"log_dir\", None)\n    checkpoint_value = config.get(\"checkpoint\", None)\n\n    if log_dir_value is None or str(log_dir_value).strip() == \"\":\n        if checkpoint_value is None or str(checkpoint_value).strip() == \"\":\n            raise ValueError(\n                \"When export_only=true, set log_dir or checkpoint.\"\n            )\n        log_dir = Path(str(checkpoint_value)).parent\n    else:\n        log_dir = Path(str(log_dir_value))\n\n    if not log_dir.is_dir():\n        raise NotADirectoryError(\n            f\"log_dir does not exist or is not a directory: {log_dir}\"\n        )\n\n    ckpt_pt_names = _normalize_ckpt_pt_names(config.get(\"ckpt_pt_names\", None))\n    if len(ckpt_pt_names) > 0:\n        selected_paths = []\n        missing_names = []\n        for name in ckpt_pt_names:\n            ckpt_path = log_dir / name\n            if ckpt_path.is_file():\n                selected_paths.append(ckpt_path)\n            else:\n                missing_names.append(name)\n\n        if len(missing_names) > 0:\n            raise FileNotFoundError(\n                f\"Missing checkpoints in log_dir={log_dir}: {missing_names}\"\n            )\n        return selected_paths\n\n    discovered_paths = sorted(log_dir.glob(\"*.pt\"), key=_checkpoint_sort_key)\n    if len(discovered_paths) == 0:\n        raise FileNotFoundError(\n            f\"No .pt checkpoints found in log_dir={log_dir}\"\n        )\n    return discovered_paths\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"evaluation/eval_isaaclab\",\n    version_base=None,\n)\ndef main(config: OmegaConf):\n    \"\"\"Evaluate the motion tracking model.\n\n    Args:\n        config: OmegaConf object containing the evaluation configuration.\n\n    \"\"\"\n    export_only = bool(config.get(\"export_only\", False))\n    if export_only:\n        checkpoint_paths = _resolve_export_ckpt_paths(config)\n        config = load_training_config(str(checkpoint_paths[0]), config)\n    else:\n        if config.checkpoint is None:\n            raise ValueError(\"Checkpoint path must be provided for evaluation\")\n        checkpoint_paths = [Path(str(config.checkpoint))]\n        config = load_training_config(config.checkpoint, config)\n\n    # Compile config without accelerator (PPO will create it)\n    config = compile_config(config, accelerator=None)\n\n    # Use checkpoint directory as log_dir for offline evaluation/export.\n    log_dir = str(checkpoint_paths[0].parent)\n    headless = config.headless\n\n    # PPO creates Accelerator, AppLauncher, and environment internally\n    algo_class = get_class(config.algo._target_)\n    algo = 
algo_class(\n        env_config=config.env,\n        config=config.algo.config,\n        log_dir=log_dir,\n        headless=headless,\n        is_offline_eval=True,\n    )\n\n    if (\n        algo.accelerator.is_main_process\n        and os.environ.get(\"TORCH_COMPILE_DISABLE\", \"0\") != \"1\"\n    ):\n        logger.info(\n            \"Tip: If you encounter Triton/compilation errors during evaluation,\"\n        )\n        logger.info(\n            \"     set environment variable: export TORCH_COMPILE_DISABLE=1\"\n        )\n\n    if algo.accelerator.is_main_process:\n        with open(os.path.join(log_dir, \"eval_config.yaml\"), \"w\") as f:\n            OmegaConf.save(config, f)\n\n    if export_only:\n        if algo.accelerator.is_main_process:\n            logger.info(\n                \"Running export-only mode for \"\n                f\"{len(checkpoint_paths)} checkpoints in {log_dir}\"\n            )\n        onnx_name_suffix = config.get(\"onnx_name_suffix\", None)\n        use_kv_cache = config.get(\"use_kv_cache\", True)\n        for i, checkpoint_path in enumerate(checkpoint_paths, start=1):\n            ckpt_path = str(checkpoint_path)\n            if algo.accelerator.is_main_process:\n                logger.info(\n                    f\"[{i}/{len(checkpoint_paths)}] Loading checkpoint: \"\n                    f\"{ckpt_path}\"\n                )\n            algo.load(ckpt_path)\n            if algo.accelerator.is_main_process:\n                onnx_path = export_policy_to_onnx(\n                    algo,\n                    ckpt_path,\n                    onnx_name_suffix=onnx_name_suffix,\n                    use_kv_cache=use_kv_cache,\n                )\n                logger.info(f\"Successfully exported policy to: {onnx_path}\")\n            algo.accelerator.wait_for_everyone()\n        if algo.accelerator.is_main_process:\n            logger.info(\"Export-only mode completed successfully!\")\n        return\n\n    if algo.accelerator.is_main_process:\n        logger.info(f\"Loading checkpoint for evaluation: {config.checkpoint}\")\n    algo.load(config.checkpoint)\n\n    command_name = list(config.env.config.commands.keys())[0]\n    if command_name == \"ref_motion\":\n        motion_cmd = algo.env._env.command_manager.get_term(\"ref_motion\")\n        algo.env._env.reset()\n        motion_cmd._update_ref_motion_state()\n\n    # Export ONNX if requested\n    if config.get(\"export_policy\", True):\n        if algo.accelerator.is_main_process:\n            onnx_name_suffix = config.get(\"onnx_name_suffix\", None)\n            onnx_path = export_policy_to_onnx(\n                algo,\n                config.checkpoint,\n                onnx_name_suffix=onnx_name_suffix,\n                use_kv_cache=config.get(\"use_kv_cache\", True),\n            )\n            logger.info(f\"Successfully exported policy to: {onnx_path}\")\n        algo.accelerator.wait_for_everyone()\n\n    calc_per_clip_metrics = bool(config.get(\"calc_per_clip_metrics\", False))\n    generate_report = bool(config.get(\"generate_report\", False))\n    dump_npzs = bool(config.get(\"dump_npzs\", False)) or calc_per_clip_metrics\n    dof_mode = config.get(\"dof_mode\", \"29\")\n    if (\n        calc_per_clip_metrics\n        and not bool(config.get(\"dump_npzs\", False))\n        and algo.accelerator.is_main_process\n    ):\n        logger.info(\n            \"calc_per_clip_metrics=true requires dumped NPZs; \"\n            \"enabling dump_npzs automatically.\"\n        )\n\n    result = 
algo.offline_evaluate_policy(dump_npzs)\n    algo.accelerator.wait_for_everyone()\n\n    if algo.accelerator.is_main_process:\n        logger.info(\"Evaluation completed successfully!\")\n        output_dir = (\n            result.get(\"output_dir\") if isinstance(result, dict) else None\n        )\n        if output_dir is not None:\n            logger.info(f\"NPZs saved to: {output_dir}\")\n\n        if calc_per_clip_metrics:\n            if output_dir is None:\n                logger.warning(\n                    \"Skipping per-clip metric calculation because \"\n                    \"output_dir is unavailable.\"\n                )\n            else:\n                dataset_suffix = _infer_dataset_suffix(\n                    output_dir, config.checkpoint\n                )\n                run_evaluation(\n                    npz_dir=output_dir,\n                    dataset_suffix=dataset_suffix,\n                    failure_pos_err_thresh_m=0.25,\n                    dof_mode=dof_mode,\n                )\n                logger.info(\n                    f\"Finished per-clip metric calculation for: {output_dir}\"\n                )\n\n        if generate_report:\n            if output_dir is None:\n                logger.warning(\n                    \"Skipping report generation because output_dir is unavailable.\"\n                )\n            else:\n                from holomotion.scripts.evaluation import (\n                    mean_process_5metrics,\n                )\n\n                report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir(\n                    output_dir\n                )\n                logger.info(f\"Generated metrics report at: {report_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/evaluation/eval_mujoco_sim2sim.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport os\nimport csv\nimport shutil\nimport sys\nimport threading\nimport time\nfrom collections import deque\nfrom pathlib import Path\nfrom threading import Thread\n\nimport cv2\nimport hydra\nimport mujoco\nimport mujoco.viewer\nimport numpy as np\nimport onnx\nimport onnxruntime\nimport torch\nfrom loguru import logger\nfrom omegaconf import ListConfig, OmegaConf, open_dict\nfrom tqdm import tqdm\nimport glob\nimport re\n\nimport ray\nfrom holomotion.src.evaluation.metrics import run_evaluation\n\ntry:\n    from horizon_tc_ui.hb_runtime import HBRuntime\nexcept ImportError:\n    HB_ONNXRuntime = None\n    logger.warning(\"HB_ONNXRuntime not available!\")\n\nONNX_IO_DUMP_DIRNAME = \"onnx_io_npy\"\n\ntry:\n    import pynput.keyboard as pynput_kb\n\n    PYNPUT_AVAILABLE = True\nexcept ImportError:\n    PYNPUT_AVAILABLE = False\n    if \"headless\" in sys.argv and \"false\" in sys.argv:\n        logger.warning(\"pynput not available, keyboard control disabled\")\n\nfrom holomotion.src.evaluation.obs import PolicyObsBuilder\nfrom holomotion.src.utils.torch_utils import (\n    quat_apply,\n    quat_inv,\n    subtract_frame_transforms,\n    quat_normalize_wxyz,\n    matrix_from_quat,\n    xyzw_to_wxyz,\n    quat_mul,\n    quat_from_euler_xyz,\n)\nfrom holomotion.src.motion_retargeting.utils.rotation_conversions import (\n    standardize_quaternion,\n)\n\nDEFAULT_FEET_GEOM_NAMES = {\n    \"left\": [\"left_foot\"],\n    \"right\": [\"right_foot\"],\n}\nDEFAULT_FEET_BODY_NAMES = {\n    \"left\": [\"left_ankle_roll_link\"],\n    \"right\": [\"right_ankle_roll_link\"],\n}\n\n\ndef _coerce_config_bool(value, default: bool = False) -> bool:\n    \"\"\"Interpret config booleans without treating non-empty strings as truthy.\"\"\"\n    if value is None:\n        return default\n    if isinstance(value, (bool, np.bool_)):\n        return bool(value)\n    if isinstance(value, str):\n        value = value.strip().lower()\n        if value in {\"1\", \"true\", \"yes\", \"y\", \"on\"}:\n            return True\n        if value in {\"0\", \"false\", \"no\", \"n\", \"off\", \"\"}:\n            return False\n    return bool(value)\n\n\nclass OffscreenRenderer:\n    \"\"\"Minimal offscreen renderer for MuJoCo frames.\"\"\"\n\n    def __init__(\n        self,\n        model,\n        height: int,\n        width: int,\n        distance: float | None = None,\n        azimuth: float | None = None,\n        elevation: float | None = None,\n    ):\n        self.model = model\n        self.height = height\n        self.width = width\n\n        self._overlay_callback = None\n\n        self._gl_ctx = mujoco.GLContext(width, height)\n        self._gl_ctx.make_current()\n\n        self._scene = mujoco.MjvScene(model, maxgeom=1000)\n        self._cam = mujoco.MjvCamera()\n        self._opt = mujoco.MjvOption()\n        mujoco.mjv_defaultFreeCamera(model, 
self._cam)\n        self.set_align_view(\n            distance=distance,\n            azimuth=azimuth,\n            elevation=elevation,\n        )\n\n        self._con = mujoco.MjrContext(\n            model,\n            mujoco.mjtFontScale.mjFONTSCALE_100,\n        )\n        self._rgb = np.zeros((height, width, 3), dtype=np.uint8)\n        self._viewport = mujoco.MjrRect(0, 0, width, height)\n\n    def set_overlay_callback(self, callback) -> None:\n        \"\"\"Register a callback to draw custom geoms into the scene each frame.\"\"\"\n        self._overlay_callback = callback\n\n    def render(self, data) -> np.ndarray:\n        mujoco.mjv_updateScene(\n            self.model,\n            data,\n            self._opt,\n            None,\n            self._cam,\n            mujoco.mjtCatBit.mjCAT_ALL.value,\n            self._scene,\n        )\n        if self._overlay_callback is not None:\n            self._overlay_callback(self._scene)\n        mujoco.mjr_render(self._viewport, self._scene, self._con)\n        mujoco.mjr_readPixels(self._rgb, None, self._viewport, self._con)\n        return np.flipud(self._rgb)\n\n    def set_align_view(\n        self,\n        lookat: np.ndarray | None = None,\n        distance: float | None = None,\n        azimuth: float | None = None,\n        elevation: float | None = None,\n    ):\n        \"\"\"Set camera to 'align' preset view (default azimuth=60, elevation=-20).\n\n        Args:\n            lookat: Optional lookat point [x, y, z]. If None, uses current lookat.\n            distance: Optional camera distance from lookat point. If None, uses current distance.\n        \"\"\"\n        self._cam.type = mujoco.mjtCamera.mjCAMERA_FREE\n        if azimuth is None:\n            self._cam.azimuth = 60.0  # Side view (looking along Y-axis)\n        else:\n            self._cam.azimuth = float(azimuth)\n        if elevation is None:\n            self._cam.elevation = -20.0  # Slight downward angle\n        else:\n            self._cam.elevation = float(elevation)\n        if lookat is not None:\n            self._cam.lookat = np.asarray(lookat, dtype=np.float32)\n        if distance is not None:\n            self._cam.distance = float(distance)\n\n    def close(self):\n        self._gl_ctx.free()\n\n\nclass VelocityKeyboardHandler:\n    \"\"\"Keyboard handler for interactive velocity commands using WASD and JL keys.\"\"\"\n\n    def __init__(\n        self,\n        vx_increment: float = 0.1,\n        vy_increment: float = 0.1,\n        vyaw_increment: float = 0.05,\n        vx_limits: tuple = (-0.5, 1.0),\n        vy_limits: tuple = (-0.3, 0.3),\n        vyaw_limits: tuple = (-0.5, 0.5),\n    ):\n        self.vx_increment = vx_increment\n        self.vy_increment = vy_increment\n        self.vyaw_increment = vyaw_increment\n\n        # Velocity limits from training config\n        self.vx_min, self.vx_max = vx_limits\n        self.vy_min, self.vy_max = vy_limits\n        self.vyaw_min, self.vyaw_max = vyaw_limits\n\n        self.vx = 0.0\n        self.vy = 0.0\n        self.vyaw = 0.0\n\n        self._listener = None\n        self._lock = threading.Lock()\n\n    def start_listener(self):\n        \"\"\"Start keyboard listener thread (requires pynput).\"\"\"\n        if not PYNPUT_AVAILABLE:\n            logger.warning(\"pynput not available, keyboard control disabled\")\n            return\n\n        def on_press(key):\n            try:\n                if hasattr(key, \"char\") and key.char:\n                    self._handle_key(key.char)\n        
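# Special (non-character) keys lack a usable 'char'; they are ignored here.\n        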
    except AttributeError:\n                pass\n\n        self._listener = pynput_kb.Listener(on_press=on_press)\n        self._listener.start()\n        logger.info(\n            f\"Keyboard listener started. Velocity limits: \"\n            f\"vx=[{self.vx_min:.1f},{self.vx_max:.1f}], \"\n            f\"vy=[{self.vy_min:.1f},{self.vy_max:.1f}], \"\n            f\"vyaw=[{self.vyaw_min:.1f},{self.vyaw_max:.1f}]\"\n        )\n\n    def stop_listener(self):\n        \"\"\"Stop keyboard listener thread.\"\"\"\n        if self._listener is not None:\n            self._listener.stop()\n            self._listener = None\n\n    def get_velocity_command(self) -> np.ndarray:\n        \"\"\"Get velocity command [vx, vy, vyaw].\n\n        Returns:\n            Velocity command [vx, vy, vyaw]\n        \"\"\"\n        with self._lock:\n            return np.array([self.vx, self.vy, self.vyaw], dtype=np.float32)\n\n    def _handle_key(self, char: str):\n        \"\"\"Handle keyboard press events.\"\"\"\n        with self._lock:\n            # W/S for vx (forward/backward)\n            if char in [\"W\", \"w\"]:\n                self.vx = np.clip(\n                    self.vx + self.vx_increment, self.vx_min, self.vx_max\n                )\n                logger.info(\n                    f\"[W] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            elif char in [\"S\", \"s\"]:\n                self.vx = np.clip(\n                    self.vx - self.vx_increment, self.vx_min, self.vx_max\n                )\n                logger.info(\n                    f\"[S] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            # A/D for vy (left/right)\n            elif char in [\"A\", \"a\"]:\n                self.vy = np.clip(\n                    self.vy + self.vy_increment, self.vy_min, self.vy_max\n                )\n                logger.info(\n                    f\"[A] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            elif char in [\"D\", \"d\"]:\n                self.vy = np.clip(\n                    self.vy - self.vy_increment, self.vy_min, self.vy_max\n                )\n                logger.info(\n                    f\"[D] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            # J/L for vyaw (turn left/right)\n            elif char in [\"J\", \"j\"]:\n                self.vyaw = np.clip(\n                    self.vyaw + self.vyaw_increment,\n                    self.vyaw_min,\n                    self.vyaw_max,\n                )\n                logger.info(\n                    f\"[J] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            elif char in [\"L\", \"l\"]:\n                self.vyaw = np.clip(\n                    self.vyaw - self.vyaw_increment,\n                    self.vyaw_min,\n                    self.vyaw_max,\n                )\n                logger.info(\n                    f\"[L] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}\"\n                )\n            # Space to reset all\n            elif char == \" \":\n                self.vx = 0.0\n                self.vy = 0.0\n                self.vyaw = 0.0\n                logger.info(\"[Space] Command reset to zero\")\n            # X to stop (emergency brake)\n            elif char in [\"X\", \"x\"]:\n                self.vx = 0.0\n                self.vy = 0.0\n                self.vyaw = 0.0\n                logger.info(\"[X] 
Emergency stop - all velocities set to zero\")\n\n\nclass MujocoEvaluator:\n    \"\"\"Class to handle MuJoCo simulation for policy evaluation.\"\"\"\n\n    def __init__(self, config):\n        \"\"\"Initialize the MuJoCo evaluator.\n\n        Args:\n            config: Configuration object with simulation parameters.\n        \"\"\"\n        self.config = config\n\n        # Initialize variables\n        self.policy_session = None\n        self.motion_encoding = None\n        self.m = None  # MuJoCo model\n        self.d = None  # MuJoCo data\n\n        # Determine command mode from config\n        self.command_mode = self._detect_command_mode()\n        if \"motion_npz_dir\" not in config:\n            logger.info(f\"Command mode: {self.command_mode}\")\n\n        # Motion data\n        self.ref_dof_pos = None\n        self.ref_dof_vel = None\n        self.filter_cutoff_hz = None\n        self.n_motion_frames = 0\n        self.motion_frame_idx = 0\n\n        # Velocity command (for velocity tracking mode)\n        self.velocity_command = np.zeros(3, dtype=np.float32)  # [vx, vy, vyaw]\n        self.target_heading = 0.0  # Target heading for velocity tracking\n        self.keyboard_handler = (\n            None  # Will be initialized if velocity_tracking\n        )\n\n        # Extract configuration parameters\n        self.simulation_dt = 1 / 200\n        self.policy_dt = 1 / 50\n        self.control_decimation = 4\n        self.dof_names_ref_motion = list(config.robot.dof_names)\n        self.num_actions = len(self.dof_names_ref_motion)\n\n        self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32)\n\n        self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.target_dof_pos_onnx = self.default_angles_onnx.copy()\n        self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)\n        self.n_fut_frames = int(config.obs.n_fut_frames)\n        self.actor_place_holder_ndim = self._find_actor_place_holder_ndim()\n\n        self.use_kv_cache = False\n        self.policy_kv_cache = None\n        self.policy_kv_input_name = None\n        self.policy_kv_output_name = None\n        self.policy_kv_shape = None\n        self.policy_model_context_len = 0\n        algo_cfg = self.config.get(\"algo\", None)\n        if algo_cfg is None:\n            raise ValueError(\"Missing config.algo for MuJoCo evaluation.\")\n        algo_config = algo_cfg.get(\"config\", None)\n        if algo_config is None:\n            raise ValueError(\n                \"Missing config.algo.config for MuJoCo evaluation.\"\n            )\n        max_context_len_cfg = algo_config.get(\"num_steps_per_env\", None)\n        if max_context_len_cfg is None:\n            raise ValueError(\n                \"Missing config.algo.config.num_steps_per_env for MuJoCo evaluation.\"\n            )\n        self.max_context_len = int(max_context_len_cfg)\n        if self.max_context_len <= 0:\n            raise ValueError(\n                \"config.algo.config.num_steps_per_env must be > 0, \"\n                f\"got {self.max_context_len}\"\n            )\n        self.policy_effective_context_len = 0\n\n        self.counter = 0\n        self.tau_hist = []\n        # Latest Unitree lowstate message (populated when using Unitree bridge)\n        # self._lowstate_msg = None\n        # Desired target positions keyed by DOF name (updated after 
each policy step)\n        self.target_dof_pos_by_name = {}\n\n        # Video/recording related\n        self._video_writer = None\n        self._offscreen = None\n        self._frame_interval = None\n        self._last_frame_time = 0.0\n        # Reference(global)->Simulation(global) rigid transform (computed at init)\n        self._ref_to_sim_ready = False\n        self._ref_to_sim_q_wxyz = np.array(\n            [1.0, 0.0, 0.0, 0.0], dtype=np.float32\n        )\n        self._ref_to_sim_t = np.zeros(3, dtype=np.float32)\n        # Optional offset between reference globals and dataset body names (e.g., world body at index 0)\n        # Robot state recording buffers for offline NPZ dumping\n        self._robot_dof_pos_seq: list[np.ndarray] = []\n        self._robot_dof_vel_seq: list[np.ndarray] = []\n        self._robot_dof_acc_seq: list[np.ndarray] = []\n        self._robot_dof_torque_seq: list[np.ndarray] = []\n        self._robot_low_level_dof_torque_seq: list[np.ndarray] = []\n        self._robot_low_level_foot_contact_seq: list[np.ndarray] = []\n        self._robot_low_level_foot_normal_force_seq: list[np.ndarray] = []\n        self._robot_low_level_foot_tangent_speed_seq: list[np.ndarray] = []\n        self._robot_actions_seq: list[np.ndarray] = []\n        self._robot_action_rate_seq: list[np.float32] = []\n        self._robot_global_translation_seq: list[np.ndarray] = []\n        self._robot_global_rotation_quat_seq: list[np.ndarray] = []\n        self._robot_global_velocity_seq: list[np.ndarray] = []\n        self._robot_global_angular_velocity_seq: list[np.ndarray] = []\n        self._robot_moe_expert_indices_seq: list[np.ndarray] = []\n        self._robot_moe_expert_logits_seq: list[np.ndarray] = []\n        self._prev_recorded_dof_vel_ref: np.ndarray | None = None\n        self._prev_actions_onnx: np.ndarray | None = None\n        (\n            self.action_ema_filter_enabled,\n            self.action_ema_filter_alpha,\n        ) = self._get_action_ema_filter_cfg()\n        self._filtered_actions_onnx: np.ndarray | None = None\n        (\n            self.policy_action_delay_step,\n            self.action_delay_type,\n        ) = self._get_action_delay_cfg()\n        self._policy_action_delay_buffer: deque[np.ndarray] = deque(\n            maxlen=max(1, self.policy_action_delay_step + 1)\n        )\n        self._current_policy_action_delay_step = 0\n        self._reset_action_delay_randomization()\n        # Camera config (viewer + offscreen)\n        self._camera_tracking_enabled = bool(\n            self.config.get(\"camera_tracking\", True)\n        )\n        self._camera_height_offset = float(\n            self.config.get(\"camera_height_offset\", 0.3)\n        )\n        self._camera_distance = float(self.config.get(\"camera_distance\", 4.0))\n        self._camera_azimuth = float(self.config.get(\"camera_azimuth\", 60.0))\n        self._camera_elevation = float(\n            self.config.get(\"camera_elevation\", -20.0)\n        )\n        self._root_body_id = -1\n        self._foot_contact_logging_enabled = False\n        self._foot_geom_id_groups: list[list[int]] = [[], []]\n        self._foot_geom_id_to_side: dict[int, int] = {}\n        self._prev_low_level_foot_geom_centers: np.ndarray | None = None\n        self.dump_onnx_io_npy = bool(\n            self.config.get(\"dump_onnx_io_npy\", False)\n        )\n        self.policy_moe_layer_output_names: list[tuple[int, str, str]] = []\n        self._reset_onnx_io_dump_buffers()\n\n    def 
_reset_onnx_io_dump_buffers(self):\n        self._onnx_io_input_names: list[str] = []\n        self._onnx_io_output_names: list[str] = []\n        self._onnx_io_inputs: dict[str, list[np.ndarray]] = {}\n        self._onnx_io_outputs: dict[str, list[np.ndarray]] = {}\n\n    def _get_action_ema_filter_cfg(self) -> tuple[bool, float]:\n        actuator_cfg = self.config.get(\"robot\", {}).get(\"actuators\", {})\n        actuator_type = actuator_cfg.get(\"actuator_type\", \"unitree\")\n        if actuator_type != \"unitree_erfi\":\n            return False, 1.0\n\n        enabled = _coerce_config_bool(\n            actuator_cfg.get(\"ema_filter_enabled\", False), default=False\n        )\n        alpha = float(actuator_cfg.get(\"ema_filter_alpha\", 1.0))\n        if not 0.0 <= alpha <= 1.0:\n            raise ValueError(\n                \"robot.actuators.ema_filter_alpha must be within [0, 1], \"\n                f\"got {alpha}.\"\n            )\n        return enabled, alpha\n\n    def _reset_action_ema_filter(self) -> None:\n        self._filtered_actions_onnx = None\n\n    def _apply_action_ema_filter(self, raw_actions: np.ndarray) -> np.ndarray:\n        raw_actions = np.asarray(raw_actions, dtype=np.float32)\n        if not self.action_ema_filter_enabled:\n            return raw_actions.copy()\n\n        if self._filtered_actions_onnx is None:\n            self._filtered_actions_onnx = raw_actions.copy()\n            return self._filtered_actions_onnx.copy()\n\n        # self.action_ema_filter_alpha = 0.7\n        filtered_actions = (\n            self.action_ema_filter_alpha * raw_actions\n            + (1.0 - self.action_ema_filter_alpha)\n            * self._filtered_actions_onnx\n        ).astype(np.float32, copy=False)\n        self._filtered_actions_onnx = filtered_actions.copy()\n        return self._filtered_actions_onnx.copy()\n\n    def _get_action_delay_cfg(self) -> tuple[int, str]:\n        max_delay_step = int(self.config.get(\"policy_action_delay_step\", 0))\n        if max_delay_step < 0:\n            raise ValueError(\n                \"policy_action_delay_step must be non-negative, \"\n                f\"got {max_delay_step}.\"\n            )\n\n        delay_type = (\n            str(self.config.get(\"action_delay_type\", \"episode\"))\n            .strip()\n            .lower()\n        )\n        if delay_type not in {\"step\", \"episode\"}:\n            raise ValueError(\n                \"action_delay_type must be one of {'step', 'episode'}, \"\n                f\"got {delay_type!r}.\"\n            )\n        return max_delay_step, delay_type\n\n    def _sample_policy_action_delay_step(self) -> int:\n        if self.policy_action_delay_step <= 0:\n            return 0\n        return int(np.random.randint(0, self.policy_action_delay_step + 1))\n\n    def _reset_action_delay_randomization(self) -> None:\n        self._policy_action_delay_buffer = deque(\n            maxlen=max(1, self.policy_action_delay_step + 1)\n        )\n        if self.policy_action_delay_step <= 0:\n            self._current_policy_action_delay_step = 0\n            return\n        if self.action_delay_type == \"episode\":\n            self._current_policy_action_delay_step = (\n                self._sample_policy_action_delay_step()\n            )\n        else:\n            self._current_policy_action_delay_step = 0\n\n    def _apply_action_delay(self, raw_actions: np.ndarray) -> np.ndarray:\n        raw_actions = np.asarray(raw_actions, dtype=np.float32)\n        if 
self.policy_action_delay_step <= 0:\n            return raw_actions.copy()\n\n        expected_buffer_len = max(1, self.policy_action_delay_step + 1)\n        if (\n            not hasattr(self, \"_policy_action_delay_buffer\")\n            or self._policy_action_delay_buffer.maxlen != expected_buffer_len\n        ):\n            self._reset_action_delay_randomization()\n\n        if self.action_delay_type == \"step\":\n            self._current_policy_action_delay_step = (\n                self._sample_policy_action_delay_step()\n            )\n\n        self._policy_action_delay_buffer.append(raw_actions.copy())\n        if self._current_policy_action_delay_step >= len(\n            self._policy_action_delay_buffer\n        ):\n            return self._policy_action_delay_buffer[-1].copy()\n\n        return self._policy_action_delay_buffer[\n            -1 - self._current_policy_action_delay_step\n        ].copy()\n\n    @staticmethod\n    def _normalize_foot_geom_name_groups(raw_spec) -> list[list[str]]:\n        if raw_spec is None:\n            return [[], []]\n\n        if OmegaConf.is_config(raw_spec):\n            raw_spec = OmegaConf.to_container(raw_spec, resolve=True)\n\n        def coerce_names(value) -> list[str]:\n            if value is None:\n                return []\n            if isinstance(value, str):\n                return [value]\n            if isinstance(value, (list, tuple)):\n                return [str(name) for name in value if str(name)]\n            return []\n\n        if isinstance(raw_spec, dict):\n            return [\n                coerce_names(raw_spec.get(\"left\", raw_spec.get(\"left_foot\"))),\n                coerce_names(\n                    raw_spec.get(\"right\", raw_spec.get(\"right_foot\"))\n                ),\n            ]\n\n        if isinstance(raw_spec, (list, tuple)) and len(raw_spec) == 2:\n            return [coerce_names(raw_spec[0]), coerce_names(raw_spec[1])]\n\n        logger.warning(\n            \"Unsupported robot.feet_geom_names format. Ignoring configured \"\n            \"foot geom names.\"\n        )\n        return [[], []]\n\n    @staticmethod\n    def _normalize_foot_body_name_groups(raw_spec) -> list[list[str]]:\n        if raw_spec is None:\n            return [\n                list(DEFAULT_FEET_BODY_NAMES[\"left\"]),\n                list(DEFAULT_FEET_BODY_NAMES[\"right\"]),\n            ]\n\n        if OmegaConf.is_config(raw_spec):\n            raw_spec = OmegaConf.to_container(raw_spec, resolve=True)\n\n        def coerce_names(value) -> list[str]:\n            if value is None:\n                return []\n            if isinstance(value, str):\n                return [value]\n            if isinstance(value, (list, tuple)):\n                return [str(name) for name in value if str(name)]\n            return []\n\n        if isinstance(raw_spec, dict):\n            return [\n                coerce_names(raw_spec.get(\"left\", raw_spec.get(\"left_foot\"))),\n                coerce_names(\n                    raw_spec.get(\"right\", raw_spec.get(\"right_foot\"))\n                ),\n            ]\n\n        if isinstance(raw_spec, (list, tuple)) and len(raw_spec) == 2:\n            return [coerce_names(raw_spec[0]), coerce_names(raw_spec[1])]\n\n        logger.warning(\n            \"Unsupported robot.feet_body_names format. 
Falling back to \"\n            f\"default foot bodies: {DEFAULT_FEET_BODY_NAMES}\"\n        )\n        return [\n            list(DEFAULT_FEET_BODY_NAMES[\"left\"]),\n            list(DEFAULT_FEET_BODY_NAMES[\"right\"]),\n        ]\n\n    def _resolve_foot_geom_ids_from_geom_names(\n        self, foot_geom_name_groups: list[list[str]]\n    ) -> list[list[int]]:\n        foot_geom_id_groups: list[list[int]] = [[], []]\n        for side_idx, geom_names in enumerate(foot_geom_name_groups):\n            for geom_name in geom_names:\n                geom_id = mujoco.mj_name2id(\n                    self.m, mujoco.mjtObj.mjOBJ_GEOM, geom_name\n                )\n                if geom_id == -1:\n                    logger.warning(\n                        f\"Foot geom '{geom_name}' was not found in the MuJoCo model.\"\n                    )\n                    continue\n                foot_geom_id_groups[side_idx].append(int(geom_id))\n        return foot_geom_id_groups\n\n    def _resolve_foot_geom_ids_from_body_names(\n        self, foot_body_name_groups: list[list[str]]\n    ) -> list[list[int]]:\n        foot_geom_id_groups: list[list[int]] = [[], []]\n        geom_bodyid = np.asarray(self.m.geom_bodyid, dtype=np.int32)\n        geom_contype = np.asarray(self.m.geom_contype, dtype=np.int32)\n        geom_conaffinity = np.asarray(self.m.geom_conaffinity, dtype=np.int32)\n        collidable_mask = (geom_contype != 0) | (geom_conaffinity != 0)\n\n        for side_idx, body_names in enumerate(foot_body_name_groups):\n            resolved_geom_ids: list[int] = []\n            for body_name in body_names:\n                body_id = mujoco.mj_name2id(\n                    self.m, mujoco.mjtObj.mjOBJ_BODY, body_name\n                )\n                if body_id == -1:\n                    logger.warning(\n                        f\"Foot body '{body_name}' was not found in the MuJoCo model.\"\n                    )\n                    continue\n                body_geom_ids = np.flatnonzero(geom_bodyid == int(body_id))\n                if body_geom_ids.size == 0:\n                    logger.warning(\n                        f\"Foot body '{body_name}' has no attached geoms.\"\n                    )\n                    continue\n                contact_geom_ids = body_geom_ids[\n                    collidable_mask[body_geom_ids]\n                ]\n                if contact_geom_ids.size == 0:\n                    contact_geom_ids = body_geom_ids\n                resolved_geom_ids.extend(contact_geom_ids.astype(int).tolist())\n\n            # Preserve order while removing duplicates.\n            deduped = list(dict.fromkeys(resolved_geom_ids))\n            foot_geom_id_groups[side_idx] = deduped\n        return foot_geom_id_groups\n\n    def _init_low_level_foot_contact_logging(self) -> None:\n        self._foot_geom_id_groups = [[], []]\n        self._foot_geom_id_to_side = {}\n        self._foot_contact_logging_enabled = False\n        self._prev_low_level_foot_geom_centers = None\n\n        foot_geom_name_groups = self._normalize_foot_geom_name_groups(\n            getattr(self.config.robot, \"feet_geom_names\", None)\n        )\n        foot_body_name_groups = self._normalize_foot_body_name_groups(\n            getattr(self.config.robot, \"feet_body_names\", None)\n        )\n        geom_name_groups = self._resolve_foot_geom_ids_from_geom_names(\n            foot_geom_name_groups\n        )\n        body_name_groups = self._resolve_foot_geom_ids_from_body_names(\n            
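# body-derived geoms serve as the fallback when no geom names resolve\n            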
foot_body_name_groups\n        )\n\n        for side_idx in range(2):\n            resolved_ids = (\n                geom_name_groups[side_idx]\n                if len(geom_name_groups[side_idx]) > 0\n                else body_name_groups[side_idx]\n            )\n            self._foot_geom_id_groups[side_idx] = list(resolved_ids)\n            for geom_id in resolved_ids:\n                self._foot_geom_id_to_side[int(geom_id)] = side_idx\n\n        if any(len(group) == 0 for group in self._foot_geom_id_groups):\n            logger.warning(\n                \"Low-level foot contact logging is unavailable because one or \"\n                \"both foot geom groups could not be resolved. Contact metrics \"\n                \"will be written as NaN.\"\n            )\n            return\n\n        self._foot_contact_logging_enabled = True\n\n    def _record_low_level_foot_contact_sample(self) -> None:\n        foot_contact = np.full((2,), np.nan, dtype=np.float32)\n        foot_normal_force = np.full((2,), np.nan, dtype=np.float32)\n        foot_tangent_speed = np.full((2,), np.nan, dtype=np.float32)\n\n        if not self._foot_contact_logging_enabled:\n            self._robot_low_level_foot_contact_seq.append(foot_contact)\n            self._robot_low_level_foot_normal_force_seq.append(\n                foot_normal_force\n            )\n            self._robot_low_level_foot_tangent_speed_seq.append(\n                foot_tangent_speed\n            )\n            return\n\n        current_centers = np.zeros((2, 3), dtype=np.float32)\n        for side_idx, geom_ids in enumerate(self._foot_geom_id_groups):\n            current_centers[side_idx] = np.mean(\n                self.d.geom_xpos[np.asarray(geom_ids, dtype=np.int32)],\n                axis=0,\n            ).astype(np.float32)\n\n        if self._prev_low_level_foot_geom_centers is None:\n            tangential_speed = np.zeros((2,), dtype=np.float32)\n        else:\n            foot_velocity = (\n                current_centers - self._prev_low_level_foot_geom_centers\n            ) / np.float32(self.simulation_dt)\n            tangential_speed = np.linalg.norm(\n                foot_velocity[:, :2], axis=1\n            ).astype(np.float32)\n        self._prev_low_level_foot_geom_centers = current_centers.copy()\n\n        foot_contact.fill(0.0)\n        foot_normal_force.fill(0.0)\n        foot_tangent_speed = tangential_speed\n\n        contact_force = np.zeros(6, dtype=np.float64)\n        for contact_idx in range(int(self.d.ncon)):\n            contact = self.d.contact[contact_idx]\n            contact_sides = set()\n            geom1 = int(contact.geom1)\n            geom2 = int(contact.geom2)\n            if geom1 in self._foot_geom_id_to_side:\n                contact_sides.add(self._foot_geom_id_to_side[geom1])\n            if geom2 in self._foot_geom_id_to_side:\n                contact_sides.add(self._foot_geom_id_to_side[geom2])\n            if len(contact_sides) != 1:\n                continue\n\n            side_idx = next(iter(contact_sides))\n            foot_contact[side_idx] = 1.0\n            mujoco.mj_contactForce(self.m, self.d, contact_idx, contact_force)\n            foot_normal_force[side_idx] += np.float32(abs(contact_force[0]))\n\n        self._robot_low_level_foot_contact_seq.append(foot_contact)\n        self._robot_low_level_foot_normal_force_seq.append(foot_normal_force)\n        self._robot_low_level_foot_tangent_speed_seq.append(foot_tangent_speed)\n\n    @staticmethod\n    def 
_flatten_single_step_output(values, *, dtype=None) -> np.ndarray:\n        arr = np.asarray(values, dtype=dtype)\n        if arr.ndim == 0:\n            raise ValueError(\n                \"Expected at least 1D output for single-step ONNX routing dump.\"\n            )\n        return arr.reshape(-1, arr.shape[-1])[0]\n\n    def _discover_policy_moe_outputs(self) -> None:\n        self.policy_moe_layer_output_names: list[tuple[int, str, str]] = []\n        routing_outputs: dict[int, dict[str, str]] = {}\n        pattern = re.compile(r\"^moe_layer_(\\d+)_expert_(indices|logits)$\")\n        for node in self.policy_session.get_outputs():\n            match = pattern.fullmatch(node.name)\n            if match is None:\n                continue\n            layer_idx = int(match.group(1))\n            kind = str(match.group(2))\n            routing_outputs.setdefault(layer_idx, {})[kind] = node.name\n\n        for layer_idx in sorted(routing_outputs):\n            layer_outputs = routing_outputs[layer_idx]\n            if \"indices\" not in layer_outputs or \"logits\" not in layer_outputs:\n                logger.warning(\n                    \"Skipping incomplete MoE routing outputs for layer \"\n                    f\"{layer_idx}: {sorted(layer_outputs)}\"\n                )\n                continue\n            self.policy_moe_layer_output_names.append(\n                (\n                    layer_idx,\n                    layer_outputs[\"indices\"],\n                    layer_outputs[\"logits\"],\n                )\n            )\n        if self.policy_moe_layer_output_names:\n            logger.info(\n                \"Detected MoE routing outputs for layers: \"\n                f\"{[layer_idx for layer_idx, _, _ in self.policy_moe_layer_output_names]}\"\n            )\n\n    def _get_stacked_moe_routing_tensors(\n        self,\n    ) -> tuple[np.ndarray | None, np.ndarray | None]:\n        indices_seq = getattr(self, \"_robot_moe_expert_indices_seq\", [])\n        logits_seq = getattr(self, \"_robot_moe_expert_logits_seq\", [])\n        if len(indices_seq) == 0 or len(logits_seq) == 0:\n            return None, None\n        return (\n            np.stack(indices_seq, axis=0).astype(np.int64),\n            np.stack(logits_seq, axis=0).astype(np.float32),\n        )\n\n    def _get_stacked_low_level_foot_contact_tensors(\n        self,\n    ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:\n        contact_seq = getattr(self, \"_robot_low_level_foot_contact_seq\", [])\n        normal_force_seq = getattr(\n            self, \"_robot_low_level_foot_normal_force_seq\", []\n        )\n        tangent_speed_seq = getattr(\n            self, \"_robot_low_level_foot_tangent_speed_seq\", []\n        )\n        if contact_seq and normal_force_seq and tangent_speed_seq:\n            return (\n                np.stack(contact_seq, axis=0).astype(np.float32),\n                np.stack(normal_force_seq, axis=0).astype(np.float32),\n                np.stack(tangent_speed_seq, axis=0).astype(np.float32),\n            )\n\n        num_low_level_samples = len(\n            getattr(self, \"_robot_low_level_dof_torque_seq\", [])\n        )\n        if num_low_level_samples <= 0:\n            return None, None, None\n\n        nan_array = np.full((num_low_level_samples, 2), np.nan, np.float32)\n        return nan_array.copy(), nan_array.copy(), nan_array.copy()\n\n    def _record_onnx_io_frame(self, input_feed, output_names, onnx_output):\n        if not self._onnx_io_input_names:\n     
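# The first recorded frame fixes the tensor-name order used when stacking.\n     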
       self._onnx_io_input_names = list(input_feed.keys())\n            self._onnx_io_inputs = {\n                name: [] for name in self._onnx_io_input_names\n            }\n        if not self._onnx_io_output_names:\n            self._onnx_io_output_names = list(output_names)\n            self._onnx_io_outputs = {\n                name: [] for name in self._onnx_io_output_names\n            }\n\n        for name in self._onnx_io_input_names:\n            if name not in input_feed:\n                raise KeyError(f\"Missing ONNX input tensor: {name}\")\n            self._onnx_io_inputs[name].append(\n                np.array(input_feed[name], copy=True)\n            )\n        for name, value in zip(self._onnx_io_output_names, onnx_output):\n            self._onnx_io_outputs[name].append(np.array(value, copy=True))\n\n    @staticmethod\n    def _stack_onnx_io_frames(\n        frame_dict: dict[str, list[np.ndarray]],\n    ) -> dict[str, np.ndarray]:\n        stacked: dict[str, np.ndarray] = {}\n        for name, frames in frame_dict.items():\n            if frames:\n                stacked[name] = np.stack(frames, axis=0)\n            else:\n                stacked[name] = np.empty((0,), dtype=np.float32)\n        return stacked\n\n    def save_onnx_io_dump(self, output_path, meta_info):\n        payload = {\n            \"input_names\": list(self._onnx_io_input_names),\n            \"output_names\": list(self._onnx_io_output_names),\n            \"inputs\": self._stack_onnx_io_frames(self._onnx_io_inputs),\n            \"outputs\": self._stack_onnx_io_frames(self._onnx_io_outputs),\n            \"source_npz\": meta_info.get(\n                \"source_npz\", meta_info.get(\"source_file\", \"\")\n            ),\n            \"onnx_model\": meta_info.get(\n                \"onnx_model\", meta_info.get(\"model\", \"\")\n            ),\n        }\n        np.save(output_path, payload, allow_pickle=True)\n\n    def _find_actor_place_holder_ndim(self):\n        n_dim = 0\n        for obs_dict in self._get_policy_atomic_obs_list():\n            name = str(list(obs_dict.keys())[0])\n            if name == \"place_holder\":\n                params = obs_dict[\"place_holder\"].get(\"params\", {})\n                n_dim = int(params.get(\"n_dim\", 0))\n            if name == \"actor_place_holder\":\n                params = obs_dict[\"actor_place_holder\"].get(\"params\", {})\n                n_dim = int(params.get(\"n_dim\", 0))\n        return n_dim\n\n    def _get_actor_obs_term_params(self, term_name: str) -> dict:\n        for obs_dict in self._get_policy_atomic_obs_list():\n            configured_name = str(list(obs_dict.keys())[0])\n            if configured_name != term_name:\n                continue\n            term_cfg = obs_dict[configured_name]\n            if not isinstance(term_cfg, dict):\n                return {}\n            params = term_cfg.get(\"params\", {})\n            return dict(params) if isinstance(params, dict) else {}\n        return {}\n\n    def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray:\n        params = self._get_actor_obs_term_params(term_name)\n        keybody_names = params.get(\"keybody_names\", None)\n        body_names = [str(name) for name in self.config.robot.body_names]\n        if keybody_names is None:\n            return np.arange(len(body_names), dtype=np.int64)\n\n        keybody_names = [str(name) for name in keybody_names]\n        body_name_to_idx = {\n            body_name: idx for idx, body_name in enumerate(body_names)\n     
   }\n        missing_names = [\n            name for name in keybody_names if name not in body_name_to_idx\n        ]\n        if len(missing_names) > 0:\n            raise ValueError(\n                f\"Unknown keybody_names in '{term_name}': {missing_names}. \"\n                f\"Available body names: {body_names}\"\n            )\n\n        return np.asarray(\n            [body_name_to_idx[name] for name in keybody_names],\n            dtype=np.int64,\n        )\n\n    @staticmethod\n    def _to_plain_obs_cfg(cfg):\n        if OmegaConf.is_config(cfg):\n            plain_cfg = OmegaConf.to_container(cfg, resolve=True)\n        else:\n            plain_cfg = dict(cfg)\n        if not isinstance(plain_cfg, dict):\n            raise ValueError(\n                f\"Observation term config must be a mapping, got {type(plain_cfg)}\"\n            )\n        return plain_cfg\n\n    def _get_actor_obs_schema_terms(self) -> list[str]:\n        modules_cfg = self.config.get(\"modules\", None)\n        if modules_cfg is None:\n            return []\n        actor_cfg = modules_cfg.get(\"actor\", None)\n        if actor_cfg is None:\n            return []\n        obs_schema = actor_cfg.get(\"obs_schema\", None)\n        if obs_schema is None:\n            return []\n\n        ordered_terms: list[str] = []\n        for _, seq_cfg in obs_schema.items():\n            seq_terms = seq_cfg.get(\"terms\", [])\n            ordered_terms.extend(str(term) for term in seq_terms)\n        return ordered_terms\n\n    def _get_actor_atomic_obs_entries(self) -> list[tuple[str, str, dict]]:\n        obs_cfg = self.config.get(\"obs\", None)\n        if obs_cfg is None:\n            raise ValueError(\"Missing config.obs for MuJoCo sim2sim\")\n        obs_groups = obs_cfg.get(\"obs_groups\", None)\n        if obs_groups is None:\n            raise ValueError(\n                \"Missing config.obs.obs_groups for MuJoCo sim2sim\"\n            )\n\n        if obs_groups.get(\"policy\", None) is not None:\n            entries: list[tuple[str, str, dict]] = []\n            for term_dict in obs_groups.policy.atomic_obs_list:\n                term_name = str(list(term_dict.keys())[0])\n                entries.append(\n                    (\n                        \"policy\",\n                        term_name,\n                        self._to_plain_obs_cfg(term_dict[term_name]),\n                    )\n                )\n            return entries\n\n        if obs_groups.get(\"unified\", None) is not None:\n            entries = []\n            for term_dict in obs_groups.unified.atomic_obs_list:\n                term_name = str(list(term_dict.keys())[0])\n                if term_name.startswith(\"critic_\"):\n                    continue\n                entries.append(\n                    (\n                        \"unified\",\n                        term_name,\n                        self._to_plain_obs_cfg(term_dict[term_name]),\n                    )\n                )\n            if not entries:\n                raise ValueError(\n                    \"obs_groups.unified found but contains no non-critic terms.\"\n                )\n            return entries\n\n        raise ValueError(\n            \"Unsupported obs config for MuJoCo sim2sim: expected obs_groups.policy or obs_groups.unified.\"\n        )\n\n    def _get_policy_atomic_obs_list(self):\n        \"\"\"Resolve the atomic obs list used to build the ONNX policy input.\n\n        Supports both legacy configs (obs_groups.policy) and PULSE-stage2 
configs\n        that use a unified group (obs_groups.unified) with actor_/critic_ prefixes.\n        \"\"\"\n        actor_atomic_entries = self._get_actor_atomic_obs_entries()\n        schema_terms = self._get_actor_obs_schema_terms()\n\n        if len(schema_terms) == 0:\n            logger.warning(\n                \"modules.actor.obs_schema is unavailable; using obs_groups actor term order for MuJoCo policy input.\"\n            )\n            return [\n                {term_name: cfg} for _, term_name, cfg in actor_atomic_entries\n            ]\n\n        by_full_key: dict[str, tuple[str, dict]] = {}\n        by_leaf_key: dict[str, tuple[str, dict]] = {}\n        ambiguous_leaf_keys: set[str] = set()\n        for group_name, term_name, term_cfg in actor_atomic_entries:\n            full_key = f\"{group_name}/{term_name}\"\n            by_full_key[full_key] = (term_name, term_cfg)\n\n            if term_name in by_leaf_key:\n                ambiguous_leaf_keys.add(term_name)\n            else:\n                by_leaf_key[term_name] = (term_name, term_cfg)\n\n        ordered_atomic_list = []\n        for schema_term in schema_terms:\n            schema_term_key = str(schema_term)\n            if schema_term_key in by_full_key:\n                term_name, term_cfg = by_full_key[schema_term_key]\n                ordered_atomic_list.append({term_name: term_cfg})\n                continue\n\n            leaf_key = schema_term_key.split(\"/\")[-1]\n            if leaf_key in ambiguous_leaf_keys:\n                raise ValueError(\n                    \"Actor obs_schema term \"\n                    f\"'{schema_term}' is ambiguous by leaf key '{leaf_key}'. \"\n                    \"Use explicit group/term hierarchy in obs_schema terms.\"\n                )\n            if leaf_key not in by_leaf_key:\n                raise ValueError(\n                    \"Actor obs_schema term \"\n                    f\"'{schema_term}' is not present in obs_groups actor atomic obs list.\"\n                )\n            term_name, term_cfg = by_leaf_key[leaf_key]\n            ordered_atomic_list.append({term_name: term_cfg})\n        return ordered_atomic_list\n\n    # ----------------- Kinematics / velocities -----------------\n    def _body_origin_world_velocity(\n        self, body_id: int\n    ) -> tuple[np.ndarray, np.ndarray]:\n        \"\"\"Compute world-frame spatial velocity (v, w) of a body's frame origin.\n\n        Returns:\n            tuple: (lin_vel_w[3], ang_vel_w[3]) in world coordinates.\n        \"\"\"\n        # World-frame Jacobians for body origin\n        jacp = np.zeros((3, self.m.nv), dtype=np.float64)\n        jacr = np.zeros((3, self.m.nv), dtype=np.float64)\n        mujoco.mj_jacBody(self.m, self.d, jacp, jacr, int(body_id))\n        # qvel is float64 in MuJoCo; keep computation in float64 then cast\n        lin_vel_w = jacp @ self.d.qvel\n        ang_vel_w = jacr @ self.d.qvel\n        return lin_vel_w.astype(np.float32), ang_vel_w.astype(np.float32)\n\n    # ----------------- Body name/id resolution -----------------\n    def _get_anchor_body_name(self) -> str:\n        if not hasattr(self, \"anchor_body_name\"):\n            self.anchor_body_name = str(\n                getattr(self.config.robot, \"anchor_body\", \"pelvis\")\n            )\n        logger.info(f\"Anchor body name: {self.anchor_body_name}\")\n        return self.anchor_body_name\n\n    def _get_torso_body_name(self) -> str:\n        if not 
hasattr(self, \"torso_body_name\"):\n            self.torso_body_name = str(\n                getattr(self.config.robot, \"torso_name\", \"torso_link\")\n            )\n        return self.torso_body_name\n\n    @property\n    def ref_motion_frame_idx(self):\n        return self.motion_frame_idx\n\n    @property\n    def anchor_body_idx(self) -> int:\n        return self.config.robot.body_names.index(\n            self.config.robot.anchor_body\n        )\n\n    @property\n    def root_body_idx(self) -> int:\n        return 0\n\n    @property\n    def torso_body_idx(self) -> int:\n        return self.config.robot.body_names.index(self.config.robot.torso_name)\n\n    @property\n    def robot_global_bodylink_pos(self):\n        \"\"\"World-frame positions of all robot bodies at their MuJoCo body frame origins.\n\n        MuJoCo stores body state for a special world body at index 0, which does not\n        correspond to any physical link and is always static. We slice it out and\n        return `xpos[1:]` so that row 0 corresponds to the root body (e.g. pelvis)\n        and the body dimension matches the HoloMotion NPZ `*_global_translation`\n        arrays.\n\n        Returns:\n            np.ndarray: Array of shape [n_bodies, 3] in MuJoCo body order with the\n            world body excluded.\n        \"\"\"\n        return self.d.xpos[1:]\n\n    @property\n    def robot_global_bodylink_rot(self):\n        \"\"\"World-frame orientations of all robot bodies as WXYZ quaternions.\n\n        As with positions, the MuJoCo world body at index 0 is excluded so that the\n        returned array is aligned with the body dimension used in HoloMotion NPZ\n        `*_global_rotation_quat` arrays (root at index 0, no world entry).\n\n        Returns:\n            np.ndarray: Array of shape [n_bodies, 4] in MuJoCo body order with the\n            world body excluded.\n        \"\"\"\n        xquat = self.d.xquat[1:]\n        xquat_t = torch.as_tensor(xquat, dtype=torch.float32, device=\"cpu\")\n        xquat_t = standardize_quaternion(xquat_t)\n\n        return xquat_t.detach().cpu().numpy()\n\n    @property\n    def robot_global_bodylink_lin_vel(self):\n        \"\"\"World-frame linear velocities of all robot body frame origins.\n\n        Uses `mujoco.mj_objectVelocity` with `mjOBJ_BODY` and `flg_centered=0` to\n        query the 6D spatial velocity at each body's frame origin, then slices the\n        translational component. The world body (ID 0) is excluded so that the body\n        dimension matches the NPZ `*_global_velocity` arrays.\n\n        Returns:\n            np.ndarray: Array of shape [n_bodies, 3] giving linear velocities in the\n            MuJoCo world frame, ordered by body ID starting from the root body.\n        \"\"\"\n        nbody = int(self.m.nbody)\n        vel_6d = np.zeros((nbody, 6), dtype=np.float64)\n        for bid in range(1, nbody):\n            mujoco.mj_objectVelocity(\n                self.m,\n                self.d,\n                mujoco.mjtObj.mjOBJ_BODY,\n                bid,\n                vel_6d[bid],\n                0,\n            )\n        return vel_6d[1:, 3:6]\n\n    @property\n    def robot_global_bodylink_ang_vel(self):\n        \"\"\"World-frame angular velocities of all robot body frame origins.\n\n        Uses the same `mujoco.mj_objectVelocity` call as\n        `robot_global_bodylink_lin_vel` and slices the rotational component. 
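(MuJoCo orders the 6D\n        velocity as [angular(3), linear(3)], hence the 0:3 slice below.) 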
The\n        world body (ID 0) is dropped so that the body dimension is identical to the\n        NPZ `*_global_angular_velocity` arrays and the translation/rotation/velocity\n        tensors all share the same body ordering.\n\n        Returns:\n            np.ndarray: Array of shape [n_bodies, 3] giving angular velocities in\n            the MuJoCo world frame, ordered by body ID starting from the root body.\n        \"\"\"\n        nbody = int(self.m.nbody)\n        vel_6d = np.zeros((nbody, 6), dtype=np.float64)\n        for bid in range(1, nbody):\n            mujoco.mj_objectVelocity(\n                self.m,\n                self.d,\n                mujoco.mjtObj.mjOBJ_BODY,\n                bid,\n                vel_6d[bid],\n                0,\n            )\n        return vel_6d[1:, 0:3]\n\n    @property\n    def robot_dof_pos(self):\n        if hasattr(self, \"actuator_qpos_indices\"):\n            return self.d.qpos[self.actuator_qpos_indices]\n        return self.d.qpos[7:]\n\n    @property\n    def robot_dof_vel(self):\n        if hasattr(self, \"actuator_qvel_indices\"):\n            return self.d.qvel[self.actuator_qvel_indices]\n        return self.d.qvel[6:]\n\n    # ----------------- Reference->Simulation alignment -----------------\n\n    def _ensure_ref_to_sim_transform_rigid(self):\n        \"\"\"Compute rigid transform (yaw + translation) from reference globals to sim globals.\n\n        The transform is defined such that the reference **anchor body** pose at frame 0 is mapped\n        onto the robot's current global anchor pose in XY translation and yaw:\n\n        - `yaw(q_ref_to_sim * q_ref_anchor_0) = yaw(q_robot_anchor_0)`\n        - `t_ref_to_sim + R(q_ref_to_sim) @ t_ref_anchor_0 = t_robot_anchor_0`\n\n        This uses the robot's initial global pose so that arbitrary initialization offsets in\n        XY position and yaw between the robot and the reference motion are absorbed into the\n        reference->simulation mapping, and all subsequent reference globals are expressed in the\n        same world frame as the robot.\n        \"\"\"\n        if self._ref_to_sim_ready:\n            return\n\n        # If we don't have reference globals, fall back to identity transform.\n        if getattr(self, \"ref_global_translation\", None) is None:\n            self._ref_to_sim_q_wxyz = np.array(\n                [1.0, 0.0, 0.0, 0.0], dtype=np.float32\n            )\n            self._ref_to_sim_t = np.zeros(3, dtype=np.float32)\n            self._ref_to_sim_ready = True\n            logger.info(\n                \"No reference global translations available; using identity Ref->Sim transform.\"\n            )\n            return\n\n        # If rotations are missing, keep the previous translation-only semantics.\n        if getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None:\n            t_robot = torch.as_tensor(\n                self.robot_global_bodylink_pos[self.anchor_body_idx],\n                dtype=torch.float32,\n                device=\"cpu\",\n            )\n            t_ref = torch.as_tensor(\n                self.ref_global_translation[0, self.anchor_body_idx].astype(\n                    np.float32\n                ),\n                dtype=torch.float32,\n                device=\"cpu\",\n            )\n            t_ref_to_sim = t_robot - t_ref\n            self._ref_to_sim_q_wxyz = np.array(\n                [1.0, 0.0, 0.0, 0.0], dtype=np.float32\n            )\n            self._ref_to_sim_t = t_ref_to_sim.detach().cpu().numpy()\n        
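    # Identity rotation: without reference orientations, yaw alignment is\n            # skipped and only the positional offset to the robot is absorbed.\n        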
    self._ref_to_sim_ready = True\n            logger.info(\n                \"Reference rotations missing; initialized Ref->Sim as translation-only \"\n                f\"transform. t={self._ref_to_sim_t}\"\n            )\n            return\n\n        # Anchor body index shared between robot globals and reference globals\n        anchor_idx = self.anchor_body_idx\n\n        # Robot anchor pose in simulation world frame (after initial state has been set)\n        t_robot = torch.as_tensor(\n            self.robot_global_bodylink_pos[anchor_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_robot_wxyz = torch.as_tensor(\n            self.robot_global_bodylink_rot[anchor_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n\n        # Reference anchor pose at frame 0 in NPZ global frame\n        t_ref0 = torch.as_tensor(\n            self.ref_global_translation[0, anchor_idx].astype(np.float32),\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref0_xyzw = torch.as_tensor(\n            self.ref_global_rotation_quat_xyzw[0, anchor_idx].astype(\n                np.float32\n            ),\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref0_wxyz = xyzw_to_wxyz(q_ref0_xyzw)\n\n        # Yaw-only rotation mapping: align reference yaw to robot yaw (keep roll/pitch from reference).\n        R_robot = matrix_from_quat(q_robot_wxyz)\n        R_ref0 = matrix_from_quat(q_ref0_wxyz)\n        yaw_robot = torch.atan2(R_robot[1, 0], R_robot[0, 0])\n        yaw_ref0 = torch.atan2(R_ref0[1, 0], R_ref0[0, 0])\n        yaw_delta = yaw_robot - yaw_ref0\n\n        yaw_quat_xyzw = quat_from_euler_xyz(\n            torch.tensor(0.0, dtype=torch.float32, device=\"cpu\"),\n            torch.tensor(0.0, dtype=torch.float32, device=\"cpu\"),\n            yaw_delta,\n        )\n        q_ref_to_sim = xyzw_to_wxyz(yaw_quat_xyzw)\n        q_ref_to_sim = quat_normalize_wxyz(q_ref_to_sim)\n\n        # Translation mapping: t_ref_to_sim + R(q_ref_to_sim) @ t_ref0 = t_robot\n        t_ref0_in_sim = quat_apply(q_ref_to_sim, t_ref0)\n        t_ref_to_sim = t_robot - t_ref0_in_sim\n\n        self._ref_to_sim_q_wxyz = (\n            q_ref_to_sim.detach().cpu().numpy().astype(np.float32)\n        )\n        self._ref_to_sim_t = (\n            t_ref_to_sim.detach().cpu().numpy().astype(np.float32)\n        )\n\n        self._ref_to_sim_ready = True\n        logger.info(\n            \"Initialized Ref->Sim rigid transform. 
\"\n            f\"q={self._ref_to_sim_q_wxyz}, t={self._ref_to_sim_t}\"\n        )\n\n    def _detect_command_mode(self) -> str:\n        m_dir = self.config.get(\"motion_npz_dir\") or self.config.get(\n            \"eval\", {}\n        ).get(\"motion_npz_dir\")\n        m_path = self.config.get(\"motion_npz_path\") or self.config.get(\n            \"eval\", {}\n        ).get(\"motion_npz_path\")\n        if m_path is not None and not os.path.exists(m_path):\n            raise FileNotFoundError(f\"Motion file not found: {m_path}\")\n\n        if (m_dir and str(m_dir) != \"\") or (m_path and str(m_path) != \"\"):\n            return \"motion_tracking\"\n        return \"velocity_tracking\"\n\n    def _init_obs_buffers(self):\n        atomic_list = self._get_policy_atomic_obs_list()\n        obs_policy_cfg = {\"atomic_obs_list\": atomic_list}\n        self.obs_builder = PolicyObsBuilder(\n            dof_names_onnx=self.dof_names_onnx,\n            default_angles_onnx=self.default_angles_onnx,\n            evaluator=self,\n            obs_policy_cfg=obs_policy_cfg,\n        )\n\n    def load_policy(self):\n        \"\"\"Load the policy model using ONNX Runtime.\"\"\"\n        onnx_model_path = Path(self.config.ckpt_onnx_path)\n\n        logger.info(f\"Loading ONNX policy from {onnx_model_path}\")\n\n        providers = [\"CPUExecutionProvider\"]\n        use_gpu = _coerce_config_bool(\n            self.config.get(\"use_gpu\", False), default=False\n        )\n        gpu_id = int(self.config.get(\"gpu_id\", 0))\n\n        available_providers = onnxruntime.get_available_providers()\n        if use_gpu:\n            if \"CUDAExecutionProvider\" in available_providers:\n                cuda_options = {\"device_id\": gpu_id}\n                if torch.cuda.is_available():\n                    torch.cuda.set_device(gpu_id)\n                    cuda_options[\"user_compute_stream\"] = str(\n                        torch.cuda.current_stream().cuda_stream\n                    )\n                providers = [\n                    (\"CUDAExecutionProvider\", cuda_options),\n                    \"CPUExecutionProvider\",\n                ]\n                logger.info(\n                    f\"Using CUDAExecutionProvider with gpu_id={gpu_id}\"\n                )\n            else:\n                logger.warning(\n                    \"use_gpu=true but CUDAExecutionProvider is unavailable; \"\n                    \"falling back to CPUExecutionProvider.\"\n                )\n\n        sess_options = onnxruntime.SessionOptions()\n        sess_options.intra_op_num_threads = 1\n        sess_options.inter_op_num_threads = 1\n        sess_options.log_severity_level = 3\n\n        self.policy_session = onnxruntime.InferenceSession(\n            str(onnx_model_path),\n            sess_options=sess_options,\n            providers=providers,\n        )\n        logger.info(\n            f\"ONNX Runtime session created successfully using: {self.policy_session.get_providers()}\"\n        )\n        self.policy_input_name = self.policy_session.get_inputs()[0].name\n        self.policy_output_name = self.policy_session.get_outputs()[0].name\n        logger.info(\n            f\"Policy  ONNX Input: {self.policy_input_name}, Output: {self.policy_output_name}\"\n        )\n\n        logger.info(\"Initializing KV-Cache for Policy...\")\n\n        self.policy_input_name = \"obs\"\n        self.policy_kv_input_name = None\n        self.policy_step_input_name = None\n        self.policy_kv_shape = None\n\n        for node in 
self.policy_session.get_inputs():\n            name = node.name\n            shape = node.shape\n            logger.info(f\"Model Input: Name={name}, Shape={shape}\")\n\n            if \"obs\" in name:\n                self.policy_input_name = name\n            elif \"past_key_values\" in name:\n                self.policy_kv_input_name = name\n                self.policy_kv_shape = shape\n            elif \"step_idx\" in name or \"step\" in name or \"pos\" in name:\n                self.policy_step_input_name = name\n\n        self.policy_output_name = self.policy_session.get_outputs()[0].name\n        self.policy_kv_output_name = None\n        for node in self.policy_session.get_outputs():\n            if \"present_key_values\" in node.name:\n                self.policy_kv_output_name = node.name\n        self._discover_policy_moe_outputs()\n\n        if self.policy_kv_input_name and self.policy_kv_shape:\n            shape = [\n                d if isinstance(d, int) else 1 for d in self.policy_kv_shape\n            ]\n            self.policy_kv_cache = np.zeros(shape, dtype=np.float32)\n            self.policy_model_context_len = (\n                int(shape[3]) if len(shape) > 3 else 0\n            )\n            if self.max_context_len > 0 and self.policy_model_context_len > 0:\n                self.policy_effective_context_len = min(\n                    self.max_context_len, self.policy_model_context_len\n                )\n                logger.info(\n                    \"Using context window from \"\n                    f\"algo.config.num_steps_per_env={self.max_context_len} \"\n                    f\"(model cache len={self.policy_model_context_len}, \"\n                    f\"effective={self.policy_effective_context_len})\"\n                )\n            else:\n                self.policy_effective_context_len = (\n                    self.policy_model_context_len\n                )\n            self.use_kv_cache = True\n            logger.info(f\"KV-Cache ENABLED. 
Shape: {shape}\")\n        else:\n            self.use_kv_cache = False\n            self.policy_kv_cache = None\n            self.policy_model_context_len = 0\n            self.policy_effective_context_len = 0\n            logger.warning(\"KV-Cache NOT found in model inputs!\")\n            if self.max_context_len > 0:\n                logger.warning(\n                    \"algo.config.num_steps_per_env is set but KV-Cache is unavailable; \"\n                    \"ignoring context window limit.\"\n                )\n\n        logger.info(\"ONNX Policy loaded successfully\")\n\n    def _read_onnx_metadata(self) -> dict:\n        \"\"\"Read model metadata from ONNX file and parse into Python types.\"\"\"\n        onnx_model_path = Path(self.config.ckpt_onnx_path)\n\n        model = onnx.load(str(onnx_model_path))\n        meta = {p.key: p.value for p in model.metadata_props}\n\n        def _parse_floats(csv_str: str):\n            return np.array(\n                [float(x) for x in csv_str.split(\",\") if x != \"\"],\n                dtype=np.float32,\n            )\n\n        result = {}\n        result[\"action_scale\"] = _parse_floats(meta[\"action_scale\"])\n        result[\"kps\"] = _parse_floats(meta[\"joint_stiffness\"])\n        result[\"kds\"] = _parse_floats(meta[\"joint_damping\"])\n        result[\"default_joint_pos\"] = _parse_floats(meta[\"default_joint_pos\"])\n        result[\"joint_names\"] = [\n            x for x in meta[\"joint_names\"].split(\",\") if x != \"\"\n        ]\n\n        # 打印解析后的元数据\n        logger.info(\"=== Loaded ONNX Metadata ===\")\n        for key, value in result.items():\n            # 如果关节名称列表很长，进行格式化处理以保持整洁\n            if key == \"joint_names\":\n                logger.info(f\"{key}: {', '.join(value)}\")\n            else:\n                logger.info(f\"{key}:\\n{value}\")\n        logger.info(\"============================\")\n\n        return result\n\n    def _apply_onnx_metadata(self):\n        \"\"\"Apply PD/scale/defaults from ONNX metadata as authoritative values.\"\"\"\n        meta = self._read_onnx_metadata()\n        self.dof_names_onnx = meta[\"joint_names\"]\n        self.action_scale_onnx = meta[\"action_scale\"].astype(np.float32)\n        self.kps_onnx = meta[\"kps\"].astype(np.float32)\n        self.kds_onnx = meta[\"kds\"].astype(np.float32)\n        self.default_angles_onnx = meta[\"default_joint_pos\"].astype(np.float32)\n\n    def _build_dof_mappings(self):\n        # Map ONNX <-> MJCF for control\n        self.onnx_to_mu = [\n            self.dof_names_onnx.index(name) for name in self.mjcf_dof_names\n        ]\n        self.mu_to_onnx = [\n            self.mjcf_dof_names.index(name) for name in self.dof_names_onnx\n        ]\n        self.ref_to_onnx = [\n            self.dof_names_ref_motion.index(name)\n            for name in self.dof_names_onnx\n        ]\n\n        # Map MuJoCo actuator DOF order -> reference DOF order used in motion npz\n        self.mu_to_ref = []\n        for mu_idx in range(len(self.mjcf_dof_names)):\n            onnx_idx = self.onnx_to_mu[mu_idx]\n            ref_idx = self.ref_to_onnx[onnx_idx]\n            self.mu_to_ref.append(ref_idx)\n\n        self.kps_mu = self.kps_onnx[self.onnx_to_mu].astype(np.float32)\n        self.kds_mu = self.kds_onnx[self.onnx_to_mu].astype(np.float32)\n        self.default_angles_mu = self.default_angles_onnx[\n            self.onnx_to_mu\n        ].astype(np.float32)\n        self.action_scale_mu = self.action_scale_onnx[self.onnx_to_mu].astype(\n            
np.float32\n        )\n\n    @staticmethod\n    def _normalize_filter_cutoff_hz(raw_values, num_frames: int) -> np.ndarray:\n        num_frames = max(int(num_frames), 0)\n        if num_frames == 0:\n            return np.zeros((0, 1), dtype=np.float32)\n        if raw_values is None:\n            return np.zeros((num_frames, 1), dtype=np.float32)\n\n        cutoff = np.asarray(raw_values, dtype=np.float32)\n        if cutoff.ndim == 0:\n            cutoff = np.full((num_frames, 1), float(cutoff), dtype=np.float32)\n            return cutoff\n        if cutoff.ndim == 1:\n            cutoff = cutoff[:, None]\n        else:\n            cutoff = cutoff.reshape(cutoff.shape[0], -1)[:, :1]\n\n        if cutoff.shape[0] == 0:\n            return np.zeros((num_frames, 1), dtype=np.float32)\n        if cutoff.shape[0] == 1 and num_frames > 1:\n            cutoff = np.repeat(cutoff, num_frames, axis=0)\n        elif cutoff.shape[0] < num_frames:\n            pad = np.repeat(\n                cutoff[-1:, :], num_frames - cutoff.shape[0], axis=0\n            )\n            cutoff = np.concatenate([cutoff, pad], axis=0)\n        elif cutoff.shape[0] > num_frames:\n            cutoff = cutoff[:num_frames]\n        return cutoff.astype(np.float32, copy=False)\n\n    def load_motion_data(self):\n        \"\"\"Load motion data from npz file.\"\"\"\n        motion_npz_path = self.config.get(\"motion_npz_path\", None)\n        if motion_npz_path is None:\n            logger.warning(\n                \"No motion_npz_path specified in config, using zero reference motion\"\n            )\n            return\n\n        logger.info(f\"Loading motion data from {motion_npz_path}\")\n\n        # Load npz file\n        with np.load(motion_npz_path, allow_pickle=True) as npz:\n            keys = list(npz.keys())\n            raw_filter_cutoff_hz = (\n                np.array(npz[\"filter_cutoff_hz\"]).astype(np.float32)\n                if \"filter_cutoff_hz\" in npz\n                else None\n            )\n\n            # Try direct arrays first (dof_pos, dof_vel or variants)\n            naming_pairs = [\n                (\"ref_dof_pos\", \"ref_dof_vel\"),\n                (\"dof_pos\", \"dof_vels\"),  # backward compat\n                # (\"ft_ref_pos\", \"ft_ref_dof_vel\"),\n            ]\n\n            pos_key = None\n            vel_key = None\n            for pos_k, vel_k in naming_pairs:\n                if pos_k in npz and vel_k in npz:\n                    pos_key = pos_k\n                    vel_key = vel_k\n                    break\n\n            if pos_key is not None and vel_key is not None:\n                # Direct arrays found\n                self.ref_dof_pos = np.array(npz[pos_key]).astype(np.float32)\n                self.ref_dof_vel = np.array(npz[vel_key]).astype(np.float32)\n            elif len(keys) == 1:\n                # Single key - might contain nested dict\n                arr = npz[keys[0]]\n                if getattr(arr, \"dtype\", None) == object:\n                    obj = arr.item() if arr.size == 1 else arr\n                    if isinstance(obj, dict):\n                        if (\n                            raw_filter_cutoff_hz is None\n                            and \"filter_cutoff_hz\" in obj\n                        ):\n                            raw_filter_cutoff_hz = np.array(\n                                obj[\"filter_cutoff_hz\"]\n                            ).astype(np.float32)\n                        # Try to find dof_pos/dof_vel in nested dict\n           
             for pos_k, vel_k in naming_pairs:\n                            if pos_k in obj and vel_k in obj:\n                                self.ref_dof_pos = np.array(obj[pos_k]).astype(\n                                    np.float32\n                                )\n                                self.ref_dof_vel = np.array(obj[vel_k]).astype(\n                                    np.float32\n                                )\n                                break\n                        else:\n                            raise ValueError(\n                                f\"Could not find dof_pos/dof_vel in nested dict. \"\n                                f\"Available keys: {list(obj.keys())}\"\n                            )\n                    else:\n                        raise ValueError(\n                            f\"Single key '{keys[0]}' does not contain a dict. \"\n                            f\"Type: {type(obj)}\"\n                        )\n                else:\n                    raise ValueError(\n                        f\"Single key '{keys[0]}' is not an object array. \"\n                        f\"Available keys: {keys}\"\n                    )\n            else:\n                raise ValueError(\n                    f\"Could not find dof_pos/dof_vel arrays. Available keys: {keys}\"\n                )\n\n            # Ensure consistent frame count\n            if self.ref_dof_pos.shape[0] != self.ref_dof_vel.shape[0]:\n                min_frames = min(\n                    self.ref_dof_pos.shape[0], self.ref_dof_vel.shape[0]\n                )\n                self.ref_dof_pos = self.ref_dof_pos[:min_frames]\n                self.ref_dof_vel = self.ref_dof_vel[:min_frames]\n                logger.warning(\n                    f\"Frame count mismatch, truncated to {min_frames} frames\"\n                )\n\n            self.n_motion_frames = self.ref_dof_pos.shape[0]\n\n            # Optional: load reference global body frames as per motion spec\n            ref_pos_keys = [\"ref_global_translation\", \"global_translation\"]\n            ref_rot_keys = [\"ref_global_rotation_quat\", \"global_rotation_quat\"]\n            ref_vel_keys = [\"ref_global_velocity\", \"global_velocity\"]\n            ref_ang_vel_keys = [\n                \"ref_global_angular_velocity\",\n                \"global_angular_velocity\",\n            ]\n            self.ref_global_translation = None\n            self.ref_global_rotation_quat_xyzw = None\n            self.ref_global_velocity = None\n            self.ref_global_angular_velocity = None\n            for k in ref_pos_keys:\n                if k in npz:\n                    self.ref_global_translation = np.array(npz[k]).astype(\n                        np.float32\n                    )\n                    break\n            for k in ref_rot_keys:\n                if k in npz:\n                    self.ref_global_rotation_quat_xyzw = np.array(\n                        npz[k]\n                    ).astype(np.float32)\n                    break\n            for k in ref_vel_keys:\n                if k in npz:\n                    self.ref_global_velocity = np.array(npz[k]).astype(\n                        np.float32\n                    )\n                    break\n            for k in ref_ang_vel_keys:\n                if k in npz:\n                    self.ref_global_angular_velocity = np.array(npz[k]).astype(\n                        np.float32\n                    )\n                    break\n            if 
self.ref_global_translation is not None:\n                # Truncate to motion frames if needed\n                t_tr = min(\n                    self.n_motion_frames, self.ref_global_translation.shape[0]\n                )\n                if t_tr < self.n_motion_frames:\n                    logger.warning(\n                        f\"Global translation shorter than motion frames ({t_tr} < {self.n_motion_frames}), truncating motion.\"\n                    )\n                    self.n_motion_frames = t_tr\n                    self.ref_dof_pos = self.ref_dof_pos[:t_tr]\n                    self.ref_dof_vel = self.ref_dof_vel[:t_tr]\n\n                self.ref_global_translation = self.ref_global_translation[\n                    :t_tr\n                ]\n            if self.ref_global_rotation_quat_xyzw is not None:\n                t_rr = min(\n                    self.n_motion_frames,\n                    self.ref_global_rotation_quat_xyzw.shape[0],\n                )\n                if t_rr < self.n_motion_frames:\n                    logger.warning(\n                        f\"Global rotation shorter than motion frames ({t_rr} < {self.n_motion_frames}), truncating motion.\"\n                    )\n                    self.n_motion_frames = t_rr\n                    self.ref_dof_pos = self.ref_dof_pos[:t_rr]\n                    self.ref_dof_vel = self.ref_dof_vel[:t_rr]\n                    # Also truncate previously processed globals if necessary\n                    if self.ref_global_translation is not None:\n                        self.ref_global_translation = (\n                            self.ref_global_translation[:t_rr]\n                        )\n\n                self.ref_global_rotation_quat_xyzw = (\n                    self.ref_global_rotation_quat_xyzw[:t_rr]\n                )\n            if self.ref_global_velocity is not None:\n                t_rv = min(\n                    self.n_motion_frames,\n                    self.ref_global_velocity.shape[0],\n                )\n                if t_rv < self.n_motion_frames:\n                    self.n_motion_frames = t_rv\n                    self.ref_dof_pos = self.ref_dof_pos[:t_rv]\n                    self.ref_dof_vel = self.ref_dof_vel[:t_rv]\n                    if self.ref_global_translation is not None:\n                        self.ref_global_translation = (\n                            self.ref_global_translation[:t_rv]\n                        )\n                    if self.ref_global_rotation_quat_xyzw is not None:\n                        self.ref_global_rotation_quat_xyzw = (\n                            self.ref_global_rotation_quat_xyzw[:t_rv]\n                        )\n\n                self.ref_global_velocity = self.ref_global_velocity[:t_rv]\n            if self.ref_global_angular_velocity is not None:\n                t_ra = min(\n                    self.n_motion_frames,\n                    self.ref_global_angular_velocity.shape[0],\n                )\n                if t_ra < self.n_motion_frames:\n                    self.n_motion_frames = t_ra\n                    self.ref_dof_pos = self.ref_dof_pos[:t_ra]\n                    self.ref_dof_vel = self.ref_dof_vel[:t_ra]\n                    if self.ref_global_translation is not None:\n                        self.ref_global_translation = (\n                            self.ref_global_translation[:t_ra]\n                        )\n                    if self.ref_global_rotation_quat_xyzw is not None:\n                        
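# keep earlier-loaded global arrays in sync with the shortened frame count\n                        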
self.ref_global_rotation_quat_xyzw = (\n                            self.ref_global_rotation_quat_xyzw[:t_ra]\n                        )\n                    if self.ref_global_velocity is not None:\n                        self.ref_global_velocity = self.ref_global_velocity[\n                            :t_ra\n                        ]\n\n                self.ref_global_angular_velocity = (\n                    self.ref_global_angular_velocity[:t_ra]\n                )\n\n        self.filter_cutoff_hz = self._normalize_filter_cutoff_hz(\n            raw_filter_cutoff_hz, self.n_motion_frames\n        )\n        logger.info(\n            f\"Loaded motion data with {self.n_motion_frames} frames and {self.ref_dof_pos.shape[1]} DOFs\"\n        )\n\n    def load_mujoco_model(self):\n        \"\"\"Load the MuJoCo model.\"\"\"\n        xml_path = self.config.get(\"robot_xml_path\", None)\n        if xml_path is None:\n            raise ValueError(\n                \"robot_xml_path should be specified in config !!!\"\n            )\n\n        logger.info(f\"Loading MuJoCo model from {xml_path}\")\n        self.m = mujoco.MjModel.from_xml_path(xml_path)\n        self.d = mujoco.MjData(self.m)\n        self.m.opt.timestep = self.simulation_dt\n        logger.info(\n            f\"MuJoCo model loaded with {self.m.nq} position DOFs and {self.m.nu} control DOFs\"\n        )\n\n    def _init_camera_config(self):\n        \"\"\"Initialize shared camera configuration for viewer and offscreen renderers.\"\"\"\n        self._root_body_id = -1\n        if not self._camera_tracking_enabled:\n            logger.info(\"Camera tracking disabled\")\n            return\n\n        # Prefer anchor body from robot config, then fall back to common root names\n        candidates: list[str] = []\n        anchor_name = self._get_anchor_body_name()\n        candidates.append(anchor_name)\n        for name in [\"pelvis\", \"torso\", \"base_link\", \"trunk\", \"root\"]:\n            if name not in candidates:\n                candidates.append(name)\n\n        for body_name in candidates:\n            bid = int(\n                mujoco.mj_name2id(self.m, mujoco.mjtObj.mjOBJ_BODY, body_name)\n            )\n            if bid != -1:\n                self._root_body_id = bid\n                break\n\n        if self._root_body_id != -1:\n            logger.info(\n                f\"Camera tracking enabled for body '{body_name}' (ID={self._root_body_id}), \"\n                f\"lookat height offset: {self._camera_height_offset:.2f}m\"\n            )\n        else:\n            logger.warning(\n                \"Could not find robot root body for camera tracking; \"\n                \"viewer and offscreen cameras will not track the robot.\"\n            )\n\n    def _configure_viewer_camera(self, viewer):\n        \"\"\"Apply shared align-view parameters to the interactive viewer camera.\"\"\"\n        mujoco.mjv_defaultFreeCamera(self.m, viewer.cam)\n        viewer.cam.azimuth = self._camera_azimuth\n        viewer.cam.elevation = self._camera_elevation\n        viewer.cam.distance = self._camera_distance\n\n    def _init_video_tools(self, tag: str):\n        \"\"\"Initialize video writer and offscreen renderer when recording is enabled.\"\"\"\n        if not bool(self.config.get(\"record_video\", False)):\n            return\n        width = int(self.config.get(\"video_width\", 1280))\n        height = int(self.config.get(\"video_height\", 720))\n        fps = float(self.config.get(\"video_fps\", 30.0))\n\n        onnx_stem 
= os.path.splitext(\n            os.path.basename(self.config.ckpt_onnx_path)\n        )[0]\n        output_dir = os.path.join(\n            os.path.dirname(self.config.ckpt_onnx_path),\n            f\"mujoco_output_{onnx_stem}\",\n        )\n        os.makedirs(output_dir, exist_ok=True)\n        motion_npz_path = self.config.get(\"motion_npz_path\", None)\n        if motion_npz_path is not None:\n            motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[\n                0\n            ]\n            out_path = os.path.join(output_dir, f\"{motion_stem}.mp4\")\n        else:\n            out_path = os.path.join(output_dir, \"velocity_tracking.mp4\")\n\n        fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n        self._video_writer = cv2.VideoWriter(\n            out_path, fourcc, fps, (width, height)\n        )\n        self._offscreen = OffscreenRenderer(\n            self.m,\n            height,\n            width,\n            distance=self._camera_distance,\n            azimuth=self._camera_azimuth,\n            elevation=self._camera_elevation,\n        )\n        self._frame_interval = 1.0 / max(fps, 1.0)\n        self._last_frame_time = 0.0\n        if getattr(self, \"ref_global_translation\", None) is not None:\n            self._offscreen.set_overlay_callback(\n                lambda scene: self._draw_ref_body_spheres_to_scene(\n                    scene, reset_ngeom=False\n                )\n            )\n        logger.info(f\"Recording enabled. Writing to: {out_path}\")\n\n    def _dump_robot_augmented_npz(self) -> None:\n        \"\"\"Copy original motion npz and append robot_* states, saved next to video output.\n\n        The output follows the holomotion offline-eval spec used in PPO:\n        - robot_dof_pos, robot_dof_vel: [T, num_dofs]\n        - robot_global_translation: [T, num_bodies, 3]\n        - robot_global_rotation_quat: [T, num_bodies, 4] (XYZW)\n        - robot_global_velocity: [T, num_bodies, 3]\n        - robot_global_angular_velocity: [T, num_bodies, 3]\n        \"\"\"\n        motion_npz_path = self.config.get(\"motion_npz_path\", None)\n        if motion_npz_path is None:\n            return\n        if len(self._robot_dof_pos_seq) == 0:\n            return\n\n        # Stack recorded sequences\n        robot_dof_pos = np.stack(self._robot_dof_pos_seq, axis=0).astype(\n            np.float32\n        )\n        robot_dof_vel = np.stack(self._robot_dof_vel_seq, axis=0).astype(\n            np.float32\n        )\n        robot_dof_acc = np.stack(self._robot_dof_acc_seq, axis=0).astype(\n            np.float32\n        )\n        robot_dof_torque = np.stack(self._robot_dof_torque_seq, axis=0).astype(\n            np.float32\n        )\n        robot_low_level_dof_torque = None\n        if len(self._robot_low_level_dof_torque_seq) > 0:\n            robot_low_level_dof_torque = np.stack(\n                self._robot_low_level_dof_torque_seq, axis=0\n            ).astype(np.float32)\n        (\n            robot_low_level_foot_contact,\n            robot_low_level_foot_normal_force,\n            robot_low_level_foot_tangent_speed,\n        ) = self._get_stacked_low_level_foot_contact_tensors()\n        robot_actions = None\n        if len(getattr(self, \"_robot_actions_seq\", [])) > 0:\n            robot_actions = np.stack(self._robot_actions_seq, axis=0).astype(\n                np.float32\n            )\n        robot_action_rate = np.asarray(\n            self._robot_action_rate_seq, dtype=np.float32\n        )\n\n        
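# Per-step world-frame body states, stacked to [T, num_bodies, ...] in\n        # MuJoCo body order with the world body excluded (matches the NPZ spec).\n        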
robot_global_translation = np.stack(\n            self._robot_global_translation_seq, axis=0\n        ).astype(np.float32)\n        robot_global_rotation_quat = np.stack(\n            self._robot_global_rotation_quat_seq, axis=0\n        ).astype(np.float32)\n        robot_global_velocity = np.stack(\n            self._robot_global_velocity_seq, axis=0\n        ).astype(np.float32)\n        robot_global_angular_velocity = np.stack(\n            self._robot_global_angular_velocity_seq, axis=0\n        ).astype(np.float32)\n        robot_moe_expert_indices, robot_moe_expert_logits = (\n            self._get_stacked_moe_routing_tensors()\n        )\n\n        # Load original motion npz\n        with np.load(motion_npz_path, allow_pickle=True) as npz:\n            data_dict = {k: npz[k] for k in npz.files}\n\n        # Augment with robot_* arrays (override if already present)\n        data_dict[\"robot_dof_pos\"] = robot_dof_pos\n        data_dict[\"robot_dof_vel\"] = robot_dof_vel\n        data_dict[\"robot_dof_acc\"] = robot_dof_acc\n        data_dict[\"robot_dof_torque\"] = robot_dof_torque\n        if robot_low_level_dof_torque is not None:\n            data_dict[\"robot_low_level_dof_torque\"] = (\n                robot_low_level_dof_torque\n            )\n        if robot_low_level_foot_contact is not None:\n            data_dict[\"robot_low_level_foot_contact\"] = (\n                robot_low_level_foot_contact\n            )\n        if robot_low_level_foot_normal_force is not None:\n            data_dict[\"robot_low_level_foot_normal_force\"] = (\n                robot_low_level_foot_normal_force\n            )\n        if robot_low_level_foot_tangent_speed is not None:\n            data_dict[\"robot_low_level_foot_tangent_speed\"] = (\n                robot_low_level_foot_tangent_speed\n            )\n        if robot_actions is not None:\n            data_dict[\"robot_actions\"] = robot_actions\n        data_dict[\"robot_low_level_torque_dt\"] = np.array(\n            self.simulation_dt, dtype=np.float32\n        )\n        data_dict[\"robot_low_level_contact_dt\"] = np.array(\n            self.simulation_dt, dtype=np.float32\n        )\n        data_dict[\"robot_action_rate\"] = robot_action_rate\n        data_dict[\"robot_global_translation\"] = robot_global_translation\n        data_dict[\"robot_global_rotation_quat\"] = robot_global_rotation_quat\n        data_dict[\"robot_global_velocity\"] = robot_global_velocity\n        data_dict[\"robot_global_angular_velocity\"] = (\n            robot_global_angular_velocity\n        )\n        if robot_moe_expert_indices is not None:\n            data_dict[\"robot_moe_expert_indices\"] = robot_moe_expert_indices\n        if robot_moe_expert_logits is not None:\n            data_dict[\"robot_moe_expert_logits\"] = robot_moe_expert_logits\n\n        # Derive output directory consistent with video writer\n        onnx_stem = os.path.splitext(\n            os.path.basename(self.config.ckpt_onnx_path)\n        )[0]\n        output_dir = os.path.join(\n            os.path.dirname(self.config.ckpt_onnx_path),\n            f\"mujoco_output_{onnx_stem}\",\n        )\n        os.makedirs(output_dir, exist_ok=True)\n        motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[0]\n        out_npz_path = os.path.join(output_dir, f\"{motion_stem}_robot.npz\")\n\n        np.savez_compressed(out_npz_path, **data_dict)\n        logger.info(\n            f\"Robot-augmented motion npz saved to: {out_npz_path} \"\n            
f\"(T={robot_dof_pos.shape[0]}, num_dofs={robot_dof_pos.shape[1]}, \"\n            f\"num_bodies={robot_global_translation.shape[1]})\"\n        )\n\n    def _close_video_tools(self):\n        if self._video_writer is not None:\n            self._video_writer.release()\n            self._video_writer = None\n        if self._offscreen is not None:\n            self._offscreen.close()\n            self._offscreen = None\n        self._frame_interval = None\n        self._last_frame_time = 0.0\n\n    def _update_camera_lookat(self, cam):\n        \"\"\"Update camera lookat to track the robot root when tracking is enabled.\"\"\"\n        if not self._camera_tracking_enabled:\n            return\n        if self._root_body_id == -1:\n            return\n        cam.lookat[:2] = self.d.xpos[self._root_body_id][:2]\n        cam.lookat[2] = (\n            self.d.xpos[self._root_body_id][2] + self._camera_height_offset\n        )\n\n    def _maybe_record_frame(self):\n        if self._video_writer is None or self._offscreen is None:\n            return\n        now = time.time()\n        if (\n            self._last_frame_time == 0.0\n            or (now - self._last_frame_time) >= self._frame_interval\n        ):\n            # Update offscreen camera lookat to track robot (if enabled)\n            self._update_camera_lookat(self._offscreen._cam)\n\n            frame_rgb = self._offscreen.render(self.d)\n            # Convert RGB (MuJoCo) -> BGR (OpenCV) before writing\n            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)\n            self._video_writer.write(frame_bgr)\n            self._last_frame_time = now\n\n    def _apply_control(self, sleep: bool):\n        \"\"\"Apply PD targets via Unitree lowcmd, step MuJoCo, optionally sleep.\"\"\"\n        for _ in range(self.control_decimation):\n            record_low_level_torque = (\n                self.command_mode == \"motion_tracking\"\n                and self.ref_dof_pos is not None\n            )\n            if record_low_level_torque:\n                torque_ref = np.zeros(\n                    len(self.dof_names_ref_motion), dtype=np.float32\n                )\n            current_dof_pos = self.robot_dof_pos\n            current_dof_vel = self.robot_dof_vel\n            for name, act_idx in self.actuator_name_to_index.items():\n                mu_idx = self.actuator_name_to_mu_idx[name]\n                joint_name = self.mjcf_dof_names[mu_idx]\n                target_q = self.target_dof_pos_by_name.get(\n                    joint_name,\n                    float(self.default_angles_mu[mu_idx]),\n                )\n                target_dq = 0.0\n                feedforward_tau = 0.0\n                kp = self.kps_mu[mu_idx]\n                kd = self.kds_mu[mu_idx]\n                current_q = current_dof_pos[mu_idx]\n                current_dq = current_dof_vel[mu_idx]\n                tau = (\n                    feedforward_tau\n                    + kp * (target_q - current_q)\n                    + kd * (target_dq - current_dq)\n                )\n                if (\n                    act_idx in self.actuator_force_range\n                    and self.actuator_force_range[act_idx] is not None\n                ):\n                    min_force, max_force = self.actuator_force_range[act_idx]\n                    tau = np.clip(tau, min_force, max_force)\n                self.d.ctrl[mu_idx] = tau\n                if record_low_level_torque:\n                    torque_ref[self.mu_to_ref[mu_idx]] = np.float32(tau)\n\n  
          mujoco.mj_step(self.m, self.d)\n            if record_low_level_torque:\n                self._robot_low_level_dof_torque_seq.append(torque_ref)\n                self._record_low_level_foot_contact_sample()\n            if sleep:\n                time.sleep(self.simulation_dt)\n\n    def _compute_pd_torque_command_ref(self) -> np.ndarray:\n        current_dof_pos = self.robot_dof_pos\n        current_dof_vel = self.robot_dof_vel\n\n        num_mu_dofs = len(self.mjcf_dof_names)\n        torque_mu = np.zeros(num_mu_dofs, dtype=np.float32)\n        for name, act_idx in self.actuator_name_to_index.items():\n            mu_idx = self.actuator_name_to_mu_idx[name]\n            joint_name = self.mjcf_dof_names[mu_idx]\n            target_q = self.target_dof_pos_by_name.get(\n                joint_name,\n                float(self.default_angles_mu[mu_idx]),\n            )\n            target_dq = 0.0\n            feedforward_tau = 0.0\n            kp = self.kps_mu[mu_idx]\n            kd = self.kds_mu[mu_idx]\n            current_q = current_dof_pos[mu_idx]\n            current_dq = current_dof_vel[mu_idx]\n            tau = (\n                feedforward_tau\n                + kp * (target_q - current_q)\n                + kd * (target_dq - current_dq)\n            )\n            if (\n                act_idx in self.actuator_force_range\n                and self.actuator_force_range[act_idx] is not None\n            ):\n                min_force, max_force = self.actuator_force_range[act_idx]\n                tau = np.clip(tau, min_force, max_force)\n            torque_mu[mu_idx] = np.float32(tau)\n\n        num_ref_dofs = len(self.dof_names_ref_motion)\n        torque_ref = np.zeros(num_ref_dofs, dtype=np.float32)\n        for mu_idx, ref_idx in enumerate(self.mu_to_ref):\n            torque_ref[ref_idx] = torque_mu[mu_idx]\n        return torque_ref\n\n    def _get_obs_ref_motion_states(self):\n        # [2 * num_actions] in ONNX order: [ref_pos, ref_vel]\n        if self.ref_dof_pos is None or self.ref_dof_vel is None:\n            return np.zeros(2 * self.num_actions, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        ref_pos_mu = self.ref_dof_pos[frame_idx]\n        ref_vel_mu = self.ref_dof_vel[frame_idx]\n        # Map reference (URDF) DOF order -> ONNX order using precomputed indices\n        ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)\n        ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32)\n        return np.concatenate([ref_pos_onnx, ref_vel_onnx], axis=0).astype(\n            np.float32\n        )\n\n    def _get_obs_ref_motion_states_fut(self):\n        # [T, 2 * num_actions] flattened, ONNX order\n        T = int(self.n_fut_frames)\n        if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:\n            return np.zeros(0, dtype=np.float32)\n        N = int(self.num_actions)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        # Build future arrays in reference DOF order [N, T]\n        pos_fut = np.zeros(\n            (len(self.dof_names_ref_motion), T), dtype=np.float32\n        )\n        vel_fut = np.zeros(\n            (len(self.dof_names_ref_motion), T), dtype=np.float32\n        )\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx < self.n_motion_frames:\n                pos_fut[:, i] = self.ref_dof_pos[idx]\n                vel_fut[:, i] = self.ref_dof_vel[idx]\n            else:\n                pos_fut[:, i] = 
self.ref_dof_pos[last_valid_frame_idx]\n                vel_fut[:, i] = self.ref_dof_vel[last_valid_frame_idx]\n        # Reorder to ONNX and flatten per training layout\n        pos_fut_onnx = pos_fut[self.ref_to_onnx, :]  # [N, T]\n        vel_fut_onnx = vel_fut[self.ref_to_onnx, :]  # [N, T]\n        fut_concat = np.concatenate(\n            [pos_fut_onnx.T, vel_fut_onnx.T], axis=1\n        )  # [T, 2N]\n        return fut_concat.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_dof_pos_fut(self):\n        # [T * num_actions] flattened positions, ONNX order\n        T = int(self.n_fut_frames)\n        if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:\n            return np.zeros(0, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        # Build future array in reference DOF order [N, T]\n        pos_fut = np.zeros(\n            (len(self.dof_names_ref_motion), T), dtype=np.float32\n        )\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx < self.n_motion_frames:\n                pos_fut[:, i] = self.ref_dof_pos[idx]\n            else:\n                pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx]\n        # Reorder to ONNX and flatten per training layout\n        pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0)  # [T, N]\n        return pos_fut_onnx.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_root_height_fut(self):\n        T = int(self.n_fut_frames)\n        if (\n            T <= 0\n            or self.ref_dof_pos is None\n            or self.ref_dof_vel is None\n            or getattr(self, \"ref_global_translation\", None) is None\n        ):\n            return np.zeros(0, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        # Build future root heights [1, T]\n        h_fut = np.zeros((1, T), dtype=np.float32)\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx < self.n_motion_frames:\n                h_fut[:, i] = self.ref_global_translation[\n                    idx, self.root_body_idx, 2\n                ]\n            else:\n                h_fut[:, i] = self.ref_global_translation[\n                    last_valid_frame_idx, self.root_body_idx, 2\n                ]\n        return h_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_dof_pos_cur(self):\n        # [num_actions] in ONNX order: ref_pos for the current frame\n        if self.ref_dof_pos is None or self.ref_dof_vel is None:\n            return np.zeros(self.num_actions, dtype=np.float32)\n        ref_pos_mu = self.ref_dof_pos[self.motion_frame_idx]\n        # Map reference (URDF) DOF order -> ONNX order using precomputed indices\n        ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)\n        return ref_pos_onnx\n\n    def _get_obs_ref_dof_vel_cur(self):\n        # [num_actions] in ONNX order: ref_vel for the current frame\n        if self.ref_dof_pos is None or self.ref_dof_vel is None:\n            return np.zeros(self.num_actions, dtype=np.float32)\n        ref_vel_mu = self.ref_dof_vel[self.motion_frame_idx]\n        # Map reference (URDF) DOF order -> ONNX order using precomputed indices\n        ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32)\n        return ref_vel_onnx\n\n    def _get_obs_ref_motion_filter_cutoff_hz(self):\n        # NOTE: the per-frame attribute lookup is currently disabled and the\n        # cutoff is pinned to a constant 1.0:\n        # cutoff = getattr(self, \"filter_cutoff_hz\", None)\n        cutoff = 1.0\n        if cutoff is None:\n            
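# Unreachable while cutoff is pinned above; kept for the attribute-based path.\n            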
return np.float32(0.0)\n        cutoff_flat = np.asarray(cutoff, dtype=np.float32).reshape(-1)\n        if cutoff_flat.size == 0:\n            return np.float32(0.0)\n        frame_idx = min(\n            max(int(getattr(self, \"motion_frame_idx\", 0)), 0),\n            cutoff_flat.size - 1,\n        )\n        return np.float32(cutoff_flat[frame_idx])\n\n    def _get_obs_ref_root_height_cur(self):\n        if getattr(self, \"ref_global_translation\", None) is None:\n            return 0.0\n        return self.ref_global_translation[\n            self.motion_frame_idx, self.root_body_idx, 2\n        ]\n\n    def _get_obs_ref_gravity_projection_cur(self):\n        if (\n            getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(3, dtype=np.float32)\n        q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n            self.motion_frame_idx, self.root_body_idx\n        ].astype(np.float32)\n        q_root_wxyz = xyzw_to_wxyz(\n            torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n        )\n        q_root_wxyz = standardize_quaternion(q_root_wxyz)\n        g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device=\"cpu\")\n        g_root = quat_apply(quat_inv(q_root_wxyz), g_w)\n        return g_root.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_ref_gravity_projection_fut(self):\n        T = int(self.n_fut_frames)\n        if (\n            T <= 0\n            or getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(0, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device=\"cpu\")\n        gravity_fut = np.zeros((T, 3), dtype=np.float32)\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx >= self.n_motion_frames:\n                idx = last_valid_frame_idx\n            q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n                idx, self.root_body_idx\n            ].astype(np.float32)\n            q_root_wxyz = xyzw_to_wxyz(\n                torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n            )\n            q_root_wxyz = standardize_quaternion(q_root_wxyz)\n            gravity_fut[i] = (\n                quat_apply(quat_inv(q_root_wxyz), g_w)\n                .detach()\n                .cpu()\n                .numpy()\n                .astype(np.float32)\n            )\n        return gravity_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_base_linvel_cur(self):\n        if (\n            getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or getattr(self, \"ref_global_velocity\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(3, dtype=np.float32)\n        q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n            self.motion_frame_idx, self.root_body_idx\n        ].astype(np.float32)\n        q_root_wxyz = xyzw_to_wxyz(\n            torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n        )\n        q_root_wxyz = standardize_quaternion(q_root_wxyz)\n        v_root_w = torch.as_tensor(\n            self.ref_global_velocity[\n                self.motion_frame_idx, self.root_body_idx\n            ],\n            dtype=torch.float32,\n            
device=\"cpu\",\n        )\n        v_root = quat_apply(quat_inv(q_root_wxyz), v_root_w)\n        return v_root.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_ref_base_linvel_fut(self):\n        T = int(self.n_fut_frames)\n        if (\n            T <= 0\n            or getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or getattr(self, \"ref_global_velocity\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(0, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        base_linvel_fut = np.zeros((T, 3), dtype=np.float32)\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx >= self.n_motion_frames:\n                idx = last_valid_frame_idx\n            q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n                idx, self.root_body_idx\n            ].astype(np.float32)\n            q_root_wxyz = xyzw_to_wxyz(\n                torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n            )\n            q_root_wxyz = standardize_quaternion(q_root_wxyz)\n            v_root_w = torch.as_tensor(\n                self.ref_global_velocity[idx, self.root_body_idx],\n                dtype=torch.float32,\n                device=\"cpu\",\n            )\n            base_linvel_fut[i] = (\n                quat_apply(quat_inv(q_root_wxyz), v_root_w)\n                .detach()\n                .cpu()\n                .numpy()\n                .astype(np.float32)\n            )\n        return base_linvel_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_base_angvel_cur(self):\n        if (\n            getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or getattr(self, \"ref_global_angular_velocity\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(3, dtype=np.float32)\n        q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n            self.motion_frame_idx, self.root_body_idx\n        ].astype(np.float32)\n        q_root_wxyz = xyzw_to_wxyz(\n            torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n        )\n        q_root_wxyz = standardize_quaternion(q_root_wxyz)\n        w_root_w = torch.as_tensor(\n            self.ref_global_angular_velocity[\n                self.motion_frame_idx, self.root_body_idx\n            ],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        w_root = quat_apply(quat_inv(q_root_wxyz), w_root_w)\n        return w_root.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_ref_base_angvel_fut(self):\n        T = int(self.n_fut_frames)\n        if (\n            T <= 0\n            or getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or getattr(self, \"ref_global_angular_velocity\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(0, dtype=np.float32)\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        base_angvel_fut = np.zeros((T, 3), dtype=np.float32)\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx >= self.n_motion_frames:\n                idx = last_valid_frame_idx\n            q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n                idx, self.root_body_idx\n            ].astype(np.float32)\n            q_root_wxyz = xyzw_to_wxyz(\n           
     torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n            )\n            q_root_wxyz = standardize_quaternion(q_root_wxyz)\n            w_root_w = torch.as_tensor(\n                self.ref_global_angular_velocity[idx, self.root_body_idx],\n                dtype=torch.float32,\n                device=\"cpu\",\n            )\n            base_angvel_fut[i] = (\n                quat_apply(quat_inv(q_root_wxyz), w_root_w)\n                .detach()\n                .cpu()\n                .numpy()\n                .astype(np.float32)\n            )\n        return base_angvel_fut.reshape(-1).astype(np.float32)\n\n    def _get_obs_ref_keybody_rel_pos_cur(self):\n        keybody_idxs = self._get_ref_keybody_indices(\n            \"actor_ref_keybody_rel_pos_cur\"\n        )\n        n_keybodies = int(keybody_idxs.shape[0])\n        if n_keybodies == 0:\n            return np.zeros(0, dtype=np.float32)\n        if (\n            getattr(self, \"ref_global_translation\", None) is None\n            or getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros(n_keybodies * 3, dtype=np.float32)\n\n        frame_idx = self.motion_frame_idx\n        ref_body_global_pos = self.ref_global_translation[frame_idx].astype(\n            np.float32\n        )  # [B, 3]\n        ref_root_global_pos = ref_body_global_pos[\n            self.root_body_idx\n        ]  # [3], world\n        q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n            frame_idx, self.root_body_idx\n        ].astype(np.float32)\n        q_root_wxyz = xyzw_to_wxyz(\n            torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n        )\n        q_root_wxyz = standardize_quaternion(q_root_wxyz)\n\n        rel_pos_w = (\n            ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :]\n        )  # [K, 3]\n        rel_pos_w_t = torch.as_tensor(\n            rel_pos_w, dtype=torch.float32, device=\"cpu\"\n        )\n        q_root_expand = q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4)\n        rel_pos_root_t = quat_apply(quat_inv(q_root_expand), rel_pos_w_t)\n        return (\n            rel_pos_root_t.detach()\n            .cpu()\n            .numpy()\n            .reshape(-1)\n            .astype(np.float32)\n        )\n\n    def _get_obs_ref_keybody_rel_pos_fut(self):\n        T = int(self.n_fut_frames)\n        if T <= 0:\n            return np.zeros(0, dtype=np.float32)\n\n        keybody_idxs = self._get_ref_keybody_indices(\n            \"actor_ref_keybody_rel_pos_fut\"\n        )\n        n_keybodies = int(keybody_idxs.shape[0])\n        if n_keybodies == 0:\n            return np.zeros((T, 0), dtype=np.float32)\n        if (\n            getattr(self, \"ref_global_translation\", None) is None\n            or getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None\n            or self.n_motion_frames <= 0\n        ):\n            return np.zeros((T, n_keybodies * 3), dtype=np.float32)\n\n        frame_idx = self.motion_frame_idx\n        last_valid_frame_idx = self.n_motion_frames - 1\n        rel_pos_fut = np.zeros((T, n_keybodies, 3), dtype=np.float32)\n\n        for i in range(T):\n            idx = frame_idx + i + 1\n            if idx >= self.n_motion_frames:\n                idx = last_valid_frame_idx\n\n            ref_body_global_pos = self.ref_global_translation[idx].astype(\n                np.float32\n            )  # [B, 3]\n            ref_root_global_pos = 
ref_body_global_pos[\n
                self.root_body_idx\n
            ]  # [3], world\n
            q_root_xyzw = self.ref_global_rotation_quat_xyzw[\n
                idx, self.root_body_idx\n
            ].astype(np.float32)\n
            q_root_wxyz = xyzw_to_wxyz(\n
                torch.as_tensor(q_root_xyzw, dtype=torch.float32, device=\"cpu\")\n
            )\n
            q_root_wxyz = standardize_quaternion(q_root_wxyz)\n
\n
            rel_pos_w = (\n
                ref_body_global_pos[keybody_idxs]\n
                - ref_root_global_pos[None, :]\n
            )  # [K, 3]\n
            rel_pos_w_t = torch.as_tensor(\n
                rel_pos_w, dtype=torch.float32, device=\"cpu\"\n
            )\n
            q_root_expand = q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4)\n
            rel_pos_fut[i] = (\n
                quat_apply(quat_inv(q_root_expand), rel_pos_w_t)\n
                .detach()\n
                .cpu()\n
                .numpy()\n
                .astype(np.float32)\n
            )\n
\n
        return rel_pos_fut.reshape(T, -1).astype(np.float32)\n
\n
    def _get_obs_place_holder(self):\n
        return np.zeros(self.actor_place_holder_ndim, dtype=np.float32)\n
\n
    def _get_obs_vr_ref_motion_states(self):\n
        # [2 * num_actions] in ONNX order: [ref_pos, zeros]; the velocity\n
        # half is zeroed for the VR variant of this observation\n
        if self.ref_dof_pos is None or self.ref_dof_vel is None:\n
            return np.zeros(2 * self.num_actions, dtype=np.float32)\n
        frame_idx = self.motion_frame_idx\n
        ref_pos_mu = self.ref_dof_pos[frame_idx]\n
        # Map reference (URDF) order -> ONNX order using precomputed indices\n
        ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)\n
        return np.concatenate(\n
            [ref_pos_onnx, np.zeros_like(ref_pos_onnx)],\n
            axis=0,\n
        ).astype(np.float32)\n
\n
    def _get_obs_vr_ref_motion_states_fut(self):\n
        # [T, 2 * num_actions] flattened, ONNX order; velocity half zeroed\n
        T = int(self.n_fut_frames)\n
        if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:\n
            return np.zeros(0, dtype=np.float32)\n
        frame_idx = self.motion_frame_idx\n
        last_valid_frame_idx = self.n_motion_frames - 1\n
        # Build future arrays in reference (URDF) order [N, T]\n
        pos_fut = np.zeros(\n
            (len(self.dof_names_ref_motion), T), dtype=np.float32\n
        )\n
        for i in range(T):\n
            idx = frame_idx + i + 1\n
            if idx < self.n_motion_frames:\n
                pos_fut[:, i] = self.ref_dof_pos[idx]\n
            else:\n
                pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx]\n
        # Reorder to ONNX and flatten per training layout\n
        pos_fut_onnx = pos_fut[self.ref_to_onnx, :]  # [N, T]\n
        fut_concat = np.concatenate(\n
            [pos_fut_onnx.T, np.zeros_like(pos_fut_onnx.T)], axis=1\n
        )  # [T, 2N]\n
        return fut_concat.reshape(-1).astype(np.float32)\n
\n
    def _get_obs_rel_robot_root_ang_vel(self):\n
        q_root_wxyz = torch.as_tensor(\n
            self.robot_global_bodylink_rot[self.root_body_idx],\n
            dtype=torch.float32,\n
            device=\"cpu\",\n
        )\n
        w_root_w = torch.as_tensor(\n
            self.robot_global_bodylink_ang_vel[self.root_body_idx],\n
            dtype=torch.float32,\n
            device=\"cpu\",\n
        )\n
        w_root_b = quat_apply(quat_inv(q_root_wxyz), w_root_w)\n
        return w_root_b.detach().cpu().numpy().astype(np.float32)\n
\n
    def _get_obs_last_action(self):\n
        return np.array(self.actions_onnx, 
dtype=np.float32).reshape(-1)\n\n    def _get_obs_velocity_command(self):\n        # Extended velocity command: [move_mask, vx, vy, vyaw]\n        if (\n            self.command_mode == \"velocity_tracking\"\n            and getattr(self, \"keyboard_handler\", None) is not None\n        ):\n            cmd = np.asarray(\n                self.keyboard_handler.get_velocity_command(), dtype=np.float32\n            ).reshape(3)\n        else:\n            cmd = np.zeros(3, dtype=np.float32)\n        out = np.zeros(4, dtype=np.float32)\n        out[1:] = cmd\n        out[0] = float(np.linalg.norm(cmd) > 0.1)\n        return out\n\n    def _get_obs_actor_ref_headling_aligned_vel_cmd(self):\n        return self._get_obs_velocity_command()\n\n    # ----------------- Actor term aliases (PULSE stage2 unified obs) -----------------\n    def _get_obs_actor_velocity_command(self):\n        return self._get_obs_velocity_command()\n\n    def _get_obs_actor_projected_gravity(self):\n        return self._get_obs_projected_gravity()\n\n    def _get_obs_actor_rel_robot_root_ang_vel(self):\n        return self._get_obs_rel_robot_root_ang_vel()\n\n    def _get_obs_actor_dof_pos(self):\n        return self._get_obs_dof_pos()\n\n    def _get_obs_actor_dof_vel(self):\n        return self._get_obs_dof_vel()\n\n    def _get_obs_actor_last_action(self):\n        return self._get_obs_last_action()\n\n    def _get_obs_actor_place_holder(self):\n        return self._get_obs_place_holder()\n\n    def _get_obs_actor_ref_dof_pos_fut(self):\n        return self._get_obs_ref_dof_pos_fut()\n\n    def _get_obs_actor_ref_dof_pos_cur(self):\n        return self._get_obs_ref_dof_pos_cur()\n\n    def _get_obs_actor_ref_motion_filter_cutoff_hz(self):\n        return self._get_obs_ref_motion_filter_cutoff_hz()\n\n    def _get_obs_actor_ref_root_height_fut(self):\n        return self._get_obs_ref_root_height_fut()\n\n    def _get_obs_actor_ref_root_height_cur(self):\n        return self._get_obs_ref_root_height_cur()\n\n    def _get_obs_actor_ref_gravity_projection_cur(self):\n        return self._get_obs_ref_gravity_projection_cur()\n\n    def _get_obs_actor_ref_gravity_projection_fut(self):\n        return self._get_obs_ref_gravity_projection_fut()\n\n    def _get_obs_actor_ref_base_linvel_cur(self):\n        return self._get_obs_ref_base_linvel_cur()\n\n    def _get_obs_actor_ref_base_linvel_fut(self):\n        return self._get_obs_ref_base_linvel_fut()\n\n    def _get_obs_actor_ref_base_angvel_cur(self):\n        return self._get_obs_ref_base_angvel_cur()\n\n    def _get_obs_actor_ref_base_angvel_fut(self):\n        return self._get_obs_ref_base_angvel_fut()\n\n    def _get_obs_actor_ref_keybody_rel_pos_cur(self):\n        return self._get_obs_ref_keybody_rel_pos_cur()\n\n    def _get_obs_actor_ref_keybody_rel_pos_fut(self):\n        return self._get_obs_ref_keybody_rel_pos_fut()\n\n    def _get_obs_global_anchor_diff(self):\n        self._ensure_ref_to_sim_transform_rigid()\n        ref_pos_sim = self.ref_global_bodylink_pos\n        ref_rot_sim = self.ref_global_bodylink_rot\n\n        if ref_pos_sim is None or ref_rot_sim is None:\n            return np.zeros(9, dtype=np.float32)\n\n        t_robot = torch.as_tensor(\n            self.robot_global_bodylink_pos[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_robot_wxyz = torch.as_tensor(\n            self.robot_global_bodylink_rot[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n  
      )\n        t_ref_sim = torch.as_tensor(\n            ref_pos_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref_sim = torch.as_tensor(\n            ref_rot_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        # Use isaaclab semantics: pose of ref (frame 2) w.r.t. robot (frame 1)\n        p_diff_t, q_diff_wxyz_t = subtract_frame_transforms(\n            t01=t_robot,\n            q01=q_robot_wxyz,\n            t02=t_ref_sim,\n            q02=q_ref_sim,\n        )\n        q_diff_wxyz_t = quat_normalize_wxyz(q_diff_wxyz_t)\n        rot_diff_mat = matrix_from_quat(q_diff_wxyz_t)\n        out = torch.cat(\n            [p_diff_t.reshape(-1), rot_diff_mat[..., :2].reshape(-1)], dim=-1\n        )\n        return out.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_global_anchor_pos_diff(self):\n        self._ensure_ref_to_sim_transform_rigid()\n        ref_pos_sim = self.ref_global_bodylink_pos\n        ref_rot_sim = self.ref_global_bodylink_rot\n\n        if ref_pos_sim is None or ref_rot_sim is None:\n            return np.zeros(3, dtype=np.float32)\n\n        t_robot = torch.as_tensor(\n            self.robot_global_bodylink_pos[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )  # [3], world\n        q_robot_wxyz = torch.as_tensor(\n            self.robot_global_bodylink_rot[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )  # [4], wxyz\n\n        # Transform reference anchor pose into simulation global frame\n        t_ref_sim = torch.as_tensor(\n            ref_pos_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref_sim = torch.as_tensor(\n            ref_rot_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n\n        pos_diff_anchor_t, _ = subtract_frame_transforms(\n            t01=t_robot,\n            q01=q_robot_wxyz,\n            t02=t_ref_sim,\n            q02=q_ref_sim,\n        )\n\n        return pos_diff_anchor_t.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_global_anchor_rot_diff(self):\n        self._ensure_ref_to_sim_transform_rigid()\n        ref_pos_sim = self.ref_global_bodylink_pos\n        ref_rot_sim = self.ref_global_bodylink_rot\n\n        if ref_pos_sim is None or ref_rot_sim is None:\n            return np.zeros(6, dtype=np.float32)\n\n        t_robot = torch.as_tensor(\n            self.robot_global_bodylink_pos[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_robot_wxyz = torch.as_tensor(\n            self.robot_global_bodylink_rot[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_robot_wxyz = standardize_quaternion(q_robot_wxyz)\n\n        t_ref_sim = torch.as_tensor(\n            ref_pos_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref_sim = torch.as_tensor(\n            ref_rot_sim[self.anchor_body_idx],\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n        q_ref_sim = standardize_quaternion(q_ref_sim)\n        _, q_diff_wxyz_t = subtract_frame_transforms(\n            t01=t_robot,\n            q01=q_robot_wxyz,\n            t02=t_ref_sim,\n            q02=q_ref_sim,\n        )\n        
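# The 6 values returned below are the first two columns of the relative\n
        # rotation matrix, a continuous 6D orientation representation; the\n
        # third column is recoverable as the cross product of the first two.\n
        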
q_diff_wxyz_t = standardize_quaternion(q_diff_wxyz_t)\n\n        rot_diff_mat = matrix_from_quat(q_diff_wxyz_t)\n\n        return (\n            rot_diff_mat[..., :2]\n            .reshape(-1)\n            .detach()\n            .cpu()\n            .numpy()\n            .astype(np.float32)\n        )\n\n    def _get_obs_global_bodylink_translation(self) -> np.ndarray:\n        \"\"\"Global body translations in simulator/URDF order, flattened as [num_bodies * 3].\n\n        The body dimension excludes the MuJoCo world body and is assumed to match\n        the NPZ `*_global_translation` arrays (root at index 0).\n        \"\"\"\n        pos = self.robot_global_bodylink_pos.astype(np.float32)  # [B, 3]\n        return pos.reshape(-1)\n\n    def _get_obs_global_bodylink_rotation_quat(self) -> np.ndarray:\n        \"\"\"Global body rotations as XYZW quaternions in simulator/URDF order, flattened [num_bodies * 4].\"\"\"\n        q_wxyz = self.robot_global_bodylink_rot  # [B, 4] in w, x, y, z\n        q_xyzw = np.empty_like(q_wxyz, dtype=np.float32)\n        q_xyzw[..., 0] = q_wxyz[..., 1]\n        q_xyzw[..., 1] = q_wxyz[..., 2]\n        q_xyzw[..., 2] = q_wxyz[..., 3]\n        q_xyzw[..., 3] = q_wxyz[..., 0]\n        return q_xyzw.reshape(-1)\n\n    def _get_obs_global_bodylink_velocity(self) -> np.ndarray:\n        \"\"\"Global body linear velocities in world frame, flattened [num_bodies * 3].\"\"\"\n        lin_vel = self.robot_global_bodylink_lin_vel.astype(\n            np.float32\n        )  # [B, 3]\n        return lin_vel.reshape(-1)\n\n    def _get_obs_global_bodylink_angular_velocity(self) -> np.ndarray:\n        \"\"\"Global body angular velocities in world frame, flattened [num_bodies * 3].\"\"\"\n        ang_vel = self.robot_global_bodylink_ang_vel.astype(\n            np.float32\n        )  # [B, 3]\n        return ang_vel.reshape(-1)\n\n    @property\n    def ref_global_bodylink_pos(self) -> np.ndarray | None:\n        \"\"\"Reference body positions transformed into the simulator global frame.\n\n        Uses the yaw+translation Ref->Sim rigid transform computed from the initial robot\n        global pose so that the reference motion is expressed in the same world frame as\n        the robot (matching XY translation and yaw at frame 0).\n\n        Returns:\n            Array of shape [num_bodies, 3] giving reference positions in simulator world frame,\n            or None if reference globals are not available.\n        \"\"\"\n        if getattr(self, \"ref_global_translation\", None) is None:\n            return None\n        if self.n_motion_frames <= 0:\n            return None\n\n        self._ensure_ref_to_sim_transform_rigid()\n\n        frame_idx = self.ref_motion_frame_idx\n        ref_pos_world = self.ref_global_translation[frame_idx].astype(\n            np.float32\n        )  # [B, 3]\n\n        pos_world_t = torch.as_tensor(\n            ref_pos_world, dtype=torch.float32, device=\"cpu\"\n        )\n\n        q_ref_to_sim = torch.as_tensor(\n            self._ref_to_sim_q_wxyz, dtype=torch.float32, device=\"cpu\"\n        )\n        q_ref_to_sim = q_ref_to_sim.unsqueeze(0).expand(\n            pos_world_t.shape[0], 4\n        )\n\n        t_ref_to_sim = torch.as_tensor(\n            self._ref_to_sim_t, dtype=torch.float32, device=\"cpu\"\n        )\n\n        # Apply yaw rotation + translation based on initial robot state\n        pos_sim_t = (\n            quat_apply(q_ref_to_sim, pos_world_t) + t_ref_to_sim[None, :]\n        )\n\n        return 
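pos_sim_t.detach().cpu().numpy().astype(np.float32)\n
\n
    # Sketch of the Ref->Sim rigid transform applied above, assuming\n
    # _ensure_ref_to_sim_transform_rigid stores a pure-yaw quaternion in\n
    # _ref_to_sim_q_wxyz and an XY(Z) offset in _ref_to_sim_t at reset:\n
    #   p_sim = R(_ref_to_sim_q_wxyz) @ p_ref + _ref_to_sim_t\n
    # so frame 0 of the reference motion lines up with the robot's initial\n
    # yaw and XY position in the simulator world frame.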
\n
\n
    @property\n
    def ref_global_bodylink_rot(self) -> np.ndarray | None:\n
        \"\"\"Reference body rotations transformed into the simulator global frame.\n
\n
        Uses the yaw component of the Ref->Sim transform so that the reference motion's\n
        global yaw is aligned with the robot's initial yaw, while preserving roll/pitch\n
        from the motion data.\n
\n
        Returns:\n
            Array of shape [num_bodies, 4] giving reference orientations in WXYZ format,\n
            or None if reference globals are not available.\n
        \"\"\"\n
        if getattr(self, \"ref_global_rotation_quat_xyzw\", None) is None:\n
            return None\n
        if self.n_motion_frames <= 0:\n
            return None\n
\n
        frame_idx = self.ref_motion_frame_idx\n
        ref_rot_xyzw = self.ref_global_rotation_quat_xyzw[frame_idx].astype(\n
            np.float32\n
        )  # [B, 4] in XYZW\n
\n
        q_ref_xyzw_t = torch.as_tensor(\n
            ref_rot_xyzw, dtype=torch.float32, device=\"cpu\"\n
        )\n
        q_ref_wxyz_t = xyzw_to_wxyz(q_ref_xyzw_t)\n
        q_ref_wxyz_t = standardize_quaternion(q_ref_wxyz_t)\n
\n
        q_ref_to_sim = torch.as_tensor(\n
            self._ref_to_sim_q_wxyz, dtype=torch.float32, device=\"cpu\"\n
        )\n
        q_ref_to_sim = q_ref_to_sim.unsqueeze(0).expand_as(q_ref_wxyz_t)\n
\n
        q_ref_sim_wxyz_t = quat_mul(q_ref_to_sim, q_ref_wxyz_t)\n
        q_ref_sim_wxyz_t = standardize_quaternion(q_ref_sim_wxyz_t)\n
\n
        return q_ref_sim_wxyz_t.detach().cpu().numpy().astype(np.float32)\n
\n
    def _draw_ref_body_spheres_to_scene(\n
        self, scene, reset_ngeom: bool\n
    ) -> None:\n
        \"\"\"Draw red marker spheres at reference body positions into a MuJoCo scene.\"\"\"\n
        ref_positions_sim = self.ref_global_bodylink_pos\n
        if ref_positions_sim is None:\n
            if reset_ngeom:\n
                scene.ngeom = 0\n
            return\n
\n
        if reset_ngeom:\n
            scene.ngeom = 0\n
\n
        radius = float(self.config.get(\"ref_marker_radius\", 0.03))\n
        rgba = np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32)\n
        size = np.array([radius, 0.0, 0.0], dtype=np.float32)\n
        mat = np.eye(3, dtype=np.float32).reshape(-1)\n
\n
        start = int(scene.ngeom)\n
        idx = 0\n
        for pos in ref_positions_sim:\n
            geom_id = start + idx\n
            if geom_id >= scene.maxgeom:\n
                break\n
            mujoco.mjv_initGeom(\n
                scene.geoms[geom_id],\n
                mujoco.mjtGeom.mjGEOM_SPHERE,\n
                size,\n
                pos.astype(np.float32),\n
                mat,\n
                rgba,\n
            )\n
            idx += 1\n
        scene.ngeom = start + idx\n
\n
    def _get_obs_rel_anchor_lin_vel(self):\n
        # Anchor linear velocity expressed in the anchor frame (IsaacLab semantics)\n
        q_anchor_wxyz = torch.as_tensor(\n
            self.robot_global_bodylink_rot[self.anchor_body_idx],\n
            dtype=torch.float32,\n
            device=\"cpu\",\n
        )\n
        v_local_t = quat_apply(\n
            quat_inv(q_anchor_wxyz),\n
            torch.as_tensor(\n
                self.robot_global_bodylink_lin_vel[self.anchor_body_idx],\n
                dtype=torch.float32,\n
                device=\"cpu\",\n
            ),\n
        )\n
        return v_local_t.detach().cpu().numpy().astype(np.float32)\n
\n
    def _get_obs_projected_gravity(self):\n
        q = torch.as_tensor(\n
            
self.robot_global_bodylink_rot[self.root_body_idx],\n            dtype=torch.float32,\n        )\n        qw, qx, qy, qz = q[0], q[1], q[2], q[3]\n        gravity_orientation = torch.zeros(3, dtype=torch.float32, device=\"cpu\")\n        gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy)\n        gravity_orientation[1] = -2.0 * (qz * qy + qw * qx)\n        gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n        return gravity_orientation.detach().cpu().numpy().astype(np.float32)\n\n    def _get_obs_dof_pos(self):\n        pos_mu = self.robot_dof_pos\n        pos_onnx = pos_mu[self.mu_to_onnx]\n        return (pos_onnx - self.default_angles_onnx.astype(np.float32)).astype(\n            np.float32\n        )\n\n    def _get_obs_dof_vel(self):\n        vel_mu = self.robot_dof_vel\n        vel_onnx = vel_mu[self.mu_to_onnx]\n        return vel_onnx.astype(np.float32)\n\n    def _record_robot_states(self) -> None:\n        \"\"\"Record current robot DOF and global body states for offline NPZ dumping.\n\n        - DOF states are stored in reference DOF order (config.robot.dof_names).\n        - Body states are stored in dataset/URDF order (config.robot.body_names).\n        \"\"\"\n        if self.command_mode != \"motion_tracking\":\n            return\n        if self.ref_dof_pos is None or self.n_motion_frames <= 0:\n            return\n        if len(self._robot_dof_pos_seq) >= self.n_motion_frames:\n            return\n\n        # Joint positions/velocities from Unitree lowstate in actuator (MuJoCo) order\n        pos_mu = self.robot_dof_pos\n        vel_mu = self.robot_dof_vel\n\n        # Map MuJoCo actuator order -> reference DOF order\n        num_dofs = len(self.dof_names_ref_motion)\n        pos_ref = np.zeros(num_dofs, dtype=np.float32)\n        vel_ref = np.zeros(num_dofs, dtype=np.float32)\n        for mu_idx, ref_idx in enumerate(self.mu_to_ref):\n            pos_ref[ref_idx] = pos_mu[mu_idx]\n            vel_ref[ref_idx] = vel_mu[mu_idx]\n\n        self._robot_dof_pos_seq.append(pos_ref)\n        self._robot_dof_vel_seq.append(vel_ref)\n        if self._prev_recorded_dof_vel_ref is None:\n            acc_ref = np.zeros_like(vel_ref, dtype=np.float32)\n        else:\n            acc_ref = (vel_ref - self._prev_recorded_dof_vel_ref) / np.float32(\n                self.policy_dt\n            )\n        self._prev_recorded_dof_vel_ref = vel_ref.copy()\n        self._robot_dof_acc_seq.append(acc_ref.astype(np.float32))\n\n        # Global bodylink states in dataset/URDF order\n        body_count = int(self.robot_global_bodylink_pos.shape[0])\n        trans = self._get_obs_global_bodylink_translation().reshape(\n            body_count, 3\n        )\n        rot = self._get_obs_global_bodylink_rotation_quat().reshape(\n            body_count, 4\n        )\n        vel = self._get_obs_global_bodylink_velocity().reshape(body_count, 3)\n        ang_vel = self._get_obs_global_bodylink_angular_velocity().reshape(\n            body_count, 3\n        )\n\n        self._robot_global_translation_seq.append(trans)\n        self._robot_global_rotation_quat_seq.append(rot)\n        self._robot_global_velocity_seq.append(vel)\n        self._robot_global_angular_velocity_seq.append(ang_vel)\n\n    def load_specific_motion(self, npz_path):\n        with np.load(npz_path, allow_pickle=True) as npz:\n            self.ref_global_translation = npz[\"ref_global_translation\"]\n            self.ref_global_rotation_quat_xyzw = npz[\n                \"ref_global_rotation_quat\"\n            ]\n     
       self.ref_global_velocity = npz[\"ref_global_velocity\"]\n            self.ref_global_angular_velocity = npz[\n                \"ref_global_angular_velocity\"\n            ]\n            self.ref_dof_pos = npz[\"ref_dof_pos\"]\n            self.ref_dof_vel = npz[\"ref_dof_vel\"]\n            raw_filter_cutoff_hz = (\n                np.array(npz[\"filter_cutoff_hz\"]).astype(np.float32)\n                if \"filter_cutoff_hz\" in npz\n                else None\n            )\n\n        self.n_motion_frames = self.ref_global_translation.shape[0]\n        # self.filter_cutoff_hz = self._normalize_filter_cutoff_hz(\n        #     raw_filter_cutoff_hz, self.n_motion_frames\n        # )\n        self.filter_cutoff_hz = 1.0\n        self._ref_to_sim_q_wxyz = np.array(\n            [1.0, 0.0, 0.0, 0.0], dtype=np.float32\n        )\n        self._ref_to_sim_t = np.zeros(3, dtype=np.float32)\n        self._ref_to_sim_ready = True\n\n    def reset_state_teleport(self):\n        self.counter = 0\n        self.motion_frame_idx = 0\n\n        mujoco.mj_resetDataKeyframe(self.m, self.d, 0)\n\n        has_ref_motion = (\n            self.ref_dof_pos is not None\n            and self.ref_dof_vel is not None\n            and self.ref_global_translation is not None\n            and self.ref_global_rotation_quat_xyzw is not None\n            and self.ref_global_velocity is not None\n            and self.ref_global_angular_velocity is not None\n        )\n\n        if has_ref_motion:\n            root_pos = self.ref_global_translation[0, 0]  # (x, y, z)\n            root_rot = self.ref_global_rotation_quat_xyzw[0, 0]  # XYZW\n            root_vel = self.ref_global_velocity[0, 0]\n            root_ang = self.ref_global_angular_velocity[0, 0]\n            dof_pos = getattr(\n                self, \"stored_full_ref_dof_pos\", self.ref_dof_pos\n            )[0]\n            dof_vel = getattr(\n                self, \"stored_full_ref_dof_vel\", self.ref_dof_vel\n            )[0]\n\n            self.d.qpos[0:3] = root_pos\n            self.d.qpos[3:7] = [\n                root_rot[3],\n                root_rot[0],\n                root_rot[1],\n                root_rot[2],\n            ]  # XYZW -> WXYZ\n            self.d.qpos[self.actuator_qpos_indices] = dof_pos[self.mu_to_ref]\n\n            self.d.qvel[0:3] = root_vel\n            self.d.qvel[3:6] = root_ang\n            self.d.qvel[self.actuator_qvel_indices] = dof_vel[self.mu_to_ref]\n            self.target_dof_pos_mu = dof_pos[self.mu_to_ref].astype(np.float32)\n            logger.info(\n                \"Teleport reset initialized from reference frame 0 \"\n                \"(root + dof pos/vel)\"\n            )\n        else:\n            self.d.qpos[self.actuator_qpos_indices] = self.default_angles_mu\n            self.d.qvel[self.actuator_qvel_indices] = 0.0\n            self.target_dof_pos_mu = self.default_angles_mu.astype(np.float32)\n            logger.info(\n                \"Teleport reset initialized from ONNX default joint positions\"\n            )\n\n        self.target_dof_pos_by_name = {\n            self.mjcf_dof_names[i]: float(self.target_dof_pos_mu[i])\n            for i in range(self.m.nu)\n        }\n        mujoco.mj_forward(self.m, self.d)\n\n        if self.use_kv_cache and self.policy_kv_shape:\n            shape = [\n                d if isinstance(d, int) else 1 for d in self.policy_kv_shape\n            ]\n            self.policy_kv_cache = np.zeros(shape, dtype=np.float32)\n\n        self._robot_dof_pos_seq = []\n        
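# This and the sequence buffers below are per-episode recording\n
        # buffers; save_batch_result stacks each of them frame-wise\n
        # (axis 0) when dumping the robot-augmented NPZ.\n
        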
self._robot_dof_vel_seq = []\n        self._robot_dof_acc_seq = []\n        self._robot_dof_torque_seq = []\n        self._robot_low_level_dof_torque_seq = []\n        self._robot_low_level_foot_contact_seq = []\n        self._robot_low_level_foot_normal_force_seq = []\n        self._robot_low_level_foot_tangent_speed_seq = []\n        self._robot_actions_seq = []\n        self._robot_action_rate_seq = []\n        self._prev_recorded_dof_vel_ref = None\n        self._prev_actions_onnx = None\n        self._reset_action_ema_filter()\n        self._reset_action_delay_randomization()\n        self._prev_low_level_foot_geom_centers = None\n        self._robot_global_translation_seq = []\n        self._robot_global_rotation_quat_seq = []\n        self._robot_global_velocity_seq = []\n        self._robot_global_angular_velocity_seq = []\n        self._robot_moe_expert_indices_seq = []\n        self._robot_moe_expert_logits_seq = []\n        self._reset_onnx_io_dump_buffers()\n\n    def save_batch_result(self, output_path, meta_info):\n        import json\n\n        metadata = dict(meta_info)\n        metadata.setdefault(\n            \"robot_low_level_torque_dt\",\n            float(getattr(self, \"simulation_dt\", 1.0 / 200.0)),\n        )\n        metadata.setdefault(\n            \"robot_low_level_contact_dt\",\n            float(getattr(self, \"simulation_dt\", 1.0 / 200.0)),\n        )\n        robot_moe_expert_indices, robot_moe_expert_logits = (\n            self._get_stacked_moe_routing_tensors()\n        )\n        (\n            robot_low_level_foot_contact,\n            robot_low_level_foot_normal_force,\n            robot_low_level_foot_tangent_speed,\n        ) = self._get_stacked_low_level_foot_contact_tensors()\n\n        res = {\n            \"robot_dof_pos\": np.stack(self._robot_dof_pos_seq),\n            \"robot_dof_vel\": np.stack(self._robot_dof_vel_seq),\n            \"robot_dof_acc\": np.stack(self._robot_dof_acc_seq),\n            \"robot_dof_torque\": np.stack(self._robot_dof_torque_seq),\n            \"robot_low_level_dof_torque\": np.stack(\n                self._robot_low_level_dof_torque_seq\n            ),\n            \"robot_low_level_foot_contact\": robot_low_level_foot_contact,\n            \"robot_low_level_foot_normal_force\": (\n                robot_low_level_foot_normal_force\n            ),\n            \"robot_low_level_foot_tangent_speed\": (\n                robot_low_level_foot_tangent_speed\n            ),\n            \"robot_low_level_torque_dt\": np.array(\n                getattr(self, \"simulation_dt\", 1.0 / 200.0), dtype=np.float32\n            ),\n            \"robot_low_level_contact_dt\": np.array(\n                getattr(self, \"simulation_dt\", 1.0 / 200.0), dtype=np.float32\n            ),\n            \"robot_action_rate\": np.asarray(\n                self._robot_action_rate_seq, dtype=np.float32\n            ),\n            \"robot_global_translation\": np.stack(\n                self._robot_global_translation_seq\n            ),\n            \"robot_global_rotation_quat\": np.stack(\n                self._robot_global_rotation_quat_seq\n            ),\n            \"robot_global_velocity\": np.stack(self._robot_global_velocity_seq),\n            \"robot_global_angular_velocity\": np.stack(\n                self._robot_global_angular_velocity_seq\n            ),\n            \"ref_dof_pos\": self.ref_dof_pos,\n            \"ref_dof_vel\": self.ref_dof_vel,\n            \"ref_global_translation\": self.ref_global_translation,\n         
   \"ref_global_rotation_quat\": self.ref_global_rotation_quat_xyzw,\n            \"ref_global_velocity\": self.ref_global_velocity,\n            \"ref_global_angular_velocity\": self.ref_global_angular_velocity,\n            \"metadata\": json.dumps(metadata),\n        }\n        if len(getattr(self, \"_robot_actions_seq\", [])) > 0:\n            res[\"robot_actions\"] = np.stack(\n                self._robot_actions_seq, axis=0\n            ).astype(np.float32)\n        if robot_moe_expert_indices is not None:\n            res[\"robot_moe_expert_indices\"] = robot_moe_expert_indices\n        if robot_moe_expert_logits is not None:\n            res[\"robot_moe_expert_logits\"] = robot_moe_expert_logits\n        np.savez_compressed(output_path, **res)\n\n    def setup(self):\n        \"\"\"Set up the evaluator by loading all required components.\"\"\"\n        self.load_mujoco_model()\n        self._init_low_level_foot_contact_logging()\n        self._build_mjcf_dof_names()\n        self.load_policy()\n        self._apply_onnx_metadata()\n        self._build_actuator_qpos_indices()\n        self._build_dof_mappings()\n        self._build_actuator_name_map()\n        self._build_actuator_force_range_map()\n        self._init_camera_config()\n        self._init_obs_buffers()\n\n        # Initialize keyboard handler for velocity tracking\n        if self.command_mode == \"velocity_tracking\":\n            self.keyboard_handler = VelocityKeyboardHandler(\n                vx_increment=0.1,\n                vy_increment=0.05,\n                vyaw_increment=0.05,\n                vx_limits=(-0.5, 1.0),\n                vy_limits=(-0.3, 0.3),\n                vyaw_limits=(-0.5, 0.5),\n            )\n            logger.info(\n                \"Velocity tracking mode enabled. 
\n
\n
    def _create_eval_progress_bar(self, desc: str, max_steps: int):\n
        if self.ref_dof_pos is not None:\n
            return tqdm(total=self.n_motion_frames, desc=desc, unit=\"frame\")\n
        if max_steps > 0:\n
            return tqdm(total=max_steps, desc=desc, unit=\"step\")\n
        return None\n
\n
    def _advance_eval_frame(self, max_steps: int) -> bool:\n
        if self.ref_dof_pos is not None:\n
            if self.motion_frame_idx >= (self.n_motion_frames - 1):\n
                return False\n
            self.motion_frame_idx += 1\n
            return True\n
        if max_steps > 0 and self.counter >= max_steps:\n
            return False\n
        return True\n
\n
    def _run_eval_step(self, max_steps: int) -> bool:\n
        self._update_policy()\n
        self.counter += 1\n
        self._apply_control(sleep=True)\n
        if self._video_writer is not None:\n
            self._maybe_record_frame()\n
        return self._advance_eval_frame(max_steps)\n
\n
    def _build_mjcf_dof_names(self):\n
        \"\"\"Build MJCF joint name lists used for control/state indexing.\n
\n
        - mjcf_dof_names: joint names corresponding to each actuator (actuator order)\n
        \"\"\"\n
        names = []\n
        for i in range(self.m.nu):\n
            j_id = int(self.m.actuator_trnid[i][0])\n
            j_name = mujoco.mj_id2name(\n
                self.m, mujoco.mjtObj.mjOBJ_JOINT, j_id\n
            )\n
            names.append(j_name)\n
        self.mjcf_dof_names = names\n
\n
    def _build_actuator_qpos_indices(self):\n
        \"\"\"Build mapping from actuator index to qpos/qvel indices.\"\"\"\n
        self.actuator_qpos_indices = np.zeros(self.m.nu, dtype=np.int32)\n
        self.actuator_qvel_indices = np.zeros(self.m.nu, dtype=np.int32)\n
        for i in range(self.m.nu):\n
            j_id = int(self.m.actuator_trnid[i, 0])\n
            self.actuator_qpos_indices[i] = self.m.jnt_qposadr[j_id]\n
            self.actuator_qvel_indices[i] = self.m.jnt_dofadr[j_id]\n
\n
    def _build_actuator_name_map(self):\n
        \"\"\"Build mappings from actuator name to indices and MJCF DOF indices.\"\"\"\n
        name_to_index = {}\n
        actuator_name_to_mu_idx = {}\n
        for i in range(self.m.nu):\n
            act_name = mujoco.mj_id2name(\n
                self.m, mujoco.mjtObj.mjOBJ_ACTUATOR, i\n
            )\n
            name_to_index[act_name] = i\n
            j_id = int(self.m.actuator_trnid[i][0])\n
            j_name = mujoco.mj_id2name(\n
                self.m, mujoco.mjtObj.mjOBJ_JOINT, j_id\n
            )\n
            mu_idx = self.mjcf_dof_names.index(j_name)\n
            actuator_name_to_mu_idx[act_name] = mu_idx\n
        self.actuator_name_to_index = name_to_index\n
        self.actuator_name_to_mu_idx = actuator_name_to_mu_idx\n
\n
    def _build_actuator_force_range_map(self):\n
        \"\"\"Build mapping from actuator index to joint actuator force range from XML.\"\"\"\n
        self.actuator_force_range = {}\n
        for i in range(self.m.nu):\n
            j_id 
= int(self.m.actuator_trnid[i][0])\n
            has_limit = False\n
            min_force = 0.0\n
            max_force = 0.0\n
            if 0 <= j_id < self.m.njnt:\n
                if self.m.jnt_actfrclimited[j_id]:\n
                    min_force = float(self.m.jnt_actfrcrange[j_id][0])\n
                    max_force = float(self.m.jnt_actfrcrange[j_id][1])\n
                    if min_force != 0.0 or max_force != 0.0:\n
                        has_limit = True\n
            if not has_limit:\n
                if self.m.actuator_forcelimited[i]:\n
                    min_force = float(self.m.actuator_forcerange[i][0])\n
                    max_force = float(self.m.actuator_forcerange[i][1])\n
                    if min_force != 0.0 or max_force != 0.0:\n
                        has_limit = True\n
            if has_limit:\n
                self.actuator_force_range[i] = (min_force, max_force)\n
            else:\n
                self.actuator_force_range[i] = None\n
\n
    def run_simulation_unitree(self):\n
        \"\"\"Run simulation using Unitree's official threading/viewer pattern.\"\"\"\n
        # Defer heavy deps to runtime to keep default path light\n
        self.counter = 0\n
        self.motion_frame_idx = 0\n
        self.reset_state_teleport()\n
        max_steps = int(self.config.get(\"max_policy_steps\", 0))\n
\n
        viewer_dt = float(self.config.get(\"unitree_viewer_dt\", 1.0 / 60.0))\n
\n
        viewer = mujoco.viewer.launch_passive(self.m, self.d)\n
\n
        # Configure viewer camera to use shared align / tracking settings\n
        self._configure_viewer_camera(viewer)\n
\n
        # Start keyboard listener for velocity tracking\n
        if (\n
            self.command_mode == \"velocity_tracking\"\n
            and self.keyboard_handler is not None\n
        ):\n
            self.keyboard_handler.start_listener()\n
\n
        # Optional recording in viewer mode\n
        if bool(self.config.get(\"record_video\", False)):\n
            self._init_video_tools(tag=\"viewer\")\n
\n
        pbar = self._create_eval_progress_bar(\"GUI eval\", max_steps)\n
\n
        locker = threading.Lock()\n
        stop_event = threading.Event()\n
\n
        def simulation_thread():\n
            while viewer.is_running() and not stop_event.is_set():\n
                with locker:\n
                    keep_running = self._run_eval_step(max_steps)\n
                    if pbar is not None:\n
                        pbar.update(1)\n
                if not keep_running:\n
                    stop_event.set()\n
                    viewer.close()\n
\n
        def physics_viewer_thread():\n
            while viewer.is_running() and not stop_event.is_set():\n
                with locker:\n
                    # Update camera lookat to track robot root (with small offset for framing)\n
                    self._update_camera_lookat(viewer.cam)\n
\n
                    # Draw reference global bodylink positions as red spheres when available\n
                    self._draw_ref_body_spheres_to_scene(\n
                        viewer.user_scn, reset_ngeom=True\n
                    )\n
\n
                    viewer.sync()\n
                time.sleep(viewer_dt)\n
\n
        viewer_thread = Thread(target=physics_viewer_thread)\n
        sim_thread = Thread(target=simulation_thread)\n
\n
        viewer_thread.start()\n
        sim_thread.start()\n
\n
        # Block until viewer closes\n
        viewer_thread.join()\n
        sim_thread.join()\n
\n
        # Close progress bar\n
      
  if pbar is not None:\n            pbar.close()\n\n        # Stop keyboard listener\n        if (\n            self.command_mode == \"velocity_tracking\"\n            and self.keyboard_handler is not None\n        ):\n            self.keyboard_handler.stop_listener()\n\n        # Teardown recording\n        self._close_video_tools()\n\n        # Dump robot-augmented motion npz if motion tracking is enabled\n        self._dump_robot_augmented_npz()\n\n    def run_simulation_unitree_headless(self):\n        \"\"\"Run simulation headless (no GUI) with optional video recording.\"\"\"\n        # Defer heavy deps to runtime to keep default path light\n\n        # Initialize\n        self.counter = 0\n        self.motion_frame_idx = 0\n        self.reset_state_teleport()\n        max_steps = int(self.config.get(\"max_policy_steps\", 0))\n\n        # Start keyboard listener for velocity tracking (even in headless mode)\n        if (\n            self.command_mode == \"velocity_tracking\"\n            and self.keyboard_handler is not None\n        ):\n            self.keyboard_handler.start_listener()\n\n        # Optional recording in headless mode\n        if bool(self.config.get(\"record_video\", False)):\n            self._init_video_tools(tag=\"headless\")\n\n        pbar = self._create_eval_progress_bar(\"Headless eval\", max_steps)\n\n        running = True\n        while running:\n            running = self._run_eval_step(max_steps)\n            if pbar is not None:\n                pbar.update(1)\n\n        if pbar is not None:\n            pbar.close()\n\n        # Stop keyboard listener\n        if (\n            self.command_mode == \"velocity_tracking\"\n            and self.keyboard_handler is not None\n        ):\n            self.keyboard_handler.stop_listener()\n\n        self._close_video_tools()\n\n        # Dump robot-augmented motion npz if motion tracking is enabled\n        self._dump_robot_augmented_npz()\n\n    def run_simulation(self):\n        if bool(self.config.get(\"headless\", False)):\n            logger.info(\"Running MuJoCo sim2sim headless\")\n            self.run_simulation_unitree_headless()\n        else:\n            self.run_simulation_unitree()\n\n    def _update_policy(self):\n        # Record robot states once per policy step for offline NPZ dumping\n        self._record_robot_states()\n\n        latest_obs = self.obs_builder.build_policy_obs()\n        policy_obs_np = latest_obs[None, :]\n        input_feed = {}\n        input_feed[self.policy_input_name] = policy_obs_np\n\n        if self.use_kv_cache:\n            if self.policy_kv_cache is None:\n                shape = [\n                    d if isinstance(d, int) else 1\n                    for d in self.policy_kv_shape\n                ]\n                self.policy_kv_cache = np.zeros(shape, dtype=np.float32)\n            # if (\n            #     self.policy_effective_context_len > 0\n            #     and self.counter > 0\n            #     and self.counter % self.policy_effective_context_len == 0\n            # ):\n            #     self.policy_kv_cache.fill(0.0)\n            input_feed[self.policy_kv_input_name] = self.policy_kv_cache\n\n        if self.policy_step_input_name is not None:\n            step_idx = self.counter\n            if self.use_kv_cache and self.policy_effective_context_len > 0:\n                step_idx = self.counter % self.policy_effective_context_len\n            step_tensor = np.array([step_idx], dtype=np.int64)\n            input_feed[self.policy_step_input_name] = 
step_tensor\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n\n        output_names = [self.policy_output_name]\n        if self.use_kv_cache and self.policy_kv_output_name:\n            output_names.append(self.policy_kv_output_name)\n        for _, indices_name, logits_name in self.policy_moe_layer_output_names:\n            output_names.extend([indices_name, logits_name])\n\n        onnx_output = self.policy_session.run(output_names, input_feed)\n        if self.dump_onnx_io_npy:\n            self._record_onnx_io_frame(input_feed, output_names, onnx_output)\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n\n        raw_actions_onnx = onnx_output[0].reshape(-1)\n        filtered_actions_onnx = self._apply_action_ema_filter(raw_actions_onnx)\n        self.actions_onnx = self._apply_action_delay(filtered_actions_onnx)\n\n        if self.use_kv_cache and len(onnx_output) > 1:\n            new_cache = onnx_output[1]\n\n            self.policy_kv_cache = new_cache\n        output_offset = 1 + int(\n            bool(self.use_kv_cache and self.policy_kv_output_name)\n        )\n        if self.policy_moe_layer_output_names:\n            step_indices = []\n            step_logits = []\n            for (\n                _layer_idx,\n                _indices_name,\n                _logits_name,\n            ) in self.policy_moe_layer_output_names:\n                step_indices.append(\n                    self._flatten_single_step_output(\n                        onnx_output[output_offset],\n                        dtype=np.int64,\n                    )\n                )\n                output_offset += 1\n                step_logits.append(\n                    self._flatten_single_step_output(\n                        onnx_output[output_offset],\n                        dtype=np.float32,\n                    )\n                )\n                output_offset += 1\n            self._robot_moe_expert_indices_seq.append(\n                np.stack(step_indices, axis=0)\n            )\n            self._robot_moe_expert_logits_seq.append(\n                np.stack(step_logits, axis=0)\n            )\n\n        self.target_dof_pos_onnx = (\n            self.actions_onnx * self.action_scale_onnx\n            + self.default_angles_onnx\n        )\n        self.target_dof_pos_mu = self.target_dof_pos_onnx[self.onnx_to_mu]\n        for i, dof_name in enumerate(self.mjcf_dof_names):\n            self.target_dof_pos_by_name[dof_name] = float(\n                self.target_dof_pos_mu[i]\n            )\n\n        if (\n            self.command_mode == \"motion_tracking\"\n            and self.ref_dof_pos is not None\n            and len(self._robot_action_rate_seq) < len(self._robot_dof_pos_seq)\n        ):\n            self._robot_actions_seq.append(\n                self.actions_onnx.astype(np.float32).copy()\n            )\n            if self._prev_actions_onnx is None:\n                action_rate = np.float32(0.0)\n            else:\n                action_rate = np.float32(\n                    np.linalg.norm(self.actions_onnx - self._prev_actions_onnx)\n                    / self.policy_dt\n                )\n            self._prev_actions_onnx = self.actions_onnx.copy()\n            self._robot_action_rate_seq.append(action_rate)\n            self._robot_dof_torque_seq.append(\n                self._compute_pd_torque_command_ref()\n            )\n\n\ndef _get_config_value(config_obj, key: str):\n    value = config_obj.get(key, None)\n   
 if value is None and config_obj.get(\"eval\", None) is not None:\n
        value = config_obj.eval.get(key, None)\n
    return value\n
\n
\n
def _normalize_ckpt_name_list(ckpt_onnx_names):\n
    if ckpt_onnx_names is None:\n
        return []\n
    if isinstance(ckpt_onnx_names, (ListConfig, list, tuple)):\n
        raw_names = list(ckpt_onnx_names)\n
    else:\n
        raise TypeError(\n
            \"ckpt_onnx_names must be a list/tuple, \"\n
            f\"got {type(ckpt_onnx_names)}\"\n
        )\n
    normalized_names = []\n
    for name in raw_names:\n
        name_str = str(name).strip()\n
        if name_str != \"\":\n
            normalized_names.append(name_str)\n
    return normalized_names\n
\n
\n
def _resolve_multi_ckpt_paths(ckpt_onnx_root_dir, ckpt_onnx_names):\n
    root_dir_str = str(ckpt_onnx_root_dir).strip()\n
    if root_dir_str == \"\":\n
        raise ValueError(\"ckpt_onnx_root_dir cannot be empty\")\n
    root_dir = Path(root_dir_str)\n
    if not root_dir.is_dir():\n
        raise NotADirectoryError(\n
            f\"ckpt_onnx_root_dir does not exist or is not a directory: {root_dir}\"\n
        )\n
\n
    requested_names = _normalize_ckpt_name_list(ckpt_onnx_names)\n
    if len(requested_names) == 0:\n
        raise ValueError(\n
            \"ckpt_onnx_names is empty. Please provide checkpoint names \"\n
            'like [\"model_1000.onnx\", \"model_2000.onnx\"].'\n
        )\n
\n
    discovered_paths = sorted(root_dir.rglob(\"*.onnx\"))\n
    if len(discovered_paths) == 0:\n
        raise FileNotFoundError(\n
            f\"No .onnx files found under ckpt_onnx_root_dir={root_dir}\"\n
        )\n
\n
    paths_by_name = {}\n
    for path in discovered_paths:\n
        paths_by_name.setdefault(path.name, []).append(path)\n
\n
    selected_paths = []\n
    missing_names = []\n
    for name in requested_names:\n
        candidates = paths_by_name.get(name, [])\n
        if len(candidates) == 0:\n
            missing_names.append(name)\n
            continue\n
        if len(candidates) > 1:\n
            logger.warning(\n
                f\"Found {len(candidates)} ONNX files named '{name}' under \"\n
                f\"{root_dir}; selecting the first one: {candidates[0]}\"\n
            )\n
        selected_paths.append(candidates[0])\n
\n
    if len(missing_names) > 0:\n
        logger.warning(\n
            \"Some requested checkpoints were not found under \"\n
            f\"{root_dir}: {missing_names}\"\n
        )\n
    if len(selected_paths) == 0:\n
        raise FileNotFoundError(\n
            \"None of the requested checkpoints were found under \"\n
            f\"{root_dir}. Requested names: {requested_names}\"\n
        )\n
\n
    return selected_paths\n
\n
\n
def _resolve_eval_ckpt_paths(config_obj):\n
    ckpt_onnx_root_dir = _get_config_value(config_obj, \"ckpt_onnx_root_dir\")\n
    if (\n
        ckpt_onnx_root_dir is not None\n
        and str(ckpt_onnx_root_dir).strip() != \"\"\n
    ):\n
        ckpt_onnx_names = _get_config_value(config_obj, \"ckpt_onnx_names\")\n
        return _resolve_multi_ckpt_paths(ckpt_onnx_root_dir, ckpt_onnx_names)\n
\n
    ckpt_onnx_path = _get_config_value(config_obj, \"ckpt_onnx_path\")\n
    if ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == \"\":\n
        raise ValueError(\n
            \"No ONNX checkpoint is provided. 
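Set ckpt_onnx_path, or set \"\n
            \"ckpt_onnx_root_dir + ckpt_onnx_names.\"\n
        )\n
    ckpt_path = Path(str(ckpt_onnx_path))\n
    if not ckpt_path.is_file():\n
        raise FileNotFoundError(f\"ONNX checkpoint not found: {ckpt_path}\")\n
    return [ckpt_path]\n
\n
\n
# Config sketch for the two checkpoint-selection modes resolved above (key\n
# names as read by _get_config_value; the paths/values are illustrative):\n
#\n
#   eval:\n
#     ckpt_onnx_path: <run_dir>/ckpt/model_5000.onnx\n
#\n
# or, for a multi-checkpoint sweep:\n
#\n
#   eval:\n
#     ckpt_onnx_root_dir: <run_dir>/ckpt\n
#     ckpt_onnx_names: [model_1000.onnx, model_2000.onnx]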
Set ckpt_onnx_path, or set \"\n            \"ckpt_onnx_root_dir + ckpt_onnx_names.\"\n        )\n    ckpt_path = Path(str(ckpt_onnx_path))\n    if not ckpt_path.is_file():\n        raise FileNotFoundError(f\"ONNX checkpoint not found: {ckpt_path}\")\n    return [ckpt_path]\n\n\ndef _checkpoint_tag_from_path(ckpt_path: Path) -> str:\n    match = re.search(r\"model_(\\d+)\", ckpt_path.name)\n    if match:\n        return f\"model_{match.group(1)}\"\n    return ckpt_path.stem\n\n\ndef _build_eval_output_dir(ckpt_path: Path, dataset_name: str) -> Path:\n    ckpt_tag = _checkpoint_tag_from_path(ckpt_path)\n    dir_name = f\"mujoco_eval_output_{ckpt_tag}_{dataset_name}\"\n    return ckpt_path.parent.parent / dir_name\n\n\ndef _build_onnx_io_dump_dir(output_dir: str | Path) -> Path:\n    return Path(output_dir) / ONNX_IO_DUMP_DIRNAME\n\n\ndef _build_onnx_io_dump_path(\n    output_dir: str | Path, source_file: str | Path\n) -> Path:\n    source_stem = Path(source_file).stem\n    return _build_onnx_io_dump_dir(output_dir) / f\"{source_stem}_onnx_io.npy\"\n\n\ndef _build_onnx_io_dump_readme_text() -> str:\n    return \"\"\"# ONNX I/O Dump Notes\n\nThis directory stores the ONNX input/output tensors dumped during MuJoCo sim2sim evaluation.\n\n## File layout\n\n- Each motion clip produces one `.npy` file named like `<clip_name>_onnx_io.npy`\n- Each `.npy` file corresponds to one original motion clip `.npz`\n- Only the default `holomotion` / `MujocoEvaluator` batch directory evaluation mode (`motion_npz_dir`) is currently supported\n\n## How to read\n\nEach `.npy` file stores a Python `dict`, so it must be loaded with `allow_pickle=True`:\n\n```python\nimport numpy as np\n\nnpy_path = \"onnx_io_npy/example_clip_onnx_io.npy\"\npayload = np.load(npy_path, allow_pickle=True).item()\n\nprint(payload.keys())\nprint(payload[\"input_names\"])\nprint(payload[\"output_names\"])\nprint(payload[\"inputs\"][\"obs\"].shape)\nprint(payload[\"outputs\"][\"action\"].shape)\n```\n\n## Data fields\n\n- `input_names`: list of the actual ONNX input tensor names\n- `output_names`: list of the actual ONNX output tensor names\n- `inputs`: dict keyed by input tensor name; axis 0 of each array is the frame index\n- `outputs`: dict keyed by output tensor name; axis 0 of each array is the frame index\n- `source_npz`: file name of the original motion clip\n- `onnx_model`: path of the ONNX model used when these tensors were dumped\n\n## Notes\n\nA single `.npy` file can hold only one top-level object, so a pickled dict is used to store the input names, the output names, and the per-frame stacked numpy arrays together.\nIf a run produced no valid ONNX I/O data, `inputs` and `outputs` may be empty dicts; check that the keys exist before reading.\n\"\"\"\n\n\ndef write_onnx_io_dump_readme(output_dir: str | Path) -> Path:\n    output_dir_path = Path(output_dir)\n    output_dir_path.mkdir(parents=True, exist_ok=True)\n    readme_path = output_dir_path / \"README.md\"\n    readme_path.write_text(_build_onnx_io_dump_readme_text(), encoding=\"utf-8\")\n    return readme_path\n\n\ndef _allocate_actor_counts(num_checkpoints: int, total_actors: int):\n    if num_checkpoints <= 0:\n        raise ValueError(\"num_checkpoints must be > 0\")\n    if total_actors <= 0:\n        raise ValueError(\"total_actors must be > 0\")\n    base = total_actors // num_checkpoints\n    rem = total_actors % num_checkpoints\n    return [base + (1 if i < rem else 0) for i in range(num_checkpoints)]\n\n\ndef _infer_step_from_ckpt_name(ckpt_name: str):\n    match = re.search(r\"model_(\\d+)\", ckpt_name)\n    if match:\n        return int(match.group(1))\n    fallback = re.search(r\"(\\d+)\", ckpt_name)\n    if fallback:\n        return int(fallback.group(1))\n    return None\n\n\ndef _read_total_macro_row(tsv_path: Path):\n    if not tsv_path.is_file():\n        return None\n    with open(tsv_path, \"r\", encoding=\"utf-8\", newline=\"\") as tsv_file:\n        reader = csv.DictReader(tsv_file, delimiter=\"\\t\")\n        for row in reader:\n            dataset_value = str(row.get(\"Dataset\", \"\")).strip().lower()\n            if \"total\" in 
dataset_value and \"macro\" in dataset_value:\n                return row\n    return None\n\n\ndef _write_total_macro_summary_table(\n    eval_targets, job_log_dir: Path | None = None\n):\n    rows_by_parent = {}\n    for target in eval_targets:\n        output_dir_path = Path(target[\"output_dir\"])\n        ckpt_path = target[\"ckpt_path\"]\n        tsv_path = output_dir_path / \"sub_dataset_macro_mean_metrics.tsv\"\n        total_row = _read_total_macro_row(tsv_path)\n        if total_row is None:\n            logger.warning(\n                \"Skipping aggregated total metrics entry because \"\n                f\"Total (Macro) row is unavailable: {tsv_path}\"\n            )\n            continue\n        parent_dir = output_dir_path.parent\n        if parent_dir not in rows_by_parent:\n            rows_by_parent[parent_dir] = []\n        rows_by_parent[parent_dir].append(\n            {\n                \"step\": _infer_step_from_ckpt_name(ckpt_path.stem),\n                \"total_row\": total_row,\n                \"ckpt_name\": ckpt_path.stem,\n            }\n        )\n\n    for parent_dir, entries in rows_by_parent.items():\n        if len(entries) == 0:\n            continue\n        entries.sort(\n            key=lambda item: (\n                item[\"step\"] is None,\n                item[\"step\"] if item[\"step\"] is not None else 0,\n                item[\"ckpt_name\"],\n            )\n        )\n        metric_columns = list(entries[0][\"total_row\"].keys())\n        available_steps = [\n            entry[\"step\"] for entry in entries if entry[\"step\"] is not None\n        ]\n        if len(available_steps) > 0:\n            step_range = f\"{min(available_steps)}-{max(available_steps)}\"\n        else:\n            step_range = \"na-na\"\n        output_name = f\"mujoco_model-{step_range}_total_metrics.tsv\"\n        output_path = parent_dir / output_name\n        generated_artifacts = [output_path]\n        with open(output_path, \"w\", encoding=\"utf-8\", newline=\"\") as out_file:\n            writer = csv.writer(out_file, delimiter=\"\\t\", lineterminator=\"\\n\")\n            writer.writerow([\"step\"] + metric_columns)\n            for entry in entries:\n                step_value = (\n                    str(entry[\"step\"]) if entry[\"step\"] is not None else \"\"\n                )\n                writer.writerow(\n                    [step_value]\n                    + [\n                        entry[\"total_row\"].get(col, \"\")\n                        for col in metric_columns\n                    ]\n                )\n        logger.info(f\"Saved aggregated total metrics table at: {output_path}\")\n\n        plot_metric_columns = [\n            col for col in metric_columns if col != \"Dataset\"\n        ]\n        if len(plot_metric_columns) > 0:\n            import matplotlib.pyplot as plt\n\n            ncols = 4\n            nrows = (len(plot_metric_columns) + ncols - 1) // ncols\n            fig, axes = plt.subplots(\n                nrows=nrows,\n                ncols=ncols,\n                figsize=(4.0 * ncols, 2.8 * nrows),\n                squeeze=False,\n            )\n\n            for idx, metric_name in enumerate(plot_metric_columns):\n                ax = axes[idx // ncols][idx % ncols]\n                trend_pairs = []\n                for entry in entries:\n                    step_value = entry[\"step\"]\n                    if step_value is None:\n                        continue\n                    raw_metric = 
entry[\"total_row\"].get(metric_name, \"\")\n                    if str(raw_metric).strip() == \"\":\n                        continue\n                    trend_pairs.append((step_value, float(raw_metric)))\n\n                if len(trend_pairs) == 0:\n                    ax.text(\n                        0.5,\n                        0.5,\n                        \"No valid data\",\n                        ha=\"center\",\n                        va=\"center\",\n                        transform=ax.transAxes,\n                    )\n                    ax.set_title(metric_name)\n                    ax.set_xticks([])\n                    ax.set_yticks([])\n                    ax.grid(False)\n                    continue\n\n                trend_pairs.sort(key=lambda pair: pair[0])\n                plot_steps = [pair[0] for pair in trend_pairs]\n                plot_values = [pair[1] for pair in trend_pairs]\n                ax.plot(plot_steps, plot_values, marker=\"o\", linewidth=1.2)\n                ax.set_title(metric_name)\n                ax.set_xlabel(\"step\")\n                ax.grid(True, alpha=0.3)\n\n            total_axes = nrows * ncols\n            for idx in range(len(plot_metric_columns), total_axes):\n                axes[idx // ncols][idx % ncols].axis(\"off\")\n\n            fig.tight_layout()\n            plot_path = output_path.with_name(\n                f\"{output_path.stem}_all_metric_trends.pdf\"\n            )\n            fig.savefig(plot_path, format=\"pdf\")\n            plt.close(fig)\n            generated_artifacts.append(plot_path)\n            logger.info(f\"Saved combined metric trend plot at: {plot_path}\")\n\n        if job_log_dir is not None:\n            for artifact_path in generated_artifacts:\n                job_log_path = job_log_dir / artifact_path.name\n                shutil.copy2(artifact_path, job_log_path)\n                logger.info(f\"Exported artifact to /job_log: {job_log_path}\")\n\n\ndef process_config(override_config):\n    \"\"\"Process the configuration, merging with training config if available.\"\"\"\n    ckpt_onnx_path = _get_config_value(override_config, \"ckpt_onnx_path\")\n    ckpt_onnx_root_dir = _get_config_value(\n        override_config, \"ckpt_onnx_root_dir\"\n    )\n    if (\n        (ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == \"\")\n        and ckpt_onnx_root_dir is not None\n        and str(ckpt_onnx_root_dir).strip() != \"\"\n    ):\n        ckpt_onnx_names = _get_config_value(override_config, \"ckpt_onnx_names\")\n        resolved_paths = _resolve_multi_ckpt_paths(\n            ckpt_onnx_root_dir, ckpt_onnx_names\n        )\n        ckpt_onnx_path = str(resolved_paths[0])\n        logger.info(\n            \"Using the first resolved checkpoint as config anchor: \"\n            f\"{ckpt_onnx_path}\"\n        )\n\n    model_type = override_config.get(\"model_type\") or \"holomotion\"\n    if model_type == \"gmt\":\n        config_path = Path(\n            \"holomotion/config/evaluation/gmt_eval_mujoco_sim2sim.yaml\"\n        )\n    elif model_type == \"any2track\":\n        config_path = Path(\n            \"holomotion/config/evaluation/any2track_eval_mujoco_sim2sim.json\"\n        )\n    elif model_type == \"sonic\":\n        config_path = Path(\n            \"holomotion/config/evaluation/sonic_eval_mujoco_sim2sim.yaml\"\n        )\n    else:\n        if ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == \"\":\n            raise ValueError(\n                \"Cannot locate training config.yaml 
for model_type='holomotion' \"\n                \"without an ONNX checkpoint path. Set ckpt_onnx_path, or set \"\n                \"ckpt_onnx_root_dir + ckpt_onnx_names.\"\n            )\n        onnx_path = Path(str(ckpt_onnx_path))\n        # Load training config.yaml from two directory levels above the ONNX file\n        # (onnx_path.parent.parent / \"config.yaml\")\n        config_path = onnx_path.parent.parent / \"config.yaml\"\n    logger.info(f\"Loading training config file from {config_path}\")\n\n    # Ensure ${eval:'...'} expressions are supported during resolution\n    if not OmegaConf.has_resolver(\"eval\"):\n        OmegaConf.register_new_resolver(\"eval\", lambda expr: eval(expr))\n\n    with open(config_path) as file:\n        train_config = OmegaConf.load(file)\n\n    # Merge training config with any overrides\n    config = OmegaConf.merge(train_config, override_config)\n    with open_dict(config):\n        config.model_type = model_type\n\n    # Resolve config values in-place\n    OmegaConf.resolve(config)\n    if (\n        (\n            config.get(\"ckpt_onnx_path\", None) is None\n            or str(config.get(\"ckpt_onnx_path\")).strip() == \"\"\n        )\n        and ckpt_onnx_path is not None\n        and str(ckpt_onnx_path).strip() != \"\"\n    ):\n        with open_dict(config):\n            config.ckpt_onnx_path = str(ckpt_onnx_path)\n    return config\n\n\ndef _create_ray_evaluator(config_dict, model_type):\n    \"\"\"Create evaluator from serializable config dict (used inside Ray actor).\"\"\"\n    from omegaconf import OmegaConf\n\n    config = OmegaConf.create(config_dict)\n    if model_type == \"gmt\":\n        from holomotion.src.evaluation.gmt_sim2sim import GMTEvaluator\n\n        return GMTEvaluator(config)\n    if model_type == \"any2track\":\n        from holomotion.src.evaluation.any2track_sim2sim import (\n            Any2TrackEvaluator,\n        )\n\n        return Any2TrackEvaluator(config)\n    if model_type == \"sonic\":\n        from holomotion.src.evaluation.sonic_mujoco_sim2sim import (\n            SonicEvaluator,\n        )\n\n        return SonicEvaluator(config)\n    return MujocoEvaluator(config)\n\n\ndef run_mujoco_sim2sim_eval(override_config: OmegaConf):\n    os.chdir(hydra.utils.get_original_cwd())\n    config = process_config(override_config)\n    is_eval_mode = False\n    dataset_dir = config.get(\"motion_npz_dir\", None)\n    specific_file = config.get(\"motion_npz_path\", None)\n    calc_per_clip_metrics = bool(config.get(\"calc_per_clip_metrics\", False))\n    generate_report = bool(config.get(\"generate_report\", False))\n    dump_npzs_cfg = bool(config.get(\"dump_npzs\", False))\n    dump_onnx_io_npy = bool(config.get(\"dump_onnx_io_npy\", False))\n    dump_npzs = dump_npzs_cfg or calc_per_clip_metrics\n    if calc_per_clip_metrics and not dump_npzs_cfg:\n        logger.info(\n            \"calc_per_clip_metrics=true requires dumped NPZs; \"\n            \"enabling dump_npzs automatically.\"\n        )\n\n    if (\n        dataset_dir\n        and os.path.isdir(str(dataset_dir))\n        and (not specific_file or str(specific_file) == \"\")\n    ):\n        is_eval_mode = True\n\n    if is_eval_mode:\n        logger.info(f\"Mode: EVALUATION on directory: {dataset_dir}\")\n        logger.remove()\n        logger.add(sys.stderr, level=\"INFO\")\n\n        dataset_name = Path(dataset_dir).name\n        ckpt_paths = _resolve_eval_ckpt_paths(config)\n        logger.info(\n            f\"Resolved {len(ckpt_paths)} checkpoint(s) for evaluation.\"\n        )\n        for 
idx, ckpt_path in enumerate(ckpt_paths):\n            logger.info(f\"  [{idx}] {ckpt_path}\")\n\n        eval_targets = []\n        for ckpt_path in ckpt_paths:\n            output_dir = _build_eval_output_dir(ckpt_path, dataset_name)\n            eval_targets.append(\n                {\n                    \"ckpt_path\": ckpt_path,\n                    \"output_dir\": str(output_dir),\n                }\n            )\n\n        if dump_npzs:\n            for target in eval_targets:\n                os.makedirs(target[\"output_dir\"], exist_ok=True)\n                if dump_onnx_io_npy:\n                    write_onnx_io_dump_readme(\n                        _build_onnx_io_dump_dir(target[\"output_dir\"])\n                    )\n\n            files = sorted(\n                [\n                    os.path.join(root, name)\n                    for root, _, filenames in os.walk(\n                        dataset_dir, followlinks=True\n                    )\n                    for name in filenames\n                    if name.endswith(\".npz\")\n                ]\n            )\n            logger.info(\n                f\"Found {len(files)} files for dataset_dir={dataset_dir}. \"\n                f\"Will evaluate {len(eval_targets)} checkpoint(s).\"\n            )\n\n            if len(files) == 0:\n                logger.warning(\n                    f\"No NPZ files found under dataset_dir={dataset_dir}\"\n                )\n\n            requested_use_gpu = _coerce_config_bool(\n                config.get(\"use_gpu\", True), default=True\n            )\n            num_available_gpus = 0\n            if requested_use_gpu and torch.cuda.is_available():\n                num_available_gpus = int(torch.cuda.device_count())\n            if requested_use_gpu and num_available_gpus == 0:\n                logger.warning(\n                    \"use_gpu=true but no CUDA device is detected; \"\n                    \"Ray actors will run on CPU.\"\n                )\n            if num_available_gpus > 0:\n                logger.info(\n                    f\"Detected {num_available_gpus} CUDA device(s). 
\"\n                    \"Using Ray for batch evaluation.\"\n                )\n\n            ray_actors_per_gpu = int(config.get(\"ray_actors_per_gpu\", 4))\n            if ray_actors_per_gpu <= 0:\n                raise ValueError(\"ray_actors_per_gpu must be > 0\")\n            ray_multi_ckpt_mode = str(\n                config.get(\"ray_multi_ckpt_mode\", \"split\")\n            )\n            if ray_multi_ckpt_mode not in (\"split\", \"per_checkpoint\"):\n                raise ValueError(\n                    \"ray_multi_ckpt_mode must be one of: \"\n                    \"'split', 'per_checkpoint'\"\n                )\n\n            success_count = 0\n            total_jobs = len(files) * len(eval_targets)\n            if total_jobs > 0:\n                base_config_dict = OmegaConf.to_container(config, resolve=True)\n                base_config_dict.setdefault(\n                    \"ray_evaluator_module\",\n                    \"holomotion.src.evaluation.eval_mujoco_sim2sim\",\n                )\n                if not ray.is_initialized():\n                    ray.init()\n                from holomotion.src.evaluation.ray_evaluator_actor import (\n                    RayEvaluatorActor,\n                )\n\n                if num_available_gpus > 0:\n                    base_actor_count = num_available_gpus * ray_actors_per_gpu\n                    gpus_per_actor = 1.0 / ray_actors_per_gpu\n                    remote_actor = ray.remote(num_gpus=gpus_per_actor)(\n                        RayEvaluatorActor\n                    )\n                else:\n                    base_actor_count = max(1, ray_actors_per_gpu)\n                    gpus_per_actor = 0.0\n                    remote_actor = ray.remote(num_gpus=0)(RayEvaluatorActor)\n\n                if ray_multi_ckpt_mode == \"per_checkpoint\":\n                    actor_counts = [\n                        base_actor_count for _ in range(len(eval_targets))\n                    ]\n                else:\n                    actor_counts = _allocate_actor_counts(\n                        len(eval_targets), base_actor_count\n                    )\n                    if min(actor_counts) <= 0:\n                        raise ValueError(\n                            \"Not enough actor budget to assign at least one actor \"\n                            \"per checkpoint in split mode. 
Reduce checkpoint count, \"\n                            \"increase ray_actors_per_gpu, or switch to \"\n                            \"ray_multi_ckpt_mode=per_checkpoint.\"\n                        )\n\n                total_actor_count = sum(actor_counts)\n                logger.info(\n                    f\"Ray: {total_actor_count} persistent actors \"\n                    f\"({ray_actors_per_gpu} per GPU, {gpus_per_actor} GPU each)\"\n                )\n                logger.info(\n                    \"Checkpoint actor allocation: \"\n                    + \", \".join(\n                        [\n                            f\"{target['ckpt_path'].name}={actor_counts[idx]}\"\n                            for idx, target in enumerate(eval_targets)\n                        ]\n                    )\n                )\n\n                refs = []\n                for target_idx, target in enumerate(eval_targets):\n                    target_config_dict = dict(base_config_dict)\n                    target_config_dict[\"ckpt_onnx_path\"] = str(\n                        target[\"ckpt_path\"]\n                    )\n                    num_actors = actor_counts[target_idx]\n                    target_actors = [\n                        remote_actor.remote(\n                            target_config_dict, target[\"output_dir\"]\n                        )\n                        for _ in range(num_actors)\n                    ]\n                    for file_idx, file_path in enumerate(files):\n                        actor = target_actors[file_idx % len(target_actors)]\n                        refs.append(actor.run_clip.remote(file_path))\n                pbar = tqdm(\n                    total=total_jobs,\n                    desc=\"Batch Processing (all checkpoints)\",\n                    unit=\"job\",\n                    dynamic_ncols=True,\n                )\n                while refs:\n                    done, refs = ray.wait(refs, num_returns=1)\n                    for ref in done:\n                        status = ray.get(ref)\n                        if status == \"success\":\n                            success_count += 1\n                        pbar.update(1)\n                pbar.close()\n            logger.info(\n                f\"Batch processing done. Success: {success_count}/{total_jobs}\"\n            )\n        else:\n            logger.info(\"Skipping NPZ dumping because dump_npzs=false.\")\n\n        job_log_dir = Path(\"/job_log\")\n        job_log_enabled = job_log_dir.is_dir() and os.access(\n            str(job_log_dir), os.W_OK\n        )\n        if job_log_enabled:\n            logger.info(\n                f\"Detected writable /job_log. Will copy summary TSVs to {job_log_dir}.\"\n            )\n        else:\n            logger.info(\n                \"/job_log is unavailable or not writable. 
\"\n                \"Skipping summary TSV export.\"\n            )\n\n        postprocess_targets = []\n        for target in eval_targets:\n            output_dir = target[\"output_dir\"]\n            output_dir_path = Path(output_dir)\n            if not output_dir_path.is_dir():\n                logger.warning(\n                    f\"Output directory does not exist, skipping post-processing: {output_dir}\"\n                )\n                continue\n            postprocess_targets.append(target)\n\n        failure_pos_err_thresh_m = float(\n            config.get(\"failure_pos_err_thresh_m\", 0.25)\n        )\n        metric_calculation = str(config.get(\"metric_calculation\", \"per_clip\"))\n        dof_mode = str(config.get(\"dof_mode\", \"29\"))\n\n        ray_parallel_metrics = bool(\n            config.get(\n                \"ray_parallel_metrics_postprocess\",\n                config.get(\"ray_parallel_metrics\", True),\n            )\n        )\n        metrics_threadpool_max_workers = config.get(\n            \"metrics_threadpool_max_workers\", None\n        )\n        should_parallelize_metrics = (\n            ray_parallel_metrics\n            and len(postprocess_targets) > 1\n            and (calc_per_clip_metrics or generate_report or job_log_enabled)\n        )\n        logger.info(\n            \"Metrics config: \"\n            f\"ray_parallel_metrics_postprocess={ray_parallel_metrics}, \"\n            f\"metrics_threadpool_max_workers={metrics_threadpool_max_workers}\"\n        )\n\n        if should_parallelize_metrics:\n            if not ray.is_initialized():\n                ray.init()\n            from holomotion.src.evaluation.ray_metrics_postprocess import (\n                run_metrics_postprocess_job,\n            )\n\n            ray_metrics_num_cpus_cfg = config.get(\n                \"ray_metrics_postprocess_num_cpus\",\n                config.get(\"ray_metrics_num_cpus\", None),\n            )\n            if ray_metrics_num_cpus_cfg is None:\n                ray_metrics_num_cpus = 0.0\n            else:\n                ray_metrics_num_cpus = float(ray_metrics_num_cpus_cfg)\n            if ray_metrics_num_cpus < 0.0:\n                raise ValueError(\"ray_metrics_num_cpus must be >= 0\")\n\n            metric_refs = []\n            for target in postprocess_targets:\n                ckpt_path = target[\"ckpt_path\"]\n                metric_refs.append(\n                    run_metrics_postprocess_job.options(\n                        num_cpus=ray_metrics_num_cpus\n                    ).remote(\n                        output_dir=target[\"output_dir\"],\n                        dataset_name=dataset_name,\n                        calc_per_clip_metrics=calc_per_clip_metrics,\n                        failure_pos_err_thresh_m=failure_pos_err_thresh_m,\n                        metric_calculation=metric_calculation,\n                        dof_mode=dof_mode,\n                        metrics_threadpool_max_workers=metrics_threadpool_max_workers,\n                        generate_report=generate_report,\n                        job_log_dir=str(job_log_dir)\n                        if job_log_enabled\n                        else None,\n                        ckpt_stem=ckpt_path.stem,\n                    )\n                )\n\n            pbar = tqdm(\n                total=len(metric_refs),\n                desc=\"Metrics post-processing (all checkpoints)\",\n                unit=\"ckpt\",\n                dynamic_ncols=True,\n            )\n            
while metric_refs:\n                done, metric_refs = ray.wait(metric_refs, num_returns=1)\n                for ref in done:\n                    result = ray.get(ref)\n                    ckpt_stem = str(result.get(\"ckpt_stem\", \"\")).strip()\n                    if ckpt_stem == \"\":\n                        ckpt_stem = \"unknown\"\n                    if calc_per_clip_metrics:\n                        logger.info(\n                            f\"Metric calculation finished for {ckpt_stem}.\"\n                        )\n                    report_path = str(result.get(\"report_path\", \"\")).strip()\n                    if report_path != \"\":\n                        logger.info(\n                            f\"Generated metrics report for {ckpt_stem} at: {report_path}\"\n                        )\n                    exported_tsv = str(\n                        result.get(\"exported_summary_tsv\", \"\")\n                    ).strip()\n                    if exported_tsv != \"\":\n                        logger.info(f\"Exported summary TSV to: {exported_tsv}\")\n                pbar.update(1)\n            pbar.close()\n        else:\n            mean_process_5metrics = None\n            if generate_report:\n                from holomotion.scripts.evaluation import mean_process_5metrics\n\n            for target in postprocess_targets:\n                output_dir = target[\"output_dir\"]\n                output_dir_path = Path(output_dir)\n                ckpt_path = target[\"ckpt_path\"]\n\n                if calc_per_clip_metrics:\n                    logger.info(\n                        \"Starting metric calculation for \"\n                        f\"{ckpt_path.name}: {output_dir}\"\n                    )\n                    run_evaluation(\n                        npz_dir=output_dir,\n                        dataset_suffix=dataset_name,\n                        failure_pos_err_thresh_m=failure_pos_err_thresh_m,\n                        metric_calculation=metric_calculation,\n                        dof_mode=dof_mode,\n                        threadpool_max_workers=metrics_threadpool_max_workers,\n                    )\n                    logger.info(\n                        f\"Metric calculation finished for {ckpt_path.name}.\"\n                    )\n\n                if generate_report:\n                    report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir(\n                        output_dir\n                    )\n                    logger.info(\n                        f\"Generated metrics report for {ckpt_path.name} at: {report_path}\"\n                    )\n\n                if job_log_enabled:\n                    sub_dataset_tsv = (\n                        output_dir_path / \"sub_dataset_macro_mean_metrics.tsv\"\n                    )\n                    if sub_dataset_tsv.is_file():\n                        export_name = f\"{ckpt_path.stem}_sub_dataset_macro_mean_metrics.tsv\"\n                        export_path = job_log_dir / export_name\n                        shutil.copy2(sub_dataset_tsv, export_path)\n                        logger.info(f\"Exported summary TSV to: {export_path}\")\n                    else:\n                        logger.warning(\n                            \"Summary TSV not found (skip export): \"\n                            f\"{sub_dataset_tsv}\"\n                        )\n\n        _write_total_macro_summary_table(\n            eval_targets,\n            job_log_dir=job_log_dir if job_log_enabled else 
None,\n        )\n\n    else:\n        if config.get(\"model_type\", \"holomotion\") == \"sonic\":\n            from holomotion.src.evaluation.sonic_mujoco_sim2sim import (\n                SonicEvaluator,\n            )\n\n            evaluator = SonicEvaluator(config)\n        else:\n            evaluator = MujocoEvaluator(config)\n        evaluator.setup()\n        evaluator.run_simulation()\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"evaluation/eval_mujoco_sim2sim\",\n    version_base=None,\n)\ndef main(override_config: OmegaConf):\n    run_mujoco_sim2sim_eval(override_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/evaluation/eval_velocity_tracking.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\nfrom pathlib import Path\n\nimport hydra\nfrom hydra.utils import get_class\nfrom loguru import logger\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.utils.config import compile_config\nfrom holomotion.src.utils.onnx_export import export_policy_to_onnx\n\n\ndef load_training_config(\n    checkpoint_path: str, eval_config: OmegaConf\n) -> OmegaConf:\n    \"\"\"Load training config from checkpoint directory.\n\n    Args:\n        checkpoint_path: Path to the checkpoint file.\n        eval_config: Full evaluation config (including command line overrides).\n\n    Returns:\n        Merged config with training config as base.\n    \"\"\"\n    checkpoint = Path(checkpoint_path)\n    config_path = checkpoint.parent / \"config.yaml\"\n\n    if not config_path.exists():\n        config_path = checkpoint.parent.parent / \"config.yaml\"\n        if not config_path.exists():\n            logger.warning(\n                \"Training config not found at \"\n                f\"{config_path}, using evaluation config\"\n            )\n            return eval_config\n\n    logger.info(f\"Loading training config from {config_path}\")\n    with open(config_path) as file:\n        train_config = OmegaConf.load(file)\n\n    # Apply eval_overrides from training config if they exist\n    if train_config.get(\"eval_overrides\") is not None:\n        train_config = OmegaConf.merge(\n            train_config, train_config.eval_overrides\n        )\n\n    # Set checkpoint path\n    train_config.checkpoint = checkpoint_path\n    train_config.algo.config.checkpoint = checkpoint_path\n\n    # For evaluation, merge eval_config into train_config\n    config = OmegaConf.merge(train_config, eval_config)\n\n    # For velocity tracking, always keep the robot configuration from training\n    if hasattr(train_config, \"robot\"):\n        config.robot = train_config.robot\n\n    # foce set the terminations and domain rand with eval_config's\n    config.env.config.terminations = eval_config.env.config.terminations\n    config.env.config.domain_rand = eval_config.env.config.domain_rand\n    config.env.config.domain_rand = eval_config.env.config.domain_rand\n\n    return config\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"evaluation/eval_isaaclab\",\n    version_base=None,\n)\ndef main(config: OmegaConf):\n    \"\"\"Evaluate the velocity tracking model.\n\n    Args:\n        config: OmegaConf object containing the evaluation configuration.\n\n    \"\"\"\n    # Load training config first\n    if config.checkpoint is None:\n        raise ValueError(\"Checkpoint path must be provided for evaluation\")\n\n    config = load_training_config(config.checkpoint, config)\n    # Compile config without accelerator (PPO will create it)\n    config = compile_config(config, accelerator=None)\n\n    # Use checkpoint directory as log_dir for 
offline evaluation\n    log_dir = os.path.dirname(config.checkpoint)\n    headless = config.headless\n\n    # PPO creates Accelerator, AppLauncher, and environment internally\n    algo_class = get_class(config.algo._target_)\n    algo = algo_class(\n        env_config=config.env,\n        config=config.algo.config,\n        log_dir=log_dir,\n        headless=headless,\n        is_offline_eval=True,\n    )\n\n    if (\n        algo.accelerator.is_main_process\n        and os.environ.get(\"TORCH_COMPILE_DISABLE\", \"0\") != \"1\"\n    ):\n        logger.info(\n            \"Tip: set TORCH_COMPILE_DISABLE=1 if Triton/compile errors occur\"\n        )\n\n    if algo.accelerator.is_main_process:\n        eval_log_dir = os.path.dirname(config.checkpoint)\n        with open(os.path.join(eval_log_dir, \"eval_config.yaml\"), \"w\") as f:\n            OmegaConf.save(config, f)\n\n    if hasattr(config, \"checkpoint\") and config.checkpoint is not None:\n        if algo.accelerator.is_main_process:\n            logger.info(\n                f\"Loading checkpoint for evaluation: {config.checkpoint}\"\n            )\n        algo.load(config.checkpoint)\n    else:\n        if algo.accelerator.is_main_process:\n            logger.warning(\"No checkpoint provided for evaluation!\")\n\n    # Export ONNX if requested\n    if config.get(\"export_policy\", True):\n        if algo.accelerator.is_main_process:\n            onnx_name_suffix = config.get(\"onnx_name_suffix\", None)\n            onnx_path = export_policy_to_onnx(\n                algo,\n                config.checkpoint,\n                onnx_name_suffix=onnx_name_suffix,\n                use_kv_cache=config.get(\"use_kv_cache\", True),\n            )\n            logger.info(f\"Successfully exported policy to: {onnx_path}\")\n        algo.accelerator.wait_for_everyone()\n\n    # Run indefinite velocity tracking rollout for visualization\n    algo.offline_evaluate_velocity_tracking()\n    if algo.accelerator.is_main_process:\n        logger.info(\"Velocity tracking evaluation completed!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/evaluation/find_worst_clips.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport json\nimport math\nfrom pathlib import Path\nfrom typing import Dict, Any, List\n\n\nJSON_INPUT_FILE = \"logs/Holomotion/metrics_output_dataset/model_17500.json\"\ninput_path = Path(JSON_INPUT_FILE).expanduser().resolve()\nOUTPUT_JSON_FILE = str(input_path.parent / \"bad_clips.json\")\n\nWORST_PERCENTAGE = 0.2\n\nMETRICS_INFO: Dict[str, Dict[str, str]] = {\n    \"whole_body_joints_dist\": {\n        \"name\": \"Joint Angle Error (Whole Body Average)\",\n        \"unit\": \"rad\",\n        \"direction\": \"higher_is_worse\",\n    },\n}\n\n\ndef find_and_save_bad_clips(\n    data: Dict[str, Any],\n    metrics_info: Dict[str, Dict[str, str]],\n    percentage: float,\n    output_file: str,\n) -> None:\n    per_clip_data: List[Dict[str, Any]] = data.get(\"per_clip\", [])\n    if not per_clip_data:\n        print(\"Error: 'per_clip' not found in JSON data.\")\n        return\n\n    total_clips = len(per_clip_data)\n    num_to_select = math.ceil(total_clips * percentage)\n    if num_to_select == 0 and total_clips > 0:\n        num_to_select = 1\n\n    bad_clips_report: Dict[str, List[Dict[str, Any]]] = {}\n\n    for key, info in metrics_info.items():\n        direction = info.get(\"direction\")\n        if not direction:\n            continue\n\n        sort_descending = direction == \"higher_is_worse\"\n\n        clips_with_metric_value = [\n            {\"motion_key\": clip[\"motion_key\"], \"value\": clip[key]}\n            for clip in per_clip_data\n            if key in clip and \"motion_key\" in clip\n        ]\n\n        if not clips_with_metric_value:\n            print(f\"Warning: no values found for metric '{key}' in data.\")\n            continue\n\n        sorted_clips = sorted(\n            clips_with_metric_value,\n            key=lambda x: x[\"value\"],\n            reverse=sort_descending,\n        )\n\n        worst_clips = sorted_clips[:num_to_select]\n        bad_clips_report[key] = worst_clips\n\n    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n        json.dump(bad_clips_report, f, indent=4, ensure_ascii=False)\n    print(f\"Saved bad-clips report to: {output_file}\")\n\n\ndef main() -> None:\n    if not Path(JSON_INPUT_FILE).is_file():\n        print(f\"Error: JSON input file '{JSON_INPUT_FILE}' not found.\")\n        return\n\n    with open(JSON_INPUT_FILE, \"r\", encoding=\"utf-8\") as f:\n        data = json.load(f)\n\n    find_and_save_bad_clips(\n        data, METRICS_INFO, WORST_PERCENTAGE, OUTPUT_JSON_FILE\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/evaluation/metrics.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom pathlib import Path\nfrom typing import Dict, List, Optional\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\n\nimport argparse\nimport csv\nimport json\nimport os\nimport re\nfrom glob import glob\nfrom zipfile import BadZipFile\n\nimport numpy as np\nimport pandas as pd\nfrom loguru import logger\nfrom scipy.signal import welch\nfrom scipy.spatial.transform import Rotation as sRot\nfrom tabulate import tabulate\nfrom tqdm import tqdm\n\n\nDEFAULT_ROBOT_CONTROL_DT = 1.0 / 50.0\nTORQUE_JUMP_RATIO_EPS = 1e-6\nMIN_WELCH_SAMPLES = 8\nSTABILITY_BURST_WINDOW_SECONDS = 0.5\nTOUCHDOWN_WINDOW_SECONDS = 0.05\nROOT_BODY_INDEX = 0\nPROBABILITY_EPS = 1e-12\n\n\ndef quat_inv(q):\n    return np.concatenate([-q[..., :3], q[..., 3:4]], axis=-1)\n\n\ndef quat_apply(q, v):\n    q = np.asarray(q, dtype=np.float64)\n    v = np.asarray(v, dtype=np.float64)\n\n    xyz = q[:, None, :3]\n    w = q[:, None, 3:4]\n\n    t = 2.0 * np.cross(xyz, v, axis=-1)\n    return v + w * t + np.cross(xyz, t, axis=-1)\n\n\ndef p_mpjpe(predicted: np.ndarray, target: np.ndarray) -> np.ndarray:\n    \"\"\"Compute Procrustes-aligned MPJPE between predicted and ground truth.\n\n    Reference:\n        This function is inspired by and partially adapted from the SMPLSim:\n        https://github.com/ZhengyiLuo/SMPLSim/blob/0d672790a7672f28361d59dadd98ae2fc1b9685e/smpl_sim/smpllib/smpl_eval.py.\n\n    \"\"\"\n    assert predicted.shape == target.shape\n\n    mu_x = np.mean(target, axis=1, keepdims=True)\n    mu_y = np.mean(predicted, axis=1, keepdims=True)\n\n    x0 = target - mu_x\n    y0 = predicted - mu_y\n\n    norm_x = np.sqrt(np.sum(x0**2, axis=(1, 2), keepdims=True))\n    norm_y = np.sqrt(np.sum(y0**2, axis=(1, 2), keepdims=True))\n\n    x0 /= norm_x\n    y0 /= norm_y\n\n    h = np.matmul(x0.transpose(0, 2, 1), y0)\n    # Per-frame SVD with graceful handling for non-convergence: mark those frames as NaN\n    batch_size = int(h.shape[0])\n    jdim = int(h.shape[1])\n    u = np.empty((batch_size, jdim, jdim), dtype=h.dtype)\n    s = np.empty((batch_size, jdim), dtype=h.dtype)\n    vt = np.empty((batch_size, jdim, jdim), dtype=h.dtype)\n    for i in range(batch_size):\n        try:\n            ui, si, vti = np.linalg.svd(h[i])\n            u[i] = ui\n            s[i] = si\n            vt[i] = vti\n        except np.linalg.LinAlgError:\n            u[i].fill(np.nan)\n            s[i].fill(np.nan)\n            vt[i].fill(np.nan)\n    v = vt.transpose(0, 2, 1)\n    r = np.matmul(v, u.transpose(0, 2, 1))\n\n    # Avoid improper rotations (reflections), i.e. 
rotations with det(R) = -1\n    sign_det_r = np.sign(np.expand_dims(np.linalg.det(r), axis=1))\n    v[:, :, -1] *= sign_det_r\n    s[:, -1] *= sign_det_r.flatten()\n    r = np.matmul(v, u.transpose(0, 2, 1))  # Corrected rotation\n\n    tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)\n\n    a = tr * norm_x / norm_y  # Scale\n    t = mu_x - a * np.matmul(mu_y, r)  # Translation\n\n    predicted_aligned = a * np.matmul(predicted, r) + t\n\n    return np.linalg.norm(\n        predicted_aligned - target, axis=len(target.shape) - 1\n    )\n\n\ndef _parse_clip_len_from_name(filename: str) -> Optional[int]:\n    \"\"\"Extract clip length from filename suffix '__start_XXX_len_N'.\"\"\"\n    m = re.search(r\"__start_\\d+_len_(\\d+)\", os.path.basename(filename))\n    return int(m.group(1)) if m else None\n\n\ndef _parse_metadata_entry(raw_metadata) -> Dict[str, object]:\n    if raw_metadata is None:\n        return {}\n\n    parsed = raw_metadata\n    if isinstance(parsed, np.ndarray):\n        if parsed.shape != ():\n            return {}\n        parsed = parsed.item()\n\n    if isinstance(parsed, dict):\n        return parsed\n\n    if isinstance(parsed, bytes):\n        parsed = parsed.decode(\"utf-8\")\n\n    if isinstance(parsed, str):\n        try:\n            obj = json.loads(parsed)\n        except json.JSONDecodeError:\n            return {}\n        return obj if isinstance(obj, dict) else {}\n\n    return {}\n\n\ndef _extract_robot_control_dt(\n    metadata: Dict[str, object], raw_data: Dict[str, np.ndarray]\n) -> float:\n    if \"robot_low_level_torque_dt\" in raw_data:\n        raw_dt = np.asarray(raw_data[\"robot_low_level_torque_dt\"]).item()\n    else:\n        raw_dt = metadata.get(\n            \"robot_low_level_torque_dt\",\n            metadata.get(\"robot_control_dt\", DEFAULT_ROBOT_CONTROL_DT),\n        )\n    try:\n        robot_control_dt = float(raw_dt)\n    except (TypeError, ValueError):\n        return DEFAULT_ROBOT_CONTROL_DT\n\n    if not np.isfinite(robot_control_dt) or robot_control_dt <= 0.0:\n        return DEFAULT_ROBOT_CONTROL_DT\n    return robot_control_dt\n\n\ndef _extract_low_level_contact_dt(\n    metadata: Dict[str, object],\n    raw_data: Dict[str, np.ndarray],\n    robot_control_dt: float,\n) -> float:\n    if \"robot_low_level_contact_dt\" in raw_data:\n        raw_dt = np.asarray(raw_data[\"robot_low_level_contact_dt\"]).item()\n    else:\n        raw_dt = metadata.get(\n            \"robot_low_level_contact_dt\",\n            metadata.get(\n                \"robot_low_level_torque_dt\",\n                metadata.get(\"robot_control_dt\", robot_control_dt),\n            ),\n        )\n    try:\n        contact_dt = float(raw_dt)\n    except (TypeError, ValueError):\n        return robot_control_dt\n\n    if not np.isfinite(contact_dt) or contact_dt <= 0.0:\n        return robot_control_dt\n    return contact_dt\n\n\ndef _aggregate_sample_metric_to_frames(\n    sample_metric: np.ndarray, num_frames: int\n) -> np.ndarray:\n    if int(sample_metric.shape[0]) == num_frames:\n        return sample_metric.astype(float, copy=False)\n    if num_frames <= 0:\n        return np.empty((0,), dtype=float)\n\n    aggregated = np.full((num_frames,), np.nan, dtype=float)\n    for frame_idx, chunk in enumerate(\n        np.array_split(sample_metric, num_frames)\n    ):\n        if chunk.size == 0:\n            continue\n        if np.all(np.isnan(chunk)):\n            continue\n        aggregated[frame_idx] = float(np.nanmean(chunk))\n    return 
aggregated\n\n\ndef _compute_torque_jump_series(\n    torque_samples: np.ndarray, torque_dt: float\n) -> tuple[np.ndarray, np.ndarray]:\n    num_samples = int(torque_samples.shape[0])\n    torque_jump_norm = np.full((num_samples,), np.nan, dtype=float)\n    torque_jump_ratio = np.full((num_samples,), np.nan, dtype=float)\n    if num_samples <= 1:\n        return torque_jump_norm, torque_jump_ratio\n\n    torque_mag = np.linalg.norm(torque_samples, axis=1)\n    torque_delta_norm = np.linalg.norm(\n        torque_samples[1:] - torque_samples[:-1], axis=1\n    )\n    torque_jump_norm[1:] = torque_delta_norm / torque_dt\n    torque_scale = np.maximum(\n        np.maximum(torque_mag[1:], torque_mag[:-1]), TORQUE_JUMP_RATIO_EPS\n    )\n    torque_jump_ratio[1:] = torque_delta_norm / torque_scale\n    return torque_jump_norm, torque_jump_ratio\n\n\ndef _safe_nanpercentile(values: np.ndarray, q: float) -> float:\n    arr = np.asarray(values, dtype=float).reshape(-1)\n    arr = arr[np.isfinite(arr)]\n    if arr.size == 0:\n        return float(\"nan\")\n    return float(np.nanpercentile(arr, q))\n\n\ndef _safe_nanmean(values: np.ndarray) -> float:\n    arr = np.asarray(values, dtype=float).reshape(-1)\n    arr = arr[np.isfinite(arr)]\n    if arr.size == 0:\n        return float(\"nan\")\n    return float(np.mean(arr))\n\n\ndef _safe_nanmedian(values: np.ndarray) -> float:\n    arr = np.asarray(values, dtype=float).reshape(-1)\n    arr = arr[np.isfinite(arr)]\n    if arr.size == 0:\n        return float(\"nan\")\n    return float(np.median(arr))\n\n\ndef _compute_rolling_nanmean_max(\n    values: np.ndarray, window_size: int\n) -> float:\n    arr = np.asarray(values, dtype=float).reshape(-1)\n    if arr.size == 0:\n        return float(\"nan\")\n    if window_size <= 1:\n        return float(np.nanmax(arr))\n\n    best = float(\"nan\")\n    max_start = int(arr.size) - int(window_size) + 1\n    if max_start <= 0:\n        if np.all(np.isnan(arr)):\n            return float(\"nan\")\n        return float(np.nanmean(arr))\n\n    for start in range(max_start):\n        window = arr[start : start + window_size]\n        if np.all(np.isnan(window)):\n            continue\n        mean_value = float(np.nanmean(window))\n        if np.isnan(best) or mean_value > best:\n            best = mean_value\n    return best\n\n\ndef _integrate_psd_band(\n    frequencies: np.ndarray,\n    power_density: np.ndarray,\n    low_hz: float,\n    high_hz: float,\n) -> float:\n    if (\n        not np.isfinite(low_hz)\n        or not np.isfinite(high_hz)\n        or high_hz <= low_hz\n    ):\n        return float(\"nan\")\n    band_mask = (frequencies >= low_hz) & (frequencies <= high_hz)\n    if not np.any(band_mask):\n        return float(\"nan\")\n    band_freq = frequencies[band_mask]\n    band_power = power_density[band_mask]\n    if band_freq.size == 1:\n        return float(band_power[0])\n    return float(np.trapz(band_power, band_freq))\n\n\ndef _compute_psd_high_frequency_ratio(\n    signal_values: np.ndarray,\n    sample_dt: float,\n    *,\n    high_band_low_hz: float,\n    band_high_hz: float,\n    band_low_hz: float = 0.5,\n) -> float:\n    samples = np.asarray(signal_values, dtype=float).reshape(-1)\n    samples = samples[np.isfinite(samples)]\n    if samples.size < MIN_WELCH_SAMPLES:\n        return float(\"nan\")\n\n    sample_rate_hz = 1.0 / float(sample_dt)\n    max_band_hz = min(float(band_high_hz), 0.45 * sample_rate_hz)\n    if max_band_hz <= max(float(band_low_hz), float(high_band_low_hz)):\n        
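# Sampling rate too low: the capped analysis band leaves no room above the high-frequency cutoff.\n        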
return float(\"nan\")\n\n    nperseg = min(int(samples.size), 256)\n    frequencies, power_density = welch(\n        samples,\n        fs=sample_rate_hz,\n        nperseg=nperseg,\n        detrend=\"constant\",\n        average=\"mean\",\n    )\n    total_power = _integrate_psd_band(\n        frequencies, power_density, float(band_low_hz), max_band_hz\n    )\n    high_power = _integrate_psd_band(\n        frequencies, power_density, float(high_band_low_hz), max_band_hz\n    )\n    if (\n        not np.isfinite(total_power)\n        or total_power <= 0.0\n        or not np.isfinite(high_power)\n    ):\n        return float(\"nan\")\n    return float(high_power / total_power)\n\n\ndef _compute_torque_chatter_hf_ratio(\n    low_level_torque: np.ndarray, low_level_dt: float\n) -> float:\n    torque_samples = np.asarray(low_level_torque, dtype=float)\n    if torque_samples.ndim != 2 or torque_samples.shape[0] < MIN_WELCH_SAMPLES:\n        return float(\"nan\")\n\n    ratios = []\n    for joint_idx in range(int(torque_samples.shape[1])):\n        ratio = _compute_psd_high_frequency_ratio(\n            torque_samples[:, joint_idx],\n            low_level_dt,\n            high_band_low_hz=10.0,\n            band_high_hz=40.0,\n        )\n        if np.isfinite(ratio):\n            ratios.append(ratio)\n    if len(ratios) == 0:\n        return float(\"nan\")\n    return float(np.mean(ratios))\n\n\ndef _compute_torso_roll_pitch_stability_metrics(\n    robot_global_angular_velocity: np.ndarray,\n    robot_control_dt: float,\n) -> Dict[str, float]:\n    angular_velocity = np.asarray(robot_global_angular_velocity, dtype=float)\n    if angular_velocity.ndim != 3 or angular_velocity.shape[0] == 0:\n        return {\n            \"torso_rp_hf_ratio\": float(\"nan\"),\n            \"torso_rp_angacc_p95\": float(\"nan\"),\n        }\n\n    torso_roll_pitch_vel = angular_velocity[:, ROOT_BODY_INDEX, :2]\n    torso_roll_pitch_speed = np.linalg.norm(torso_roll_pitch_vel, axis=1)\n    hf_ratio = _compute_psd_high_frequency_ratio(\n        torso_roll_pitch_speed,\n        robot_control_dt,\n        high_band_low_hz=5.0,\n        band_high_hz=20.0,\n    )\n\n    if torso_roll_pitch_vel.shape[0] <= 1:\n        angacc_p95 = float(\"nan\")\n    else:\n        roll_pitch_angacc = np.diff(torso_roll_pitch_vel, axis=0) / float(\n            robot_control_dt\n        )\n        roll_pitch_angacc_mag = np.linalg.norm(roll_pitch_angacc, axis=1)\n        angacc_p95 = _safe_nanpercentile(roll_pitch_angacc_mag, 95.0)\n\n    return {\n        \"torso_rp_hf_ratio\": hf_ratio,\n        \"torso_rp_angacc_p95\": angacc_p95,\n    }\n\n\ndef _compute_expert_switching_js_div(\n    robot_moe_expert_logits: np.ndarray | None,\n) -> float:\n    if robot_moe_expert_logits is None:\n        return float(\"nan\")\n\n    logits = np.asarray(robot_moe_expert_logits, dtype=float)\n    if logits.ndim != 3 or logits.shape[0] <= 1 or logits.shape[-1] <= 1:\n        return float(\"nan\")\n\n    if not np.all(np.isfinite(logits)):\n        return float(\"nan\")\n\n    shifted_logits = logits - np.max(logits, axis=-1, keepdims=True)\n    probs = np.exp(shifted_logits)\n    probs /= np.sum(probs, axis=-1, keepdims=True)\n\n    prev_probs = np.clip(probs[:-1], PROBABILITY_EPS, 1.0)\n    next_probs = np.clip(probs[1:], PROBABILITY_EPS, 1.0)\n    mixture = 0.5 * (prev_probs + next_probs)\n\n    kl_prev = np.sum(\n        prev_probs * (np.log(prev_probs) - np.log(mixture)), axis=-1\n    )\n    kl_next = np.sum(\n        next_probs * (np.log(next_probs) 
- np.log(mixture)), axis=-1\n    )\n    js_divergence = 0.5 * (kl_prev + kl_next) / np.log(2.0)\n    return _safe_nanmean(js_divergence)\n\n\ndef _compute_contact_stability_metrics(\n    foot_contact_samples: np.ndarray | None,\n    foot_normal_force_samples: np.ndarray | None,\n    foot_tangent_speed_samples: np.ndarray | None,\n    contact_dt: float,\n) -> Dict[str, float]:\n    metrics = {\n        \"foot_contact_toggle_rate\": float(\"nan\"),\n        \"foot_impact_force_p95\": float(\"nan\"),\n        \"stance_slip_speed_p95\": float(\"nan\"),\n    }\n    if (\n        foot_contact_samples is None\n        or foot_normal_force_samples is None\n        or foot_tangent_speed_samples is None\n    ):\n        return metrics\n\n    contact = np.asarray(foot_contact_samples, dtype=float)\n    normal_force = np.asarray(foot_normal_force_samples, dtype=float)\n    tangent_speed = np.asarray(foot_tangent_speed_samples, dtype=float)\n    if (\n        contact.shape != normal_force.shape\n        or contact.shape != tangent_speed.shape\n        or contact.ndim != 2\n        or contact.shape[1] != 2\n    ):\n        return metrics\n\n    finite_contact = np.isfinite(contact)\n    if not np.any(finite_contact):\n        return metrics\n\n    contact_binary = np.where(contact >= 0.5, 1.0, 0.0)\n    valid_pair_mask = finite_contact[1:] & finite_contact[:-1]\n    toggle_count = int(\n        np.sum(\n            np.abs(contact_binary[1:] - contact_binary[:-1]) * valid_pair_mask\n        )\n    )\n    clip_duration_seconds = float(contact.shape[0]) * float(contact_dt)\n    if clip_duration_seconds > 0.0:\n        metrics[\"foot_contact_toggle_rate\"] = (\n            float(toggle_count) / clip_duration_seconds\n        )\n\n    touchdown_window = max(\n        1, int(round(TOUCHDOWN_WINDOW_SECONDS / float(contact_dt)))\n    )\n    touchdown_peaks = []\n    for foot_idx in range(2):\n        foot_contact = contact_binary[:, foot_idx]\n        foot_force = normal_force[:, foot_idx]\n        onset_mask = np.zeros_like(foot_contact, dtype=bool)\n        onset_mask[0] = foot_contact[0] >= 0.5\n        onset_mask[1:] = (foot_contact[1:] >= 0.5) & (foot_contact[:-1] < 0.5)\n        for onset_idx in np.flatnonzero(onset_mask):\n            window = foot_force[onset_idx : onset_idx + touchdown_window]\n            if window.size == 0 or np.all(~np.isfinite(window)):\n                continue\n            touchdown_peaks.append(float(np.nanmax(window)))\n    metrics[\"foot_impact_force_p95\"] = _safe_nanpercentile(\n        np.asarray(touchdown_peaks, dtype=float), 95.0\n    )\n\n    stance_slip_mask = (contact_binary >= 0.5) & np.isfinite(tangent_speed)\n    if np.any(stance_slip_mask):\n        metrics[\"stance_slip_speed_p95\"] = _safe_nanpercentile(\n            tangent_speed[stance_slip_mask], 95.0\n        )\n    return metrics\n\n\ndef _compute_clip_stability_summary(\n    data: Dict[str, np.ndarray],\n    robot_control_dt: float,\n    low_level_contact_dt: float,\n) -> Dict[str, float]:\n    robot_low_level_dof_torque = (\n        np.asarray(data[\"robot_low_level_dof_torque\"])\n        if \"robot_low_level_dof_torque\" in data\n        else None\n    )\n    if robot_low_level_dof_torque is None and \"robot_dof_torque\" in data:\n        robot_low_level_dof_torque = np.asarray(data[\"robot_dof_torque\"])\n\n    if robot_low_level_dof_torque is None:\n        torque_chatter_hf_ratio = float(\"nan\")\n        torque_jump_burst_max = float(\"nan\")\n    else:\n        torque_chatter_hf_ratio = 
_compute_torque_chatter_hf_ratio(\n            robot_low_level_dof_torque, low_level_contact_dt\n        )\n        _, torque_jump_ratio = _compute_torque_jump_series(\n            robot_low_level_dof_torque, low_level_contact_dt\n        )\n        torque_jump_window = max(\n            1,\n            int(\n                round(\n                    STABILITY_BURST_WINDOW_SECONDS\n                    / float(low_level_contact_dt)\n                )\n            ),\n        )\n        torque_jump_burst_max = _compute_rolling_nanmean_max(\n            torque_jump_ratio[1:], torque_jump_window\n        )\n\n    torso_metrics = _compute_torso_roll_pitch_stability_metrics(\n        np.asarray(data[\"robot_global_angular_velocity\"]),\n        robot_control_dt,\n    )\n    contact_metrics = _compute_contact_stability_metrics(\n        np.asarray(data[\"robot_low_level_foot_contact\"])\n        if \"robot_low_level_foot_contact\" in data\n        else None,\n        np.asarray(data[\"robot_low_level_foot_normal_force\"])\n        if \"robot_low_level_foot_normal_force\" in data\n        else None,\n        np.asarray(data[\"robot_low_level_foot_tangent_speed\"])\n        if \"robot_low_level_foot_tangent_speed\" in data\n        else None,\n        low_level_contact_dt,\n    )\n    expert_switching_js_div = _compute_expert_switching_js_div(\n        np.asarray(data[\"robot_moe_expert_logits\"])\n        if \"robot_moe_expert_logits\" in data\n        else None\n    )\n    return {\n        \"torque_chatter_hf_ratio\": torque_chatter_hf_ratio,\n        \"torque_jump_burst_max\": torque_jump_burst_max,\n        \"expert_switching_js_div\": expert_switching_js_div,\n        **torso_metrics,\n        **contact_metrics,\n    }\n\n\ndef _compute_clip_torque_jump_summary(\n    data: Dict[str, np.ndarray],\n    dof_mode: str,\n    torque_dt: float,\n) -> Dict[str, float]:\n    robot_dof_torque = (\n        np.asarray(data[\"robot_dof_torque\"])\n        if \"robot_dof_torque\" in data\n        else None\n    )\n    robot_low_level_dof_torque = (\n        np.asarray(data[\"robot_low_level_dof_torque\"])\n        if \"robot_low_level_dof_torque\" in data\n        else None\n    )\n\n    if dof_mode == \"23\" and robot_dof_torque is not None:\n        total_dofs_in_file = int(robot_dof_torque.shape[1])\n        if total_dofs_in_file == 29:\n            idx_23_in_29_dof = list(range(19)) + list(range(22, 26))\n            robot_dof_torque = robot_dof_torque[:, idx_23_in_29_dof]\n            if (\n                robot_low_level_dof_torque is not None\n                and int(robot_low_level_dof_torque.shape[1])\n                == total_dofs_in_file\n            ):\n                robot_low_level_dof_torque = robot_low_level_dof_torque[\n                    :, idx_23_in_29_dof\n                ]\n\n    chatter_torque = robot_low_level_dof_torque\n    if chatter_torque is None:\n        chatter_torque = robot_dof_torque\n\n    if chatter_torque is None or int(chatter_torque.shape[0]) <= 1:\n        return {\n            \"mean_torque_jump_norm\": float(\"nan\"),\n            \"p95_torque_jump_norm\": float(\"nan\"),\n            \"mean_torque_jump_ratio\": float(\"nan\"),\n            \"p95_torque_jump_ratio\": float(\"nan\"),\n        }\n\n    torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series(\n        chatter_torque, torque_dt\n    )\n    return {\n        \"mean_torque_jump_norm\": float(np.nanmean(torque_jump_norm)),\n        \"p95_torque_jump_norm\": float(\n            
np.nanpercentile(torque_jump_norm[1:], 95)\n        ),\n        \"mean_torque_jump_ratio\": float(np.nanmean(torque_jump_ratio)),\n        \"p95_torque_jump_ratio\": float(\n            np.nanpercentile(torque_jump_ratio[1:], 95)\n        ),\n    }\n\n\ndef _per_frame_metrics_from_npz(\n    motion_key: str,\n    data: Dict[str, np.ndarray],\n    dof_mode: str = \"29\",\n    robot_control_dt: float = DEFAULT_ROBOT_CONTROL_DT,\n) -> pd.DataFrame:\n    \"\"\"Compute per-frame metrics for a single motion clip from loaded npz arrays.\n\n    Expects the following keys in `data` (URDF order):\n    - ref_dof_pos, robot_dof_pos\n    - ref_global_translation, robot_global_translation\n    - ref_global_rotation_quat, robot_global_rotation_quat (xyzw)\n    \"\"\"\n    # Required arrays\n    jpos_gt = np.asarray(data[\"ref_global_translation\"])  # (T, J, 3)\n    jpos_pred = np.asarray(data[\"robot_global_translation\"])  # (T, J, 3)\n    rot_gt = np.asarray(data[\"ref_global_rotation_quat\"])  # (T, J, 4) xyzw\n    rot_pred = np.asarray(data[\"robot_global_rotation_quat\"])  # (T, J, 4)\n    dof_gt = np.asarray(data[\"ref_dof_pos\"])  # (T, D)\n    dof_pred = np.asarray(data[\"robot_dof_pos\"])  # (T, D)\n    robot_dof_vel = (\n        np.asarray(data[\"robot_dof_vel\"]) if \"robot_dof_vel\" in data else None\n    )\n    robot_dof_acc = (\n        np.asarray(data[\"robot_dof_acc\"]) if \"robot_dof_acc\" in data else None\n    )\n    robot_dof_torque = (\n        np.asarray(data[\"robot_dof_torque\"])\n        if \"robot_dof_torque\" in data\n        else None\n    )\n    robot_low_level_dof_torque = (\n        np.asarray(data[\"robot_low_level_dof_torque\"])\n        if \"robot_low_level_dof_torque\" in data\n        else None\n    )\n    robot_action_rate = (\n        np.asarray(data[\"robot_action_rate\"])\n        if \"robot_action_rate\" in data\n        else None\n    )\n\n    total_dofs_in_file = int(dof_gt.shape[1])\n    IDX_23_IN_29_DOF = list(range(19)) + list(range(22, 26))\n    IDX_23_IN_29_BODY = [0] + [i + 1 for i in IDX_23_IN_29_DOF]\n\n    if dof_mode == \"23\":\n        if total_dofs_in_file == 29:\n            dof_gt = dof_gt[:, IDX_23_IN_29_DOF]\n            dof_pred = dof_pred[:, IDX_23_IN_29_DOF]\n            if (\n                robot_dof_vel is not None\n                and int(robot_dof_vel.shape[1]) == total_dofs_in_file\n            ):\n                robot_dof_vel = robot_dof_vel[:, IDX_23_IN_29_DOF]\n            if (\n                robot_dof_acc is not None\n                and int(robot_dof_acc.shape[1]) == total_dofs_in_file\n            ):\n                robot_dof_acc = robot_dof_acc[:, IDX_23_IN_29_DOF]\n            if (\n                robot_dof_torque is not None\n                and int(robot_dof_torque.shape[1]) == total_dofs_in_file\n            ):\n                robot_dof_torque = robot_dof_torque[:, IDX_23_IN_29_DOF]\n            if (\n                robot_low_level_dof_torque is not None\n                and int(robot_low_level_dof_torque.shape[1])\n                == total_dofs_in_file\n            ):\n                robot_low_level_dof_torque = robot_low_level_dof_torque[\n                    :, IDX_23_IN_29_DOF\n                ]\n\n            jpos_gt = jpos_gt[:, IDX_23_IN_29_BODY, :]\n            jpos_pred = jpos_pred[:, IDX_23_IN_29_BODY, :]\n\n            rot_gt = rot_gt[:, IDX_23_IN_29_BODY, :]\n            rot_pred = rot_pred[:, IDX_23_IN_29_BODY, :]\n\n    assert jpos_gt.shape == jpos_pred.shape\n    assert rot_gt.shape == rot_pred.shape\n    assert 
dof_gt.shape == dof_pred.shape\n\n    num_frames = int(jpos_gt.shape[0])\n\n    # Global MPJPE [mm]\n    mpjpe_g = (\n        np.mean(np.linalg.norm(jpos_gt - jpos_pred, axis=2), axis=1) * 1000.0\n    )\n\n    # Per-frame failure signal [m]: currently the absolute root-height error.\n    # The max body-link position error variant is kept commented for reference:\n    # per_joint_err = np.linalg.norm(jpos_pred - jpos_gt, axis=2)\n    # frame_max_body_pos_err = np.max(per_joint_err, axis=1)\n    frame_max_body_pos_err = np.abs(jpos_pred[:, 0, 2] - jpos_gt[:, 0, 2])\n\n    # Localize by root (index 0)\n    jpos_gt_local = jpos_gt - jpos_gt[:, [0]]\n    jpos_pred_local = jpos_pred - jpos_pred[:, [0]]\n    ref_body_pos_root_rel = quat_apply(\n        quat_inv(rot_gt[:, 0, :]),\n        jpos_gt - jpos_gt[:, [0]],\n    )\n    robot_body_pos_root_rel = quat_apply(\n        quat_inv(rot_pred[:, 0, :]),\n        jpos_pred - jpos_pred[:, [0]],\n    )\n\n    mpjpe_l = (\n        np.mean(\n            np.linalg.norm(\n                robot_body_pos_root_rel - ref_body_pos_root_rel, axis=2\n            ),\n            axis=1,\n        )\n        * 1000.0\n    )\n\n    # Procrustes-aligned MPJPE [mm]\n    pa_per_joint = p_mpjpe(jpos_pred_local, jpos_gt_local)\n    mpjpe_pa = np.mean(pa_per_joint, axis=1) * 1000.0\n\n    # Velocity/acceleration errors from positions (discrete frame diffs) [mm/frame],[mm/frame^2]\n    vel_gt = jpos_gt[1:] - jpos_gt[:-1]\n    vel_pred = jpos_pred[1:] - jpos_pred[:-1]\n    vel_dist = (\n        np.mean(np.linalg.norm(vel_pred - vel_gt, axis=2), axis=1) * 1000.0\n    )\n\n    acc_gt = jpos_gt[:-2] - 2 * jpos_gt[1:-1] + jpos_gt[2:]\n    acc_pred = jpos_pred[:-2] - 2 * jpos_pred[1:-1] + jpos_pred[2:]\n    accel_dist = (\n        np.mean(np.linalg.norm(acc_pred - acc_gt, axis=2), axis=1) * 1000.0\n    )\n\n    # DOF angle errors [radians] — whole body average\n    dof_err = np.abs(dof_pred - dof_gt)\n    whole_body_joints_dist = np.mean(dof_err, axis=1)\n\n    # Root orientation errors [radians] — handle zero-norm/invalid quaternions by NaN\n    q_gt_root = rot_gt[:, 0, :]\n    q_pred_root = rot_pred[:, 0, :]\n    norms_gt = np.linalg.norm(q_gt_root, axis=1)\n    norms_pred = np.linalg.norm(q_pred_root, axis=1)\n    valid_mask = (\n        (norms_gt > 0.0)\n        & (norms_pred > 0.0)\n        & np.isfinite(norms_gt)\n        & np.isfinite(norms_pred)\n    )\n\n    root_r_error = np.full((num_frames,), np.nan, dtype=float)\n    root_p_error = np.full((num_frames,), np.nan, dtype=float)\n    root_y_error = np.full((num_frames,), np.nan, dtype=float)\n\n    if np.any(valid_mask):\n        q_gt_valid = q_gt_root[valid_mask] / norms_gt[valid_mask, None]\n        q_pred_valid = q_pred_root[valid_mask] / norms_pred[valid_mask, None]\n        rel_valid = sRot.from_quat(q_gt_valid).inv() * sRot.from_quat(\n            q_pred_valid\n        )\n        euler_xyz = rel_valid.as_euler(\"xyz\", degrees=False)\n        root_r_error[valid_mask] = np.abs(euler_xyz[:, 0])\n        root_p_error[valid_mask] = np.abs(euler_xyz[:, 1])\n        root_y_error[valid_mask] = np.abs(euler_xyz[:, 2])\n\n    # Root velocity error [m/frame]\n    root_pos_gt = jpos_gt[:, 0, :]\n    root_pos_pred = jpos_pred[:, 0, :]\n    root_vel_err = np.linalg.norm(\n        (root_pos_pred[1:] - root_pos_pred[:-1])\n        - (root_pos_gt[1:] - root_pos_gt[:-1]),\n        axis=1,\n    )\n\n    # Root height error [m]\n    root_height_error = np.abs(root_pos_pred[:, 2] - root_pos_gt[:, 2])\n\n    # Robot low-level magnitudes (optional)\n    mean_dof_vel = np.full((num_frames,), np.nan, 
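# NaN placeholder; filled below only when the optional key is present\n        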
dtype=float)\n    if robot_dof_vel is not None:\n        if int(robot_dof_vel.shape[0]) != num_frames:\n            raise ValueError(\n                \"robot_dof_vel frame length mismatch: \"\n                f\"{robot_dof_vel.shape[0]} vs {num_frames}\"\n            )\n        mean_dof_vel = np.linalg.norm(robot_dof_vel, axis=1)\n\n    mean_dof_acc = np.full((num_frames,), np.nan, dtype=float)\n    if robot_dof_acc is not None:\n        if int(robot_dof_acc.shape[0]) != num_frames:\n            raise ValueError(\n                \"robot_dof_acc frame length mismatch: \"\n                f\"{robot_dof_acc.shape[0]} vs {num_frames}\"\n            )\n        mean_dof_acc = np.linalg.norm(robot_dof_acc, axis=1)\n\n    mean_dof_torque = np.full((num_frames,), np.nan, dtype=float)\n    mean_torque_jump_norm = np.full((num_frames,), np.nan, dtype=float)\n    mean_torque_jump_ratio = np.full((num_frames,), np.nan, dtype=float)\n    if robot_dof_torque is not None:\n        if int(robot_dof_torque.shape[0]) != num_frames:\n            raise ValueError(\n                \"robot_dof_torque frame length mismatch: \"\n                f\"{robot_dof_torque.shape[0]} vs {num_frames}\"\n            )\n        mean_dof_torque = np.linalg.norm(robot_dof_torque, axis=1)\n    chatter_torque = robot_low_level_dof_torque\n    if chatter_torque is None:\n        chatter_torque = robot_dof_torque\n    if chatter_torque is not None and int(chatter_torque.shape[0]) > 1:\n        torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series(\n            chatter_torque, robot_control_dt\n        )\n        mean_torque_jump_norm = _aggregate_sample_metric_to_frames(\n            torque_jump_norm, num_frames\n        )\n        mean_torque_jump_ratio = _aggregate_sample_metric_to_frames(\n            torque_jump_ratio, num_frames\n        )\n\n    mean_action_rate = np.full((num_frames,), np.nan, dtype=float)\n    if robot_action_rate is not None:\n        flat_action_rate = robot_action_rate.reshape(-1)\n        if int(flat_action_rate.shape[0]) != num_frames:\n            raise ValueError(\n                \"robot_action_rate frame length mismatch: \"\n                f\"{flat_action_rate.shape[0]} vs {num_frames}\"\n            )\n        mean_action_rate = flat_action_rate\n\n    # Frame DataFrame (align lengths by padding NaN at the start where needed)\n    def pad_front(x: np.ndarray, pad: int) -> np.ndarray:\n        if pad <= 0:\n            return x\n        return np.concatenate(\n            [np.full((pad,), np.nan, dtype=float), x], axis=0\n        )\n\n    df = pd.DataFrame(\n        {\n            \"motion_key\": [motion_key] * num_frames,\n            \"frame_idx\": np.arange(num_frames, dtype=int),\n            \"mpjpe_g\": mpjpe_g,\n            \"mpjpe_l\": mpjpe_l,\n            \"mpjpe_pa\": mpjpe_pa,\n            \"vel_dist\": pad_front(vel_dist, 1),\n            \"accel_dist\": pad_front(accel_dist, 2),\n            \"frame_max_body_pos_err\": frame_max_body_pos_err,\n            \"whole_body_joints_dist\": whole_body_joints_dist,\n            \"root_r_error\": root_r_error,\n            \"root_p_error\": root_p_error,\n            \"root_y_error\": root_y_error,\n            \"root_vel_error\": pad_front(root_vel_err, 1),\n            \"root_height_error\": root_height_error,\n            \"mean_dof_vel\": mean_dof_vel,\n            \"mean_dof_acc\": mean_dof_acc,\n            \"mean_dof_torque\": mean_dof_torque,\n            \"mean_torque_jump_norm\": mean_torque_jump_norm,\n            
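# sample-rate torque-jump series were aggregated to per-frame values by _aggregate_sample_metric_to_frames above\n            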
\"mean_torque_jump_ratio\": mean_torque_jump_ratio,\n            \"mean_action_rate\": mean_action_rate,\n        }\n    )\n    return df\n\n\ndef offline_evaluate_dumped_npzs(\n    npz_dir: str,\n    output_json_path: str,\n    failure_pos_err_thresh_m: float = 0.25,\n    metric_calculation: str = \"per_clip\",\n    dof_mode: str = \"29\",\n    threadpool_max_workers: Optional[int] = None,\n) -> Dict[str, dict]:\n    \"\"\"Evaluate dumped NPZs in `npz_dir` and write a JSON summary to `output_dir`.\n\n    The function produces dataset-wide averages and per-clip averages across frames.\n    \"\"\"\n    npz_dir_abs = Path(npz_dir).resolve()\n    os.makedirs(npz_dir_abs, exist_ok=True)\n\n    # Add file handler for logging to metric.log\n    metric_log_path = npz_dir_abs / \"metric.log\"\n    logger.add(\n        str(metric_log_path),\n        format=\"{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}\",\n        level=\"INFO\",\n    )\n\n    logger.info(f\"Input NPZ directory (absolute): {npz_dir_abs}\")\n\n    # Gather NPZ files\n    files = sorted(glob(os.path.join(npz_dir_abs, \"*.npz\")))\n    if len(files) == 0:\n        raise FileNotFoundError(f\"No NPZ files found in: {npz_dir_abs}\")\n\n    # Accumulate per-frame metrics\n    frame_tables: List[pd.DataFrame] = []\n    clip_meta: Dict[str, dict] = {}\n\n    skipped_files_count = 0\n\n    required_keys = [\n        \"ref_dof_pos\",\n        \"ref_dof_vel\",\n        \"ref_global_translation\",\n        \"ref_global_rotation_quat\",\n        \"ref_global_velocity\",\n        \"ref_global_angular_velocity\",\n        \"robot_dof_pos\",\n        \"robot_dof_vel\",\n        \"robot_global_translation\",\n        \"robot_global_rotation_quat\",\n        \"robot_global_velocity\",\n        \"robot_global_angular_velocity\",\n    ]\n    optional_keys = [\n        \"robot_dof_acc\",\n        \"robot_dof_torque\",\n        \"robot_low_level_dof_torque\",\n        \"robot_low_level_torque_dt\",\n        \"robot_low_level_foot_contact\",\n        \"robot_low_level_foot_normal_force\",\n        \"robot_low_level_foot_tangent_speed\",\n        \"robot_low_level_contact_dt\",\n        \"robot_action_rate\",\n        \"robot_moe_expert_indices\",\n        \"robot_moe_expert_logits\",\n    ]\n\n    def _compute_metrics_from_file(fpath: str):\n        try:\n            with np.load(fpath, allow_pickle=True) as npz_data:\n                # Extract arrays and metadata\n                data = {k: npz_data[k] for k in required_keys}\n                for k in optional_keys:\n                    if k in npz_data.files:\n                        data[k] = npz_data[k]\n\n                metadata = _parse_metadata_entry(npz_data.get(\"metadata\"))\n                robot_control_dt = _extract_robot_control_dt(metadata, data)\n                low_level_contact_dt = _extract_low_level_contact_dt(\n                    metadata, data, robot_control_dt\n                )\n\n            motion_key = os.path.splitext(os.path.basename(fpath))[0]\n            clip_len_from_name = _parse_clip_len_from_name(fpath)\n\n            df_frames = _per_frame_metrics_from_npz(\n                motion_key=motion_key,\n                data=data,\n                dof_mode=dof_mode,\n                robot_control_dt=robot_control_dt,\n            )\n            chatter_summary = _compute_clip_torque_jump_summary(\n                data=data, dof_mode=dof_mode, torque_dt=robot_control_dt\n            )\n            stability_summary = _compute_clip_stability_summary(\n                
data=data,\n                robot_control_dt=robot_control_dt,\n                low_level_contact_dt=low_level_contact_dt,\n            )\n\n            # Clip-level info and failure criterion (max body-link pos error > threshold)\n            num_frames_clip = int(df_frames.shape[0])\n            clip_length = int(\n                metadata.get(\n                    \"clip_length\", clip_len_from_name or num_frames_clip\n                )\n            )\n            max_body_err = float(\n                np.nanmax(df_frames[\"frame_max_body_pos_err\"].to_numpy())\n            )\n            success = 1.0 if max_body_err <= failure_pos_err_thresh_m else 0.0\n            clip_meta_entry = {\n                \"motion_key\": motion_key,\n                \"num_frames\": num_frames_clip,\n                \"clip_length\": clip_length,\n                \"success\": success,\n                \"max_body_pos_err\": max_body_err,\n                \"failure_threshold_m\": float(failure_pos_err_thresh_m),\n                **chatter_summary,\n                **stability_summary,\n            }\n            return fpath, df_frames, motion_key, clip_meta_entry, None\n        except (ValueError, KeyError, BadZipFile, EOFError, OSError) as e:\n            return fpath, None, None, None, e\n\n    if threadpool_max_workers is None:\n        max_workers = max(1, min(len(files), 24))\n        requested_workers = None\n    else:\n        requested_workers = int(threadpool_max_workers)\n        if requested_workers <= 0:\n            raise ValueError(\"threadpool_max_workers must be > 0\")\n        max_workers = min(requested_workers, len(files))\n        if max_workers <= 0:\n            max_workers = 1\n    logger.info(\n        f\"Metric ThreadPoolExecutor max_workers={max_workers} \"\n        f\"(requested={requested_workers}, num_npz_files={len(files)})\"\n    )\n    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n        futures = {\n            executor.submit(_compute_metrics_from_file, fpath): file_idx\n            for file_idx, fpath in enumerate(files)\n        }\n        processed_results = [None] * len(files)\n        for future in tqdm(\n            as_completed(futures.keys()),\n            total=len(files),\n            desc=\"Compute metrics from NPZs\",\n        ):\n            processed_results[futures[future]] = future.result()\n\n    for result in processed_results:\n        (\n            fpath,\n            df_frames,\n            motion_key,\n            clip_meta_entry,\n            file_error,\n        ) = result\n        if file_error is not None:\n            logger.warning(f\"\\nCaught an error while processing file: {fpath}\")\n            logger.warning(f\"Error type: {type(file_error).__name__}\")\n            logger.warning(f\"Error message: {file_error}\")\n            logger.warning(\"This file will be SKIPPED.\")\n            skipped_files_count += 1\n            continue\n        frame_tables.append(df_frames)\n        clip_meta[motion_key] = clip_meta_entry\n\n    if skipped_files_count > 0:\n        logger.info(\n            f\"\\nFinished processing. Skipped a total of {skipped_files_count} files due to errors.\"\n        )\n\n    # If all files were skipped, there's nothing to process further.\n    if not frame_tables:\n        logger.error(\n            \"No valid NPZ files could be processed. 
Aborting evaluation.\"\n        )\n        return {}\n\n    # Concatenate per-frame metrics\n    all_frames = pd.concat(frame_tables, ignore_index=True)\n\n    # Per-clip averages\n    frame_metric_cols = [\n        \"mpjpe_g\",\n        \"mpjpe_l\",\n        \"whole_body_joints_dist\",\n        \"root_vel_error\",\n        \"root_r_error\",\n        \"root_p_error\",\n        \"root_y_error\",\n        \"root_height_error\",\n        \"mean_dof_vel\",\n        \"mean_dof_acc\",\n        \"mean_dof_torque\",\n        \"mean_torque_jump_norm\",\n        \"mean_torque_jump_ratio\",\n        \"mean_action_rate\",\n    ]\n    percentile_metric_cols = [\n        \"mean_torque_jump_norm\",\n        \"mean_torque_jump_ratio\",\n    ]\n    percentile_rename_map = {\n        \"mean_torque_jump_norm\": \"p95_torque_jump_norm\",\n        \"mean_torque_jump_ratio\": \"p95_torque_jump_ratio\",\n    }\n    metric_cols = frame_metric_cols + list(percentile_rename_map.values())\n    clip_only_metric_cols = [\n        \"torque_chatter_hf_ratio\",\n        \"torque_jump_burst_max\",\n        \"expert_switching_js_div\",\n        \"torso_rp_hf_ratio\",\n        \"torso_rp_angacc_p95\",\n        \"foot_contact_toggle_rate\",\n        \"foot_impact_force_p95\",\n        \"stance_slip_speed_p95\",\n    ]\n    metric_cols += clip_only_metric_cols\n    # Metric display configuration: metric_key -> (display_name, unit)\n    metric_display_map = {\n        \"mpjpe_g\": (\"Global Bodylink Mean Position Error\", \"mm\"),\n        \"mpjpe_l\": (\"Local Bodylink Mean Position Error\", \"mm\"),\n        \"whole_body_joints_dist\": (\"DOF Position Error\", \"rad\"),\n        \"root_vel_error\": (\"Root Velocity Error\", \"m/s\"),\n        \"root_r_error\": (\"Root Roll Error\", \"rad\"),\n        \"root_p_error\": (\"Root Pitch Error\", \"rad\"),\n        \"root_y_error\": (\"Root Yaw Error\", \"rad\"),\n        \"root_height_error\": (\"Root Height Error\", \"mm\"),\n        \"mean_dof_vel\": (\"Mean DOF Velocity\", \"rad/s\"),\n        \"mean_dof_acc\": (\"Mean DOF Acceleration\", \"rad/s^2\"),\n        \"mean_dof_torque\": (\"Mean DOF Torque\", \"N*m\"),\n        \"mean_torque_jump_norm\": (\"Mean Torque Jump Norm\", \"N*m/s\"),\n        \"p95_torque_jump_norm\": (\"P95 Torque Jump Norm\", \"N*m/s\"),\n        \"mean_torque_jump_ratio\": (\"Mean Torque Jump Ratio\", \"ratio\"),\n        \"p95_torque_jump_ratio\": (\"P95 Torque Jump Ratio\", \"ratio\"),\n        \"mean_action_rate\": (\"Mean Action Rate\", \"1/s\"),\n        \"torque_chatter_hf_ratio\": (\"Torque Chatter HF Ratio\", \"ratio\"),\n        \"torque_jump_burst_max\": (\"Torque Jump Burst Max\", \"ratio\"),\n        \"expert_switching_js_div\": (\"Expert Switching JS Div\", \"bits\"),\n        \"torso_rp_hf_ratio\": (\"Torso RP HF Ratio\", \"ratio\"),\n        \"torso_rp_angacc_p95\": (\"Torso RP Angular Accel P95\", \"rad/s^2\"),\n        \"foot_contact_toggle_rate\": (\"Foot Contact Toggle Rate\", \"1/s\"),\n        \"foot_impact_force_p95\": (\"Foot Impact Force P95\", \"N\"),\n        \"stance_slip_speed_p95\": (\"Stance Slip Speed P95\", \"m/s\"),\n    }\n\n    per_clip_mean = (\n        all_frames.groupby(\"motion_key\")[frame_metric_cols]\n        .mean(numeric_only=True)\n        .reset_index()\n    )\n    per_clip_p95 = (\n        all_frames.groupby(\"motion_key\")[percentile_metric_cols]\n        .quantile(0.95)\n        .reset_index()\n        .rename(columns=percentile_rename_map)\n    )\n    per_clip_summary = per_clip_mean.merge(\n        
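# left join keeps every clip from the mean table even if a p95 entry is missing\n        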
per_clip_p95, on=\"motion_key\", how=\"left\"\n    )\n    for metric_key in (\n        \"mean_torque_jump_norm\",\n        \"p95_torque_jump_norm\",\n        \"mean_torque_jump_ratio\",\n        \"p95_torque_jump_ratio\",\n    ):\n        per_clip_summary[metric_key] = per_clip_summary[\"motion_key\"].map(\n            {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta}\n        )\n    for metric_key in clip_only_metric_cols:\n        per_clip_summary[metric_key] = per_clip_summary[\"motion_key\"].map(\n            {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta}\n        )\n\n    # Merge with success flags\n    per_clip_records = []\n    for _, row in per_clip_summary.iterrows():\n        mk = row[\"motion_key\"]\n        rec = {**row.to_dict(), **clip_meta.get(mk, {})}\n        per_clip_records.append(rec)\n\n    # Persist per-clip metrics as a tabular CSV for easier downstream analysis.\n    per_clip_df = pd.DataFrame(per_clip_records)\n    output_csv_path = str(npz_dir_abs / \"per_clip_metrics.csv\")\n    per_clip_df.to_csv(output_csv_path, index=False)\n    logger.info(f\"Saved per-clip metrics CSV to: {output_csv_path}\")\n\n    dataset_means = {}\n    dataset_medians = {}\n    if metric_calculation == \"per_frame\":\n        agg_source = all_frames\n        agg_desc = \"PER-FRAME\"\n    else:\n        agg_source = per_clip_summary\n        agg_desc = \"PER-CLIP\"\n    for k in metric_cols:\n        if k in agg_source.columns:\n            arr = agg_source[k].to_numpy()\n        else:\n            arr = per_clip_summary[k].to_numpy()\n        dataset_means[k] = _safe_nanmean(arr)\n        dataset_medians[k] = _safe_nanmedian(arr)\n\n    success_rate = float(\n        np.mean([clip_meta[mk][\"success\"] for mk in clip_meta])\n        if len(clip_meta) > 0\n        else 0.0\n    )\n    dataset_means[\"success_rate\"] = success_rate\n\n    # Compose result and write\n    result = {\n        \"dataset\": {\n            \"calculation_mode\": metric_calculation,\n            \"mean\": dataset_means,\n            \"median\": dataset_medians,\n            \"success_rate\": success_rate,\n        },\n        \"num_clips\": int(len(clip_meta)),\n        \"per_clip\": per_clip_records,\n    }\n    with open(output_json_path, \"w\", encoding=\"utf-8\") as f:\n        json.dump(result, f, indent=2)\n\n    # Conversion factors for unit conversion (assuming 50Hz)\n    frame_rate_hz = 50.0\n    unit_conversions = {\n        \"root_height_error\": 1000.0,  # m to mm\n        \"root_vel_error\": frame_rate_hz,  # m/frame to m/s\n    }\n\n    table_data = []\n    # Iterate through metric_display_map to preserve order\n    for key in metric_display_map.keys():\n        if key not in dataset_means:\n            continue\n\n        val_mean = dataset_means[key]\n        val_median = dataset_medians[key]\n        display_name, unit = metric_display_map[key]\n\n        # Apply unit conversion if needed\n        if key in unit_conversions:\n            factor = unit_conversions[key]\n            val_mean = val_mean * factor\n            val_median = val_median * factor\n\n        def fmt(v):\n            return f\"{v:.4f}\" if isinstance(v, float) else str(v)\n\n        table_data.append([display_name, fmt(val_mean), fmt(val_median), unit])\n\n    table_headers = [\"Metric\", \"Mean\", \"Median\", \"Unit\"]\n    output_tsv_path = str(npz_dir_abs / \"whole_dataset_metrics.tsv\")\n    with open(output_tsv_path, \"w\", encoding=\"utf-8\", newline=\"\") as f:\n        writer = 
csv.writer(f, delimiter=\"\\t\", lineterminator=\"\\n\")\n        writer.writerow(table_headers)\n        writer.writerows(table_data)\n    logger.info(f\"Saved whole-dataset metrics TSV to: {output_tsv_path}\")\n\n    table_str = tabulate(\n        table_data,\n        headers=table_headers,\n        tablefmt=\"simple_outline\",\n        colalign=(\"left\", \"left\", \"left\", \"left\"),\n    )\n    logger.info(\n        \"\\n\"\n        + \"=\" * 80\n        + f\"\\nDATASET-WISE METRICS ({agg_desc})\\n\"\n        + \"=\" * 80\n        + f\"\\n\\n{table_str}\\n\"\n        + \"=\" * 80\n        + \"\\n\"\n    )\n\n    return result\n\n\ndef parse_ckpt_and_dataset_from_eval_dirname(\n    eval_dir_name: str, dataset_suffix: str\n):\n    VALID_PREFIXES = [\"isaaclab_eval_output_\", \"mujoco_eval_output_\"]\n\n    matched_prefix = None\n    for prefix in VALID_PREFIXES:\n        if eval_dir_name.startswith(prefix):\n            matched_prefix = prefix\n            break\n\n    if matched_prefix is None:\n        return None, None\n\n    rest = eval_dir_name[len(matched_prefix) :]\n    if not rest.endswith(dataset_suffix):\n        return None, None\n\n    model_part = rest[: -len(dataset_suffix)]\n    if model_part.endswith(\"_\"):\n        model_part = model_part[:-1]\n\n    m = re.search(r\"model_(\\d+)$\", model_part)\n    if not m:\n        return None, dataset_suffix\n\n    return m.group(1), dataset_suffix\n\n\ndef run_evaluation(\n    npz_dir: str,\n    dataset_suffix: str,\n    failure_pos_err_thresh_m: float = 0.25,\n    metric_calculation: str = \"per_clip\",\n    dof_mode: str = \"29\",\n    threadpool_max_workers: Optional[int] = None,\n):\n    \"\"\"\n    Main function to run evaluation. It scans a root directory, runs evaluation\n    for each found subdirectory, and logs a final summary.\n\n    Args:\n        npz_dir (str): Top-level directory containing all model evaluation results (e.g., 'logs/test').\n        dataset_suffix (str): Dataset suffix used to match evaluation output directory names.\n        failure_pos_err_thresh_m (float): The position error threshold in meters to determine a failure.\n    \"\"\"\n    root_path = Path(npz_dir)\n\n    logger.info(f\"Starting batch evaluation. Root directory: '{root_path}'\")\n    logger.info(\n        f\"Searching for directories matching pattern: '{dataset_suffix}'\"\n    )\n\n    def has_npz_files(path: Path) -> bool:\n        return path.is_dir() and any(path.glob(\"*.npz\"))\n\n    is_single_eval_dir = (\n        root_path.is_dir()\n        and (\n            root_path.name.startswith(\"isaaclab_eval_output_\")\n            or root_path.name.startswith(\"mujoco_eval_output_\")\n        )\n        and has_npz_files(root_path)\n    )\n\n    if is_single_eval_dir:\n        output_path = root_path\n    else:\n        output_path = root_path / f\"metrics_output_{dataset_suffix}\"\n    output_path.mkdir(parents=True, exist_ok=True)\n\n    if is_single_eval_dir:\n        logger.info(\n            f\"Detected '{root_path}' as a single evaluation directory. \"\n            \"Running offline evaluation only for this directory.\"\n        )\n        model_name = root_path.parent.name\n\n        ckpt_str, ds = parse_ckpt_and_dataset_from_eval_dirname(\n            root_path.name, dataset_suffix\n        )\n        if ckpt_str is None:\n            logger.warning(\n                f\"Could not parse checkpoint/dataset from directory name '{root_path.name}'. 
\"\n                \"Using 'checkpoint_unknown' in output filename.\"\n            )\n            ckpt_str = \"checkpoint_unknown\"\n            ds = dataset_suffix\n\n        output_json_name = f\"{model_name}_{ckpt_str}_{dof_mode}dof.json\"\n        output_json_path = output_path / output_json_name\n\n        offline_evaluate_dumped_npzs(\n            npz_dir=str(root_path),\n            output_json_path=str(output_json_path),\n            failure_pos_err_thresh_m=failure_pos_err_thresh_m,\n            metric_calculation=metric_calculation,\n            dof_mode=dof_mode,\n            threadpool_max_workers=threadpool_max_workers,\n        )\n        logger.success(\n            f\"Finished single-directory evaluation: model='{model_name}', checkpoint={ckpt_str}\"\n        )\n        return\n    logger.info(\n        f\"Treating '{root_path}' as root directory for batch evaluation.\"\n    )\n    # Find all directories matching the evaluation output pattern.\n    eval_dirs = sorted(\n        p\n        for p in root_path.glob(f\"**/*eval_output_*_{dataset_suffix}\")\n        if p.is_dir()\n    )\n    if not eval_dirs:\n        logger.error(\n            f\"No directories matching the pattern '{dataset_suffix}' found under '{root_path}'. \"\n            \"Please check the path and pattern.\"\n        )\n        return\n\n    all_results = []\n\n    # Process each found evaluation directory.\n    for eval_dir in tqdm(eval_dirs, desc=\"Overall Progress\"):\n        # Extract model name from the parent directory.\n        model_name = eval_dir.parent.name\n        # Parse the checkpoint number from the directory name.\n        ckpt_str, ds = parse_ckpt_and_dataset_from_eval_dirname(\n            eval_dir.name, dataset_suffix\n        )\n        if ckpt_str is None:\n            logger.warning(\n                f\"Could not parse ckpt/dataset from '{eval_dir.name}'. Skipping.\"\n            )\n            continue\n\n        checkpoint = int(ckpt_str)\n\n        logger.info(\n            f\"\\n--- Processing: model='{model_name}', dataset='{ds}', checkpoint={checkpoint} ---\"\n        )\n\n        # Construct a unique output JSON filename.\n        output_json_name = f\"{model_name}_{checkpoint}.json\"\n        output_json_path = output_path / output_json_name\n\n        # Call the evaluation function for the current directory.\n        result = offline_evaluate_dumped_npzs(\n            npz_dir=str(eval_dir),\n            output_json_path=str(output_json_path),\n            failure_pos_err_thresh_m=failure_pos_err_thresh_m,\n            metric_calculation=metric_calculation,\n            dof_mode=dof_mode,\n            threadpool_max_workers=threadpool_max_workers,\n        )\n\n        if result and \"dataset\" in result:\n            # Collect dataset-level average metrics for the final summary.\n            flat_result = {\n                \"model\": model_name,\n                \"checkpoint\": checkpoint,\n                **result[\"dataset\"],\n            }\n            all_results.append(flat_result)\n            logger.success(\n                f\"--- Finished processing: model='{model_name}', checkpoint={checkpoint} ---\"\n            )\n        else:\n            logger.error(\n                f\"--- Failed to process: model='{model_name}', checkpoint={checkpoint} ---\"\n            )\n\n    if not all_results:\n        logger.error(\n            \"No evaluations succeeded. 
Cannot generate a summary report.\"\n        )\n        return\n\n    logger.info(\"\\n\" + \"=\" * 80)\n    logger.info(\"Batch evaluation finished successfully.\")\n    logger.info(f\"Total successful evaluations: {len(all_results)}\")\n    logger.info(\"=\" * 80)\n\n\nif __name__ == \"__main__\":\n    argument_parser = argparse.ArgumentParser()\n    argument_parser.add_argument(\"--npz_dir\", type=str, required=True)\n    argument_parser.add_argument(\n        \"--dataset_suffix\",\n        type=str,\n        required=True,\n    )\n    argument_parser.add_argument(\n        \"--failure_pos_err_thresh_m\", type=float, default=0.25\n    )\n    argument_parser.add_argument(\n        \"--metric_calculation\",\n        type=str,\n        choices=[\"per_clip\", \"per_frame\"],\n        default=\"per_clip\",\n        help=\"Calculation mode for dataset metrics. 'per_clip' averages clip means (Macro). 'per_frame' averages all frames (Micro).\",\n    )\n    argument_parser.add_argument(\n        \"--dof_mode\",\n        type=str,\n        choices=[\"29\", \"23\"],\n        default=\"29\",\n        help=\"Compute metrics for full 29 DoF or reduced 23 DoF (excluding hands).\",\n    )\n    argument_parser.add_argument(\n        \"--threadpool_max_workers\",\n        type=int,\n        default=None,\n        help=\"Max workers for per-NPZ ThreadPoolExecutor. \"\n        \"Default: None (auto = min(num_files, 24)).\",\n    )\n    args = argument_parser.parse_args()\n\n    run_evaluation(\n        npz_dir=args.npz_dir,\n        dataset_suffix=args.dataset_suffix,\n        failure_pos_err_thresh_m=args.failure_pos_err_thresh_m,\n        metric_calculation=args.metric_calculation,\n        dof_mode=args.dof_mode,\n        threadpool_max_workers=args.threadpool_max_workers,\n    )\n"
  },
  {
    "path": "holomotion/src/evaluation/multi_model_metrics_report.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nimport argparse\nimport itertools\nimport json\nfrom collections import defaultdict\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nimport seaborn as sns\nfrom loguru import logger\nfrom scipy.stats import mannwhitneyu\nimport textwrap\n\nDEFAULT_METRICS_TO_ANALYZE = [\n    \"mpjpe_g\",\n    \"mpjpe_l\",\n    \"whole_body_joints_dist\",\n    \"root_vel_error\",\n    \"root_r_error\",\n    \"root_p_error\",\n    \"root_y_error\",\n    \"root_height_error\",\n]\n\nRADAR_METRICS = [\n    \"mpjpe_g\",\n    \"mpjpe_l\",\n    \"whole_body_joints_dist\",\n    \"root_vel_error\",\n]\nDEFAULT_RADAR_MAPPING = {m: m for m in RADAR_METRICS}\n\nDEFAULT_ALPHA = 0.05\n\n\nclass AnalysisReportGenerator:\n    \"\"\"Load per-clip JSON metrics, run analysis, generate plots + markdown report.\"\"\"\n\n    def __init__(\n        self,\n        json_dir: str,\n        plots_dir: str,\n        dataset_name: str,\n        metrics_to_analyze: List[str],\n        radar_metric_mapping: Dict[str, str],\n        metric_types_for_radar: Dict[str, str],\n        alpha: float = DEFAULT_ALPHA,\n        plot_quantile_cutoff: float = 0.99,\n        kde_linewidth: float = 2.5,\n        min_normalized_value: float = 0.2,\n        radar_chart_filename: str = \"radar_chart_comparison.png\",\n    ) -> None:\n        self.json_dir = Path(json_dir)\n        self.plots_dir = Path(plots_dir)\n        self.dataset_name = dataset_name\n        self.metrics_to_analyze = metrics_to_analyze\n        self.radar_metric_mapping = radar_metric_mapping.copy()\n        self.metric_types_for_radar = metric_types_for_radar.copy()\n        self.alpha = alpha\n        self.plot_quantile_cutoff = plot_quantile_cutoff\n        self.kde_linewidth = kde_linewidth\n        self.min_normalized_value = min_normalized_value\n        self.radar_chart_filename = radar_chart_filename\n\n        self.df: Optional[pd.DataFrame] = None\n        self.models: List[str] = []\n\n    def run(self) -> None:\n        self.plots_dir.mkdir(exist_ok=True, parents=True)\n\n        self.df = self._load_and_prepare_data()\n        if self.df is None or self.df.empty:\n            logger.warning(\"No valid data loaded; aborting analysis.\")\n            return\n\n        self.models = sorted(self.df[\"model\"].unique().tolist())\n        if len(self.models) < 1:\n            logger.warning(\"No models found in data; aborting analysis.\")\n            return\n\n        self._create_matplotlib_radar_chart()\n        markdown_content = self._generate_markdown_report()\n\n        ts = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n        out_md = (\n            self.plots_dir / f\"analysis_report_{self.dataset_name}_{ts}.md\"\n        )\n        out_md.write_text(markdown_content, 
encoding=\"utf-8\")\n        logger.info(f\"Markdown report written to: {out_md}\")\n\n    def _load_and_prepare_data(self) -> Optional[pd.DataFrame]:\n        if not self.json_dir.is_dir():\n            logger.error(f\"json_dir '{self.json_dir}' is not a directory.\")\n            return None\n\n        json_files = list(self.json_dir.glob(\"*.json\"))\n        if not json_files:\n            logger.error(f\"No .json files found in '{self.json_dir}'.\")\n            return None\n\n        all_clips: List[Dict[str, Any]] = []\n\n        for jf in json_files:\n            model_name = jf.stem\n\n            data = json.loads(jf.read_text(encoding=\"utf-8\"))\n            if not isinstance(data, dict) or \"per_clip\" not in data:\n                logger.warning(\n                    f\"Skipping non-eval JSON file '{jf.name}' \"\n                    f\"(top-level type={type(data)}, has_per_clip={'per_clip' in data if isinstance(data, dict) else False}).\"\n                )\n                continue\n\n            per_clip = data.get(\"per_clip\")\n            if not per_clip:\n                logger.warning(\n                    f\"File '{jf.name}' has empty 'per_clip'; skipping.\"\n                )\n                continue\n\n            for clip in per_clip:\n                clip[\"model\"] = model_name\n                all_clips.append(clip)\n\n        if not all_clips:\n            logger.error(\"No per_clip data found in any JSON files.\")\n            return None\n\n        df = pd.DataFrame(all_clips)\n        logger.info(\n            f\"Loaded {len(df)} clip records from {len(json_files)} JSON files.\"\n        )\n        return df\n\n    def _create_kde_plot(self, metric: str, save_path: Path) -> None:\n        if self.df is None or metric not in self.df.columns:\n            return\n        if self.df[metric].isnull().all():\n            return\n        q_high = self.df[metric].quantile(self.plot_quantile_cutoff)\n        df_filtered = self.df[self.df[metric] <= q_high]\n\n        plt.style.use(\"seaborn-v0_8-whitegrid\")\n        fig, ax = plt.subplots(figsize=(12, 7))\n        sns.kdeplot(\n            data=df_filtered,\n            x=metric,\n            hue=\"model\",\n            hue_order=self.models,\n            ax=ax,\n            fill=False,\n            common_norm=False,\n            palette=\"tab10\",\n            linewidth=self.kde_linewidth,\n        )\n        ax.set_title(\n            f'Error Distribution for \"{metric}\" on {self.dataset_name}',\n            fontsize=16,\n            weight=\"bold\",\n        )\n        ax.set_xlabel(f\"Error Value ({metric})\", fontsize=12)\n        ax.set_ylabel(\"Density\", fontsize=12)\n        ax.set_xlim(left=0)\n\n        legend = ax.get_legend()\n        if legend:\n            legend.set_title(\"Models\", prop={\"size\": 14, \"weight\": \"bold\"})\n            for text in legend.get_texts():\n                text.set_fontsize(14)\n\n        fig.savefig(save_path, dpi=150, bbox_inches=\"tight\")\n        plt.close(fig)\n\n    def _create_matplotlib_radar_chart(self) -> None:\n        if self.df is None:\n            return\n\n        original_metrics = list(self.radar_metric_mapping.keys())\n\n        raw_labels = [self.radar_metric_mapping[m] for m in original_metrics]\n        display_labels = [\n            textwrap.fill(label, width=20, break_long_words=False)\n            for label in raw_labels\n        ]\n        num_metrics = len(original_metrics)\n\n        median_df = 
self.df.groupby(\"model\")[original_metrics].median()\n        rounded_median_df = median_df.round(2)\n\n        normalized_df = pd.DataFrame(\n            index=self.models, columns=original_metrics, dtype=float\n        )\n        scale = 1.0 - self.min_normalized_value\n\n        for metric in original_metrics:\n            medians = rounded_median_df[metric].dropna()\n            if medians.empty:\n                normalized_df[metric] = self.min_normalized_value\n                continue\n\n            min_val, max_val = medians.min(), medians.max()\n            rng = max_val - min_val if max_val > min_val else 1.0\n\n            for model in self.models:\n                val = rounded_median_df.loc[model, metric]\n                if pd.isna(val):\n                    normalized_df.loc[model, metric] = (\n                        self.min_normalized_value\n                    )\n                    continue\n\n                lower_better = (\n                    self.metric_types_for_radar.get(metric, \"lower\") == \"lower\"\n                )\n                if lower_better:\n                    base = (max_val - val) / rng\n                else:\n                    base = (val - min_val) / rng\n\n                norm = self.min_normalized_value + base * scale\n                normalized_df.loc[model, metric] = norm\n\n        angles = np.linspace(\n            0, 2 * np.pi, num_metrics, endpoint=False\n        ).tolist()\n        angles += angles[:1]\n\n        fig, ax = plt.subplots(\n            figsize=(10, 10), subplot_kw=dict(projection=\"polar\")\n        )\n        cmap = plt.get_cmap(\"tab10\")\n        colors = {m: cmap(i % 10) for i, m in enumerate(self.models)}\n\n        for model in self.models:\n            vals = normalized_df.loc[model].tolist()\n            vals += vals[:1]\n            ax.fill(angles, vals, color=colors[model], alpha=0.25)\n            ax.plot(\n                angles,\n                vals,\n                color=colors[model],\n                linewidth=2.5,\n                label=model,\n                marker=\"o\",\n                markersize=7,\n                markeredgecolor=\"white\",\n                markeredgewidth=1,\n            )\n\n        for j, metric in enumerate(original_metrics):\n            angle = angles[j]\n            groups: Dict[str, List[float]] = defaultdict(list)\n            for model in self.models:\n                orig_val = rounded_median_df.loc[model, metric]\n                norm_val = normalized_df.loc[model, metric]\n                groups[f\"{orig_val:.2f}\"].append(norm_val)\n\n            for label_text, norm_vals in groups.items():\n                avg_norm = float(np.mean(norm_vals))\n                offset = 0.05\n                ax.text(\n                    angle,\n                    avg_norm + offset,\n                    label_text,\n                    ha=\"center\",\n                    va=\"center\",\n                    color=\"black\",\n                    weight=\"bold\",\n                    fontsize=9,\n                    bbox=dict(\n                        boxstyle=\"square,pad=0.3\",\n                        fc=\"white\",\n                        ec=\"none\",\n                        alpha=0.8,\n                    ),\n                )\n\n        ax.set_thetagrids(np.degrees(angles[:-1]), display_labels, fontsize=16)\n        ax.tick_params(axis=\"x\", pad=30)\n        ax.set_rgrids([0.4, 0.6, 0.8, 1.0], labels=[])\n        ax.set_ylim(0, 1.25)\n        
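# headroom above 1.0 leaves room for the value labels drawn outside each vertex\n        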
ax.spines[\"polar\"].set_visible(False)\n        ax.grid(color=\"grey\", linestyle=\"--\", linewidth=0.5)\n\n        handles, labels = ax.get_legend_handles_labels()\n        if handles:\n            legend_map = dict(zip(labels, handles))\n            ordered_labels = [m for m in self.models if m in legend_map]\n            ordered_handles = [legend_map[m] for m in ordered_labels]\n            ax.legend(\n                handles=ordered_handles,\n                labels=ordered_labels,\n                loc=\"upper center\",\n                bbox_to_anchor=(0.5, 1.15),\n                ncol=len(ordered_handles),\n                fontsize=14,\n                frameon=False,\n            )\n\n        fig.suptitle(\n            f\"Model Comparison on {self.dataset_name} Dataset\",\n            fontsize=20,\n            weight=\"bold\",\n            y=1.05,\n        )\n\n        save_path = self.plots_dir / self.radar_chart_filename\n        fig.savefig(save_path, dpi=300, bbox_inches=\"tight\")\n        plt.close(fig)\n        logger.info(f\"Radar chart saved to: {save_path}\")\n\n    def _generate_markdown_report(self) -> str:\n        if self.df is None or len(self.models) < 2:\n            return \"\"\n\n        parts: List[str] = [\n            f\"**Dataset**: {self.dataset_name}\",\n            f\"**Models**: {', '.join(self.models)}\",\n            f\"**Significance level (alpha)**: {self.alpha}\",\n            \"### Pairwise metric comparisons and distributions\",\n        ]\n\n        two_models = len(self.models) == 2\n        if two_models:\n            model1, model2 = self.models[0], self.models[1]\n\n        for metric in self.metrics_to_analyze:\n            if metric not in self.df.columns:\n                continue\n\n            p_value_str = \"\"\n            if two_models:\n                d1 = self.df.loc[self.df[\"model\"] == model1, metric].dropna()\n                d2 = self.df.loc[self.df[\"model\"] == model2, metric].dropna()\n                if not d1.empty and not d2.empty:\n                    _, p_val = mannwhitneyu(d1, d2, alternative=\"two-sided\")\n                    p_value_str = f\" (p = {p_val:.3g})\"\n\n            parts.append(f\"#### Metric: `{metric}`{p_value_str}\")\n\n            metric_stats: List[Dict[str, Any]] = []\n            for name in self.models:\n                data = self.df.loc[self.df[\"model\"] == name, metric].dropna()\n                if data.empty:\n                    continue\n                metric_stats.append(\n                    {\n                        \"Model\": name,\n                        \"Median\": data.median(),\n                        \"Q1 (25%)\": data.quantile(0.25),\n                        \"Q3 (75%)\": data.quantile(0.75),\n                    }\n                )\n\n            if metric_stats:\n                stats_df = (\n                    pd.DataFrame(metric_stats)\n                    .sort_values(by=\"Median\")\n                    .reset_index(drop=True)\n                )\n                parts.append(stats_df.to_markdown(index=False, floatfmt=\".4f\"))\n\n            findings: List[str] = []\n            lower_better = (\n                self.metric_types_for_radar.get(metric, \"lower\") == \"lower\"\n            )\n\n            for m1, m2 in itertools.combinations(self.models, 2):\n                d1 = self.df.loc[self.df[\"model\"] == m1, metric].dropna()\n                d2 = self.df.loc[self.df[\"model\"] == m2, metric].dropna()\n                if d1.empty or d2.empty:\n                    
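# both samples must be non-empty for the Mann-Whitney U test, so skip this pair\n                    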
continue\n                _, p_val = mannwhitneyu(d1, d2, alternative=\"two-sided\")\n                if p_val >= self.alpha:\n                    continue\n\n                m1_med, m2_med = d1.median(), d2.median()\n                better, worse = (m1, m2) if m1_med < m2_med else (m2, m1)\n                if not lower_better:\n                    better, worse = worse, better\n\n                findings.append(\n                    f\"- **{better}** is significantly better than **{worse}** \"\n                    f\"(p < {self.alpha}).\"\n                )\n\n            if findings:\n                parts.append(\"\\n\".join(findings))\n            else:\n                parts.append(\n                    \"No statistically significant differences between models.\"\n                )\n            safe_metric = metric.replace(\" \", \"_\")\n            plot_filename = f\"{safe_metric}.png\"\n            plot_path = self.plots_dir / plot_filename\n            self._create_kde_plot(metric, plot_path)\n            parts.append(\n                f\"##### Distribution plot\\n\"\n                f\"![{metric} distribution on {self.dataset_name}]({plot_filename})\"\n            )\n\n        return \"\\n\\n\".join(parts)\n\n\ndef parse_args() -> argparse.Namespace:\n    parser = argparse.ArgumentParser(\n        description=\"Analyze per-clip JSON metrics, generate plots and markdown report.\"\n    )\n    parser.add_argument(\"--json_dir\", type=str, required=True)\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    json_dir = Path(args.json_dir).resolve()\n    name = json_dir.name\n    for prefix in (\"metrics_output_\",):\n        if name.startswith(prefix) and len(name) > len(prefix):\n            name = name[len(prefix) :]\n            break\n    dataset_name = name  # e.g. \"AMASS\"\n    plots_dir = json_dir / f\"analysis_plots_{dataset_name}\"\n\n    metric_types_for_radar = {m: \"lower\" for m in DEFAULT_METRICS_TO_ANALYZE}\n\n    analyzer = AnalysisReportGenerator(\n        json_dir=args.json_dir,\n        plots_dir=str(plots_dir),\n        dataset_name=dataset_name,\n        metrics_to_analyze=DEFAULT_METRICS_TO_ANALYZE,\n        radar_metric_mapping=DEFAULT_RADAR_MAPPING,\n        metric_types_for_radar=metric_types_for_radar,\n        alpha=DEFAULT_ALPHA,\n    )\n    analyzer.run()\n"
  },
  {
    "path": "holomotion/src/evaluation/obs/__init__.py",
    "content": "from .obs_builder import PolicyObsBuilder, get_gravity_orientation\n\n__all__ = [\n    \"PolicyObsBuilder\",\n    \"get_gravity_orientation\",\n]\n"
  },
  {
    "path": "holomotion/src/evaluation/obs/obs_builder.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport numpy as np\nimport torch\n\nfrom typing import Dict, List, Sequence, Any, Optional\n\n\ndef get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:\n    \"\"\"Calculate gravity orientation from quaternion.\n\n    Args:\n        quaternion: Array-like [w, x, y, z]\n\n    Returns:\n        np.ndarray of shape (3,) representing gravity projection.\n    \"\"\"\n    qw = float(quaternion[0])\n    qx = float(quaternion[1])\n    qy = float(quaternion[2])\n    qz = float(quaternion[3])\n\n    gravity_orientation = np.zeros(3, dtype=np.float32)\n    gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy)\n    gravity_orientation[1] = -2.0 * (qz * qy + qw * qx)\n    gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz)\n    return gravity_orientation\n\n\nclass _CircularBuffer:\n    \"\"\"History buffer for batched tensor data (batch==1 in our eval/deploy).\n\n    Stores history in oldest->newest order when accessed via .buffer.\n    \"\"\"\n\n    def __init__(self, max_len: int, feat_dim: int, device: str):\n        if max_len < 1:\n            raise ValueError(f\"max_len must be >= 1, got {max_len}\")\n        self._max_len = int(max_len)\n        self._feat_dim = int(feat_dim)\n        self._device = device\n        self._pointer = -1\n        self._num_pushes = 0\n        self._buffer: torch.Tensor = torch.zeros(\n            (self._max_len, 1, self._feat_dim),\n            dtype=torch.float32,\n            device=\"cpu\",\n        )\n\n    @property\n    def buffer(self) -> torch.Tensor:\n        \"\"\"Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1.\"\"\"\n        if self._num_pushes == 0:\n            raise RuntimeError(\n                \"Attempting to read from an empty history buffer.\"\n            )\n        # roll such that oldest is at index=0 along the history axis\n        rolled = torch.roll(\n            self._buffer, shifts=self._max_len - self._pointer - 1, dims=0\n        )\n        return torch.transpose(rolled, 0, 1)  # [1, max_len, feat]\n\n    def append(self, data: torch.Tensor) -> None:\n        \"\"\"Append one step: data shape [1, feat_dim] on the configured device.\"\"\"\n        if (\n            data.ndim != 2\n            or data.shape[0] != 1\n            or data.shape[1] != self._feat_dim\n        ):\n            raise ValueError(\n                f\"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}\"\n            )\n        self._pointer = (self._pointer + 1) % self._max_len\n        self._buffer[self._pointer] = data\n        if self._num_pushes == 0:\n            # duplicate first push across entire history for warm start\n            self._buffer[:] = data\n        self._num_pushes += 1\n\n\nclass PolicyObsBuilder:\n    \"\"\"Builds policy observations from Unitree lowstate with temporal history.\n\n    Designed to be shared between MuJoCo 
sim2sim evaluation and ROS2 deployment.\n    History management is internal and produces a flattened vector of size\n    sum_i(context_length * feat_i) across the configured observation items.\n\n    Supports two command modes:\n    - \"motion_tracking\": uses reference motion states\n    - \"velocity_tracking\": uses velocity commands [vx, vy, vyaw]\n    \"\"\"\n\n    def __init__(\n        self,\n        dof_names_onnx: Sequence[str],\n        default_angles_onnx: np.ndarray,\n        evaluator: Optional[Any] = None,\n        obs_policy_cfg: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        self.dof_names_onnx: List[str] = list(dof_names_onnx)\n        self.num_actions: int = len(self.dof_names_onnx)\n        self.evaluator = evaluator\n        self.obs_policy_cfg = obs_policy_cfg\n\n        if default_angles_onnx.shape[0] != self.num_actions:\n            raise ValueError(\n                \"default_angles_onnx length must match num actions\"\n            )\n        self.default_angles_onnx = default_angles_onnx.astype(np.float32)\n        self.default_angles_dict: Dict[str, float] = {\n            name: float(self.default_angles_onnx[idx])\n            for idx, name in enumerate(self.dof_names_onnx)\n        }\n\n        # Build observation schema from config (required)\n        if self.obs_policy_cfg is None:\n            raise ValueError(\n                \"obs_policy_cfg with 'atomic_obs_list' is required\"\n            )\n        self.term_specs: List[Dict[str, Any]] = []\n\n        for term_dict in self.obs_policy_cfg[\"atomic_obs_list\"]:\n            for name, cfg in term_dict.items():\n                spec = {**cfg}\n                spec[\"name\"] = name\n                self.term_specs.append(spec)\n\n        # Buffers are created lazily after first dimension inference\n        self._buffers: Dict[str, _CircularBuffer] = {}\n\n    def reset(self) -> None:\n        for buf in self._buffers.values():\n            buf._pointer = -1\n            buf._num_pushes = 0\n            buf._buffer.zero_()\n\n    def _compute_term(\n        self,\n        name: str,\n    ) -> np.ndarray:\n        # Prefer evaluator-provided methods; no legacy fallbacks\n        if self.evaluator is not None:\n            meth = getattr(self.evaluator, f\"_get_obs_{name}\", None)\n            if callable(meth):\n                out = meth()\n                return np.asarray(out, dtype=np.float32).reshape(-1)\n        raise ValueError(\n            f\"Unknown observation term '{name}' or evaluator method missing.\"\n        )\n\n    def build_policy_obs(self) -> np.ndarray:\n        \"\"\"Append one step using evaluator-provided observation terms and return flattened obs.\"\"\"\n        # Compute per-term outputs\n        values: Dict[str, np.ndarray] = {}\n\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            scale = spec.get(\"scale\", 1.0)\n            values[name] = self._compute_term(name) * scale\n\n        # Lazily initialize buffers with inferred feature dims\n        if len(self._buffers) == 0:\n            for spec in self.term_specs:\n                name = spec[\"name\"]\n                hist_len = int(spec.get(\"history_length\", 0))\n                if hist_len <= 0:\n                    continue\n                feat_dim = int(values[name].reshape(-1).shape[0])\n                self._buffers[name] = _CircularBuffer(\n                    hist_len, feat_dim, \"cpu\"\n                )\n        # Append current step to buffers (skip terms without history)\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            if name in self._buffers:\n                item = 
torch.as_tensor(\n                    values[name].reshape(1, -1),\n                    dtype=torch.float32,\n                    device=\"cpu\",\n                )\n                self._buffers[name].append(item)\n        # Assemble flat list according to term ordering and history flatten rules\n        flat_list: List[np.ndarray] = []\n        for spec in self.term_specs:\n            name = spec[\"name\"]\n            flatten = bool(spec.get(\"flatten\", True))\n            if name in self._buffers:\n                buf = self._buffers[name].buffer[0]  # [hist, feat]\n                arr = (\n                    buf.reshape(-1).detach().cpu().numpy()\n                    if flatten\n                    else buf[-1].detach().cpu().numpy()\n                )\n                flat_list.append(arr.astype(np.float32))\n            else:\n                # no history -> use computed value directly\n                flat_list.append(values[name].reshape(-1).astype(np.float32))\n\n        if len(flat_list) == 0:\n            return np.zeros(0, dtype=np.float32)\n        return np.concatenate(flat_list, axis=0).astype(np.float32)\n"
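if __name__ == \"__main__\":\n    # Hedged usage sketch, not part of the runtime API: the term names,\n    # scales, and history lengths below are illustrative assumptions; only\n    # the `_get_obs_<term>` naming convention is taken from _compute_term.\n    class _StubEvaluator:\n        def _get_obs_base_ang_vel(self):\n            return np.zeros(3, dtype=np.float32)\n\n        def _get_obs_dof_pos(self):\n            return np.zeros(2, dtype=np.float32)\n\n    builder = PolicyObsBuilder(\n        dof_names_onnx=[\"j0\", \"j1\"],\n        default_angles_onnx=np.zeros(2, dtype=np.float32),\n        evaluator=_StubEvaluator(),\n        obs_policy_cfg={\n            \"atomic_obs_list\": [\n                {\"base_ang_vel\": {\"scale\": 0.25, \"history_length\": 4}},\n                {\"dof_pos\": {\"scale\": 1.0, \"history_length\": 4}},\n            ]\n        },\n    )\n    obs = builder.build_policy_obs()\n    # Each term contributes history_length * feat_dim entries when flattened.\n    assert obs.shape == (4 * 3 + 4 * 2,)\n    # Identity quaternion [w, x, y, z] = [1, 0, 0, 0] maps gravity to [0, 0, -1].\n    assert np.allclose(\n        get_gravity_orientation(np.array([1.0, 0.0, 0.0, 0.0])),\n        [0.0, 0.0, -1.0],\n    )\n"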
  },
  {
    "path": "holomotion/src/evaluation/ray_evaluator_actor.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\"\"\"Minimal Ray actor for batch eval. Lives in its own module so the class\ncan be pickled without pulling in torch/jit from eval_mujoco_sim2sim.\n\"\"\"\n\nimport importlib\nimport os\nimport sys\n\nimport ray\nfrom loguru import logger\n\n\nclass RayEvaluatorActor:\n    \"\"\"Persistent Ray actor: one evaluator (one ONNX session) per actor.\n\n    Schedule with num_gpus=1/actors_per_gpu so that multiple actors share one GPU.\n    Ray sets CUDA_VISIBLE_DEVICES so this actor sees a single GPU as device 0.\n    \"\"\"\n\n    def __init__(self, config_dict, output_dir):\n        logger.remove()\n        logger.add(sys.stderr, level=\"WARNING\")\n        self.output_dir = output_dir\n        self.config_dict = config_dict\n        model_type = config_dict.get(\"model_type\") or \"holomotion\"\n        self.evaluator = _load_ray_evaluator(config_dict, model_type)\n        self.evaluator.setup()\n        if model_type == \"gmt\":\n            self.evaluator.gmt_proprio_buf.clear()\n\n    def run_clip(self, file_path):\n        from holomotion.src.evaluation.eval_mujoco_sim2sim import (\n            _build_onnx_io_dump_dir,\n            _build_onnx_io_dump_path,\n        )\n\n        fname = os.path.basename(file_path)\n        save_name = fname.replace(\".npz\", \"_eval.npz\")\n        save_path = os.path.join(self.output_dir, save_name)\n        self.evaluator.load_specific_motion(file_path)\n        self.evaluator.reset_state_teleport()\n        for i in range(self.evaluator.n_motion_frames):\n            self.evaluator.motion_frame_idx = i\n            self.evaluator._update_policy()\n            self.evaluator._apply_control(sleep=False)\n            self.evaluator.counter += 1\n        meta = {\n            \"source_file\": fname,\n            \"model\": str(self.config_dict.get(\"ckpt_onnx_path\", \"\")),\n            \"source_npz\": fname,\n            \"onnx_model\": str(self.config_dict.get(\"ckpt_onnx_path\", \"\")),\n        }\n        self.evaluator.save_batch_result(save_path, meta)\n        model_type = self.config_dict.get(\"model_type\") or \"holomotion\"\n        if bool(self.config_dict.get(\"dump_onnx_io_npy\", False)) and (\n            model_type == \"holomotion\"\n        ):\n            onnx_io_dir = _build_onnx_io_dump_dir(self.output_dir)\n            os.makedirs(onnx_io_dir, exist_ok=True)\n            self.evaluator.save_onnx_io_dump(\n                _build_onnx_io_dump_path(self.output_dir, fname), meta\n            )\n        return \"success\"\n\n\ndef _load_ray_evaluator(config_dict, model_type):\n    module_name = config_dict.get(\n        \"ray_evaluator_module\",\n        \"holomotion.src.evaluation.eval_mujoco_sim2sim\",\n    )\n    factory_module = importlib.import_module(module_name)\n    return factory_module._create_ray_evaluator(config_dict, model_type)\n"
  },
  {
    "path": "holomotion/src/evaluation/ray_metrics_postprocess.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport shutil\nimport sys\nfrom pathlib import Path\nfrom typing import Any\n\nimport ray\nfrom loguru import logger\n\n\n@ray.remote\ndef run_metrics_postprocess_job(\n    output_dir: str,\n    dataset_name: str,\n    calc_per_clip_metrics: bool,\n    failure_pos_err_thresh_m: float,\n    metric_calculation: str,\n    dof_mode: str,\n    metrics_threadpool_max_workers: int | None,\n    generate_report: bool,\n    job_log_dir: str | None,\n    ckpt_stem: str,\n) -> dict[str, Any]:\n    logger.remove()\n    logger.add(sys.stderr, level=\"WARNING\")\n\n    if calc_per_clip_metrics:\n        from holomotion.src.evaluation.metrics import run_evaluation\n\n        run_evaluation(\n            npz_dir=output_dir,\n            dataset_suffix=dataset_name,\n            failure_pos_err_thresh_m=failure_pos_err_thresh_m,\n            metric_calculation=metric_calculation,\n            dof_mode=dof_mode,\n            threadpool_max_workers=metrics_threadpool_max_workers,\n        )\n\n    report_path = None\n    if generate_report:\n        from holomotion.scripts.evaluation import mean_process_5metrics\n\n        report_path = (\n            mean_process_5metrics.generate_macro_mean_report_from_json_dir(\n                output_dir\n            )\n        )\n\n    exported_summary_tsv = None\n    if job_log_dir is not None:\n        job_log_dir_path = Path(job_log_dir)\n        sub_dataset_tsv = (\n            Path(output_dir) / \"sub_dataset_macro_mean_metrics.tsv\"\n        )\n        if sub_dataset_tsv.is_file():\n            export_name = f\"{ckpt_stem}_sub_dataset_macro_mean_metrics.tsv\"\n            export_path = job_log_dir_path / export_name\n            shutil.copy2(sub_dataset_tsv, export_path)\n            exported_summary_tsv = export_path\n\n    return {\n        \"ckpt_stem\": ckpt_stem,\n        \"output_dir\": output_dir,\n        \"report_path\": str(report_path) if report_path is not None else \"\",\n        \"exported_summary_tsv\": str(exported_summary_tsv)\n        if exported_summary_tsv is not None\n        else \"\",\n    }\n"
  },
  {
    "path": "holomotion/src/modules/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
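  {
    "path": "holomotion/src/modules/tensordict_assembler_demo.py",
    "content": "# Hedged usage sketch for TensorDictAssembler (illustrative demo added for\n# documentation purposes; the file name, schema terms, and dims here are\n# assumptions, not part of the original repo). It shows flat-mode assembly\n# over a two-term schema, including a hierarchical \"group/term\" key as\n# resolved by TensorDictAssembler._get_from_data.\n\nimport torch\nfrom tensordict import TensorDict\n\nfrom holomotion.src.modules.agent_modules import TensorDictAssembler\n\n\ndef main() -> None:\n    schema = {\n        \"proprio\": {\"seq_len\": 1, \"terms\": [\"dof_pos\", \"latent/z\"]},\n    }\n    assembler = TensorDictAssembler(schema, output_mode=\"flat\")\n    obs = TensorDict(\n        {\n            \"dof_pos\": torch.zeros(8, 23),\n            \"latent\": TensorDict({\"z\": torch.zeros(8, 16)}, batch_size=[8]),\n        },\n        batch_size=[8],\n    )\n    flat = assembler(obs)\n    # Terms are concatenated along the last dim in schema order: 23 + 16.\n    assert flat.shape == (8, 39)\n    assert assembler.infer_output_dim(obs) == 39\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },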
  {
    "path": "holomotion/src/modules/agent_modules.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom __future__ import annotations\nimport io\nimport copy\nimport math\nfrom pathlib import Path\n\nimport holomotion.src.modules.network_modules as NM\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom holomotion.src.modules.network_modules import EmpiricalNormalization\nfrom loguru import logger\nfrom tensordict import TensorDict\nfrom tensordict.nn import TensorDictModuleBase\nfrom torch.distributions import Normal\n\n\ndef _module_device(module: nn.Module) -> torch.device:\n    for tensor in module.parameters():\n        return tensor.device\n    for tensor in module.buffers():\n        return tensor.device\n    return torch.device(\"cpu\")\n\n\ndef _clone_module_for_cpu_export(module: nn.Module) -> nn.Module:\n    \"\"\"Clone a module for CPU-side export without mutating the live module.\"\"\"\n    buffer = io.BytesIO()\n    # Keep the training module on-device; rank-local device hops during export\n    # can desynchronize DDP state and hang later collectives.\n    torch.save(module, buffer)\n    buffer.seek(0)\n    clone = torch.load(buffer, map_location=\"cpu\", weights_only=False)\n    clone = clone.to(\"cpu\")\n    clone.eval()\n    return clone\n\n\nclass TensorDictAssembler(torch.nn.Module):\n    def __init__(self, schema_config: dict, *, output_mode: str = \"flat\"):\n        super().__init__()\n        self.schema_config = schema_config\n        self.output_mode = str(output_mode).lower()\n        if self.output_mode not in (\"flat\", \"seq\"):\n            raise ValueError(\n                f\"output_mode must be one of {{'flat','seq'}}, got {output_mode}\"\n            )\n\n        self.seq_len_dict: dict[str, int] = {\n            str(k): int(v.get(\"seq_len\", 1)) for k, v in schema_config.items()\n        }\n        _uniq_lens = sorted(set(self.seq_len_dict.values()))\n        self.seq_len: int | None = (\n            int(_uniq_lens[0]) if len(_uniq_lens) == 1 else None\n        )\n        if self.output_mode == \"seq\" and self.seq_len is None:\n            raise ValueError(\n                \"TensorDictAssembler(output_mode='seq') requires a single unique seq_len \"\n                f\"across schema groups, got seq_len_dict={self.seq_len_dict}\"\n            )\n\n        self.output_dim: int | None = None\n\n    @staticmethod\n    def _get_from_data(data: TensorDict, key: str):\n        # Support hierarchical keys like \"latent/z\"\n        if key in data.keys():\n            return data.get(key)\n        if \"/\" in key:\n            current = data\n            for p in key.split(\"/\"):\n                if isinstance(current, TensorDict) and p in current.keys():\n                    current = current.get(p)\n                else:\n                    return None\n            return current\n        return None\n\n    def _validate_to_seq(\n        self,\n        tensor: 
torch.Tensor,\n        seq_len: int,\n        term: str,\n    ) -> torch.Tensor:\n        \"\"\"Return [B, seq_len, d] tensor.\"\"\"\n        if tensor.ndim == 2:\n            # [B, d] treat as seq_len=1\n            if seq_len != 1:\n                raise ValueError(\n                    f\"Term '{term}' expected seq_len={seq_len} but tensor is 2D {tensor.shape}\"\n                )\n            return tensor[:, None, :]\n        if tensor.ndim == 3:\n            if tensor.shape[1] != seq_len:\n                raise ValueError(\n                    f\"Term '{term}' seq_len mismatch: expected {seq_len}, got {tensor.shape[1]}\"\n                )\n            return tensor\n        raise ValueError(\n            f\"Term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}\"\n        )\n\n    def _validate_and_flatten(\n        self,\n        tensor: torch.Tensor,\n        seq_len: int,\n        term: str,\n    ) -> torch.Tensor:\n        if tensor.ndim == 2:\n            # [B, D] treat as seq_len=1\n            if seq_len != 1:\n                raise ValueError(\n                    f\"Term '{term}' expected seq_len={seq_len} but tensor is 2D {tensor.shape}\"\n                )\n            return tensor\n        if tensor.ndim == 3:\n            if tensor.shape[1] != seq_len:\n                raise ValueError(\n                    f\"Term '{term}' seq_len mismatch: expected {seq_len}, got {tensor.shape[1]}\"\n                )\n            b, t, d = tensor.shape\n            return tensor.reshape(b, t * d)\n        raise ValueError(\n            f\"Term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}\"\n        )\n\n    def forward(self, data: TensorDict) -> torch.Tensor:\n        if not isinstance(data, TensorDict):\n            raise TypeError(\"TensorDictAssembler expects TensorDict input.\")\n\n        if self.output_mode == \"flat\":\n            assembled = []\n            output_dim = 0\n            batch_size = None\n\n            for _, seq_cfg in self.schema_config.items():\n                seq_len = int(seq_cfg.get(\"seq_len\", 1))\n                terms = seq_cfg.get(\"terms\", [])\n                for term in terms:\n                    tensor = self._get_from_data(data, term)\n                    if tensor is None:\n                        raise KeyError(\n                            f\"Missing term '{term}' in TensorDict input for assembler. \"\n                            \"Use explicit hierarchical terms (e.g. 
'group/term') \"\n                            \"for nested TensorDict keys.\"\n                        )\n                    flat = self._validate_and_flatten(tensor, seq_len, term)\n                    if batch_size is None:\n                        batch_size = flat.shape[0]\n                    elif flat.shape[0] != batch_size:\n                        raise ValueError(\n                            f\"Batch size mismatch for term '{term}': {flat.shape[0]} vs {batch_size}\"\n                        )\n                    assembled.append(flat)\n                    output_dim += flat.shape[-1]\n\n            if not assembled:\n                raise ValueError(\n                    \"Assembler received an empty schema or no tensors found\"\n                )\n\n            out = torch.cat(assembled, dim=-1)\n\n            # Cache output_dim on first successful forward\n            if self.output_dim is None:\n                self.output_dim = output_dim\n            return out\n\n        # output_mode == \"seq\"\n        assembled_seq = []\n        batch_size = None\n        seq_len_ref = None\n\n        for _, seq_cfg in self.schema_config.items():\n            seq_len = int(seq_cfg.get(\"seq_len\", 1))\n            if seq_len_ref is None:\n                seq_len_ref = seq_len\n            elif seq_len != seq_len_ref:\n                raise ValueError(\n                    \"TensorDictAssembler(output_mode='seq') requires consistent seq_len \"\n                    f\"across schema groups, got {seq_len_ref} vs {seq_len}\"\n                )\n            terms = seq_cfg.get(\"terms\", [])\n            for term in terms:\n                tensor = self._get_from_data(data, term)\n                if tensor is None:\n                    raise KeyError(\n                        f\"Missing term '{term}' in TensorDict input for assembler. \"\n                        \"Use explicit hierarchical terms (e.g. 
'group/term') \"\n                        \"for nested TensorDict keys.\"\n                    )\n                seq_tensor = self._validate_to_seq(tensor, seq_len, term)\n                if batch_size is None:\n                    batch_size = seq_tensor.shape[0]\n                elif seq_tensor.shape[0] != batch_size:\n                    raise ValueError(\n                        f\"Batch size mismatch for term '{term}': {seq_tensor.shape[0]} vs {batch_size}\"\n                    )\n                assembled_seq.append(seq_tensor)\n\n        if not assembled_seq:\n            raise ValueError(\n                \"Assembler received an empty schema or no tensors found\"\n            )\n\n        out = torch.cat(assembled_seq, dim=-1)\n        # Expose seq_len and output_dim for sequence assembly\n        if self.seq_len is None:\n            self.seq_len = int(out.shape[1])\n        if self.output_dim is None:\n            self.output_dim = int(out.shape[-1])\n        return out\n\n    @torch.inference_mode()\n    def infer_output_dim(self, sample: TensorDict) -> int:\n        \"\"\"Run a dry forward pass to populate output_dim without grads.\"\"\"\n        if self.output_dim is not None:\n            return int(self.output_dim)\n        _ = self.forward(sample)\n        return self.output_dim\n\n\nclass PPOActorOnnxModule(nn.Module):\n    def __init__(\n        self,\n        actor_module: nn.Module,\n        obs_normalizer: nn.Module,\n        obs_norm_enabled: bool,\n        obs_norm_clip: float,\n    ):\n        super().__init__()\n        self.actor_module = actor_module\n        self.obs_normalizer = obs_normalizer\n        self.obs_norm_enabled = bool(obs_norm_enabled)\n        self.obs_norm_clip = float(obs_norm_clip)\n\n    def forward(self, obs: torch.Tensor) -> torch.Tensor:\n        actor_obs = obs\n        if self.obs_norm_enabled:\n            actor_obs = self.obs_normalizer.normalize_only(actor_obs)\n            if self.obs_norm_clip > 0.0:\n                actor_obs = torch.clamp(\n                    actor_obs, -self.obs_norm_clip, self.obs_norm_clip\n                )\n        return self.actor_module(actor_obs)\n\n\nclass PPOTFActorOnnxModule(nn.Module):\n    def __init__(\n        self,\n        actor_module: nn.Module,\n        obs_normalizer: nn.Module,\n        obs_norm_enabled: bool,\n        obs_norm_clip: float,\n    ):\n        super().__init__()\n        self.actor_module = actor_module\n        self.obs_normalizer = obs_normalizer\n        self.obs_norm_enabled = bool(obs_norm_enabled)\n        self.obs_norm_clip = float(obs_norm_clip)\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        past_key_values: torch.Tensor,\n        step_idx: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        actor_obs = obs\n        if self.obs_norm_enabled:\n            actor_obs = self.obs_normalizer.normalize_only(actor_obs)\n            if self.obs_norm_clip > 0.0:\n                actor_obs = torch.clamp(\n                    actor_obs, -self.obs_norm_clip, self.obs_norm_clip\n                )\n        return self.actor_module(\n            actor_obs,\n            past_key_values=past_key_values,\n            current_pos=step_idx,\n        )\n\n\nclass PPOTFWoKVCacheActorOnnxModule(nn.Module):\n    def __init__(\n        self,\n        actor_module: nn.Module,\n        obs_normalizer: nn.Module,\n        obs_norm_enabled: bool,\n        obs_norm_clip: float,\n    ):\n        super().__init__()\n        self.actor_module = actor_module\n        
self.obs_normalizer = obs_normalizer\n        self.obs_norm_enabled = bool(obs_norm_enabled)\n        self.obs_norm_clip = float(obs_norm_clip)\n\n    def forward(self, obs: torch.Tensor) -> torch.Tensor:\n        if obs.ndim != 3:\n            raise ValueError(\n                f\"Expected obs [B, 32, D] for no-kv ONNX path, got {obs.shape}\"\n            )\n        if obs.shape[1] != 32:\n            raise ValueError(\n                f\"Expected fixed token length 32, got {int(obs.shape[1])}\"\n            )\n        actor_obs = obs\n        if self.obs_norm_enabled:\n            actor_obs = self.obs_normalizer.normalize_only(actor_obs)\n            if self.obs_norm_clip > 0.0:\n                actor_obs = torch.clamp(\n                    actor_obs, -self.obs_norm_clip, self.obs_norm_clip\n                )\n        action_seq = self.actor_module.sequence_mu(actor_obs, attn_mask=None)\n        return action_seq[:, -1, :]\n\n\nclass PPOCondTFActorOnnxModule(nn.Module):\n    def __init__(\n        self,\n        actor_module: nn.Module,\n        state_obs_normalizer: nn.Module,\n        obs_norm_enabled: bool,\n        obs_norm_clip: float,\n        state_dim: int,\n        future_seq_len: int,\n        future_token_dim: int,\n        future_term_dims: list[int],\n    ):\n        super().__init__()\n        self.actor_module = actor_module\n        self.state_obs_normalizer = state_obs_normalizer\n        self.obs_norm_enabled = bool(obs_norm_enabled)\n        self.obs_norm_clip = float(obs_norm_clip)\n        self.state_dim = int(state_dim)\n        self.future_seq_len = int(future_seq_len)\n        self.future_token_dim = int(future_token_dim)\n        self.future_term_dims = [int(x) for x in future_term_dims]\n        if any(d <= 0 for d in self.future_term_dims):\n            raise ValueError(\n                f\"future_term_dims must be all positive, got {self.future_term_dims}\"\n            )\n        if sum(self.future_term_dims) != self.future_token_dim:\n            raise ValueError(\n                \"future_term_dims sum mismatch: expected \"\n                f\"{self.future_token_dim}, got {sum(self.future_term_dims)}\"\n            )\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        past_key_values: torch.Tensor,\n        step_idx: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        if obs.ndim != 2:\n            raise ValueError(f\"Expected obs [B, D], got {obs.shape}\")\n        state_obs = obs[:, : self.state_dim]\n        future_flat = obs[:, self.state_dim :]\n        expected_future_dim = self.future_seq_len * self.future_token_dim\n        if future_flat.shape[-1] != expected_future_dim:\n            raise ValueError(\n                \"Future obs dim mismatch for ONNX path: expected \"\n                f\"{expected_future_dim}, got {future_flat.shape[-1]}\"\n            )\n        if self.obs_norm_enabled:\n            state_obs = self.state_obs_normalizer.normalize_only(state_obs)\n            if self.obs_norm_clip > 0.0:\n                state_obs = torch.clamp(\n                    state_obs, -self.obs_norm_clip, self.obs_norm_clip\n                )\n        # Reconstruct [B, N_fut, D_fut] from term-major flattened layout:\n        # [term1 (N_fut*d1), term2 (N_fut*d2), ...] 
-> per-step concat along last dim.\n        b = int(obs.shape[0])\n        offset = 0\n        future_parts = []\n        for d_term in self.future_term_dims:\n            span = int(self.future_seq_len * d_term)\n            chunk = future_flat[:, offset : offset + span]\n            future_parts.append(chunk.reshape(b, self.future_seq_len, d_term))\n            offset += span\n        if offset != int(future_flat.shape[-1]):\n            raise ValueError(\n                \"Future flat slicing mismatch in ONNX path: \"\n                f\"consumed={offset}, total={int(future_flat.shape[-1])}\"\n            )\n        future_obs = torch.cat(future_parts, dim=-1)\n        return self.actor_module._forward_inference_onnx_cond(\n            state_obs,\n            future_obs,\n            past_key_values,\n            step_idx,\n        )\n\n\nclass PPOActor(TensorDictModuleBase):\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        super().__init__()\n\n        self.use_logvar = module_config_dict.get(\"use_logvar\", False)\n        obs_norm_cfg = module_config_dict.get(\"obs_norm\", {})\n        self.obs_norm_enabled = bool(obs_norm_cfg.get(\"enabled\", False))\n        if self.obs_norm_enabled:\n            self.obs_norm_clip = float(obs_norm_cfg.get(\"clip_range\", 0.0))\n            self.obs_norm_eps = float(obs_norm_cfg.get(\"epsilon\", 1.0e-8))\n            self.obs_norm_update_method = str(\n                obs_norm_cfg.get(\n                    \"update_method\", obs_norm_cfg.get(\"method\", \"cumulative\")\n                )\n            ).lower()\n            ema_momentum = obs_norm_cfg.get(\"ema_momentum\")\n            if self.obs_norm_update_method == \"ema\" and ema_momentum is None:\n                raise ValueError(\n                    \"obs_norm.ema_momentum must be set when update_method='ema'.\"\n                )\n            self.obs_norm_ema_momentum = (\n                float(ema_momentum) if ema_momentum is not None else 0.0\n            )\n\n        module_config_dict = self._process_module_config(\n            module_config_dict,\n            num_actions,\n        )\n\n        self.actor_net_type = module_config_dict.get(\"type\", \"MLP\")\n\n        logger.info(f\"actor_net_type: {self.actor_net_type}\")\n\n        actor_net_class = getattr(NM, self.actor_net_type, None)\n\n        if actor_net_class is None:\n            raise NotImplementedError(\n                f\"Unknown actor_net_type: {self.actor_net_type}\"\n            )\n\n        if actor_net_class is NM.MLP and obs_schema is None:\n            raise ValueError(\n                \"PPOActor(MLP) requires obs_schema so the agent module can assemble \"\n                \"TensorDict observations into a flat tensor.\"\n            )\n\n        if obs_schema is not None:\n            output_mode = \"seq\" if actor_net_class is NM.ConvMLP else \"flat\"\n            self.assembler = TensorDictAssembler(\n                obs_schema, output_mode=output_mode\n            )\n            if obs_example is not None:\n                self.assembler.infer_output_dim(obs_example)\n            if self.assembler.output_dim is None:\n                raise ValueError(\n                    \"TensorDictAssembler could not infer output_dim\"\n                )\n            input_dim_for_net = int(self.assembler.output_dim)\n        else:\n            raise ValueError(\"obs_schema can't be None!\")\n\n        actor_in_keys: list[str] = []\n        for _, seq_cfg in obs_schema.items():\n            if not isinstance(seq_cfg, dict):\n                continue\n            for term in seq_cfg.get(\"terms\", []):\n                actor_in_keys.append(str(term))\n        self.in_keys = actor_in_keys\n        self.out_keys = [\n            \"actions\",\n            \"actions_log_prob\",
\n            \"mu\",\n            \"sigma\",\n            \"entropy\",\n        ]\n        if self.obs_norm_enabled and self.assembler is not None:\n            self.obs_normalizer = EmpiricalNormalization(\n                shape=self.assembler.output_dim,\n                eps=self.obs_norm_eps,\n                update_method=self.obs_norm_update_method,\n                ema_momentum=self.obs_norm_ema_momentum,\n            )\n        else:\n            self.obs_normalizer = nn.Identity()\n\n        # Always pass obs_example if available\n        if obs_example is not None:\n            self.actor_module = actor_net_class(\n                input_dim=input_dim_for_net,\n                output_dim=int(module_config_dict[\"output_dim\"]),\n                module_config_dict=module_config_dict,\n            )\n        else:\n            raise ValueError(\"Obs example can't be None!\")\n\n        if \"output_head_init_scale\" in module_config_dict:\n            output_head_init_scale = float(\n                module_config_dict[\"output_head_init_scale\"]\n            )\n            if output_head_init_scale <= 0.0:\n                raise ValueError(\"output_head_init_scale must be > 0.\")\n            output_head = self.actor_module.output_head\n            if not isinstance(output_head, nn.Linear):\n                raise ValueError(\n                    \"output_head_init_scale requires actor_module.output_head to be nn.Linear.\"\n                )\n            with torch.no_grad():\n                output_head.weight.mul_(output_head_init_scale)\n                if output_head.bias is not None:\n                    output_head.bias.mul_(output_head_init_scale)\n\n        self._actor_schema_module = bool(\n            getattr(self.actor_module, \"proprio_assembler\", None)\n        )\n        self.fix_sigma = module_config_dict.get(\"fix_sigma\", False)\n        self.max_sigma = module_config_dict.get(\"max_sigma\", 1.0)\n        self.min_sigma = module_config_dict.get(\"min_sigma\", 0.1)\n\n        if \"noise_std_type\" in module_config_dict:\n            self.noise_std_type = str(\n                module_config_dict[\"noise_std_type\"]\n            ).lower()\n        elif self.use_logvar:\n            self.noise_std_type = \"log\"\n        else:\n            self.noise_std_type = \"scalar\"\n\n        # Action noise parameters (kept outside nets so optimizer updates them)\n        if self.noise_std_type == \"log\":\n            logger.info(\"Using log-std parameterization for action noise\")\n            self.log_std = nn.Parameter(\n                torch.log(torch.ones(num_actions) * init_noise_std)\n            )\n            if self.fix_sigma:\n                self.log_std.requires_grad = False\n        else:  # scalar (default)\n            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))\n            if self.fix_sigma:\n                self.std.requires_grad = False\n        self.distribution = None\n        # disable args validation for speedup\n        Normal.set_default_validate_args(False)\n        self.actor_obs_transforms: list[callable] = []\n        if self.obs_norm_enabled:\n            self.actor_obs_transforms.append(self._normalize_actor_obs)\n\n    def _process_module_config(self, module_config_dict, num_actions):\n        if module_config_dict.get(\"output_schema\", None) is not None:\n            raise ValueError(\n                \"PPOActor no longer supports module_config_dict.output_schema. 
\"\n                \"Use scalar module_config_dict.output_dim instead.\"\n            )\n\n        # Resolve output_dim placeholders when present.\n        if \"output_dim\" in module_config_dict:\n            output_dim = module_config_dict[\"output_dim\"]\n            if isinstance(output_dim, list):\n                raise ValueError(\n                    \"PPOActor expects module_config_dict.output_dim to be a scalar. \"\n                    \"List-valued output_dim is not supported.\"\n                )\n            if output_dim == \"robot_action_dim\":\n                module_config_dict[\"output_dim\"] = num_actions\n\n        return module_config_dict\n\n    def _sigma_from_params(self) -> torch.Tensor:\n        if self.noise_std_type == \"log\":\n            return torch.exp(self.log_std)\n        return self.std\n\n    def _normalize_actor_obs(\n        self, obs: torch.Tensor, update: bool\n    ) -> torch.Tensor:\n        if not self.obs_norm_enabled:\n            return obs\n        clip = float(self.obs_norm_clip)\n        if obs.ndim == 3:\n            b, seq_len, d = obs.shape\n            flat_obs = obs.reshape(b * seq_len, d)\n            if update:\n                self.obs_normalizer.update(flat_obs)\n            flat_obs = self.obs_normalizer.normalize_only(flat_obs)\n            obs = flat_obs.reshape(b, seq_len, d)\n        else:\n            if update:\n                self.obs_normalizer.update(obs)\n            obs = self.obs_normalizer.normalize_only(obs)\n        if clip > 0.0:\n            obs = torch.clamp(obs, -clip, clip)\n        return obs\n\n    def _sigma_like(self, like: torch.Tensor) -> torch.Tensor:\n        sigma_vec = self._sigma_from_params()\n        sigma_vec = torch.clamp(\n            sigma_vec,\n            min=float(self.min_sigma),\n            max=float(self.max_sigma),\n        )\n        if sigma_vec.ndim == 1 and like.ndim >= 2:\n            view_shape = [1 for _ in range(like.ndim - 1)] + [\n                sigma_vec.shape[0]\n            ]\n            return sigma_vec.view(*view_shape).expand_as(like)\n        if sigma_vec.shape != like.shape:\n            return sigma_vec.expand_as(like)\n        return sigma_vec\n\n    @property\n    def actor(self):\n        return self.actor_module\n\n    @property\n    def flat_obs_dim(self) -> int:\n        if self.assembler is None:\n            raise ValueError(\n                \"PPOActor has no assembler; flat obs dim unavailable.\"\n            )\n        if self.assembler.output_dim is None:\n            raise ValueError(\n                \"PPOActor assembler output_dim is not initialized.\"\n            )\n        return int(self.assembler.output_dim)\n\n    def export_onnx(\n        self,\n        onnx_path: str | Path,\n        *,\n        opset_version: int = 17,\n    ) -> str:\n        if self._actor_schema_module:\n            raise ValueError(\n                \"PPOActor export expects flat-obs actor modules, not schema-native modules.\"\n            )\n        export_path = Path(onnx_path)\n        export_path.parent.mkdir(parents=True, exist_ok=True)\n\n        if hasattr(self.actor_module, \"clear_router_distribution_cache\"):\n            self.actor_module.clear_router_distribution_cache()\n        actor_module = _clone_module_for_cpu_export(self.actor_module)\n        if self.obs_norm_enabled:\n            obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer)\n        else:\n            obs_normalizer = nn.Identity()\n\n        exporter = PPOActorOnnxModule(\n      
      actor_module=actor_module,\n            obs_normalizer=obs_normalizer,\n            obs_norm_enabled=self.obs_norm_enabled,\n            obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0,\n        ).to(\"cpu\")\n        exporter.eval()\n\n        obs = torch.zeros(\n            1, self.flat_obs_dim, device=\"cpu\", dtype=torch.float32\n        )\n        torch.onnx.export(\n            exporter,\n            (obs,),\n            str(export_path),\n            export_params=True,\n            opset_version=opset_version,\n            verbose=False,\n            dynamo=False,\n            input_names=[\"obs\"],\n            output_names=[\"actions\"],\n        )\n        return str(export_path)\n\n    def forward(\n        self,\n        obs_td: TensorDict,\n        actions: torch.Tensor | None = None,\n        mode: str = \"sampling\",\n        *,\n        update_obs_norm: bool = True,\n    ) -> TensorDict:\n        \"\"\"TensorDict-first forward for PPOActor.\n\n        Returns a TensorDict with keys:\n        - actions: [B, A]\n        - actions_log_prob: [B] (sampling/logp only)\n        - mu: [B, A]\n        - sigma: [B, A]\n        - entropy: [B] (sampling/logp only)\n        \"\"\"\n        if mode not in (\"sampling\", \"logp\", \"inference\"):\n            raise ValueError(f\"Unsupported mode: {mode}\")\n        if not isinstance(obs_td, TensorDict):\n            raise ValueError(\"PPOActor.forward expects TensorDict input.\")\n\n        td = obs_td.clone(\n            recurse=False\n        )  # this only clones the tree structure, not the data\n\n        if self._actor_schema_module:\n            mu = self.actor_module(obs_td)\n        else:\n            if self.assembler is None:\n                raise ValueError(\n                    \"Flat-tensor actor module requires obs_schema in PPOActor init.\"\n                )\n            actor_obs = self.assembler(obs_td)\n            update = bool(update_obs_norm)\n            for fn in self.actor_obs_transforms:\n                actor_obs = fn(actor_obs, update)\n            mu = self.actor_module(actor_obs)\n\n        sigma = self._sigma_like(mu)\n        td.set(\"mu\", mu)\n        td.set(\"sigma\", sigma)\n\n        if mode == \"inference\":\n            actions_out = mu\n            td.set(\"actions\", actions_out)\n            return td\n\n        self.distribution = Normal(mu, sigma)\n        if mode == \"sampling\":\n            actions_out = self.distribution.sample()\n        else:\n            if actions is None:\n                raise ValueError(\"actions must be provided when mode='logp'\")\n            actions_out = actions\n\n        td.set(\"actions\", actions_out)\n        td.set(\n            \"actions_log_prob\",\n            self.distribution.log_prob(actions_out).sum(dim=-1),\n        )\n        td.set(\"entropy\", self.distribution.entropy().sum(dim=-1))\n        return td\n\n    def update_distribution(self, actor_obs):\n        mean = self.actor(actor_obs)\n        # Resolve std according to parameterization\n        std_val = self._sigma_from_params()\n\n        std_val = torch.clamp(std_val, min=self.min_sigma, max=self.max_sigma)\n        self.distribution = Normal(mean, std_val)\n\n    def override_sigma(self, sigma_override: float | torch.Tensor) -> None:\n        \"\"\"Override actor sigma parameters (std) explicitly.\n\n        Args:\n            sigma_override: scalar or [A] tensor for sigma_theta (std).\n        \"\"\"\n        if self.noise_std_type not in (\"scalar\", \"log\"):\n            raise ValueError(\n                f\"Unsupported noise_std_type for override: {self.noise_std_type}\"\n            )\n        param = self.log_std if self.noise_std_type == \"log\" else self.std\n        sigma_tensor = torch.as_tensor(\n            sigma_override, device=param.device, dtype=param.dtype\n        )\n        if sigma_tensor.numel() == 1:\n            sigma_tensor = sigma_tensor.expand_as(param)\n        elif sigma_tensor.shape != param.shape:\n            raise ValueError(\n                f\"sigma_override shape {tuple(sigma_tensor.shape)} does not match \"\n                f\"actor sigma shape {tuple(param.shape)}.\"\n            )\n        if torch.any(sigma_tensor <= 0):\n            raise ValueError(\"sigma_override must be > 0 for all dims.\")\n        if self.noise_std_type == \"log\":\n            sigma_tensor = torch.log(sigma_tensor)\n        with torch.no_grad():\n            param.copy_(sigma_tensor)\n\n\nclass PPOCritic(TensorDictModuleBase):\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict,\n        *,\n        obs_example: dict | None = None,\n    ):\n        super().__init__()\n        self.critic_net_type = module_config_dict.get(\"type\", \"MLP\")\n        obs_norm_cfg = module_config_dict.get(\"obs_norm\", {})\n        self.obs_norm_enabled = bool(obs_norm_cfg.get(\"enabled\", False))\n\n        if self.obs_norm_enabled:\n            self.obs_norm_clip = float(obs_norm_cfg.get(\"clip_range\", 0.0))\n            self.obs_norm_eps = float(obs_norm_cfg.get(\"epsilon\", 1.0e-8))\n\n            self.obs_norm_update_method = str(\n                obs_norm_cfg.get(\n                    \"update_method\", obs_norm_cfg.get(\"method\", \"cumulative\")\n                )\n            ).lower()\n            ema_momentum = obs_norm_cfg.get(\"ema_momentum\")\n            if self.obs_norm_update_method == \"ema\" and ema_momentum is None:\n                raise ValueError(\n                    \"obs_norm.ema_momentum must be set when update_method='ema'.\"\n                )\n            self.obs_norm_ema_momentum = (\n                float(ema_momentum) if ema_momentum is not None else 0.0\n            )\n\n        critic_net_class = getattr(NM, self.critic_net_type, None)\n        if critic_net_class is None:\n            critic_net_class = globals().get(self.critic_net_type, None)\n        if critic_net_class is None or not isinstance(critic_net_class, type):\n            available_classes = [\n                name\n                for name in dir(NM)\n                if isinstance(getattr(NM, name, None), type)\n            ] + [\n                name\n                for name, obj in globals().items()\n                if isinstance(obj, type)\n            ]\n            raise NotImplementedError(\n                f\"Unknown critic_net_type: {self.critic_net_type}. 
\"\n                f\"Available classes: {available_classes}\"\n            )\n\n        if critic_net_class is NM.MLP and obs_schema is None:\n            raise ValueError(\n                \"PPOCritic(MLP) requires obs_schema so the agent module can assemble \"\n                \"TensorDict observations into a flat tensor.\"\n            )\n\n        # Build assembler for flat-tensor networks only\n        # Schema-based networks (e.g., MultiTaskCritic) don't need it\n        if obs_schema is not None:\n            output_mode = \"seq\" if critic_net_class is NM.ConvMLP else \"flat\"\n            self.assembler = TensorDictAssembler(\n                obs_schema, output_mode=output_mode\n            )\n            if obs_example is not None:\n                self.assembler.infer_output_dim(obs_example)\n            if self.assembler.output_dim is None:\n                raise ValueError(\n                    \"TensorDictAssembler could not infer output_dim; provide obs_example.\"\n                )\n            input_dim_for_net = int(self.assembler.output_dim)\n        else:\n            # Schema-based modules don't use wrapper's assembler\n            self.assembler = None\n            input_dim_for_net = 0\n\n        critic_in_keys: list[str] = []\n        if obs_schema is not None:\n            for _, seq_cfg in obs_schema.items():\n                if not isinstance(seq_cfg, dict):\n                    continue\n                for term in seq_cfg.get(\"terms\", []):\n                    critic_in_keys.append(str(term))\n        self.in_keys = critic_in_keys\n        self.out_keys = [\"values\"]\n\n        if self.obs_norm_enabled and self.assembler is not None:\n            self.obs_normalizer = EmpiricalNormalization(\n                shape=self.assembler.output_dim,\n                eps=self.obs_norm_eps,\n                update_method=self.obs_norm_update_method,\n                ema_momentum=self.obs_norm_ema_momentum,\n            )\n        else:\n            self.obs_normalizer = nn.Identity()\n\n        # Always pass obs_example if available\n        if obs_example is not None:\n            self.critic_module = critic_net_class(\n                input_dim=input_dim_for_net,\n                output_dim=int(module_config_dict[\"output_dim\"]),\n                module_config_dict=module_config_dict,\n            )\n        else:\n            raise ValueError(\"obs_example can't be None!\")\n        self._critic_schema_module = bool(\n            getattr(self.critic_module, \"proprio_assembler\", None)\n        )\n        self.critic_obs_transforms: list[callable] = []\n        if self.obs_norm_enabled:\n            self.critic_obs_transforms.append(self._normalize_critic_obs)\n\n    def _normalize_critic_obs(\n        self, obs: torch.Tensor, update: bool\n    ) -> torch.Tensor:\n        if not self.obs_norm_enabled:\n            return obs\n        clip = float(self.obs_norm_clip)\n        if obs.ndim == 3:\n            b, seq_len, d = obs.shape\n            flat_obs = obs.reshape(b * seq_len, d)\n            if update:\n                self.obs_normalizer.update(flat_obs)\n            flat_obs = self.obs_normalizer.normalize_only(flat_obs)\n            obs = flat_obs.reshape(b, seq_len, d)\n        else:\n            if update:\n                self.obs_normalizer.update(obs)\n            obs = self.obs_normalizer.normalize_only(obs)\n        if clip > 0.0:\n            obs = torch.clamp(obs, -clip, clip)\n        return obs\n\n    def forward(\n        self,\n        obs_td: 
TensorDict,\n        update_obs_norm: bool = True,\n        **kwargs,\n    ) -> TensorDict:\n        \"\"\"TensorDict-first forward for PPOCritic.\n\n        Args:\n            obs_td: TensorDict observations keyed by obs terms.\n            update_obs_norm: If False, skip updating running stats.\n\n        Returns:\n            TensorDict with key:\n                - \"values\": [B, 1]\n        \"\"\"\n        if not isinstance(obs_td, TensorDict):\n            raise ValueError(\"PPOCritic.forward expects TensorDict input.\")\n\n        td = obs_td.clone(recurse=False)\n        if self._critic_schema_module:\n            values = self.critic_module(obs_td)\n            if values.ndim == 1:\n                values = values[..., None]\n            td.set(\"values\", values)\n            return td\n\n        if self.assembler is None:\n            raise ValueError(\n                \"Flat-tensor critic module requires obs_schema in PPOCritic init.\"\n            )\n        critic_obs = self.assembler(obs_td)\n        update = bool(update_obs_norm)\n        for fn in self.critic_obs_transforms:\n            critic_obs = fn(critic_obs, update)\n        values = self.critic_module(critic_obs)\n        if values.ndim == 1:\n            values = values[..., None]\n        td.set(\"values\", values)\n        return td\n\n\nclass PPOTFActor(PPOActor):\n    \"\"\"Transformer-based PPO actor wrapper compatible with PPOActor interface.\n\n    - Uses NM.TransformerDecoderPolicy as actor_module\n    - Provides KV-cache controls\n    - Uses model-predicted diagonal std for distribution\n    \"\"\"\n\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        super().__init__(\n            obs_schema=obs_schema,\n            module_config_dict=module_config_dict,\n            num_actions=num_actions,\n            init_noise_std=init_noise_std,\n            obs_example=obs_example,\n        )\n        # Ensure initial std is strictly inside [min_sigma, max_sigma] to avoid boundary saturation\n        init_std_val = float(init_noise_std)\n        if not (self.min_sigma < init_std_val < self.max_sigma):\n            # Expand bounds conservatively if needed\n            if init_std_val >= self.max_sigma:\n                self.max_sigma = max(self.max_sigma, init_std_val * 2.0)\n            if init_std_val <= self.min_sigma:\n                self.min_sigma = min(self.min_sigma, init_std_val * 0.1)\n        aux_cfg = module_config_dict.get(\"aux_state_pred\", {})\n        self.aux_state_pred_enabled = bool(aux_cfg.get(\"enabled\", False))\n        aux_cmd_cfg = module_config_dict.get(\"aux_router_command_recon\", {})\n        self.aux_router_command_recon_enabled = bool(\n            aux_cmd_cfg.get(\"enabled\", False)\n        )\n        aux_switch_cfg = module_config_dict.get(\n            \"aux_router_switch_penalty\", {}\n        )\n        self.aux_router_switch_penalty_enabled = bool(\n            aux_switch_cfg.get(\"enabled\", False)\n        )\n        aux_router_future_cfg = module_config_dict.get(\n            \"aux_router_future_recon\", {}\n        )\n        self.aux_router_future_recon_enabled = bool(\n            aux_router_future_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_future_recon_assembler: TensorDictAssembler | None = (\n            None\n        )\n\n    def _sigma_from_params(self) -> 
torch.Tensor:\n        # Prefer log-std if present; otherwise use softplus(linear) for positivity\n        if hasattr(self, \"log_std\"):\n            return torch.exp(self.log_std)\n        return F.softplus(self.std)\n\n    def reset_kv_cache(self, num_envs: int, device):\n        if hasattr(self.actor_module, \"reset_kv_cache\"):\n            self.actor_module.reset_kv_cache(num_envs, device)\n\n    def clear_env_cache(self, env_ids: torch.Tensor):\n        if hasattr(self.actor_module, \"clear_env_cache\"):\n            self.actor_module.clear_env_cache(env_ids)\n\n    def onnx_past_key_values_shape(\n        self, *, batch_size: int = 1\n    ) -> tuple[int, int, int, int, int, int]:\n        num_kv_layers = int(\n            getattr(\n                self.actor_module, \"onnx_kv_layers\", self.actor_module.n_layers\n            )\n        )\n        return (\n            num_kv_layers,\n            2,\n            int(batch_size),\n            int(self.actor_module.max_ctx_len),\n            int(self.actor_module.n_kv_heads),\n            int(self.actor_module.head_dim),\n        )\n\n    def onnx_moe_layer_indices(self) -> list[int]:\n        layers = getattr(self.actor_module, \"layers\", None)\n        if layers is None:\n            return []\n        return [\n            layer_idx\n            for layer_idx, layer in enumerate(layers)\n            if isinstance(layer, NM.GroupedMoEBlock)\n        ]\n\n    def onnx_routing_output_names(self) -> list[str]:\n        output_names: list[str] = []\n        for layer_idx in self.onnx_moe_layer_indices():\n            output_names.extend(\n                [\n                    f\"moe_layer_{layer_idx}_expert_indices\",\n                    f\"moe_layer_{layer_idx}_expert_logits\",\n                ]\n            )\n        return output_names\n\n    def _maybe_update_aux_router_future_recon_norm(\n        self,\n        obs_td: TensorDict,\n        *,\n        update: bool,\n    ) -> None:\n        if (\n            not update\n            or not self.aux_router_future_recon_enabled\n            or self.aux_router_future_recon_assembler is None\n        ):\n            return\n        future_target = self.aux_router_future_recon_assembler(obs_td)\n        self.actor_module.update_aux_router_future_recon_normalizer(\n            future_target\n        )\n\n    def export_onnx(\n        self,\n        onnx_path: str | Path,\n        *,\n        opset_version: int = 17,\n        use_kv_cache: bool = True,\n    ) -> str:\n        export_path = Path(onnx_path)\n        export_path.parent.mkdir(parents=True, exist_ok=True)\n\n        if hasattr(self.actor_module, \"clear_router_distribution_cache\"):\n            self.actor_module.clear_router_distribution_cache()\n        actor_module = _clone_module_for_cpu_export(self.actor_module)\n        if self.obs_norm_enabled:\n            obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer)\n        else:\n            obs_normalizer = nn.Identity()\n\n        obs = torch.zeros(\n            1, self.flat_obs_dim, device=\"cpu\", dtype=torch.float32\n        )\n        if use_kv_cache:\n            exporter = PPOTFActorOnnxModule(\n                actor_module=actor_module,\n                obs_normalizer=obs_normalizer,\n                obs_norm_enabled=self.obs_norm_enabled,\n                obs_norm_clip=self.obs_norm_clip\n                if self.obs_norm_enabled\n                else 0.0,\n            ).to(\"cpu\")\n            exporter.eval()\n\n            cache_shape = 
self.onnx_past_key_values_shape(batch_size=1)\n            past_key_values = torch.zeros(\n                *cache_shape, device=\"cpu\", dtype=torch.float32\n            )\n            step_idx = torch.tensor([0], dtype=torch.long, device=\"cpu\")\n            output_names = [\n                \"actions\",\n                \"present_key_values\",\n                *self.onnx_routing_output_names(),\n            ]\n\n            torch.onnx.export(\n                exporter,\n                (obs, past_key_values, step_idx),\n                str(export_path),\n                export_params=True,\n                opset_version=opset_version,\n                verbose=False,\n                dynamo=False,\n                input_names=[\"obs\", \"past_key_values\", \"step_idx\"],\n                output_names=output_names,\n            )\n        else:\n            exporter = PPOTFWoKVCacheActorOnnxModule(\n                actor_module=actor_module,\n                obs_normalizer=obs_normalizer,\n                obs_norm_enabled=self.obs_norm_enabled,\n                obs_norm_clip=self.obs_norm_clip\n                if self.obs_norm_enabled\n                else 0.0,\n            ).to(\"cpu\")\n            exporter.eval()\n            obs = torch.zeros(\n                1, 32, self.flat_obs_dim, device=\"cpu\", dtype=torch.float32\n            )\n\n            torch.onnx.export(\n                exporter,\n                (obs,),\n                str(export_path),\n                export_params=True,\n                opset_version=opset_version,\n                verbose=False,\n                dynamo=False,\n                input_names=[\"obs\"],\n                output_names=[\"actions\"],\n            )\n        return str(export_path)\n\n    def update_distribution(self, actor_obs):\n        \"\"\"Distribution using TransformerDecoderPolicy single-step mu + learnable log-std.\n\n        Args:\n            actor_obs: [B, D] normalized obs\n        \"\"\"\n        mu = self.actor_module.single_step_mu(actor_obs)\n        std = self._sigma_from_params()\n        std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma)\n        self.distribution = Normal(mu, std)\n\n    def forward(\n        self,\n        obs_td: TensorDict | torch.Tensor,\n        actions: torch.Tensor | None = None,\n        mode: str = \"sampling\",\n        attn_mask: torch.Tensor | None = None,\n        *,\n        update_obs_norm: bool = True,\n        past_key_values: torch.Tensor | None = None,\n        current_pos: torch.Tensor | None = None,\n    ) -> TensorDict | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"TensorDict-first forward for PPOTFActor.\n\n        Modes:\n        - \"sampling\" / \"logp\" / \"inference\": single-step policy with KV-cache-aware\n          mean prediction via `actor_module.single_step_mu`.\n        - \"sequence_logp\": sequence log-prob evaluation with attention mask support.\n        \"\"\"\n        if past_key_values is not None:\n            if isinstance(obs_td, TensorDict):\n                if self.assembler is None:\n                    raise ValueError(\n                        \"PPOTFActor requires obs_schema/assembler for ONNX cache path.\"\n                    )\n                actor_obs = self.assembler(obs_td)\n            else:\n                actor_obs = obs_td\n            return self.actor_module(\n                actor_obs,\n                past_key_values=past_key_values,\n                current_pos=current_pos,\n            )\n        if mode == 
\"sequence_logp\":\n            if not isinstance(obs_td, TensorDict):\n                raise ValueError(\n                    \"PPOTFActor.forward(mode='sequence_logp') expects TensorDict input.\"\n                )\n            if obs_td.batch_dims != 2:\n                raise ValueError(\n                    \"PPOTFActor.forward(mode='sequence_logp') expects TensorDict with \"\n                    f\"batch_dims=2 [B, T], got batch_size={tuple(obs_td.batch_size)}\"\n                )\n            if self.assembler is None:\n                raise ValueError(\n                    \"PPOTFActor requires obs_schema to assemble sequence observations.\"\n                )\n            if actions is None:\n                raise ValueError(\n                    \"actions must be provided when mode='sequence_logp'\"\n                )\n\n            b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])\n            flat_td = obs_td.flatten(0, 1)\n            actor_obs_flat = self.assembler(flat_td)\n            update = bool(update_obs_norm)\n            for fn in self.actor_obs_transforms:\n                actor_obs_flat = fn(actor_obs_flat, update)\n            self._maybe_update_aux_router_future_recon_norm(\n                flat_td, update=update\n            )\n            actor_obs_seq = actor_obs_flat.reshape(b, t, -1)\n\n            if actor_obs_seq.ndim != 3:\n                raise ValueError(\n                    \"PPOTFActor forward(mode='sequence_logp') expects actor_obs \"\n                    f\"with shape [B, T, D], got {actor_obs_seq.shape}\"\n                )\n            mu, sigma, logp, entropy, aux_preds = self.sequence_forward_logp(\n                actor_obs_seq, actions, attn_mask\n            )\n            td = obs_td.clone(recurse=False)\n            td.set(\"mu\", mu)\n            td.set(\"sigma\", sigma)\n            td.set(\"actions\", actions)\n            td.set(\"actions_log_prob\", logp)\n            td.set(\"entropy\", entropy)\n            if aux_preds is not None:\n                if \"base_lin_vel_loc\" in aux_preds:\n                    td.set(\n                        \"aux_base_lin_vel_loc\", aux_preds[\"base_lin_vel_loc\"]\n                    )\n                    td.set(\n                        \"aux_base_lin_vel_log_std\",\n                        aux_preds[\"base_lin_vel_log_std\"],\n                    )\n                    td.set(\"aux_root_height_loc\", aux_preds[\"root_height_loc\"])\n                    td.set(\n                        \"aux_root_height_log_std\",\n                        aux_preds[\"root_height_log_std\"],\n                    )\n                    td.set(\n                        \"aux_keybody_contact_logits\",\n                        aux_preds[\"keybody_contact_logits\"],\n                    )\n                    td.set(\n                        \"aux_ref_keybody_rel_pos\",\n                        aux_preds[\"ref_keybody_rel_pos\"],\n                    )\n                    td.set(\n                        \"aux_robot_keybody_rel_pos\",\n                        aux_preds[\"robot_keybody_rel_pos\"],\n                    )\n                    if \"denoise_ref_root_lin_vel_residual\" in aux_preds:\n                        td.set(\n                            \"aux_denoise_ref_root_lin_vel_residual\",\n                            aux_preds[\"denoise_ref_root_lin_vel_residual\"],\n                        )\n                    if \"denoise_ref_root_ang_vel_residual\" in aux_preds:\n                        td.set(\n  
                          \"aux_denoise_ref_root_ang_vel_residual\",\n                            aux_preds[\"denoise_ref_root_ang_vel_residual\"],\n                        )\n                    if \"denoise_ref_dof_pos_residual\" in aux_preds:\n                        td.set(\n                            \"aux_denoise_ref_dof_pos_residual\",\n                            aux_preds[\"denoise_ref_dof_pos_residual\"],\n                        )\n                if \"router_command_recon\" in aux_preds:\n                    td.set(\n                        \"aux_router_command_recon\",\n                        aux_preds[\"router_command_recon\"],\n                    )\n                if \"router_future_recon\" in aux_preds:\n                    td.set(\n                        \"aux_router_future_recon\",\n                        aux_preds[\"router_future_recon\"],\n                    )\n                if \"router_features\" in aux_preds:\n                    td.set(\"router_features\", aux_preds[\"router_features\"])\n                if \"router_temporal_features\" in aux_preds:\n                    td.set(\n                        \"router_temporal_features\",\n                        aux_preds[\"router_temporal_features\"],\n                    )\n            return td\n\n        if mode not in (\"sampling\", \"logp\", \"inference\"):\n            raise ValueError(f\"Unsupported mode: {mode}\")\n        if not isinstance(obs_td, TensorDict):\n            raise ValueError(\"PPOTFActor.forward expects TensorDict input.\")\n        if self.assembler is None:\n            raise ValueError(\n                \"Flat-tensor actor module requires obs_schema in PPOTFActor init.\"\n            )\n\n        td = obs_td.clone(recurse=False)\n        actor_obs = self.assembler(obs_td)\n        update = bool(update_obs_norm)\n        for fn in self.actor_obs_transforms:\n            actor_obs = fn(actor_obs, update)\n        self._maybe_update_aux_router_future_recon_norm(obs_td, update=update)\n\n        if hasattr(self.actor_module, \"single_step_mu\"):\n            mu = self.actor_module.single_step_mu(actor_obs)\n        else:\n            mu = self.actor_module(actor_obs)\n        sigma = self._sigma_like(mu)\n        td.set(\"mu\", mu)\n        td.set(\"sigma\", sigma)\n\n        if mode == \"inference\":\n            td.set(\"actions\", mu)\n            return td\n\n        self.distribution = Normal(mu, sigma)\n        if mode == \"sampling\":\n            actions_out = self.distribution.sample()\n        else:\n            if actions is None:\n                raise ValueError(\"actions must be provided when mode='logp'\")\n            actions_out = actions\n        td.set(\"actions\", actions_out)\n        td.set(\n            \"actions_log_prob\",\n            self.distribution.log_prob(actions_out).sum(dim=-1),\n        )\n        td.set(\"entropy\", self.distribution.entropy().sum(dim=-1))\n        return td\n\n    def sequence_forward_logp(\n        self,\n        obs_seq: torch.Tensor,\n        actions: torch.Tensor,\n        attn_mask: torch.Tensor | None,\n    ) -> tuple[\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        dict[str, torch.Tensor] | None,\n    ]:\n        \"\"\"Sequence log-prob path with learnable per-action log-std.\n\n        Args:\n            obs_seq: [B, T, D]\n            actions: [B, T, A]\n            attn_mask: [B, T, T] boolean (True if attend allowed)\n\n        Returns:\n            mu: [B, T, A], sigma: [B, 
T, A], logp: [B, T, 1], entropy: [B, T, 1]\n        \"\"\"\n        aux_preds = None\n        aux_router_future_recon_enabled = bool(\n            getattr(self, \"aux_router_future_recon_enabled\", False)\n        )\n        need_pre_moe_aux = self.aux_state_pred_enabled\n        need_router_features = (\n            self.aux_router_command_recon_enabled\n            or self.aux_router_switch_penalty_enabled\n        )\n        need_router_aux = (\n            need_router_features or aux_router_future_recon_enabled\n        )\n        need_ref_aux_hidden = bool(\n            (need_pre_moe_aux or aux_router_future_recon_enabled)\n            and getattr(\n                self.actor_module, \"supports_explicit_ref_aux_hidden\", False\n            )\n        )\n        if need_pre_moe_aux and need_router_aux:\n            sequence_mu_kwargs = {\n                \"attn_mask\": attn_mask,\n                \"return_pre_moe_hidden\": True,\n                \"return_router_features\": need_router_features,\n                \"return_router_temporal_features\": self.aux_router_switch_penalty_enabled,\n            }\n            if need_ref_aux_hidden:\n                sequence_mu_kwargs[\"return_ref_aux_hidden\"] = True\n            actor_outputs = self.actor_module.sequence_mu(\n                obs_seq,\n                **sequence_mu_kwargs,\n            )\n            output_parts = list(actor_outputs)\n            mu = output_parts.pop(0)\n            pre_moe_hidden = output_parts.pop(0)\n            ref_aux_hidden = (\n                output_parts.pop(0) if need_ref_aux_hidden else None\n            )\n            router_features = (\n                output_parts.pop(0) if need_router_features else None\n            )\n            router_temporal_features = (\n                output_parts.pop(0)\n                if self.aux_router_switch_penalty_enabled\n                else None\n            )\n            aux_preds = self.actor_module.predict_aux_from_pre_moe(\n                pre_moe_hidden,\n                ref_aux_hidden=ref_aux_hidden if need_ref_aux_hidden else None,\n            )\n            if router_features is not None:\n                aux_preds[\"router_features\"] = router_features\n            if router_temporal_features is not None:\n                aux_preds[\"router_temporal_features\"] = (\n                    router_temporal_features\n                )\n            if self.aux_router_command_recon_enabled:\n                aux_preds[\"router_command_recon\"] = (\n                    self.actor_module.predict_aux_router_command_from_router_features(\n                        router_features\n                    )\n                )\n            if aux_router_future_recon_enabled:\n                aux_preds[\"router_future_recon\"] = (\n                    self.actor_module.predict_aux_router_future_recon_from_router_hidden(\n                        ref_aux_hidden\n                    )\n                )\n        elif need_pre_moe_aux:\n            sequence_mu_kwargs = {\n                \"attn_mask\": attn_mask,\n                \"return_pre_moe_hidden\": True,\n            }\n            if need_ref_aux_hidden:\n                sequence_mu_kwargs[\"return_ref_aux_hidden\"] = True\n            actor_outputs = self.actor_module.sequence_mu(\n                obs_seq,\n                **sequence_mu_kwargs,\n            )\n            if need_ref_aux_hidden:\n                mu, pre_moe_hidden, ref_aux_hidden = actor_outputs\n            else:\n                mu, pre_moe_hidden 
= actor_outputs\n            aux_preds = self.actor_module.predict_aux_from_pre_moe(\n                pre_moe_hidden,\n                ref_aux_hidden=ref_aux_hidden if need_ref_aux_hidden else None,\n            )\n        elif need_router_aux:\n            sequence_mu_kwargs = {\n                \"attn_mask\": attn_mask,\n                \"return_router_features\": need_router_features,\n                \"return_router_temporal_features\": self.aux_router_switch_penalty_enabled,\n            }\n            if need_ref_aux_hidden:\n                sequence_mu_kwargs[\"return_ref_aux_hidden\"] = True\n            actor_outputs = self.actor_module.sequence_mu(\n                obs_seq,\n                **sequence_mu_kwargs,\n            )\n            output_parts = list(actor_outputs)\n            mu = output_parts.pop(0)\n            ref_aux_hidden = (\n                output_parts.pop(0) if need_ref_aux_hidden else None\n            )\n            router_features = (\n                output_parts.pop(0) if need_router_features else None\n            )\n            router_temporal_features = (\n                output_parts.pop(0)\n                if self.aux_router_switch_penalty_enabled\n                else None\n            )\n            aux_preds = {}\n            if router_features is not None:\n                aux_preds[\"router_features\"] = router_features\n            if router_temporal_features is not None:\n                aux_preds[\"router_temporal_features\"] = (\n                    router_temporal_features\n                )\n            if self.aux_router_command_recon_enabled:\n                aux_preds[\"router_command_recon\"] = (\n                    self.actor_module.predict_aux_router_command_from_router_features(\n                        router_features\n                    )\n                )\n            if aux_router_future_recon_enabled:\n                aux_preds[\"router_future_recon\"] = (\n                    self.actor_module.predict_aux_router_future_recon_from_router_hidden(\n                        ref_aux_hidden\n                    )\n                )\n        else:\n            mu = self.actor_module.sequence_mu(obs_seq, attn_mask=attn_mask)\n        # Match sampling-time clamping for stability and consistent KL/log-prob\n        sigma_vec = self._sigma_from_params().clamp(\n            self.min_sigma, self.max_sigma\n        )\n        sigma = sigma_vec[None, None, :].expand_as(mu)\n        var = sigma * sigma\n        logp = -0.5 * (\n            ((actions - mu) ** 2) / (var + 1.0e-8)\n            + 2.0 * torch.log(sigma + 1.0e-8)\n            + math.log(2.0 * math.pi)\n        ).sum(dim=-1, keepdim=True)\n        entropy = (\n            0.5 + 0.5 * math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)\n        ).sum(dim=-1, keepdim=True)\n        return mu, sigma, logp, entropy, aux_preds\n\n\nclass PPOTFRefRouterActor(PPOTFActor):\n    @staticmethod\n    def _leaf_obs_name(term: str) -> str:\n        return str(term).rsplit(\"/\", maxsplit=1)[-1]\n\n    @classmethod\n    def _infer_flat_term_dim(\n        cls,\n        *,\n        obs_example: TensorDict,\n        term: str,\n        seq_len: int,\n    ) -> int:\n        tensor = TensorDictAssembler._get_from_data(obs_example, str(term))\n        if tensor is None:\n            raise KeyError(\n                f\"Missing obs term '{term}' in obs_example while inferring \"\n                \"reference-router feature indices.\"\n            )\n        if not isinstance(tensor, torch.Tensor):\n     
       raise TypeError(\n                f\"Obs term '{term}' must be a torch.Tensor, got {type(tensor)}.\"\n            )\n        if tensor.ndim == 2:\n            if seq_len != 1:\n                raise ValueError(\n                    f\"Obs term '{term}' expected seq_len={seq_len} but tensor \"\n                    f\"is 2D with shape {tuple(tensor.shape)}.\"\n                )\n            return int(tensor.shape[-1])\n        if tensor.ndim == 3:\n            if int(tensor.shape[1]) != seq_len:\n                raise ValueError(\n                    f\"Obs term '{term}' seq_len mismatch: expected {seq_len}, \"\n                    f\"got {int(tensor.shape[1])}.\"\n                )\n            return int(tensor.shape[1] * tensor.shape[-1])\n        raise ValueError(\n            f\"Obs term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}.\"\n        )\n\n    @classmethod\n    def infer_router_feature_indices(\n        cls,\n        obs_schema: dict,\n        obs_example: TensorDict,\n    ) -> list[int]:\n        if not isinstance(obs_example, TensorDict):\n            raise ValueError(\n                \"PPOTFRefRouterActor requires TensorDict obs_example.\"\n            )\n\n        router_feature_indices: list[int] = []\n        offset = 0\n        for _, seq_cfg in obs_schema.items():\n            if not isinstance(seq_cfg, dict):\n                continue\n            seq_len = int(seq_cfg.get(\"seq_len\", 1))\n            for term in seq_cfg.get(\"terms\", []):\n                term_str = str(term)\n                flat_dim = cls._infer_flat_term_dim(\n                    obs_example=obs_example,\n                    term=term_str,\n                    seq_len=seq_len,\n                )\n                leaf_name = cls._leaf_obs_name(term_str)\n                if leaf_name.startswith(\"actor_ref_\"):\n                    router_feature_indices.extend(\n                        range(offset, offset + flat_dim)\n                    )\n                offset += flat_dim\n\n        if len(router_feature_indices) == 0:\n            raise ValueError(\n                \"PPOTFRefRouterActor could not infer any actor_ref_* features \"\n                \"from obs_schema.\"\n            )\n        return router_feature_indices\n\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        if obs_schema is None:\n            raise ValueError(\n                \"PPOTFRefRouterActor requires non-empty obs_schema.\"\n            )\n        if obs_example is None:\n            raise ValueError(\"PPOTFRefRouterActor requires obs_example.\")\n        if bool(module_config_dict.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"PPOTFRefRouterActor does not support use_future_cross_attn=True.\"\n            )\n\n        actor_module_cfg = copy.deepcopy(module_config_dict)\n        aux_future_cfg = actor_module_cfg.get(\"aux_router_future_recon\", {})\n        if bool(aux_future_cfg.get(\"enabled\", False)):\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicy does not support \"\n                \"aux_router_future_recon.\"\n            )\n        router_feature_indices = self.infer_router_feature_indices(\n            obs_schema, obs_example\n        )\n        actor_module_cfg[\"router_input_dim\"] = int(len(router_feature_indices))\n        
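# Illustrative sketch (hypothetical term names and dims): given two\n        # terms, robot/dof_pos (dim 23) and ref/actor_ref_dof_pos_cur\n        # (dim 23), the flat obs covers indices 0..45 and the router\n        # consumes only indices 23..45, the actor_ref_* span located by\n        # infer_router_feature_indices above.\n        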
actor_module_cfg[\"router_feature_indices\"] = list(\n            router_feature_indices\n        )\n        if \"router_embed_mlp_hidden\" not in actor_module_cfg:\n            actor_module_cfg[\"router_embed_mlp_hidden\"] = int(\n                actor_module_cfg.get(\"obs_embed_mlp_hidden\", 1024)\n            )\n\n        super().__init__(\n            obs_schema=obs_schema,\n            module_config_dict=actor_module_cfg,\n            num_actions=num_actions,\n            init_noise_std=init_noise_std,\n            obs_example=obs_example,\n        )\n        self.router_feature_indices = list(router_feature_indices)\n\n\nclass PPOTFRefRouterSeqActor(PPOTFActor):\n    REQUIRED_CURRENT_REF_TERMS = (\n        \"actor_ref_gravity_projection_cur\",\n        \"actor_ref_base_linvel_cur\",\n        \"actor_ref_base_angvel_cur\",\n        \"actor_ref_dof_pos_cur\",\n        \"actor_ref_root_height_cur\",\n    )\n    REQUIRED_FUTURE_REF_TERMS = (\n        \"actor_ref_gravity_projection_fut\",\n        \"actor_ref_base_linvel_fut\",\n        \"actor_ref_base_angvel_fut\",\n        \"actor_ref_dof_pos_fut\",\n        \"actor_ref_root_height_fut\",\n    )\n    SUPPORTED_AUX_WEIGHT_NAMES = {\n        \"w_base_lin_vel\",\n        \"w_keybody_contact\",\n        \"w_ref_keybody_rel_pos\",\n        \"w_robot_keybody_rel_pos\",\n    }\n\n    @staticmethod\n    def _leaf_obs_name(term: str) -> str:\n        return str(term).rsplit(\"/\", maxsplit=1)[-1]\n\n    @classmethod\n    def _infer_flat_term_dim(\n        cls,\n        *,\n        obs_example: TensorDict,\n        term: str,\n        seq_len: int,\n    ) -> int:\n        tensor = TensorDictAssembler._get_from_data(obs_example, str(term))\n        if tensor is None:\n            raise KeyError(\n                f\"Missing obs term '{term}' in obs_example while inferring shared ref partitions.\"\n            )\n        if not isinstance(tensor, torch.Tensor):\n            raise TypeError(\n                f\"Obs term '{term}' must be a torch.Tensor, got {type(tensor)}.\"\n            )\n        if tensor.ndim == 2:\n            if seq_len != 1:\n                raise ValueError(\n                    f\"Obs term '{term}' expected seq_len={seq_len} but tensor \"\n                    f\"is 2D with shape {tuple(tensor.shape)}.\"\n                )\n            return int(tensor.shape[-1])\n        if tensor.ndim == 3:\n            if int(tensor.shape[1]) != seq_len:\n                raise ValueError(\n                    f\"Obs term '{term}' seq_len mismatch: expected {seq_len}, \"\n                    f\"got {int(tensor.shape[1])}.\"\n                )\n            return int(tensor.shape[-1])\n        raise ValueError(\n            f\"Obs term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}.\"\n        )\n\n    @classmethod\n    def _validate_v2_aux_config(cls, module_config_dict: dict) -> None:\n        aux_cmd_cfg = module_config_dict.get(\"aux_router_command_recon\", {})\n        if bool(aux_cmd_cfg.get(\"enabled\", False)):\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 does not support \"\n                \"aux_router_command_recon.\"\n            )\n        aux_cfg = module_config_dict.get(\"aux_state_pred\", {})\n        if not bool(aux_cfg.get(\"enabled\", False)):\n            return\n        for key, value in aux_cfg.items():\n            if not str(key).startswith(\"w_\"):\n                continue\n            if float(value) <= 0.0:\n                continue\n            if 
str(key) not in cls.SUPPORTED_AUX_WEIGHT_NAMES:\n                raise ValueError(\n                    \"ReferenceRoutedGroupedMoETransformerPolicyV2 only supports \"\n                    \"aux_state_pred weights for \"\n                    \"base_lin_vel, keybody_contact, ref_keybody_rel_pos, and \"\n                    \"robot_keybody_rel_pos. Unsupported weight: \"\n                    f\"{key}.\"\n                )\n\n    @classmethod\n    def _build_aux_router_future_recon_schema(\n        cls, obs_schema: dict\n    ) -> dict[str, dict]:\n        required_terms = set(cls.REQUIRED_FUTURE_REF_TERMS)\n        matched_terms: set[str] = set()\n        future_schema: dict[str, dict] = {}\n\n        for group_name, seq_cfg in obs_schema.items():\n            if not isinstance(seq_cfg, dict):\n                continue\n            terms = [\n                str(term)\n                for term in seq_cfg.get(\"terms\", [])\n                if cls._leaf_obs_name(str(term)) in required_terms\n            ]\n            if len(terms) == 0:\n                continue\n            next_seq_cfg = dict(seq_cfg)\n            next_seq_cfg[\"terms\"] = terms\n            future_schema[str(group_name)] = next_seq_cfg\n            matched_terms.update(cls._leaf_obs_name(term) for term in terms)\n\n        missing_terms = sorted(required_terms.difference(matched_terms))\n        if missing_terms:\n            raise ValueError(\n                \"PPOTFRefRouterSeqActor could not infer all future ref terms \"\n                \"for aux_router_future_recon. Missing: \"\n                + \", \".join(missing_terms)\n            )\n        return future_schema\n\n    @classmethod\n    def _prepare_aux_router_future_recon(\n        cls,\n        *,\n        actor_module_cfg: dict,\n        obs_schema: dict,\n        obs_example: TensorDict,\n    ) -> TensorDictAssembler | None:\n        aux_future_cfg = copy.deepcopy(\n            actor_module_cfg.get(\"aux_router_future_recon\", {})\n        )\n        if not bool(aux_future_cfg.get(\"enabled\", False)):\n            actor_module_cfg[\"aux_router_future_recon\"] = aux_future_cfg\n            return None\n\n        future_schema = cls._build_aux_router_future_recon_schema(obs_schema)\n        future_assembler = TensorDictAssembler(\n            future_schema, output_mode=\"flat\"\n        )\n        aux_future_cfg[\"output_dim\"] = int(\n            future_assembler.infer_output_dim(obs_example)\n        )\n        actor_module_cfg[\"aux_router_future_recon\"] = aux_future_cfg\n        return future_assembler\n\n    @classmethod\n    def _infer_shared_ref_layout(\n        cls,\n        obs_schema: dict,\n        obs_example: TensorDict,\n    ) -> dict[str, int | list[int] | list[tuple[int, int, int]]]:\n        if not isinstance(obs_example, TensorDict):\n            raise ValueError(\n                \"PPOTFRefRouterSeqActor requires TensorDict obs_example.\"\n            )\n\n        required_cur = set(cls.REQUIRED_CURRENT_REF_TERMS)\n        required_fut = set(cls.REQUIRED_FUTURE_REF_TERMS)\n        found_cur: dict[str, tuple[int, int]] = {}\n        found_fut: dict[str, tuple[int, int, int]] = {}\n        state_indices: list[int] = []\n        ref_cur_indices: list[int] = []\n        offset = 0\n        ref_fut_seq_len: int | None = None\n\n        for _, seq_cfg in obs_schema.items():\n            if not isinstance(seq_cfg, dict):\n                continue\n            seq_len = int(seq_cfg.get(\"seq_len\", 1))\n            for term in 
seq_cfg.get(\"terms\", []):\n                term_str = str(term)\n                leaf_name = cls._leaf_obs_name(term_str)\n                flat_term_dim = cls._infer_flat_term_dim(\n                    obs_example=obs_example,\n                    term=term_str,\n                    seq_len=seq_len,\n                )\n                flat_span = int(seq_len * flat_term_dim)\n                term_range = list(range(offset, offset + flat_span))\n\n                if leaf_name in required_cur:\n                    if seq_len != 1:\n                        raise ValueError(\n                            \"current ref term \"\n                            f\"'{leaf_name}' must have seq_len=1, got {seq_len}.\"\n                        )\n                    if leaf_name in found_cur:\n                        raise ValueError(\n                            f\"duplicate current ref term '{leaf_name}' in obs_schema.\"\n                        )\n                    found_cur[leaf_name] = (offset, flat_term_dim)\n                    ref_cur_indices.extend(term_range)\n                elif leaf_name in required_fut:\n                    if leaf_name in found_fut:\n                        raise ValueError(\n                            f\"duplicate future ref term '{leaf_name}' in obs_schema.\"\n                        )\n                    if ref_fut_seq_len is None:\n                        ref_fut_seq_len = seq_len\n                    elif ref_fut_seq_len != seq_len:\n                        raise ValueError(\n                            \"future ref terms must share one seq_len, got \"\n                            f\"{ref_fut_seq_len} and {seq_len}.\"\n                        )\n                    found_fut[leaf_name] = (\n                        offset,\n                        offset + flat_span,\n                        flat_term_dim,\n                    )\n                else:\n                    state_indices.extend(term_range)\n\n                offset += flat_span\n\n        missing_cur = sorted(required_cur.difference(found_cur.keys()))\n        if missing_cur:\n            raise ValueError(\n                \"missing required current ref term(s): \"\n                + \", \".join(missing_cur)\n            )\n        missing_fut = sorted(required_fut.difference(found_fut.keys()))\n        if missing_fut:\n            raise ValueError(\n                \"missing required future ref term(s): \"\n                + \", \".join(missing_fut)\n            )\n        if ref_fut_seq_len is None or ref_fut_seq_len <= 0:\n            raise ValueError(\n                \"missing required future ref terms in obs_schema.\"\n            )\n        if len(state_indices) == 0:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 requires at least \"\n                \"one non-reference actor state feature.\"\n            )\n\n        ordered_fut_slices = [\n            found_fut[leaf_name] for leaf_name in cls.REQUIRED_FUTURE_REF_TERMS\n        ]\n        return {\n            \"full_obs_input_dim\": int(offset),\n            \"state_obs_input_dim\": int(len(state_indices)),\n            \"ref_cur_token_dim\": int(len(ref_cur_indices)),\n            \"ref_fut_token_dim\": int(\n                sum(end - start for start, end, _ in ordered_fut_slices)\n                // ref_fut_seq_len\n            ),\n            \"ref_fut_seq_len\": int(ref_fut_seq_len),\n            \"state_feature_indices\": state_indices,\n            \"ref_cur_feature_indices\": 
ref_cur_indices,\n            \"ref_fut_slices\": ordered_fut_slices,\n        }\n\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        if obs_schema is None:\n            raise ValueError(\n                \"PPOTFRefRouterSeqActor requires non-empty obs_schema.\"\n            )\n        if obs_example is None:\n            raise ValueError(\"PPOTFRefRouterSeqActor requires obs_example.\")\n        if bool(module_config_dict.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"PPOTFRefRouterSeqActor does not support use_future_cross_attn=True.\"\n            )\n        self._validate_v2_aux_config(module_config_dict)\n        inferred_layout = self._infer_shared_ref_layout(\n            obs_schema, obs_example\n        )\n\n        actor_module_cfg = copy.deepcopy(module_config_dict)\n        actor_module_cfg[\"input_dim_override\"] = int(\n            inferred_layout[\"state_obs_input_dim\"]\n        )\n        actor_module_cfg[\"state_obs_input_dim\"] = int(\n            inferred_layout[\"state_obs_input_dim\"]\n        )\n        actor_module_cfg[\"ref_cur_token_dim\"] = int(\n            inferred_layout[\"ref_cur_token_dim\"]\n        )\n        actor_module_cfg[\"ref_fut_token_dim\"] = int(\n            inferred_layout[\"ref_fut_token_dim\"]\n        )\n        actor_module_cfg[\"ref_fut_seq_len\"] = int(\n            inferred_layout[\"ref_fut_seq_len\"]\n        )\n        actor_module_cfg[\"state_feature_indices\"] = list(\n            inferred_layout[\"state_feature_indices\"]\n        )\n        actor_module_cfg[\"ref_cur_feature_indices\"] = list(\n            inferred_layout[\"ref_cur_feature_indices\"]\n        )\n        actor_module_cfg[\"ref_fut_slices\"] = [\n            list(item) for item in inferred_layout[\"ref_fut_slices\"]\n        ]\n        actor_module_cfg.pop(\"router_hist_obs_schema\", None)\n        actor_module_cfg.pop(\"router_fut_obs_schema\", None)\n\n        super().__init__(\n            obs_schema=obs_schema,\n            module_config_dict=actor_module_cfg,\n            num_actions=num_actions,\n            init_noise_std=init_noise_std,\n            obs_example=obs_example,\n        )\n        self.full_obs_input_dim = int(inferred_layout[\"full_obs_input_dim\"])\n        self.state_obs_input_dim = int(inferred_layout[\"state_obs_input_dim\"])\n        self.ref_cur_token_dim = int(inferred_layout[\"ref_cur_token_dim\"])\n        self.ref_fut_token_dim = int(inferred_layout[\"ref_fut_token_dim\"])\n        self.ref_fut_seq_len = int(inferred_layout[\"ref_fut_seq_len\"])\n        self.state_feature_indices = list(\n            inferred_layout[\"state_feature_indices\"]\n        )\n        self.ref_cur_feature_indices = list(\n            inferred_layout[\"ref_cur_feature_indices\"]\n        )\n        self.ref_fut_slices = [\n            tuple(int(v) for v in item)\n            for item in inferred_layout[\"ref_fut_slices\"]\n        ]\n\n\nclass PPOTFRefRouterV3Actor(PPOTFRefRouterSeqActor):\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        if obs_schema is None:\n            raise ValueError(\n                \"PPOTFRefRouterV3Actor requires non-empty 
obs_schema.\"\n            )\n        if obs_example is None:\n            raise ValueError(\"PPOTFRefRouterV3Actor requires obs_example.\")\n        if bool(module_config_dict.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"PPOTFRefRouterV3Actor does not support use_future_cross_attn=True.\"\n            )\n        self._validate_v2_aux_config(module_config_dict)\n        inferred_layout = self._infer_shared_ref_layout(\n            obs_schema, obs_example\n        )\n\n        actor_module_cfg = copy.deepcopy(module_config_dict)\n        actor_module_cfg[\"state_obs_input_dim\"] = int(\n            inferred_layout[\"state_obs_input_dim\"]\n        )\n        actor_module_cfg[\"ref_cur_token_dim\"] = int(\n            inferred_layout[\"ref_cur_token_dim\"]\n        )\n        actor_module_cfg[\"ref_fut_token_dim\"] = int(\n            inferred_layout[\"ref_fut_token_dim\"]\n        )\n        actor_module_cfg[\"ref_fut_seq_len\"] = int(\n            inferred_layout[\"ref_fut_seq_len\"]\n        )\n        actor_module_cfg[\"state_feature_indices\"] = list(\n            inferred_layout[\"state_feature_indices\"]\n        )\n        actor_module_cfg[\"ref_cur_feature_indices\"] = list(\n            inferred_layout[\"ref_cur_feature_indices\"]\n        )\n        actor_module_cfg[\"ref_fut_slices\"] = [\n            list(item) for item in inferred_layout[\"ref_fut_slices\"]\n        ]\n        actor_module_cfg.pop(\"router_hist_obs_schema\", None)\n        actor_module_cfg.pop(\"router_fut_obs_schema\", None)\n        future_recon_assembler = self._prepare_aux_router_future_recon(\n            actor_module_cfg=actor_module_cfg,\n            obs_schema=obs_schema,\n            obs_example=obs_example,\n        )\n\n        PPOTFActor.__init__(\n            self,\n            obs_schema=obs_schema,\n            module_config_dict=actor_module_cfg,\n            num_actions=num_actions,\n            init_noise_std=init_noise_std,\n            obs_example=obs_example,\n        )\n        self.full_obs_input_dim = int(inferred_layout[\"full_obs_input_dim\"])\n        self.state_obs_input_dim = int(inferred_layout[\"state_obs_input_dim\"])\n        self.ref_cur_token_dim = int(inferred_layout[\"ref_cur_token_dim\"])\n        self.ref_fut_token_dim = int(inferred_layout[\"ref_fut_token_dim\"])\n        self.ref_fut_seq_len = int(inferred_layout[\"ref_fut_seq_len\"])\n        self.state_feature_indices = list(\n            inferred_layout[\"state_feature_indices\"]\n        )\n        self.ref_cur_feature_indices = list(\n            inferred_layout[\"ref_cur_feature_indices\"]\n        )\n        self.ref_fut_slices = [\n            tuple(int(v) for v in item)\n            for item in inferred_layout[\"ref_fut_slices\"]\n        ]\n        self.aux_router_future_recon_assembler = future_recon_assembler\n\n\nclass PPOCondTFActor(PPOTFActor):\n    \"\"\"Transformer actor with flat state obs and seq future-token conditioning.\"\"\"\n\n    def __init__(\n        self,\n        obs_schema: dict | None,\n        module_config_dict: dict,\n        num_actions: int,\n        init_noise_std: float,\n        *,\n        obs_example: dict | None = None,\n    ):\n        super().__init__(\n            obs_schema=obs_schema,\n            module_config_dict=module_config_dict,\n            num_actions=num_actions,\n            init_noise_std=init_noise_std,\n            obs_example=obs_example,\n        )\n        if obs_schema is None:\n            raise 
ValueError(\"PPOCondTFActor requires non-empty obs_schema.\")\n        if \"flattened_obs\" not in obs_schema:\n            raise ValueError(\"obs_schema must contain 'flattened_obs'.\")\n        if \"flattened_obs_fut\" not in obs_schema:\n            raise ValueError(\"obs_schema must contain 'flattened_obs_fut'.\")\n        if obs_example is None:\n            raise ValueError(\"PPOCondTFActor requires obs_example.\")\n\n        self.state_schema = {\"flattened_obs\": obs_schema[\"flattened_obs\"]}\n        self.future_schema = {\n            \"flattened_obs_fut\": obs_schema[\"flattened_obs_fut\"]\n        }\n        self.state_assembler = TensorDictAssembler(\n            self.state_schema, output_mode=\"flat\"\n        )\n        self.future_assembler = TensorDictAssembler(\n            self.future_schema, output_mode=\"seq\"\n        )\n        self.state_dim = int(\n            self.state_assembler.infer_output_dim(obs_example)\n        )\n        self.future_token_dim = int(\n            self.future_assembler.infer_output_dim(obs_example)\n        )\n        self.future_seq_len = int(self.future_assembler.seq_len)\n        self.future_term_dims = self._infer_future_term_dims(obs_example)\n        self.full_obs_dim = int(self.flat_obs_dim)\n        expected_full = self.state_dim + (\n            self.future_seq_len * self.future_token_dim\n        )\n        if self.full_obs_dim != expected_full:\n            raise ValueError(\n                \"Assembled obs dim mismatch in PPOCondTFActor: \"\n                f\"full={self.full_obs_dim}, expected={expected_full}\"\n            )\n        if self.obs_norm_enabled:\n            self.state_obs_normalizer = EmpiricalNormalization(\n                shape=self.state_dim,\n                eps=self.obs_norm_eps,\n                update_method=self.obs_norm_update_method,\n                ema_momentum=self.obs_norm_ema_momentum,\n            )\n        else:\n            self.state_obs_normalizer = nn.Identity()\n\n    def _infer_future_term_dims(self, obs_example: TensorDict) -> list[int]:\n        if not isinstance(obs_example, TensorDict):\n            raise ValueError(\"PPOCondTFActor requires TensorDict obs_example.\")\n        fut_cfg = self.future_schema.get(\"flattened_obs_fut\", None)\n        if fut_cfg is None:\n            raise ValueError(\n                \"Missing future schema group 'flattened_obs_fut'.\"\n            )\n        terms = fut_cfg.get(\"terms\", [])\n        if not isinstance(terms, list) or len(terms) == 0:\n            raise ValueError(\"Future schema terms must be a non-empty list.\")\n        dims: list[int] = []\n        for term in terms:\n            tensor = TensorDictAssembler._get_from_data(obs_example, str(term))\n            if tensor is None:\n                raise KeyError(\n                    f\"Missing future term '{term}' in obs_example TensorDict.\"\n                )\n            if not isinstance(tensor, torch.Tensor):\n                raise TypeError(\n                    f\"Future term '{term}' must be a torch.Tensor, got {type(tensor)}\"\n                )\n            if tensor.ndim == 2:\n                dims.append(int(tensor.shape[-1]))\n            elif tensor.ndim == 3:\n                dims.append(int(tensor.shape[-1]))\n            else:\n                raise ValueError(\n                    f\"Future term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}\"\n                )\n        if sum(dims) != int(self.future_token_dim):\n            raise ValueError(\n              
  \"Inferred future_term_dims sum mismatch: expected \"\n                f\"{int(self.future_token_dim)}, got {sum(dims)} (dims={dims})\"\n            )\n        return dims\n\n    @property\n    def flat_obs_dim(self) -> int:\n        if self.assembler is None:\n            raise ValueError(\n                \"PPOCondTFActor requires the base flat assembler for ONNX.\"\n            )\n        if self.assembler.output_dim is None:\n            raise ValueError(\"Base assembler output_dim is not initialized.\")\n        return int(self.assembler.output_dim)\n\n    def _normalize_state_obs(\n        self, state_obs: torch.Tensor, update: bool\n    ) -> torch.Tensor:\n        if not self.obs_norm_enabled:\n            return state_obs\n        if state_obs.ndim != 2:\n            raise ValueError(\n                f\"state_obs must be [B, D_state], got {tuple(state_obs.shape)}\"\n            )\n        if update:\n            self.state_obs_normalizer.update(state_obs)\n        state_obs = self.state_obs_normalizer.normalize_only(state_obs)\n        if self.obs_norm_clip > 0.0:\n            state_obs = torch.clamp(\n                state_obs, -self.obs_norm_clip, self.obs_norm_clip\n            )\n        return state_obs\n\n    def _assemble_state_future(\n        self, obs_td: TensorDict\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if not isinstance(obs_td, TensorDict):\n            raise ValueError(\n                \"PPOCondTFActor._assemble_state_future expects TensorDict input.\"\n            )\n        state_obs = self.state_assembler(obs_td)\n        future_obs = self.future_assembler(obs_td)\n        return state_obs, future_obs\n\n    def _split_flat_obs(\n        self, obs: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if obs.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {obs.shape}\")\n        state_obs = obs[:, : self.state_dim]\n        future_flat = obs[:, self.state_dim :]\n        expected_dim = self.future_seq_len * self.future_token_dim\n        if future_flat.shape[-1] != expected_dim:\n            raise ValueError(\n                \"Future flat obs dim mismatch: expected \"\n                f\"{expected_dim}, got {future_flat.shape[-1]}\"\n            )\n        b = int(obs.shape[0])\n        offset = 0\n        future_parts = []\n        for d_term in self.future_term_dims:\n            span = int(self.future_seq_len * d_term)\n            chunk = future_flat[:, offset : offset + span]\n            future_parts.append(chunk.reshape(b, self.future_seq_len, d_term))\n            offset += span\n        if offset != int(future_flat.shape[-1]):\n            raise ValueError(\n                \"Future flat slicing mismatch: \"\n                f\"consumed={offset}, total={int(future_flat.shape[-1])}\"\n            )\n        future_obs = torch.cat(future_parts, dim=-1)\n        return state_obs, future_obs\n\n    def export_onnx(\n        self,\n        onnx_path: str | Path,\n        *,\n        opset_version: int = 17,\n    ) -> str:\n        export_path = Path(onnx_path)\n        export_path.parent.mkdir(parents=True, exist_ok=True)\n\n        if hasattr(self.actor_module, \"clear_router_distribution_cache\"):\n            self.actor_module.clear_router_distribution_cache()\n        actor_module = _clone_module_for_cpu_export(self.actor_module)\n        if self.obs_norm_enabled:\n            state_obs_normalizer = _clone_module_for_cpu_export(\n                self.state_obs_normalizer\n            )\n        else:\n        
    state_obs_normalizer = nn.Identity()\n\n        exporter = PPOCondTFActorOnnxModule(\n            actor_module=actor_module,\n            state_obs_normalizer=state_obs_normalizer,\n            obs_norm_enabled=self.obs_norm_enabled,\n            obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0,\n            state_dim=self.state_dim,\n            future_seq_len=self.future_seq_len,\n            future_token_dim=self.future_token_dim,\n            future_term_dims=self.future_term_dims,\n        ).to(\"cpu\")\n        exporter.eval()\n\n        cache_shape = self.onnx_past_key_values_shape(batch_size=1)\n        obs = torch.zeros(\n            1, self.flat_obs_dim, device=\"cpu\", dtype=torch.float32\n        )\n        past_key_values = torch.zeros(\n            *cache_shape, device=\"cpu\", dtype=torch.float32\n        )\n        step_idx = torch.tensor([0], dtype=torch.long, device=\"cpu\")\n        output_names = [\n            \"actions\",\n            \"present_key_values\",\n            *self.onnx_routing_output_names(),\n        ]\n\n        torch.onnx.export(\n            exporter,\n            (obs, past_key_values, step_idx),\n            str(export_path),\n            export_params=True,\n            opset_version=opset_version,\n            verbose=False,\n            dynamo=False,\n            input_names=[\"obs\", \"past_key_values\", \"step_idx\"],\n            output_names=output_names,\n        )\n        return str(export_path)\n\n    def update_distribution(self, actor_obs):\n        if not isinstance(actor_obs, tuple) or len(actor_obs) != 2:\n            raise ValueError(\n                \"PPOCondTFActor.update_distribution expects tuple(state_obs, future_obs).\"\n            )\n        state_obs, future_obs = actor_obs\n        mu = self.actor_module.single_step_mu_cond(\n            state_obs,\n            future_obs,\n            future_mask=None,\n        )\n        std = self._sigma_from_params()\n        std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma)\n        self.distribution = Normal(mu, std)\n\n    def forward(\n        self,\n        obs_td: TensorDict | torch.Tensor,\n        actions: torch.Tensor | None = None,\n        mode: str = \"sampling\",\n        attn_mask: torch.Tensor | None = None,\n        *,\n        update_obs_norm: bool = True,\n        past_key_values: torch.Tensor | None = None,\n        current_pos: torch.Tensor | None = None,\n    ) -> TensorDict | tuple[torch.Tensor, torch.Tensor]:\n        if past_key_values is not None:\n            if isinstance(obs_td, TensorDict):\n                state_obs, future_obs = self._assemble_state_future(obs_td)\n            else:\n                state_obs, future_obs = self._split_flat_obs(obs_td)\n            state_obs = self._normalize_state_obs(state_obs, update=False)\n            return self.actor_module._forward_inference_onnx_cond(\n                state_obs,\n                future_obs,\n                past_key_values,\n                current_pos,\n            )\n\n        if mode == \"sequence_logp\":\n            if not isinstance(obs_td, TensorDict):\n                raise ValueError(\n                    \"PPOCondTFActor.forward(mode='sequence_logp') expects TensorDict input.\"\n                )\n            if obs_td.batch_dims != 2:\n                raise ValueError(\n                    \"PPOCondTFActor.forward(mode='sequence_logp') expects batch_dims=2 [B, T], \"\n                    f\"got batch_size={tuple(obs_td.batch_size)}\"\n                
)\n            if actions is None:\n                raise ValueError(\n                    \"actions must be provided when mode='sequence_logp'\"\n                )\n\n            b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])\n            future_mask = None\n            if \"future_mask\" in obs_td.keys():\n                future_mask = obs_td.get(\"future_mask\")\n                if future_mask.shape != (b, t, self.future_seq_len):\n                    raise ValueError(\n                        \"future_mask shape mismatch in sequence_logp: expected \"\n                        f\"{(b, t, self.future_seq_len)}, got {tuple(future_mask.shape)}\"\n                    )\n                future_mask = future_mask.to(torch.bool)\n            flat_td = obs_td.flatten(0, 1)\n            state_flat, future_flat = self._assemble_state_future(flat_td)\n            update = bool(update_obs_norm)\n            state_flat = self._normalize_state_obs(state_flat, update=update)\n            state_seq = state_flat.reshape(b, t, -1)\n            future_seq = future_flat.reshape(\n                b, t, self.future_seq_len, self.future_token_dim\n            )\n\n            (\n                mu,\n                sigma,\n                logp,\n                entropy,\n                aux_preds,\n            ) = self.sequence_forward_logp_cond(\n                state_seq,\n                future_seq,\n                actions,\n                attn_mask,\n                future_mask,\n            )\n            td = obs_td.clone(recurse=False)\n            td.set(\"mu\", mu)\n            td.set(\"sigma\", sigma)\n            td.set(\"actions\", actions)\n            td.set(\"actions_log_prob\", logp)\n            td.set(\"entropy\", entropy)\n            if aux_preds is not None:\n                if \"base_lin_vel_loc\" in aux_preds:\n                    td.set(\n                        \"aux_base_lin_vel_loc\", aux_preds[\"base_lin_vel_loc\"]\n                    )\n                    td.set(\n                        \"aux_base_lin_vel_log_std\",\n                        aux_preds[\"base_lin_vel_log_std\"],\n                    )\n                    td.set(\"aux_root_height_loc\", aux_preds[\"root_height_loc\"])\n                    td.set(\n                        \"aux_root_height_log_std\",\n                        aux_preds[\"root_height_log_std\"],\n                    )\n                    td.set(\n                        \"aux_keybody_contact_logits\",\n                        aux_preds[\"keybody_contact_logits\"],\n                    )\n                    td.set(\n                        \"aux_ref_keybody_rel_pos\",\n                        aux_preds[\"ref_keybody_rel_pos\"],\n                    )\n                    td.set(\n                        \"aux_robot_keybody_rel_pos\",\n                        aux_preds[\"robot_keybody_rel_pos\"],\n                    )\n                    if \"denoise_ref_root_lin_vel_residual\" in aux_preds:\n                        td.set(\n                            \"aux_denoise_ref_root_lin_vel_residual\",\n                            aux_preds[\"denoise_ref_root_lin_vel_residual\"],\n                        )\n                    if \"denoise_ref_root_ang_vel_residual\" in aux_preds:\n                        td.set(\n                            \"aux_denoise_ref_root_ang_vel_residual\",\n                            aux_preds[\"denoise_ref_root_ang_vel_residual\"],\n                        )\n                    if 
\"denoise_ref_dof_pos_residual\" in aux_preds:\n                        td.set(\n                            \"aux_denoise_ref_dof_pos_residual\",\n                            aux_preds[\"denoise_ref_dof_pos_residual\"],\n                        )\n                if \"router_command_recon\" in aux_preds:\n                    td.set(\n                        \"aux_router_command_recon\",\n                        aux_preds[\"router_command_recon\"],\n                    )\n                if \"router_features\" in aux_preds:\n                    td.set(\"router_features\", aux_preds[\"router_features\"])\n                if \"router_temporal_features\" in aux_preds:\n                    td.set(\n                        \"router_temporal_features\",\n                        aux_preds[\"router_temporal_features\"],\n                    )\n            return td\n\n        if mode not in (\"sampling\", \"logp\", \"inference\"):\n            raise ValueError(f\"Unsupported mode: {mode}\")\n        if not isinstance(obs_td, TensorDict):\n            raise ValueError(\n                \"PPOCondTFActor.forward expects TensorDict input.\"\n            )\n\n        td = obs_td.clone(recurse=False)\n        state_obs, future_obs = self._assemble_state_future(obs_td)\n        update = bool(update_obs_norm)\n        state_obs = self._normalize_state_obs(state_obs, update=update)\n        future_mask = None\n        if \"future_mask\" in td.keys():\n            future_mask = td.get(\"future_mask\")\n            if future_mask.shape != (state_obs.shape[0], self.future_seq_len):\n                raise ValueError(\n                    \"future_mask shape mismatch in single-step forward: expected \"\n                    f\"{(state_obs.shape[0], self.future_seq_len)}, got {tuple(future_mask.shape)}\"\n                )\n            future_mask = future_mask.to(torch.bool)\n        mu = self.actor_module.single_step_mu_cond(\n            state_obs, future_obs, future_mask=future_mask\n        )\n        sigma = self._sigma_like(mu)\n        td.set(\"mu\", mu)\n        td.set(\"sigma\", sigma)\n\n        if mode == \"inference\":\n            td.set(\"actions\", mu)\n            return td\n\n        self.distribution = Normal(mu, sigma)\n        if mode == \"sampling\":\n            actions_out = self.distribution.sample()\n        else:\n            if actions is None:\n                raise ValueError(\"actions must be provided when mode='logp'\")\n            actions_out = actions\n        td.set(\"actions\", actions_out)\n        td.set(\n            \"actions_log_prob\",\n            self.distribution.log_prob(actions_out).sum(dim=-1),\n        )\n        td.set(\"entropy\", self.distribution.entropy().sum(dim=-1))\n        return td\n\n    def sequence_forward_logp_cond(\n        self,\n        state_seq: torch.Tensor,\n        future_seq: torch.Tensor,\n        actions: torch.Tensor,\n        attn_mask: torch.Tensor | None,\n        future_mask: torch.Tensor | None,\n    ) -> tuple[\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        dict[str, torch.Tensor] | None,\n    ]:\n        aux_preds = None\n        need_pre_moe_aux = self.aux_state_pred_enabled\n        need_router_aux = (\n            self.aux_router_command_recon_enabled\n            or self.aux_router_switch_penalty_enabled\n        )\n        if need_pre_moe_aux and need_router_aux:\n            actor_outputs = self.actor_module.sequence_mu_cond(\n                state_seq,\n                
future_seq,\n                attn_mask=attn_mask,\n                future_mask=future_mask,\n                return_pre_moe_hidden=True,\n                return_router_features=True,\n                return_router_temporal_features=self.aux_router_switch_penalty_enabled,\n            )\n            if self.aux_router_switch_penalty_enabled:\n                (\n                    mu,\n                    pre_moe_hidden,\n                    router_features,\n                    router_temporal_features,\n                ) = actor_outputs\n            else:\n                mu, pre_moe_hidden, router_features = actor_outputs\n            aux_preds = self.actor_module.predict_aux_from_pre_moe(\n                pre_moe_hidden\n            )\n            aux_preds[\"router_features\"] = router_features\n            if self.aux_router_switch_penalty_enabled:\n                aux_preds[\"router_temporal_features\"] = (\n                    router_temporal_features\n                )\n            if self.aux_router_command_recon_enabled:\n                aux_preds[\"router_command_recon\"] = (\n                    self.actor_module.predict_aux_router_command_from_router_features(\n                        router_features\n                    )\n                )\n        elif need_pre_moe_aux:\n            mu, pre_moe_hidden = self.actor_module.sequence_mu_cond(\n                state_seq,\n                future_seq,\n                attn_mask=attn_mask,\n                future_mask=future_mask,\n                return_pre_moe_hidden=True,\n            )\n            aux_preds = self.actor_module.predict_aux_from_pre_moe(\n                pre_moe_hidden\n            )\n        elif need_router_aux:\n            actor_outputs = self.actor_module.sequence_mu_cond(\n                state_seq,\n                future_seq,\n                attn_mask=attn_mask,\n                future_mask=future_mask,\n                return_router_features=True,\n                return_router_temporal_features=self.aux_router_switch_penalty_enabled,\n            )\n            if self.aux_router_switch_penalty_enabled:\n                (\n                    mu,\n                    router_features,\n                    router_temporal_features,\n                ) = actor_outputs\n            else:\n                mu, router_features = actor_outputs\n            aux_preds = {\"router_features\": router_features}\n            if self.aux_router_switch_penalty_enabled:\n                aux_preds[\"router_temporal_features\"] = (\n                    router_temporal_features\n                )\n            if self.aux_router_command_recon_enabled:\n                aux_preds[\"router_command_recon\"] = (\n                    self.actor_module.predict_aux_router_command_from_router_features(\n                        router_features\n                    )\n                )\n        else:\n            mu = self.actor_module.sequence_mu_cond(\n                state_seq,\n                future_seq,\n                attn_mask=attn_mask,\n                future_mask=future_mask,\n            )\n        sigma_vec = self._sigma_from_params().clamp(\n            self.min_sigma, self.max_sigma\n        )\n        sigma = sigma_vec[None, None, :].expand_as(mu)\n        var = sigma * sigma\n        logp = -0.5 * (\n            ((actions - mu) ** 2) / (var + 1.0e-8)\n            + 2.0 * torch.log(sigma + 1.0e-8)\n            + math.log(2.0 * math.pi)\n        ).sum(dim=-1, keepdim=True)\n        entropy = (\n            0.5 + 0.5 * 
math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)\n        ).sum(dim=-1, keepdim=True)\n        return mu, sigma, logp, entropy, aux_preds\n"
  },
  {
    "path": "holomotion/src/modules/network_modules.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport math\nfrom contextlib import nullcontext\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\n\n\nclass EmpiricalNormalization(nn.Module):\n    \"\"\"Normalize mean and variance of values based on empirical values.\"\"\"\n\n    def __init__(\n        self,\n        shape,\n        eps: float = 1e-2,\n        until: int | None = None,\n        *,\n        update_method: str = \"cumulative\",\n        ema_momentum: float | None = None,\n    ):\n        \"\"\"Initialize EmpiricalNormalization module.\n\n        Args:\n            shape (int or tuple of int): Shape of input values except\n                batch axis.\n            eps (float): Small value for stability.\n            until (int or None): If this arg is specified, the link learns\n                input values until the sum of batch sizes\n            exceeds it.\n            update_method:\n                One of {\"cumulative\", \"ema\"}.\n                - \"cumulative\": count-based updates (legacy behavior).\n                - \"ema\": EMA updates of mean and second moment.\n            ema_momentum:\n                EMA momentum in (0, 1]. 
Required when update_method == \"ema\".\n        \"\"\"\n        super().__init__()\n        self.eps = eps\n        self.until = until\n        self.update_method = str(update_method).lower()\n        self.ema_momentum = (\n            float(ema_momentum) if ema_momentum is not None else None\n        )\n        if self.update_method in (\"count\", \"cumulative\"):\n            self.update_method = \"cumulative\"\n        elif self.update_method in (\"ema\", \"exp\", \"exponential\"):\n            self.update_method = \"ema\"\n        else:\n            raise ValueError(\n                f\"update_method must be one of {{'cumulative','ema'}}, got {update_method}\"\n            )\n        if self.update_method == \"ema\":\n            if self.ema_momentum is None:\n                raise ValueError(\n                    \"ema_momentum must be provided when update_method == 'ema'\"\n                )\n            if not (0.0 < self.ema_momentum <= 1.0):\n                raise ValueError(\n                    f\"ema_momentum must be in (0, 1], got {self.ema_momentum}\"\n                )\n        self.register_buffer(\"_mean\", torch.zeros(shape)[None, ...])\n        self.register_buffer(\"_var\", torch.ones(shape)[None, ...])\n        self.register_buffer(\"_std\", torch.ones(shape)[None, ...])\n        self.register_buffer(\"_ex2\", torch.ones(shape)[None, ...])\n        self.register_buffer(\"count\", torch.tensor(0, dtype=torch.long))\n        self.register_buffer(\"_last_sync_mean\", torch.zeros(shape)[None, ...])\n        self.register_buffer(\"_last_sync_var\", torch.ones(shape)[None, ...])\n        self.register_buffer(\n            \"_last_sync_count\", torch.tensor(0, dtype=torch.long)\n        )\n\n    @property\n    def mean(self):\n        return self._mean.squeeze(0).clone()\n\n    @property\n    def std(self):\n        return self._std.squeeze(0).clone()\n\n    def forward(self, x):\n        \"\"\"Normalize the input with the running statistics.\n\n        In training mode the statistics are updated from x first.\n\n        Args:\n            x (torch.Tensor): Input values.\n\n        Returns:\n            torch.Tensor: Normalized output values.\n        \"\"\"\n\n        if self.training:\n            self.update(x)\n        return (x - self._mean) / (self._std + self.eps)\n\n    def normalize_only(self, x):\n        return (x - self._mean) / (self._std + self.eps)\n\n    @torch.compiler.disable\n    @torch.jit.unused\n    def update(self, x):\n        \"\"\"Update running statistics from a batch without computing output.\"\"\"\n\n        if self.until is not None and self.count >= self.until:\n            return\n\n        count_x = x.shape[0]\n        self.count += count_x\n        if self.update_method == \"ema\":\n            m = float(self.ema_momentum)\n            mean_x = torch.mean(x, dim=0, keepdim=True)\n            ex2_x = torch.mean(x * x, dim=0, keepdim=True)\n            self._mean.mul_(1.0 - m).add_(mean_x, alpha=m)\n            self._ex2.mul_(1.0 - m).add_(ex2_x, alpha=m)\n            var = torch.clamp(self._ex2 - self._mean * self._mean, min=0.0)\n            self._var.copy_(var)\n            self._std.copy_(torch.sqrt(self._var))\n            return\n\n        rate = count_x / self.count\n\n        var_x = torch.var(x, dim=0, unbiased=False, keepdim=True)\n        mean_x = torch.mean(x, dim=0, keepdim=True)\n        delta_mean = mean_x - self._mean\n        self._mean += rate * delta_mean\n        self._var += rate * (\n            var_x - self._var + delta_mean * (mean_x - self._mean)\n        )\n  
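      # Welford-style incremental moment merge: with a prior count of 100\n        # and a batch of 28, self.count is already 128 here, so\n        # rate = 28 / 128. Refresh the cached std from the merged variance:\n  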
        self._std.copy_(torch.sqrt(self._var))\n\n    @torch.jit.unused\n    def inverse(self, y):\n        return y * (self._std + self.eps) + self._mean\n\n    def sync_stats_across_processes(self, accelerator):\n        \"\"\"Synchronize normalization statistics across distributed processes.\"\"\"\n        if accelerator.num_processes <= 1:\n            return\n\n        if self.update_method == \"ema\":\n            # EMA stats are already running estimates.\n            # Sync by averaging across ranks.\n            mean_g = accelerator.reduce(\n                self._mean.to(dtype=torch.float32), reduction=\"mean\"\n            )\n            ex2_g = accelerator.reduce(\n                self._ex2.to(dtype=torch.float32), reduction=\"mean\"\n            )\n            var_g = torch.clamp(ex2_g - mean_g * mean_g, min=0.0)\n            self._mean.copy_(mean_g.to(self._mean.dtype))\n            self._ex2.copy_(ex2_g.to(self._ex2.dtype))\n            self._var.copy_(var_g.to(self._var.dtype))\n            self._std.copy_(torch.sqrt(self._var))\n            return\n\n        # Weighted synchronization with correction to avoid double counting\n        device = self._mean.device\n        count_local = self.count.to(device=device, dtype=torch.float32)\n        mean_local = self._mean.to(device=device, dtype=torch.float32)\n        var_local = self._var.to(device=device, dtype=torch.float32)\n\n        # Local weighted sums\n        sum_count = accelerator.reduce(count_local, reduction=\"sum\")\n        sum_mean_count = accelerator.reduce(\n            mean_local * count_local, reduction=\"sum\"\n        )\n        sum_ex2_count = accelerator.reduce(\n            (var_local + mean_local * mean_local) * count_local,\n            reduction=\"sum\",\n        )\n\n        # Correct for replication of previously-synced global stats\n        # across ranks.\n        last_c = self._last_sync_count.to(device=device, dtype=torch.float32)\n        if last_c.item() > 0:\n            w_minus_1 = float(accelerator.num_processes - 1)\n            last_mean = self._last_sync_mean.to(\n                device=device, dtype=torch.float32\n            )\n            last_var = self._last_sync_var.to(\n                device=device, dtype=torch.float32\n            )\n            sum_count = sum_count - w_minus_1 * last_c\n            sum_mean_count = sum_mean_count - w_minus_1 * (last_mean * last_c)\n            sum_ex2_count = sum_ex2_count - w_minus_1 * (\n                (last_var + last_mean * last_mean) * last_c\n            )\n\n        if sum_count.item() <= 0:\n            return\n\n        global_mean = sum_mean_count / sum_count\n        global_ex2 = sum_ex2_count / sum_count\n        global_var = torch.clamp(\n            global_ex2 - global_mean * global_mean, min=0.0\n        )\n        global_std = torch.sqrt(global_var)\n\n        # Copy back (keep original buffer shapes)\n        self._mean.copy_(global_mean.to(self._mean.dtype))\n        self._var.copy_(global_var.to(self._var.dtype))\n        self._std.copy_(global_std.to(self._std.dtype))\n        # Set global sample count and remember snapshot for next correction\n        self.count.copy_(sum_count.to(self.count.dtype))\n        self._last_sync_mean.copy_(global_mean.to(self._last_sync_mean.dtype))\n        self._last_sync_var.copy_(global_var.to(self._last_sync_var.dtype))\n        self._last_sync_count.copy_(self.count)\n\n\n
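# Usage sketch for EmpiricalNormalization (illustrative; obs stands in for a\n# [batch, obs_dim] observation tensor):\n#   normalizer = EmpiricalNormalization(shape=obs_dim, eps=1.0e-2)\n#   normalizer.train()\n#   y = normalizer(obs)          # updates running stats, then normalizes\n#   x = normalizer.inverse(y)    # approximately recovers obs\n\n\n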
class MLP(nn.Module):\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        module_config_dict: dict,\n    ):\n        super().__init__()\n        self.module_config_dict = module_config_dict\n        self.input_dim = int(input_dim)\n        self.output_dim = int(output_dim)\n        if self.input_dim <= 0:\n            raise ValueError(\n                f\"MLP input_dim must be positive, got {self.input_dim}\"\n            )\n        if self.output_dim <= 0:\n            raise ValueError(\n                f\"MLP output_dim must be positive, got {self.output_dim}\"\n            )\n\n        def _make_norm(\n            norm_type: str,\n            dim: int,\n            *,\n            eps: float,\n        ) -> nn.Module:\n            t = str(norm_type).lower()\n            if t in (\"none\", \"identity\", \"null\"):\n                return nn.Identity()\n            if t in (\"layernorm\", \"ln\"):\n                return nn.LayerNorm(dim, eps=eps)\n            if t in (\"rmsnorm\", \"rms\"):\n                return RMSNorm(dim, eps=eps)\n            raise ValueError(\n                f\"Unknown norm '{t}'. Expected one of {{'none', 'layernorm', 'rmsnorm'}}.\"\n            )\n\n        self.hidden_norm_type = module_config_dict.get(\"hidden_norm\", \"none\")\n        self.hidden_norm_eps = float(\n            module_config_dict.get(\"hidden_norm_eps\", 1.0e-6)\n        )\n\n        layer_config = self.module_config_dict[\"layer_config\"]\n        hidden_dims: list[int] = list(layer_config.get(\"hidden_dims\", []))\n        activation_cls = getattr(nn, str(layer_config[\"activation\"]))\n\n        layers: list[nn.Module] = []\n        prev = self.input_dim\n        for h in hidden_dims:\n            h_i = int(h)\n            layers.append(nn.Linear(prev, h_i))\n            layers.append(\n                _make_norm(\n                    self.hidden_norm_type,\n                    h_i,\n                    eps=self.hidden_norm_eps,\n                )\n            )\n            layers.append(activation_cls())\n            prev = h_i\n        self.trunk = nn.Sequential(*layers) if layers else nn.Identity()\n        self.output_head = nn.Linear(prev, self.output_dim)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward.\n\n        Args:\n            x: [..., input_dim] assembled tensor observations.\n\n        Returns:\n            y: [..., output_dim]\n        \"\"\"\n        if not isinstance(x, torch.Tensor):\n            raise TypeError(f\"MLP expects torch.Tensor input, got {type(x)}\")\n        h = self.trunk(x)\n        return self.output_head(h)\n\n\n
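# Usage sketch for MLP (illustrative config; key names match the reads above):\n#   cfg = {\n#       \"layer_config\": {\"hidden_dims\": [512, 256], \"activation\": \"SiLU\"},\n#       \"hidden_norm\": \"rmsnorm\",\n#       \"hidden_norm_eps\": 1.0e-6,\n#   }\n#   head = MLP(input_dim=128, output_dim=32, module_config_dict=cfg)\n#   y = head(torch.randn(4, 128))  # -> [4, 32]\n\n\n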
class ConvMLP(nn.Module):\n    \"\"\"Conv1d + pooling history encoder with an MLP head.\"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        module_config_dict: dict,\n    ):\n        super().__init__()\n        self.module_config_dict = module_config_dict\n        self.input_dim = int(input_dim)\n        self.output_dim = int(output_dim)\n\n        layer_cfg = dict(module_config_dict.get(\"layer_config\", {}))\n        activation = str(layer_cfg.get(\"activation\", \"SiLU\"))\n\n        self.conv_channels = int(module_config_dict.get(\"conv_channels\", 128))\n        self.conv_layers = int(module_config_dict.get(\"conv_layers\", 2))\n        self.conv_kernel_size = int(\n            module_config_dict.get(\"conv_kernel_size\", 3)\n        )\n        self.pool_type = str(\n            module_config_dict.get(\"pool_type\", \"avg\")\n        ).lower()\n        if self.pool_type not in {\"avg\", \"max\"}:\n            raise ValueError(\n                f\"pool_type must be one of {{'avg','max'}}, got {self.pool_type}\"\n            )\n\n        conv_modules: list[nn.Module] = []\n        padding = self.conv_kernel_size // 2\n        in_ch = int(self.input_dim)\n        for _ in range(self.conv_layers):\n            conv_modules.append(\n                nn.Conv1d(\n                    in_channels=in_ch,\n                    out_channels=self.conv_channels,\n                    kernel_size=self.conv_kernel_size,\n                    padding=padding,\n                    bias=True,\n                )\n            )\n            conv_modules.append(getattr(nn, activation)())\n            in_ch = self.conv_channels\n\n        if self.pool_type == \"avg\":\n            conv_modules.append(nn.AdaptiveAvgPool1d(1))\n        else:\n            conv_modules.append(nn.AdaptiveMaxPool1d(1))\n\n        self.hist_encoder = nn.Sequential(*conv_modules)\n\n        fused_dim = int(self.conv_channels + self.input_dim)\n        self.mlp_head = MLP(\n            input_dim=fused_dim,\n            output_dim=int(self.output_dim),\n            module_config_dict=module_config_dict,\n        )\n\n    def forward(self, hist_seq: torch.Tensor) -> torch.Tensor:\n        ctx = self.hist_encoder(hist_seq.transpose(1, 2)).squeeze(-1)\n        latest = hist_seq[:, -1, :]\n        fused = torch.cat([ctx, latest], dim=-1)\n        return self.mlp_head(fused)\n\n\n
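# Usage sketch for ConvMLP (illustrative; hist_seq is a [B, T, input_dim]\n# observation-history tensor; cfg as in the MLP sketch above):\n#   enc = ConvMLP(input_dim=64, output_dim=32, module_config_dict=cfg)\n#   y = enc(torch.randn(8, 16, 64))  # pooled context + latest frame -> [8, 32]\n\n\n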
class ReferenceMotionConvRouterEncoder(nn.Module):\n    \"\"\"Conv1d encoder for reference-motion router sequences.\"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        *,\n        conv_channels: int = 128,\n        conv_layers: int = 2,\n        conv_kernel_size: int = 3,\n        pool_type: str = \"avg\",\n    ):\n        super().__init__()\n        self.input_dim = int(input_dim)\n        self.output_dim = int(output_dim)\n        self.conv_channels = int(conv_channels)\n        self.conv_layers = int(conv_layers)\n        self.conv_kernel_size = int(conv_kernel_size)\n        self.pool_type = str(pool_type).lower()\n        if self.input_dim <= 0:\n            raise ValueError(\n                f\"input_dim must be positive, got {self.input_dim}\"\n            )\n        if self.output_dim <= 0:\n            raise ValueError(\n                f\"output_dim must be positive, got {self.output_dim}\"\n            )\n        if self.conv_channels <= 0:\n            raise ValueError(\n                f\"conv_channels must be positive, got {self.conv_channels}\"\n            )\n        if self.conv_layers <= 0:\n            raise ValueError(\n                f\"conv_layers must be positive, got {self.conv_layers}\"\n            )\n        if self.conv_kernel_size <= 0:\n            raise ValueError(\n                f\"conv_kernel_size must be positive, got {self.conv_kernel_size}\"\n            )\n        if self.pool_type not in {\"avg\", \"max\"}:\n            raise ValueError(\n                f\"pool_type must be one of {{'avg','max'}}, got {self.pool_type}\"\n            )\n\n        padding = self.conv_kernel_size // 2\n        conv_modules: list[nn.Module] = []\n        in_ch = self.input_dim\n        for _ in range(self.conv_layers):\n            conv_modules.append(\n                nn.Conv1d(\n                    in_channels=in_ch,\n                    out_channels=self.conv_channels,\n                    kernel_size=self.conv_kernel_size,\n                    padding=padding,\n                    bias=True,\n                )\n            )\n            conv_modules.append(nn.SiLU())\n            in_ch = self.conv_channels\n        self.temporal_trunk = nn.Sequential(*conv_modules)\n        if self.pool_type == \"avg\":\n            self.pool = nn.AdaptiveAvgPool1d(1)\n        else:\n            self.pool = nn.AdaptiveMaxPool1d(1)\n        self.out_proj = nn.Sequential(\n            nn.Linear(self.conv_channels, self.output_dim),\n            nn.SiLU(),\n            nn.Linear(self.output_dim, self.output_dim),\n        )\n\n    def forward(self, seq: torch.Tensor) -> torch.Tensor:\n        if seq.ndim != 3:\n            raise ValueError(\n                f\"Expected router seq with shape [B, T, D], got {tuple(seq.shape)}.\"\n            )\n        if int(seq.shape[-1]) != self.input_dim:\n            raise ValueError(\n                \"Router seq dim mismatch: expected \"\n                f\"{self.input_dim}, got {int(seq.shape[-1])}.\"\n            )\n        x = seq.transpose(1, 2)\n        x = self.temporal_trunk(x)\n        x = self.pool(x).squeeze(-1)\n        return self.out_proj(x)\n\n\nclass SingleQueryAttentionPool(nn.Module):\n    def __init__(self, d_model: int):\n        super().__init__()\n        self.d_model = int(d_model)\n        self.scale = float(self.d_model) ** -0.5\n        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)\n        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)\n        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)\n        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        tokens: torch.Tensor,\n    ) -> torch.Tensor:\n        if query.ndim == 2 and tokens.ndim != 3:\n            raise ValueError(\n                \"SingleQueryAttentionPool expected [B, N, D] tokens for \"\n                f\"2D query, got {tuple(tokens.shape)}.\"\n            )\n        if query.ndim == 3 and tokens.ndim != 4:\n            raise ValueError(\n                \"SingleQueryAttentionPool expected [B, T, N, D] tokens for \"\n                f\"3D query, got {tuple(tokens.shape)}.\"\n            )\n        if query.ndim not in (2, 3):\n            raise ValueError(\n                f\"SingleQueryAttentionPool query must be 2D or 3D, got {query.ndim}.\"\n            )\n        q = self.q_proj(query).unsqueeze(-2)\n        k = self.k_proj(tokens)\n        v = self.v_proj(tokens)\n        attn = torch.softmax(\n            (q * k).sum(dim=-1, keepdim=True) * self.scale,\n            dim=-2,\n        )\n        return self.out_proj((attn * v).sum(dim=-2))\n\n\nclass GroupedMoETransformerPolicy(nn.Module):\n    \"\"\"Hybrid modern Transformer decoder policy.\n\n    Structure:\n        - Layer 0: Dense MLP (ModernTransformerBlock).\n        - Optional final layer: Dense MLP when dense_layer_at_last=True.\n        - Intermediate layers: MoE MLP (GroupedMoEBlock).\n\n    Features:\n        - Real-valued RoPE (rotary position embeddings).\n        - RMSNorm: Root Mean Square Normalization.\n        - GQA: Grouped Query Attention (configurable n_kv_heads).\n        - QK-Norm: RMSNorm on Queries and Keys.\n        - Gated Attention: Qwen-style element-wise sigmoid gating.\n        - SwiGLU MLP: DeepseekV3MLP for feed-forward.\n        - Flash Attention via F.scaled_dot_product_attention.\n        - Gradient Checkpointing: optional, for memory efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim: int,\n        
output_dim: int,\n        module_config_dict: dict,\n    ):\n        super().__init__()\n        self.input_dim = int(input_dim)\n        self.output_dim = int(output_dim)\n        self.module_config_dict = module_config_dict\n\n        self.num_fine_experts = module_config_dict[\"num_fine_experts\"]\n        self.num_shared_experts = module_config_dict[\"num_shared_experts\"]\n        self.top_k = module_config_dict[\"top_k\"]\n        self.use_dynamic_bias = module_config_dict.get(\n            \"use_dynamic_bias\", False\n        )\n        self.bias_update_rate = module_config_dict.get(\n            \"bias_update_rate\", 0.001\n        )\n        self.routing_score_fn = str(\n            module_config_dict.get(\"routing_score_fn\", \"softmax\")\n        ).lower()\n        self.freeze_router = bool(\n            module_config_dict.get(\"freeze_router\", False)\n        )\n        self.routing_scale = float(\n            module_config_dict.get(\"routing_scale\", 1.0)\n        )\n        self.expert_bias_clip = float(\n            module_config_dict.get(\"expert_bias_clip\", 0.0)\n        )\n        dead_margin_cfg = module_config_dict.get(\n            \"dead_expert_margin_to_topk\", {}\n        )\n        selected_margin_cfg = module_config_dict.get(\n            \"selected_expert_margin_to_unselected\", {}\n        )\n        self.dead_expert_margin_to_topk_enabled = bool(\n            dead_margin_cfg.get(\"enabled\", False)\n        )\n        self.selected_expert_margin_to_unselected_enabled = bool(\n            selected_margin_cfg.get(\"enabled\", False)\n        )\n        self.selected_expert_margin_to_unselected_target = float(\n            selected_margin_cfg.get(\"target\", 0.0)\n        )\n        if self.routing_score_fn not in (\"softmax\", \"sigmoid\"):\n            raise ValueError(\n                f\"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}\"\n            )\n        if self.routing_scale <= 0.0:\n            raise ValueError(\n                f\"routing_scale must be > 0, got {self.routing_scale}\"\n            )\n        if self.expert_bias_clip < 0.0:\n            raise ValueError(\n                f\"expert_bias_clip must be >= 0, got {self.expert_bias_clip}\"\n            )\n        if self.selected_expert_margin_to_unselected_target < 0.0:\n            raise ValueError(\n                \"selected_expert_margin_to_unselected.target must be >= 0, \"\n                f\"got {self.selected_expert_margin_to_unselected_target}\"\n            )\n\n        _ov = module_config_dict.get(\"input_dim_override\", None)\n        self.obs_input_dim = (\n            int(_ov) if isinstance(_ov, (int, float)) else None\n        )\n\n        self.obs_embed_mlp_hidden = int(\n            module_config_dict.get(\"obs_embed_mlp_hidden\", 1024)\n        )\n\n        self.d_model = int(module_config_dict.get(\"d_model\", 256))\n        self.n_layers = int(module_config_dict.get(\"n_layers\", 4))\n        self.dense_layer_at_last = bool(\n            module_config_dict.get(\"dense_layer_at_last\", False)\n        )\n        self.n_heads = int(module_config_dict.get(\"n_heads\", 4))\n        self.n_kv_heads = int(\n            module_config_dict.get(\"n_kv_heads\", self.n_heads // 2)\n        )\n        self.ff_mult = float(module_config_dict.get(\"ff_mult\", 4))\n        self.ff_mult_dense = int(\n            module_config_dict.get(\"ff_mult_dense\", self.ff_mult * 3)\n        )\n        self.attn_dropout = 
float(module_config_dict.get(\"attn_dropout\", 0.0))\n        self.mlp_dropout = float(module_config_dict.get(\"mlp_dropout\", 0.0))\n        self.max_ctx_len = int(module_config_dict.get(\"max_ctx_len\", 64))\n        self.use_qk_norm = module_config_dict.get(\"use_qk_norm\", True)\n        self.use_gated_attn = module_config_dict.get(\"use_gated_attn\", True)\n        self.gated_attn_type = module_config_dict.get(\n            \"gated_attn_type\", \"headwise\"\n        )\n        self.use_checkpointing = module_config_dict.get(\n            \"use_checkpointing\", False\n        )\n        self.use_future_cross_attn = bool(\n            module_config_dict.get(\"use_future_cross_attn\", False)\n        )\n        self.state_obs_dim = int(\n            module_config_dict.get(\n                \"state_obs_dim\", self.obs_input_dim or self.input_dim\n            )\n        )\n        self.future_seq_len = int(module_config_dict.get(\"future_seq_len\", 0))\n        self.future_token_dim = int(\n            module_config_dict.get(\"future_token_dim\", 0)\n        )\n\n        self.head_dim = self.d_model // self.n_heads\n        if self.d_model % self.n_heads != 0:\n            raise ValueError(\n                f\"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})\"\n            )\n        if self.head_dim % 2 != 0:\n            raise ValueError(\n                f\"RoPE requires even head_dim, got head_dim={self.head_dim}\"\n            )\n\n        # RoPE configuration (used in both sequence and KV-cached single-step inference)\n        self.rope_theta = float(module_config_dict.get(\"rope_theta\", 10000.0))\n        self.inv_freq = 1.0 / (\n            self.rope_theta\n            ** (\n                torch.arange(0, self.head_dim, 2, dtype=torch.float32)\n                / self.head_dim\n            )\n        )  # [head_dim//2]\n        self.register_buffer(\"_rope_inv_freq\", self.inv_freq, persistent=False)\n        self._set_cos_sin_cache(seq_len=8192)\n\n        obs_in = self.obs_input_dim or self.input_dim\n        if self.use_future_cross_attn:\n            if self.future_seq_len <= 0:\n                raise ValueError(\n                    \"future_seq_len must be positive when use_future_cross_attn=True\"\n                )\n            if self.future_token_dim <= 0:\n                raise ValueError(\n                    \"future_token_dim must be positive when use_future_cross_attn=True\"\n                )\n            self.state_obs_embed = nn.Sequential(\n                nn.Linear(self.state_obs_dim, self.obs_embed_mlp_hidden),\n                nn.SiLU(),\n                nn.Linear(self.obs_embed_mlp_hidden, self.d_model),\n            )\n            # Keep a single state embedding module so DDP doesn't see unused\n            # parameters from an extra unused `obs_embed` in conditional mode.\n            self.obs_embed = self.state_obs_embed\n            self.future_obs_embed = nn.Sequential(\n                nn.Linear(self.future_token_dim, self.obs_embed_mlp_hidden),\n                nn.SiLU(),\n                nn.Linear(self.obs_embed_mlp_hidden, self.d_model),\n            )\n            self.future_pos_embed = nn.Embedding(\n                self.future_seq_len, self.d_model\n            )\n        else:\n            self.obs_embed = nn.Sequential(\n                nn.Linear(obs_in, self.obs_embed_mlp_hidden),\n                nn.SiLU(),\n                nn.Linear(self.obs_embed_mlp_hidden, self.d_model),\n            )\n            
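# Dense (non-conditional) mode: a single observation embedding and no\n            # future-token pathway.\n            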
self.state_obs_embed = None\n            self.future_obs_embed = None\n            self.future_pos_embed = None\n        # Stack of TransformerBlocks: first layer is always dense; the last\n        # layer is also dense when dense_layer_at_last=True.\n        self.layers = nn.ModuleList()\n        for i in range(self.n_layers):\n            use_dense_layer = i == 0 or (\n                self.dense_layer_at_last and i == self.n_layers - 1\n            )\n            if use_dense_layer:\n                layer = ModernTransformerBlock(\n                    d_model=self.d_model,\n                    n_heads=self.n_heads,\n                    n_kv_heads=self.n_kv_heads,\n                    ff_mult=self.ff_mult_dense,\n                    use_qk_norm=self.use_qk_norm,\n                    use_gated_attn=self.use_gated_attn,\n                    gated_attn_type=self.gated_attn_type,\n                    attn_dropout=self.attn_dropout,\n                    mlp_dropout=self.mlp_dropout,\n                    use_cross_attn=self.use_future_cross_attn,\n                )\n            else:\n                layer = GroupedMoEBlock(\n                    d_model=self.d_model,\n                    n_heads=self.n_heads,\n                    n_kv_heads=self.n_kv_heads,\n                    ff_mult=self.ff_mult,\n                    use_qk_norm=self.use_qk_norm,\n                    use_gated_attn=self.use_gated_attn,\n                    gated_attn_type=self.gated_attn_type,\n                    attn_dropout=self.attn_dropout,\n                    mlp_dropout=self.mlp_dropout,\n                    num_fine_experts=self.num_fine_experts,\n                    num_shared_experts=self.num_shared_experts,\n                    top_k=self.top_k,\n                    use_dynamic_bias=self.use_dynamic_bias,\n                    bias_update_rate=self.bias_update_rate,\n                    routing_score_fn=self.routing_score_fn,\n                    freeze_router=self.freeze_router,\n                    routing_scale=self.routing_scale,\n                    expert_bias_clip=self.expert_bias_clip,\n                    dead_expert_margin_to_topk_enabled=(\n                        self.dead_expert_margin_to_topk_enabled\n                    ),\n                    selected_expert_margin_to_unselected_enabled=(\n                        self.selected_expert_margin_to_unselected_enabled\n                    ),\n                    selected_expert_margin_to_unselected_target=(\n                        self.selected_expert_margin_to_unselected_target\n                    ),\n                    use_cross_attn=self.use_future_cross_attn,\n                )\n            self.layers.append(layer)\n        self._last_moe_layer_idx = None\n        for layer_idx, layer in enumerate(self.layers):\n            if isinstance(layer, GroupedMoEBlock):\n                self._last_moe_layer_idx = layer_idx\n\n        self.norm_f = RMSNorm(self.d_model)\n        self.action_mu_head = nn.Sequential(\n            nn.Linear(self.d_model, self.d_model),\n            nn.SiLU(),\n            nn.Linear(self.d_model, self.output_dim),\n        )\n        aux_cfg = module_config_dict.get(\"aux_state_pred\", {})\n        self.aux_state_pred_enabled = bool(aux_cfg.get(\"enabled\", False))\n        self.aux_contact_dim = int(\n            len(aux_cfg.get(\"keybody_contact_names\", []))\n        )\n        self.aux_keybody_pos_dim = int(\n            len(aux_cfg.get(\"keybody_rel_pos_names\", []))\n        )\n        
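# Aux denoise heads are created below only when their loss weights are\n        # positive, so the module carries no never-trained parameters.\n        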
self.use_aux_denoise_ref_root_lin_vel = bool(\n            float(aux_cfg.get(\"w_denoise_ref_root_lin_vel\", 0.0)) > 0.0\n        )\n        self.use_aux_denoise_ref_root_ang_vel = bool(\n            float(aux_cfg.get(\"w_denoise_ref_root_ang_vel\", 0.0)) > 0.0\n        )\n        self.use_aux_denoise_ref_dof_pos = bool(\n            float(aux_cfg.get(\"w_denoise_ref_dof_pos\", 0.0)) > 0.0\n        )\n        if self.aux_state_pred_enabled:\n            self.aux_vel_head = nn.Linear(self.d_model, 6)\n            self.aux_height_head = nn.Linear(self.d_model, 2)\n            self.aux_denoise_ref_root_lin_vel_head = (\n                nn.Linear(self.d_model, 3)\n                if self.use_aux_denoise_ref_root_lin_vel\n                else None\n            )\n            self.aux_denoise_ref_root_ang_vel_head = (\n                nn.Linear(self.d_model, 3)\n                if self.use_aux_denoise_ref_root_ang_vel\n                else None\n            )\n            self.aux_contact_head = (\n                nn.Linear(self.d_model, self.aux_contact_dim)\n                if self.aux_contact_dim > 0\n                else None\n            )\n            self.aux_ref_keybody_pos_head = (\n                nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3)\n                if self.aux_keybody_pos_dim > 0\n                else None\n            )\n            self.aux_robot_keybody_pos_head = (\n                nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3)\n                if self.aux_keybody_pos_dim > 0\n                else None\n            )\n            self.aux_denoise_ref_dof_pos_head = (\n                nn.Linear(self.d_model, self.output_dim)\n                if self.use_aux_denoise_ref_dof_pos\n                else None\n            )\n        else:\n            self.aux_vel_head = None\n            self.aux_height_head = None\n            self.aux_denoise_ref_root_lin_vel_head = None\n            self.aux_denoise_ref_root_ang_vel_head = None\n            self.aux_contact_head = None\n            self.aux_ref_keybody_pos_head = None\n            self.aux_robot_keybody_pos_head = None\n            self.aux_denoise_ref_dof_pos_head = None\n\n        # True per-layer KV cache for single-step inference.\n        # K/V shapes: [B, n_layers, max_ctx_len, n_kv_heads, head_dim]\n        self._k_cache: torch.Tensor | None = None\n        self._v_cache: torch.Tensor | None = None\n        # Cache state per environment\n        self._kv_cache_len: torch.Tensor | None = None  # [B]\n        self._kv_cache_write_idx: torch.Tensor | None = None  # [B]\n        self._kv_cache_abs_pos: torch.Tensor | None = None  # [B]\n        self._prev_last_moe_router_p: torch.Tensor | None = None\n        self._prev_last_moe_router_valid: torch.Tensor | None = None\n        self._last_moe_router_js_sum: torch.Tensor | None = None\n        self._last_moe_router_js_count: torch.Tensor | None = None\n        self._last_moe_router_top1_switch_sum: torch.Tensor | None = None\n        self._last_moe_router_top1_switch_count: torch.Tensor | None = None\n        aux_cmd_cfg = module_config_dict.get(\"aux_router_command_recon\", {})\n        self.aux_router_command_recon_enabled = bool(\n            aux_cmd_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_command_recon_output_dim = int(\n            aux_cmd_cfg.get(\"output_dim\", 0)\n        )\n        self.aux_router_command_recon_hidden_dim = int(\n            aux_cmd_cfg.get(\"hidden_dim\", self.d_model)\n        )\n        
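# Router features concatenate one distribution per MoE layer, so the aux\n        # command-recon head input width is _num_moe_layers * num_fine_experts.\n        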
self._num_moe_layers = sum(\n            1 for layer in self.layers if isinstance(layer, GroupedMoEBlock)\n        )\n        if self.aux_router_command_recon_enabled:\n            if self._num_moe_layers <= 0:\n                raise ValueError(\n                    \"aux_router_command_recon requires at least one GroupedMoEBlock.\"\n                )\n            if self.aux_router_command_recon_output_dim <= 0:\n                raise ValueError(\n                    \"aux_router_command_recon.output_dim must be positive when enabled.\"\n                )\n            router_feature_dim = self._num_moe_layers * self.num_fine_experts\n            self.aux_router_command_recon_head = nn.Sequential(\n                nn.Linear(\n                    router_feature_dim,\n                    self.aux_router_command_recon_hidden_dim,\n                ),\n                nn.SiLU(),\n                nn.Linear(\n                    self.aux_router_command_recon_hidden_dim,\n                    self.aux_router_command_recon_output_dim,\n                ),\n            )\n        else:\n            self.aux_router_command_recon_head = None\n        aux_router_future_cfg = module_config_dict.get(\n            \"aux_router_future_recon\", {}\n        )\n        self.aux_router_future_recon_enabled = bool(\n            aux_router_future_cfg.get(\"enabled\", False)\n        )\n        self.aux_router_future_recon_output_dim = int(\n            aux_router_future_cfg.get(\"output_dim\", 0)\n        )\n        self.aux_router_future_recon_hidden_dim = int(\n            aux_router_future_cfg.get(\"hidden_dim\", self.d_model)\n        )\n        aux_router_future_norm_cfg = aux_router_future_cfg.get(\n            \"target_norm\", {}\n        )\n        self.aux_router_future_recon_norm_eps = float(\n            aux_router_future_norm_cfg.get(\"epsilon\", 1.0e-2)\n        )\n        self.aux_router_future_recon_norm_update_method = str(\n            aux_router_future_norm_cfg.get(\"update_method\", \"cumulative\")\n        ).lower()\n        aux_router_future_norm_ema = aux_router_future_norm_cfg.get(\n            \"ema_momentum\", None\n        )\n        self.aux_router_future_recon_norm_ema_momentum = (\n            float(aux_router_future_norm_ema)\n            if aux_router_future_norm_ema is not None\n            else None\n        )\n        if self.aux_router_future_recon_enabled:\n            if self.aux_router_future_recon_output_dim <= 0:\n                raise ValueError(\n                    \"aux_router_future_recon.output_dim must be positive when enabled.\"\n                )\n            self.aux_router_future_recon_head = nn.Sequential(\n                nn.Linear(\n                    self.d_model,\n                    self.aux_router_future_recon_hidden_dim,\n                ),\n                nn.SiLU(),\n                nn.Linear(\n                    self.aux_router_future_recon_hidden_dim,\n                    self.aux_router_future_recon_output_dim,\n                ),\n            )\n            self.aux_router_future_recon_normalizer = EmpiricalNormalization(\n                shape=self.aux_router_future_recon_output_dim,\n                eps=self.aux_router_future_recon_norm_eps,\n                update_method=self.aux_router_future_recon_norm_update_method,\n                ema_momentum=self.aux_router_future_recon_norm_ema_momentum,\n            )\n        else:\n            self.aux_router_future_recon_head = None\n            self.aux_router_future_recon_normalizer = None\n        
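# Routers may be frozen via config; _load_from_state_dict re-applies the\n        # freeze after checkpoint loads.\n        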
self._apply_base_freeze_router_state()\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        if self.use_future_cross_attn:\n            # In conditional mode, `obs_embed` is tied to `state_obs_embed`.\n            # Older checkpoints may contain separate weights for both; ensure we\n            # always load the trained state embedding weights.\n            obs_prefix = prefix + \"obs_embed.\"\n            state_prefix = prefix + \"state_obs_embed.\"\n            for suffix in (\"0.weight\", \"0.bias\", \"2.weight\", \"2.bias\"):\n                s_key = state_prefix + suffix\n                o_key = obs_prefix + suffix\n                if s_key in state_dict:\n                    state_dict[o_key] = state_dict[s_key]\n\n        legacy_aux_prefix = prefix + \"aux_command_recon_head.\"\n        current_aux_prefix = prefix + \"aux_router_command_recon_head.\"\n        legacy_aux_keys = [\n            key\n            for key in list(state_dict.keys())\n            if key.startswith(legacy_aux_prefix)\n        ]\n        if legacy_aux_keys:\n            if self.aux_router_command_recon_head is not None:\n                for legacy_key in legacy_aux_keys:\n                    suffix = legacy_key.removeprefix(legacy_aux_prefix)\n                    current_key = current_aux_prefix + suffix\n                    state_dict.setdefault(current_key, state_dict[legacy_key])\n            for legacy_key in legacy_aux_keys:\n                state_dict.pop(legacy_key, None)\n\n        super()._load_from_state_dict(\n            state_dict,\n            prefix,\n            local_metadata,\n            strict,\n            missing_keys,\n            unexpected_keys,\n            error_msgs,\n        )\n        self._apply_freeze_router_state()\n\n    def _router_no_grad_context(self):\n        if self.freeze_router:\n            return torch.no_grad()\n        return nullcontext()\n\n    def _apply_base_freeze_router_state(self) -> None:\n        for layer in self.layers:\n            if isinstance(layer, GroupedMoEBlock):\n                layer._apply_freeze_router_state()\n\n    def _apply_freeze_router_state(self) -> None:\n        self._apply_base_freeze_router_state()\n        if self.aux_router_future_recon_head is not None:\n            self.aux_router_future_recon_head.requires_grad_(\n                not self.freeze_router\n            )\n\n    def _set_cos_sin_cache(self, seq_len):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype\n        )\n\n        # outer product: [seq_len, head_dim/2]\n        freqs = torch.outer(t, self.inv_freq)\n\n        # Concatenate so the shape matches rotate_half: [seq_len, head_dim].\n        # Unlike the complex-valued formulation, the real-valued rotation\n        # simply repeats freqs along the last dimension.\n        emb = torch.cat((freqs, freqs), dim=-1)\n\n        # [seq_len, head_dim]\n        self.register_buffer(\"cos_cached\", emb.cos(), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin(), persistent=False)\n\n    def get_cos_sin(self, x, position_ids):\n        \"\"\"Gather cos/sin rows for the given positions.\n\n        Args:\n            x: [B, T, D]; used only to match the output dtype.\n            position_ids: [B, T].\n\n        Returns:\n            cos, sin: [B, T, head_dim] (broadcastable against q/k).\n        \"\"\"\n        # cos_cached: [max_seq_len, head_dim]\n        # F.embedding(position_ids, cache) -> [B, T, head_dim]\n        cos = F.embedding(position_ids, self.cos_cached)\n        sin = F.embedding(position_ids, self.sin_cached)\n        return cos.to(x.dtype), sin.to(x.dtype)\n\n
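    # RoPE application sketch (illustrative, HF-style real-valued form that\n    # matches the cache layout above): with cos/sin broadcast to q of shape\n    # [B, n_heads, T, head_dim],\n    #   q_rot = q * cos + rotate_half(q) * sin,\n    # where rotate_half(x) = cat((-x[..., d // 2 :], x[..., : d // 2]), dim=-1).\n\n    # Router-shift diagnostics below accumulate, per step, the Jensen-Shannon\n    # divergence between consecutive router distributions P and Q:\n    #   JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2,\n    # together with a top-1 expert switch rate.\n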
    def _init_last_moe_router_shift_state(self, num_envs: int, device) -> None:\n        if self._last_moe_layer_idx is None:\n            self._prev_last_moe_router_p = None\n            self._prev_last_moe_router_valid = None\n            self._last_moe_router_js_sum = None\n            self._last_moe_router_js_count = None\n            self._last_moe_router_top1_switch_sum = None\n            self._last_moe_router_top1_switch_count = None\n            return\n        self._prev_last_moe_router_p = torch.zeros(\n            num_envs,\n            self.num_fine_experts,\n            device=device,\n            dtype=torch.float32,\n        )\n        self._prev_last_moe_router_valid = torch.zeros(\n            num_envs, device=device, dtype=torch.bool\n        )\n        self._last_moe_router_js_sum = torch.zeros(\n            (), device=device, dtype=torch.float32\n        )\n        self._last_moe_router_js_count = torch.zeros(\n            (), device=device, dtype=torch.float32\n        )\n        self._last_moe_router_top1_switch_sum = torch.zeros(\n            (), device=device, dtype=torch.float32\n        )\n        self._last_moe_router_top1_switch_count = torch.zeros(\n            (), device=device, dtype=torch.float32\n        )\n\n    def _accumulate_last_moe_router_shift(\n        self, router_distribution: torch.Tensor\n    ) -> None:\n        if (\n            self._prev_last_moe_router_p is None\n            or self._prev_last_moe_router_valid is None\n            or self._last_moe_router_js_sum is None\n            or self._last_moe_router_js_count is None\n            or self._last_moe_router_top1_switch_sum is None\n            or self._last_moe_router_top1_switch_count is None\n        ):\n            return\n        if (\n            router_distribution.ndim != 3\n            or int(router_distribution.shape[1]) != 1\n        ):\n            return\n        curr_p = router_distribution[:, 0, :].to(torch.float32)\n        if int(curr_p.shape[0]) != int(self._prev_last_moe_router_p.shape[0]):\n            return\n        prev_valid = self._prev_last_moe_router_valid\n        if torch.any(prev_valid):\n            prev_p = self._prev_last_moe_router_p[prev_valid]\n            curr_p_valid = curr_p[prev_valid]\n            mix_p = 0.5 * (curr_p_valid + prev_p)\n            eps = 1.0e-20\n            curr_safe = curr_p_valid.clamp_min(eps)\n            prev_safe = prev_p.clamp_min(eps)\n            mix_safe = mix_p.clamp_min(eps)\n            kl_curr = (\n                curr_p_valid * (torch.log(curr_safe) - torch.log(mix_safe))\n            ).sum(dim=-1)\n            kl_prev = (\n                prev_p * (torch.log(prev_safe) - torch.log(mix_safe))\n            ).sum(dim=-1)\n            js = 0.5 * (kl_curr + kl_prev)\n            self._last_moe_router_js_sum.add_(js.sum())\n            self._last_moe_router_js_count.add_(float(js.numel()))\n            curr_top1 = curr_p_valid.argmax(dim=-1)\n            prev_top1 = prev_p.argmax(dim=-1)\n            switch = (curr_top1 != prev_top1).to(torch.float32)\n            self._last_moe_router_top1_switch_sum.add_(switch.sum())\n            self._last_moe_router_top1_switch_count.add_(float(switch.numel()))\n        self._prev_last_moe_router_p.copy_(curr_p)\n        self._prev_last_moe_router_valid.fill_(True)\n\n    def get_last_moe_router_shift_stats(\n        self,\n    ) -> 
dict[str, torch.Tensor | None]:\n        return {\n            \"js_sum\": self._last_moe_router_js_sum,\n            \"js_count\": self._last_moe_router_js_count,\n            \"top1_switch_sum\": self._last_moe_router_top1_switch_sum,\n            \"top1_switch_count\": self._last_moe_router_top1_switch_count,\n        }\n\n    def reset_kv_cache(self, num_envs: int, device):\n        \"\"\"Initialize per-environment KV cache for single-step inference.\"\"\"\n        cache_dtype = (\n            torch.float16\n            if torch.device(device).type == \"cuda\"\n            else torch.float32\n        )\n        self._k_cache = torch.zeros(\n            num_envs,\n            self.n_layers,\n            self.max_ctx_len,\n            self.n_kv_heads,\n            self.head_dim,\n            device=device,\n            dtype=cache_dtype,\n        )\n        self._v_cache = torch.zeros_like(self._k_cache)\n        self._kv_cache_len = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._kv_cache_write_idx = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._kv_cache_abs_pos = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._init_last_moe_router_shift_state(num_envs, device)\n\n    def clear_env_cache(self, env_ids: torch.Tensor | None):\n        \"\"\"Reset KV cache state for specific environments.\"\"\"\n        if self._k_cache is None:\n            return\n        if env_ids is None:\n            self._k_cache.zero_()\n            self._v_cache.zero_()\n            self._kv_cache_len.zero_()\n            self._kv_cache_write_idx.zero_()\n            self._kv_cache_abs_pos.zero_()\n            if self._prev_last_moe_router_p is not None:\n                self._prev_last_moe_router_p.zero_()\n            if self._prev_last_moe_router_valid is not None:\n                self._prev_last_moe_router_valid.zero_()\n            if self._last_moe_router_js_sum is not None:\n                self._last_moe_router_js_sum.zero_()\n            if self._last_moe_router_js_count is not None:\n                self._last_moe_router_js_count.zero_()\n            if self._last_moe_router_top1_switch_sum is not None:\n                self._last_moe_router_top1_switch_sum.zero_()\n            if self._last_moe_router_top1_switch_count is not None:\n                self._last_moe_router_top1_switch_count.zero_()\n        else:\n            self._k_cache[env_ids] = 0.0\n            self._v_cache[env_ids] = 0.0\n            self._kv_cache_len[env_ids] = 0\n            self._kv_cache_write_idx[env_ids] = 0\n            self._kv_cache_abs_pos[env_ids] = 0\n            if self._prev_last_moe_router_valid is not None:\n                self._prev_last_moe_router_valid[env_ids] = False\n            if self._prev_last_moe_router_p is not None:\n                self._prev_last_moe_router_p[env_ids] = 0.0\n\n    def set_collect_routing_stats(self, collect: bool) -> None:\n        collect_flag = bool(collect)\n        for layer_idx, layer in enumerate(self.layers):\n            if isinstance(layer, GroupedMoEBlock):\n                layer.collect_routing_stats = collect_flag\n                layer.collect_router_distribution = (\n                    collect_flag and layer_idx == self._last_moe_layer_idx\n                )\n\n    def reset_routing_stats(self) -> None:\n        for layer in self.layers:\n            if isinstance(layer, GroupedMoEBlock):\n                
layer.reset_routing_stats()\n\n    def clear_router_distribution_cache(self) -> None:\n        for layer in self.layers:\n            if isinstance(layer, GroupedMoEBlock):\n                layer.last_router_distribution = None\n                layer.last_router_logits = None\n                layer.capture_router_distribution = False\n                layer.capture_router_logits = False\n\n    def _set_capture_router_distributions(self, capture: bool) -> None:\n        self._set_capture_router_features(\n            capture_distributions=capture,\n            capture_logits=False,\n        )\n\n    def _set_capture_router_features(\n        self,\n        *,\n        capture_distributions: bool,\n        capture_logits: bool,\n    ) -> None:\n        capture_distribution_flag = bool(capture_distributions)\n        capture_logits_flag = bool(capture_logits)\n        for layer in self.layers:\n            if isinstance(layer, GroupedMoEBlock):\n                layer.capture_router_distribution = capture_distribution_flag\n                layer.capture_router_logits = capture_logits_flag\n\n    def apply_dynamic_bias_update_from_stats(self) -> None:\n        for layer in self.layers:\n            if isinstance(layer, GroupedMoEBlock):\n                layer.apply_bias_update_from_counts()\n\n    def _make_causal_mask(self, T: int, device) -> torch.Tensor:\n        \"\"\"Generate causal attention mask: shape [T, T], True where attend allowed.\"\"\"\n        return torch.tril(torch.ones(T, T, device=device, dtype=torch.bool))\n\n    def _forward_layers_range(\n        self,\n        h: torch.Tensor,\n        cos: torch.Tensor | None,\n        sin: torch.Tensor | None,\n        mask: torch.Tensor | None,\n        memory: torch.Tensor | None = None,\n        memory_mask: torch.Tensor | None = None,\n        router_h: torch.Tensor | None = None,\n        router_h_per_layer: list[torch.Tensor | None] | None = None,\n        *,\n        start_layer: int,\n        end_layer: int,\n        return_pre_moe_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        \"\"\"Forward through a contiguous layer range with optional checkpointing.\"\"\"\n        if (\n            start_layer < 0\n            or end_layer < start_layer\n            or end_layer > len(self.layers)\n        ):\n            raise ValueError(\n                \"Invalid layer range for _forward_layers_range: \"\n                f\"start_layer={start_layer}, end_layer={end_layer}, \"\n                f\"num_layers={len(self.layers)}.\"\n            )\n        pre_moe_hidden = None\n        router_features = []\n        router_temporal_features = []\n        self._set_capture_router_features(\n            capture_distributions=return_router_features,\n            capture_logits=return_router_temporal_features,\n        )\n        try:\n            for layer_idx in range(start_layer, end_layer):\n                layer = self.layers[layer_idx]\n                layer_router_h = router_h\n                if router_h_per_layer is not None:\n                    layer_router_h = router_h_per_layer[layer_idx]\n                if self.use_checkpointing and self.training:\n                    if isinstance(layer, GroupedMoEBlock):\n                        h = checkpoint.checkpoint(\n                            layer,\n                            h,\n                            cos,\n                            sin,\n        
                    mask,\n                            memory,\n                            memory_mask,\n                            layer_router_h,\n                            use_reentrant=False,\n                        )\n                    else:\n                        h = checkpoint.checkpoint(\n                            layer,\n                            h,\n                            cos,\n                            sin,\n                            mask,\n                            memory,\n                            memory_mask,\n                            use_reentrant=False,\n                        )\n                else:\n                    if isinstance(layer, GroupedMoEBlock):\n                        h = layer(\n                            h,\n                            cos,\n                            sin,\n                            mask,\n                            memory,\n                            memory_mask,\n                            router_x=layer_router_h,\n                        )\n                    else:\n                        h = layer(h, cos, sin, mask, memory, memory_mask)\n                if return_pre_moe_hidden and layer_idx == 0:\n                    pre_moe_hidden = h\n                if return_router_features and isinstance(\n                    layer, GroupedMoEBlock\n                ):\n                    if layer.last_router_distribution is None:\n                        raise ValueError(\n                            f\"Missing router distribution for MoE layer {layer_idx}.\"\n                        )\n                    router_features.append(layer.last_router_distribution)\n                if return_router_temporal_features and isinstance(\n                    layer, GroupedMoEBlock\n                ):\n                    if layer.last_router_logits is None:\n                        raise ValueError(\n                            f\"Missing router logits for MoE layer {layer_idx}.\"\n                        )\n                    router_temporal_features.append(layer.last_router_logits)\n        finally:\n            self._set_capture_router_features(\n                capture_distributions=False,\n                capture_logits=False,\n            )\n\n        outputs: list[torch.Tensor] = [h]\n        if return_pre_moe_hidden:\n            if pre_moe_hidden is None:\n                raise ValueError(\n                    \"Missing pre-MoE hidden state from the leading dense layer.\"\n                )\n            outputs.append(pre_moe_hidden)\n        if return_router_features:\n            if len(router_features) == 0:\n                raise ValueError(\n                    \"Missing router features while return_router_features=True.\"\n                )\n            outputs.append(torch.cat(router_features, dim=-1))\n        if return_router_temporal_features:\n            if len(router_temporal_features) == 0:\n                raise ValueError(\n                    \"Missing router temporal features while \"\n                    \"return_router_temporal_features=True.\"\n                )\n            outputs.append(torch.cat(router_temporal_features, dim=-1))\n        if len(outputs) == 1:\n            return outputs[0]\n        return tuple(outputs)\n\n    def _forward_layers(\n        self,\n        h: torch.Tensor,\n        cos: torch.Tensor | None,\n        sin: torch.Tensor | None,\n        mask: torch.Tensor | None,\n        memory: torch.Tensor | None = None,\n        memory_mask: torch.Tensor | None = 
None,\n        router_h: torch.Tensor | None = None,\n        router_h_per_layer: list[torch.Tensor | None] | None = None,\n        return_pre_moe_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        return self._forward_layers_range(\n            h,\n            cos,\n            sin,\n            mask,\n            memory,\n            memory_mask,\n            router_h,\n            router_h_per_layer,\n            start_layer=0,\n            end_layer=len(self.layers),\n            return_pre_moe_hidden=return_pre_moe_hidden,\n            return_router_features=return_router_features,\n            return_router_temporal_features=return_router_temporal_features,\n        )\n\n    def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None:\n        \"\"\"Hook for subclasses to derive router inputs; the base returns None.\"\"\"\n        return None\n\n    def sequence_mu(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n        return_hidden: bool = False,\n        return_pre_moe_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        \"\"\"Compute per-token action mean for sequences.\n\n        Args:\n            x: [B, T, D] flat obs per token.\n            attn_mask: [B, T, T] boolean mask (True if attend allowed), or None for causal.\n            return_hidden: If True, also return the hidden states; honored only\n                when no other return_* flag is set.\n\n        Returns:\n            mu: [B, T, A], followed by one extra tensor per enabled return_*\n                flag (pre-MoE hidden, router features, router temporal features).\n            h: [B, T, d_model] (only if return_hidden=True)\n        \"\"\"\n        B, T, _ = x.shape\n        h = self.obs_embed(x)  # [B, T, d_model]\n        router_h = self._compute_router_hidden(x)\n\n        # SDPA bool attention mask uses True = allowed (can attend).\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)  # [B, 1, T, T]\n            # Episode-aware positions: first attendable token is episode start.\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)  # [B, T]\n            t_idx = torch.arange(T, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(B, T)\n            pos = t_idx - start_idx  # [B, T]\n        else:\n            tgt_mask = None\n            pos = torch.arange(T, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(B, T)\n\n        cos, sin = self.get_cos_sin(h, pos)  # [B, T, head_dim]\n        if return_hidden and return_pre_moe_hidden:\n            raise ValueError(\n                \"return_hidden and return_pre_moe_hidden cannot both be True.\"\n            )\n        forward_out = self._forward_layers(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n            return_pre_moe_hidden=return_pre_moe_hidden,\n            return_router_features=return_router_features,\n            return_router_temporal_features=return_router_temporal_features,\n        )\n        extras: list[torch.Tensor] = []\n        if isinstance(forward_out, tuple):\n            h = forward_out[0]\n            extras = list(forward_out[1:])\n        else:\n            h = forward_out\n        h = self.norm_f(h)\n        mu = self.action_mu_head(h)\n        outputs: list[torch.Tensor] = [mu]\n        if return_pre_moe_hidden:\n            outputs.append(extras.pop(0))\n        if return_router_features:\n            
outputs.append(extras.pop(0))\n        if return_router_temporal_features:\n            outputs.append(extras.pop(0))\n        if len(outputs) > 1:\n            return tuple(outputs)\n        if return_hidden:\n            return mu, h\n        return mu\n\n    def sequence_hidden(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Compute per-token latent features for sequences.\n\n        Args:\n            x: [B, T, D] flat obs per token.\n            attn_mask: [B, T, T] boolean mask (True if attend allowed).\n\n        Returns:\n            h_f: [B, T, d_model]\n        \"\"\"\n        B, T, _ = x.shape\n        h = self.obs_embed(x)  # [B, T, d_model]\n        router_h = self._compute_router_hidden(x)\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)  # [B, 1, T, T]\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)  # [B, T]\n            t_idx = torch.arange(T, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(B, T)\n            pos = t_idx - start_idx  # [B, T]\n        else:\n            tgt_mask = None\n            pos = torch.arange(T, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(B, T)\n\n        cos, sin = self.get_cos_sin(h, pos)\n        h = self._forward_layers(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n        )\n        h = self.norm_f(h)\n        return h\n\n    def _embed_future_tokens(\n        self, future_tokens: torch.Tensor\n    ) -> torch.Tensor:\n        if not self.use_future_cross_attn:\n            raise ValueError(\n                \"_embed_future_tokens requires use_future_cross_attn=True\"\n            )\n        if future_tokens.ndim == 3:\n            b, n, d = future_tokens.shape\n            if n != self.future_seq_len:\n                raise ValueError(\n                    f\"future token length mismatch: expected {self.future_seq_len}, got {n}\"\n                )\n            if d != self.future_token_dim:\n                raise ValueError(\n                    f\"future token dim mismatch: expected {self.future_token_dim}, got {d}\"\n                )\n            pos = torch.arange(\n                n, device=future_tokens.device, dtype=torch.long\n            )\n            pos_emb = self.future_pos_embed(pos)[None, :, :]\n            return self.future_obs_embed(future_tokens) + pos_emb\n        if future_tokens.ndim == 4:\n            b, t, n, d = future_tokens.shape\n            if n != self.future_seq_len:\n                raise ValueError(\n                    f\"future token length mismatch: expected {self.future_seq_len}, got {n}\"\n                )\n            if d != self.future_token_dim:\n                raise ValueError(\n                    f\"future token dim mismatch: expected {self.future_token_dim}, got {d}\"\n                )\n            pos = torch.arange(\n                n, device=future_tokens.device, dtype=torch.long\n            )\n            pos_emb = self.future_pos_embed(pos)[None, None, :, :]\n            return self.future_obs_embed(future_tokens) + pos_emb\n        raise ValueError(\n            f\"future_tokens must be 3D or 4D, got shape {tuple(future_tokens.shape)}\"\n        )\n\n    def sequence_mu_cond(\n        self,\n        state_seq: torch.Tensor,\n        future_seq: torch.Tensor,\n        *,\n        attn_mask: 
torch.Tensor | None = None,\n        future_mask: torch.Tensor | None = None,\n        return_pre_moe_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        if not self.use_future_cross_attn:\n            raise ValueError(\n                \"sequence_mu_cond requires use_future_cross_attn=True\"\n            )\n        if state_seq.ndim != 3:\n            raise ValueError(\n                f\"state_seq must have shape [B, T, D], got {tuple(state_seq.shape)}\"\n            )\n        if future_seq.ndim != 4:\n            raise ValueError(\n                \"future_seq must have shape [B, T, N_fut, D_fut], \"\n                f\"got {tuple(future_seq.shape)}\"\n            )\n        b, t, d_state = state_seq.shape\n        bf, tf, n_fut, d_fut = future_seq.shape\n        if bf != b or tf != t:\n            raise ValueError(\n                \"state_seq and future_seq batch/time mismatch: \"\n                f\"state={tuple(state_seq.shape)}, future={tuple(future_seq.shape)}\"\n            )\n        if d_state != self.state_obs_dim:\n            raise ValueError(\n                f\"state_seq dim mismatch: expected {self.state_obs_dim}, got {d_state}\"\n            )\n        if n_fut != self.future_seq_len:\n            raise ValueError(\n                f\"future_seq len mismatch: expected {self.future_seq_len}, got {n_fut}\"\n            )\n        if d_fut != self.future_token_dim:\n            raise ValueError(\n                f\"future_seq dim mismatch: expected {self.future_token_dim}, got {d_fut}\"\n            )\n\n        h = self.state_obs_embed(state_seq)\n        memory = self._embed_future_tokens(future_seq)\n        if future_mask is None:\n            future_mask = torch.ones(\n                b,\n                t,\n                n_fut,\n                dtype=torch.bool,\n                device=state_seq.device,\n            )\n        if future_mask.shape != (b, t, n_fut):\n            raise ValueError(\n                \"future_mask shape mismatch: expected \"\n                f\"{(b, t, n_fut)}, got {tuple(future_mask.shape)}\"\n            )\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)\n            t_idx = torch.arange(t, device=state_seq.device, dtype=torch.long)[\n                None, :\n            ].expand(b, t)\n            pos = t_idx - start_idx\n        else:\n            tgt_mask = None\n            pos = torch.arange(t, device=state_seq.device, dtype=torch.long)[\n                None, :\n            ].expand(b, t)\n\n        cos, sin = self.get_cos_sin(h, pos)\n        forward_out = self._forward_layers(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            memory=memory,\n            memory_mask=future_mask,\n            return_pre_moe_hidden=return_pre_moe_hidden,\n            return_router_features=return_router_features,\n            return_router_temporal_features=return_router_temporal_features,\n        )\n        extras: list[torch.Tensor] = []\n        if isinstance(forward_out, tuple):\n            h = forward_out[0]\n            extras = list(forward_out[1:])\n        else:\n            h = forward_out\n        h = self.norm_f(h)\n        mu = self.action_mu_head(h)\n        outputs: list[torch.Tensor] = [mu]\n        if return_pre_moe_hidden:\n            
outputs.append(extras.pop(0))\n        if return_router_features:\n            outputs.append(extras.pop(0))\n        if return_router_temporal_features:\n            outputs.append(extras.pop(0))\n        if len(outputs) > 1:\n            return tuple(outputs)\n        return mu\n\n    def predict_aux_from_pre_moe(\n        self,\n        pre_moe_hidden: torch.Tensor,\n        *,\n        ref_aux_hidden: torch.Tensor | None = None,\n    ) -> dict[str, torch.Tensor]:\n        if not self.aux_state_pred_enabled:\n            raise ValueError(\n                \"predict_aux_from_pre_moe requires aux_state_pred.enabled=True.\"\n            )\n        if pre_moe_hidden.ndim != 3:\n            raise ValueError(\n                f\"Expected pre_moe_hidden with shape [B, T, D], got {tuple(pre_moe_hidden.shape)}\"\n            )\n        vel_params = self.aux_vel_head(pre_moe_hidden)\n        height_params = self.aux_height_head(pre_moe_hidden)\n        vel_loc, vel_log_std = vel_params.chunk(2, dim=-1)\n        height_loc, height_log_std = height_params.chunk(2, dim=-1)\n        aux_outputs = {\n            \"base_lin_vel_loc\": vel_loc,\n            \"base_lin_vel_log_std\": vel_log_std,\n            \"root_height_loc\": height_loc,\n            \"root_height_log_std\": height_log_std,\n        }\n        if self.aux_contact_head is not None:\n            aux_outputs[\"keybody_contact_logits\"] = self.aux_contact_head(\n                pre_moe_hidden\n            )\n        else:\n            aux_outputs[\"keybody_contact_logits\"] = pre_moe_hidden.new_zeros(\n                pre_moe_hidden.shape[0],\n                pre_moe_hidden.shape[1],\n                0,\n            )\n        if self.aux_denoise_ref_root_lin_vel_head is not None:\n            aux_outputs[\"denoise_ref_root_lin_vel_residual\"] = (\n                self.aux_denoise_ref_root_lin_vel_head(pre_moe_hidden)\n            )\n        if self.aux_denoise_ref_root_ang_vel_head is not None:\n            aux_outputs[\"denoise_ref_root_ang_vel_residual\"] = (\n                self.aux_denoise_ref_root_ang_vel_head(pre_moe_hidden)\n            )\n        if self.aux_ref_keybody_pos_head is not None:\n            aux_outputs[\"ref_keybody_rel_pos\"] = self.aux_ref_keybody_pos_head(\n                pre_moe_hidden\n            ).reshape(\n                pre_moe_hidden.shape[0],\n                pre_moe_hidden.shape[1],\n                self.aux_keybody_pos_dim,\n                3,\n            )\n            aux_outputs[\"robot_keybody_rel_pos\"] = (\n                self.aux_robot_keybody_pos_head(pre_moe_hidden).reshape(\n                    pre_moe_hidden.shape[0],\n                    pre_moe_hidden.shape[1],\n                    self.aux_keybody_pos_dim,\n                    3,\n                )\n            )\n        else:\n            aux_outputs[\"ref_keybody_rel_pos\"] = pre_moe_hidden.new_zeros(\n                pre_moe_hidden.shape[0],\n                pre_moe_hidden.shape[1],\n                0,\n                3,\n            )\n            aux_outputs[\"robot_keybody_rel_pos\"] = pre_moe_hidden.new_zeros(\n                pre_moe_hidden.shape[0],\n                pre_moe_hidden.shape[1],\n                0,\n                3,\n            )\n        if self.aux_denoise_ref_dof_pos_head is not None:\n            aux_outputs[\"denoise_ref_dof_pos_residual\"] = (\n                self.aux_denoise_ref_dof_pos_head(pre_moe_hidden)\n            )\n        return aux_outputs\n\n    def 
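_aux_heads_usage_sketch(self) -> None:\n        \"\"\"Illustrative usage sketch for the auxiliary state heads.\n\n        A minimal sketch (hypothetical shapes), assuming aux_state_pred is\n        enabled and the velocity/height heads target 3-dim base velocity\n        and scalar root height:\n\n            mu, pre_moe = self.sequence_mu(obs, return_pre_moe_hidden=True)\n            aux = self.predict_aux_from_pre_moe(pre_moe)\n            aux[\"base_lin_vel_loc\"]      # [B, T, 3]\n            aux[\"root_height_log_std\"]   # [B, T, 1]\n        \"\"\"\n\n    def 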
predict_aux_router_command_from_router_features(\n        self, router_features: torch.Tensor\n    ) -> torch.Tensor:\n        if not self.aux_router_command_recon_enabled:\n            raise ValueError(\n                \"predict_aux_router_command_from_router_features requires \"\n                \"aux_router_command_recon.enabled=True.\"\n            )\n        if router_features.ndim != 3:\n            raise ValueError(\n                \"Expected router_features with shape [B, T, D], got \"\n                f\"{tuple(router_features.shape)}.\"\n            )\n        if self.aux_router_command_recon_head is None:\n            raise ValueError(\n                \"aux_router_command_recon_head is not initialized.\"\n            )\n        return self.aux_router_command_recon_head(router_features)\n\n    def update_aux_router_future_recon_normalizer(\n        self, future_target: torch.Tensor\n    ) -> None:\n        if not self.aux_router_future_recon_enabled:\n            raise ValueError(\n                \"update_aux_router_future_recon_normalizer requires \"\n                \"aux_router_future_recon.enabled=True.\"\n            )\n        if self.aux_router_future_recon_normalizer is None:\n            raise ValueError(\n                \"aux_router_future_recon_normalizer is not initialized.\"\n            )\n        if future_target.ndim < 2:\n            raise ValueError(\n                \"Expected future_target with shape [B, D] or [B, T, D], got \"\n                f\"{tuple(future_target.shape)}.\"\n            )\n        flat_target = future_target.reshape(\n            -1, future_target.shape[-1]\n        ).detach()\n        self.aux_router_future_recon_normalizer.update(flat_target)\n\n    def normalize_aux_router_future_recon_target(\n        self, future_target: torch.Tensor\n    ) -> torch.Tensor:\n        if not self.aux_router_future_recon_enabled:\n            raise ValueError(\n                \"normalize_aux_router_future_recon_target requires \"\n                \"aux_router_future_recon.enabled=True.\"\n            )\n        if self.aux_router_future_recon_normalizer is None:\n            raise ValueError(\n                \"aux_router_future_recon_normalizer is not initialized.\"\n            )\n        if future_target.ndim < 2:\n            raise ValueError(\n                \"Expected future_target with shape [B, D] or [B, T, D], got \"\n                f\"{tuple(future_target.shape)}.\"\n            )\n        flat_target = future_target.reshape(-1, future_target.shape[-1])\n        norm_target = self.aux_router_future_recon_normalizer.normalize_only(\n            flat_target\n        )\n        return norm_target.reshape_as(future_target)\n\n    def predict_aux_router_future_recon_from_router_hidden(\n        self, router_hidden: torch.Tensor\n    ) -> torch.Tensor:\n        if not self.aux_router_future_recon_enabled:\n            raise ValueError(\n                \"predict_aux_router_future_recon_from_router_hidden requires \"\n                \"aux_router_future_recon.enabled=True.\"\n            )\n        if router_hidden.ndim != 3:\n            raise ValueError(\n                \"Expected router_hidden with shape [B, T, D], got \"\n                f\"{tuple(router_hidden.shape)}.\"\n            )\n        if self.aux_router_future_recon_head is None:\n            raise ValueError(\n                \"aux_router_future_recon_head is not initialized.\"\n            )\n        return self.aux_router_future_recon_head(router_hidden)\n\n    def 
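_kv_cache_usage_sketch(self) -> None:\n        \"\"\"Illustrative single-step rollout sketch.\n\n        A minimal sketch, assuming reset_kv_cache has been called first;\n        if the cache is unset, the single-step entry points fall back to\n        the uncached sequence path:\n\n            self.reset_kv_cache(num_envs, device)\n            for t in range(seq_len):\n                act = self.single_step_mu(obs_seq[:, t])  # [B, A]\n        \"\"\"\n\n    def 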
single_step_mu_cond(\n        self,\n        state_x: torch.Tensor,\n        future_tokens: torch.Tensor,\n        *,\n        future_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        if not self.use_future_cross_attn:\n            raise ValueError(\n                \"single_step_mu_cond requires use_future_cross_attn=True\"\n            )\n        if state_x.ndim != 2:\n            raise ValueError(f\"Expected state_x [B, D], got {state_x.shape}\")\n        if future_tokens.ndim != 3:\n            raise ValueError(\n                \"Expected future_tokens [B, N_fut, D_fut], \"\n                f\"got {future_tokens.shape}\"\n            )\n        b, d_state = state_x.shape\n        bf, n_fut, d_fut = future_tokens.shape\n        if bf != b:\n            raise ValueError(\n                f\"Batch mismatch between state_x and future_tokens: {b} vs {bf}\"\n            )\n        if d_state != self.state_obs_dim:\n            raise ValueError(\n                f\"state_x dim mismatch: expected {self.state_obs_dim}, got {d_state}\"\n            )\n        if n_fut != self.future_seq_len:\n            raise ValueError(\n                f\"future len mismatch: expected {self.future_seq_len}, got {n_fut}\"\n            )\n        if d_fut != self.future_token_dim:\n            raise ValueError(\n                f\"future dim mismatch: expected {self.future_token_dim}, got {d_fut}\"\n            )\n\n        if self._k_cache is None:\n            state_seq = state_x[:, None, :]\n            future_seq = future_tokens[:, None, :, :]\n            if future_mask is not None:\n                future_mask = future_mask[:, None, :]\n            mu_seq = self.sequence_mu_cond(\n                state_seq,\n                future_seq,\n                attn_mask=None,\n                future_mask=future_mask,\n            )\n            return mu_seq[:, 0, :]\n\n        if self._k_cache.device != state_x.device:\n            self._k_cache = self._k_cache.to(state_x.device)\n            self._v_cache = self._v_cache.to(state_x.device)\n            self._kv_cache_len = self._kv_cache_len.to(state_x.device)\n            self._kv_cache_write_idx = self._kv_cache_write_idx.to(\n                state_x.device\n            )\n            self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(state_x.device)\n\n        h = self.state_obs_embed(state_x)[:, None, :]\n        memory = self._embed_future_tokens(future_tokens)\n\n        if self._k_cache.dtype != h.dtype:\n            self._k_cache = self._k_cache.to(h.dtype)\n            self._v_cache = self._v_cache.to(h.dtype)\n\n        cache_len = self._kv_cache_len\n        insert_pos = self._kv_cache_write_idx\n        max_len = int(self.max_ctx_len)\n        new_len = torch.clamp(cache_len + 1, max=max_len)\n\n        self._kv_cache_len = new_len\n        self._kv_cache_write_idx = (insert_pos + 1) % max_len\n\n        pos = self._kv_cache_abs_pos\n        self._kv_cache_abs_pos = pos + 1\n        pos_ids = pos.unsqueeze(1)\n        cos, sin = self.get_cos_sin(h, pos_ids)\n\n        memory_mask = None\n        if future_mask is not None:\n            if future_mask.shape != (b, n_fut):\n                raise ValueError(\n                    \"future_mask shape mismatch for single-step path: expected \"\n                    f\"{(b, n_fut)}, got {tuple(future_mask.shape)}\"\n                )\n            memory_mask = future_mask[:, None, None, :]\n\n        for layer_idx, layer in enumerate(self.layers):\n            x_norm = layer.norm1(h)\n     
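       # per-layer views into the shared ring-buffer KV cache\n     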
       k_cache_l = self._k_cache[:, layer_idx]\n            v_cache_l = self._v_cache[:, layer_idx]\n            attn_out, _, _ = layer.attn.forward_single_token(\n                x_norm,\n                cos,\n                sin,\n                k_cache_l,\n                v_cache_l,\n                new_len,\n                insert_pos,\n            )\n            h = h + attn_out\n            if layer.use_cross_attn:\n                h = h + layer.cross_attn(\n                    layer.norm_cross(h), memory, memory_mask\n                )\n            h2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                ffn = layer.compute_moe_ffn(h2)\n                if (\n                    layer_idx == self._last_moe_layer_idx\n                    and layer.collect_routing_stats\n                    and layer.last_router_distribution is not None\n                ):\n                    self._accumulate_last_moe_router_shift(\n                        layer.last_router_distribution\n                    )\n            else:\n                ffn = layer.mlp_dropout(layer.mlp(h2))\n            h = h + ffn\n\n        h = self.norm_f(h)\n        return self.action_mu_head(h[:, 0, :])\n\n    def forward(\n        self,\n        input: torch.Tensor,\n        past_key_values: torch.Tensor | None = None,\n        current_pos: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Forward pass for single-step inference.\n\n        Routes to the ONNX cached path when past_key_values is given;\n        otherwise treats the input as a single history-free step.\n        \"\"\"\n        if past_key_values is not None:\n            return self._forward_inference_onnx(\n                input, past_key_values, current_pos\n            )\n        if input.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {input.shape}\")\n        mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)\n        return mu_seq[:, 0, :]\n\n    def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Compute action mean for a single step using per-layer KV cache.\n\n        Uses a ring-buffer KV cache with per-env absolute positions for RoPE.\n        \"\"\"\n        if x.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {x.shape}\")\n        B, _ = x.shape\n\n        if self._k_cache is None:\n            mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)\n            return mu_seq[:, 0, :]\n\n        # Ensure cache device matches\n        if self._k_cache.device != x.device:\n            self._k_cache = self._k_cache.to(x.device)\n            self._v_cache = self._v_cache.to(x.device)\n            self._kv_cache_len = self._kv_cache_len.to(x.device)\n            self._kv_cache_write_idx = self._kv_cache_write_idx.to(x.device)\n            self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(x.device)\n\n        h = self.obs_embed(x)[:, None, :]  # [B, 1, d_model]\n        router_h = self._compute_router_hidden(x)\n        if router_h is not None:\n            router_h = router_h[:, None, :]\n\n        # Ensure cache dtype matches compute dtype (convert once if needed)\n        if self._k_cache.dtype != h.dtype:\n            self._k_cache = self._k_cache.to(h.dtype)\n            self._v_cache = self._v_cache.to(h.dtype)\n\n        cache_len = self._kv_cache_len  # [B]\n        insert_pos = self._kv_cache_write_idx  # [B]\n        max_len = int(self.max_ctx_len)\n        new_len = torch.clamp(cache_len + 1, max=max_len)  # [B]\n\n        self._kv_cache_len = new_len\n        self._kv_cache_write_idx = (insert_pos + 1) % max_len\n\n      
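  # The write index wraps modulo max_ctx_len (ring buffer), while the\n        # absolute position below keeps growing so RoPE stays monotonic.\n      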
  # RoPE frequencies for current absolute position\n        pos = self._kv_cache_abs_pos  # [B]\n        self._kv_cache_abs_pos = pos + 1\n        pos_ids = pos.unsqueeze(1)  # [B, 1]\n        cos, sin = self.get_cos_sin(h, pos_ids)\n\n        for layer_idx, layer in enumerate(self.layers):\n            x_norm = layer.norm1(h)\n            k_cache_l = self._k_cache[\n                :, layer_idx\n            ]  # [B, L, n_kv_heads, head_dim]\n            v_cache_l = self._v_cache[:, layer_idx]\n            attn_out, _, _ = layer.attn.forward_single_token(\n                x_norm,\n                cos,\n                sin,\n                k_cache_l,\n                v_cache_l,\n                new_len,\n                insert_pos,\n            )\n            h = h + attn_out\n            # FFN/MoE path for single token (h: [B,1,D])\n            h2 = layer.norm2(h)  # [B,1,D]\n            if isinstance(layer, GroupedMoEBlock):\n                ffn = layer.compute_moe_ffn(h2, router_x=router_h)\n                if (\n                    layer_idx == self._last_moe_layer_idx\n                    and layer.collect_routing_stats\n                    and layer.last_router_distribution is not None\n                ):\n                    self._accumulate_last_moe_router_shift(\n                        layer.last_router_distribution\n                    )\n            else:\n                ffn = layer.mlp_dropout(layer.mlp(h2))  # keep [B, 1, D]\n\n            h = h + ffn\n\n        h = self.norm_f(h)\n        return self.action_mu_head(h[:, 0, :])\n\n    def _forward_inference_onnx(\n        self,\n        x: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        \"\"\"Single-step inference compatible with ONNX export.\n\n        Aligns strictly with `single_step_mu` logic using real-valued RoPE.\n\n        Args:\n            x: [B, D] (batch is typically 1 for ONNX export)\n            past_key_values: [n_layers, 2, B, max_len, n_kv_heads, head_dim]\n            current_pos: [B] or scalar, the absolute step index (0, 1, 2, ...)\n\n        Returns:\n            action: [B, A]\n            present_key_values: Updated KV cache tensor\n        \"\"\"\n        # Embedding [B, D] -> [B, 1, D]\n        h = self.obs_embed(x)[:, None, :]  # [B, 1, d_model]\n        router_h = self._compute_router_hidden(x)\n        if router_h is not None:\n            router_h = router_h[:, None, :]\n        B = h.shape[0]\n\n        # Calculate Cache Indices (Ring Buffer Logic)\n        # past_key_values shape: [L, 2, B, T, H, D] -> T is index 3\n        max_len = past_key_values.shape[3]\n\n        if current_pos.ndim == 0:\n            current_pos = current_pos.view(1).expand(B)\n\n        # insert_pos: [B]\n        insert_pos = current_pos % max_len\n\n        # new_len: [B]\n        new_len = torch.clamp(current_pos + 1, max=max_len)\n\n        # position_ids: [B, 1]\n        position_ids = current_pos.unsqueeze(1)\n\n        # cos, sin shape: [B, 1, head_dim]\n        cos, sin = self.get_cos_sin(h, position_ids)\n\n        present_key_values_list = []\n        routing_debug_outputs: list[torch.Tensor] = []\n        export_routing_debug = torch.onnx.is_in_onnx_export()\n\n        for i, layer in enumerate(self.layers):\n            # Unpack Cache: [2, B, T, H, D]\n            layer_past = past_key_values[i]\n            k_cache = layer_past[0]\n            v_cache = layer_past[1]\n\n            h_norm = layer.norm1(h)\n\n            # 
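Cache update: forward_single_token writes this step's K/V at\n            # insert_pos; the returned tensors are restacked below so the\n            # exported graph keeps pure tensor inputs/outputs.\n            # 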
Attention\n            attn_out, new_k_cache, new_v_cache = (\n                layer.attn.forward_single_token(\n                    x=h_norm,\n                    cos=cos,\n                    sin=sin,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    new_len=new_len,\n                    insert_pos=insert_pos,\n                )\n            )\n\n            h = h + attn_out\n\n            # FFN / MoE\n            h_norm2 = layer.norm2(h)\n\n            if isinstance(layer, GroupedMoEBlock):\n                if export_routing_debug:\n                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(\n                        h_norm2,\n                        router_x=router_h,\n                        return_routing_debug=True,\n                    )\n                    routing_debug_outputs.extend([topk_idx, router_logits])\n                else:\n                    ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)\n            else:\n                # Dense MLP\n                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))\n            h = h + ffn_out\n            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)\n            present_key_values_list.append(current_layer_kv)\n\n        h = self.norm_f(h)\n        action = self.action_mu_head(h[:, 0, :])\n        present_key_values = torch.stack(present_key_values_list, dim=0)\n\n        if export_routing_debug and routing_debug_outputs:\n            return (action, present_key_values, *routing_debug_outputs)\n        return action, present_key_values\n\n\nclass ReferenceRoutedGroupedMoETransformerPolicy(GroupedMoETransformerPolicy):\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        module_config_dict: dict,\n    ):\n        module_config = dict(module_config_dict)\n        if bool(module_config.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicy does not support \"\n                \"use_future_cross_attn=True.\"\n            )\n\n        router_input_dim = module_config.get(\"router_input_dim\", None)\n        router_feature_indices = module_config.get(\n            \"router_feature_indices\", None\n        )\n        if router_input_dim is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicy requires router_input_dim.\"\n            )\n        if router_feature_indices is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicy requires \"\n                \"router_feature_indices.\"\n            )\n\n        self.router_input_dim = int(router_input_dim)\n        self.router_feature_indices = tuple(\n            int(idx) for idx in router_feature_indices\n        )\n        if self.router_input_dim <= 0:\n            raise ValueError(\n                f\"router_input_dim must be positive, got {self.router_input_dim}.\"\n            )\n        if len(self.router_feature_indices) != self.router_input_dim:\n            raise ValueError(\n                \"router_input_dim must match len(router_feature_indices): \"\n                f\"{self.router_input_dim} vs {len(self.router_feature_indices)}.\"\n            )\n        if any(idx < 0 for idx in self.router_feature_indices):\n            raise ValueError(\n                f\"router_feature_indices must be non-negative, got {self.router_feature_indices}.\"\n            )\n        
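# Parent constructor builds the shared transformer trunk; the\n        # router-specific embedding is created after it returns.\n        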
super().__init__(\n            input_dim=input_dim,\n            output_dim=output_dim,\n            module_config_dict=module_config,\n        )\n        obs_in = int(self.obs_input_dim or self.input_dim)\n        if any(idx >= obs_in for idx in self.router_feature_indices):\n            raise ValueError(\n                \"router_feature_indices exceed the flat actor obs dim \"\n                f\"{obs_in}: {self.router_feature_indices}\"\n            )\n        self.router_embed_mlp_hidden = int(\n            module_config.get(\n                \"router_embed_mlp_hidden\", self.obs_embed_mlp_hidden\n            )\n        )\n        self.register_buffer(\n            \"_router_feature_indices\",\n            torch.tensor(self.router_feature_indices, dtype=torch.long),\n            persistent=False,\n        )\n        self.router_obs_embed = nn.Sequential(\n            nn.Linear(self.router_input_dim, self.router_embed_mlp_hidden),\n            nn.SiLU(),\n            nn.Linear(self.router_embed_mlp_hidden, self.d_model),\n        )\n        self._apply_freeze_router_state()\n\n    def _apply_freeze_router_state(self) -> None:\n        super()._apply_freeze_router_state()\n        self.router_obs_embed.requires_grad_(not self.freeze_router)\n\n    def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None:\n        if x.shape[-1] != int(self.obs_input_dim or self.input_dim):\n            raise ValueError(\n                \"Reference-routed policy expected flat obs dim \"\n                f\"{int(self.obs_input_dim or self.input_dim)}, got {x.shape[-1]}.\"\n            )\n        router_idx = self._router_feature_indices\n        if router_idx.device != x.device:\n            router_idx = router_idx.to(x.device)\n        router_obs = torch.index_select(x, dim=x.ndim - 1, index=router_idx)\n        return self.router_obs_embed(router_obs)\n\n    def _forward_inference_onnx_cond(\n        self,\n        state_x: torch.Tensor,\n        future_tokens: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        if not self.use_future_cross_attn:\n            raise ValueError(\n                \"_forward_inference_onnx_cond requires use_future_cross_attn=True\"\n            )\n        if state_x.ndim != 2:\n            raise ValueError(\n                f\"state_x must have shape [B, D_state], got {tuple(state_x.shape)}\"\n            )\n        if future_tokens.ndim != 3:\n            raise ValueError(\n                \"future_tokens must have shape [B, N_fut, D_fut], \"\n                f\"got {tuple(future_tokens.shape)}\"\n            )\n        h = self.state_obs_embed(state_x)[:, None, :]\n        memory = self._embed_future_tokens(future_tokens)\n        b = h.shape[0]\n        max_len = past_key_values.shape[3]\n        if current_pos.ndim == 0:\n            current_pos = current_pos.view(1).expand(b)\n        insert_pos = current_pos % max_len\n        new_len = torch.clamp(current_pos + 1, max=max_len)\n        position_ids = current_pos.unsqueeze(1)\n        cos, sin = self.get_cos_sin(h, position_ids)\n\n        present_key_values_list = []\n        routing_debug_outputs: list[torch.Tensor] = []\n        export_routing_debug = torch.onnx.is_in_onnx_export()\n        for i, layer in enumerate(self.layers):\n            layer_past = past_key_values[i]\n            k_cache = layer_past[0]\n            v_cache = layer_past[1]\n\n            h_norm = layer.norm1(h)\n            attn_out, 
new_k_cache, new_v_cache = (\n                layer.attn.forward_single_token(\n                    x=h_norm,\n                    cos=cos,\n                    sin=sin,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    new_len=new_len,\n                    insert_pos=insert_pos,\n                )\n            )\n            h = h + attn_out\n\n            if layer.use_cross_attn:\n                h = h + layer.cross_attn(layer.norm_cross(h), memory, None)\n\n            h_norm2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                if export_routing_debug:\n                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(\n                        h_norm2,\n                        return_routing_debug=True,\n                    )\n                    routing_debug_outputs.extend([topk_idx, router_logits])\n                else:\n                    ffn_out = layer.compute_moe_ffn(h_norm2)\n            else:\n                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))\n            h = h + ffn_out\n            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)\n            present_key_values_list.append(current_layer_kv)\n\n        h = self.norm_f(h)\n        action = self.action_mu_head(h[:, 0, :])\n        present_key_values = torch.stack(present_key_values_list, dim=0)\n        if export_routing_debug and routing_debug_outputs:\n            return (action, present_key_values, *routing_debug_outputs)\n        return action, present_key_values\n\n\nclass ReferenceRoutedGroupedMoETransformerPolicyV2(\n    GroupedMoETransformerPolicy\n):\n    supports_explicit_ref_aux_hidden = True\n\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        module_config_dict: dict,\n    ):\n        module_config = dict(module_config_dict)\n        if bool(module_config.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 does not \"\n                \"support use_future_cross_attn=True.\"\n            )\n        state_obs_input_dim = module_config.get(\"state_obs_input_dim\", None)\n        ref_cur_token_dim = module_config.get(\"ref_cur_token_dim\", None)\n        ref_fut_token_dim = module_config.get(\"ref_fut_token_dim\", None)\n        ref_fut_seq_len = module_config.get(\"ref_fut_seq_len\", None)\n        state_feature_indices = module_config.get(\n            \"state_feature_indices\", None\n        )\n        ref_cur_feature_indices = module_config.get(\n            \"ref_cur_feature_indices\", None\n        )\n        ref_fut_slices = module_config.get(\"ref_fut_slices\", None)\n        if state_obs_input_dim is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 requires \"\n                \"state_obs_input_dim.\"\n            )\n        if ref_cur_token_dim is None or ref_fut_token_dim is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 requires \"\n                \"ref_cur_token_dim and ref_fut_token_dim.\"\n            )\n        if ref_fut_seq_len is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 requires \"\n                \"ref_fut_seq_len.\"\n            )\n        if state_feature_indices is None or ref_cur_feature_indices is None:\n            raise ValueError(\n                
\"ReferenceRoutedGroupedMoETransformerPolicyV2 requires \"\n                \"state_feature_indices and ref_cur_feature_indices.\"\n            )\n        if ref_fut_slices is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 requires \"\n                \"ref_fut_slices.\"\n            )\n\n        self.full_obs_input_dim = int(input_dim)\n        self.state_obs_input_dim = int(state_obs_input_dim)\n        self.ref_cur_token_dim = int(ref_cur_token_dim)\n        self.ref_fut_token_dim = int(ref_fut_token_dim)\n        self.ref_fut_seq_len = int(ref_fut_seq_len)\n        self.state_feature_indices = tuple(\n            int(idx) for idx in state_feature_indices\n        )\n        self.ref_cur_feature_indices = tuple(\n            int(idx) for idx in ref_cur_feature_indices\n        )\n        self.ref_fut_slices = tuple(\n            (int(start), int(end), int(dim))\n            for start, end, dim in ref_fut_slices\n        )\n        if self.state_obs_input_dim <= 0:\n            raise ValueError(\n                \"state_obs_input_dim must be positive, got \"\n                f\"{self.state_obs_input_dim}.\"\n            )\n        if self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0:\n            raise ValueError(\n                \"ref token dims must be positive, got \"\n                f\"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}.\"\n            )\n        if self.ref_cur_token_dim != self.ref_fut_token_dim:\n            raise ValueError(\n                \"current/future ref token dims must match, got \"\n                f\"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}.\"\n            )\n        if self.ref_fut_seq_len <= 0:\n            raise ValueError(\n                f\"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}.\"\n            )\n        if len(self.state_feature_indices) != self.state_obs_input_dim:\n            raise ValueError(\n                \"state_obs_input_dim must match len(state_feature_indices): \"\n                f\"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}.\"\n            )\n        if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim:\n            raise ValueError(\n                \"ref_cur_token_dim must match len(ref_cur_feature_indices): \"\n                f\"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}.\"\n            )\n        fut_flat_dim = 0\n        for start, end, dim in self.ref_fut_slices:\n            if end <= start or dim <= 0:\n                raise ValueError(\n                    f\"Invalid ref_fut_slices entry {(start, end, dim)}.\"\n                )\n            if (end - start) != self.ref_fut_seq_len * dim:\n                raise ValueError(\n                    \"Future ref slice span must equal ref_fut_seq_len * dim, got \"\n                    f\"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}.\"\n                )\n            fut_flat_dim += end - start\n        expected_full_input_dim = (\n            self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim\n        )\n        if self.full_obs_input_dim != expected_full_input_dim:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 expected full \"\n                f\"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}.\"\n            )\n\n        self.ref_hist_n_layers = int(module_config.get(\"ref_hist_n_layers\", 1))\n        if 
self.ref_hist_n_layers != 1:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV2 currently supports \"\n                \"exactly one ref history attention layer.\"\n            )\n        self.ref_future_conv_channels = int(\n            module_config.get(\n                \"ref_future_conv_channels\", self.ref_cur_token_dim\n            )\n        )\n        self.ref_future_conv_layers = int(\n            module_config.get(\"ref_future_conv_layers\", 2)\n        )\n        self.ref_future_conv_kernel_size = int(\n            module_config.get(\"ref_future_conv_kernel_size\", 3)\n        )\n        self.ref_future_conv_stride = int(\n            module_config.get(\"ref_future_conv_stride\", 2)\n        )\n        if self.ref_future_conv_layers <= 0:\n            raise ValueError(\n                \"ref_future_conv_layers must be positive, got \"\n                f\"{self.ref_future_conv_layers}.\"\n            )\n        if self.ref_future_conv_kernel_size <= 0:\n            raise ValueError(\n                \"ref_future_conv_kernel_size must be positive, got \"\n                f\"{self.ref_future_conv_kernel_size}.\"\n            )\n        if self.ref_future_conv_stride <= 0:\n            raise ValueError(\n                \"ref_future_conv_stride must be positive, got \"\n                f\"{self.ref_future_conv_stride}.\"\n            )\n\n        module_config[\"input_dim_override\"] = self.state_obs_input_dim\n        super().__init__(\n            input_dim=input_dim,\n            output_dim=output_dim,\n            module_config_dict=module_config,\n        )\n        self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers)\n        self.register_buffer(\n            \"_state_feature_indices\",\n            torch.tensor(self.state_feature_indices, dtype=torch.long),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"_ref_cur_feature_indices\",\n            torch.tensor(self.ref_cur_feature_indices, dtype=torch.long),\n            persistent=False,\n        )\n\n        self.ref_frame_embed = nn.Sequential(\n            nn.Linear(self.ref_cur_token_dim, self.obs_embed_mlp_hidden),\n            nn.SiLU(),\n            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),\n        )\n        self.ref_hist_norm = RMSNorm(self.d_model)\n        self.ref_hist_attn = ModernAttention(\n            d_model=self.d_model,\n            n_heads=self.n_heads,\n            n_kv_heads=self.n_kv_heads,\n            use_qk_norm=self.use_qk_norm,\n            use_gated_attn=self.use_gated_attn,\n            gated_attn_type=self.gated_attn_type,\n            attn_dropout=self.attn_dropout,\n        )\n        self.ref_hist_out_norm = RMSNorm(self.d_model)\n\n        padding = self.ref_future_conv_kernel_size // 2\n        conv_modules: list[nn.Module] = []\n        in_ch = self.d_model\n        for layer_idx in range(self.ref_future_conv_layers):\n            out_ch = (\n                self.d_model\n                if layer_idx == self.ref_future_conv_layers - 1\n                else self.ref_future_conv_channels\n            )\n            conv_modules.append(\n                nn.Conv1d(\n                    in_channels=in_ch,\n                    out_channels=out_ch,\n                    kernel_size=self.ref_future_conv_kernel_size,\n                    stride=self.ref_future_conv_stride,\n                    padding=padding,\n                    bias=True,\n                )\n            )\n            
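# interleave a SiLU nonlinearity after every conv stage\n            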
conv_modules.append(nn.SiLU())\n            in_ch = out_ch\n        self.ref_future_conv = nn.Sequential(*conv_modules)\n\n        self.actor_ref_pool = SingleQueryAttentionPool(self.d_model)\n        self.router_ref_pool = SingleQueryAttentionPool(self.d_model)\n        self.router_query = nn.Parameter(torch.zeros(self.d_model))\n        self.actor_ref_ctx_norm = RMSNorm(self.d_model)\n        self.actor_film_hidden_norm = RMSNorm(self.d_model)\n        self.actor_ref_film = nn.Sequential(\n            nn.Linear(self.d_model, self.d_model),\n            nn.SiLU(),\n            nn.Linear(self.d_model, 2 * self.d_model),\n        )\n        nn.init.zeros_(self.actor_ref_film[-1].weight)\n        nn.init.zeros_(self.actor_ref_film[-1].bias)\n        self.actor_film_gain_max = float(\n            module_config.get(\"actor_film_gain_max\", 1.0)\n        )\n        self.actor_film_gain_init = float(\n            module_config.get(\"actor_film_gain_init\", 0.05)\n        )\n        if self.actor_film_gain_max <= 0.0:\n            raise ValueError(\n                \"actor_film_gain_max must be positive, got \"\n                f\"{self.actor_film_gain_max}.\"\n            )\n        if not (0.0 < self.actor_film_gain_init < self.actor_film_gain_max):\n            raise ValueError(\n                \"actor_film_gain_init must be in (0, actor_film_gain_max), \"\n                f\"got {self.actor_film_gain_init} with max \"\n                f\"{self.actor_film_gain_max}.\"\n            )\n        gain_init_ratio = self.actor_film_gain_init / self.actor_film_gain_max\n        self.actor_film_gain_raw = nn.Parameter(\n            torch.full(\n                (self.d_model,),\n                math.log(gain_init_ratio / (1.0 - gain_init_ratio)),\n            )\n        )\n        self.actor_film_scale_max = 0.5\n        self.actor_film_shift_max = 0.5\n        self.actor_film_delta_rms_eps = float(\n            module_config.get(\"actor_film_delta_rms_eps\", 1.0e-6)\n        )\n        if self.actor_film_delta_rms_eps <= 0.0:\n            raise ValueError(\n                \"actor_film_delta_rms_eps must be positive, got \"\n                f\"{self.actor_film_delta_rms_eps}.\"\n            )\n\n        self._ref_hist_k_cache: torch.Tensor | None = None\n        self._ref_hist_v_cache: torch.Tensor | None = None\n        self._apply_freeze_router_state()\n\n    def _apply_freeze_router_state(self) -> None:\n        super()._apply_freeze_router_state()\n        requires_grad = not self.freeze_router\n        self.ref_frame_embed.requires_grad_(requires_grad)\n        self.ref_hist_norm.requires_grad_(requires_grad)\n        self.ref_hist_attn.requires_grad_(requires_grad)\n        self.ref_hist_out_norm.requires_grad_(requires_grad)\n        self.ref_future_conv.requires_grad_(requires_grad)\n        self.router_ref_pool.requires_grad_(requires_grad)\n        self.router_query.requires_grad_(requires_grad)\n\n    def _build_shared_ref_tokens(\n        self,\n        ref_cur_x: torch.Tensor,\n        ref_fut_x: torch.Tensor,\n        pos: torch.Tensor,\n        tgt_mask: torch.Tensor | None,\n    ) -> torch.Tensor:\n        with self._router_no_grad_context():\n            ref_cur_h = self.ref_frame_embed(ref_cur_x)\n            ref_hist_attn = self.ref_hist_attn(\n                self.ref_hist_norm(ref_cur_h),\n                *self.get_cos_sin(ref_cur_h, pos),\n                mask=tgt_mask,\n            )\n            ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)\n            
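# embed the future ref frames and conv-downsample them into tokens\n            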
ref_fut_tokens = self._encode_future_tokens(ref_fut_x)\n            return torch.cat([ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2)\n\n    def _build_shared_ref_tokens_single_step(\n        self,\n        ref_cur_x: torch.Tensor,\n        ref_fut_x: torch.Tensor,\n        pos_ids: torch.Tensor,\n        *,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        new_len: torch.Tensor,\n        insert_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        with self._router_no_grad_context():\n            ref_cur_h = self.ref_frame_embed(ref_cur_x)[:, None, :]\n            ref_cos, ref_sin = self.get_cos_sin(ref_cur_h, pos_ids)\n            ref_hist_attn, ref_k_cache, ref_v_cache = (\n                self.ref_hist_attn.forward_single_token(\n                    self.ref_hist_norm(ref_cur_h),\n                    ref_cos,\n                    ref_sin,\n                    k_cache,\n                    v_cache,\n                    new_len,\n                    insert_pos,\n                )\n            )\n            ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)\n            ref_fut_tokens = self._encode_future_tokens(ref_fut_x)\n            shared_ref_tokens = torch.cat([ref_hist_h, ref_fut_tokens], dim=1)\n        return shared_ref_tokens, ref_k_cache, ref_v_cache\n\n    def _split_actor_ref_inputs(\n        self, x: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        if x.ndim not in (2, 3):\n            raise ValueError(\n                f\"Expected full obs tensor with ndim 2 or 3, got {x.ndim}.\"\n            )\n        if int(x.shape[-1]) != self.full_obs_input_dim:\n            raise ValueError(\n                \"Full obs dim mismatch for reference router V2: expected \"\n                f\"{self.full_obs_input_dim}, got {int(x.shape[-1])}.\"\n            )\n        state_idx = self._state_feature_indices.to(x.device)\n        ref_cur_idx = self._ref_cur_feature_indices.to(x.device)\n        state_x = torch.index_select(x, dim=x.ndim - 1, index=state_idx)\n        ref_cur_x = torch.index_select(x, dim=x.ndim - 1, index=ref_cur_idx)\n        fut_parts: list[torch.Tensor] = []\n        for start, end, dim in self.ref_fut_slices:\n            chunk = x[..., start:end]\n            if x.ndim == 2:\n                fut_parts.append(\n                    chunk.reshape(int(x.shape[0]), self.ref_fut_seq_len, dim)\n                )\n            else:\n                fut_parts.append(\n                    chunk.reshape(\n                        int(x.shape[0]),\n                        int(x.shape[1]),\n                        self.ref_fut_seq_len,\n                        dim,\n                    )\n                )\n        ref_fut_x = torch.cat(fut_parts, dim=-1)\n        return state_x, ref_cur_x, ref_fut_x\n\n    def _encode_future_tokens(self, ref_fut_x: torch.Tensor) -> torch.Tensor:\n        if ref_fut_x.ndim == 3:\n            fut = self.ref_frame_embed(ref_fut_x)\n            return self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)\n        if ref_fut_x.ndim == 4:\n            batch, time, seq_len, dim = ref_fut_x.shape\n            fut = self.ref_frame_embed(\n                ref_fut_x.reshape(batch * time, seq_len, dim)\n            )\n            fut = self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)\n            return fut.reshape(batch, time, fut.shape[1], self.d_model)\n        raise ValueError(\n            f\"Expected ref_fut_x with ndim 3 or 4, got 
{ref_fut_x.ndim}.\"\n        )\n\n    def _pool_router_context(\n        self, shared_ref_tokens: torch.Tensor\n    ) -> torch.Tensor:\n        with self._router_no_grad_context():\n            if shared_ref_tokens.ndim == 3:\n                query = self.router_query.to(\n                    device=shared_ref_tokens.device,\n                    dtype=shared_ref_tokens.dtype,\n                )[None, :].expand(int(shared_ref_tokens.shape[0]), -1)\n            elif shared_ref_tokens.ndim == 4:\n                query = self.router_query.to(\n                    device=shared_ref_tokens.device,\n                    dtype=shared_ref_tokens.dtype,\n                )[None, None, :].expand(\n                    int(shared_ref_tokens.shape[0]),\n                    int(shared_ref_tokens.shape[1]),\n                    -1,\n                )\n            else:\n                raise ValueError(\n                    \"shared_ref_tokens must have ndim 3 or 4, got \"\n                    f\"{shared_ref_tokens.ndim}.\"\n                )\n            return self.router_ref_pool(query, shared_ref_tokens)\n\n    def _apply_actor_ref_film(\n        self, state_hidden: torch.Tensor, actor_ref_ctx: torch.Tensor\n    ) -> torch.Tensor:\n        ctx = self.actor_ref_ctx_norm(actor_ref_ctx)\n        scale_raw, shift_raw = self.actor_ref_film(ctx).chunk(2, dim=-1)\n        scale = self.actor_film_scale_max * torch.tanh(scale_raw)\n        shift = self.actor_film_shift_max * torch.tanh(shift_raw)\n        hidden_norm = self.actor_film_hidden_norm(state_hidden)\n        delta = scale * hidden_norm + shift\n        delta = self._normalize_actor_film_delta(delta)\n        gain = self._actor_film_gain().to(\n            device=state_hidden.device, dtype=state_hidden.dtype\n        )\n        expand_shape = [1] * (delta.ndim - 1) + [self.d_model]\n        return state_hidden + delta * gain.view(*expand_shape)\n\n    def _actor_film_gain(self) -> torch.Tensor:\n        return self.actor_film_gain_max * torch.sigmoid(\n            self.actor_film_gain_raw\n        )\n\n    def _normalize_actor_film_delta(self, delta: torch.Tensor) -> torch.Tensor:\n        rms = delta.pow(2).mean(dim=-1, keepdim=True)\n        return delta * torch.rsqrt(rms + self.actor_film_delta_rms_eps)\n\n    def _ensure_internal_cache_device(\n        self,\n        device,\n        *,\n        dtype: torch.dtype | None = None,\n    ) -> None:\n        if self._k_cache is not None and self._k_cache.device != device:\n            self._k_cache = self._k_cache.to(device)\n            self._v_cache = self._v_cache.to(device)\n            self._ref_hist_k_cache = self._ref_hist_k_cache.to(device)\n            self._ref_hist_v_cache = self._ref_hist_v_cache.to(device)\n            self._kv_cache_len = self._kv_cache_len.to(device)\n            self._kv_cache_write_idx = self._kv_cache_write_idx.to(device)\n            self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(device)\n        if (\n            dtype is not None\n            and self._k_cache is not None\n            and self._k_cache.dtype != dtype\n        ):\n            self._k_cache = self._k_cache.to(dtype)\n            self._v_cache = self._v_cache.to(dtype)\n            self._ref_hist_k_cache = self._ref_hist_k_cache.to(dtype)\n            self._ref_hist_v_cache = self._ref_hist_v_cache.to(dtype)\n\n    def reset_kv_cache(self, num_envs: int, device):\n        cache_dtype = (\n            torch.float16\n            if torch.device(device).type == \"cuda\"\n            else 
torch.float32\n        )\n        self._k_cache = torch.zeros(\n            num_envs,\n            self.n_layers,\n            self.max_ctx_len,\n            self.n_kv_heads,\n            self.head_dim,\n            device=device,\n            dtype=cache_dtype,\n        )\n        self._v_cache = torch.zeros_like(self._k_cache)\n        self._ref_hist_k_cache = torch.zeros(\n            num_envs,\n            self.ref_hist_n_layers,\n            self.max_ctx_len,\n            self.n_kv_heads,\n            self.head_dim,\n            device=device,\n            dtype=cache_dtype,\n        )\n        self._ref_hist_v_cache = torch.zeros_like(self._ref_hist_k_cache)\n        self._kv_cache_len = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._kv_cache_write_idx = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._kv_cache_abs_pos = torch.zeros(\n            num_envs, dtype=torch.long, device=device\n        )\n        self._init_last_moe_router_shift_state(num_envs, device)\n\n    def clear_env_cache(self, env_ids: torch.Tensor | None):\n        if self._k_cache is None:\n            return\n        if env_ids is None:\n            self._k_cache.zero_()\n            self._v_cache.zero_()\n            self._ref_hist_k_cache.zero_()\n            self._ref_hist_v_cache.zero_()\n            self._kv_cache_len.zero_()\n            self._kv_cache_write_idx.zero_()\n            self._kv_cache_abs_pos.zero_()\n            if self._prev_last_moe_router_p is not None:\n                self._prev_last_moe_router_p.zero_()\n            if self._prev_last_moe_router_valid is not None:\n                self._prev_last_moe_router_valid.zero_()\n            if self._last_moe_router_js_sum is not None:\n                self._last_moe_router_js_sum.zero_()\n            if self._last_moe_router_js_count is not None:\n                self._last_moe_router_js_count.zero_()\n            if self._last_moe_router_top1_switch_sum is not None:\n                self._last_moe_router_top1_switch_sum.zero_()\n            if self._last_moe_router_top1_switch_count is not None:\n                self._last_moe_router_top1_switch_count.zero_()\n            return\n        self._k_cache[env_ids] = 0.0\n        self._v_cache[env_ids] = 0.0\n        self._ref_hist_k_cache[env_ids] = 0.0\n        self._ref_hist_v_cache[env_ids] = 0.0\n        self._kv_cache_len[env_ids] = 0\n        self._kv_cache_write_idx[env_ids] = 0\n        self._kv_cache_abs_pos[env_ids] = 0\n        if self._prev_last_moe_router_valid is not None:\n            self._prev_last_moe_router_valid[env_ids] = False\n        if self._prev_last_moe_router_p is not None:\n            self._prev_last_moe_router_p[env_ids] = 0.0\n\n    def predict_aux_from_pre_moe(\n        self,\n        pre_moe_hidden: torch.Tensor,\n        *,\n        ref_aux_hidden: torch.Tensor | None = None,\n    ) -> dict[str, torch.Tensor]:\n        aux_outputs = super().predict_aux_from_pre_moe(\n            pre_moe_hidden, ref_aux_hidden=ref_aux_hidden\n        )\n        if self.aux_ref_keybody_pos_head is not None:\n            if ref_aux_hidden is None:\n                raise ValueError(\n                    \"Missing shared-ref auxiliary hidden state for \"\n                    \"ref_keybody_rel_pos prediction.\"\n                )\n            aux_outputs[\"ref_keybody_rel_pos\"] = self.aux_ref_keybody_pos_head(\n                ref_aux_hidden\n            ).reshape(\n                
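# flat keybody output -> [B, T, aux_keybody_pos_dim, 3]\n                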
ref_aux_hidden.shape[0],\n                ref_aux_hidden.shape[1],\n                self.aux_keybody_pos_dim,\n                3,\n            )\n        return aux_outputs\n\n    def sequence_mu(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n        return_hidden: bool = False,\n        return_pre_moe_hidden: bool = False,\n        return_ref_aux_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        _, time, _ = x.shape\n        state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        state_h = self.obs_embed(state_x)\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)\n            t_idx = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(int(x.shape[0]), time)\n            pos = t_idx - start_idx\n        else:\n            tgt_mask = None\n            pos = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(int(x.shape[0]), time)\n\n        shared_ref_tokens = self._build_shared_ref_tokens(\n            ref_cur_x=ref_cur_x,\n            ref_fut_x=ref_fut_x,\n            pos=pos,\n            tgt_mask=tgt_mask,\n        )\n        actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)\n        router_h = self._pool_router_context(shared_ref_tokens)\n        cos, sin = self.get_cos_sin(state_h, pos)\n        if return_hidden and return_pre_moe_hidden:\n            raise ValueError(\n                \"return_hidden and return_pre_moe_hidden cannot both be True.\"\n            )\n        block0_h = self._forward_layers_range(\n            state_h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n            start_layer=0,\n            end_layer=1,\n        )\n        h = self._apply_actor_ref_film(block0_h, actor_ref_ctx)\n        forward_out = self._forward_layers_range(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n            start_layer=1,\n            end_layer=len(self.layers),\n            return_router_features=return_router_features,\n            return_router_temporal_features=return_router_temporal_features,\n        )\n        extras: list[torch.Tensor] = []\n        if isinstance(forward_out, tuple):\n            h = forward_out[0]\n            extras = list(forward_out[1:])\n        else:\n            h = forward_out\n        h = self.norm_f(h)\n        mu = self.action_mu_head(h)\n        outputs: list[torch.Tensor] = [mu]\n        if return_pre_moe_hidden:\n            outputs.append(block0_h)\n        if return_ref_aux_hidden:\n            outputs.append(router_h)\n        if return_router_features:\n            outputs.append(extras.pop(0))\n        if return_router_temporal_features:\n            outputs.append(extras.pop(0))\n        if len(outputs) > 1:\n            return tuple(outputs)\n        if return_hidden:\n            return mu, h\n        return mu\n\n    def sequence_hidden(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        _, time, _ = x.shape\n        state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        state_h = 
self.obs_embed(state_x)\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)\n            t_idx = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(int(x.shape[0]), time)\n            pos = t_idx - start_idx\n        else:\n            tgt_mask = None\n            pos = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(int(x.shape[0]), time)\n\n        shared_ref_tokens = self._build_shared_ref_tokens(\n            ref_cur_x=ref_cur_x,\n            ref_fut_x=ref_fut_x,\n            pos=pos,\n            tgt_mask=tgt_mask,\n        )\n        actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)\n        router_h = self._pool_router_context(shared_ref_tokens)\n        cos, sin = self.get_cos_sin(state_h, pos)\n        h = self._forward_layers_range(\n            state_h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n            start_layer=0,\n            end_layer=1,\n        )\n        h = self._apply_actor_ref_film(h, actor_ref_ctx)\n        h = self._forward_layers_range(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h=router_h,\n            start_layer=1,\n            end_layer=len(self.layers),\n        )\n        h = self.norm_f(h)\n        return h\n\n    def forward(\n        self,\n        input: torch.Tensor,\n        past_key_values: torch.Tensor | None = None,\n        current_pos: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        if past_key_values is not None:\n            return self._forward_inference_onnx(\n                input, past_key_values, current_pos\n            )\n        if input.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {input.shape}\")\n        mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)\n        return mu_seq[:, 0, :]\n\n    def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:\n        if x.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {x.shape}\")\n        state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        batch = int(state_x.shape[0])\n        if self._k_cache is None:\n            mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)\n            return mu_seq[:, 0, :]\n\n        state_h = self.obs_embed(state_x)\n        self._ensure_internal_cache_device(x.device, dtype=state_h.dtype)\n\n        cache_len = self._kv_cache_len\n        insert_pos = self._kv_cache_write_idx\n        max_len = int(self.max_ctx_len)\n        new_len = torch.clamp(cache_len + 1, max=max_len)\n\n        self._kv_cache_len = new_len\n        self._kv_cache_write_idx = (insert_pos + 1) % max_len\n\n        pos = self._kv_cache_abs_pos\n        self._kv_cache_abs_pos = pos + 1\n        pos_ids = pos.unsqueeze(1)\n        shared_ref_tokens, _, _ = self._build_shared_ref_tokens_single_step(\n            ref_cur_x=ref_cur_x,\n            ref_fut_x=ref_fut_x,\n            pos_ids=pos_ids,\n            k_cache=self._ref_hist_k_cache[:, 0],\n            v_cache=self._ref_hist_v_cache[:, 0],\n            new_len=new_len,\n            insert_pos=insert_pos,\n        )\n        actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[\n            :, None, :\n        ]\n        router_h = 
self._pool_router_context(shared_ref_tokens)[:, None, :]\n        cos, sin = self.get_cos_sin(state_h[:, None, :], pos_ids)\n\n        h = state_h[:, None, :]\n        for layer_idx, layer in enumerate(self.layers[:1]):\n            x_norm = layer.norm1(h)\n            k_cache_l = self._k_cache[:, layer_idx]\n            v_cache_l = self._v_cache[:, layer_idx]\n            attn_out, _, _ = layer.attn.forward_single_token(\n                x_norm,\n                cos,\n                sin,\n                k_cache_l,\n                v_cache_l,\n                new_len,\n                insert_pos,\n            )\n            h = h + attn_out\n            h2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                ffn = layer.compute_moe_ffn(h2, router_x=router_h)\n                if (\n                    layer_idx == self._last_moe_layer_idx\n                    and layer.collect_routing_stats\n                    and layer.last_router_distribution is not None\n                ):\n                    self._accumulate_last_moe_router_shift(\n                        layer.last_router_distribution\n                    )\n            else:\n                ffn = layer.mlp_dropout(layer.mlp(h2))\n            h = h + ffn\n\n        h = self._apply_actor_ref_film(h, actor_ref_ctx)\n\n        for layer_idx, layer in enumerate(self.layers[1:], start=1):\n            x_norm = layer.norm1(h)\n            k_cache_l = self._k_cache[:, layer_idx]\n            v_cache_l = self._v_cache[:, layer_idx]\n            attn_out, _, _ = layer.attn.forward_single_token(\n                x_norm,\n                cos,\n                sin,\n                k_cache_l,\n                v_cache_l,\n                new_len,\n                insert_pos,\n            )\n            h = h + attn_out\n            h2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                ffn = layer.compute_moe_ffn(h2, router_x=router_h)\n                if (\n                    layer_idx == self._last_moe_layer_idx\n                    and layer.collect_routing_stats\n                    and layer.last_router_distribution is not None\n                ):\n                    self._accumulate_last_moe_router_shift(\n                        layer.last_router_distribution\n                    )\n            else:\n                ffn = layer.mlp_dropout(layer.mlp(h2))\n            h = h + ffn\n\n        h = self.norm_f(h)\n        return self.action_mu_head(h[:, 0, :]).reshape(batch, -1)\n\n    def _forward_inference_onnx(\n        self,\n        x: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        state_h = self.obs_embed(state_x)\n        batch = state_h.shape[0]\n        max_len = past_key_values.shape[3]\n\n        if current_pos.ndim == 0:\n            current_pos = current_pos.view(1).expand(batch)\n\n        insert_pos = current_pos % max_len\n        new_len = torch.clamp(current_pos + 1, max=max_len)\n        position_ids = current_pos.unsqueeze(1)\n        ref_layer_past = past_key_values[0]\n        shared_ref_tokens, ref_k_cache, ref_v_cache = (\n            self._build_shared_ref_tokens_single_step(\n                ref_cur_x=ref_cur_x,\n                ref_fut_x=ref_fut_x,\n                pos_ids=position_ids,\n                k_cache=ref_layer_past[0],\n                v_cache=ref_layer_past[1],\n               
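 # slot 0 of past_key_values is reserved for ref-history attention\n               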
 new_len=new_len,\n                insert_pos=insert_pos,\n            )\n        )\n        actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[\n            :, None, :\n        ]\n        router_h = self._pool_router_context(shared_ref_tokens)[:, None, :]\n        cos, sin = self.get_cos_sin(state_h[:, None, :], position_ids)\n\n        present_key_values_list = [\n            torch.stack([ref_k_cache, ref_v_cache], dim=0)\n        ]\n        routing_debug_outputs: list[torch.Tensor] = []\n        export_routing_debug = torch.onnx.is_in_onnx_export()\n\n        h = state_h[:, None, :]\n        for i, layer in enumerate(self.layers[:1]):\n            layer_past = past_key_values[self.ref_hist_n_layers + i]\n            k_cache = layer_past[0]\n            v_cache = layer_past[1]\n\n            h_norm = layer.norm1(h)\n            attn_out, new_k_cache, new_v_cache = (\n                layer.attn.forward_single_token(\n                    x=h_norm,\n                    cos=cos,\n                    sin=sin,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    new_len=new_len,\n                    insert_pos=insert_pos,\n                )\n            )\n            h = h + attn_out\n\n            h_norm2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                if export_routing_debug:\n                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(\n                        h_norm2,\n                        router_x=router_h,\n                        return_routing_debug=True,\n                    )\n                    routing_debug_outputs.extend([topk_idx, router_logits])\n                else:\n                    ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)\n            else:\n                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))\n            h = h + ffn_out\n            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)\n            present_key_values_list.append(current_layer_kv)\n\n        h = self._apply_actor_ref_film(h, actor_ref_ctx)\n\n        for i, layer in enumerate(self.layers[1:], start=1):\n            layer_past = past_key_values[self.ref_hist_n_layers + i]\n            k_cache = layer_past[0]\n            v_cache = layer_past[1]\n\n            h_norm = layer.norm1(h)\n            attn_out, new_k_cache, new_v_cache = (\n                layer.attn.forward_single_token(\n                    x=h_norm,\n                    cos=cos,\n                    sin=sin,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    new_len=new_len,\n                    insert_pos=insert_pos,\n                )\n            )\n            h = h + attn_out\n\n            h_norm2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                if export_routing_debug:\n                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(\n                        h_norm2,\n                        router_x=router_h,\n                        return_routing_debug=True,\n                    )\n                    routing_debug_outputs.extend([topk_idx, router_logits])\n                else:\n                    ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)\n            else:\n                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))\n            h = h + ffn_out\n            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)\n            
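# Each layer's K/V pair is stacked into one [2, ...] tensor; the list\n            # is stacked once more at the end to form the ONNX present_key_values.\n            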
present_key_values_list.append(current_layer_kv)\n\n        h = self.norm_f(h)\n        action = self.action_mu_head(h[:, 0, :])\n        present_key_values = torch.stack(present_key_values_list, dim=0)\n\n        if export_routing_debug and routing_debug_outputs:\n            return (action, present_key_values, *routing_debug_outputs)\n        return action, present_key_values\n\n\nclass ReferenceRoutedGroupedMoETransformerPolicyV3(\n    ReferenceRoutedGroupedMoETransformerPolicyV2\n):\n    supports_explicit_ref_aux_hidden = True\n\n    def __init__(\n        self,\n        input_dim: int,\n        output_dim: int,\n        module_config_dict: dict,\n    ):\n        module_config = dict(module_config_dict)\n        if bool(module_config.get(\"use_future_cross_attn\", False)):\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 does not \"\n                \"support use_future_cross_attn=True.\"\n            )\n        state_obs_input_dim = module_config.get(\"state_obs_input_dim\", None)\n        ref_cur_token_dim = module_config.get(\"ref_cur_token_dim\", None)\n        ref_fut_token_dim = module_config.get(\"ref_fut_token_dim\", None)\n        ref_fut_seq_len = module_config.get(\"ref_fut_seq_len\", None)\n        state_feature_indices = module_config.get(\n            \"state_feature_indices\", None\n        )\n        ref_cur_feature_indices = module_config.get(\n            \"ref_cur_feature_indices\", None\n        )\n        ref_fut_slices = module_config.get(\"ref_fut_slices\", None)\n        if state_obs_input_dim is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 requires \"\n                \"state_obs_input_dim.\"\n            )\n        if ref_cur_token_dim is None or ref_fut_token_dim is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 requires \"\n                \"ref_cur_token_dim and ref_fut_token_dim.\"\n            )\n        if ref_fut_seq_len is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 requires \"\n                \"ref_fut_seq_len.\"\n            )\n        if state_feature_indices is None or ref_cur_feature_indices is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 requires \"\n                \"state_feature_indices and ref_cur_feature_indices.\"\n            )\n        if ref_fut_slices is None:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 requires \"\n                \"ref_fut_slices.\"\n            )\n\n        self.full_obs_input_dim = int(input_dim)\n        self.state_obs_input_dim = int(state_obs_input_dim)\n        self.ref_cur_token_dim = int(ref_cur_token_dim)\n        self.ref_fut_token_dim = int(ref_fut_token_dim)\n        self.ref_fut_seq_len = int(ref_fut_seq_len)\n        self.state_feature_indices = tuple(\n            int(idx) for idx in state_feature_indices\n        )\n        self.ref_cur_feature_indices = tuple(\n            int(idx) for idx in ref_cur_feature_indices\n        )\n        self.ref_fut_slices = tuple(\n            (int(start), int(end), int(dim))\n            for start, end, dim in ref_fut_slices\n        )\n        if self.state_obs_input_dim <= 0:\n            raise ValueError(\n                \"state_obs_input_dim must be positive, got \"\n                f\"{self.state_obs_input_dim}.\"\n            )\n        if 
self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0:\n            raise ValueError(\n                \"ref token dims must be positive, got \"\n                f\"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}.\"\n            )\n        if self.ref_cur_token_dim != self.ref_fut_token_dim:\n            raise ValueError(\n                \"current/future ref token dims must match, got \"\n                f\"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}.\"\n            )\n        if self.ref_fut_seq_len <= 0:\n            raise ValueError(\n                f\"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}.\"\n            )\n        if len(self.state_feature_indices) != self.state_obs_input_dim:\n            raise ValueError(\n                \"state_obs_input_dim must match len(state_feature_indices): \"\n                f\"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}.\"\n            )\n        if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim:\n            raise ValueError(\n                \"ref_cur_token_dim must match len(ref_cur_feature_indices): \"\n                f\"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}.\"\n            )\n        fut_flat_dim = 0\n        for start, end, dim in self.ref_fut_slices:\n            if end <= start or dim <= 0:\n                raise ValueError(\n                    f\"Invalid ref_fut_slices entry {(start, end, dim)}.\"\n                )\n            if (end - start) != self.ref_fut_seq_len * dim:\n                raise ValueError(\n                    \"Future ref slice span must equal ref_fut_seq_len * dim, got \"\n                    f\"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}.\"\n                )\n            fut_flat_dim += end - start\n        expected_full_input_dim = (\n            self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim\n        )\n        if self.full_obs_input_dim != expected_full_input_dim:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 expected full \"\n                f\"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}.\"\n            )\n\n        self.ref_hist_n_layers = int(module_config.get(\"ref_hist_n_layers\", 1))\n        if self.ref_hist_n_layers != 1:\n            raise ValueError(\n                \"ReferenceRoutedGroupedMoETransformerPolicyV3 currently supports \"\n                \"exactly one ref history attention layer.\"\n            )\n\n        layer_proj_hidden_default = int(\n            module_config.get(\n                \"router_layer_proj_hidden_dim\",\n                module_config.get(\"d_model\", 256),\n            )\n        )\n        self.ref_motion_input_dim = int(\n            self.ref_cur_token_dim\n            + self.ref_fut_seq_len * self.ref_fut_token_dim\n        )\n        self.router_layer_proj_hidden_dim = int(layer_proj_hidden_default)\n        if self.router_layer_proj_hidden_dim <= 0:\n            raise ValueError(\n                \"router_layer_proj_hidden_dim must be positive, got \"\n                f\"{self.router_layer_proj_hidden_dim}.\"\n            )\n\n        GroupedMoETransformerPolicy.__init__(\n            self,\n            input_dim=input_dim,\n            output_dim=output_dim,\n            module_config_dict=module_config,\n        )\n        self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers)\n        self.register_buffer(\n            
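# Feature-column indices (long dtype, non-persistent) for slicing the\n            # flat observation into state / current-ref groups.\n            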
\"_state_feature_indices\",\n            torch.tensor(self.state_feature_indices, dtype=torch.long),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"_ref_cur_feature_indices\",\n            torch.tensor(self.ref_cur_feature_indices, dtype=torch.long),\n            persistent=False,\n        )\n\n        self.ref_frame_embed = nn.Sequential(\n            nn.Linear(self.ref_motion_input_dim, self.obs_embed_mlp_hidden),\n            nn.SiLU(),\n            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),\n        )\n        self.ref_hist_norm = RMSNorm(self.d_model)\n        self.ref_hist_attn = ModernAttention(\n            d_model=self.d_model,\n            n_heads=self.n_heads,\n            n_kv_heads=self.n_kv_heads,\n            use_qk_norm=self.use_qk_norm,\n            use_gated_attn=self.use_gated_attn,\n            gated_attn_type=self.gated_attn_type,\n            attn_dropout=self.attn_dropout,\n        )\n        self.ref_hist_out_norm = RMSNorm(self.d_model)\n\n        self._moe_layer_indices = tuple(\n            i\n            for i, layer in enumerate(self.layers)\n            if isinstance(layer, GroupedMoEBlock)\n        )\n        self.router_layer_projections = nn.ModuleList(\n            [\n                nn.Sequential(\n                    RMSNorm(self.d_model),\n                    nn.Linear(self.d_model, self.router_layer_proj_hidden_dim),\n                    nn.SiLU(),\n                    nn.Linear(self.router_layer_proj_hidden_dim, self.d_model),\n                )\n                for _ in self._moe_layer_indices\n            ]\n        )\n\n        self._ref_hist_k_cache: torch.Tensor | None = None\n        self._ref_hist_v_cache: torch.Tensor | None = None\n        self._apply_freeze_router_state()\n\n    def _apply_freeze_router_state(self) -> None:\n        GroupedMoETransformerPolicy._apply_freeze_router_state(self)\n        requires_grad = not self.freeze_router\n        self.ref_frame_embed.requires_grad_(requires_grad)\n        self.ref_hist_norm.requires_grad_(requires_grad)\n        self.ref_hist_attn.requires_grad_(requires_grad)\n        self.ref_hist_out_norm.requires_grad_(requires_grad)\n        self.router_layer_projections.requires_grad_(requires_grad)\n\n    def _build_router_ref_motion(\n        self,\n        ref_cur_x: torch.Tensor,\n        ref_fut_x: torch.Tensor,\n    ) -> torch.Tensor:\n        if ref_cur_x.ndim not in (2, 3):\n            raise ValueError(\n                f\"Expected ref_cur_x with ndim 2 or 3, got {ref_cur_x.ndim}.\"\n            )\n        if ref_fut_x.ndim != ref_cur_x.ndim + 1:\n            raise ValueError(\n                \"Expected ref_fut_x to add one future-seq axis relative to \"\n                f\"ref_cur_x, got cur={tuple(ref_cur_x.shape)}, \"\n                f\"fut={tuple(ref_fut_x.shape)}.\"\n            )\n        ref_fut_flat = torch.flatten(ref_fut_x, start_dim=-2)\n        return torch.cat([ref_cur_x, ref_fut_flat], dim=-1)\n\n    def _build_shared_router_summary(\n        self,\n        ref_hist_h: torch.Tensor,\n    ) -> torch.Tensor:\n        with self._router_no_grad_context():\n            return ref_hist_h\n\n    def _build_router_h_per_layer(\n        self,\n        shared_router_summary: torch.Tensor,\n    ) -> list[torch.Tensor | None]:\n        with self._router_no_grad_context():\n            router_h_per_layer: list[torch.Tensor | None] = [\n                None for _ in self.layers\n            ]\n            for proj, layer_idx in zip(\n       
         self.router_layer_projections, self._moe_layer_indices\n            ):\n                router_h_per_layer[layer_idx] = proj(shared_router_summary)\n            return router_h_per_layer\n\n    def _build_ref_hist_hidden(\n        self,\n        ref_motion_x: torch.Tensor,\n        pos: torch.Tensor,\n        tgt_mask: torch.Tensor | None,\n    ) -> torch.Tensor:\n        with self._router_no_grad_context():\n            ref_motion_h = self.ref_frame_embed(ref_motion_x)\n            ref_hist_attn = self.ref_hist_attn(\n                self.ref_hist_norm(ref_motion_h),\n                *self.get_cos_sin(ref_motion_h, pos),\n                mask=tgt_mask,\n            )\n            return self.ref_hist_out_norm(ref_motion_h + ref_hist_attn)\n\n    def _build_ref_hist_hidden_single_step(\n        self,\n        ref_motion_x: torch.Tensor,\n        pos_ids: torch.Tensor,\n        *,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        new_len: torch.Tensor,\n        insert_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        with self._router_no_grad_context():\n            ref_motion_h = self.ref_frame_embed(ref_motion_x)[:, None, :]\n            ref_cos, ref_sin = self.get_cos_sin(ref_motion_h, pos_ids)\n            ref_hist_attn, ref_k_cache, ref_v_cache = (\n                self.ref_hist_attn.forward_single_token(\n                    self.ref_hist_norm(ref_motion_h),\n                    ref_cos,\n                    ref_sin,\n                    k_cache,\n                    v_cache,\n                    new_len,\n                    insert_pos,\n                )\n            )\n            ref_hist_h = self.ref_hist_out_norm(ref_motion_h + ref_hist_attn)\n        return ref_hist_h, ref_k_cache, ref_v_cache\n\n    def predict_aux_from_pre_moe(\n        self,\n        pre_moe_hidden: torch.Tensor,\n        *,\n        ref_aux_hidden: torch.Tensor | None = None,\n    ) -> dict[str, torch.Tensor]:\n        return GroupedMoETransformerPolicy.predict_aux_from_pre_moe(\n            self,\n            pre_moe_hidden,\n            ref_aux_hidden=ref_aux_hidden,\n        )\n\n    def sequence_mu(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n        return_hidden: bool = False,\n        return_pre_moe_hidden: bool = False,\n        return_ref_aux_hidden: bool = False,\n        return_router_features: bool = False,\n        return_router_temporal_features: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...]:\n        batch, time, _ = x.shape\n        h = self.obs_embed(x)\n        _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)\n            t_idx = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(batch, time)\n            pos = t_idx - start_idx\n        else:\n            tgt_mask = None\n            pos = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(batch, time)\n\n        ref_hist_h = self._build_ref_hist_hidden(\n            ref_motion_x=ref_motion_x,\n            pos=pos,\n            tgt_mask=tgt_mask,\n        )\n        shared_router_summary = self._build_shared_router_summary(ref_hist_h)\n        router_h_per_layer = 
self._build_router_h_per_layer(\n            shared_router_summary\n        )\n        cos, sin = self.get_cos_sin(h, pos)\n        if return_hidden and return_pre_moe_hidden:\n            raise ValueError(\n                \"return_hidden and return_pre_moe_hidden cannot both be True.\"\n            )\n        forward_out = self._forward_layers(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h_per_layer=router_h_per_layer,\n            return_pre_moe_hidden=return_pre_moe_hidden,\n            return_router_features=return_router_features,\n            return_router_temporal_features=return_router_temporal_features,\n        )\n        extras: list[torch.Tensor] = []\n        if isinstance(forward_out, tuple):\n            h = forward_out[0]\n            extras = list(forward_out[1:])\n        else:\n            h = forward_out\n        h = self.norm_f(h)\n        mu = self.action_mu_head(h)\n        outputs: list[torch.Tensor] = [mu]\n        if return_pre_moe_hidden:\n            outputs.append(extras.pop(0))\n        if return_ref_aux_hidden:\n            outputs.append(shared_router_summary)\n        if return_router_features:\n            outputs.append(extras.pop(0))\n        if return_router_temporal_features:\n            outputs.append(extras.pop(0))\n        if len(outputs) > 1:\n            return tuple(outputs)\n        if return_hidden:\n            return mu, h\n        return mu\n\n    def sequence_hidden(\n        self,\n        x: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        batch, time, _ = x.shape\n        h = self.obs_embed(x)\n        _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)\n\n        if attn_mask is not None:\n            tgt_mask = attn_mask.unsqueeze(1)\n            start_idx = attn_mask.to(torch.int64).argmax(dim=-1)\n            t_idx = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(batch, time)\n            pos = t_idx - start_idx\n        else:\n            tgt_mask = None\n            pos = torch.arange(time, device=x.device, dtype=torch.long)[\n                None, :\n            ].expand(batch, time)\n\n        ref_hist_h = self._build_ref_hist_hidden(\n            ref_motion_x=ref_motion_x,\n            pos=pos,\n            tgt_mask=tgt_mask,\n        )\n        shared_router_summary = self._build_shared_router_summary(ref_hist_h)\n        router_h_per_layer = self._build_router_h_per_layer(\n            shared_router_summary\n        )\n        cos, sin = self.get_cos_sin(h, pos)\n        h = self._forward_layers(\n            h,\n            cos=cos,\n            sin=sin,\n            mask=tgt_mask,\n            router_h_per_layer=router_h_per_layer,\n        )\n        h = self.norm_f(h)\n        return h\n\n    def forward(\n        self,\n        input: torch.Tensor,\n        past_key_values: torch.Tensor | None = None,\n        current_pos: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        if past_key_values is not None:\n            return self._forward_inference_onnx(\n                input, past_key_values, current_pos\n            )\n        if input.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {input.shape}\")\n        mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)\n        return mu_seq[:, 0, 
:]\n\n    def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:\n        if x.ndim != 2:\n            raise ValueError(f\"Expected [B, D], got {x.shape}\")\n        _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)\n        batch = int(x.shape[0])\n        if self._k_cache is None:\n            mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)\n            return mu_seq[:, 0, :]\n\n        h = self.obs_embed(x)[:, None, :]\n        self._ensure_internal_cache_device(x.device, dtype=h.dtype)\n\n        cache_len = self._kv_cache_len\n        insert_pos = self._kv_cache_write_idx\n        max_len = int(self.max_ctx_len)\n        new_len = torch.clamp(cache_len + 1, max=max_len)\n\n        self._kv_cache_len = new_len\n        self._kv_cache_write_idx = (insert_pos + 1) % max_len\n\n        pos = self._kv_cache_abs_pos\n        self._kv_cache_abs_pos = pos + 1\n        pos_ids = pos.unsqueeze(1)\n\n        ref_hist_h, _, _ = self._build_ref_hist_hidden_single_step(\n            ref_motion_x=ref_motion_x,\n            pos_ids=pos_ids,\n            k_cache=self._ref_hist_k_cache[:, 0],\n            v_cache=self._ref_hist_v_cache[:, 0],\n            new_len=new_len,\n            insert_pos=insert_pos,\n        )\n        shared_router_summary = self._build_shared_router_summary(ref_hist_h)\n        router_h_per_layer = self._build_router_h_per_layer(\n            shared_router_summary\n        )\n        cos, sin = self.get_cos_sin(h, pos_ids)\n\n        for layer_idx, layer in enumerate(self.layers):\n            x_norm = layer.norm1(h)\n            k_cache_l = self._k_cache[:, layer_idx]\n            v_cache_l = self._v_cache[:, layer_idx]\n            attn_out, _, _ = layer.attn.forward_single_token(\n                x_norm,\n                cos,\n                sin,\n                k_cache_l,\n                v_cache_l,\n                new_len,\n                insert_pos,\n            )\n            h = h + attn_out\n            h2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                ffn = layer.compute_moe_ffn(\n                    h2, router_x=router_h_per_layer[layer_idx]\n                )\n                if (\n                    layer_idx == self._last_moe_layer_idx\n                    and layer.collect_routing_stats\n                    and layer.last_router_distribution is not None\n                ):\n                    self._accumulate_last_moe_router_shift(\n                        layer.last_router_distribution\n                    )\n            else:\n                ffn = layer.mlp_dropout(layer.mlp(h2))\n            h = h + ffn\n\n        h = self.norm_f(h)\n        return self.action_mu_head(h[:, 0, :]).reshape(batch, -1)\n\n    def _forward_inference_onnx(\n        self,\n        x: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, ...]:\n        _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)\n        ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)\n        h = self.obs_embed(x)[:, None, :]\n        batch = h.shape[0]\n        max_len = past_key_values.shape[3]\n\n        if current_pos.ndim == 0:\n            current_pos = current_pos.view(1).expand(batch)\n\n        insert_pos = current_pos % max_len\n        new_len = torch.clamp(current_pos + 1, max=max_len)\n        position_ids = current_pos.unsqueeze(1)\n\n        ref_layer_past = 
past_key_values[0]\n        ref_hist_h, ref_k_cache, ref_v_cache = (\n            self._build_ref_hist_hidden_single_step(\n                ref_motion_x=ref_motion_x,\n                pos_ids=position_ids,\n                k_cache=ref_layer_past[0],\n                v_cache=ref_layer_past[1],\n                new_len=new_len,\n                insert_pos=insert_pos,\n            )\n        )\n        shared_router_summary = self._build_shared_router_summary(ref_hist_h)\n        router_h_per_layer = self._build_router_h_per_layer(\n            shared_router_summary\n        )\n        cos, sin = self.get_cos_sin(h, position_ids)\n\n        present_key_values_list = [\n            torch.stack([ref_k_cache, ref_v_cache], dim=0)\n        ]\n        routing_debug_outputs: list[torch.Tensor] = []\n        export_routing_debug = torch.onnx.is_in_onnx_export()\n\n        for i, layer in enumerate(self.layers):\n            layer_past = past_key_values[self.ref_hist_n_layers + i]\n            k_cache = layer_past[0]\n            v_cache = layer_past[1]\n\n            h_norm = layer.norm1(h)\n            attn_out, new_k_cache, new_v_cache = (\n                layer.attn.forward_single_token(\n                    x=h_norm,\n                    cos=cos,\n                    sin=sin,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    new_len=new_len,\n                    insert_pos=insert_pos,\n                )\n            )\n            h = h + attn_out\n\n            h_norm2 = layer.norm2(h)\n            if isinstance(layer, GroupedMoEBlock):\n                if export_routing_debug:\n                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(\n                        h_norm2,\n                        router_x=router_h_per_layer[i],\n                        return_routing_debug=True,\n                    )\n                    routing_debug_outputs.extend([topk_idx, router_logits])\n                else:\n                    ffn_out = layer.compute_moe_ffn(\n                        h_norm2, router_x=router_h_per_layer[i]\n                    )\n            else:\n                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))\n            h = h + ffn_out\n            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)\n            present_key_values_list.append(current_layer_kv)\n\n        h = self.norm_f(h)\n        action = self.action_mu_head(h[:, 0, :])\n        present_key_values = torch.stack(present_key_values_list, dim=0)\n\n        if export_routing_debug and routing_debug_outputs:\n            return (action, present_key_values, *routing_debug_outputs)\n        return action, present_key_values\n\n\nclass GroupedMoEBlock(nn.Module):\n    def __init__(\n        self,\n        d_model: int,\n        n_heads: int,\n        num_fine_experts: int,\n        num_shared_experts: int,\n        top_k: int,\n        n_kv_heads: int | None = None,\n        ff_mult: float = 2,\n        use_qk_norm: bool = True,\n        use_gated_attn: bool = True,\n        gated_attn_type: str = \"headwise\",\n        attn_dropout: float = 0.0,\n        mlp_dropout: float = 0.0,\n        use_dynamic_bias: bool = False,\n        bias_update_rate: float = 0.001,\n        routing_score_fn: str = \"softmax\",\n        freeze_router: bool = False,\n        routing_scale: float = 1.0,\n        expert_bias_clip: float = 0.0,\n        dead_expert_margin_to_topk_enabled: bool = False,\n        selected_expert_margin_to_unselected_enabled: bool = 
False,\n        selected_expert_margin_to_unselected_target: float = 0.0,\n        use_cross_attn: bool = False,\n    ):\n        super().__init__()\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.num_fine_experts = num_fine_experts\n        self.num_shared_experts = num_shared_experts\n        self.top_k = top_k\n        self.use_dynamic_bias = use_dynamic_bias\n        self.bias_update_rate = bias_update_rate\n        self.routing_score_fn = str(routing_score_fn).lower()\n        self.freeze_router = bool(freeze_router)\n        self.routing_scale = float(routing_scale)\n        self.expert_bias_clip = float(expert_bias_clip)\n        self.dead_expert_margin_to_topk_enabled = bool(\n            dead_expert_margin_to_topk_enabled\n        )\n        self.selected_expert_margin_to_unselected_enabled = bool(\n            selected_expert_margin_to_unselected_enabled\n        )\n        self.selected_expert_margin_to_unselected_target = float(\n            selected_expert_margin_to_unselected_target\n        )\n        if self.routing_score_fn not in (\"softmax\", \"sigmoid\"):\n            raise ValueError(\n                f\"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}\"\n            )\n        if self.routing_scale <= 0.0:\n            raise ValueError(\n                f\"routing_scale must be > 0, got {self.routing_scale}\"\n            )\n        if self.expert_bias_clip < 0.0:\n            raise ValueError(\n                f\"expert_bias_clip must be >= 0, got {self.expert_bias_clip}\"\n            )\n        if self.selected_expert_margin_to_unselected_target < 0.0:\n            raise ValueError(\n                \"selected_expert_margin_to_unselected_target must be >= 0, \"\n                f\"got {self.selected_expert_margin_to_unselected_target}\"\n            )\n        self.register_buffer(\"expert_bias\", torch.zeros(num_fine_experts))\n        self.register_buffer(\n            \"routing_counts_accum\",\n            torch.zeros(num_fine_experts, dtype=torch.long),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_routed_expert_usage\",\n            torch.zeros(num_fine_experts, dtype=torch.float32),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_routed_active_expert_count\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_routed_max_expert_frac\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_active_expert_ratio\", torch.tensor(0.0), persistent=False\n        )\n        self.register_buffer(\n            \"last_max_expert_frac\", torch.tensor(0.0), persistent=False\n        )\n        self.register_buffer(\n            \"last_expert_count_cv\", torch.tensor(0.0), persistent=False\n        )\n        self.register_buffer(\n            \"last_min_expert_frac\", torch.tensor(0.0), persistent=False\n        )\n        self.register_buffer(\n            \"last_dead_expert_ratio\", torch.tensor(0.0), persistent=False\n        )\n        self.register_buffer(\n            \"last_dense_expert_usage\",\n            torch.zeros(num_fine_experts, dtype=torch.float32),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_dead_expert_margin_to_topk_loss_value\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n     
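   # The last_* buffers above and below snapshot the most recent routing\n        # margin diagnostics for logging; persistent=False keeps them out of\n        # checkpoints.\n     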
   self.register_buffer(\n            \"last_dead_expert_margin_to_topk_target\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_selected_expert_margin_to_unselected\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n        self.register_buffer(\n            \"last_selected_expert_margin_to_unselected_loss_value\",\n            torch.tensor(0.0),\n            persistent=False,\n        )\n        self.collect_routing_stats = False\n        self.collect_router_distribution = False\n        self.capture_router_distribution = False\n        self.capture_router_logits = False\n        self.last_router_distribution: torch.Tensor | None = None\n        self.last_router_logits: torch.Tensor | None = None\n        self.last_dead_expert_margin_to_topk_loss: torch.Tensor | None = None\n        self.last_selected_expert_margin_to_unselected_loss: (\n            torch.Tensor | None\n        ) = None\n        self.use_cross_attn = bool(use_cross_attn)\n\n        self.norm1 = RMSNorm(d_model)\n        self.attn = ModernAttention(\n            d_model=d_model,\n            n_heads=n_heads,\n            n_kv_heads=n_kv_heads,\n            use_qk_norm=use_qk_norm,\n            use_gated_attn=use_gated_attn,\n            gated_attn_type=gated_attn_type,\n            attn_dropout=attn_dropout,\n        )\n        if self.use_cross_attn:\n            self.norm_cross = RMSNorm(d_model)\n            self.cross_attn = ModernCrossAttention(\n                d_model=d_model,\n                n_heads=n_heads,\n                n_kv_heads=n_kv_heads,\n                use_qk_norm=use_qk_norm,\n                use_gated_attn=use_gated_attn,\n                gated_attn_type=gated_attn_type,\n                attn_dropout=attn_dropout,\n            )\n        else:\n            self.norm_cross = None\n            self.cross_attn = None\n\n        self.norm2 = RMSNorm(d_model)\n        self.intermediate_dim = int(d_model * ff_mult)\n\n        self.router = nn.Linear(d_model, num_fine_experts, bias=False)\n        self._apply_freeze_router_state()\n\n        # Gate + Up (Combined)\n        self.gate_up_proj = nn.Parameter(\n            torch.empty(\n                num_fine_experts, self.d_model, 2 * self.intermediate_dim\n            )\n        )\n        # Down\n        self.down_proj = nn.Parameter(\n            torch.empty(num_fine_experts, self.intermediate_dim, self.d_model)\n        )\n\n        self.shared_experts = DeepseekV3MLP(\n            hidden_size=d_model,\n            intermediate_size=int(d_model * ff_mult * num_shared_experts),\n        )\n        self.mlp_dropout = (\n            nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity()\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        nn.init.xavier_uniform_(self.gate_up_proj)\n        nn.init.xavier_uniform_(self.down_proj)\n\n    def _load_from_state_dict(\n        self,\n        state_dict,\n        prefix,\n        local_metadata,\n        strict,\n        missing_keys,\n        unexpected_keys,\n        error_msgs,\n    ):\n        gate_up_key = prefix + \"gate_up_proj\"\n        current_gate_up_shape = tuple(self.gate_up_proj.shape)\n        legacy_gate_up_shape = (\n            self.num_fine_experts,\n            2 * self.intermediate_dim,\n            self.d_model,\n        )\n\n        down_key = prefix + \"down_proj\"\n        current_down_shape = tuple(self.down_proj.shape)\n        
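# Legacy checkpoints stored expert weights as [E, out_dim, in_dim];\n        # when that layout is detected by shape, the tensors are transposed\n        # below to the current [E, in_dim, out_dim] layout.\n        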
legacy_down_shape = (\n            self.num_fine_experts,\n            self.d_model,\n            self.intermediate_dim,\n        )\n\n        is_legacy_layout = None\n        if gate_up_key in state_dict:\n            gate_up_shape = tuple(state_dict[gate_up_key].shape)\n            gate_up_is_current = gate_up_shape == current_gate_up_shape\n            gate_up_is_legacy = gate_up_shape == legacy_gate_up_shape\n            if gate_up_is_current and not gate_up_is_legacy:\n                is_legacy_layout = False\n            elif gate_up_is_legacy and not gate_up_is_current:\n                is_legacy_layout = True\n        if is_legacy_layout is None and down_key in state_dict:\n            down_shape = tuple(state_dict[down_key].shape)\n            down_is_current = down_shape == current_down_shape\n            down_is_legacy = down_shape == legacy_down_shape\n            if down_is_current and not down_is_legacy:\n                is_legacy_layout = False\n            elif down_is_legacy and not down_is_current:\n                is_legacy_layout = True\n\n        if gate_up_key in state_dict:\n            gate_up_w = state_dict[gate_up_key]\n            gate_up_shape = tuple(gate_up_w.shape)\n            gate_up_is_legacy_only = (\n                gate_up_shape == legacy_gate_up_shape\n                and gate_up_shape != current_gate_up_shape\n            )\n            gate_up_is_ambiguous = (\n                gate_up_shape == legacy_gate_up_shape\n                and gate_up_shape == current_gate_up_shape\n            )\n            if gate_up_is_legacy_only or (\n                gate_up_is_ambiguous and is_legacy_layout\n            ):\n                state_dict[gate_up_key] = gate_up_w.transpose(\n                    -2, -1\n                ).contiguous()\n\n        if down_key in state_dict:\n            down_w = state_dict[down_key]\n            down_shape = tuple(down_w.shape)\n            down_is_legacy_only = (\n                down_shape == legacy_down_shape\n                and down_shape != current_down_shape\n            )\n            down_is_ambiguous = (\n                down_shape == legacy_down_shape\n                and down_shape == current_down_shape\n            )\n            if down_is_legacy_only or (down_is_ambiguous and is_legacy_layout):\n                state_dict[down_key] = down_w.transpose(-2, -1).contiguous()\n\n        super()._load_from_state_dict(\n            state_dict,\n            prefix,\n            local_metadata,\n            strict,\n            missing_keys,\n            unexpected_keys,\n            error_msgs,\n        )\n        self._apply_freeze_router_state()\n\n    def _apply_freeze_router_state(self) -> None:\n        self.router.requires_grad_(not self.freeze_router)\n\n    def reset_routing_stats(self) -> None:\n        self.routing_counts_accum.zero_()\n        self.last_router_distribution = None\n        self.last_router_logits = None\n        self.last_routed_expert_usage.zero_()\n        self.last_routed_active_expert_count.zero_()\n        self.last_routed_max_expert_frac.zero_()\n        self.last_dense_expert_usage.zero_()\n        self.last_dead_expert_margin_to_topk_loss_value.zero_()\n        self.last_dead_expert_margin_to_topk_target.zero_()\n        self.last_dead_expert_margin_to_topk_loss = None\n        self.last_selected_expert_margin_to_unselected.zero_()\n        self.last_selected_expert_margin_to_unselected_loss_value.zero_()\n        self.last_selected_expert_margin_to_unselected_loss = None\n\n    def 
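_routing_stats_snapshot_example(self) -> dict[str, float]:\n        \"\"\"Illustrative sketch only; not called by the training code.\n\n        Shows how the non-persistent routing buffers kept by this block can\n        be read for logging. The method name is hypothetical.\n        \"\"\"\n        return {\n            \"active_expert_ratio\": float(self.last_active_expert_ratio),\n            \"max_expert_frac\": float(self.last_max_expert_frac),\n            \"dead_expert_ratio\": float(self.last_dead_expert_ratio),\n        }\n\n    def 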
accumulate_routing_stats(self, topk_idx: torch.Tensor) -> None:\n        with torch.no_grad():\n            counts = torch.bincount(\n                topk_idx.reshape(-1), minlength=self.num_fine_experts\n            )\n            self.routing_counts_accum.add_(counts)\n\n    def _apply_bias_update_from_counts(self, counts: torch.Tensor) -> None:\n        with torch.no_grad():\n            if dist.is_available() and dist.is_initialized():\n                dist.all_reduce(counts, op=dist.ReduceOp.SUM)\n            total = counts.sum()\n            if int(total.item()) == 0:\n                self.last_active_expert_ratio.zero_()\n                self.last_max_expert_frac.zero_()\n                self.last_expert_count_cv.zero_()\n                self.last_min_expert_frac.zero_()\n                self.last_dead_expert_ratio.zero_()\n                return\n            if self.use_dynamic_bias:\n                avg = counts.float().mean()\n                error = avg - counts.float()\n                self.expert_bias.add_(\n                    self.bias_update_rate * torch.sign(error)\n                )\n            total = total.clamp_min(1)\n            active_ratio = (counts > 0).to(torch.float32).mean()\n            max_expert_frac = counts.max().to(torch.float32) / total.to(\n                torch.float32\n            )\n            min_expert_frac = counts.min().to(torch.float32) / total.to(\n                torch.float32\n            )\n            dead_expert_ratio = (counts == 0).to(torch.float32).mean()\n            counts_f = counts.to(torch.float32)\n            counts_mean = counts_f.mean().clamp_min(1.0e-6)\n            counts_std = counts_f.std(unbiased=False)\n            expert_count_cv = counts_std / counts_mean\n            self.last_active_expert_ratio.copy_(active_ratio)\n            self.last_max_expert_frac.copy_(max_expert_frac)\n            self.last_expert_count_cv.copy_(expert_count_cv)\n            self.last_min_expert_frac.copy_(min_expert_frac)\n            self.last_dead_expert_ratio.copy_(dead_expert_ratio)\n            if self.use_dynamic_bias and self.expert_bias_clip > 0.0:\n                self.expert_bias.clamp_(\n                    min=-self.expert_bias_clip, max=self.expert_bias_clip\n                )\n\n    def apply_bias_update_from_counts(self) -> None:\n        with torch.no_grad():\n            counts = self.routing_counts_accum.clone()\n            self.routing_counts_accum.zero_()\n        self._apply_bias_update_from_counts(counts)\n\n    def _update_routed_expert_stats_and_floor_loss(\n        self,\n        topk_idx: torch.Tensor,\n        dense_distribution: torch.Tensor,\n        choice_scores: torch.Tensor,\n    ) -> torch.Tensor:\n        counts = torch.bincount(\n            topk_idx.reshape(-1), minlength=self.num_fine_experts\n        ).to(torch.float32)\n        total_assignments = max(int(topk_idx.numel()), 1)\n        hard_usage = counts / float(total_assignments)\n        active_count = (counts > 0).to(torch.float32).sum()\n        max_frac = hard_usage.max() if hard_usage.numel() > 0 else counts.sum()\n        with torch.no_grad():\n            self.last_routed_expert_usage.copy_(\n                hard_usage.to(self.last_routed_expert_usage.dtype)\n            )\n            self.last_routed_active_expert_count.copy_(\n                active_count.to(self.last_routed_active_expert_count.dtype)\n            )\n            self.last_routed_max_expert_frac.copy_(\n                max_frac.to(self.last_routed_max_expert_frac.dtype)\n  
          )\n        dense_usage = dense_distribution.to(torch.float32).mean(dim=(0, 1))\n        with torch.no_grad():\n            self.last_dense_expert_usage.copy_(\n                dense_usage.detach().to(self.last_dense_expert_usage.dtype)\n            )\n        kth_choice_score = choice_scores.gather(-1, topk_idx)[..., -1:]\n        if self.top_k < self.num_fine_experts:\n            selected_mask = F.one_hot(\n                topk_idx, num_classes=self.num_fine_experts\n            ).any(dim=-2)\n            best_unselected_score = (\n                choice_scores.masked_fill(\n                    selected_mask, torch.finfo(choice_scores.dtype).min\n                )\n                .max(dim=-1, keepdim=True)\n                .values\n            )\n            selected_margin_gap = kth_choice_score - best_unselected_score\n            selected_margin = selected_margin_gap.mean()\n        else:\n            selected_margin_gap = choice_scores.new_zeros(\n                choice_scores.shape[:2] + (1,)\n            )\n            selected_margin = choice_scores.new_zeros(())\n        if self.selected_expert_margin_to_unselected_enabled:\n            selected_margin_loss = torch.relu(\n                self.selected_expert_margin_to_unselected_target\n                - selected_margin_gap\n            ).mean()\n        else:\n            selected_margin_loss = choice_scores.new_zeros(())\n        with torch.no_grad():\n            self.last_selected_expert_margin_to_unselected.copy_(\n                selected_margin.detach().to(\n                    self.last_selected_expert_margin_to_unselected.dtype\n                )\n            )\n            self.last_selected_expert_margin_to_unselected_loss_value.copy_(\n                selected_margin_loss.detach().to(\n                    self.last_selected_expert_margin_to_unselected_loss_value.dtype\n                )\n            )\n        self.last_selected_expert_margin_to_unselected_loss = (\n            selected_margin_loss\n        )\n\n        if not self.dead_expert_margin_to_topk_enabled:\n            margin_loss = dense_distribution.new_zeros(())\n            with torch.no_grad():\n                self.last_dead_expert_margin_to_topk_loss_value.zero_()\n                self.last_dead_expert_margin_to_topk_target.zero_()\n            self.last_dead_expert_margin_to_topk_loss = margin_loss\n            return margin_loss\n\n        dead_mask = (counts == 0).to(choice_scores.dtype)\n        margin_gap = torch.relu(kth_choice_score - choice_scores)\n        dead_margin_sum = (\n            margin_gap * dead_mask.view(1, 1, self.num_fine_experts)\n        ).sum()\n        dead_count = dead_mask.sum()\n        num_tokens = choice_scores.new_ones(\n            choice_scores.shape[:2], dtype=choice_scores.dtype\n        ).sum()\n        normalizer = dead_count.clamp_min(1.0) * num_tokens\n        margin_loss = dead_margin_sum / normalizer\n        with torch.no_grad():\n            self.last_dead_expert_margin_to_topk_loss_value.copy_(\n                margin_loss.detach().to(\n                    self.last_dead_expert_margin_to_topk_loss_value.dtype\n                )\n            )\n            self.last_dead_expert_margin_to_topk_target.copy_(\n                kth_choice_score.mean()\n                .detach()\n                .to(self.last_dead_expert_margin_to_topk_target.dtype)\n            )\n        self.last_dead_expert_margin_to_topk_loss = margin_loss\n        return margin_loss\n\n    @torch.compiler.disable\n    def 
_compute_sparse_experts(\n        self,\n        x: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_scores: torch.Tensor,\n    ) -> torch.Tensor:\n        B, T, D = x.size()\n        num_top_k = self.top_k\n\n        is_exporting = torch.onnx.is_in_onnx_export()\n        if is_exporting:\n            # ONNX/runtime path: compute only selected experts (top_k),\n            # avoiding O(num_experts) per-step overhead at bs=1.\n            return self._compute_with_topk_selection(x, topk_idx, topk_scores)\n\n        x_flat = x.view(-1, D)\n        expert_ids = topk_idx.view(-1)\n        scores = topk_scores.view(-1)\n\n        raw_token_indices = (\n            torch.arange(B * T, device=x.device)\n            .unsqueeze(1)\n            .expand(-1, num_top_k)\n            .reshape(-1)\n        )\n        sorted_expert_ids, perm = torch.sort(expert_ids)\n        sorted_token_indices = raw_token_indices[perm]\n        x_sorted = x_flat[sorted_token_indices]\n        scores_sorted = scores[perm]\n\n        # Path B: high-performance grouped GEMM on expert-sorted tokens.\n        output_sorted = self._compute_with_grouped_mm(\n            x_sorted, sorted_expert_ids\n        )\n\n        output_sorted = output_sorted * scores_sorted.unsqueeze(-1)\n        inv_perm = torch.argsort(perm)\n        output_flat = output_sorted[inv_perm]\n        output_final = output_flat.view(B * T, num_top_k, D).sum(dim=1)\n\n        return output_final.view(B, T, D)\n\n    def _compute_with_grouped_mm(\n        self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Grouped GEMM over expert-sorted tokens.\n\n        Based on the official implementation's contract:\n\n        - offsets are cumulative end indices (cumsum of per-expert counts).\n        - offsets has length num_fine_experts, not num_fine_experts + 1.\n        - offsets dtype must be int32.\n\n        Example: per-expert counts [2, 0, 3] give offsets [2, 2, 5].\n        \"\"\"\n        tokens_per_expert = torch.bincount(\n            sorted_expert_ids.long(), minlength=self.num_fine_experts\n        )\n        counts = tokens_per_expert[: self.num_fine_experts]\n\n        offsets = torch.cumsum(counts, dim=0, dtype=torch.int32)\n\n        gate_up_out = _grouped_linear(\n            x_sorted, self.gate_up_proj, offs=offsets\n        )\n\n        x1, x2 = gate_up_out.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n\n        out = _grouped_linear(hidden, self.down_proj, offs=offsets)\n\n        return out\n\n    def _compute_with_topk_selection(\n        self,\n        x: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_scores: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"ONNX-friendly sparse expert compute that scales with top_k, not\n        num_fine_experts.\n        \"\"\"\n        B, T, D = x.shape\n        N = B * T\n        K = self.top_k\n        orig_dtype = x.dtype\n\n        x_tokens = x.reshape(N, D)\n        idx = topk_idx.reshape(N, K)\n        scores = topk_scores.reshape(N, K)\n\n        x_rep = x_tokens[:, None, :].expand(N, K, D).reshape(N * K, D)\n        idx_flat = idx.reshape(N * K)\n        compute_dtype = self.gate_up_proj.dtype\n        if x_rep.dtype != compute_dtype:\n            x_rep = x_rep.to(compute_dtype)\n\n        gate_up_w = self.gate_up_proj.index_select(0, idx_flat)\n        gate_up_out = torch.bmm(x_rep.unsqueeze(1), gate_up_w).squeeze(1)\n\n        x1, x2 = gate_up_out.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n\n        down_w = self.down_proj.index_select(0, idx_flat)\n        sparse_flat = torch.bmm(hidden.unsqueeze(1), down_w).squeeze(1)\n        if sparse_flat.dtype != orig_dtype:\n            sparse_flat = sparse_flat.to(orig_dtype)\n\n        sparse = sparse_flat.view(N, K, D)\n        weighted = sparse * scores.to(sparse.dtype).unsqueeze(-1)\n        out = weighted.sum(dim=1)\n        return out.view(B, T, D)\n\n    def _compute_with_loop_fallback(\n        self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Path A: per-expert loop fallback (works with F.linear and 3D weights).\"\"\"\n        results = []\n        for i in range(self.num_fine_experts):\n            mask = sorted_expert_ids == i\n            inp_i = x_sorted[mask]\n\n            # Gate + Up\n            w_gate_up = self.gate_up_proj[i].transpose(0, 1)\n            gate_up_out = F.linear(inp_i, w_gate_up)\n\n            x1, x2 = gate_up_out.chunk(2, dim=-1)\n            hidden = F.silu(x1) * x2\n\n            # Down\n            w_down = self.down_proj[i].transpose(0, 1)\n            out_i = F.linear(hidden, w_down)\n            results.append(out_i)\n\n        return torch.cat(results, dim=0)\n\n    def compute_moe_ffn(\n        self,\n        x: torch.Tensor,\n        router_x: torch.Tensor | None = None,\n        *,\n        return_routing_debug: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"MoE FFN: dense shared experts plus top-k routed fine experts.\"\"\"\n        B, T, D = x.shape\n        should_cache_router_distribution = (\n            self.collect_routing_stats and self.collect_router_distribution\n        ) or self.capture_router_distribution\n        should_cache_router_logits = self.capture_router_logits\n        # 1. Shared Experts (Dense Path)\n        shared_out = self.shared_experts(x)\n\n        # 2. Router (Gating)\n        router_input = x if router_x is None else router_x\n        if router_input.shape != x.shape:\n            raise ValueError(\n                \"router_x shape must match x shape in compute_moe_ffn: \"\n                f\"x={tuple(x.shape)}, router_x={tuple(router_input.shape)}\"\n            )\n        if self.freeze_router:\n            with torch.no_grad():\n                logits = self.router(router_input)\n        else:\n            logits = self.router(router_input)\n        logits_fp32 = logits.to(torch.float32)\n        bias_fp32 = None\n        if self.use_dynamic_bias:\n            bias_fp32 = self.expert_bias.to(\n                device=logits.device, dtype=torch.float32\n            )\n\n        if self.routing_score_fn == \"softmax\":\n            choice_logits = logits_fp32\n            if bias_fp32 is not None:\n                # Keep dynamic bias as a selection correction, not a mixture-weight shaper.\n                choice_logits = choice_logits + bias_fp32\n            choice_scores = choice_logits\n            _, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1)\n            dense_distribution = torch.softmax(logits_fp32, dim=-1)\n            if torch.onnx.is_in_onnx_export():\n                selected_probs = dense_distribution.gather(-1, topk_idx)\n            else:\n                selected_logits = logits_fp32.gather(-1, topk_idx)\n                log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True)\n                selected_probs = torch.exp(selected_logits - log_z)\n            topk_scores = selected_probs / selected_probs.sum(\n                dim=-1, keepdim=True\n            ).clamp_min(1.0e-20)\n            router_distribution = None\n            if should_cache_router_distribution:\n                router_distribution = dense_distribution\n        else:  # sigmoid\n            scores = torch.sigmoid(logits_fp32)\n            
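# Normalize sigmoid affinities into a dense distribution so the\n            # load-balance statistics stay comparable with the softmax branch.\n            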
dense_distribution = scores / scores.sum(\n                dim=-1, keepdim=True\n            ).clamp_min(1.0e-20)\n            scores_for_choice = scores\n            if bias_fp32 is not None:\n                # DeepSeek-style correction bias for expert choice.\n                scores_for_choice = scores_for_choice + bias_fp32\n            choice_scores = scores_for_choice\n            _, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1)\n            selected_scores = scores.gather(-1, topk_idx)\n            # Match DeepSeek-style routing: bias affects only expert choice,\n            # while the expert mixing weights come from the original sigmoid\n            # affinities normalized over the selected experts.\n            topk_scores = selected_scores / selected_scores.sum(\n                dim=-1, keepdim=True\n            ).clamp_min(1.0e-20)\n            router_distribution = None\n            if should_cache_router_distribution:\n                router_distribution = dense_distribution\n\n        if self.collect_routing_stats:\n            self.accumulate_routing_stats(topk_idx)\n        if (\n            should_cache_router_distribution\n            and router_distribution is not None\n        ):\n            self.last_router_distribution = router_distribution\n        else:\n            self.last_router_distribution = None\n        if should_cache_router_logits:\n            self.last_router_logits = logits_fp32\n        else:\n            self.last_router_logits = None\n        self._update_routed_expert_stats_and_floor_loss(\n            topk_idx=topk_idx,\n            dense_distribution=dense_distribution,\n            choice_scores=choice_scores,\n        )\n        if self.routing_scale != 1.0:\n            topk_scores = topk_scores * self.routing_scale\n        topk_scores = topk_scores.to(logits.dtype)\n\n        # 3. Sparse Experts Computation (Grouped MM / ONNX Loop)\n        sparse_out = self._compute_sparse_experts(x, topk_idx, topk_scores)\n\n        # 4. 
Combine\n        output = shared_out + sparse_out\n        output = self.mlp_dropout(output)\n\n        if return_routing_debug:\n            return output, topk_idx, logits_fp32\n        return output\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor = None,\n        sin: torch.Tensor = None,\n        mask: torch.Tensor | None = None,\n        memory: torch.Tensor | None = None,\n        memory_mask: torch.Tensor | None = None,\n        router_x: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass compatible with ONNX and Attention/Norm.\"\"\"\n        norm_x = self.norm1(x)\n        attn_out = self.attn(norm_x, cos, sin, mask)\n        x = x + attn_out\n        if self.use_cross_attn and memory is not None:\n            x_cross = self.norm_cross(x)\n            if memory.ndim == 4:\n                b, t, d_model = x_cross.shape\n                _, _, n_fut, _ = memory.shape\n                q = x_cross.reshape(b * t, 1, d_model)\n                mem = memory.reshape(b * t, n_fut, d_model)\n                mem_mask = None\n                if memory_mask is not None:\n                    if memory_mask.ndim != 3:\n                        raise ValueError(\n                            \"memory_mask for 4D memory must have shape [B, T, N_fut]\"\n                        )\n                    mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut)\n                cross = self.cross_attn(q, mem, mem_mask).reshape(\n                    b, t, d_model\n                )\n            else:\n                cross = self.cross_attn(x_cross, memory, memory_mask)\n            x = x + cross\n\n        h = self.norm2(x)\n        ffn_out = self.compute_moe_ffn(h, router_x=router_x)\n        x = x + ffn_out\n\n        return x\n\n\nclass ModernTransformerBlock(nn.Module):\n    \"\"\"Modern Transformer block with pre-norm, SwiGLU MLP, and modern attention.\n\n    Features:\n        - Pre-normalization with RMSNorm.\n        - ModernAttention (GQA, QK-Norm, RealRoPE, Gated Attention).\n        - DeepseekV3MLP (SwiGLU) for feed-forward.\n    \"\"\"\n\n    def __init__(\n        self,\n        d_model: int,\n        n_heads: int,\n        n_kv_heads: int | None = None,\n        ff_mult: int = 4,\n        use_qk_norm: bool = True,\n        use_gated_attn: bool = True,\n        gated_attn_type: str = \"headwise\",\n        attn_dropout: float = 0.0,\n        mlp_dropout: float = 0.0,\n        use_cross_attn: bool = False,\n    ):\n        super().__init__()\n        self.use_cross_attn = bool(use_cross_attn)\n        self.norm1 = RMSNorm(d_model)\n        self.attn = ModernAttention(\n            d_model=d_model,\n            n_heads=n_heads,\n            n_kv_heads=n_kv_heads,\n            use_qk_norm=use_qk_norm,\n            use_gated_attn=use_gated_attn,\n            gated_attn_type=gated_attn_type,\n            attn_dropout=attn_dropout,\n        )\n        if self.use_cross_attn:\n            self.norm_cross = RMSNorm(d_model)\n            self.cross_attn = ModernCrossAttention(\n                d_model=d_model,\n                n_heads=n_heads,\n                n_kv_heads=n_kv_heads,\n                use_qk_norm=use_qk_norm,\n                use_gated_attn=use_gated_attn,\n                gated_attn_type=gated_attn_type,\n                attn_dropout=attn_dropout,\n            )\n        else:\n            self.norm_cross = None\n            self.cross_attn = None\n        self.norm2 = RMSNorm(d_model)\n        self.mlp = DeepseekV3MLP(\n      
      hidden_size=d_model, intermediate_size=d_model * ff_mult\n        )\n        self.mlp_dropout = (\n            nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity()\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        mask: torch.Tensor | None = None,\n        memory: torch.Tensor | None = None,\n        memory_mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass with pre-norm residual connections.\n\n        Args:\n            x: Input tensor [B, T, d_model].\n            cos: RoPE cosine frequencies [B, T, head_dim].\n            sin: RoPE sine frequencies [B, T, head_dim].\n            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).\n            memory: Optional cross-attention memory, [B, N, d_model] or\n                [B, T, N_fut, d_model].\n            memory_mask: Optional memory mask; [B, T, N_fut] when memory is 4D.\n\n        Returns:\n            out: Output tensor [B, T, d_model].\n        \"\"\"\n        x = x + self.attn(self.norm1(x), cos, sin, mask)\n        if self.use_cross_attn and memory is not None:\n            x_cross = self.norm_cross(x)\n            if memory.ndim == 4:\n                b, t, d_model = x_cross.shape\n                _, _, n_fut, _ = memory.shape\n                q = x_cross.reshape(b * t, 1, d_model)\n                mem = memory.reshape(b * t, n_fut, d_model)\n                mem_mask = None\n                if memory_mask is not None:\n                    if memory_mask.ndim != 3:\n                        raise ValueError(\n                            \"memory_mask for 4D memory must have shape [B, T, N_fut]\"\n                        )\n                    mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut)\n                cross = self.cross_attn(q, mem, mem_mask).reshape(\n                    b, t, d_model\n                )\n            else:\n                cross = self.cross_attn(x_cross, memory, memory_mask)\n            x = x + cross\n        x = x + self.mlp_dropout(self.mlp(self.norm2(x)))\n        return x\n\n\nclass DeepseekV3MLP(nn.Module):\n    \"\"\"SwiGLU MLP with fused gate+up projection for efficiency.\"\"\"\n\n    def __init__(self, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        # Fused gate and up projection: outputs [gate, up] concatenated\n        self.gate_up_proj = nn.Linear(\n            self.hidden_size,\n            2 * self.intermediate_size,\n            bias=True,\n        )\n        self.down_proj = nn.Linear(\n            self.intermediate_size,\n            self.hidden_size,\n            bias=True,\n        )\n        self.act_fn = nn.SiLU()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # x: [..., hidden_size]\n        gate_up = self.gate_up_proj(x)  # [..., 2 * intermediate_size]\n        gate, up = gate_up.chunk(2, dim=-1)  # each [..., intermediate_size]\n        return self.down_proj(self.act_fn(gate) * up)\n\n\nclass RMSNorm(nn.Module):\n    \"\"\"Root Mean Square Layer Normalization (used in Llama, DeepSeek, Qwen).\"\"\"\n\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        var = x.pow(2).mean(-1, keepdim=True)\n        x_normed = x * torch.rsqrt(var + self.eps)\n        return x_normed * self.weight\n\n\nclass ModernAttention(nn.Module):\n    \"\"\"Modern attention with GQA, QK-Norm, RealRoPE, Flash Attention, and Gated Attention.\n\n    Features:\n        - GQA: Grouped Query 
Attention (n_kv_heads < n_heads).\n        - QK-Norm: RMSNorm on queries and keys for stability.\n        - RealRoPE: Real-valued Rotary Positional Embeddings.\n        - Flash Attention: via F.scaled_dot_product_attention.\n        - Gated Attention: Headwise or element-wise sigmoid gating (Qwen3-style).\n        - Fused Projections: Q separate, KV fused for efficiency.\n\n    Reference: https://github.com/qiuzh20/gated_attention\n    \"\"\"\n\n    def __init__(\n        self,\n        d_model: int,\n        n_heads: int,\n        n_kv_heads: int | None = None,\n        use_qk_norm: bool = True,\n        use_gated_attn: bool = True,\n        gated_attn_type: str = \"headwise\",\n        attn_dropout: float = 0.0,\n    ):\n        \"\"\"Initialize ModernAttention.\n\n        Args:\n            d_model: Model dimension.\n            n_heads: Number of query heads.\n            n_kv_heads: Number of key/value heads (for GQA). Defaults to n_heads.\n            use_qk_norm: Apply RMSNorm to Q and K.\n            use_gated_attn: Enable gated attention.\n            gated_attn_type: \"headwise\" (Qwen3-style, one gate per head) or\n                             \"elementwise\" (one gate per element).\n            attn_dropout: Dropout probability for attention.\n        \"\"\"\n        super().__init__()\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads\n        self.head_dim = d_model // n_heads\n        self.n_rep = self.n_heads // self.n_kv_heads\n        self.use_qk_norm = use_qk_norm\n        self.use_gated_attn = use_gated_attn\n        self.gated_attn_type = gated_attn_type\n        self.attn_dropout = attn_dropout\n\n        # Fused projections: Q separate, KV fused (for GQA efficiency)\n        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)\n        self.kv_proj = nn.Linear(\n            d_model, 2 * self.n_kv_heads * self.head_dim, bias=False\n        )\n        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)\n\n        if self.use_qk_norm:\n            self.q_norm = RMSNorm(self.head_dim)\n            self.k_norm = RMSNorm(self.head_dim)\n\n        if self.use_gated_attn:\n            if self.gated_attn_type == \"headwise\":\n                # Qwen3-style: one gate scalar per head [B, T, n_heads]\n                self.gate_proj = nn.Linear(d_model, n_heads, bias=False)\n            else:\n                # Element-wise: gate each element [B, T, d_model]\n                self.gate_proj = nn.Linear(d_model, d_model, bias=False)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,  # [B, T, D]\n        sin: torch.Tensor,  # [B, T, D]\n        mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"Forward pass.\n\n        Args:\n            x: Input tensor [B, T, d_model].\n            cos: RoPE cosine frequencies [B, T, head_dim].\n            sin: RoPE sine frequencies [B, T, head_dim].\n            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).\n\n        Returns:\n            out: Output tensor [B, T, d_model].\n        \"\"\"\n        B, T, _ = x.shape\n\n        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)\n        kv = self.kv_proj(x).view(B, T, 2, self.n_kv_heads, self.head_dim)\n        k, v = kv[:, :, 0], kv[:, :, 1]  # each [B, T, n_kv_heads, head_dim]\n\n        if self.use_qk_norm:\n            q = self.q_norm(q)\n            k = self.k_norm(k)\n\n      
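  # Shape sketch (illustrative numbers, not from any config): with d_model=512,\n        # n_heads=8, n_kv_heads=2 (head_dim=64), q is [B, T, 8, 64] while k/v are\n        # [B, T, 2, 64]; enable_gqa lets SDPA broadcast the 2 KV heads across all\n        # 8 query heads without materializing repeated tensors.\n      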
  # Transpose for SDPA: [B, n_heads, T, head_dim]\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        q, k = apply_rotary_pos_emb(q, k, cos, sin)\n\n        # Flash Attention via SDPA (handles GQA internally)\n        dropout_p = self.attn_dropout if self.training else 0.0\n        is_exporting = torch.onnx.is_in_onnx_export()\n\n        if is_exporting:\n            k = repeat_kv(k, self.n_rep)\n            v = repeat_kv(v, self.n_rep)\n            enable_gqa = False\n        else:\n            enable_gqa = True\n        attn_out = export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=mask,\n            dropout_p=dropout_p,\n            is_causal=(mask is None),\n            enable_gqa=enable_gqa,\n        )\n\n        # attn_out: [B, n_heads, T, head_dim]\n\n        # Gated Attention (Qwen3-style)\n        if self.use_gated_attn:\n            if self.gated_attn_type == \"headwise\":\n                # Headwise gating: [B, T, n_heads] -> [B, n_heads, T, 1]\n                g = torch.sigmoid(self.gate_proj(x))  # [B, T, n_heads]\n                g = g.transpose(1, 2)[..., None]  # [B, n_heads, T, 1]\n                attn_out = attn_out * g\n            else:\n                # Element-wise gating: apply after reshaping\n                attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)\n                g = torch.sigmoid(self.gate_proj(x))  # [B, T, d_model]\n                attn_out = attn_out * g\n                return self.o_proj(attn_out)\n\n        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)\n        return self.o_proj(attn_out)\n\n    def forward_single_token(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n        new_len: torch.Tensor,\n        insert_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"Forward for single token with per-environment KV cache update.\n\n        Args:\n            x: Input tensor [B, 1, d_model].\n            cos: RoPE cosine frequencies [B, 1, head_dim].\n            sin: RoPE sine frequencies [B, 1, head_dim].\n            k_cache: K cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].\n            v_cache: V cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].\n            new_len: New valid cache length per env AFTER inserting this token [B].\n            insert_pos: Insert position per env [B].\n\n        Returns:\n            attn_out: [B, 1, d_model]\n            k_cache: Updated K cache\n            v_cache: Updated V cache\n        \"\"\"\n        B = x.shape[0]\n        q = self.q_proj(x).view(B, 1, self.n_heads, self.head_dim)\n        kv = self.kv_proj(x).view(B, 1, 2, self.n_kv_heads, self.head_dim)\n        k_new, v_new = kv[:, :, 0], kv[:, :, 1]  # [B, 1, n_kv_heads, head_dim]\n\n        if self.use_qk_norm:\n            q = self.q_norm(q)\n            k_new = self.k_norm(k_new)\n\n        # Apply RoPE with per-environment position\n        q = q.transpose(1, 2)\n        k_new = k_new.transpose(1, 2)\n        q, k_new = apply_rotary_pos_emb(q, k_new, cos, sin)\n\n        q = q.transpose(1, 2)\n        k_new = k_new.transpose(1, 2)\n\n        # Scatter K, V into cache at per-env insert positions\n        # insert_pos: [B] -> [B, 1, 1, 1] for scatter\n        idx = (\n            insert_pos.view(B, 1, 1, 1)\n            .expand(B, 
1, self.n_kv_heads, self.head_dim)\n            .to(torch.int64)\n        )\n        if torch.onnx.is_in_onnx_export():\n            # === ONNX export mode: out-of-place scatter (returns a new tensor) ===\n            k_cache = k_cache.scatter(1, idx, k_new.to(k_cache.dtype))\n            v_cache = v_cache.scatter(1, idx, v_new.to(v_cache.dtype))\n        else:\n            # === Rollout mode: in-place scatter (modifies the cache directly) ===\n            k_cache.scatter_(1, idx, k_new.to(k_cache.dtype))\n            v_cache.scatter_(1, idx, v_new.to(v_cache.dtype))\n\n        # Compute attention over cached keys/values\n        # Mask out positions >= new_len (after insert)\n        max_len = k_cache.shape[1]\n        new_len = new_len.clamp(max=max_len)  # [B]\n        # Build per-env mask: [B, max_len] where True = valid (can attend)\n        pos_idx = torch.arange(max_len, device=x.device, dtype=torch.int64)\n        valid_mask = pos_idx[None, :] < new_len[:, None]  # [B, max_len]\n        # For SDPA bool mask: True = allowed (can attend)\n        attn_mask = valid_mask[:, None, None, :]  # [B, 1, 1, max_len]\n\n        # GQA: Use native SDPA broadcasting (no repeat_interleave)\n        k_attn = k_cache.to(q.dtype)\n        v_attn = v_cache.to(q.dtype)\n\n        # Transpose for SDPA: [B, n_heads, T, head_dim]\n        q_t = q.transpose(1, 2)  # [B, n_heads, 1, head_dim]\n        k_t = k_attn.transpose(1, 2)  # [B, n_kv_heads, max_len, head_dim]\n        v_t = v_attn.transpose(1, 2)\n\n        dropout_p = self.attn_dropout if self.training else 0.0\n        is_exporting = torch.onnx.is_in_onnx_export()\n\n        if is_exporting:\n            k_t = repeat_kv(k_t, self.n_rep)\n            v_t = repeat_kv(v_t, self.n_rep)\n            enable_gqa = False\n        else:\n            enable_gqa = True\n        attn_out = export_safe_scaled_dot_product_attention(\n            q_t,\n            k_t,\n            v_t,\n            attn_mask=attn_mask,\n            dropout_p=dropout_p,\n            is_causal=False,\n            enable_gqa=enable_gqa,\n        )\n        # attn_out: [B, n_heads, 1, head_dim]\n\n        # Gated Attention\n        if self.use_gated_attn:\n            if self.gated_attn_type == \"headwise\":\n                g = torch.sigmoid(self.gate_proj(x))  # [B, 1, n_heads]\n                g = g.transpose(1, 2)[..., None]  # [B, n_heads, 1, 1]\n                attn_out = attn_out * g\n            else:\n                attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, -1)\n                g = torch.sigmoid(self.gate_proj(x))\n                attn_out = attn_out * g\n                return self.o_proj(attn_out), k_cache, v_cache\n\n        attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, -1)\n        return self.o_proj(attn_out), k_cache, v_cache\n\n\nclass ModernCrossAttention(nn.Module):\n    \"\"\"Cross-attention with GQA/QK-norm and optional gated attention.\"\"\"\n\n    def __init__(\n        self,\n        d_model: int,\n        n_heads: int,\n        n_kv_heads: int | None = None,\n        use_qk_norm: bool = True,\n        use_gated_attn: bool = True,\n        gated_attn_type: str = \"headwise\",\n        attn_dropout: float = 0.0,\n    ):\n        super().__init__()\n        self.d_model = int(d_model)\n        self.n_heads = int(n_heads)\n        self.n_kv_heads = (\n            int(n_kv_heads) if n_kv_heads is not None else int(n_heads)\n        )\n        self.head_dim = self.d_model // self.n_heads\n        self.n_rep = self.n_heads // self.n_kv_heads\n        self.use_qk_norm = bool(use_qk_norm)\n   
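     # NOTE: as in ModernAttention, GQA requires n_heads to be divisible by\n        # n_kv_heads; e.g., n_heads=8 with n_kv_heads=2 yields n_rep=4.\n   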
     self.use_gated_attn = bool(use_gated_attn)\n        self.gated_attn_type = str(gated_attn_type)\n        self.attn_dropout = float(attn_dropout)\n\n        self.q_proj = nn.Linear(\n            self.d_model, self.n_heads * self.head_dim, bias=False\n        )\n        self.kv_proj = nn.Linear(\n            self.d_model, 2 * self.n_kv_heads * self.head_dim, bias=False\n        )\n        self.o_proj = nn.Linear(\n            self.n_heads * self.head_dim, self.d_model, bias=False\n        )\n\n        if self.use_qk_norm:\n            self.q_norm = RMSNorm(self.head_dim)\n            self.k_norm = RMSNorm(self.head_dim)\n\n        if self.use_gated_attn:\n            if self.gated_attn_type == \"headwise\":\n                self.gate_proj = nn.Linear(\n                    self.d_model, self.n_heads, bias=False\n                )\n            else:\n                self.gate_proj = nn.Linear(\n                    self.d_model, self.d_model, bias=False\n                )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        memory: torch.Tensor,\n        mask: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        if x.ndim != 3:\n            raise ValueError(f\"x must be [B, T, D], got {tuple(x.shape)}\")\n        if memory.ndim != 3:\n            raise ValueError(\n                f\"memory must be [B, N, D], got {tuple(memory.shape)}\"\n            )\n        b, t, _ = x.shape\n        bm, n, _ = memory.shape\n        if bm != b:\n            raise ValueError(\n                f\"batch mismatch between x and memory: {b} vs {bm}\"\n            )\n\n        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim)\n        kv = self.kv_proj(memory).view(b, n, 2, self.n_kv_heads, self.head_dim)\n        k, v = kv[:, :, 0], kv[:, :, 1]\n\n        if self.use_qk_norm:\n            q = self.q_norm(q)\n            k = self.k_norm(k)\n\n        q = q.transpose(1, 2)\n        k = k.transpose(1, 2)\n        v = v.transpose(1, 2)\n\n        dropout_p = self.attn_dropout if self.training else 0.0\n        is_exporting = torch.onnx.is_in_onnx_export()\n        if is_exporting:\n            k = repeat_kv(k, self.n_rep)\n            v = repeat_kv(v, self.n_rep)\n            enable_gqa = False\n        else:\n            enable_gqa = True\n        attn_out = export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=mask,\n            dropout_p=dropout_p,\n            is_causal=False,\n            enable_gqa=enable_gqa,\n        )\n\n        if self.use_gated_attn:\n            if self.gated_attn_type == \"headwise\":\n                g = torch.sigmoid(self.gate_proj(x))\n                g = g.transpose(1, 2)[..., None]\n                attn_out = attn_out * g\n            else:\n                attn_out = attn_out.transpose(1, 2).contiguous().view(b, t, -1)\n                g = torch.sigmoid(self.gate_proj(x))\n                attn_out = attn_out * g\n                return self.o_proj(attn_out)\n\n        attn_out = attn_out.transpose(1, 2).contiguous().view(b, t, -1)\n        return self.o_proj(attn_out)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"Standard LLaMA GQA replication logic, optimized for ONNX export.\n    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).\n\n    Input shape:  (batch, num_key_value_heads, seqlen, head_dim)\n    Output shape: (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = 
hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n\n    # 1. Unsqueeze: [batch, n_kv, 1, seq, dim]\n    # 2. Expand:    [batch, n_kv, n_rep, seq, dim]\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n\n    # 3. Reshape:   [batch, n_kv * n_rep, seq, dim] -> [batch, n_head, seq, dim]\n    return hidden_states.reshape(\n        batch, num_key_value_heads * n_rep, slen, head_dim\n    )\n\n\ndef export_safe_scaled_dot_product_attention(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    *,\n    attn_mask: torch.Tensor | None,\n    dropout_p: float,\n    is_causal: bool,\n    enable_gqa: bool = False,\n) -> torch.Tensor:\n    if (\n        not torch.onnx.is_in_onnx_export()\n        or attn_mask is None\n        or attn_mask.dtype != torch.bool\n    ):\n        return F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=attn_mask,\n            dropout_p=dropout_p,\n            is_causal=is_causal,\n            enable_gqa=enable_gqa,\n        )\n\n    # Use additive float bias during ONNX export so the legacy exporter\n    # does not emit the bool-mask SDPA cleanup path with IsNaN.\n    mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)\n    mask_bias = mask_bias.masked_fill(~attn_mask, torch.finfo(q.dtype).min)\n    return F.scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask_bias,\n        dropout_p=dropout_p,\n        is_causal=is_causal,\n        enable_gqa=enable_gqa,\n    )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    orig_dtype = q.dtype\n\n    # Force the rotation math to fp32 for numerical precision\n    q_fp32 = q.to(torch.float32)\n    k_fp32 = k.to(torch.float32)\n    cos_fp32 = cos.to(torch.float32)\n    sin_fp32 = sin.to(torch.float32)\n\n    q_embed = (q_fp32 * cos_fp32) + (rotate_half(q_fp32) * sin_fp32)\n    k_embed = (k_fp32 * cos_fp32) + (rotate_half(k_fp32) * sin_fp32)\n    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)\n\n\ndef _grouped_linear(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n    bias: torch.Tensor | None = None,\n    offs: torch.Tensor | None = None,\n) -> torch.Tensor:\n    \"\"\"input: [Total_Tokens, In_Dim]\n    weight: [Num_Experts, In_Dim, Out_Dim]\n    \"\"\"\n    orig_dtype = input.dtype\n    if input.dtype != weight.dtype:\n        input = input.to(weight.dtype)\n    out = torch._grouped_mm(input, weight, offs=offs)\n    if out.dtype != orig_dtype:\n        out = out.to(orig_dtype)\n    if bias is not None:\n        out = out + bias\n    return out\n
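\n\n# Usage sketch (illustrative only; the hyperparameters and RoPE table below\n# are arbitrary assumptions, not taken from any project config):\n#\n#   attn = ModernAttention(d_model=64, n_heads=4, n_kv_heads=2)\n#   x = torch.randn(2, 8, 64)  # [B, T, d_model]\n#   head_dim = 64 // 4\n#   pos = torch.arange(8, dtype=torch.float32)\n#   inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2) / head_dim))\n#   ang = pos[:, None] * inv_freq[None, :]  # [T, head_dim // 2]\n#   cos = torch.cat([ang.cos(), ang.cos()], dim=-1).expand(2, 8, head_dim)\n#   sin = torch.cat([ang.sin(), ang.sin()], dim=-1).expand(2, 8, head_dim)\n#   out = attn(x, cos, sin)  # -> [2, 8, 64]; causal because mask is None\n"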
  },
  {
    "path": "holomotion/src/motion_retargeting/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/gmr_to_holomotion.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport json, os, sys\nfrom pathlib import Path\nfrom typing import Dict, Tuple, Optional, List\n\nimport joblib\nimport numpy as np\nimport torch\nimport hydra\nfrom omegaconf import OmegaConf, DictConfig, ListConfig\nfrom tqdm import tqdm\n\nfrom loguru import logger\nfrom holomotion.src.utils import torch_utils\nfrom holomotion.src.motion_retargeting.utils.torch_humanoid_batch import (\n    HumanoidBatch,\n)\nfrom holomotion.src.motion_retargeting.utils import (\n    rotation_conversions as rot_conv,\n)\nfrom holomotion.src.motion_retargeting.holomotion_preprocess import (\n    HoloMotionPreprocessor,\n    ProcessedClip,\n)\nimport ray\nimport logging\n\n\ndef quaternion_to_axis_angle(q: torch.Tensor) -> torch.Tensor:\n    q = q / torch.norm(q, dim=-1, keepdim=True)\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n    angle = 2 * torch.acos(torch.clamp(w, -1.0, 1.0))\n    s = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))\n    s = torch.clamp(s, min=1e-6)\n    ax = x / s\n    ay = y / s\n    az = z / s\n    axis_angles = torch.stack([ax * angle, ay * angle, az * angle], dim=-1)\n    return axis_angles\n\n\ndef dof_to_pose_aa(\n    dof_pos: np.ndarray,\n    robot_config_path: Optional[str],\n    root_rot: Optional[np.ndarray],\n) -> np.ndarray:\n    \"\"\"Compute pose_aa via FK; if no config is provided, return zero placeholders.\"\"\"\n    if not robot_config_path:\n        T = dof_pos.shape[0]\n        return np.zeros((T, 27, 3), dtype=np.float32)\n\n    robot_cfg = OmegaConf.load(robot_config_path)\n    logger.info(f\"Loaded robot config for FK from: {robot_config_path}\")\n    fk = HumanoidBatch(robot_cfg.robot)\n    num_aug = len(robot_cfg.robot.extend_config)\n\n    dof_t = torch.as_tensor(dof_pos, dtype=torch.float32)\n    if dof_t.dim() == 3 and dof_t.shape[-1] == 1:\n        dof_t = dof_t.squeeze(-1)\n    T = dof_t.shape[0]\n\n    if root_rot is None:\n        root_aa = torch.zeros((T, 3), dtype=torch.float32)\n    else:\n        rr = torch.as_tensor(root_rot, dtype=torch.float32)\n        root_aa = quaternion_to_axis_angle(rr) if rr.shape[-1] == 4 else rr\n\n    joint_aa = fk.dof_axis * dof_t.unsqueeze(-1)\n    pose_aa = torch.cat(\n        [root_aa.unsqueeze(1), joint_aa, torch.zeros((T, num_aug, 3))], dim=1\n    )\n    return pose_aa.numpy().astype(np.float32, copy=False)\n\n\ndef load_any_pkl(p: Path):\n    with open(p, \"rb\") as f:\n        return joblib.load(f)\n\n\ndef unwrap_source(obj) -> Dict[str, np.ndarray]:\n    \"\"\"Accept {top_key: inner} or flat dict (early GMR).\"\"\"\n    if isinstance(obj, dict) and len(obj) == 1:\n        inner = next(iter(obj.values()))\n        if isinstance(inner, dict):\n            return inner\n    if isinstance(obj, dict):\n        return obj\n    raise ValueError(\"Unsupported PKL structure\")\n\n\ndef make_motion_key(p: Path, src_dir: Path) -> str:\n    rel = 
p.relative_to(src_dir).with_suffix(\"\")\n    return \"/\".join(rel.parts)\n\n\ndef key_to_filename(key: str) -> str:\n    return key.replace(\"/\", \"_\") + \".npz\"\n\n\ndef get_ref_schema(\n    ref_dir: Path,\n) -> Tuple[Dict[str, Tuple[Tuple[int, ...], np.dtype]], str]:\n    \"\"\"\n    Read schema only from ref_dir/_schema.json.\n    Expected JSON structure:\n    {\n      \"schema\": {\n        \"root_trans_offset\": {\"shape\": [T, 3], \"dtype\": \"float64\"},\n        \"pose_aa\": {\"shape\": [T, 27, 3], \"dtype\": \"float32\"},\n        ...\n      },\n      \"sample_top_key\": \"xxx\"\n    }\n    \"\"\"\n    ref_dir = Path(ref_dir)\n    cache_path = ref_dir / \"_schema.json\"\n    if not cache_path.exists():\n        raise FileNotFoundError(f\"Schema JSON not found: {cache_path}\")\n    try:\n        with open(cache_path, \"r\", encoding=\"utf-8\") as f:\n            obj = json.load(f)\n    except Exception as e:\n        raise ValueError(f\"Failed to parse _schema.json: {e}\")\n\n    schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]] = {}\n    raw = obj.get(\"schema\", {})\n    if not isinstance(raw, dict) or not raw:\n        raise ValueError(\"Schema JSON missing 'schema' object or it's empty.\")\n\n    for k, v in raw.items():\n        if not isinstance(v, dict) or \"shape\" not in v or \"dtype\" not in v:\n            raise ValueError(f\"Bad schema entry for key '{k}': {v}\")\n        shape = tuple(int(x) for x in v[\"shape\"])\n        dtype = np.dtype(v[\"dtype\"])\n        schema[k] = (shape, dtype)\n\n    sample_top_key = obj.get(\"sample_top_key\", \"\")\n    return schema, sample_top_key\n\n\ndef infer_T(src_inner: Dict[str, np.ndarray]) -> Optional[int]:\n    for key in [\n        \"root_trans_offset\",\n        \"root_pos\",\n        \"pose_aa\",\n        \"dof\",\n        \"dof_pos\",\n        \"root_rot\",\n        \"smpl_joints\",\n    ]:\n        v = src_inner.get(key)\n        if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] > 0:\n            return int(v.shape[0])\n    T = 0\n    for v in src_inner.values():\n        if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] > T:\n            T = int(v.shape[0])\n    return T or None\n\n\ndef build_inner_from_source(\n    src_inner: Dict[str, np.ndarray],\n    schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]],\n    T_default: int,\n) -> Dict[str, object]:\n    alt_map = {\n        \"root_trans_offset\": [\n            \"root_trans_offset\",\n            \"root_pos\",\n            \"trans\",\n            \"root_trans\",\n        ],\n        \"pose_aa\": [\"pose_aa\"],\n        \"dof\": [\"dof\", \"dof_pos\"],\n        \"root_rot\": [\"root_rot\", \"root_orient\", \"root_quat\"],\n        \"smpl_joints\": [\"smpl_joints\", \"joints\", \"smpljoints\"],\n        \"fps\": [\"fps\", \"mocap_framerate\", \"mocap_frame_rate\"],\n    }\n    out: Dict[str, object] = {}\n    T = infer_T(src_inner) or T_default\n\n    for key, (shape, dtype) in schema.items():\n        if key == \"fps\":\n            fps = None\n            for cand in alt_map[\"fps\"]:\n                v = src_inner.get(cand)\n                if isinstance(v, (int, np.integer)):\n                    fps = int(v)\n                    break\n            out[\"fps\"] = int(fps) if fps is not None else 30\n            continue\n\n        src_arr = None\n        for cand in alt_map.get(key, []):\n            v = src_inner.get(cand)\n            if isinstance(v, np.ndarray) and v.ndim >= 1:\n                src_arr = v\n                break\n\n    
    # Target shape: override leading T; keep source column count for DOF\n        if key == \"dof\" and isinstance(src_arr, np.ndarray):\n            target_shape = (T, src_arr.shape[1] if src_arr.ndim >= 2 else 1)\n        else:\n            ts = list(shape)\n            if ts:\n                ts[0] = T\n            target_shape = tuple(ts)\n\n        if src_arr is None:\n            out[key] = np.zeros(target_shape, dtype=dtype)\n            continue\n\n        arr = src_arr.astype(dtype, copy=False)\n\n        if key == \"dof\" and arr.ndim == 1:\n            arr = arr.reshape(-1, 1)\n        if arr.shape[0] > T:\n            arr = arr[:T]\n        elif arr.shape[0] < T:\n            pad = np.repeat(arr[-1:], T - arr.shape[0], axis=0)\n            arr = np.concatenate([arr, pad], axis=0)\n\n        if (\n            key != \"dof\"\n            and len(target_shape) == 2\n            and arr.shape[1] != target_shape[1]\n        ):\n            if arr.shape[1] > target_shape[1]:\n                arr = arr[:, : target_shape[1]]\n            else:\n                pad = np.zeros(\n                    (T, target_shape[1] - arr.shape[1]), dtype=arr.dtype\n                )\n                arr = np.concatenate([arr, pad], axis=1)\n\n        if len(target_shape) == 3:\n            d1 = min(arr.shape[1], target_shape[1])\n            d2 = min(arr.shape[2], target_shape[2])\n            arr = arr[:, :d1, :d2]\n            # Pad each trailing dim separately; padding both at once with a\n            # single zeros block would give mismatched concatenation shapes.\n            if arr.shape[1] < target_shape[1]:\n                pad = np.zeros(\n                    (T, target_shape[1] - arr.shape[1], arr.shape[2]),\n                    dtype=arr.dtype,\n                )\n                arr = np.concatenate([arr, pad], axis=1)\n            if arr.shape[2] < target_shape[2]:\n                pad2 = np.zeros(\n                    (T, target_shape[1], target_shape[2] - arr.shape[2]),\n                    dtype=arr.dtype,\n                )\n                arr = np.concatenate([arr, pad2], axis=2)\n\n        out[key] = arr.astype(dtype, copy=False)\n    return out\n\n\ndef to_torch(tensor):\n    if torch.is_tensor(tensor):\n        return tensor\n    else:\n        return torch.from_numpy(tensor.copy())\n\n\ndef batch_interpolate_tensor(\n    tensor, orig_times, target_times, use_slerp=False\n):\n    \"\"\"Optimized tensor interpolation with batch processing.\"\"\"\n    target_num_frames = len(target_times)\n    shape = list(tensor.shape)\n    shape[0] = target_num_frames\n\n    # Create empty output tensor\n    result = torch.zeros(shape, device=tensor.device, dtype=tensor.dtype)\n\n    if len(tensor.shape) == 2:\n        # For 2D tensors - process all frames at once\n        # Create masks for the three cases\n        before_mask = target_times <= orig_times[0]\n        after_mask = target_times >= orig_times[-1]\n        valid_mask = ~(before_mask | after_mask)\n\n        # Handle edge cases\n        if before_mask.any():\n            result[before_mask] = tensor[0]\n        if after_mask.any():\n            result[after_mask] = tensor[-1]\n\n        # Process interpolation for valid times\n        if valid_mask.any():\n            valid_times = target_times[valid_mask]\n            # Get indices for lower frames\n            indices = torch.searchsorted(orig_times, valid_times) - 1\n            # Ensure indices are valid\n    
        indices = torch.clamp(indices, 0, len(orig_times) - 2)\n            next_indices = indices + 1\n\n            # Calculate weights\n            alphas = (valid_times - orig_times[indices]) / (\n                orig_times[next_indices] - orig_times[indices]\n            )\n            alphas = alphas.unsqueeze(-1)  # Add dimension for broadcasting\n\n            if use_slerp and tensor.shape[1] == 4:  # Quaternion data\n                # Process in smaller batches to avoid memory issues\n                batch_size = 1000  # Adjust based on available memory\n                num_valid = valid_mask.sum()\n\n                for i in range(0, num_valid, batch_size):\n                    end_idx = min(i + batch_size, num_valid)\n                    batch_indices = torch.where(valid_mask)[0][i:end_idx]\n                    batch_alphas = alphas[i:end_idx]\n                    batch_lower_indices = indices[i:end_idx]\n                    batch_upper_indices = next_indices[i:end_idx]\n\n                    # Get frame data for this batch\n                    frames_low = tensor[batch_lower_indices]\n                    frames_high = tensor[batch_upper_indices]\n\n                    # Apply SLERP to this batch\n                    result[batch_indices] = torch_utils.slerp(\n                        frames_low, frames_high, batch_alphas\n                    )\n            else:\n                # Standard linear interpolation - can be done in one batch\n                frames_low = tensor[indices]\n                frames_high = tensor[next_indices]\n                result[valid_mask] = (\n                    frames_low * (1 - alphas) + frames_high * alphas\n                )\n\n    elif len(tensor.shape) == 3:\n        # For 3D tensors - process each joint sequence\n        for j in range(tensor.shape[1]):\n            result[:, j] = batch_interpolate_tensor(\n                tensor[:, j], orig_times, target_times, use_slerp\n            )\n\n    return result\n\n\ndef fast_interpolate_motion(motion_dict, source_fps, target_fps):\n    \"\"\"Optimized motion interpolation that preserves correctness\"\"\"\n    # Early return if no interpolation needed\n    if source_fps == target_fps:\n        return motion_dict\n\n    # Calculate timestamps\n    orig_dt = 1.0 / source_fps\n    target_dt = 1.0 / target_fps\n\n    # Find the first tensor to determine number of frames\n    for v in motion_dict.values():\n        if torch.is_tensor(v):\n            num_frames = v.shape[0]\n            device = v.device\n            break\n    else:\n        return motion_dict  # No tensor data to interpolate\n\n    orig_times = torch.arange(0, num_frames, device=device) * orig_dt\n    wallclock_len = orig_dt * (num_frames - 1)\n    target_num_frames = int(wallclock_len * target_fps) + 1\n    target_times = (\n        torch.arange(0, target_num_frames, device=device) * target_dt\n    )\n\n    # Create interpolated motion dictionary\n    interp_motion = {}\n\n    for k, v in motion_dict.items():\n        if not torch.is_tensor(v):\n            interp_motion[k] = v\n            continue\n\n        is_quat = \"quat\" in k\n        interp_motion[k] = batch_interpolate_tensor(\n            v, orig_times, target_times, is_quat\n        )\n\n    return interp_motion\n\n\ndef process_single_motion(\n    robot_cfg: dict,\n    all_samples,  # Can be dict or LazyMotionLoader\n    curr_key: str,\n    target_fps: int = 50,\n    fast_interpolate: bool = True,\n    debug_mode: bool = False,\n):\n    logger.debug(f\"Starting 
process_single_motion for key: {curr_key}\")\n\n    humanoid_fk = HumanoidBatch(robot_cfg)\n\n    motion_sample_dict = all_samples[curr_key]\n\n    if len(motion_sample_dict) == 1:\n        motion_sample_dict = motion_sample_dict[\n            list(motion_sample_dict.keys())[0]\n        ]\n\n    logger.debug(\"Step 3: Extracting sequence length\")\n    if debug_mode:\n        # In debug mode, let exceptions bubble up naturally\n        if \"root_trans_offset\" not in motion_sample_dict:\n            available_keys = list(motion_sample_dict.keys())\n            raise KeyError(\n                f\"'root_trans_offset' not found in motion data. Available keys: {available_keys}\"\n            )\n        seq_len = motion_sample_dict[\"root_trans_offset\"].shape[0]\n        start, end = 0, seq_len\n        logger.debug(f\"Step 3 completed - seq_len: {seq_len}\")\n    else:\n        try:\n            if \"root_trans_offset\" not in motion_sample_dict:\n                available_keys = list(motion_sample_dict.keys())\n                raise KeyError(\n                    f\"'root_trans_offset' not found in motion data. Available keys: {available_keys}\"\n                )\n            seq_len = motion_sample_dict[\"root_trans_offset\"].shape[0]\n            start, end = 0, seq_len\n            logger.debug(f\"Step 3 completed - seq_len: {seq_len}\")\n        except Exception as e:\n            logger.error(\n                f\"Step 3 failed - Extracting sequence length: {e}\",\n                exc_info=True,\n            )\n            raise RuntimeError(\n                f\"Failed to extract sequence length: {e}\"\n            ) from e\n\n    logger.debug(\"Step 4: Processing root translation\")\n    if debug_mode:\n        # In debug mode, let exceptions bubble up naturally\n        trans = to_torch(motion_sample_dict[\"root_trans_offset\"]).clone()[\n            start:end\n        ]\n        logger.debug(f\"Step 4 completed - trans shape: {trans.shape}\")\n    else:\n        try:\n            trans = to_torch(motion_sample_dict[\"root_trans_offset\"]).clone()[\n                start:end\n            ]\n            logger.debug(f\"Step 4 completed - trans shape: {trans.shape}\")\n        except Exception as e:\n            logger.error(\n                f\"Step 4 failed - Processing root translation: {e}\",\n                exc_info=True,\n            )\n            raise RuntimeError(\n                f\"Failed to process root translation: {e}\"\n            ) from e\n\n    logger.debug(\"Step 5: Processing pose_aa\")\n    if debug_mode:\n        # In debug mode, let exceptions bubble up naturally\n        if \"pose_aa\" not in motion_sample_dict:\n            available_keys = list(motion_sample_dict.keys())\n            raise KeyError(\n                f\"'pose_aa' not found in motion data. 
Available keys: {available_keys}\"\n            )\n        pose_aa = to_torch(motion_sample_dict[\"pose_aa\"][start:end]).clone()\n        # If available, enforce root rotation from input quaternions (XYZW)\n        if \"root_rot\" in motion_sample_dict:\n            root_quat_xyzw = to_torch(\n                motion_sample_dict[\"root_rot\"][start:end]\n            ).clone()\n            root_quat_wxyz = rot_conv.xyzw_to_wxyz(root_quat_xyzw)\n            root_axis_angle = rot_conv.quaternion_to_axis_angle(root_quat_wxyz)\n            pose_aa[:, 0, :] = root_axis_angle\n        logger.debug(f\"Step 5 completed - pose_aa shape: {pose_aa.shape}\")\n    else:\n        try:\n            if \"pose_aa\" not in motion_sample_dict:\n                available_keys = list(motion_sample_dict.keys())\n                raise KeyError(\n                    f\"'pose_aa' not found in motion data. Available keys: {available_keys}\"\n                )\n            pose_aa = to_torch(\n                motion_sample_dict[\"pose_aa\"][start:end]\n            ).clone()\n            # If available, enforce root rotation from input quaternions (XYZW)\n            if \"root_rot\" in motion_sample_dict:\n                root_quat_xyzw = to_torch(\n                    motion_sample_dict[\"root_rot\"][start:end]\n                ).clone()\n                root_quat_wxyz = rot_conv.xyzw_to_wxyz(root_quat_xyzw)\n                root_axis_angle = rot_conv.quaternion_to_axis_angle(\n                    root_quat_wxyz\n                )\n                pose_aa[:, 0, :] = root_axis_angle\n            logger.debug(f\"Step 5 completed - pose_aa shape: {pose_aa.shape}\")\n        except Exception as e:\n            logger.error(\n                f\"Step 5 failed - Processing pose_aa: {e}\", exc_info=True\n            )\n            raise RuntimeError(f\"Failed to process pose_aa: {e}\") from e\n\n    logger.debug(\"Step 6: Calculating dt\")\n    if debug_mode:\n        # In debug mode, let exceptions bubble up naturally\n        if \"fps\" not in motion_sample_dict:\n            available_keys = list(motion_sample_dict.keys())\n            raise KeyError(\n                f\"'fps' not found in motion data. Available keys: {available_keys}\"\n            )\n        fps = motion_sample_dict[\"fps\"]\n        if fps <= 0:\n            raise ValueError(f\"Invalid fps value: {fps}\")\n        dt = 1 / fps\n        logger.debug(f\"Step 6 completed - fps: {fps}, dt: {dt}\")\n    else:\n        try:\n            if \"fps\" not in motion_sample_dict:\n                available_keys = list(motion_sample_dict.keys())\n                raise KeyError(\n                    f\"'fps' not found in motion data. 
Available keys: {available_keys}\"\n                )\n            fps = motion_sample_dict[\"fps\"]\n            if fps <= 0:\n                raise ValueError(f\"Invalid fps value: {fps}\")\n            dt = 1 / fps\n            logger.debug(f\"Step 6 completed - fps: {fps}, dt: {dt}\")\n        except Exception as e:\n            logger.error(f\"Step 6 failed - Calculating dt: {e}\", exc_info=True)\n            raise RuntimeError(f\"Failed to calculate dt: {e}\") from e\n\n    logger.debug(\"Step 8: Running forward kinematics\")\n    if debug_mode:\n        # In debug mode, let exceptions bubble up naturally\n        curr_motion = humanoid_fk.fk_batch(\n            pose_aa[None,],\n            trans[None,],\n            return_full=True,\n            dt=dt,\n        )\n        logger.debug(\"Step 8 completed\")\n    else:\n        try:\n            curr_motion = humanoid_fk.fk_batch(\n                pose_aa[None,],\n                trans[None,],\n                return_full=True,\n                dt=dt,\n            )\n            logger.debug(\"Step 8 completed\")\n        except Exception as e:\n            logger.error(\n                f\"Step 8 failed - Forward kinematics: {e}\", exc_info=True\n            )\n            raise RuntimeError(f\"Failed to run forward kinematics: {e}\") from e\n    curr_motion = dict(\n        {\n            k: v.squeeze() if torch.is_tensor(v) else v\n            for k, v in curr_motion.items()\n        }\n    )\n    motion_fps = curr_motion[\"fps\"]\n    motion_dt = 1.0 / motion_fps\n    num_frames = curr_motion[\"global_rotation\"].shape[0]\n    wallclock_len = motion_dt * (num_frames - 1)\n    num_dofs = len(robot_cfg.motion.dof_names)\n    num_bodies = len(robot_cfg.motion.body_names)\n    num_extended_bodies = num_bodies + len(\n        robot_cfg.motion.get(\"extend_config\", [])\n    )\n\n    # build a frame_flag array to indicate three status:\n    # start_of_motion: 0, middle_of_motion: 1, end_of_motion: 2\n    frame_flag = torch.ones(num_frames).int()\n    frame_flag[0] = 0\n    frame_flag[-1] = 2\n    curr_motion[\"frame_flag\"] = frame_flag\n\n    # rename and pop some keys\n    curr_motion[\"global_rotation_quat\"] = curr_motion.pop(\"global_rotation\")\n    curr_motion[\"local_rotation_quat\"] = curr_motion.pop(\"local_rotation\")\n    if \"global_translation_extend\" in curr_motion:\n        curr_motion[\"global_rotation_quat_extend\"] = curr_motion.pop(\n            \"global_rotation_extend\"\n        )\n    curr_motion.pop(\"fps\")\n    curr_motion.pop(\"global_rotation_mat\")\n    if \"global_rotation_mat_extend\" in curr_motion:\n        curr_motion.pop(\"global_rotation_mat_extend\")\n\n    # add some keys\n    curr_motion[\"global_root_translation\"] = curr_motion[\"global_translation\"][\n        :, 0\n    ]\n    curr_motion[\"global_root_rotation_quat\"] = curr_motion[\n        \"global_rotation_quat\"\n    ][:, 0]\n\n    # Interpolate to target_fps if different from original fps\n    if target_fps != motion_fps:\n        curr_motion = fast_interpolate_motion(\n            curr_motion, motion_fps, target_fps\n        )\n        motion_fps = target_fps\n        motion_dt = 1.0 / target_fps\n        num_frames = (\n            next(iter(curr_motion.values())).shape[0]\n            if curr_motion\n            else num_frames\n        )\n        wallclock_len = motion_dt * (num_frames - 1)\n\n    sample_dict = {\n        \"motion_name\": curr_key,\n        \"motion_fps\": motion_fps,\n        \"num_frames\": num_frames,\n        
\"wallclock_len\": wallclock_len,\n        \"num_dofs\": num_dofs,\n        \"num_bodies\": num_bodies,\n        \"num_extended_bodies\": num_extended_bodies,\n    }\n    sample_dict.update(\n        {\n            k: curr_motion[k].float().cpu().numpy()\n            for k in sorted(curr_motion.keys())\n        }\n    )\n\n    if debug_mode:\n        for k, v in sample_dict.items():\n            if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray):\n                logger.debug(f\"{k}: {v.shape}\")\n            else:\n                logger.debug(f\"{k}: {v}\")\n\n    return sample_dict\n\n\nclass InMemoryAlignedLoader:\n    \"\"\"Minimal Loader interface: compatible with process_single_motion sample access.\"\"\"\n\n    def __init__(self, mapping: Dict[str, Dict[str, object]]):\n        self._map = mapping\n\n    def keys(self) -> List[str]:\n        return list(self._map.keys())\n\n    def __len__(self):\n        return len(self._map)\n\n    def __getitem__(self, k: str):\n        return self._map[k]\n\n    def load(self, k: str):\n        return self._map[k]\n\n    def get(self, k: str, default=None):\n        return self._map.get(k, default)\n\n\ndef arrays_for_npz(\n    sample: Dict, emit_prefixed: bool = True, emit_legacy: bool = False\n) -> Dict[str, np.ndarray]:\n    \"\"\"\n    Build NPZ arrays:\n    - Always include frame_flag if present\n    - If emit_prefixed: write ref_* arrays mapped from base keys\n    - If emit_legacy: also include legacy, unprefixed keys for compatibility\n    \"\"\"\n    base_to_ref = {\n        \"dof_pos\": \"ref_dof_pos\",\n        \"dof_vel\": \"ref_dof_vel\",\n        \"dof_vels\": \"ref_dof_vel\",\n        \"global_translation\": \"ref_global_translation\",\n        \"global_rotation_quat\": \"ref_global_rotation_quat\",\n        \"global_velocity\": \"ref_global_velocity\",\n        \"global_angular_velocity\": \"ref_global_angular_velocity\",\n    }\n    out: Dict[str, np.ndarray] = {}\n    if isinstance(sample.get(\"frame_flag\"), np.ndarray):\n        out[\"frame_flag\"] = sample[\"frame_flag\"]\n    for base, ref_name in base_to_ref.items():\n        v = sample.get(base, None)\n        if isinstance(v, np.ndarray):\n            if emit_prefixed:\n                out[ref_name] = v\n            if emit_legacy:\n                out[base] = v\n    return out\n\n\n@ray.remote\nclass MotionProcessorActor:\n    \"\"\"\n    Persistent Ray actor that loads robot config once and processes PKLs asynchronously.\n    \"\"\"\n\n    def __init__(\n        self,\n        robot_cfg_path: str,\n        schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]],\n    ):\n        cfg = OmegaConf.load(robot_cfg_path)\n        self.robot_cfg = cfg.robot\n        self.schema = schema\n        # Cached FK holder for DOF → axis-angle conversion (uses dof_axis)\n        self._fk_for_dof = HumanoidBatch(self.robot_cfg)\n\n    def _dof_to_pose_aa_cached(\n        self, dof_pos: np.ndarray, root_rot: Optional[np.ndarray]\n    ) -> np.ndarray:\n        dof_t = torch.as_tensor(dof_pos, dtype=torch.float32)\n        if dof_t.dim() == 3 and dof_t.shape[-1] == 1:\n            dof_t = dof_t.squeeze(-1)\n        T = int(dof_t.shape[0])\n\n        if root_rot is None:\n            root_aa = torch.zeros((T, 3), dtype=torch.float32)\n        else:\n            rr = torch.as_tensor(root_rot, dtype=torch.float32)\n            root_aa = quaternion_to_axis_angle(rr) if rr.shape[-1] == 4 else rr\n\n        num_aug = len(self.robot_cfg.extend_config)\n        joint_aa = 
self._fk_for_dof.dof_axis * dof_t[:, :, None]\n        pose_aa = torch.cat(\n            [root_aa[:, None, :], joint_aa, torch.zeros((T, num_aug, 3))],\n            dim=1,\n        )\n        return pose_aa.numpy().astype(np.float32, copy=False)\n\n    def process_pkl(\n        self,\n        p_str: str,\n        src_dir_str: str,\n        target_fps: int,\n        fast_interpolate: bool,\n        debug_mode: bool,\n    ) -> Tuple[bool, Dict[str, object]]:\n        \"\"\"\n        Returns (success, payload). On success, payload contains:\n          { \"flat_key\": str, \"sample\": Dict[str, np.ndarray|scalar] }\n        \"\"\"\n        p = Path(p_str)\n        src_dir = Path(src_dir_str)\n        motion_key_rel = make_motion_key(p, src_dir)\n        flat_key = motion_key_rel.replace(\"/\", \"_\")\n\n        obj = load_any_pkl(p)\n        inner = unwrap_source(obj)\n        T_default = infer_T(inner) or 1\n        aligned = build_inner_from_source(inner, self.schema, T_default)\n\n        dof = aligned.get(\"dof\")\n        if isinstance(dof, np.ndarray) and dof.size > 0:\n            root_rot = aligned.get(\"root_rot\", None)\n            aligned[\"pose_aa\"] = self._dof_to_pose_aa_cached(dof, root_rot)\n\n        loader = InMemoryAlignedLoader({flat_key: aligned})\n        sample = process_single_motion(\n            self.robot_cfg,\n            loader,\n            flat_key,\n            int(target_fps),\n            bool(fast_interpolate),\n            bool(debug_mode),\n        )\n        payload: Dict[str, object] = {\"flat_key\": flat_key, \"sample\": sample}\n        return True, payload\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"motion_retargeting/gmr_to_holomotion\",\n    version_base=None,\n)\ndef main(cfg: DictConfig) -> None:\n    # Setup logging\n    logger.remove()\n    log_level = \"DEBUG\" if bool(cfg.processing.debug_mode) else \"INFO\"\n    logger.add(sys.stderr, level=log_level, colorize=True)\n\n    src_path = Path(str(cfg.io.src_dir)).expanduser().resolve()\n    ref_dir = Path(str(cfg.io.ref_dir)).expanduser().resolve()\n    out_root = Path(str(cfg.io.out_root)).expanduser().resolve()\n    clips_dir = out_root / \"clips\"\n    clips_dir.mkdir(parents=True, exist_ok=True)\n\n    # dump resolved config used\n    (out_root).mkdir(parents=True, exist_ok=True)\n    with open(out_root / \"config_used.yaml\", \"w\") as f:\n        f.write(OmegaConf.to_yaml(cfg))\n\n    # 1) schema from _schema.json\n    schema, _ = get_ref_schema(ref_dir)\n\n    # 2) gather PKLs\n    if src_path.is_file() and src_path.suffix == \".pkl\":\n        src_pkls = [src_path]\n        root_for_keys = src_path.parent\n    else:\n        src_pkls = []\n        for dirpath, _, filenames in os.walk(src_path, followlinks=True):\n            for filename in filenames:\n                if filename.endswith(\".pkl\"):\n                    p = Path(dirpath) / filename\n                    if p.is_file():\n                        src_pkls.append(p)\n        src_pkls = sorted(src_pkls)\n        root_for_keys = src_path\n\n    # 3) quiet third-party DEBUG logs (e.g., filelock/Ray)\n    logging.getLogger(\"filelock\").setLevel(logging.WARNING)\n    logging.getLogger(\"ray\").setLevel(logging.ERROR)\n    os.environ.setdefault(\"RAY_BACKEND_LOG_LEVEL\", \"error\")\n\n    # 4) initialize Ray (an unset/None address falls through to local mode;\n    #    wrapping it in str() would make None truthy as \"None\")\n    if cfg.ray.ray_address:\n        ray.init(\n            address=str(cfg.ray.ray_address),\n            ignore_reinit_error=True,\n            log_to_driver=False,\n            
include_dashboard=False,\n            logging_level=logging.ERROR,\n        )\n    else:\n        num_cpus = (\n            None if int(cfg.ray.num_workers) <= 0 else int(cfg.ray.num_workers)\n        )\n        ray.init(\n            num_cpus=num_cpus,\n            ignore_reinit_error=True,\n            log_to_driver=False,\n            include_dashboard=False,\n            logging_level=logging.ERROR,\n        )\n\n    # 5) build work list (skip existing if requested)\n    skip_existing = bool(cfg.processing.skip_existing)\n    work_list: List[Path] = []\n    for p in src_pkls:\n        motion_key = make_motion_key(p, root_for_keys)\n        out_name = key_to_filename(motion_key)\n        if skip_existing and (clips_dir / out_name).exists():\n            continue\n        work_list.append(p)\n\n    if not work_list:\n        logger.info(\"No tasks to run (all outputs exist or no PKLs found).\")\n        ray.shutdown()\n        return\n\n    # 6) create persistent actors (each loads robot config once)\n    if int(cfg.ray.num_workers) > 0:\n        num_actors = min(len(work_list), int(cfg.ray.num_workers))\n    else:\n        available_cpus = int(ray.available_resources().get(\"CPU\", 1))\n        num_actors = min(len(work_list), max(1, available_cpus))\n    actors = [\n        MotionProcessorActor.remote(str(cfg.io.robot_config), schema)\n        for _ in range(num_actors)\n    ]\n\n    # Parse pipeline config\n    pipeline_cfg = cfg.get(\"preprocess\", None)\n    pipeline = None\n    if pipeline_cfg is not None:\n        pipeline_val = pipeline_cfg.get(\"pipeline\", None)\n        if pipeline_val is not None:\n            if isinstance(pipeline_val, (list, tuple, ListConfig)):\n                pipeline = [str(s) for s in pipeline_val]\n            elif isinstance(pipeline_val, str):\n                import ast\n\n                pipeline = ast.literal_eval(pipeline_val)\n            else:\n                logger.warning(\n                    f\"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}\"\n                )\n                pipeline = []\n        else:\n            pipeline = []\n    else:\n        pipeline = []\n\n    # Separate per-clip stages from dataset-level stages\n    per_clip_pipeline = (\n        [s for s in pipeline if s != \"tagging\"] if pipeline else []\n    )\n    tagging_enabled = pipeline and \"tagging\" in pipeline\n\n    logger.info(\"=\" * 80)\n    logger.info(\"Preprocessing Configuration:\")\n    if pipeline:\n        logger.info(f\"  Pipeline stages: {pipeline}\")\n        logger.info(f\"  Number of stages: {len(pipeline)}\")\n        for i, stage in enumerate(pipeline, 1):\n            logger.info(f\"    {i}. 
{stage}\")\n        if tagging_enabled:\n            logger.info(\n                \"  Note: 'tagging' is a dataset-level operation and will run after all clips are processed\"\n            )\n    else:\n        logger.info(\n            \"  No preprocessing pipeline specified - no processors will be applied\"\n        )\n    logger.info(\"=\" * 80)\n\n    preprocessor = HoloMotionPreprocessor(\n        slicing_cfg=cfg.slicing,\n        filtering_cfg=cfg.filtering,\n        tagging_cfg=cfg.tagging,\n        padding_cfg=cfg.get(\"padding\", None),\n        pipeline=per_clip_pipeline if per_clip_pipeline else None,\n    )\n\n    # 7) asynchronously schedule PKLs to actors (round-robin)\n    pending = {}\n    next_idx = 0\n    # prime the queue\n    for i in range(min(num_actors, len(work_list))):\n        p = work_list[next_idx]\n        next_idx += 1\n        ref = actors[i].process_pkl.remote(\n            str(p),\n            str(root_for_keys),\n            int(cfg.processing.target_fps),\n            bool(cfg.processing.fast_interpolate),\n            bool(cfg.processing.debug_mode),\n        )\n        pending[ref] = i\n\n    # 8) collect results and keep feeding new tasks (post-process in-memory, then write)\n    total_outputs = 0\n    with tqdm(total=len(work_list), desc=\"Ray: PKL→NPZ (Hydra)\") as pbar:\n        while pending:\n            done, _ = ray.wait(list(pending.keys()), num_returns=1)\n            ref = done[0]\n            actor_idx = pending.pop(ref)\n            ok, payload = ray.get(ref)\n            if ok:\n                flat_key: str = payload[\"flat_key\"]  # type: ignore[assignment]\n                sample: Dict = payload[\"sample\"]  # type: ignore[assignment]\n                arrays_ref = arrays_for_npz(\n                    sample,\n                    emit_prefixed=bool(cfg.naming.emit_prefixed),\n                    emit_legacy=bool(cfg.naming.emit_legacy),\n                )\n                base_meta = {\n                    \"motion_key\": flat_key,\n                    \"raw_motion_key\": flat_key,\n                    \"motion_fps\": float(sample[\"motion_fps\"]),\n                    \"num_frames\": int(sample[\"num_frames\"]),\n                    \"wallclock_len\": float(sample[\"wallclock_len\"]),\n                    \"num_dofs\": int(sample[\"num_dofs\"]),\n                    \"num_bodies\": int(sample[\"num_bodies\"]),\n                    \"num_extended_bodies\": int(sample[\"num_extended_bodies\"]),\n                    \"slice_start\": 0,\n                    \"slice_end\": int(sample[\"num_frames\"]),\n                }\n                base_clip = ProcessedClip(\n                    motion_key=flat_key,\n                    metadata=base_meta,\n                    arrays=arrays_ref,\n                )\n                clips = preprocessor.process_clip(base_clip)\n                for clip in clips:\n                    out_name = f\"{clip.motion_key}.npz\"\n                    out_path = clips_dir / out_name\n                    np.savez_compressed(\n                        out_path,\n                        metadata=json.dumps(clip.metadata),\n                        **clip.arrays,\n                    )\n                    total_outputs += 1\n            else:\n                logger.warning(f\"Processing failed: {payload}\")\n            pbar.update(1)\n            if next_idx < len(work_list):\n                p = work_list[next_idx]\n                next_idx += 1\n                new_ref = actors[actor_idx].process_pkl.remote(\n       
             str(p),\n                    str(root_for_keys),\n                    int(cfg.processing.target_fps),\n                    bool(cfg.processing.fast_interpolate),\n                    bool(cfg.processing.debug_mode),\n                )\n                pending[new_ref] = actor_idx\n\n    # 9) Optional kinematic tagging (write to out_root level)\n    if tagging_enabled:\n        out_json = cfg.tagging.get(\"output_json_path\", None)\n        tags_path = (\n            Path(str(out_json)).expanduser().resolve()\n            if out_json\n            else (out_root / \"kinematic_tags.json\")\n        )\n        preprocessor.tag_directory(clips_dir, tags_path)\n\n    logger.info(\n        f\"Done. NPZ written to: {clips_dir} (total clips: {total_outputs})\"\n    )\n    ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/holomotion_fk.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom __future__ import annotations\n\nimport os\nimport xml.etree.ElementTree as ETree\nfrom typing import Dict, List, Tuple\n\nimport torch\nimport pytorch_kinematics as pk\n\nfrom loguru import logger\nfrom holomotion.src.utils import torch_utils\n\n\nclass MJCFParser:\n    def __init__(self, robot_file_path: str) -> None:\n        self._robot_file_path = robot_file_path\n\n    @staticmethod\n    def parse_vec(\n        text: str | None, size: int, default: List[float]\n    ) -> List[float]:\n        if text is None:\n            return list(default)\n        values = [float(v) for v in text.strip().split()]\n        if len(values) != size:\n            raise ValueError(\n                f\"Expected {size} values, got {len(values)} in '{text}'\"\n            )\n        return values\n\n    @staticmethod\n    def _find_parent(\n        root: ETree.Element, child: ETree.Element\n    ) -> ETree.Element | None:\n        for parent in root.iter():\n            for node in list(parent):\n                if node is child:\n                    return parent\n        return None\n\n    @staticmethod\n    def _select_include_children(\n        parent: ETree.Element, inc_root: ETree.Element\n    ) -> List[ETree.Element]:\n        if inc_root.tag == \"mujoco\":\n            if parent.tag != \"mujoco\":\n                sub = inc_root.find(parent.tag)\n                if sub is not None:\n                    return list(sub)\n            return list(inc_root)\n        if inc_root.tag == parent.tag:\n            return list(inc_root)\n        return list(inc_root)\n\n    @classmethod\n    def _resolve_includes(cls, root: ETree.Element, base_dir: str) -> None:\n        includes = root.findall(\".//include\")\n        while includes:\n            for inc in includes:\n                inc_file = inc.attrib.get(\"file\")\n                if inc_file is None:\n                    raise ValueError(\"Include tag missing 'file' attribute\")\n                inc_path = os.path.join(base_dir, inc_file)\n                inc_root = ETree.parse(inc_path).getroot()\n                cls._resolve_includes(inc_root, os.path.dirname(inc_path))\n                parent = cls._find_parent(root, inc)\n                if parent is None:\n                    raise ValueError(\"Failed to resolve include parent\")\n                insert_children = cls._select_include_children(\n                    parent, inc_root\n                )\n                insert_index = list(parent).index(inc)\n                for child in list(insert_children):\n                    parent.insert(insert_index, child)\n                    insert_index += 1\n                parent.remove(inc)\n            includes = root.findall(\".//include\")\n\n    def load_root(self) -> ETree.Element:\n        root = ETree.parse(self._robot_file_path).getroot()\n        self._resolve_includes(root, 
os.path.dirname(self._robot_file_path))\n        return root\n\n    def parse(\n        self,\n    ) -> Tuple[\n        List[str],\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        List[List[str]],\n        Dict[str, int],\n        Dict[str, List[float]],\n        List[str],\n        torch.Tensor,\n        List[List[int]],\n    ]:\n        root = self.load_root()\n        xml_world = root.find(\"worldbody\")\n        if xml_world is None:\n            raise ValueError(\"MJCF missing worldbody\")\n        xml_body_root = xml_world.find(\"body\")\n        if xml_body_root is None:\n            raise ValueError(\"MJCF missing root body\")\n\n        body_names: List[str] = []\n        parents: List[int] = []\n        local_translation: List[List[float]] = []\n        local_rotation: List[List[float]] = []\n        body_joint_order: List[List[str]] = []\n        joint_body_index: Dict[str, int] = {}\n        joint_axis: Dict[str, List[float]] = {}\n\n        def _add_body(xml_body: ETree.Element, parent_index: int) -> None:\n            body_idx = len(body_names)\n            body_names.append(xml_body.attrib.get(\"name\", \"\"))\n            parents.append(parent_index)\n            local_translation.append(\n                self.parse_vec(xml_body.attrib.get(\"pos\"), 3, [0.0, 0.0, 0.0])\n            )\n            local_rotation.append(\n                self.parse_vec(\n                    xml_body.attrib.get(\"quat\"), 4, [1.0, 0.0, 0.0, 0.0]\n                )\n            )\n            joints_in_body: List[str] = []\n            for joint in xml_body.findall(\"joint\"):\n                joint_name = joint.attrib.get(\"name\")\n                if joint_name is None:\n                    raise ValueError(\"Joint missing name\")\n                joint_type = joint.attrib.get(\"type\", \"hinge\")\n                if joint_type == \"free\":\n                    continue\n                if joint_type != \"hinge\":\n                    raise ValueError(f\"Unsupported joint type: {joint_type}\")\n                axis = self.parse_vec(\n                    joint.attrib.get(\"axis\"), 3, [0.0, 0.0, 1.0]\n                )\n                joint_body_index[joint_name] = body_idx\n                joint_axis[joint_name] = axis\n                joints_in_body.append(joint_name)\n            body_joint_order.append(joints_in_body)\n            for child in xml_body.findall(\"body\"):\n                _add_body(child, body_idx)\n\n        _add_body(xml_body_root, -1)\n        if local_translation:\n            local_translation[0] = [0.0, 0.0, 0.0]\n            local_rotation[0] = [1.0, 0.0, 0.0, 0.0]\n\n        dof_names: List[str] = []\n        for elem in root.iter():\n            if elem.tag == \"actuator\":\n                for child in list(elem):\n                    joint_name = child.attrib.get(\"joint\")\n                    if joint_name is not None:\n                        dof_names.append(joint_name)\n        if len(dof_names) == 0:\n            raise ValueError(\"No actuated joints found in MJCF\")\n        dof_axis: List[List[float]] = []\n        for joint_name in dof_names:\n            if joint_name not in joint_body_index:\n                raise ValueError(f\"Actuator joint not found: {joint_name}\")\n            dof_axis.append(joint_axis[joint_name])\n\n        dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)}\n        body_dof_indices: List[List[int]] = []\n        for joints in body_joint_order:\n            indices: List[int] = 
[]\n            for name in joints:\n                if name in dof_name_to_index:\n                    indices.append(dof_name_to_index[name])\n            body_dof_indices.append(indices)\n\n        return (\n            body_names,\n            torch.tensor(parents, dtype=torch.long),\n            torch.tensor(local_translation, dtype=torch.float32),\n            torch.tensor(local_rotation, dtype=torch.float32),\n            body_joint_order,\n            joint_body_index,\n            joint_axis,\n            dof_names,\n            torch.tensor(dof_axis, dtype=torch.float32),\n            body_dof_indices,\n        )\n\n\nclass URDFParser:\n    def __init__(self, urdf_path: str) -> None:\n        self._urdf_path = urdf_path\n\n    @staticmethod\n    def _as_tf(\n        tf: torch.Tensor | None, identity: torch.Tensor\n    ) -> torch.Tensor:\n        if tf is None:\n            return identity\n        if tf.ndim == 3:\n            return tf[0]\n        return tf\n\n    def _load_chain(self) -> pk.Chain:\n        with open(self._urdf_path, mode=\"r\", encoding=\"utf-8\") as f:\n            urdf_text = f.read()\n        return pk.build_chain_from_urdf(urdf_text)\n\n    def parse(\n        self,\n    ) -> Tuple[\n        List[str],\n        torch.Tensor,\n        torch.Tensor,\n        torch.Tensor,\n        List[List[str]],\n        Dict[str, int],\n        Dict[str, List[float]],\n        List[str],\n        torch.Tensor,\n        List[List[int]],\n    ]:\n        pk_chain = self._load_chain()\n        dof_names = pk_chain.get_joint_parameter_names()\n        if len(dof_names) == 0:\n            raise ValueError(\"No actuated joints found in URDF\")\n        dof_axis = pk_chain.axes.to(dtype=torch.float32)\n\n        root_name = pk_chain._root.name\n        moving_frames = pk_chain.get_frame_names(exclude_fixed=True)\n        body_names = [root_name] + [\n            name for name in moving_frames if name != root_name\n        ]\n        body_name_to_index = {name: idx for idx, name in enumerate(body_names)}\n\n        num_frames = len(pk_chain.idx_to_frame)\n        frame_name_to_index = {\n            name: idx for idx, name in pk_chain.idx_to_frame.items()\n        }\n\n        full_parent_indices: List[int] = []\n        for i in range(num_frames):\n            chain_indices = pk_chain.parents_indices[i]\n            if chain_indices.numel() <= 1:\n                full_parent_indices.append(-1)\n            else:\n                full_parent_indices.append(int(chain_indices[-2].item()))\n\n        identity = torch.eye(4, dtype=torch.float32)\n        frame_transforms: List[torch.Tensor] = [identity] * num_frames\n        for i in range(num_frames):\n            link_offset = self._as_tf(pk_chain.link_offsets[i], identity)\n            joint_offset = self._as_tf(pk_chain.joint_offsets[i], identity)\n            if i == 0:\n                link_offset = identity\n                joint_offset = identity\n            parent = full_parent_indices[i]\n            if parent < 0:\n                frame_tf = identity\n            else:\n                frame_tf = (\n                    frame_transforms[parent] @ link_offset @ joint_offset\n                )\n            frame_transforms[i] = frame_tf\n\n        parents: List[int] = []\n        local_translation: List[List[float]] = []\n        local_rotation_mat: List[torch.Tensor] = []\n        body_joint_order: List[List[str]] = []\n        joint_body_index: Dict[str, int] = {}\n        joint_axis: Dict[str, List[float]] = {}\n\n        
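# For each body, walk up the frame tree through fixed frames to the\n        # nearest ancestor that is itself a body; offsets of the skipped\n        # fixed frames are folded into the local transform computed below.\n        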
for body_name in body_names:\n            frame_idx = frame_name_to_index[body_name]\n            parent_frame_idx = full_parent_indices[frame_idx]\n            parent_body_idx = -1\n            while parent_frame_idx >= 0:\n                parent_name = pk_chain.idx_to_frame[parent_frame_idx]\n                if parent_name in body_name_to_index:\n                    parent_body_idx = body_name_to_index[parent_name]\n                    break\n                parent_frame_idx = full_parent_indices[parent_frame_idx]\n            parents.append(parent_body_idx)\n\n            if parent_body_idx < 0:\n                local_tf = identity\n            else:\n                local_tf = (\n                    torch.linalg.inv(frame_transforms[parent_frame_idx])\n                    @ frame_transforms[frame_idx]\n                )\n            local_translation.append(local_tf[:3, 3].tolist())\n            local_rotation_mat.append(local_tf[:3, :3])\n\n            joints_in_body: List[str] = []\n            joint_index = int(pk_chain.joint_indices[frame_idx].item())\n            if joint_index >= 0:\n                joint_type = int(pk_chain.joint_type_indices[frame_idx].item())\n                if joint_type != 1:\n                    raise ValueError(\n                        f\"Unsupported joint type index: {joint_type}\"\n                    )\n                joint_name = dof_names[joint_index]\n                joints_in_body.append(joint_name)\n                joint_body_index[joint_name] = body_name_to_index[body_name]\n                joint_axis[joint_name] = dof_axis[joint_index].tolist()\n            body_joint_order.append(joints_in_body)\n\n        local_rotation = torch_utils.quat_from_matrix(\n            torch.stack(local_rotation_mat, dim=0)\n        )\n\n        dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)}\n        body_dof_indices: List[List[int]] = []\n        for joints in body_joint_order:\n            indices: List[int] = []\n            for name in joints:\n                if name in dof_name_to_index:\n                    indices.append(dof_name_to_index[name])\n            body_dof_indices.append(indices)\n\n        return (\n            body_names,\n            torch.tensor(parents, dtype=torch.long),\n            torch.tensor(local_translation, dtype=torch.float32),\n            local_rotation.to(dtype=torch.float32),\n            body_joint_order,\n            joint_body_index,\n            joint_axis,\n            dof_names,\n            dof_axis,\n            body_dof_indices,\n        )\n\n\n# @torch.compile(dynamic=True)\nclass HoloMotionFK(torch.nn.Module):\n    def __init__(\n        self,\n        robot_file_path: str,\n        device: torch.device | str = \"cpu\",\n        dtype: torch.dtype = torch.float32,\n    ) -> None:\n        super().__init__()\n        self.robot_file_path = robot_file_path\n        _, ext = os.path.splitext(robot_file_path)\n        ext = ext.lower()\n        if ext == \".urdf\":\n            parser = URDFParser(robot_file_path)\n        elif ext in [\".xml\", \".mjcf\"]:\n            parser = MJCFParser(robot_file_path)\n        else:\n            raise ValueError(f\"Unsupported file extension: {ext}\")\n\n        logger.info(\n            f\"Parsing robot file for online forward kinematics: {robot_file_path}...\"\n        )\n\n        (\n            body_names,\n            parents,\n            local_translation,\n            local_rotation,\n            body_joint_order,\n            joint_body_index,\n     
       joint_axis,\n            dof_names,\n            dof_axis,\n            body_dof_indices,\n        ) = parser.parse()\n        self.body_names = body_names\n        self.dof_names = dof_names\n        self.num_bodies = len(body_names)\n        self.num_dof = len(dof_names)\n        parents = parents.to(device=device)\n        local_translation = local_translation.to(device=device, dtype=dtype)\n        local_rotation = local_rotation.to(device=device, dtype=dtype)\n        local_rotation_mat = torch_utils.matrix_from_quat(local_rotation)\n        dof_axis = dof_axis.to(device=device, dtype=dtype)\n        max_body_dofs = max(\n            (len(indices) for indices in body_dof_indices), default=0\n        )\n        body_dof_index_tensor = torch.full(\n            (self.num_bodies, max_body_dofs),\n            -1,\n            dtype=torch.long,\n        )\n        body_dof_mask = torch.zeros(\n            (self.num_bodies, max_body_dofs), dtype=torch.bool\n        )\n        for body_idx, indices in enumerate(body_dof_indices):\n            if not indices:\n                continue\n            body_dof_index_tensor[body_idx, : len(indices)] = torch.tensor(\n                indices, dtype=torch.long\n            )\n            body_dof_mask[body_idx, : len(indices)] = True\n        self.register_buffer(\"_parents\", parents)\n        self.register_buffer(\"_local_translation\", local_translation)\n        self.register_buffer(\"_local_rotation_mat\", local_rotation_mat)\n        self.register_buffer(\"_dof_axis\", dof_axis)\n        self.register_buffer(\"_body_dof_index_tensor\", body_dof_index_tensor)\n        self.register_buffer(\"_body_dof_mask\", body_dof_mask)\n        self._body_joint_order = body_joint_order\n        self._joint_body_index = joint_body_index\n        self._joint_axis = joint_axis\n        self._body_dof_indices = body_dof_indices\n\n    @torch.no_grad()\n    def forward(\n        self,\n        root_pos: torch.Tensor,\n        root_quat: torch.Tensor,\n        dof_pos: torch.Tensor,\n        fps: float,\n        quat_format: str = \"xyzw\",\n        sub_batch_size: int = 64,\n        vel_smoothing_sigma: float = 2.0,\n    ) -> Dict[str, torch.Tensor]:\n        \"\"\"Forward kinematics and smoothed velocities.\n\n        Args:\n            root_pos: (B, T, 3)\n            root_quat: (B, T, 4), XYZW by default\n            dof_pos: (B, T, ndof)\n            fps: frames per second\n            sub_batch_size: split batch into chunks to reduce peak memory\n            vel_smoothing_sigma: Gaussian sigma for smoothing velocity signals\n                along the time axis (set <= 0 to disable).\n\n        Returns:\n            Dict with global_translation/global_rotation_quat/global_velocity/\n            global_angular_velocity/dof_pos/dof_vel.\n        \"\"\"\n        if fps <= 0.0:\n            raise ValueError(f\"Invalid fps: {fps}\")\n        if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:\n            raise ValueError(\"Inputs must be (B, T, ...)\")\n        if (\n            root_pos.shape[:2] != root_quat.shape[:2]\n            or root_pos.shape[:2] != dof_pos.shape[:2]\n        ):\n            raise ValueError(\"Mismatched batch/time shapes among inputs\")\n        if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:\n            raise ValueError(\n                \"root_pos must be (B,T,3) and root_quat must be (B,T,4)\"\n            )\n        if dof_pos.shape[-1] != self.num_dof:\n            raise ValueError(\n                
f\"dof_pos last dim {dof_pos.shape[-1]} does not match \"\n                f\"{self.num_dof}\"\n            )\n\n        device = self._local_translation.device\n        dtype = self._local_translation.dtype\n        root_pos = root_pos.to(device=device, dtype=dtype)\n        root_quat = root_quat.to(device=device, dtype=dtype)\n        dof_pos = dof_pos.to(device=device, dtype=dtype)\n\n        batch_size, seq_len = root_pos.shape[:2]\n        if (\n            sub_batch_size is None\n            or sub_batch_size <= 0\n            or sub_batch_size >= batch_size\n        ):\n            return self._forward_impl(\n                root_pos=root_pos,\n                root_quat=root_quat,\n                dof_pos=dof_pos,\n                fps=fps,\n                quat_format=quat_format,\n                vel_smoothing_sigma=float(vel_smoothing_sigma),\n            )\n\n        global_translation = torch.empty(\n            (batch_size, seq_len, self.num_bodies, 3),\n            device=device,\n            dtype=dtype,\n        )\n        global_rotation_quat = torch.empty(\n            (batch_size, seq_len, self.num_bodies, 4),\n            device=device,\n            dtype=dtype,\n        )\n        global_velocity = torch.empty_like(global_translation)\n        global_angular_velocity = torch.empty_like(global_translation)\n        dof_pos_out = torch.empty_like(dof_pos)\n        dof_vel = torch.empty_like(dof_pos)\n\n        for start in range(0, batch_size, sub_batch_size):\n            end = min(start + sub_batch_size, batch_size)\n            out = self._forward_impl(\n                root_pos=root_pos[start:end],\n                root_quat=root_quat[start:end],\n                dof_pos=dof_pos[start:end],\n                fps=fps,\n                quat_format=quat_format,\n                vel_smoothing_sigma=float(vel_smoothing_sigma),\n            )\n            global_translation[start:end] = out[\"global_translation\"]\n            global_rotation_quat[start:end] = out[\"global_rotation_quat\"]\n            global_velocity[start:end] = out[\"global_velocity\"]\n            global_angular_velocity[start:end] = out[\"global_angular_velocity\"]\n            dof_pos_out[start:end] = out[\"dof_pos\"]\n            dof_vel[start:end] = out[\"dof_vel\"]\n\n        return {\n            \"global_translation\": global_translation,\n            \"global_rotation_quat\": global_rotation_quat,\n            \"global_velocity\": global_velocity,\n            \"global_angular_velocity\": global_angular_velocity,\n            \"dof_pos\": dof_pos_out,\n            \"dof_vel\": dof_vel,\n        }\n\n    def _forward_impl(\n        self,\n        root_pos: torch.Tensor,\n        root_quat: torch.Tensor,\n        dof_pos: torch.Tensor,\n        fps: float,\n        quat_format: str,\n        vel_smoothing_sigma: float,\n    ) -> Dict[str, torch.Tensor]:\n        device = self._local_translation.device\n        dtype = self._local_translation.dtype\n        if quat_format == \"xyzw\":\n            root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat)\n        elif quat_format == \"wxyz\":\n            root_quat_wxyz = root_quat\n        else:\n            raise ValueError(f\"Unsupported quat_format: {quat_format}\")\n\n        root_rotmat = torch_utils.matrix_from_quat(root_quat_wxyz)\n        dof_rotmats = torch_utils.axis_angle_to_matrix(dof_pos, self._dof_axis)\n\n        positions_world = torch.empty(\n            (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3),\n            
device=device,\n            dtype=dtype,\n        )\n        rotations_world = torch.empty(\n            (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3, 3),\n            device=device,\n            dtype=dtype,\n        )\n\n        for i in range(self.num_bodies):\n            parent = int(self._parents[i].item())\n            if parent < 0:\n                positions_world[:, :, i] = root_pos\n                rotations_world[:, :, i] = root_rotmat\n                continue\n            parent_pos = positions_world[:, :, parent]\n            parent_rot = rotations_world[:, :, parent]\n            offset = self._local_translation[i]\n            pos = parent_pos + torch.einsum(\"btij,j->bti\", parent_rot, offset)\n            rot = torch.matmul(parent_rot, self._local_rotation_mat[i])\n            body_dof_indices = self._body_dof_indices[i]\n            for dof_idx in body_dof_indices:\n                rot = torch.matmul(rot, dof_rotmats[:, :, dof_idx])\n            positions_world[:, :, i] = pos\n            rotations_world[:, :, i] = rot\n\n        global_translation = positions_world\n        global_rotation_mat = rotations_world\n        global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat)\n        global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz)\n\n        dt = 1.0 / fps\n        if dof_pos.shape[1] < 2:\n            dof_vel = torch.zeros_like(dof_pos)\n        else:\n            diff = (dof_pos[:, 1:] - dof_pos[:, :-1]) / dt\n            pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:]\n            dof_vel = torch.cat([diff, pad], dim=1)\n        dof_vel = torch_utils.smooth_time_series(\n            dof_vel, sigma=float(vel_smoothing_sigma), dim=1\n        )\n\n        global_velocity = torch_utils.grad_t(global_translation, dt)\n        global_velocity = torch_utils.smooth_time_series(\n            global_velocity, sigma=float(vel_smoothing_sigma), dim=1\n        )\n\n        if global_quat_xyzw.shape[1] < 2:\n            global_angular_velocity = torch.zeros_like(global_translation)\n        else:\n            q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:])\n            q0_inv = torch_utils.quat_conjugate(\n                torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1])\n            )\n            q_rel = torch_utils.quat_mul(q1, q0_inv)\n            q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True)\n            q_rel = torch_utils.standardize_quaternion(q_rel)\n\n            identity = torch.tensor(\n                [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype\n            )[None, None, None]\n            q_rel_full = identity.expand(\n                global_quat_xyzw.shape[0],\n                global_quat_xyzw.shape[1],\n                global_quat_xyzw.shape[2],\n                4,\n            ).clone()\n            q_rel_full[:, :-1] = q_rel\n            global_angular_velocity = (\n                torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt\n            )\n            global_angular_velocity = torch_utils.smooth_time_series(\n                global_angular_velocity,\n                sigma=float(vel_smoothing_sigma),\n                dim=1,\n            )\n\n        return {\n            \"global_translation\": global_translation,\n            \"global_rotation_quat\": global_quat_xyzw,\n            \"global_velocity\": global_velocity,\n            \"global_angular_velocity\": global_angular_velocity,\n            \"dof_pos\": dof_pos,\n            \"dof_vel\": dof_vel,\n        
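# NOTE: \"global_rotation_quat\" is returned in XYZW order (converted\n        # from the internal WXYZ representation above).\n        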
}\n\n\n# class HoloMotionFK_V2(torch.nn.Module):\n#     \"\"\"\n#     Use pytorch_kinematics to compute FK.\n#     \"\"\"\n\n#     def __init__(\n#         self,\n#         robot_file_path: str,\n#         device: torch.device | str = \"cpu\",\n#         dtype: torch.dtype = torch.float32,\n#     ) -> None:\n#         super().__init__()\n#         self.robot_file_path = robot_file_path\n#         urdf_path = os.path.splitext(robot_file_path)[0] + \".urdf\"\n#         if not os.path.isfile(urdf_path):\n#             raise FileNotFoundError(f\"URDF not found: {urdf_path}\")\n#         with open(urdf_path, mode=\"r\", encoding=\"utf-8\") as f:\n#             urdf_text = f.read()\n\n#         pk_chain = pk.build_chain_from_urdf(urdf_text)\n#         pk_chain = pk_chain.to(dtype=dtype, device=device)\n\n#         self.dof_names = pk_chain.get_joint_parameter_names()\n#         self.num_dof = len(self.dof_names)\n#         root_name = pk_chain._root.name\n#         moving_frames = pk_chain.get_frame_names(exclude_fixed=True)\n#         self.body_names = [root_name] + [\n#             name for name in moving_frames if name != root_name\n#         ]\n#         self.num_bodies = len(self.body_names)\n\n#         body_frame_indices = pk_chain.get_frame_indices(*self.body_names)\n#         self.register_buffer(\"_body_frame_indices\", body_frame_indices)\n\n#         num_frames = len(pk_chain.idx_to_frame)\n#         identity = torch.eye(4, device=device, dtype=dtype)\n#         link_offsets = []\n#         joint_offsets = []\n#         for i in range(num_frames):\n#             link_offset = pk_chain.link_offsets[i]\n#             joint_offset = pk_chain.joint_offsets[i]\n#             if link_offset is None:\n#                 link_offset = identity\n#             if joint_offset is None:\n#                 joint_offset = identity\n#             if link_offset.ndim == 3:\n#                 link_offset = link_offset[0]\n#             if joint_offset.ndim == 3:\n#                 joint_offset = joint_offset[0]\n#             link_offsets.append(link_offset)\n#             joint_offsets.append(joint_offset)\n#         if num_frames > 0:\n#             link_offsets[0] = identity\n#             joint_offsets[0] = identity\n\n#         parent_indices: List[int] = []\n#         for i in range(num_frames):\n#             chain_indices = pk_chain.parents_indices[i]\n#             if chain_indices.numel() <= 1:\n#                 parent_indices.append(-1)\n#             else:\n#                 parent_indices.append(int(chain_indices[-2].item()))\n\n#         self.register_buffer(\"_pk_axes\", pk_chain.axes)\n#         self.register_buffer(\n#             \"_pk_joint_type_indices\", pk_chain.joint_type_indices\n#         )\n#         self.register_buffer(\"_pk_joint_indices\", pk_chain.joint_indices)\n#         self.register_buffer(\n#             \"_pk_link_offsets\", torch.stack(link_offsets, dim=0)\n#         )\n#         self.register_buffer(\n#             \"_pk_joint_offsets\", torch.stack(joint_offsets, dim=0)\n#         )\n#         self.register_buffer(\n#             \"_pk_parent_indices\",\n#             torch.tensor(parent_indices, dtype=torch.long, device=device),\n#         )\n#         self._num_frames = num_frames\n\n#     def forward(\n#         self,\n#         root_pos: torch.Tensor,\n#         root_quat: torch.Tensor,\n#         dof_pos: torch.Tensor,\n#         fps: float,\n#         quat_format: str = \"xyzw\",\n#     ) -> Dict[str, torch.Tensor]:\n#         \"\"\"\n#         Args:\n#   
          root_pos: (B, T, 3)\n#             root_quat: (B, T, 4), XYZW by default\n#             dof_pos: (B, T, ndof)\n#             fps: frames per second\n#         Returns:\n#             Dict with global_translation/global_rotation_quat/global_velocity/\n#             global_angular_velocity/dof_pos/dof_vel.\n#         \"\"\"\n#         if fps <= 0.0:\n#             raise ValueError(f\"Invalid fps: {fps}\")\n#         if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:\n#             raise ValueError(\"Inputs must be (B, T, ...)\")\n#         if (\n#             root_pos.shape[:2] != root_quat.shape[:2]\n#             or root_pos.shape[:2] != dof_pos.shape[:2]\n#         ):\n#             raise ValueError(\"Mismatched batch/time shapes among inputs\")\n#         if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:\n#             raise ValueError(\n#                 \"root_pos must be (B,T,3) and root_quat must be (B,T,4)\"\n#             )\n#         if dof_pos.shape[-1] != self.num_dof:\n#             raise ValueError(\n#                 f\"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}\"\n#             )\n\n#         device = self._pk_axes.device\n#         dtype = self._pk_axes.dtype\n#         root_pos = root_pos.to(device=device, dtype=dtype)\n#         root_quat = root_quat.to(device=device, dtype=dtype)\n#         dof_pos = dof_pos.to(device=device, dtype=dtype)\n\n#         if quat_format == \"xyzw\":\n#             root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat)\n#         elif quat_format == \"wxyz\":\n#             root_quat_wxyz = root_quat\n#         else:\n#             raise ValueError(f\"Unsupported quat_format: {quat_format}\")\n\n#         batch_size, seq_len = root_pos.shape[:2]\n#         flat_size = batch_size * seq_len\n#         root_pos_flat = root_pos.reshape(flat_size, 3)\n#         root_quat_flat = root_quat_wxyz.reshape(flat_size, 4)\n#         dof_pos_flat = dof_pos.reshape(flat_size, self.num_dof)\n\n#         axes_expanded = self._pk_axes[None].expand(flat_size, -1, -1)\n#         revolute_tf = axis_and_angle_to_matrix_44(axes_expanded, dof_pos_flat)\n#         prismatic_tf = axis_and_d_to_pris_matrix(axes_expanded, dof_pos_flat)\n\n#         frame_transforms = torch.empty(\n#             (flat_size, self._num_frames, 4, 4), device=device, dtype=dtype\n#         )\n#         identity = torch.eye(4, device=device, dtype=dtype).repeat(\n#             flat_size, 1, 1\n#         )\n\n#         for i in range(self._num_frames):\n#             parent = int(self._pk_parent_indices[i].item())\n#             if parent < 0:\n#                 frame_tf = identity\n#             else:\n#                 frame_tf = frame_transforms[:, parent]\n#             frame_tf = frame_tf @ self._pk_link_offsets[i]\n#             frame_tf = frame_tf @ self._pk_joint_offsets[i]\n#             joint_type = int(self._pk_joint_type_indices[i].item())\n#             if joint_type == 1:\n#                 joint_index = int(self._pk_joint_indices[i].item())\n#                 frame_tf = frame_tf @ revolute_tf[:, joint_index]\n#             elif joint_type == 2:\n#                 joint_index = int(self._pk_joint_indices[i].item())\n#                 frame_tf = frame_tf @ prismatic_tf[:, joint_index]\n#             frame_transforms[:, i] = frame_tf\n\n#         chain_tf = torch.index_select(\n#             frame_transforms, 1, self._body_frame_indices\n#         )\n\n#         root_rotmat = torch_utils.matrix_from_quat(root_quat_flat)\n#       
  root_tf = torch.eye(4, device=device, dtype=dtype).repeat(\n#             flat_size, 1, 1\n#         )\n#         root_tf[:, :3, :3] = root_rotmat\n#         root_tf[:, :3, 3] = root_pos_flat\n\n#         world_tf = root_tf[:, None] @ chain_tf\n#         world_tf = world_tf.reshape(batch_size, seq_len, self.num_bodies, 4, 4)\n#         global_translation = world_tf[:, :, :, :3, 3]\n#         global_rotation_mat = world_tf[:, :, :, :3, :3]\n#         global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat)\n#         global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz)\n\n#         dt = 1.0 / fps\n#         if dof_pos.shape[1] < 2:\n#             dof_vel = torch.zeros_like(dof_pos)\n#         else:\n#             diff = (dof_pos[:, 1:] - dof_pos[:, :-1]) / dt\n#             pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:]\n#             dof_vel = torch.cat([diff, pad], dim=1)\n\n#         global_velocity = torch_utils.grad_t(global_translation, dt)\n#         global_velocity = torch_utils.gaussian_filter1d(\n#             global_velocity, sigma=2.0, dim=1\n#         )\n\n#         if global_quat_xyzw.shape[1] < 2:\n#             global_angular_velocity = torch.zeros_like(global_translation)\n#         else:\n#             q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:])\n#             q0_inv = torch_utils.quat_conjugate(\n#                 torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1])\n#             )\n#             q_rel = torch_utils.quat_mul(q1, q0_inv)\n#             q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True)\n#             q_rel = torch_utils.standardize_quaternion(q_rel)\n\n#             identity = torch.tensor(\n#                 [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype\n#             )[None, None, None]\n#             q_rel_full = identity.expand(\n#                 global_quat_xyzw.shape[0],\n#                 global_quat_xyzw.shape[1],\n#                 global_quat_xyzw.shape[2],\n#                 4,\n#             ).clone()\n#             q_rel_full[:, :-1] = q_rel\n#             global_angular_velocity = (\n#                 torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt\n#             )\n#             global_angular_velocity = torch_utils.gaussian_filter1d(\n#                 global_angular_velocity,\n#                 sigma=2.0,\n#                 dim=1,\n#             )\n\n#         return {\n#             \"global_translation\": global_translation,\n#             \"global_rotation_quat\": global_quat_xyzw,\n#             \"global_velocity\": global_velocity,\n#             \"global_angular_velocity\": global_angular_velocity,\n#             \"dof_pos\": dof_pos,\n#             \"dof_vel\": dof_vel,\n#         }\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/holomotion_preprocess.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nimport json\nimport logging\nimport os\nimport sys\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport hydra\nimport numpy as np\nimport ray\nimport torch\nfrom loguru import logger\nfrom omegaconf import DictConfig, ListConfig, OmegaConf\nfrom scipy.spatial.transform import Rotation as sRot\nfrom scipy.spatial.transform import Slerp\nfrom tqdm import tqdm\n\nfrom holomotion.src.motion_retargeting.utils.torch_humanoid_batch import (\n    HumanoidBatch,\n)\nfrom holomotion.src.motion_retargeting.utils import (\n    rotation_conversions as rot_conv,\n)\nfrom holomotion.src.motion_retargeting.reference_filtering import (\n    butterworth_filter_ref_arrays as shared_butterworth_filter_ref_arrays,\n)\n\n\ndef compute_slices(\n    sequence_len: int, window_size: int, overlap: int\n) -> List[Tuple[int, int]]:\n    step = window_size - overlap\n    if step <= 0:\n        raise ValueError(\"window_size must be > overlap\")\n    slices: List[Tuple[int, int]] = []\n    start = 0\n    length = int(sequence_len)\n    while start < length:\n        end = min(start + window_size, length)\n        slices.append((start, end))\n        if end == length:\n            break\n        start += step\n    return slices\n\n\ndef _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]:\n    shape = a.shape\n    t = shape[0]\n    return a.reshape(t, -1), shape\n\n\ndef _butterworth_lowpass_smooth_time(\n    a: np.ndarray, fps: float, cutoff_hz: float, order: int\n) -> np.ndarray:\n    from scipy.signal import butter, filtfilt\n\n    t = a.shape[0]\n    if t < 3:\n        return a.astype(np.float32, copy=True)\n    if fps <= 0.0 or cutoff_hz <= 0.0:\n        return a.astype(np.float32, copy=True)\n    nyquist = 0.5 * float(fps)\n    wn = float(cutoff_hz) / nyquist\n    if wn >= 1.0:\n        wn = 0.999\n    if wn <= 0.0:\n        return a.astype(np.float32, copy=True)\n    flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False))\n    b, a_coefs = butter(int(order), wn, btype=\"low\", analog=False)\n    maxlen = max(len(b), len(a_coefs))\n    padlen_required = max(3 * (maxlen - 1), 3 * maxlen)\n    if t <= padlen_required:\n        return a.astype(np.float32, copy=True)\n    filtered = filtfilt(b, a_coefs, flat, axis=0, method=\"pad\")\n    return filtered.reshape(shape).astype(np.float32, copy=False)\n\n\ndef _quat_normalize(q: np.ndarray) -> np.ndarray:\n    norm = np.linalg.norm(q, axis=-1, keepdims=True)\n    norm = np.where(norm == 0.0, 1.0, norm)\n    return (q / norm).astype(np.float32, copy=False)\n\n\ndef _quat_hemisphere_align(q: np.ndarray) -> np.ndarray:\n    if q.shape[0] == 0:\n        return q\n    aligned = q.copy()\n    prev = aligned[0]\n    for t in range(1, aligned.shape[0]):\n        dots = np.sum(prev * aligned[t], axis=-1)\n        
mask = dots < 0.0\n        if np.any(mask):\n            aligned[t, mask] = -aligned[t, mask]\n        prev = aligned[t]\n    return aligned\n\n\ndef _quat_conjugate(q: np.ndarray) -> np.ndarray:\n    conj = q.copy()\n    conj[..., :3] = -conj[..., :3]\n    return conj\n\n\ndef _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray:\n    av = a[..., :3]\n    aw = a[..., 3:4]\n    bv = b[..., :3]\n    bw = b[..., 3:4]\n    cross = np.cross(av, bv)\n    vec = aw * bv + bw * av + cross\n    scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True)\n    return np.concatenate([vec, scalar], axis=-1)\n\n\ndef _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray:\n    t = a.shape[0]\n    if t < 2 or dt <= 0.0:\n        return np.zeros_like(a, dtype=np.float32)\n    deriv = np.gradient(\n        a.astype(np.float64, copy=False),\n        dt,\n        axis=0,\n        edge_order=2 if t >= 3 else 1,\n    )\n    return deriv.astype(np.float32, copy=False)\n\n\ndef _angular_velocity_from_quat(\n    q: np.ndarray, q_dot: np.ndarray\n) -> np.ndarray:\n    q_conj = _quat_conjugate(q)\n    prod = _quat_multiply(q_conj, q_dot)\n    omega = 2.0 * prod[..., :3]\n    return omega.astype(np.float32, copy=False)\n\n\ndef butterworth_filter_ref_arrays(\n    arrays: Dict[str, np.ndarray], fps: float, cutoff_hz: float, order: int\n) -> Dict[str, np.ndarray]:\n    return shared_butterworth_filter_ref_arrays(\n        arrays=arrays,\n        fps=fps,\n        cutoff_hz=cutoff_hz,\n        order=order,\n    )\n\n\ndef _summary(arr: np.ndarray) -> Dict[str, float]:\n    if arr.size == 0:\n        return {\n            \"mean\": 0.0,\n            \"std\": 0.0,\n            \"median\": 0.0,\n            \"min\": 0.0,\n            \"max\": 0.0,\n            \"q25\": 0.0,\n            \"q75\": 0.0,\n        }\n    return {\n        \"mean\": float(arr.mean()),\n        \"std\": float(arr.std()),\n        \"median\": float(np.median(arr)),\n        \"min\": float(arr.min()),\n        \"max\": float(arr.max()),\n        \"q25\": float(np.quantile(arr, 0.25)),\n        \"q75\": float(np.quantile(arr, 0.75)),\n    }\n\n\ndef _ds_summary(arr: np.ndarray) -> Dict[str, float]:\n    if arr.size == 0:\n        return {\n            \"DS_mean\": 0.0,\n            \"DS_std\": 0.0,\n            \"DS_median\": 0.0,\n            \"DS_min\": 0.0,\n            \"DS_max\": 0.0,\n            \"DS_q25\": 0.0,\n            \"DS_q75\": 0.0,\n        }\n    return {\n        \"DS_mean\": float(arr.mean()),\n        \"DS_std\": float(arr.std()),\n        \"DS_median\": float(np.median(arr)),\n        \"DS_min\": float(arr.min()),\n        \"DS_max\": float(arr.max()),\n        \"DS_q25\": float(np.quantile(arr, 0.25)),\n        \"DS_q75\": float(np.quantile(arr, 0.75)),\n    }\n\n\ndef _interpolate_linear(\n    start: np.ndarray, end: np.ndarray, num_frames: int\n) -> np.ndarray:\n    \"\"\"Linear interpolation between start and end over num_frames.\n\n    Returns array where result[0] == start and result[-1] == end.\n    \"\"\"\n    start = np.asarray(start, dtype=np.float32)\n    end = np.asarray(end, dtype=np.float32)\n    if num_frames <= 1:\n        return start[None, ...]\n    t = np.linspace(0.0, 1.0, num_frames, dtype=np.float32)\n    for _ in range(start.ndim):\n        t = t[..., None]\n    result = ((1.0 - t) * start + t * end).astype(np.float32)\n    result[0] = start\n    result[-1] = end\n    return result\n\n\ndef _interpolate_quaternions_slerp(\n    start_quat: np.ndarray, end_quat: np.ndarray, num_frames: int\n) 
-> np.ndarray:\n    \"\"\"SLERP interpolation between two quaternions (XYZW format) over num_frames.\n\n    Args:\n        start_quat: shape [4] in XYZW format\n        end_quat: shape [4] in XYZW format\n        num_frames: number of interpolation frames\n\n    Returns:\n        shape [num_frames, 4] in XYZW format, with result[0] == start_quat\n        and result[-1] == end_quat.\n    \"\"\"\n    start_quat = np.asarray(start_quat, dtype=np.float32)\n    end_quat = np.asarray(end_quat, dtype=np.float32)\n    if num_frames <= 1:\n        return start_quat[None, ...]\n    rotations = sRot.from_quat([start_quat, end_quat])\n    slerp = Slerp([0.0, 1.0], rotations)\n    t = np.linspace(0.0, 1.0, num_frames)\n    result = slerp(t).as_quat().astype(np.float32)\n    result[0] = start_quat\n    result[-1] = end_quat\n    return result\n\n\ndef _extract_yaw_only_quat(quat: np.ndarray) -> np.ndarray:\n    \"\"\"Extract yaw-only quaternion (XYZW format) from a full quaternion.\n\n    Args:\n        quat: shape [4] in XYZW format\n\n    Returns:\n        shape [4] in XYZW format with only yaw rotation\n    \"\"\"\n    rot = sRot.from_quat(quat)\n    euler = rot.as_euler(\"xyz\", degrees=False)\n    yaw_only_euler = np.array([0.0, 0.0, euler[2]])\n    yaw_only_rot = sRot.from_euler(\"xyz\", yaw_only_euler, degrees=False)\n    return yaw_only_rot.as_quat().astype(np.float32)\n\n\ndef _dof_to_pose_aa(\n    dof_pos: np.ndarray,\n    root_rot_xyzw: np.ndarray,\n    humanoid_fk: \"HumanoidBatch\",\n    num_augment_joints: int,\n) -> np.ndarray:\n    \"\"\"Convert DOF positions and root rotation to pose axis-angle.\n\n    Args:\n        dof_pos: shape [T, num_dofs]\n        root_rot_xyzw: shape [T, 4] in XYZW format\n        humanoid_fk: HumanoidBatch instance\n        num_augment_joints: number of augmented joints\n\n    Returns:\n        pose_aa: shape [T, num_bodies + num_augment_joints, 3]\n    \"\"\"\n    dof_t = torch.as_tensor(dof_pos, dtype=torch.float32)\n    T = dof_t.shape[0]\n    root_quat_wxyz = rot_conv.xyzw_to_wxyz(\n        torch.as_tensor(root_rot_xyzw, dtype=torch.float32)\n    )\n    root_aa = rot_conv.quaternion_to_axis_angle(root_quat_wxyz)\n    joint_aa = humanoid_fk.dof_axis * dof_t[:, :, None]\n    pose_aa = torch.cat(\n        [\n            root_aa[:, None, :],\n            joint_aa,\n            torch.zeros((T, num_augment_joints, 3), dtype=torch.float32),\n        ],\n        dim=1,\n    )\n    return pose_aa.numpy().astype(np.float32)\n\n\ndef _compute_fk_motion(\n    dof_pos: np.ndarray,\n    root_pos: np.ndarray,\n    root_rot_xyzw: np.ndarray,\n    humanoid_fk: \"HumanoidBatch\",\n    num_augment_joints: int,\n    fps: float,\n) -> Dict[str, np.ndarray]:\n    \"\"\"Compute all motion arrays from dof_pos, root_pos, and root_rot via FK.\n\n    Args:\n        dof_pos: shape [T, num_dofs]\n        root_pos: shape [T, 3]\n        root_rot_xyzw: shape [T, 4] in XYZW format\n        humanoid_fk: HumanoidBatch instance\n        num_augment_joints: number of augmented joints\n        fps: frames per second\n\n    Returns:\n        Dict with ref_dof_pos, ref_dof_vel, ref_global_translation,\n        ref_global_rotation_quat, ref_global_velocity, ref_global_angular_velocity,\n        frame_flag\n    \"\"\"\n    T = dof_pos.shape[0]\n    dt = 1.0 / fps\n    pose_aa = _dof_to_pose_aa(\n        dof_pos, root_rot_xyzw, humanoid_fk, num_augment_joints\n    )\n    pose_aa_t = torch.as_tensor(pose_aa, dtype=torch.float32)\n    root_pos_t = torch.as_tensor(root_pos, dtype=torch.float32)\n    
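# fk_batch consumes batched tensors (hence the leading batch dim added\n    # via [None]); return_full=True with dt=1/fps populates the velocity\n    # fields read below.\n    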
fk_result = humanoid_fk.fk_batch(\n        pose_aa_t[None, ...],\n        root_pos_t[None, ...],\n        return_full=True,\n        dt=dt,\n    )\n    frame_flag = np.ones(T, dtype=np.int32)\n    frame_flag[0] = 0\n    frame_flag[-1] = 2\n    return {\n        \"ref_dof_pos\": fk_result.dof_pos.squeeze(0).numpy().astype(np.float32),\n        \"ref_dof_vel\": fk_result.dof_vels.squeeze(0)\n        .numpy()\n        .astype(np.float32),\n        \"ref_global_translation\": fk_result.global_translation.squeeze(0)\n        .numpy()\n        .astype(np.float32),\n        \"ref_global_rotation_quat\": fk_result.global_rotation.squeeze(0)\n        .numpy()\n        .astype(np.float32),\n        \"ref_global_velocity\": fk_result.global_velocity.squeeze(0)\n        .numpy()\n        .astype(np.float32),\n        \"ref_global_angular_velocity\": fk_result.global_angular_velocity.squeeze(\n            0\n        )\n        .numpy()\n        .astype(np.float32),\n        \"frame_flag\": frame_flag,\n    }\n\n\n@dataclass\nclass ProcessedClip:\n    motion_key: str\n    metadata: Dict[str, Any]\n    arrays: Dict[str, np.ndarray]\n\n\nclass HoloMotionPreprocessor:\n    \"\"\"\n    Composable preprocessing pipeline operating on standardized HoloMotion NPZ clips.\n\n    Supports per-clip stages like slicing and Butterworth filtering,\n    plus dataset-level kinematic tagging.\n    \"\"\"\n\n    def __init__(\n        self,\n        slicing_cfg: Optional[DictConfig] = None,\n        filtering_cfg: Optional[DictConfig] = None,\n        tagging_cfg: Optional[DictConfig] = None,\n        padding_cfg: Optional[DictConfig] = None,\n        pipeline: Optional[List[str]] = None,\n    ) -> None:\n        self.slicing_cfg = slicing_cfg\n        self.filtering_cfg = filtering_cfg\n        self.tagging_cfg = tagging_cfg\n        self.padding_cfg = padding_cfg\n        self.pipeline = self._resolve_pipeline(pipeline)\n        self._humanoid_fk: Optional[HumanoidBatch] = None\n        self._robot_cfg: Optional[DictConfig] = None\n\n    def _resolve_pipeline(self, pipeline: Optional[List[str]]) -> List[str]:\n        if pipeline is not None:\n            return list(pipeline)\n        return []\n\n    def process_clip(self, clip: ProcessedClip) -> List[ProcessedClip]:\n        clips = [clip]\n        logger.debug(\n            f\"Processing clip '{clip.motion_key}' with pipeline: {self.pipeline}\"\n        )\n        for stage in self.pipeline:\n            logger.debug(f\"Applying stage: {stage}\")\n            if stage in (\"slicing\", \"slice\"):\n                next_clips: List[ProcessedClip] = []\n                for c in clips:\n                    next_clips.extend(self._apply_slicing(c))\n                clips = next_clips\n                logger.debug(f\"After slicing: {len(clips)} clips\")\n            elif stage in (\n                \"apply_butterworth_filter\",\n                \"filtering\",\n                \"butterworth_filter\",\n            ):\n                clips = [self._apply_filtering(c) for c in clips]\n                logger.debug(\n                    f\"After apply_butterworth_filter: {len(clips)} clips\"\n                )\n            elif stage == \"filename_as_motionkey\":\n                clips = [self._apply_filename_as_motionkey(c) for c in clips]\n                logger.debug(\n                    f\"After filename_as_motionkey: {len(clips)} clips\"\n                )\n            elif stage == \"legacy_to_ref_keys\":\n                clips = [self._apply_legacy_to_ref_keys(c) for c 
in clips]\n                logger.debug(f\"After legacy_to_ref_keys: {len(clips)} clips\")\n            elif stage == \"add_legacy_keys\":\n                clips = [self._apply_add_legacy_keys(c) for c in clips]\n                logger.debug(f\"After add_legacy_keys: {len(clips)} clips\")\n            elif stage == \"add_padding\":\n                clips = [self._apply_add_padding(c) for c in clips]\n                logger.debug(f\"After add_padding: {len(clips)} clips\")\n            else:\n                logger.warning(\n                    f\"Unknown preprocessing stage '{stage}' ignored.\"\n                )\n        return clips\n\n    def _apply_slicing(self, clip: ProcessedClip) -> List[ProcessedClip]:\n        cfg = self.slicing_cfg\n        if cfg is None:\n            logger.warning(\n                \"Slicing requested but slicing_cfg is None - skipping slicing\"\n            )\n            return [clip]\n\n        window_size = int(getattr(cfg, \"window_size\", 0))\n        overlap = int(getattr(cfg, \"overlap\", 0))\n        seq_len = int(clip.metadata.get(\"num_frames\", 0))\n        if seq_len <= 0:\n            return [clip]\n\n        slice_specs = compute_slices(seq_len, window_size, overlap)\n        if not slice_specs:\n            return [clip]\n\n        fps = float(clip.metadata.get(\"motion_fps\", 0.0))\n        raw_motion_key = str(\n            clip.metadata.get(\n                \"raw_motion_key\", clip.metadata.get(\"motion_key\", \"\")\n            )\n        )\n        base_motion_key = str(clip.metadata.get(\"motion_key\", raw_motion_key))\n        arrays = clip.arrays\n\n        out_clips: List[ProcessedClip] = []\n        for s, e in slice_specs:\n            arrays_window: Dict[str, np.ndarray] = {}\n            for k, v in arrays.items():\n                if (\n                    isinstance(v, np.ndarray)\n                    and v.ndim >= 1\n                    and v.shape[0] == seq_len\n                ):\n                    arrays_window[k] = v[s:e]\n                else:\n                    arrays_window[k] = v\n\n            num_frames = int(e - s)\n            if num_frames <= 0:\n                continue\n\n            wallclock_len = float(num_frames - 1) / fps if fps > 0.0 else 0.0\n            if s == 0 and e == seq_len:\n                motion_key = base_motion_key\n            else:\n                motion_key = f\"{base_motion_key}_s{s}_e{e}\"\n\n            meta = dict(clip.metadata)\n            meta[\"motion_key\"] = motion_key\n            meta[\"raw_motion_key\"] = raw_motion_key\n            meta[\"num_frames\"] = num_frames\n            meta[\"wallclock_len\"] = wallclock_len\n            meta[\"slice_start\"] = int(s)\n            meta[\"slice_end\"] = int(e)\n            out_clips.append(\n                ProcessedClip(\n                    motion_key=motion_key,\n                    metadata=meta,\n                    arrays=arrays_window,\n                )\n            )\n        return out_clips\n\n    def _apply_filtering(self, clip: ProcessedClip) -> ProcessedClip:\n        cfg = self.filtering_cfg\n        if cfg is None:\n            logger.warning(\n                \"Filtering requested but filtering_cfg is None - skipping filtering\"\n            )\n            return clip\n\n        fps = float(clip.metadata.get(\"motion_fps\", 0.0))\n        cutoff = float(getattr(cfg, \"butter_cutoff_hz\", 0.0))\n        order = int(getattr(cfg, \"butter_order\", 4))\n        ft = butterworth_filter_ref_arrays(\n            
clip.arrays, fps=fps, cutoff_hz=cutoff, order=order\n        )\n        arrays = dict(clip.arrays)\n        arrays.update(ft)\n        return ProcessedClip(\n            motion_key=clip.motion_key,\n            metadata=clip.metadata,\n            arrays=arrays,\n        )\n\n    def _apply_filename_as_motionkey(\n        self, clip: ProcessedClip\n    ) -> ProcessedClip:\n        filename = clip.metadata.get(\"source_filename\", None)\n        if filename is None:\n            logger.warning(\n                \"filename_as_motionkey requested but source_filename not found in metadata - skipping\"\n            )\n            return clip\n\n        new_motion_key = str(filename)\n        meta = dict(clip.metadata)\n        meta[\"motion_key\"] = new_motion_key\n        if \"raw_motion_key\" not in meta:\n            meta[\"raw_motion_key\"] = clip.motion_key\n\n        return ProcessedClip(\n            motion_key=new_motion_key,\n            metadata=meta,\n            arrays=clip.arrays,\n        )\n\n    def _apply_add_legacy_keys(self, clip: ProcessedClip) -> ProcessedClip:\n        \"\"\"Add deprecated legacy keys for backward compatibility.\n\n        Maps ref_* keys to legacy unprefixed keys according to spec:\n        - ref_dof_pos -> dof_pos\n        - ref_dof_vel -> dof_vels\n        - ref_global_translation -> global_translation\n        - ref_global_rotation_quat -> global_rotation_quat\n        - ref_global_velocity -> global_velocity\n        - ref_global_angular_velocity -> global_angular_velocity\n        \"\"\"\n        ref_to_legacy = {\n            \"ref_dof_pos\": \"dof_pos\",\n            \"ref_dof_vel\": \"dof_vels\",\n            \"ref_global_translation\": \"global_translation\",\n            \"ref_global_rotation_quat\": \"global_rotation_quat\",\n            \"ref_global_velocity\": \"global_velocity\",\n            \"ref_global_angular_velocity\": \"global_angular_velocity\",\n        }\n\n        arrays = dict(clip.arrays)\n        for ref_key, legacy_key in ref_to_legacy.items():\n            if ref_key in arrays:\n                if legacy_key not in arrays:\n                    arrays[legacy_key] = arrays[ref_key].copy()\n                    logger.debug(\n                        f\"Added legacy key '{legacy_key}' from '{ref_key}'\"\n                    )\n                else:\n                    logger.debug(\n                        f\"Legacy key '{legacy_key}' already exists, skipping\"\n                    )\n\n        return ProcessedClip(\n            motion_key=clip.motion_key,\n            metadata=clip.metadata,\n            arrays=arrays,\n        )\n\n    def _apply_legacy_to_ref_keys(self, clip: ProcessedClip) -> ProcessedClip:\n        \"\"\"Add new ref_* keys from legacy unprefixed keys.\n\n        Maps legacy keys to ref_* keys according to spec while keeping the\n        original legacy arrays:\n        - dof_pos -> ref_dof_pos\n        - dof_vels -> ref_dof_vel\n        - global_translation -> ref_global_translation\n        - global_rotation_quat -> ref_global_rotation_quat\n        - global_velocity -> ref_global_velocity\n        - global_angular_velocity -> ref_global_angular_velocity\n        \"\"\"\n        legacy_to_ref = {\n            \"dof_pos\": \"ref_dof_pos\",\n            \"dof_vels\": \"ref_dof_vel\",\n            \"global_translation\": \"ref_global_translation\",\n            \"global_rotation_quat\": \"ref_global_rotation_quat\",\n            \"global_velocity\": \"ref_global_velocity\",\n            
\"global_angular_velocity\": \"ref_global_angular_velocity\",\n        }\n\n        arrays = dict(clip.arrays)\n        for legacy_key, ref_key in legacy_to_ref.items():\n            if legacy_key in arrays:\n                if ref_key not in arrays:\n                    arrays[ref_key] = arrays[legacy_key].copy()\n                    logger.debug(\n                        f\"Added ref key '{ref_key}' from legacy key '{legacy_key}'\"\n                    )\n                else:\n                    logger.debug(\n                        f\"Ref key '{ref_key}' already exists, skipping\"\n                    )\n\n        return ProcessedClip(\n            motion_key=clip.motion_key,\n            metadata=clip.metadata,\n            arrays=arrays,\n        )\n\n    def _get_humanoid_fk(self) -> HumanoidBatch:\n        \"\"\"Lazy-load and cache HumanoidBatch for FK computation.\"\"\"\n        if self._humanoid_fk is not None:\n            return self._humanoid_fk\n        cfg = self.padding_cfg\n        robot_config_path = str(getattr(cfg, \"robot_config_path\", \"\"))\n        self._robot_cfg = OmegaConf.load(robot_config_path)\n        self._humanoid_fk = HumanoidBatch(self._robot_cfg.robot)\n        return self._humanoid_fk\n\n    def _get_default_dof_pos(self) -> np.ndarray:\n        \"\"\"Get default DOF positions from robot config.\"\"\"\n        robot_cfg = self._robot_cfg.robot\n        dof_names = list(robot_cfg.dof_names)\n        init_state = robot_cfg.get(\"init_state\", {})\n        default_angles = init_state.get(\"default_joint_angles\", {})\n        default_dof = np.zeros(len(dof_names), dtype=np.float32)\n        for i, name in enumerate(dof_names):\n            default_dof[i] = float(default_angles.get(name, 0.0))\n        return default_dof\n\n    def _apply_add_padding(self, clip: ProcessedClip) -> ProcessedClip:\n        \"\"\"Add transition and static padding to the motion clip.\n\n        Adds stand-still padding at default pose before and after the motion,\n        with smooth transitions between default pose and the motion's first/last\n        frames. 
Recalculates all states from root pos, rot and dof pos via FK.\n        \"\"\"\n        cfg = self.padding_cfg\n        if cfg is None:\n            logger.warning(\n                \"Padding requested but padding_cfg is None - skipping padding\"\n            )\n            return clip\n\n        fps = float(clip.metadata.get(\"motion_fps\", 50.0))\n        stand_still_time = float(getattr(cfg, \"stand_still_time\", 1.0))\n        transition_time = float(getattr(cfg, \"transition_time\", 1.5))\n        robot_config_path = str(getattr(cfg, \"robot_config_path\", \"\"))\n        if not robot_config_path:\n            raise ValueError(\n                \"robot_config_path must be specified in padding_cfg\"\n            )\n\n        stand_still_frames = max(1, int(stand_still_time * fps))\n        transition_frames = max(1, int(transition_time * fps))\n\n        humanoid_fk = self._get_humanoid_fk()\n        default_dof = self._get_default_dof_pos()\n        extend_config = self._robot_cfg.robot.get(\"extend_config\", [])\n        num_augment = len(extend_config) if extend_config else 0\n\n        # Get root offset from HumanoidBatch (usually from MJCF root body pos)\n        # self._offsets is [1, num_bodies, 3]\n        # root_offset = humanoid_fk._offsets[0, 0].cpu().numpy()\n\n        arrays = clip.arrays\n        dof_pos = arrays.get(\"ref_dof_pos\", arrays.get(\"dof_pos\"))\n        global_trans = arrays.get(\n            \"ref_global_translation\", arrays.get(\"global_translation\")\n        )\n        global_rot = arrays.get(\n            \"ref_global_rotation_quat\", arrays.get(\"global_rotation_quat\")\n        )\n        if dof_pos is None or global_trans is None or global_rot is None:\n            raise ValueError(\n                \"Missing required arrays for padding: ref_dof_pos, \"\n                \"ref_global_translation, or ref_global_rotation_quat\"\n            )\n\n        T_orig = dof_pos.shape[0]\n        dof_pos = dof_pos.astype(np.float32, copy=True)\n        root_pos = global_trans[:, 0, :].astype(np.float32, copy=True)\n        root_rot = global_rot[:, 0, :].astype(np.float32, copy=True)\n\n        first_dof = dof_pos[0].copy()\n        last_dof = dof_pos[-1].copy()\n        first_root_pos = root_pos[0].copy()\n        last_root_pos = root_pos[-1].copy()\n        first_root_rot = root_rot[0].copy()\n        last_root_rot = root_rot[-1].copy()\n\n        logger.debug(\n            f\"Padding: T_orig={T_orig}, first_root_pos={first_root_pos}, \"\n            f\"last_root_pos={last_root_pos}\"\n        )\n        logger.debug(\n            f\"Padding: first_dof[:3]={first_dof[:3]}, last_dof[:3]={last_dof[:3]}\"\n        )\n        logger.debug(\n            f\"Padding: original dof_pos[-1][:3]={dof_pos[-1][:3]}, \"\n            f\"original root_pos[-1]={root_pos[-1]}, original root_rot[-1]={root_rot[-1]}\"\n        )\n\n        first_yaw_quat = _extract_yaw_only_quat(first_root_rot)\n        last_yaw_quat = _extract_yaw_only_quat(last_root_rot)\n\n        start_stand_dof = np.tile(default_dof, (stand_still_frames, 1))\n        start_trans_dof = _interpolate_linear(\n            default_dof, first_dof, transition_frames\n        )\n        end_trans_dof = _interpolate_linear(\n            last_dof, default_dof, transition_frames\n        )\n        end_stand_dof = np.tile(default_dof, (stand_still_frames, 1))\n\n        start_stand_root_pos = np.tile(first_root_pos, (stand_still_frames, 1))\n        start_trans_root_pos = _interpolate_linear(\n            
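# root translation is held constant at the first frame's position here\n            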
first_root_pos, first_root_pos, transition_frames\n        )\n        end_trans_root_pos = _interpolate_linear(\n            last_root_pos, last_root_pos, transition_frames\n        )\n        end_stand_root_pos = np.tile(last_root_pos, (stand_still_frames, 1))\n\n        start_stand_root_rot = np.tile(first_yaw_quat, (stand_still_frames, 1))\n        start_trans_root_rot = _interpolate_quaternions_slerp(\n            first_yaw_quat, first_root_rot, transition_frames\n        )\n        end_trans_root_rot = _interpolate_quaternions_slerp(\n            last_root_rot, last_yaw_quat, transition_frames\n        )\n        end_stand_root_rot = np.tile(last_yaw_quat, (stand_still_frames, 1))\n\n        # Construct full sequence of inputs\n        full_dof = np.concatenate(\n            [\n                start_stand_dof,\n                start_trans_dof,\n                dof_pos,\n                end_trans_dof,\n                end_stand_dof,\n            ],\n            axis=0,\n        )\n        full_root_pos = np.concatenate(\n            [\n                start_stand_root_pos,\n                start_trans_root_pos,\n                root_pos,\n                end_trans_root_pos,\n                end_stand_root_pos,\n            ],\n            axis=0,\n        )\n        full_root_rot = np.concatenate(\n            [\n                start_stand_root_rot,\n                start_trans_root_rot,\n                root_rot,\n                end_trans_root_rot,\n                end_stand_root_rot,\n            ],\n            axis=0,\n        )\n\n        # Compute FK for the entire sequence to ensure continuity\n        new_arrays = _compute_fk_motion(\n            full_dof,\n            full_root_pos,\n            full_root_rot,\n            humanoid_fk,\n            num_augment,\n            fps,\n        )\n\n        T_new = full_dof.shape[0]\n        wallclock_len = float(T_new - 1) / fps if fps > 0.0 else 0.0\n        meta = dict(clip.metadata)\n        meta[\"num_frames\"] = T_new\n        meta[\"wallclock_len\"] = wallclock_len\n        meta[\"padding_stand_still_frames\"] = stand_still_frames\n        meta[\"padding_transition_frames\"] = transition_frames\n        meta[\"original_num_frames\"] = T_orig\n\n        return ProcessedClip(\n            motion_key=clip.motion_key,\n            metadata=meta,\n            arrays=new_arrays,\n        )\n\n    def process_npz_file(self, npz_path: Path) -> List[ProcessedClip]:\n        with np.load(npz_path, allow_pickle=False) as data:\n            if \"metadata\" not in data:\n                raise KeyError(f\"'metadata' missing in NPZ: {npz_path}\")\n            meta_text = str(data[\"metadata\"])\n            metadata = json.loads(meta_text)\n            motion_key = str(metadata[\"motion_key\"])\n            arrays: Dict[str, np.ndarray] = {}\n            for k in data.files:\n                if k == \"metadata\":\n                    continue\n                arrays[k] = np.array(data[k], copy=False)\n\n        filename_without_ext = npz_path.stem\n        metadata[\"source_filename\"] = filename_without_ext\n\n        base_clip = ProcessedClip(\n            motion_key=motion_key,\n            metadata=metadata,\n            arrays=arrays,\n        )\n        return self.process_clip(base_clip)\n\n    def run_on_directory(\n        self,\n        src_root: Path,\n        out_root: Path,\n        use_ray: bool = False,\n        num_workers: int = 0,\n    ) -> None:\n        if src_root.is_dir():\n            if (src_root / 
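\n                # prefer a 'clips' subdirectory when the source root contains one\n                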
\"clips\").is_dir():\n                clips_src = src_root / \"clips\"\n            else:\n                clips_src = src_root\n        else:\n            raise ValueError(f\"Source root is not a directory: {src_root}\")\n\n        clips_dst = out_root / \"clips\"\n        clips_dst.mkdir(parents=True, exist_ok=True)\n\n        files = sorted([p for p in clips_src.rglob(\"*.npz\") if p.is_file()])\n        if not files:\n            logger.info(\"No NPZ files found to process.\")\n            return\n\n        if use_ray:\n            if num_workers <= 0:\n                available_cpus = int(ray.available_resources().get(\"CPU\", 1))\n                effective_workers = max(1, available_cpus)\n            else:\n                effective_workers = num_workers\n            self._run_on_directory_ray(files, clips_dst, effective_workers)\n        else:\n            self._run_on_directory_sequential(files, clips_dst)\n\n    def _run_on_directory_sequential(\n        self, files: List[Path], clips_dst: Path\n    ) -> None:\n        logger.info(f\"Processing {len(files)} NPZ files sequentially\")\n        logger.info(f\"Pipeline stages to apply: {self.pipeline}\")\n        total_input_clips = 0\n        total_output_clips = 0\n        for p in tqdm(files, desc=\"HoloMotion preprocess NPZ\", unit=\"file\"):\n            clips = self.process_npz_file(p)\n            total_input_clips += 1\n            for clip in clips:\n                total_output_clips += 1\n                out_name = f\"{clip.motion_key}.npz\"\n                out_path = clips_dst / out_name\n                metadata_json = json.dumps(clip.metadata)\n                np.savez_compressed(\n                    out_path, metadata=metadata_json, **clip.arrays\n                )\n        logger.info(\n            f\"Processed {total_input_clips} input files into {total_output_clips} output clips\"\n        )\n\n    def _run_on_directory_ray(\n        self, files: List[Path], clips_dst: Path, num_workers: int\n    ) -> None:\n        if num_workers <= 0:\n            available_cpus = int(ray.available_resources().get(\"CPU\", 1))\n            num_actors = min(len(files), max(1, available_cpus))\n        else:\n            num_actors = min(len(files), num_workers)\n        actors = [\n            PreprocessorActor.remote(\n                slicing_cfg=self.slicing_cfg,\n                filtering_cfg=self.filtering_cfg,\n                tagging_cfg=self.tagging_cfg,\n                padding_cfg=self.padding_cfg,\n                pipeline=self.pipeline,\n            )\n            for _ in range(num_actors)\n        ]\n\n        pending = {}\n        next_idx = 0\n        for i in range(min(num_actors, len(files))):\n            p = files[next_idx]\n            next_idx += 1\n            ref = actors[i].process_npz_file.remote(str(p))\n            pending[ref] = i\n\n        total_outputs = 0\n        with tqdm(\n            total=len(files), desc=\"Ray: HoloMotion preprocess NPZ\"\n        ) as pbar:\n            while pending:\n                done, _ = ray.wait(list(pending.keys()), num_returns=1)\n                ref = done[0]\n                actor_idx = pending.pop(ref)\n                clips = ray.get(ref)\n                for clip in clips:\n                    out_name = f\"{clip.motion_key}.npz\"\n                    out_path = clips_dst / out_name\n                    metadata_json = json.dumps(clip.metadata)\n                    np.savez_compressed(\n                        out_path, metadata=metadata_json, **clip.arrays\n  
                  )\n                    total_outputs += 1\n                pbar.update(1)\n                if next_idx < len(files):\n                    p = files[next_idx]\n                    next_idx += 1\n                    new_ref = actors[actor_idx].process_npz_file.remote(str(p))\n                    pending[new_ref] = actor_idx\n\n        logger.info(f\"Processed {total_outputs} clips total.\")\n\n    def tag_directory(self, clips_dir: Path, tags_path: Path) -> None:\n        files = sorted([p for p in clips_dir.rglob(\"*.npz\") if p.is_file()])\n\n        clip_info: Dict[str, Dict[str, Dict[str, float]]] = {}\n        all_speed: List[np.ndarray] = []\n        all_wnorm: List[np.ndarray] = []\n        all_zrel: List[np.ndarray] = []\n        all_jerk: List[np.ndarray] = []\n\n        for f in tqdm(files, desc=\"Tagging kinematics\", unit=\"file\"):\n            with np.load(f, allow_pickle=True) as data:\n                meta_text = str(data[\"metadata\"])\n                meta = json.loads(meta_text)\n                key = str(meta[\"motion_key\"])\n                fps = float(meta[\"motion_fps\"])\n\n                def pick(name: str) -> np.ndarray:\n                    if f\"ft_ref_{name}\" in data:\n                        return np.array(data[f\"ft_ref_{name}\"], copy=False)\n                    if f\"ref_{name}\" in data:\n                        return np.array(data[f\"ref_{name}\"], copy=False)\n                    return np.array([], dtype=np.float32)\n\n                gv = pick(\"global_velocity\")\n                ga = pick(\"global_angular_velocity\")\n                gt = pick(\"global_translation\")\n\n                if gv.size > 0:\n                    root_vel = gv[:, 0, :]\n                    speed = np.linalg.norm(root_vel, axis=1)\n                else:\n                    speed = np.array([], dtype=float)\n\n                if ga.size > 0:\n                    root_w = ga[:, 0, :]\n                    wnorm = np.linalg.norm(root_w, axis=1)\n                else:\n                    wnorm = np.array([], dtype=float)\n\n                if gt.size > 0:\n                    root_pos_z = gt[:, 0, 2]\n                    z_rel = np.abs(root_pos_z - float(root_pos_z[0]))\n                else:\n                    z_rel = np.array([], dtype=float)\n\n                if gv.shape[0] >= 3:\n                    dt = 1.0 / fps if fps > 0.0 else 0.0\n                    a = (\n                        np.diff(gv, axis=0) / dt\n                        if dt > 0.0\n                        else np.zeros_like(gv)\n                    )\n                    j = (\n                        np.diff(a, axis=0) / dt\n                        if dt > 0.0\n                        else np.zeros_like(a)\n                    )\n                    jn = np.linalg.norm(j, axis=2)\n                else:\n                    jn = np.array([], dtype=float)\n\n                clip_info[key] = {\n                    \"root_linear_speed\": _summary(speed),\n                    \"root_angular_speed\": _summary(wnorm),\n                    \"root_delta_z\": _summary(z_rel),\n                    \"jerk\": _summary(jn),\n                }\n\n                if speed.size > 0:\n                    all_speed.append(speed.astype(float))\n                if wnorm.size > 0:\n                    all_wnorm.append(wnorm.astype(float))\n                if z_rel.size > 0:\n                    all_zrel.append(z_rel.astype(float))\n                if jn.size > 0:\n                    
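# pool per-clip samples for the dataset-level stats computed below\n                    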
all_jerk.append(jn.astype(float))\n\n        speed_cat = (\n            np.concatenate([a for a in all_speed if a.size > 0], axis=0)\n            if len(all_speed) > 0\n            else np.array([], dtype=float)\n        )\n        wnorm_cat = (\n            np.concatenate([a for a in all_wnorm if a.size > 0], axis=0)\n            if len(all_wnorm) > 0\n            else np.array([], dtype=float)\n        )\n        zrel_cat = (\n            np.concatenate([a for a in all_zrel if a.size > 0], axis=0)\n            if len(all_zrel) > 0\n            else np.array([], dtype=float)\n        )\n        jerk_cat = (\n            np.concatenate([a for a in all_jerk if a.size > 0], axis=0)\n            if len(all_jerk) > 0\n            else np.array([], dtype=float)\n        )\n\n        result = {\n            \"dataset_stats\": {\n                \"root_linear_speed\": _ds_summary(speed_cat),\n                \"root_angular_speed\": _ds_summary(wnorm_cat),\n                \"root_delta_z\": _ds_summary(zrel_cat),\n                \"jerk\": _ds_summary(jerk_cat),\n            },\n            \"clip_info\": clip_info,\n        }\n        with open(tags_path, \"w\") as f:\n            json.dump(result, f, indent=2, sort_keys=True)\n        logger.info(f\"Wrote kinematic tags JSON to: {tags_path}\")\n\n\n@ray.remote\nclass PreprocessorActor:\n    \"\"\"Ray actor that holds a HoloMotionPreprocessor instance for parallel processing.\"\"\"\n\n    def __init__(\n        self,\n        slicing_cfg: Optional[DictConfig] = None,\n        filtering_cfg: Optional[DictConfig] = None,\n        tagging_cfg: Optional[DictConfig] = None,\n        padding_cfg: Optional[DictConfig] = None,\n        pipeline: Optional[List[str]] = None,\n    ) -> None:\n        self.preprocessor = HoloMotionPreprocessor(\n            slicing_cfg=slicing_cfg,\n            filtering_cfg=filtering_cfg,\n            tagging_cfg=tagging_cfg,\n            padding_cfg=padding_cfg,\n            pipeline=pipeline,\n        )\n        logger.debug(\n            f\"PreprocessorActor initialized with pipeline: {self.preprocessor.pipeline}\"\n        )\n\n    def process_npz_file(self, npz_path_str: str) -> List[ProcessedClip]:\n        npz_path = Path(npz_path_str)\n        return self.preprocessor.process_npz_file(npz_path)\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"motion_retargeting/holomotion_preprocess\",\n    version_base=None,\n)\ndef main(cfg: DictConfig) -> None:\n    logger.remove()\n    logger.add(sys.stderr, level=\"INFO\", colorize=True)\n\n    src_root = Path(str(cfg.io.src_root)).expanduser().resolve()\n    out_root = Path(str(cfg.io.out_root)).expanduser().resolve()\n    out_root.mkdir(parents=True, exist_ok=True)\n\n    # Dump resolved config used\n    with open(out_root / \"config_used.yaml\", \"w\") as f:\n        f.write(OmegaConf.to_yaml(cfg))\n\n    # Parse pipeline\n    pipeline_cfg = cfg.get(\"preprocess\", None)\n    logger.debug(f\"Raw preprocess config: {pipeline_cfg}\")\n    pipeline = None\n    if pipeline_cfg is not None:\n        pipeline_val = pipeline_cfg.get(\"pipeline\", None)\n        logger.debug(\n            f\"Raw pipeline value: {pipeline_val} (type: {type(pipeline_val)})\"\n        )\n        if pipeline_val is not None:\n            if isinstance(pipeline_val, (list, tuple, ListConfig)):\n                pipeline = [str(s) for s in pipeline_val]\n            elif isinstance(pipeline_val, str):\n                import ast\n\n                pipeline = 
ast.literal_eval(pipeline_val)\n            else:\n                logger.warning(\n                    f\"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}\"\n                )\n                pipeline = []\n        else:\n            logger.debug(\"pipeline_val is None\")\n    else:\n        logger.debug(\"preprocess config is None\")\n\n    # Separate per-clip stages from dataset-level stages\n    per_clip_pipeline = (\n        [s for s in pipeline if s != \"tagging\"] if pipeline else []\n    )\n    tagging_enabled = pipeline and \"tagging\" in pipeline\n\n    logger.info(\"=\" * 80)\n    logger.info(\"Preprocessing Configuration:\")\n    logger.info(f\"  Source directory: {src_root}\")\n    logger.info(f\"  Output directory: {out_root}\")\n    if pipeline:\n        logger.info(f\"  Pipeline stages: {pipeline}\")\n        logger.info(f\"  Number of stages: {len(pipeline)}\")\n        for i, stage in enumerate(pipeline, 1):\n            logger.info(f\"    {i}. {stage}\")\n        if tagging_enabled:\n            logger.info(\n                \"  Note: 'tagging' is a dataset-level operation and will run after all clips are processed\"\n            )\n    else:\n        logger.warning(\n            \"  No preprocessing pipeline specified - no processors will be applied!\"\n        )\n    logger.info(\"=\" * 80)\n\n    use_ray = bool(cfg.get(\"ray\", {}).get(\"enabled\", False))\n    num_workers = int(cfg.get(\"ray\", {}).get(\"num_workers\", 0))\n    ray_address = str(cfg.get(\"ray\", {}).get(\"ray_address\", \"\"))\n\n    if use_ray:\n        logging.getLogger(\"filelock\").setLevel(logging.WARNING)\n        logging.getLogger(\"ray\").setLevel(logging.ERROR)\n        os.environ.setdefault(\"RAY_BACKEND_LOG_LEVEL\", \"error\")\n\n        if ray_address:\n            ray.init(\n                address=ray_address,\n                ignore_reinit_error=True,\n                log_to_driver=False,\n                include_dashboard=False,\n                logging_level=logging.ERROR,\n            )\n            if num_workers <= 0:\n                num_workers = int(ray.available_resources().get(\"CPU\", 1))\n        else:\n            num_cpus = None if num_workers <= 0 else num_workers\n            ray.init(\n                num_cpus=num_cpus,\n                ignore_reinit_error=True,\n                log_to_driver=False,\n                include_dashboard=False,\n                logging_level=logging.ERROR,\n            )\n            if num_workers <= 0:\n                num_workers = int(ray.available_resources().get(\"CPU\", 1))\n\n    preprocessor = HoloMotionPreprocessor(\n        slicing_cfg=cfg.slicing,\n        filtering_cfg=cfg.filtering,\n        tagging_cfg=cfg.tagging,\n        padding_cfg=cfg.get(\"padding\", None),\n        pipeline=per_clip_pipeline if per_clip_pipeline else None,\n    )\n\n    logger.info(\n        f\"Preprocessor initialized with pipeline: {preprocessor.pipeline}\"\n    )\n    logger.info(\n        f\"  Slicing config present: {preprocessor.slicing_cfg is not None}\"\n    )\n    logger.info(\n        f\"  Filtering config present: {preprocessor.filtering_cfg is not None}\"\n    )\n    logger.info(\n        f\"  Tagging config present: {preprocessor.tagging_cfg is not None}\"\n    )\n\n    preprocessor.run_on_directory(\n        src_root, out_root, use_ray=use_ray, num_workers=num_workers\n    )\n\n    if use_ray:\n        ray.shutdown()\n\n    if tagging_enabled:\n        if str(cfg.tagging.output_json_path):\n            tags_path = 
(\n                Path(str(cfg.tagging.output_json_path)).expanduser().resolve()\n            )\n        else:\n            tags_path = out_root / \"kinematic_tags.json\"\n        clips_dir = out_root / \"clips\"\n        preprocessor.tag_directory(clips_dir, tags_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/kinematic_filter.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport json\nimport sys\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nimport hydra\nimport yaml\nfrom loguru import logger\nfrom omegaconf import DictConfig, OmegaConf\nfrom tqdm import tqdm\n\n\ndef _eval_rule(val: float, op: str, thr: float) -> bool:\n    if op == \">\":\n        return val > thr\n    if op == \">=\":\n        return val >= thr\n    if op == \"<\":\n        return val < thr\n    if op == \"<=\":\n        return val <= thr\n    if op == \"==\":\n        return val == thr\n    if op == \"!=\":\n        return val != thr\n    raise ValueError(f\"Unsupported op: {op}\")\n\n\ndef _deep_get(container: Dict[str, Any], parts: List[str]) -> Optional[float]:\n    cur: Any = container\n    for p in parts:\n        if not isinstance(cur, dict) or p not in cur:\n            return None\n        cur = cur[p]\n    if isinstance(cur, (int, float)):\n        return float(cur)\n    return None\n\n\ndef _resolve_value(\n    tags_root: Dict[str, Any],\n    clip_group: Dict[str, Any],\n    path: str,\n) -> Optional[float]:\n    \"\"\"Resolve a threshold path to a numeric value.\n\n    - dataset_stats.<feature>.<DS_stat> reads from tags_root\n    - kinematic_features.<feature>.<stat> reads from clip_group\n    - <feature>.<stat> (no prefix) also reads from clip_group for convenience.\n    \"\"\"\n    parts = str(path).split(\".\")\n    if len(parts) == 0:\n        return None\n    if parts[0] == \"dataset_stats\":\n        return _deep_get(tags_root, parts)\n    if parts[0] == \"kinematic_features\":\n        return _deep_get(clip_group, parts[1:])\n    return _deep_get(clip_group, parts)\n\n\ndef filter_with_schema(\n    tags: Dict[str, Any],\n    schema: Dict[str, Any],\n) -> Tuple[Set[str], Dict[str, int], Dict[str, int]]:\n    thresholds: Dict[str, Dict[str, Any]] = schema.get(\"thresholds\", {}) or {}\n    across_mode = str(schema.get(\"across\", \"union\"))\n    out: Set[str] = set()\n    path_counts: Dict[str, int] = {}\n    group_counts: Dict[str, int] = {}\n\n    clips: Dict[str, Dict[str, Any]] = tags.get(\"clip_info\", {}) or {}\n    for motion_key, groups in tqdm(\n        clips.items(), desc=\"Evaluating schema\", unit=\"clip\"\n    ):\n        hits: List[bool] = []\n        hits_by_path: Dict[str, bool] = {}\n        group_hit_any: Dict[str, bool] = {}\n        for path, spec in thresholds.items():\n            parts = str(path).split(\".\")\n            if len(parts) == 0:\n                continue\n            val = _resolve_value(tags, groups, path)\n            if val is None:\n                continue\n            op = str(spec.get(\"op\", \">\"))\n            thr = float(spec[\"value\"])\n            hit = _eval_rule(val, op, thr)\n            hits.append(hit)\n            hits_by_path[path] = hit\n            grp = parts[0]\n            if hit:\n                
group_hit_any[grp] = True\n        if len(hits) == 0:\n            continue\n        if across_mode == \"union\":\n            excluded = any(hits)\n        elif across_mode == \"intersection\":\n            excluded = all(hits)\n        else:\n            raise ValueError(f\"Invalid across mode: {across_mode}\")\n        if not excluded:\n            continue\n        out.add(motion_key)\n        # accumulate counts for excluded clips\n        for pth, hit in hits_by_path.items():\n            if hit:\n                path_counts[pth] = path_counts.get(pth, 0) + 1\n        for grp, any_hit in group_hit_any.items():\n            if any_hit:\n                group_counts[grp] = group_counts.get(grp, 0) + 1\n    return out, path_counts, group_counts\n\n\ndef _default_schema_path() -> Path:\n    # holomotion/src/motion_retargeting/kinematic_filter.py\n    # -> holomotion/config/motion_retargeting/kinematic_filtering_schema.yaml\n    this_file = Path(__file__).resolve()\n    holomotion_dir = this_file.parents[2]\n    return (\n        holomotion_dir\n        / \"config\"\n        / \"motion_retargeting\"\n        / \"kinematic_filtering_schema.yaml\"\n    )\n\n\ndef run(\n    dataset_root: str,\n    schema_yaml_path: Optional[str] = None,\n    output_yaml_path: Optional[str] = None,\n    schema_obj: Optional[Dict[str, Any]] = None,\n) -> Set[str]:\n    \"\"\"Execute kinematic filtering using tags and a schema.\n\n    - dataset_root: directory containing 'kinematic_tags.json'\n    - schema_yaml_path: external YAML with 'across' and 'thresholds' (optional)\n    - schema_obj: inline dict with 'across' and 'thresholds' (optional)\n    - output_yaml_path: where to write the excluded list YAML (optional)\n    \"\"\"\n    root = Path(dataset_root).expanduser().resolve()\n    tags_path = root / \"kinematic_tags.json\"\n    if not tags_path.is_file():\n        raise FileNotFoundError(f\"Missing kinematic tags JSON: {tags_path}\")\n\n    schema: Dict[str, Any]\n    if schema_obj is not None:\n        schema = dict(schema_obj)\n    else:\n        schema_path = (\n            Path(schema_yaml_path).expanduser().resolve()\n            if schema_yaml_path\n            else _default_schema_path()\n        )\n        if not schema_path.is_file():\n            raise FileNotFoundError(f\"Missing schema YAML: {schema_path}\")\n        schema = yaml.safe_load(open(schema_path, \"r\", encoding=\"utf-8\"))\n\n    out_yaml = (\n        Path(output_yaml_path).expanduser().resolve()\n        if output_yaml_path\n        else (root / \"excluded_kinematic_motion_names.yaml\")\n    )\n\n    logger.info(f\"Dataset root: {root}\")\n    logger.info(f\"Reading tags from: {tags_path}\")\n    logger.info(\n        \"Using schema from: inline config\"\n        if schema_obj is not None\n        else \"Using schema from YAML file\"\n    )\n    # Pretty-print resolved schema to console\n    try:\n        logger.info(\n            \"Resolved schema:\\n\"\n            + yaml.safe_dump(schema, sort_keys=True, default_flow_style=False)\n        )\n    except Exception:\n        pass\n\n    tags = json.load(open(tags_path, \"r\", encoding=\"utf-8\"))\n\n    # Dump the used filter config into dataset root\n    try:\n        used_cfg = {\n            \"dataset_root\": str(root),\n            \"output_yaml\": str(out_yaml),\n            \"schema\": schema,\n        }\n        with open(\n            root / \"kinematic_filter_config_used.yaml\", \"w\", encoding=\"utf-8\"\n        ) as f:\n            yaml.safe_dump(\n                
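# snapshot the exact filter config used for this run\n                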
used_cfg, f, sort_keys=True, default_flow_style=False\n            )\n    except Exception:\n        pass\n\n    excluded_keys, path_counts, group_counts = filter_with_schema(tags, schema)\n    with open(out_yaml, \"w\", encoding=\"utf-8\") as f:\n        f.write(\"# @package _global_\\n\\n\")\n        f.write(\"excluded_motion_names:\\n\")\n        for k in sorted(excluded_keys):\n            f.write(f\"- {k}\\n\")\n\n    logger.info(f\"Excluded by config: {len(excluded_keys)}\")\n    if len(group_counts) > 0:\n        logger.info(\"Excluded counts by category:\")\n        for grp, cnt in sorted(\n            group_counts.items(), key=lambda kv: kv[1], reverse=True\n        ):\n            logger.info(f\"- {grp}: {cnt}\")\n    if len(path_counts) > 0:\n        logger.info(\"Excluded counts by threshold path:\")\n        for pth, cnt in sorted(\n            path_counts.items(), key=lambda kv: kv[1], reverse=True\n        ):\n            logger.info(f\"- {pth}: {cnt}\")\n    logger.info(f\"Wrote excluded list to: {out_yaml}\")\n    return excluded_keys\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"motion_retargeting/kinematic_filter\",\n    version_base=None,\n)\ndef main(cfg: DictConfig) -> None:\n    logger.remove()\n    logger.add(sys.stderr, level=\"INFO\", colorize=True)\n\n    dataset_root = str(cfg.io.dataset_root)\n    # Optional fields (external schema YAML override and output path)\n    schema_val = \"\"\n    out_val = \"\"\n    schema_obj = None\n    if \"schema\" in cfg:\n        # Inline schema object\n        schema_obj = OmegaConf.to_object(cfg.schema)\n    if \"filtering\" in cfg and hasattr(cfg.filtering, \"schema_yaml\"):\n        schema_val = str(cfg.filtering.get(\"schema_yaml\", \"\") or \"\")\n        out_val = str(cfg.filtering.get(\"output_yaml\", \"\") or \"\")\n    elif \"filtering\" in cfg:\n        out_val = str(cfg.filtering.get(\"output_yaml\", \"\") or \"\")\n\n    schema_yaml_path = schema_val if len(schema_val) > 0 else None\n    output_yaml_path = out_val if len(out_val) > 0 else None\n\n    run(\n        dataset_root=dataset_root,\n        schema_yaml_path=schema_yaml_path,\n        schema_obj=schema_obj,\n        output_yaml_path=output_yaml_path,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/pack_hdf5_v2.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport json\nimport os\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport h5py\nimport hydra\nimport numpy as np\nfrom loguru import logger\nfrom omegaconf import ListConfig, OmegaConf\nfrom tqdm import tqdm\n\n\ndef _ensure_dir(path: str) -> None:\n    os.makedirs(path, exist_ok=True)\n\n\n@dataclass\nclass ArraySpec:\n    name: str\n    shape_tail: Tuple[int, ...]  # shape excluding time dim\n    dtype: np.dtype\n\n\n@dataclass\nclass ClipEntry:\n    clip_id: int\n    name: str\n    path: str\n\n\nclass Hdf5ShardWriter:\n    def __init__(\n        self,\n        h5_path: str,\n        array_specs: List[ArraySpec],\n        chunks_t: int,\n        compression: str,\n    ) -> None:\n        self.h5_path = h5_path\n        self.array_specs = array_specs\n        self.chunks_t = int(chunks_t)\n        self.compression = compression\n\n        _ensure_dir(os.path.dirname(self.h5_path))\n        self.h5 = h5py.File(self.h5_path, \"w\")\n\n        self.datasets: Dict[str, h5py.Dataset] = {}\n        for spec in self.array_specs:\n            chunks = (self.chunks_t, *spec.shape_tail)\n            maxshape = (None, *spec.shape_tail)\n            ds = self.h5.create_dataset(\n                spec.name,\n                shape=(0, *spec.shape_tail),\n                maxshape=maxshape,\n                chunks=chunks,\n                compression=(\n                    self.compression if self.compression != \"none\" else None\n                ),\n                dtype=spec.dtype,\n                shuffle=True if self.compression != \"none\" else False,\n            )\n            self.datasets[spec.name] = ds\n\n        self._clip_starts: List[int] = []\n        self._clip_lengths: List[int] = []\n        self._clip_motion_ids: List[int] = []\n        self._clip_metadata: List[str] = []\n\n        self.t_cursor = 0\n\n    def append_motion(\n        self,\n        motion_id: int,\n        np_arrays: Dict[str, np.ndarray],\n        metadata_json: str,\n    ) -> Tuple[int, int]:\n        if \"ref_dof_pos\" not in np_arrays:\n            raise KeyError(\"ref_dof_pos missing for HDF5 v2 packing\")\n        t_len = int(np_arrays[\"ref_dof_pos\"].shape[0])\n\n        start = self.t_cursor\n        end = start + t_len\n        for spec in self.array_specs:\n            if spec.name not in np_arrays:\n                raise KeyError(\n                    f\"Missing array '{spec.name}' for HDF5 v2 packing\"\n                )\n            ds = self.datasets[spec.name]\n            ds.resize((end, *spec.shape_tail))\n            ds[start:end, ...] 
= np_arrays[spec.name]\n\n        self._clip_starts.append(start)\n        self._clip_lengths.append(t_len)\n        self._clip_motion_ids.append(motion_id)\n        self._clip_metadata.append(metadata_json)\n        self.t_cursor = end\n        return start, t_len\n\n    def finalize(self) -> Dict[str, Any]:\n        g = self.h5.create_group(\"clips\")\n        g.create_dataset(\n            \"start\", data=np.asarray(self._clip_starts, dtype=np.int64)\n        )\n        g.create_dataset(\n            \"length\", data=np.asarray(self._clip_lengths, dtype=np.int64)\n        )\n        g.create_dataset(\n            \"motion_key_id\",\n            data=np.asarray(self._clip_motion_ids, dtype=np.int64),\n        )\n        vlen_str = h5py.string_dtype(encoding=\"utf-8\")\n        g.create_dataset(\n            \"metadata_json\",\n            data=np.asarray(self._clip_metadata, dtype=vlen_str),\n        )\n\n        summary = {\n            \"file\": self.h5_path,\n            \"num_clips\": len(self._clip_starts),\n            \"num_frames\": int(self.t_cursor),\n        }\n        self.h5.flush()\n        self.h5.close()\n        return summary\n\n\ndef _normalize_root_list(value: Any) -> List[str]:\n    if value is None:\n        return []\n    if isinstance(value, (str, os.PathLike)):\n        return [str(value)]\n    if isinstance(value, (list, tuple, ListConfig)):\n        return [str(v) for v in list(value)]\n    return [str(value)]\n\n\ndef _discover_motion_entries(roots: List[str]) -> List[ClipEntry]:\n    motion_key_to_path: Dict[str, str] = {}\n    for root in roots:\n        root_path = Path(root).expanduser().resolve()\n        parent_dir_name = root_path.name\n        clips_dir = root_path / \"clips\"\n        base_dir = clips_dir if clips_dir.is_dir() else root_path\n        if not base_dir.is_dir():\n            raise FileNotFoundError(f\"NPZ directory not found: {base_dir}\")\n        for dirpath, _, filenames in os.walk(str(base_dir)):\n            for fname in filenames:\n                if not fname.endswith(\".npz\"):\n                    continue\n                base_key = os.path.splitext(fname)[0]\n                motion_key = f\"{parent_dir_name}_{base_key}\"\n                npz_path = os.path.join(dirpath, fname)\n                if motion_key in motion_key_to_path:\n                    raise ValueError(f\"Duplicate motion key: {motion_key}\")\n                motion_key_to_path[motion_key] = npz_path\n\n    entries = [\n        ClipEntry(clip_id=i, name=key, path=motion_key_to_path[key])\n        for i, key in enumerate(sorted(motion_key_to_path.keys()))\n    ]\n    if len(entries) == 0:\n        raise ValueError(\"No NPZ files found in input directories.\")\n    return entries\n\n\ndef _load_metadata_json(npz_path: Path) -> Tuple[str, Dict[str, Any]]:\n    with np.load(npz_path, allow_pickle=False) as data:\n        if \"metadata\" not in data:\n            raise KeyError(f\"'metadata' missing in NPZ: {npz_path}\")\n        metadata_json = str(data[\"metadata\"])\n        num_frames_from_dof = data[\"ref_dof_pos\"].shape[0]\n    metadata = json.loads(metadata_json)\n    num_frames_from_metadata = metadata[\"num_frames\"]\n    assert num_frames_from_dof == num_frames_from_metadata, (\n        f\"num_frames_from_dof {num_frames_from_dof} != num_frames_from_metadata {num_frames_from_metadata} in {npz_path}\"\n    )\n    if not isinstance(metadata, dict):\n        raise ValueError(f\"metadata must be a JSON object in {npz_path}\")\n    return metadata_json, 
metadata\n\n\ndef _cast_array(array: np.ndarray, name: str, npz_path: Path) -> np.ndarray:\n    if array.dtype == np.float32:\n        return array\n    if array.dtype.kind == \"O\":\n        raise ValueError(f\"Array '{name}' in {npz_path} has object dtype.\")\n    if np.issubdtype(array.dtype, np.integer):\n        logger.warning(\n            \"Casting array '{}' in {} from {} to float32.\",\n            name,\n            npz_path,\n            array.dtype,\n        )\n        return array.astype(np.float32, copy=False)\n    raise ValueError(\n        f\"Array '{name}' in {npz_path} has dtype {array.dtype}, \"\n        \"expected float32 or integer.\"\n    )\n\n\ndef _discover_array_specs(first_npz: Path) -> List[ArraySpec]:\n    with np.load(first_npz, allow_pickle=False) as data:\n        if \"ref_dof_pos\" not in data:\n            raise KeyError(f\"'ref_dof_pos' missing in NPZ: {first_npz}\")\n        if \"ref_global_translation\" not in data:\n            raise KeyError(\n                f\"'ref_global_translation' missing in NPZ: {first_npz}\"\n            )\n        if \"ref_global_rotation_quat\" not in data:\n            raise KeyError(\n                f\"'ref_global_rotation_quat' missing in NPZ: {first_npz}\"\n            )\n        dof_pos = data[\"ref_dof_pos\"]\n        global_trans = data[\"ref_global_translation\"]\n        global_rot = data[\"ref_global_rotation_quat\"]\n        if dof_pos.ndim < 2:\n            raise ValueError(f\"'ref_dof_pos' must be (T, ndof) in {first_npz}\")\n        if global_trans.ndim < 2 or global_trans.shape[-1] != 3:\n            raise ValueError(\n                f\"'ref_global_translation' must end with 3 in {first_npz}\"\n            )\n        if global_rot.ndim < 2 or global_rot.shape[-1] != 4:\n            raise ValueError(\n                f\"'ref_global_rotation_quat' must end with 4 in {first_npz}\"\n            )\n        dof_tail = tuple(dof_pos.shape[1:])\n    return [\n        ArraySpec(name=\"ref_dof_pos\", shape_tail=dof_tail, dtype=np.float32),\n        ArraySpec(name=\"ref_root_pos\", shape_tail=(3,), dtype=np.float32),\n        ArraySpec(name=\"ref_root_rot\", shape_tail=(4,), dtype=np.float32),\n    ]\n\n\ndef _load_npz_arrays(\n    npz_path: Path,\n    num_frames: int,\n    dof_tail: Tuple[int, ...],\n) -> Dict[str, np.ndarray]:\n    with np.load(npz_path, allow_pickle=False) as data:\n        dof_pos = _cast_array(data[\"ref_dof_pos\"], \"ref_dof_pos\", npz_path)\n        global_trans = _cast_array(\n            data[\"ref_global_translation\"], \"ref_global_translation\", npz_path\n        )\n        global_rot = _cast_array(\n            data[\"ref_global_rotation_quat\"],\n            \"ref_global_rotation_quat\",\n            npz_path,\n        )\n\n    if global_trans.ndim == 2:\n        root_pos = global_trans\n    elif global_trans.ndim >= 3:\n        root_pos = global_trans[:, 0, :]\n    else:\n        raise ValueError(\n            f\"ref_global_translation must be (T,3) or (T,B,3) in {npz_path}\"\n        )\n    if global_rot.ndim == 2:\n        root_rot = global_rot\n    elif global_rot.ndim >= 3:\n        root_rot = global_rot[:, 0, :]\n    else:\n        raise ValueError(\n            f\"ref_global_rotation_quat must be (T,4) or (T,B,4) in {npz_path}\"\n        )\n\n    expected_dof_shape = (num_frames, *dof_tail)\n    if dof_pos.shape != expected_dof_shape:\n        raise ValueError(\n            f\"ref_dof_pos shape {dof_pos.shape} does not match {expected_dof_shape} \"\n            f\"in {npz_path}\"\n  
      )\n    if root_pos.shape != (num_frames, 3):\n        raise ValueError(\n            f\"ref_root_pos shape {root_pos.shape} does not match {(num_frames, 3)} \"\n            f\"in {npz_path}\"\n        )\n    if root_rot.shape != (num_frames, 4):\n        raise ValueError(\n            f\"ref_root_rot shape {root_rot.shape} does not match {(num_frames, 4)} \"\n            f\"in {npz_path}\"\n        )\n\n    return {\n        \"ref_dof_pos\": dof_pos,\n        \"ref_root_pos\": root_pos,\n        \"ref_root_rot\": root_rot,\n    }\n\n\ndef _relative_npz_path(npz_path: Path, roots: List[str]) -> str:\n    npz_path = npz_path.expanduser().resolve()\n    for root in roots:\n        root_path = Path(root).expanduser().resolve()\n        try:\n            rel = npz_path.relative_to(root_path)\n        except ValueError:\n            continue\n        return str(Path(root_path.name) / rel)\n    return str(npz_path)\n\n\ndef _nan_array_names(arrays: Dict[str, np.ndarray]) -> List[str]:\n    nan_names: List[str] = []\n    for name, array in arrays.items():\n        if not np.issubdtype(array.dtype, np.floating):\n            continue\n        if np.isnan(array).any():\n            nan_names.append(name)\n            return nan_names\n    return []\n\n\ndef _estimate_bytes_for_motion(npz_path: Path, mode: str) -> int:\n    \"\"\"Estimate per-clip byte contribution for shard sizing.\n\n    Note:\n    - ``uncompressed_nbytes`` matches the in-memory float32 payload size and does\n      *not* correspond to on-disk shard size when compression is enabled.\n    - ``npz_filesize`` uses the compressed input file size as a cheap proxy for\n      on-disk shard size.\n    - ``h5_filesize`` does not use this estimator (it measures actual shard size\n      after writes).\n    \"\"\"\n    mode_norm = str(mode).lower().strip()\n    if mode_norm in (\"uncompressed_nbytes\", \"nbytes\", \"uncompressed\"):\n        with np.load(npz_path, allow_pickle=False) as data:\n            total = 0\n            for key in (\n                \"ref_dof_pos\",\n                \"ref_global_translation\",\n                \"ref_global_rotation_quat\",\n            ):\n                if key in data:\n                    total += int(data[key].nbytes)\n        return total\n    if mode_norm in (\"npz_filesize\", \"npz_size\", \"npz_bytes\"):\n        return int(npz_path.stat().st_size)\n    raise ValueError(\n        f\"Unsupported shard_target_mode '{mode}'. 
Expected one of: \"\n        \"uncompressed_nbytes | npz_filesize | h5_filesize\"\n    )\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"motion_retargeting/pack_hdf5_v2\",\n    version_base=None,\n)\ndef main(cfg: OmegaConf) -> None:\n    roots = _normalize_root_list(cfg.get(\"holomotion_npz_root\", None))\n    if len(roots) == 0:\n        roots = _normalize_root_list(\n            cfg.get(\"holomotion_retargeted_dirs\", None)\n        )\n    if len(roots) == 0:\n        legacy_root = cfg.get(\"precomputed_npz_root\", None)\n        roots = _normalize_root_list(legacy_root)\n    if len(roots) == 0:\n        raise ValueError(\"holomotion_npz_root must be provided.\")\n\n    hdf5_root = cfg.get(\n        \"hdf5_root\", os.path.join(os.getcwd(), \"holomotion_hdf5_v2\")\n    )\n    chunks_t = int(cfg.get(\"chunks_t\", 1024))\n    compression = str(cfg.get(\"compression\", \"lzf\")).lower()\n    shard_target_gb = float(cfg.get(\"shard_target_gb\", 2.0))\n    shard_target_bytes = int(\n        cfg.get(\"shard_target_bytes\", shard_target_gb * (1 << 30))\n    )\n    shard_target_mode = str(\n        cfg.get(\"shard_target_mode\", \"uncompressed_nbytes\")\n    )\n\n    for root in roots:\n        if not os.path.isdir(root):\n            raise FileNotFoundError(f\"NPZ clips directory not found: {root}\")\n\n    entries = _discover_motion_entries(roots)\n    motion_keys: List[str] = []\n    motion_key2id: Dict[str, int] = {}\n    nan_npz_paths: List[str] = []\n\n    first_npz = Path(entries[0].path)\n    array_specs = _discover_array_specs(first_npz)\n    array_names_created = [s.name for s in array_specs]\n    dof_tail = next(\n        spec.shape_tail for spec in array_specs if spec.name == \"ref_dof_pos\"\n    )\n    logger.info(\n        \"HDF5 v2 datasets: {} (dof_tail={})\",\n        array_names_created,\n        dof_tail,\n    )\n\n    dof_names: List[str] = []\n    body_names: List[str] = []\n    extended_body_names: List[str] = []\n    robot_cfg = cfg.get(\"robot\", None)\n    if robot_cfg is not None and \"motion\" in robot_cfg:\n        motion_cfg = robot_cfg[\"motion\"]\n        dof_names = list(motion_cfg.get(\"dof_names\", []))\n        body_names = list(motion_cfg.get(\"body_names\", []))\n        extended_body_names = list(\n            list(motion_cfg.get(\"body_names\", []))\n            + [\n                i.get(\"joint_name\")\n                for i in motion_cfg.get(\"extend_config\", [])\n            ]\n        )\n\n    shard_dir = os.path.join(str(hdf5_root), \"shards\")\n    _ensure_dir(shard_dir)\n\n    hdf5_shards: List[Dict[str, Any]] = []\n    clips_manifest: Dict[str, Dict[str, Any]] = {}\n\n    curr_shard_idx = 0\n    curr_shard_bytes = 0\n    writer: Optional[Hdf5ShardWriter] = None\n    pbar = tqdm(total=len(entries), desc=\"Packing HDF5 v2 shards\")\n\n    for entry in entries:\n        npz_path = Path(entry.path)\n        metadata_json, metadata = _load_metadata_json(npz_path)\n        if \"num_frames\" not in metadata:\n            raise KeyError(f\"'num_frames' missing in metadata: {npz_path}\")\n        num_frames = int(metadata[\"num_frames\"])\n        if num_frames <= 0:\n            raise ValueError(f\"Invalid num_frames {num_frames} in {npz_path}\")\n\n        arrays_np = _load_npz_arrays(\n            npz_path=npz_path, num_frames=num_frames, dof_tail=dof_tail\n        )\n        nan_arrays = _nan_array_names(arrays_np)\n        if len(nan_arrays) > 0:\n            rel_npz_path = _relative_npz_path(npz_path, roots)\n            
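# record the NaN-contaminated file (relative to its input root) and skip it\n            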
nan_npz_paths.append(rel_npz_path)\n            logger.warning(\n                \"NaN detected in NPZ (arrays: {}), skipping: {}\",\n                nan_arrays,\n                npz_path,\n            )\n            pbar.update(1)\n            continue\n\n        shard_mode_norm = shard_target_mode.lower().strip()\n        if shard_mode_norm in (\n            \"h5_filesize\",\n            \"h5_size\",\n            \"output_filesize\",\n            \"disk\",\n        ):\n            if writer is None:\n                shard_name = f\"holomotion_{curr_shard_idx:03d}.h5\"\n                shard_path = os.path.join(shard_dir, shard_name)\n                writer = Hdf5ShardWriter(\n                    shard_path,\n                    array_specs,\n                    chunks_t=chunks_t,\n                    compression=compression,\n                )\n        else:\n            motion_bytes = _estimate_bytes_for_motion(\n                npz_path, shard_target_mode\n            )\n            if (\n                writer is None\n                or (curr_shard_bytes + motion_bytes) > shard_target_bytes\n            ):\n                if writer is not None:\n                    shard_summary = writer.finalize()\n                    hdf5_shards.append(\n                        {\n                            \"file\": os.path.relpath(\n                                shard_summary[\"file\"], str(hdf5_root)\n                            ),\n                            \"num_clips\": shard_summary[\"num_clips\"],\n                            \"num_frames\": shard_summary[\"num_frames\"],\n                        }\n                    )\n                    curr_shard_idx += 1\n                    curr_shard_bytes = 0\n\n                shard_name = f\"holomotion_{curr_shard_idx:03d}.h5\"\n                shard_path = os.path.join(shard_dir, shard_name)\n                writer = Hdf5ShardWriter(\n                    shard_path,\n                    array_specs,\n                    chunks_t=chunks_t,\n                    compression=compression,\n                )\n\n        motion_id = motion_key2id.get(entry.name)\n        if motion_id is None:\n            motion_id = len(motion_keys)\n            motion_key2id[entry.name] = motion_id\n            motion_keys.append(entry.name)\n        start, length = writer.append_motion(\n            motion_id=motion_id,\n            np_arrays=arrays_np,\n            metadata_json=metadata_json,\n        )\n\n        clips_manifest[entry.name] = {\n            \"motion_key\": entry.name,\n            \"shard\": curr_shard_idx,\n            \"clip_idx\": len(writer._clip_starts) - 1,\n            \"start\": int(start),\n            \"length\": int(length),\n            \"available_arrays\": list(array_names_created),\n            \"metadata\": metadata,\n        }\n        if shard_mode_norm in (\n            \"h5_filesize\",\n            \"h5_size\",\n            \"output_filesize\",\n            \"disk\",\n        ):\n            writer.h5.flush()\n            curr_shard_bytes = int(os.path.getsize(writer.h5_path))\n        else:\n            curr_shard_bytes += motion_bytes\n        pbar.update(1)\n\n        if (\n            shard_mode_norm\n            in (\"h5_filesize\", \"h5_size\", \"output_filesize\", \"disk\")\n            and curr_shard_bytes >= shard_target_bytes\n            and writer is not None\n        ):\n            shard_summary = writer.finalize()\n            hdf5_shards.append(\n                {\n                    \"file\": 
os.path.relpath(\n                        shard_summary[\"file\"], str(hdf5_root)\n                    ),\n                    \"num_clips\": shard_summary[\"num_clips\"],\n                    \"num_frames\": shard_summary[\"num_frames\"],\n                }\n            )\n            curr_shard_idx += 1\n            curr_shard_bytes = 0\n            writer = None\n\n    pbar.close()\n\n    if writer is not None:\n        shard_summary = writer.finalize()\n        hdf5_shards.append(\n            {\n                \"file\": os.path.relpath(shard_summary[\"file\"], str(hdf5_root)),\n                \"num_clips\": shard_summary[\"num_clips\"],\n                \"num_frames\": shard_summary[\"num_frames\"],\n            }\n        )\n\n    manifest = {\n        \"version\": 1,\n        \"root\": str(hdf5_root),\n        \"hdf5_shards\": hdf5_shards,\n        \"clips\": clips_manifest,\n        \"motion_keys\": motion_keys,\n        \"dof_names\": dof_names,\n        \"body_names\": body_names,\n        \"extended_body_names\": extended_body_names,\n        \"array_names\": array_names_created,\n        \"chunks_t\": int(chunks_t),\n        \"compression\": compression,\n        \"shard_target_mode\": str(shard_target_mode),\n        \"shard_target_bytes\": int(shard_target_bytes),\n    }\n    _ensure_dir(str(hdf5_root))\n    nan_paths_path = os.path.join(str(hdf5_root), \"nan_npz_paths.json\")\n    with open(nan_paths_path, \"w\") as f:\n        json.dump(nan_npz_paths, f, indent=2)\n    if len(nan_npz_paths) > 0:\n        logger.warning(\n            \"Skipped {} NPZ files due to NaNs. List: {}\",\n            len(nan_npz_paths),\n            nan_paths_path,\n        )\n    else:\n        logger.info(\"No NaN detected in NPZ inputs.\")\n    with open(os.path.join(str(hdf5_root), \"manifest.json\"), \"w\") as f:\n        json.dump(manifest, f, indent=2)\n    logger.info(\n        \"HDF5 v2 packing complete. Shards: {}. Root: {}\",\n        len(hdf5_shards),\n        hdf5_root,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/reference_filtering.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom typing import Dict, Mapping, Tuple\n\nimport numpy as np\n\n# This module keeps the offline preprocess filtering path and the online\n# root/DoF-before-FK path aligned while still exposing helpers tailored to\n# each tensor family.\n\n\ndef _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]:\n    shape = a.shape\n    t = shape[0]\n    return a.reshape(t, -1), shape\n\n\ndef _butterworth_lowpass_smooth_time(\n    a: np.ndarray, fps: float, cutoff_hz: float, order: int\n) -> np.ndarray:\n    from scipy.signal import butter, filtfilt\n\n    t = a.shape[0]\n    if t < 3:\n        return a.astype(np.float32, copy=True)\n    if fps <= 0.0 or cutoff_hz <= 0.0:\n        return a.astype(np.float32, copy=True)\n    nyquist = 0.5 * float(fps)\n    wn = float(cutoff_hz) / nyquist\n    if wn >= 1.0:\n        wn = 0.999\n    if wn <= 0.0:\n        return a.astype(np.float32, copy=True)\n    flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False))\n    b, a_coefs = butter(int(order), wn, btype=\"low\", analog=False)\n    maxlen = max(len(b), len(a_coefs))\n    padlen_required = max(3 * (maxlen - 1), 3 * maxlen)\n    if t <= padlen_required:\n        return a.astype(np.float32, copy=True)\n    filtered = filtfilt(b, a_coefs, flat, axis=0, method=\"pad\")\n    return filtered.reshape(shape).astype(np.float32, copy=False)\n\n\ndef _quat_normalize(q: np.ndarray) -> np.ndarray:\n    norm = np.linalg.norm(q, axis=-1, keepdims=True)\n    norm = np.where(norm == 0.0, 1.0, norm)\n    return (q / norm).astype(np.float32, copy=False)\n\n\ndef _quat_hemisphere_align(q: np.ndarray) -> np.ndarray:\n    if q.shape[0] == 0:\n        return q\n    aligned = q.copy()\n    prev = aligned[0]\n    for t in range(1, aligned.shape[0]):\n        dots = np.sum(prev * aligned[t], axis=-1)\n        mask = dots < 0.0\n        if np.any(mask):\n            aligned[t, mask] = -aligned[t, mask]\n        prev = aligned[t]\n    return aligned\n\n\ndef _quat_conjugate(q: np.ndarray) -> np.ndarray:\n    conj = q.copy()\n    conj[..., :3] = -conj[..., :3]\n    return conj\n\n\ndef _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray:\n    av = a[..., :3]\n    aw = a[..., 3:4]\n    bv = b[..., :3]\n    bw = b[..., 3:4]\n    cross = np.cross(av, bv)\n    vec = aw * bv + bw * av + cross\n    scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True)\n    return np.concatenate([vec, scalar], axis=-1)\n\n\ndef _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray:\n    t = a.shape[0]\n    if t < 2 or dt <= 0.0:\n        return np.zeros_like(a, dtype=np.float32)\n    deriv = np.gradient(\n        a.astype(np.float64, copy=False),\n        dt,\n        axis=0,\n        edge_order=2 if t >= 3 else 1,\n    )\n    return deriv.astype(np.float32, copy=False)\n\n\ndef _angular_velocity_from_quat(\n    q: np.ndarray, q_dot: np.ndarray\n) 
-> np.ndarray:\n    q_conj = _quat_conjugate(q)\n    prod = _quat_multiply(q_conj, q_dot)\n    omega = 2.0 * prod[..., :3]\n    return omega.astype(np.float32, copy=False)\n\n\ndef butterworth_filter_ref_arrays(\n    arrays: Mapping[str, np.ndarray],\n    fps: float,\n    cutoff_hz: float,\n    order: int,\n) -> Dict[str, np.ndarray]:\n    out: Dict[str, np.ndarray] = {}\n    dt = 1.0 / float(fps) if float(fps) > 0.0 else 0.0\n    if \"ref_dof_pos\" in arrays:\n        dof_pos = arrays[\"ref_dof_pos\"].astype(np.float32, copy=True)\n        smooth_dof_pos = _butterworth_lowpass_smooth_time(\n            dof_pos, fps, cutoff_hz, order\n        )\n        dof_vel = _finite_difference_time(smooth_dof_pos, dt)\n        out[\"ft_ref_dof_pos\"] = smooth_dof_pos\n        out[\"ft_ref_dof_vel\"] = dof_vel\n    if \"ref_global_translation\" in arrays:\n        body_pos = arrays[\"ref_global_translation\"].astype(\n            np.float32, copy=True\n        )\n        smooth_body_pos = _butterworth_lowpass_smooth_time(\n            body_pos, fps, cutoff_hz, order\n        )\n        body_vel = _finite_difference_time(smooth_body_pos, dt)\n        out[\"ft_ref_global_translation\"] = smooth_body_pos\n        out[\"ft_ref_global_velocity\"] = body_vel\n    if \"ref_global_rotation_quat\" in arrays:\n        body_rot = arrays[\"ref_global_rotation_quat\"].astype(\n            np.float32, copy=True\n        )\n        body_rot = _quat_normalize(body_rot)\n        body_rot = _quat_hemisphere_align(body_rot)\n        smooth_body_rot = _butterworth_lowpass_smooth_time(\n            body_rot, fps, cutoff_hz, order\n        )\n        smooth_body_rot = _quat_normalize(smooth_body_rot)\n        body_rot_dot = _finite_difference_time(smooth_body_rot, dt)\n        out[\"ft_ref_global_rotation_quat\"] = _quat_normalize(smooth_body_rot)\n        out[\"ft_ref_global_angular_velocity\"] = _angular_velocity_from_quat(\n            smooth_body_rot, body_rot_dot\n        )\n    return out\n\n\ndef butterworth_filter_root_dof_arrays(\n    arrays: Mapping[str, np.ndarray],\n    fps: float,\n    cutoff_hz: float,\n    order: int,\n) -> Dict[str, np.ndarray]:\n    out: Dict[str, np.ndarray] = {}\n    if \"ref_root_pos\" in arrays:\n        root_pos = arrays[\"ref_root_pos\"].astype(np.float32, copy=True)\n        out[\"ft_ref_root_pos\"] = _butterworth_lowpass_smooth_time(\n            root_pos, fps, cutoff_hz, order\n        )\n    if \"ref_root_rot\" in arrays:\n        root_rot = arrays[\"ref_root_rot\"].astype(np.float32, copy=True)\n        root_rot = _quat_normalize(root_rot)\n        root_rot = _quat_hemisphere_align(root_rot)\n        smooth_root_rot = _butterworth_lowpass_smooth_time(\n            root_rot, fps, cutoff_hz, order\n        )\n        out[\"ft_ref_root_rot\"] = _quat_normalize(smooth_root_rot)\n    if \"ref_dof_pos\" in arrays:\n        dof_pos = arrays[\"ref_dof_pos\"].astype(np.float32, copy=True)\n        out[\"ft_ref_dof_pos\"] = _butterworth_lowpass_smooth_time(\n            dof_pos, fps, cutoff_hz, order\n        )\n    return out\n"
  },
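  {
    "path": "holomotion/src/motion_retargeting/examples/reference_filtering_demo.py",
    "content": "# Project HoloMotion\n#\n# Illustrative usage sketch for reference_filtering.py; NOT part of the\n# pipeline. The array shapes (120 frames, 24 bodies, 23 DoFs), the\n# 30 FPS / 6 Hz / order-4 filter settings, this file's location, and the\n# import path (which assumes running from the repo root) are all\n# assumptions made for this demo.\n\nimport numpy as np\n\nfrom holomotion.src.motion_retargeting.reference_filtering import (\n    butterworth_filter_ref_arrays,\n    butterworth_filter_root_dof_arrays,\n)\n\n\ndef main() -> None:\n    rng = np.random.default_rng(0)\n    t, num_bodies, num_dofs = 120, 24, 23\n\n    # Quaternions are laid out (x, y, z, w); random unit quaternions are\n    # good enough for a smoke test.\n    quat = rng.normal(size=(t, num_bodies, 4)).astype(np.float32)\n    quat /= np.linalg.norm(quat, axis=-1, keepdims=True)\n\n    arrays = {\n        \"ref_dof_pos\": rng.normal(size=(t, num_dofs)).astype(np.float32),\n        \"ref_global_translation\": rng.normal(\n            size=(t, num_bodies, 3)\n        ).astype(np.float32),\n        \"ref_global_rotation_quat\": quat,\n    }\n\n    # Offline preprocess path: smooths the signals and derives velocities\n    # by finite differences on the smoothed data.\n    out = butterworth_filter_ref_arrays(\n        arrays, fps=30.0, cutoff_hz=6.0, order=4\n    )\n    for key, value in out.items():\n        print(key, value.shape, value.dtype)\n\n    # Online root/DoF-before-FK path: smooths only, no velocities.\n    root_arrays = {\n        \"ref_root_pos\": arrays[\"ref_global_translation\"][:, 0],\n        \"ref_root_rot\": quat[:, 0],\n        \"ref_dof_pos\": arrays[\"ref_dof_pos\"],\n    }\n    out_root = butterworth_filter_root_dof_arrays(\n        root_arrays, fps=30.0, cutoff_hz=6.0, order=4\n    )\n    for key, value in out_root.items():\n        print(key, value.shape, value.dtype)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },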
  {
    "path": "holomotion/src/motion_retargeting/utils/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/motion_retargeting/utils/_schema.json",
    "content": "{\n  \"schema\": {\n    \"root_trans_offset\": {\n      \"shape\": [\n        682,\n        3\n      ],\n      \"dtype\": \"float64\"\n    },\n    \"pose_aa\": {\n      \"shape\": [\n        682,\n        27,\n        3\n      ],\n      \"dtype\": \"float32\"\n    },\n    \"dof\": {\n      \"shape\": [\n        682,\n        23\n      ],\n      \"dtype\": \"float32\"\n    },\n    \"root_rot\": {\n      \"shape\": [\n        682,\n        4\n      ],\n      \"dtype\": \"float64\"\n    },\n    \"smpl_joints\": {\n      \"shape\": [\n        682,\n        24,\n        3\n      ],\n      \"dtype\": \"float32\"\n    },\n    \"fps\": {\n      \"shape\": [],\n      \"dtype\": \"int64\"\n    }\n  },\n  \"sample_top_key\": \"2024-12-28 16.03.15-视频-2025年主播必学120支热门小舞蹈-帅帅的《电话卡点舞》 #电话卡点舞 #...舞 #抖音热歌 #舞蹈教学 #网红必学+p02_1_btws_pad\"\n}"
  },
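  {
    "path": "holomotion/src/motion_retargeting/examples/schema_check_demo.py",
    "content": "# Project HoloMotion\n#\n# Illustrative sketch of checking one retargeted motion clip against the\n# field layout documented in _schema.json. Assumptions: clips are dicts of\n# numpy arrays keyed as in the schema, the leading time dimension (682 in\n# the sampled clip) varies per clip, and this file's location is made up\n# for the demo.\n\nimport numpy as np\n\n# Trailing shape (everything after the time axis) and dtype per field,\n# taken from _schema.json.\nEXPECTED = {\n    \"root_trans_offset\": ((3,), np.float64),\n    \"pose_aa\": ((27, 3), np.float32),\n    \"dof\": ((23,), np.float32),\n    \"root_rot\": ((4,), np.float64),\n    \"smpl_joints\": ((24, 3), np.float32),\n}\n\n\ndef check_clip(clip: dict) -> list:\n    \"\"\"Return a list of human-readable schema violations (empty if OK).\"\"\"\n    problems = []\n    for key, (trailing, dtype) in EXPECTED.items():\n        if key not in clip:\n            problems.append(f\"missing field: {key}\")\n            continue\n        arr = np.asarray(clip[key])\n        if arr.shape[1:] != trailing:\n            problems.append(\n                f\"{key}: trailing shape {arr.shape[1:]} != {trailing}\"\n            )\n        if arr.dtype != np.dtype(dtype):\n            problems.append(f\"{key}: dtype {arr.dtype} != {np.dtype(dtype)}\")\n    if \"fps\" not in clip:\n        problems.append(\"missing field: fps\")\n    return problems\n\n\nif __name__ == \"__main__\":\n    t = 10\n    demo_clip = {\n        \"root_trans_offset\": np.zeros((t, 3), np.float64),\n        \"pose_aa\": np.zeros((t, 27, 3), np.float32),\n        \"dof\": np.zeros((t, 23), np.float32),\n        \"root_rot\": np.zeros((t, 4), np.float64),\n        \"smpl_joints\": np.zeros((t, 24, 3), np.float32),\n        \"fps\": np.int64(30),\n    }\n    print(check_clip(demo_clip) or \"clip matches the _schema.json layout\")\n"
  },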
  {
    "path": "holomotion/src/motion_retargeting/utils/rotation_conversions.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef wxyz_to_xyzw(quat):\n    return quat[..., [1, 2, 3, 0]]\n\n\ndef xyzw_to_wxyz(quat):\n    return quat[..., [3, 0, 1, 2]]\n\n\nDevice = Union[str, torch.device]\n\"\"\"\nThe transformation matrices returned from the functions in this file assume\nthe points on which the transformation will be applied are column vectors.\ni.e. the R matrix is structured as\n\n    R = [\n            [Rxx, Rxy, Rxz],\n            [Ryx, Ryy, Ryz],\n            [Rzx, Rzy, Rzz],\n        ]  # (3, 3)\n\nThis matrix can be applied to column vectors by post multiplication\nby the points e.g.\n\n    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point\n    transformed_points = R * points\n\nTo apply the same matrix to points which are row vectors, the R matrix\ncan be transposed and pre multiplied by the points:\n\ne.g.\n    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point\n    transformed_points = points * R.transpose(1, 0)\n\"\"\"\n\n\ndef quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\ndef _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"Return a tensor of absolute value.\n\n    Return a tensor where each element has the absolute value taken from\n    the corresponding element of a, with sign taken from the corresponding\n    element of b. 
This is like the standard copysign floating-point operation,\n    but is not careful about negative 0 and NaN.\n\n    Args:\n        a: source tensor.\n        b: tensor whose signs will be used, of the same shape as a.\n\n    Returns:\n        Tensor of the same shape as a with the signs of b.\n\n    \"\"\"\n    signs_differ = (a < 0) != (b < 0)\n    return torch.where(signs_differ, -a, a)\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Return torch.sqrt(torch.max(0, x)).\n\n    The subgradient is zero where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as rotation matrices to quaternions.\n\n    Quaternions are returned with the real part first: (w, x, y, z).\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack(\n                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1\n            ),\n            torch.stack(\n                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1\n            ),\n            torch.stack(\n                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1\n            ),\n            torch.stack(\n                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1\n            ),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is\n    # small, the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same\n    # (up to a sign), forall i; we pick the best-conditioned one\n    # (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,\n        :,  # pyre-ignore[16]\n    ].reshape(batch_dim + (4,))\n\n\ndef _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Return rotation matrices for rotations about a single axis.\n\n    Euler angle conventions are built from such per-axis rotations; one\n    matrix is returned for each value of the angle given.\n\n    Args:\n        axis: Axis label \"X\", \"Y\", or \"Z\".\n        angle: any shape tensor of Euler angles in radians\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n\n    \"\"\"\n    cos = torch.cos(angle)\n    sin = torch.sin(angle)\n    one = torch.ones_like(angle)\n    zero = torch.zeros_like(angle)\n\n    if axis == \"X\":\n        r_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)\n    elif axis == \"Y\":\n        r_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)\n    elif axis == \"Z\":\n        r_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)\n    else:\n        raise ValueError(\"letter must be either X, Y or Z.\")\n\n    return torch.stack(r_flat, -1).reshape(angle.shape + (3, 3))\n\n\ndef euler_angles_to_matrix(\n    euler_angles: torch.Tensor, convention: str\n) -> torch.Tensor:\n    \"\"\"Convert rotations given as Euler angles in radians to rotation matrices.\n\n    Args:\n        euler_angles: Euler angles in radians as tensor of shape (..., 3).\n        convention: Convention string of three uppercase letters from\n            {\"X\", \"Y\", and \"Z\"}.\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n\n    \"\"\"\n    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:\n        raise ValueError(\"Invalid input euler angles.\")\n    if len(convention) != 3:\n        raise ValueError(\"Convention must have 3 letters.\")\n    if convention[1] in (convention[0], convention[2]):\n        raise ValueError(f\"Invalid convention {convention}.\")\n    for letter in convention:\n        if letter not in (\"X\", \"Y\", \"Z\"):\n            raise ValueError(f\"Invalid letter {letter} in convention string.\")\n    matrices = [\n        _axis_angle_rotation(c, e)\n        for c, e in zip(\n            convention, torch.unbind(euler_angles, -1), strict=False\n        )\n    ]\n    # return functools.reduce(torch.matmul, matrices)\n    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])\n\n\ndef _angle_from_tan(\n    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool\n) -> torch.Tensor:\n    \"\"\"Extract the first or third Euler angle from a rotation matrix.\n\n    Uses the two matrix entries that are a positive constant times the\n    angle's sine and cosine.\n\n    Args:\n        axis: Axis label \"X\", \"Y\", or \"Z\" for the angle we are finding.\n        other_axis: Axis label \"X\", \"Y\", or \"Z\" for the middle axis in the\n            convention.\n        data: Rotation matrices as tensor of shape (..., 3, 3).\n        horizontal: Whether we are looking for the angle for the third axis,\n            which means the relevant entries are in the same row of the\n            rotation matrix. 
If not, they are in the same column.\n        tait_bryan: Whether the first and third axes in the convention differ.\n\n    Returns:\n        Euler Angles in radians for each matrix in data as a tensor\n        of shape (...).\n\n    \"\"\"\n    i1, i2 = {\"X\": (2, 1), \"Y\": (0, 2), \"Z\": (1, 0)}[axis]\n    if horizontal:\n        i2, i1 = i1, i2\n    even = (axis + other_axis) in [\"XY\", \"YZ\", \"ZX\"]\n    if horizontal == even:\n        return torch.atan2(data[..., i1], data[..., i2])\n    if tait_bryan:\n        return torch.atan2(-data[..., i2], data[..., i1])\n    return torch.atan2(data[..., i2], -data[..., i1])\n\n\ndef _index_from_letter(letter: str) -> int:\n    if letter == \"X\":\n        return 0\n    if letter == \"Y\":\n        return 1\n    if letter == \"Z\":\n        return 2\n    raise ValueError(\"letter must be either X, Y or Z.\")\n\n\ndef matrix_to_euler_angles(\n    matrix: torch.Tensor, convention: str\n) -> torch.Tensor:\n    \"\"\"Convert rotations given as rotation matrices to Euler angles in radians.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n        convention: Convention string of three uppercase letters.\n\n    Returns:\n        Euler angles in radians as tensor of shape (..., 3).\n\n    \"\"\"\n    if len(convention) != 3:\n        raise ValueError(\"Convention must have 3 letters.\")\n    if convention[1] in (convention[0], convention[2]):\n        raise ValueError(f\"Invalid convention {convention}.\")\n    for letter in convention:\n        if letter not in (\"X\", \"Y\", \"Z\"):\n            raise ValueError(f\"Invalid letter {letter} in convention string.\")\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n    i0 = _index_from_letter(convention[0])\n    i2 = _index_from_letter(convention[2])\n    tait_bryan = i0 != i2\n    if tait_bryan:\n        central_angle = torch.asin(\n            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)\n        )\n    else:\n        central_angle = torch.acos(matrix[..., i0, i0])\n\n    o = (\n        _angle_from_tan(\n            convention[0], convention[1], matrix[..., i2], False, tait_bryan\n        ),\n        central_angle,\n        _angle_from_tan(\n            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan\n        ),\n    )\n    return torch.stack(o, -1)\n\n\ndef random_quaternions(\n    n: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[Device] = None,\n) -> torch.Tensor:\n    \"\"\"Generate random quaternions representing rotations.\n\n    i.e. versors with nonnegative real part.\n\n    Args:\n        n: Number of quaternions in a batch to return.\n        dtype: Type to return.\n        device: Desired device of returned tensor. 
Default:\n            uses the current device for the default tensor type.\n\n    Returns:\n        Quaternions as tensor of shape (N, 4).\n\n    \"\"\"\n    if isinstance(device, str):\n        device = torch.device(device)\n    o = torch.randn((n, 4), dtype=dtype, device=device)\n    s = (o * o).sum(1)\n    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]\n    return o\n\n\ndef random_rotations(\n    n: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[Device] = None,\n) -> torch.Tensor:\n    \"\"\"Generate random rotations as 3x3 rotation matrices.\n\n    Args:\n        n: Number of rotation matrices in a batch to return.\n        dtype: Type to return.\n        device: Device of returned tensor. Default: if None,\n            uses the current device for the default tensor type.\n\n    Returns:\n        Rotation matrices as tensor of shape (n, 3, 3).\n\n    \"\"\"\n    quaternions = random_quaternions(n, dtype=dtype, device=device)\n    return quaternion_to_matrix(quaternions)\n\n\ndef random_rotation(\n    dtype: Optional[torch.dtype] = None, device: Optional[Device] = None\n) -> torch.Tensor:\n    \"\"\"Generate a single random 3x3 rotation matrix.\n\n    Args:\n        dtype: Type to return\n        device: Device of returned tensor. Default: if None,\n            uses the current device for the default tensor type\n\n    Returns:\n        Rotation matrix as tensor of shape (3, 3).\n\n    \"\"\"\n    return random_rotations(1, dtype, device)[0]\n\n\ndef standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert a unit quaternion to a standard form.\n\n    The standard form has a non-negative real part.\n\n    Args:\n        quaternions: Quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n\n    \"\"\"\n    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)\n\n\ndef quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiply two quaternions.\n\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product of a and b, a tensor of quaternions shape (..., 4).\n\n    \"\"\"\n    aw, ax, ay, az = torch.unbind(a, -1)\n    bw, bx, by, bz = torch.unbind(b, -1)\n    ow = aw * bw - ax * bx - ay * by - az * bz\n    ox = aw * bx + ax * bw + ay * bz - az * by\n    oy = aw * by - ax * bz + ay * bw + az * bx\n    oz = aw * bz + ax * by - ay * bx + az * bw\n    return torch.stack((ow, ox, oy, oz), -1)\n\n\ndef quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiply two quaternions representing rotations.\n\n    Returns the quaternion representing their composition, i.e. the versor\n    with nonnegative real part.\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        a: Quaternions as tensor of shape (..., 4), real part first.\n        b: Quaternions as tensor of shape (..., 4), real part first.\n\n    Returns:\n        The product of a and b, a tensor of quaternions of shape (..., 4).\n\n    \"\"\"\n    ab = quaternion_raw_multiply(a, b)\n    return standardize_quaternion(ab)\n\n\ndef quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:\n    \"\"\"Get the quaternion representing the inverse rotation.\n\n    Args:\n        quaternion: Quaternions as tensor of shape (..., 4), with real part\n            first, which must be versors (unit quaternions).\n\n    Returns:\n        The inverse, a tensor of quaternions of shape (..., 4).\n\n    \"\"\"\n    scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)\n    return quaternion * scaling\n\n\ndef quaternion_apply(\n    quaternion: torch.Tensor, point: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"Apply the rotation given by a quaternion to a 3D point.\n\n    Usual torch rules for broadcasting apply.\n\n    Args:\n        quaternion: Tensor of quaternions, real part first, of shape (..., 4).\n        point: Tensor of 3D points of shape (..., 3).\n\n    Returns:\n        Tensor of rotated points of shape (..., 3).\n\n    \"\"\"\n    if point.size(-1) != 3:\n        raise ValueError(f\"Points are not in 3D, {point.shape}.\")\n    real_parts = point.new_zeros(point.shape[:-1] + (1,))\n    point_as_quaternion = torch.cat((real_parts, point), -1)\n    out = quaternion_raw_multiply(\n        quaternion_raw_multiply(quaternion, point_as_quaternion),\n        quaternion_invert(quaternion),\n    )\n    return out[..., 1:]\n\n\ndef axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as axis/angle to rotation matrices.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n\n    \"\"\"\n    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))\n\n\ndef matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as rotation matrices to axis/angle.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        Rotations given as a vector in axis angle form, as a tensor\n            of shape (..., 3), where the magnitude is the angle\n            turned anticlockwise in radians around the vector's\n            direction.\n\n    \"\"\"\n    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))\n\n\ndef axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as axis/angle to quaternions.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n\n    \"\"\"\n    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n    half_angles = angles * 0.5\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = 
torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    quaternions = torch.cat(\n        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles],\n        dim=-1,\n    )\n    return quaternions\n\n\ndef quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to axis/angle.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotations given as a vector in axis angle form, as a tensor\n            of shape (..., 3), where the magnitude is the angle\n            turned anticlockwise in radians around the vector's\n            direction.\n\n    \"\"\"\n    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)\n    half_angles = torch.atan2(norms, quaternions[..., :1])\n    angles = 2 * half_angles\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    return quaternions[..., 1:] / sin_half_angles_over_angles\n\n\ndef rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:\n    \"\"\"Converts 6D rotation to rotation matrix.\n\n    Using Gram--Schmidt orthogonalization per Section B of [1].\n    Representation by Zhou et al. [1]\n\n    Args:\n        d6: 6D rotation representation, of size (*, 6)\n\n    Returns:\n        batch of rotation matrices of size (*, 3, 3)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n\n    \"\"\"\n    a1, a2 = d6[..., :3], d6[..., 3:]\n    b1 = F.normalize(a1, dim=-1)\n    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1\n    b2 = F.normalize(b2, dim=-1)\n    b3 = torch.cross(b1, b2, dim=-1)\n    return torch.stack((b1, b2, b3), dim=-2)\n\n\ndef matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"Converts rotation matrices to 6D rotation representation by Zhou et al.\n\n    by dropping the last row. Note that 6D representation is not unique.\n\n    Args:\n        matrix: batch of rotation matrices of size (*, 3, 3)\n\n    Returns:\n        6D rotation representation, of size (*, 6)\n\n    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.\n    On the Continuity of Rotation Representations in Neural Networks.\n    IEEE Conference on Computer Vision and Pattern Recognition, 2019.\n    Retrieved from http://arxiv.org/abs/1812.07035\n\n    \"\"\"\n    batch_dim = matrix.size()[:-2]\n    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))\n"
  },
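  {
    "path": "holomotion/src/motion_retargeting/examples/rotation_conversions_demo.py",
    "content": "# Project HoloMotion\n#\n# Illustrative sanity checks for rotation_conversions.py (demo only). The\n# batch size, seed, tolerance, this file's location, and the import path\n# (which assumes running from the repo root) are arbitrary choices.\n\nimport torch\n\nfrom holomotion.src.motion_retargeting.utils.rotation_conversions import (\n    axis_angle_to_matrix,\n    axis_angle_to_quaternion,\n    matrix_to_axis_angle,\n    quaternion_to_matrix,\n    standardize_quaternion,\n    wxyz_to_xyzw,\n    xyzw_to_wxyz,\n)\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n\n    # Random rotations with angle < pi so representations stay unambiguous.\n    axis = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)\n    angle = torch.rand(8, 1) * 3.0\n    axis_angle = axis * angle\n\n    # axis-angle -> quaternion (w, x, y, z) -> matrix -> axis-angle.\n    quat_wxyz = axis_angle_to_quaternion(axis_angle)\n    mat = quaternion_to_matrix(quat_wxyz)\n    recovered = matrix_to_axis_angle(mat)\n\n    # Axis-angle vectors for the same rotation need not match elementwise,\n    # so compare the rotations as matrices instead.\n    print(\n        \"round trip ok:\",\n        torch.allclose(axis_angle_to_matrix(recovered), mat, atol=1e-5),\n    )\n\n    # The quaternion layout helpers are exact inverses of each other.\n    quat_xyzw = wxyz_to_xyzw(quat_wxyz)\n    print(\"layout ok:\", torch.equal(xyzw_to_wxyz(quat_xyzw), quat_wxyz))\n\n    # standardize_quaternion flips signs so the real part is non-negative.\n    std = standardize_quaternion(quat_wxyz)\n    print(\"real part >= 0:\", bool((std[..., 0] >= 0).all()))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },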
  {
    "path": "holomotion/src/motion_retargeting/utils/torch_humanoid_batch.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\n# This file was originally copied from the [PHC] repository:\n# https://github.com/ZhengyiLuo/PHC\n# Modifications have been made to fit the needs of this project.\n#\n\nimport os\nimport os.path as osp\nimport sys\n\nsys.path.append(os.getcwd())\nimport copy\nimport logging\nimport xml.etree.ElementTree as ETree\nfrom collections import OrderedDict, defaultdict\nfrom io import BytesIO\n\nimport numpy as np\nimport open3d as o3d\nimport scipy.ndimage.filters as filters\nimport smpl_sim.poselib.core.rotation3d as poselib_rotation3d\nimport smpl_sim.utils.rotation_conversions as torch_rotation_conversions\nimport torch\nfrom easydict import EasyDict\nfrom lxml.etree import XMLParser, parse\nfrom omegaconf import DictConfig\nfrom scipy.spatial.transform import Rotation as sRot\nfrom tqdm import tqdm\n\n# from loguru import logger\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.DEBUG, format=\"%(asctime)s - %(levelname)s - %(message)s\"\n)\n\n\nclass HumanoidBatch:\n    def __init__(self, cfg, device=None):\n        if device is None:\n            device = torch.device(\"cpu\")\n        self.cfg = cfg\n        self.mjcf_file = cfg.asset.assetFileName\n\n        parser = XMLParser(remove_blank_text=True)\n        tree = parse(\n            BytesIO(open(self.mjcf_file, \"rb\").read()),\n            parser=parser,\n        )\n        self.dof_axis = []\n        joints = sorted(\n            [\n                j.attrib[\"name\"]\n                for j in tree.getroot().find(\"worldbody\").findall(\".//joint\")\n            ]\n        )\n        motors = sorted(\n            [\n                m.attrib[\"name\"]\n                for m in tree.getroot().find(\"actuator\").getchildren()\n            ]\n        )\n        assert len(motors) > 0, \"No motors found in the mjcf file\"\n\n        self.num_dof = len(motors)\n        self.num_extend_dof = self.num_dof\n\n        self.mjcf_data = mjcf_data = self.from_mjcf(self.mjcf_file)\n        self.body_names = copy.deepcopy(mjcf_data[\"node_names\"])\n        # logger.info(f\"Body names from {self.mjcf_file}: {self.body_names}\")\n        self._parents = mjcf_data[\"parent_indices\"]\n        self.body_names_augment = copy.deepcopy(mjcf_data[\"node_names\"])\n        self._proper_kinematic_structure = copy.deepcopy(\n            mjcf_data[\"node_names\"]\n        )\n        self._offsets = mjcf_data[\"local_translation\"][None,].to(device)\n        self._local_rotation = mjcf_data[\"local_rotation\"][None,].to(device)\n        self.actuated_joints_idx = np.array(\n            [\n                self.body_names.index(k)\n                for k, v in mjcf_data[\"body_to_joint\"].items()\n            ]\n        )\n\n        for m in motors:\n            if m not in joints:\n                print(m)\n\n        if (\n            \"type\"\n            in 
tree.getroot().find(\"worldbody\").findall(\".//joint\")[0].attrib\n            and tree.getroot()\n            .find(\"worldbody\")\n            .findall(\".//joint\")[0]\n            .attrib[\"type\"]\n            == \"free\"\n        ):\n            for j in tree.getroot().find(\"worldbody\").findall(\".//joint\")[1:]:\n                self.dof_axis.append(\n                    [int(i) for i in j.attrib[\"axis\"].split(\" \")]\n                )\n            self.has_freejoint = True\n        elif (\n            \"type\"\n            not in tree.getroot()\n            .find(\"worldbody\")\n            .findall(\".//joint\")[0]\n            .attrib\n        ):\n            for j in tree.getroot().find(\"worldbody\").findall(\".//joint\"):\n                self.dof_axis.append(\n                    [int(i) for i in j.attrib[\"axis\"].split(\" \")]\n                )\n            self.has_freejoint = True\n        else:\n            for j in tree.getroot().find(\"worldbody\").findall(\".//joint\")[6:]:\n                self.dof_axis.append(\n                    [int(i) for i in j.attrib[\"axis\"].split(\" \")]\n                )\n            self.has_freejoint = False\n\n        axis_list = []\n        for _i, axis in enumerate(self.dof_axis):\n            if axis == [1, 0, 0]:\n                axis_list.append(\"x\")\n            elif axis == [0, 1, 0]:\n                axis_list.append(\"y\")\n            elif axis == [0, 0, 1]:\n                axis_list.append(\"z\")\n            else:\n                raise ValueError(f\"Invalid axis: {axis}\")\n        # print(\"Axis list for this humanoid: \", axis_list)\n\n        self.dof_axis = torch.tensor(self.dof_axis)\n\n        for extend_config in cfg.extend_config:\n            self.body_names_augment += [extend_config.joint_name]\n            self._parents = torch.cat(\n                [\n                    self._parents,\n                    torch.tensor(\n                        [self.body_names.index(extend_config.parent_name)]\n                    ).to(device),\n                ],\n                dim=0,\n            )\n            self._offsets = torch.cat(\n                [\n                    self._offsets,\n                    torch.tensor([[extend_config.pos]]).to(device),\n                ],\n                dim=1,\n            )\n            self._local_rotation = torch.cat(\n                [\n                    self._local_rotation,\n                    torch.tensor([[extend_config.rot]]).to(device),\n                ],\n                dim=1,\n            )\n            self.num_extend_dof += 1\n\n            parent_id = self._proper_kinematic_structure.index(\n                extend_config.parent_name\n            )\n            self._proper_kinematic_structure.insert(\n                parent_id + 1, extend_config.joint_name\n            )\n\n        self.num_bodies = len(self.body_names)\n        self.num_bodies_augment = len(self.body_names_augment)\n\n        self.joints_range = mjcf_data[\"joints_range\"].to(device)\n        self._local_rotation_mat = (\n            torch_rotation_conversions.quaternion_to_matrix(\n                self._local_rotation\n            ).float()\n        )  # w, x, y ,z\n        self.load_mesh()\n\n        self.extend_to_proper_mapping = []\n        for _i, name in enumerate(self._proper_kinematic_structure):\n            self.extend_to_proper_mapping.append(\n                self.body_names_augment.index(name)\n            )\n        self.proper_to_extend_mapping = []\n        for _i, 
name in enumerate(self.body_names_augment):\n            self.proper_to_extend_mapping.append(\n                self._proper_kinematic_structure.index(name)\n            )\n\n    def from_mjcf(self, path):\n        # function from Poselib:\n        tree = ETree.parse(path)\n        xml_doc_root = tree.getroot()\n        xml_world_body = xml_doc_root.find(\"worldbody\")\n        if xml_world_body is None:\n            raise ValueError(\"MJCF parsed incorrectly please verify it.\")\n        # assume this is the root\n        xml_body_root = xml_world_body.find(\"body\")\n        if xml_body_root is None:\n            raise ValueError(\"MJCF parsed incorrectly please verify it.\")\n\n        # xml_joint_root = xml_body_root.find(\"joint\")  # Unused variable\n\n        node_names = []\n        parent_indices = []\n        local_translation = []\n        local_rotation = []\n        joints_range = []\n        body_to_joint = OrderedDict()\n\n        # recursively adding all nodes into the skel_tree\n        def _add_xml_node(xml_node, parent_index, node_index):\n            node_name = xml_node.attrib.get(\"name\")\n            # parse the local translation into float list\n            pos = np.fromstring(\n                xml_node.attrib.get(\"pos\", \"0 0 0\"), dtype=float, sep=\" \"\n            )\n            quat = np.fromstring(\n                xml_node.attrib.get(\"quat\", \"1 0 0 0\"), dtype=float, sep=\" \"\n            )\n            node_names.append(node_name)\n            parent_indices.append(parent_index)\n            local_translation.append(pos)\n            local_rotation.append(quat)\n            curr_index = node_index\n            node_index += 1\n            all_joints = xml_node.findall(\n                \"joint\"\n            )  # joints need to remove the first 6 joints\n            if len(all_joints) == 6:\n                all_joints = all_joints[6:]\n\n            for joint in all_joints:\n                if joint.attrib.get(\"range\") is not None:\n                    joints_range.append(\n                        np.fromstring(\n                            joint.attrib.get(\"range\"), dtype=float, sep=\" \"\n                        )\n                    )\n                else:\n                    if not joint.attrib.get(\"type\") == \"free\":\n                        joints_range.append([-np.pi, np.pi])\n            for joint_node in xml_node.findall(\"joint\"):\n                body_to_joint[node_name] = joint_node.attrib.get(\"name\")\n\n            for next_node in xml_node.findall(\"body\"):\n                node_index = _add_xml_node(next_node, curr_index, node_index)\n\n            return node_index\n\n        _add_xml_node(xml_body_root, -1, 0)\n        assert len(joints_range) == self.num_dof\n        return {\n            \"node_names\": node_names,\n            \"parent_indices\": torch.from_numpy(\n                np.array(parent_indices, dtype=np.int32)\n            ),\n            \"local_translation\": torch.from_numpy(\n                np.array(local_translation, dtype=np.float32)\n            ),\n            \"local_rotation\": torch.from_numpy(\n                np.array(local_rotation, dtype=np.float32)\n            ),\n            \"joints_range\": torch.from_numpy(np.array(joints_range)),\n            \"body_to_joint\": body_to_joint,\n        }\n\n    def fk_batch(\n        self, pose, trans, convert_to_mat=True, return_full=False, dt=1 / 30\n    ):\n        # device, dtype = pose.device, pose.dtype  # Unused variables\n        # pose_input = 
pose.clone()  # Unused variable\n        b, seq_len = pose.shape[:2]\n        pose = pose[\n            ..., : len(self._parents), :\n        ]  # H1 fitted joints might have extra joints\n\n        if convert_to_mat:\n            pose_quat = torch_rotation_conversions.axis_angle_to_quaternion(\n                pose.clone()\n            )\n            pose_mat = torch_rotation_conversions.quaternion_to_matrix(\n                pose_quat\n            )\n        else:\n            pose_mat = pose\n\n        if len(pose_mat.shape) != 5:\n            pose_mat = pose_mat.reshape(b, seq_len, -1, 3, 3)\n        # j = pose_mat.shape[2] - 1  # Exclude root - unused variable\n        wbody_pos, wbody_mat = self.forward_kinematics_batch(\n            pose_mat[:, :, 1:], pose_mat[:, :, 0:1], trans\n        )\n\n        return_dict = EasyDict()\n\n        wbody_rot = torch_rotation_conversions.wxyz_to_xyzw(\n            torch_rotation_conversions.matrix_to_quaternion(wbody_mat)\n        )\n        if len(self.cfg.extend_config) > 0:\n            if return_full:\n                return_dict.global_velocity_extend = self._compute_velocity(\n                    wbody_pos, dt\n                )\n                return_dict.global_angular_velocity_extend = (\n                    self._compute_angular_velocity(wbody_rot, dt)\n                )\n\n            return_dict.global_translation_extend = wbody_pos.clone()\n            return_dict.global_rotation_mat_extend = wbody_mat.clone()\n            return_dict.global_rotation_extend = wbody_rot\n\n            wbody_pos = wbody_pos[..., : self.num_bodies, :]\n            wbody_mat = wbody_mat[..., : self.num_bodies, :, :]\n            wbody_rot = wbody_rot[..., : self.num_bodies, :]\n\n        return_dict.global_translation = wbody_pos\n        return_dict.global_rotation_mat = wbody_mat\n        return_dict.global_rotation = wbody_rot\n        if return_full:\n            rigidbody_linear_velocity = self._compute_velocity(wbody_pos, dt)\n            # Isaac gym is [x, y, z, w]. All the previous functions are\n            # [w, x, y, z]\n            rigidbody_angular_velocity = self._compute_angular_velocity(\n                wbody_rot, dt\n            )\n            return_dict.local_rotation = (\n                torch_rotation_conversions.wxyz_to_xyzw(pose_quat)\n            )\n            return_dict.global_root_velocity = rigidbody_linear_velocity[\n                ..., 0, :\n            ]\n            return_dict.global_root_angular_velocity = (\n                rigidbody_angular_velocity[..., 0, :]\n            )\n            return_dict.global_angular_velocity = rigidbody_angular_velocity\n            return_dict.global_velocity = rigidbody_linear_velocity\n\n            if len(self.cfg.extend_config) > 0:\n                return_dict.dof_pos = pose.sum(dim=-1)[\n                    ..., 1 : self.num_bodies\n                ]\n                # You can sum over the last axis since each Unitree joint\n                # has a single DoF. The last two are for hands and 
don't\n                # really matter.\n            else:\n                if len(self.actuated_joints_idx) != len(self.body_names):\n                    return_dict.dof_pos = pose.sum(dim=-1)[\n                        ..., self.actuated_joints_idx\n                    ]\n                else:\n                    return_dict.dof_pos = pose.sum(dim=-1)[..., 1:]\n\n            dof_vel = (\n                return_dict.dof_pos[:, 1:] - return_dict.dof_pos[:, :-1]\n            ) / dt\n            return_dict.dof_vels = torch.cat(\n                [dof_vel, dof_vel[:, -2:-1]], dim=1\n            )\n            return_dict.fps = int(1 / dt)\n\n        return return_dict\n\n    def convert_to_proper_kinematic(self, return_dict):\n        if len(self.cfg.extend_config) > 0:\n            return_dict.global_translation_extend = (\n                return_dict.global_translation_extend[\n                    ..., self.extend_to_proper_mapping, :\n                ]\n            )\n            return_dict.global_rotation_mat_extend = (\n                return_dict.global_rotation_mat_extend[\n                    ..., self.extend_to_proper_mapping, :, :\n                ]\n            )\n            return_dict.global_rotation_extend = (\n                return_dict.global_rotation_extend[\n                    ..., self.extend_to_proper_mapping, :\n                ]\n            )\n            return_dict.global_velocity_extend = (\n                return_dict.global_velocity_extend[\n                    ..., self.extend_to_proper_mapping, :\n                ]\n            )\n            return_dict.global_angular_velocity_extend = (\n                return_dict.global_angular_velocity_extend[\n                    ..., self.extend_to_proper_mapping, :\n                ]\n            )\n        else:\n            return_dict.global_translation = return_dict.global_translation[\n                ..., self.extend_to_proper_mapping, :\n            ]\n            return_dict.global_rotation_mat = return_dict.global_rotation_mat[\n                ..., self.extend_to_proper_mapping, :, :\n            ]\n            return_dict.global_rotation = return_dict.global_rotation[\n                ..., self.extend_to_proper_mapping, :\n            ]\n            return_dict.global_velocity = return_dict.global_velocity[\n                ..., self.extend_to_proper_mapping, :\n            ]\n            return_dict.global_angular_velocity = (\n                return_dict.global_angular_velocity[\n                    ..., self.extend_to_proper_mapping, :\n                ]\n            )\n        return return_dict\n\n    def forward_kinematics_batch(\n        self, rotations, root_rotations, root_positions\n    ):\n        \"\"\"Perform forward kinematics using the trajectory and rotations.\n\n        Arguments (where B = batch size, S = sequence length,\n        J = number of bodies):\n         -- rotations: (B, S, J - 1, 3, 3) tensor of local rotation\n         matrices for each non-root body.\n         -- root_rotations: (B, S, 1, 3, 3) tensor of root rotation\n         matrices.\n         -- root_positions: (B, S, 3) tensor of root joint positions.\n        Output: world positions (B, S, J, 3) and world rotation matrices\n        (B, S, J, 3, 3).\n\n        Reference:\n            https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/\n            torch_humanoid_batch.py\n        \"\"\"\n        device, dtype = root_rotations.device, root_rotations.dtype\n        b, seq_len = rotations.size()[0:2]\n        j = self._offsets.shape[1]\n        positions_world = []\n        rotations_world = []\n\n        expanded_offsets = (\n            self._offsets[:, None]\n            .expand(b, 
seq_len, j, 3)\n            .to(device)\n            .type(dtype)\n        )\n        # print(expanded_offsets.shape, j)\n\n        for i in range(j):\n            if self._parents[i] == -1:\n                positions_world.append(root_positions)\n                rotations_world.append(root_rotations)\n            else:\n                jpos = (\n                    torch.matmul(\n                        rotations_world[self._parents[i]][:, :, 0],\n                        expanded_offsets[:, :, i, :, None],\n                    ).squeeze(-1)\n                    + positions_world[self._parents[i]]\n                )\n                rot_mat = torch.matmul(\n                    rotations_world[self._parents[i]],\n                    torch.matmul(\n                        self._local_rotation_mat[:, (i) : (i + 1)],\n                        rotations[:, :, (i - 1) : i, :],\n                    ),\n                )\n                # rot_mat = torch.matmul(rotations_world[self._parents[i]],\n                # rotations[:, :, (i - 1):i, :])\n                # print(rotations[:, :, (i - 1):i, :].shape,\n                # self._local_rotation_mat.shape)\n\n                positions_world.append(jpos)\n                rotations_world.append(rot_mat)\n\n        positions_world = torch.stack(positions_world, dim=2)\n        rotations_world = torch.cat(rotations_world, dim=2)\n        return positions_world, rotations_world\n\n    @staticmethod\n    def _compute_velocity(p, time_delta, guassian_filter=True):\n        velocity = np.gradient(p.numpy(), axis=-3) / time_delta\n        if guassian_filter:\n            velocity = torch.from_numpy(\n                filters.gaussian_filter1d(velocity, 2, axis=-3, mode=\"nearest\")\n            ).to(p)\n        else:\n            velocity = torch.from_numpy(velocity).to(p)\n\n        return velocity\n\n    @staticmethod\n    def _compute_angular_velocity(r, time_delta: float, guassian_filter=True):\n        # assume the second last dimension is the time axis\n        diff_quat_data = poselib_rotation3d.quat_identity_like(r).to(r)\n        diff_quat_data[..., :-1, :, :] = poselib_rotation3d.quat_mul_norm(\n            r[..., 1:, :, :],\n            poselib_rotation3d.quat_inverse(r[..., :-1, :, :]),\n        )\n        diff_angle, diff_axis = poselib_rotation3d.quat_angle_axis(\n            diff_quat_data\n        )\n        angular_velocity = diff_axis * diff_angle.unsqueeze(-1) / time_delta\n        if guassian_filter:\n            angular_velocity = torch.from_numpy(\n                filters.gaussian_filter1d(\n                    angular_velocity.numpy(), 2, axis=-3, mode=\"nearest\"\n                ),\n            )\n        return angular_velocity\n\n    def load_mesh(self):\n        xml_base = os.path.dirname(self.mjcf_file)\n        # Read the compiler tag from the g1.xml file to find if there is a\n        # meshdir defined\n        tree = ETree.parse(self.mjcf_file)\n        xml_doc_root = tree.getroot()\n        compiler_tag = xml_doc_root.find(\"compiler\")\n\n        if compiler_tag is not None and \"meshdir\" in compiler_tag.attrib:\n            mesh_base = os.path.join(xml_base, compiler_tag.attrib[\"meshdir\"])\n        else:\n            mesh_base = xml_base\n\n        self.tree = tree = ETree.parse(self.mjcf_file)\n        xml_doc_root = tree.getroot()\n        xml_world_body = xml_doc_root.find(\"worldbody\")\n\n        xml_assets = xml_doc_root.find(\"asset\")\n        all_mesh = xml_assets.findall(\".//mesh\")\n\n        geoms = 
xml_world_body.findall(\".//geom\")\n\n        # all_joints = xml_world_body.findall(\".//joint\")  # Unused variable\n        # all_motors = tree.findall(\".//motor\")  # Unused variable\n        # all_bodies = xml_world_body.findall(\".//body\")  # Unused variable\n\n        def find_parent(root, child):\n            for parent in root.iter():\n                for elem in parent:\n                    if elem == child:\n                        return parent\n            return None\n\n        mesh_dict = {}\n        # mesh_parent_dict = {}  # Unused variable\n\n        for mesh_file_node in all_mesh:\n            mesh_name = mesh_file_node.attrib[\"name\"]\n            mesh_file = mesh_file_node.attrib[\"file\"]\n            mesh_full_file = osp.join(mesh_base, mesh_file)\n            mesh_obj = o3d.io.read_triangle_mesh(mesh_full_file)\n            mesh_dict[mesh_name] = mesh_obj\n\n        geom_transform = {}\n\n        body_to_mesh = defaultdict(set)\n        mesh_to_body = {}\n        for geom_node in geoms:\n            if \"mesh\" in geom_node.attrib:\n                parent = find_parent(xml_doc_root, geom_node)\n                body_to_mesh[parent.attrib[\"name\"]].add(\n                    geom_node.attrib[\"mesh\"]\n                )\n                mesh_to_body[geom_node] = parent\n                if \"pos\" in geom_node.attrib or \"quat\" in geom_node.attrib:\n                    geom_transform[parent.attrib[\"name\"]] = {}\n                    geom_transform[parent.attrib[\"name\"]][\"pos\"] = np.array(\n                        [0.0, 0.0, 0.0]\n                    )\n                    geom_transform[parent.attrib[\"name\"]][\"quat\"] = np.array(\n                        [1.0, 0.0, 0.0, 0.0]\n                    )\n                    if \"pos\" in geom_node.attrib:\n                        geom_transform[parent.attrib[\"name\"]][\"pos\"] = (\n                            np.array(\n                                [\n                                    float(f)\n                                    for f in geom_node.attrib[\"pos\"].split(\" \")\n                                ]\n                            )\n                        )\n                    if \"quat\" in geom_node.attrib:\n                        geom_transform[parent.attrib[\"name\"]][\"quat\"] = (\n                            np.array(\n                                [\n                                    float(f)\n                                    for f in geom_node.attrib[\"quat\"].split(\n                                        \" \"\n                                    )\n                                ]\n                            )\n                        )\n\n            else:\n                pass\n\n        self.geom_transform = geom_transform\n        self.mesh_dict = mesh_dict\n        self.body_to_mesh = body_to_mesh\n        self.mesh_to_body = mesh_to_body\n\n    def mesh_fk(self, pose=None, trans=None):\n        \"\"\"Load the mesh from the XML file and merge into the humanoid.\n\n        Reference:\n            https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/\n            torch_humanoid_batch.py\n        \"\"\"\n        if pose is None:\n            fk_res = self.fk_batch(\n                torch.zeros(1, 1, len(self.body_names_augment), 3),\n                torch.zeros(1, 1, 3),\n            )\n        else:\n            fk_res = self.fk_batch(pose, trans)\n\n        g_trans = fk_res.global_translation.squeeze()\n        g_rot = fk_res.global_rotation_mat.squeeze()\n        
geoms = self.tree.find(\"worldbody\").findall(\".//geom\")\n        joined_mesh_obj = []\n        for geom in geoms:\n            if \"mesh\" not in geom.attrib:\n                continue\n            # parent_name = geom.attrib[\"mesh\"]\n\n            k = self.mesh_to_body[geom].attrib[\"name\"]\n            mesh_names = self.body_to_mesh[k]\n            body_idx = self.body_names.index(k)\n\n            body_trans = g_trans[body_idx].numpy().copy()\n            body_rot = g_rot[body_idx].numpy().copy()\n            for mesh_name in mesh_names:\n                mesh_obj = copy.deepcopy(self.mesh_dict[mesh_name])\n                if k in self.geom_transform:\n                    pos = self.geom_transform[k][\"pos\"]\n                    quat = self.geom_transform[k][\"quat\"]\n                    body_trans = body_trans + body_rot @ pos\n                    global_rot = (\n                        body_rot\n                        @ sRot.from_quat(quat[[1, 2, 3, 0]]).as_matrix()\n                    ).T\n                else:\n                    global_rot = body_rot.T\n                mesh_obj.rotate(global_rot.T, center=(0, 0, 0))\n                mesh_obj.translate(body_trans)\n                joined_mesh_obj.append(mesh_obj)\n\n        # Merge all meshes into a single mesh\n        merged_mesh = joined_mesh_obj[0]\n        for mesh in joined_mesh_obj[1:]:\n            merged_mesh += mesh\n\n        # Save the merged mesh to a file\n        # merged_mesh.compute_vertex_normals()\n        # o3d.io.write_triangle_mesh(f\"data/{self.cfg.humanoid_type}/\n        # combined_{self.cfg.humanoid_type}.stl\", merged_mesh)\n        return merged_mesh\n\n\n# @hydra.main(version_base=None, config_path=\"../../phc/data/cfg\",\n# config_name=\"config\")\ndef main(cfg: DictConfig):\n    device = torch.device(\"cpu\")\n    humanoid_fk = HumanoidBatch(cfg.robot, device)\n    humanoid_fk.mesh_fk()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
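  {
    "path": "holomotion/src/motion_retargeting/examples/humanoid_batch_fk_demo.py",
    "content": "# Project HoloMotion\n#\n# Illustrative sketch of driving HumanoidBatch.fk_batch directly (demo\n# only). The MJCF path below is a placeholder; point cfg at a real robot\n# config whose mesh assets are available before running. This file's\n# location and the import path (which assumes running from the repo root)\n# are assumptions made for the demo.\n\nimport torch\nfrom easydict import EasyDict\n\nfrom holomotion.src.motion_retargeting.utils.torch_humanoid_batch import (\n    HumanoidBatch,\n)\n\n\ndef main() -> None:\n    cfg = EasyDict(\n        asset={\"assetFileName\": \"path/to/robot.xml\"},  # placeholder\n        extend_config=[],  # no virtual extension joints in this sketch\n    )\n    fk = HumanoidBatch(cfg, device=torch.device(\"cpu\"))\n\n    # One clip, two frames, identity pose (zero axis-angle) at the origin.\n    b, t = 1, 2\n    pose_aa = torch.zeros(b, t, len(fk.body_names_augment), 3)\n    trans = torch.zeros(b, t, 3)\n\n    out = fk.fk_batch(pose_aa, trans, return_full=True, dt=1 / 30)\n    print(out.global_translation.shape)  # (b, t, num_bodies, 3)\n    print(out.global_rotation.shape)  # (b, t, num_bodies, 4), xyzw layout\n    print(out.dof_pos.shape)  # (b, t, num_dof)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },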
  {
    "path": "holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n#\n# This file was originally copied from the [PHC] repository:\n# https://github.com/ZhengyiLuo/PHC\n# Modifications have been made to fit the needs of this project.\n\nimport glob\nimport os\nfrom typing import Any, Dict, List, Tuple\n\nimport cv2\nimport hydra\nimport mujoco\nimport numpy as np\nimport ray\nfrom omegaconf import DictConfig\nfrom tqdm.auto import tqdm\n\n\nclass OffscreenRenderer:\n    \"\"\"Offscreen renderer (no SMPL markers or joint spheres).\"\"\"\n\n    def __init__(self, model, height, width):\n        self.model = model\n        self.height = height\n        self.width = width\n\n        # Create OpenGL context\n        self.ctx = mujoco.GLContext(width, height)\n        self.ctx.make_current()\n\n        # Scene and camera setup\n        self.scene = mujoco.MjvScene(model, maxgeom=1000)\n        self.cam = mujoco.MjvCamera()\n        self.opt = mujoco.MjvOption()\n\n        self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE\n        self.cam.distance = 4.0\n        self.cam.azimuth = 60.0\n        self.cam.elevation = -20\n        self.cam.lookat = np.array([0.0, 0.0, 1.0])\n\n        # Rendering context\n        self.con = mujoco.MjrContext(\n            model, mujoco.mjtFontScale.mjFONTSCALE_100\n        )\n\n        # Buffers\n        self.rgb_buffer = np.zeros((height, width, 3), dtype=np.uint8)\n        self.viewport = mujoco.MjrRect(0, 0, width, height)\n\n    def render(\n        self,\n        data,\n        ref_body_positions: np.ndarray | None = None,\n        ref_marker_radius: float = 0.03,\n        ref_marker_rgba: np.ndarray | None = None,\n    ):\n        mujoco.mjv_updateScene(\n            self.model,\n            data,\n            self.opt,\n            None,\n            self.cam,\n            mujoco.mjtCatBit.mjCAT_ALL.value,\n            self.scene,\n        )\n        _draw_body_spheres_to_scene(\n            scene=self.scene,\n            body_positions=ref_body_positions,\n            radius=ref_marker_radius,\n            rgba=ref_marker_rgba,\n        )\n        mujoco.mjr_render(self.viewport, self.scene, self.con)\n        mujoco.mjr_readPixels(self.rgb_buffer, None, self.viewport, self.con)\n        return np.flipud(self.rgb_buffer)\n\n    def close(self):\n        self.ctx.free()\n\n\ndef _get_key_prefix_order(cfg: DictConfig) -> List[str]:\n    \"\"\"\n    Determine the key prefix order used to extract arrays from NPZ files.\n    Priority:\n      1) cfg.key_prefix_order (list or single value)\n      2) cfg.key_prefix (single value)\n      3) default [\"ref_\", \"\", \"robot_\"]\n    \"\"\"\n    configured = cfg.get(\"key_prefix_order\", None)\n    if configured is not None:\n        order_list = (\n            [str(p) for p in configured]\n            if isinstance(configured, (list, tuple))\n            else [str(configured)]\n        )\n    else:\n        single = 
cfg.get(\"key_prefix\", None)\n        if single is not None:\n            order_list = [str(single)]\n        else:\n            order_list = [\"ref_\", \"\", \"robot_\"]\n    print(f\"Using key_prefix_order: {order_list}\")\n    return order_list\n\n\ndef _get_ref_key_prefix_order(cfg: DictConfig) -> List[str]:\n    \"\"\"Determine the prefix order used to read reference overlay arrays.\"\"\"\n    configured = cfg.get(\"ref_key_prefix_order\", None)\n    if configured is not None:\n        order_list = (\n            [str(p) for p in configured]\n            if isinstance(configured, (list, tuple))\n            else [str(configured)]\n        )\n    else:\n        single = cfg.get(\"ref_key_prefix\", None)\n        if single is not None:\n            order_list = [str(single)]\n        else:\n            order_list = [\"ref_\"]\n    print(f\"Using ref_key_prefix_order: {order_list}\")\n    return order_list\n\n\ndef _pick_with_prefixes(\n    arrays: Dict[str, np.ndarray],\n    base_name: str,\n    prefixes: List[str],\n) -> np.ndarray | None:\n    \"\"\"\n    Return arrays[prefix + base_name] for the first matching prefix in order.\n    For non-empty prefixes, also attempts \"<prefix.rstrip('_')>_<base_name>\".\n    \"\"\"\n    for prefix in prefixes:\n        if prefix == \"\":\n            candidate = base_name\n            if candidate in arrays:\n                return arrays[candidate]\n        else:\n            cand1 = f\"{prefix}{base_name}\"\n            if cand1 in arrays:\n                return arrays[cand1]\n            cand2 = f\"{prefix.rstrip('_')}_{base_name}\"\n            if cand2 in arrays:\n                return arrays[cand2]\n    return None\n\n\ndef _resolve_visualization_arrays(\n    arrays: Dict[str, np.ndarray],\n    key_prefix_order: List[str],\n    draw_ref_body_spheres: bool = False,\n    ref_key_prefix_order: List[str] | None = None,\n) -> Dict[str, np.ndarray | None]:\n    \"\"\"Resolve playback arrays and optional reference overlay arrays.\"\"\"\n    dof_pos = _pick_with_prefixes(arrays, \"dof_pos\", key_prefix_order)\n    global_translation = _pick_with_prefixes(\n        arrays, \"global_translation\", key_prefix_order\n    )\n    global_rotation_quat = _pick_with_prefixes(\n        arrays, \"global_rotation_quat\", key_prefix_order\n    )\n\n    ref_body_positions = None\n    if draw_ref_body_spheres:\n        ref_prefixes = (\n            ref_key_prefix_order\n            if ref_key_prefix_order is not None\n            else [\"ref_\"]\n        )\n        ref_body_positions = _pick_with_prefixes(\n            arrays, \"global_translation\", ref_prefixes\n        )\n\n    return {\n        \"dof_pos\": dof_pos,\n        \"global_translation\": global_translation,\n        \"global_rotation_quat\": global_rotation_quat,\n        \"ref_body_positions\": ref_body_positions,\n    }\n\n\ndef _draw_body_spheres_to_scene(\n    scene,\n    body_positions: np.ndarray | None,\n    radius: float,\n    rgba: np.ndarray | None,\n) -> None:\n    \"\"\"Append sphere markers for body positions to the current MuJoCo scene.\"\"\"\n    if body_positions is None:\n        return\n\n    sphere_rgba = (\n        np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32)\n        if rgba is None\n        else np.asarray(rgba, dtype=np.float32)\n    )\n    size = np.array([radius, 0.0, 0.0], dtype=np.float32)\n    mat = np.eye(3, dtype=np.float32).reshape(-1)\n\n    start = int(scene.ngeom)\n    idx = 0\n    for pos in body_positions:\n        geom_id = start + idx\n        if geom_id 
def _resolve_visualization_arrays(\n    arrays: Dict[str, np.ndarray],\n    key_prefix_order: List[str],\n    draw_ref_body_spheres: bool = False,\n    ref_key_prefix_order: List[str] | None = None,\n) -> Dict[str, np.ndarray | None]:\n    \"\"\"Resolve playback arrays and optional reference overlay arrays.\"\"\"\n    dof_pos = _pick_with_prefixes(arrays, \"dof_pos\", key_prefix_order)\n    global_translation = _pick_with_prefixes(\n        arrays, \"global_translation\", key_prefix_order\n    )\n    global_rotation_quat = _pick_with_prefixes(\n        arrays, \"global_rotation_quat\", key_prefix_order\n    )\n\n    ref_body_positions = None\n    if draw_ref_body_spheres:\n        ref_prefixes = (\n            ref_key_prefix_order\n            if ref_key_prefix_order is not None\n            else [\"ref_\"]\n        )\n        ref_body_positions = _pick_with_prefixes(\n            arrays, \"global_translation\", ref_prefixes\n        )\n\n    return {\n        \"dof_pos\": dof_pos,\n        \"global_translation\": global_translation,\n        \"global_rotation_quat\": global_rotation_quat,\n        \"ref_body_positions\": ref_body_positions,\n    }\n\n\ndef _draw_body_spheres_to_scene(\n    scene,\n    body_positions: np.ndarray | None,\n    radius: float,\n    rgba: np.ndarray | None,\n) -> None:\n    \"\"\"Append sphere markers for body positions to the current MuJoCo scene.\"\"\"\n    if body_positions is None:\n        return\n\n    sphere_rgba = (\n        np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32)\n        if rgba is None\n        else np.asarray(rgba, dtype=np.float32)\n    )\n    size = np.array([radius, 0.0, 0.0], dtype=np.float32)\n    mat = np.eye(3, dtype=np.float32).reshape(-1)\n\n    start = int(scene.ngeom)\n    idx = 0\n    for pos in body_positions:\n        geom_id = start + idx\n        if geom_id >= scene.maxgeom:\n            break\n        mujoco.mjv_initGeom(\n            scene.geoms[geom_id],\n            mujoco.mjtGeom.mjGEOM_SPHERE,\n            size,\n            pos.astype(np.float32),\n            mat,\n            sphere_rgba,\n        )\n        idx += 1\n    scene.ngeom = start + idx\n\n\ndef _load_npz_as_motion(\n    npz_path: str,\n) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], str]:\n    \"\"\"Load a single .npz file.\n\n    Returns (arrays_dict, metadata_dict, motion_name):\n    - metadata: parsed from the JSON string stored under the \"metadata\" key\n    - motion_name: file name without extension\n    \"\"\"\n    import json\n\n    with np.load(npz_path) as z:\n        arrays = {k: z[k] for k in z.files if k != \"metadata\"}\n        metadata: Dict[str, Any] = {}\n        if \"metadata\" in z.files:\n            # Metadata is serialized as a JSON string wrapped in an array.\n            try:\n                metadata = json.loads(str(np.atleast_1d(z[\"metadata\"])[0]))\n            except (ValueError, IndexError):\n                print(f\"Warning: unparsable metadata in {npz_path}\")\n\n    motion_name = os.path.splitext(os.path.basename(npz_path))[0]\n    return arrays, metadata, motion_name\n\n\ndef _collect_all_npz(\n    npz_root: str, motion_name: str\n) -> List[Tuple[Dict[str, np.ndarray], Dict[str, Any], str]]:\n    \"\"\"Collect all NPZ files to process based on configuration.\"\"\"\n    print(f\"Collecting NPZ files from {npz_root} (motion_name={motion_name})\")\n    base = (\n        os.path.join(npz_root, \"clips\")\n        if os.path.isdir(os.path.join(npz_root, \"clips\"))\n        else npz_root\n    )\n    if motion_name == \"all\":\n        npz_files = sorted(\n            glob.glob(os.path.join(base, \"**\", \"*.npz\"), recursive=True)\n        )\n    else:\n        # Try the resolved base dir first, then fall back to npz_root itself.\n        candidate = os.path.join(base, f\"{motion_name}.npz\")\n        if not os.path.isfile(candidate):\n            fallback = os.path.join(npz_root, f\"{motion_name}.npz\")\n            if os.path.isfile(fallback):\n                candidate = fallback\n        npz_files = [candidate]\n\n    motions = []\n    for f in tqdm(npz_files, desc=\"Loading npz files\"):\n        try:\n            arrays, metadata, name = _load_npz_as_motion(f)\n            motions.append((arrays, metadata, name))\n        except Exception as e:\n            print(f\"Failed to load {f}: {e}\")\n    return motions\n\n\ndef _infer_fps_from_meta(\n    metadata: Dict[str, Any], default_fps: float = 50.0\n) -> float:\n    \"\"\"Infer FPS value from metadata.\"\"\"\n    try:\n        return float(metadata.get(\"motion_fps\", default_fps))\n    except Exception:\n        return float(default_fps)\n\n\ndef _time_length(*arrays) -> int:\n    \"\"\"Return the smallest time-dimension length among the given arrays.\n\n    Entries that are not numpy arrays (e.g. None) are ignored.\n    \"\"\"\n    T = None\n    for a in arrays:\n        if isinstance(a, np.ndarray) and a.ndim >= 1:\n            t = a.shape[0]\n            T = t if T is None else min(T, t)\n    return T if T is not None else 0\n\n\n
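# Example (illustrative): _time_length(np.zeros((120, 3)), np.zeros((100, 3)),\n# None) returns 100, the shortest time axis, so streams of mismatched length\n# are truncated to a common playback length.\n\n\n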
@ray.remote\ndef process_single_motion_remote_npz(\n    arrays: Dict[str, np.ndarray],\n    metadata: Dict[str, Any],\n    motion_name: str,\n    cfg_dict: dict,\n) -> str:\n    try:\n        cfg = DictConfig(cfg_dict)\n\n        # MuJoCo model\n        mj_model = mujoco.MjModel.from_xml_path(cfg.robot.asset.assetFileName)\n        mj_data = mujoco.MjData(mj_model)\n\n        # Renderer\n        width, height = 1280, 720\n        renderer = OffscreenRenderer(mj_model, height, width)\n\n        # FPS\n        src_fps = _infer_fps_from_meta(metadata, default_fps=50.0)\n        skip_frames = getattr(cfg, \"skip_frames\", 1)\n        actual_fps = src_fps / max(1, int(skip_frames))\n\n        # Video writer\n        fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n        out_path = os.path.join(cfg.video_dir, f\"{motion_name}.mp4\")\n        os.makedirs(os.path.dirname(out_path), exist_ok=True)\n        out = cv2.VideoWriter(out_path, fourcc, actual_fps, (width, height))\n\n        try:\n            prefix_order = _get_key_prefix_order(cfg)\n            draw_ref_body_spheres = bool(\n                getattr(cfg, \"draw_ref_body_spheres\", False)\n            )\n            ref_prefix_order = _get_ref_key_prefix_order(cfg)\n            resolved = _resolve_visualization_arrays(\n                arrays=arrays,\n                key_prefix_order=prefix_order,\n                draw_ref_body_spheres=draw_ref_body_spheres,\n                ref_key_prefix_order=ref_prefix_order,\n            )\n            dof_pos = resolved[\"dof_pos\"]\n            gpos = resolved[\"global_translation\"]\n            grot = resolved[\"global_rotation_quat\"]\n            ref_body_positions = resolved[\"ref_body_positions\"]\n\n            if (\n                not isinstance(dof_pos, np.ndarray)\n                or not isinstance(gpos, np.ndarray)\n                or not isinstance(grot, np.ndarray)\n            ):\n                raise ValueError(\n                    \"Missing required NPZ keys: dof_pos / global_translation / global_rotation_quat\"\n                )\n\n            # Time dimension alignment\n            T = _time_length(dof_pos, gpos, grot, ref_body_positions)\n            if T == 0:\n                raise ValueError(\"No valid frames found.\")\n\n            for t in range(0, T, max(1, int(skip_frames))):\n                # Root position and quaternion: take from body 0\n                root_pos = gpos[t, 0]\n                root_quat_xyzw = grot[t, 0]\n                # NPZ quaternions are XYZW; MuJoCo qpos expects WXYZ.\n                root_quat_wxyz = root_quat_xyzw[[3, 0, 1, 2]]\n\n                mj_data.qpos[:3] = root_pos\n                mj_data.qpos[3:7] = root_quat_wxyz\n                mj_data.qpos[7:] = dof_pos[t]\n\n                mujoco.mj_forward(mj_model, mj_data)\n                safe_lookat = np.array(\n                    renderer.cam.lookat\n                )  # copy the current camera look-at before editing it\n                safe_lookat[0] = root_pos[0]\n                safe_lookat[1] = root_pos[1]\n\n                min_height = 1.0\n                safe_lookat[2] = max(root_pos[2], min_height)\n                renderer.cam.lookat[:] = safe_lookat\n                frame_ref_body_positions = (\n                    ref_body_positions[t]\n                    if isinstance(ref_body_positions, np.ndarray)\n                    else None\n                )\n                frame = renderer.render(\n                    mj_data,\n                    ref_body_positions=frame_ref_body_positions,\n                )\n                # Convert RGB (MuJoCo) -> BGR (OpenCV) before writing\n                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n                out.write(frame_bgr)\n\n        finally:\n            out.release()\n            renderer.close()\n\n        return motion_name\n\n    except Exception as e:\n        return f\"ERROR_{motion_name}: {str(e)}\"\n\n\n
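# Single-process path used when cfg.motion_name selects one clip; it mirrors\n# the Ray worker above so the two code paths stay in sync. Minimal usage\n# sketch (hypothetical objects):\n#   MotionRendererNPZ().process_single_motion(arrays, metadata, \"clip_0\", cfg)\n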
class MotionRendererNPZ:\n    def process_single_motion(\n        self,\n        arrays: Dict[str, np.ndarray],\n        metadata: Dict[str, Any],\n        motion_name: str,\n        cfg: DictConfig,\n    ):\n        mj_model = mujoco.MjModel.from_xml_path(cfg.robot.asset.assetFileName)\n        mj_data = mujoco.MjData(mj_model)\n\n        width, height = 1280, 720\n        renderer = OffscreenRenderer(mj_model, height, width)\n\n        src_fps = _infer_fps_from_meta(metadata, default_fps=50.0)\n        skip_frames = getattr(cfg, \"skip_frames\", 1)\n        actual_fps = src_fps / max(1, int(skip_frames))\n\n        fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n        out_path = os.path.join(cfg.video_dir, f\"{motion_name}.mp4\")\n        os.makedirs(os.path.dirname(out_path), exist_ok=True)\n        out = cv2.VideoWriter(out_path, fourcc, actual_fps, (width, height))\n\n        try:\n            prefix_order = _get_key_prefix_order(cfg)\n            draw_ref_body_spheres = bool(\n                getattr(cfg, \"draw_ref_body_spheres\", False)\n            )\n            ref_prefix_order = _get_ref_key_prefix_order(cfg)\n            resolved = _resolve_visualization_arrays(\n                arrays=arrays,\n                key_prefix_order=prefix_order,\n                draw_ref_body_spheres=draw_ref_body_spheres,\n                ref_key_prefix_order=ref_prefix_order,\n            )\n            dof_pos = resolved[\"dof_pos\"]\n            gpos = resolved[\"global_translation\"]\n            grot = resolved[\"global_rotation_quat\"]\n            ref_body_positions = resolved[\"ref_body_positions\"]\n\n            if (\n                not isinstance(dof_pos, np.ndarray)\n                or not isinstance(gpos, np.ndarray)\n                or not isinstance(grot, np.ndarray)\n            ):\n                raise ValueError(\n                    \"Missing required NPZ keys: dof_pos / global_translation / global_rotation_quat\"\n                )\n\n            T = _time_length(dof_pos, gpos, grot, ref_body_positions)\n            if T == 0:\n                raise ValueError(\"No valid frames found.\")\n\n            for t in tqdm(\n                range(0, T, max(1, int(skip_frames))),\n                desc=f\"Rendering {motion_name}\",\n            ):\n                root_pos = gpos[t, 0]\n                root_quat_xyzw = grot[t, 0]\n                # NPZ quaternions are XYZW; MuJoCo qpos expects WXYZ.\n                root_quat_wxyz = root_quat_xyzw[[3, 0, 1, 2]]\n\n                mj_data.qpos[:3] = root_pos\n                mj_data.qpos[3:7] = root_quat_wxyz\n                mj_data.qpos[7:] = dof_pos[t]\n\n                mujoco.mj_forward(mj_model, mj_data)\n                safe_lookat = np.array(\n                    renderer.cam.lookat\n                )  # copy the current camera look-at before editing it\n                safe_lookat[0] = root_pos[0]\n                safe_lookat[1] = root_pos[1]\n\n                min_height = 1.0\n                safe_lookat[2] = max(root_pos[2], min_height)\n                renderer.cam.lookat[:] = safe_lookat\n                frame_ref_body_positions = (\n                    ref_body_positions[t]\n                    if isinstance(ref_body_positions, np.ndarray)\n                    else None\n                )\n                frame = renderer.render(\n                    mj_data,\n                    ref_body_positions=frame_ref_body_positions,\n                )\n                frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n                out.write(frame_bgr)\n        finally:\n            out.release()\n            renderer.close()\n\n        return motion_name\n\n\n
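# Example invocation of the Hydra entrypoint below (hypothetical override\n# values; adjust paths and script name to your checkout):\n#   python <this_script>.py motion_npz_root=data/retargeted motion_name=all \\\n#       video_dir=outputs/videos skip_frames=2\n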
   config_name=\"unitree_G1_29dof_retargeting\",\n)\ndef main(cfg: DictConfig) -> None:\n    \"\"\"\n    Required config fields:\n    - cfg.robot.asset.assetFileName : Path to the MuJoCo XML file\n    - cfg.video_dir : Output video directory\n    - cfg.motion_npz_root : Directory containing NPZ files\n    - cfg.motion_name : \"all\" or a specific clip name (without extension)\n    - cfg.skip_frames : Frame step size (>=1)\n    Optional:\n    - cfg.key_prefix_order : List[str] or str for key prefix matching order\n    - cfg.key_prefix : Single prefix to use (overridden by key_prefix_order)\n    \"\"\"\n    try:\n        # NPZ input\n        motions = _collect_all_npz(cfg.motion_npz_root, cfg.motion_name)\n        if not motions:\n            print(\"No NPZ motions found.\")\n            return\n\n        # Ray parallel or single-thread mode\n        if cfg.motion_name == \"all\":\n            if not ray.is_initialized():\n                num_cpus = min(os.cpu_count(), cfg.get(\"max_workers\", 8))\n                ray.init(num_cpus=num_cpus)\n                print(f\"Initialized Ray with {num_cpus} workers\")\n\n            cfg_dict = dict(cfg)\n            tasks = [\n                process_single_motion_remote_npz.remote(\n                    arr, meta, name, cfg_dict\n                )\n                for (arr, meta, name) in motions\n            ]\n\n            completed, failed = [], []\n            with tqdm(total=len(tasks), desc=\"Processing Motions\") as pbar:\n                remaining = list(tasks)\n                while remaining:\n                    ready, remaining = ray.wait(\n                        remaining, num_returns=1, timeout=1.0\n                    )\n                    for t in ready:\n                        try:\n                            res = ray.get(t)\n                            if isinstance(res, str) and res.startswith(\n                                \"ERROR_\"\n                            ):\n                                failed.append(res)\n                                print(f\"Failed: {res}\")\n                            else:\n                                completed.append(res)\n                                print(f\"Completed: {res}\")\n                        except Exception as e:\n                            failed.append(f\"Task exception: {e}\")\n                        pbar.update(1)\n\n            print(\"\\nProcessing complete!\")\n            print(f\"Success: {len(completed)}; Failed: {len(failed)}\")\n            if failed:\n                for f in failed:\n                    print(\"  -\", f)\n            ray.shutdown()\n        else:\n            renderer = MotionRendererNPZ()\n            for arr, meta, name in motions:\n                res = renderer.process_single_motion(arr, meta, name, cfg)\n                print(f\"Processed: {res}\")\n\n    except Exception as e:\n        print(f\"Error during processing: {e}\")\n        if ray.is_initialized():\n            ray.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/src/training/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/training/h5_dataloader.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\n\"\"\"Simplified HDF5 motion cache backed by a PyTorch ``DataLoader``.\n\nThis module provides two core utilities:\n\n* ``Hdf5MotionDataset`` – loads contiguous motion windows directly from HDF5\n  shards using metadata stored in ``manifest.json``.\n* ``MotionClipBatchCache`` – maintains a double-buffered cache of motion clips\n  with deterministic swapping semantics suitable for high-throughput\n  reinforcement learning.\n\nCompared to the legacy slot-based prefetcher, this implementation keeps the\npipeline intentionally simple:\n\n* A dataset-worker keeps shard handles open locally; no Ray dependency.\n* Each cached batch has a fixed shape\n  ``[max_num_clips, max_frame_length, feature_dims]``.\n* Swapping a batch is handled via an O(1) pointer flip once the next batch is\n  staged on the desired device (CPU or GPU).\n\nThe cache exposes helper methods that mirror the data access patterns required\nby ``RefMotionCommand``:\n\n* ``sample_env_assignments`` for initial clip/frame sampling.\n* ``gather_tensor`` to fetch exactly one tensor field for ``1 + n_future``\n  frames per environment.\n\nAll tensors returned by this module are ``torch.float32`` unless stated\notherwise; tensor shapes are noted explicitly in type annotations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport math\nimport os\nimport re\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import (\n    Any,\n    Dict,\n    Iterator,\n    List,\n    Mapping,\n    Optional,\n    Sequence,\n    Tuple,\n)\n\nimport h5py\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nfrom torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler\nfrom loguru import logger\nfrom tabulate import tabulate\nfrom tqdm import tqdm\n\nfrom holomotion.src.motion_retargeting.reference_filtering import (\n    butterworth_filter_root_dof_arrays,\n)\nfrom holomotion.src.utils import torch_utils\nfrom holomotion.src.motion_retargeting.holomotion_fk import HoloMotionFK\n\nTensor = torch.Tensor\n\n\ndef _cpu_only_dataloader_worker_init_fn(worker_id: int) -> None:\n    \"\"\"Keep cache workers lightweight without mutating CUDA visibility.\"\"\"\n    del worker_id\n    torch.set_num_threads(1)\n\n\ndef _allocate_batch_counts(\n    raw_counts: List[float], target_total: int\n) -> List[int]:\n    \"\"\"Allocate integer counts that sum exactly to target_total.\"\"\"\n    total = int(max(0, target_total))\n    if len(raw_counts) == 0:\n        return []\n    base_counts = [max(0, int(c)) for c in raw_counts]\n    residuals = [float(c) - float(int(c)) for c in raw_counts]\n    remaining = total - int(sum(base_counts))\n    if remaining > 0:\n        order = sorted(\n            
def _allocate_batch_counts(\n    raw_counts: List[float], target_total: int\n) -> List[int]:\n    \"\"\"Allocate integer counts that sum exactly to target_total.\"\"\"\n    total = int(max(0, target_total))\n    if len(raw_counts) == 0:\n        return []\n    base_counts = [max(0, int(c)) for c in raw_counts]\n    residuals = [float(c) - float(int(c)) for c in raw_counts]\n    remaining = total - int(sum(base_counts))\n    if remaining > 0:\n        order = sorted(\n            range(len(residuals)),\n            key=lambda i: residuals[i],\n            reverse=True,\n        )\n        idx_pos = 0\n        while remaining > 0:\n            j = order[idx_pos % len(order)]\n            base_counts[j] += 1\n            remaining -= 1\n            idx_pos += 1\n    elif remaining < 0:\n        order = sorted(range(len(residuals)), key=lambda i: residuals[i])\n        idx_pos = 0\n        while remaining < 0:\n            j = order[idx_pos % len(order)]\n            if base_counts[j] > 0:\n                base_counts[j] -= 1\n                remaining += 1\n            idx_pos += 1\n    if sum(base_counts) != total:\n        raise RuntimeError(\n            \"Internal error: integer batch-count allocation did not preserve total.\"\n        )\n    return [max(0, int(c)) for c in base_counts]\n\n\ndef _configure_weighted_bins(\n    keys: List[str],\n    cfg: Mapping[str, Any],\n    batch_size_for_log: int,\n) -> Tuple[List[List[int]], List[float], List[Dict[str, Any]]]:\n    \"\"\"Common helper to parse config, assign bins, and compute batch fractions.\"\"\"\n    if batch_size_for_log <= 0:\n        batch_size_for_log = 1\n\n    cfg_local: Dict[str, Any] = dict(cfg or {})\n\n    patterns_cfg = cfg_local.get(\"bin_regex_patterns\")\n    if patterns_cfg is None:\n        # Back-compat: accept the historical misspelling of this key.\n        patterns_cfg = cfg_local.get(\"bin_regrex_patterns\")\n    if not patterns_cfg:\n        raise ValueError(\n            \"weighted_bin configuration requires 'bin_regex_patterns' \"\n            \"(list of {regex, ratio}) to be configured\"\n        )\n\n    compiled_patterns: List[Dict[str, Any]] = []\n    ratios: List[float] = []\n    for idx, entry in enumerate(patterns_cfg):\n        if not isinstance(entry, Mapping):\n            raise ValueError(\n                f\"Entry {idx} in bin_regex_patterns must be a mapping, \"\n                f\"got {type(entry)}\"\n            )\n        # 'regrex' is accepted for back-compat with older configs.\n        regex_str = entry.get(\"regex\", entry.get(\"regrex\", None))\n        if not isinstance(regex_str, str) or not regex_str:\n            raise ValueError(\n                f\"Entry {idx} in bin_regex_patterns is missing a non-empty \"\n                f\"'regex' field\"\n            )\n        ratio_val = entry.get(\"ratio\", None)\n        if ratio_val is None:\n            raise ValueError(\n                f\"Entry {idx} in bin_regex_patterns is missing 'ratio'\"\n            )\n        ratio_f = float(ratio_val)\n        if ratio_f < 0.0 or ratio_f > 1.0:\n            raise ValueError(\n                f\"Entry {idx} in bin_regex_patterns has invalid ratio \"\n                f\"{ratio_f:.6f}; expected in [0.0, 1.0]\"\n            )\n        compiled_patterns.append(\n            {\n                \"name\": str(entry.get(\"name\", f\"bin_{idx}\")),\n                \"regex\": regex_str,\n                \"compiled\": re.compile(regex_str),\n            }\n        )\n        ratios.append(ratio_f)\n\n    sum_explicit = float(sum(ratios))\n    if sum_explicit > 1.0 + 1.0e-6:\n        raise ValueError(\n            f\"Sum of weighted-bin ratios is {sum_explicit:.6f} (> 1.0). 
\"\n            \"Please reduce the ratios so that their sum is <= 1.0.\"\n        )\n    if sum_explicit > 1.0:\n        sum_explicit = 1.0\n    others_ratio = max(0.0, 1.0 - sum_explicit)\n\n    if len(keys) == 0:\n        raise ValueError(\n            \"weighted_bin configuration received an empty key set\"\n        )\n\n    num_items_total = float(len(keys))\n    num_explicit = len(compiled_patterns)\n    bin_indices: List[List[int]] = [[] for _ in range(num_explicit + 1)]\n\n    for idx, motion_key in enumerate(keys):\n        assigned = False\n        for b_idx, pat in enumerate(compiled_patterns):\n            if pat[\"compiled\"].search(motion_key):\n                bin_indices[b_idx].append(idx)\n                assigned = True\n                break\n        if not assigned:\n            bin_indices[-1].append(idx)\n\n    # Combine explicit ratios with implicit \"others\" ratio\n    all_ratios: List[float] = list(ratios)\n    all_ratios.append(others_ratio)\n\n    # If all motion keys are covered by explicit regex bins, but the specified\n    # ratios sum to less than 1.0, linearly reweight explicit ratios so that\n    # they sum to 1.0 and disable the implicit \"others\" bin.\n    others_count = len(bin_indices[-1])\n    if others_count == 0 and others_ratio > 0.0 and sum_explicit > 0.0:\n        scale = 1.0 / sum_explicit\n        ratios = [r * scale for r in ratios]\n        others_ratio = 0.0\n        all_ratios = list(ratios)\n        all_ratios.append(others_ratio)\n        logger.info(\n            \"Weighted-bin: all regex bins cover the dataset; \"\n            \"linearly reweighted explicit ratios to sum to 1.0 and disabled \"\n            \"the implicit 'others' bin.\"\n        )\n\n    # Validate non-empty bins for any positive ratio (including others)\n    for b_idx, r in enumerate(all_ratios):\n        if r > 0.0 and len(bin_indices[b_idx]) == 0:\n            if b_idx < num_explicit:\n                name = compiled_patterns[b_idx][\"name\"]\n                regex_s = compiled_patterns[b_idx][\"regex\"]\n                raise ValueError(\n                    f\"Weighted-bin '{name}' (regex='{regex_s}') has ratio \"\n                    f\"{r:.6f} but matched no motion keys\"\n                )\n            raise ValueError(\n                f\"Weighted-bin 'others' has ratio {r:.6f} but matched no motion keys\"\n            )\n\n    # Prepare logging summary using the configured cache batch size\n    raw_counts_log = [ratio * batch_size_for_log for ratio in all_ratios]\n    base_counts_log = _allocate_batch_counts(\n        raw_counts=raw_counts_log,\n        target_total=batch_size_for_log,\n    )\n    batch_fractions_log = [\n        float(c) / float(batch_size_for_log) for c in base_counts_log\n    ]\n\n    # Build specs using the final, actually used batch fractions\n    specs: List[Dict[str, Any]] = []\n    total_items = float(max(1, num_items_total))\n    for b_idx in range(num_explicit):\n        name = compiled_patterns[b_idx][\"name\"]\n        regex_s = compiled_patterns[b_idx][\"regex\"]\n        n = len(bin_indices[b_idx])\n        ds_frac = float(n) / total_items\n        bf = batch_fractions_log[b_idx]\n        specs.append(\n            {\n                \"name\": name,\n                \"regex\": regex_s,\n                \"ratio\": bf,\n                \"count\": n,\n                \"dataset_fraction\": ds_frac,\n                \"batch_fraction\": bf,\n            }\n        )\n    # Others bin\n    others_name = \"others\"\n    others_regex = 
\"<unmatched>\"\n    n_o = len(bin_indices[-1])\n    ds_frac_o = float(n_o) / total_items\n    bf_o = batch_fractions_log[-1]\n    specs.append(\n        {\n            \"name\": others_name,\n            \"regex\": others_regex,\n            \"ratio\": bf_o,\n            \"count\": n_o,\n            \"dataset_fraction\": ds_frac_o,\n            \"batch_fraction\": bf_o,\n        }\n    )\n\n    return bin_indices, all_ratios, specs\n\n\ndef _collect_manifest_keys(\n    manifest_path: str | Sequence[str],\n) -> Tuple[List[str], Dict[str, str], List[str]]:\n    if isinstance(manifest_path, (str, os.PathLike)):\n        manifest_paths: List[str] = [str(manifest_path)]\n    else:\n        manifest_paths = [str(p) for p in manifest_path]\n    if len(manifest_paths) == 0:\n        raise ValueError(\"Expected at least one manifest path\")\n\n    key_source: Dict[str, str] = {}\n    for mp in manifest_paths:\n        if not os.path.exists(mp):\n            raise FileNotFoundError(\n                f\"HDF5 manifest not found at {mp}. \"\n                \"Please set robot.motion.hdf5_root/train_hdf5_roots \"\n                \"to the correct path.\"\n            )\n        with open(mp, \"r\", encoding=\"utf-8\") as handle:\n            manifest = json.load(handle)\n        clips = manifest.get(\"clips\", {})\n        if not clips:\n            raise ValueError(\n                f\"Manifest at {mp} contains no clips; cannot preview sampling.\"\n            )\n        for key in clips.keys():\n            if key in key_source:\n                raise ValueError(\n                    f\"Duplicate motion clip key '{key}' found in multiple \"\n                    \"manifests; clip keys must be globally unique.\"\n                )\n            key_source[key] = mp\n\n    return list(key_source.keys()), key_source, manifest_paths\n\n\ndef _normalize_online_filter_cfg(\n    cfg: Optional[Mapping[str, Any]],\n    *,\n    default_vel_smoothing_sigma: float = 2.0,\n) -> Dict[str, Any]:\n    cfg_local = dict(cfg or {})\n    enabled = bool(cfg_local.get(\"enabled\", False))\n    cutoff_pool_cfg = cfg_local.get(\"butter_cutoff_hz_pool\", [])\n    cutoff_pool = tuple(float(v) for v in cutoff_pool_cfg)\n    ref_vel_smoothing_sigma = float(\n        cfg_local.get(\"ref_vel_smoothing_sigma\", default_vel_smoothing_sigma)\n    )\n    ft_ref_vel_smoothing_sigma = float(\n        cfg_local.get(\n            \"ft_ref_vel_smoothing_sigma\", default_vel_smoothing_sigma\n        )\n    )\n    if enabled and len(cutoff_pool) == 0:\n        raise ValueError(\n            \"online_filter.enabled=True requires butter_cutoff_hz_pool to \"\n            \"contain at least one cutoff value\"\n        )\n    butter_order = int(cfg_local.get(\"butter_order\", 4))\n    if butter_order <= 0:\n        raise ValueError(\"online_filter.butter_order must be positive\")\n    return {\n        \"enabled\": enabled,\n        \"butter_order\": butter_order,\n        \"butter_cutoff_hz_pool\": cutoff_pool,\n        \"ref_vel_smoothing_sigma\": ref_vel_smoothing_sigma,\n        \"ft_ref_vel_smoothing_sigma\": ft_ref_vel_smoothing_sigma,\n    }\n\n\ndef preview_weighted_bin_from_manifest(\n    manifest_path: str | Sequence[str],\n    batch_size: int,\n    cfg: Mapping[str, Any],\n) -> None:\n    \"\"\"Lightweight preview of weighted-bin sampling using manifest.json only.\n\n    This helper is intended to be called at configuration time before any\n    MotionClipBatchCache/DataLoader is constructed, so that invalid regex or\n    ratio settings 
can fail fast without incurring the cost of cache setup.\n    \"\"\"\n    if batch_size <= 0:\n        batch_size = 1\n\n    keys, _, _ = _collect_manifest_keys(manifest_path=manifest_path)\n    _, _, specs = _configure_weighted_bins(\n        keys=keys,\n        cfg=cfg,\n        batch_size_for_log=batch_size,\n    )\n\n    table_rows = []\n    for item in specs:\n        table_rows.append(\n            [\n                item[\"name\"],\n                item[\"regex\"],\n                f\"{item['ratio']:.4f}\",\n                int(item[\"count\"]),\n                f\"{item['dataset_fraction']:.4f}\",\n                f\"{item['batch_fraction']:.4f}\",\n            ]\n        )\n    headers = [\n        \"bin\",\n        \"regex\",\n        \"final_ratio\",\n        \"num_clips\",\n        \"clip_fraction\",\n        \"batch_fraction\",\n    ]\n    logger.info(\n        \"Weighted-bin config preview (manifest-level):\\n\"\n        + tabulate(table_rows, headers=headers, tablefmt=\"simple_outline\")\n    )\n\n\ndef preview_uniform_from_manifest(\n    manifest_path: str | Sequence[str],\n    batch_size: int,\n    *,\n    max_frame_length: int,\n    min_window_length: int,\n    handpicked_motion_names: Optional[Sequence[str]] = None,\n    excluded_motion_names: Optional[Sequence[str]] = None,\n) -> None:\n    \"\"\"Manifest-level preview table for uniform/curriculum sampling.\"\"\"\n    if batch_size <= 0:\n        batch_size = 1\n    if max_frame_length <= 0:\n        raise ValueError(\"max_frame_length must be positive\")\n    if min_window_length <= 0:\n        raise ValueError(\"min_window_length must be positive\")\n\n    _, _, manifest_paths = _collect_manifest_keys(manifest_path=manifest_path)\n    handpicked_set = (\n        set(handpicked_motion_names)\n        if handpicked_motion_names is not None\n        else None\n    )\n    excluded_set = (\n        set(excluded_motion_names)\n        if excluded_motion_names is not None\n        else None\n    )\n\n    def _normalize_key(value: Any) -> Optional[str]:\n        if value is None:\n            return None\n        if isinstance(value, bytes):\n            value = value.decode(\"utf-8\")\n        key = value if isinstance(value, str) else str(value)\n        if not key:\n            return None\n        return key\n\n    def _build_aliases(motion_key: str, meta: Mapping[str, Any]) -> List[str]:\n        aliases: List[str] = []\n\n        def _add(value: Any) -> None:\n            key = _normalize_key(value)\n            if key is None or key in aliases:\n                return\n            aliases.append(key)\n\n        _add(motion_key)\n        if isinstance(meta, Mapping):\n            _add(meta.get(\"motion_key\"))\n            metadata = meta.get(\"metadata\")\n            if isinstance(metadata, Mapping):\n                _add(metadata.get(\"motion_key\"))\n                _add(metadata.get(\"raw_motion_key\"))\n        return aliases\n\n    def _count_windows(clip_length: int) -> Tuple[int, int]:\n        remaining = clip_length\n        offset = 0\n        num_windows = 0\n        num_frames = 0\n        while remaining > 0:\n            window_length = min(max_frame_length, remaining)\n            if window_length >= min_window_length:\n                num_windows += 1\n                num_frames += int(window_length)\n            offset += int(window_length)\n            remaining = max(0, clip_length - offset)\n        return num_windows, num_frames\n\n    stats_by_manifest: Dict[str, Dict[str, float]] = {}\n    for 
mp in manifest_paths:\n        with open(mp, \"r\", encoding=\"utf-8\") as handle:\n            manifest = json.load(handle)\n        clips = manifest.get(\"clips\", {})\n        if not clips:\n            raise ValueError(\n                f\"Manifest at {mp} contains no clips; cannot preview sampling.\"\n            )\n        num_windows = 0\n        num_frames = 0\n        duration_s = 0.0\n        for key, meta in clips.items():\n            if isinstance(meta, Mapping):\n                aliases = _build_aliases(key, meta)\n            else:\n                aliases = [key]\n            if handpicked_set is not None and not any(\n                alias in handpicked_set for alias in aliases\n            ):\n                continue\n            if excluded_set is not None and any(\n                alias in excluded_set for alias in aliases\n            ):\n                continue\n            length = (\n                int(meta.get(\"length\", 0)) if isinstance(meta, Mapping) else 0\n            )\n            if length <= 0:\n                continue\n            metadata = (\n                meta.get(\"metadata\") if isinstance(meta, Mapping) else None\n            )\n            motion_fps_val = None\n            if isinstance(metadata, Mapping):\n                motion_fps_val = metadata.get(\"motion_fps\")\n            if motion_fps_val is None and isinstance(meta, Mapping):\n                motion_fps_val = meta.get(\"motion_fps\")\n            if motion_fps_val is None:\n                raise ValueError(\n                    f\"motion_fps missing for clip {key} in manifest {mp}\"\n                )\n            motion_fps = float(motion_fps_val)\n            if motion_fps <= 0.0:\n                raise ValueError(\n                    f\"Invalid motion_fps {motion_fps} for clip {key} in {mp}\"\n                )\n            clip_windows, clip_frames = _count_windows(length)\n            num_windows += int(clip_windows)\n            num_frames += int(clip_frames)\n            duration_s += float(clip_frames) / float(motion_fps)\n        stats_by_manifest[mp] = {\n            \"num_windows\": float(num_windows),\n            \"num_frames\": float(num_frames),\n            \"duration_s\": float(duration_s),\n        }\n\n    total_windows = int(\n        sum(stats[\"num_windows\"] for stats in stats_by_manifest.values())\n    )\n    if total_windows == 0:\n        raise ValueError(\n            \"No motion windows satisfy the requested frame length constraints\"\n        )\n\n    table_rows = []\n    denom = float(max(1, total_windows))\n    for mp in manifest_paths:\n        stats = stats_by_manifest.get(mp, {})\n        count = int(stats.get(\"num_windows\", 0))\n        frames = int(stats.get(\"num_frames\", 0))\n        duration_h = float(stats.get(\"duration_s\", 0.0)) / 3600.0\n        frac = float(count) / denom\n        table_rows.append(\n            [\n                os.path.dirname(mp),\n                count,\n                f\"{frac:.4f}\",\n                frames,\n                f\"{duration_h:.2f}\",\n                f\"{frac:.4f}\",\n            ]\n        )\n    headers = [\n        \"dataset_root\",\n        \"num_windows\",\n        \"window_fraction\",\n        \"num_frames\",\n        \"duration_h\",\n        \"batch_fraction\",\n    ]\n    logger.info(\n        \"Uniform sampling preview (manifest-level):\\n\"\n        + tabulate(table_rows, headers=headers, tablefmt=\"simple_outline\")\n    )\n\n\ndef preview_sampling_from_cfg(motion_cfg: Mapping[str, 
Any]) -> None:\n    \"\"\"Preview manifest-level sampling table for uniform/weighted-bin.\"\"\"\n    sampling_strategy_cfg = motion_cfg.get(\"sampling_strategy\", None)\n    if sampling_strategy_cfg is None:\n        sampling_strategy = \"uniform\"\n    else:\n        sampling_strategy = str(sampling_strategy_cfg).lower()\n    if sampling_strategy not in (\"uniform\", \"weighted_bin\", \"curriculum\"):\n        return\n\n    backend = str(motion_cfg.get(\"backend\", \"hdf5\")).lower()\n    if backend not in (\"hdf5\", \"hdf5_simple\", \"hdf5_v2\"):\n        return\n\n    train_roots = _normalize_root_list(\n        motion_cfg.get(\"train_hdf5_roots\", None)\n    )\n    if len(train_roots) == 0:\n        hdf5_root = motion_cfg.get(\"hdf5_root\", None)\n        if not hdf5_root:\n            return\n        train_roots = [str(hdf5_root)]\n    manifest_paths = [\n        os.path.join(str(root), \"manifest.json\") for root in train_roots\n    ]\n    cache_cfg = motion_cfg.get(\"cache\", {})\n    batch_size = int(cache_cfg.get(\"max_num_clips\", 1))\n\n    if sampling_strategy == \"weighted_bin\":\n        weighted_bin_cfg = dict(motion_cfg.get(\"weighted_bin\", {}))\n        preview_weighted_bin_from_manifest(\n            manifest_path=manifest_paths\n            if len(manifest_paths) > 1\n            else manifest_paths[0],\n            batch_size=batch_size,\n            cfg=weighted_bin_cfg,\n        )\n        return\n\n    # \"uniform\" and \"curriculum\" share the manifest-level preview below.\n    max_frame_length = int(motion_cfg.get(\"max_frame_length\", 1))\n    min_window_length = int(motion_cfg.get(\"min_frame_length\", 1))\n    handpicked_motion_names = motion_cfg.get(\"handpicked_motion_names\", None)\n    excluded_motion_names = motion_cfg.get(\"excluded_motion_names\", None)\n    preview_uniform_from_manifest(\n        manifest_path=manifest_paths\n        if len(manifest_paths) > 1\n        else manifest_paths[0],\n        batch_size=batch_size,\n        max_frame_length=max_frame_length,\n        min_window_length=min_window_length,\n        handpicked_motion_names=handpicked_motion_names,\n        excluded_motion_names=excluded_motion_names,\n    )\n\n\n# Mapping from in-memory tensor name to the HDF5 dataset it is read from.\nMANDATORY_DATASETS = {\n    \"dof_pos\": \"dof_pos\",\n    \"dof_vel\": \"dof_vel\",\n    \"rg_pos\": \"global_translation\",\n    \"rb_rot\": \"global_rotation_quat\",\n    \"body_vel\": \"global_velocity\",\n    \"body_ang_vel\": \"global_angular_velocity\",\n}\n\n\n
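# Heading-removal sketch for the transform below (XYZW quaternions, z-up):\n# with root yaw psi at frame 0, every position/velocity vector v becomes\n# R_z(-psi) @ v and every orientation q becomes q_z(-psi) * q, after the\n# frame-0 root XY has been shifted to the origin, so each clip starts at\n# (0, 0) with its initial heading aligned to +X.\n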
class _WorldFrameNormalizeTransform:\n    \"\"\"Normalize motion tensors into a canonical z-up world frame in-place.\"\"\"\n\n    @staticmethod\n    def _apply_prefix(\n        arrays: Dict[str, Tensor],\n        prefix: str,\n        *,\n        offset_xy: Tensor,\n        q_flat_wxyz: Tensor,\n        ref_rg_pos_shape: torch.Size,\n        ref_rb_rot_shape: torch.Size,\n    ) -> None:\n        pos_key = f\"{prefix}rg_pos\"\n        rot_key = f\"{prefix}rb_rot\"\n        vel_key = f\"{prefix}body_vel\"\n        ang_key = f\"{prefix}body_ang_vel\"\n        if (\n            pos_key not in arrays\n            or rot_key not in arrays\n            or vel_key not in arrays\n            or ang_key not in arrays\n        ):\n            return\n\n        pos = arrays[pos_key]\n        rot = arrays[rot_key]\n        vel = arrays[vel_key]\n        ang = arrays[ang_key]\n        if pos.shape != ref_rg_pos_shape or rot.shape != ref_rb_rot_shape:\n            return\n\n        # Center XY using canonical offset.\n        pos[..., 0] -= offset_xy[0]\n        pos[..., 1] -= offset_xy[1]\n\n        # Rotate vectors using shared quaternion utilities (WXYZ convention).\n        pos_flat = pos.reshape(-1, 3)\n        vel_flat = vel.reshape(-1, 3)\n        ang_flat = ang.reshape(-1, 3)\n        pos[:] = torch_utils.quat_apply(q_flat_wxyz, pos_flat).reshape_as(pos)\n        vel[:] = torch_utils.quat_apply(q_flat_wxyz, vel_flat).reshape_as(vel)\n        ang[:] = torch_utils.quat_apply(q_flat_wxyz, ang_flat).reshape_as(ang)\n\n        # Rotate orientations: q' = q_heading_inv * q.\n        rot_flat_xyzw = rot.reshape(-1, 4)\n        rot_flat_wxyz = torch_utils.xyzw_to_wxyz(rot_flat_xyzw)\n        rot_out_wxyz = torch_utils.quat_mul(q_flat_wxyz, rot_flat_wxyz)\n        rot[:] = torch_utils.wxyz_to_xyzw(rot_out_wxyz).reshape_as(rot)\n\n    def __call__(self, arrays: Dict[str, Tensor]) -> None:\n        if \"ref_rg_pos\" not in arrays or \"ref_rb_rot\" not in arrays:\n            raise ValueError(\"ref_rg_pos and ref_rb_rot are required\")\n        if \"ref_body_vel\" not in arrays or \"ref_body_ang_vel\" not in arrays:\n            raise ValueError(\"ref_body_vel and ref_body_ang_vel are required\")\n\n        rg_pos = arrays[\"ref_rg_pos\"]\n        rb_rot = arrays[\"ref_rb_rot\"]\n\n        # Root pose at frame 0, body 0 (XYZW quaternion, z-up).\n        p_root0 = rg_pos[0, 0]  # [3]\n        q_root0 = rb_rot[0, 0]  # [4]\n\n        # Compute XY offset from root at frame 0 (applied per prefix in\n        # _apply_prefix).\n        offset_xy = p_root0.clone()\n        offset_xy[2] = 0.0\n\n        # Extract yaw from q_root0 (XYZW) using z-up convention.\n        x = q_root0[0]\n        y = q_root0[1]\n        z = q_root0[2]\n        w = q_root0[3]\n        siny_cosp = 2.0 * (w * z + x * y)\n        cosy_cosp = w * w + x * x - y * y - z * z\n        yaw0 = torch.atan2(siny_cosp, cosy_cosp)\n\n        # Quaternion for rotation around +Z by -yaw0 (remove initial heading).\n        half = -0.5 * yaw0\n        sin_half = torch.sin(half)\n        cos_half = torch.cos(half)\n        q_heading_inv = torch.stack(\n            [\n                torch.zeros_like(sin_half),\n                torch.zeros_like(sin_half),\n                sin_half,\n                cos_half,\n            ],\n            dim=-1,\n        )  # [4], XYZW\n\n        t, b, _ = rg_pos.shape\n        q_flat = q_heading_inv.view(1, 1, 4).expand(t, b, 4).reshape(-1, 4)\n        q_flat_wxyz = torch_utils.xyzw_to_wxyz(q_flat)\n\n        for pfx in (\"ref_\", \"ft_ref_\"):\n            self._apply_prefix(\n                arrays,\n                pfx,\n                offset_xy=offset_xy,\n                q_flat_wxyz=q_flat_wxyz,\n                ref_rg_pos_shape=rg_pos.shape,\n                ref_rb_rot_shape=rb_rot.shape,\n            )\n\n\nclass _CpuFKTransform:\n    \"\"\"Compute FK on CPU and write ref_* tensors in-place.\"\"\"\n\n    def __init__(self, robot_file_path: str) -> None:\n        self._fk = HoloMotionFK(\n            robot_file_path=str(robot_file_path), device=torch.device(\"cpu\")\n        )\n        self._fk = self._fk.to(torch.device(\"cpu\"))\n\n    def __call__(\n        self,\n        arrays: Dict[str, Tensor],\n        fps: float,\n        prefix: str = \"ref_\",\n        vel_smoothing_sigma: float = 2.0,\n    ) -> None:\n        root_pos_key = f\"{prefix}root_pos\"\n        root_rot_key = f\"{prefix}root_rot\"\n        dof_pos_key = f\"{prefix}dof_pos\"\n        if (\n            root_pos_key not in arrays\n            or root_rot_key not in arrays\n            or dof_pos_key not in arrays\n        ):\n            raise KeyError(f\"Missing {prefix}root_* or {prefix}dof_pos for FK\")\n        with torch.no_grad():\n            fk_out = self._fk(\n                root_pos=arrays[root_pos_key][None, ...],\n                root_quat=arrays[root_rot_key][None, ...],\n                dof_pos=arrays[dof_pos_key][None, ...],\n                fps=float(fps),\n                vel_smoothing_sigma=float(vel_smoothing_sigma),\n                quat_format=\"xyzw\",\n            )\n        arrays[f\"{prefix}rg_pos\"] = fk_out[\"global_translation\"][0]\n        arrays[f\"{prefix}rb_rot\"] = fk_out[\"global_rotation_quat\"][0]\n        arrays[f\"{prefix}body_vel\"] = fk_out[\"global_velocity\"][0]\n        arrays[f\"{prefix}body_ang_vel\"] = fk_out[\"global_angular_velocity\"][0]\n        arrays[f\"{prefix}dof_vel\"] = fk_out[\"dof_vel\"][0]\n\n\n
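# Shape sketch for the FK pass above (T = frames, B = robot bodies, D = DoFs):\n# inputs ref_root_pos [T, 3], ref_root_rot [T, 4] (XYZW), ref_dof_pos [T, D];\n# outputs written in-place: ref_rg_pos [T, B, 3], ref_rb_rot [T, B, 4],\n# ref_body_vel / ref_body_ang_vel [T, B, 3], and ref_dof_vel [T, D].\n\n\n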
@dataclass\nclass MotionWindow:\n    \"\"\"Metadata describing a contiguous motion window within an HDF5 shard.\"\"\"\n\n    motion_key: str  # unique per window\n    shard_index: int\n    start: int\n    length: int\n    raw_motion_key: str  # original clip key\n    window_index: int\n\n\n@dataclass\nclass MotionClipSample:\n    \"\"\"In-memory representation of a motion window.\n\n    Attributes:\n        motion_key: Unique window identifier (includes slice info).\n        raw_motion_key: Original clip identifier from manifest.\n        window_index: Index of the window in the dataset's global window\n            enumeration.\n        tensors: Mapping from tensor name to data tensor of shape\n            ``[window_length, ...]`` (float32 unless specified otherwise).\n        length: Number of valid frames contained in the sample (``<=``\n            ``max_frame_length``).\n    \"\"\"\n\n    motion_key: str\n    raw_motion_key: str\n    window_index: int\n    tensors: Dict[str, Tensor]\n    length: int\n\n\n
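# Collation sketch: two samples with lengths 90 and 120 collate into tensors\n# of shape [2, 120, ...]; the shorter clip repeats its frame 89 over frames\n# 90..119, and ClipBatch.lengths records [90, 120] so consumers can mask the\n# padded tail.\n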
@dataclass\nclass ClipBatch:\n    \"\"\"Batch of motion clips ready for consumption by the environment.\n\n    Attributes:\n        tensors: Mapping from tensor name to tensor with shape\n            ``[batch_size, max_frame_length, ...]`` placed on the staging\n            device.\n        lengths: Valid frame counts per clip ``[batch_size]``.\n        motion_keys: List of motion keys corresponding to each clip.\n        raw_motion_keys: Original clip identifiers per clip.\n        window_indices: Dataset window indices per clip ``[batch_size]``.\n        max_frame_length: Fixed length configured for the cache.\n    \"\"\"\n\n    tensors: Dict[str, Tensor]\n    lengths: Tensor\n    motion_keys: List[str]\n    raw_motion_keys: List[str]\n    window_indices: Tensor\n    max_frame_length: int\n\n    @staticmethod\n    def collate_fn(samples: List[MotionClipSample]) -> \"ClipBatch\":\n        if len(samples) == 0:\n            raise ValueError(\n                \"ClipBatch collate_fn received an empty sample list\"\n            )\n\n        max_frame_length = max(\n            sample.tensors[\"ref_dof_pos\"].shape[0] for sample in samples\n        )\n        max_frame_length = int(max_frame_length)\n\n        batched_tensors: Dict[str, Tensor] = {}\n        lengths = torch.zeros(len(samples), dtype=torch.long)\n        motion_keys = []\n        raw_motion_keys = []\n        window_indices = torch.zeros(len(samples), dtype=torch.long)\n\n        for batch_idx, sample in enumerate(samples):\n            lengths[batch_idx] = sample.length\n            motion_keys.append(sample.motion_key)\n            raw_motion_keys.append(sample.raw_motion_key)\n            window_indices[batch_idx] = int(sample.window_index)\n\n            for name, tensor in sample.tensors.items():\n                if name not in batched_tensors:\n                    pad_shape = (\n                        len(samples),\n                        max_frame_length,\n                    ) + tensor.shape[1:]\n                    batched_tensors[name] = torch.zeros(\n                        pad_shape,\n                        dtype=tensor.dtype,\n                        device=tensor.device,\n                    )\n\n                target = batched_tensors[name]\n                valid_frames = sample.length\n                target[batch_idx, :valid_frames] = tensor\n\n                if valid_frames < max_frame_length and valid_frames > 0:\n                    # Pad short clips by repeating their last valid frame.\n                    target[batch_idx, valid_frames:] = tensor[valid_frames - 1]\n\n        return ClipBatch(\n            tensors=batched_tensors,\n            lengths=lengths,\n            motion_keys=motion_keys,\n            raw_motion_keys=raw_motion_keys,\n            window_indices=window_indices,\n            max_frame_length=max_frame_length,\n        )\n\n\nclass Hdf5RootDofDataset(Dataset[MotionClipSample]):\n    \"\"\"HDF5 dataset reading ref_root_* + ref_dof_pos only.\"\"\"\n\n    def __init__(\n        self,\n        manifest_path: str | Sequence[str],\n        max_frame_length: int,\n        min_window_length: int = 1,\n        handpicked_motion_names: Optional[List[str]] = None,\n        excluded_motion_names: Optional[List[str]] = None,\n        fk_robot_file_path: Optional[str] = None,\n        fk_vel_smoothing_sigma: float = 2.0,\n        fk_world_frame_normalization: bool = True,\n        online_filter_cfg: Optional[Mapping[str, Any]] = None,\n        allowed_prefixes: Optional[Sequence[str]] = None,\n    ) -> None:\n        super().__init__()\n        if max_frame_length <= 0:\n            raise ValueError(\"max_frame_length must be positive\")\n        if min_window_length <= 0:\n            raise ValueError(\"min_window_length must be positive\")\n\n        self.max_frame_length = int(max_frame_length)\n        self.min_window_length = int(min_window_length)\n        self.handpicked_motion_names = (\n            set(handpicked_motion_names)\n            if handpicked_motion_names is not None\n            else None\n        )\n        self.excluded_motion_names = (\n            set(excluded_motion_names)\n            if excluded_motion_names is not None\n            else None\n        )\n        self._fk_robot_file_path = (\n            str(fk_robot_file_path) if fk_robot_file_path is not None else \"\"\n        )\n        if not self._fk_robot_file_path:\n            raise ValueError(\"fk_robot_file_path is required for hdf5_v2 FK\")\n        self._fk_world_frame_normalization = bool(fk_world_frame_normalization)\n        self._fk_transform = _CpuFKTransform(self._fk_robot_file_path)\n        self._world_frame_transform = (\n            _WorldFrameNormalizeTransform()\n            if self._fk_world_frame_normalization\n            else None\n        )\n        self._fk_vel_smoothing_sigma = float(fk_vel_smoothing_sigma)\n        self._online_filter_cfg = _normalize_online_filter_cfg(\n            online_filter_cfg,\n            default_vel_smoothing_sigma=self._fk_vel_smoothing_sigma,\n        )\n        self._online_filter_enabled = bool(self._online_filter_cfg[\"enabled\"])\n        self._online_filter_butter_order = int(\n            self._online_filter_cfg[\"butter_order\"]\n        )\n        self._online_filter_cutoff_hz_pool = tuple(\n            float(v) for v in self._online_filter_cfg[\"butter_cutoff_hz_pool\"]\n        )\n        self._ref_vel_smoothing_sigma = float(\n            self._online_filter_cfg[\"ref_vel_smoothing_sigma\"]\n        )\n        self._ft_ref_vel_smoothing_sigma = float(\n            self._online_filter_cfg[\"ft_ref_vel_smoothing_sigma\"]\n        )\n
    self._online_filter_cfg[\"ft_ref_vel_smoothing_sigma\"]\n        )\n        if allowed_prefixes is None:\n            self._allowed_prefixes = (\"ref_\", \"ft_ref_\")\n        else:\n            self._allowed_prefixes = tuple(str(v) for v in allowed_prefixes)\n        if \"ref_\" not in self._allowed_prefixes:\n            raise ValueError(\n                \"Hdf5RootDofDataset requires 'ref_' in allowed_prefixes\"\n            )\n\n        if isinstance(manifest_path, (str, os.PathLike)):\n            manifest_paths: List[str] = [str(manifest_path)]\n        else:\n            manifest_paths = [str(p) for p in manifest_path]\n        if len(manifest_paths) == 0:\n            raise ValueError(\"At least one manifest_path must be provided\")\n\n        self.hdf5_root = os.path.dirname(manifest_paths[0])\n        self._manifest_paths: List[str] = manifest_paths\n        self._shard_paths: List[str] = []\n        self.shards: List[Dict[str, Any]] = []\n        self.clips: Dict[str, Dict[str, Any]] = {}\n\n        for mp in manifest_paths:\n            if not os.path.exists(mp):\n                raise FileNotFoundError(\n                    f\"HDF5 manifest not found at {mp}. \"\n                    \"Please set robot.motion.hdf5_root/train_hdf5_roots \"\n                    \"to the correct path.\"\n                )\n            with open(mp, \"r\", encoding=\"utf-8\") as handle:\n                manifest = json.load(handle)\n\n            root = os.path.dirname(mp)\n            shards_local = list(manifest.get(\"hdf5_shards\", []))\n            clips_local = manifest.get(\"clips\", {})\n\n            shard_offset = len(self.shards)\n            for shard_meta in shards_local:\n                self.shards.append(shard_meta)\n                rel = shard_meta.get(\"file\", None)\n                if not isinstance(rel, str) or not rel:\n                    raise ValueError(\n                        f\"Shard entry in manifest {mp} is missing a valid 'file' field\"\n                    )\n                self._shard_paths.append(os.path.join(root, rel))\n\n            for key, meta in clips_local.items():\n                if key in self.clips:\n                    raise ValueError(\n                        f\"Duplicate motion clip key '{key}' found in multiple \"\n                        \"manifests; clip keys must be globally unique.\"\n                    )\n                meta_global = dict(meta)\n                meta_global[\"shard\"] = (\n                    int(meta_global.get(\"shard\", 0)) + shard_offset\n                )\n                self.clips[key] = meta_global\n\n        if len(self.shards) == 0:\n            raise ValueError(\n                f\"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}\"\n            )\n\n        self.windows: List[MotionWindow] = self._enumerate_windows()\n        if len(self.windows) == 0:\n            raise ValueError(\n                \"No motion windows satisfy the requested frame length constraints\"\n            )\n\n        # Setting up hdf5 file handles management for bounded host-memory usage\n        self._file_handles: \"OrderedDict[int, h5py.File]\" = OrderedDict()\n        max_open_env = os.getenv(\"HOLOMOTION_HDF5_MAX_OPEN_SHARDS\")\n        if max_open_env is None:\n            self._h5_max_open_files = 16\n        else:\n            self._h5_max_open_files = max(1, int(max_open_env))\n        self._h5_access_counter = 0\n        self._h5_cleanup_interval = int(\n            1.0e6\n        )  # clean h5 handles every 1 
    def set_progress_counter(self, counter: Optional[mp.Value]) -> None:\n        \"\"\"Attach a shared counter that workers bump once per sample.\"\"\"\n        self._progress_counter = counter\n\n    @staticmethod\n    def _normalize_motion_key(value: Any) -> Optional[str]:\n        if value is None:\n            return None\n        if isinstance(value, bytes):\n            value = value.decode(\"utf-8\")\n        if isinstance(value, str):\n            key = value\n        else:\n            key = str(value)\n        if not key:\n            return None\n        return key\n\n    def _build_motion_key_aliases(\n        self, motion_key: str, meta: Mapping[str, Any]\n    ) -> Tuple[str, ...]:\n        aliases: List[str] = []\n\n        def _add(value: Any) -> None:\n            key = self._normalize_motion_key(value)\n            if key is None:\n                return\n            if key in aliases:\n                return\n            aliases.append(key)\n\n        _add(motion_key)\n        if isinstance(meta, Mapping):\n            _add(meta.get(\"motion_key\"))\n            metadata = meta.get(\"metadata\")\n            if isinstance(metadata, Mapping):\n                _add(metadata.get(\"motion_key\"))\n                _add(metadata.get(\"raw_motion_key\"))\n        return tuple(aliases)\n\n    def _enumerate_windows(self) -> List[MotionWindow]:\n        windows: List[MotionWindow] = []\n        for motion_key, meta in self.clips.items():\n            aliases = self._build_motion_key_aliases(motion_key, meta)\n            if self.handpicked_motion_names is not None and not any(\n                alias in self.handpicked_motion_names for alias in aliases\n            ):\n                continue\n            if self.excluded_motion_names is not None and any(\n                alias in self.excluded_motion_names for alias in aliases\n            ):\n                continue\n\n            shard_index = int(meta.get(\"shard\", 0))\n            start = int(meta.get(\"start\", 0))\n            length = int(meta.get(\"length\", 0))\n\n            if length <= 0:\n                continue\n\n            remaining = length\n            offset = 0\n            window_index = 0\n            while remaining > 0:\n                window_length = min(self.max_frame_length, remaining)\n                if window_length >= self.min_window_length:\n                    win_start = start + offset\n                    unique_key = (\n                        f\"{motion_key}__start_{win_start}_len_{window_length}\"\n                    )\n                    windows.append(\n                        MotionWindow(\n                            motion_key=unique_key,\n                            shard_index=shard_index,\n                            start=win_start,\n                            length=window_length,\n                            raw_motion_key=motion_key,\n                            window_index=window_index,\n                        )\n                    )\n                    window_index += 1\n                offset += window_length\n                remaining = max(0, length - offset)\n        return windows\n\n    def __len__(self) -> int:\n        return len(self.windows)\n\n    @staticmethod\n    def _cast_motion_np(np_array: np.ndarray, name: str) -> Tensor:\n        if np_array.dtype == np.float32:\n            pass\n        elif np_array.dtype.kind == \"O\":\n            raise ValueError(f\"{name} has object dtype\")\n        elif np.issubdtype(np_array.dtype, np.integer):\n            logger.warning(\n                \"Casting {} from {} to 
float32.\", name, np_array.dtype\n            )\n            np_array = np_array.astype(np.float32, copy=False)\n        else:\n            raise ValueError(\n                f\"{name} has dtype {np_array.dtype}, expected float32 or integer.\"\n            )\n        return torch.from_numpy(np_array).to(torch.float32)\n\n    @staticmethod\n    def _make_scalar_metadata_tensor(value: float, length: int) -> Tensor:\n        return torch.full((int(length), 1), float(value), dtype=torch.float32)\n\n    def _sample_online_filter_cutoff_hz(self) -> float:\n        if not self._online_filter_enabled:\n            return 0.0\n        cutoff_pool = self._online_filter_cutoff_hz_pool\n        if len(cutoff_pool) == 0:\n            raise ValueError(\n                \"Online filter is enabled but butter_cutoff_hz_pool is empty\"\n            )\n        if len(cutoff_pool) == 1:\n            return cutoff_pool[0]\n        sample_idx = int(torch.randint(len(cutoff_pool), size=(1,)).item())\n        return cutoff_pool[sample_idx]\n\n    def _add_online_filtered_reference_tensors(\n        self,\n        arrays: Dict[str, Tensor],\n        fps: float,\n        cutoff_hz: float,\n    ) -> None:\n        filtered_inputs_np = butterworth_filter_root_dof_arrays(\n            arrays={\n                \"ref_root_pos\": arrays[\"ref_root_pos\"].cpu().numpy(),\n                \"ref_root_rot\": arrays[\"ref_root_rot\"].cpu().numpy(),\n                \"ref_dof_pos\": arrays[\"ref_dof_pos\"].cpu().numpy(),\n            },\n            fps=float(fps),\n            cutoff_hz=float(cutoff_hz),\n            order=self._online_filter_butter_order,\n        )\n        for tensor_name, np_array in filtered_inputs_np.items():\n            arrays[tensor_name] = torch.from_numpy(np_array).to(torch.float32)\n        self._fk_transform(\n            arrays,\n            fps,\n            prefix=\"ft_ref_\",\n            vel_smoothing_sigma=self._ft_ref_vel_smoothing_sigma,\n        )\n\n    @staticmethod\n    def _derive_root_state_tensors(\n        arrays: Dict[str, Tensor],\n        prefix: str = \"ref_\",\n    ) -> None:\n        rg_pos_key = f\"{prefix}rg_pos\"\n        rb_rot_key = f\"{prefix}rb_rot\"\n        body_vel_key = f\"{prefix}body_vel\"\n        body_ang_vel_key = f\"{prefix}body_ang_vel\"\n        if (\n            rg_pos_key not in arrays\n            or rb_rot_key not in arrays\n            or body_vel_key not in arrays\n            or body_ang_vel_key not in arrays\n        ):\n            return\n        # Keep root-level tensors consistent with the FK-derived body tensors.\n        arrays[f\"{prefix}root_pos\"] = arrays[rg_pos_key][:, 0, :]\n        arrays[f\"{prefix}root_rot\"] = arrays[rb_rot_key][:, 0, :]\n        arrays[f\"{prefix}root_vel\"] = arrays[body_vel_key][:, 0, :]\n        arrays[f\"{prefix}root_ang_vel\"] = arrays[body_ang_vel_key][:, 0, :]\n\n    def __getitem__(self, index: int) -> MotionClipSample:\n        window = self.windows[index]\n        shard_handle = self._get_shard_handle(window.shard_index)\n        start, end = window.start, window.start + window.length\n        arrays: Dict[str, Tensor] = {}\n\n        for dataset_name in (\"ref_root_pos\", \"ref_root_rot\", \"ref_dof_pos\"):\n            if dataset_name not in shard_handle:\n                raise KeyError(\n                    f\"Missing mandatory dataset '{dataset_name}' in shard index \"\n                    f\"{window.shard_index}\"\n                )\n            np_array = 
np.asarray(shard_handle[dataset_name][start:end, ...])\n            arrays[dataset_name] = self._cast_motion_np(np_array, dataset_name)\n\n        if \"frame_flag\" in shard_handle:\n            frame_flag_np = shard_handle[\"frame_flag\"][start:end]\n            if frame_flag_np.dtype.kind == \"O\":\n                raise ValueError(\"frame_flag has object dtype\")\n            frame_flag = torch.from_numpy(frame_flag_np).to(torch.long)\n        else:\n            frame_flag = torch.ones(window.length, dtype=torch.long)\n            if window.length > 1:\n                frame_flag[0] = 0\n                frame_flag[-1] = 2\n            elif window.length == 1:\n                frame_flag[0] = 2\n        arrays[\"frame_flag\"] = frame_flag\n\n        clip_meta = self.clips.get(window.raw_motion_key, {})\n        metadata = clip_meta.get(\"metadata\", {})\n        motion_fps_val = metadata.get(\n            \"motion_fps\", clip_meta.get(\"motion_fps\")\n        )\n        if motion_fps_val is None:\n            raise ValueError(\n                f\"motion_fps missing for clip {window.raw_motion_key}\"\n            )\n        motion_fps = float(motion_fps_val)\n        if motion_fps <= 0.0:\n            raise ValueError(\n                f\"Invalid motion_fps {motion_fps} for clip {window.raw_motion_key}\"\n            )\n        arrays[\"motion_fps\"] = self._make_scalar_metadata_tensor(\n            motion_fps, window.length\n        )\n        cutoff_hz = self._sample_online_filter_cutoff_hz()\n        arrays[\"filter_cutoff_hz\"] = self._make_scalar_metadata_tensor(\n            cutoff_hz, window.length\n        )\n\n        self._fk_transform(\n            arrays,\n            motion_fps,\n            vel_smoothing_sigma=self._ref_vel_smoothing_sigma,\n        )\n        if self._online_filter_enabled and \"ft_ref_\" in self._allowed_prefixes:\n            self._add_online_filtered_reference_tensors(\n                arrays,\n                motion_fps,\n                cutoff_hz,\n            )\n        if self._world_frame_transform is not None:\n            self._world_frame_transform(arrays)\n\n        self._derive_root_state_tensors(arrays, prefix=\"ref_\")\n        self._derive_root_state_tensors(arrays, prefix=\"ft_ref_\")\n\n        if self._progress_counter is not None:\n            with self._progress_counter.get_lock():\n                self._progress_counter.value += 1\n\n        return MotionClipSample(\n            motion_key=window.motion_key,\n            raw_motion_key=window.raw_motion_key,\n            window_index=int(index),\n            tensors=arrays,\n            length=window.length,\n        )\n\n    def _get_shard_handle(self, shard_index: int) -> h5py.File:\n        # periodically clean up the file handles\n        self._h5_access_counter += 1\n        if self._h5_access_counter >= self._h5_cleanup_interval:\n            self.close()\n            self._h5_access_counter = 0\n\n        if shard_index in self._file_handles:\n            handle = self._file_handles.pop(shard_index)\n            if handle.id:\n                self._file_handles[shard_index] = handle\n                return handle\n\n        if shard_index < 0 or shard_index >= len(self._shard_paths):\n            raise IndexError(\n                f\"Shard index {shard_index} out of range for \"\n                f\"{len(self._shard_paths)} available shards\"\n            )\n        shard_path = self._shard_paths[shard_index]\n        rdcc_nbytes_env = os.getenv(\"HOLOMOTION_HDF5_RDCC_NBYTES\")\n 
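       # Default to a 4 MiB raw-chunk cache per open shard; override via\n        # the HOLOMOTION_HDF5_RDCC_NBYTES environment variable.\n 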
       if rdcc_nbytes_env is None:\n            rdcc_nbytes = 4 * 1024 * 1024\n        else:\n            rdcc_nbytes = int(rdcc_nbytes_env)\n        handle = h5py.File(\n            shard_path,\n            \"r\",\n            libver=\"latest\",\n            swmr=True,\n            rdcc_nbytes=rdcc_nbytes,\n            rdcc_w0=0.75,\n        )\n        if (\n            self._h5_max_open_files is not None\n            and len(self._file_handles) >= self._h5_max_open_files\n        ):\n            old_index, old_handle = self._file_handles.popitem(last=False)\n            old_handle.close()\n        self._file_handles[shard_index] = handle\n        return handle\n\n    def close(self) -> None:\n        logger.info(\"Clearing HDF5 file handles ...\")\n        for handle in self._file_handles.values():\n            if handle.id:\n                handle.close()\n        self._file_handles.clear()\n\n    def __del__(self) -> None:\n        self.close()\n\n\ndef _normalize_root_list(value: Any) -> List[str]:\n    if value is None:\n        return []\n    if isinstance(value, (str, os.PathLike)):\n        return [str(value)]\n    return [str(v) for v in value]\n\n\ndef build_motion_datasets_from_cfg(\n    motion_cfg: Mapping[str, Any],\n    *,\n    max_frame_length: int,\n    min_window_length: int,\n    world_frame_normalization: bool = True,\n    handpicked_motion_names: Optional[List[str]] = None,\n    excluded_motion_names: Optional[List[str]] = None,\n    allowed_prefixes: Optional[Sequence[str]] = None,\n) -> Tuple[\n    Dataset[MotionClipSample],\n    Optional[Dataset[MotionClipSample]],\n    Dict[str, Any],\n]:\n    preview_sampling_from_cfg(motion_cfg=motion_cfg)\n    backend = str(motion_cfg.get(\"backend\", \"hdf5\")).lower()\n    if backend in (\"hdf5\", \"hdf5_simple\"):\n        train_roots = _normalize_root_list(\n            motion_cfg.get(\"train_hdf5_roots\", None)\n        )\n        if len(train_roots) == 0:\n            hdf5_root = motion_cfg.get(\"hdf5_root\", None)\n            if not hdf5_root:\n                raise ValueError(\n                    \"HDF5 backend requires train_hdf5_roots or hdf5_root\"\n                )\n            train_roots = [str(hdf5_root)]\n        manifest_paths = [\n            os.path.join(str(root), \"manifest.json\") for root in train_roots\n        ]\n        train_dataset = Hdf5MotionDataset(\n            manifest_path=manifest_paths\n            if len(manifest_paths) > 1\n            else manifest_paths[0],\n            max_frame_length=max_frame_length,\n            min_window_length=min_window_length,\n            handpicked_motion_names=handpicked_motion_names,\n            excluded_motion_names=excluded_motion_names,\n            world_frame_normalization=world_frame_normalization,\n            allowed_prefixes=allowed_prefixes,\n        )\n\n        val_roots = _normalize_root_list(\n            motion_cfg.get(\"val_hdf5_roots\", motion_cfg.get(\"val_hdf5_root\"))\n        )\n        val_dataset = None\n        if len(val_roots) > 0:\n            val_manifest_paths = [\n                os.path.join(str(root), \"manifest.json\") for root in val_roots\n            ]\n            val_dataset = Hdf5MotionDataset(\n                manifest_path=val_manifest_paths\n                if len(val_manifest_paths) > 1\n                else val_manifest_paths[0],\n                max_frame_length=max_frame_length,\n                min_window_length=min_window_length,\n                handpicked_motion_names=handpicked_motion_names,\n             
   excluded_motion_names=excluded_motion_names,\n                world_frame_normalization=world_frame_normalization,\n                allowed_prefixes=allowed_prefixes,\n            )\n        return train_dataset, val_dataset, {}\n\n    if backend == \"hdf5_v2\":\n        fk_robot_file_path = motion_cfg.get(\"fk_robot_file_path\")\n        fk_vel_smoothing_sigma = float(\n            motion_cfg.get(\"fk_vel_smoothing_sigma\", 2.0)\n        )\n        fk_world_frame_normalization = bool(\n            motion_cfg.get(\"online_fk_world_frame_normalization\", True)\n        )\n        cache_cfg = motion_cfg.get(\"cache\", {})\n        allowed_prefixes = cache_cfg.get(\n            \"allowed_prefixes\",\n            [\"ref_\", \"ft_ref_\"],\n        )\n        online_filter_cfg = motion_cfg.get(\"online_filter\", {})\n        train_roots = _normalize_root_list(\n            motion_cfg.get(\"train_hdf5_roots\", None)\n        )\n        if len(train_roots) == 0:\n            hdf5_root = motion_cfg.get(\"hdf5_root\", None)\n            if not hdf5_root:\n                raise ValueError(\n                    \"HDF5 v2 backend requires train_hdf5_roots or hdf5_root\"\n                )\n            train_roots = [str(hdf5_root)]\n        train_manifest_paths = [\n            os.path.join(str(root), \"manifest.json\") for root in train_roots\n        ]\n        train_dataset = Hdf5RootDofDataset(\n            manifest_path=train_manifest_paths\n            if len(train_manifest_paths) > 1\n            else train_manifest_paths[0],\n            max_frame_length=max_frame_length,\n            min_window_length=min_window_length,\n            handpicked_motion_names=handpicked_motion_names,\n            excluded_motion_names=excluded_motion_names,\n            fk_robot_file_path=fk_robot_file_path,\n            fk_vel_smoothing_sigma=fk_vel_smoothing_sigma,\n            fk_world_frame_normalization=fk_world_frame_normalization,\n            online_filter_cfg=online_filter_cfg,\n            allowed_prefixes=allowed_prefixes,\n        )\n\n        val_roots = _normalize_root_list(\n            motion_cfg.get(\"val_hdf5_roots\", motion_cfg.get(\"val_hdf5_root\"))\n        )\n        val_dataset = None\n        if len(val_roots) > 0:\n            val_manifest_paths = [\n                os.path.join(str(root), \"manifest.json\") for root in val_roots\n            ]\n            val_dataset = Hdf5RootDofDataset(\n                manifest_path=val_manifest_paths\n                if len(val_manifest_paths) > 1\n                else val_manifest_paths[0],\n                max_frame_length=max_frame_length,\n                min_window_length=min_window_length,\n                handpicked_motion_names=handpicked_motion_names,\n                excluded_motion_names=excluded_motion_names,\n                fk_robot_file_path=fk_robot_file_path,\n                fk_vel_smoothing_sigma=fk_vel_smoothing_sigma,\n                fk_world_frame_normalization=fk_world_frame_normalization,\n                online_filter_cfg=online_filter_cfg,\n                allowed_prefixes=allowed_prefixes,\n            )\n        cache_kwargs = {\n            \"stage_on_swap_only\": bool(\n                motion_cfg.get(\"stage_on_swap_only\", True)\n            )\n        }\n        return train_dataset, val_dataset, cache_kwargs\n\n    raise ValueError(f\"Unsupported motion backend: {backend}\")\n\n\ndef _cache_collate_fn(\n    samples: List[MotionClipSample],\n    mode: str,\n    batch_size: int,\n) -> ClipBatch:\n    \"\"\"Collate 
function for motion cache DataLoader (supports validation padding).\"\"\"\n    if mode == \"val\" and batch_size > len(samples) and len(samples) > 0:\n        extra = batch_size - len(samples)\n        gen = torch.Generator()\n        idx = torch.randint(0, len(samples), size=(extra,), generator=gen)\n        padded = list(samples)\n        for i in idx.tolist():\n            padded.append(samples[i])\n        return ClipBatch.collate_fn(padded)\n    return ClipBatch.collate_fn(samples)\n\n\nclass InfiniteDistributedSampler(DistributedSampler):\n    \"\"\"Distributed sampler that yields an infinite stream by cycling epochs.\"\"\"\n\n    def __iter__(self):\n        # Infinite stream by cycling epochs\n        while True:\n            self.set_epoch(getattr(self, \"_epoch\", 0))\n            for idx in super().__iter__():\n                yield idx\n            self._epoch = getattr(self, \"_epoch\", 0) + 1\n\n\nclass InfiniteRandomSampler(Sampler[int]):\n    \"\"\"Random sampler that yields infinite reshuffled passes over the dataset.\"\"\"\n\n    def __init__(self, data_source: Dataset, seed: int = 0) -> None:\n        self.data_source = data_source\n        self.seed = int(seed)\n        self.epoch = 0\n\n    def __iter__(self):\n        # Yield infinite permutations of indices\n        while True:\n            g = torch.Generator()\n            g.manual_seed(self.seed + self.epoch)\n            perm = torch.randperm(len(self.data_source), generator=g)\n            for idx in perm.tolist():\n                yield int(idx)\n            self.epoch += 1\n\n    def __len__(self) -> int:\n        # Large sentinel to satisfy components that query length\n        return 2**31 - 1\n\n\nclass WeightedBinInfiniteSampler(Sampler[int]):\n    \"\"\"Infinite sampler that respects regex-based weighted bins over indices.\"\"\"\n\n    def __init__(\n        self,\n        dataset_len: int,\n        bin_indices: List[List[int]],\n        ratios: List[float],\n        batch_size: int,\n        seed: int,\n    ) -> None:\n        self._ds_len = int(max(0, dataset_len))\n        self._bins = [torch.tensor(b, dtype=torch.long) for b in bin_indices]\n        self._ratios = list(ratios)\n        self._batch_size = int(max(1, batch_size))\n        self._seed = int(seed)\n        self._epoch = 0\n\n        raw_counts = [r * float(self._batch_size) for r in self._ratios]\n        self._counts = _allocate_batch_counts(\n            raw_counts=raw_counts,\n            target_total=self._batch_size,\n        )\n\n    def __iter__(self):\n        while True:\n            g = torch.Generator()\n            g.manual_seed(self._seed + self._epoch)\n            batch: List[int] = []\n            for bin_idx, count in zip(self._bins, self._counts):\n                if count <= 0 or bin_idx.numel() == 0:\n                    continue\n                choice = torch.randint(\n                    0,\n                    int(bin_idx.numel()),\n                    size=(count,),\n                    generator=g,\n                )\n                selected = bin_idx[choice].tolist()\n                batch.extend(int(x) for x in selected)\n\n            if not batch:\n                # Fallback: uniform over dataset indices\n                if self._ds_len == 0:\n                    raise ValueError(\n                        \"WeightedBinInfiniteSampler cannot sample from an empty dataset\"\n                    )\n                all_idx = torch.randint(\n                    0,\n                    self._ds_len,\n                
    size=(self._batch_size,),\n                    generator=g,\n                )\n                batch = [int(x) for x in all_idx.tolist()]\n\n            if len(batch) > self._batch_size:\n                batch = batch[: self._batch_size]\n            elif len(batch) < self._batch_size:\n                # A single pad pass can fall short when the assembled batch is\n                # much smaller than batch_size; keep repeating picks so epoch\n                # boundaries stay aligned with full cache batches.\n                while len(batch) < self._batch_size:\n                    pad = self._batch_size - len(batch)\n                    batch.extend(batch[:pad])\n\n            perm = torch.randperm(len(batch), generator=g)\n            for idx in perm.tolist():\n                yield int(batch[idx])\n            self._epoch += 1\n\n    def __len__(self) -> int:\n        return 2**31 - 1\n\n\nclass PrioritizedInfiniteSampler(Sampler[int]):\n    \"\"\"Infinite sampler with persistent prioritized and fresh uniform pools.\"\"\"\n\n    def __init__(\n        self,\n        dataset_len: int,\n        batch_size: int,\n        seed: int,\n        *,\n        p_a_ratio: float = 0.2,\n        ema_alpha_signal: float = 0.2,\n        ema_alpha_rel_improve: float = 0.2,\n        relative_eps: float = 1.0e-6,\n    ) -> None:\n        self._ds_len = int(max(0, dataset_len))\n        self._batch_size = int(max(1, batch_size))\n        self._seed = int(seed)\n        self._epoch = 0\n\n        self._p_a_ratio = float(min(1.0, max(0.0, p_a_ratio)))\n        self._ema_alpha_signal = float(min(1.0, max(0.0, ema_alpha_signal)))\n        self._ema_alpha_rel_improve = float(\n            min(1.0, max(0.0, ema_alpha_rel_improve))\n        )\n        self._relative_eps = float(max(1.0e-12, relative_eps))\n\n        if self._ds_len <= 0:\n            self._ema_completion_rate = torch.zeros(0, dtype=torch.float32)\n            self._ema_completion_rate_sq = torch.zeros(0, dtype=torch.float32)\n            self._ema_completion_rel_improve = torch.zeros(\n                0, dtype=torch.float32\n            )\n            self._selection_counts = torch.zeros(0, dtype=torch.long)\n            self._seen_mask = torch.zeros(0, dtype=torch.bool)\n            self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long)\n            self._prioritized_pool_mask = torch.zeros(0, dtype=torch.bool)\n        else:\n            self._ema_completion_rate = torch.zeros(\n                self._ds_len, dtype=torch.float32\n            )\n            self._ema_completion_rate_sq = torch.zeros(\n                self._ds_len, dtype=torch.float32\n            )\n            self._ema_completion_rel_improve = torch.zeros(\n                self._ds_len, dtype=torch.float32\n            )\n            self._selection_counts = torch.zeros(\n                self._ds_len, dtype=torch.long\n            )\n            self._seen_mask = torch.zeros(self._ds_len, dtype=torch.bool)\n            self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long)\n            self._prioritized_pool_mask = torch.zeros(\n                self._ds_len, dtype=torch.bool\n            )\n        self._state_version = 0\n        self._last_updated_swap = -1\n        self._last_prioritized_pool_mean_score = 0.0\n        self._last_uniform_pool_mean_score = 0.0\n        self._last_entered_prioritized_pool_count = 0\n        self._last_exited_prioritized_pool_count = 0\n        self._uniform_cycle_start = 0\n        self._uniform_cycle_step = 1\n        self._uniform_cycle_offset = self._ds_len\n        self._uniform_cycle_epoch = 0\n\n    @property\n    def state_version(self) -> int:\n        return int(self._state_version)\n\n    def get_pool_statistics(self) -> Optional[Dict[str, float]]:\n        
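\"\"\"Return the latest pool metrics, or None for an empty dataset.\"\"\"\n        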
if self._ds_len <= 0:\n            return None\n        return self._pool_metric_stats()\n\n    @staticmethod\n    def _aggregate_by_index(\n        window_indices: Tensor,\n        values: Tensor,\n        counts: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor]:\n        if window_indices.numel() == 0:\n            return (\n                torch.zeros(0, dtype=torch.long),\n                torch.zeros(0, dtype=torch.float32),\n                torch.zeros(0, dtype=torch.float32),\n            )\n        unique_indices, inverse = torch.unique(\n            window_indices.to(dtype=torch.long),\n            sorted=False,\n            return_inverse=True,\n        )\n        out_weighted_sum = torch.zeros(\n            unique_indices.numel(), dtype=torch.float32\n        )\n        out_count = torch.zeros(unique_indices.numel(), dtype=torch.float32)\n        out_weighted_sum.scatter_add_(0, inverse, values * counts)\n        out_count.scatter_add_(0, inverse, counts)\n        return unique_indices, out_weighted_sum, out_count\n\n    def _pool_batch_sizes(self) -> Tuple[int, int]:\n        if self._ds_len <= 0:\n            return 0, 0\n        uniform_count = int(round(self._p_a_ratio * float(self._batch_size)))\n        uniform_count = max(0, min(self._batch_size, uniform_count))\n        prioritized_count = max(0, self._batch_size - uniform_count)\n        return uniform_count, prioritized_count\n\n    def _priority_scores_for_indices(self, indices: Tensor) -> Tensor:\n        if indices.numel() == 0 or self._ds_len <= 0:\n            return torch.zeros(0, dtype=torch.float32)\n        idx = indices.to(dtype=torch.long)\n        progress = torch.clamp(\n            self._ema_completion_rel_improve.index_select(0, idx),\n            min=0.0,\n            max=1.0,\n        )\n        remaining_difficulty = torch.clamp(\n            1.0 - self._ema_completion_rate.index_select(0, idx),\n            min=0.0,\n            max=1.0,\n        )\n        seen = self._seen_mask.index_select(0, idx).to(dtype=torch.float32)\n        return progress * remaining_difficulty * seen\n\n    def _pool_metric_stats(self) -> Dict[str, float]:\n        prioritized_pool_size = int(self._prioritized_pool_indices.numel())\n        return {\n            \"prioritized_pool_size\": float(prioritized_pool_size),\n            \"prioritized_pool_mean_score\": float(\n                self._last_prioritized_pool_mean_score\n            ),\n            \"uniform_pool_mean_score\": float(\n                self._last_uniform_pool_mean_score\n            ),\n            \"entered_prioritized_pool_count\": float(\n                self._last_entered_prioritized_pool_count\n            ),\n            \"exited_prioritized_pool_count\": float(\n                self._last_exited_prioritized_pool_count\n            ),\n        }\n\n    def get_window_state_for_indices(\n        self, window_indices: Tensor\n    ) -> Dict[str, Tensor]:\n        if self._ds_len <= 0:\n            empty_bool = torch.zeros(0, dtype=torch.bool)\n            empty_float = torch.zeros(0, dtype=torch.float32)\n            return {\n                \"ema_completion_rate\": empty_float,\n                \"completion_rate_rel_improve\": empty_float,\n                \"selection_count\": torch.zeros(0, dtype=torch.long),\n                \"seen\": empty_bool,\n                \"in_prioritized_pool\": empty_bool,\n            }\n        idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()\n        if idx.numel() == 0:\n            empty_bool = 
torch.zeros(0, dtype=torch.bool)\n            empty_float = torch.zeros(0, dtype=torch.float32)\n            return {\n                \"ema_completion_rate\": empty_float,\n                \"completion_rate_rel_improve\": empty_float,\n                \"selection_count\": torch.zeros(0, dtype=torch.long),\n                \"seen\": empty_bool,\n                \"in_prioritized_pool\": empty_bool,\n            }\n        return {\n            \"ema_completion_rate\": self._ema_completion_rate.index_select(\n                0, idx\n            ).to(dtype=torch.float32),\n            \"completion_rate_rel_improve\": (\n                self._ema_completion_rel_improve.index_select(0, idx).to(\n                    dtype=torch.float32\n                )\n            ),\n            \"selection_count\": self._selection_counts.index_select(0, idx),\n            \"seen\": self._seen_mask.index_select(0, idx),\n            \"in_prioritized_pool\": self._prioritized_pool_mask.index_select(\n                0, idx\n            ),\n        }\n\n    def _rebuild_prioritized_pool(self, candidate_indices: Tensor) -> None:\n        if self._ds_len <= 0:\n            return\n        _, prioritized_count = self._pool_batch_sizes()\n        previous_indices = self._prioritized_pool_indices\n        selected = torch.zeros(0, dtype=torch.long)\n        if prioritized_count > 0:\n            candidates = torch.cat(\n                [\n                    previous_indices.to(dtype=torch.long),\n                    candidate_indices.to(dtype=torch.long).reshape(-1),\n                ]\n            )\n            candidates = torch.unique(candidates, sorted=False)\n            scores = self._priority_scores_for_indices(candidates)\n            positive = scores > 0.0\n            if bool(positive.any().item()):\n                candidates = candidates[positive]\n                scores = scores[positive]\n                order = torch.argsort(scores, descending=True)\n                selected = candidates.index_select(\n                    0, order[: min(prioritized_count, candidates.numel())]\n                )\n                scores = scores.index_select(\n                    0, order[: min(prioritized_count, scores.numel())]\n                )\n                self._last_prioritized_pool_mean_score = float(\n                    scores.mean().item()\n                )\n            else:\n                self._last_prioritized_pool_mean_score = 0.0\n            if candidates.numel() > selected.numel():\n                selected_mask = torch.zeros(\n                    candidates.numel(), dtype=torch.bool\n                )\n                if selected.numel() > 0:\n                    matches = candidates[:, None] == selected[None, :]\n                    selected_mask = matches.any(dim=1)\n                nonselected_scores = self._priority_scores_for_indices(\n                    candidates[~selected_mask]\n                )\n                self._last_uniform_pool_mean_score = (\n                    float(nonselected_scores.mean().item())\n                    if nonselected_scores.numel() > 0\n                    else 0.0\n                )\n            else:\n                self._last_uniform_pool_mean_score = 0.0\n        else:\n            self._last_prioritized_pool_mean_score = 0.0\n            self._last_uniform_pool_mean_score = 0.0\n        if previous_indices.numel() > 0:\n            self._prioritized_pool_mask[previous_indices] = False\n        if selected.numel() > 0:\n            
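# Update pool membership, then diff the old and new index sets to\n            # report churn (entered/exited counts) in the pool statistics.\n            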
self._prioritized_pool_mask[selected] = True\n        previous_set = set(previous_indices.tolist())\n        selected_set = set(selected.tolist())\n        self._last_entered_prioritized_pool_count = len(\n            selected_set - previous_set\n        )\n        self._last_exited_prioritized_pool_count = len(\n            previous_set - selected_set\n        )\n        self._prioritized_pool_indices = selected\n\n    def maybe_update_from_observations(\n        self,\n        *,\n        window_indices: Tensor,\n        mpkpe_signal_means: Tensor,\n        completion_rate_means: Tensor,\n        counts: Tensor,\n        swap_index: int,\n    ) -> bool:\n        if self._ds_len <= 0:\n            return False\n        swap_idx = int(swap_index)\n        if swap_idx <= 0:\n            return False\n        if self._last_updated_swap == swap_idx:\n            return False\n\n        indices = (\n            window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()\n        )\n        # Keep validating the MPKPE tensor shape so the command-side\n        # curriculum aggregation stays aligned with completion-rate updates.\n        mpkpe_signal_numel = int(mpkpe_signal_means.numel())\n        completion_rate = (\n            completion_rate_means.detach()\n            .to(dtype=torch.float32)\n            .reshape(-1)\n            .cpu()\n        )\n        cnt = counts.detach().to(dtype=torch.float32).reshape(-1).cpu()\n        if not (\n            indices.numel() == mpkpe_signal_numel\n            and mpkpe_signal_numel == completion_rate.numel()\n            and completion_rate.numel() == cnt.numel()\n        ):\n            raise ValueError(\n                \"Prioritized sampler update tensors must have matching shape.\"\n            )\n\n        valid_dataset_idx = (indices >= 0) & (indices < self._ds_len)\n        valid = (\n            valid_dataset_idx & torch.isfinite(completion_rate) & (cnt > 0.0)\n        )\n        current_batch_indices = torch.unique(\n            indices[valid_dataset_idx], sorted=False\n        )\n        if not bool(valid.any().item()):\n            self._last_entered_prioritized_pool_count = 0\n            self._last_exited_prioritized_pool_count = 0\n            self._last_updated_swap = swap_idx\n            return False\n\n        idx_valid = indices[valid]\n        completion_rate_valid = completion_rate[valid]\n        cnt_valid = cnt[valid]\n\n        touched_idx, completion_rate_sum, completion_rate_count_sum = (\n            self._aggregate_by_index(\n                idx_valid,\n                completion_rate_valid,\n                cnt_valid,\n            )\n        )\n        if touched_idx.numel() == 0:\n            self._last_entered_prioritized_pool_count = 0\n            self._last_exited_prioritized_pool_count = 0\n            self._last_updated_swap = swap_idx\n            return False\n\n        completion_rate_obs = (\n            completion_rate_sum / completion_rate_count_sum.clamp_min(1.0e-12)\n        )\n        completion_rate_obs = torch.clamp(\n            completion_rate_obs, min=0.0, max=1.0\n        )\n\n        prev_seen = self._seen_mask[touched_idx]\n        prev_completion_rate = self._ema_completion_rate[touched_idx]\n        prev_completion_rate_sq = self._ema_completion_rate_sq[touched_idx]\n        prev_completion_rate_var = torch.clamp(\n            prev_completion_rate_sq\n            - prev_completion_rate * prev_completion_rate,\n            min=1.0e-6,\n        )\n        prev_completion_rate_std = 
torch.sqrt(prev_completion_rate_var)\n        next_completion_rate = torch.where(\n            prev_seen,\n            (1.0 - self._ema_alpha_signal) * prev_completion_rate\n            + self._ema_alpha_signal * completion_rate_obs,\n            completion_rate_obs,\n        )\n        next_completion_rate_sq = torch.where(\n            prev_seen,\n            (1.0 - self._ema_alpha_signal) * prev_completion_rate_sq\n            + self._ema_alpha_signal\n            * (completion_rate_obs * completion_rate_obs),\n            completion_rate_obs * completion_rate_obs,\n        )\n\n        completion_rel_improve_obs = torch.zeros_like(next_completion_rate)\n        completion_rel_improve_obs[prev_seen] = torch.tanh(\n            (completion_rate_obs[prev_seen] - prev_completion_rate[prev_seen])\n            / (prev_completion_rate_std[prev_seen] + self._relative_eps)\n        )\n        prev_completion_rel = self._ema_completion_rel_improve[touched_idx]\n        next_completion_rel = torch.where(\n            prev_seen,\n            (1.0 - self._ema_alpha_rel_improve) * prev_completion_rel\n            + self._ema_alpha_rel_improve * completion_rel_improve_obs,\n            completion_rel_improve_obs,\n        )\n\n        self._ema_completion_rate[touched_idx] = next_completion_rate\n        self._ema_completion_rate_sq[touched_idx] = next_completion_rate_sq\n        self._ema_completion_rel_improve[touched_idx] = next_completion_rel\n        self._seen_mask[touched_idx] = True\n\n        self._rebuild_prioritized_pool(touched_idx)\n        self._state_version += 1\n        self._last_updated_swap = swap_idx\n        return True\n\n    def _reset_uniform_cycle(self) -> None:\n        if self._ds_len <= 0:\n            self._uniform_cycle_start = 0\n            self._uniform_cycle_step = 1\n            self._uniform_cycle_offset = 0\n            return\n        generator = torch.Generator()\n        generator.manual_seed(self._seed + self._uniform_cycle_epoch * 1000003)\n        self._uniform_cycle_epoch += 1\n        self._uniform_cycle_start = int(\n            torch.randint(\n                low=0,\n                high=self._ds_len,\n                size=(1,),\n                generator=generator,\n            ).item()\n        )\n        if self._ds_len <= 1:\n            self._uniform_cycle_step = 1\n        else:\n            step = int(\n                torch.randint(\n                    low=1,\n                    high=self._ds_len,\n                    size=(1,),\n                    generator=generator,\n                ).item()\n            )\n            while math.gcd(step, self._ds_len) != 1:\n                step += 1\n                if step >= self._ds_len:\n                    step = 1\n            self._uniform_cycle_step = step\n        self._uniform_cycle_offset = 0\n\n    def _next_uniform_index(self) -> int:\n        if self._uniform_cycle_offset >= self._ds_len:\n            self._reset_uniform_cycle()\n        next_index = (\n            self._uniform_cycle_start\n            + self._uniform_cycle_offset * self._uniform_cycle_step\n        ) % self._ds_len\n        self._uniform_cycle_offset += 1\n        return int(next_index)\n\n    def _sample_uniform_indices(\n        self,\n        generator: torch.Generator,\n        count: int,\n        *,\n        exclude: Optional[Tensor] = None,\n    ) -> Tensor:\n        del generator\n        if count <= 0 or self._ds_len <= 0:\n            return torch.zeros(0, dtype=torch.long)\n        blocked = set()\n        if 
exclude is not None and exclude.numel() > 0:\n            blocked.update(\n                exclude.detach().to(dtype=torch.long).reshape(-1).tolist()\n            )\n        take = min(int(count), max(0, self._ds_len - len(blocked)))\n        if take <= 0:\n            return torch.zeros(0, dtype=torch.long)\n        selected: List[int] = []\n        stagnant_steps = 0\n        while len(selected) < take and stagnant_steps < self._ds_len:\n            next_index = self._next_uniform_index()\n            if next_index in blocked:\n                stagnant_steps += 1\n                continue\n            selected.append(next_index)\n            blocked.add(next_index)\n            stagnant_steps = 0\n        return torch.tensor(selected, dtype=torch.long)\n\n    def _sample_prioritized_indices(\n        self, generator: torch.Generator, count: int\n    ) -> Tensor:\n        if count <= 0 or self._prioritized_pool_indices.numel() == 0:\n            return torch.zeros(0, dtype=torch.long)\n        perm = torch.randperm(\n            self._prioritized_pool_indices.numel(), generator=generator\n        )\n        take = min(count, int(self._prioritized_pool_indices.numel()))\n        return self._prioritized_pool_indices.index_select(0, perm[:take])\n\n    def _sample_batch_indices(self, generator: torch.Generator) -> Tensor:\n        uniform_count, prioritized_count = self._pool_batch_sizes()\n        prioritized_indices = self._sample_prioritized_indices(\n            generator, prioritized_count\n        )\n        uniform_indices = self._sample_uniform_indices(\n            generator,\n            uniform_count,\n            exclude=prioritized_indices,\n        )\n        sampled_indices = torch.cat(\n            [uniform_indices, prioritized_indices], dim=0\n        )\n        if sampled_indices.numel() < self._batch_size:\n            extra_indices = self._sample_uniform_indices(\n                generator,\n                self._batch_size - int(sampled_indices.numel()),\n                exclude=sampled_indices,\n            )\n            sampled_indices = torch.cat(\n                [sampled_indices, extra_indices], dim=0\n            )\n        if sampled_indices.numel() != self._batch_size:\n            raise ValueError(\n                \"Prioritized sampler failed to assemble a full cache batch.\"\n            )\n\n        if sampled_indices.numel() > 0:\n            self._selection_counts[sampled_indices] += 1\n        return sampled_indices\n\n    def get_scores_for_indices(self, window_indices: Tensor) -> Tensor:\n        if self._ds_len <= 0:\n            return torch.zeros_like(window_indices, dtype=torch.float32)\n        idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()\n        if idx.numel() == 0:\n            return torch.zeros(0, dtype=torch.float32)\n        scores = self._priority_scores_for_indices(idx)\n        return scores.to(dtype=torch.float32)\n\n    def __iter__(self):\n        while True:\n            if self._ds_len <= 0:\n                raise ValueError(\n                    \"PrioritizedInfiniteSampler cannot sample from \"\n                    \"an empty dataset.\"\n                )\n            g = torch.Generator()\n            g.manual_seed(self._seed + self._epoch)\n            sampled_indices = self._sample_batch_indices(generator=g)\n            perm = torch.randperm(sampled_indices.numel(), generator=g)\n            yielded_indices = sampled_indices.index_select(0, perm)\n            for idx in yielded_indices.tolist():\n       
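         # Indices stream out one at a time; the DataLoader's batch_size\n                # regroups them into one prioritized-plus-uniform cache batch.\n       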
         yield int(idx)\n            self._epoch += 1\n\n    def __len__(self) -> int:\n        return 2**31 - 1\n\n\nclass Hdf5MotionDataset(Dataset[MotionClipSample]):\n    \"\"\"Dataset that materializes fixed-length motion windows from HDF5 shards.\"\"\"\n\n    def __init__(\n        self,\n        manifest_path: str | Sequence[str],\n        max_frame_length: int,\n        min_window_length: int = 1,\n        handpicked_motion_names: Optional[List[str]] = None,\n        excluded_motion_names: Optional[List[str]] = None,\n        world_frame_normalization: bool = True,\n        allowed_prefixes: Optional[Sequence[str]] = None,\n    ) -> None:\n        super().__init__()\n        if max_frame_length <= 0:\n            raise ValueError(\"max_frame_length must be positive\")\n\n        self.max_frame_length = int(max_frame_length)\n        self.min_window_length = int(min_window_length)\n        self.handpicked_motion_names = (\n            set(handpicked_motion_names)\n            if handpicked_motion_names is not None\n            else None\n        )\n        self.excluded_motion_names = (\n            set(excluded_motion_names)\n            if excluded_motion_names is not None\n            else None\n        )\n        self._world_frame_transform = (\n            _WorldFrameNormalizeTransform()\n            if bool(world_frame_normalization)\n            else None\n        )\n        # Honor the constructor argument instead of silently ignoring it.\n        self._allowed_prefixes: Tuple[str, ...] = (\n            tuple(allowed_prefixes)\n            if allowed_prefixes is not None\n            else (\"ref_\", \"ft_ref_\")\n        )\n        self._progress_counter: Optional[mp.Value] = None\n\n        # Normalize manifest path(s) to a list for aggregation.\n        if isinstance(manifest_path, (str, os.PathLike)):\n            manifest_paths: List[str] = [str(manifest_path)]\n        else:\n            manifest_paths = [str(p) for p in manifest_path]\n        if len(manifest_paths) == 0:\n            raise ValueError(\"At least one manifest_path must be provided\")\n\n        # Aggregate shards and clips across one or many manifests into a single\n        # logical dataset. Clip keys must be globally unique.\n        self.hdf5_root = os.path.dirname(manifest_paths[0])\n        self._manifest_paths: List[str] = manifest_paths\n        self._shard_paths: List[str] = []\n        self.shards: List[Dict[str, Any]] = []\n        self.clips: Dict[str, Dict[str, Any]] = {}\n\n        # Use \"manifest_file\" rather than \"mp\" so the loop variable does\n        # not shadow the module-level multiprocessing alias.\n        for manifest_file in manifest_paths:\n            if not os.path.exists(manifest_file):\n                raise FileNotFoundError(\n                    f\"HDF5 manifest not found at {manifest_file}. \"\n                    \"Please set robot.motion.hdf5_root/train_hdf5_roots \"\n                    \"to the correct path.\"\n                )\n            with open(manifest_file, \"r\", encoding=\"utf-8\") as handle:\n                manifest = json.load(handle)\n\n            root = os.path.dirname(manifest_file)\n            shards_local = list(manifest.get(\"hdf5_shards\", []))\n            clips_local = manifest.get(\"clips\", {})\n\n            shard_offset = len(self.shards)\n            for shard_meta in shards_local:\n                self.shards.append(shard_meta)\n                rel = shard_meta.get(\"file\", None)\n                if not isinstance(rel, str) or not rel:\n                    raise ValueError(\n                        f\"Shard entry in manifest {manifest_file} is \"\n                        \"missing a valid 'file' field\"\n                    )\n                self._shard_paths.append(os.path.join(root, rel))\n\n            for key, meta in clips_local.items():\n                if key in self.clips:\n                    raise ValueError(\n                        f\"Duplicate motion clip key '{key}' found in multiple \"\n                        \"manifests; clip keys must be globally unique.\"\n                    )\n                meta_global = dict(meta)\n                meta_global[\"shard\"] = (\n                    int(meta_global.get(\"shard\", 0)) + shard_offset\n                )\n                self.clips[key] = meta_global\n\n        if len(self.shards) == 0:\n            raise ValueError(\n                f\"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}\"\n            )\n\n        self.windows: List[MotionWindow] = self._enumerate_windows()\n        if len(self.windows) == 0:\n            raise ValueError(\n                \"No motion windows satisfy the requested frame length constraints\"\n            )\n\n        # LRU cache of open HDF5 shard handles; size is bounded to avoid\n        # unbounded host-memory usage from per-file raw chunk caches.\n        self._file_handles: \"OrderedDict[int, h5py.File]\" = OrderedDict()\n        max_open_env = os.getenv(\"HOLOMOTION_HDF5_MAX_OPEN_SHARDS\")\n        if max_open_env is None:\n            self._max_open_files = 64\n        else:\n            self._max_open_files = max(1, int(max_open_env))\n\n    def set_progress_counter(self, counter: Optional[mp.Value]) -> None:\n        self._progress_counter = counter\n\n    def _enumerate_windows(self) -> List[MotionWindow]:\n        windows: List[MotionWindow] = []\n        for motion_key, meta in self.clips.items():\n            if (\n                self.handpicked_motion_names is not None\n                and motion_key not in self.handpicked_motion_names\n            ):\n                continue\n            if (\n                self.excluded_motion_names is not None\n                and motion_key in self.excluded_motion_names\n            ):\n                continue\n\n            shard_index = int(meta.get(\"shard\", 0))\n            start = int(meta.get(\"start\", 0))\n            length = int(meta.get(\"length\", 0))\n\n            if length <= 0:\n                continue\n\n            remaining = length\n            offset = 0\n            window_index = 0\n            while remaining > 0:\n                window_length = min(self.max_frame_length, remaining)\n                if window_length >= self.min_window_length:\n                    win_start = start + offset\n                    unique_key = (\n                        f\"{motion_key}__start_{win_start}_len_{window_length}\"\n   
                 )\n                    windows.append(\n                        MotionWindow(\n                            motion_key=unique_key,\n                            shard_index=shard_index,\n                            start=win_start,\n                            length=window_length,\n                            raw_motion_key=motion_key,\n                            window_index=window_index,\n                        )\n                    )\n                    window_index += 1\n                offset += window_length\n                remaining = max(0, length - offset)\n\n        return windows\n\n    def __len__(self) -> int:\n        return len(self.windows)\n\n    def __getitem__(self, index: int) -> MotionClipSample:\n        window = self.windows[index]\n        shard_handle = self._get_shard_handle(window.shard_index)\n        start, end = window.start, window.start + window.length\n\n        arrays: Dict[str, Tensor] = {}\n\n        # Mandatory reference source: ref_*\n        for logical_name, dataset_name in MANDATORY_DATASETS.items():\n            dname = f\"ref_{dataset_name}\"\n            if dname not in shard_handle:\n                raise KeyError(\n                    f\"Missing mandatory dataset '{dname}' in shard index {window.shard_index}\"\n                )\n            np_array = shard_handle[dname][start:end]\n            arrays[f\"ref_{logical_name}\"] = torch.from_numpy(np_array).to(\n                torch.float32\n            )\n\n        # Optional filtered reference source: ft_ref_*\n        for logical_name, dataset_name in MANDATORY_DATASETS.items():\n            dname = f\"ft_ref_{dataset_name}\"\n            if dname in shard_handle:\n                np_array = shard_handle[dname][start:end]\n                arrays[f\"ft_ref_{logical_name}\"] = torch.from_numpy(\n                    np_array\n                ).to(torch.float32)\n\n        if \"frame_flag\" in shard_handle:\n            frame_flag_np = shard_handle[\"frame_flag\"][start:end]\n            frame_flag = torch.from_numpy(frame_flag_np).to(torch.long)\n        else:\n            frame_flag = torch.ones(window.length, dtype=torch.long)\n            if window.length > 1:\n                frame_flag[0] = 0\n                frame_flag[-1] = 2\n            elif window.length == 1:\n                # Single-frame window: mark as both start and end (use 2 for end)\n                frame_flag[0] = 2\n        arrays[\"frame_flag\"] = frame_flag\n\n        if self._world_frame_transform is not None:\n            self._world_frame_transform(arrays)\n\n        # Derived root_* for ref_* (after normalization)\n        arrays[\"ref_root_pos\"] = arrays[\"ref_rg_pos\"][:, 0, :]\n        arrays[\"ref_root_rot\"] = arrays[\"ref_rb_rot\"][:, 0, :]\n        arrays[\"ref_root_vel\"] = arrays[\"ref_body_vel\"][:, 0, :]\n        arrays[\"ref_root_ang_vel\"] = arrays[\"ref_body_ang_vel\"][:, 0, :]\n\n        # Derived root_* for optional ft_ref_* (after normalization)\n        if (\n            \"ft_ref_rg_pos\" in arrays\n            and \"ft_ref_rb_rot\" in arrays\n            and \"ft_ref_body_vel\" in arrays\n            and \"ft_ref_body_ang_vel\" in arrays\n        ):\n            arrays[\"ft_ref_root_pos\"] = arrays[\"ft_ref_rg_pos\"][:, 0, :]\n            arrays[\"ft_ref_root_rot\"] = arrays[\"ft_ref_rb_rot\"][:, 0, :]\n            arrays[\"ft_ref_root_vel\"] = arrays[\"ft_ref_body_vel\"][:, 0, :]\n            arrays[\"ft_ref_root_ang_vel\"] = arrays[\"ft_ref_body_ang_vel\"][\n                :, 
0, :\n            ]\n\n        if self._progress_counter is not None:\n            with self._progress_counter.get_lock():\n                self._progress_counter.value += 1\n\n        return MotionClipSample(\n            motion_key=window.motion_key,\n            raw_motion_key=window.raw_motion_key,\n            window_index=int(index),\n            tensors=arrays,\n            length=window.length,\n        )\n\n    def _get_shard_handle(self, shard_index: int) -> h5py.File:\n        if shard_index in self._file_handles:\n            handle = self._file_handles.pop(shard_index)\n            if handle.id:\n                # Mark as most recently used.\n                self._file_handles[shard_index] = handle\n                return handle\n\n        if shard_index < 0 or shard_index >= len(self._shard_paths):\n            raise IndexError(\n                f\"Shard index {shard_index} out of range for \"\n                f\"{len(self._shard_paths)} available shards\"\n            )\n        shard_path = self._shard_paths[shard_index]\n        # Open with SWMR and a configurable raw chunk cache to speed up repeated reads.\n        # The default cache size (in bytes) can be overridden via the\n        # HOLOMOTION_HDF5_RDCC_NBYTES environment variable.\n        rdcc_nbytes_env = os.getenv(\"HOLOMOTION_HDF5_RDCC_NBYTES\")\n        if rdcc_nbytes_env is None:\n            rdcc_nbytes = 256 * 1024 * 1024  # 256MB default\n        else:\n            rdcc_nbytes = int(rdcc_nbytes_env)\n        handle = h5py.File(\n            shard_path,\n            \"r\",\n            libver=\"latest\",\n            swmr=True,\n            rdcc_nbytes=rdcc_nbytes,\n            rdcc_w0=0.75,\n        )\n        # Enforce LRU limit on the number of simultaneously open shard files.\n        if (\n            self._max_open_files is not None\n            and len(self._file_handles) >= self._max_open_files\n        ):\n            old_index, old_handle = self._file_handles.popitem(last=False)\n            old_handle.close()\n        self._file_handles[shard_index] = handle\n        return handle\n\n    def close(self) -> None:\n        \"\"\"Close all open HDF5 shard handles for this dataset.\"\"\"\n        for handle in self._file_handles.values():\n            if handle.id:\n                handle.close()\n        self._file_handles.clear()\n\n\nclass MotionClipBatchCache:\n    \"\"\"Double-buffered motion cache for RL training and evaluation.\"\"\"\n\n    @staticmethod\n    def _infer_cuda_device_index() -> int:\n        device_count = int(torch.cuda.device_count())\n        local_rank_env = os.environ.get(\"LOCAL_RANK\")\n        if local_rank_env is not None:\n            local_rank = int(local_rank_env)\n            if 0 <= local_rank < device_count:\n                return local_rank\n        return int(torch.cuda.current_device())\n\n    @classmethod\n    def _normalize_stage_device(\n        cls, stage_device: Optional[object]\n    ) -> Optional[torch.device]:\n        if stage_device is None:\n            return None\n\n        if isinstance(stage_device, torch.device):\n            if stage_device.type == \"cpu\":\n                return None\n            if stage_device.type != \"cuda\":\n                raise ValueError(\n                    f\"Unsupported stage_device type: {stage_device.type}\"\n                )\n            if not torch.cuda.is_available():\n                raise RuntimeError(\n                    \"stage_device requested CUDA but CUDA is not available\"\n                )\n      
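      # Keep an explicit device index as-is; otherwise resolve to the\n            # LOCAL_RANK GPU so each DDP worker stages onto its own device.\n      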
      if stage_device.index is not None:\n                return stage_device\n            return torch.device(\"cuda\", cls._infer_cuda_device_index())\n\n        if isinstance(stage_device, str):\n            stage_device_str = stage_device.strip().lower()\n            if stage_device_str in (\"none\", \"cpu\"):\n                return None\n            if stage_device_str == \"cuda\":\n                if not torch.cuda.is_available():\n                    raise RuntimeError(\n                        \"stage_device requested CUDA but CUDA is not available\"\n                    )\n                return torch.device(\"cuda\", cls._infer_cuda_device_index())\n            if stage_device_str.startswith(\"cuda:\"):\n                if not torch.cuda.is_available():\n                    raise RuntimeError(\n                        \"stage_device requested CUDA but CUDA is not available\"\n                    )\n                return torch.device(stage_device_str)\n            raise ValueError(\n                f\"Unsupported stage_device string: {stage_device}\"\n            )\n\n        raise TypeError(\n            f\"Unsupported stage_device value type: {type(stage_device)}\"\n        )\n\n    def __init__(\n        self,\n        train_dataset: Dataset[MotionClipSample],\n        *,\n        val_dataset: Optional[Dataset[MotionClipSample]] = None,\n        batch_size: int,\n        stage_device: Optional[torch.device] = None,\n        num_workers: int = 4,\n        prefetch_factor: int = 2,\n        pin_memory: bool = True,\n        persistent_workers: bool = True,\n        sampler_rank: int = 0,\n        sampler_world_size: int = 1,\n        allowed_prefixes: Optional[Sequence[str]] = None,\n        swap_interval_steps: Optional[int] = None,\n        force_timeout_on_swap: bool = True,\n        stage_on_swap_only: bool = False,\n        batch_progress_bar: bool = False,\n        seed: Optional[int] = None,\n        loader_timeout: float = 0.0,\n    ) -> None:\n        if batch_size <= 0:\n            raise ValueError(\"batch_size must be positive\")\n        if float(loader_timeout) < 0.0:\n            raise ValueError(\"loader_timeout must be >= 0\")\n\n        self._datasets = {\n            \"train\": train_dataset,\n            \"val\": val_dataset if val_dataset is not None else train_dataset,\n        }\n        self._mode = \"train\"\n        self._seed = (\n            int(seed) if seed is not None else int(time.time_ns() & 0x7FFFFFFF)\n        )\n        self._stage_device = self._normalize_stage_device(stage_device)\n        self._sampler_rank = int(sampler_rank)\n        self._sampler_world_size = int(max(1, sampler_world_size))\n        self._batch_size = int(batch_size)\n        self._allowed_prefixes: Optional[Tuple[str, ...]] = (\n            tuple(allowed_prefixes) if allowed_prefixes is not None else None\n        )\n\n        # If enabled, keep the prefetched batch on CPU (FK on CPU) and stage to GPU\n        # only during cache swapping (advance).\n        self._stage_on_swap_only = bool(stage_on_swap_only)\n        self._batch_progress_bar = bool(batch_progress_bar)\n        self._loader_timeout = float(loader_timeout)\n        self.force_timeout_on_swap = bool(force_timeout_on_swap)\n        self._batch_progress_counter: Optional[mp.Value] = None\n        if self._should_use_batch_progress():\n            ctx = mp.get_context(\"spawn\")\n            self._batch_progress_counter = ctx.Value(\"i\", 0)\n\n        self.swap_interval_steps = (\n            
swap_interval_steps\n            if swap_interval_steps is not None\n            else train_dataset.max_frame_length\n        )\n\n        self._num_workers = int(max(0, num_workers))\n        self._prefetch_factor = (\n            prefetch_factor if prefetch_factor is not None else None\n        )\n        self._pin_memory = bool(pin_memory)\n        self._persistent_workers = bool(persistent_workers and num_workers > 0)\n\n        self._dataloader: Optional[DataLoader] = None\n        self._sampler: Optional[Sampler[int]] = None\n        self._iterator: Optional[Iterator[ClipBatch]] = None\n\n        self._current_batch: Optional[ClipBatch] = None\n        self._next_batch: Optional[ClipBatch] = None\n        self._swap_index = 0\n\n        self._effective_batch_size: Optional[int] = None\n        self._num_batches: Optional[int] = None\n\n        # Weighted-bin sampling state\n        self._weighted_bin_enabled: bool = False\n        self._weighted_bin_bins: Optional[List[List[int]]] = None\n        self._weighted_bin_ratios: Optional[List[float]] = None\n        self._weighted_bin_specs: Optional[List[Dict[str, Any]]] = None\n        self._cache_curriculum_enabled: bool = False\n        self._cache_curriculum_cfg: Dict[str, Any] = {}\n        self._cache_curriculum_sampler: Optional[\n            PrioritizedInfiniteSampler\n        ] = None\n        self._cache_curriculum_dump_enabled: bool = False\n        self._cache_curriculum_dump_every_swaps: int = 10\n        self._cache_curriculum_dump_chunk_size: int = 4096\n        self._cache_curriculum_dump_dir: Path = Path(\n            \"cache_curriculum_window_scores\"\n        )\n        self._cache_curriculum_last_dump_swap: int = -1\n\n        # Async GPU staging helpers\n        self._copy_stream = None\n        self._pending_ready_event = None\n        self._current_ready_event = None\n        self._next_ready_event = None\n\n        self._build_dataloader()\n        if (\n            self._stage_device is not None\n            and self._stage_device.type == \"cuda\"\n        ):\n            self._copy_stream = torch.cuda.Stream(device=self._stage_device)\n        self._prime_buffers()\n\n    @property\n    def current_batch(self) -> ClipBatch:\n        assert self._current_batch is not None\n        return self._current_batch\n\n    @property\n    def max_frame_length(self) -> int:\n        return self.current_batch.max_frame_length\n\n    @property\n    def clip_count(self) -> int:\n        return self.current_batch.lengths.shape[0]\n\n    @property\n    def mode(self) -> str:\n        return self._mode\n\n    @property\n    def swap_index(self) -> int:\n        return self._swap_index\n\n    @property\n    def num_batches(self) -> int:\n        if self._num_batches is None:\n            raise RuntimeError(\"DataLoader is not initialised\")\n        return int(self._num_batches)\n\n    def set_mode(self, mode: str) -> None:\n        if mode == self._mode:\n            return\n        if mode not in self._datasets:\n            raise ValueError(f\"Unknown cache mode: {mode}\")\n        self._mode = mode\n        self._build_dataloader()\n        self._prime_buffers()\n\n    def set_seed(self, seed: int, *, reinitialize: bool = True) -> None:\n        self._seed = int(seed)\n        if reinitialize:\n            self._build_dataloader()\n            self._prime_buffers()\n\n    def advance(self) -> None:\n        if self._stage_on_swap_only:\n            if self._next_batch is None:\n                self._next_batch = 
self._fetch_next_batch()\n            # Stage the prefetched CPU batch to GPU only at swap time.\n            staged = self._stage_batch_blocking(self._next_batch)\n            self._current_batch = staged\n            self._next_batch = self._fetch_next_batch()\n            self._swap_index += 1\n            return\n\n        if self._next_batch is None:\n            self._next_batch = self._fetch_next_batch()\n        # Ensure asynchronous staging finished before swapping in next batch\n        if (\n            self._next_ready_event is not None\n            and self._stage_device is not None\n            and self._stage_device.type == \"cuda\"\n        ):\n            torch.cuda.current_stream(self._stage_device).wait_event(\n                self._next_ready_event\n            )\n        self._current_batch = self._next_batch\n        self._next_batch = self._fetch_next_batch()\n        self._swap_index += 1\n\n    # -------------------------\n    # Weighted-bin configuration\n    # -------------------------\n    def enable_weighted_bin_sampling(\n        self, cfg: Optional[Dict[str, Any]] = None\n    ) -> None:\n        \"\"\"Enable regex-based weighted-bin sampling over manifest motion keys.\n\n        The configuration must provide a list under ``bin_regex_patterns`` (or the\n        legacy name ``bin_regrex_patterns``), where each element is a mapping with:\n\n        - ``regex`` (or ``regrex``): Python regular expression applied to the\n          manifest clip key (e.g., ``AMASS_.*``, ``VR_pico_.*``).\n        - ``ratio``: Target sampling ratio in [0, 1].\n\n        The sum of explicit bin ratios must be <= 1.0. Any remaining mass is\n        assigned to an implicit ``others`` bin that collects all clips not\n        matched by any regex.\n        \"\"\"\n        cfg_local: Dict[str, Any] = dict(cfg or {})\n        if self._cache_curriculum_enabled:\n            raise ValueError(\n                \"weighted-bin and cache curriculum sampling cannot be enabled together.\"\n            )\n\n        dataset = self._datasets.get(\"train\")\n        if dataset is None:\n            raise ValueError(\n                \"Weighted-bin sampling requires a training dataset\"\n            )\n\n        # Collect manifest-level motion keys for all windows in order\n        window_keys: List[str] = []\n        for window in dataset.windows:\n            motion_key = getattr(window, \"raw_motion_key\", None)\n            if motion_key is None:\n                full_key = getattr(window, \"motion_key\", \"\")\n                if \"__start_\" in full_key:\n                    motion_key = full_key.split(\"__start_\", 1)[0]\n                else:\n                    motion_key = full_key\n            window_keys.append(motion_key)\n\n        bin_indices, all_ratios, specs = _configure_weighted_bins(\n            keys=window_keys,\n            cfg=cfg_local,\n            batch_size_for_log=int(self._batch_size),\n        )\n\n        # Log summary in terms of windows\n        table_rows = []\n        for item in specs:\n            table_rows.append(\n                [\n                    item[\"name\"],\n                    item[\"regex\"],\n                    f\"{item['ratio']:.4f}\",\n                    int(item[\"count\"]),\n                    f\"{item['dataset_fraction']:.4f}\",\n                    f\"{item['batch_fraction']:.4f}\",\n                ]\n            )\n        headers = [\n            \"bin\",\n            \"regex\",\n            \"final_ratio\",\n            
\"num_windows\",\n            \"dataset_fraction\",\n            \"batch_fraction\",\n        ]\n        logger.info(\n            \"Motion cache weighted-bin sampling configured:\\n\"\n            + tabulate(table_rows, headers=headers, tablefmt=\"simple_outline\")\n        )\n\n        # Activate weighted-bin sampling and rebuild dataloader/cache\n        self._weighted_bin_enabled = True\n        self._weighted_bin_bins = bin_indices\n        self._weighted_bin_ratios = all_ratios\n        self._weighted_bin_specs = specs\n        self._build_dataloader()\n        self._prime_buffers()\n\n    def enable_cache_curriculum_sampling(\n        self, cfg: Optional[Dict[str, Any]] = None\n    ) -> None:\n        if self._weighted_bin_enabled:\n            raise ValueError(\n                \"cache curriculum and weighted-bin sampling cannot be enabled together.\"\n            )\n        self._cache_curriculum_enabled = True\n        self._cache_curriculum_cfg = dict(cfg or {})\n        self._cache_curriculum_dump_enabled = bool(\n            self._cache_curriculum_cfg.get(\n                \"dump_whole_window_scores_json\", True\n            )\n        )\n        self._cache_curriculum_dump_every_swaps = max(\n            1,\n            int(\n                self._cache_curriculum_cfg.get(\n                    \"dump_whole_window_scores_every_swaps\", 10\n                )\n            ),\n        )\n        self._cache_curriculum_dump_chunk_size = max(\n            1,\n            int(\n                self._cache_curriculum_cfg.get(\n                    \"dump_whole_window_scores_chunk_size\", 4096\n                )\n            ),\n        )\n        self._cache_curriculum_dump_dir = Path(\n            str(\n                self._cache_curriculum_cfg.get(\n                    \"dump_whole_window_scores_dir\",\n                    \"cache_curriculum_window_scores\",\n                )\n            )\n        )\n        self._cache_curriculum_last_dump_swap = -1\n        self._prepare_cache_curriculum_dump_dir(\n            self._cache_curriculum_dump_dir,\n            reason=\"enabled\",\n        )\n        self._build_dataloader()\n        self._prime_buffers()\n\n    def _prepare_cache_curriculum_dump_dir(\n        self, dump_dir: Path, *, reason: str\n    ) -> None:\n        self._cache_curriculum_dump_dir = Path(str(dump_dir))\n        if not self._cache_curriculum_dump_enabled:\n            return\n        self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True)\n        logger.info(\n            \"Cache curriculum whole-window score dump \"\n            f\"{reason}: dir={self._cache_curriculum_dump_dir}, \"\n            f\"every_swaps={self._cache_curriculum_dump_every_swaps}, \"\n            f\"rank={self._sampler_rank}\"\n        )\n\n    def set_cache_curriculum_dump_dir(self, dump_dir: str) -> None:\n        self._prepare_cache_curriculum_dump_dir(\n            Path(str(dump_dir)),\n            reason=\"directory set\",\n        )\n\n    def update_cache_curriculum(\n        self,\n        *,\n        window_indices: Tensor,\n        mpkpe_signal_means: Tensor,\n        completion_rate_means: Tensor,\n        counts: Tensor,\n        swap_index: int,\n    ) -> bool:\n        if self._cache_curriculum_sampler is None:\n            return False\n        updated = (\n            self._cache_curriculum_sampler.maybe_update_from_observations(\n                window_indices=window_indices,\n                mpkpe_signal_means=mpkpe_signal_means,\n                
completion_rate_means=completion_rate_means,\n                counts=counts,\n                swap_index=swap_index,\n            )\n        )\n        if updated:\n            self._refresh_prefetched_batch()\n        self._maybe_dump_cache_curriculum_scores_json(swap_index=swap_index)\n        return updated\n\n    def _refresh_prefetched_batch(self) -> None:\n        if self._next_batch is None:\n            return\n        self._next_batch = self._fetch_next_batch()\n\n    def _maybe_dump_cache_curriculum_scores_json(\n        self, *, swap_index: int\n    ) -> None:\n        if not self._cache_curriculum_dump_enabled:\n            return\n        if self._cache_curriculum_sampler is None:\n            return\n\n        swap_idx = int(swap_index)\n        if swap_idx <= 0:\n            return\n        if swap_idx % self._cache_curriculum_dump_every_swaps != 0:\n            return\n        if self._cache_curriculum_last_dump_swap == swap_idx:\n            return\n\n        dataset = self._datasets[\"train\"]\n        ds_len = int(len(dataset))\n        if ds_len <= 0:\n            return\n\n        self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True)\n        output_path = self._cache_curriculum_dump_dir / (\n            \"whole_window_scores_\"\n            f\"rank_{self._sampler_rank:04d}_swap_{swap_idx:06d}.json\"\n        )\n        sampler_version = int(self._cache_curriculum_sampler.state_version)\n        windows = dataset.windows\n        score_values: List[float] = []\n        completion_values: List[float] = []\n        rel_improve_values: List[float] = []\n        selection_count_values: List[int] = []\n        seen_values: List[bool] = []\n        in_pool_values: List[bool] = []\n        chunk_size = max(1, int(self._cache_curriculum_dump_chunk_size))\n        for chunk_start in range(0, ds_len, chunk_size):\n            chunk_end = min(ds_len, chunk_start + chunk_size)\n            chunk_indices = torch.arange(\n                chunk_start, chunk_end, dtype=torch.long\n            )\n            chunk_scores = (\n                self._cache_curriculum_sampler.get_scores_for_indices(\n                    chunk_indices\n                )\n            )\n            chunk_state = (\n                self._cache_curriculum_sampler.get_window_state_for_indices(\n                    chunk_indices\n                )\n            )\n            if chunk_scores.numel() != chunk_indices.numel():\n                raise ValueError(\n                    \"Whole-window score dump shape mismatch for \"\n                    \"cache curriculum sampler.\"\n                )\n            score_values.extend(chunk_scores.tolist())\n            completion_values.extend(\n                chunk_state[\"ema_completion_rate\"].tolist()\n            )\n            rel_improve_values.extend(\n                chunk_state[\"completion_rate_rel_improve\"].tolist()\n            )\n            selection_count_values.extend(\n                chunk_state[\"selection_count\"].tolist()\n            )\n            seen_values.extend(chunk_state[\"seen\"].tolist())\n            in_pool_values.extend(chunk_state[\"in_prioritized_pool\"].tolist())\n        rows: List[Dict[str, Any]] = []\n        for window_index in range(ds_len):\n            window = windows[window_index]\n            rows.append(\n                {\n                    \"swap_index\": int(swap_idx),\n                    \"rank\": int(self._sampler_rank),\n                    \"sampler_state_version\": sampler_version,\n            
        \"window_index\": int(window_index),\n                    \"raw_motion_key\": str(window.raw_motion_key),\n                    \"motion_key\": str(window.motion_key),\n                    \"start\": int(window.start),\n                    \"length\": int(window.length),\n                    \"score\": float(score_values[window_index]),\n                    \"selection_count\": int(\n                        selection_count_values[window_index]\n                    ),\n                    \"ema_completion_rate\": float(\n                        completion_values[window_index]\n                    ),\n                    \"completion_rate_rel_improve\": float(\n                        rel_improve_values[window_index]\n                    ),\n                    \"seen\": bool(seen_values[window_index]),\n                    \"in_prioritized_pool\": bool(in_pool_values[window_index]),\n                }\n            )\n        payload: Dict[str, Any] = {\n            \"swap_index\": int(swap_idx),\n            \"rank\": int(self._sampler_rank),\n            \"sampler_state_version\": sampler_version,\n            \"num_windows\": int(ds_len),\n            \"pool_metrics\": self._cache_curriculum_sampler.get_pool_statistics()\n            or {},\n            \"rows\": rows,\n        }\n        with output_path.open(\"w\", encoding=\"utf-8\") as handle:\n            json.dump(payload, handle, indent=2)\n            handle.write(\"\\n\")\n        self._cache_curriculum_last_dump_swap = swap_idx\n\n    def cache_curriculum_scores_for_window_indices(\n        self, window_indices: Tensor\n    ) -> Optional[Tuple[Tensor, Dict[str, Tensor], int]]:\n        if self._cache_curriculum_sampler is None:\n            return None\n        scores = self._cache_curriculum_sampler.get_scores_for_indices(\n            window_indices\n        )\n        state = self._cache_curriculum_sampler.get_window_state_for_indices(\n            window_indices\n        )\n        version = self._cache_curriculum_sampler.state_version\n        return scores, state, version\n\n    def cache_curriculum_pool_statistics(\n        self,\n    ) -> Optional[Dict[str, float]]:\n        if self._cache_curriculum_sampler is None:\n            return None\n        return self._cache_curriculum_sampler.get_pool_statistics()\n\n    def sample_env_assignments(\n        self,\n        num_envs: int,\n        n_future_frames: int,\n        device: torch.device,\n        *,\n        deterministic_start: bool = False,\n    ) -> Tuple[Tensor, Tensor]:\n        batch = self.current_batch\n        lengths = batch.lengths.to(device)\n\n        if num_envs <= 0:\n            raise ValueError(\"num_envs must be positive\")\n\n        total = int(lengths.shape[0])\n        if total == 0:\n            raise ValueError(\n                \"Cannot sample from an empty batch. 
Ensure the cache contains \"\n                \"at least one motion clip before calling sample_env_assignments.\"\n            )\n        clip_indices = torch.randint(\n            low=0, high=total, size=(num_envs,), device=device\n        )\n\n        max_start = torch.clamp(\n            lengths[clip_indices] - 1 - n_future_frames, min=0\n        )\n        if deterministic_start:\n            frame_starts = torch.zeros_like(max_start)\n        else:\n            rand = torch.rand_like(max_start, dtype=torch.float32)\n            frame_starts = torch.floor(rand * (max_start + 1).float()).to(\n                torch.long\n            )\n\n        return clip_indices, frame_starts\n\n    def _prepare_gather_indices(\n        self,\n        *,\n        clip_indices: Tensor,\n        frame_indices: Tensor,\n        n_future_frames: int,\n    ) -> Tuple[Tensor, Tensor]:\n        batch = self.current_batch\n        staged_device = batch.lengths.device\n        selected_clips = clip_indices.to(\n            staged_device, dtype=torch.long\n        ).clone()\n        frame_indices = frame_indices.to(\n            staged_device, dtype=torch.long\n        ).clone()\n\n        temporal_span = 1 + int(n_future_frames)\n        time_offsets = torch.arange(\n            temporal_span, device=staged_device, dtype=torch.long\n        )\n        gather_timesteps = frame_indices[:, None] + time_offsets[None, :]\n\n        lengths = batch.lengths\n        max_valid = torch.clamp(\n            lengths.index_select(0, selected_clips) - 1, min=0\n        )\n        gather_timesteps = torch.minimum(\n            gather_timesteps, max_valid[:, None]\n        ).clone()\n\n        return selected_clips, gather_timesteps\n\n    def gather_tensor(\n        self,\n        tensor_name: str,\n        *,\n        clip_indices: Tensor,\n        frame_indices: Tensor,\n        n_future_frames: int,\n    ) -> Tensor:\n        batch = self.current_batch\n        if tensor_name not in batch.tensors:\n            raise KeyError(\n                f\"Tensor '{tensor_name}' is not present in current_batch\"\n            )\n        selected_clips, gather_timesteps = self._prepare_gather_indices(\n            clip_indices=clip_indices,\n            frame_indices=frame_indices,\n            n_future_frames=n_future_frames,\n        )\n        tensor = batch.tensors[tensor_name]\n        return tensor[selected_clips[:, None], gather_timesteps, ...]\n\n    def lengths_for_indices(self, clip_indices: Tensor) -> Tensor:\n        lengths = self.current_batch.lengths.to(clip_indices.device)\n        return lengths.index_select(0, clip_indices.long())\n\n    def motion_keys_for_indices(self, clip_indices: Tensor) -> List[str]:\n        result = []\n        base_keys = self.current_batch.motion_keys\n        for idx in clip_indices.tolist():\n            result.append(base_keys[int(idx)])\n        return result\n\n    def window_indices_for_indices(self, clip_indices: Tensor) -> Tensor:\n        base_indices = self.current_batch.window_indices.to(\n            clip_indices.device\n        )\n        return base_indices.index_select(0, clip_indices.long())\n\n    def _prime_buffers(self) -> None:\n        if self._stage_on_swap_only:\n            # Prefetch on CPU; stage to GPU only for current batch.\n            cpu_current = self._fetch_next_batch()\n            self._current_batch = self._stage_batch_blocking(cpu_current)\n            self._next_batch = self._fetch_next_batch()\n            self._pending_ready_event = None\n            
self._current_ready_event = None\n            self._next_ready_event = None\n            return\n\n        self._current_batch = self._fetch_next_batch()\n        # Ensure first staged batch is ready before consumption\n        if (\n            self._current_ready_event is not None\n            and self._stage_device is not None\n            and self._stage_device.type == \"cuda\"\n        ):\n            t0 = time.time()\n            torch.cuda.current_stream(self._stage_device).wait_event(\n                self._current_ready_event\n            )\n            t1 = time.time()\n            logger.info(\n                f\"Perf/Cache/cuda_wait_event_ms={((t1 - t0) * 1e3):.2f} (first)\"\n            )\n        self._next_batch = self._fetch_next_batch()\n\n    def _fetch_next_batch(self) -> ClipBatch:\n        batch = self._load_next_batch()\n        if self._stage_on_swap_only:\n            # Prefetch raw batch on CPU.\n            return batch\n\n        staged = self._stage_batch(batch, record_event=True)\n        # Move pending event into current/next slot\n        if self._current_batch is None:\n            self._current_ready_event = self._pending_ready_event\n        else:\n            self._next_ready_event = self._pending_ready_event\n        self._pending_ready_event = None\n        return staged\n\n    def _load_next_batch(self) -> ClipBatch:\n        if self._should_use_batch_progress():\n            return self._load_next_batch_with_progress()\n        return self._load_next_batch_raw()\n\n    def _load_next_batch_raw(self) -> ClipBatch:\n        if self._iterator is None:\n            self._iterator = self._build_iterator()\n\n        try:\n            batch = next(self._iterator)\n        except StopIteration:\n            self._iterator = self._build_iterator(reset_epoch=True)\n            batch = next(self._iterator)\n        return batch\n\n    def _load_next_batch_with_progress(self) -> ClipBatch:\n        if self._iterator is None:\n            self._iterator = self._build_iterator()\n\n        expected = int(self._effective_batch_size or self._batch_size)\n        counter = self._batch_progress_counter\n        if counter is None:\n            return self._load_next_batch_raw()\n\n        with counter.get_lock():\n            counter.value = 0\n\n        pbar = tqdm(\n            total=expected,\n            desc=\"Collecting motion batch\",\n            leave=False,\n            dynamic_ncols=True,\n        )\n        last = 0\n        with ThreadPoolExecutor(max_workers=1) as executor:\n            future = executor.submit(self._load_next_batch_raw)\n            while not future.done():\n                with counter.get_lock():\n                    value = counter.value\n                if value > last:\n                    step = min(value, expected) - last\n                    if step > 0:\n                        pbar.update(step)\n                        last += step\n                time.sleep(0.05)\n            batch = future.result(timeout=self._result_timeout())\n\n        with counter.get_lock():\n            value = counter.value\n        if value > last:\n            step = min(value, expected) - last\n            if step > 0:\n                pbar.update(step)\n        pbar.close()\n        return batch\n\n    def _stage_batch_blocking(self, batch: ClipBatch) -> ClipBatch:\n        \"\"\"Stage a CPU batch to the configured device on the current stream.\n\n        This path is used when `stage_on_swap_only=True` so that only the current\n        cache 
batch resides on GPU.\n        \"\"\"\n        if self._stage_device is None:\n            return batch\n        non_blocking = bool(\n            self._pin_memory and self._stage_device.type == \"cuda\"\n        )\n        tensors = {\n            name: tensor.to(self._stage_device, non_blocking=non_blocking)\n            for name, tensor in batch.tensors.items()\n        }\n        lengths = batch.lengths.to(\n            self._stage_device, non_blocking=non_blocking\n        )\n        window_indices = batch.window_indices.to(\n            self._stage_device, non_blocking=non_blocking\n        )\n        staged = ClipBatch(\n            tensors=tensors,\n            lengths=lengths,\n            motion_keys=batch.motion_keys,\n            raw_motion_keys=getattr(\n                batch, \"raw_motion_keys\", batch.motion_keys\n            ),\n            window_indices=window_indices,\n            max_frame_length=batch.max_frame_length,\n        )\n        return staged\n\n    def _stage_batch(\n        self,\n        batch: ClipBatch,\n        record_event: bool = False,\n    ) -> ClipBatch:\n        if self._stage_device is None:\n            return batch\n\n        # If CUDA, copy on a dedicated stream and record readiness\n        if self._copy_stream is None and (\n            self._stage_device is not None\n            and self._stage_device.type == \"cuda\"\n        ):\n            self._copy_stream = torch.cuda.Stream(device=self._stage_device)\n            logger.info(\n                f\"Perf/Cache: created CUDA copy stream lazily on {self._stage_device}\"\n            )\n\n        if self._copy_stream is not None:\n            # Estimate payload size for the debug log below.\n            try:\n                total_bytes = 0\n                for tensor in batch.tensors.values():\n                    total_bytes += int(tensor.element_size() * tensor.numel())\n                total_bytes += int(\n                    batch.lengths.element_size() * batch.lengths.numel()\n                )\n                total_bytes += int(\n                    batch.window_indices.element_size()\n                    * batch.window_indices.numel()\n                )\n            except Exception:\n                total_bytes = -1\n            if total_bytes >= 0:\n                logger.debug(\n                    \"Perf/Cache: staging \"\n                    f\"~{total_bytes / 1e6:.1f} MB via copy stream\"\n                )\n            with torch.cuda.stream(self._copy_stream):\n                tensors = {\n                    name: tensor.to(self._stage_device, non_blocking=True)\n                    for name, tensor in batch.tensors.items()\n                }\n                lengths = batch.lengths.to(\n                    self._stage_device, non_blocking=True\n                )\n                window_indices = batch.window_indices.to(\n                    self._stage_device, non_blocking=True\n                )\n            if record_event:\n                ev = torch.cuda.Event()\n                ev.record(self._copy_stream)\n                self._pending_ready_event = ev\n\n        else:\n            tensors = {\n                name: tensor.to(self._stage_device, non_blocking=True)\n                for name, tensor in batch.tensors.items()\n            }\n            lengths = batch.lengths.to(self._stage_device, non_blocking=True)\n            window_indices = batch.window_indices.to(\n                self._stage_device, non_blocking=True\n            )\n\n        return ClipBatch(\n            tensors=tensors,\n            lengths=lengths,\n            motion_keys=batch.motion_keys,\n            raw_motion_keys=getattr(\n                batch, \"raw_motion_keys\", batch.motion_keys\n         
   ),\n            window_indices=window_indices,\n            max_frame_length=batch.max_frame_length,\n        )\n\n    def _build_iterator(\n        self, *, reset_epoch: bool = False\n    ) -> Iterator[ClipBatch]:\n        if self._dataloader is None:\n            raise RuntimeError(\"DataLoader is not initialised\")\n\n        if isinstance(self._sampler, DistributedSampler) and reset_epoch:\n            self._sampler.set_epoch(self._swap_index + 1)\n\n        return iter(self._dataloader)\n\n    def _build_dataloader(self) -> None:\n        dataset = self._datasets[self._mode]\n        dataset.set_progress_counter(self._batch_progress_counter)\n\n        # Clamp batch size to dataset length to avoid empty iterator when drop_last is disabled\n        effective_batch_size = self._batch_size\n        ds_len = len(dataset)\n        if isinstance(ds_len, int) and ds_len > 0:\n            effective_batch_size = max(1, min(self._batch_size, ds_len))\n\n        # Sampler selection: validation uses standard distributed/sequential samplers;\n        # training can optionally use weighted-bin sampling.\n        if self._mode == \"val\":\n            if self._sampler_world_size > 1:\n                self._sampler = DistributedSampler(\n                    dataset,\n                    num_replicas=self._sampler_world_size,\n                    rank=self._sampler_rank,\n                    shuffle=False,\n                    drop_last=False,\n                )\n            else:\n                self._sampler = None\n            self._cache_curriculum_sampler = None\n        else:\n            if self._cache_curriculum_enabled:\n                seed = self._seed + self._sampler_rank * 100003\n                cfg = dict(self._cache_curriculum_cfg)\n                self._cache_curriculum_sampler = PrioritizedInfiniteSampler(\n                    dataset_len=ds_len,\n                    batch_size=effective_batch_size,\n                    seed=seed,\n                    p_a_ratio=float(cfg.get(\"p_a_ratio\", 0.2)),\n                    ema_alpha_signal=float(cfg.get(\"ema_alpha_signal\", 0.2)),\n                    ema_alpha_rel_improve=float(\n                        cfg.get(\"ema_alpha_rel_improve\", 0.2)\n                    ),\n                    relative_eps=float(cfg.get(\"relative_eps\", 1.0e-6)),\n                )\n                self._cache_curriculum_last_dump_swap = -1\n                self._sampler = self._cache_curriculum_sampler\n            elif (\n                self._weighted_bin_enabled\n                and self._weighted_bin_bins is not None\n                and self._weighted_bin_ratios is not None\n            ):\n                seed = self._seed + self._sampler_rank * 100003\n                self._sampler = WeightedBinInfiniteSampler(\n                    dataset_len=ds_len,\n                    bin_indices=self._weighted_bin_bins,\n                    ratios=self._weighted_bin_ratios,\n                    batch_size=effective_batch_size,\n                    seed=seed,\n                )\n                self._cache_curriculum_sampler = None\n            else:\n                if self._sampler_world_size > 1:\n                    # Infinite sampler for training: no epoch boundaries\n                    self._sampler = InfiniteDistributedSampler(\n                        dataset,\n                        num_replicas=self._sampler_world_size,\n                        rank=self._sampler_rank,\n                        shuffle=True,\n                        
drop_last=False,\n                    )\n                else:\n                    # Infinite sampler for single-process training\n                    self._sampler = InfiniteRandomSampler(dataset)\n                self._cache_curriculum_sampler = None\n\n        # Only pass prefetch_factor when using workers\n        pf = (\n            self._prefetch_factor\n            if (self._num_workers and self._num_workers > 0)\n            else None\n        )\n        pw = (\n            self._persistent_workers\n            if (self._num_workers and self._num_workers > 0)\n            else False\n        )\n\n        # Collate wrapper: in validation, pad the batch up to cache size by\n        # uniformly repeating samples when dataset is smaller than batch size.\n        collate = partial(\n            _cache_collate_fn,\n            mode=self._mode,\n            batch_size=self._batch_size,\n        )\n\n        mp_ctx = None\n        if self._num_workers and self._num_workers > 0:\n            mp_ctx = mp.get_context(\"spawn\")\n\n        worker_init_fn = None\n        if (\n            self._num_workers > 0\n            and self._stage_device is not None\n            and self._stage_device.type == \"cuda\"\n        ):\n            worker_init_fn = _cpu_only_dataloader_worker_init_fn\n\n        self._dataloader = DataLoader(\n            dataset,\n            batch_size=effective_batch_size,\n            sampler=self._sampler,\n            shuffle=(self._sampler is None and self._mode != \"val\"),\n            num_workers=self._num_workers,\n            prefetch_factor=pf,\n            pin_memory=self._pin_memory,\n            timeout=self._loader_timeout_seconds(),\n            persistent_workers=pw,\n            collate_fn=collate,\n            drop_last=False,\n            multiprocessing_context=mp_ctx,\n            worker_init_fn=worker_init_fn,\n        )\n        self._iterator = None\n        self._current_batch = None\n        self._next_batch = None\n        self._swap_index = 0\n\n        # Compute number of batches only for validation; training is infinite\n        local_len = ds_len\n        if self._mode == \"val\":\n            if self._sampler is not None:\n                local_len = (\n                    ds_len + self._sampler_world_size - 1\n                ) // self._sampler_world_size\n            self._effective_batch_size = int(effective_batch_size)\n            self._num_batches = (\n                local_len + self._effective_batch_size - 1\n            ) // self._effective_batch_size\n        else:\n            self._effective_batch_size = int(effective_batch_size)\n            self._num_batches = 2**31  # effectively infinite for logging\n\n    def close(self) -> None:\n        \"\"\"Release DataLoader workers and close underlying HDF5 datasets.\"\"\"\n        datasets = self.__dict__.get(\"_datasets\")\n        if datasets is None:\n            return\n        self._iterator = None\n        self._current_batch = None\n        self._next_batch = None\n        self._dataloader = None\n        self._copy_stream = None\n        self._pending_ready_event = None\n        self._current_ready_event = None\n        self._next_ready_event = None\n\n        for ds in datasets.values():\n            if ds is not None:\n                ds.close()\n\n    def __del__(self) -> None:\n        self.close()\n\n    def _loader_timeout_seconds(self) -> float:\n        if not self.force_timeout_on_swap:\n            return 0.0\n        return self._loader_timeout\n\n    def 
_result_timeout(self) -> Optional[float]:\n        timeout_s = self._loader_timeout_seconds()\n        if timeout_s <= 0.0:\n            return None\n        return timeout_s + 1.0\n\n    def _should_use_batch_progress(self) -> bool:\n        if not self._batch_progress_bar:\n            return False\n        if self._sampler_world_size > 1:\n            return False\n        if self._loader_timeout_seconds() > 0.0:\n            return False\n        return True\n"
  },
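The cache above keeps two batch buffers: `current_batch` serves consumers while `advance()` promotes the prefetched batch and immediately starts fetching its replacement, so a swap never waits on fresh I/O. Below is a minimal, self-contained sketch of that double-buffering pattern; `ToyCache` and the integer "batches" are illustrative only, not part of this repo:

```python
# Double-buffering sketch: current/next buffers are primed up front, then
# advance() promotes next -> current and refills the prefetch slot.
from itertools import count
from typing import Iterator


class ToyCache:
    def __init__(self, batch_iter: Iterator[int]) -> None:
        self._iter = batch_iter
        # Prime both buffers so current_batch is valid right after __init__,
        # mirroring _prime_buffers() above.
        self._current = next(self._iter)
        self._next = next(self._iter)
        self._swap_index = 0

    @property
    def current_batch(self) -> int:
        return self._current

    def advance(self) -> None:
        # Promote the prefetched batch, then prefetch another one.
        self._current = self._next
        self._next = next(self._iter)
        self._swap_index += 1


cache = ToyCache(iter(count()))
assert cache.current_batch == 0
cache.advance()
assert cache.current_batch == 1
```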
  {
    "path": "holomotion/src/training/reference_filter_export.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nimport json\nimport tempfile\nfrom pathlib import Path\nfrom typing import Mapping, Sequence\n\nimport matplotlib\nimport numpy as np\nfrom omegaconf import DictConfig, ListConfig\nfrom scipy.spatial.transform import Rotation as Rotation3D\n\nfrom holomotion.src.training.h5_dataloader import (\n    MotionClipSample,\n    build_motion_datasets_from_cfg,\n)\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\n\ndef _to_numpy(array_like) -> np.ndarray:\n    if hasattr(array_like, \"detach\"):\n        array_like = array_like.detach().cpu().numpy()\n    return np.asarray(array_like, dtype=np.float32)\n\n\ndef _require_tensor(\n    tensors: Mapping[str, object], tensor_name: str, error_message: str\n) -> np.ndarray:\n    if tensor_name not in tensors:\n        raise ValueError(error_message)\n    return _to_numpy(tensors[tensor_name])\n\n\ndef _quat_xyzw_to_rpy(quat_xyzw: np.ndarray) -> np.ndarray:\n    quat_xyzw = np.asarray(quat_xyzw, dtype=np.float32)\n    flat = quat_xyzw.reshape(-1, 4)\n    euler = Rotation3D.from_quat(flat).as_euler(\"xyz\", degrees=False)\n    return euler.reshape(*quat_xyzw.shape[:-1], 3).astype(\n        np.float32, copy=False\n    )\n\n\ndef _write_npz(output_path: Path, payload: Mapping[str, np.ndarray]) -> None:\n    np.savez(str(output_path), **payload)\n\n\ndef _plot_series_groups(\n    output_path: Path,\n    title: str,\n    groups: Sequence[tuple[str, np.ndarray, np.ndarray]],\n    axis_labels: Sequence[str] = (\"x\", \"y\", \"z\"),\n) -> None:\n    nrows = len(groups)\n    ncols = len(axis_labels)\n    fig, axes = plt.subplots(\n        nrows=nrows,\n        ncols=ncols,\n        figsize=(4.0 * ncols, 2.8 * max(1, nrows)),\n        squeeze=False,\n    )\n    plot_steps = np.arange(groups[0][1].shape[0], dtype=np.int32)\n    for row_idx, (group_name, ref_values, ft_values) in enumerate(groups):\n        for col_idx, axis_name in enumerate(axis_labels):\n            ax = axes[row_idx, col_idx]\n            ax.plot(\n                plot_steps,\n                ref_values[:, col_idx],\n                label=\"raw\",\n                linewidth=1.4,\n            )\n            ax.plot(\n                plot_steps,\n                ft_values[:, col_idx],\n                label=\"filtered\",\n                linewidth=1.2,\n            )\n            ax.set_title(f\"{group_name} {axis_name}\")\n            ax.grid(True, alpha=0.3)\n            if row_idx == 0 and col_idx == 0:\n                ax.legend(loc=\"best\")\n    fig.suptitle(title)\n    fig.tight_layout()\n    fig.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n    plt.close(fig)\n\n\ndef _plot_dof_matrix(\n    output_path: Path,\n    title: str,\n    dof_names: Sequence[str],\n    ref_values: np.ndarray,\n    ft_values: np.ndarray,\n) -> None:\n    num_dofs = len(dof_names)\n    fig, axes = plt.subplots(\n        
nrows=num_dofs,\n        ncols=1,\n        figsize=(14.0, max(2.8 * num_dofs, 3.5)),\n        squeeze=False,\n    )\n    plot_steps = np.arange(ref_values.shape[0], dtype=np.int32)\n    for idx, dof_name in enumerate(dof_names):\n        ax = axes[idx, 0]\n        ax.plot(plot_steps, ref_values[:, idx], label=\"raw\", linewidth=1.4)\n        ax.plot(plot_steps, ft_values[:, idx], label=\"filtered\", linewidth=1.2)\n        ax.set_title(dof_name)\n        ax.grid(True, alpha=0.3)\n        if idx == 0:\n            ax.legend(loc=\"best\")\n    fig.suptitle(title)\n    fig.tight_layout()\n    fig.savefig(output_path, dpi=150, bbox_inches=\"tight\")\n    plt.close(fig)\n\n\ndef export_reference_filter_debug_artifacts(\n    *,\n    sample: MotionClipSample,\n    output_dir: str | Path,\n    body_names: Sequence[str],\n    dof_names: Sequence[str],\n    selected_body_links: Sequence[str],\n) -> Path:\n    tensors = sample.tensors\n    if \"ft_ref_rg_pos\" not in tensors or \"ft_ref_dof_pos\" not in tensors:\n        raise ValueError(\n            \"Filtered reference tensors are unavailable. Ensure online filtering \"\n            \"is enabled and ft_ref_* tensors are materialized.\"\n        )\n\n    output_dir = Path(output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n\n    ref_root_pos = _require_tensor(\n        tensors,\n        \"ref_root_pos\",\n        \"Missing ref_root_pos tensor in sampled clip.\",\n    )\n    ft_root_pos = _require_tensor(\n        tensors,\n        \"ft_ref_root_pos\",\n        \"Missing ft_ref_root_pos tensor in sampled clip.\",\n    )\n    ref_root_rot = _require_tensor(\n        tensors,\n        \"ref_root_rot\",\n        \"Missing ref_root_rot tensor in sampled clip.\",\n    )\n    ft_root_rot = _require_tensor(\n        tensors,\n        \"ft_ref_root_rot\",\n        \"Missing ft_ref_root_rot tensor in sampled clip.\",\n    )\n    ref_root_vel = _require_tensor(\n        tensors,\n        \"ref_root_vel\",\n        \"Missing ref_root_vel tensor in sampled clip.\",\n    )\n    ft_root_vel = _require_tensor(\n        tensors,\n        \"ft_ref_root_vel\",\n        \"Missing ft_ref_root_vel tensor in sampled clip.\",\n    )\n    ref_root_ang_vel = _require_tensor(\n        tensors,\n        \"ref_root_ang_vel\",\n        \"Missing ref_root_ang_vel tensor in sampled clip.\",\n    )\n    ft_root_ang_vel = _require_tensor(\n        tensors,\n        \"ft_ref_root_ang_vel\",\n        \"Missing ft_ref_root_ang_vel tensor in sampled clip.\",\n    )\n    ref_rg_pos = _require_tensor(\n        tensors,\n        \"ref_rg_pos\",\n        \"Missing ref_rg_pos tensor in sampled clip.\",\n    )\n    ft_ref_rg_pos = _require_tensor(\n        tensors,\n        \"ft_ref_rg_pos\",\n        \"Missing ft_ref_rg_pos tensor in sampled clip.\",\n    )\n    ref_body_vel = _require_tensor(\n        tensors,\n        \"ref_body_vel\",\n        \"Missing ref_body_vel tensor in sampled clip.\",\n    )\n    ft_ref_body_vel = _require_tensor(\n        tensors,\n        \"ft_ref_body_vel\",\n        \"Missing ft_ref_body_vel tensor in sampled clip.\",\n    )\n    ref_body_ang_vel = _require_tensor(\n        tensors,\n        \"ref_body_ang_vel\",\n        \"Missing ref_body_ang_vel tensor in sampled clip.\",\n    )\n    ft_ref_body_ang_vel = _require_tensor(\n        tensors,\n        \"ft_ref_body_ang_vel\",\n        \"Missing ft_ref_body_ang_vel tensor in sampled clip.\",\n    )\n    ref_dof_pos = _require_tensor(\n        tensors,\n        \"ref_dof_pos\",\n        \"Missing 
ref_dof_pos tensor in sampled clip.\",\n    )\n    ft_ref_dof_pos = _require_tensor(\n        tensors,\n        \"ft_ref_dof_pos\",\n        \"Missing ft_ref_dof_pos tensor in sampled clip.\",\n    )\n    ref_dof_vel = _require_tensor(\n        tensors,\n        \"ref_dof_vel\",\n        \"Missing ref_dof_vel tensor in sampled clip.\",\n    )\n    ft_ref_dof_vel = _require_tensor(\n        tensors,\n        \"ft_ref_dof_vel\",\n        \"Missing ft_ref_dof_vel tensor in sampled clip.\",\n    )\n\n    body_name_to_idx = {name: idx for idx, name in enumerate(body_names)}\n    missing_links = [\n        link_name\n        for link_name in selected_body_links\n        if link_name not in body_name_to_idx\n    ]\n    if missing_links:\n        raise ValueError(\n            f\"Requested body links are missing from robot.body_names: {missing_links}\"\n        )\n\n    ref_root_rpy = _quat_xyzw_to_rpy(ref_root_rot)\n    ft_root_rpy = _quat_xyzw_to_rpy(ft_root_rot)\n\n    root_payload = {\n        \"ref_global_pos\": ref_root_pos,\n        \"ft_ref_global_pos\": ft_root_pos,\n        \"ref_rpy\": ref_root_rpy,\n        \"ft_ref_rpy\": ft_root_rpy,\n        \"ref_lin_vel\": ref_root_vel,\n        \"ft_ref_lin_vel\": ft_root_vel,\n        \"ref_ang_vel\": ref_root_ang_vel,\n        \"ft_ref_ang_vel\": ft_root_ang_vel,\n    }\n    _write_npz(output_dir / \"root_signals.npz\", root_payload)\n\n    body_payload: dict[str, np.ndarray] = {}\n    for link_name in selected_body_links:\n        body_idx = body_name_to_idx[link_name]\n        body_payload[f\"{link_name}__ref_global_pos\"] = ref_rg_pos[\n            :, body_idx, :\n        ]\n        body_payload[f\"{link_name}__ft_ref_global_pos\"] = ft_ref_rg_pos[\n            :, body_idx, :\n        ]\n        body_payload[f\"{link_name}__ref_lin_vel\"] = ref_body_vel[\n            :, body_idx, :\n        ]\n        body_payload[f\"{link_name}__ft_ref_lin_vel\"] = ft_ref_body_vel[\n            :, body_idx, :\n        ]\n        body_payload[f\"{link_name}__ref_ang_vel\"] = ref_body_ang_vel[\n            :, body_idx, :\n        ]\n        body_payload[f\"{link_name}__ft_ref_ang_vel\"] = ft_ref_body_ang_vel[\n            :, body_idx, :\n        ]\n    _write_npz(output_dir / \"bodylink_signals.npz\", body_payload)\n\n    dof_payload = {\n        \"ref_dof_pos\": ref_dof_pos,\n        \"ft_ref_dof_pos\": ft_ref_dof_pos,\n        \"ref_dof_vel\": ref_dof_vel,\n        \"ft_ref_dof_vel\": ft_ref_dof_vel,\n    }\n    _write_npz(output_dir / \"dof_signals.npz\", dof_payload)\n\n    filter_cutoff_tensor = tensors.get(\"filter_cutoff_hz\")\n    filter_cutoff_hz = None\n    if filter_cutoff_tensor is not None:\n        cutoff_values = _to_numpy(filter_cutoff_tensor).reshape(-1)\n        if cutoff_values.size > 0:\n            filter_cutoff_hz = float(cutoff_values[0])\n\n    metadata = {\n        \"motion_key\": sample.motion_key,\n        \"raw_motion_key\": sample.raw_motion_key,\n        \"window_index\": int(sample.window_index),\n        \"length\": int(sample.length),\n        \"filter_cutoff_hz\": filter_cutoff_hz,\n        \"selected_body_links\": list(selected_body_links),\n        \"body_names\": list(body_names),\n        \"dof_names\": list(dof_names),\n    }\n    (output_dir / \"metadata.json\").write_text(\n        json.dumps(metadata, indent=2, sort_keys=True),\n        encoding=\"utf-8\",\n    )\n\n    _plot_series_groups(\n        output_dir / \"root_comparison.png\",\n        title=\"Root Raw vs Filtered Reference Signals\",\n        groups=[\n     
       (\"global_pos\", ref_root_pos, ft_root_pos),\n            (\"rpy\", ref_root_rpy, ft_root_rpy),\n            (\"lin_vel\", ref_root_vel, ft_root_vel),\n            (\"ang_vel\", ref_root_ang_vel, ft_root_ang_vel),\n        ],\n    )\n\n    for link_name in selected_body_links:\n        _plot_series_groups(\n            output_dir / f\"{link_name}_comparison.png\",\n            title=f\"{link_name} Raw vs Filtered Reference Signals\",\n            groups=[\n                (\n                    \"global_pos\",\n                    body_payload[f\"{link_name}__ref_global_pos\"],\n                    body_payload[f\"{link_name}__ft_ref_global_pos\"],\n                ),\n                (\n                    \"lin_vel\",\n                    body_payload[f\"{link_name}__ref_lin_vel\"],\n                    body_payload[f\"{link_name}__ft_ref_lin_vel\"],\n                ),\n                (\n                    \"ang_vel\",\n                    body_payload[f\"{link_name}__ref_ang_vel\"],\n                    body_payload[f\"{link_name}__ft_ref_ang_vel\"],\n                ),\n            ],\n        )\n\n    _plot_dof_matrix(\n        output_dir / \"dof_pos_comparison.png\",\n        title=\"DOF Position Raw vs Filtered\",\n        dof_names=dof_names,\n        ref_values=ref_dof_pos,\n        ft_values=ft_ref_dof_pos,\n    )\n    _plot_dof_matrix(\n        output_dir / \"dof_vel_comparison.png\",\n        title=\"DOF Velocity Raw vs Filtered\",\n        dof_names=dof_names,\n        ref_values=ref_dof_vel,\n        ft_values=ft_ref_dof_vel,\n    )\n\n    return output_dir\n\n\ndef _to_plain_sequence(values) -> list[str]:\n    if values is None:\n        return []\n    if isinstance(values, (ListConfig, tuple, list)):\n        return [str(v) for v in values]\n    return [str(values)]\n\n\ndef export_reference_filter_artifacts_from_config(config) -> Path:\n    debug_cfg = getattr(config, \"debug_reference_filter_export\", None)\n    if debug_cfg is None or not bool(debug_cfg.get(\"enabled\", False)):\n        raise ValueError(\"debug_reference_filter_export.enabled must be true.\")\n\n    motion_cfg = config.robot.motion\n    online_filter_cfg = motion_cfg.get(\"online_filter\", {})\n    if not bool(online_filter_cfg.get(\"enabled\", False)):\n        raise ValueError(\n            \"Reference filter debug export requires robot.motion.online_filter.enabled=true.\"\n        )\n\n    output_dir = debug_cfg.get(\"output_dir\", None)\n    if output_dir in (None, \"\"):\n        output_dir = tempfile.mkdtemp(prefix=\"motrack-ref-filter-\")\n\n    train_dataset, _, _ = build_motion_datasets_from_cfg(\n        motion_cfg=motion_cfg,\n        max_frame_length=int(motion_cfg.max_frame_length),\n        min_window_length=int(motion_cfg.min_frame_length),\n        world_frame_normalization=bool(\n            motion_cfg.get(\"world_frame_normalization\", True)\n        ),\n    )\n    sample = train_dataset[0]\n\n    return export_reference_filter_debug_artifacts(\n        sample=sample,\n        output_dir=Path(str(output_dir)),\n        body_names=_to_plain_sequence(config.robot.body_names),\n        dof_names=_to_plain_sequence(config.robot.dof_names),\n        selected_body_links=_to_plain_sequence(\n            debug_cfg.get(\"selected_body_links\", [])\n        ),\n    )\n"
  },
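`export_reference_filter_debug_artifacts` converts root quaternions to roll/pitch/yaw through SciPy (via `_quat_xyzw_to_rpy`) before plotting. A quick sanity check of the same scalar-last (xyzw) convention the module relies on; the test values are illustrative:

```python
import numpy as np
from scipy.spatial.transform import Rotation as Rotation3D

# Identity quaternion in xyzw (scalar-last, SciPy's default) -> zero RPY.
identity = np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
assert np.allclose(Rotation3D.from_quat(identity).as_euler("xyz"), 0.0)

# 90-degree yaw about z -> (roll, pitch, yaw) = (0, 0, pi/2).
half = np.sqrt(0.5)
yaw90 = np.array([[0.0, 0.0, half, half]], dtype=np.float32)
rpy = Rotation3D.from_quat(yaw90).as_euler("xyz", degrees=False)
assert np.allclose(rpy, [0.0, 0.0, np.pi / 2], atol=1e-5)
```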
  {
    "path": "holomotion/src/training/train.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport os\nfrom pathlib import Path\nimport sys\n\nimport hydra\nfrom hydra.utils import get_class\nfrom omegaconf import ListConfig, OmegaConf\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration\nfrom loguru import logger\nfrom holomotion.src.training.reference_filter_export import (\n    export_reference_filter_artifacts_from_config,\n)\nfrom holomotion.src.utils.config import compile_config\n\n\ndef _resolve_mujoco_eval_onnx_names(\n    exported_dir: Path, ckpt_onnx_names\n) -> list[str]:\n    if not exported_dir.is_dir():\n        raise FileNotFoundError(\n            f\"Exported ONNX directory not found: {exported_dir}\"\n        )\n    existing = sorted([p.name for p in exported_dir.glob(\"*.onnx\")])\n    if len(existing) == 0:\n        raise FileNotFoundError(f\"No .onnx files found under {exported_dir}\")\n    existing_set = set(existing)\n\n    if ckpt_onnx_names is None:\n        return existing\n    if isinstance(ckpt_onnx_names, ListConfig):\n        requested = list(ckpt_onnx_names)\n    elif isinstance(ckpt_onnx_names, (list, tuple)):\n        requested = list(ckpt_onnx_names)\n    else:\n        raise TypeError(\n            \"mujoco_eval.ckpt_onnx_names must be a list/tuple, \"\n            f\"got {type(ckpt_onnx_names)}\"\n        )\n    requested_norm = []\n    for name in requested:\n        name_str = str(name).strip()\n        if name_str == \"\":\n            continue\n        requested_norm.append(Path(name_str).name)\n    if len(requested_norm) == 0:\n        return existing\n\n    selected = [name for name in requested_norm if name in existing_set]\n    if len(selected) == 0:\n        raise ValueError(\n            \"No requested ONNX checkpoints exist under exported directory. 
\"\n            f\"exported_dir={exported_dir}, requested={requested_norm}, \"\n            f\"existing={existing}\"\n        )\n    return selected\n\n\ndef _exec_mujoco_eval(eval_override_dict: dict) -> None:\n    cli_args = []\n    for key in sorted(eval_override_dict.keys()):\n        value = eval_override_dict[key]\n        if value is None:\n            continue\n        if isinstance(value, bool):\n            cli_args.append(f\"{key}={'true' if value else 'false'}\")\n        elif isinstance(value, (int, float)):\n            cli_args.append(f\"{key}={value}\")\n        elif isinstance(value, str):\n            cli_args.append(f\"{key}={value}\")\n        elif isinstance(value, (list, tuple)):\n            inner = \",\".join([str(v) for v in value])\n            cli_args.append(f\"{key}=[{inner}]\")\n        else:\n            cli_args.append(f\"{key}={value}\")\n\n    argv = [\n        sys.executable,\n        \"-m\",\n        \"holomotion.src.evaluation.eval_mujoco_sim2sim\",\n    ] + cli_args\n    os.execv(sys.executable, argv)\n\n\ndef _maybe_export_reference_filter_artifacts(config: OmegaConf) -> None:\n    debug_cfg = getattr(config, \"debug_reference_filter_export\", None)\n    if debug_cfg is None or not bool(debug_cfg.get(\"enabled\", False)):\n        return\n    if not bool(getattr(config, \"main_process\", True)):\n        return\n    export_dir = export_reference_filter_artifacts_from_config(config)\n    logger.info(f\"Exported reference filter debug artifacts to: {export_dir}\")\n\n\n@hydra.main(\n    config_path=\"../../config\",\n    config_name=\"training/train_base\",\n    version_base=None,\n)\ndef main(config: OmegaConf):\n    \"\"\"Train the motion tracking model.\n\n    Args:\n        config: OmegaConf object containing the configuration.\n\n    \"\"\"\n\n    config = compile_config(config, accelerator=None)\n    dist = None\n\n    # In distributed runs, Hydra resolves ${now:...} per process so experiment_save_dir\n    # can differ by rank (e.g. staggered startup). 
Use Accelerator to init the process\n    # group, then broadcast rank 0's path so all ranks write to the same directory.\n    if getattr(config, \"num_processes\", 1) > 1:\n        project_config = ProjectConfiguration(\n            project_dir=config.experiment_save_dir,\n            logging_dir=config.experiment_save_dir,\n        )\n        _accelerator = Accelerator(project_config=project_config)\n        import torch.distributed as dist\n\n        path_list = (\n            [config.experiment_save_dir]\n            if _accelerator.is_main_process\n            else [None]\n        )\n        dist.broadcast_object_list(path_list, src=0)\n        config.experiment_save_dir = path_list[0]\n\n    _maybe_export_reference_filter_artifacts(config)\n    if dist is not None:\n        dist.barrier()\n\n    log_dir = config.experiment_save_dir\n    headless = config.headless\n    algo_class = get_class(config.algo._target_)\n    algo = algo_class(\n        env_config=config.env,\n        config=config.algo.config,\n        log_dir=log_dir,\n        headless=headless,\n    )\n\n    algo.load(config.checkpoint)\n    algo.learn()\n\n    if not bool(config.mujoco_eval.get(\"enabled\", False)):\n        return\n    if not bool(config.algo.config.get(\"export_policy\", False)):\n        msg = (\n            \"mujoco_eval.enabled=true requires \"\n            \"algo.config.export_policy=true to export ONNX \"\n            \"before post-training evaluation.\"\n        )\n        raise ValueError(msg)\n\n    if not bool(algo.is_main_process):\n        os._exit(0)\n\n    exported_dir = Path(log_dir) / \"exported\"\n    selected_onnx_names = _resolve_mujoco_eval_onnx_names(\n        exported_dir, config.mujoco_eval.get(\"ckpt_onnx_names\", None)\n    )\n    eval_override_dict = OmegaConf.to_container(\n        config.mujoco_eval, resolve=True\n    )\n    eval_override_dict.pop(\"enabled\", None)\n    eval_override_dict[\"ckpt_onnx_root_dir\"] = str(exported_dir)\n    eval_override_dict[\"ckpt_onnx_names\"] = selected_onnx_names\n    _exec_mujoco_eval(eval_override_dict)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
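`_exec_mujoco_eval` hands the process over to the MuJoCo evaluator by rebuilding Hydra-style `key=value` overrides and calling `os.execv`. Here is a sketch of just the argument-formatting step (the override keys are hypothetical); note the `bool` branch must come before any numeric branch because `isinstance(True, int)` is `True` in Python:

```python
# Convert an override dict into Hydra CLI arguments, as _exec_mujoco_eval does.
overrides = {
    "headless": True,                         # bool -> 'true'/'false'
    "num_envs": 4,                            # numbers pass through
    "ckpt_onnx_names": ["a.onnx", "b.onnx"],  # lists -> [v1,v2]
}
cli_args = []
for key in sorted(overrides):
    value = overrides[key]
    if isinstance(value, bool):
        cli_args.append(f"{key}={'true' if value else 'false'}")
    elif isinstance(value, (list, tuple)):
        cli_args.append(f"{key}=[{','.join(str(v) for v in value)}]")
    else:
        cli_args.append(f"{key}={value}")
print(cli_args)
# ['ckpt_onnx_names=[a.onnx,b.onnx]', 'headless=true', 'num_envs=4']
```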
  {
    "path": "holomotion/src/utils/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/utils/config.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\nimport copy\nimport math\nimport os\nfrom pathlib import Path\n\nimport torch\nfrom accelerate import Accelerator\nfrom loguru import logger\nfrom omegaconf import OmegaConf\n\n\ndef setup_hydra_resolvers():\n    \"\"\"Set up custom resolvers for OmegaConf.\n\n    This function registers a set of custom resolvers with OmegaConf to allow\n    for more dynamic and flexible configurations within Hydra. These resolvers\n    enable performing calculations, conditional logic, and other operations\n    directly in the YAML configuration files. For example,\n    you can use `${sqrt:4}` to get `2.0`.\n\n    The registered resolvers include:\n    - `eval`: Evaluates a Python expression.\n    - `if`: Conditional logic (if-else).\n    - `eq`: Case-insensitive string comparison.\n    - `sqrt`: Calculates the square root.\n    - `sum`: Sums a list of numbers.\n    - `ceil`: Computes the ceiling of a number.\n    - `int`: Casts a value to an integer.\n    - `len`: Returns the length of a list or string.\n    - `sum_list`: Sums a list of numbers.\n    \"\"\"\n    try:\n        OmegaConf.register_new_resolver(\"eval\", eval)\n        OmegaConf.register_new_resolver(\n            \"if\", lambda pred, a, b: a if pred else b\n        )\n        OmegaConf.register_new_resolver(\n            \"eq\", lambda x, y: x.lower() == y.lower()\n        )\n        OmegaConf.register_new_resolver(\"sqrt\", lambda x: math.sqrt(float(x)))\n        OmegaConf.register_new_resolver(\"sum\", lambda x: sum(x))\n        OmegaConf.register_new_resolver(\"ceil\", lambda x: math.ceil(x))\n        OmegaConf.register_new_resolver(\"int\", lambda x: int(x))\n        OmegaConf.register_new_resolver(\"len\", lambda x: len(x))\n        OmegaConf.register_new_resolver(\"sum_list\", lambda lst: sum(lst))\n    except Exception as e:\n        logger.warning(f\"Warning: Some resolvers already registered: {e}\")\n\n\ndef compile_config(\n    config: OmegaConf,\n    accelerator: Accelerator = None,\n    eval: bool = False,\n) -> OmegaConf:\n    \"\"\"Compile the configuration.\n\n    Args:\n        config: Unresolved configuration.\n        accelerator: Accelerator instance.\n        eval: If true, skip directory creation and config saving\n            (evaluation-only runs).\n\n    Returns:\n        Compiled configuration.\n\n    \"\"\"\n    setup_hydra_resolvers()\n    config = copy.deepcopy(config)\n    config = compile_config_hf_accelerate(config, accelerator)\n    config = compile_config_directories(config, eval)\n    config = compile_config_devices(config)\n    return config\n\n\ndef compile_config_hf_accelerate(\n    config,\n    accelerator: Accelerator = None,\n) -> OmegaConf:\n    \"\"\"Compile the configuration for HF Accelerate.\n\n    Args:\n        config: Configuration.\n        accelerator: Accelerator instance.\n\n    Returns:\n        Compiled configuration.\n\n    \"\"\"\n    if accelerator is not None:\n        device = accelerator.device\n        is_main_process = 
accelerator.is_main_process\n        process_idx = accelerator.process_index\n        total_processes = accelerator.num_processes\n    else:\n        # Best-effort distributed metadata when running under torchrun / Accelerate launch,\n        # even if an Accelerator instance is not provided yet.\n        process_idx = int(\n            os.environ.get(\n                \"RANK\", os.environ.get(\"ACCELERATE_PROCESS_INDEX\", \"0\")\n            )\n        )\n        total_processes = int(\n            os.environ.get(\n                \"WORLD_SIZE\", os.environ.get(\"ACCELERATE_NUM_PROCESSES\", \"1\")\n            )\n        )\n        local_rank = int(\n            os.environ.get(\n                \"LOCAL_RANK\",\n                os.environ.get(\"ACCELERATE_LOCAL_PROCESS_INDEX\", \"0\"),\n            )\n        )\n        is_main_process = process_idx == 0\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\", local_rank)\n        else:\n            device = torch.device(\"cpu\")\n\n    config.process_id = process_idx\n    config.num_processes = total_processes\n    config.main_process = is_main_process\n\n    if hasattr(config, \"device\"):\n        config.device = str(device)\n\n    logger.info(f\"Using device: {device}\")\n    if is_main_process:\n        logger.info(f\"Process {process_idx} on device: {device}\")\n\n    return config\n\n\ndef compile_config_devices(config):\n    \"\"\"Propagate device and process metadata into the environment configuration.\"\"\"\n    config = copy.deepcopy(config)\n    if hasattr(config, \"device\"):\n        device_str = str(config.device)\n    else:\n        device_str = str(\n            torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        )\n    world_size = getattr(config, \"num_processes\", 1)\n    process_rank = getattr(config, \"process_id\", 0)\n    is_main_process = getattr(config, \"main_process\", True)\n\n    if hasattr(config, \"env\") and hasattr(config.env, \"config\"):\n        env_cfg = config.env.config\n        env_cfg_struct = OmegaConf.is_struct(env_cfg)\n        OmegaConf.set_struct(env_cfg, False)\n        env_cfg.num_processes = world_size\n        env_cfg.process_id = process_rank\n        env_cfg.main_process = is_main_process\n        env_cfg.simulation_device = device_str\n        for key in [\n            \"sim_device\",\n            \"rl_device\",\n            \"compute_device\",\n            \"physx_device\",\n        ]:\n            setattr(env_cfg, key, device_str)\n        if hasattr(env_cfg, \"simulation\"):\n            for sim_key in [\"device\", \"compute_device\", \"rl_device\"]:\n                sim_cfg = env_cfg.simulation\n                sim_struct = OmegaConf.is_struct(sim_cfg)\n                OmegaConf.set_struct(sim_cfg, False)\n                setattr(sim_cfg, sim_key, device_str)\n                OmegaConf.set_struct(sim_cfg, sim_struct)\n        OmegaConf.set_struct(env_cfg, env_cfg_struct)\n\n    return config\n\n\ndef compile_config_directories(config, eval: bool = False) -> OmegaConf:\n    \"\"\"Compile the directory-related configuration.\n\n    Args:\n        config: Configuration.\n        eval: If true, return the configuration unchanged (no\n            directories are created and no config file is saved).\n\n    Returns:\n        Compiled configuration.\n\n    \"\"\"\n    if eval:\n        return config\n    config = copy.deepcopy(config)\n    experiment_save_dir = Path(config.experiment_dir)\n    experiment_save_dir.mkdir(exist_ok=True, parents=True)\n    config.experiment_save_dir = str(experiment_save_dir)\n    if hasattr(config, \"env\"):\n        config.env.config.save_rendering_dir = str(\n            Path(config.experiment_dir) / \"renderings_training\"\n        )\n    unresolved_conf = OmegaConf.to_container(config, resolve=False)\n    if config.main_process:\n        logger.info(f\"Saving config file to {experiment_save_dir}\")\n        with open(experiment_save_dir / \"config.yaml\", \"w\") as file:\n            OmegaConf.save(unresolved_conf, file)\n    return config\n"
  },
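A minimal, self-contained sketch of the launcher fallback used above when no Accelerator instance is available. The helper name distributed_metadata_from_env is illustrative and not part of the module; torchrun populates RANK / WORLD_SIZE / LOCAL_RANK, accelerate launch provides the ACCELERATE_* equivalents, and a bare single-process run falls back to (0, 1, 0).

import os


def distributed_metadata_from_env() -> tuple[int, int, int]:
    # Prefer torchrun-style variables, then Accelerate's, then single-process defaults.
    rank = int(os.environ.get("RANK", os.environ.get("ACCELERATE_PROCESS_INDEX", "0")))
    world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("ACCELERATE_NUM_PROCESSES", "1")))
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("ACCELERATE_LOCAL_PROCESS_INDEX", "0")))
    return rank, world_size, local_rank


# Without any launcher these default to a single-process setup.
print(distributed_metadata_from_env())  # e.g. (0, 1, 0)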
  {
    "path": "holomotion/src/utils/frame_utils.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nimport isaaclab.utils.math as isaaclab_math\nimport torch\n\n\ndef positions_world_to_env_frame(\n    positions_w: torch.Tensor,\n    env_origins: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Convert simulator-world positions to IsaacLab environment frame.\n\n    IsaacLab's MDP root position helpers return positions in the environment\n    frame, i.e. simulation-world coordinates with per-environment\n    `env_origins` subtracted. This helper applies the same\n    translation removal to arbitrary position tensors so position arithmetic\n    stays frame-consistent.\n    \"\"\"\n    if positions_w.ndim < 2 or positions_w.shape[-1] != 3:\n        raise ValueError(\n            \"positions_w must have shape [B, ..., 3], \"\n            f\"got {tuple(positions_w.shape)}.\"\n        )\n    if env_origins.ndim != 2 or env_origins.shape[-1] != 3:\n        raise ValueError(\n            \"env_origins must have shape [B, 3], \"\n            f\"got {tuple(env_origins.shape)}.\"\n        )\n    if positions_w.shape[0] != env_origins.shape[0]:\n        raise ValueError(\n            \"Batch size mismatch between positions_w and env_origins: \"\n            f\"{positions_w.shape[0]} vs {env_origins.shape[0]}.\"\n        )\n    origin_view = env_origins.view(\n        env_origins.shape[0],\n        *([1] * (positions_w.ndim - 2)),\n        3,\n    )\n    return positions_w - origin_view\n\n\ndef root_relative_positions_from_env_frame(\n    body_pos_env: torch.Tensor,\n    root_pos_env: torch.Tensor,\n    root_quat_w: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Convert environment-frame body positions into the root frame.\n\n    The input positions must already be in IsaacLab's environment frame rather\n    than raw simulator-world coordinates. 
Orientation is unaffected by\n    `env_origins`, so the articulation root quaternion is reused directly.\n    \"\"\"\n    if body_pos_env.ndim < 3 or body_pos_env.shape[-1] != 3:\n        raise ValueError(\n            \"body_pos_env must have shape [B, ..., 3], \"\n            f\"got {tuple(body_pos_env.shape)}.\"\n        )\n    if root_pos_env.ndim != 2 or root_pos_env.shape[-1] != 3:\n        raise ValueError(\n            \"root_pos_env must have shape [B, 3], \"\n            f\"got {tuple(root_pos_env.shape)}.\"\n        )\n    if root_quat_w.ndim != 2 or root_quat_w.shape[-1] != 4:\n        raise ValueError(\n            \"root_quat_w must have shape [B, 4], \"\n            f\"got {tuple(root_quat_w.shape)}.\"\n        )\n    if body_pos_env.shape[0] != root_pos_env.shape[0]:\n        raise ValueError(\n            \"Batch size mismatch between body_pos_env and root_pos_env: \"\n            f\"{body_pos_env.shape[0]} vs {root_pos_env.shape[0]}.\"\n        )\n    if body_pos_env.shape[0] != root_quat_w.shape[0]:\n        raise ValueError(\n            \"Batch size mismatch between body_pos_env and root_quat_w: \"\n            f\"{body_pos_env.shape[0]} vs {root_quat_w.shape[0]}.\"\n        )\n    root_pos_view = root_pos_env.view(\n        root_pos_env.shape[0],\n        *([1] * (body_pos_env.ndim - 2)),\n        3,\n    )\n    root_quat_view = root_quat_w.view(\n        root_quat_w.shape[0],\n        *([1] * (body_pos_env.ndim - 2)),\n        4,\n    ).expand(*body_pos_env.shape[:-1], 4)\n    rel_pos_env = body_pos_env - root_pos_view\n    return isaaclab_math.quat_apply_inverse(root_quat_view, rel_pos_env)\n\n\ndef root_relative_positions_from_mixed_position_frames(\n    body_pos_w: torch.Tensor,\n    root_pos_env: torch.Tensor,\n    root_quat_w: torch.Tensor,\n    env_origins: torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Build root-relative positions from world-frame bodies.\n\n    This is the safe adapter for common IsaacLab code paths where body poses\n    are read from `robot.data.body_pos_w` in simulator world coordinates while\n    `isaaclab_mdp.root_pos_w(env)` is already expressed in the environment\n    frame.\n    \"\"\"\n    body_pos_env = positions_world_to_env_frame(body_pos_w, env_origins)\n    return root_relative_positions_from_env_frame(\n        body_pos_env=body_pos_env,\n        root_pos_env=root_pos_env,\n        root_quat_w=root_quat_w,\n    )\n"
  },
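A usage sketch for the frame helpers above (importing the module requires IsaacLab to be installed). With an identity root quaternion in (w, x, y, z), the mixed-frame adapter reduces to plain translation removal, which the final assertion checks; shapes follow the docstrings (B environments, K bodies).

import torch

from holomotion.src.utils.frame_utils import (
    root_relative_positions_from_mixed_position_frames,
)

B, K = 2, 5
body_pos_w = torch.randn(B, K, 3)  # simulator-world body positions
env_origins = torch.randn(B, 3)    # per-environment origins from the scene
root_pos_env = torch.randn(B, 3)   # env-frame root position
root_quat_w = torch.tensor([1.0, 0.0, 0.0, 0.0]).repeat(B, 1)  # identity, (w, x, y, z)

rel = root_relative_positions_from_mixed_position_frames(
    body_pos_w=body_pos_w,
    root_pos_env=root_pos_env,
    root_quat_w=root_quat_w,
    env_origins=env_origins,
)
# Identity rotation: only the two translation removals remain.
expected = body_pos_w - env_origins[:, None, :] - root_pos_env[:, None, :]
assert torch.allclose(rel, expected)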
  {
    "path": "holomotion/src/utils/isaac_utils/__init__.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n"
  },
  {
    "path": "holomotion/src/utils/isaac_utils/maths.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n# This file was originally copied from the [ASAP] repository:\n# https://github.com/LeCAR-Lab/ASAP\n# Modifications have been made to fit the needs of this project.\n\nimport os\nimport random\n\nimport numpy as np\nimport torch\n\n\n@torch.jit.script\ndef normalize(x, eps: float = 1e-9):\n    return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)\n\n\n@torch.jit.script\ndef torch_rand_float(lower, upper, shape, device):\n    # type: (float, float, Tuple[int, int], str) -> Tensor\n    return (upper - lower) * torch.rand(*shape, device=device) + lower\n\n\n@torch.jit.script\ndef copysign(a, b):\n    # type: (float, Tensor) -> Tensor\n    a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])\n    return torch.abs(a) * torch.sign(b)\n\n\ndef set_seed(seed, torch_deterministic=False):\n    \"\"\"Set seed across modules\"\"\"\n    if seed == -1 and torch_deterministic:\n        seed = 42\n    elif seed == -1:\n        seed = np.random.randint(0, 10000)\n    print(f\"Setting seed: {seed}\")\n\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n    if torch_deterministic:\n        # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n        os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n        torch.backends.cudnn.benchmark = False\n        torch.backends.cudnn.deterministic = True\n        torch.use_deterministic_algorithms(True)\n    else:\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cudnn.deterministic = False\n\n    return seed\n"
  },
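A quick check of the seeding helpers above, assuming a CPU-only run (pass "cuda" as the device string when available): reseeding with the same value makes torch_rand_float reproduce its samples, and the samples stay in the half-open [lower, upper) range.

import torch

from holomotion.src.utils.isaac_utils.maths import set_seed, torch_rand_float

set_seed(123)
a = torch_rand_float(-1.0, 1.0, (4, 3), "cpu")
set_seed(123)
b = torch_rand_float(-1.0, 1.0, (4, 3), "cpu")
assert torch.equal(a, b)                      # same seed, same samples
assert (a >= -1.0).all() and (a < 1.0).all()  # half-open range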
  {
    "path": "holomotion/src/utils/isaac_utils/rotations.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n# This file was originally copied from the [ASAP] repository:\n# https://github.com/LeCAR-Lab/ASAP\n# Modifications have been made to fit the needs of this project.\n\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom holomotion.src.utils.isaac_utils.maths import (\n    copysign,\n    normalize,\n)\n\n\n@torch.jit.script\ndef quat_unit(a):\n    return normalize(a)\n\n\n@torch.jit.script\ndef quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:\n    shape = b.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 3)\n    if w_last:\n        xyz = a[:, :3]\n        w = a[:, 3:]\n    else:\n        xyz = a[:, 1:]\n        w = a[:, :1]\n    t = xyz.cross(b, dim=-1) * 2\n    return (b + w * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor:\n    quat_yaw = quat.clone().view(-1, 4)\n    quat_yaw[:, :2] = 0.0\n    quat_yaw = normalize(quat_yaw)\n    return quat_apply(quat_yaw, vec, w_last)\n\n\n@torch.jit.script\ndef wrap_to_pi(angles):\n    angles %= 2 * np.pi\n    angles -= 2 * np.pi * (angles > np.pi)\n    return angles\n\n\n@torch.jit.script\ndef quat_conjugate(a: Tensor, w_last: bool) -> Tensor:\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    if w_last:\n        return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)\n    else:\n        return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape)\n\n\n@torch.jit.script\ndef quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:\n    shape = b.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 3)\n    if w_last:\n        xyz = a[:, :3]\n        w = a[:, 3:]\n    else:\n        xyz = a[:, 1:]\n        w = a[:, :1]\n    t = xyz.cross(b, dim=-1) * 2\n    return (b + w * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    shape = q.shape\n    if w_last:\n        q_w = q[:, -1]\n        q_vec = q[:, :3]\n    else:\n        q_w = q[:, 0]\n        q_vec = q[:, 1:]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    shape = q.shape\n    if w_last:\n        q_w = q[:, -1]\n        q_vec = q[:, :3]\n    else:\n        q_w = q[:, 0]\n        q_vec = q[:, 1:]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        
).squeeze(-1)\n        * 2.0\n    )\n    return a - b + c\n\n\n@torch.jit.script\ndef quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]:\n    \"\"\"The (angle, axis) representation of the rotation. The axis is normalized to unit length.\n    The angle is guaranteed to be between [0, pi].\n    \"\"\"\n    if w_last:\n        w = x[..., -1]\n        axis = x[..., :3]\n    else:\n        w = x[..., 0]\n        axis = x[..., 1:]\n    s = 2 * (w**2) - 1\n    angle = s.clamp(-1, 1).arccos()  # just to be safe\n    axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)\n    return angle, axis\n\n\n@torch.jit.script\ndef quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor:\n    theta = (angle / 2).unsqueeze(-1)\n    xyz = normalize(axis) * theta.sin()\n    w = theta.cos()\n    if w_last:\n        return quat_unit(torch.cat([xyz, w], dim=-1))\n    else:\n        return quat_unit(torch.cat([w, xyz], dim=-1))\n\n\n@torch.jit.script\ndef vec_to_heading(h_vec):\n    h_theta = torch.atan2(h_vec[..., 1], h_vec[..., 0])\n    return h_theta\n\n\n@torch.jit.script\ndef heading_to_quat(h_theta, w_last: bool):\n    axis = torch.zeros(\n        h_theta.shape\n        + [\n            3,\n        ],\n        device=h_theta.device,\n    )\n    axis[..., 2] = 1\n    heading_q = quat_from_angle_axis(h_theta, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor:\n    basis_vec = torch.zeros(q.shape[0], 3, device=q.device)\n    basis_vec[:, axis] = 1\n    return quat_rotate(q, basis_vec, w_last)\n\n\n@torch.jit.script\ndef normalize_angle(x):\n    return torch.atan2(torch.sin(x), torch.cos(x))\n\n\n@torch.jit.script\ndef get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor:\n    return quat_rotate(q, v, w_last)\n\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # type: (Tensor) -> Tuple[Tensor, Tensor]\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    # ZL: could have issues.\n    min_theta = 1e-5\n    qx, qy, qz, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = torch.abs(sin_theta) > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    # type: (Tensor, Tensor, Tensor) -> Tensor\n    cos_half_theta = torch.sum(q0 * q1, dim=-1)\n\n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratioB = torch.sin(t * half_theta) / sin_half_theta\n\n    new_q = ratioA * q0 + ratioB * q1\n\n    new_q = torch.where(\n        torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q\n    )\n    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # type: (Tensor, 
Tensor) -> Tensor\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n\n@torch.jit.script\ndef my_quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef calc_heading(q):\n    # type: (Tensor) -> Tensor\n    # calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    # this is the x axis heading\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = my_quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # type: (Tensor) -> Tensor\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n\n@torch.jit.script\ndef calc_heading_quat(q, w_last):\n    # type: (Tensor, bool) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef calc_heading_quat_inv(q, w_last):\n    # type: (Tensor, bool) -> Tensor\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis, w_last=w_last)\n    return heading_q\n\n\n@torch.jit.script\ndef quat_inverse(x, w_last):\n    # type: (Tensor, bool) -> Tensor\n    \"\"\"The inverse of the rotation\"\"\"\n    return quat_conjugate(x, w_last=w_last)\n\n\n@torch.jit.script\ndef get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]:\n    if w_last:\n        qx, qy, qz, qw = 0, 1, 2, 3\n    else:\n        qw, qx, qy, qz = 0, 1, 2, 3\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n        q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)\n\n\n# @torch.jit.script\ndef get_euler_xyz_in_tensor(q):\n    qx, qy, qz, qw = 0, 1, 2, 3\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n     
   q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return torch.stack((roll, pitch, yaw), dim=-1)\n\n\n
@torch.jit.script\ndef quat_pos(x):\n    \"\"\"Flip quaternion signs so the real part (stored last) is non-negative.\"\"\"\n    q = x\n    z = (q[..., 3:] < 0).float()\n    q = (1 - 2 * z) * q\n    return q\n\n\n@torch.jit.script\ndef is_valid_quat(q):\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n    return (w * w + x * x + y * y + z * z).allclose(torch.ones_like(w))\n\n\n
@torch.jit.script\ndef quat_normalize(q):\n    \"\"\"Normalize a quaternion to unit length with a non-negative real part.\"\"\"\n    q = quat_unit(quat_pos(q))  # normalized to positive and unit quaternion\n    return q\n\n\n
@torch.jit.script\ndef quat_mul(a, b, w_last: bool):\n    assert a.shape == b.shape\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    b = b.reshape(-1, 4)\n\n    if w_last:\n        x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n    else:\n        w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    if w_last:\n        quat = torch.stack([x, y, z, w], dim=-1).view(shape)\n    else:\n        quat = torch.stack([w, x, y, z], dim=-1).view(shape)\n\n    return quat\n\n\n
@torch.jit.script\ndef quat_mul_norm(x, y, w_last):\n    # type: (Tensor, Tensor, bool) -> Tensor\n    \"\"\"Combine two sets of 3D rotations and renormalize the result.\n    The shapes need to be broadcastable.\n    \"\"\"\n    return quat_normalize(quat_mul(x, y, w_last))\n\n\n
@torch.jit.script\ndef quat_identity(shape: List[int]):\n    \"\"\"Construct 3D identity rotation given shape\"\"\"\n    w = torch.ones(shape + [1])\n    xyz = torch.zeros(shape + [3])\n    q = torch.cat([xyz, w], dim=-1)\n    return quat_normalize(q)\n\n\n@torch.jit.script\ndef quat_identity_like(x):\n    \"\"\"Construct identity 3D rotation with the same shape\"\"\"\n    return quat_identity(x.shape[:-1])\n\n\n
@torch.jit.script\ndef transform_from_rotation_translation(\n    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None\n):\n    \"\"\"Construct a transform from a quaternion and 3D translation. 
At most one of them may be None.\"\"\"\n    assert r is not None or t is not None, (\n        \"rotation and translation can't be all None\"\n    )\n    if r is None:\n        assert t is not None\n        # The identity rotation must share t's batch shape, i.e. (..., 4).\n        r = quat_identity(list(t.shape[:-1]))\n    if t is None:\n        # The zero translation must share r's batch shape, i.e. (..., 3).\n        t = torch.zeros(list(r.shape[:-1]) + [3])\n    return torch.cat([r, t], dim=-1)\n\n\n
@torch.jit.script\ndef transform_rotation(x):\n    \"\"\"Get rotation from transform\"\"\"\n    return x[..., :4]\n\n\n@torch.jit.script\ndef transform_translation(x):\n    \"\"\"Get translation from transform\"\"\"\n    return x[..., 4:]\n\n\n
@torch.jit.script\ndef transform_mul(x, y):\n    \"\"\"Compose two transforms\"\"\"\n    z = transform_from_rotation_translation(\n        r=quat_mul_norm(\n            transform_rotation(x), transform_rotation(y), w_last=True\n        ),\n        t=quat_rotate(\n            transform_rotation(x), transform_translation(y), w_last=True\n        )\n        + transform_translation(x),\n    )\n    return z\n\n\n
@torch.compile\ndef quaternion_to_matrix(\n    quaternions: torch.Tensor,\n    w_last: bool = True,\n) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions as tensor of shape (..., 4).\n            If w_last=True (default): real part last (x, y, z, w)\n            If w_last=False: real part first (w, x, y, z)\n        w_last: If True, quaternion format is (x, y, z, w).\n                If False, quaternion format is (w, x, y, z). Default: True.\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n\n    \"\"\"\n    if w_last:\n        i, j, k, r = torch.unbind(quaternions, -1)\n    else:\n        r, i, j, k = torch.unbind(quaternions, -1)\n\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\n
@torch.jit.script\ndef axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as axis/angle to quaternions.\n\n    Args:\n        axis_angle: Rotations given as a vector in axis angle form,\n            as a tensor of shape (..., 3), where the magnitude is\n            the angle turned anticlockwise in radians around the\n            vector's direction.\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n\n    \"\"\"\n    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)\n    half_angles = angles * 0.5\n    eps = 1e-6\n    small_angles = angles.abs() < eps\n    sin_half_angles_over_angles = torch.empty_like(angles)\n    sin_half_angles_over_angles[~small_angles] = (\n        torch.sin(half_angles[~small_angles]) / angles[~small_angles]\n    )\n    # for x small, sin(x/2) is about x/2 - (x/2)^3/6\n    # so sin(x/2)/x is about 1/2 - (x*x)/48\n    sin_half_angles_over_angles[small_angles] = (\n        0.5 - (angles[small_angles] * angles[small_angles]) / 48\n    )\n    quaternions = torch.cat(\n        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles],\n        dim=-1,\n    )\n    return quaternions\n\n\n# 
@torch.jit.script\ndef wxyz_to_xyzw(quat):\n    return quat[..., [1, 2, 3, 0]]\n\n\n# @torch.jit.script\ndef xyzw_to_wxyz(quat):\n    return quat[..., [3, 0, 1, 2]]\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"W x y z\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        matrix.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack(\n                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1\n            ),\n            torch.stack(\n                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1\n            ),\n            torch.stack(\n                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1\n            ),\n            torch.stack(\n                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1\n            ),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n\n    return quat_candidates[\n        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,\n        :,  # pyre-ignore[16]\n    ].reshape(batch_dim + (4,))\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    ret[positive_mask] = torch.sqrt(x[positive_mask])\n    return ret\n\n\ndef quat_w_first(rot):\n    rot = torch.cat([rot[..., [-1]], rot[..., :-1]], -1)\n    return rot\n\n\n@torch.jit.script\ndef quat_from_euler_xyz(roll, pitch, yaw):\n    cy = torch.cos(yaw * 0.5)\n    sy = torch.sin(yaw * 0.5)\n    cr = torch.cos(roll * 0.5)\n    sr = torch.sin(roll * 0.5)\n    cp = torch.cos(pitch * 0.5)\n    sp = torch.sin(pitch * 0.5)\n\n    qw = cy * cr * cp + sy * sr * sp\n    qx = cy * sr * cp - sy * cr * sp\n    qy = cy * cr * sp + sy * sr * cp\n    qz = sy * cr * cp - cy * sr * sp\n\n    return torch.stack([qx, qy, qz, qw], dim=-1)\n\n\n@torch.compile\ndef remove_yaw_component(\n    quat_raw: Tensor,\n    quat_init: Tensor,\n    w_last: bool = True,\n) -> Tensor:\n    \"\"\"Remove yaw component from quaternion while keeping roll and pitch.\n\n    This function extracts the yaw component from the initial quaternion and uses\n    it to normalize the raw quaternion, effectively removing the initial 
heading\n    offset while preserving roll and pitch components.\n\n    Args:\n        quat_raw: Current quaternion from IMU, shape (..., 4)\n        quat_init: Initial quaternion (contains the yaw to be removed), shape (..., 4)\n        w_last: If True, quaternion format is (x, y, z, w).\n                If False, quaternion format is (w, x, y, z). Default: True.\n\n    Returns:\n        Quaternion with initial yaw component removed, same shape as input.\n        The resulting quaternion represents roll and pitch relative to the\n        heading-aligned coordinate frame.\n\n    Example:\n        >>> # Initial robot orientation (roll=0°, pitch=0°, yaw=45°)\n        >>> quat_init = quat_from_euler_xyz(\n        ...     torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.7854)\n        ... )\n        >>> # Current IMU reading (roll=10°, pitch=20°, yaw=60°)\n        >>> quat_raw = quat_from_euler_xyz(\n        ...     torch.tensor(0.1745),\n        ...     torch.tensor(0.3491),\n        ...     torch.tensor(1.0472),\n        ... )\n        >>> quat_norm = remove_yaw_component(quat_raw, quat_init)\n        >>> # quat_norm contains roll=10°, pitch=20°, with initial yaw offset removed\n    \"\"\"\n    # Extract quaternion components based on format\n    if w_last:\n        q_w = quat_init[..., -1]\n        q_vec = quat_init[..., :3]\n    else:\n        q_w = quat_init[..., 0]\n        q_vec = quat_init[..., 1:]\n\n    # Calculate heading by rotating x-axis with quaternion\n    # ref_dir = [1, 0, 0] (x-axis)\n    ref_dir = torch.zeros_like(q_vec)\n    ref_dir[..., 0] = 1.0\n\n    # Quaternion rotation: v' = v + 2 * w * (q_vec × v) + 2 * q_vec × (q_vec × v)\n    cross1 = torch.cross(q_vec, ref_dir, dim=-1)\n    cross2 = torch.cross(q_vec, cross1, dim=-1)\n    rot_dir = ref_dir + 2.0 * q_w.unsqueeze(-1) * cross1 + 2.0 * cross2\n\n    # Extract heading angle from rotated x-axis\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n\n    # Create inverse heading quaternion (rotation about negative z-axis)\n    half_heading = (-heading) * 0.5\n    heading_q_inv = torch.zeros_like(quat_init)\n\n    if w_last:\n        heading_q_inv[..., 0] = 0.0  # x\n        heading_q_inv[..., 1] = 0.0  # y\n        heading_q_inv[..., 2] = torch.sin(half_heading)  # z\n        heading_q_inv[..., 3] = torch.cos(half_heading)  # w\n    else:\n        heading_q_inv[..., 0] = torch.cos(half_heading)  # w\n        heading_q_inv[..., 1] = 0.0  # x\n        heading_q_inv[..., 2] = 0.0  # y\n        heading_q_inv[..., 3] = torch.sin(half_heading)  # z\n\n    # Quaternion multiplication: heading_q_inv * quat_raw\n    shape = quat_raw.shape\n    a = heading_q_inv.reshape(-1, 4)\n    b = quat_raw.reshape(-1, 4)\n\n    if w_last:\n        x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n    else:\n        w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]\n        w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]\n\n    # Quaternion multiplication formula\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    if w_last:\n        quat_result = torch.stack([x, y, z, w], dim=-1).view(shape)\n    else:\n        quat_result = torch.stack([w, x, y, z], 
dim=-1).view(shape)\n\n    # Normalize the result quaternion\n    norm = torch.norm(quat_result, p=2, dim=-1, keepdim=True)\n    quat_norm = quat_result / norm.clamp(min=1e-8)\n\n    return quat_norm\n"
  },
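A round-trip sanity sketch for the conversions above, assuming a torch build where torch.compile works (quaternion_to_matrix is compiled): quaternion -> matrix -> quaternion recovers the input up to the usual q / -q sign ambiguity. Both ends use the (w, x, y, z) layout via w_last=False, matching matrix_to_quaternion's real-part-first output.

import torch

from holomotion.src.utils.isaac_utils.rotations import (
    matrix_to_quaternion,
    quaternion_to_matrix,
)

q = torch.nn.functional.normalize(torch.randn(8, 4), dim=-1)  # unit quats, (w, x, y, z)
mat = quaternion_to_matrix(q, w_last=False)
q_back = matrix_to_quaternion(mat)  # real part first
sign = torch.sign((q * q_back).sum(dim=-1, keepdim=True))  # resolve q vs -q
assert torch.allclose(q, sign * q_back, atol=1e-5)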
  {
    "path": "holomotion/src/utils/isaac_utils/setup.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n# This file was originally copied from the [ASAP] repository:\n# https://github.com/LeCAR-Lab/ASAP\n# Modifications have been made to fit the needs of this project.\n\nfrom setuptools import setup\n\nsetup(\n    name=\"isaac_utils\",\n    packages=[\"isaac_utils\"],\n    version=\"0.0.1\",\n    description=\"Unified torch env_utils for IsaacGym and IsaacSim\",\n    author=\"\",\n    classifiers=[],\n)\n"
  },
  {
    "path": "holomotion/src/utils/onnx_export.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nimport inspect\nimport re\nfrom pathlib import Path\n\nfrom loguru import logger\n\n\ndef _list_to_csv_str(arr, *, decimals: int = 3, delimiter: str = \",\") -> str:\n    fmt = f\"{{:.{decimals}f}}\"\n    return delimiter.join(\n        fmt.format(x) if isinstance(x, (int, float)) else str(x) for x in arr\n    )\n\n\ndef attach_onnx_metadata_holomotion(env, onnx_path: str) -> None:\n    import onnx\n\n    metadata = {\n        \"joint_names\": env.scene[\"robot\"].data.joint_names,\n        \"joint_stiffness\": env.scene[\"robot\"]\n        .data.default_joint_stiffness[0]\n        .cpu()\n        .tolist(),\n        \"joint_damping\": env.scene[\"robot\"]\n        .data.default_joint_damping[0]\n        .cpu()\n        .tolist(),\n        \"default_joint_pos\": env.scene[\"robot\"]\n        .data.default_joint_pos[0]\n        .cpu()\n        .tolist(),\n        \"action_scale\": env.action_manager.get_term(\"dof_pos\")\n        ._scale[0]\n        .cpu()\n        .tolist(),\n    }\n\n    model = onnx.load(onnx_path)\n    for key, value in metadata.items():\n        entry = onnx.StringStringEntryProto()\n        entry.key = key\n        entry.value = (\n            _list_to_csv_str(value) if isinstance(value, list) else str(value)\n        )\n        model.metadata_props.append(entry)\n    onnx.save(model, onnx_path)\n\n\ndef export_policy_to_onnx(\n    algo,\n    checkpoint_path: str,\n    *,\n    onnx_name_suffix: str | None = None,\n    use_kv_cache: bool = True,\n) -> str:\n    checkpoint = Path(checkpoint_path)\n    export_dir = checkpoint.parent / \"exported\"\n    export_dir.mkdir(exist_ok=True)\n\n    onnx_name = checkpoint.name.replace(\".pt\", \".onnx\")\n    if onnx_name_suffix is not None:\n        suffix = re.sub(r\"[\\s+]\", \"_\", str(onnx_name_suffix))\n        onnx_name = onnx_name.replace(\".onnx\", f\"_{suffix}.onnx\")\n    onnx_path = export_dir / onnx_name\n\n    logger.info(\"Starting ONNX minimal policy export (actions-only)...\")\n    actor_was_training = getattr(algo.actor, \"training\", None)\n    critic_was_training = getattr(algo.critic, \"training\", None)\n    algo.actor.eval()\n    algo.critic.eval()\n\n    try:\n        actor_for_export = algo.accelerator.unwrap_model(algo.actor)\n        orig_mod = getattr(actor_for_export, \"_orig_mod\", None)\n        if orig_mod is not None:\n            actor_for_export = orig_mod\n\n        export_signature = inspect.signature(actor_for_export.export_onnx)\n        export_kwargs = {\"onnx_path\": onnx_path, \"opset_version\": 17}\n        if \"use_kv_cache\" in export_signature.parameters:\n            export_kwargs[\"use_kv_cache\"] = bool(use_kv_cache)\n\n        onnx_path_str = actor_for_export.export_onnx(**export_kwargs)\n        attach_onnx_metadata_holomotion(algo.env._env, onnx_path=onnx_path_str)\n        logger.info(\n            f\"Successfully 
exported minimal policy to: {onnx_path_str}\"\n        )\n        return onnx_path_str\n    finally:\n        if actor_was_training is not None:\n            algo.actor.train(actor_was_training)\n        if critic_was_training is not None:\n            algo.critic.train(critic_was_training)\n"
  },
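A sketch of reading back the metadata that attach_onnx_metadata_holomotion writes; "policy.onnx" is a placeholder path. List-valued entries were serialized as comma-separated strings by _list_to_csv_str, so they are parsed back with split(",").

import onnx

model = onnx.load("policy.onnx")
metadata = {prop.key: prop.value for prop in model.metadata_props}
joint_names = metadata["joint_names"].split(",")
action_scale = [float(x) for x in metadata["action_scale"].split(",")]
print(len(joint_names), action_scale[:3])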
  {
    "path": "holomotion/src/utils/reference_prefix.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\nfrom typing import Mapping\n\n\ndef resolve_reference_tensor_key(\n    batch_tensors: Mapping[str, object],\n    base_key: str,\n    prefix: str = \"ref_\",\n) -> str:\n    tensor_key = base_key\n    if prefix:\n        prefixed_key = f\"{prefix}{base_key}\"\n        if prefixed_key in batch_tensors:\n            tensor_key = prefixed_key\n        elif prefix == \"ft_ref_\":\n            raise KeyError(\n                f\"Filtered tensor '{prefixed_key}' is not present in the \"\n                \"current motion cache batch. Ensure online filtering is \"\n                \"enabled and 'ft_ref_' is materialized in allowed_prefixes.\"\n            )\n        elif base_key not in batch_tensors:\n            raise KeyError(\n                f\"Neither '{prefixed_key}' nor '{base_key}' is present in \"\n                \"the current motion cache batch.\"\n            )\n    elif base_key not in batch_tensors:\n        raise KeyError(\n            f\"Tensor '{base_key}' is not present in the current motion cache batch.\"\n        )\n    return tensor_key\n"
  },
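A usage sketch for resolve_reference_tensor_key with a toy batch: a 'ref_'-prefixed key wins when present, the bare key is used otherwise, and the strict 'ft_ref_' prefix raises when the filtered tensor was never materialized.

from holomotion.src.utils.reference_prefix import resolve_reference_tensor_key

batch = {"ref_root_pos": ..., "dof_pos": ...}  # placeholder tensor values
assert resolve_reference_tensor_key(batch, "root_pos") == "ref_root_pos"
assert resolve_reference_tensor_key(batch, "dof_pos") == "dof_pos"
try:
    resolve_reference_tensor_key(batch, "dof_pos", prefix="ft_ref_")
except KeyError:
    pass  # online filtering must materialize 'ft_ref_' tensors first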
  {
    "path": "holomotion/src/utils/torch_utils.py",
    "content": "\"\"\"Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\n\ndef to_torch(x, dtype=torch.float, device=\"cpu\", requires_grad=False):\n    return torch.tensor(\n        x, dtype=dtype, device=device, requires_grad=requires_grad\n    )\n\n\n@torch.jit.script\ndef quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:\n    \"\"\"Multiply two quaternions together.\n\n    Args:\n        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).\n        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).\n\n    Returns:\n        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).\n\n    Raises:\n        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.\n    \"\"\"\n    # check input is correct\n    if q1.shape != q2.shape:\n        msg = f\"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}.\"\n        raise ValueError(msg)\n    # reshape to (N, 4) for multiplication\n    shape = q1.shape\n    q1 = q1.reshape(-1, 4)\n    q2 = q2.reshape(-1, 4)\n    # extract components from quaternions\n    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]\n    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]\n    # perform multiplication\n    ww = (z1 + x1) * (x2 + y2)\n    yy = (w1 - y1) * (w2 + z2)\n    zz = (w1 + y1) * (w2 - z2)\n    xx = ww + yy + zz\n    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))\n    w = qq - ww + (z1 - y1) * (y2 - z2)\n    x = qq - xx + (x1 + w1) * (x2 + w2)\n    y = qq - yy + (w1 - x1) * (y2 + z2)\n    z = qq - zz + (z1 + y1) * (w2 - x2)\n\n    return torch.stack([w, x, y, z], dim=-1).view(shape)\n\n\n@torch.jit.script\ndef normalize(x, eps: float = 1e-9):\n    return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)\n\n\n@torch.jit.script\ndef quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:\n    \"\"\"Apply a quaternion rotation to a vector.\n\n    Args:\n        quat: The quaternion in (w, x, y, z). Shape is (..., 4).\n        vec: The vector in (x, y, z). Shape is (..., 3).\n\n    Returns:\n        The rotated vector in (x, y, z). Shape is (..., 3).\n    \"\"\"\n    # store shape\n    shape = vec.shape\n    # reshape to (N, 3) for multiplication\n    quat = quat.reshape(-1, 4)\n    vec = vec.reshape(-1, 3)\n    # extract components from quaternions\n    xyz = quat[:, 1:]\n    t = xyz.cross(vec, dim=-1) * 2\n    return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:\n    \"\"\"Apply an inverse quaternion rotation to a vector.\n\n    Args:\n        quat: The quaternion in (w, x, y, z). Shape is (..., 4).\n        vec: The vector in (x, y, z). Shape is (..., 3).\n\n    Returns:\n        The rotated vector in (x, y, z). 
Shape is (..., 3).\n    \"\"\"\n    # store shape\n    shape = vec.shape\n    # reshape to (N, 3) for multiplication\n    quat = quat.reshape(-1, 4)\n    vec = vec.reshape(-1, 3)\n    # extract components from quaternions\n    xyz = quat[:, 1:]\n    t = xyz.cross(vec, dim=-1) * 2\n    return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)\n\n\n@torch.jit.script\ndef quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n# @torch.jit.script\ndef quat_rotate_inverse(q, v):\n    shape = q.shape\n    q_w = q[:, -1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a - b + c\n\n\n@torch.jit.script\ndef quat_conjugate(a):\n    shape = a.shape\n    a = a.reshape(-1, 4)\n    return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape)\n    # return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)\n\n\n@torch.jit.script\ndef quat_unit(a):\n    return normalize(a)\n\n\n@torch.jit.script\ndef quat_from_angle_axis(angle, axis):\n    theta = (angle / 2).unsqueeze(-1)\n    xyz = normalize(axis) * theta.sin()\n    w = theta.cos()\n    return quat_unit(torch.cat([xyz, w], dim=-1))\n\n\n@torch.jit.script\ndef normalize_angle(x):\n    return torch.atan2(torch.sin(x), torch.cos(x))\n\n\n@torch.jit.script\ndef tf_inverse(q, t):\n    q_inv = quat_conjugate(q)\n    return q_inv, -quat_apply(q_inv, t)\n\n\n@torch.jit.script\ndef tf_apply(q, t, v):\n    return quat_apply(q, v) + t\n\n\n@torch.jit.script\ndef tf_vector(q, v):\n    return quat_apply(q, v)\n\n\n@torch.jit.script\ndef tf_combine(q1, t1, q2, t2):\n    return quat_mul(q1, q2), quat_apply(q1, t2) + t1\n\n\n@torch.jit.script\ndef get_basis_vector(q, v):\n    return quat_rotate(q, v)\n\n\ndef get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3):\n    \"\"\"Construct arguments to `Vec` according to axis index.\"\"\"\n    zs = np.zeros((n_dims,))\n    assert axis_idx < n_dims, (\n        \"the axis dim should be within the vector dimensions\"\n    )\n    zs[axis_idx] = 1.0\n    params = np.where(zs == 1.0, value, zs)\n    params[0] = x_value\n    return list(params.astype(dtype))\n\n\n@torch.jit.script\ndef copysign(a, b):\n    a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])\n    return torch.abs(a) * torch.sign(b)\n\n\n@torch.jit.script\ndef get_euler_xyz(q: torch.Tensor) -> tuple:\n    qx, qy, qz, qw = 0, 1, 2, 3\n    # roll (x-axis rotation)\n    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])\n    cosr_cosp = (\n        q[:, qw] * q[:, qw]\n        - q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        + q[:, qz] * q[:, qz]\n    )\n    roll = torch.atan2(sinr_cosp, cosr_cosp)\n\n    # pitch (y-axis rotation)\n    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])\n    pitch = torch.where(\n        torch.abs(sinp) >= 1,\n        copysign(torch.tensor(np.pi / 2.0, device=sinp.device), sinp),\n        torch.asin(sinp),\n    )\n\n    # yaw (z-axis rotation)\n    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] 
+ q[:, qx] * q[:, qy])\n    cosy_cosp = (\n        q[:, qw] * q[:, qw]\n        + q[:, qx] * q[:, qx]\n        - q[:, qy] * q[:, qy]\n        - q[:, qz] * q[:, qz]\n    )\n    yaw = torch.atan2(siny_cosp, cosy_cosp)\n\n    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)\n\n\n@torch.jit.script\ndef quat_from_euler_xyz(roll, pitch, yaw):\n    cy = torch.cos(yaw * 0.5)\n    sy = torch.sin(yaw * 0.5)\n    cr = torch.cos(roll * 0.5)\n    sr = torch.sin(roll * 0.5)\n    cp = torch.cos(pitch * 0.5)\n    sp = torch.sin(pitch * 0.5)\n\n    qw = cy * cr * cp + sy * sr * sp\n    qx = cy * sr * cp - sy * cr * sp\n    qy = cy * cr * sp + sy * sr * cp\n    qz = sy * cr * cp - cy * sr * sp\n\n    return torch.stack([qx, qy, qz, qw], dim=-1)\n\n\ndef torch_rand_float(lower, upper, shape, device):\n    return (upper - lower) * torch.rand(*shape, device=device) + lower\n\n\n# @torch.jit.script\n@torch.compile\ndef torch_random_dir_2(shape, device):\n    angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)\n    return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)\n\n\n@torch.jit.script\ndef tensor_clamp(t, min_t, max_t):\n    return torch.max(torch.min(t, max_t), min_t)\n\n\n@torch.jit.script\ndef scale(x, lower, upper):\n    return 0.5 * (x + 1.0) * (upper - lower) + lower\n\n\n@torch.jit.script\ndef unscale(x, lower, upper):\n    return (2.0 * x - upper - lower) / (upper - lower)\n\n\ndef unscale_np(x, lower, upper):\n    return (2.0 * x - upper - lower) / (upper - lower)\n\n\n@torch.jit.script\ndef quat_to_angle_axis(q):\n    # computes axis-angle representation from quaternion q\n    # q must be normalized\n    min_theta = 1e-5\n    qx, _, _, qw = 0, 1, 2, 3\n\n    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])\n    angle = 2 * torch.acos(q[..., qw])\n    angle = normalize_angle(angle)\n    sin_theta_expand = sin_theta.unsqueeze(-1)\n    axis = q[..., qx:qw] / sin_theta_expand\n\n    mask = torch.abs(sin_theta) > min_theta\n    default_axis = torch.zeros_like(axis)\n    default_axis[..., -1] = 1\n\n    angle = torch.where(mask, angle, torch.zeros_like(angle))\n    mask_expand = mask.unsqueeze(-1)\n    axis = torch.where(mask_expand, axis, default_axis)\n    return angle, axis\n\n\n@torch.jit.script\ndef angle_axis_to_exp_map(angle, axis):\n    # compute exponential map from axis-angle\n    angle_expand = angle.unsqueeze(-1)\n    exp_map = angle_expand * axis\n    return exp_map\n\n\n@torch.jit.script\ndef quat_to_exp_map(q):\n    # compute exponential map from quaternion\n    # q must be normalized\n    angle, axis = quat_to_angle_axis(q)\n    exp_map = angle_axis_to_exp_map(angle, axis)\n    return exp_map\n\n\n@torch.jit.script\ndef slerp(q0, q1, t):\n    cos_half_theta = torch.sum(q0 * q1, dim=-1)\n\n    neg_mask = cos_half_theta < 0\n    q1 = q1.clone()\n    q1[neg_mask] = -q1[neg_mask]\n    cos_half_theta = torch.abs(cos_half_theta)\n    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)\n\n    half_theta = torch.acos(cos_half_theta)\n    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)\n\n    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta\n    ratio_b = torch.sin(t * half_theta) / sin_half_theta\n\n    new_q = ratio_a * q0 + ratio_b * q1\n\n    new_q = torch.where(\n        torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q\n    )\n    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)\n\n    return new_q\n\n\n@torch.jit.script\ndef my_quat_rotate(q, v):\n    shape = q.shape\n    q_w = q[:, 
-1]\n    q_vec = q[:, :3]\n    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)\n    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0\n    c = (\n        q_vec\n        * torch.bmm(\n            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)\n        ).squeeze(-1)\n        * 2.0\n    )\n    return a + b + c\n\n\n@torch.jit.script\ndef calc_heading(q):\n    # calculate heading direction from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    # this is the x axis heading\n    ref_dir = torch.zeros_like(q[..., 0:3])\n    ref_dir[..., 0] = 1\n    rot_dir = my_quat_rotate(q, ref_dir)\n\n    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])\n    return heading\n\n\n@torch.jit.script\ndef calc_heading_quat(q):\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(heading, axis)\n    return heading_q\n\n\n@torch.jit.script\ndef calc_heading_quat_inv(q):\n    # calculate heading rotation from quaternion\n    # the heading is the direction on the xy plane\n    # q must be normalized\n    heading = calc_heading(q)\n    axis = torch.zeros_like(q[..., 0:3])\n    axis[..., 2] = 1\n\n    heading_q = quat_from_angle_axis(-heading, axis)\n    return heading_q\n\n\n@torch.compiler.disable\ndef axis_angle_from_quat(\n    quat: torch.Tensor,\n    w_last: bool = True,\n) -> torch.Tensor:\n    \"\"\"Compute axis-angle (log map) vector from a quaternion.\n\n    Args:\n        quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].\n        w_last (bool): Whether the scalar part w is the last element.\n\n    Returns:\n        torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi].\n\n    Notes:\n        - The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability.\n        - Uses a stable small-angle handling to avoid NaNs and gradient issues.\n    \"\"\"\n    # Handle different quaternion formats\n    if w_last:\n        # Quaternion is [q_x, q_y, q_z, q_w]\n        quat_w_orig = quat[..., -1:]\n    else:\n        # Quaternion is [q_w, q_x, q_y, q_z]\n        quat_w_orig = quat[..., 0:1]\n\n    # Normalize quaternion to have w > 0\n    quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0))\n\n    # Ensure unit quaternion for stability\n    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min(\n        1.0e-9\n    )\n\n    # Recompute quat_xyz and quat_w after potential sign flip\n    if w_last:\n        quat_w = quat[..., -1:]\n        quat_xyz = quat[..., :3]\n    else:\n        quat_w = quat[..., 0:1]\n        quat_xyz = quat[..., 1:4]\n\n    mag = torch.linalg.norm(quat_xyz, dim=-1)\n    half_angle = torch.atan2(mag, quat_w.squeeze(-1))\n    angle = 2.0 * half_angle\n    # check whether to apply Taylor approximation\n    use_taylor = angle.abs() <= 1.0e-6\n    # To prevent NaN gradients with torch.where, we compute both branches and blend\n    # based on the condition.\n    # See: https://pytorch.org/docs/1.9.0/generated/torch.where.html#torch-where\n    # \"However, if you need the gradients to flow through the branches, please use torch.lerp\"\n    # Although we are not using lerp, the principle of avoiding sharp branches is the same.\n    sin_half_angles_over_angles_approx = 0.5 - angle * angle / 
48\n    # Clamp angle to avoid division by zero in the non-taylor branch when angle is exactly 0.\n    angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle)\n    sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe\n\n    sin_half_angles_over_angles = torch.where(\n        use_taylor,\n        sin_half_angles_over_angles_approx,\n        sin_half_angles_over_angles_exact,\n    )\n    return quat_xyz / sin_half_angles_over_angles[..., None]\n\n\n
@torch.compile\ndef quat_box_minus(\n    q1: torch.Tensor,\n    q2: torch.Tensor,\n    w_last: bool = True,\n) -> torch.Tensor:\n    \"\"\"Right-invariant quaternion difference mapped to so(3) via log map.\n\n    Computes log(q1 * q2^{-1}) using the shortest rotation convention.\n\n    Args:\n        q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].\n        q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`.\n        w_last (bool): Whether the scalar part w is the last element.\n\n    Returns:\n        torch.Tensor: (..., 3) axis-angle error vector.\n    \"\"\"\n    if w_last:\n        # This module's quat_mul / quat_conjugate expect the (w, x, y, z)\n        # layout, so convert from (x, y, z, w) first.\n        q1_wxyz = torch.cat([q1[..., 3:4], q1[..., 0:3]], dim=-1)\n        q2_wxyz = torch.cat([q2[..., 3:4], q2[..., 0:3]], dim=-1)\n    else:\n        q1_wxyz = q1\n        q2_wxyz = q2\n\n    quat_diff = quat_mul(q1_wxyz, quat_conjugate(q2_wxyz))  # q1 * q2^-1\n    return axis_angle_from_quat(quat_diff, w_last=False)  # log(qd)\n\n\n
@torch.compile\ndef quat_error_magnitude(\n    q1: torch.Tensor,\n    q2: torch.Tensor,\n    w_last: bool = True,\n) -> torch.Tensor:\n    \"\"\"Geodesic angle between two orientations given as quaternions.\n\n    Args:\n        q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].\n        q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`.\n        w_last (bool): Whether the scalar part w is the last element.\n\n    Returns:\n        torch.Tensor: (...,) rotation angle in radians in [0, pi].\n    \"\"\"\n    axis_angle_error = quat_box_minus(q1, q2, w_last=w_last)\n    return torch.norm(axis_angle_error, dim=-1)\n\n\n
@torch.jit.script\ndef quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:\n    \"\"\"Computes the inverse of a quaternion.\n\n    Args:\n        q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).\n        eps: A small value to avoid division by zero. Defaults to 1e-9.\n\n    Returns:\n        The inverse quaternion in (w, x, y, z). 
Shape is (N, 4).\n    \"\"\"\n    return quat_conjugate(q) / q.pow(2).sum(dim=-1, keepdim=True).clamp(\n        min=eps\n    )\n\n\n# --------------------- WXYZ helpers (torch) ---------------------\ndef xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert quaternion from XYZW to WXYZ.\n    Args:\n        q (torch.Tensor): (..., 4) quaternion in XYZW.\n    Returns:\n        torch.Tensor: (..., 4) quaternion in WXYZ.\n    \"\"\"\n    return torch.cat([q[..., 3:4], q[..., 0:3]], dim=-1)\n\n\n
def wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert quaternion from WXYZ to XYZW.\n    Args:\n        q (torch.Tensor): (..., 4) quaternion in WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) quaternion in XYZW.\n    \"\"\"\n    return torch.cat([q[..., 1:4], q[..., 0:1]], dim=-1)\n\n\n
@torch.compile\ndef quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Hamilton product in WXYZ layout using fused implementation.\n    Args:\n        q1 (torch.Tensor): (..., 4) WXYZ.\n        q2 (torch.Tensor): (..., 4) WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) WXYZ.\n    \"\"\"\n    # This module's quat_mul already uses the (w, x, y, z) layout.\n    return quat_mul(q1, q2)\n\n\n
def rotate_vec_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Rotate vector v by quaternion q (WXYZ).\n    Args:\n        q_wxyz (torch.Tensor): (..., 4) WXYZ.\n        v (torch.Tensor): (..., 3).\n    Returns:\n        torch.Tensor: (..., 3) rotated vector.\n    \"\"\"\n    # Support single-vector inputs by promoting to batch\n    single = q_wxyz.ndim == 1\n    if single:\n        q_in = q_wxyz[None, :]\n        v_in = v[None, :]\n    else:\n        q_in = q_wxyz\n        v_in = v\n    # quat_apply already expects the (w, x, y, z) layout.\n    out = quat_apply(q_in, v_in)\n    if single:\n        return out[0]\n    return out\n\n\n
def rotate_vec_inv_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Rotate vector v by inverse of quaternion q (WXYZ).\n    Args:\n        q_wxyz (torch.Tensor): (..., 4) WXYZ.\n        v (torch.Tensor): (..., 3).\n    Returns:\n        torch.Tensor: (..., 3) rotated vector in inverse rotation.\n    \"\"\"\n    single = q_wxyz.ndim == 1\n    if single:\n        q_in = q_wxyz[None, :]\n        v_in = v[None, :]\n    else:\n        q_in = q_wxyz\n        v_in = v\n    # quat_apply_inverse already expects the (w, x, y, z) layout.\n    out = quat_apply_inverse(q_in, v_in)\n    if single:\n        return out[0]\n    return out\n\n\n
def subtract_frame_transforms(\n    t01: torch.Tensor,\n    q01: torch.Tensor,\n    t02: torch.Tensor | None = None,\n    q02: torch.Tensor | None = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    r\"\"\"Subtract transformations between two reference frames into a stationary frame.\n\n    It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \\times T_{02}`,\n    where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B.\n\n    Args:\n        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).\n        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).\n        t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).\n            Defaults to None, in which case the position is assumed to be zero.\n        q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). 
Shape is (N, 4).\n            Defaults to None, in which case the orientation is assumed to be identity.\n\n    Returns:\n        A tuple containing the position and orientation of frame 2 w.r.t. frame 1.\n        Shape of the tensors are (N, 3) and (N, 4) respectively.\n    \"\"\"\n    # compute orientation\n    q10 = quat_inv(q01)\n    if q02 is not None:\n        q12 = quat_mul(q10, q02)\n    else:\n        q12 = q10\n    # compute translation\n    if t02 is not None:\n        t12 = quat_apply(q10, t02 - t01)\n    else:\n        t12 = quat_apply(q10, -t01)\n    return t12, q12\n\n\n@torch.compile\ndef quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Normalize quaternion in WXYZ layout.\n    Args:\n        q_wxyz (torch.Tensor): (..., 4) WXYZ.\n    Returns:\n        torch.Tensor: (..., 4) normalized WXYZ.\n    \"\"\"\n    return q_wxyz / torch.linalg.norm(q_wxyz, dim=-1, keepdim=True).clamp_min(\n        1.0e-9\n    )\n\n\n# @torch.compile\n@torch.jit.script\ndef matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).\n\n    Returns:\n        Rotation matrices. The shape is (..., 3, 3).\n\n    Reference:\n        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\n@torch.jit.script\ndef rot6d_from_quat(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as quaternions to 6D rotation representation.\n\n    Uses the continuous 6D rotation representation from Zhou et al. 
(CVPR 2019).\n\n    Args:\n        quaternions: (..., 4) quaternion in (w, x, y, z).\n\n    Returns:\n        (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened).\n    \"\"\"\n    mat = matrix_from_quat(quaternions)  # (..., 3, 3)\n    batch_shape = mat.shape[:-2]\n    # Transpose so the two columns are contiguous ([col0, col1]); this matches\n    # matrix_from_rot6d below, which reads rot6d[..., :3] as the first column.\n    return mat[..., :, :2].transpose(-2, -1).reshape(batch_shape + (6,))\n\n\n@torch.jit.script\ndef matrix_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert 6D rotation representation to rotation matrix.\n\n    Uses Gram-Schmidt orthogonalization to reconstruct the rotation matrix\n    from the first two columns.\n\n    Args:\n        rot6d: (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened).\n\n    Returns:\n        (..., 3, 3) rotation matrix.\n    \"\"\"\n    # Extract first two columns\n    a1 = rot6d[..., :3]  # first column\n    a2 = rot6d[..., 3:]  # second column\n\n    # Gram-Schmidt orthogonalization\n    b1 = torch.nn.functional.normalize(a1, dim=-1)\n    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1\n    b2 = torch.nn.functional.normalize(b2, dim=-1)\n    b3 = torch.cross(b1, b2, dim=-1)\n\n    # Stack columns to form rotation matrix\n    mat = torch.stack((b1, b2, b3), dim=-1)  # (..., 3, 3)\n    return mat\n\n\n@torch.jit.script\ndef quat_from_matrix(mat: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotation matrix to quaternion.\n\n    Args:\n        mat: (..., 3, 3) rotation matrix.\n\n    Returns:\n        (..., 4) quaternion in (w, x, y, z).\n    \"\"\"\n    batch_dim = mat.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n        mat.reshape(batch_dim + (9,)), dim=-1\n    )\n\n    # Compute q_abs = sqrt(max(0, trace_terms))\n    q_abs = torch.sqrt(\n        torch.clamp(\n            torch.stack(\n                [\n                    1.0 + m00 + m11 + m22,\n                    1.0 + m00 - m11 - m22,\n                    1.0 - m00 + m11 - m22,\n                    1.0 - m00 - m11 + m22,\n                ],\n                dim=-1,\n            ),\n            min=0.0,\n        )\n    )\n\n    # Compute quaternion candidates for each branch\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack(\n                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1\n            ),\n            torch.stack(\n                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1\n            ),\n            torch.stack(\n                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1\n            ),\n            torch.stack(\n                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1\n            ),\n        ],\n        dim=-2,\n    )\n\n    # Normalize candidates (floor at 0.1 for numerical stability)\n    flr = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clamp(min=flr))\n\n    # Pick the best-conditioned candidate (largest denominator)\n    return quat_candidates[\n        torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,\n        :,\n    ].reshape(batch_dim + (4,))\n\n\n@torch.jit.script\ndef quat_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert 6D rotation representation to quaternions.\n\n    Args:\n        rot6d: (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened).\n\n    Returns:\n        (..., 4) quaternion in (w, x, y, z).\n    \"\"\"\n    mat = matrix_from_rot6d(rot6d)\n    
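# Decode via the rotation matrix: 6D -> matrix -> quaternion (w, x, y, z).\n    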
return quat_from_matrix(mat)\n\n\n@torch.jit.script\ndef euler_xyz_from_quat(\n    quat: torch.Tensor, wrap_to_2pi: bool = False\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Convert rotations given as quaternions to Euler angles in radians.\n\n    Note:\n        The euler angles are assumed in XYZ extrinsic convention.\n\n    Args:\n        quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).\n        wrap_to_2pi (bool): Whether to wrap output Euler angles into [0, 2π). If\n            False, angles are returned in the default range (−π, π]. Defaults to\n            False.\n\n    Returns:\n        A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).\n\n    Reference:\n        https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles\n    \"\"\"\n    q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]\n    # roll (x-axis rotation)\n    sin_roll = 2.0 * (q_w * q_x + q_y * q_z)\n    cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y)\n    roll = torch.atan2(sin_roll, cos_roll)\n\n    # pitch (y-axis rotation)\n    sin_pitch = 2.0 * (q_w * q_y - q_z * q_x)\n    pitch = torch.where(\n        torch.abs(sin_pitch) >= 1,\n        torch.copysign(\n            torch.tensor(torch.pi / 2.0, device=quat.device, dtype=quat.dtype),\n            sin_pitch,\n        ),\n        torch.asin(sin_pitch),\n    )\n\n    # yaw (z-axis rotation)\n    sin_yaw = 2.0 * (q_w * q_z + q_x * q_y)\n    cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z)\n    yaw = torch.atan2(sin_yaw, cos_yaw)\n\n    if wrap_to_2pi:\n        return (\n            roll % (2 * torch.pi),\n            pitch % (2 * torch.pi),\n            yaw % (2 * torch.pi),\n        )\n    return roll, pitch, yaw\n\n\n@torch.jit.script\ndef yaw_quat(quat: torch.Tensor) -> torch.Tensor:\n    \"\"\"Extract the yaw component of a quaternion.\n\n    Args:\n        quat: The orientation in (w, x, y, z). 
Shape is (..., 4).\n\n    Returns:\n        A quaternion with only the yaw component.\n    \"\"\"\n    shape = quat.shape\n    quat_yaw = quat.view(-1, 4)\n    qw = quat_yaw[:, 0]\n    qx = quat_yaw[:, 1]\n    qy = quat_yaw[:, 2]\n    qz = quat_yaw[:, 3]\n    yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz))\n    quat_yaw = torch.zeros_like(quat_yaw)\n    quat_yaw[:, 3] = torch.sin(yaw / 2)\n    quat_yaw[:, 0] = torch.cos(yaw / 2)\n    quat_yaw = normalize(quat_yaw)\n    return quat_yaw.view(shape)\n\n\ndef standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert a unit quaternion to a standard form: one in which the real\n    part is non-negative.\n\n    Args:\n        quaternions: Quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n    \"\"\"\n    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)\n\n\n@torch.compiler.disable\ndef gaussian_kernel1d(\n    sigma: float, device: torch.device, dtype: torch.dtype\n) -> torch.Tensor:\n    if sigma <= 0.0:\n        raise ValueError(f\"Invalid sigma: {sigma}\")\n    radius = int(4.0 * sigma + 0.5)\n    x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)\n    kernel = torch.exp(-0.5 * (x / sigma).square())\n    return kernel / kernel.sum()\n\n\n@torch.compiler.disable\ndef gaussian_filter1d(x: torch.Tensor, sigma: float, dim: int) -> torch.Tensor:\n    if x.shape[dim] < 2:\n        return x\n    kernel = gaussian_kernel1d(sigma, device=x.device, dtype=x.dtype).reshape(\n        1, 1, -1\n    )\n    x_perm = x.movedim(dim, -1)\n    x_flat = x_perm.reshape(-1, 1, x_perm.shape[-1])\n    pad = kernel.shape[-1] // 2\n    x_flat = F.pad(x_flat, (pad, pad), mode=\"replicate\")\n    y = F.conv1d(x_flat, kernel)\n    y = y.reshape(x_perm.shape)\n    return y.movedim(-1, dim)\n\n\ndef smooth_time_series(\n    x: torch.Tensor, sigma: float, dim: int\n) -> torch.Tensor:\n    \"\"\"Gaussian smooth along a time dimension.\n\n    This is a thin wrapper around :func:`gaussian_filter1d` that treats\n    non-positive sigma as \"no-op\" for easy ablations.\n    \"\"\"\n    if sigma <= 0.0:\n        return x\n    return gaussian_filter1d(x, sigma=float(sigma), dim=int(dim))\n\n\n@torch.compiler.disable\ndef grad_t(x: torch.Tensor, dt: float) -> torch.Tensor:\n    if dt <= 0.0:\n        raise ValueError(f\"Invalid dt: {dt}\")\n    if x.shape[1] < 2:\n        return torch.zeros_like(x)\n    grad = torch.empty_like(x)\n    inv_dt = 1.0 / dt\n    grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt\n    grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt\n    if x.shape[1] > 2:\n        grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt)\n    return grad\n\n\ndef axis_angle_to_matrix(\n    angles: torch.Tensor, axes: torch.Tensor\n) -> torch.Tensor:\n    if axes.shape[-1] != 3:\n        raise ValueError(\"Axes must have shape (N, 3)\")\n    axis_norm = torch.linalg.norm(axes, dim=-1)\n    if torch.any(axis_norm <= 0):\n        raise ValueError(\"Axis vector has zero norm\")\n    axis = axes / axis_norm[:, None]\n    aat = torch.einsum(\"ni,nj->nij\", axis, axis)\n    skew = torch.zeros(\n        (axis.shape[0], 3, 3), device=axes.device, dtype=axes.dtype\n    )\n    ax, ay, az = axis[:, 0], axis[:, 1], axis[:, 2]\n    skew[:, 0, 1] = -az\n    skew[:, 0, 2] = ay\n    skew[:, 1, 0] = az\n    skew[:, 1, 2] = -ax\n    skew[:, 2, 0] = -ay\n    skew[:, 2, 1] = ax\n    sin_t = torch.sin(angles)\n    
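# Rodrigues' rotation formula:\n    # R = cos(t) * I + (1 - cos(t)) * (a a^T) + sin(t) * [a]_x\n    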
cos_t = torch.cos(angles)\n    eye = torch.eye(3, device=axes.device, dtype=axes.dtype)[None, None, None]\n    return (\n        cos_t[..., None, None] * eye\n        + (1.0 - cos_t)[..., None, None] * aat[None, None]\n        + sin_t[..., None, None] * skew[None, None]\n    )\n"
  },
  {
    "path": "holomotion/src/utils/unitree_g1_actuator_calculator.py",
    "content": "# Project HoloMotion\n#\n# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n# implied. See the License for the specific language governing\n# permissions and limitations under the License.\n\n\n\nfrom __future__ import annotations\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any\n\n\n@dataclass(frozen=True)\nclass MotorFamily:\n    name: str\n    armature: float\n    x1: float\n    x2: float\n    y1: float\n    y2: float\n    fs: float\n    fd: float\n    va: float = 0.01\n\n\n@dataclass(frozen=True)\nclass JointSpec:\n    joint_expr: str\n    motor: MotorFamily\n    effort_limit: float\n    velocity_limit: float\n    servo_scale: float = 1.0\n    envelope_scale: float = 1.0\n    friction_scale: float = 1.0\n\n\n# -----------------------------------------------------------------------------\n# Base actuator families\n# -----------------------------------------------------------------------------\n\nN5020_16 = MotorFamily(\n    name=\"N5020_16\",\n    armature=0.003609725,\n    x1=30.86,\n    x2=40.13,\n    y1=24.8,\n    y2=31.9,\n    fs=0.6,\n    fd=0.06,\n)\n\nN7520_14P3 = MotorFamily(\n    name=\"N7520_14P3\",\n    armature=0.010177520,\n    x1=22.63,\n    x2=35.52,\n    y1=71.0,\n    y2=83.3,\n    fs=1.6,\n    fd=0.16,\n)\n\nN7520_22P5 = MotorFamily(\n    name=\"N7520_22P5\",\n    armature=0.025101925,\n    x1=14.5,\n    x2=22.7,\n    y1=111.0,\n    y2=131.0,\n    fs=2.4,\n    fd=0.24,\n)\n\nW4010_25 = MotorFamily(\n    name=\"W4010_25\",\n    armature=0.00425,\n    x1=15.3,\n    x2=24.76,\n    y1=4.8,\n    y2=8.6,\n    fs=0.6,\n    fd=0.06,\n)\n\n\n# -----------------------------------------------------------------------------\n# Design constants\n# -----------------------------------------------------------------------------\n\nNATURAL_FREQ_HZ = 10.0\nDAMPING_RATIO = 2.0\n\n# Set this to your actual physics dt before running the generator.\nPHYSICS_DT = 1.0 / 200.0\n\n# Desired action delay budget: at most 2 * (1 / 50) = 0.04 s.\nMIN_DELAY_SECONDS = 0.0\nMAX_DELAY_SECONDS = 2.0 / 50.0\n\n\ndef seconds_to_delay_steps(delay_seconds: float, physics_dt: float) -> int:\n    return int(math.floor(delay_seconds / physics_dt + 1e-12))\n\n\nMIN_DELAY = seconds_to_delay_steps(MIN_DELAY_SECONDS, PHYSICS_DT)\nMAX_DELAY = seconds_to_delay_steps(MAX_DELAY_SECONDS, PHYSICS_DT)\n\n\n# -----------------------------------------------------------------------------\n# Single-group mapping\n#\n# ankle / waist:\n# - servo-side armature/gains are doubled\n# - torque envelope is NOT doubled\n# - friction is NOT doubled\n# -----------------------------------------------------------------------------\n\nALL_JOINT_SPECS: list[JointSpec] = [\n    # legs\n    JointSpec(\n        \".*_hip_yaw_joint\", N7520_14P3, effort_limit=88.0, velocity_limit=32.0\n    ),\n    JointSpec(\n        \".*_hip_roll_joint\",\n        N7520_22P5,\n        effort_limit=139.0,\n        velocity_limit=20.0,\n    ),\n    JointSpec(\n        \".*_hip_pitch_joint\",\n        N7520_14P3,\n        effort_limit=88.0,\n        
velocity_limit=32.0,\n    ),\n    JointSpec(\n        \".*_knee_joint\", N7520_22P5, effort_limit=139.0, velocity_limit=20.0\n    ),\n    # feet\n    JointSpec(\n        \".*_ankle_pitch_joint\",\n        N5020_16,\n        effort_limit=50.0,\n        velocity_limit=37.0,\n        servo_scale=2.0,\n    ),\n    JointSpec(\n        \".*_ankle_roll_joint\",\n        N5020_16,\n        effort_limit=50.0,\n        velocity_limit=37.0,\n        servo_scale=2.0,\n    ),\n    # waist\n    JointSpec(\n        \"waist_roll_joint\",\n        N5020_16,\n        effort_limit=50.0,\n        velocity_limit=37.0,\n        servo_scale=2.0,\n    ),\n    JointSpec(\n        \"waist_pitch_joint\",\n        N5020_16,\n        effort_limit=50.0,\n        velocity_limit=37.0,\n        servo_scale=2.0,\n    ),\n    JointSpec(\n        \"waist_yaw_joint\", N7520_14P3, effort_limit=88.0, velocity_limit=32.0\n    ),\n    # arms\n    JointSpec(\n        \".*_shoulder_pitch_joint\",\n        N5020_16,\n        effort_limit=25.0,\n        velocity_limit=37.0,\n    ),\n    JointSpec(\n        \".*_shoulder_roll_joint\",\n        N5020_16,\n        effort_limit=25.0,\n        velocity_limit=37.0,\n    ),\n    JointSpec(\n        \".*_shoulder_yaw_joint\",\n        N5020_16,\n        effort_limit=25.0,\n        velocity_limit=37.0,\n    ),\n    JointSpec(\n        \".*_elbow_joint\", N5020_16, effort_limit=25.0, velocity_limit=37.0\n    ),\n    JointSpec(\n        \".*_wrist_roll_joint\", N5020_16, effort_limit=25.0, velocity_limit=37.0\n    ),\n    JointSpec(\n        \".*_wrist_pitch_joint\", W4010_25, effort_limit=5.0, velocity_limit=22.0\n    ),\n    JointSpec(\n        \".*_wrist_yaw_joint\", W4010_25, effort_limit=5.0, velocity_limit=22.0\n    ),\n]\n\n\ndef compute_pd_gains(\n    armature: float, natural_freq_hz: float, damping_ratio: float\n) -> tuple[float, float]:\n    wn = natural_freq_hz * 2.0 * math.pi\n    kp = armature * wn * wn\n    kd = 2.0 * damping_ratio * armature * wn\n    return kp, kd\n\n\ndef fmt_float(x: float) -> str:\n    return format(float(x), \".12g\")\n\n\ndef fmt_value(value: Any, indent: int = 0) -> str:\n    sp = \" \" * indent\n\n    if isinstance(value, dict):\n        if not value:\n            return \"{}\"\n        lines = [\"{\"]\n        for k, v in value.items():\n            lines.append(f\"{sp}    {k!r}: {fmt_value(v, indent + 4)},\")\n        lines.append(f\"{sp}}}\")\n        return \"\\n\".join(lines)\n\n    if isinstance(value, list):\n        if not value:\n            return \"[]\"\n        lines = [\"[\"]\n        for item in value:\n            lines.append(f\"{sp}    {fmt_value(item, indent + 4)},\")\n        lines.append(f\"{sp}]\")\n        return \"\\n\".join(lines)\n\n    if isinstance(value, float):\n        return fmt_float(value)\n\n    return repr(value)\n\n\ndef build_single_group_cfg(\n    specs: list[JointSpec],\n    natural_freq_hz: float = NATURAL_FREQ_HZ,\n    damping_ratio: float = DAMPING_RATIO,\n    min_delay: int = MIN_DELAY,\n    max_delay: int = MAX_DELAY,\n) -> dict[str, Any]:\n    joint_names_expr = [spec.joint_expr for spec in specs]\n\n    effort_limit: dict[str, float] = {}\n    velocity_limit: dict[str, float] = {}\n    stiffness: dict[str, float] = {}\n    damping: dict[str, float] = {}\n    armature: dict[str, float] = {}\n    x1: dict[str, float] = {}\n    x2: dict[str, float] = {}\n    y1: dict[str, float] = {}\n    y2: dict[str, float] = {}\n    fs: dict[str, float] = {}\n    fd: dict[str, float] = {}\n    va: dict[str, float] = {}\n    
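# The loop below sets action_scale = 0.25 * effort_limit / kp, so a unit\n    # policy action commands roughly a quarter of the torque budget through\n    # the P-gain.\n    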
action_scale: dict[str, float] = {}\n\n    for spec in specs:\n        name = spec.joint_expr\n        servo_armature = spec.motor.armature * spec.servo_scale\n        kp, kd = compute_pd_gains(\n            servo_armature, natural_freq_hz, damping_ratio\n        )\n\n        effort_limit[name] = spec.effort_limit\n        velocity_limit[name] = spec.velocity_limit\n        stiffness[name] = kp\n        damping[name] = kd\n        armature[name] = servo_armature\n\n        x1[name] = spec.motor.x1\n        x2[name] = spec.motor.x2\n        y1[name] = spec.motor.y1 * spec.envelope_scale\n        y2[name] = spec.motor.y2 * spec.envelope_scale\n        fs[name] = spec.motor.fs * spec.friction_scale\n        fd[name] = spec.motor.fd * spec.friction_scale\n        va[name] = spec.motor.va\n\n        action_scale[name] = 0.25 * spec.effort_limit / kp\n\n    return {\n        \"joint_names_expr\": joint_names_expr,\n        \"min_delay\": min_delay,\n        \"max_delay\": max_delay,\n        \"effort_limit\": effort_limit,\n        \"velocity_limit\": velocity_limit,\n        \"stiffness\": stiffness,\n        \"damping\": damping,\n        \"armature\": armature,\n        \"friction\": 0.0,\n        \"dynamic_friction\": 0.0,\n        \"viscous_friction\": 0.0,\n        \"X1\": x1,\n        \"X2\": x2,\n        \"Y1\": y1,\n        \"Y2\": y2,\n        \"Fs\": fs,\n        \"Fd\": fd,\n        \"Va\": va,\n        \"action_scale\": action_scale,\n    }\n\n\ndef render_single_group_cfg(\n    cfg: dict[str, Any], group_name: str = \"all_joints\"\n) -> str:\n    ordered_keys = [\n        \"joint_names_expr\",\n        \"min_delay\",\n        \"max_delay\",\n        \"effort_limit\",\n        \"velocity_limit\",\n        \"stiffness\",\n        \"damping\",\n        \"armature\",\n        \"friction\",\n        \"dynamic_friction\",\n        \"viscous_friction\",\n        \"X1\",\n        \"X2\",\n        \"Y1\",\n        \"Y2\",\n        \"Fs\",\n        \"Fd\",\n        \"Va\",\n    ]\n\n    lines = [\n        \"from unitree_actuators import UnitreeActuatorCfg\",\n        \"\",\n        \"G1_HIFI_ACTUATORS = {\",\n        f\"    {group_name!r}: UnitreeActuatorCfg(\",\n    ]\n    for key in ordered_keys:\n        rendered = fmt_value(cfg[key], indent=8)\n        lines.append(f\"        {key}={rendered},\")\n    lines.append(\"    )\")\n    lines.append(\"}\")\n    lines.append(\"\")\n    lines.append(\"G1_HIFI_ACTION_SCALE = {\")\n    for joint_expr in cfg[\"joint_names_expr\"]:\n        lines.append(\n            f\"    {joint_expr!r}: {fmt_float(cfg['action_scale'][joint_expr])},\"\n        )\n    lines.append(\"}\")\n    return \"\\n\".join(lines)\n\n\ndef print_summary(cfg: dict[str, Any]) -> None:\n    print(\"# === SUMMARY ===\")\n    print(f\"# physics_dt = {fmt_float(PHYSICS_DT)}\")\n    print(f\"# min_delay  = {cfg['min_delay']}\")\n    print(f\"# max_delay  = {cfg['max_delay']}\")\n    print(\n        \"# joint_expr | effort_limit | velocity_limit | armature | kp | kd | \"\n        \"X1 | X2 | Y1 | Y2 | Fs | Fd | action_scale\"\n    )\n    for joint_expr in cfg[\"joint_names_expr\"]:\n        print(\n            f\"# {joint_expr} | \"\n            f\"{fmt_float(cfg['effort_limit'][joint_expr])} | \"\n            f\"{fmt_float(cfg['velocity_limit'][joint_expr])} | \"\n            f\"{fmt_float(cfg['armature'][joint_expr])} | \"\n            f\"{fmt_float(cfg['stiffness'][joint_expr])} | \"\n            f\"{fmt_float(cfg['damping'][joint_expr])} | \"\n            
f\"{fmt_float(cfg['X1'][joint_expr])} | \"\n            f\"{fmt_float(cfg['X2'][joint_expr])} | \"\n            f\"{fmt_float(cfg['Y1'][joint_expr])} | \"\n            f\"{fmt_float(cfg['Y2'][joint_expr])} | \"\n            f\"{fmt_float(cfg['Fs'][joint_expr])} | \"\n            f\"{fmt_float(cfg['Fd'][joint_expr])} | \"\n            f\"{fmt_float(cfg['action_scale'][joint_expr])}\"\n        )\n    print()\n\n\ndef main() -> None:\n    cfg = build_single_group_cfg(ALL_JOINT_SPECS)\n    print_summary(cfg)\n    print(render_single_group_cfg(cfg, group_name=\"all_joints\"))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "holomotion/tests/__init__.py",
    "content": ""
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=64.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"holomotion\"\nversion = \"1.2.0\"\ndescription = \"HoloMotion\"\nauthors = [\n    {name = \"Horizon Robotics\"},\n]\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\n\ndependencies = []\n\n\n[project.urls]\nHomepage = \"https://horizonrobotics.github.io/robot_lab/holomotion/ \"\nRepository = \"https://github.com/\"\n\n[tool.setuptools.packages.find]\nwhere = [\".\"]\ninclude = [\"holomotion*\"]\n\n\n[tool.ruff]\n\nexclude = [\n    # common\n    \".bzr\",\n    \".direnv\",\n    \".eggs\",\n    \".git\",\n    \".git-rewrite\",\n    \".hg\",\n    \".ipynb_checkpoints\",\n    \".mypy_cache\",\n    \".nox\",\n    \".pants.d\",\n    \".pyenv\",\n    \".pytest_cache\",\n    \".pytype\",\n    \".ruff_cache\",\n    \".svn\",\n    \".tox\",\n    \".venv\",\n    \".vscode\",\n    \"__pypackages__\",\n    \"_build\",\n    \"buck-out\",\n    \"build\",\n    \"dist\",\n    \"node_modules\",\n    \"site-packages\",\n    \"venv\",\n    # project\n    \"3rdparty/*\",\n    \"dummy/*\",\n    \"*.pyi\",\n    \"*_pb2.py\",\n]\n\n# Same as Black.\nline-length = 79\nindent-width = 4\n\n# required python 3.11\ntarget-version = \"py311\"\n\n[tool.ruff.lint]\n\nselect = [\n    \"E\",   # flake8-errors\n    \"F\",   # pyflake\n    \"I\",   # isort\n    \"B\",   # flake8-bugber\n    \"TID\", # flake8-tidy-imports\n    \"D\",   # pydocstyle\n    \"Q\",   # flake8-quotes\n    \"W\",   # flake8-warnings\n    \"N\",   # pep8-naming\n]\n\nignore = [\n    \"D104\",\n    \"D107\",\n    \"D202\",\n    \"D105\",\n    \"D100\",\n    \"D102\",\n    \"D103\",\n    \"D101\",\n    \"D301\",\n    \"F403\",\n    \"B904\", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... 
from None` to distinguish them from errors in exception handling\n    \"B028\", # No explicit `stacklevel` keyword argument found\n    \"D417\", # requires documentation for every function parameter.\n]\n\n[tool.ruff.lint.isort]\nknown-third-party = []\nno-lines-before = [\"future\", \"standard-library\"]\ncombine-as-imports = true\nforce-wrap-aliases = true\n\n[tool.ruff.lint.pydocstyle]\nconvention = \"google\"\n\n[tool.ruff.lint.flake8-tidy-imports]\n# Disallow all relative imports.\nban-relative-imports = \"all\"\n\n[tool.ruff.lint.flake8-quotes]\navoid-escape = false\n\n[tool.ruff.lint.mccabe]\nmax-complexity = 18\n\n[tool.ruff.lint.per-file-ignores]\n\"__init__.py\" = [\"TID252\", \"F401\"]\n\n[tool.ruff.lint.pep8-naming]\nclassmethod-decorators = [\n    # Allow Pydantic's `@validator` decorator to trigger class method treatment.\n    \"pydantic.validator\",\n    # Allow SQLAlchemy's dynamic decorators, like `@field.expression`, to trigger class method treatment.\n    \"declared_attr\",\n    \"expression\",\n    \"comparator\",\n]\nignore-names = [\n    # ruff default (https://docs.astral.sh/ruff/settings/#lintpep8-naming)\n    \"setUp\",\n    \"tearDown\",\n    \"setUpClass\",\n    \"tearDownClass\",\n    \"setUpModule\",\n    \"tearDownModule\",\n    \"asyncSetUp\",\n    \"asyncTearDown\",\n    \"setUpTestData\",\n    \"failureException\",\n    \"longMessage\",\n    \"maxDiff\",\n    # project\n    \"PROJECT_ROOT\",  # project test environment fixture\n    \"ROBO_ORCHARD_TEST_WORKSPACE\",  # project test fixture\n    \"F\",  # import torch.nn.functional as F\n]\n\n[tool.ruff.format]\n\n# Like Black, use double quotes for strings.\nquote-style = \"double\"\n\n# Like Black, indent with spaces, rather than tabs.\nindent-style = \"space\"\n\n# Like Black, respect magic trailing commas.\nskip-magic-trailing-comma = false\n\n# Like Black, automatically detect the appropriate line ending.\nline-ending = \"auto\"\n\ndocstring-code-format = true\n"
  },
  {
    "path": "tests/benchmark_legacy_onnx_attention.py",
    "content": "import sys\nimport tempfile\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport onnx\nimport onnxruntime\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.modules.network_modules import (\n    export_safe_scaled_dot_product_attention,\n)\n\n\nclass _RawAttentionModule(nn.Module):\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        mask: torch.Tensor,\n    ) -> torch.Tensor:\n        return F.scaled_dot_product_attention(\n            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False\n        )\n\n\nclass _SafeAttentionModule(nn.Module):\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        mask: torch.Tensor,\n    ) -> torch.Tensor:\n        return export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=mask,\n            dropout_p=0.0,\n            is_causal=False,\n        )\n\n\ndef _export_model(\n    module: nn.Module,\n    export_path: Path,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    mask: torch.Tensor,\n) -> None:\n    torch.onnx.export(\n        module.eval(),\n        (q, k, v, mask),\n        str(export_path),\n        opset_version=17,\n        input_names=[\"q\", \"k\", \"v\", \"mask\"],\n        output_names=[\"out\"],\n        dynamo=False,\n        verbose=False,\n    )\n\n\ndef _benchmark_session(\n    model_path: Path,\n    provider,\n    feed: dict[str, np.ndarray],\n    *,\n    warmup_iters: int = 50,\n    measure_iters: int = 300,\n) -> float:\n    providers = (\n        [\"CPUExecutionProvider\"]\n        if provider == \"CPUExecutionProvider\"\n        else [provider, \"CPUExecutionProvider\"]\n    )\n    session = onnxruntime.InferenceSession(\n        str(model_path),\n        providers=providers,\n    )\n    for _ in range(warmup_iters):\n        session.run([\"out\"], feed)\n    start = time.perf_counter()\n    for _ in range(measure_iters):\n        session.run([\"out\"], feed)\n    elapsed_s = time.perf_counter() - start\n    return (elapsed_s * 1000.0) / measure_iters\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n    q = torch.randn(4, 8, 1, 64)\n    k = torch.randn(4, 8, 32, 64)\n    v = torch.randn(4, 8, 32, 64)\n    valid_lengths = torch.tensor([32, 24, 16, 8], dtype=torch.int64)\n    mask = (\n        torch.arange(32, dtype=torch.int64)[None, :] < valid_lengths[:, None]\n    )\n    mask = mask[:, None, None, :]\n    feed = {\n        \"q\": q.numpy(),\n        \"k\": k.numpy(),\n        \"v\": v.numpy(),\n        \"mask\": mask.numpy(),\n    }\n\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        tmp_path = Path(tmp_dir)\n        raw_path = tmp_path / \"raw_attention.onnx\"\n        safe_path = tmp_path / \"safe_attention.onnx\"\n        _export_model(_RawAttentionModule(), raw_path, q, k, v, mask)\n        _export_model(_SafeAttentionModule(), safe_path, q, k, v, mask)\n        raw_model = onnx.load(str(raw_path))\n        safe_model = onnx.load(str(safe_path))\n        raw_ops = [node.op_type for node in raw_model.graph.node]\n        safe_ops = [node.op_type for node in safe_model.graph.node]\n        print(\n            \"Graph ops: \"\n            f\"raw_has_isnan={'IsNaN' in raw_ops}, \"\n            f\"safe_has_isnan={'IsNaN' in safe_ops}\"\n        )\n\n        cpu_raw = 
_benchmark_session(raw_path, \"CPUExecutionProvider\", feed)\n        cpu_safe = _benchmark_session(safe_path, \"CPUExecutionProvider\", feed)\n        print(\n            f\"CPUExecutionProvider: raw={cpu_raw:.4f} ms, \"\n            f\"safe={cpu_safe:.4f} ms, \"\n            f\"delta={(cpu_safe - cpu_raw) / cpu_raw * 100.0:.2f}%\"\n        )\n\n        if \"CUDAExecutionProvider\" in onnxruntime.get_available_providers():\n            cuda_raw = _benchmark_session(\n                raw_path, \"CUDAExecutionProvider\", feed\n            )\n            cuda_safe = _benchmark_session(\n                safe_path, \"CUDAExecutionProvider\", feed\n            )\n            print(\n                f\"CUDAExecutionProvider: raw={cuda_raw:.4f} ms, \"\n                f\"safe={cuda_safe:.4f} ms, \"\n                f\"delta={(cuda_safe - cuda_raw) / cuda_raw * 100.0:.2f}%\"\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/benchmark_moe_router_export.py",
    "content": "import re\nimport time\nfrom pathlib import Path\n\nimport torch\n\n\ndef _extract_int_setting(config_path: Path, key: str) -> int:\n    pattern = re.compile(rf\"^\\s*{re.escape(key)}:\\s*([0-9]+)\\s*$\")\n    for line in config_path.read_text().splitlines():\n        match = pattern.match(line)\n        if match:\n            return int(match.group(1))\n    raise ValueError(\n        f\"Could not find integer setting {key!r} in {config_path}\"\n    )\n\n\ndef _load_b0310_shape_config() -> dict[str, int]:\n    repo_root = Path(__file__).resolve().parents[1]\n    module_cfg = (\n        repo_root\n        / \"holomotion\"\n        / \"config\"\n        / \"modules\"\n        / \"motion_tracking\"\n        / \"tf_motrack_v3.yaml\"\n    )\n    return {\n        \"num_fine_experts\": _extract_int_setting(\n            module_cfg, \"num_fine_experts\"\n        ),\n        \"top_k\": _extract_int_setting(module_cfg, \"top_k\"),\n        \"max_ctx_len\": _extract_int_setting(module_cfg, \"max_ctx_len\"),\n    }\n\n\ndef _router_scores_training(\n    logits_fp32: torch.Tensor,\n    *,\n    top_k: int,\n    bias_fp32: torch.Tensor | None = None,\n) -> torch.Tensor:\n    choice_logits = (\n        logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32\n    )\n    _, topk_idx = torch.topk(choice_logits, top_k, dim=-1)\n    selected_logits = logits_fp32.gather(-1, topk_idx)\n    log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True)\n    selected_probs = torch.exp(selected_logits - log_z)\n    return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min(\n        1.0e-20\n    )\n\n\ndef _router_scores_export_safe(\n    logits_fp32: torch.Tensor,\n    *,\n    top_k: int,\n    bias_fp32: torch.Tensor | None = None,\n) -> torch.Tensor:\n    choice_logits = (\n        logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32\n    )\n    _, topk_idx = torch.topk(choice_logits, top_k, dim=-1)\n    selected_probs = torch.softmax(logits_fp32, dim=-1).gather(-1, topk_idx)\n    return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min(\n        1.0e-20\n    )\n\n\ndef _benchmark(\n    fn,\n    logits_fp32: torch.Tensor,\n    *,\n    top_k: int,\n    bias_fp32: torch.Tensor | None = None,\n    warmup_iters: int = 200,\n    measure_iters: int = 2000,\n) -> float:\n    is_cuda = logits_fp32.is_cuda\n    with torch.inference_mode():\n        for _ in range(warmup_iters):\n            fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32)\n        if is_cuda:\n            torch.cuda.synchronize(logits_fp32.device)\n        start = time.perf_counter()\n        for _ in range(measure_iters):\n            fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32)\n        if is_cuda:\n            torch.cuda.synchronize(logits_fp32.device)\n    elapsed_s = time.perf_counter() - start\n    return (elapsed_s * 1000.0) / measure_iters\n\n\ndef _run_case(\n    device: torch.device,\n    *,\n    case_name: str,\n    batch_size: int,\n    seq_len: int,\n    num_fine_experts: int,\n    top_k: int,\n) -> None:\n    seed = 0\n    generator = torch.Generator(device=\"cpu\")\n    generator.manual_seed(seed)\n    logits_fp32 = torch.randn(\n        batch_size,\n        seq_len,\n        num_fine_experts,\n        generator=generator,\n        dtype=torch.float32,\n    ).to(device)\n\n    eager_scores = _router_scores_training(logits_fp32, top_k=top_k)\n    export_scores = _router_scores_export_safe(logits_fp32, top_k=top_k)\n    max_abs_diff = (eager_scores - 
export_scores).abs().max().item()\n\n    eager_ms = _benchmark(\n        _router_scores_training,\n        logits_fp32,\n        top_k=top_k,\n    )\n    export_ms = _benchmark(\n        _router_scores_export_safe,\n        logits_fp32,\n        top_k=top_k,\n    )\n    delta_pct = ((export_ms - eager_ms) / eager_ms) * 100.0\n\n    print(\n        f\"{device.type}:{case_name}: \"\n        f\"shape=({batch_size}, {seq_len}, {num_fine_experts}), \"\n        f\"top_k={top_k}, \"\n        f\"training={eager_ms:.6f} ms, \"\n        f\"export_safe={export_ms:.6f} ms, \"\n        f\"delta={delta_pct:.2f}%, \"\n        f\"max_abs_diff={max_abs_diff:.3e}\"\n    )\n\n\ndef main() -> None:\n    shape_cfg = _load_b0310_shape_config()\n    num_fine_experts = shape_cfg[\"num_fine_experts\"]\n    top_k = shape_cfg[\"top_k\"]\n    max_ctx_len = shape_cfg[\"max_ctx_len\"]\n\n    cases = [\n        (\"single_step_export\", 1, 1),\n        (\"training_like_sequence\", 16, max_ctx_len),\n    ]\n    devices = [torch.device(\"cpu\")]\n    if torch.cuda.is_available():\n        devices.append(torch.device(\"cuda\"))\n\n    print(\n        \"Benchmarking MoE router formulas with \"\n        f\"num_fine_experts={num_fine_experts}, top_k={top_k}, \"\n        f\"max_ctx_len={max_ctx_len}\"\n    )\n    for device in devices:\n        for case_name, batch_size, seq_len in cases:\n            _run_case(\n                device,\n                case_name=case_name,\n                batch_size=batch_size,\n                seq_len=seq_len,\n                num_fine_experts=num_fine_experts,\n                top_k=top_k,\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/test_actor_export_config.py",
    "content": "import importlib\nimport sys\nimport unittest\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.modules.agent_modules import (\n    PPOTFActor,\n    PPOTFRefRouterActor,\n    PPOTFRefRouterSeqActor,\n    PPOTFRefRouterV3Actor,\n    _clone_module_for_cpu_export,\n)\nfrom holomotion.src.modules.network_modules import (\n    GroupedMoEBlock,\n    GroupedMoETransformerPolicy,\n    ReferenceRoutedGroupedMoETransformerPolicy,\n    ReferenceRoutedGroupedMoETransformerPolicyV2,\n    ReferenceRoutedGroupedMoETransformerPolicyV3,\n    export_safe_scaled_dot_product_attention,\n)\nfrom holomotion.src.utils.onnx_export import export_policy_to_onnx\nfrom tensordict import TensorDict\n\ntry:\n    onnx = importlib.import_module(\"onnx\")\n    torch = importlib.import_module(\"torch\")\n    nn = importlib.import_module(\"torch.nn\")\nexcept ModuleNotFoundError as exc:\n    raise unittest.SkipTest(\n        f\"Optional ONNX test dependency missing: {exc.name}\"\n    ) from exc\n\n\nclass _DummyTFModule(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.n_layers = 1\n        self.max_ctx_len = 4\n        self.n_kv_heads = 1\n        self.head_dim = 2\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        return obs[:, :2], past_key_values\n\n\nclass _DummyAttentionTFModule(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.n_layers = 1\n        self.max_ctx_len = 4\n        self.n_kv_heads = 1\n        self.head_dim = 2\n\n    def forward(\n        self,\n        obs: torch.Tensor,\n        past_key_values: torch.Tensor,\n        current_pos: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        batch_size = obs.shape[0]\n        max_len = past_key_values.shape[3]\n        valid_len = (current_pos + 1).clamp(max=max_len)\n        pos_idx = torch.arange(max_len, device=obs.device, dtype=torch.int64)\n        mask = (pos_idx[None, :] < valid_len[:, None])[:, None, None, :]\n\n        q = obs[:, :2].reshape(batch_size, 1, 1, 2)\n        k = torch.zeros(batch_size, 1, max_len, 2, device=obs.device)\n        v = torch.ones(batch_size, 1, max_len, 2, device=obs.device)\n        attn_out = export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=mask,\n            dropout_p=0.0,\n            is_causal=False,\n        )\n        actions = attn_out.reshape(batch_size, 2)\n        return actions, past_key_values\n\n\nclass _RecordingDeviceModule(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.current_device = \"cuda:0\"\n        self.to_calls = []\n\n    def to(self, device):\n        self.to_calls.append(str(device))\n        self.current_device = str(device)\n        return self\n\n\ndef _make_minimal_real_transformer_actor(\n    *,\n    n_layers: int = 1,\n    routing_score_fn: str = \"softmax\",\n    num_fine_experts: int = 1,\n    top_k: int = 1,\n    use_dynamic_bias: bool = False,\n    dense_layer_at_last: bool = False,\n    selected_expert_margin_to_unselected_enabled: bool = False,\n    selected_expert_margin_to_unselected_target: float = 0.0,\n) -> PPOTFActor:\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = GroupedMoETransformerPolicy(\n        
input_dim=6,\n        output_dim=2,\n        module_config_dict={\n            \"type\": \"GroupedMoETransformerPolicy\",\n            \"num_fine_experts\": num_fine_experts,\n            \"num_shared_experts\": 0,\n            \"top_k\": top_k,\n            \"obs_embed_mlp_hidden\": 8,\n            \"d_model\": 8,\n            \"n_layers\": n_layers,\n            \"n_heads\": 2,\n            \"n_kv_heads\": 1,\n            \"ff_mult\": 1.0,\n            \"ff_mult_dense\": 1,\n            \"attn_dropout\": 0.0,\n            \"mlp_dropout\": 0.0,\n            \"max_ctx_len\": 4,\n            \"dense_layer_at_last\": dense_layer_at_last,\n            \"use_gated_attn\": False,\n            \"use_qk_norm\": True,\n            \"routing_score_fn\": routing_score_fn,\n            \"use_dynamic_bias\": use_dynamic_bias,\n            \"selected_expert_margin_to_unselected\": {\n                \"enabled\": selected_expert_margin_to_unselected_enabled,\n                \"target\": selected_expert_margin_to_unselected_target,\n            },\n        },\n    )\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.assembler = SimpleNamespace(output_dim=6)\n    return actor\n\n\ndef _capture_moe_router_outputs(\n    monkeypatch,\n    *,\n    export_mode: bool,\n    top_k: int,\n    use_dynamic_bias: bool,\n    x: torch.Tensor,\n    router_weight: torch.Tensor,\n    router_x: torch.Tensor | None = None,\n    expert_bias: torch.Tensor | None = None,\n):\n    block = GroupedMoEBlock(\n        d_model=x.shape[-1],\n        n_heads=2,\n        n_kv_heads=1,\n        num_fine_experts=router_weight.shape[0],\n        num_shared_experts=1,\n        top_k=top_k,\n        ff_mult=1.0,\n        use_qk_norm=True,\n        use_gated_attn=False,\n        attn_dropout=0.0,\n        mlp_dropout=0.0,\n        use_dynamic_bias=use_dynamic_bias,\n        routing_score_fn=\"softmax\",\n    )\n    block.eval()\n\n    with torch.no_grad():\n        block.router.weight.copy_(router_weight)\n        if expert_bias is not None:\n            block.expert_bias.copy_(expert_bias)\n\n    captured = {}\n\n    def _fake_sparse_experts(\n        x_input: torch.Tensor,\n        topk_idx: torch.Tensor,\n        topk_scores: torch.Tensor,\n    ) -> torch.Tensor:\n        captured[\"topk_idx\"] = topk_idx.detach().clone()\n        captured[\"topk_scores\"] = topk_scores.detach().clone()\n        return torch.zeros_like(x_input)\n\n    monkeypatch.setattr(torch.onnx, \"is_in_onnx_export\", lambda: export_mode)\n    monkeypatch.setattr(block, \"_compute_sparse_experts\", _fake_sparse_experts)\n\n    captured[\"output\"] = block.compute_moe_ffn(x, router_x=router_x)\n    return captured\n\n\ndef _make_minimal_ref_router_actor() -> PPOTFRefRouterActor:\n    actor = PPOTFRefRouterActor.__new__(PPOTFRefRouterActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = ReferenceRoutedGroupedMoETransformerPolicy(\n        input_dim=8,\n        output_dim=2,\n        module_config_dict={\n            \"type\": \"ReferenceRoutedGroupedMoETransformerPolicy\",\n            \"num_fine_experts\": 4,\n            \"num_shared_experts\": 0,\n            \"top_k\": 2,\n            \"obs_embed_mlp_hidden\": 8,\n            \"router_embed_mlp_hidden\": 8,\n            \"router_input_dim\": 4,\n            \"router_feature_indices\": [0, 1, 4, 5],\n            \"d_model\": 8,\n            \"n_layers\": 2,\n            \"n_heads\": 2,\n            \"n_kv_heads\": 1,\n            
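# Small model dims and a short context keep the ONNX export test fast.\n            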
\"ff_mult\": 1.0,\n            \"ff_mult_dense\": 1,\n            \"attn_dropout\": 0.0,\n            \"mlp_dropout\": 0.0,\n            \"max_ctx_len\": 4,\n            \"use_gated_attn\": False,\n            \"use_qk_norm\": True,\n            \"routing_score_fn\": \"softmax\",\n            \"use_dynamic_bias\": False,\n        },\n    )\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.assembler = SimpleNamespace(output_dim=8)\n    return actor\n\n\ndef _make_ref_router_v2_obs_schema() -> dict:\n    return {\n        \"flattened_obs\": {\n            \"seq_len\": 1,\n            \"terms\": [\n                \"unified/actor_ref_gravity_projection_cur\",\n                \"unified/actor_ref_base_linvel_cur\",\n                \"unified/actor_ref_base_angvel_cur\",\n                \"unified/actor_ref_dof_pos_cur\",\n                \"unified/actor_projected_gravity\",\n                \"unified/actor_rel_robot_root_ang_vel\",\n                \"unified/actor_dof_vel\",\n                \"unified/actor_dof_pos\",\n                \"unified/actor_ref_root_height_cur\",\n                \"unified/actor_last_action\",\n            ],\n        },\n        \"flattened_obs_fut\": {\n            \"seq_len\": 5,\n            \"terms\": [\n                \"unified/actor_ref_gravity_projection_fut\",\n                \"unified/actor_ref_base_linvel_fut\",\n                \"unified/actor_ref_base_angvel_fut\",\n                \"unified/actor_ref_dof_pos_fut\",\n                \"unified/actor_ref_root_height_fut\",\n            ],\n        },\n    }\n\n\ndef _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict:\n    shape = list(batch_size)\n    actor_fut_shape = shape + [5]\n    unified = TensorDict(\n        {\n            \"actor_ref_gravity_projection_cur\": torch.randn(*shape, 3),\n            \"actor_ref_base_linvel_cur\": torch.randn(*shape, 3),\n            \"actor_ref_base_angvel_cur\": torch.randn(*shape, 3),\n            \"actor_ref_dof_pos_cur\": torch.randn(*shape, 2),\n            \"actor_projected_gravity\": torch.randn(*shape, 3),\n            \"actor_rel_robot_root_ang_vel\": torch.randn(*shape, 3),\n            \"actor_dof_vel\": torch.randn(*shape, 3),\n            \"actor_dof_pos\": torch.randn(*shape, 3),\n            \"actor_ref_root_height_cur\": torch.randn(*shape, 1),\n            \"actor_last_action\": torch.randn(*shape, 2),\n            \"actor_ref_gravity_projection_fut\": torch.randn(\n                *actor_fut_shape, 3\n            ),\n            \"actor_ref_base_linvel_fut\": torch.randn(*actor_fut_shape, 3),\n            \"actor_ref_base_angvel_fut\": torch.randn(*actor_fut_shape, 3),\n            \"actor_ref_dof_pos_fut\": torch.randn(*actor_fut_shape, 2),\n            \"actor_ref_root_height_fut\": torch.randn(*actor_fut_shape, 1),\n        },\n        batch_size=shape,\n    )\n    return TensorDict({\"unified\": unified}, batch_size=shape)\n\n\ndef _make_minimal_ref_router_v2_actor() -> PPOTFRefRouterSeqActor:\n    obs_schema = _make_ref_router_v2_obs_schema()\n    obs_example = _make_ref_router_v2_obs([2])\n    return PPOTFRefRouterSeqActor(\n        obs_schema=obs_schema,\n        module_config_dict={\n            \"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV2\",\n            \"num_fine_experts\": 4,\n            \"num_shared_experts\": 0,\n            \"top_k\": 2,\n            \"obs_embed_mlp_hidden\": 8,\n            \"d_model\": 8,\n            \"n_layers\": 2,\n     
       \"n_heads\": 2,\n            \"n_kv_heads\": 1,\n            \"ff_mult\": 1.0,\n            \"ff_mult_dense\": 1,\n            \"attn_dropout\": 0.0,\n            \"mlp_dropout\": 0.0,\n            \"max_ctx_len\": 4,\n            \"use_gated_attn\": False,\n            \"use_qk_norm\": True,\n            \"routing_score_fn\": \"softmax\",\n            \"use_dynamic_bias\": False,\n            \"ref_hist_n_layers\": 1,\n            \"ref_future_conv_channels\": 8,\n            \"ref_future_conv_layers\": 2,\n            \"ref_future_conv_kernel_size\": 3,\n            \"ref_future_conv_stride\": 2,\n            \"obs_norm\": {\"enabled\": False},\n            \"output_dim\": 2,\n        },\n        num_actions=2,\n        init_noise_std=0.2,\n        obs_example=obs_example,\n    )\n\n\ndef _make_minimal_ref_router_v3_actor() -> PPOTFRefRouterV3Actor:\n    obs_schema = _make_ref_router_v2_obs_schema()\n    obs_example = _make_ref_router_v2_obs([2])\n    return PPOTFRefRouterV3Actor(\n        obs_schema=obs_schema,\n        module_config_dict={\n            \"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV3\",\n            \"num_fine_experts\": 4,\n            \"num_shared_experts\": 0,\n            \"top_k\": 2,\n            \"obs_embed_mlp_hidden\": 8,\n            \"d_model\": 8,\n            \"n_layers\": 2,\n            \"n_heads\": 2,\n            \"n_kv_heads\": 1,\n            \"ff_mult\": 1.0,\n            \"ff_mult_dense\": 1,\n            \"attn_dropout\": 0.0,\n            \"mlp_dropout\": 0.0,\n            \"max_ctx_len\": 4,\n            \"use_gated_attn\": False,\n            \"use_qk_norm\": True,\n            \"routing_score_fn\": \"softmax\",\n            \"use_dynamic_bias\": False,\n            \"ref_hist_n_layers\": 1,\n            \"router_future_hidden_dim\": 12,\n            \"router_layer_proj_hidden_dim\": 10,\n            \"obs_norm\": {\"enabled\": False},\n            \"output_dim\": 2,\n        },\n        num_actions=2,\n        init_noise_std=0.2,\n        obs_example=obs_example,\n    )\n\n\ndef test_export_policy_to_onnx_uses_opset_17(monkeypatch, tmp_path):\n    captured = {}\n\n    class _FakeActor:\n        def eval(self):\n            return self\n\n        def export_onnx(\n            self,\n            *,\n            onnx_path,\n            opset_version,\n            use_kv_cache=True,\n        ):\n            captured[\"onnx_path\"] = onnx_path\n            captured[\"opset_version\"] = opset_version\n            captured[\"use_kv_cache\"] = use_kv_cache\n            return str(onnx_path)\n\n    actor = _FakeActor()\n    algo = SimpleNamespace(\n        actor=actor,\n        critic=SimpleNamespace(eval=lambda: None),\n        accelerator=SimpleNamespace(unwrap_model=lambda model: model),\n        env=SimpleNamespace(_env=object()),\n    )\n\n    monkeypatch.setattr(\n        \"holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion\",\n        lambda env, onnx_path: None,\n    )\n\n    checkpoint_path = tmp_path / \"model.pt\"\n    checkpoint_path.write_bytes(b\"\")\n    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)\n\n    assert captured[\"opset_version\"] == 17\n    assert captured[\"use_kv_cache\"] is False\n\n\ndef test_export_policy_to_onnx_restores_training_mode(monkeypatch, tmp_path):\n    class _FakeActor:\n        def __init__(self):\n            self.training = True\n\n        def eval(self):\n            self.training = False\n            return self\n\n        def train(self, mode: bool = 
True):\n            self.training = bool(mode)\n            return self\n\n        def export_onnx(\n            self,\n            *,\n            onnx_path,\n            opset_version,\n            use_kv_cache=True,\n        ):\n            return str(onnx_path)\n\n    class _FakeCritic:\n        def __init__(self):\n            self.training = True\n\n        def eval(self):\n            self.training = False\n            return self\n\n        def train(self, mode: bool = True):\n            self.training = bool(mode)\n            return self\n\n    actor = _FakeActor()\n    critic = _FakeCritic()\n    algo = SimpleNamespace(\n        actor=actor,\n        critic=critic,\n        accelerator=SimpleNamespace(unwrap_model=lambda model: model),\n        env=SimpleNamespace(_env=object()),\n    )\n\n    monkeypatch.setattr(\n        \"holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion\",\n        lambda env, onnx_path: None,\n    )\n\n    checkpoint_path = tmp_path / \"model.pt\"\n    checkpoint_path.write_bytes(b\"\")\n    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)\n\n    assert actor.training is True\n    assert critic.training is True\n\n\ndef test_clone_module_for_cpu_export_does_not_move_live_module(monkeypatch):\n    module = _RecordingDeviceModule()\n\n    monkeypatch.setattr(\n        \"holomotion.src.modules.agent_modules._module_device\",\n        lambda _: torch.device(\"cuda:0\"),\n    )\n\n    cloned = _clone_module_for_cpu_export(module)\n\n    assert module.to_calls == []\n    assert module.current_device == \"cuda:0\"\n    assert isinstance(cloned, _RecordingDeviceModule)\n    assert cloned is not module\n\n\ndef test_ppotf_actor_export_uses_legacy_torchscript(monkeypatch, tmp_path):\n    export_calls = []\n\n    def _fake_export(*args, **kwargs):\n        export_calls.append(kwargs)\n\n    monkeypatch.setattr(torch.onnx, \"export\", _fake_export)\n\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = _DummyTFModule()\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.assembler = SimpleNamespace(output_dim=6)\n\n    out_path = tmp_path / \"policy.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    assert len(export_calls) == 1\n    assert export_calls[0][\"opset_version\"] == 17\n    assert export_calls[0][\"dynamo\"] is False\n\n\ndef test_ppotf_actor_export_onnx_avoids_isnan(tmp_path):\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = _DummyAttentionTFModule()\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.assembler = SimpleNamespace(output_dim=2)\n\n    out_path = tmp_path / \"policy.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert \"IsNaN\" not in op_types\n\n\ndef test_ppotf_real_transformer_export_onnx_avoids_isnan(tmp_path):\n    actor = _make_minimal_real_transformer_actor()\n\n    out_path = tmp_path / \"policy_real_tf.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    op_types = 
[node.op_type for node in model.graph.node]\n\n    assert \"IsNaN\" not in op_types\n\n\ndef test_ppotf_real_moe_transformer_export_reaches_router_ops(tmp_path):\n    actor = _make_minimal_real_transformer_actor(\n        n_layers=2,\n        num_fine_experts=4,\n        top_k=2,\n        routing_score_fn=\"softmax\",\n    )\n\n    out_path = tmp_path / \"policy_real_moe_tf.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert \"TopK\" in op_types\n\n\ndef test_ppotf_real_moe_transformer_export_exposes_routing_outputs(tmp_path):\n    actor = _make_minimal_real_transformer_actor(\n        n_layers=3,\n        num_fine_experts=4,\n        top_k=2,\n        routing_score_fn=\"softmax\",\n    )\n\n    out_path = tmp_path / \"policy_real_moe_tf_outputs.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    output_names = [value.name for value in model.graph.output]\n\n    assert output_names == [\n        \"actions\",\n        \"present_key_values\",\n        \"moe_layer_1_expert_indices\",\n        \"moe_layer_1_expert_logits\",\n        \"moe_layer_2_expert_indices\",\n        \"moe_layer_2_expert_logits\",\n    ]\n\n\ndef test_ppotf_real_moe_transformer_export_dense_last_uses_actual_moe_indices(\n    tmp_path,\n):\n    actor = _make_minimal_real_transformer_actor(\n        n_layers=4,\n        num_fine_experts=4,\n        top_k=2,\n        routing_score_fn=\"softmax\",\n        dense_layer_at_last=True,\n    )\n\n    out_path = tmp_path / \"policy_real_moe_tf_dense_last_outputs.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    output_names = [value.name for value in model.graph.output]\n\n    assert output_names == [\n        \"actions\",\n        \"present_key_values\",\n        \"moe_layer_1_expert_indices\",\n        \"moe_layer_1_expert_logits\",\n        \"moe_layer_2_expert_indices\",\n        \"moe_layer_2_expert_logits\",\n    ]\n\n\ndef test_ppotf_real_moe_transformer_export_avoids_reduce_log_sum_exp(\n    tmp_path,\n):\n    actor = _make_minimal_real_transformer_actor(\n        n_layers=2,\n        num_fine_experts=4,\n        top_k=2,\n        routing_score_fn=\"softmax\",\n    )\n\n    out_path = tmp_path / \"policy_real_moe_tf_no_rlse.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert \"ReduceLogSumExp\" not in op_types\n\n\ndef test_export_safe_moe_router_matches_training_scores_for_topk1(monkeypatch):\n    x = torch.tensor([[[1.0, -0.5, 0.25, 2.0]]], dtype=torch.float32)\n    router_weight = torch.tensor(\n        [\n            [0.1, 0.3, -0.2, 0.5],\n            [0.2, -0.4, 0.1, 0.7],\n            [-0.3, 0.6, 0.2, -0.1],\n            [0.4, 0.1, -0.5, 0.2],\n        ],\n        dtype=torch.float32,\n    )\n\n    eager = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x,\n        router_weight=router_weight,\n    )\n    export = 
_capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=True,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x,\n        router_weight=router_weight,\n    )\n\n    assert torch.equal(export[\"topk_idx\"], eager[\"topk_idx\"])\n    torch.testing.assert_close(\n        export[\"topk_scores\"],\n        eager[\"topk_scores\"],\n        atol=1.0e-6,\n        rtol=1.0e-5,\n    )\n\n\ndef test_export_safe_moe_router_matches_training_scores_with_dynamic_bias(\n    monkeypatch,\n):\n    x = torch.tensor(\n        [\n            [[0.2, -1.0, 0.5, 1.1], [0.4, 0.3, -0.7, 0.9]],\n            [[-0.6, 0.8, 1.0, -0.2], [0.1, -0.4, 0.6, 0.7]],\n        ],\n        dtype=torch.float32,\n    )\n    router_weight = torch.tensor(\n        [\n            [0.2, -0.1, 0.5, 0.3],\n            [-0.4, 0.7, 0.2, 0.1],\n            [0.6, 0.2, -0.3, 0.4],\n            [0.1, 0.5, 0.4, -0.6],\n        ],\n        dtype=torch.float32,\n    )\n    expert_bias = torch.tensor([0.0, 0.4, -0.3, 0.2], dtype=torch.float32)\n\n    eager = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=2,\n        use_dynamic_bias=True,\n        x=x,\n        router_weight=router_weight,\n        expert_bias=expert_bias,\n    )\n    export = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=True,\n        top_k=2,\n        use_dynamic_bias=True,\n        x=x,\n        router_weight=router_weight,\n        expert_bias=expert_bias,\n    )\n\n    assert torch.equal(export[\"topk_idx\"], eager[\"topk_idx\"])\n    torch.testing.assert_close(\n        export[\"topk_scores\"],\n        eager[\"topk_scores\"],\n        atol=1.0e-6,\n        rtol=1.0e-5,\n    )\n\n\ndef test_grouped_moe_router_x_keeps_topk_when_main_input_changes(monkeypatch):\n    router_weight = torch.tensor(\n        [\n            [1.0, 0.0, 0.0, 0.0],\n            [0.0, 1.0, 0.0, 0.0],\n            [0.0, 0.0, 1.0, 0.0],\n        ],\n        dtype=torch.float32,\n    )\n    router_x = torch.tensor([[[4.0, 1.0, 0.0, 0.0]]], dtype=torch.float32)\n    x_a = torch.tensor([[[0.0, 0.5, 1.0, 1.5]]], dtype=torch.float32)\n    x_b = torch.tensor([[[3.0, -2.0, -1.0, 6.0]]], dtype=torch.float32)\n\n    out_a = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x_a,\n        router_x=router_x,\n        router_weight=router_weight,\n    )\n    out_b = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x_b,\n        router_x=router_x,\n        router_weight=router_weight,\n    )\n\n    assert torch.equal(out_a[\"topk_idx\"], out_b[\"topk_idx\"])\n    assert not torch.allclose(out_a[\"output\"], out_b[\"output\"])\n\n\ndef test_grouped_moe_router_x_changes_topk_when_router_input_changes(\n    monkeypatch,\n):\n    router_weight = torch.tensor(\n        [\n            [1.0, 0.0, 0.0, 0.0],\n            [0.0, 1.0, 0.0, 0.0],\n            [0.0, 0.0, 1.0, 0.0],\n        ],\n        dtype=torch.float32,\n    )\n    x = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32)\n    router_x_a = torch.tensor([[[3.0, 0.0, 0.0, 0.0]]], dtype=torch.float32)\n    router_x_b = torch.tensor([[[0.0, 5.0, 0.0, 0.0]]], dtype=torch.float32)\n\n    out_a = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x,\n        router_x=router_x_a,\n        
router_weight=router_weight,\n    )\n    out_b = _capture_moe_router_outputs(\n        monkeypatch,\n        export_mode=False,\n        top_k=1,\n        use_dynamic_bias=False,\n        x=x,\n        router_x=router_x_b,\n        router_weight=router_weight,\n    )\n\n    assert not torch.equal(out_a[\"topk_idx\"], out_b[\"topk_idx\"])\n\n\ndef test_ref_router_actor_export_keeps_single_obs_input_and_reaches_moe(\n    tmp_path,\n):\n    actor = _make_minimal_ref_router_actor()\n\n    out_path = tmp_path / \"policy_ref_router.onnx\"\n    PPOTFActor.export_onnx(\n        actor,\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    input_names = [value.name for value in model.graph.input]\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert input_names == [\"obs\", \"past_key_values\", \"step_idx\"]\n    assert \"TopK\" in op_types\n\n\ndef test_ref_router_v2_actor_export_keeps_single_obs_input_and_reaches_moe(\n    tmp_path,\n):\n    actor = _make_minimal_ref_router_v2_actor()\n\n    assert actor.onnx_past_key_values_shape(batch_size=1) == (\n        3,\n        2,\n        1,\n        4,\n        1,\n        4,\n    )\n\n    out_path = tmp_path / \"policy_ref_router_v2.onnx\"\n    actor.export_onnx(\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    input_names = [value.name for value in model.graph.input]\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert input_names == [\"obs\", \"past_key_values\", \"step_idx\"]\n    assert \"TopK\" in op_types\n\n\ndef test_ref_router_v3_actor_export_keeps_single_obs_input_and_reaches_moe(\n    tmp_path,\n):\n    actor = _make_minimal_ref_router_v3_actor()\n\n    assert actor.onnx_past_key_values_shape(batch_size=1) == (\n        3,\n        2,\n        1,\n        4,\n        1,\n        4,\n    )\n\n    out_path = tmp_path / \"policy_ref_router_v3.onnx\"\n    actor.export_onnx(\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    input_names = [value.name for value in model.graph.input]\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert input_names == [\"obs\", \"past_key_values\", \"step_idx\"]\n    assert \"TopK\" in op_types\n\n\ndef test_real_transformer_actor_export_supports_selected_expert_margin(\n    tmp_path,\n):\n    actor = _make_minimal_real_transformer_actor(\n        n_layers=2,\n        num_fine_experts=4,\n        top_k=2,\n        selected_expert_margin_to_unselected_enabled=True,\n        selected_expert_margin_to_unselected_target=0.4,\n    )\n\n    out_path = tmp_path / \"policy_selected_expert_margin.onnx\"\n    actor.export_onnx(\n        out_path,\n        opset_version=17,\n        use_kv_cache=True,\n    )\n\n    model = onnx.load(str(out_path))\n    input_names = [value.name for value in model.graph.input]\n    op_types = [node.op_type for node in model.graph.node]\n\n    assert input_names == [\"obs\", \"past_key_values\", \"step_idx\"]\n    assert \"TopK\" in op_types\n"
  },
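  {
    "path": "not_for_commit/sketches/softmax_topk_router_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nShows the routing pattern the MoE export tests above look for: a softmax\nrouter followed by torch.topk, which lowers to a plain ONNX TopK node and\nneeds no IsNaN or ReduceLogSumExp ops. All names here (softmax_topk_route,\nrouter_weight) are local assumptions, not the PPOTFActor API.\n\"\"\"\n\nimport torch\n\n\ndef softmax_topk_route(x, router_weight, top_k):\n    # One logit per expert for every token.\n    logits = x @ router_weight.t()\n    # Export-friendly scoring: softmax, then a single TopK selection.\n    scores = torch.softmax(logits, dim=-1)\n    topk_scores, topk_idx = torch.topk(scores, k=top_k, dim=-1)\n    # Renormalize the selected scores so they sum to 1 per token.\n    topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)\n    return topk_scores, topk_idx\n\n\nif __name__ == \"__main__\":\n    x = torch.randn(1, 3, 4)\n    router_weight = torch.randn(4, 4)  # 4 experts over a 4-dim input\n    scores, idx = softmax_topk_route(x, router_weight, top_k=2)\n    assert scores.shape == (1, 3, 2)\n    assert idx.shape == (1, 3, 2)\n    assert torch.allclose(scores.sum(dim=-1), torch.ones(1, 3))\n"
  },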
  {
    "path": "tests/test_algo_base_iteration_logging.py",
    "content": "from pathlib import Path\nimport sys\nfrom unittest import mock\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.algo.algo_base import BaseOnpolicyRL\n\n\ndef test_log_iteration_uses_checkpoint_start_for_total_iterations():\n    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)\n    algo.log_dir = \"/tmp/holomotion-test\"\n    algo.gpu_world_size = 1\n    algo.num_steps_per_env = 8\n    algo.num_envs = 16\n    algo.num_learning_iterations = 10\n    algo.current_learning_iteration = 123\n    algo.rewbuffer = []\n    algo.lenbuffer = []\n    algo._aggregate_episode_log_metrics = lambda: {}\n    algo._get_additional_log_metrics = lambda: {}\n    algo.algo_logger = mock.Mock()\n\n    BaseOnpolicyRL._log_iteration(\n        algo,\n        it=123,\n        loss_dict={\"policy\": 1.5},\n        collection_time=2.0,\n        learn_time=2.0,\n    )\n\n    algo.algo_logger.log_iteration.assert_called_once()\n    _, kwargs = algo.algo_logger.log_iteration.call_args\n    assert kwargs[\"step\"] == 123\n    assert kwargs[\"total_learning_iterations\"] == 133\n    assert kwargs[\"metrics\"][\"0-Train/iteration\"] == 123\n    assert kwargs[\"metrics\"][\"0-Train/iterations_total\"] == 133\n"
  },
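  {
    "path": "not_for_commit/sketches/iteration_total_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nSpells out the bookkeeping tests/test_algo_base_iteration_logging.py\nasserts: when training resumes from a checkpoint at iteration start, the\nreported total is start + num_learning_iterations (123 + 10 = 133), so the\nremaining run is counted from the checkpoint rather than from zero.\ntotal_learning_iterations is a local stand-in, not the algo's method.\n\"\"\"\n\n\ndef total_learning_iterations(start: int, remaining: int) -> int:\n    # The resumed run trains `remaining` more iterations on top of the\n    # checkpoint's starting iteration.\n    return start + remaining\n\n\nif __name__ == \"__main__\":\n    assert total_learning_iterations(123, 10) == 133\n"
  },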
  {
    "path": "tests/test_build_quantization_dataset.py",
    "content": "import importlib.util\nfrom pathlib import Path\n\n\ndef _load_module():\n    module_path = (\n        Path(__file__).resolve().parents[1]\n        / \"not_for_commit\"\n        / \"build_quantization_dataset.py\"\n    )\n    spec = importlib.util.spec_from_file_location(\n        \"build_quantization_dataset\", module_path\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec.loader is not None\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_allocate_sample_counts_normalizes_weights_and_matches_total():\n    module = _load_module()\n\n    counts = module.allocate_sample_counts(\n        {\n            \"AMASS\": 0.2,\n            \"lafan1\": 0.2,\n            \"MotionMillion-ft\": 0.4,\n            \"pico_train\": 0.05,\n        },\n        17,\n    )\n\n    assert counts == {\n        \"AMASS\": 4,\n        \"lafan1\": 4,\n        \"MotionMillion-ft\": 8,\n        \"pico_train\": 1,\n    }\n    assert sum(counts.values()) == 17\n\n\ndef test_build_quantization_dataset_creates_symlinks(tmp_path):\n    module = _load_module()\n    npz_root = tmp_path / \"retargeted\"\n    npz_root.mkdir()\n\n    for dataset_name, clip_count in {\"AMASS\": 3, \"lafan1\": 2}.items():\n        dataset_dir = npz_root / dataset_name\n        dataset_dir.mkdir()\n        for clip_idx in range(clip_count):\n            (dataset_dir / f\"clip_{clip_idx}.npz\").write_text(\n                f\"{dataset_name}-{clip_idx}\", encoding=\"utf-8\"\n            )\n\n    output_dir = module.build_quantization_dataset(\n        npz_root=npz_root,\n        dataset_ratios={\"AMASS\": 2.0, \"lafan1\": 1.0},\n        num_clips=3,\n        seed=0,\n        current_date=\"20260324\",\n    )\n\n    assert output_dir == npz_root / \"20260324_quant_dataset\"\n    created_links = sorted(output_dir.iterdir())\n    assert len(created_links) == 3\n    assert all(link.is_symlink() for link in created_links)\n    assert {link.name.split(\"__\", 1)[0] for link in created_links} == {\n        \"AMASS\",\n        \"lafan1\",\n    }\n    for link in created_links:\n        assert link.resolve().is_file()\n        assert link.suffix == \".npz\"\n\n\ndef test_build_quantization_dataset_caps_each_dataset_at_available_clips(\n    tmp_path,\n):\n    module = _load_module()\n    npz_root = tmp_path / \"retargeted\"\n    npz_root.mkdir()\n\n    for dataset_name, clip_count in {\"AMASS\": 1, \"lafan1\": 5}.items():\n        dataset_dir = npz_root / dataset_name\n        dataset_dir.mkdir()\n        for clip_idx in range(clip_count):\n            (dataset_dir / f\"clip_{clip_idx}.npz\").write_text(\n                f\"{dataset_name}-{clip_idx}\", encoding=\"utf-8\"\n            )\n\n    output_dir = module.build_quantization_dataset(\n        npz_root=npz_root,\n        dataset_ratios={\"AMASS\": 2.0, \"lafan1\": 1.0},\n        num_clips=6,\n        seed=0,\n        current_date=\"20260324\",\n    )\n\n    created_links = sorted(output_dir.iterdir())\n\n    assert len(created_links) == 3\n    assert sum(link.name.startswith(\"AMASS__\") for link in created_links) == 1\n    assert sum(link.name.startswith(\"lafan1__\") for link in created_links) == 2\n"
  },
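  {
    "path": "not_for_commit/sketches/allocate_sample_counts_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nReimplements the behaviour tests/test_build_quantization_dataset.py pins\ndown for allocate_sample_counts: weights are normalized, per-dataset shares\nare floored, and leftover clips go to the largest fractional remainders so\nthe counts sum exactly to the requested total. `allocate` is a local\nstand-in under these assumptions, not the module's actual implementation.\n\"\"\"\n\nimport math\n\n\ndef allocate(weights: dict[str, float], total: int) -> dict[str, int]:\n    norm = sum(weights.values())\n    shares = {name: total * w / norm for name, w in weights.items()}\n    counts = {name: math.floor(share) for name, share in shares.items()}\n    # Hand out the remaining clips by descending fractional part.\n    leftover = total - sum(counts.values())\n    by_fraction = sorted(\n        shares, key=lambda name: shares[name] - counts[name], reverse=True\n    )\n    for name in by_fraction[:leftover]:\n        counts[name] += 1\n    return counts\n\n\nif __name__ == \"__main__\":\n    counts = allocate(\n        {\n            \"AMASS\": 0.2,\n            \"lafan1\": 0.2,\n            \"MotionMillion-ft\": 0.4,\n            \"pico_train\": 0.05,\n        },\n        17,\n    )\n    assert counts == {\n        \"AMASS\": 4,\n        \"lafan1\": 4,\n        \"MotionMillion-ft\": 8,\n        \"pico_train\": 1,\n    }\n    assert sum(counts.values()) == 17\n"
  },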
  {
    "path": "tests/test_cache_curriculum_sampler.py",
    "content": "import json\nfrom types import SimpleNamespace\n\nimport torch\n\nimport holomotion.src.training.h5_dataloader as h5_dataloader\nfrom holomotion.src.algo.ppo import PPO\nfrom holomotion.src.training.h5_dataloader import (\n    MotionClipBatchCache,\n    ClipBatch,\n    PrioritizedInfiniteSampler,\n)\n\n\ndef _update_sampler(\n    sampler: PrioritizedInfiniteSampler,\n    completion_rates: list[float],\n    *,\n    swap_index: int,\n) -> bool:\n    num_windows = len(completion_rates)\n    window_indices = torch.arange(num_windows, dtype=torch.long)\n    completion_rate_means = torch.tensor(completion_rates, dtype=torch.float32)\n    mpkpe_signal_means = torch.zeros(num_windows, dtype=torch.float32)\n    counts = torch.ones(num_windows, dtype=torch.float32)\n    return sampler.maybe_update_from_observations(\n        window_indices=window_indices,\n        mpkpe_signal_means=mpkpe_signal_means,\n        completion_rate_means=completion_rate_means,\n        counts=counts,\n        swap_index=swap_index,\n    )\n\n\ndef _update_sampler_subset(\n    sampler: PrioritizedInfiniteSampler,\n    *,\n    window_indices: list[int],\n    completion_rates: list[float],\n    swap_index: int,\n    counts: list[float] | None = None,\n) -> bool:\n    if counts is None:\n        counts = [1.0] * len(window_indices)\n    window_index_tensor = torch.tensor(window_indices, dtype=torch.long)\n    completion_rate_tensor = torch.tensor(\n        completion_rates, dtype=torch.float32\n    )\n    count_tensor = torch.tensor(counts, dtype=torch.float32)\n    mpkpe_signal_means = torch.zeros(len(window_indices), dtype=torch.float32)\n    return sampler.maybe_update_from_observations(\n        window_indices=window_index_tensor,\n        mpkpe_signal_means=mpkpe_signal_means,\n        completion_rate_means=completion_rate_tensor,\n        counts=count_tensor,\n        swap_index=swap_index,\n    )\n\n\nclass _ChunkLimitedSampler:\n    def __init__(self, *, max_query_size: int) -> None:\n        self.max_query_size = int(max_query_size)\n        self.state_version = 7\n        self.query_sizes: list[int] = []\n\n    def _checked_indices(self, window_indices: torch.Tensor) -> torch.Tensor:\n        indices = window_indices.to(dtype=torch.long).reshape(-1)\n        size = int(indices.numel())\n        self.query_sizes.append(size)\n        if size > self.max_query_size:\n            raise AssertionError(\n                f\"expected chunked access <= {self.max_query_size}, got {size}\"\n            )\n        return indices\n\n    def get_scores_for_indices(\n        self, window_indices: torch.Tensor\n    ) -> torch.Tensor:\n        indices = self._checked_indices(window_indices)\n        return indices.to(dtype=torch.float32)\n\n    def get_window_state_for_indices(\n        self, window_indices: torch.Tensor\n    ) -> dict[str, torch.Tensor]:\n        indices = self._checked_indices(window_indices)\n        count = int(indices.numel())\n        return {\n            \"ema_completion_rate\": torch.zeros(count, dtype=torch.float32),\n            \"completion_rate_rel_improve\": torch.zeros(\n                count, dtype=torch.float32\n            ),\n            \"selection_count\": indices + 10,\n            \"seen\": torch.zeros(count, dtype=torch.bool),\n            \"in_prioritized_pool\": torch.zeros(count, dtype=torch.bool),\n        }\n\n    def get_pool_statistics(self) -> dict[str, float]:\n        return {\n            \"prioritized_pool_size\": 0.0,\n            \"prioritized_pool_mean_score\": 
0.0,\n            \"uniform_pool_mean_score\": 0.0,\n            \"entered_prioritized_pool_count\": 0.0,\n            \"exited_prioritized_pool_count\": 0.0,\n        }\n\n\nclass _FakeTrainDataset:\n    def __init__(self, windows: list[SimpleNamespace]) -> None:\n        self.windows = windows\n\n    def __len__(self) -> int:\n        return len(self.windows)\n\n    def close(self) -> None:\n        return\n\n\ndef test_sampler_uses_configured_uniform_ratio_immediately():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=8,\n        batch_size=10,\n        seed=0,\n        p_a_ratio=0.3,\n    )\n\n    assert sampler._pool_batch_sizes() == (3, 7)\n\n\ndef test_completion_rate_relative_improvement_alone_drives_scores():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=3,\n        batch_size=2,\n        seed=0,\n        p_a_ratio=0.5,\n        ema_alpha_signal=0.5,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    assert _update_sampler(sampler, [0.2, 0.2, 0.2], swap_index=1)\n    assert _update_sampler(sampler, [0.8, 0.2, 0.2], swap_index=2)\n\n    scores = sampler.get_scores_for_indices(torch.arange(3, dtype=torch.long))\n\n    assert scores[0].item() > 0.0\n    assert scores[1].item() == 0.0\n    assert scores[2].item() == 0.0\n\n\ndef test_sampler_weights_progress_by_remaining_difficulty():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=2,\n        batch_size=2,\n        seed=0,\n        p_a_ratio=0.5,\n        ema_alpha_signal=1.0,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    assert _update_sampler(sampler, [0.1, 0.8], swap_index=1)\n    assert _update_sampler(sampler, [0.2, 0.9], swap_index=2)\n\n    scores = sampler.get_scores_for_indices(torch.arange(2, dtype=torch.long))\n\n    assert scores[0].item() > scores[1].item()\n\n\ndef test_sampler_tracks_cumulative_selection_counts():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=4,\n        batch_size=2,\n        seed=0,\n        p_a_ratio=0.5,\n    )\n\n    iterator = iter(sampler)\n    selected_indices = [next(iterator) for _ in range(4)]\n    state = sampler.get_window_state_for_indices(torch.arange(4))\n    selection_count = state[\"selection_count\"]\n    expected_count = torch.bincount(\n        torch.tensor(selected_indices, dtype=torch.long), minlength=4\n    )\n\n    assert int(selection_count.sum().item()) == 4\n    assert torch.equal(selection_count, expected_count)\n\n\ndef test_low_completion_plateau_drops_from_prioritized_replay():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=3,\n        batch_size=3,\n        seed=0,\n        p_a_ratio=1.0 / 3.0,\n        ema_alpha_signal=1.0,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    assert _update_sampler(sampler, [0.1, 0.2, 0.2], swap_index=1)\n    assert _update_sampler(sampler, [0.4, 0.2, 0.2], swap_index=2)\n    assert _update_sampler(sampler, [0.4, 0.8, 0.8], swap_index=3)\n\n    state = sampler.get_window_state_for_indices(\n        torch.arange(3, dtype=torch.long)\n    )\n    assert not bool(state[\"in_prioritized_pool\"][0].item())\n\n    generator = torch.Generator().manual_seed(0)\n    uniform_pick = sampler._sample_uniform_indices(generator, 3)\n    assert 0 in uniform_pick.tolist()\n\n\ndef test_prioritized_windows_persist_beyond_immediate_batch():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=6,\n        batch_size=4,\n        seed=0,\n        p_a_ratio=0.5,\n        ema_alpha_signal=1.0,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    assert 
_update_sampler_subset(\n        sampler,\n        window_indices=[0, 1],\n        completion_rates=[0.2, 0.2],\n        swap_index=1,\n    )\n    assert _update_sampler_subset(\n        sampler,\n        window_indices=[0, 1],\n        completion_rates=[0.8, 0.7],\n        swap_index=2,\n    )\n    assert _update_sampler_subset(\n        sampler,\n        window_indices=[2, 3],\n        completion_rates=[0.3, 0.3],\n        swap_index=3,\n    )\n\n    state = sampler.get_window_state_for_indices(torch.tensor([0, 1]))\n    assert torch.equal(\n        state[\"in_prioritized_pool\"],\n        torch.tensor([True, True], dtype=torch.bool),\n    )\n\n\ndef test_sampler_reports_pool_means_and_membership_churn():\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=4,\n        batch_size=4,\n        seed=0,\n        p_a_ratio=0.5,\n        ema_alpha_signal=0.5,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    assert _update_sampler(sampler, [0.2, 0.2, 0.2, 0.2], swap_index=1)\n    assert _update_sampler(sampler, [0.9, 0.8, 0.2, 0.2], swap_index=2)\n    assert _update_sampler(sampler, [0.1, 0.2, 0.9, 0.8], swap_index=3)\n\n    next(iter(sampler))\n    stats = sampler.get_pool_statistics()\n\n    assert stats is not None\n    assert set(stats) == {\n        \"prioritized_pool_size\",\n        \"prioritized_pool_mean_score\",\n        \"uniform_pool_mean_score\",\n        \"entered_prioritized_pool_count\",\n        \"exited_prioritized_pool_count\",\n    }\n    assert stats[\"prioritized_pool_size\"] == 2.0\n    assert stats[\"entered_prioritized_pool_count\"] == 2.0\n    assert stats[\"exited_prioritized_pool_count\"] == 2.0\n    assert (\n        stats[\"prioritized_pool_mean_score\"] > stats[\"uniform_pool_mean_score\"]\n    )\n\n\ndef test_sampler_hot_path_avoids_full_dataset_temporaries(monkeypatch):\n    sampler = PrioritizedInfiniteSampler(\n        dataset_len=1_000_000,\n        batch_size=8,\n        seed=0,\n        p_a_ratio=0.5,\n        ema_alpha_signal=1.0,\n        ema_alpha_rel_improve=1.0,\n    )\n\n    orig_zeros = h5_dataloader.torch.zeros\n    orig_arange = h5_dataloader.torch.arange\n    orig_randperm = h5_dataloader.torch.randperm\n\n    def _guard_size(arg) -> int | None:\n        if isinstance(arg, int):\n            return arg\n        if (\n            isinstance(arg, tuple)\n            and len(arg) == 1\n            and isinstance(arg[0], int)\n        ):\n            return arg[0]\n        return None\n\n    def guarded_zeros(*args, **kwargs):\n        size = _guard_size(args[0]) if args else None\n        if size == sampler._ds_len:\n            raise AssertionError(\"full-dataset zeros in hot path\")\n        return orig_zeros(*args, **kwargs)\n\n    def guarded_arange(*args, **kwargs):\n        if args and args[0] == sampler._ds_len:\n            raise AssertionError(\"full-dataset arange in hot path\")\n        return orig_arange(*args, **kwargs)\n\n    def guarded_randperm(*args, **kwargs):\n        if args and args[0] == sampler._ds_len:\n            raise AssertionError(\"full-dataset randperm in hot path\")\n        return orig_randperm(*args, **kwargs)\n\n    monkeypatch.setattr(h5_dataloader.torch, \"zeros\", guarded_zeros)\n    monkeypatch.setattr(h5_dataloader.torch, \"arange\", guarded_arange)\n    monkeypatch.setattr(h5_dataloader.torch, \"randperm\", guarded_randperm)\n\n    assert _update_sampler_subset(\n        sampler,\n        window_indices=[5, 25, 125, 625],\n        completion_rates=[0.1, 0.2, 0.3, 0.4],\n        swap_index=1,\n    
)\n    batch_indices = sampler._sample_batch_indices(\n        torch.Generator().manual_seed(0)\n    )\n    assert int(batch_indices.numel()) == 8\n\n\ndef test_ppo_logs_only_core_curriculum_metrics():\n    algo = PPO.__new__(PPO)\n    algo.actor_learning_rate = 1.0e-4\n    algo.critic_learning_rate = 2.0e-4\n    algo._last_update_metrics = {}\n    algo.command_name = \"ref_motion\"\n    algo._get_mean_policy_std = lambda: torch.tensor(0.0)\n\n    cache = SimpleNamespace(\n        swap_index=12,\n        cache_curriculum_pool_statistics=lambda: {\n            \"prioritized_pool_size\": 2.0,\n            \"prioritized_pool_mean_score\": 0.8,\n            \"uniform_pool_mean_score\": 0.1,\n            \"entered_prioritized_pool_count\": 1.0,\n            \"exited_prioritized_pool_count\": 1.0,\n        },\n    )\n    motion_cmd = SimpleNamespace(_motion_cache=cache)\n    algo.env = SimpleNamespace(\n        _env=SimpleNamespace(\n            command_manager=SimpleNamespace(\n                get_term=lambda name: motion_cmd,\n            )\n        )\n    )\n\n    metrics = algo._get_additional_log_metrics()\n\n    assert metrics[\"1-Perf/Cache/swap_index\"] == 12.0\n    assert metrics[\"1-Perf/Cache/prioritized_pool_size\"] == 2.0\n    assert metrics[\"1-Perf/Cache/prioritized_pool_mean_score\"] == 0.8\n    assert metrics[\"1-Perf/Cache/uniform_pool_mean_score\"] == 0.1\n    assert metrics[\"1-Perf/Cache/entered_prioritized_pool_count\"] == 1.0\n    assert metrics[\"1-Perf/Cache/exited_prioritized_pool_count\"] == 1.0\n    assert (\n        \"1-Perf/Cache/curriculum_probability_coefficient_of_variation\"\n        not in metrics\n    )\n    assert (\n        \"1-Perf/Cache/curriculum_max_probability_over_uniform\" not in metrics\n    )\n    assert \"1-Perf/Cache/uniform_floor_ratio\" not in metrics\n\n\ndef test_cache_curriculum_dumps_on_scheduled_swap_even_without_state_update():\n    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)\n    cache._datasets = {}\n    cache._cache_curriculum_sampler = SimpleNamespace(\n        maybe_update_from_observations=lambda **kwargs: False,\n    )\n    dumped_swaps = []\n    cache._maybe_dump_cache_curriculum_scores_json = (\n        lambda *, swap_index: dumped_swaps.append(swap_index)\n    )\n\n    updated = cache.update_cache_curriculum(\n        window_indices=torch.tensor([0], dtype=torch.long),\n        mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32),\n        completion_rate_means=torch.tensor([0.0], dtype=torch.float32),\n        counts=torch.tensor([1.0], dtype=torch.float32),\n        swap_index=5,\n    )\n\n    assert updated is False\n    assert dumped_swaps == [5]\n\n\ndef test_update_cache_curriculum_refreshes_prefetched_batch_when_state_changes():\n    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)\n    cache._datasets = {}\n    cache._cache_curriculum_sampler = SimpleNamespace(\n        maybe_update_from_observations=lambda **kwargs: True,\n    )\n    cache._cache_curriculum_dump_enabled = False\n    cache._next_batch = ClipBatch(\n        tensors={},\n        lengths=torch.tensor([1], dtype=torch.long),\n        motion_keys=[\"stale\"],\n        raw_motion_keys=[\"stale\"],\n        window_indices=torch.tensor([0], dtype=torch.long),\n        max_frame_length=1,\n    )\n    refreshed_batch = ClipBatch(\n        tensors={},\n        lengths=torch.tensor([1], dtype=torch.long),\n        motion_keys=[\"fresh\"],\n        raw_motion_keys=[\"fresh\"],\n        window_indices=torch.tensor([1], 
dtype=torch.long),\n        max_frame_length=1,\n    )\n    cache._fetch_next_batch = lambda: refreshed_batch\n\n    updated = cache.update_cache_curriculum(\n        window_indices=torch.tensor([0], dtype=torch.long),\n        mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32),\n        completion_rate_means=torch.tensor([0.0], dtype=torch.float32),\n        counts=torch.tensor([1.0], dtype=torch.float32),\n        swap_index=5,\n    )\n\n    assert updated is True\n    assert cache._next_batch.motion_keys == [\"fresh\"]\n\n\ndef test_cache_curriculum_whole_window_dump_streams_rows_in_chunks(\n    tmp_path,\n):\n    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)\n    cache._datasets = {\n        \"train\": _FakeTrainDataset(\n            [\n                SimpleNamespace(\n                    raw_motion_key=f\"raw_{idx}\",\n                    motion_key=f\"motion_{idx}\",\n                    start=idx,\n                    length=idx + 1,\n                )\n                for idx in range(5)\n            ]\n        )\n    }\n    sampler = _ChunkLimitedSampler(max_query_size=2)\n    cache._cache_curriculum_sampler = sampler\n    cache._cache_curriculum_dump_enabled = True\n    cache._cache_curriculum_dump_every_swaps = 1\n    cache._cache_curriculum_dump_chunk_size = 2\n    cache._cache_curriculum_last_dump_swap = -1\n    cache._cache_curriculum_dump_dir = tmp_path\n    cache._sampler_rank = 3\n\n    cache._maybe_dump_cache_curriculum_scores_json(swap_index=1)\n\n    output_path = tmp_path / \"whole_window_scores_rank_0003_swap_000001.json\"\n    payload = json.loads(output_path.read_text())\n\n    assert output_path.exists()\n    assert payload[\"swap_index\"] == 1\n    assert payload[\"rank\"] == 3\n    assert payload[\"sampler_state_version\"] == 7\n    assert payload[\"num_windows\"] == 5\n    assert len(payload[\"rows\"]) == 5\n    assert payload[\"rows\"][0][\"window_index\"] == 0\n    assert payload[\"rows\"][0][\"selection_count\"] == 10\n    assert payload[\"rows\"][-1][\"window_index\"] == 4\n    assert payload[\"rows\"][-1][\"selection_count\"] == 14\n    assert \"probability\" not in payload[\"rows\"][0]\n    assert max(sampler.query_sizes) == 2\n"
  },
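  {
    "path": "not_for_commit/sketches/prioritized_pool_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nMirrors two behaviours tests/test_cache_curriculum_sampler.py relies on:\n(1) a batch is split between a prioritized pool and a uniform pool by\n`p_a_ratio`, and (2) a window's score grows with the relative improvement\nof its completion rate, weighted by the remaining difficulty\n(1 - completion), so plateaued windows decay to zero. All names here are\nlocal assumptions, not the PrioritizedInfiniteSampler API or its exact\nscoring formula.\n\"\"\"\n\n\ndef pool_batch_sizes(batch_size: int, p_a_ratio: float) -> tuple[int, int]:\n    # Prioritized-pool share first; the remainder is sampled uniformly.\n    prioritized = round(batch_size * p_a_ratio)\n    return prioritized, batch_size - prioritized\n\n\ndef rel_improve_score(prev_rate: float, new_rate: float) -> float:\n    # Improvement scaled by the room left to improve: a window that jumps\n    # while still far from completion outranks one polishing an easy clip.\n    improvement = max(new_rate - prev_rate, 0.0)\n    return improvement * (1.0 - new_rate)\n\n\nif __name__ == \"__main__\":\n    assert pool_batch_sizes(10, 0.3) == (3, 7)\n    assert rel_improve_score(0.1, 0.2) > rel_improve_score(0.8, 0.9)\n    assert rel_improve_score(0.2, 0.2) == 0.0\n"
  },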
  {
    "path": "tests/test_domain_rand_config_builder.py",
    "content": "import importlib.util\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType\n\n\nMODULE_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_domain_rand.py\"\n)\n\n\nclass _DummyEventTermCfg:\n    def __init__(self, **kwargs):\n        self.__dict__.update(kwargs)\n\n\nclass _DummySceneEntityCfg:\n    def __init__(self, *args, **kwargs):\n        self.args = args\n        self.kwargs = kwargs\n\n    def resolve(self, _scene):\n        return None\n\n\ndef _load_domain_rand_module(monkeypatch):\n    isaaclab = ModuleType(\"isaaclab\")\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n    isaaclab_utils_math = ModuleType(\"isaaclab.utils.math\")\n\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedEnv = object\n    isaaclab_envs_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_envs_mdp.events = ModuleType(\"isaaclab.envs.mdp.events\")\n    isaaclab_envs_mdp.events._randomize_prop_by_op = (\n        lambda *args, **kwargs: None\n    )\n\n    isaaclab_managers = ModuleType(\"isaaclab.managers\")\n    isaaclab_managers.SceneEntityCfg = _DummySceneEntityCfg\n    isaaclab_managers.EventTermCfg = _DummyEventTermCfg\n\n    isaaclab.envs = isaaclab_envs\n    isaaclab.assets = isaaclab_assets\n    isaaclab.utils = isaaclab_utils\n    isaaclab_envs.mdp = isaaclab_envs_mdp\n    isaaclab_utils.math = isaaclab_utils_math\n\n    for name, module in {\n        \"isaaclab\": isaaclab,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.math\": isaaclab_utils_math,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.envs.mdp\": isaaclab_envs_mdp,\n        \"isaaclab.envs.mdp.events\": isaaclab_envs_mdp.events,\n        \"isaaclab.managers\": isaaclab_managers,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    module_name = \"_test_domain_rand_builder\"\n    spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_build_domain_rand_config_skips_non_event_metadata(monkeypatch):\n    module = _load_domain_rand_module(monkeypatch)\n\n    events_cfg = module.build_domain_rand_config(\n        {\n            \"erfi\": {\n                \"enabled\": True,\n                \"rfi_probability\": 0.5,\n                \"rfi_lim\": 0.1,\n                \"randomize_rfi_lim\": True,\n                \"rfi_lim_range\": [0.5, 1.5],\n                \"rao_lim\": 0.1,\n            },\n            \"action_delay\": {\n                \"enabled\": True,\n                \"min_delay\": 1,\n                \"max_delay\": 3,\n            },\n            \"motion_init_perturb\": {\n                \"root_pose_perturb_range\": {\"x\": [-0.1, 0.1]}\n            },\n            \"obs_noise\": {\"actor_dof_pos\": {\"n_min\": -0.01, \"n_max\": 0.01}},\n            \"default_dof_pos_bias\": {\n                \"mode\": \"startup\",\n                \"params\": {\n                    \"joint_names\": [\".*\"],\n                    \"pos_distribution_params\": [-0.01, 0.01],\n          
          \"operation\": \"add\",\n                    \"distribution\": \"uniform\",\n                },\n            },\n        }\n    )\n\n    assert hasattr(events_cfg, \"default_dof_pos_bias\")\n    assert events_cfg.default_dof_pos_bias.mode == \"startup\"\n    assert not hasattr(events_cfg, \"erfi\")\n    assert not hasattr(events_cfg, \"action_delay\")\n    assert not hasattr(events_cfg, \"motion_init_perturb\")\n    assert not hasattr(events_cfg, \"obs_noise\")\n"
  },
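  {
    "path": "not_for_commit/sketches/event_term_filter_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nShows the filtering tests/test_domain_rand_config_builder.py asserts: only\nentries that look like event terms (they carry `mode` and `params` keys)\nbecome attributes of the events config, while metadata blocks such as\n`erfi` or `obs_noise` are consumed elsewhere and skipped here. build_events\nis a local stand-in, not the builder's actual implementation.\n\"\"\"\n\nfrom types import SimpleNamespace\n\n\ndef build_events(cfg: dict) -> SimpleNamespace:\n    events = SimpleNamespace()\n    for name, value in cfg.items():\n        # Keep only event-term-shaped entries; skip plain metadata dicts.\n        if isinstance(value, dict) and {\"mode\", \"params\"} <= value.keys():\n            setattr(events, name, SimpleNamespace(**value))\n    return events\n\n\nif __name__ == \"__main__\":\n    events = build_events(\n        {\n            \"erfi\": {\"enabled\": True},\n            \"default_dof_pos_bias\": {\"mode\": \"startup\", \"params\": {}},\n        }\n    )\n    assert events.default_dof_pos_bias.mode == \"startup\"\n    assert not hasattr(events, \"erfi\")\n"
  },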
  {
    "path": "tests/test_eval_mujoco_action_delay.py",
    "content": "from collections import deque\n\nimport numpy as np\nfrom omegaconf import OmegaConf\n\nimport holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim\n\n\ndef test_action_delay_cfg_defaults_to_disabled_episode():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create({})\n\n    max_delay_step, delay_type = evaluator._get_action_delay_cfg()\n\n    assert max_delay_step == 0\n    assert delay_type == \"episode\"\n\n\ndef test_action_delay_cfg_rejects_invalid_delay_type():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create(\n        {\n            \"policy_action_delay_step\": 2,\n            \"action_delay_type\": \"frame\",\n        }\n    )\n\n    try:\n        evaluator._get_action_delay_cfg()\n    except ValueError as exc:\n        assert \"action_delay_type\" in str(exc)\n    else:\n        raise AssertionError(\"Expected ValueError for invalid delay type.\")\n\n\ndef test_apply_action_delay_passthrough_when_disabled():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.policy_action_delay_step = 0\n    evaluator.action_delay_type = \"episode\"\n    evaluator._policy_action_delay_buffer = deque(maxlen=1)\n    evaluator._current_policy_action_delay_step = 0\n\n    delayed = evaluator._apply_action_delay(\n        np.array([1.0, -1.0], dtype=np.float32)\n    )\n\n    np.testing.assert_allclose(\n        delayed, np.array([1.0, -1.0], dtype=np.float32)\n    )\n\n\ndef test_apply_action_delay_episode_reuses_single_sample(monkeypatch):\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.policy_action_delay_step = 1\n    evaluator.action_delay_type = \"episode\"\n\n    calls = []\n\n    def fake_randint(low, high):\n        calls.append((low, high))\n        return 1\n\n    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, \"randint\", fake_randint)\n\n    evaluator._reset_action_delay_randomization()\n\n    first = evaluator._apply_action_delay(np.array([1.0], dtype=np.float32))\n    second = evaluator._apply_action_delay(np.array([2.0], dtype=np.float32))\n    third = evaluator._apply_action_delay(np.array([3.0], dtype=np.float32))\n\n    assert calls == [(0, 2)]\n    assert evaluator._current_policy_action_delay_step == 1\n    np.testing.assert_allclose(first, np.array([1.0], dtype=np.float32))\n    np.testing.assert_allclose(second, np.array([1.0], dtype=np.float32))\n    np.testing.assert_allclose(third, np.array([2.0], dtype=np.float32))\n\n\ndef test_apply_action_delay_step_resamples_each_policy_step(monkeypatch):\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.policy_action_delay_step = 2\n    evaluator.action_delay_type = \"step\"\n    evaluator._policy_action_delay_buffer = deque(maxlen=3)\n    evaluator._current_policy_action_delay_step = 0\n\n    sampled_delays = iter([2, 0, 1])\n    calls = []\n\n    def fake_randint(low, high):\n        calls.append((low, high))\n        return next(sampled_delays)\n\n    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, \"randint\", fake_randint)\n\n    first = evaluator._apply_action_delay(np.array([1.0], dtype=np.float32))\n    second = 
evaluator._apply_action_delay(np.array([2.0], dtype=np.float32))\n    third = evaluator._apply_action_delay(np.array([3.0], dtype=np.float32))\n\n    assert calls == [(0, 3), (0, 3), (0, 3)]\n    assert evaluator._current_policy_action_delay_step == 1\n    np.testing.assert_allclose(first, np.array([1.0], dtype=np.float32))\n    np.testing.assert_allclose(second, np.array([2.0], dtype=np.float32))\n    np.testing.assert_allclose(third, np.array([2.0], dtype=np.float32))\n"
  },
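  {
    "path": "not_for_commit/sketches/action_delay_buffer_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nShows the deque pattern tests/test_eval_mujoco_action_delay.py exercises:\nwith a fixed delay of k policy steps, a buffer of maxlen k + 1 returns its\noldest entry, so the first k steps replay the earliest action and step t\nsees the action from step t - k. DelayBuffer is a local stand-in for the\nevaluator internals, not the MujocoEvaluator API.\n\"\"\"\n\nfrom collections import deque\n\nimport numpy as np\n\n\nclass DelayBuffer:\n    def __init__(self, delay_steps: int) -> None:\n        self._buf = deque(maxlen=delay_steps + 1)\n\n    def apply(self, action: np.ndarray) -> np.ndarray:\n        self._buf.append(action.copy())\n        # Until the buffer fills, the oldest entry is the very first action.\n        return self._buf[0]\n\n\nif __name__ == \"__main__\":\n    buf = DelayBuffer(delay_steps=1)\n    out = [buf.apply(np.array([v], dtype=np.float32)) for v in (1.0, 2.0, 3.0)]\n    np.testing.assert_allclose(out[0], [1.0])\n    np.testing.assert_allclose(out[1], [1.0])\n    np.testing.assert_allclose(out[2], [2.0])\n"
  },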
  {
    "path": "tests/test_eval_mujoco_action_ema.py",
    "content": "import numpy as np\nfrom omegaconf import OmegaConf\n\nimport holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim\n\n\ndef test_action_ema_filter_cfg_reads_erfi_settings():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create(\n        {\n            \"robot\": {\n                \"actuators\": {\n                    \"actuator_type\": \"unitree_erfi\",\n                    \"ema_filter_enabled\": True,\n                    \"ema_filter_alpha\": 0.37,\n                }\n            },\n        }\n    )\n\n    enabled, alpha = evaluator._get_action_ema_filter_cfg()\n\n    assert enabled is True\n    assert alpha == 0.37\n\n\ndef test_action_ema_filter_defaults_to_disabled_for_non_erfi():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create(\n        {\n            \"robot\": {\n                \"actuators\": {\n                    \"actuator_type\": \"unitree\",\n                    \"ema_filter_enabled\": True,\n                    \"ema_filter_alpha\": 0.37,\n                }\n            },\n        }\n    )\n\n    enabled, alpha = evaluator._get_action_ema_filter_cfg()\n\n    assert enabled is False\n    assert alpha == 1.0\n\n\ndef test_apply_action_ema_filter_uses_previous_filtered_action():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.action_ema_filter_enabled = True\n    evaluator.action_ema_filter_alpha = 0.25\n    evaluator._filtered_actions_onnx = None\n\n    first = evaluator._apply_action_ema_filter(\n        np.array([1.0, -1.0], dtype=np.float32)\n    )\n    second = evaluator._apply_action_ema_filter(\n        np.array([3.0, 1.0], dtype=np.float32)\n    )\n\n    np.testing.assert_allclose(first, np.array([1.0, -1.0], dtype=np.float32))\n    np.testing.assert_allclose(second, np.array([1.5, -0.5], dtype=np.float32))\n\n\ndef test_reset_action_ema_filter_clears_state():\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator._filtered_actions_onnx = np.array([1.0], dtype=np.float32)\n\n    evaluator._reset_action_ema_filter()\n\n    assert evaluator._filtered_actions_onnx is None\n"
  },
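  {
    "path": "not_for_commit/sketches/action_ema_filter_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nShows the first-order EMA filter tests/test_eval_mujoco_action_ema.py\nasserts: y_t = alpha * x_t + (1 - alpha) * y_{t-1}, with a passthrough\nfirst sample and a reset that clears the stored state. EmaActionFilter is\na local stand-in, not the MujocoEvaluator API.\n\"\"\"\n\nimport numpy as np\n\n\nclass EmaActionFilter:\n    def __init__(self, alpha: float) -> None:\n        self.alpha = float(alpha)\n        self._prev = None\n\n    def apply(self, action: np.ndarray) -> np.ndarray:\n        if self._prev is None:\n            # First sample passes through unfiltered.\n            self._prev = action.astype(np.float32)\n        else:\n            self._prev = self.alpha * action + (1.0 - self.alpha) * self._prev\n        return self._prev\n\n    def reset(self) -> None:\n        self._prev = None\n\n\nif __name__ == \"__main__\":\n    f = EmaActionFilter(alpha=0.25)\n    first = f.apply(np.array([1.0, -1.0], dtype=np.float32))\n    second = f.apply(np.array([3.0, 1.0], dtype=np.float32))\n    np.testing.assert_allclose(first, [1.0, -1.0])\n    np.testing.assert_allclose(second, [1.5, -0.5])\n"
  },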
  {
    "path": "tests/test_eval_mujoco_contact_export.py",
    "content": "import json\nfrom pathlib import Path\n\nimport numpy as np\nfrom omegaconf import OmegaConf\n\nimport holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim\n\n\ndef _build_export_evaluator(tmp_path: Path):\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.simulation_dt = 0.005\n    evaluator._get_stacked_moe_routing_tensors = lambda: (None, None)\n    evaluator._robot_dof_pos_seq = [\n        np.array([0.0, 1.0], dtype=np.float32),\n        np.array([0.5, 1.5], dtype=np.float32),\n    ]\n    evaluator._robot_dof_vel_seq = [\n        np.array([0.1, 0.2], dtype=np.float32),\n        np.array([0.3, 0.4], dtype=np.float32),\n    ]\n    evaluator._robot_dof_acc_seq = [\n        np.array([1.0, 2.0], dtype=np.float32),\n        np.array([3.0, 4.0], dtype=np.float32),\n    ]\n    evaluator._robot_dof_torque_seq = [\n        np.array([5.0, 6.0], dtype=np.float32),\n        np.array([7.0, 8.0], dtype=np.float32),\n    ]\n    evaluator._robot_low_level_dof_torque_seq = [\n        np.array([1.0, 2.0], dtype=np.float32),\n        np.array([3.0, 4.0], dtype=np.float32),\n        np.array([5.0, 6.0], dtype=np.float32),\n        np.array([7.0, 8.0], dtype=np.float32),\n    ]\n    evaluator._robot_low_level_foot_contact_seq = [\n        np.array([1.0, 0.0], dtype=np.float32),\n        np.array([1.0, 1.0], dtype=np.float32),\n        np.array([0.0, 1.0], dtype=np.float32),\n        np.array([0.0, 0.0], dtype=np.float32),\n    ]\n    evaluator._robot_low_level_foot_normal_force_seq = [\n        np.array([50.0, 0.0], dtype=np.float32),\n        np.array([60.0, 55.0], dtype=np.float32),\n        np.array([0.0, 45.0], dtype=np.float32),\n        np.array([0.0, 0.0], dtype=np.float32),\n    ]\n    evaluator._robot_low_level_foot_tangent_speed_seq = [\n        np.array([0.02, 0.0], dtype=np.float32),\n        np.array([0.03, 0.04], dtype=np.float32),\n        np.array([0.0, 0.05], dtype=np.float32),\n        np.array([0.0, 0.0], dtype=np.float32),\n    ]\n    evaluator._robot_action_rate_seq = [\n        np.float32(0.0),\n        np.float32(1.0),\n    ]\n    evaluator._robot_actions_seq = [\n        np.array([0.11, 0.22], dtype=np.float32),\n        np.array([0.33, 0.44], dtype=np.float32),\n    ]\n    evaluator._robot_global_translation_seq = [\n        np.zeros((2, 3), dtype=np.float32),\n        np.ones((2, 3), dtype=np.float32),\n    ]\n    evaluator._robot_global_rotation_quat_seq = [\n        np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 1)),\n        np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 1)),\n    ]\n    evaluator._robot_global_velocity_seq = [\n        np.zeros((2, 3), dtype=np.float32),\n        np.ones((2, 3), dtype=np.float32),\n    ]\n    evaluator._robot_global_angular_velocity_seq = [\n        np.zeros((2, 3), dtype=np.float32),\n        np.ones((2, 3), dtype=np.float32),\n    ]\n    evaluator.ref_dof_pos = np.zeros((2, 2), dtype=np.float32)\n    evaluator.ref_dof_vel = np.zeros((2, 2), dtype=np.float32)\n    evaluator.ref_global_translation = np.zeros((2, 2, 3), dtype=np.float32)\n    evaluator.ref_global_rotation_quat_xyzw = np.tile(\n        np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 2, 1)\n    )\n    evaluator.ref_global_velocity = np.zeros((2, 2, 3), dtype=np.float32)\n    evaluator.ref_global_angular_velocity = np.zeros(\n        (2, 2, 3), dtype=np.float32\n    )\n\n    motion_npz_path = tmp_path / \"motion.npz\"\n    
np.savez_compressed(\n        motion_npz_path,\n        metadata=np.array(json.dumps({\"clip_length\": 2}), dtype=np.str_),\n    )\n    evaluator.config = OmegaConf.create(\n        {\n            \"motion_npz_path\": str(motion_npz_path),\n            \"ckpt_onnx_path\": str(tmp_path / \"model.onnx\"),\n        }\n    )\n    return evaluator\n\n\ndef test_save_batch_result_exports_low_level_contact_traces(tmp_path: Path):\n    evaluator = _build_export_evaluator(tmp_path)\n    output_path = tmp_path / \"batch_result.npz\"\n\n    evaluator.save_batch_result(str(output_path), {\"clip_length\": 2})\n\n    with np.load(output_path, allow_pickle=True) as data:\n        assert \"robot_actions\" in data.files\n        assert \"robot_low_level_foot_contact\" in data.files\n        assert \"robot_low_level_foot_normal_force\" in data.files\n        assert \"robot_low_level_foot_tangent_speed\" in data.files\n        assert \"robot_low_level_contact_dt\" in data.files\n        np.testing.assert_allclose(\n            data[\"robot_actions\"],\n            np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32),\n        )\n        assert data[\"robot_low_level_foot_contact\"].shape == (4, 2)\n        np.testing.assert_allclose(\n            data[\"robot_low_level_contact_dt\"], np.array(0.005, np.float32)\n        )\n\n\ndef test_dump_robot_augmented_npz_exports_low_level_contact_traces(\n    tmp_path: Path,\n):\n    evaluator = _build_export_evaluator(tmp_path)\n\n    evaluator._dump_robot_augmented_npz()\n\n    output_path = (\n        tmp_path\n        / \"mujoco_output_model\"\n        / f\"{Path(evaluator.config.motion_npz_path).stem}_robot.npz\"\n    )\n    with np.load(output_path, allow_pickle=True) as data:\n        assert \"robot_actions\" in data.files\n        assert \"robot_low_level_foot_contact\" in data.files\n        assert \"robot_low_level_foot_normal_force\" in data.files\n        assert \"robot_low_level_foot_tangent_speed\" in data.files\n        assert \"robot_low_level_contact_dt\" in data.files\n        np.testing.assert_allclose(\n            data[\"robot_actions\"],\n            np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32),\n        )\n        assert data[\"robot_low_level_foot_normal_force\"].shape == (4, 2)\n\n\ndef test_init_low_level_foot_contact_logging_falls_back_to_ankle_roll_bodies(\n    monkeypatch,\n):\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create({\"robot\": {}})\n    evaluator.m = type(\n        \"FakeModel\",\n        (),\n        {\n            \"geom_bodyid\": np.array([5, 6, 6, 9, 10], dtype=np.int32),\n            \"geom_contype\": np.array([0, 1, 1, 0, 1], dtype=np.int32),\n            \"geom_conaffinity\": np.array([0, 1, 1, 0, 1], dtype=np.int32),\n        },\n    )()\n\n    def fake_name2id(model, obj_type, name):\n        if obj_type == eval_mujoco_sim2sim.mujoco.mjtObj.mjOBJ_GEOM:\n            return -1\n        if obj_type == eval_mujoco_sim2sim.mujoco.mjtObj.mjOBJ_BODY:\n            return {\n                \"left_ankle_roll_link\": 6,\n                \"right_ankle_roll_link\": 10,\n            }.get(name, -1)\n        return -1\n\n    monkeypatch.setattr(eval_mujoco_sim2sim.mujoco, \"mj_name2id\", fake_name2id)\n\n    evaluator._init_low_level_foot_contact_logging()\n\n    assert evaluator._foot_contact_logging_enabled is True\n    assert evaluator._foot_geom_id_groups == [[1, 2], [4]]\n    assert 
evaluator._foot_geom_id_to_side == {1: 0, 2: 0, 4: 1}\n"
  },
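  {
    "path": "not_for_commit/sketches/contact_trace_export_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the holomotion package or its tests.\n\nShows the export layout tests/test_eval_mujoco_contact_export.py checks:\nper-policy-step traces and per-sim-step (low-level) traces are stacked into\narrays and written to one .npz together with the simulation dt, so the\nlow-level rows can later be re-aligned with the policy-rate rows. Names are\nlocal assumptions, not the MujocoEvaluator internals.\n\"\"\"\n\nimport numpy as np\n\n\ndef export_traces(path, policy_traces, low_level_traces, simulation_dt):\n    # Stack each list of per-step vectors into a (steps, dim) array.\n    arrays = {name: np.stack(seq) for name, seq in policy_traces.items()}\n    arrays.update(\n        {name: np.stack(seq) for name, seq in low_level_traces.items()}\n    )\n    # Record the sampling period of the low-level traces.\n    arrays[\"robot_low_level_contact_dt\"] = np.float32(simulation_dt)\n    np.savez_compressed(path, **arrays)\n\n\nif __name__ == \"__main__\":\n    import tempfile\n    from pathlib import Path\n\n    with tempfile.TemporaryDirectory() as tmp:\n        out = Path(tmp) / \"batch_result.npz\"\n        export_traces(\n            out,\n            {\"robot_actions\": [np.zeros(2, np.float32)] * 2},\n            {\"robot_low_level_foot_contact\": [np.zeros(2, np.float32)] * 4},\n            simulation_dt=0.005,\n        )\n        with np.load(out) as data:\n            assert data[\"robot_actions\"].shape == (2, 2)\n            assert data[\"robot_low_level_foot_contact\"].shape == (4, 2)\n            assert float(data[\"robot_low_level_contact_dt\"]) == np.float32(\n                0.005\n            )\n"
  },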
  {
    "path": "tests/test_eval_mujoco_s100_horizon_ptq.py",
    "content": "import sys\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nimport numpy as np\nimport pytest\nfrom omegaconf import OmegaConf\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nimport holomotion.src.evaluation.eval_mujoco_sim2sim_s100 as eval_mujoco_sim2sim_s100\n\n\nclass _FakeIoNode:\n    def __init__(self, name, shape):\n        self.name = name\n        self.shape = shape\n\n\ndef _make_value_info(name, shape):\n    dims = [SimpleNamespace(dim_value=dim) for dim in shape]\n    tensor_shape = SimpleNamespace(dim=dims)\n    tensor_type = SimpleNamespace(shape=tensor_shape)\n    return SimpleNamespace(\n        name=name, type=SimpleNamespace(tensor_type=tensor_type)\n    )\n\n\ndef _make_fake_onnx_model():\n    return SimpleNamespace(\n        graph=SimpleNamespace(\n            input=[\n                _make_value_info(\"obs\", [1, 16]),\n                _make_value_info(\"past_key_values\", [1, 2, 3, 4]),\n                _make_value_info(\"step_idx\", [1]),\n            ],\n            output=[\n                _make_value_info(\"action\", [1, 12]),\n                _make_value_info(\"present_key_values\", [1, 2, 3, 4]),\n            ],\n        )\n    )\n\n\ndef _make_evaluator(model_path: Path, bc_path: Path | None = None):\n    config_dict = {\n        \"ckpt_onnx_path\": str(model_path),\n        \"use_gpu\": False,\n        \"gpu_id\": 0,\n    }\n    if bc_path is not None:\n        config_dict[\"bc_path\"] = str(bc_path)\n\n    evaluator = eval_mujoco_sim2sim_s100.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim_s100.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create(config_dict)\n    evaluator.max_context_len = 0\n    evaluator._discover_policy_moe_outputs = lambda: None\n    return evaluator\n\n\ndef test_load_policy_falls_back_to_horizon_quantized_bc_for_ptq_onnx(\n    monkeypatch, tmp_path\n):\n    model_path = tmp_path / \"demo_ptq_model.onnx\"\n    model_path.write_bytes(b\"onnx\")\n    quantized_path = tmp_path / \"demo_quantized_model.bc\"\n    quantized_path.write_bytes(b\"bc\")\n    captured = {}\n\n    class _FakeHBRuntime:\n        def __init__(self, model_path):\n            captured[\"hb_model_path\"] = model_path\n            self.input_names = [\"obs\", \"past_key_values\", \"step_idx\"]\n            self.output_names = [\"action\", \"present_key_values\"]\n\n        def run(self, output_names, input_feed):\n            raise AssertionError(\"run should not be called in this test\")\n\n    def _raise_hz_calibration(*args, **kwargs):\n        raise RuntimeError(\"Failed to load custom op HzCalibration\")\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"InferenceSession\",\n        _raise_hz_calibration,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100,\n        \"HBRuntime\",\n        _FakeHBRuntime,\n        raising=False,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnx,\n        \"load\",\n        lambda _: _make_fake_onnx_model(),\n    )\n\n    evaluator = _make_evaluator(model_path)\n\n    evaluator.load_policy()\n\n    assert captured[\"hb_model_path\"] == str(quantized_path)\n    assert evaluator.policy_input_name == \"obs\"\n    assert evaluator.policy_kv_input_name == \"past_key_values\"\n    assert evaluator.policy_step_input_name == \"step_idx\"\n    
assert evaluator.policy_output_name == \"action\"\n    assert evaluator.policy_kv_output_name == \"present_key_values\"\n    assert evaluator.policy_model_context_len == 4\n\n\n@pytest.mark.parametrize(\n    \"runtime_name\",\n    [\n        \"demo_model_16000_ptq_model.bc\",\n        \"demo_model_16000_ptq_model.hbm\",\n        \"demo_model_16000_quantized_model.hbm\",\n    ],\n)\ndef test_load_policy_resolves_common_horizon_runtime_artifact_names(\n    monkeypatch, tmp_path, runtime_name\n):\n    model_path = tmp_path / \"demo_model_16000_ptq_model.onnx\"\n    model_path.write_bytes(b\"onnx\")\n    runtime_path = tmp_path / runtime_name\n    runtime_path.write_bytes(b\"runtime\")\n    captured = {}\n\n    class _FakeHBRuntime:\n        def __init__(self, model_path):\n            captured[\"hb_model_path\"] = model_path\n            self.input_names = [\"obs\", \"past_key_values\", \"step_idx\"]\n            self.output_names = [\"action\", \"present_key_values\"]\n\n        def run(self, output_names, input_feed):\n            raise AssertionError(\"run should not be called in this test\")\n\n    def _raise_hz_calibration(*args, **kwargs):\n        raise RuntimeError(\"Failed to load custom op HzCalibration\")\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"InferenceSession\",\n        _raise_hz_calibration,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100,\n        \"HBRuntime\",\n        _FakeHBRuntime,\n        raising=False,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnx,\n        \"load\",\n        lambda _: _make_fake_onnx_model(),\n    )\n\n    evaluator = _make_evaluator(model_path)\n\n    evaluator.load_policy()\n\n    assert captured[\"hb_model_path\"] == str(runtime_path)\n    assert evaluator.policy_model_context_len == 4\n\n\ndef test_load_policy_raises_original_error_when_ptq_fallback_bc_missing(\n    monkeypatch, tmp_path\n):\n    model_path = tmp_path / \"demo_ptq_model.onnx\"\n    model_path.write_bytes(b\"onnx\")\n\n    def _raise_hz_calibration(*args, **kwargs):\n        raise RuntimeError(\"Failed to load custom op HzCalibration\")\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"InferenceSession\",\n        _raise_hz_calibration,\n    )\n\n    evaluator = _make_evaluator(model_path)\n\n    with pytest.raises(RuntimeError, match=\"HzCalibration\"):\n        evaluator.load_policy()\n\n\ndef test_load_policy_keeps_standard_onnxruntime_path_for_regular_onnx(\n    monkeypatch, tmp_path\n):\n    model_path = tmp_path / \"demo_model.onnx\"\n    model_path.write_bytes(b\"onnx\")\n    captured = {}\n\n    class _FakeInferenceSession:\n        def __init__(self, model_path, sess_options, providers):\n            captured[\"model_path\"] = model_path\n            captured[\"providers\"] = providers\n\n        def get_providers(self):\n            return [\"CPUExecutionProvider\"]\n\n        def get_inputs(self):\n            return [\n                _FakeIoNode(\"obs\", [1, 16]),\n                _FakeIoNode(\"past_key_values\", [1, 2, 3, 4]),\n                _FakeIoNode(\"step_idx\", [1]),\n            ]\n\n        def 
get_outputs(self):\n            return [\n                _FakeIoNode(\"action\", [1, 12]),\n                _FakeIoNode(\"present_key_values\", [1, 2, 3, 4]),\n            ]\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"InferenceSession\",\n        _FakeInferenceSession,\n    )\n\n    evaluator = _make_evaluator(model_path)\n\n    evaluator.load_policy()\n\n    assert captured[\"model_path\"] == str(model_path)\n    assert captured[\"providers\"] == [\"CPUExecutionProvider\"]\n    assert evaluator.policy_model_context_len == 4\n\n\ndef test_load_policy_prefers_explicit_bc_path_for_inference_and_onnx_for_metadata(\n    monkeypatch, tmp_path\n):\n    model_path = tmp_path / \"demo_model.onnx\"\n    model_path.write_bytes(b\"onnx\")\n    runtime_path = tmp_path / \"demo_quantized_model.bc\"\n    runtime_path.write_bytes(b\"bc\")\n    captured = {}\n\n    class _FakeHBRuntime:\n        def __init__(self, model_path):\n            captured[\"hb_model_path\"] = model_path\n            self.input_names = [\"obs\", \"past_key_values\", \"step_idx\"]\n            self.output_names = [\"action\", \"present_key_values\"]\n\n        def run(self, output_names, input_feed):\n            raise AssertionError(\"run should not be called in this test\")\n\n    def _unexpected_ort_session(*args, **kwargs):\n        raise AssertionError(\n            \"InferenceSession should not be created when bc_path is set\"\n        )\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnxruntime,\n        \"InferenceSession\",\n        _unexpected_ort_session,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100,\n        \"HBRuntime\",\n        _FakeHBRuntime,\n        raising=False,\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100.onnx,\n        \"load\",\n        lambda _: _make_fake_onnx_model(),\n    )\n\n    evaluator = _make_evaluator(model_path, bc_path=runtime_path)\n\n    evaluator.load_policy()\n\n    assert captured[\"hb_model_path\"] == str(runtime_path)\n    assert evaluator.policy_input_name == \"obs\"\n    assert evaluator.policy_kv_input_name == \"past_key_values\"\n    assert evaluator.policy_step_input_name == \"step_idx\"\n    assert evaluator.policy_output_name == \"action\"\n    assert evaluator.policy_kv_output_name == \"present_key_values\"\n    assert evaluator.policy_model_context_len == 4\n\n\ndef test_bc_runtime_run_normalizes_inputs_for_hbruntime(monkeypatch, tmp_path):\n    runtime_path = tmp_path / \"demo_quantized_model.bc\"\n    runtime_path.write_bytes(b\"bc\")\n    captured = {}\n\n    class _FakeHBRuntime:\n        def __init__(self, model_path):\n            captured[\"model_path\"] = model_path\n            self.input_names = [\"obs\", \"past_key_values\", \"step_idx\"]\n            self.output_names = [\"action\", \"present_key_values\"]\n\n        def run(self, output_names, input_feed):\n            captured[\"output_names\"] = list(output_names)\n            captured[\"input_feed\"] = input_feed\n            return [\"ok\"]\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim_s100,\n        \"HBRuntime\",\n        _FakeHBRuntime,\n        raising=False,\n    )\n\n    
wrapper = eval_mujoco_sim2sim_s100._HbSessionWrapper(runtime_path)\n    obs = np.arange(6, dtype=np.float64).reshape(2, 3).T\n    past_key_values = np.arange(24, dtype=np.float64).reshape(2, 3, 4)\n    step_idx = np.array([7], dtype=np.int32)\n\n    outputs = wrapper.run(\n        [\"action\"],\n        {\n            \"obs\": obs,\n            \"past_key_values\": past_key_values,\n            \"step_idx\": step_idx,\n        },\n    )\n\n    assert outputs == [\"ok\"]\n    assert captured[\"model_path\"] == str(runtime_path)\n    assert captured[\"output_names\"] == [\"action\"]\n    assert captured[\"input_feed\"][\"obs\"].dtype == np.float32\n    assert captured[\"input_feed\"][\"obs\"].flags[\"C_CONTIGUOUS\"]\n    assert captured[\"input_feed\"][\"past_key_values\"].dtype == np.float32\n    assert captured[\"input_feed\"][\"past_key_values\"].flags[\"C_CONTIGUOUS\"]\n    assert captured[\"input_feed\"][\"step_idx\"].dtype == np.int64\n    assert captured[\"input_feed\"][\"step_idx\"].flags[\"C_CONTIGUOUS\"]\n\n\ndef test_update_policy_raises_clear_error_before_runtime_on_obs_dim_mismatch():\n    evaluator = eval_mujoco_sim2sim_s100.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim_s100.MujocoEvaluator\n    )\n    evaluator._record_robot_states = lambda: None\n    evaluator.obs_builder = SimpleNamespace(\n        build_policy_obs=lambda: np.zeros(425, dtype=np.float32)\n    )\n    evaluator.policy_input_name = \"obs\"\n    evaluator.policy_output_name = \"action\"\n    evaluator.policy_obs_expected_dim = 786\n    evaluator.use_kv_cache = False\n    evaluator.policy_step_input_name = None\n    evaluator.policy_kv_output_name = None\n    evaluator.policy_moe_layer_output_names = []\n    evaluator.dump_onnx_io_npy = False\n    evaluator.counter = 0\n    evaluator.command_mode = \"velocity_tracking\"\n    evaluator.config = OmegaConf.create(\n        {\"motion_npz_dir\": \"\", \"motion_npz_path\": \"\"}\n    )\n    evaluator.policy_session = SimpleNamespace(\n        run=lambda *args, **kwargs: (_ for _ in ()).throw(\n            AssertionError(\"runtime should not be called on shape mismatch\")\n        )\n    )\n\n    with pytest.raises(\n        ValueError, match=\"expects 786 features but evaluator built 425\"\n    ):\n        evaluator._update_policy()\n"
  },
  {
    "path": "tests/test_eval_mujoco_use_gpu.py",
    "content": "import sys\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nimport holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim\n\n\nclass _FakeIoNode:\n    def __init__(self, name, shape):\n        self.name = name\n        self.shape = shape\n\n\ndef test_load_policy_treats_false_string_use_gpu_as_cpu(monkeypatch):\n    captured = {}\n\n    class _FakeInferenceSession:\n        def __init__(self, model_path, sess_options, providers):\n            captured[\"model_path\"] = model_path\n            captured[\"providers\"] = providers\n\n        def get_providers(self):\n            return [\"CPUExecutionProvider\"]\n\n        def get_inputs(self):\n            return [_FakeIoNode(\"obs\", [1, 16])]\n\n        def get_outputs(self):\n            return [_FakeIoNode(\"action\", [1, 12])]\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim.onnxruntime,\n        \"get_available_providers\",\n        lambda: [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"],\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim.onnxruntime,\n        \"InferenceSession\",\n        _FakeInferenceSession,\n    )\n\n    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(\n        eval_mujoco_sim2sim.MujocoEvaluator\n    )\n    evaluator.config = OmegaConf.create(\n        {\n            \"ckpt_onnx_path\": \"model.onnx\",\n            \"use_gpu\": \"false\",\n            \"gpu_id\": 3,\n        }\n    )\n    evaluator.max_context_len = 0\n\n    evaluator.load_policy()\n\n    assert captured[\"model_path\"] == \"model.onnx\"\n    assert captured[\"providers\"] == [\"CPUExecutionProvider\"]\n\n\ndef test_create_ray_evaluator_preserves_use_gpu_false(monkeypatch):\n    captured = {}\n\n    class _FakeEvaluator:\n        def __init__(self, config):\n            captured[\"use_gpu\"] = config.use_gpu\n            captured[\"gpu_id\"] = config.gpu_id\n\n    monkeypatch.setattr(eval_mujoco_sim2sim, \"MujocoEvaluator\", _FakeEvaluator)\n\n    eval_mujoco_sim2sim._create_ray_evaluator(\n        {\"use_gpu\": False, \"gpu_id\": 5}, \"holomotion\"\n    )\n\n    assert captured[\"use_gpu\"] is False\n    assert captured[\"gpu_id\"] == 5\n\n\ndef test_run_mujoco_sim2sim_eval_preserves_use_gpu_false(\n    monkeypatch, tmp_path\n):\n    captured = {}\n\n    class _FakeEvaluator:\n        def __init__(self, config):\n            captured[\"use_gpu\"] = config.use_gpu\n\n        def setup(self):\n            captured[\"setup\"] = True\n\n        def run_simulation(self):\n            captured[\"run_simulation\"] = True\n\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim.hydra.utils,\n        \"get_original_cwd\",\n        lambda: str(tmp_path),\n    )\n    monkeypatch.setattr(\n        eval_mujoco_sim2sim,\n        \"process_config\",\n        lambda _: OmegaConf.create(\n            {\n                \"use_gpu\": False,\n                \"model_type\": \"holomotion\",\n            }\n        ),\n    )\n    monkeypatch.setattr(eval_mujoco_sim2sim, \"MujocoEvaluator\", _FakeEvaluator)\n\n    eval_mujoco_sim2sim.run_mujoco_sim2sim_eval(OmegaConf.create({}))\n\n    assert captured[\"use_gpu\"] is False\n    assert captured[\"setup\"] is True\n    assert captured[\"run_simulation\"] is True\n"
  },
  {
    "path": "tests/test_eval_onnx_io_dump.py",
    "content": "import json\nimport sys\nimport types\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nimport numpy as np\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.evaluation.eval_mujoco_sim2sim import (\n    MujocoEvaluator,\n    write_onnx_io_dump_readme,\n)\nfrom holomotion.src.evaluation.ray_evaluator_actor import RayEvaluatorActor\n\n\nclass _Config(SimpleNamespace):\n    def get(self, key, default=None):\n        return getattr(self, key, default)\n\n\ndef test_save_onnx_io_dump_stacks_per_frame_inputs_and_outputs(tmp_path):\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator._reset_onnx_io_dump_buffers()\n\n    evaluator._record_onnx_io_frame(\n        input_feed={\n            \"obs\": np.array([[1.0, 2.0]], dtype=np.float32),\n            \"step\": np.array([0], dtype=np.int64),\n        },\n        output_names=[\"action\", \"kv_cache\"],\n        onnx_output=[\n            np.array([[0.1, 0.2]], dtype=np.float32),\n            np.array([[[3.0, 4.0]]], dtype=np.float32),\n        ],\n    )\n    evaluator._record_onnx_io_frame(\n        input_feed={\n            \"obs\": np.array([[5.0, 6.0]], dtype=np.float32),\n            \"step\": np.array([1], dtype=np.int64),\n        },\n        output_names=[\"action\", \"kv_cache\"],\n        onnx_output=[\n            np.array([[0.3, 0.4]], dtype=np.float32),\n            np.array([[[7.0, 8.0]]], dtype=np.float32),\n        ],\n    )\n\n    output_path = tmp_path / \"clip_onnx_io.npy\"\n    evaluator.save_onnx_io_dump(\n        output_path,\n        {\n            \"source_npz\": \"clip.npz\",\n            \"onnx_model\": \"model.onnx\",\n        },\n    )\n\n    payload = np.load(output_path, allow_pickle=True).item()\n\n    assert payload[\"input_names\"] == [\"obs\", \"step\"]\n    assert payload[\"output_names\"] == [\"action\", \"kv_cache\"]\n    np.testing.assert_allclose(\n        payload[\"inputs\"][\"obs\"],\n        np.array([[[1.0, 2.0]], [[5.0, 6.0]]], dtype=np.float32),\n    )\n    np.testing.assert_array_equal(\n        payload[\"inputs\"][\"step\"],\n        np.array([[0], [1]], dtype=np.int64),\n    )\n    np.testing.assert_allclose(\n        payload[\"outputs\"][\"action\"],\n        np.array([[[0.1, 0.2]], [[0.3, 0.4]]], dtype=np.float32),\n    )\n    np.testing.assert_allclose(\n        payload[\"outputs\"][\"kv_cache\"],\n        np.array([[[[3.0, 4.0]]], [[[7.0, 8.0]]]], dtype=np.float32),\n    )\n    assert payload[\"source_npz\"] == \"clip.npz\"\n    assert payload[\"onnx_model\"] == \"model.onnx\"\n\n\ndef test_write_onnx_io_dump_readme_creates_chinese_loading_instructions(\n    tmp_path,\n):\n    readme_path = write_onnx_io_dump_readme(tmp_path)\n\n    assert readme_path == tmp_path / \"README.md\"\n    content = readme_path.read_text(encoding=\"utf-8\")\n    assert \"每个动作片段会生成一个 `.npy` 文件\" in content\n    assert \"allow_pickle=True\" in content\n    assert \"np.load(npy_path, allow_pickle=True).item()\" in content\n\n\ndef test_save_batch_result_persists_low_level_torque_dump_and_dt(tmp_path):\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.policy_dt = 0.02\n    evaluator.simulation_dt = 0.005\n\n    evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)]\n    
evaluator._robot_low_level_dof_torque_seq = [\n        np.array([1.0, -1.0], dtype=np.float32),\n        np.array([-1.0, 1.0], dtype=np.float32),\n    ]\n    evaluator._robot_action_rate_seq = [np.float32(0.0)]\n    evaluator._robot_global_translation_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator._robot_global_rotation_quat_seq = [\n        np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)\n    ]\n    evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)]\n    evaluator._robot_global_angular_velocity_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator.ref_dof_pos = np.zeros((1, 2), dtype=np.float32)\n    evaluator.ref_dof_vel = np.zeros((1, 2), dtype=np.float32)\n    evaluator.ref_global_translation = np.zeros((1, 1, 3), dtype=np.float32)\n    evaluator.ref_global_rotation_quat_xyzw = np.array(\n        [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32\n    )\n    evaluator.ref_global_velocity = np.zeros((1, 1, 3), dtype=np.float32)\n    evaluator.ref_global_angular_velocity = np.zeros(\n        (1, 1, 3), dtype=np.float32\n    )\n\n    output_path = tmp_path / \"demo_eval.npz\"\n    evaluator.save_batch_result(\n        output_path, {\"source_file\": \"clip.npz\", \"clip_length\": 1}\n    )\n\n    with np.load(output_path, allow_pickle=True) as payload:\n        metadata = json.loads(payload[\"metadata\"].item())\n        np.testing.assert_allclose(\n            payload[\"robot_low_level_dof_torque\"],\n            np.array([[1.0, -1.0], [-1.0, 1.0]], dtype=np.float32),\n        )\n\n    assert metadata[\"source_file\"] == \"clip.npz\"\n    assert metadata[\"clip_length\"] == 1\n    assert metadata[\"robot_low_level_torque_dt\"] == 0.005\n\n\ndef test_save_batch_result_persists_moe_routing_tensors(tmp_path):\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.policy_dt = 0.02\n    evaluator.simulation_dt = 0.005\n\n    evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)]\n    evaluator._robot_low_level_dof_torque_seq = [\n        np.array([0.5, -0.5], dtype=np.float32)\n    ]\n    evaluator._robot_action_rate_seq = [np.float32(0.0)]\n    evaluator._robot_global_translation_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator._robot_global_rotation_quat_seq = [\n        np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)\n    ]\n    evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)]\n    evaluator._robot_global_angular_velocity_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator._robot_moe_expert_indices_seq = [\n        np.array([[1, 3], [0, 2]], dtype=np.int64)\n    ]\n    evaluator._robot_moe_expert_logits_seq = [\n        np.array(\n            [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]],\n            dtype=np.float32,\n        )\n    ]\n    evaluator.ref_dof_pos = np.zeros((1, 2), dtype=np.float32)\n    evaluator.ref_dof_vel = np.zeros((1, 2), dtype=np.float32)\n    evaluator.ref_global_translation = np.zeros((1, 1, 3), dtype=np.float32)\n    evaluator.ref_global_rotation_quat_xyzw = np.array(\n        [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32\n    )\n    evaluator.ref_global_velocity = np.zeros((1, 1, 3), dtype=np.float32)\n    evaluator.ref_global_angular_velocity = np.zeros(\n        (1, 1, 3), dtype=np.float32\n    
)\n\n    output_path = tmp_path / \"demo_eval_moe.npz\"\n    evaluator.save_batch_result(output_path, {\"source_file\": \"clip.npz\"})\n\n    with np.load(output_path, allow_pickle=True) as payload:\n        np.testing.assert_array_equal(\n            payload[\"robot_moe_expert_indices\"],\n            np.array([[[1, 3], [0, 2]]], dtype=np.int64),\n        )\n        np.testing.assert_allclose(\n            payload[\"robot_moe_expert_logits\"],\n            np.array(\n                [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]],\n                dtype=np.float32,\n            ),\n        )\n\n\ndef test_dump_robot_augmented_npz_persists_moe_routing_tensors(tmp_path):\n    source_npz = tmp_path / \"clip.npz\"\n    np.savez(source_npz, ref=np.array([1], dtype=np.int32))\n\n    onnx_path = tmp_path / \"model.onnx\"\n    onnx_path.write_bytes(b\"\")\n\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.simulation_dt = 0.005\n    evaluator.config = _Config(\n        motion_npz_path=str(source_npz),\n        ckpt_onnx_path=str(onnx_path),\n    )\n    evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)]\n    evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)]\n    evaluator._robot_low_level_dof_torque_seq = [\n        np.array([0.5, -0.5], dtype=np.float32)\n    ]\n    evaluator._robot_action_rate_seq = [np.float32(0.0)]\n    evaluator._robot_global_translation_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator._robot_global_rotation_quat_seq = [\n        np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)\n    ]\n    evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)]\n    evaluator._robot_global_angular_velocity_seq = [\n        np.zeros((1, 3), dtype=np.float32)\n    ]\n    evaluator._robot_moe_expert_indices_seq = [\n        np.array([[1, 3], [0, 2]], dtype=np.int64)\n    ]\n    evaluator._robot_moe_expert_logits_seq = [\n        np.array(\n            [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]],\n            dtype=np.float32,\n        )\n    ]\n\n    evaluator._dump_robot_augmented_npz()\n\n    out_path = tmp_path / \"mujoco_output_model\" / \"clip_robot.npz\"\n    with np.load(out_path, allow_pickle=True) as payload:\n        np.testing.assert_array_equal(\n            payload[\"robot_moe_expert_indices\"],\n            np.array([[[1, 3], [0, 2]]], dtype=np.int64),\n        )\n        np.testing.assert_allclose(\n            payload[\"robot_moe_expert_logits\"],\n            np.array(\n                [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]],\n                dtype=np.float32,\n            ),\n        )\n\n\ndef test_ray_actor_run_clip_overwrites_existing_outputs_and_sidecar(tmp_path):\n    class _FakeEvaluator:\n        def __init__(self):\n            self.n_motion_frames = 2\n            self.calls = []\n            self.counter = 0\n\n        def load_specific_motion(self, file_path):\n            self.calls.append((\"load\", file_path))\n\n        def reset_state_teleport(self):\n            self.calls.append((\"reset\",))\n\n        def _update_policy(self):\n            self.calls.append((\"update\",))\n\n        def _apply_control(self, sleep=False):\n            self.calls.append((\"apply\", sleep))\n\n        def save_batch_result(self, output_path, meta_info):\n            self.calls.append((\"save_batch\", output_path, meta_info))\n            
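# Overwrite whatever already exists on disk so the test can tell a\n            # fresh output apart from a stale file.\n            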
Path(output_path).write_text(\"fresh-npz\", encoding=\"utf-8\")\n\n        def save_onnx_io_dump(self, output_path, meta_info):\n            self.calls.append((\"save_onnx\", output_path, meta_info))\n            np.save(\n                output_path,\n                {\"source_npz\": meta_info[\"source_file\"]},\n                allow_pickle=True,\n            )\n\n    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)\n    actor.output_dir = str(tmp_path)\n    actor.config_dict = {\n        \"ckpt_onnx_path\": \"model.onnx\",\n        \"dump_onnx_io_npy\": True,\n    }\n    actor.evaluator = _FakeEvaluator()\n\n    clip_path = tmp_path / \"demo_clip.npz\"\n    np.savez(clip_path, dummy=np.array([1], dtype=np.int32))\n\n    existing_npz = tmp_path / \"demo_clip_eval.npz\"\n    existing_npz.write_text(\"stale\", encoding=\"utf-8\")\n    onnx_dir = tmp_path / \"onnx_io_npy\"\n    onnx_dir.mkdir()\n\n    status = actor.run_clip(str(clip_path))\n\n    assert status == \"success\"\n    assert existing_npz.read_text(encoding=\"utf-8\") == \"fresh-npz\"\n    onnx_dump_path = onnx_dir / \"demo_clip_onnx_io.npy\"\n    assert onnx_dump_path.is_file()\n    payload = np.load(onnx_dump_path, allow_pickle=True).item()\n    assert payload[\"source_npz\"] == \"demo_clip.npz\"\n    assert (\"load\", str(clip_path)) in actor.evaluator.calls\n    assert (\"reset\",) in actor.evaluator.calls\n    assert actor.evaluator.calls.count((\"update\",)) == 2\n\n\ndef test_ray_actor_skips_sidecar_for_non_default_model_type(tmp_path):\n    class _FakeEvaluator:\n        def __init__(self):\n            self.n_motion_frames = 1\n            self.calls = []\n            self.counter = 0\n\n        def load_specific_motion(self, file_path):\n            self.calls.append((\"load\", file_path))\n\n        def reset_state_teleport(self):\n            self.calls.append((\"reset\",))\n\n        def _update_policy(self):\n            self.calls.append((\"update\",))\n\n        def _apply_control(self, sleep=False):\n            self.calls.append((\"apply\", sleep))\n\n        def save_batch_result(self, output_path, meta_info):\n            self.calls.append((\"save_batch\", output_path, meta_info))\n            Path(output_path).write_text(\"fresh-npz\", encoding=\"utf-8\")\n\n        def save_onnx_io_dump(self, output_path, meta_info):\n            self.calls.append((\"save_onnx\", output_path, meta_info))\n\n    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)\n    actor.output_dir = str(tmp_path)\n    actor.config_dict = {\n        \"ckpt_onnx_path\": \"model.onnx\",\n        \"dump_onnx_io_npy\": True,\n        \"model_type\": \"gmt\",\n    }\n    actor.evaluator = _FakeEvaluator()\n\n    clip_path = tmp_path / \"demo_clip.npz\"\n    np.savez(clip_path, dummy=np.array([1], dtype=np.int32))\n\n    status = actor.run_clip(str(clip_path))\n\n    assert status == \"success\"\n    assert not any(call[0] == \"save_onnx\" for call in actor.evaluator.calls)\n\n\ndef test_ray_actor_treats_empty_model_type_as_default_holomotion(tmp_path):\n    class _FakeEvaluator:\n        def __init__(self):\n            self.n_motion_frames = 1\n            self.calls = []\n            self.counter = 0\n\n        def load_specific_motion(self, file_path):\n            self.calls.append((\"load\", file_path))\n\n        def reset_state_teleport(self):\n            self.calls.append((\"reset\",))\n\n        def _update_policy(self):\n            self.calls.append((\"update\",))\n\n        def _apply_control(self, sleep=False):\n            
self.calls.append((\"apply\", sleep))\n\n        def save_batch_result(self, output_path, meta_info):\n            self.calls.append((\"save_batch\", output_path, meta_info))\n            Path(output_path).write_text(\"fresh-npz\", encoding=\"utf-8\")\n\n        def save_onnx_io_dump(self, output_path, meta_info):\n            self.calls.append((\"save_onnx\", output_path, meta_info))\n            np.save(output_path, {\"source_npz\": meta_info[\"source_file\"]})\n\n    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)\n    actor.output_dir = str(tmp_path)\n    actor.config_dict = {\n        \"ckpt_onnx_path\": \"model.onnx\",\n        \"dump_onnx_io_npy\": True,\n        \"model_type\": \"\",\n    }\n    actor.evaluator = _FakeEvaluator()\n\n    clip_path = tmp_path / \"demo_clip.npz\"\n    np.savez(clip_path, dummy=np.array([1], dtype=np.int32))\n\n    status = actor.run_clip(str(clip_path))\n\n    assert status == \"success\"\n    assert any(call[0] == \"save_onnx\" for call in actor.evaluator.calls)\n\n\ndef test_ray_actor_init_uses_configured_evaluator_module(\n    monkeypatch, tmp_path\n):\n    class _FakeEvaluator:\n        def __init__(self):\n            self.setup_called = False\n\n        def setup(self):\n            self.setup_called = True\n\n    captured = {}\n    fake_evaluator = _FakeEvaluator()\n\n    def _unexpected_default_factory(*args, **kwargs):\n        raise AssertionError(\"default evaluator factory should not be used\")\n\n    def _fake_override_factory(config_dict, model_type):\n        captured[\"config_dict\"] = config_dict\n        captured[\"model_type\"] = model_type\n        return fake_evaluator\n\n    monkeypatch.setattr(\n        \"holomotion.src.evaluation.eval_mujoco_sim2sim._create_ray_evaluator\",\n        _unexpected_default_factory,\n    )\n    sys.modules[\"holomotion.src.evaluation.fake_eval_module\"] = (\n        types.SimpleNamespace(_create_ray_evaluator=_fake_override_factory)\n    )\n\n    actor = RayEvaluatorActor(\n        {\n            \"ckpt_onnx_path\": \"model.onnx\",\n            \"model_type\": \"holomotion\",\n            \"ray_evaluator_module\": \"holomotion.src.evaluation.fake_eval_module\",\n        },\n        str(tmp_path),\n    )\n\n    assert actor.evaluator is fake_evaluator\n    assert fake_evaluator.setup_called is True\n    assert captured[\"model_type\"] == \"holomotion\"\n    assert (\n        captured[\"config_dict\"][\"ray_evaluator_module\"]\n        == \"holomotion.src.evaluation.fake_eval_module\"\n    )\n"
  },
  {
    "path": "tests/test_evaluation_metrics.py",
    "content": "import csv\nimport json\nfrom pathlib import Path\n\nimport numpy as np\n\nfrom holomotion.src.evaluation.metrics import (\n    _compute_clip_stability_summary,\n    _per_frame_metrics_from_npz,\n    offline_evaluate_dumped_npzs,\n)\n\n\ndef _make_eval_data(\n    robot_dof_torque: np.ndarray,\n    *,\n    robot_dof_vel: np.ndarray | None = None,\n    robot_dof_acc: np.ndarray | None = None,\n    robot_action_rate: np.ndarray | None = None,\n    robot_low_level_dof_torque: np.ndarray | None = None,\n    robot_global_angular_velocity: np.ndarray | None = None,\n    robot_low_level_foot_contact: np.ndarray | None = None,\n    robot_low_level_foot_normal_force: np.ndarray | None = None,\n    robot_low_level_foot_tangent_speed: np.ndarray | None = None,\n    robot_moe_expert_logits: np.ndarray | None = None,\n):\n    num_frames = int(robot_dof_torque.shape[0])\n    num_dofs = int(robot_dof_torque.shape[1])\n\n    root = np.zeros((num_frames, 1, 3), dtype=np.float32)\n    child = np.tile(\n        np.array([[[0.0, 0.0, 1.0]]], dtype=np.float32), (num_frames, 1, 1)\n    )\n    global_translation = np.concatenate([root, child], axis=1)\n    global_rotation = np.tile(\n        np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32),\n        (num_frames, 2, 1),\n    )\n\n    zeros_dof = np.zeros((num_frames, num_dofs), dtype=np.float32)\n\n    payload = {\n        \"ref_dof_pos\": zeros_dof.copy(),\n        \"robot_dof_pos\": zeros_dof.copy(),\n        \"ref_dof_vel\": zeros_dof.copy(),\n        \"ref_global_translation\": global_translation.copy(),\n        \"robot_global_translation\": global_translation.copy(),\n        \"ref_global_rotation_quat\": global_rotation.copy(),\n        \"robot_global_rotation_quat\": global_rotation.copy(),\n        \"ref_global_velocity\": np.zeros((num_frames, 2, 3), dtype=np.float32),\n        \"ref_global_angular_velocity\": np.zeros(\n            (num_frames, 2, 3), dtype=np.float32\n        ),\n        \"robot_global_velocity\": np.zeros(\n            (num_frames, 2, 3), dtype=np.float32\n        ),\n        \"robot_global_angular_velocity\": (\n            np.zeros((num_frames, 2, 3), dtype=np.float32)\n            if robot_global_angular_velocity is None\n            else robot_global_angular_velocity.astype(np.float32)\n        ),\n        \"robot_dof_vel\": (\n            zeros_dof.copy()\n            if robot_dof_vel is None\n            else robot_dof_vel.astype(np.float32)\n        ),\n        \"robot_dof_acc\": (\n            zeros_dof.copy()\n            if robot_dof_acc is None\n            else robot_dof_acc.astype(np.float32)\n        ),\n        \"robot_dof_torque\": robot_dof_torque.astype(np.float32),\n        \"robot_action_rate\": (\n            np.zeros((num_frames,), dtype=np.float32)\n            if robot_action_rate is None\n            else robot_action_rate.astype(np.float32)\n        ),\n    }\n    if robot_low_level_dof_torque is not None:\n        payload[\"robot_low_level_dof_torque\"] = (\n            robot_low_level_dof_torque.astype(np.float32)\n        )\n    if robot_low_level_foot_contact is not None:\n        payload[\"robot_low_level_foot_contact\"] = (\n            robot_low_level_foot_contact.astype(np.float32)\n        )\n    if robot_low_level_foot_normal_force is not None:\n        payload[\"robot_low_level_foot_normal_force\"] = (\n            robot_low_level_foot_normal_force.astype(np.float32)\n        )\n    if robot_low_level_foot_tangent_speed is not None:\n        
payload[\"robot_low_level_foot_tangent_speed\"] = (\n            robot_low_level_foot_tangent_speed.astype(np.float32)\n        )\n    if robot_moe_expert_logits is not None:\n        payload[\"robot_moe_expert_logits\"] = robot_moe_expert_logits.astype(\n            np.float32\n        )\n    return payload\n\n\ndef test_per_frame_metrics_include_torque_jump_diagnostics():\n    constant_torque = np.ones((4, 2), dtype=np.float32)\n    constant_df = _per_frame_metrics_from_npz(\n        motion_key=\"constant\",\n        data=_make_eval_data(constant_torque),\n        robot_control_dt=0.5,\n    )\n\n    assert \"mean_torque_jump_norm\" in constant_df.columns\n    assert \"mean_torque_jump_ratio\" in constant_df.columns\n    assert np.isnan(constant_df[\"mean_torque_jump_norm\"].iloc[0])\n    assert np.isnan(constant_df[\"mean_torque_jump_ratio\"].iloc[0])\n    np.testing.assert_allclose(\n        np.nan_to_num(constant_df[\"mean_torque_jump_norm\"].to_numpy()),\n        np.zeros(4, dtype=np.float64),\n    )\n    np.testing.assert_allclose(\n        np.nan_to_num(constant_df[\"mean_torque_jump_ratio\"].to_numpy()),\n        np.zeros(4, dtype=np.float64),\n    )\n\n    jump_torque = np.array(\n        [\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [-1.0, 0.0],\n            [-1.0, 0.0],\n        ],\n        dtype=np.float32,\n    )\n    jump_df = _per_frame_metrics_from_npz(\n        motion_key=\"jump\",\n        data=_make_eval_data(jump_torque),\n        robot_control_dt=0.5,\n    )\n\n    assert jump_df[\"mean_torque_jump_norm\"].iloc[2] > 3.9\n    assert jump_df[\"mean_torque_jump_ratio\"].iloc[2] > 1.9\n\n\ndef test_offline_evaluate_dumped_npzs_exports_torque_jump_summary_metrics(\n    tmp_path: Path,\n):\n    eval_dir = tmp_path / \"eval\"\n    eval_dir.mkdir()\n\n    jump_torque_50hz = np.tile(\n        np.array([[1.0, 0.0]], dtype=np.float32), (4, 1)\n    )\n    jump_torque_low_level = np.array(\n        [\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [1.0, 0.0],\n            [-1.0, 0.0],\n            [1.0, 0.0],\n            [-1.0, 0.0],\n            [1.0, 0.0],\n            [-1.0, 0.0],\n            [1.0, 0.0],\n            [-1.0, 0.0],\n            [1.0, 0.0],\n        ],\n        dtype=np.float32,\n    )\n    payload = _make_eval_data(\n        jump_torque_50hz,\n        robot_low_level_dof_torque=jump_torque_low_level,\n    )\n    payload[\"metadata\"] = np.array(\n        json.dumps({\"clip_length\": 4, \"robot_low_level_torque_dt\": 0.005}),\n        dtype=np.str_,\n    )\n\n    np.savez_compressed(eval_dir / \"demo_clip.npz\", **payload)\n\n    output_json_path = eval_dir / \"summary.json\"\n    result = offline_evaluate_dumped_npzs(\n        npz_dir=str(eval_dir),\n        output_json_path=str(output_json_path),\n    )\n\n    per_clip = result[\"per_clip\"][0]\n    for key in (\n        \"mean_torque_jump_norm\",\n        \"p95_torque_jump_norm\",\n        \"mean_torque_jump_ratio\",\n        \"p95_torque_jump_ratio\",\n    ):\n        assert key in per_clip\n        assert key in result[\"dataset\"][\"mean\"]\n\n    assert per_clip[\"mean_dof_torque\"] == 1.0\n    assert per_clip[\"p95_torque_jump_norm\"] > 300.0\n    assert per_clip[\"p95_torque_jump_ratio\"] > 1.0\n\n    with output_json_path.open(\"r\", encoding=\"utf-8\") as handle:\n        written = json.load(handle)\n    assert \"p95_torque_jump_ratio\" in 
written[\"dataset\"][\"mean\"]\n\n    csv_path = eval_dir / \"per_clip_metrics.csv\"\n    with csv_path.open(\"r\", encoding=\"utf-8\", newline=\"\") as handle:\n        reader = csv.DictReader(handle)\n        row = next(reader)\n    assert \"p95_torque_jump_ratio\" in row\n    assert \"mean_torque_jump_norm\" in row\n\n\ndef test_compute_clip_stability_summary_detects_chatter_and_support_events():\n    num_frames = 50\n    num_low_level = 200\n    policy_dt = 0.02\n    low_level_dt = 0.005\n\n    t_policy = np.arange(num_frames, dtype=np.float32) * policy_dt\n    t_low = np.arange(num_low_level, dtype=np.float32) * low_level_dt\n\n    smooth_ang_vel = np.zeros((num_frames, 2, 3), dtype=np.float32)\n    smooth_ang_vel[:, 0, 0] = 0.2 * np.sin(2.0 * np.pi * 1.0 * t_policy)\n\n    unstable_ang_vel = smooth_ang_vel.copy()\n    unstable_ang_vel[:, 0, 0] += 0.7 * np.sin(\n        2.0 * np.pi * 8.0 * t_policy\n    ).astype(np.float32)\n    unstable_ang_vel[:, 0, 1] += 0.4 * np.sin(\n        2.0 * np.pi * 6.0 * t_policy\n    ).astype(np.float32)\n\n    smooth_low_level_torque = np.zeros((num_low_level, 2), dtype=np.float32)\n    smooth_low_level_torque[:, 0] = np.sin(2.0 * np.pi * 1.0 * t_low)\n\n    unstable_low_level_torque = smooth_low_level_torque.copy()\n    unstable_low_level_torque[:, 0] += 0.8 * np.sin(\n        2.0 * np.pi * 15.0 * t_low\n    ).astype(np.float32)\n    unstable_low_level_torque[80:85, 0] += 2.5\n    unstable_low_level_torque[120:123, 0] -= 2.5\n\n    stable_contact = np.zeros((num_low_level, 2), dtype=np.float32)\n    stable_contact[:100, 0] = 1.0\n    stable_contact[100:, 1] = 1.0\n    stable_normal_force = stable_contact * np.array(\n        [[80.0, 75.0]], dtype=np.float32\n    )\n    stable_tangent_speed = stable_contact * 0.01\n\n    unstable_contact = np.zeros((num_low_level, 2), dtype=np.float32)\n    for start in range(0, num_low_level, 10):\n        unstable_contact[start : start + 5, 0] = 1.0\n        unstable_contact[start + 5 : start + 10, 1] = 1.0\n    unstable_normal_force = unstable_contact * 60.0\n    touchdown_mask = unstable_contact.copy()\n    touchdown_mask[1:] = np.clip(\n        unstable_contact[1:] - unstable_contact[:-1], a_min=0.0, a_max=None\n    )\n    unstable_normal_force += touchdown_mask * 120.0\n    unstable_tangent_speed = unstable_contact * 0.25\n\n    smooth_metrics = _compute_clip_stability_summary(\n        data=_make_eval_data(\n            np.zeros((num_frames, 2), dtype=np.float32),\n            robot_low_level_dof_torque=smooth_low_level_torque,\n            robot_global_angular_velocity=smooth_ang_vel,\n            robot_low_level_foot_contact=stable_contact,\n            robot_low_level_foot_normal_force=stable_normal_force,\n            robot_low_level_foot_tangent_speed=stable_tangent_speed,\n        ),\n        robot_control_dt=policy_dt,\n        low_level_contact_dt=low_level_dt,\n    )\n    unstable_metrics = _compute_clip_stability_summary(\n        data=_make_eval_data(\n            np.zeros((num_frames, 2), dtype=np.float32),\n            robot_low_level_dof_torque=unstable_low_level_torque,\n            robot_global_angular_velocity=unstable_ang_vel,\n            robot_low_level_foot_contact=unstable_contact,\n            robot_low_level_foot_normal_force=unstable_normal_force,\n            robot_low_level_foot_tangent_speed=unstable_tangent_speed,\n        ),\n        robot_control_dt=policy_dt,\n        low_level_contact_dt=low_level_dt,\n    )\n\n    assert (\n        unstable_metrics[\"torque_chatter_hf_ratio\"]\n   
     > smooth_metrics[\"torque_chatter_hf_ratio\"]\n    )\n    assert (\n        unstable_metrics[\"torque_jump_burst_max\"]\n        > smooth_metrics[\"torque_jump_burst_max\"]\n    )\n    assert (\n        unstable_metrics[\"torso_rp_hf_ratio\"]\n        > smooth_metrics[\"torso_rp_hf_ratio\"]\n    )\n    assert (\n        unstable_metrics[\"torso_rp_angacc_p95\"]\n        > smooth_metrics[\"torso_rp_angacc_p95\"]\n    )\n    assert (\n        unstable_metrics[\"foot_contact_toggle_rate\"]\n        > smooth_metrics[\"foot_contact_toggle_rate\"]\n    )\n    assert (\n        unstable_metrics[\"foot_impact_force_p95\"]\n        > smooth_metrics[\"foot_impact_force_p95\"]\n    )\n    assert (\n        unstable_metrics[\"stance_slip_speed_p95\"]\n        > smooth_metrics[\"stance_slip_speed_p95\"]\n    )\n\n\ndef test_compute_clip_stability_summary_reports_expert_switching_js_div():\n    num_frames = 8\n    stable_logits = np.tile(\n        np.array(\n            [\n                [8.0, -4.0, -4.0],\n                [-4.0, 8.0, -4.0],\n            ],\n            dtype=np.float32,\n        )[None, :, :],\n        (num_frames, 1, 1),\n    )\n    switching_logits = stable_logits.copy()\n    switching_logits[1::2, 0, :] = np.array(\n        [-4.0, 8.0, -4.0], dtype=np.float32\n    )\n    switching_logits[1::2, 1, :] = np.array(\n        [-4.0, -4.0, 8.0], dtype=np.float32\n    )\n\n    stable_metrics = _compute_clip_stability_summary(\n        data=_make_eval_data(\n            np.zeros((num_frames, 2), dtype=np.float32),\n            robot_moe_expert_logits=stable_logits,\n        ),\n        robot_control_dt=0.02,\n        low_level_contact_dt=0.02,\n    )\n    switching_metrics = _compute_clip_stability_summary(\n        data=_make_eval_data(\n            np.zeros((num_frames, 2), dtype=np.float32),\n            robot_moe_expert_logits=switching_logits,\n        ),\n        robot_control_dt=0.02,\n        low_level_contact_dt=0.02,\n    )\n\n    assert stable_metrics[\"expert_switching_js_div\"] < 1e-6\n    assert (\n        switching_metrics[\"expert_switching_js_div\"]\n        > stable_metrics[\"expert_switching_js_div\"]\n    )\n\n\ndef test_offline_evaluate_dumped_npzs_reports_nan_contact_metrics_for_legacy_npz(\n    tmp_path: Path,\n):\n    eval_dir = tmp_path / \"legacy_eval\"\n    eval_dir.mkdir()\n\n    payload = _make_eval_data(np.ones((8, 2), dtype=np.float32))\n    payload[\"metadata\"] = np.array(\n        json.dumps({\"clip_length\": 8, \"robot_low_level_torque_dt\": 0.005}),\n        dtype=np.str_,\n    )\n    np.savez_compressed(eval_dir / \"legacy_clip.npz\", **payload)\n\n    output_json_path = eval_dir / \"summary.json\"\n    result = offline_evaluate_dumped_npzs(\n        npz_dir=str(eval_dir),\n        output_json_path=str(output_json_path),\n    )\n\n    per_clip = result[\"per_clip\"][0]\n    for key in (\n        \"torque_chatter_hf_ratio\",\n        \"torque_jump_burst_max\",\n        \"torso_rp_hf_ratio\",\n        \"torso_rp_angacc_p95\",\n        \"foot_contact_toggle_rate\",\n        \"foot_impact_force_p95\",\n        \"stance_slip_speed_p95\",\n        \"expert_switching_js_div\",\n    ):\n        assert key in per_clip\n        assert key in result[\"dataset\"][\"mean\"]\n\n    assert np.isnan(per_clip[\"foot_contact_toggle_rate\"])\n    assert np.isnan(per_clip[\"foot_impact_force_p95\"])\n    assert np.isnan(per_clip[\"stance_slip_speed_p95\"])\n    assert np.isnan(per_clip[\"expert_switching_js_div\"])\n"
  },
  {
    "path": "tests/test_isaaclab_termination.py",
    "content": "import importlib.util\nimport sys\nimport types\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nimport pytest\nimport torch\n\nMODULE_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_termination.py\"\n)\n\nMOTION_COMMAND_MODULE_NAME = (\n    \"holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command\"\n)\nISAACLAB_UTILS_MODULE_NAME = (\n    \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n)\n\n\nclass _Scene(SimpleNamespace):\n    def __getitem__(self, key):\n        return getattr(self, key)\n\n\ndef _load_isaaclab_termination_module(module_name: str):\n    isaaclab_module = types.ModuleType(\"isaaclab\")\n    isaaclab_envs = types.ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedRLEnv = object\n\n    isaaclab_terminations = types.SimpleNamespace(\n        time_out=lambda env: torch.zeros(1, dtype=torch.bool),\n        bad_orientation=lambda env, limit_angle: torch.zeros(\n            1, dtype=torch.bool\n        ),\n        root_height_below_minimum=lambda env, minimum_height: torch.zeros(\n            1, dtype=torch.bool\n        ),\n        native_only_term=lambda env, margin: torch.zeros(1, dtype=torch.bool),\n    )\n    isaaclab_envs_mdp = types.ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_envs_mdp.terminations = isaaclab_terminations\n\n    isaaclab_managers = types.ModuleType(\"isaaclab.managers\")\n\n    class _TerminationTermCfg:\n        def __init__(self, func, params=None, time_out=False):\n            self.func = func\n            self.params = {} if params is None else params\n            self.time_out = time_out\n\n    isaaclab_managers.TerminationTermCfg = _TerminationTermCfg\n    isaaclab_managers.SceneEntityCfg = object\n\n    isaaclab_utils = types.ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n    isaaclab_utils_math = types.ModuleType(\"isaaclab.utils.math\")\n    isaaclab_utils_math.quat_apply_inverse = (\n        lambda quat, vec: torch.zeros_like(vec)\n    )\n    isaaclab_utils.math = isaaclab_utils_math\n\n    isaaclab_assets = types.ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n\n    isaaclab_components_package = types.ModuleType(\n        \"holomotion.src.env.isaaclab_components\"\n    )\n    motion_command_module = types.ModuleType(MOTION_COMMAND_MODULE_NAME)\n    motion_command_module.RefMotionCommand = object\n\n    isaaclab_utils_module = types.ModuleType(ISAACLAB_UTILS_MODULE_NAME)\n    isaaclab_utils_module._get_body_indices = lambda robot, keybody_names: None\n    isaaclab_utils_module.resolve_holo_config = lambda cfg: cfg\n    isaaclab_components_package.isaaclab_motion_tracking_command = (\n        motion_command_module\n    )\n    isaaclab_components_package.isaaclab_utils = isaaclab_utils_module\n\n    fake_modules = {\n        \"isaaclab\": isaaclab_module,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.envs.mdp\": isaaclab_envs_mdp,\n        \"isaaclab.managers\": isaaclab_managers,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.math\": isaaclab_utils_math,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"holomotion.src.env.isaaclab_components\": isaaclab_components_package,\n        MOTION_COMMAND_MODULE_NAME: motion_command_module,\n        ISAACLAB_UTILS_MODULE_NAME: isaaclab_utils_module,\n    }\n    original_modules = {name: sys.modules.get(name) for name in 
fake_modules}\n\n    sys.modules.update(fake_modules)\n    try:\n        spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)\n        module = importlib.util.module_from_spec(spec)\n        assert spec.loader is not None\n        spec.loader.exec_module(module)\n        return module\n    finally:\n        for name, original in original_modules.items():\n            if original is None:\n                sys.modules.pop(name, None)\n            else:\n                sys.modules[name] = original\n\n\ndef test_wholebody_mpjpe_far_flags_envs_above_mean_error_threshold():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test\"\n    )\n\n    current_dof_pos = torch.tensor(\n        [\n            [0.0, 0.2, 0.6],\n            [0.0, 0.1, 0.2],\n        ]\n    )\n    ref_dof_pos = torch.zeros_like(current_dof_pos)\n    command = SimpleNamespace(\n        robot=SimpleNamespace(data=SimpleNamespace(joint_pos=current_dof_pos)),\n        get_ref_motion_dof_pos_cur=lambda prefix=\"ref_\": ref_dof_pos,\n        get_ref_motion_dof_pos_immediate_next=lambda prefix=\"ref_\": ref_dof_pos,\n    )\n    env = SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command)\n    )\n\n    result = termination_module.wholebody_mpjpe_far(env, threshold=0.2)\n\n    assert result.dtype == torch.bool\n    assert torch.equal(result, torch.tensor([True, False]))\n\n\ndef test_wholebody_mpjpe_far_uses_immediate_next_reference():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_next_dof\"\n    )\n\n    current_dof_pos = torch.tensor([[0.0, 0.1, 0.2]])\n    command = SimpleNamespace(\n        robot=SimpleNamespace(data=SimpleNamespace(joint_pos=current_dof_pos)),\n        get_ref_motion_dof_pos_cur=lambda prefix=\"ref_\": (_ for _ in ()).throw(\n            AssertionError(\"current reference should not be used\")\n        ),\n        get_ref_motion_dof_pos_immediate_next=lambda prefix=\"ref_\": current_dof_pos,\n    )\n    env = SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command)\n    )\n\n    result = termination_module.wholebody_mpjpe_far(env, threshold=0.05)\n\n    assert torch.equal(result, torch.tensor([False]))\n\n\ndef test_keybody_ref_pos_far_uses_immediate_next_reference():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_next_keybody\"\n    )\n\n    body_pos = torch.tensor(\n        [[[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]], dtype=torch.float32\n    )\n    robot = SimpleNamespace(\n        body_names=[\"anchor\", \"target\"],\n        data=SimpleNamespace(body_pos_w=body_pos),\n    )\n    command = SimpleNamespace(\n        robot=robot,\n        get_ref_motion_bodylink_global_pos_cur=(\n            lambda prefix=\"ref_\": (_ for _ in ()).throw(\n                AssertionError(\"current reference should not be used\")\n            )\n        ),\n        get_ref_motion_bodylink_global_pos_immediate_next=(\n            lambda prefix=\"ref_\": body_pos\n        ),\n    )\n    env = SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command)\n    )\n\n    result = termination_module.keybody_ref_pos_far(\n        env,\n        threshold=0.1,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.equal(result, torch.tensor([False]))\n\n\ndef test_ref_gravity_projection_far_uses_immediate_next_reference():\n    termination_module = 
_load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_next_gravity\"\n    )\n    gravity = torch.tensor([[0.0, 0.0, -1.0]], dtype=torch.float32)\n    anchor_quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)\n    robot = SimpleNamespace(\n        data=SimpleNamespace(\n            GRAVITY_VEC_W=gravity,\n            body_quat_w=anchor_quat[:, None, :],\n        )\n    )\n    command = SimpleNamespace(\n        robot=robot,\n        anchor_bodylink_idx=0,\n        get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=(\n            lambda prefix=\"ref_\": (_ for _ in ()).throw(\n                AssertionError(\"current reference should not be used\")\n            )\n        ),\n        get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next=(\n            lambda prefix=\"ref_\": anchor_quat\n        ),\n    )\n    env = SimpleNamespace(\n        scene=_Scene(robot=robot),\n        command_manager=SimpleNamespace(get_term=lambda name: command),\n    )\n\n    result = termination_module.ref_gravity_projection_far(\n        env,\n        threshold=0.1,\n    )\n\n    assert torch.equal(result, torch.tensor([False]))\n\n\ndef test_build_terminations_config_registers_wholebody_mpjpe_far():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_for_cfg\"\n    )\n\n    config = termination_module.build_terminations_config(\n        {\n            \"wholebody_mpjpe_far\": {\n                \"params\": {\"threshold\": 0.3},\n            }\n        }\n    )\n\n    assert (\n        config.wholebody_mpjpe_far.func\n        is termination_module.wholebody_mpjpe_far\n    )\n    assert config.wholebody_mpjpe_far.params == {\"threshold\": 0.3}\n    assert config.wholebody_mpjpe_far.time_out is False\n\n\ndef test_build_terminations_config_resolves_native_isaaclab_termination():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_for_native_cfg\"\n    )\n\n    config = termination_module.build_terminations_config(\n        {\n            \"native_only_term\": {\n                \"params\": {\"margin\": 0.3},\n            }\n        }\n    )\n\n    assert (\n        config.native_only_term.func\n        is termination_module.isaaclab_mdp.terminations.native_only_term\n    )\n    assert config.native_only_term.params == {\"margin\": 0.3}\n    assert config.native_only_term.time_out is False\n\n\ndef test_build_terminations_config_raises_on_unknown_termination():\n    termination_module = _load_isaaclab_termination_module(\n        \"isaaclab_termination_under_test_for_unknown_cfg\"\n    )\n\n    with pytest.raises(ValueError, match=\"Unknown termination function\"):\n        termination_module.build_terminations_config({\"missing_term\": {}})\n"
  },
  {
    "path": "tests/test_mean_process_5metrics.py",
    "content": "import json\nimport sys\nfrom pathlib import Path\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.scripts.evaluation import mean_process_5metrics\n\n\nLEGACY_METRICS = [\n    \"mpjpe_g\",\n    \"mpjpe_l\",\n    \"whole_body_joints_dist\",\n    \"root_vel_error\",\n    \"root_r_error\",\n    \"root_p_error\",\n    \"root_y_error\",\n    \"root_height_error\",\n    \"mean_dof_vel\",\n    \"mean_dof_acc\",\n    \"mean_dof_torque\",\n    \"mean_action_rate\",\n    \"success\",\n]\n\nTORQUE_JUMP_METRICS = [\n    \"mean_torque_jump_norm\",\n    \"p95_torque_jump_norm\",\n    \"mean_torque_jump_ratio\",\n    \"p95_torque_jump_ratio\",\n]\n\n\ndef test_macro_report_appends_torque_jump_columns_to_legacy_tables(\n    tmp_path, monkeypatch\n):\n    json_path = tmp_path / \"model_a.json\"\n    payload = {\n        \"per_clip\": [\n            {\n                \"motion_key\": \"clips_AMASS_demo\",\n                \"mpjpe_g\": 1.0,\n                \"mpjpe_l\": 2.0,\n                \"whole_body_joints_dist\": 3.0,\n                \"root_vel_error\": 4.0,\n                \"root_r_error\": 5.0,\n                \"root_p_error\": 6.0,\n                \"root_y_error\": 7.0,\n                \"root_height_error\": 8.0,\n                \"mean_dof_vel\": 9.0,\n                \"mean_dof_acc\": 10.0,\n                \"mean_dof_torque\": 11.0,\n                \"mean_action_rate\": 12.0,\n                \"success\": 1.0,\n                \"mean_torque_jump_norm\": 13.0,\n                \"p95_torque_jump_norm\": 14.0,\n                \"mean_torque_jump_ratio\": 15.0,\n                \"p95_torque_jump_ratio\": 16.0,\n            }\n        ]\n    }\n    json_path.write_text(json.dumps(payload), encoding=\"utf-8\")\n\n    mean_df, _ = mean_process_5metrics.process_data(str(tmp_path))\n\n    assert mean_df.columns.tolist() == [\n        \"Method\",\n        \"Dataset\",\n        *LEGACY_METRICS,\n        *TORQUE_JUMP_METRICS,\n    ]\n\n    captured_headers = {}\n\n    def _fake_tabulate(_rows, headers, **_kwargs):\n        captured_headers[\"headers\"] = headers\n        return \"fake-table\"\n\n    monkeypatch.setattr(mean_process_5metrics, \"tabulate\", _fake_tabulate)\n\n    report_path = (\n        mean_process_5metrics.generate_macro_mean_report_from_json_dir(\n            str(tmp_path)\n        )\n    )\n\n    tsv_path = tmp_path / \"sub_dataset_macro_mean_metrics.tsv\"\n    header = tsv_path.read_text(encoding=\"utf-8\").splitlines()[0].split(\"\\t\")\n    assert header == [\n        \"Dataset\",\n        \"Global Bodylink Pos Err\",\n        \"Local Bodylink Pos Err\",\n        \"Dof Position Err\",\n        \"Root Vel Err\",\n        \"Root Roll Err\",\n        \"Root Pitch Err\",\n        \"Root Yaw Err\",\n        \"Root Height Err\",\n        \"Mean Dof Vel\",\n        \"Mean Dof Acc\",\n        \"Mean Dof Torque\",\n        \"Mean Action Rate\",\n        \"Success Rate\",\n        \"Mean Torque Jump Norm\",\n        \"P95 Torque Jump Norm\",\n        \"Mean Torque Jump Ratio\",\n        \"P95 Torque Jump Ratio\",\n    ]\n    assert captured_headers[\"headers\"] == header\n\n    report_text = Path(report_path).read_text(encoding=\"utf-8\")\n    legacy_index = report_text.index(\"Success Rate\")\n    torque_index = report_text.index(\"Mean Torque Jump Norm\")\n    assert torque_index > legacy_index\n"
  },
  {
    "path": "tests/test_motion_cache_gather_state.py",
    "content": "import sys\nimport unittest\nfrom unittest import mock\nfrom pathlib import Path\n\nimport torch\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.training.h5_dataloader import (\n    ClipBatch,\n    Hdf5RootDofDataset,\n    MotionClipBatchCache,\n    MotionWindow,\n    _CpuFKTransform,\n    _normalize_online_filter_cfg,\n    build_motion_datasets_from_cfg,\n)\n\n\ndef _expected_field(\n    tensor: torch.Tensor,\n    clip_indices: torch.Tensor,\n    frame_indices: torch.Tensor,\n    n_future_frames: int,\n    lengths: torch.Tensor,\n) -> torch.Tensor:\n    temporal_span = 1 + int(n_future_frames)\n    time_offsets = torch.arange(temporal_span, dtype=torch.long)\n    gather_timesteps = frame_indices[:, None] + time_offsets[None, :]\n    max_valid = torch.clamp(lengths.index_select(0, clip_indices) - 1, min=0)\n    gather_timesteps = torch.minimum(gather_timesteps, max_valid[:, None])\n    return tensor[clip_indices[:, None], gather_timesteps]\n\n\nclass MotionCacheGatherStateTests(unittest.TestCase):\n    def test_normalize_online_filter_cfg_includes_velocity_smoothing_sigmas(\n        self,\n    ):\n        default_cfg = _normalize_online_filter_cfg({})\n\n        self.assertEqual(default_cfg[\"ref_vel_smoothing_sigma\"], 2.0)\n        self.assertEqual(default_cfg[\"ft_ref_vel_smoothing_sigma\"], 2.0)\n\n        explicit_cfg = _normalize_online_filter_cfg(\n            {\n                \"enabled\": True,\n                \"butter_cutoff_hz_pool\": [3.0],\n                \"ref_vel_smoothing_sigma\": 0.0,\n                \"ft_ref_vel_smoothing_sigma\": 2.0,\n            },\n            default_vel_smoothing_sigma=0.5,\n        )\n\n        self.assertEqual(explicit_cfg[\"ref_vel_smoothing_sigma\"], 0.0)\n        self.assertEqual(explicit_cfg[\"ft_ref_vel_smoothing_sigma\"], 2.0)\n\n    def test_normalize_online_filter_cfg_uses_fk_sigma_fallback_defaults(self):\n        cfg = _normalize_online_filter_cfg(\n            {},\n            default_vel_smoothing_sigma=0.5,\n        )\n\n        self.assertEqual(cfg[\"ref_vel_smoothing_sigma\"], 0.5)\n        self.assertEqual(cfg[\"ft_ref_vel_smoothing_sigma\"], 0.5)\n\n    def test_build_motion_datasets_from_cfg_passes_fk_sigma_fallback(self):\n        with (\n            mock.patch(\n                \"holomotion.src.training.h5_dataloader.preview_sampling_from_cfg\"\n            ),\n            mock.patch(\n                \"holomotion.src.training.h5_dataloader.Hdf5RootDofDataset\"\n            ) as dataset_cls,\n        ):\n            build_motion_datasets_from_cfg(\n                {\n                    \"backend\": \"hdf5_v2\",\n                    \"hdf5_root\": \"/tmp/train\",\n                    \"fk_robot_file_path\": \"robot.xml\",\n                    \"fk_vel_smoothing_sigma\": 0.5,\n                    \"cache\": {\"allowed_prefixes\": [\"ref_\", \"ft_ref_\"]},\n                    \"online_filter\": {\"enabled\": False},\n                },\n                max_frame_length=16,\n                min_window_length=4,\n            )\n\n        self.assertEqual(dataset_cls.call_count, 1)\n        self.assertEqual(\n            dataset_cls.call_args.kwargs[\"fk_vel_smoothing_sigma\"],\n            0.5,\n        )\n\n    def test_build_motion_datasets_from_cfg_defaults_fk_sigma_fallback(self):\n        with (\n            mock.patch(\n                \"holomotion.src.training.h5_dataloader.preview_sampling_from_cfg\"\n            ),\n            mock.patch(\n                
\"holomotion.src.training.h5_dataloader.Hdf5RootDofDataset\"\n            ) as dataset_cls,\n        ):\n            build_motion_datasets_from_cfg(\n                {\n                    \"backend\": \"hdf5_v2\",\n                    \"hdf5_root\": \"/tmp/train\",\n                    \"fk_robot_file_path\": \"robot.xml\",\n                    \"cache\": {\"allowed_prefixes\": [\"ref_\", \"ft_ref_\"]},\n                    \"online_filter\": {\"enabled\": False},\n                },\n                max_frame_length=16,\n                min_window_length=4,\n            )\n\n        self.assertEqual(dataset_cls.call_count, 1)\n        self.assertEqual(\n            dataset_cls.call_args.kwargs[\"fk_vel_smoothing_sigma\"],\n            2.0,\n        )\n\n    def test_gather_tensor_returns_expected_values(self):\n        cache = MotionClipBatchCache.__new__(MotionClipBatchCache)\n\n        ref_dof_pos = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(\n            2, 6, 3\n        )\n        ref_rg_pos = torch.arange(2 * 6 * 2 * 3, dtype=torch.float32).reshape(\n            2, 6, 2, 3\n        )\n        lengths = torch.tensor([6, 4], dtype=torch.long)\n        window_indices = torch.tensor([10, 11], dtype=torch.long)\n\n        cache._current_batch = ClipBatch(\n            tensors={\n                \"ref_dof_pos\": ref_dof_pos,\n                \"ref_rg_pos\": ref_rg_pos,\n            },\n            lengths=lengths,\n            motion_keys=[\"clip-a\", \"clip-b\"],\n            raw_motion_keys=[\"clip-a\", \"clip-b\"],\n            window_indices=window_indices,\n            max_frame_length=6,\n        )\n\n        clip_indices = torch.tensor([1, 0, 1, 1], dtype=torch.long)\n        frame_indices = torch.tensor([0, 2, 3, 1], dtype=torch.long)\n\n        gathered_dof_pos = cache.gather_tensor(\n            \"ref_dof_pos\",\n            clip_indices=clip_indices,\n            frame_indices=frame_indices,\n            n_future_frames=2,\n        )\n        gathered_rg_pos = cache.gather_tensor(\n            \"ref_rg_pos\",\n            clip_indices=clip_indices,\n            frame_indices=frame_indices,\n            n_future_frames=2,\n        )\n\n        expected_dof_pos = _expected_field(\n            ref_dof_pos,\n            clip_indices,\n            frame_indices,\n            n_future_frames=2,\n            lengths=lengths,\n        )\n        expected_rg_pos = _expected_field(\n            ref_rg_pos,\n            clip_indices,\n            frame_indices,\n            n_future_frames=2,\n            lengths=lengths,\n        )\n\n        torch.testing.assert_close(gathered_dof_pos, expected_dof_pos)\n        torch.testing.assert_close(gathered_rg_pos, expected_rg_pos)\n        self.assertEqual(tuple(gathered_dof_pos.shape), (4, 3, 3))\n        self.assertEqual(tuple(gathered_rg_pos.shape), (4, 3, 2, 3))\n\n    def test_gather_tensor_reflects_updated_indices_without_cached_state(self):\n        cache = MotionClipBatchCache.__new__(MotionClipBatchCache)\n\n        ref_dof_pos = torch.arange(3 * 6 * 3, dtype=torch.float32).reshape(\n            3, 6, 3\n        )\n        lengths = torch.tensor([6, 5, 4], dtype=torch.long)\n        window_indices = torch.tensor([10, 11, 12], dtype=torch.long)\n\n        cache._current_batch = ClipBatch(\n            tensors={\"ref_dof_pos\": ref_dof_pos},\n            lengths=lengths,\n            motion_keys=[\"clip-a\", \"clip-b\", \"clip-c\"],\n            raw_motion_keys=[\"clip-a\", \"clip-b\", \"clip-c\"],\n            
window_indices=window_indices,\n            max_frame_length=6,\n        )\n\n        initial_clip_indices = torch.tensor([0, 1, 2, 1], dtype=torch.long)\n        initial_frame_indices = torch.tensor([0, 1, 2, 0], dtype=torch.long)\n        updated_clip_indices = torch.tensor([0, 2, 1, 0], dtype=torch.long)\n        updated_frame_indices = torch.tensor([1, 0, 3, 2], dtype=torch.long)\n\n        initial_gathered = cache.gather_tensor(\n            \"ref_dof_pos\",\n            clip_indices=initial_clip_indices,\n            frame_indices=initial_frame_indices,\n            n_future_frames=2,\n        )\n        updated_gathered = cache.gather_tensor(\n            \"ref_dof_pos\",\n            clip_indices=updated_clip_indices,\n            frame_indices=updated_frame_indices,\n            n_future_frames=2,\n        )\n\n        expected_initial = _expected_field(\n            ref_dof_pos,\n            initial_clip_indices,\n            initial_frame_indices,\n            n_future_frames=2,\n            lengths=lengths,\n        )\n        expected_updated = _expected_field(\n            ref_dof_pos,\n            updated_clip_indices,\n            updated_frame_indices,\n            n_future_frames=2,\n            lengths=lengths,\n        )\n\n        torch.testing.assert_close(initial_gathered, expected_initial)\n        torch.testing.assert_close(updated_gathered, expected_updated)\n\n    def test_cpu_fk_transform_forwards_explicit_vel_smoothing_sigma(self):\n        transform = _CpuFKTransform.__new__(_CpuFKTransform)\n        transform._fk = mock.Mock(\n            return_value={\n                \"global_translation\": torch.zeros(1, 4, 2, 3),\n                \"global_rotation_quat\": torch.zeros(1, 4, 2, 4),\n                \"global_velocity\": torch.zeros(1, 4, 2, 3),\n                \"global_angular_velocity\": torch.zeros(1, 4, 2, 3),\n                \"dof_vel\": torch.zeros(1, 4, 2),\n            }\n        )\n        arrays = {\n            \"ref_root_pos\": torch.zeros(4, 3),\n            \"ref_root_rot\": torch.zeros(4, 4),\n            \"ref_dof_pos\": torch.zeros(4, 2),\n        }\n\n        transform(\n            arrays,\n            fps=60.0,\n            prefix=\"ref_\",\n            vel_smoothing_sigma=0.0,\n        )\n\n        self.assertEqual(\n            transform._fk.call_args.kwargs[\"vel_smoothing_sigma\"],\n            0.0,\n        )\n\n    def test_cpu_fk_transform_defaults_vel_smoothing_sigma_to_two(self):\n        transform = _CpuFKTransform.__new__(_CpuFKTransform)\n        transform._fk = mock.Mock(\n            return_value={\n                \"global_translation\": torch.zeros(1, 4, 2, 3),\n                \"global_rotation_quat\": torch.zeros(1, 4, 2, 4),\n                \"global_velocity\": torch.zeros(1, 4, 2, 3),\n                \"global_angular_velocity\": torch.zeros(1, 4, 2, 3),\n                \"dof_vel\": torch.zeros(1, 4, 2),\n            }\n        )\n        arrays = {\n            \"ref_root_pos\": torch.zeros(4, 3),\n            \"ref_root_rot\": torch.zeros(4, 4),\n            \"ref_dof_pos\": torch.zeros(4, 2),\n        }\n\n        transform(arrays, fps=60.0)\n\n        self.assertEqual(\n            transform._fk.call_args.kwargs[\"vel_smoothing_sigma\"],\n            2.0,\n        )\n\n    def test_hdf5_v2_sample_exposes_zero_cutoff_metadata_when_disabled(self):\n        dataset = self._make_stub_root_dof_dataset()\n\n        sample = dataset[0]\n\n        self.assertIn(\"filter_cutoff_hz\", sample.tensors)\n        
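# With the online filter disabled, the per-frame cutoff metadata is\n        # expected to be all zeros.\n        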
torch.testing.assert_close(\n            sample.tensors[\"filter_cutoff_hz\"],\n            torch.zeros(4, 1, dtype=torch.float32),\n        )\n\n    def test_hdf5_v2_sample_exposes_sampled_cutoff_metadata(self):\n        dataset = self._make_stub_root_dof_dataset(\n            cutoff_pool=(3.0,),\n            online_filter_enabled=True,\n        )\n\n        sample = dataset[0]\n\n        self.assertIn(\"filter_cutoff_hz\", sample.tensors)\n        torch.testing.assert_close(\n            sample.tensors[\"filter_cutoff_hz\"],\n            torch.full((4, 1), 3.0, dtype=torch.float32),\n        )\n\n    def test_hdf5_v2_sample_generates_filtered_reference_family(self):\n        dataset = self._make_stub_root_dof_dataset(\n            cutoff_pool=(3.0,),\n            online_filter_enabled=True,\n        )\n\n        sample = dataset[0]\n\n        for tensor_name in (\n            \"ft_ref_root_pos\",\n            \"ft_ref_root_rot\",\n            \"ft_ref_dof_pos\",\n            \"ft_ref_rg_pos\",\n            \"ft_ref_rb_rot\",\n            \"ft_ref_body_vel\",\n            \"ft_ref_body_ang_vel\",\n            \"ft_ref_dof_vel\",\n            \"ft_ref_root_vel\",\n            \"ft_ref_root_ang_vel\",\n        ):\n            self.assertIn(tensor_name, sample.tensors)\n\n    def test_hdf5_v2_sample_uses_split_fk_smoothing_sigmas(self):\n        dataset = self._make_stub_root_dof_dataset(\n            cutoff_pool=(3.0,),\n            online_filter_enabled=True,\n            ref_vel_smoothing_sigma=0.0,\n            ft_ref_vel_smoothing_sigma=2.0,\n        )\n\n        sample = dataset[0]\n\n        self.assertIn(\"ref_root_vel\", sample.tensors)\n        self.assertIn(\"ft_ref_root_vel\", sample.tensors)\n        self.assertEqual(\n            dataset._fk_calls,\n            [(\"ref_\", 0.0), (\"ft_ref_\", 2.0)],\n        )\n\n    def test_hdf5_v2_sample_skips_filtered_reference_family_when_disabled(\n        self,\n    ):\n        dataset = self._make_stub_root_dof_dataset(\n            cutoff_pool=(3.0,),\n            online_filter_enabled=True,\n            allowed_prefixes=(\"ref_\",),\n        )\n\n        sample = dataset[0]\n\n        self.assertNotIn(\"ft_ref_root_pos\", sample.tensors)\n        self.assertNotIn(\"ft_ref_rg_pos\", sample.tensors)\n\n    @staticmethod\n    def _make_stub_root_dof_dataset(\n        *,\n        cutoff_pool=(0.0,),\n        online_filter_enabled=False,\n        allowed_prefixes=(\"ref_\", \"ft_ref_\"),\n        ref_vel_smoothing_sigma=2.0,\n        ft_ref_vel_smoothing_sigma=2.0,\n    ):\n        dataset = Hdf5RootDofDataset.__new__(Hdf5RootDofDataset)\n        dataset.windows = [\n            MotionWindow(\n                motion_key=\"clip-a__start_0_len_4\",\n                shard_index=0,\n                start=0,\n                length=4,\n                raw_motion_key=\"clip-a\",\n                window_index=0,\n            )\n        ]\n        dataset.clips = {\n            \"clip-a\": {\n                \"metadata\": {\n                    \"motion_fps\": 60.0,\n                }\n            }\n        }\n        dataset._progress_counter = None\n        dataset._world_frame_transform = None\n        dataset._file_handles = {}\n        dataset._h5_access_counter = 0\n        dataset._h5_cleanup_interval = int(1e6)\n        dataset._online_filter_enabled = bool(online_filter_enabled)\n        dataset._online_filter_cutoff_hz_pool = tuple(cutoff_pool)\n        dataset._allowed_prefixes = tuple(allowed_prefixes)\n        
dataset._ref_vel_smoothing_sigma = float(ref_vel_smoothing_sigma)\n        dataset._ft_ref_vel_smoothing_sigma = float(ft_ref_vel_smoothing_sigma)\n        dataset._fk_calls = []\n\n        shard_handle = {\n            \"ref_root_pos\": torch.arange(12, dtype=torch.float32)\n            .reshape(4, 3)\n            .numpy(),\n            \"ref_root_rot\": torch.tensor(\n                [[0.0, 0.0, 0.0, 1.0]] * 4, dtype=torch.float32\n            ).numpy(),\n            \"ref_dof_pos\": torch.arange(8, dtype=torch.float32)\n            .reshape(4, 2)\n            .numpy(),\n        }\n\n        dataset._online_filter_butter_order = 4\n\n        def fake_fk_transform(\n            arrays,\n            fps,\n            prefix=\"ref_\",\n            vel_smoothing_sigma=2.0,\n        ):\n            del fps\n            dataset._fk_calls.append((prefix, float(vel_smoothing_sigma)))\n            root_pos = arrays[f\"{prefix}root_pos\"]\n            root_rot = arrays[f\"{prefix}root_rot\"]\n            arrays[f\"{prefix}rg_pos\"] = torch.stack(\n                [root_pos, root_pos], dim=1\n            )\n            arrays[f\"{prefix}rb_rot\"] = torch.stack(\n                [root_rot, root_rot], dim=1\n            )\n            arrays[f\"{prefix}body_vel\"] = torch.zeros(\n                4, 2, 3, dtype=torch.float32\n            )\n            arrays[f\"{prefix}body_ang_vel\"] = torch.zeros(\n                4, 2, 3, dtype=torch.float32\n            )\n            arrays[f\"{prefix}dof_vel\"] = torch.zeros(4, 2, dtype=torch.float32)\n\n        dataset._fk_transform = fake_fk_transform\n        dataset._get_shard_handle = lambda shard_index: shard_handle\n        return dataset\n"
  },
  {
    "path": "tests/test_motion_cache_startup.py",
    "content": "from pathlib import Path\nimport sys\nfrom types import SimpleNamespace\n\nimport unittest\nfrom unittest import mock\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.algo.algo_base import BaseOnpolicyRL\nimport holomotion.src.training.h5_dataloader as h5_dataloader_module\nfrom holomotion.src.training.h5_dataloader import MotionClipBatchCache\n\n\nclass _FakeDataset:\n    def __init__(self, length: int = 8) -> None:\n        self._length = int(length)\n        self.max_frame_length = 16\n        self.progress_counter = None\n\n    def __len__(self) -> int:\n        return self._length\n\n    def __getitem__(self, index: int):\n        raise AssertionError(\"__getitem__ should not be called in these tests\")\n\n    def set_progress_counter(self, counter) -> None:\n        self.progress_counter = counter\n\n    def close(self) -> None:\n        return\n\n\nclass MotionCacheStartupTests(unittest.TestCase):\n    def test_motion_cache_uses_explicit_constructor_seed(self):\n        with (\n            mock.patch.object(\n                MotionClipBatchCache, \"_build_dataloader\", lambda self: None\n            ),\n            mock.patch.object(\n                MotionClipBatchCache, \"_prime_buffers\", lambda self: None\n            ),\n        ):\n            cache = MotionClipBatchCache(\n                train_dataset=_FakeDataset(),\n                batch_size=2,\n                num_workers=0,\n                pin_memory=False,\n                persistent_workers=False,\n                seed=1234,\n            )\n\n        self.assertEqual(cache._seed, 1234)\n\n    def test_setup_seeding_does_not_reinitialize_motion_cache(self):\n        algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)\n        algo.config = {\"seed\": 100}\n        algo.process_rank = 2\n        algo.command_name = \"ref_motion\"\n\n        env_seed_calls = []\n        motion_cache_seed_calls = []\n        algo.env = SimpleNamespace(\n            seed=lambda seed: env_seed_calls.append(seed)\n        )\n        algo.command_term = SimpleNamespace(\n            cfg=SimpleNamespace(seed=102),\n            set_motion_cache_seed=lambda seed,\n            reinitialize: motion_cache_seed_calls.append((seed, reinitialize)),\n        )\n\n        BaseOnpolicyRL._setup_seeding(algo)\n\n        self.assertEqual(algo.base_seed, 100)\n        self.assertEqual(algo.seed, 102)\n        self.assertEqual(env_seed_calls, [102])\n        self.assertEqual(motion_cache_seed_calls, [(102, False)])\n\n    def test_motion_cache_passes_loader_timeout_to_dataloader(self):\n        captured_kwargs = {}\n\n        class _FakeLoader:\n            def __init__(self, *args, **kwargs) -> None:\n                del args\n                captured_kwargs.update(kwargs)\n\n        with (\n            mock.patch.object(h5_dataloader_module, \"DataLoader\", _FakeLoader),\n            mock.patch.object(\n                MotionClipBatchCache, \"_prime_buffers\", lambda self: None\n            ),\n        ):\n            MotionClipBatchCache(\n                train_dataset=_FakeDataset(),\n                batch_size=2,\n                num_workers=0,\n                pin_memory=False,\n                persistent_workers=False,\n                loader_timeout=17,\n            )\n\n        self.assertEqual(captured_kwargs[\"timeout\"], 17)\n\n    def test_motion_cache_disables_progress_bar_in_distributed_runs(self):\n        with (\n            mock.patch.object(\n                MotionClipBatchCache, 
\"_build_dataloader\", lambda self: None\n            ),\n            mock.patch.object(\n                MotionClipBatchCache, \"_prime_buffers\", lambda self: None\n            ),\n        ):\n            cache = MotionClipBatchCache(\n                train_dataset=_FakeDataset(),\n                batch_size=2,\n                num_workers=0,\n                pin_memory=False,\n                persistent_workers=False,\n                sampler_world_size=8,\n                batch_progress_bar=True,\n            )\n\n        self.assertIs(cache._should_use_batch_progress(), False)\n        self.assertIsNone(cache._batch_progress_counter)\n\n    def test_motion_cache_keeps_progress_bar_for_local_runs(self):\n        with (\n            mock.patch.object(\n                MotionClipBatchCache, \"_build_dataloader\", lambda self: None\n            ),\n            mock.patch.object(\n                MotionClipBatchCache, \"_prime_buffers\", lambda self: None\n            ),\n        ):\n            cache = MotionClipBatchCache(\n                train_dataset=_FakeDataset(),\n                batch_size=2,\n                num_workers=0,\n                pin_memory=False,\n                persistent_workers=False,\n                sampler_world_size=1,\n                batch_progress_bar=True,\n            )\n\n        self.assertIs(cache._should_use_batch_progress(), True)\n        self.assertIsNotNone(cache._batch_progress_counter)\n\n    def test_motion_cache_requires_positive_loader_timeout(self):\n        with (\n            mock.patch.object(\n                MotionClipBatchCache, \"_build_dataloader\", lambda self: None\n            ),\n            mock.patch.object(\n                MotionClipBatchCache, \"_prime_buffers\", lambda self: None\n            ),\n        ):\n            with self.assertRaisesRegex(\n                ValueError, \"loader_timeout must be >= 0\"\n            ):\n                MotionClipBatchCache(\n                    train_dataset=_FakeDataset(),\n                    batch_size=2,\n                    num_workers=0,\n                    pin_memory=False,\n                    persistent_workers=False,\n                    loader_timeout=-1,\n                )\n"
  },
  {
    "path": "tests/test_motion_tracking_command_reference_prefix.py",
    "content": "import sys\nimport unittest\nfrom pathlib import Path\nfrom types import SimpleNamespace\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.utils.reference_prefix import resolve_reference_tensor_key\n\n\nclass MotionTrackingCommandReferencePrefixTests(unittest.TestCase):\n    def test_ft_ref_prefix_uses_filtered_tensor_when_present(self):\n        resolved = resolve_reference_tensor_key(\n            batch_tensors={\"ft_ref_root_pos\": SimpleNamespace()},\n            base_key=\"root_pos\",\n            prefix=\"ft_ref_\",\n        )\n\n        self.assertEqual(resolved, \"ft_ref_root_pos\")\n\n    def test_ft_ref_prefix_requires_filtered_tensor(self):\n        with self.assertRaises(KeyError):\n            resolve_reference_tensor_key(\n                batch_tensors={\"root_pos\": SimpleNamespace()},\n                base_key=\"root_pos\",\n                prefix=\"ft_ref_\",\n            )\n\n    def test_ref_prefix_falls_back_to_unprefixed_tensor(self):\n        resolved = resolve_reference_tensor_key(\n            batch_tensors={\"root_pos\": SimpleNamespace()},\n            base_key=\"root_pos\",\n            prefix=\"ref_\",\n        )\n\n        self.assertEqual(resolved, \"root_pos\")\n\n    def test_ref_prefix_prefers_prefixed_tensor_when_present(self):\n        resolved = resolve_reference_tensor_key(\n            batch_tensors={\n                \"root_pos\": SimpleNamespace(),\n                \"ref_root_pos\": SimpleNamespace(),\n            },\n            base_key=\"root_pos\",\n            prefix=\"ref_\",\n        )\n\n        self.assertEqual(resolved, \"ref_root_pos\")\n"
  },
  {
    "path": "tests/test_motion_tracking_timing.py",
    "content": "import importlib.util\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType, SimpleNamespace\n\nimport pytest\nimport torch\n\nMODULE_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_motion_tracking_command.py\"\n)\n\n\nclass _DummyConfig:\n    def __init__(self, *args, **kwargs):\n        self.args = args\n        self.kwargs = kwargs\n\n\ndef _install_fake_motion_command_deps(monkeypatch):\n    isaaclab_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_sim = ModuleType(\"isaaclab.sim\")\n    isaaclab_sim.PreviewSurfaceCfg = _DummyConfig\n    isaaclab_sim.PhysxCfg = _DummyConfig\n    isaaclab_sim.SimulationCfg = _DummyConfig\n    isaaclab_math = ModuleType(\"isaaclab.utils.math\")\n    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec\n    isaaclab_math.quat_apply = lambda quat, vec: vec\n    isaaclab_math.yaw_quat = lambda quat: quat\n    isaaclab_math.quat_inv = lambda quat: quat\n    isaaclab_math.quat_mul = lambda lhs, rhs: lhs\n    isaaclab_math.sample_uniform = (\n        lambda low, high, shape, device=None: torch.zeros(\n            *shape, device=device\n        )\n    )\n\n    isaaclab_actuators = ModuleType(\"isaaclab.actuators\")\n    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig\n\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n    isaaclab_assets.ArticulationCfg = _DummyConfig\n    isaaclab_assets.AssetBaseCfg = _DummyConfig\n\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedRLEnv = object\n    isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig\n    isaaclab_envs.ViewerCfg = _DummyConfig\n\n    isaaclab_envs_mdp_actions = ModuleType(\"isaaclab.envs.mdp.actions\")\n    isaaclab_envs_mdp_actions.JointEffortActionCfg = _DummyConfig\n\n    isaaclab_managers = ModuleType(\"isaaclab.managers\")\n    isaaclab_managers.ActionTermCfg = _DummyConfig\n    isaaclab_managers.CommandTerm = object\n    isaaclab_managers.CommandTermCfg = _DummyConfig\n    isaaclab_managers.EventTermCfg = _DummyConfig\n    isaaclab_managers.ObservationGroupCfg = _DummyConfig\n    isaaclab_managers.ObservationTermCfg = _DummyConfig\n    isaaclab_managers.RewardTermCfg = _DummyConfig\n    isaaclab_managers.TerminationTermCfg = _DummyConfig\n\n    isaaclab_markers = ModuleType(\"isaaclab.markers\")\n    isaaclab_markers.VisualizationMarkers = _DummyConfig\n    isaaclab_markers.VisualizationMarkersCfg = _DummyConfig\n\n    isaaclab_markers_config = ModuleType(\"isaaclab.markers.config\")\n    isaaclab_markers_config.SPHERE_MARKER_CFG = SimpleNamespace(\n        replace=lambda **kwargs: SimpleNamespace(\n            markers={\"sphere\": SimpleNamespace(radius=None)},\n            **kwargs,\n        )\n    )\n\n    isaaclab_scene = ModuleType(\"isaaclab.scene\")\n    isaaclab_scene.InteractiveSceneCfg = _DummyConfig\n\n    isaaclab_sensors = ModuleType(\"isaaclab.sensors\")\n    isaaclab_sensors.ContactSensorCfg = _DummyConfig\n    isaaclab_sensors.RayCasterCfg = _DummyConfig\n    isaaclab_sensors.patterns = _DummyConfig\n\n    isaaclab_terrains = ModuleType(\"isaaclab.terrains\")\n    isaaclab_terrains.TerrainImporterCfg = _DummyConfig\n\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n\n    isaaclab_noise = ModuleType(\"isaaclab.utils.noise\")\n    isaaclab_noise.AdditiveUniformNoiseCfg = _DummyConfig\n\n    h5_dataloader = 
ModuleType(\"holomotion.src.training.h5_dataloader\")\n    h5_dataloader.Hdf5MotionDataset = object\n    h5_dataloader.Hdf5RootDofDataset = object\n    h5_dataloader.MotionClipBatchCache = object\n    h5_dataloader.build_motion_datasets_from_cfg = lambda *args, **kwargs: None\n\n    rotations = ModuleType(\"holomotion.src.utils.isaac_utils.rotations\")\n    rotations.calc_heading_quat_inv = lambda *args, **kwargs: None\n    rotations.get_euler_xyz = lambda *args, **kwargs: None\n    rotations.my_quat_rotate = lambda *args, **kwargs: None\n    rotations.quat_inverse = lambda *args, **kwargs: None\n    rotations.quat_mul = lambda *args, **kwargs: None\n    rotations.quat_rotate = lambda *args, **kwargs: None\n    rotations.quat_rotate_inverse = lambda *args, **kwargs: None\n    rotations.quaternion_to_matrix = lambda *args, **kwargs: None\n    rotations.wrap_to_pi = lambda *args, **kwargs: None\n    rotations.wxyz_to_xyzw = lambda x: x\n    rotations.xyzw_to_wxyz = lambda x: x\n\n    reference_prefix = ModuleType(\"holomotion.src.utils.reference_prefix\")\n    reference_prefix.resolve_reference_tensor_key = (\n        lambda batch_tensors, base_key, prefix=\"ref_\": f\"{prefix}{base_key}\"\n    )\n\n    omegaconf = ModuleType(\"omegaconf\")\n    omegaconf.OmegaConf = SimpleNamespace(\n        to_container=lambda value, resolve=True: value\n    )\n\n    loguru = ModuleType(\"loguru\")\n    loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None)\n\n    tqdm = ModuleType(\"tqdm\")\n    tqdm.tqdm = lambda iterable, *args, **kwargs: iterable\n\n    scipy = ModuleType(\"scipy\")\n    scipy_spatial = ModuleType(\"scipy.spatial\")\n    scipy_transform = ModuleType(\"scipy.spatial.transform\")\n    scipy_transform.Rotation = object\n\n    for name, module in {\n        \"isaaclab.envs.mdp\": isaaclab_mdp,\n        \"isaaclab.sim\": isaaclab_sim,\n        \"isaaclab.utils.math\": isaaclab_math,\n        \"isaaclab.actuators\": isaaclab_actuators,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.envs.mdp.actions\": isaaclab_envs_mdp_actions,\n        \"isaaclab.managers\": isaaclab_managers,\n        \"isaaclab.markers\": isaaclab_markers,\n        \"isaaclab.markers.config\": isaaclab_markers_config,\n        \"isaaclab.scene\": isaaclab_scene,\n        \"isaaclab.sensors\": isaaclab_sensors,\n        \"isaaclab.terrains\": isaaclab_terrains,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.noise\": isaaclab_noise,\n        \"holomotion.src.training.h5_dataloader\": h5_dataloader,\n        \"holomotion.src.utils.isaac_utils.rotations\": rotations,\n        \"holomotion.src.utils.reference_prefix\": reference_prefix,\n        \"omegaconf\": omegaconf,\n        \"loguru\": loguru,\n        \"tqdm\": tqdm,\n        \"scipy\": scipy,\n        \"scipy.spatial\": scipy_spatial,\n        \"scipy.spatial.transform\": scipy_transform,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n\ndef _load_motion_command_module(monkeypatch):\n    _install_fake_motion_command_deps(monkeypatch)\n    module_name = \"_test_motion_tracking_timing\"\n    spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_immediate_next_reference_getters_use_slot_one(monkeypatch):\n    
module = _load_motion_command_module(monkeypatch)\n    command = module.RefMotionCommand.__new__(module.RefMotionCommand)\n    command.urdf2sim_dof_idx = torch.tensor([1, 0], dtype=torch.long)\n    command.urdf2sim_body_idx = torch.tensor([1, 0], dtype=torch.long)\n    command._env_origins = torch.tensor(\n        [[10.0, 20.0, 30.0]], dtype=torch.float32\n    )\n    base_tensors = {\n        \"ref_dof_pos\": torch.tensor(\n            [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]], dtype=torch.float32\n        ),\n        \"ref_root_pos\": torch.tensor(\n            [[[0.0, 1.0, 2.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],\n            dtype=torch.float32,\n        ),\n        \"ref_body_vel\": torch.tensor(\n            [\n                [\n                    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],\n                    [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],\n                    [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0]],\n                ]\n            ],\n            dtype=torch.float32,\n        ),\n    }\n    command._get_ref_state_array = (\n        lambda base_key, prefix=\"ref_\": base_tensors[f\"{prefix}{base_key}\"]\n    )\n\n    dof_pos = command.get_ref_motion_dof_pos_immediate_next()\n    root_pos = command.get_ref_motion_root_global_pos_immediate_next()\n    body_lin_vel = (\n        command.get_ref_motion_bodylink_global_lin_vel_immediate_next()\n    )\n\n    assert torch.allclose(dof_pos, torch.tensor([[4.0, 3.0]]))\n    assert torch.allclose(root_pos, torch.tensor([[17.0, 28.0, 39.0]]))\n    assert torch.allclose(\n        body_lin_vel,\n        torch.tensor([[[3.0, 3.0, 3.0], [2.0, 2.0, 2.0]]]),\n    )\n\n\ndef test_update_command_skips_just_reset_envs(monkeypatch):\n    module = _load_motion_command_module(monkeypatch)\n    command = module.RefMotionCommand.__new__(module.RefMotionCommand)\n    command.device = torch.device(\"cpu\")\n    command.num_envs = 3\n    command._frame_indices = torch.tensor([10, 20, 30], dtype=torch.long)\n    command._swap_step_counter = 0\n    command._swap_pending = False\n    command._motion_cache = SimpleNamespace(swap_interval_steps=100)\n    command._env = SimpleNamespace(\n        episode_length_buf=torch.tensor([5, 0, 2], dtype=torch.long)\n    )\n    command._filter_env_ids_for_motion_task = lambda env_ids: env_ids\n    command._resample_when_motion_end_cache = lambda: None\n    command._update_ref_motion_state_from_cache = lambda env_ids=None: None\n\n    command._update_command()\n\n    assert torch.equal(command._frame_indices, torch.tensor([11, 20, 31]))\n    assert command._swap_step_counter == 1\n\n\ndef test_update_command_resumes_advancing_after_reset_step(monkeypatch):\n    module = _load_motion_command_module(monkeypatch)\n    command = module.RefMotionCommand.__new__(module.RefMotionCommand)\n    command.device = torch.device(\"cpu\")\n    command.num_envs = 1\n    command._frame_indices = torch.tensor([20], dtype=torch.long)\n    command._swap_step_counter = 0\n    command._swap_pending = False\n    command._motion_cache = SimpleNamespace(swap_interval_steps=100)\n    command._env = SimpleNamespace(episode_length_buf=torch.tensor([0]))\n    command._filter_env_ids_for_motion_task = lambda env_ids: env_ids\n    command._resample_when_motion_end_cache = lambda: None\n    command._update_ref_motion_state_from_cache = lambda env_ids=None: None\n\n    command._update_command()\n    assert torch.equal(command._frame_indices, torch.tensor([20]))\n\n    command._env.episode_length_buf = torch.tensor([1])\n    command._update_command()\n    assert 
torch.equal(command._frame_indices, torch.tensor([21]))\n\n\ndef test_mpjpe_metrics_use_immediate_next_reference(monkeypatch):\n    module = _load_motion_command_module(monkeypatch)\n    command = module.RefMotionCommand.__new__(module.RefMotionCommand)\n    command.device = torch.device(\"cpu\")\n    command.num_envs = 1\n    command.metrics = {}\n    command.arm_dof_indices = [0]\n    command.torso_dof_indices = [1]\n    command.leg_dof_indices = [2]\n    command.robot = SimpleNamespace(\n        data=SimpleNamespace(\n            joint_pos=torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32)\n        )\n    )\n    # Generator .throw() builds a lambda that raises when invoked, so the\n    # test fails if the metric reads the current-frame reference instead\n    # of the immediate-next one.\n    command.get_ref_motion_dof_pos_cur = lambda prefix=\"ref_\": (\n        _ for _ in ()\n    ).throw(AssertionError(\"current reference should not be used\"))\n    command.get_ref_motion_dof_pos_immediate_next = (\n        lambda prefix=\"ref_\": torch.tensor(\n            [[0.1, 0.2, 0.3]], dtype=torch.float32\n        )\n    )\n\n    command._update_mpjpe_metrics()\n\n    assert torch.allclose(\n        command.metrics[\"Task/MPJPE_WholeBody\"], torch.zeros(1)\n    )\n\n\ndef test_mpkpe_metrics_use_immediate_next_reference(monkeypatch):\n    module = _load_motion_command_module(monkeypatch)\n    command = module.RefMotionCommand.__new__(module.RefMotionCommand)\n    command.device = torch.device(\"cpu\")\n    command.num_envs = 1\n    command.metrics = {}\n    command.arm_body_indices = [0]\n    command.torso_body_indices = [1]\n    command.leg_body_indices = [2]\n    command.robot = SimpleNamespace(\n        data=SimpleNamespace(\n            body_pos_w=torch.tensor(\n                [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]],\n                dtype=torch.float32,\n            )\n        )\n    )\n    # As above: raise if the current-frame getter is ever called.\n    command.get_ref_motion_bodylink_global_pos_cur = lambda prefix=\"ref_\": (\n        _ for _ in ()\n    ).throw(AssertionError(\"current reference should not be used\"))\n    command.get_ref_motion_bodylink_global_pos_immediate_next = (\n        lambda prefix=\"ref_\": torch.tensor(\n            [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]],\n            dtype=torch.float32,\n        )\n    )\n\n    command._update_mpkpe_metrics()\n\n    assert torch.allclose(\n        command.metrics[\"Task/MPKPE_WholeBody\"], torch.zeros(1)\n    )\n"
  },
  {
    "path": "tests/test_mujoco_filtered_ref_compat.py",
    "content": "import tempfile\nfrom pathlib import Path\n\nimport numpy as np\nfrom omegaconf import OmegaConf\n\nfrom holomotion.src.evaluation.eval_mujoco_sim2sim import MujocoEvaluator\nfrom holomotion.src.evaluation.obs.obs_builder import PolicyObsBuilder\n\n\nPROJECT_ROOT = Path(__file__).resolve().parents[1]\nOBS_CONFIG_PATH = (\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3_with_freq.yaml\"\n)\nMODULE_CONFIG_PATH = (\n    PROJECT_ROOT\n    / \"holomotion/config/modules/motion_tracking/tf_motrack_v3_with_ft.yaml\"\n)\nOBS_CONFIG_PATH_V2 = (\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3_sonic_router_v2.yaml\"\n)\nMODULE_CONFIG_PATH_V2 = (\n    PROJECT_ROOT\n    / \"holomotion/config/modules/motion_tracking/tf_motrack_v3_wo_eepos_ref_route_v2.yaml\"\n)\nDOMAIN_RAND_CONFIG_PATH = (\n    PROJECT_ROOT\n    / \"holomotion/config/env/domain_randomization/domain_rand_strong.yaml\"\n)\n\n\ndef _make_minimal_motion_npz(path: Path, *, include_cutoff: bool) -> None:\n    payload = {\n        \"ref_global_translation\": np.zeros((2, 1, 3), dtype=np.float32),\n        \"ref_global_rotation_quat\": np.tile(\n            np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32),\n            (2, 1, 1),\n        ),\n        \"ref_global_velocity\": np.zeros((2, 1, 3), dtype=np.float32),\n        \"ref_global_angular_velocity\": np.zeros((2, 1, 3), dtype=np.float32),\n        \"ref_dof_pos\": np.zeros((2, 2), dtype=np.float32),\n        \"ref_dof_vel\": np.zeros((2, 2), dtype=np.float32),\n    }\n    if include_cutoff:\n        payload[\"filter_cutoff_hz\"] = np.array(\n            [[2.0], [3.0]], dtype=np.float32\n        )\n    np.savez(path, **payload)\n\n\ndef test_policy_obs_list_accepts_shared_cutoff_term():\n    config = OmegaConf.merge(\n        OmegaConf.load(OBS_CONFIG_PATH),\n        OmegaConf.load(MODULE_CONFIG_PATH),\n    )\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.config = config\n\n    atomic_obs_list = evaluator._get_policy_atomic_obs_list()\n\n    term_names = [str(list(item.keys())[0]) for item in atomic_obs_list]\n    assert term_names[0] == \"ref_motion_filter_cutoff_hz\"\n    assert \"actor_ref_gravity_projection_cur\" in term_names\n\n\ndef test_cutoff_obs_getters_use_current_frame_and_default_zero():\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.motion_frame_idx = 1\n    evaluator.filter_cutoff_hz = np.array([[2.0], [3.0]], dtype=np.float32)\n\n    assert evaluator._get_obs_ref_motion_filter_cutoff_hz() == np.float32(3.0)\n    assert (\n        evaluator._get_obs_actor_ref_motion_filter_cutoff_hz()\n        == np.float32(3.0)\n    )\n\n    missing = MujocoEvaluator.__new__(MujocoEvaluator)\n    missing.motion_frame_idx = 0\n    assert missing._get_obs_ref_motion_filter_cutoff_hz() == 0.0\n\n\ndef test_policy_obs_list_v2_uses_only_actor_schema_terms():\n    config = OmegaConf.merge(\n        OmegaConf.load(OBS_CONFIG_PATH_V2),\n        OmegaConf.load(MODULE_CONFIG_PATH_V2),\n        OmegaConf.load(DOMAIN_RAND_CONFIG_PATH),\n    )\n    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n    evaluator.config = config\n\n    atomic_obs_list = evaluator._get_policy_atomic_obs_list()\n    term_names = [str(list(item.keys())[0]) for item in atomic_obs_list]\n\n    assert not any(name.startswith(\"actor_moe_router_\") for name in term_names)\n    assert \"actor_ref_gravity_projection_cur\" in term_names\n    assert 
\"actor_ref_base_linvel_fut\" in term_names\n\n\ndef test_load_specific_motion_loads_cutoff_metadata_with_zero_fallback():\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        with_cutoff = Path(tmp_dir) / \"with_cutoff.npz\"\n        without_cutoff = Path(tmp_dir) / \"without_cutoff.npz\"\n        _make_minimal_motion_npz(with_cutoff, include_cutoff=True)\n        _make_minimal_motion_npz(without_cutoff, include_cutoff=False)\n\n        evaluator = MujocoEvaluator.__new__(MujocoEvaluator)\n        evaluator.load_specific_motion(with_cutoff)\n        np.testing.assert_allclose(\n            evaluator.filter_cutoff_hz,\n            np.array([[2.0], [3.0]], dtype=np.float32),\n        )\n\n        evaluator.load_specific_motion(without_cutoff)\n        np.testing.assert_allclose(\n            evaluator.filter_cutoff_hz,\n            np.zeros((2, 1), dtype=np.float32),\n        )\n"
  },
  {
    "path": "tests/test_obs_norm_compile.py",
    "content": "import torch\nimport torch.nn as nn\nfrom holomotion.src.modules.agent_modules import PPOTFActor\nfrom holomotion.src.modules.network_modules import EmpiricalNormalization\n\n\ndef _make_actor_with_obs_norm(obs_dim: int = 16) -> PPOTFActor:\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.obs_norm_enabled = True\n    actor.obs_norm_clip = 10.0\n    actor.obs_normalizer = EmpiricalNormalization(shape=(obs_dim,))\n    return actor\n\n\ndef test_obs_norm_update_is_not_captured_by_dynamo():\n    actor = _make_actor_with_obs_norm()\n    obs = torch.randn(8, 16)\n\n    def normalize_with_update(x: torch.Tensor) -> torch.Tensor:\n        return actor._normalize_actor_obs(x, True)\n\n    explanation = torch._dynamo.explain(normalize_with_update)(obs)\n    graph_code = \"\\n\".join(graph.code for graph in explanation.graphs)\n\n    assert \"torch.var\" not in graph_code\n    assert \"torch.mean\" not in graph_code\n\n    count_before_compile = actor.obs_normalizer.count.item()\n    compiled = torch.compile(normalize_with_update, backend=\"eager\")\n    normalized = compiled(obs)\n\n    assert normalized.shape == obs.shape\n    assert (\n        actor.obs_normalizer.count.item() - count_before_compile\n        == obs.shape[0]\n    )\n"
  },
  {
    "path": "tests/test_observation_frames.py",
    "content": "import importlib.util\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType, SimpleNamespace\n\nimport pytest\nimport torch\n\nOBSERVATION_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_observation.py\"\n)\n\n\nclass _DummyConfig:\n    def __init__(self, *args, **kwargs):\n        self.args = args\n        self.kwargs = kwargs\n\n\nclass _Scene(SimpleNamespace):\n    def __getitem__(self, key):\n        return getattr(self, key)\n\n\ndef _identity_quat(*shape: int) -> torch.Tensor:\n    quat = torch.zeros(*shape, 4, dtype=torch.float32)\n    quat[..., 0] = 1.0\n    return quat\n\n\ndef _load_observation_module(monkeypatch):\n    isaaclab = ModuleType(\"isaaclab\")\n    isaaclab_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_math = ModuleType(\"isaaclab.utils.math\")\n    isaaclab_math.quat_apply = lambda quat, vec: vec\n    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec\n    isaaclab_math.quat_inv = lambda quat: quat\n    isaaclab_math.matrix_from_quat = lambda quat: torch.zeros(\n        *quat.shape[:-1], 3, 3, dtype=quat.dtype, device=quat.device\n    )\n    isaaclab_math.subtract_frame_transforms = lambda t01, q01, t02, q02: (\n        t02 - t01,\n        q02,\n    )\n    isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None)\n    isaaclab_noise = ModuleType(\"isaaclab.utils.noise\")\n    isaaclab_noise.__getattr__ = lambda name: _DummyConfig\n\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedRLEnv = object\n    isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig\n    isaaclab_envs.ViewerCfg = _DummyConfig\n    isaaclab_sim = ModuleType(\"isaaclab.sim\")\n    isaaclab_sim.__getattr__ = lambda name: _DummyConfig\n    isaaclab_actuators = ModuleType(\"isaaclab.actuators\")\n    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n    isaaclab_assets.ArticulationCfg = _DummyConfig\n    isaaclab_assets.AssetBaseCfg = _DummyConfig\n    isaaclab_managers = ModuleType(\"isaaclab.managers\")\n    isaaclab_managers.__getattr__ = lambda name: _DummyConfig\n    isaaclab_markers = ModuleType(\"isaaclab.markers\")\n    isaaclab_markers.VisualizationMarkers = _DummyConfig\n    isaaclab_markers.VisualizationMarkersCfg = _DummyConfig\n    isaaclab_markers_config = ModuleType(\"isaaclab.markers.config\")\n    isaaclab_markers_config.FRAME_MARKER_CFG = _DummyConfig\n    isaaclab_scene = ModuleType(\"isaaclab.scene\")\n    isaaclab_scene.InteractiveSceneCfg = _DummyConfig\n    isaaclab_sensors = ModuleType(\"isaaclab.sensors\")\n    isaaclab_sensors.ContactSensorCfg = _DummyConfig\n    isaaclab_sensors.RayCasterCfg = _DummyConfig\n    isaaclab_sensors.patterns = _DummyConfig\n    isaaclab_terrains = ModuleType(\"isaaclab.terrains\")\n    isaaclab_terrains.TerrainImporterCfg = _DummyConfig\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n\n    omegaconf = ModuleType(\"omegaconf\")\n    omegaconf.DictConfig = dict\n    omegaconf.ListConfig = list\n    omegaconf.OmegaConf = SimpleNamespace(\n        to_container=lambda value, resolve=True: value\n    )\n\n    fake_utils_module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n    )\n    fake_utils_module.resolve_holo_config = lambda value: value\n\n    isaaclab.envs = 
isaaclab_envs\n    isaaclab.sim = isaaclab_sim\n    isaaclab.actuators = isaaclab_actuators\n    isaaclab.assets = isaaclab_assets\n    isaaclab.managers = isaaclab_managers\n    isaaclab.markers = isaaclab_markers\n    isaaclab.scene = isaaclab_scene\n    isaaclab.sensors = isaaclab_sensors\n    isaaclab.terrains = isaaclab_terrains\n    isaaclab.utils = isaaclab_utils\n    isaaclab_envs.mdp = isaaclab_mdp\n    isaaclab_utils.math = isaaclab_math\n    isaaclab_utils.noise = isaaclab_noise\n\n    for name, module in {\n        \"isaaclab\": isaaclab,\n        \"isaaclab.envs.mdp\": isaaclab_mdp,\n        \"isaaclab.utils.math\": isaaclab_math,\n        \"isaaclab.utils.noise\": isaaclab_noise,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.sim\": isaaclab_sim,\n        \"isaaclab.actuators\": isaaclab_actuators,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.managers\": isaaclab_managers,\n        \"isaaclab.markers\": isaaclab_markers,\n        \"isaaclab.markers.config\": isaaclab_markers_config,\n        \"isaaclab.scene\": isaaclab_scene,\n        \"isaaclab.sensors\": isaaclab_sensors,\n        \"isaaclab.terrains\": isaaclab_terrains,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"omegaconf\": omegaconf,\n        (\n            \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n        ): fake_utils_module,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    module_name = \"_test_observation_frames\"\n    spec = importlib.util.spec_from_file_location(\n        module_name, OBSERVATION_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_ref_future_observations_can_limit_num_frames(monkeypatch):\n    observation = _load_observation_module(monkeypatch)\n\n    class _Command:\n        def get_ref_motion_dof_pos_fut(self, prefix=\"ref_\"):\n            return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n\n        def get_ref_motion_dof_vel_fut(self, prefix=\"ref_\"):\n            return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n\n        def get_ref_motion_gravity_projection_fut(self, prefix=\"ref_\"):\n            return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n\n        def get_ref_motion_base_linvel_fut(self, prefix=\"ref_\"):\n            return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n\n        def get_ref_motion_base_angvel_fut(self, prefix=\"ref_\"):\n            return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n\n        def get_ref_motion_root_global_pos_fut(self, prefix=\"ref_\"):\n            pos = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)\n            pos[..., 2] = torch.tensor(\n                [[1.0, 2.0, 3.0, 4.0], [0.5, 1.5, 2.5, 3.5]],\n                dtype=torch.float32,\n            )\n            return pos\n\n    class _CommandManager:\n        def get_term(self, name):\n            return _Command()\n\n    env = SimpleNamespace(\n        command_manager=_CommandManager(),\n        scene=SimpleNamespace(env_origins=torch.zeros(2, 3)),\n    )\n\n    dof_pos = observation.ObservationFunctions._get_obs_ref_dof_pos_fut(\n        env, num_frames=2\n    )\n    dof_vel = observation.ObservationFunctions._get_obs_ref_dof_vel_fut(\n        env, num_frames=2\n    )\n    gravity = (\n        
observation.ObservationFunctions._get_obs_ref_gravity_projection_fut(\n            env, num_frames=2\n        )\n    )\n    base_linvel = (\n        observation.ObservationFunctions._get_obs_ref_base_linvel_fut(\n            env, num_frames=2\n        )\n    )\n    base_angvel = (\n        observation.ObservationFunctions._get_obs_ref_base_angvel_fut(\n            env, num_frames=2\n        )\n    )\n    root_height = (\n        observation.ObservationFunctions._get_obs_ref_root_height_fut(\n            env, num_frames=2\n        )\n    )\n\n    assert dof_pos.shape == (2, 2, 3)\n    assert dof_vel.shape == (2, 6)\n    assert gravity.shape == (2, 2, 3)\n    assert base_linvel.shape == (2, 2, 3)\n    assert base_angvel.shape == (2, 2, 3)\n    assert root_height.shape == (2, 2, 1)\n    torch.testing.assert_close(\n        dof_pos,\n        torch.tensor(\n            [\n                [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n                [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n    torch.testing.assert_close(\n        dof_vel,\n        torch.tensor(\n            [\n                [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],\n                [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n    torch.testing.assert_close(\n        root_height[..., 0], torch.tensor([[1.0, 2.0], [0.5, 1.5]])\n    )\n\n\ndef _make_env():\n    env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32)\n    robot_data = SimpleNamespace(\n        body_pos_w=torch.tensor(\n            [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32\n        ),\n        body_quat_w=_identity_quat(1, 2),\n    )\n    robot = SimpleNamespace(body_names=[\"anchor\", \"target\"], data=robot_data)\n    command = SimpleNamespace(\n        anchor_bodylink_name=\"anchor\",\n        get_ref_motion_anchor_bodylink_global_pos_cur=(\n            lambda prefix=\"ref_\": torch.tensor([[10.0, 0.0, 0.0]])\n        ),\n        get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=(\n            lambda prefix=\"ref_\": _identity_quat(1)\n        ),\n    )\n    return SimpleNamespace(\n        num_envs=1,\n        scene=_Scene(env_origins=env_origins, robot=robot),\n        command_manager=SimpleNamespace(get_term=lambda name: command),\n    )\n\n\ndef test_global_robot_bodylink_pos_is_in_environment_frame(monkeypatch):\n    observation = _load_observation_module(monkeypatch)\n    env = _make_env()\n\n    pos = observation.ObservationFunctions._get_obs_global_robot_bodylink_pos(\n        env,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.allclose(pos, torch.tensor([[[1.0, 0.0, 0.0]]]))\n\n\ndef test_root_rel_robot_bodylink_pos_uses_consistent_env_frame(monkeypatch):\n    observation = _load_observation_module(monkeypatch)\n    env = _make_env()\n    observation.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros(\n        1, 3, dtype=torch.float32\n    )\n    observation.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1)\n\n    pos = (\n        observation.ObservationFunctions._get_obs_root_rel_robot_bodylink_pos(\n            env,\n            keybody_names=[\"target\"],\n        )\n    )\n\n    assert torch.allclose(pos, torch.tensor([[[1.0, 0.0, 0.0]]]))\n\n\ndef test_global_anchor_pos_diff_uses_environment_frame_consistently(\n    monkeypatch,\n):\n    observation = _load_observation_module(monkeypatch)\n    env = _make_env()\n\n    pos_diff = (\n        
observation.ObservationFunctions._get_obs_global_anchor_pos_diff(env)\n    )\n\n    assert torch.allclose(pos_diff, torch.zeros(1, 3))\n\n\ndef test_build_additive_uniform_noise_cfg_supports_optional_z_override(\n    monkeypatch,\n):\n    observation = _load_observation_module(monkeypatch)\n\n    noise = observation._build_noise_cfg(\n        {\n            \"type\": \"AdditiveUniformNoiseCfg\",\n            \"params\": {\n                \"n_min\": -0.1,\n                \"n_max\": 0.1,\n                \"n_min_z\": -0.02,\n                \"n_max_z\": 0.03,\n            },\n        }\n    )\n\n    assert torch.equal(\n        noise.kwargs[\"n_min\"], torch.tensor([-0.1, -0.1, -0.02])\n    )\n    assert torch.equal(noise.kwargs[\"n_max\"], torch.tensor([0.1, 0.1, 0.03]))\n\n\ndef test_build_additive_uniform_noise_cfg_keeps_scalar_bounds_without_z_override(\n    monkeypatch,\n):\n    observation = _load_observation_module(monkeypatch)\n\n    noise = observation._build_noise_cfg(\n        {\n            \"type\": \"AdditiveUniformNoiseCfg\",\n            \"params\": {\n                \"n_min\": -0.1,\n                \"n_max\": 0.1,\n            },\n        }\n    )\n\n    assert noise.kwargs[\"n_min\"] == pytest.approx(-0.1)\n    assert noise.kwargs[\"n_max\"] == pytest.approx(0.1)\n"
  },
  {
    "path": "tests/test_onnx_attention_export.py",
    "content": "import sys\nimport tempfile\nfrom pathlib import Path\n\nimport numpy as np\nimport onnx\nimport onnxruntime\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.modules.network_modules import (\n    export_safe_scaled_dot_product_attention,\n)\n\n\nclass _ExportAttentionModule(nn.Module):\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        mask: torch.Tensor,\n    ) -> torch.Tensor:\n        return export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=mask,\n            dropout_p=0.0,\n            is_causal=False,\n        )\n\n\nclass _ExportCausalAttentionModule(nn.Module):\n    def forward(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n    ) -> torch.Tensor:\n        return export_safe_scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=None,\n            dropout_p=0.0,\n            is_causal=True,\n        )\n\n\ndef _export_model(\n    export_path: Path,\n    module: nn.Module,\n    inputs: tuple[torch.Tensor, ...],\n    input_names: list[str],\n) -> None:\n    torch.onnx.export(\n        module.eval(),\n        inputs,\n        str(export_path),\n        opset_version=17,\n        input_names=input_names,\n        output_names=[\"out\"],\n        dynamo=False,\n        verbose=False,\n    )\n\n\ndef _export_op_types(\n    module: nn.Module,\n    *inputs: torch.Tensor,\n    input_names: list[str],\n) -> list[str]:\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        export_path = Path(tmp_dir) / \"attention.onnx\"\n        _export_model(export_path, module, inputs, input_names)\n        model = onnx.load(str(export_path))\n    return [node.op_type for node in model.graph.node]\n\n\ndef _run_onnx(\n    module: nn.Module,\n    *inputs: torch.Tensor,\n    input_names: list[str],\n) -> np.ndarray:\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        export_path = Path(tmp_dir) / \"attention.onnx\"\n        _export_model(export_path, module, inputs, input_names)\n        session = onnxruntime.InferenceSession(\n            str(export_path),\n            providers=[\"CPUExecutionProvider\"],\n        )\n        feed = {\n            name: tensor.detach().cpu().numpy()\n            for name, tensor in zip(input_names, inputs, strict=True)\n        }\n        outputs = session.run([\"out\"], feed)\n    return outputs[0]\n\n\ndef test_export_safe_attention_uses_native_bool_mask_outside_export(\n    monkeypatch,\n):\n    captured = {}\n    original_sdpa = F.scaled_dot_product_attention\n\n    def _spy_sdpa(\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        *,\n        attn_mask: torch.Tensor | None = None,\n        dropout_p: float = 0.0,\n        is_causal: bool = False,\n        enable_gqa: bool = False,\n    ) -> torch.Tensor:\n        captured[\"mask_dtype\"] = None if attn_mask is None else attn_mask.dtype\n        return original_sdpa(\n            q,\n            k,\n            v,\n            attn_mask=attn_mask,\n            dropout_p=dropout_p,\n            is_causal=is_causal,\n            enable_gqa=enable_gqa,\n        )\n\n    monkeypatch.setattr(torch.onnx, \"is_in_onnx_export\", lambda: False)\n    monkeypatch.setattr(F, \"scaled_dot_product_attention\", _spy_sdpa)\n\n    q = 
torch.randn(1, 2, 3, 4)\n    k = torch.randn(1, 2, 5, 4)\n    v = torch.randn(1, 2, 5, 4)\n    mask = torch.ones(1, 1, 3, 5, dtype=torch.bool)\n\n    export_safe_scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask,\n        dropout_p=0.0,\n        is_causal=False,\n    )\n\n    assert captured[\"mask_dtype\"] == torch.bool\n\n\ndef test_export_safe_attention_matches_sdpa_for_valid_masks():\n    torch.manual_seed(0)\n    q = torch.randn(2, 4, 3, 8)\n    k = torch.randn(2, 4, 5, 8)\n    v = torch.randn(2, 4, 5, 8)\n    mask = torch.tensor(\n        [\n            [[[True, True, False, False, False]]],\n            [[[True, False, False, False, False]]],\n        ],\n        dtype=torch.bool,\n    ).expand(2, 1, 3, 5)\n\n    expected = F.scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask,\n        dropout_p=0.0,\n        is_causal=False,\n    )\n    actual = export_safe_scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask,\n        dropout_p=0.0,\n        is_causal=False,\n    )\n\n    torch.testing.assert_close(actual, expected, atol=1.0e-6, rtol=1.0e-5)\n\n\ndef test_legacy_attention_export_avoids_isnan():\n    torch.manual_seed(2)\n    q = torch.randn(1, 4, 1, 8)\n    k = torch.randn(1, 4, 16, 8)\n    v = torch.randn(1, 4, 16, 8)\n    mask = torch.ones(1, 1, 1, 16, dtype=torch.bool)\n\n    op_types = _export_op_types(\n        _ExportAttentionModule(),\n        q,\n        k,\n        v,\n        mask,\n        input_names=[\"q\", \"k\", \"v\", \"mask\"],\n    )\n\n    assert \"IsNaN\" not in op_types\n\n\ndef test_legacy_attention_export_ort_matches_pytorch_for_future_mask():\n    torch.manual_seed(3)\n    q = torch.randn(2, 4, 3, 8)\n    k = torch.randn(2, 4, 5, 8)\n    v = torch.randn(2, 4, 5, 8)\n    mask = torch.tensor(\n        [\n            [[[True, True, True, False, False]]],\n            [[[True, False, False, False, False]]],\n        ],\n        dtype=torch.bool,\n    ).expand(2, 1, 3, 5)\n\n    expected = export_safe_scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask,\n        dropout_p=0.0,\n        is_causal=False,\n    )\n    actual = _run_onnx(\n        _ExportAttentionModule(),\n        q,\n        k,\n        v,\n        mask,\n        input_names=[\"q\", \"k\", \"v\", \"mask\"],\n    )\n\n    np.testing.assert_allclose(\n        actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5\n    )\n\n\ndef test_legacy_attention_export_ort_matches_pytorch_for_causal_path():\n    torch.manual_seed(4)\n    q = torch.randn(2, 4, 6, 8)\n    k = torch.randn(2, 4, 6, 8)\n    v = torch.randn(2, 4, 6, 8)\n\n    expected = export_safe_scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=None,\n        dropout_p=0.0,\n        is_causal=True,\n    )\n    actual = _run_onnx(\n        _ExportCausalAttentionModule(),\n        q,\n        k,\n        v,\n        input_names=[\"q\", \"k\", \"v\"],\n    )\n\n    np.testing.assert_allclose(\n        actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5\n    )\n\n\ndef test_legacy_attention_export_ort_matches_pytorch_for_kv_mask():\n    torch.manual_seed(5)\n    q = torch.randn(2, 4, 1, 8)\n    k = torch.randn(2, 4, 16, 8)\n    v = torch.randn(2, 4, 16, 8)\n    valid_lengths = torch.tensor([16, 5], dtype=torch.int64)\n    mask = (\n        torch.arange(16, dtype=torch.int64)[None, :] < valid_lengths[:, None]\n    )\n    mask = mask[:, None, 
None, :]\n\n    expected = export_safe_scaled_dot_product_attention(\n        q,\n        k,\n        v,\n        attn_mask=mask,\n        dropout_p=0.0,\n        is_causal=False,\n    )\n    actual = _run_onnx(\n        _ExportAttentionModule(),\n        q,\n        k,\n        v,\n        mask,\n        input_names=[\"q\", \"k\", \"v\", \"mask\"],\n    )\n\n    np.testing.assert_allclose(\n        actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5\n    )\n"
  },
  {
    "path": "tests/test_onnx_export.py",
    "content": "import sys\nfrom types import SimpleNamespace\n\nfrom holomotion.src.utils.onnx_export import attach_onnx_metadata_holomotion\n\n\nclass _FakeEntry:\n    def __init__(self):\n        self.key = \"\"\n        self.value = \"\"\n\n\nclass _FakeTensor:\n    def __init__(self, values):\n        self._values = values\n\n    def __getitem__(self, index):\n        return _FakeTensor(self._values[index])\n\n    def cpu(self):\n        return self\n\n    def tolist(self):\n        return self._values\n\n\ndef test_attach_onnx_metadata_uses_default_joint_gains(monkeypatch):\n    model = SimpleNamespace(metadata_props=[])\n    fake_onnx = SimpleNamespace(\n        load=lambda path: model,\n        save=lambda loaded_model, path: None,\n        StringStringEntryProto=_FakeEntry,\n    )\n    monkeypatch.setitem(sys.modules, \"onnx\", fake_onnx)\n\n    robot_data = SimpleNamespace(\n        joint_names=[\"joint_a\", \"joint_b\"],\n        joint_stiffness=_FakeTensor([[0.0, 0.0]]),\n        joint_damping=_FakeTensor([[0.0, 0.0]]),\n        default_joint_stiffness=_FakeTensor([[10.0, 20.0]]),\n        default_joint_damping=_FakeTensor([[1.0, 2.0]]),\n        default_joint_pos=_FakeTensor([[0.1, -0.2]]),\n    )\n    action_term = SimpleNamespace(_scale=_FakeTensor([[0.5, 0.25]]))\n    env = SimpleNamespace(\n        scene={\"robot\": SimpleNamespace(data=robot_data)},\n        action_manager=SimpleNamespace(\n            get_term=lambda name: action_term,\n        ),\n    )\n\n    attach_onnx_metadata_holomotion(env, \"dummy.onnx\")\n\n    metadata = {entry.key: entry.value for entry in model.metadata_props}\n    assert metadata[\"joint_stiffness\"] == \"10.000,20.000\"\n    assert metadata[\"joint_damping\"] == \"1.000,2.000\"\n"
  },
  {
    "path": "tests/test_plot_moe_expert_heatmap.py",
    "content": "import importlib.util\nfrom pathlib import Path\nfrom unittest.mock import MagicMock\n\nimport numpy as np\n\nSCRIPT_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"not_for_commit\"\n    / \"plot_moe_expert_heatmap.py\"\n)\n\n\ndef _load_plot_moe_expert_heatmap_module():\n    spec = importlib.util.spec_from_file_location(\n        \"plot_moe_expert_heatmap\", SCRIPT_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec.loader is not None\n    spec.loader.exec_module(module)\n    return module\n\n\ndef _write_eval_npz(path: Path) -> None:\n    np.savez(\n        path,\n        robot_moe_expert_logits=np.array(\n            [\n                [[0.0, 1.0, 2.0, 3.0], [1.5, 0.5, -0.5, -1.5]],\n                [[0.1, 1.1, 2.1, 3.1], [1.0, 0.0, -1.0, -2.0]],\n                [[0.2, 1.2, 2.2, 3.2], [0.5, -0.5, -1.5, -2.5]],\n                [[0.3, 1.3, 2.3, 3.3], [0.0, -1.0, -2.0, -3.0]],\n                [[0.4, 1.4, 2.4, 3.4], [-0.5, -1.5, -2.5, -3.5]],\n            ],\n            dtype=np.float32,\n        ),\n        robot_moe_expert_indices=np.array(\n            [\n                [[3, 2], [0, 1]],\n                [[3, 2], [0, 1]],\n                [[3, 2], [0, 1]],\n                [[3, 2], [0, 1]],\n                [[3, 2], [0, 1]],\n            ],\n            dtype=np.int64,\n        ),\n        robot_dof_torque=np.linspace(-1.0, 1.0, 15, dtype=np.float32).reshape(\n            5, 3\n        ),\n        robot_actions=np.linspace(-0.5, 0.5, 15, dtype=np.float32).reshape(\n            5, 3\n        ),\n        robot_low_level_dof_torque=np.zeros((20, 3), dtype=np.float32),\n        robot_low_level_torque_dt=np.array(0.01, dtype=np.float32),\n    )\n\n\ndef test_plot_dump_exports_moe_heatmap_pdf(tmp_path):\n    module = _load_plot_moe_expert_heatmap_module()\n\n    npz_path = tmp_path / \"demo_eval.npz\"\n    _write_eval_npz(npz_path)\n\n    output_path = module.plot_dump(npz_path)\n\n    assert output_path == (\n        tmp_path / \"demo_eval_moe_expert_probability_heatmap.pdf\"\n    )\n    assert output_path.is_file()\n    assert (tmp_path / \"demo_eval_robot_dof_torque_line_plot.pdf\").is_file()\n    assert (tmp_path / \"demo_eval_robot_actions_line_plot.pdf\").is_file()\n\n\ndef test_selected_expert_weights_are_renormalized_within_selected_ids():\n    module = _load_plot_moe_expert_heatmap_module()\n\n    probabilities = np.array(\n        [\n            [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]],\n            [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]],\n        ],\n        dtype=np.float32,\n    )\n    expert_indices = np.array(\n        [\n            [[1, 3], [0, 2]],\n            [[0, 2], [1, 3]],\n        ],\n        dtype=np.int64,\n    )\n\n    selected_weights = module.compute_selected_expert_weights(\n        probabilities, expert_indices\n    )\n\n    np.testing.assert_allclose(\n        selected_weights,\n        np.array(\n            [\n                [[1.0 / 3.0, 2.0 / 3.0], [0.875, 0.125]],\n                [[0.5, 0.5], [0.23076923, 0.7692308]],\n            ],\n            dtype=np.float32,\n        ),\n    )\n\n\ndef test_selected_expert_heatmap_only_colors_activated_experts():\n    module = _load_plot_moe_expert_heatmap_module()\n\n    probabilities = np.array(\n        [\n            [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]],\n            [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]],\n        ],\n        dtype=np.float32,\n    )\n    expert_indices = np.array(\n        [\n            [[1, 3], 
[0, 2]],\n            [[0, 2], [1, 3]],\n        ],\n        dtype=np.int64,\n    )\n\n    selected_heatmap = module.build_selected_expert_heatmap(\n        probabilities, expert_indices\n    )\n\n    np.testing.assert_allclose(\n        selected_heatmap,\n        np.array(\n            [\n                [\n                    [0.0, 1.0 / 3.0, 0.0, 2.0 / 3.0],\n                    [0.875, 0.0, 0.125, 0.0],\n                ],\n                [\n                    [0.5, 0.0, 0.5, 0.0],\n                    [0.0, 0.23076923, 0.0, 0.7692308],\n                ],\n            ],\n            dtype=np.float32,\n        ),\n    )\n\n\ndef test_collect_npz_paths_recursively_sorts_directory_entries(tmp_path):\n    module = _load_plot_moe_expert_heatmap_module()\n\n    input_dir = tmp_path / \"evals\"\n    first_npz = input_dir / \"z_branch\" / \"clip_z.npz\"\n    second_npz = input_dir / \"a_branch\" / \"nested\" / \"clip_a.npz\"\n    second_npz.parent.mkdir(parents=True)\n    first_npz.parent.mkdir(parents=True)\n    _write_eval_npz(first_npz)\n    _write_eval_npz(second_npz)\n    (input_dir / \"ignore.txt\").write_text(\"ignore\", encoding=\"utf-8\")\n\n    assert module.collect_npz_paths(input_dir) == [second_npz, first_npz]\n\n\ndef test_plot_input_path_directory_generates_all_heatmaps_with_tqdm(\n    tmp_path,\n):\n    module = _load_plot_moe_expert_heatmap_module()\n\n    input_dir = tmp_path / \"evals\"\n    npz_paths = [\n        input_dir / \"z_branch\" / \"clip_z.npz\",\n        input_dir / \"a_branch\" / \"nested\" / \"clip_a.npz\",\n    ]\n    for npz_path in npz_paths:\n        npz_path.parent.mkdir(parents=True, exist_ok=True)\n        _write_eval_npz(npz_path)\n\n    expected_output_paths = [\n        input_dir\n        / \"a_branch\"\n        / \"nested\"\n        / \"clip_a_moe_expert_probability_heatmap.pdf\",\n        input_dir / \"z_branch\" / \"clip_z_moe_expert_probability_heatmap.pdf\",\n    ]\n\n    fake_tqdm = MagicMock(side_effect=lambda iterable, **_: iterable)\n    original_tqdm = module.tqdm\n    module.tqdm = fake_tqdm\n    try:\n        output_paths = module.plot_input_path(input_dir)\n    finally:\n        module.tqdm = original_tqdm\n\n    assert output_paths == expected_output_paths\n    assert all(path.is_file() for path in expected_output_paths)\n    expected_torque_paths = [\n        input_dir\n        / \"a_branch\"\n        / \"nested\"\n        / \"clip_a_robot_dof_torque_line_plot.pdf\",\n        input_dir / \"z_branch\" / \"clip_z_robot_dof_torque_line_plot.pdf\",\n    ]\n    assert all(path.is_file() for path in expected_torque_paths)\n    expected_action_paths = [\n        input_dir\n        / \"a_branch\"\n        / \"nested\"\n        / \"clip_a_robot_actions_line_plot.pdf\",\n        input_dir / \"z_branch\" / \"clip_z_robot_actions_line_plot.pdf\",\n    ]\n    assert all(path.is_file() for path in expected_action_paths)\n    assert list(fake_tqdm.call_args.args[0]) == sorted(npz_paths)\n    assert fake_tqdm.call_args.kwargs == {\n        \"desc\": \"Generating plot PDFs\",\n        \"unit\": \"file\",\n        \"dynamic_ncols\": True,\n    }\n\n\ndef test_plot_dump_requires_2d_robot_dof_torque(tmp_path):\n    module = _load_plot_moe_expert_heatmap_module()\n\n    npz_path = tmp_path / \"bad_eval.npz\"\n    np.savez(\n        npz_path,\n        robot_moe_expert_logits=np.zeros((2, 1, 3), dtype=np.float32),\n        robot_dof_torque=np.zeros((2,), dtype=np.float32),\n    )\n\n    try:\n        module.plot_dump(npz_path)\n    except ValueError as 
exc:\n        assert \"robot_dof_torque must have shape [frames, dofs]\" in str(exc)\n    else:\n        raise AssertionError(\n            \"Expected plot_dump to reject 1-D robot_dof_torque\"\n        )\n"
  },
  {
    "path": "tests/test_plot_state_series.py",
    "content": "import importlib.util\nfrom pathlib import Path\n\nimport numpy as np\n\nSCRIPT_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"not_for_commit\"\n    / \"plot_state_series.py\"\n)\n\n\ndef _load_plot_state_series_module():\n    spec = importlib.util.spec_from_file_location(\n        \"plot_state_series\", SCRIPT_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec.loader is not None\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_plot_dump_exports_time_matched_scalar_series(tmp_path):\n    module = _load_plot_state_series_module()\n\n    robot_config_path = tmp_path / \"robot.yaml\"\n    robot_config_path.write_text(\n        \"robot:\\n  dof_names:\\n    - joint_a\\n    - joint_b\\n\",\n        encoding=\"utf-8\",\n    )\n\n    npz_path = tmp_path / \"demo_eval.npz\"\n    np.savez(\n        npz_path,\n        robot_dof_torque=np.arange(10, dtype=np.float32).reshape(5, 2),\n        robot_dof_acc=np.arange(10, 20, dtype=np.float32).reshape(5, 2),\n        robot_action_rate=np.linspace(0.0, 1.0, 5, dtype=np.float32),\n        reward=np.linspace(1.0, 2.0, 5, dtype=np.float32),\n        bad_scalar=np.array([1.0, 2.0], dtype=np.float32),\n        metadata=np.array(\"demo\", dtype=\"<U4\"),\n    )\n\n    module.plot_dump(npz_path, robot_config_path)\n\n    output_dir = tmp_path / \"demo_eval\"\n    assert (output_dir / \"torque.pdf\").is_file()\n    assert (output_dir / \"dof_acc.pdf\").is_file()\n    assert (output_dir / \"robot_action_rate.pdf\").is_file()\n    assert (output_dir / \"reward.pdf\").is_file()\n    assert not (output_dir / \"bad_scalar.pdf\").exists()\n    assert not (output_dir / \"metadata.pdf\").exists()\n"
  },
  {
    "path": "tests/test_ppo_checkpoint_sigma_override.py",
    "content": "from types import SimpleNamespace\nfrom unittest import mock\n\nimport torch\nimport torch.nn as nn\nfrom holomotion.src.algo.ppo import PPO, _checkpoint_state_to_cpu\nfrom holomotion.src.modules.agent_modules import PPOActor\n\n\nclass _DummyActor(nn.Module):\n    def __init__(self, events: list[str]):\n        super().__init__()\n        self.events = events\n        self.noise_std_type = \"log\"\n        self.log_std = nn.Parameter(torch.zeros(3, dtype=torch.float32))\n\n    def override_sigma(self, sigma_override):\n        self.events.append(\"override_sigma\")\n        PPOActor.override_sigma(self, sigma_override)\n\n\ndef test_ppo_load_reapplies_sigma_override_after_checkpoint_restore():\n    events: list[str] = []\n    actor = _DummyActor(events)\n    algo = PPO.__new__(PPO)\n    algo.is_main_process = False\n    algo.device = torch.device(\"cpu\")\n    algo.actor = actor\n    algo.critic = nn.Linear(1, 1)\n    algo.accelerator = SimpleNamespace(unwrap_model=lambda model: model)\n    algo.config = {\"override_sigma\": True, \"sigma_override\": 0.1}\n    algo._load_extra_checkpoint_state = mock.Mock()\n    algo._resolve_model_file_path = (\n        lambda ckpt_path, model_name: f\"{ckpt_path}:{model_name}\"\n    )\n\n    algo.actor_optimizer = mock.Mock()\n    algo.actor_optimizer.load_state_dict.side_effect = (\n        lambda state_dict: events.append(\"actor_optimizer\")\n    )\n    algo.critic_optimizer = mock.Mock()\n    algo.critic_optimizer.load_state_dict.side_effect = (\n        lambda state_dict: events.append(\"critic_optimizer\")\n    )\n\n    loaded_sigma = torch.tensor([0.7, 0.8, 0.9], dtype=torch.float32)\n\n    def _fake_load_accelerate_model(model, model_path, *, strict):\n        events.append(f\"load:{model_path}\")\n        if model is actor:\n            with torch.no_grad():\n                model.log_std.copy_(loaded_sigma.log())\n\n    algo._load_accelerate_model = mock.Mock(\n        side_effect=_fake_load_accelerate_model\n    )\n\n    loaded_dict = {\n        \"actor_optimizer_state_dict\": {\"state\": {}},\n        \"critic_optimizer_state_dict\": {\"state\": {}},\n        \"iter\": 123,\n        \"infos\": {\"source\": \"unit-test\"},\n    }\n\n    with mock.patch(\n        \"holomotion.src.algo.ppo.torch.load\", return_value=loaded_dict\n    ):\n        infos = algo.load(\"checkpoint.pt\")\n\n    assert infos == {\"source\": \"unit-test\"}\n    assert torch.allclose(\n        actor.log_std.exp(),\n        torch.full((3,), 0.1, dtype=torch.float32),\n    )\n    assert events.index(\"override_sigma\") > events.index(\"actor_optimizer\")\n    assert events.index(\"override_sigma\") > events.index(\"critic_optimizer\")\n    algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)\n\n\ndef test_ppo_load_skips_optimizer_restore_during_offline_eval():\n    algo = PPO.__new__(PPO)\n    algo.is_main_process = False\n    algo.is_offline_eval = True\n    algo.device = torch.device(\"cpu\")\n    algo.actor = nn.Linear(1, 1)\n    algo.critic = nn.Linear(1, 1)\n    algo.accelerator = SimpleNamespace(unwrap_model=lambda model: model)\n    algo.config = {}\n    algo._load_extra_checkpoint_state = mock.Mock()\n    algo._resolve_model_file_path = (\n        lambda ckpt_path, model_name: f\"{ckpt_path}:{model_name}\"\n    )\n    algo._load_accelerate_model = mock.Mock()\n    algo._maybe_override_loaded_actor_sigma = mock.Mock()\n\n    algo.actor_optimizer = mock.Mock()\n    algo.critic_optimizer = mock.Mock()\n\n    loaded_dict = {\n        
\"actor_optimizer_state_dict\": {\"state\": {\"stale\": {}}},\n        \"critic_optimizer_state_dict\": {\"state\": {\"stale\": {}}},\n        \"iter\": 321,\n        \"infos\": {\"source\": \"offline-eval\"},\n    }\n\n    with mock.patch(\n        \"holomotion.src.algo.ppo.torch.load\", return_value=loaded_dict\n    ):\n        infos = algo.load(\"checkpoint.pt\")\n\n    assert infos == {\"source\": \"offline-eval\"}\n    assert algo.current_learning_iteration == 321\n    algo.actor_optimizer.load_state_dict.assert_not_called()\n    algo.critic_optimizer.load_state_dict.assert_not_called()\n    algo._maybe_override_loaded_actor_sigma.assert_called_once_with()\n    algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)\n\n\ndef test_ppo_load_skips_incompatible_optimizer_state_restore():\n    algo = PPO.__new__(PPO)\n    algo.is_main_process = False\n    algo.is_offline_eval = False\n    algo.device = torch.device(\"cpu\")\n    algo.actor = nn.Linear(1, 1)\n    algo.critic = nn.Linear(1, 1)\n    algo.accelerator = SimpleNamespace(unwrap_model=lambda model: model)\n    algo.config = {}\n    algo._load_extra_checkpoint_state = mock.Mock()\n    algo._resolve_model_file_path = (\n        lambda ckpt_path, model_name: f\"{ckpt_path}:{model_name}\"\n    )\n    algo._load_accelerate_model = mock.Mock()\n    algo._maybe_override_loaded_actor_sigma = mock.Mock()\n\n    algo.actor_optimizer = mock.Mock()\n    algo.actor_optimizer.state_dict.return_value = {\n        \"state\": {},\n        \"param_groups\": [{\"params\": [0]}],\n    }\n    algo.actor_optimizer.load_state_dict.side_effect = AssertionError(\n        \"incompatible actor optimizer state should be skipped\"\n    )\n    algo.critic_optimizer = mock.Mock()\n    algo.critic_optimizer.state_dict.return_value = {\n        \"state\": {},\n        \"param_groups\": [{\"params\": [0]}],\n    }\n\n    loaded_dict = {\n        \"actor_optimizer_state_dict\": {\n            \"state\": {0: {\"step\": torch.tensor(1)}},\n            \"param_groups\": [{\"params\": [0, 1]}],\n        },\n        \"critic_optimizer_state_dict\": {\n            \"state\": {0: {\"step\": torch.tensor(2)}},\n            \"param_groups\": [{\"params\": [0]}],\n        },\n        \"iter\": 77,\n        \"infos\": {\"source\": \"resume-training\"},\n    }\n\n    with mock.patch(\n        \"holomotion.src.algo.ppo.torch.load\", return_value=loaded_dict\n    ):\n        infos = algo.load(\"checkpoint.pt\")\n\n    assert infos == {\"source\": \"resume-training\"}\n    assert algo.current_learning_iteration == 77\n    algo.actor_optimizer.load_state_dict.assert_not_called()\n    algo.critic_optimizer.load_state_dict.assert_called_once_with(\n        loaded_dict[\"critic_optimizer_state_dict\"]\n    )\n    algo._maybe_override_loaded_actor_sigma.assert_called_once_with()\n    algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)\n\n\ndef test_checkpoint_state_to_cpu_moves_nested_tensors():\n    source = {\n        \"state\": {\n            0: {\n                \"exp_avg\": torch.tensor([1.0, 2.0], requires_grad=True),\n                \"exp_avg_sq\": torch.tensor([3.0, 4.0]),\n            }\n        },\n        \"param_groups\": [{\"lr\": 1.0e-3}],\n        \"step_tensor\": torch.tensor([5]),\n    }\n\n    converted = _checkpoint_state_to_cpu(source)\n\n    assert converted is not source\n    assert converted[\"state\"] is not source[\"state\"]\n    assert converted[\"state\"][0][\"exp_avg\"].device.type == \"cpu\"\n    assert 
converted[\"state\"][0][\"exp_avg_sq\"].device.type == \"cpu\"\n    assert converted[\"step_tensor\"].device.type == \"cpu\"\n    assert converted[\"state\"][0][\"exp_avg\"].requires_grad is False\n    torch.testing.assert_close(\n        converted[\"state\"][0][\"exp_avg\"],\n        source[\"state\"][0][\"exp_avg\"].detach(),\n    )\n"
  },
  {
    "path": "tests/test_ppo_entropy_annealing.py",
    "content": "from pathlib import Path\nimport sys\nfrom types import SimpleNamespace\n\nimport pytest\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.algo.algo_base import BaseOnpolicyRL\nfrom holomotion.src.algo.ppo import PPO\n\n\ndef _build_entropy_algo(\n    *,\n    initial_entropy_coef: float,\n    anneal_entropy: bool,\n    zero_entropy_point: float,\n    current_learning_iteration: int,\n    total_learning_iterations: int,\n    num_learning_iterations: int = 0,\n):\n    algo = PPO.__new__(PPO)\n    algo.initial_entropy_coef = float(initial_entropy_coef)\n    algo.anneal_entropy = bool(anneal_entropy)\n    algo.zero_entropy_point = float(zero_entropy_point)\n    algo.current_learning_iteration = int(current_learning_iteration)\n    algo.total_learning_iterations = int(total_learning_iterations)\n    algo.num_learning_iterations = int(num_learning_iterations)\n    return algo\n\n\ndef test_entropy_coef_is_constant_when_annealing_disabled():\n    algo = _build_entropy_algo(\n        initial_entropy_coef=5.0e-3,\n        anneal_entropy=False,\n        zero_entropy_point=1.0,\n        current_learning_iteration=50,\n        total_learning_iterations=100,\n    )\n\n    assert algo._get_effective_entropy_coef() == pytest.approx(5.0e-3)\n\n\ndef test_entropy_coef_decays_and_respects_resumed_total_iterations():\n    algo = _build_entropy_algo(\n        initial_entropy_coef=5.0e-3,\n        anneal_entropy=True,\n        zero_entropy_point=1.0,\n        current_learning_iteration=123,\n        total_learning_iterations=133,\n        num_learning_iterations=10,\n    )\n\n    expected = 5.0e-3 * max(0.0, 1.0 - 123.0 / 133.0)\n    assert algo._get_effective_entropy_coef() == pytest.approx(expected)\n\n\ndef test_entropy_coef_clamps_to_zero_at_and_after_zero_point():\n    algo = _build_entropy_algo(\n        initial_entropy_coef=5.0e-3,\n        anneal_entropy=True,\n        zero_entropy_point=0.75,\n        current_learning_iteration=75,\n        total_learning_iterations=100,\n    )\n\n    assert algo._get_effective_entropy_coef() == pytest.approx(0.0)\n\n    algo.current_learning_iteration = 90\n    assert algo._get_effective_entropy_coef() == pytest.approx(0.0)\n\n\n@pytest.mark.parametrize(\n    (\"initial_entropy_coef\", \"anneal_entropy\", \"zero_entropy_point\"),\n    [\n        (-1.0, False, 1.0),\n        (1.0, True, 0.0),\n        (1.0, True, -0.1),\n        (1.0, True, 1.1),\n    ],\n)\ndef test_validate_entropy_schedule_config_rejects_invalid_values(\n    initial_entropy_coef: float,\n    anneal_entropy: bool,\n    zero_entropy_point: float,\n):\n    with pytest.raises(ValueError):\n        PPO._validate_entropy_schedule_config(\n            initial_entropy_coef=initial_entropy_coef,\n            anneal_entropy=anneal_entropy,\n            zero_entropy_point=zero_entropy_point,\n        )\n\n\ndef test_learn_sets_current_iteration_before_each_update():\n    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)\n    algo.env = SimpleNamespace(reset_all=lambda: ({},))\n    algo._wrap_obs_dict = lambda obs_dict: obs_dict\n    algo._ensure_storage = lambda obs_td: None\n    algo.train_mode = lambda: None\n    algo.rollout_policy = lambda obs_td: obs_td\n    algo.log_dir = \"/tmp/holomotion-test\"\n    algo.num_learning_iterations = 3\n    algo.current_learning_iteration = 5\n    algo.total_learning_iterations = 0\n    algo.log_interval = 100\n    algo.save_interval = 100\n    algo.is_main_process = False\n    algo.ep_infos = []\n    
algo._post_iteration_hook = lambda it: None\n    algo._post_training_hook = lambda: None\n    algo._release_cuda_cache = lambda: None\n    algo.save = lambda *args, **kwargs: None\n    algo.accelerator = SimpleNamespace(\n        wait_for_everyone=lambda: None,\n        end_training=lambda: None,\n    )\n\n    observed_iterations = []\n    observed_totals = []\n\n    def _update():\n        observed_iterations.append(algo.current_learning_iteration)\n        observed_totals.append(algo.total_learning_iterations)\n        return {}\n\n    algo.update = _update\n\n    BaseOnpolicyRL.learn(algo)\n\n    assert observed_iterations == [5, 6, 7]\n    assert observed_totals == [8, 8, 8]\n"
  },
  {
    "path": "tests/test_ppo_symmetry_loss.py",
    "content": "from contextlib import nullcontext\nfrom types import ModuleType, SimpleNamespace\nimport sys\n\nimport pytest\nimport torch\nimport torch.nn as nn\nfrom omegaconf import OmegaConf\nfrom tensordict import TensorDict\n\nfrom holomotion.src.algo.ppo import PPO\n\n\nclass _DummyAccelerator:\n    def autocast(self):\n        return nullcontext()\n\n    def backward(self, loss):\n        loss.backward()\n\n    def clip_grad_norm_(self, parameters, max_norm):\n        torch.nn.utils.clip_grad_norm_(list(parameters), max_norm)\n\n    def reduce(self, tensor, reduction=\"mean\"):\n        return tensor\n\n\nclass _DummyActor(nn.Module):\n    def __init__(self, num_actions: int, mirror_offset: float):\n        super().__init__()\n        self.mu_param = nn.Parameter(torch.full((num_actions,), 0.25))\n        self.log_std = nn.Parameter(torch.zeros(num_actions))\n        self.mirror_offset = float(mirror_offset)\n\n    def forward(\n        self,\n        obs_td: TensorDict,\n        actions: torch.Tensor | None = None,\n        mode: str = \"sampling\",\n        *,\n        update_obs_norm: bool = True,\n    ) -> TensorDict:\n        del obs_td, update_obs_norm\n        batch_size = int(actions.shape[0]) if actions is not None else 2\n        mu = self.mu_param.unsqueeze(0).expand(batch_size, -1)\n        sigma = torch.exp(self.log_std).unsqueeze(0).expand(batch_size, -1)\n        out = TensorDict({}, batch_size=[batch_size])\n        out.set(\"mu\", mu)\n        out.set(\"sigma\", sigma)\n        if mode == \"inference\":\n            out.set(\"actions\", mu + self.mirror_offset)\n            return out\n        if actions is None:\n            actions = mu\n        out.set(\"actions\", actions)\n        zero_with_grad = mu.sum(dim=-1) * 0.0\n        out.set(\"actions_log_prob\", zero_with_grad)\n        out.set(\"entropy\", zero_with_grad)\n        return out\n\n\nclass _DummyCritic(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.value = nn.Parameter(torch.tensor([0.1], dtype=torch.float32))\n\n    def forward(self, obs_td: TensorDict, *, update_obs_norm: bool = True):\n        del obs_td, update_obs_norm\n        batch_size = 2\n        out = TensorDict({}, batch_size=[batch_size])\n        out.set(\"values\", self.value.view(1, 1).expand(batch_size, 1))\n        return out\n\n\nclass _SingleBatchStorage:\n    def __init__(self, batch):\n        self._batch = batch\n        self.data = {\n            \"returns\": torch.zeros(2, 1, 1, dtype=torch.float32),\n            \"values\": torch.zeros(2, 1, 1, dtype=torch.float32),\n        }\n        self.num_envs = 1\n        self.num_transitions_per_env = 2\n        self.cleared = False\n\n    def iter_minibatches(self, num_mini_batches: int, num_epochs: int):\n        del num_mini_batches, num_epochs\n        yield self._batch\n\n    def clear(self):\n        self.cleared = True\n\n\ndef _install_mirror_stub():\n    module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_observation\"\n    )\n\n    class MirrorFunctions:\n        @staticmethod\n        def mirror_dof(\n            x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor\n        ):\n            perm = perm.to(device=x.device, dtype=torch.long)\n            sign = sign.to(device=x.device, dtype=x.dtype)\n            mirrored = torch.index_select(x, dim=x.ndim - 1, index=perm)\n            view_shape = [1] * (mirrored.ndim - 1) + [int(sign.numel())]\n            return mirrored * sign.view(*view_shape)\n\n        
@staticmethod\n        def mirror_action(\n            actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor\n        ):\n            return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign)\n\n        @staticmethod\n        def mirror_vec3(x: torch.Tensor):\n            sign = torch.tensor(\n                [1.0, -1.0, 1.0], device=x.device, dtype=x.dtype\n            )\n            view_shape = [1] * (x.ndim - 1) + [3]\n            return x * sign.view(*view_shape)\n\n        @staticmethod\n        def mirror_axial_vec3(x: torch.Tensor):\n            sign = torch.tensor(\n                [-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype\n            )\n            view_shape = [1] * (x.ndim - 1) + [3]\n            return x * sign.view(*view_shape)\n\n        @staticmethod\n        def mirror_velocity_command(x: torch.Tensor):\n            if x.shape[-1] == 3:\n                sign = torch.tensor(\n                    [1.0, -1.0, -1.0], device=x.device, dtype=x.dtype\n                )\n            else:\n                sign = torch.tensor(\n                    [1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype\n                )\n            view_shape = [1] * (x.ndim - 1) + [int(sign.numel())]\n            return x * sign.view(*view_shape)\n\n    module.MirrorFunctions = MirrorFunctions\n    sys.modules[module.__name__] = module\n\n\ndef test_setup_symmetry_builds_expected_dof_permutation_and_signs():\n    _install_mirror_stub()\n    algo = PPO.__new__(PPO)\n    algo.device = torch.device(\"cpu\")\n    algo.num_actions = 5\n    algo.command_name = \"base_velocity\"\n    algo.symmetry_loss_enabled = True\n    algo.is_main_process = False\n    algo.config = OmegaConf.create(\n        {\n            \"module_dict\": {\n                \"actor\": {\n                    \"obs_schema\": {\n                        \"flattened_obs\": {\n                            \"seq_len\": 2,\n                            \"terms\": [\"unified/actor_dof_pos\"],\n                        }\n                    }\n                }\n            },\n            \"symmetry_loss\": {\n                \"enabled\": True,\n                \"coef\": 0.1,\n                \"dof_sign_by_name\": {\n                    \"left_hip_pitch_joint\": 1.0,\n                    \"right_hip_pitch_joint\": 1.0,\n                    \"waist_yaw_joint\": -1.0,\n                    \"left_knee_joint\": 1.0,\n                    \"right_knee_joint\": 1.0,\n                },\n            },\n        }\n    )\n    algo.env = SimpleNamespace(\n        _env=SimpleNamespace(\n            scene={\n                \"robot\": SimpleNamespace(\n                    joint_names=[\n                        \"left_hip_pitch_joint\",\n                        \"right_hip_pitch_joint\",\n                        \"waist_yaw_joint\",\n                        \"left_knee_joint\",\n                        \"right_knee_joint\",\n                    ]\n                )\n            }\n        )\n    )\n    algo.env_config = OmegaConf.create(\n        {\n            \"config\": {\n                \"robot\": {\n                    \"dof_sign_by_name\": {\n                        \"left_hip_pitch_joint\": 1.0,\n                        \"right_hip_pitch_joint\": 1.0,\n                        \"waist_yaw_joint\": -1.0,\n                        \"left_knee_joint\": 1.0,\n                        \"right_knee_joint\": 1.0,\n                    }\n                },\n                \"obs\": {\n                    \"obs_groups\": {\n         
               \"unified\": {\n                            \"atomic_obs_list\": [\n                                {\n                                    \"actor_dof_pos\": {\n                                        \"mirror_func\": \"mirror_dof\",\n                                    }\n                                }\n                            ]\n                        }\n                    }\n                },\n            }\n        }\n    )\n\n    algo._setup_symmetry()\n\n    assert algo._sym_dof_perm.tolist() == [1, 0, 2, 4, 3]\n    assert algo._sym_dof_sign.tolist() == [1.0, 1.0, -1.0, 1.0, 1.0]\n\n\ndef test_mirror_actor_obs_uses_slash_qualified_actor_terms_only():\n    _install_mirror_stub()\n    algo = PPO.__new__(PPO)\n    algo.command_name = \"base_velocity\"\n    algo.symmetry_loss_enabled = True\n    algo.symmetry_loss_coef = 0.1\n    algo._obs_mirror_map = {\n        \"unified/actor_velocity_command\": lambda x: x * 2.0,\n        \"unified/actor_dof_pos\": lambda x: x + 1.0,\n    }\n    obs_td = TensorDict.from_dict(\n        {\n            \"unified\": {\n                \"actor_velocity_command\": torch.tensor(\n                    [[[1.0, 2.0, 3.0]]], dtype=torch.float32\n                ),\n                \"actor_dof_pos\": torch.tensor(\n                    [[[0.1, 0.2]]], dtype=torch.float32\n                ),\n                \"critic_dof_pos\": torch.tensor(\n                    [[9.0, 8.0]], dtype=torch.float32\n                ),\n            }\n        },\n        batch_size=[1],\n        device=\"cpu\",\n    )\n\n    mirrored = algo._mirror_actor_obs(obs_td)\n\n    torch.testing.assert_close(\n        mirrored[\"unified\", \"actor_velocity_command\"],\n        torch.tensor([[[2.0, 4.0, 6.0]]], dtype=torch.float32),\n    )\n    torch.testing.assert_close(\n        mirrored[\"unified\", \"actor_dof_pos\"],\n        torch.tensor([[[1.1, 1.2]]], dtype=torch.float32),\n    )\n    torch.testing.assert_close(\n        mirrored[\"unified\", \"critic_dof_pos\"],\n        obs_td[\"unified\", \"critic_dof_pos\"],\n    )\n\n\ndef test_update_reports_symmetry_loss_only_for_velocity_tracking():\n    algo = PPO.__new__(PPO)\n    algo.device = torch.device(\"cpu\")\n    algo.accelerator = _DummyAccelerator()\n    algo.actor = _DummyActor(num_actions=2, mirror_offset=1.0)\n    algo.critic = _DummyCritic()\n    algo.actor_optimizer = torch.optim.SGD(algo.actor.parameters(), lr=0.01)\n    algo.critic_optimizer = torch.optim.SGD(algo.critic.parameters(), lr=0.01)\n    algo.storage = _SingleBatchStorage(\n        SimpleNamespace(\n            obs=TensorDict.from_dict(\n                {\n                    \"unified\": {\n                        \"actor_dof_pos\": torch.zeros(2, 1, 2),\n                        \"critic_dof_pos\": torch.zeros(2, 2),\n                    }\n                },\n                batch_size=[2],\n                device=\"cpu\",\n            ),\n            actions=torch.zeros(2, 2),\n            values=torch.zeros(2, 1),\n            advantages=torch.zeros(2, 1),\n            returns=torch.zeros(2, 1),\n            actions_log_prob=torch.zeros(2, 1),\n            mu=torch.zeros(2, 2),\n            sigma=torch.ones(2, 2),\n        )\n    )\n    algo.value_loss_coef = 1.0\n    algo.clip_param = 0.2\n    algo.max_grad_norm = 1.0\n    algo.schedule = \"fixed\"\n    algo.desired_kl = None\n    algo.distributed_update_mode = \"legacy\"\n    algo.num_mini_batches = 1\n    algo.num_learning_epochs = 1\n    algo.configured_num_mini_batches = 1\n    
algo.requested_num_mini_batches = 1\n    algo.distributed_lr_scale_factor = 1.0\n    algo.entropy_coef = 0.0\n    algo.initial_entropy_coef = 0.0\n    algo.anneal_entropy = False\n    algo.use_clipped_value_loss = False\n    algo.actor_learning_rate = 1.0e-3\n    algo.critic_learning_rate = 1.0e-3\n    algo.global_advantage_norm = True\n    algo.is_distributed = False\n    algo.symmetry_loss_enabled = True\n    algo.symmetry_loss_coef = 0.5\n    algo._mirror_actor_obs = lambda obs_td: obs_td\n    algo._mirror_env_action = lambda actions: actions\n    algo._post_update_hook = lambda loss_dict: None\n\n    algo.command_name = \"base_velocity\"\n    velocity_loss = algo.update()\n\n    assert velocity_loss[\"symmetry_loss\"] == pytest.approx(1.0)\n\n    algo.storage = _SingleBatchStorage(algo.storage._batch)\n    algo.command_name = \"ref_motion\"\n    non_velocity_loss = algo.update()\n\n    assert \"symmetry_loss\" not in non_velocity_loss\n"
  },
  {
    "path": "tests/test_ppo_tf_aux_keybody.py",
    "content": "import copy\nimport sys\nfrom types import ModuleType, SimpleNamespace\nfrom unittest import mock\n\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom holomotion.src.algo.algo_utils import PpoAuxTransition, RolloutStorage\nfrom holomotion.src.algo.ppo import PPO\nfrom holomotion.src.algo.ppo_tf import PPOTF\nfrom holomotion.src.modules.agent_modules import (\n    PPOTFActor,\n    TensorDictAssembler,\n)\nfrom holomotion.src.modules.network_modules import GroupedMoETransformerPolicy\nfrom holomotion.src.modules.network_modules import GroupedMoEBlock\nfrom holomotion.src.modules.network_modules import ModernTransformerBlock\nfrom tensordict import TensorDict\n\n\ndef _make_aux_policy(\n    *,\n    denoise_root_lin_vel_weight: float = 1.0e-2,\n    denoise_root_ang_vel_weight: float = 1.0e-2,\n    denoise_dof_pos_weight: float = 1.0e-2,\n) -> GroupedMoETransformerPolicy:\n    module_config = {\n        \"num_fine_experts\": 2,\n        \"num_shared_experts\": 1,\n        \"top_k\": 1,\n        \"routing_score_fn\": \"softmax\",\n        \"routing_scale\": 1.0,\n        \"use_dynamic_bias\": False,\n        \"bias_update_rate\": 0.001,\n        \"expert_bias_clip\": 0.0,\n        \"obs_embed_mlp_hidden\": 16,\n        \"d_model\": 8,\n        \"n_heads\": 2,\n        \"n_kv_heads\": 1,\n        \"use_gated_attn\": False,\n        \"n_layers\": 1,\n        \"ff_mult\": 1.0,\n        \"ff_mult_dense\": 2,\n        \"attn_dropout\": 0.0,\n        \"mlp_dropout\": 0.0,\n        \"max_ctx_len\": 4,\n        \"aux_state_pred\": {\n            \"enabled\": True,\n            \"w_denoise_ref_root_lin_vel\": denoise_root_lin_vel_weight,\n            \"w_denoise_ref_root_ang_vel\": denoise_root_ang_vel_weight,\n            \"w_denoise_ref_dof_pos\": denoise_dof_pos_weight,\n            \"keybody_contact_names\": [\n                \"left_knee_link\",\n                \"right_knee_link\",\n            ],\n            \"keybody_rel_pos_names\": [\n                \"left_knee_link\",\n                \"right_knee_link\",\n            ],\n        },\n    }\n    return GroupedMoETransformerPolicy(\n        input_dim=6,\n        output_dim=4,\n        module_config_dict=module_config,\n    )\n\n\ndef _make_aux_actor() -> PPOTFActor:\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = _make_aux_policy()\n    actor.aux_state_pred_enabled = True\n    actor.aux_router_command_recon_enabled = False\n    actor.aux_router_switch_penalty_enabled = False\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.actor_obs_transforms = []\n    actor.assembler = TensorDictAssembler(\n        {\"flat_obs\": {\"seq_len\": 1, \"terms\": [\"flat_obs\"]}},\n        output_mode=\"flat\",\n    )\n    actor.min_sigma = 0.01\n    actor.max_sigma = 1.0\n    actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32))\n    return actor\n\n\ndef _make_aux_command_policy(\n    *,\n    n_layers: int = 3,\n    dense_layer_at_last: bool = False,\n    enable_aux_router_command_recon: bool = True,\n    freeze_router: bool = False,\n) -> GroupedMoETransformerPolicy:\n    module_config = {\n        \"num_fine_experts\": 2,\n        \"num_shared_experts\": 1,\n        \"top_k\": 1,\n        \"routing_score_fn\": \"softmax\",\n        \"routing_scale\": 1.0,\n        \"use_dynamic_bias\": False,\n        \"bias_update_rate\": 0.001,\n        \"expert_bias_clip\": 0.0,\n      
  \"obs_embed_mlp_hidden\": 16,\n        \"d_model\": 8,\n        \"n_heads\": 2,\n        \"n_kv_heads\": 1,\n        \"use_gated_attn\": False,\n        \"n_layers\": n_layers,\n        \"ff_mult\": 1.0,\n        \"ff_mult_dense\": 2,\n        \"attn_dropout\": 0.0,\n        \"mlp_dropout\": 0.0,\n        \"max_ctx_len\": 4,\n        \"dense_layer_at_last\": dense_layer_at_last,\n        \"freeze_router\": freeze_router,\n        \"aux_router_command_recon\": {\n            \"enabled\": enable_aux_router_command_recon,\n            \"output_dim\": 5,\n            \"hidden_dim\": 7,\n        },\n    }\n    return GroupedMoETransformerPolicy(\n        input_dim=6,\n        output_dim=4,\n        module_config_dict=module_config,\n    )\n\n\ndef _make_temporal_aux_actor() -> PPOTFActor:\n    actor = PPOTFActor.__new__(PPOTFActor)\n    nn.Module.__init__(actor)\n    actor.actor_module = _make_aux_command_policy()\n    actor.aux_state_pred_enabled = False\n    actor.aux_router_command_recon_enabled = True\n    actor.aux_router_switch_penalty_enabled = True\n    actor.obs_norm_enabled = False\n    actor.obs_normalizer = nn.Identity()\n    actor.obs_norm_clip = 0.0\n    actor.actor_obs_transforms = []\n    actor.assembler = TensorDictAssembler(\n        {\"flat_obs\": {\"seq_len\": 1, \"terms\": [\"flat_obs\"]}},\n        output_mode=\"flat\",\n    )\n    actor.min_sigma = 0.01\n    actor.max_sigma = 1.0\n    actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32))\n    return actor\n\n\ndef _make_temporal_only_aux_actor() -> PPOTFActor:\n    actor = _make_temporal_aux_actor()\n    actor.aux_router_command_recon_enabled = False\n    return actor\n\n\ndef test_rollout_storage_allocates_ref_and_robot_keybody_targets():\n    original_tokens = dict(PpoAuxTransition.SHAPE_TOKENS)\n    PpoAuxTransition.SHAPE_TOKENS[\"C\"] = 2\n    PpoAuxTransition.SHAPE_TOKENS[\"K\"] = 8\n    try:\n        obs_template = TensorDict(\n            {\"flat_obs\": torch.zeros(2, 5)},\n            batch_size=[2],\n        )\n        storage = RolloutStorage(\n            num_envs=2,\n            num_transitions_per_env=3,\n            obs_template=obs_template,\n            actions_shape=(4,),\n            transition_cls=PpoAuxTransition,\n        )\n    finally:\n        PpoAuxTransition.SHAPE_TOKENS = original_tokens\n\n    assert storage.data[\"gt_ref_keybody_rel_pos\"].shape == (3, 2, 8, 3)\n    assert storage.data[\"gt_robot_keybody_rel_pos\"].shape == (3, 2, 8, 3)\n    assert storage.data[\"gt_denoise_ref_root_lin_vel\"].shape == (3, 2, 3)\n    assert storage.data[\"gt_denoise_ref_root_ang_vel\"].shape == (3, 2, 3)\n    assert storage.data[\"gt_denoise_ref_dof_pos\"].shape == (3, 2, 4)\n\n\ndef test_grouped_moe_policy_returns_keybody_position_predictions():\n    policy = _make_aux_policy()\n    pre_moe_hidden = torch.randn(2, 3, policy.d_model)\n\n    outputs = policy.predict_aux_from_pre_moe(pre_moe_hidden)\n\n    assert outputs[\"base_lin_vel_loc\"].shape == (2, 3, 3)\n    assert outputs[\"base_lin_vel_log_std\"].shape == (2, 3, 3)\n    assert outputs[\"root_height_loc\"].shape == (2, 3, 1)\n    assert outputs[\"root_height_log_std\"].shape == (2, 3, 1)\n    assert outputs[\"keybody_contact_logits\"].shape == (2, 3, 2)\n    assert outputs[\"ref_keybody_rel_pos\"].shape == (2, 3, 2, 3)\n    assert outputs[\"robot_keybody_rel_pos\"].shape == (2, 3, 2, 3)\n    assert outputs[\"denoise_ref_root_lin_vel_residual\"].shape == (2, 3, 3)\n    assert outputs[\"denoise_ref_root_ang_vel_residual\"].shape == (2, 3, 
3)\n    assert outputs[\"denoise_ref_dof_pos_residual\"].shape == (2, 3, 4)\n\n\ndef test_grouped_moe_policy_omits_denoise_predictions_when_weights_zero():\n    policy = _make_aux_policy(\n        denoise_root_lin_vel_weight=0.0,\n        denoise_root_ang_vel_weight=0.0,\n        denoise_dof_pos_weight=0.0,\n    )\n    pre_moe_hidden = torch.randn(2, 3, policy.d_model)\n\n    outputs = policy.predict_aux_from_pre_moe(pre_moe_hidden)\n\n    assert \"denoise_ref_root_lin_vel_residual\" not in outputs\n    assert \"denoise_ref_root_ang_vel_residual\" not in outputs\n    assert \"denoise_ref_dof_pos_residual\" not in outputs\n\n\ndef test_ppotf_actor_sequence_logp_emits_actor_facing_dof_denoise_keys():\n    actor = _make_aux_actor()\n    obs_td = TensorDict(\n        {\"flat_obs\": torch.randn(2, 3, 6, dtype=torch.float32)},\n        batch_size=[2, 3],\n    )\n    actions = torch.randn(2, 3, 4, dtype=torch.float32)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    outputs = actor(\n        obs_td,\n        actions=actions,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert outputs[\"aux_denoise_ref_dof_pos_residual\"].shape == (2, 3, 4)\n    assert \"aux_denoise_ref_keybody_rel_pos_loc\" not in outputs.keys()\n    assert \"aux_denoise_ref_keybody_rel_pos_log_std\" not in outputs.keys()\n\n\ndef test_grouped_moe_policy_default_layout_keeps_dense_first_and_moe_tail():\n    policy = _make_aux_command_policy(\n        enable_aux_router_command_recon=False,\n    )\n\n    assert len(policy.layers) == 3\n    assert isinstance(policy.layers[0], ModernTransformerBlock)\n    assert all(\n        isinstance(layer, GroupedMoEBlock) for layer in policy.layers[1:]\n    )\n    assert policy._num_moe_layers == 2\n    assert policy._last_moe_layer_idx == 2\n\n\ndef test_grouped_moe_policy_dense_layer_at_last_keeps_only_middle_layers_moe():\n    policy = _make_aux_command_policy(\n        n_layers=4,\n        dense_layer_at_last=True,\n    )\n    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    _, router_features = policy.sequence_mu(\n        obs_seq,\n        attn_mask=attn_mask,\n        return_router_features=True,\n    )\n\n    assert isinstance(policy.layers[0], ModernTransformerBlock)\n    assert isinstance(policy.layers[1], GroupedMoEBlock)\n    assert isinstance(policy.layers[2], GroupedMoEBlock)\n    assert isinstance(policy.layers[3], ModernTransformerBlock)\n    assert policy._num_moe_layers == 2\n    assert policy._last_moe_layer_idx == 2\n    assert router_features.shape == (2, 3, 4)\n\n\ndef test_grouped_moe_policy_dense_layer_at_last_allows_shallow_fully_dense():\n    policy = _make_aux_command_policy(\n        n_layers=2,\n        dense_layer_at_last=True,\n        enable_aux_router_command_recon=False,\n    )\n\n    assert len(policy.layers) == 2\n    assert all(\n        isinstance(layer, ModernTransformerBlock) for layer in policy.layers\n    )\n    assert policy._num_moe_layers == 0\n    assert policy._last_moe_layer_idx is None\n\n\ndef test_grouped_moe_policy_command_recon_uses_live_router_features():\n    policy = _make_aux_command_policy()\n    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    _, router_features = policy.sequence_mu(\n    
    obs_seq,\n        attn_mask=attn_mask,\n        return_router_features=True,\n    )\n    pred = policy.predict_aux_router_command_from_router_features(\n        router_features\n    )\n\n    assert policy._num_moe_layers == 2\n    assert router_features.shape == (2, 3, 4)\n    assert pred.shape == (2, 3, 5)\n    assert router_features.requires_grad\n\n    pred.sum().backward()\n\n    first_moe = next(\n        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)\n    )\n    assert first_moe.last_router_distribution is not None\n    assert first_moe.last_router_distribution.requires_grad\n    assert first_moe.router.weight.grad is not None\n\n\ndef test_grouped_moe_policy_freeze_router_detaches_router_features_and_params():\n    policy = _make_aux_command_policy(freeze_router=True)\n    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    _, router_features, router_temporal_features = policy.sequence_mu(\n        obs_seq,\n        attn_mask=attn_mask,\n        return_router_features=True,\n        return_router_temporal_features=True,\n    )\n    pred = policy.predict_aux_router_command_from_router_features(\n        router_features\n    )\n\n    first_moe = next(\n        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)\n    )\n    assert first_moe.freeze_router is True\n    assert first_moe.router.weight.requires_grad is False\n    assert router_features.requires_grad is False\n    assert router_temporal_features.requires_grad is False\n\n    pred.sum().backward()\n\n    assert first_moe.last_router_distribution is not None\n    assert first_moe.last_router_distribution.requires_grad is False\n    assert first_moe.last_router_logits is not None\n    assert first_moe.last_router_logits.requires_grad is False\n    assert first_moe.router.weight.grad is None\n\n\ndef test_grouped_moe_policy_loads_legacy_aux_command_recon_head_keys():\n    policy = _make_aux_command_policy(enable_aux_router_command_recon=True)\n    state_dict = copy.deepcopy(policy.state_dict())\n\n    expected_tensors = {}\n    for key in list(state_dict.keys()):\n        if \"aux_router_command_recon_head.\" not in key:\n            continue\n        legacy_key = key.replace(\n            \"aux_router_command_recon_head.\",\n            \"aux_command_recon_head.\",\n        )\n        legacy_value = torch.randn_like(state_dict[key])\n        expected_tensors[key] = legacy_value\n        state_dict[legacy_key] = legacy_value\n        del state_dict[key]\n\n    result = policy.load_state_dict(state_dict, strict=True)\n\n    assert result.missing_keys == []\n    assert result.unexpected_keys == []\n    for key, expected in expected_tensors.items():\n        actual = policy.state_dict()[key]\n        assert torch.allclose(actual, expected)\n\n\ndef test_grouped_moe_policy_ignores_legacy_aux_command_recon_head_keys_when_disabled():\n    policy = _make_aux_command_policy(enable_aux_router_command_recon=False)\n    legacy_policy = _make_aux_command_policy(\n        enable_aux_router_command_recon=True\n    )\n    state_dict = copy.deepcopy(policy.state_dict())\n\n    for key, value in legacy_policy.state_dict().items():\n        if \"aux_router_command_recon_head.\" not in key:\n            continue\n        legacy_key = key.replace(\n            \"aux_router_command_recon_head.\",\n            \"aux_command_recon_head.\",\n        )\n        state_dict[legacy_key] 
= value.clone()\n\n    result = policy.load_state_dict(state_dict, strict=True)\n\n    assert result.missing_keys == []\n    assert result.unexpected_keys == []\n    assert policy.aux_router_command_recon_head is None\n\n\ndef test_grouped_moe_policy_clears_router_cache_before_deepcopy():\n    policy = _make_aux_command_policy()\n    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    _, router_features = policy.sequence_mu(\n        obs_seq,\n        attn_mask=attn_mask,\n        return_router_features=True,\n    )\n    pred = policy.predict_aux_router_command_from_router_features(\n        router_features\n    )\n    pred.sum().backward()\n\n    first_moe = next(\n        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)\n    )\n    assert first_moe.last_router_distribution is not None\n\n    policy.clear_router_distribution_cache()\n    copied = copy.deepcopy(policy)\n\n    copied_first_moe = next(\n        layer for layer in copied.layers if isinstance(layer, GroupedMoEBlock)\n    )\n    assert copied_first_moe.last_router_distribution is None\n\n\ndef test_grouped_moe_block_tracks_least_utilized_expert_stats():\n    block = GroupedMoEBlock(\n        d_model=8,\n        n_heads=2,\n        n_kv_heads=1,\n        num_fine_experts=4,\n        num_shared_experts=1,\n        top_k=1,\n        ff_mult=1.0,\n        use_qk_norm=True,\n        use_gated_attn=False,\n        attn_dropout=0.0,\n        mlp_dropout=0.0,\n        use_dynamic_bias=False,\n        routing_score_fn=\"softmax\",\n    )\n\n    block._apply_bias_update_from_counts(torch.tensor([5, 3, 0, 2]))\n    assert block.last_active_expert_ratio.item() == pytest.approx(0.75)\n    assert block.last_max_expert_frac.item() == pytest.approx(0.5)\n    assert block.last_min_expert_frac.item() == pytest.approx(0.0)\n    assert block.last_dead_expert_ratio.item() == pytest.approx(0.25)\n\n    block._apply_bias_update_from_counts(torch.tensor([5, 3, 1, 1]))\n    assert block.last_min_expert_frac.item() == pytest.approx(0.1)\n    assert block.last_dead_expert_ratio.item() == pytest.approx(0.0)\n\n\ndef test_grouped_moe_block_tracks_dead_expert_margin_to_topk_loss():\n    block = GroupedMoEBlock(\n        d_model=4,\n        n_heads=2,\n        n_kv_heads=1,\n        num_fine_experts=3,\n        num_shared_experts=1,\n        top_k=1,\n        ff_mult=1.0,\n        use_qk_norm=True,\n        use_gated_attn=False,\n        attn_dropout=0.0,\n        mlp_dropout=0.0,\n        use_dynamic_bias=False,\n        routing_score_fn=\"softmax\",\n        dead_expert_margin_to_topk_enabled=True,\n    )\n\n    topk_idx = torch.tensor([[[0], [0]]], dtype=torch.long)\n    dense_distribution = torch.tensor(\n        [[[0.8, 0.15, 0.05], [0.7, 0.2, 0.1]]], dtype=torch.float32\n    )\n    choice_scores = torch.log(dense_distribution)\n\n    loss = block._update_routed_expert_stats_and_floor_loss(\n        topk_idx=topk_idx,\n        dense_distribution=dense_distribution,\n        choice_scores=choice_scores,\n    )\n\n    expected = torch.relu(\n        choice_scores.gather(-1, topk_idx)[..., -1:] - choice_scores\n    )\n    expected = expected[..., 1:].sum() / 4.0\n\n    torch.testing.assert_close(loss, expected)\n    torch.testing.assert_close(\n        block.last_dead_expert_margin_to_topk_loss, expected\n    )\n    torch.testing.assert_close(\n        block.last_dead_expert_margin_to_topk_loss_value,\n        
expected.detach(),\n    )\n    torch.testing.assert_close(\n        block.last_dead_expert_margin_to_topk_target,\n        choice_scores.gather(-1, topk_idx)[..., -1:].mean(),\n    )\n    torch.testing.assert_close(\n        block.last_dense_expert_usage,\n        dense_distribution.mean(dim=(0, 1)),\n    )\n\n\ndef test_grouped_moe_block_tracks_selected_expert_margin_to_unselected():\n    block = GroupedMoEBlock(\n        d_model=4,\n        n_heads=2,\n        n_kv_heads=1,\n        num_fine_experts=4,\n        num_shared_experts=1,\n        top_k=2,\n        ff_mult=1.0,\n        use_qk_norm=True,\n        use_gated_attn=False,\n        attn_dropout=0.0,\n        mlp_dropout=0.0,\n        use_dynamic_bias=False,\n        routing_score_fn=\"softmax\",\n        selected_expert_margin_to_unselected_enabled=True,\n        selected_expert_margin_to_unselected_target=0.4,\n    )\n\n    topk_idx = torch.tensor([[[0, 2], [1, 0]]], dtype=torch.long)\n    dense_distribution = torch.tensor(\n        [\n            [\n                [0.42, 0.21, 0.28, 0.09],\n                [0.27, 0.36, 0.22, 0.15],\n            ]\n        ],\n        dtype=torch.float32,\n    )\n    choice_scores = torch.tensor(\n        [[[1.0, 0.3, 0.8, 0.1], [0.9, 1.2, 0.7, 0.4]]],\n        dtype=torch.float32,\n    )\n\n    block._update_routed_expert_stats_and_floor_loss(\n        topk_idx=topk_idx,\n        dense_distribution=dense_distribution,\n        choice_scores=choice_scores,\n    )\n\n    expected_margin = torch.tensor((0.5 + 0.2) / 2.0)\n    expected_loss = torch.tensor((0.0 + 0.2) / 2.0)\n\n    torch.testing.assert_close(\n        block.last_selected_expert_margin_to_unselected,\n        expected_margin,\n    )\n    torch.testing.assert_close(\n        block.last_selected_expert_margin_to_unselected_loss,\n        expected_loss,\n    )\n    torch.testing.assert_close(\n        block.last_selected_expert_margin_to_unselected_loss_value,\n        expected_loss,\n    )\n\n\ndef test_ppotf_summarize_moe_layer_stats_includes_least_utilized_metrics():\n    moe_layers = [\n        SimpleNamespace(\n            last_active_expert_ratio=torch.tensor(0.75),\n            last_max_expert_frac=torch.tensor(0.50),\n            last_min_expert_frac=torch.tensor(0.00),\n            last_dead_expert_ratio=torch.tensor(0.25),\n            last_expert_count_cv=torch.tensor(1.20),\n            last_selected_expert_margin_to_unselected=torch.tensor(0.30),\n        ),\n        SimpleNamespace(\n            last_active_expert_ratio=torch.tensor(0.50),\n            last_max_expert_frac=torch.tensor(0.30),\n            last_min_expert_frac=torch.tensor(0.05),\n            last_dead_expert_ratio=torch.tensor(0.50),\n            last_expert_count_cv=torch.tensor(0.80),\n            last_selected_expert_margin_to_unselected=torch.tensor(0.10),\n        ),\n    ]\n\n    metrics = PPOTF._summarize_moe_layer_stats(moe_layers)\n\n    assert metrics[\"moe_active_expert_ratio\"] == pytest.approx(0.625)\n    assert metrics[\"moe_max_expert_frac\"] == pytest.approx(0.40)\n    assert metrics[\"moe_least_expert_frac\"] == pytest.approx(0.025)\n    assert metrics[\"moe_dead_expert_ratio\"] == pytest.approx(0.375)\n    assert metrics[\"moe_expert_count_cv\"] == pytest.approx(1.0)\n    assert metrics[\n        \"moe_selected_expert_margin_to_unselected\"\n    ] == pytest.approx(0.20)\n\n\ndef test_compute_routed_expert_orthogonal_loss_uses_active_experts_only():\n    algo = PPOTF.__new__(PPOTF)\n    algo.router_expert_orthogonal_min_active_usage = 0.1\n 
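   # Experts whose routed usage falls below this floor are excluded\n    # from the orthogonality penalty (expert 2 in this test).\n 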
   algo.router_expert_orthogonal_eps = 1.0e-8\n\n    moe_layer = SimpleNamespace(\n        last_routed_expert_usage=torch.tensor(\n            [0.2, 0.12, 0.05], dtype=torch.float32\n        ),\n        down_proj=torch.tensor(\n            [\n                [[1.0, 0.0]],\n                [[1.0, 1.0]],\n                [[0.0, 1.0]],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n    loss, active_count, mean_offdiag = (\n        algo._compute_routed_expert_orthogonal_loss(\n            moe_layer,\n            dtype=torch.float32,\n            device=torch.device(\"cpu\"),\n        )\n    )\n\n    active_vecs = F.normalize(\n        torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32),\n        p=2.0,\n        dim=-1,\n        eps=1.0e-8,\n    )\n    gram = active_vecs @ active_vecs.transpose(0, 1)\n    offdiag = gram.masked_select(~torch.eye(2, dtype=torch.bool))\n\n    torch.testing.assert_close(active_count, torch.tensor(2.0))\n    torch.testing.assert_close(loss, offdiag.square().sum())\n    torch.testing.assert_close(mean_offdiag, offdiag.abs().mean())\n\n\ndef test_compute_routed_expert_orthogonal_loss_returns_zero_below_two_active():\n    algo = PPOTF.__new__(PPOTF)\n    algo.router_expert_orthogonal_min_active_usage = 0.1\n    algo.router_expert_orthogonal_eps = 1.0e-8\n\n    moe_layer = SimpleNamespace(\n        last_routed_expert_usage=torch.tensor(\n            [0.2, 0.05, 0.01], dtype=torch.float32\n        ),\n        down_proj=torch.randn(3, 1, 2, dtype=torch.float32),\n    )\n\n    loss, active_count, mean_offdiag = (\n        algo._compute_routed_expert_orthogonal_loss(\n            moe_layer,\n            dtype=torch.float32,\n            device=torch.device(\"cpu\"),\n        )\n    )\n\n    torch.testing.assert_close(loss, torch.tensor(0.0))\n    torch.testing.assert_close(active_count, torch.tensor(1.0))\n    torch.testing.assert_close(mean_offdiag, torch.tensor(0.0))\n\n\ndef test_ppotf_actor_sequence_logp_emits_router_features_for_aux_router_losses():\n    actor = _make_temporal_aux_actor()\n    obs_td = TensorDict(\n        {\"flat_obs\": torch.randn(2, 3, 6, dtype=torch.float32)},\n        batch_size=[2, 3],\n    )\n    actions = torch.randn(2, 3, 4, dtype=torch.float32)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    outputs = actor(\n        obs_td,\n        actions=actions,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert outputs[\"router_features\"].shape == (2, 3, 4)\n    assert outputs[\"router_temporal_features\"].shape == (2, 3, 4)\n    assert outputs[\"aux_router_command_recon\"].shape == (2, 3, 5)\n\n\ndef test_ppotf_actor_sequence_logp_emits_only_router_features_for_temporal_only_aux():\n    actor = _make_temporal_only_aux_actor()\n    obs_td = TensorDict(\n        {\"flat_obs\": torch.randn(2, 3, 6, dtype=torch.float32)},\n        batch_size=[2, 3],\n    )\n    actions = torch.randn(2, 3, 4, dtype=torch.float32)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    outputs = actor(\n        obs_td,\n        actions=actions,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert outputs[\"router_features\"].shape == (2, 3, 4)\n    assert outputs[\"router_temporal_features\"].shape == (2, 3, 4)\n    assert \"aux_router_command_recon\" not in outputs.keys()\n\n\ndef 
test_masked_adjacent_router_js_averages_only_valid_adjacent_tokens():\n    router_features = torch.tensor(\n        [\n            [\n                [0.8, 0.2, 0.6, 0.4],\n                [0.6, 0.4, 0.5, 0.5],\n                [0.1, 0.9, 0.4, 0.6],\n            ]\n        ],\n        dtype=torch.float32,\n    )\n    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)\n\n    loss = PPOTF._masked_adjacent_router_js(\n        router_features=router_features,\n        valid_tok=valid_tok,\n        num_moe_layers=2,\n        num_fine_experts=2,\n    )\n\n    layer0_prev = torch.tensor([0.8, 0.2], dtype=torch.float32)\n    layer0_curr = torch.tensor([0.6, 0.4], dtype=torch.float32)\n    mix0 = 0.5 * (layer0_prev + layer0_curr)\n    js0 = 0.5 * (\n        (layer0_prev * (torch.log(layer0_prev) - torch.log(mix0))).sum()\n        + (layer0_curr * (torch.log(layer0_curr) - torch.log(mix0))).sum()\n    )\n    layer1_prev = torch.tensor([0.6, 0.4], dtype=torch.float32)\n    layer1_curr = torch.tensor([0.5, 0.5], dtype=torch.float32)\n    mix1 = 0.5 * (layer1_prev + layer1_curr)\n    js1 = 0.5 * (\n        (layer1_prev * (torch.log(layer1_prev) - torch.log(mix1))).sum()\n        + (layer1_curr * (torch.log(layer1_curr) - torch.log(mix1))).sum()\n    )\n    expected = 0.5 * (js0 + js1)\n    assert torch.isclose(loss, expected)\n\n\ndef test_masked_adjacent_router_normed_smooth_l1_averages_only_valid_adjacent_tokens():\n    router_temporal_features = torch.tensor(\n        [\n            [\n                [3.0, 1.0, 0.0],\n                [2.0, 0.0, 2.0],\n                [1.0, 1.0, 1.0],\n            ]\n        ],\n        dtype=torch.float32,\n    )\n    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)\n\n    loss = PPOTF._masked_adjacent_router_normed_smooth_l1(\n        router_temporal_features=router_temporal_features,\n        valid_tok=valid_tok,\n        num_moe_layers=1,\n        num_fine_experts=3,\n    )\n\n    prev_logits = router_temporal_features[:, :1].reshape(1, 1, 1, 3)\n    curr_logits = router_temporal_features[:, 1:2].reshape(1, 1, 1, 3)\n    prev_norm = F.normalize(\n        prev_logits - prev_logits.mean(dim=-1, keepdim=True),\n        p=2.0,\n        dim=-1,\n        eps=1.0e-5,\n    )\n    curr_norm = F.normalize(\n        curr_logits - curr_logits.mean(dim=-1, keepdim=True),\n        p=2.0,\n        dim=-1,\n        eps=1.0e-5,\n    )\n    expected = F.smooth_l1_loss(\n        curr_norm,\n        prev_norm,\n        reduction=\"none\",\n        beta=1.0,\n    ).mean()\n    assert torch.isclose(loss, expected)\n\n\ndef test_masked_aux_keybody_mse_averages_only_valid_tokens():\n    pred = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])\n    target = torch.zeros_like(pred)\n    valid_tok = torch.tensor([[1.0, 0.0]])\n\n    loss = PPOTF._masked_aux_keybody_mse(pred, target, valid_tok)\n\n    expected = torch.tensor((1.0 + 4.0 + 9.0) / 3.0)\n    assert torch.isclose(loss, expected)\n\n\ndef test_masked_aux_huber_averages_only_valid_tokens():\n    pred = torch.zeros(1, 2, 1, 3)\n    target = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])\n    valid_tok = torch.tensor([[1.0, 0.0]])\n\n    loss = PPOTF._masked_aux_huber(\n        pred=pred,\n        target=target,\n        valid_tok=valid_tok,\n        beta=1.0,\n    )\n\n    expected = torch.tensor((0.5 + 1.5 + 2.5) / 3.0)\n    assert torch.isclose(loss, expected)\n\n\ndef test_setup_configs_rejects_router_aux_terms_outside_motion_tracking():\n    algo = PPOTF.__new__(PPOTF)\n    algo.config 
= {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\n            \"enabled\": False,\n        },\n        \"aux_router_switch_penalty\": {\"enabled\": True, \"weight\": 1.0},\n    }\n    algo.command_name = \"velocity\"\n\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        with pytest.raises(ValueError, match=\"aux_router_switch_penalty\"):\n            algo._setup_configs()\n\n\ndef test_setup_configs_rejects_unknown_router_switch_penalty_metric():\n    algo = PPOTF.__new__(PPOTF)\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\n            \"enabled\": False,\n        },\n        \"aux_router_switch_penalty\": {\n            \"enabled\": True,\n            \"weight\": 1.0,\n            \"metric\": \"not_a_metric\",\n        },\n    }\n    algo.command_name = \"ref_motion\"\n\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        with pytest.raises(\n            ValueError, match=\"aux_router_switch_penalty.metric\"\n        ):\n            algo._setup_configs()\n\n\ndef test_setup_configs_reads_dead_expert_margin_to_topk_only():\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n        \"dead_expert_margin_to_topk\": {\"enabled\": True, \"weight\": 0.7},\n    }\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        algo._setup_configs()\n    assert algo.use_dead_expert_margin_to_topk is True\n    assert algo.dead_expert_margin_to_topk_weight == pytest.approx(0.7)\n\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n    }\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        algo._setup_configs()\n    assert algo.use_dead_expert_margin_to_topk is False\n    assert algo.dead_expert_margin_to_topk_weight == pytest.approx(0.0)\n\n\ndef test_setup_configs_reads_selected_expert_margin_to_unselected():\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n        \"selected_expert_margin_to_unselected\": {\n            \"enabled\": True,\n            \"weight\": 0.9,\n            \"target\": 0.3,\n        },\n    }\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        algo._setup_configs()\n    assert algo.use_selected_expert_margin_to_unselected is True\n    assert algo.selected_expert_margin_to_unselected_weight == pytest.approx(\n        0.9\n    )\n    assert algo.selected_expert_margin_to_unselected_target == pytest.approx(\n        0.3\n    )\n\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n    }\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        
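# With the key omitted entirely, the margin loss must default to\n        # disabled with zero weight and target.\n        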
algo._setup_configs()\n    assert algo.use_selected_expert_margin_to_unselected is False\n    assert algo.selected_expert_margin_to_unselected_weight == pytest.approx(\n        0.0\n    )\n    assert algo.selected_expert_margin_to_unselected_target == pytest.approx(\n        0.0\n    )\n\n\ndef test_setup_configs_reads_aux_router_future_recon():\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n        \"aux_router_future_recon\": {\n            \"enabled\": True,\n            \"weight\": 0.7,\n            \"hidden_dim\": 13,\n            \"huber_beta\": 0.3,\n        },\n        \"module_dict\": {\n            \"actor\": {\n                \"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV3\",\n            }\n        },\n    }\n\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        algo._setup_configs()\n\n    assert algo.use_aux_router_future_recon is True\n    assert algo.aux_router_future_recon_weight == pytest.approx(0.7)\n    assert algo.aux_router_future_recon_hidden_dim == 13\n    assert algo.aux_router_future_recon_huber_beta == pytest.approx(0.3)\n\n\ndef test_setup_configs_reads_router_expert_orthogonal():\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n        \"dead_expert_margin_to_topk\": {\"enabled\": True, \"weight\": 0.7},\n        \"router_expert_orthogonal\": {\n            \"enabled\": True,\n            \"weight\": 0.9,\n            \"min_active_usage\": 0.2,\n            \"eps\": 1.0e-6,\n        },\n    }\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        algo._setup_configs()\n    assert algo.use_router_expert_orthogonal is True\n    assert algo.router_expert_orthogonal_weight == pytest.approx(0.9)\n    assert algo.router_expert_orthogonal_min_active_usage == pytest.approx(0.2)\n    assert algo.router_expert_orthogonal_eps == pytest.approx(1.0e-6)\n\n\ndef test_setup_configs_rejects_router_expert_orthogonal_without_dead_margin():\n    algo = PPOTF.__new__(PPOTF)\n    algo.command_name = \"ref_motion\"\n    algo.config = {\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n        \"router_expert_orthogonal\": {\n            \"enabled\": True,\n            \"weight\": 0.9,\n        },\n    }\n\n    with mock.patch.object(PPO, \"_setup_configs\", return_value=None):\n        with pytest.raises(ValueError, match=\"requires.*dead_expert\"):\n            algo._setup_configs()\n\n\ndef test_build_transition_uses_filtered_residual_targets_for_denoise_outputs():\n    algo = PPOTF.__new__(PPOTF)\n    algo.use_aux_state_pred = True\n    algo.use_aux_root_height = False\n    algo.use_aux_denoise_ref_root_lin_vel = True\n    algo.use_aux_denoise_ref_root_ang_vel = True\n    algo.use_aux_denoise_ref_dof_pos = True\n    algo.aux_state_pred_num_contact_bodies = 0\n    algo.aux_state_pred_num_keybody_bodies = 0\n    algo.command_name = \"ref_motion\"\n    algo.num_envs = 2\n    algo.device = torch.device(\"cpu\")\n    algo.transition_cls = PpoAuxTransition\n\n    
world_lin_vel = torch.tensor(\n        [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], dtype=torch.float32\n    )\n    world_ang_vel = torch.tensor(\n        [[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]], dtype=torch.float32\n    )\n    base_lin_vel = torch.tensor(\n        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32\n    )\n    base_ang_vel = torch.tensor(\n        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32\n    )\n    dof_pos = torch.tensor(\n        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32\n    )\n    command = SimpleNamespace(\n        get_ref_motion_root_global_lin_vel_cur=lambda prefix=\"ref_\": (\n            world_lin_vel if prefix == \"ft_ref_\" else world_lin_vel + 100.0\n        ),\n        get_ref_motion_root_global_ang_vel_cur=lambda prefix=\"ref_\": (\n            world_ang_vel if prefix == \"ft_ref_\" else world_ang_vel + 100.0\n        ),\n        get_ref_motion_base_linvel_cur=lambda prefix=\"ref_\": (\n            base_lin_vel if prefix == \"ft_ref_\" else base_lin_vel + 100.0\n        ),\n        get_ref_motion_base_angvel_cur=lambda prefix=\"ref_\": (\n            base_ang_vel if prefix == \"ft_ref_\" else base_ang_vel + 100.0\n        ),\n        get_ref_motion_dof_pos_cur=lambda prefix=\"ref_\": (\n            dof_pos if prefix == \"ft_ref_\" else dof_pos + 100.0\n        ),\n    )\n    algo.env = SimpleNamespace(\n        _env=SimpleNamespace(\n            command_manager=SimpleNamespace(get_term=lambda name: command)\n        )\n    )\n\n    obs_td = TensorDict(\n        {\"flat_obs\": torch.zeros(2, 5, dtype=torch.float32)},\n        batch_size=[2],\n    )\n    actor_out = TensorDict(\n        {\n            \"actions\": torch.zeros(2, 4, dtype=torch.float32),\n            \"actions_log_prob\": torch.zeros(2, dtype=torch.float32),\n            \"mu\": torch.zeros(2, 4, dtype=torch.float32),\n            \"sigma\": torch.ones(2, 4, dtype=torch.float32),\n        },\n        batch_size=[2],\n    )\n    critic_out = TensorDict(\n        {\"values\": torch.zeros(2, 1, dtype=torch.float32)},\n        batch_size=[2],\n    )\n\n    isaaclab_pkg = ModuleType(\"isaaclab\")\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_mdp.base_lin_vel = lambda env: torch.zeros(\n        2, 3, dtype=torch.float32\n    )\n    isaaclab_envs.mdp = isaaclab_mdp\n    isaaclab_pkg.envs = isaaclab_envs\n\n    with mock.patch.dict(\n        sys.modules,\n        {\n            \"isaaclab\": isaaclab_pkg,\n            \"isaaclab.envs\": isaaclab_envs,\n            \"isaaclab.envs.mdp\": isaaclab_mdp,\n        },\n    ):\n        transition = algo._build_transition(obs_td, actor_out, critic_out)\n\n    torch.testing.assert_close(\n        transition.gt_denoise_ref_root_lin_vel,\n        torch.full_like(base_lin_vel, -100.0),\n    )\n    torch.testing.assert_close(\n        transition.gt_denoise_ref_root_ang_vel,\n        torch.full_like(base_ang_vel, -100.0),\n    )\n    torch.testing.assert_close(\n        transition.gt_denoise_ref_dof_pos, torch.full_like(dof_pos, -100.0)\n    )\n\n\ndef test_build_transition_rejects_mismatched_denoise_dof_target_shape():\n    algo = PPOTF.__new__(PPOTF)\n    algo.use_aux_state_pred = True\n    algo.use_aux_root_height = False\n    algo.use_aux_denoise_ref_root_lin_vel = False\n    algo.use_aux_denoise_ref_root_ang_vel = False\n    algo.use_aux_denoise_ref_dof_pos = True\n    algo.aux_state_pred_num_contact_bodies = 0\n    
algo.aux_state_pred_num_keybody_bodies = 0\n    algo.command_name = \"ref_motion\"\n    algo.num_envs = 2\n    algo.device = torch.device(\"cpu\")\n    algo.transition_cls = PpoAuxTransition\n    command = SimpleNamespace(\n        get_ref_motion_dof_pos_cur=lambda prefix=\"ref_\": torch.zeros(\n            2, 5, dtype=torch.float32\n        )\n    )\n    algo.env = SimpleNamespace(\n        _env=SimpleNamespace(\n            command_manager=SimpleNamespace(get_term=lambda name: command)\n        )\n    )\n\n    obs_td = TensorDict(\n        {\"flat_obs\": torch.zeros(2, 5, dtype=torch.float32)},\n        batch_size=[2],\n    )\n    actor_out = TensorDict(\n        {\n            \"actions\": torch.zeros(2, 4, dtype=torch.float32),\n            \"actions_log_prob\": torch.zeros(2, dtype=torch.float32),\n            \"mu\": torch.zeros(2, 4, dtype=torch.float32),\n            \"sigma\": torch.ones(2, 4, dtype=torch.float32),\n        },\n        batch_size=[2],\n    )\n    critic_out = TensorDict(\n        {\"values\": torch.zeros(2, 1, dtype=torch.float32)},\n        batch_size=[2],\n    )\n\n    isaaclab_pkg = ModuleType(\"isaaclab\")\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_mdp.base_lin_vel = lambda env: torch.zeros(\n        2, 3, dtype=torch.float32\n    )\n    isaaclab_envs.mdp = isaaclab_mdp\n    isaaclab_pkg.envs = isaaclab_envs\n\n    with mock.patch.dict(\n        sys.modules,\n        {\n            \"isaaclab\": isaaclab_pkg,\n            \"isaaclab.envs\": isaaclab_envs,\n            \"isaaclab.envs.mdp\": isaaclab_mdp,\n        },\n    ):\n        with pytest.raises(ValueError, match=\"gt_denoise_ref_dof_pos\"):\n            algo._build_transition(obs_td, actor_out, critic_out)\n\n\ndef test_compute_aux_router_future_recon_loss_uses_normalized_future_targets():\n    algo = PPOTF.__new__(PPOTF)\n    algo.aux_router_future_recon_huber_beta = 0.5\n\n    obs_schema = {\n        \"flattened_obs_fut\": {\n            \"seq_len\": 2,\n            \"terms\": [\n                \"unified/actor_ref_base_linvel_fut\",\n                \"unified/actor_ref_dof_pos_fut\",\n            ],\n        }\n    }\n    obs_b = TensorDict(\n        {\n            \"unified\": TensorDict(\n                {\n                    \"actor_ref_base_linvel_fut\": torch.tensor(\n                        [\n                            [\n                                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],\n                                [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],\n                                [[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]],\n                            ]\n                        ],\n                        dtype=torch.float32,\n                    ),\n                    \"actor_ref_dof_pos_fut\": torch.tensor(\n                        [\n                            [\n                                [[0.1, 0.2], [0.3, 0.4]],\n                                [[0.5, 0.6], [0.7, 0.8]],\n                                [[0.9, 1.0], [1.1, 1.2]],\n                            ]\n                        ],\n                        dtype=torch.float32,\n                    ),\n                },\n                batch_size=[1, 3],\n            )\n        },\n        batch_size=[1, 3],\n    )\n    assembler = TensorDictAssembler(obs_schema, output_mode=\"flat\")\n\n    class _DummyPolicy(nn.Module):\n        def normalize_aux_router_future_recon_target(\n            self, future_target: torch.Tensor\n        ) -> 
torch.Tensor:\n            return future_target * 0.25\n\n    actor_wrapper = SimpleNamespace(\n        aux_router_future_recon_assembler=assembler,\n        actor_module=_DummyPolicy(),\n    )\n    raw_target = assembler(obs_b.flatten(0, 1)).reshape(1, 3, -1)\n    normalized_target = raw_target * 0.25\n    pred = normalized_target + torch.tensor(\n        [\n            [\n                [0.0] * raw_target.shape[-1],\n                [0.5] * raw_target.shape[-1],\n                [1.0] * raw_target.shape[-1],\n            ]\n        ],\n        dtype=torch.float32,\n    )\n    actor_out = TensorDict(\n        {\"aux_router_future_recon\": pred},\n        batch_size=[1, 3],\n    )\n    valid_tok = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)\n\n    loss = algo._compute_aux_router_future_recon_loss(\n        actor_wrapper=actor_wrapper,\n        actor_out=actor_out,\n        obs_b=obs_b,\n        valid_tok=valid_tok,\n    )\n\n    expected = PPOTF._masked_aux_huber(\n        pred=pred,\n        target=normalized_target,\n        valid_tok=valid_tok,\n        beta=0.5,\n    )\n    assert torch.isclose(loss, expected)\n\n\ndef test_root_relative_body_pos_uses_consistent_environment_frame():\n    body_pos_w = torch.tensor(\n        [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32\n    )\n    root_pos_env = torch.zeros(1, 3, dtype=torch.float32)\n    root_quat_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)\n    env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32)\n\n    rel = PPOTF._root_relative_body_pos_from_mixed_position_frames(\n        body_pos_w=body_pos_w,\n        root_pos_env=root_pos_env,\n        root_quat_w=root_quat_w,\n        env_origins=env_origins,\n    )\n\n    expected = torch.tensor(\n        [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], dtype=torch.float32\n    )\n    assert torch.allclose(rel, expected)\n"
  },
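  {
    "path": "tests/examples/sketch_masked_adjacent_router_js.py",
    "content": "\"\"\"Illustrative sketch, not part of the original test suite.\n\nRe-derives the masked adjacent-token Jensen-Shannon penalty that\ntest_masked_adjacent_router_js_averages_only_valid_adjacent_tokens checks\nnumerically. The function name and signature here are hypothetical; the\nreal implementation is PPOTF._masked_adjacent_router_js.\n\"\"\"\n\nimport torch\n\n\ndef masked_adjacent_router_js(\n    router_features: torch.Tensor,\n    valid_tok: torch.Tensor,\n    num_moe_layers: int,\n    num_fine_experts: int,\n) -> torch.Tensor:\n    # router_features: [B, T, num_moe_layers * num_fine_experts] routing\n    # probabilities laid out layer-major; valid_tok: [B, T] in {0, 1}.\n    b, t, _ = router_features.shape\n    probs = router_features.reshape(b, t, num_moe_layers, num_fine_experts)\n    prev, curr = probs[:, :-1], probs[:, 1:]\n    mix = 0.5 * (prev + curr)\n    # JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = (p + q) / 2.\n    kl_prev = (prev * (prev.log() - mix.log())).sum(dim=-1)\n    kl_curr = (curr * (curr.log() - mix.log())).sum(dim=-1)\n    js = 0.5 * (kl_prev + kl_curr)  # [B, T-1, num_moe_layers]\n    # An adjacent pair contributes only when both of its tokens are valid;\n    # the mean also runs over MoE layers, matching the test's expectation\n    # of 0.5 * (js0 + js1) for one valid pair and two layers.\n    pair_mask = valid_tok[:, :-1] * valid_tok[:, 1:]\n    pair_mask = pair_mask.unsqueeze(-1).expand_as(js)\n    return (js * pair_mask).sum() / pair_mask.sum().clamp_min(1.0)\n"
  },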
  {
    "path": "tests/test_ref_router_actor.py",
    "content": "from __future__ import annotations\n\nimport pytest\nimport torch\nfrom holomotion.src.algo.ppo_tf import PPOTF\nfrom holomotion.src.modules.agent_modules import PPOTFRefRouterActor\nfrom tensordict import TensorDict\n\n\ndef _make_ref_router_obs_schema() -> dict:\n    return {\n        \"flattened_obs\": {\n            \"seq_len\": 1,\n            \"terms\": [\n                \"unified/actor_ref_dof_pos_cur\",\n                \"unified/actor_dof_pos\",\n                \"unified/actor_ref_root_height_cur\",\n                \"unified/actor_last_action\",\n            ],\n        },\n        \"flattened_obs_fut\": {\n            \"seq_len\": 2,\n            \"terms\": [\n                \"unified/actor_ref_dof_pos_fut\",\n                \"unified/actor_ref_root_height_fut\",\n            ],\n        },\n    }\n\n\ndef _make_ref_router_obs(batch_size: list[int]) -> TensorDict:\n    shape = list(batch_size)\n    fut_shape = shape + [2]\n    unified = TensorDict(\n        {\n            \"actor_ref_dof_pos_cur\": torch.randn(*shape, 2),\n            \"actor_dof_pos\": torch.randn(*shape, 3),\n            \"actor_ref_root_height_cur\": torch.randn(*shape, 1),\n            \"actor_last_action\": torch.randn(*shape, 2),\n            \"actor_ref_dof_pos_fut\": torch.randn(*fut_shape, 2),\n            \"actor_ref_root_height_fut\": torch.randn(*fut_shape, 1),\n        },\n        batch_size=shape,\n    )\n    return TensorDict({\"unified\": unified}, batch_size=shape)\n\n\ndef _make_ref_router_actor(\n    *,\n    num_actions: int = 4,\n    freeze_router: bool = False,\n    aux_router_future_recon: dict | None = None,\n) -> PPOTFRefRouterActor:\n    obs_schema = _make_ref_router_obs_schema()\n    obs_example = _make_ref_router_obs([2])\n    module_config = {\n        \"type\": \"ReferenceRoutedGroupedMoETransformerPolicy\",\n        \"num_fine_experts\": 3,\n        \"num_shared_experts\": 1,\n        \"top_k\": 1,\n        \"routing_score_fn\": \"softmax\",\n        \"routing_scale\": 1.0,\n        \"use_dynamic_bias\": False,\n        \"bias_update_rate\": 0.001,\n        \"expert_bias_clip\": 0.0,\n        \"obs_embed_mlp_hidden\": 16,\n        \"d_model\": 8,\n        \"n_layers\": 2,\n        \"n_heads\": 2,\n        \"n_kv_heads\": 1,\n        \"use_gated_attn\": False,\n        \"use_qk_norm\": True,\n        \"ff_mult\": 1.0,\n        \"ff_mult_dense\": 2,\n        \"attn_dropout\": 0.0,\n        \"mlp_dropout\": 0.0,\n        \"max_ctx_len\": 4,\n        \"freeze_router\": freeze_router,\n        \"obs_norm\": {\"enabled\": False},\n        \"output_dim\": num_actions,\n        \"aux_router_future_recon\": aux_router_future_recon\n        or {\"enabled\": False},\n    }\n    return PPOTFRefRouterActor(\n        obs_schema=obs_schema,\n        module_config_dict=module_config,\n        num_actions=num_actions,\n        init_noise_std=0.2,\n        obs_example=obs_example,\n    )\n\n\ndef test_ref_router_actor_infers_only_actor_ref_feature_indices():\n    obs_schema = _make_ref_router_obs_schema()\n    obs_example = _make_ref_router_obs([2])\n\n    indices = PPOTFRefRouterActor.infer_router_feature_indices(\n        obs_schema, obs_example\n    )\n\n    assert indices == [0, 1, 5, 8, 9, 10, 11, 12, 13]\n\n\ndef test_ref_router_actor_single_step_and_sequence_logp_match_contract():\n    actor = _make_ref_router_actor()\n    obs_td = _make_ref_router_obs([2])\n\n    inference_out = actor(\n        obs_td,\n        mode=\"inference\",\n        update_obs_norm=False,\n    )\n   
 assert inference_out[\"actions\"].shape == (2, 4)\n    assert inference_out[\"mu\"].shape == (2, 4)\n    assert inference_out[\"sigma\"].shape == (2, 4)\n\n    cache_shape = actor.onnx_past_key_values_shape(batch_size=2)\n    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)\n    step_idx = torch.zeros(2, dtype=torch.long)\n    with torch.no_grad():\n        actions, present = actor(\n            obs_td,\n            past_key_values=past_key_values,\n            current_pos=step_idx,\n        )\n    assert actions.shape == (2, 4)\n    assert present.shape == cache_shape\n\n    obs_seq = _make_ref_router_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n    seq_out = actor(\n        obs_seq,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert seq_out[\"mu\"].shape == (2, 3, 4)\n    assert seq_out[\"sigma\"].shape == (2, 3, 4)\n    assert seq_out[\"actions_log_prob\"].shape == (2, 3, 1)\n    assert seq_out[\"entropy\"].shape == (2, 3, 1)\n\n\ndef test_ref_router_actor_rejects_aux_router_future_recon():\n    with pytest.raises(\n        ValueError,\n        match=\"does not support aux_router_future_recon\",\n    ):\n        _make_ref_router_actor(\n            aux_router_future_recon={\"enabled\": True, \"weight\": 1.0}\n        )\n\n\ndef test_ppotf_select_actor_wrapper_rejects_ref_router_cross_attn():\n    with pytest.raises(\n        ValueError,\n        match=\"ReferenceRoutedGroupedMoETransformerPolicy\",\n    ):\n        PPOTF._select_actor_wrapper_cls(\n            {\n                \"type\": \"ReferenceRoutedGroupedMoETransformerPolicy\",\n                \"use_future_cross_attn\": True,\n            }\n        )\n\n\ndef test_ref_router_actor_freeze_router_freezes_router_obs_embed():\n    actor = _make_ref_router_actor(freeze_router=True)\n    module = actor.actor_module\n\n    assert module.freeze_router is True\n    assert module.router_obs_embed[0].weight.requires_grad is False\n    assert module.router_obs_embed[0].bias.requires_grad is False\n    assert module.router_obs_embed[2].weight.requires_grad is False\n    assert module.router_obs_embed[2].bias.requires_grad is False\n\n\ndef test_ref_router_actor_freeze_router_reapplies_after_load_state_dict():\n    actor = _make_ref_router_actor(freeze_router=True)\n    module = actor.actor_module\n    state_dict = module.state_dict()\n\n    module.router_obs_embed.requires_grad_(True)\n    for layer in module.layers:\n        if hasattr(layer, \"router\"):\n            layer.router.requires_grad_(True)\n\n    result = module.load_state_dict(state_dict, strict=True)\n\n    assert result.missing_keys == []\n    assert result.unexpected_keys == []\n    assert module.router_obs_embed[0].weight.requires_grad is False\n    assert module.router_obs_embed[0].bias.requires_grad is False\n    assert module.router_obs_embed[2].weight.requires_grad is False\n    assert module.router_obs_embed[2].bias.requires_grad is False\n    for layer in module.layers:\n        if hasattr(layer, \"router\"):\n            assert layer.router.weight.requires_grad is False\n"
  },
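  {
    "path": "tests/examples/sketch_kv_cache_parity.py",
    "content": "\"\"\"Illustrative sketch, not part of the original test suite.\n\nToy model of the parity contract the cached/ONNX tests enforce: running a\ncausal attention layer over a whole sequence must match running it one\nstep at a time against a growing key/value cache. The single-head\nattention below is a minimal stand-in, not the repo's actor module.\n\"\"\"\n\nimport torch\n\ntorch.manual_seed(0)\nD = 8\nWQ, WK, WV = (torch.randn(D, D) for _ in range(3))\n\n\ndef full_sequence(x: torch.Tensor) -> torch.Tensor:\n    # x: [B, T, D]; the causal mask is built the same way the tests build\n    # attn_mask via torch.tril.\n    q, k, v = x @ WQ, x @ WK, x @ WV\n    scores = (q @ k.transpose(-2, -1)) / D**0.5\n    mask = torch.tril(torch.ones(x.shape[1], x.shape[1], dtype=torch.bool))\n    scores = scores.masked_fill(~mask, float(\"-inf\"))\n    return torch.softmax(scores, dim=-1) @ v\n\n\ndef cached_steps(x: torch.Tensor) -> torch.Tensor:\n    # Feed one token at a time, appending keys/values to the cache, the\n    # same pattern the actors exercise via past_key_values / current_pos.\n    b, t, _ = x.shape\n    k_cache = torch.zeros(b, 0, D)\n    v_cache = torch.zeros(b, 0, D)\n    outs = []\n    for i in range(t):\n        xi = x[:, i : i + 1]\n        q, k, v = xi @ WQ, xi @ WK, xi @ WV\n        k_cache = torch.cat([k_cache, k], dim=1)\n        v_cache = torch.cat([v_cache, v], dim=1)\n        scores = (q @ k_cache.transpose(-2, -1)) / D**0.5\n        outs.append(torch.softmax(scores, dim=-1) @ v_cache)\n    return torch.cat(outs, dim=1)\n\n\nif __name__ == \"__main__\":\n    x = torch.randn(2, 4, D)\n    torch.testing.assert_close(\n        full_sequence(x), cached_steps(x), atol=1.0e-5, rtol=1.0e-4\n    )\n"
  },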
  {
    "path": "tests/test_ref_router_seq_actor.py",
    "content": "from __future__ import annotations\n\nimport pytest\nimport torch\nfrom holomotion.src.algo.ppo_tf import PPOTF\nfrom holomotion.src.modules.agent_modules import (\n    PPOTFRefRouterSeqActor,\n    PPOTFRefRouterV3Actor,\n)\nfrom tensordict import TensorDict\n\n\nREF_CUR_TERM_DIMS = {\n    \"actor_ref_gravity_projection_cur\": 3,\n    \"actor_ref_base_linvel_cur\": 3,\n    \"actor_ref_base_angvel_cur\": 3,\n    \"actor_ref_dof_pos_cur\": 2,\n    \"actor_ref_root_height_cur\": 1,\n}\n\nREF_FUT_TERM_DIMS = {\n    \"actor_ref_gravity_projection_fut\": 3,\n    \"actor_ref_base_linvel_fut\": 3,\n    \"actor_ref_base_angvel_fut\": 3,\n    \"actor_ref_dof_pos_fut\": 2,\n    \"actor_ref_root_height_fut\": 1,\n}\n\n\ndef _make_ref_router_v2_obs_schema(\n    *,\n    include_ref_cur: bool = True,\n    include_ref_fut: bool = True,\n) -> dict:\n    flat_terms = []\n    if include_ref_cur:\n        flat_terms.extend(\n            [\n                \"unified/actor_ref_gravity_projection_cur\",\n                \"unified/actor_ref_base_linvel_cur\",\n                \"unified/actor_ref_base_angvel_cur\",\n                \"unified/actor_ref_dof_pos_cur\",\n                \"unified/actor_ref_root_height_cur\",\n            ]\n        )\n    flat_terms.extend(\n        [\n            \"unified/actor_projected_gravity\",\n            \"unified/actor_rel_robot_root_ang_vel\",\n            \"unified/actor_dof_pos\",\n            \"unified/actor_dof_vel\",\n            \"unified/actor_last_action\",\n        ]\n    )\n    schema = {\n        \"flattened_obs\": {\"seq_len\": 1, \"terms\": flat_terms},\n    }\n    if include_ref_fut:\n        schema[\"flattened_obs_fut\"] = {\n            \"seq_len\": 5,\n            \"terms\": [\n                \"unified/actor_ref_gravity_projection_fut\",\n                \"unified/actor_ref_base_linvel_fut\",\n                \"unified/actor_ref_base_angvel_fut\",\n                \"unified/actor_ref_dof_pos_fut\",\n                \"unified/actor_ref_root_height_fut\",\n            ],\n        }\n    return schema\n\n\ndef _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict:\n    shape = list(batch_size)\n    fut_shape = shape + [5]\n    unified = TensorDict(\n        {\n            \"actor_ref_gravity_projection_cur\": torch.randn(*shape, 3),\n            \"actor_ref_base_linvel_cur\": torch.randn(*shape, 3),\n            \"actor_ref_base_angvel_cur\": torch.randn(*shape, 3),\n            \"actor_ref_dof_pos_cur\": torch.randn(*shape, 2),\n            \"actor_ref_root_height_cur\": torch.randn(*shape, 1),\n            \"actor_projected_gravity\": torch.randn(*shape, 3),\n            \"actor_rel_robot_root_ang_vel\": torch.randn(*shape, 3),\n            \"actor_dof_pos\": torch.randn(*shape, 4),\n            \"actor_dof_vel\": torch.randn(*shape, 4),\n            \"actor_last_action\": torch.randn(*shape, 2),\n            \"actor_ref_gravity_projection_fut\": torch.randn(*fut_shape, 3),\n            \"actor_ref_base_linvel_fut\": torch.randn(*fut_shape, 3),\n            \"actor_ref_base_angvel_fut\": torch.randn(*fut_shape, 3),\n            \"actor_ref_dof_pos_fut\": torch.randn(*fut_shape, 2),\n            \"actor_ref_root_height_fut\": torch.randn(*fut_shape, 1),\n        },\n        batch_size=shape,\n    )\n    return TensorDict({\"unified\": unified}, batch_size=shape)\n\n\ndef _make_ref_router_v2_actor(\n    *,\n    obs_schema: dict | None = None,\n    num_actions: int = 4,\n    aux_state_pred: dict | None = None,\n    
aux_router_command_recon: dict | None = None,\n    freeze_router: bool = False,\n) -> PPOTFRefRouterSeqActor:\n    obs_schema = (\n        _make_ref_router_v2_obs_schema() if obs_schema is None else obs_schema\n    )\n    obs_example = _make_ref_router_v2_obs([2])\n    module_config = {\n        \"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV2\",\n        \"num_fine_experts\": 3,\n        \"num_shared_experts\": 1,\n        \"top_k\": 1,\n        \"routing_score_fn\": \"softmax\",\n        \"routing_scale\": 1.0,\n        \"use_dynamic_bias\": False,\n        \"bias_update_rate\": 0.001,\n        \"expert_bias_clip\": 0.0,\n        \"obs_embed_mlp_hidden\": 16,\n        \"d_model\": 8,\n        \"n_layers\": 2,\n        \"n_heads\": 2,\n        \"n_kv_heads\": 1,\n        \"use_gated_attn\": False,\n        \"use_qk_norm\": True,\n        \"ff_mult\": 1.0,\n        \"ff_mult_dense\": 2,\n        \"attn_dropout\": 0.0,\n        \"mlp_dropout\": 0.0,\n        \"max_ctx_len\": 4,\n        \"freeze_router\": freeze_router,\n        \"ref_hist_n_layers\": 1,\n        \"ref_future_conv_channels\": 8,\n        \"ref_future_conv_layers\": 2,\n        \"ref_future_conv_kernel_size\": 3,\n        \"ref_future_conv_stride\": 2,\n        \"obs_norm\": {\"enabled\": False},\n        \"output_dim\": num_actions,\n        \"aux_state_pred\": aux_state_pred or {\"enabled\": False},\n        \"aux_router_command_recon\": aux_router_command_recon\n        or {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n    }\n    return PPOTFRefRouterSeqActor(\n        obs_schema=obs_schema,\n        module_config_dict=module_config,\n        num_actions=num_actions,\n        init_noise_std=0.2,\n        obs_example=obs_example,\n    )\n\n\ndef _make_ref_router_v3_actor(\n    *,\n    obs_schema: dict | None = None,\n    num_actions: int = 4,\n    freeze_router: bool = False,\n    aux_router_future_recon: dict | None = None,\n) -> PPOTFRefRouterV3Actor:\n    obs_schema = (\n        _make_ref_router_v2_obs_schema() if obs_schema is None else obs_schema\n    )\n    obs_example = _make_ref_router_v2_obs([2])\n    module_config = {\n        \"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV3\",\n        \"num_fine_experts\": 3,\n        \"num_shared_experts\": 1,\n        \"top_k\": 1,\n        \"routing_score_fn\": \"softmax\",\n        \"routing_scale\": 1.0,\n        \"use_dynamic_bias\": False,\n        \"bias_update_rate\": 0.001,\n        \"expert_bias_clip\": 0.0,\n        \"obs_embed_mlp_hidden\": 16,\n        \"d_model\": 8,\n        \"n_layers\": 2,\n        \"n_heads\": 2,\n        \"n_kv_heads\": 1,\n        \"use_gated_attn\": False,\n        \"use_qk_norm\": True,\n        \"ff_mult\": 1.0,\n        \"ff_mult_dense\": 2,\n        \"attn_dropout\": 0.0,\n        \"mlp_dropout\": 0.0,\n        \"max_ctx_len\": 4,\n        \"freeze_router\": freeze_router,\n        \"ref_hist_n_layers\": 1,\n        \"router_future_hidden_dim\": 12,\n        \"router_layer_proj_hidden_dim\": 10,\n        \"obs_norm\": {\"enabled\": False},\n        \"output_dim\": num_actions,\n        \"aux_state_pred\": {\"enabled\": False},\n        \"aux_router_command_recon\": {\"enabled\": False},\n        \"aux_router_future_recon\": aux_router_future_recon\n        or {\"enabled\": False},\n        \"aux_router_switch_penalty\": {\"enabled\": False},\n    }\n    return PPOTFRefRouterV3Actor(\n        obs_schema=obs_schema,\n        module_config_dict=module_config,\n        
num_actions=num_actions,\n        init_noise_std=0.2,\n        obs_example=obs_example,\n    )\n\n\ndef test_ppotf_select_actor_wrapper_uses_ref_router_seq_actor():\n    actor_cls = PPOTF._select_actor_wrapper_cls(\n        {\"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV2\"}\n    )\n\n    assert actor_cls is PPOTFRefRouterSeqActor\n\n\ndef test_ppotf_select_actor_wrapper_uses_ref_router_v3_actor():\n    actor_cls = PPOTF._select_actor_wrapper_cls(\n        {\"type\": \"ReferenceRoutedGroupedMoETransformerPolicyV3\"}\n    )\n\n    assert actor_cls is PPOTFRefRouterV3Actor\n\n\ndef test_ref_router_seq_actor_infers_shared_ref_partitions_without_router_schemas():\n    actor = _make_ref_router_v2_actor()\n\n    assert actor.state_obs_input_dim > 0\n    assert actor.ref_cur_token_dim == sum(REF_CUR_TERM_DIMS.values())\n    assert actor.ref_fut_token_dim == sum(REF_FUT_TERM_DIMS.values())\n    assert actor.ref_fut_seq_len == 5\n\n    cache_shape = actor.onnx_past_key_values_shape(batch_size=2)\n    assert len(cache_shape) == 6\n    assert cache_shape[0] == actor.actor_module.onnx_kv_layers\n    assert cache_shape[1] == 2\n    assert cache_shape[2] == 2\n    assert cache_shape[-1] == 4\n\n\ndef test_ref_router_v3_actor_keeps_full_obs_backbone_and_layer_router_adapters():\n    actor = _make_ref_router_v3_actor()\n    module = actor.actor_module\n\n    assert actor.full_obs_input_dim > actor.state_obs_input_dim\n    assert module.full_obs_input_dim == actor.full_obs_input_dim\n    assert module.obs_embed[0].in_features == actor.full_obs_input_dim\n    assert len(module.router_layer_projections) == sum(\n        isinstance(layer, type(module.layers[1])) for layer in module.layers\n    )\n\n\ndef test_ref_router_v3_history_backbone_consumes_flat_ref_motion():\n    actor = _make_ref_router_v3_actor()\n    module = actor.actor_module\n    x = torch.randn(2, 3, module.full_obs_input_dim)\n\n    _, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x)\n\n    assert hasattr(module, \"_build_router_ref_motion\")\n    ref_motion_x = module._build_router_ref_motion(ref_cur_x, ref_fut_x)\n    assert ref_motion_x.shape == (\n        2,\n        3,\n        actor.ref_cur_token_dim\n        + actor.ref_fut_seq_len * actor.ref_fut_token_dim,\n    )\n    assert module.ref_frame_embed[0].in_features == ref_motion_x.shape[-1]\n    assert not hasattr(module, \"router_future_obs_embed\")\n    assert not hasattr(module, \"router_future_pool\")\n    assert not hasattr(module, \"router_summary_fusion\")\n    assert not hasattr(module, \"router_summary_norm\")\n\n\ndef test_ref_router_v3_actor_sequence_logp_emits_aux_router_future_recon():\n    actor = _make_ref_router_v3_actor(\n        aux_router_future_recon={\n            \"enabled\": True,\n            \"hidden_dim\": 9,\n            \"weight\": 1.0,\n        }\n    )\n    obs_td = _make_ref_router_v2_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    seq_out = actor(\n        obs_td,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert seq_out[\"aux_router_future_recon\"].shape == (\n        2,\n        3,\n        actor.ref_fut_seq_len * actor.ref_fut_token_dim,\n    )\n\n\ndef test_ref_router_v3_actor_updates_future_recon_empirical_normalizer():\n    actor = _make_ref_router_v3_actor(\n        aux_router_future_recon={\n            \"enabled\": True,\n       
     \"hidden_dim\": 9,\n            \"weight\": 1.0,\n        }\n    )\n    obs_td = _make_ref_router_v2_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    assert actor.aux_router_future_recon_assembler is not None\n    assert (\n        int(actor.actor_module.aux_router_future_recon_normalizer.count) == 0\n    )\n\n    actor(\n        obs_td,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=True,\n    )\n\n    assert (\n        int(actor.actor_module.aux_router_future_recon_normalizer.count) == 6\n    )\n\n\ndef test_ref_router_v2_freeze_router_freezes_reference_router_path():\n    actor = _make_ref_router_v2_actor(freeze_router=True)\n    module = actor.actor_module\n    x = torch.randn(2, 3, module.full_obs_input_dim, requires_grad=True)\n\n    mu, router_h, router_temporal_features = module.sequence_mu(\n        x,\n        return_ref_aux_hidden=True,\n        return_router_temporal_features=True,\n    )\n    mu.sum().backward()\n\n    first_moe = next(\n        layer for layer in module.layers if hasattr(layer, \"router\")\n    )\n    assert module.freeze_router is True\n    assert module.ref_frame_embed[0].weight.requires_grad is False\n    assert module.ref_hist_attn.q_proj.weight.requires_grad is False\n    assert module.ref_future_conv[0].weight.requires_grad is False\n    assert module.router_ref_pool.q_proj.weight.requires_grad is False\n    assert module.router_query.requires_grad is False\n    assert first_moe.router.weight.requires_grad is False\n    assert router_h.requires_grad is False\n    assert router_temporal_features.requires_grad is False\n    assert module.ref_frame_embed[0].weight.grad is None\n    assert module.ref_hist_attn.q_proj.weight.grad is None\n    assert module.ref_future_conv[0].weight.grad is None\n    assert module.router_ref_pool.q_proj.weight.grad is None\n    assert module.router_query.grad is None\n    assert first_moe.router.weight.grad is None\n\n\ndef test_ref_router_v3_freeze_router_reapplies_after_load_state_dict():\n    actor = _make_ref_router_v3_actor(freeze_router=True)\n    module = actor.actor_module\n    state_dict = module.state_dict()\n\n    module.ref_frame_embed.requires_grad_(True)\n    module.ref_hist_attn.requires_grad_(True)\n    module.router_layer_projections.requires_grad_(True)\n    for layer in module.layers:\n        if hasattr(layer, \"router\"):\n            layer.router.requires_grad_(True)\n\n    result = module.load_state_dict(state_dict, strict=True)\n\n    assert result.missing_keys == []\n    assert result.unexpected_keys == []\n    assert module.ref_frame_embed[0].weight.requires_grad is False\n    assert module.ref_hist_attn.q_proj.weight.requires_grad is False\n    assert module.router_layer_projections[0][1].weight.requires_grad is False\n    assert not hasattr(module, \"router_future_obs_embed\")\n    assert not hasattr(module, \"router_future_pool\")\n    assert not hasattr(module, \"router_summary_fusion\")\n    assert not hasattr(module, \"router_summary_norm\")\n    for layer in module.layers:\n        if hasattr(layer, \"router\"):\n            assert layer.router.weight.requires_grad is False\n\n\ndef test_ref_router_v2_film_head_starts_near_identity():\n    actor = _make_ref_router_v2_actor()\n    module = actor.actor_module\n\n    assert hasattr(module, \"_actor_film_gain\")\n    gains = module._actor_film_gain().detach()\n    
assert gains.shape == (module.d_model,)\n    assert torch.allclose(gains, torch.full_like(gains, 0.05), atol=1.0e-5)\n    last_linear = module.actor_ref_film[-1]\n    assert torch.count_nonzero(last_linear.weight.detach()) == 0\n    assert torch.count_nonzero(last_linear.bias.detach()) == 0\n\n    hidden = torch.randn(2, 1, module.d_model)\n    actor_ref_ctx = torch.randn(2, 1, module.d_model)\n    conditioned = module._apply_actor_ref_film(hidden, actor_ref_ctx)\n    assert torch.allclose(conditioned, hidden)\n\n\ndef test_ref_router_v2_pre_moe_hidden_precedes_film_modulation():\n    actor = _make_ref_router_v2_actor()\n    module = actor.actor_module\n\n    with torch.no_grad():\n        module.actor_film_gain_raw.fill_(100.0)\n        module.actor_ref_film[-1].weight.zero_()\n        module.actor_ref_film[-1].bias.fill_(2.0)\n\n    x = torch.randn(2, module.full_obs_input_dim)\n    x_seq = x[:, None, :]\n    state_x, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x_seq)\n    state_h = module.obs_embed(state_x)\n    ref_cur_h = module.ref_frame_embed(ref_cur_x)\n    ref_hist_attn = module.ref_hist_attn(\n        module.ref_hist_norm(ref_cur_h),\n        *module.get_cos_sin(ref_cur_h, torch.zeros(2, 1, dtype=torch.long)),\n        mask=None,\n    )\n    ref_hist_h = module.ref_hist_out_norm(ref_cur_h + ref_hist_attn)\n    ref_fut_tokens = module._encode_future_tokens(ref_fut_x)\n    shared_ref_tokens = torch.cat(\n        [ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2\n    )\n    router_h = module._pool_router_context(shared_ref_tokens)\n    cos, sin = module.get_cos_sin(state_h, torch.zeros(2, 1, dtype=torch.long))\n    block0_hidden = module._forward_layers_range(\n        state_h,\n        cos=cos,\n        sin=sin,\n        mask=None,\n        router_h=router_h,\n        start_layer=0,\n        end_layer=1,\n    )\n\n    _, pre_moe_hidden = module.sequence_mu(\n        x_seq,\n        return_pre_moe_hidden=True,\n    )\n\n    assert torch.allclose(pre_moe_hidden, block0_hidden)\n\n\ndef test_ref_router_v2_film_gain_is_bounded_per_channel():\n    actor = _make_ref_router_v2_actor()\n    module = actor.actor_module\n\n    assert hasattr(module, \"actor_film_gain_raw\")\n    with torch.no_grad():\n        module.actor_film_gain_raw.copy_(\n            torch.linspace(-100.0, 100.0, module.d_model)\n        )\n\n    gains = module._actor_film_gain()\n\n    assert gains.shape == (module.d_model,)\n    assert torch.all(gains >= 0.0)\n    assert torch.all(gains <= module.actor_film_gain_max + 1.0e-6)\n    assert torch.unique(gains).numel() > 1\n\n\ndef test_ref_router_v2_film_perturbation_rms_stays_bounded():\n    actor = _make_ref_router_v2_actor()\n    module = actor.actor_module\n\n    assert hasattr(module, \"actor_film_gain_raw\")\n    with torch.no_grad():\n        module.actor_ref_film[-1].weight.zero_()\n        module.actor_ref_film[-1].bias.fill_(100.0)\n        module.actor_film_gain_raw.fill_(100.0)\n\n    hidden = torch.randn(4, 3, module.d_model)\n    actor_ref_ctx = torch.randn(4, 3, module.d_model)\n    conditioned = module._apply_actor_ref_film(hidden, actor_ref_ctx)\n    delta = conditioned - hidden\n    delta_rms = delta.pow(2).mean(dim=-1).sqrt()\n\n    assert torch.all(delta_rms <= module.actor_film_gain_max + 1.0e-5)\n\n\ndef test_ref_router_v2_aux_prediction_stays_bound_to_returned_pre_moe_hidden():\n    actor = _make_ref_router_v2_actor(\n        aux_state_pred={\n            \"enabled\": True,\n            \"w_base_lin_vel\": 1.0,\n            
\"w_keybody_contact\": 1.0,\n            \"w_ref_keybody_rel_pos\": 1.0,\n            \"w_robot_keybody_rel_pos\": 1.0,\n            \"keybody_contact_names\": [\"knee\"],\n            \"keybody_rel_pos_names\": [\"knee\"],\n        }\n    )\n    module = actor.actor_module\n    module.eval()\n\n    x_a = torch.randn(1, 1, module.full_obs_input_dim)\n    x_b = x_a + 0.5\n\n    with torch.no_grad():\n        _, pre_a, ref_aux_a = module.sequence_mu(\n            x_a,\n            return_pre_moe_hidden=True,\n            return_ref_aux_hidden=True,\n        )\n        aux_a = module.predict_aux_from_pre_moe(\n            pre_a, ref_aux_hidden=ref_aux_a\n        )\n        _, pre_b, ref_aux_b = module.sequence_mu(\n            x_b,\n            return_pre_moe_hidden=True,\n            return_ref_aux_hidden=True,\n        )\n        aux_a_late = module.predict_aux_from_pre_moe(\n            pre_a, ref_aux_hidden=ref_aux_a\n        )\n        aux_b = module.predict_aux_from_pre_moe(\n            pre_b, ref_aux_hidden=ref_aux_b\n        )\n\n    assert torch.allclose(\n        aux_a_late[\"ref_keybody_rel_pos\"], aux_a[\"ref_keybody_rel_pos\"]\n    )\n    assert torch.allclose(\n        aux_a_late[\"base_lin_vel_loc\"], aux_a[\"base_lin_vel_loc\"]\n    )\n    assert not torch.allclose(\n        aux_a[\"ref_keybody_rel_pos\"], aux_b[\"ref_keybody_rel_pos\"]\n    )\n    assert not hasattr(pre_a, \"_ref_aux_hidden\")\n\n\ndef test_ref_router_v2_sequence_single_step_and_cached_onnx_agree():\n    actor = _make_ref_router_v2_actor()\n    module = actor.actor_module\n    module.eval()\n\n    x_seq = torch.randn(1, 2, module.full_obs_input_dim)\n    attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0)\n\n    with torch.no_grad():\n        mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask)\n\n        module.reset_kv_cache(num_envs=1, device=x_seq.device)\n        mu_step_0 = module.single_step_mu(x_seq[:, 0, :])\n        mu_step_1 = module.single_step_mu(x_seq[:, 1, :])\n        mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1)\n\n        cache_shape = actor.onnx_past_key_values_shape(batch_size=1)\n        past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype)\n        step_0 = torch.zeros(1, dtype=torch.long)\n        step_1 = torch.ones(1, dtype=torch.long)\n        mu_onnx_0, present_0 = module.forward(\n            x_seq[:, 0, :],\n            past_key_values=past_key_values,\n            current_pos=step_0,\n        )\n        mu_onnx_1, present_1 = module.forward(\n            x_seq[:, 1, :],\n            past_key_values=present_0,\n            current_pos=step_1,\n        )\n        mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1)\n\n    assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4)\n    assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4)\n    assert present_0.shape == cache_shape\n    assert present_1.shape == cache_shape\n\n\ndef test_ref_router_v3_sequence_single_step_and_cached_onnx_agree():\n    actor = _make_ref_router_v3_actor()\n    module = actor.actor_module\n    module.eval()\n\n    x_seq = torch.randn(1, 2, module.full_obs_input_dim)\n    attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0)\n\n    with torch.no_grad():\n        mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask)\n\n        module.reset_kv_cache(num_envs=1, device=x_seq.device)\n        mu_step_0 = module.single_step_mu(x_seq[:, 0, :])\n        mu_step_1 = module.single_step_mu(x_seq[:, 1, :])\n        
mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1)\n\n        cache_shape = actor.onnx_past_key_values_shape(batch_size=1)\n        past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype)\n        step_0 = torch.zeros(1, dtype=torch.long)\n        step_1 = torch.ones(1, dtype=torch.long)\n        mu_onnx_0, present_0 = module.forward(\n            x_seq[:, 0, :],\n            past_key_values=past_key_values,\n            current_pos=step_0,\n        )\n        mu_onnx_1, present_1 = module.forward(\n            x_seq[:, 1, :],\n            past_key_values=present_0,\n            current_pos=step_1,\n        )\n        mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1)\n\n    assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4)\n    assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4)\n    assert present_0.shape == cache_shape\n    assert present_1.shape == cache_shape\n\n\ndef test_ref_router_seq_actor_single_step_and_sequence_logp_match_contract():\n    actor = _make_ref_router_v2_actor()\n    obs_td = _make_ref_router_v2_obs([2])\n\n    inference_out = actor(\n        obs_td,\n        mode=\"inference\",\n        update_obs_norm=False,\n    )\n    assert inference_out[\"actions\"].shape == (2, 4)\n    assert inference_out[\"mu\"].shape == (2, 4)\n    assert inference_out[\"sigma\"].shape == (2, 4)\n\n    cache_shape = actor.onnx_past_key_values_shape(batch_size=2)\n    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)\n    step_idx = torch.zeros(2, dtype=torch.long)\n    with torch.no_grad():\n        actions, present = actor(\n            obs_td,\n            past_key_values=past_key_values,\n            current_pos=step_idx,\n        )\n    assert actions.shape == (2, 4)\n    assert present.shape == cache_shape\n\n    obs_seq = _make_ref_router_v2_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n    seq_out = actor(\n        obs_seq,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert seq_out[\"mu\"].shape == (2, 3, 4)\n    assert seq_out[\"sigma\"].shape == (2, 3, 4)\n    assert seq_out[\"actions_log_prob\"].shape == (2, 3, 1)\n    assert seq_out[\"entropy\"].shape == (2, 3, 1)\n\n\ndef test_ref_router_v3_actor_single_step_and_sequence_logp_match_contract():\n    actor = _make_ref_router_v3_actor()\n    obs_td = _make_ref_router_v2_obs([2])\n\n    inference_out = actor(\n        obs_td,\n        mode=\"inference\",\n        update_obs_norm=False,\n    )\n    assert inference_out[\"actions\"].shape == (2, 4)\n    assert inference_out[\"mu\"].shape == (2, 4)\n    assert inference_out[\"sigma\"].shape == (2, 4)\n\n    cache_shape = actor.onnx_past_key_values_shape(batch_size=2)\n    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)\n    step_idx = torch.zeros(2, dtype=torch.long)\n    with torch.no_grad():\n        actions, present = actor(\n            obs_td,\n            past_key_values=past_key_values,\n            current_pos=step_idx,\n        )\n    assert actions.shape == (2, 4)\n    assert present.shape == cache_shape\n\n    obs_seq = _make_ref_router_v2_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n    seq_out = actor(\n        obs_seq,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n      
  attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert seq_out[\"mu\"].shape == (2, 3, 4)\n    assert seq_out[\"sigma\"].shape == (2, 3, 4)\n    assert seq_out[\"actions_log_prob\"].shape == (2, 3, 1)\n    assert seq_out[\"entropy\"].shape == (2, 3, 1)\n\n\ndef test_ref_router_seq_actor_sequence_logp_emits_aux_preds_without_metadata():\n    actor = _make_ref_router_v2_actor(\n        aux_state_pred={\n            \"enabled\": True,\n            \"w_base_lin_vel\": 1.0,\n            \"w_keybody_contact\": 1.0,\n            \"w_ref_keybody_rel_pos\": 1.0,\n            \"w_robot_keybody_rel_pos\": 1.0,\n            \"keybody_contact_names\": [\"knee\"],\n            \"keybody_rel_pos_names\": [\"knee\"],\n        }\n    )\n\n    obs_seq = _make_ref_router_v2_obs([2, 3])\n    actions_seq = torch.randn(2, 3, 4)\n    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(\n        2, -1, -1\n    )\n\n    seq_out = actor(\n        obs_seq,\n        actions=actions_seq,\n        mode=\"sequence_logp\",\n        attn_mask=attn_mask,\n        update_obs_norm=False,\n    )\n\n    assert \"aux_ref_keybody_rel_pos\" in seq_out.keys()\n    assert \"aux_robot_keybody_rel_pos\" in seq_out.keys()\n    assert \"aux_base_lin_vel_loc\" in seq_out.keys()\n    assert seq_out[\"aux_ref_keybody_rel_pos\"].shape == (2, 3, 1, 3)\n    assert seq_out[\"aux_robot_keybody_rel_pos\"].shape == (2, 3, 1, 3)\n\n\ndef test_ref_router_seq_actor_requires_all_shared_ref_terms():\n    obs_schema = _make_ref_router_v2_obs_schema(include_ref_cur=False)\n\n    with pytest.raises(ValueError, match=\"missing required current ref term\"):\n        _make_ref_router_v2_actor(obs_schema=obs_schema)\n\n\ndef test_ref_router_seq_actor_rejects_aux_router_command_recon():\n    with pytest.raises(ValueError, match=\"aux_router_command_recon\"):\n        _make_ref_router_v2_actor(\n            aux_router_command_recon={\"enabled\": True, \"hidden_dim\": 8}\n        )\n\n\ndef test_ref_router_seq_actor_rejects_unsupported_aux_state_pred_weights():\n    with pytest.raises(ValueError, match=\"root_height\"):\n        _make_ref_router_v2_actor(\n            aux_state_pred={\n                \"enabled\": True,\n                \"w_base_lin_vel\": 0.0,\n                \"w_keybody_contact\": 0.0,\n                \"w_ref_keybody_rel_pos\": 0.0,\n                \"w_robot_keybody_rel_pos\": 0.0,\n                \"w_root_height\": 1.0,\n                \"keybody_contact_names\": [],\n                \"keybody_rel_pos_names\": [],\n            }\n        )\n"
  },
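  {
    "path": "tests/examples/sketch_bounded_film.py",
    "content": "\"\"\"Illustrative sketch, not part of the original test suite.\n\nOne plausible construction for the near-identity, RMS-bounded FiLM\nconditioning that the test_ref_router_v2_film_* tests pin down: the final\nprojection is zero-initialized so the module starts as an exact identity,\nand a squashed per-channel gain keeps the perturbation RMS below gain_max\nno matter how large the raw parameters grow. Class and attribute names\nare hypothetical, not the repo's actual module.\n\"\"\"\n\nimport torch\nfrom torch import nn\n\n\nclass BoundedFiLM(nn.Module):\n    def __init__(self, d_model: int, gain_max: float = 0.1):\n        super().__init__()\n        self.gain_max = gain_max\n        self.film = nn.Sequential(\n            nn.Linear(d_model, d_model),\n            nn.SiLU(),\n            nn.Linear(d_model, d_model),\n        )\n        # Zero-init the last layer so conditioning starts as the identity.\n        nn.init.zeros_(self.film[-1].weight)\n        nn.init.zeros_(self.film[-1].bias)\n        self.gain_raw = nn.Parameter(torch.zeros(d_model))\n\n    def gain(self) -> torch.Tensor:\n        # Per-channel gain squashed into (0, gain_max).\n        return self.gain_max * torch.sigmoid(self.gain_raw)\n\n    def forward(\n        self, hidden: torch.Tensor, ctx: torch.Tensor\n    ) -> torch.Tensor:\n        delta = self.film(ctx)\n        rms = delta.pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(1.0e-6)\n        # RMS-normalizing before scaling bounds the perturbation RMS by\n        # gain_max, mirroring the film_perturbation_rms test.\n        return hidden + self.gain() * (delta / rms)\n\n\nif __name__ == \"__main__\":\n    film = BoundedFiLM(8)\n    hidden = torch.randn(2, 3, 8)\n    # At init the module is an exact identity map.\n    torch.testing.assert_close(film(hidden, torch.randn(2, 3, 8)), hidden)\n"
  },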
  {
    "path": "tests/test_reference_filter_export.py",
    "content": "import json\nimport sys\nimport tempfile\nimport unittest\nfrom unittest import mock\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom omegaconf import OmegaConf\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.training.h5_dataloader import MotionClipSample\nfrom holomotion.src.training.reference_filter_export import (\n    export_reference_filter_artifacts_from_config,\n    export_reference_filter_debug_artifacts,\n)\n\n\ndef _quat_xyzw_from_rpy(roll: float, pitch: float, yaw: float) -> torch.Tensor:\n    cr = np.cos(roll * 0.5)\n    sr = np.sin(roll * 0.5)\n    cp = np.cos(pitch * 0.5)\n    sp = np.sin(pitch * 0.5)\n    cy = np.cos(yaw * 0.5)\n    sy = np.sin(yaw * 0.5)\n    return torch.tensor(\n        [\n            sr * cp * cy - cr * sp * sy,\n            cr * sp * cy + sr * cp * sy,\n            cr * cp * sy - sr * sp * cy,\n            cr * cp * cy + sr * sp * sy,\n        ],\n        dtype=torch.float32,\n    )\n\n\ndef _make_sample(*, include_filtered: bool = True) -> MotionClipSample:\n    timesteps = 4\n    num_bodies = 5\n    num_dofs = 3\n\n    ref_rg_pos = torch.arange(\n        timesteps * num_bodies * 3, dtype=torch.float32\n    ).reshape(timesteps, num_bodies, 3)\n    ref_body_vel = ref_rg_pos + 100.0\n    ref_body_ang_vel = ref_rg_pos + 200.0\n    ref_dof_pos = torch.arange(\n        timesteps * num_dofs, dtype=torch.float32\n    ).reshape(timesteps, num_dofs)\n    ref_dof_vel = ref_dof_pos + 50.0\n\n    ref_rb_rot = torch.stack(\n        [\n            _quat_xyzw_from_rpy(0.0, 0.0, 0.0),\n            _quat_xyzw_from_rpy(0.1, -0.2, 0.3),\n            _quat_xyzw_from_rpy(0.2, -0.1, 0.4),\n            _quat_xyzw_from_rpy(0.3, 0.0, 0.5),\n        ],\n        dim=0,\n    )[:, None, :].repeat(1, num_bodies, 1)\n\n    tensors = {\n        \"ref_rg_pos\": ref_rg_pos,\n        \"ref_rb_rot\": ref_rb_rot,\n        \"ref_body_vel\": ref_body_vel,\n        \"ref_body_ang_vel\": ref_body_ang_vel,\n        \"ref_root_pos\": ref_rg_pos[:, 0, :],\n        \"ref_root_rot\": ref_rb_rot[:, 0, :],\n        \"ref_root_vel\": ref_body_vel[:, 0, :],\n        \"ref_root_ang_vel\": ref_body_ang_vel[:, 0, :],\n        \"ref_dof_pos\": ref_dof_pos,\n        \"ref_dof_vel\": ref_dof_vel,\n        \"filter_cutoff_hz\": torch.full(\n            (timesteps, 1), 2.0, dtype=torch.float32\n        ),\n    }\n    if include_filtered:\n        tensors.update(\n            {\n                \"ft_ref_rg_pos\": ref_rg_pos + 0.5,\n                \"ft_ref_rb_rot\": ref_rb_rot.clone(),\n                \"ft_ref_body_vel\": ref_body_vel + 0.25,\n                \"ft_ref_body_ang_vel\": ref_body_ang_vel + 0.25,\n                \"ft_ref_root_pos\": ref_rg_pos[:, 0, :] + 0.5,\n                \"ft_ref_root_rot\": ref_rb_rot[:, 0, :].clone(),\n                \"ft_ref_root_vel\": ref_body_vel[:, 0, :] + 0.25,\n                \"ft_ref_root_ang_vel\": ref_body_ang_vel[:, 0, :] + 0.25,\n                \"ft_ref_dof_pos\": ref_dof_pos + 0.75,\n                \"ft_ref_dof_vel\": ref_dof_vel + 0.75,\n            }\n        )\n\n    return MotionClipSample(\n        motion_key=\"clip-a__start_0_len_4\",\n        raw_motion_key=\"clip-a\",\n        window_index=0,\n        tensors=tensors,\n        length=timesteps,\n    )\n\n\nclass ReferenceFilterExportTests(unittest.TestCase):\n    def test_export_reference_filter_artifacts_from_config_builds_dataset(\n        self,\n    ):\n        sample = _make_sample()\n        body_names = [\n            
\"root_link\",\n            \"torso_link\",\n            \"left_wrist_yaw_link\",\n            \"right_wrist_yaw_link\",\n            \"left_ankle_roll_link\",\n        ]\n        dof_names = [\n            \"waist_yaw_joint\",\n            \"left_wrist_yaw_joint\",\n            \"left_ankle_roll_joint\",\n        ]\n\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            config = OmegaConf.create(\n                {\n                    \"robot\": {\n                        \"body_names\": body_names,\n                        \"dof_names\": dof_names,\n                        \"motion\": {\n                            \"online_filter\": {\"enabled\": True},\n                            \"max_frame_length\": 4,\n                            \"min_frame_length\": 1,\n                            \"world_frame_normalization\": True,\n                        },\n                    },\n                    \"debug_reference_filter_export\": {\n                        \"enabled\": True,\n                        \"output_dir\": tmp_dir,\n                        \"selected_body_links\": [\n                            \"left_wrist_yaw_link\",\n                            \"left_ankle_roll_link\",\n                        ],\n                    },\n                }\n            )\n\n            with mock.patch(\n                \"holomotion.src.training.reference_filter_export.\"\n                \"build_motion_datasets_from_cfg\",\n                return_value=([sample], None, {}),\n            ) as build_mock:\n                output_dir = export_reference_filter_artifacts_from_config(\n                    config\n                )\n\n            self.assertEqual(output_dir, Path(tmp_dir))\n            build_mock.assert_called_once()\n            self.assertTrue((Path(tmp_dir) / \"metadata.json\").is_file())\n\n    def test_export_reference_filter_debug_artifacts_writes_outputs(self):\n        sample = _make_sample()\n        body_names = [\n            \"root_link\",\n            \"torso_link\",\n            \"left_wrist_yaw_link\",\n            \"right_wrist_yaw_link\",\n            \"left_ankle_roll_link\",\n        ]\n        dof_names = [\n            \"waist_yaw_joint\",\n            \"left_wrist_yaw_joint\",\n            \"left_ankle_roll_joint\",\n        ]\n        selected_body_links = [\n            \"left_wrist_yaw_link\",\n            \"right_wrist_yaw_link\",\n            \"left_ankle_roll_link\",\n        ]\n\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            output_dir = Path(tmp_dir)\n            export_reference_filter_debug_artifacts(\n                sample=sample,\n                output_dir=output_dir,\n                body_names=body_names,\n                dof_names=dof_names,\n                selected_body_links=selected_body_links,\n            )\n\n            self.assertTrue((output_dir / \"metadata.json\").is_file())\n            self.assertTrue((output_dir / \"root_signals.npz\").is_file())\n            self.assertTrue((output_dir / \"bodylink_signals.npz\").is_file())\n            self.assertTrue((output_dir / \"dof_signals.npz\").is_file())\n            self.assertTrue((output_dir / \"root_comparison.png\").is_file())\n            self.assertTrue(\n                (output_dir / \"left_wrist_yaw_link_comparison.png\").is_file()\n            )\n            self.assertTrue((output_dir / \"dof_pos_comparison.png\").is_file())\n            self.assertTrue((output_dir / \"dof_vel_comparison.png\").is_file())\n\n            metadata = 
json.loads(\n                (output_dir / \"metadata.json\").read_text(encoding=\"utf-8\")\n            )\n            self.assertEqual(metadata[\"filter_cutoff_hz\"], 2.0)\n            self.assertEqual(\n                metadata[\"selected_body_links\"], selected_body_links\n            )\n            self.assertEqual(metadata[\"dof_names\"], dof_names)\n\n            root_payload = np.load(output_dir / \"root_signals.npz\")\n            self.assertIn(\"ref_global_pos\", root_payload.files)\n            self.assertIn(\"ft_ref_rpy\", root_payload.files)\n            self.assertEqual(root_payload[\"ref_global_pos\"].shape, (4, 3))\n            self.assertEqual(root_payload[\"ft_ref_rpy\"].shape, (4, 3))\n\n            dof_payload = np.load(output_dir / \"dof_signals.npz\")\n            self.assertEqual(dof_payload[\"ref_dof_pos\"].shape, (4, 3))\n            self.assertEqual(dof_payload[\"ft_ref_dof_vel\"].shape, (4, 3))\n\n    def test_export_reference_filter_debug_artifacts_requires_filtered_tensors(\n        self,\n    ):\n        sample = _make_sample(include_filtered=False)\n\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            with self.assertRaisesRegex(\n                ValueError, \"Filtered reference tensors are unavailable\"\n            ):\n                export_reference_filter_debug_artifacts(\n                    sample=sample,\n                    output_dir=Path(tmp_dir),\n                    body_names=[\"root_link\", \"left_wrist_yaw_link\"],\n                    dof_names=[\"waist_yaw_joint\"],\n                    selected_body_links=[\"left_wrist_yaw_link\"],\n                )\n"
  },
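  {
    "path": "tests/examples/sketch_rpy_from_quat_xyzw.py",
    "content": "\"\"\"Illustrative sketch, not part of the original test suite.\n\nInverse of the _quat_xyzw_from_rpy helper used in these tests: recovers\nroll/pitch/yaw from an xyzw quaternion, the representation behind the\n*_rpy arrays the export writes into root_signals.npz. A hypothetical\nreference helper, not the repo's implementation.\n\"\"\"\n\nimport numpy as np\nimport torch\n\n\ndef rpy_from_quat_xyzw(q: torch.Tensor) -> torch.Tensor:\n    # q: [..., 4] as (x, y, z, w), assumed normalized; standard ZYX\n    # (yaw-pitch-roll) extraction.\n    x, y, z, w = q.unbind(-1)\n    roll = torch.atan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x * x + y * y))\n    pitch = torch.asin((2.0 * (w * y - z * x)).clamp(-1.0, 1.0))\n    yaw = torch.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))\n    return torch.stack([roll, pitch, yaw], dim=-1)\n\n\nif __name__ == \"__main__\":\n    # Round-trip check against the construction in _quat_xyzw_from_rpy.\n    roll, pitch, yaw = 0.1, -0.2, 0.3\n    cr, sr = np.cos(roll * 0.5), np.sin(roll * 0.5)\n    cp, sp = np.cos(pitch * 0.5), np.sin(pitch * 0.5)\n    cy, sy = np.cos(yaw * 0.5), np.sin(yaw * 0.5)\n    q = torch.tensor(\n        [\n            sr * cp * cy - cr * sp * sy,\n            cr * sp * cy + sr * cp * sy,\n            cr * cp * sy - sr * sp * cy,\n            cr * cp * cy + sr * sp * sy,\n        ],\n        dtype=torch.float32,\n    )\n    torch.testing.assert_close(\n        rpy_from_quat_xyzw(q),\n        torch.tensor([roll, pitch, yaw]),\n        atol=1.0e-5,\n        rtol=1.0e-5,\n    )\n"
  },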
  {
    "path": "tests/test_reference_motion_config_wiring.py",
    "content": "import sys\nimport unittest\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\n\nPROJECT_ROOT = Path(__file__).resolve().parents[1]\n\nOBS_CONFIG_PATHS = [\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motion_tracking.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_tf_more_info.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_mlp_20260210.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_tf_20260210.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/observations/motion_tracking/obs_motrack_teacher.yaml\",\n]\n\nTERMINATION_CONFIG_PATHS = [\n    PROJECT_ROOT\n    / \"holomotion/config/env/terminations/termination_motion_tracking.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/terminations/termination_motion_tracking_simple.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/terminations/termination_motrack_with_kpe.yaml\",\n    PROJECT_ROOT\n    / \"holomotion/config/env/terminations/termination_motrack_with_kpe_jpe.yaml\",\n]\n\nROBOT_TRAINING_CONFIG_PATH = (\n    PROJECT_ROOT\n    / \"holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml\"\n)\nREWARD_CONFIG_PATH = (\n    PROJECT_ROOT\n    / \"holomotion/config/env/rewards/motion_tracking/rew_motrack_robust.yaml\"\n)\n\n\nclass ReferenceMotionConfigWiringTests(unittest.TestCase):\n    def test_motion_tracking_observation_configs_expose_cutoff_term(self):\n        for config_path in OBS_CONFIG_PATHS:\n            with self.subTest(config_path=str(config_path)):\n                config = OmegaConf.load(config_path)\n                self.assertTrue(\n                    self._config_has_obs_term(\n                        config,\n                        term_name=\"ref_motion_filter_cutoff_hz\",\n                    )\n                )\n\n    def test_motion_tracking_termination_configs_forward_ref_prefix(self):\n        for config_path in TERMINATION_CONFIG_PATHS:\n            with self.subTest(config_path=str(config_path)):\n                config = OmegaConf.load(config_path)\n                for term_name, term_cfg in config.terminations.items():\n                    if term_name == \"time_out\":\n                        continue\n                    self.assertIn(\"params\", term_cfg)\n                    self.assertIn(\"ref_prefix\", term_cfg.params)\n\n    def test_robot_training_config_uses_hdf5_v2_backend(self):\n        config = OmegaConf.load(ROBOT_TRAINING_CONFIG_PATH)\n        self.assertEqual(config.robot.motion.backend, \"hdf5_v2\")\n\n    def test_termination_ref_prefix_resolves_with_reward_config(self):\n        reward_config = OmegaConf.load(REWARD_CONFIG_PATH)\n        for config_path in TERMINATION_CONFIG_PATHS:\n            with self.subTest(config_path=str(config_path)):\n                termination_config = OmegaConf.load(config_path)\n                merged = OmegaConf.merge(reward_config, termination_config)\n                for term_name, term_cfg in merged.terminations.items():\n                    if term_name == \"time_out\":\n                        continue\n                    resolved_ref_prefix = OmegaConf.select(\n                        merged,\n                        
f\"terminations.{term_name}.params.ref_prefix\",\n                    )\n                    self.assertIsInstance(resolved_ref_prefix, str)\n                    self.assertTrue(resolved_ref_prefix.endswith(\"ref_\"))\n\n    @staticmethod\n    def _config_has_obs_term(config, term_name: str) -> bool:\n        for group_cfg in config.obs.obs_groups.values():\n            for term_cfg in group_cfg.atomic_obs_list:\n                for obs_name, obs_value in term_cfg.items():\n                    if obs_name == term_name:\n                        return True\n                    if obs_value.get(\"func\") == term_name:\n                        return True\n        return False\n"
  },
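  {
    "path": "tests/README.md",
    "content": "# Tests\n\nUnit tests for holomotion configs, reward terms, actuators, and\nretargeting utilities. Some files (e.g.\n`test_reference_motion_config_wiring.py`) are `unittest` suites; the\nothers are plain pytest test functions, so the whole directory can be run\nthrough pytest (assuming pytest is available in your environment):\n\n```bash\npython -m pytest tests/\n```\n\nThe reward, motion-tracking, actuator, and scene tests do not require an\nIsaac Lab installation: they stub the `isaaclab` package with\n`types.ModuleType` fakes via pytest's `monkeypatch` fixture, then load the\nmodule under test straight from its file with `importlib`. Below is a\nminimal sketch of that pattern; the helper name and the empty fake module\nare illustrative only (the real tests attach every class and function the\nmodule under test imports):\n\n```python\nimport importlib.util\nimport sys\nfrom types import ModuleType\n\n\ndef load_with_fake_isaaclab(monkeypatch, module_path):\n    # Stand-in for the real isaaclab package.\n    fake_isaaclab = ModuleType(\"isaaclab\")\n    monkeypatch.setitem(sys.modules, \"isaaclab\", fake_isaaclab)\n\n    spec = importlib.util.spec_from_file_location(\"_under_test\", module_path)\n    assert spec is not None\n    assert spec.loader is not None\n    module = importlib.util.module_from_spec(spec)\n    sys.modules[\"_under_test\"] = module\n    spec.loader.exec_module(module)\n    return module\n```\n\nBecause `monkeypatch` restores `sys.modules` after each test, the fakes do\nnot leak between tests.\n"
  },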
  {
    "path": "tests/test_root_rel_rewards.py",
    "content": "import importlib.util\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType, SimpleNamespace\n\nimport torch\n\nREWARDS_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_rewards.py\"\n)\nMOTION_TRACKING_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"motion_tracking.py\"\n)\n\n\nclass _DummyConfig:\n    def __init__(self, *args, **kwargs):\n        self.args = args\n        self.kwargs = kwargs\n        if args:\n            self.name = args[0]\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n        if not hasattr(self, \"params\"):\n            self.params = {}\n\n\nclass _DummyManagerTermBase:\n    def __init__(self, cfg, env):\n        self.cfg = cfg\n        self._env = env\n\n    @property\n    def num_envs(self):\n        return self._env.num_envs\n\n    @property\n    def device(self):\n        return self._env.device\n\n    def reset(self, env_ids=None):\n        pass\n\n\ndef _identity_quat(*shape: int) -> torch.Tensor:\n    quat = torch.zeros(*shape, 4, dtype=torch.float32)\n    quat[..., 0] = 1.0\n    return quat\n\n\ndef _load_rewards_module(monkeypatch):\n    isaaclab_root = ModuleType(\"isaaclab\")\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedRLEnv = object\n    isaaclab_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None)\n    isaaclab_managers = ModuleType(\"isaaclab.managers\")\n    isaaclab_managers.ManagerTermBase = _DummyManagerTermBase\n    isaaclab_managers.RewardTermCfg = _DummyConfig\n    isaaclab_managers.SceneEntityCfg = _DummyConfig\n    isaaclab_sensors = ModuleType(\"isaaclab.sensors\")\n    isaaclab_sensors.ContactSensor = object\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n    isaaclab_math = ModuleType(\"isaaclab.utils.math\")\n    isaaclab_math.quat_apply = lambda quat, vec: vec\n    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec\n    isaaclab_math.quat_inv = lambda quat: quat\n    isaaclab_math.quat_mul = lambda lhs, rhs: lhs\n    isaaclab_math.yaw_quat = lambda quat: quat\n    isaaclab_math.quat_error_magnitude = lambda lhs, rhs: torch.linalg.norm(\n        lhs - rhs, dim=-1\n    )\n    isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None)\n\n    hydra_utils = ModuleType(\"hydra.utils\")\n    hydra_utils.instantiate = lambda value, *args, **kwargs: value\n\n    omegaconf = ModuleType(\"omegaconf\")\n    omegaconf.DictConfig = dict\n    omegaconf.ListConfig = list\n    omegaconf.OmegaConf = SimpleNamespace(\n        to_container=lambda value, resolve=True: value\n    )\n\n    loguru = ModuleType(\"loguru\")\n    loguru.logger = SimpleNamespace(\n        info=lambda *args, **kwargs: None,\n        warning=lambda *args, **kwargs: None,\n    )\n\n    fake_command_module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.\"\n        \"isaaclab_motion_tracking_command\"\n    )\n    fake_command_module.RefMotionCommand = object\n\n    fake_utils_module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n    )\n    fake_utils_module._get_body_indices = lambda robot, keybody_names: [\n        
robot.body_names.index(name) for name in keybody_names\n    ]\n    fake_utils_module._get_dof_indices = lambda robot, key_dofs: []\n    fake_utils_module.resolve_holo_config = lambda value: value\n\n    for name, module in {\n        \"isaaclab\": isaaclab_root,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.envs.mdp\": isaaclab_mdp,\n        \"isaaclab.managers\": isaaclab_managers,\n        \"isaaclab.sensors\": isaaclab_sensors,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.math\": isaaclab_math,\n        \"hydra.utils\": hydra_utils,\n        \"omegaconf\": omegaconf,\n        \"loguru\": loguru,\n        (\n            \"holomotion.src.env.isaaclab_components.\"\n            \"isaaclab_motion_tracking_command\"\n        ): fake_command_module,\n        (\n            \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n        ): fake_utils_module,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    isaaclab_root.assets = isaaclab_assets\n    isaaclab_root.envs = isaaclab_envs\n    isaaclab_root.managers = isaaclab_managers\n    isaaclab_root.sensors = isaaclab_sensors\n    isaaclab_root.utils = isaaclab_utils\n    isaaclab_envs.mdp = isaaclab_mdp\n    isaaclab_utils.math = isaaclab_math\n\n    module_name = \"_test_root_rel_rewards\"\n    spec = importlib.util.spec_from_file_location(module_name, REWARDS_PATH)\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef _load_motion_tracking_module(monkeypatch):\n    class _DummyConfigClass:\n        def __init__(self, *args, **kwargs):\n            self.args = args\n            self.kwargs = kwargs\n\n    isaaclab_root = ModuleType(\"isaaclab\")\n    isaaclab_actuators = ModuleType(\"isaaclab.actuators\")\n    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfigClass\n\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.Articulation = object\n\n    isaaclab_envs = ModuleType(\"isaaclab.envs\")\n    isaaclab_envs.ManagerBasedEnv = object\n    isaaclab_envs.ManagerBasedRLEnv = object\n    isaaclab_envs.ManagerBasedRLEnvCfg = object\n    isaaclab_envs.ViewerCfg = _DummyConfigClass\n\n    isaaclab_envs_mdp = ModuleType(\"isaaclab.envs.mdp\")\n    isaaclab_envs_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None)\n\n    isaaclab_envs_mdp_events = ModuleType(\"isaaclab.envs.mdp.events\")\n    isaaclab_envs_mdp_events._randomize_prop_by_op = (\n        lambda *args, **kwargs: None\n    )\n\n    isaaclab_managers = ModuleType(\"isaaclab.managers\")\n    isaaclab_managers.EventTermCfg = _DummyConfigClass\n    isaaclab_managers.SceneEntityCfg = _DummyConfig\n\n    isaaclab_sim = ModuleType(\"isaaclab.sim\")\n    isaaclab_sim.PhysxCfg = _DummyConfigClass\n    isaaclab_sim.SimulationCfg = _DummyConfigClass\n\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = lambda cls: cls\n\n    isaaclab_utils_io = ModuleType(\"isaaclab.utils.io\")\n    isaaclab_utils_io.dump_yaml = lambda *args, **kwargs: None\n\n    isaaclab_utils_math = ModuleType(\"isaaclab.utils.math\")\n    isaaclab_utils_math.__getattr__ = lambda name: (\n        lambda *args, **kwargs: None\n    )\n\n    easydict = ModuleType(\"easydict\")\n    easydict.EasyDict = lambda value=None: value if value is not None else {}\n\n    omegaconf = 
ModuleType(\"omegaconf\")\n    omegaconf.OmegaConf = SimpleNamespace(\n        to_container=lambda value, resolve=True: value\n    )\n\n    loguru = ModuleType(\"loguru\")\n    loguru.logger = SimpleNamespace(\n        info=lambda *args, **kwargs: None,\n        warning=lambda *args, **kwargs: None,\n    )\n\n    isaaclab_components = ModuleType(\"holomotion.src.env.isaaclab_components\")\n    for name in [\n        \"ActionsCfg\",\n        \"VelTrack_CommandsCfg\",\n        \"MoTrack_CommandsCfg\",\n        \"EventsCfg\",\n        \"MotionTrackingSceneCfg\",\n        \"ObservationsCfg\",\n        \"RewardsCfg\",\n        \"TerminationsCfg\",\n        \"CurriculumCfg\",\n    ]:\n        setattr(isaaclab_components, name, _DummyConfigClass)\n    for name in [\n        \"build_actions_config\",\n        \"build_motion_tracking_commands_config\",\n        \"build_velocity_commands_config\",\n        \"build_domain_rand_config\",\n        \"build_curriculum_config\",\n        \"build_observations_config\",\n        \"build_rewards_config\",\n        \"build_scene_config\",\n        \"build_terminations_config\",\n    ]:\n        setattr(isaaclab_components, name, lambda *args, **kwargs: None)\n\n    fake_observation_module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_observation\"\n    )\n    fake_observation_module.ObservationFunctions = object\n\n    fake_utils_module = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n    )\n    fake_utils_module.resolve_holo_config = lambda value: value\n\n    for name, module in {\n        \"isaaclab\": isaaclab_root,\n        \"isaaclab.actuators\": isaaclab_actuators,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.envs\": isaaclab_envs,\n        \"isaaclab.envs.mdp\": isaaclab_envs_mdp,\n        \"isaaclab.envs.mdp.events\": isaaclab_envs_mdp_events,\n        \"isaaclab.managers\": isaaclab_managers,\n        \"isaaclab.sim\": isaaclab_sim,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.io\": isaaclab_utils_io,\n        \"isaaclab.utils.math\": isaaclab_utils_math,\n        \"easydict\": easydict,\n        \"omegaconf\": omegaconf,\n        \"loguru\": loguru,\n        \"holomotion.src.env.isaaclab_components\": isaaclab_components,\n        (\n            \"holomotion.src.env.isaaclab_components.isaaclab_observation\"\n        ): fake_observation_module,\n        (\n            \"holomotion.src.env.isaaclab_components.isaaclab_utils\"\n        ): fake_utils_module,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    isaaclab_root.actuators = isaaclab_actuators\n    isaaclab_root.assets = isaaclab_assets\n    isaaclab_root.envs = isaaclab_envs\n    isaaclab_root.managers = isaaclab_managers\n    isaaclab_root.sim = isaaclab_sim\n    isaaclab_root.utils = isaaclab_utils\n    isaaclab_envs.mdp = isaaclab_envs_mdp\n    isaaclab_utils.io = isaaclab_utils_io\n    isaaclab_utils.math = isaaclab_utils_math\n\n    module_name = \"_test_motion_tracking\"\n    spec = importlib.util.spec_from_file_location(\n        module_name, MOTION_TRACKING_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef _make_env():\n    env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32)\n    robot_data = SimpleNamespace(\n        body_pos_w=torch.tensor(\n            
[[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32\n        ),\n        body_quat_w=_identity_quat(1, 2),\n        body_lin_vel_w=torch.tensor(\n            [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32\n        ),\n        body_ang_vel_w=torch.tensor(\n            [[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]], dtype=torch.float32\n        ),\n    )\n    robot = SimpleNamespace(body_names=[\"anchor\", \"target\"], data=robot_data)\n    command = SimpleNamespace(\n        robot=robot,\n        anchor_bodylink_idx=0,\n        get_ref_motion_root_global_pos_cur=lambda prefix=\"ref_\": torch.tensor(\n            [[10.0, 0.0, 0.0]], dtype=torch.float32\n        ),\n        get_ref_motion_root_global_pos_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[10.0, 0.0, 0.0]], dtype=torch.float32\n            )\n        ),\n        get_ref_motion_root_global_rot_quat_wxyz_cur=(\n            lambda prefix=\"ref_\": _identity_quat(1)\n        ),\n        get_ref_motion_root_global_rot_quat_wxyz_immediate_next=(\n            lambda prefix=\"ref_\": _identity_quat(1)\n        ),\n        get_ref_motion_root_global_lin_vel_cur=(\n            lambda prefix=\"ref_\": torch.zeros(1, 3, dtype=torch.float32)\n        ),\n        get_ref_motion_root_global_lin_vel_immediate_next=(\n            lambda prefix=\"ref_\": torch.zeros(1, 3, dtype=torch.float32)\n        ),\n        get_ref_motion_root_global_ang_vel_cur=(\n            lambda prefix=\"ref_\": torch.tensor([[0.0, 0.0, 1.0]])\n        ),\n        get_ref_motion_root_global_ang_vel_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor([[0.0, 0.0, 1.0]])\n        ),\n        get_ref_motion_bodylink_global_pos_cur=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32\n            )\n        ),\n        get_ref_motion_bodylink_global_pos_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32\n            )\n        ),\n        get_ref_motion_bodylink_global_lin_vel_cur=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32\n            )\n        ),\n        get_ref_motion_bodylink_global_lin_vel_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32\n            )\n        ),\n        get_ref_motion_bodylink_global_rot_wxyz_immediate_next=(\n            lambda prefix=\"ref_\": _identity_quat(1, 2)\n        ),\n        get_ref_motion_bodylink_global_ang_vel_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]], dtype=torch.float32\n            )\n        ),\n    )\n    return SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command),\n        scene=SimpleNamespace(env_origins=env_origins),\n    )\n\n\ndef _make_torque_rate_env(\n    applied_torque: torch.Tensor,\n    actuators: dict,\n    joint_vel: torch.Tensor | None = None,\n    joint_vel_limits: torch.Tensor | None = None,\n):\n    class _Scene(dict):\n        pass\n\n    if joint_vel is None:\n        joint_vel = torch.zeros_like(applied_torque)\n    if joint_vel_limits is None:\n        joint_vel_limits = torch.ones_like(applied_torque)\n\n    asset = SimpleNamespace(\n        data=SimpleNamespace(\n            
applied_torque=applied_torque.clone(),\n            joint_vel=joint_vel.clone(),\n            joint_vel_limits=joint_vel_limits.clone(),\n        ),\n        actuators=actuators,\n    )\n    scene = _Scene(robot=asset)\n    return SimpleNamespace(\n        scene=scene,\n        num_envs=applied_torque.shape[0],\n        device=applied_torque.device,\n        episode_length_buf=torch.zeros(\n            applied_torque.shape[0],\n            dtype=torch.long,\n            device=applied_torque.device,\n        ),\n    )\n\n\ndef _make_action_acc_env(action: torch.Tensor):\n    return SimpleNamespace(\n        action_manager=SimpleNamespace(action=action.clone()),\n        num_envs=action.shape[0],\n        device=action.device,\n        episode_length_buf=torch.zeros(\n            action.shape[0], dtype=torch.long, device=action.device\n        ),\n    )\n\n\ndef test_root_rel_keybody_pos_reward_uses_true_root_frame(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_env()\n    rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros(\n        1, 3, dtype=torch.float32\n    )\n    rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1)\n\n    reward = rewards.root_rel_keybodylink_pos_tracking_l2_exp(\n        env,\n        std=1.0,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.allclose(reward, torch.ones(1))\n\n\ndef test_root_rel_keybody_pos_bydmmc_reward_uses_true_root_frame(\n    monkeypatch,\n):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_env()\n    rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros(\n        1, 3, dtype=torch.float32\n    )\n    rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1)\n\n    reward = rewards.root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style(\n        env,\n        std=1.0,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.allclose(reward, torch.ones(1))\n\n\ndef test_root_rel_keybody_lin_vel_reward_uses_true_root_frame(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_env()\n    rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros(\n        1, 3, dtype=torch.float32\n    )\n    rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1)\n    rewards.isaaclab_mdp.root_lin_vel_w = lambda _env: torch.zeros(\n        1, 3, dtype=torch.float32\n    )\n    rewards.isaaclab_mdp.root_ang_vel_w = lambda _env: torch.tensor(\n        [[0.0, 0.0, 1.0]], dtype=torch.float32\n    )\n\n    reward = rewards.root_rel_keybodylink_lin_vel_tracking_l2_exp(\n        env,\n        std=1.0,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.allclose(reward, torch.ones(1))\n\n\ndef test_root_pos_xy_tracking_uses_immediate_next_reference(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    robot_data = SimpleNamespace(\n        root_pos_w=torch.tensor([[1.0, 2.0, 0.0]], dtype=torch.float32)\n    )\n    robot = SimpleNamespace(data=robot_data)\n    command = SimpleNamespace(\n        robot=robot,\n        get_ref_motion_root_global_pos_cur=(\n            lambda prefix=\"ref_\": (_ for _ in ()).throw(\n                AssertionError(\"current reference should not be used\")\n            )\n        ),\n        get_ref_motion_root_global_pos_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[1.0, 2.0, 3.0]], dtype=torch.float32\n            )\n        ),\n    )\n    env = SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command)\n    
)\n\n    reward = rewards.root_pos_xy_tracking_exp(env, std=1.0)\n\n    assert torch.allclose(reward, torch.ones(1))\n\n\ndef test_global_keybody_lin_vel_tracking_uses_immediate_next_reference(\n    monkeypatch,\n):\n    rewards = _load_rewards_module(monkeypatch)\n    robot_data = SimpleNamespace(\n        body_lin_vel_w=torch.tensor(\n            [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]], dtype=torch.float32\n        )\n    )\n    robot = SimpleNamespace(\n        body_names=[\"anchor\", \"target\"],\n        data=robot_data,\n    )\n    command = SimpleNamespace(\n        robot=robot,\n        get_ref_motion_bodylink_global_lin_vel_cur=(\n            lambda prefix=\"ref_\": (_ for _ in ()).throw(\n                AssertionError(\"current reference should not be used\")\n            )\n        ),\n        get_ref_motion_bodylink_global_lin_vel_immediate_next=(\n            lambda prefix=\"ref_\": torch.tensor(\n                [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]], dtype=torch.float32\n            )\n        ),\n    )\n    env = SimpleNamespace(\n        command_manager=SimpleNamespace(get_term=lambda name: command)\n    )\n\n    reward = rewards.global_keybodylink_lin_vel_tracking_l2_exp(\n        env,\n        std=1.0,\n        keybody_names=[\"target\"],\n    )\n\n    assert torch.allclose(reward, torch.ones(1))\n\n\ndef test_normed_torque_rate_matches_selected_joint_math(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.zeros(2, 3, dtype=torch.float32),\n        actuators={\n            \"all_joints\": SimpleNamespace(\n                joint_indices=slice(None),\n                effort_limit=torch.tensor(\n                    [[10.0, 20.0, 40.0], [10.0, 20.0, 40.0]],\n                    dtype=torch.float32,\n                ),\n            )\n        },\n    )\n    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 2], dtype=torch.long)\n    )\n\n    first = term(env, asset_cfg=asset_cfg)\n    assert torch.allclose(first, torch.zeros(2))\n\n    env.episode_length_buf[:] = 1\n    env.scene[\"robot\"].data.applied_torque = torch.tensor(\n        [[1.0, 9.0, 4.0], [2.0, 7.0, 8.0]],\n        dtype=torch.float32,\n    )\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [\n            (1.0 / 10.0) ** 2 + (4.0 / 40.0) ** 2,\n            (2.0 / 10.0) ** 2 + (8.0 / 40.0) ** 2,\n        ],\n        dtype=torch.float32,\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_torque_rate_assembles_limits_across_actuator_groups(\n    monkeypatch,\n):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.zeros(1, 3, dtype=torch.float32),\n        actuators={\n            \"implicit_group\": SimpleNamespace(\n                joint_indices=torch.tensor([0, 2], dtype=torch.long),\n                effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),\n            ),\n            \"unitree_group\": SimpleNamespace(\n                joint_indices=torch.tensor([1], dtype=torch.long),\n                effort_limit=torch.tensor([[5.0]], dtype=torch.float32),\n            ),\n        },\n    )\n    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long)\n    )\n\n    _ = term(env, asset_cfg=asset_cfg)\n 
   env.episode_length_buf[:] = 1\n    env.scene[\"robot\"].data.applied_torque = torch.tensor(\n        [[1.0, 1.0, 2.0]], dtype=torch.float32\n    )\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [(1.0 / 10.0) ** 2 + (1.0 / 5.0) ** 2 + (2.0 / 20.0) ** 2],\n        dtype=torch.float32,\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_torque_rate_resets_first_step_history(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.tensor(\n            [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32\n        ),\n        actuators={\n            \"all_joints\": SimpleNamespace(\n                joint_indices=slice(None),\n                effort_limit=torch.tensor(\n                    [[10.0, 10.0], [10.0, 10.0]], dtype=torch.float32\n                ),\n            )\n        },\n    )\n    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 1], dtype=torch.long)\n    )\n\n    first = term(env, asset_cfg=asset_cfg)\n    assert torch.allclose(first, torch.zeros(2))\n\n    env.episode_length_buf[:] = 1\n    env.scene[\"robot\"].data.applied_torque = torch.tensor(\n        [[2.0, 4.0], [5.0, 8.0]], dtype=torch.float32\n    )\n    second = term(env, asset_cfg=asset_cfg)\n    assert torch.all(second > 0.0)\n\n    term.reset(env_ids=[0])\n    env.scene[\"robot\"].data.applied_torque = torch.tensor(\n        [[7.0, 9.0], [6.0, 10.0]], dtype=torch.float32\n    )\n    after_reset = term(env, asset_cfg=asset_cfg)\n\n    assert torch.isclose(after_reset[0], torch.tensor(0.0))\n    assert after_reset[1] > 0.0\n\n\ndef test_normed_torque_rate_reuses_cached_normalization(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    actuator = SimpleNamespace(\n        joint_indices=slice(None),\n        effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),\n    )\n    env = _make_torque_rate_env(\n        applied_torque=torch.zeros(1, 2, dtype=torch.float32),\n        actuators={\"all_joints\": actuator},\n    )\n    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 1], dtype=torch.long)\n    )\n\n    _ = term(env, asset_cfg=asset_cfg)\n\n    actuator.effort_limit = torch.tensor(\n        [[1000.0, 1000.0]], dtype=torch.float32\n    )\n    env.episode_length_buf[:] = 1\n    env.scene[\"robot\"].data.applied_torque = torch.tensor(\n        [[2.0, 4.0]], dtype=torch.float32\n    )\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [(2.0 / 10.0) ** 2 + (4.0 / 20.0) ** 2], dtype=torch.float32\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_positive_work_matches_selected_joint_math(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.tensor(\n            [[2.0, 5.0, 8.0], [3.0, -4.0, 6.0]], dtype=torch.float32\n        ),\n        joint_vel=torch.tensor(\n            [[1.0, -5.0, 2.0], [2.0, 3.0, -2.0]], dtype=torch.float32\n        ),\n        joint_vel_limits=torch.tensor(\n            [[4.0, 10.0, 8.0], [4.0, 10.0, 8.0]], dtype=torch.float32\n        ),\n        actuators={\n            \"all_joints\": SimpleNamespace(\n                joint_indices=slice(None),\n                effort_limit=torch.tensor(\n                    [[4.0, 10.0, 
16.0], [4.0, 10.0, 16.0]],\n                    dtype=torch.float32,\n                ),\n            )\n        },\n    )\n    term = rewards.normed_positive_work(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 2], dtype=torch.long)\n    )\n\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [\n            (2.0 / 4.0) * (1.0 / 4.0) + (8.0 / 16.0) * (2.0 / 8.0),\n            (3.0 / 4.0) * (2.0 / 4.0),\n        ],\n        dtype=torch.float32,\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_positive_work_assembles_effort_limits_across_actuators(\n    monkeypatch,\n):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float32),\n        joint_vel=torch.tensor([[5.0, 2.0, -1.0]], dtype=torch.float32),\n        joint_vel_limits=torch.tensor([[10.0, 4.0, 8.0]], dtype=torch.float32),\n        actuators={\n            \"implicit_group\": SimpleNamespace(\n                joint_indices=torch.tensor([0, 2], dtype=torch.long),\n                effort_limit=torch.tensor([[4.0, 20.0]], dtype=torch.float32),\n            ),\n            \"unitree_group\": SimpleNamespace(\n                joint_indices=torch.tensor([1], dtype=torch.long),\n                effort_limit=torch.tensor([[6.0]], dtype=torch.float32),\n            ),\n        },\n    )\n    term = rewards.normed_positive_work(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long)\n    )\n\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [\n            (2.0 / 4.0) * (5.0 / 10.0) + (3.0 / 6.0) * (2.0 / 4.0),\n        ],\n        dtype=torch.float32,\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_positive_work_reuses_cached_effort_limits(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    actuator = SimpleNamespace(\n        joint_indices=slice(None),\n        effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),\n    )\n    env = _make_torque_rate_env(\n        applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32),\n        joint_vel=torch.tensor([[5.0, 10.0]], dtype=torch.float32),\n        joint_vel_limits=torch.tensor([[10.0, 20.0]], dtype=torch.float32),\n        actuators={\"all_joints\": actuator},\n    )\n    term = rewards.normed_positive_work(_DummyConfig(params={}), env)\n    asset_cfg = SimpleNamespace(\n        name=\"robot\", joint_ids=torch.tensor([0, 1], dtype=torch.long)\n    )\n\n    first = term(env, asset_cfg=asset_cfg)\n    assert torch.allclose(\n        first,\n        torch.tensor(\n            [(2.0 / 10.0) * (5.0 / 10.0) + (4.0 / 20.0) * (10.0 / 20.0)],\n            dtype=torch.float32,\n        ),\n    )\n\n    actuator.effort_limit = torch.tensor(\n        [[1000.0, 1000.0]], dtype=torch.float32\n    )\n    reward = term(env, asset_cfg=asset_cfg)\n\n    expected = torch.tensor(\n        [(2.0 / 10.0) * (5.0 / 10.0) + (4.0 / 20.0) * (10.0 / 20.0)],\n        dtype=torch.float32,\n    )\n    assert torch.allclose(reward, expected)\n\n\ndef test_normed_positive_work_requires_positive_finite_velocity_limits(\n    monkeypatch,\n):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_torque_rate_env(\n        applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32),\n        
joint_vel=torch.tensor([[1.0, 1.0]], dtype=torch.float32),\n        joint_vel_limits=torch.tensor([[0.0, torch.inf]], dtype=torch.float32),\n        actuators={\n            \"all_joints\": SimpleNamespace(\n                joint_indices=slice(None),\n                effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),\n            )\n        },\n    )\n    term = rewards.normed_positive_work(_DummyConfig(params={}), env)\n\n    try:\n        term(\n            env,\n            asset_cfg=SimpleNamespace(\n                name=\"robot\", joint_ids=torch.tensor([0, 1], dtype=torch.long)\n            ),\n        )\n    except ValueError as exc:\n        assert (\n            \"normed_positive_work requires finite, strictly positive\"\n            in str(exc)\n        )\n    else:\n        raise AssertionError(\n            \"expected normed_positive_work to reject invalid limits\"\n        )\n\n\ndef test_action_acc_matches_second_order_action_change(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_action_acc_env(torch.zeros(2, 2, dtype=torch.float32))\n    term = rewards.action_acc(_DummyConfig(params={}), env)\n\n    first = term(env)\n    assert torch.allclose(first, torch.zeros(2))\n\n    env.episode_length_buf[:] = 1\n    env.action_manager.action = torch.tensor(\n        [[1.0, 2.0], [2.0, 1.0]], dtype=torch.float32\n    )\n    second = term(env)\n    assert torch.allclose(second, torch.zeros(2))\n\n    env.episode_length_buf[:] = 2\n    env.action_manager.action = torch.tensor(\n        [[3.0, 1.0], [5.0, 1.0]], dtype=torch.float32\n    )\n    third = term(env)\n\n    expected = torch.tensor([10.0, 2.0], dtype=torch.float32)\n    assert torch.allclose(third, expected)\n\n\ndef test_action_acc_reset_clears_history(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n    env = _make_action_acc_env(torch.zeros(1, 2, dtype=torch.float32))\n    term = rewards.action_acc(_DummyConfig(params={}), env)\n\n    assert torch.allclose(term(env), torch.zeros(1))\n\n    env.episode_length_buf[:] = 1\n    env.action_manager.action = torch.tensor([[1.0, 1.0]], dtype=torch.float32)\n    assert torch.allclose(term(env), torch.zeros(1))\n\n    env.episode_length_buf[:] = 2\n    env.action_manager.action = torch.tensor([[2.0, 0.0]], dtype=torch.float32)\n    assert torch.allclose(term(env), torch.tensor([4.0]))\n\n    term.reset(env_ids=[0])\n    env.action_manager.action = torch.tensor([[7.0, 7.0]], dtype=torch.float32)\n\n    assert torch.allclose(term(env), torch.zeros(1))\n\n\ndef test_build_rewards_config_exposes_action_acc_term(monkeypatch):\n    rewards = _load_rewards_module(monkeypatch)\n\n    rewards_cfg = rewards.build_rewards_config(\n        {\n            \"action_acc\": {\n                \"weight\": -2.5,\n                \"params\": {},\n            }\n        }\n    )\n\n    assert rewards_cfg.action_acc.func is rewards.action_acc\n    assert rewards_cfg.action_acc.weight == -2.5\n    assert rewards_cfg.action_acc.params == {}\n\n\ndef test_motion_tracking_logs_normed_torque_rate_metric(monkeypatch):\n    motion_tracking = _load_motion_tracking_module(monkeypatch)\n    env = motion_tracking.MotionTrackingEnv.__new__(\n        motion_tracking.MotionTrackingEnv\n    )\n    env.metrics = {}\n    env._robot_prev_joint_vel = None\n    env._robot_prev_applied_torque = None\n    env._robot_torque_rate_inv_effort_limit = None\n    env._robot_torque_rate_needs_reseed = None\n\n    robot = SimpleNamespace(\n        data=SimpleNamespace(\n       
     joint_vel=torch.zeros(2, 2, dtype=torch.float32),\n            applied_torque=torch.zeros(2, 2, dtype=torch.float32),\n        ),\n        actuators={\n            \"all_joints\": SimpleNamespace(\n                joint_indices=slice(None),\n                effort_limit=torch.tensor(\n                    [[10.0, 20.0], [10.0, 20.0]], dtype=torch.float32\n                ),\n            )\n        },\n    )\n    env._env = SimpleNamespace(\n        step_dt=0.5,\n        action_manager=SimpleNamespace(\n            action=torch.zeros(2, 2, dtype=torch.float32),\n            prev_action=torch.zeros(2, 2, dtype=torch.float32),\n        ),\n        scene={\"robot\": robot},\n        episode_length_buf=torch.zeros(2, dtype=torch.long),\n        num_envs=2,\n        device=torch.device(\"cpu\"),\n    )\n\n    infos = {\"log\": {}}\n    env._update_robot_metrics(infos)\n    assert torch.isclose(\n        infos[\"log\"][\"Metrics/Robot/Normed_Torque_Rate\"],\n        torch.tensor(0.0),\n    )\n\n    env._env.episode_length_buf[:] = 1\n    robot.data.applied_torque = torch.tensor(\n        [[1.0, 4.0], [2.0, 8.0]], dtype=torch.float32\n    )\n    env._update_robot_metrics(infos)\n\n    expected = torch.tensor(\n        [\n            (1.0 / 10.0) ** 2 + (4.0 / 20.0) ** 2,\n            (2.0 / 10.0) ** 2 + (8.0 / 20.0) ** 2,\n        ],\n        dtype=torch.float32,\n    ).mean()\n    assert torch.allclose(\n        infos[\"log\"][\"Metrics/Robot/Normed_Torque_Rate\"], expected\n    )\n"
  },
  {
    "path": "tests/test_unitree_actuators.py",
    "content": "import importlib.util\nimport json\nimport sys\nfrom pathlib import Path\nfrom types import ModuleType, SimpleNamespace\n\nimport pytest\nimport torch\n\n\nACTUATOR_MODULE_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"unitree_actuators.py\"\n)\nSCENE_MODULE_PATH = (\n    Path(__file__).resolve().parents[1]\n    / \"holomotion\"\n    / \"src\"\n    / \"env\"\n    / \"isaaclab_components\"\n    / \"isaaclab_scene.py\"\n)\n\n\nclass _DummyArticulationActions:\n    def __init__(\n        self,\n        joint_positions=None,\n        joint_velocities=None,\n        joint_efforts=None,\n        joint_indices=None,\n    ):\n        self.joint_positions = joint_positions\n        self.joint_velocities = joint_velocities\n        self.joint_efforts = joint_efforts\n        self.joint_indices = joint_indices\n\n\nclass _DummyDelayedPDActuatorCfg:\n    min_delay = 0\n    max_delay = 0\n\n    def __init__(self, **kwargs):\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n\nclass _DummyDelayedPDActuator:\n    def __init__(self, cfg, *args, **kwargs):\n        self.cfg = cfg\n        self._num_envs = kwargs.get(\"num_envs\", 4)\n        self._device = kwargs.get(\"device\", \"cpu\")\n        self.num_joints = len(\n            kwargs.get(\"joint_names\", [\"joint_a\", \"joint_b\"])\n        )\n        self.computed_effort = torch.zeros(\n            self._num_envs, self.num_joints, device=self._device\n        )\n        self.applied_effort = torch.zeros_like(self.computed_effort)\n        effort_limit = kwargs.get(\"effort_limit\", 100.0)\n        if isinstance(effort_limit, torch.Tensor):\n            self.effort_limit = effort_limit.clone().to(device=self._device)\n        else:\n            self.effort_limit = torch.full_like(\n                self.computed_effort, float(effort_limit)\n            )\n        self.super_compute_inputs = []\n        self.super_compute_joint_positions = []\n        self.reset_calls = []\n\n    def _parse_joint_parameter(self, value, default):\n        if value is None:\n            value = default\n        if isinstance(value, torch.Tensor):\n            return value.clone().to(device=self._device)\n        if isinstance(value, dict):\n            values = list(value.values())\n            tensor = torch.tensor(\n                values, dtype=torch.float32, device=self._device\n            )\n            return tensor.unsqueeze(0).repeat(self._num_envs, 1)\n        if isinstance(value, (float, int)):\n            return torch.full_like(self.computed_effort, float(value))\n        raise TypeError(f\"Unsupported parameter type: {type(value)}\")\n\n    def reset(self, env_ids):\n        self.reset_calls.append(env_ids)\n\n    def compute(self, control_action, joint_pos, joint_vel):\n        if control_action.joint_efforts is None:\n            self.super_compute_inputs.append(None)\n        else:\n            self.super_compute_inputs.append(\n                control_action.joint_efforts.clone()\n            )\n        if control_action.joint_positions is None:\n            self.super_compute_joint_positions.append(None)\n        else:\n            self.super_compute_joint_positions.append(\n                control_action.joint_positions.clone()\n            )\n        self.computed_effort = control_action.joint_efforts.clone()\n        self.applied_effort = control_action.joint_efforts.clone()\n        return 
control_action\n\n\ndef _configclass(cls):\n    annotations = getattr(cls, \"__annotations__\", {})\n    defaults = {\n        name: getattr(cls, name) for name in annotations if hasattr(cls, name)\n    }\n\n    def __init__(self, **kwargs):\n        for name, value in defaults.items():\n            setattr(self, name, value)\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n    cls.__init__ = __init__\n    return cls\n\n\ndef _load_unitree_actuator_module(monkeypatch):\n    isaaclab_root = ModuleType(\"isaaclab\")\n    isaaclab_actuators = ModuleType(\"isaaclab.actuators\")\n    isaaclab_actuators.DelayedPDActuator = _DummyDelayedPDActuator\n    isaaclab_actuators.DelayedPDActuatorCfg = _DummyDelayedPDActuatorCfg\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = _configclass\n    isaaclab_utils_types = ModuleType(\"isaaclab.utils.types\")\n    isaaclab_utils_types.ArticulationActions = _DummyArticulationActions\n\n    for name, module in {\n        \"isaaclab\": isaaclab_root,\n        \"isaaclab.actuators\": isaaclab_actuators,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"isaaclab.utils.types\": isaaclab_utils_types,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    isaaclab_root.actuators = isaaclab_actuators\n    isaaclab_root.utils = isaaclab_utils\n    isaaclab_utils.types = isaaclab_utils_types\n\n    module_name = \"_test_unitree_actuators\"\n    spec = importlib.util.spec_from_file_location(\n        module_name, ACTUATOR_MODULE_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef _make_erfi_actuator(module, *, cfg_kwargs=None, num_envs=4, num_joints=3):\n    if cfg_kwargs is None:\n        cfg_kwargs = {}\n    cfg_defaults = {\n        \"Y1\": 100.0,\n        \"Y2\": 120.0,\n        \"erfi_enabled\": True,\n        \"ema_filter_enabled\": False,\n        \"ema_filter_alpha\": 1.0,\n        \"ema_filter_debug_dump_path\": None,\n        \"ema_filter_debug_stop_after_dump\": False,\n        \"rfi_probability\": 0.5,\n        \"rfi_lim\": 0.1,\n        \"randomize_rfi_lim\": True,\n        \"rfi_lim_range\": (0.5, 1.5),\n        \"rao_lim\": 0.1,\n    }\n    cfg_defaults.update(cfg_kwargs)\n    cfg = module.UnitreeErfiActuatorCfg(**cfg_defaults)\n    actuator = module.UnitreeErfiActuator(\n        cfg,\n        joint_names=[f\"joint_{idx}\" for idx in range(num_joints)],\n        joint_ids=torch.arange(num_joints),\n        num_envs=num_envs,\n        device=\"cpu\",\n        stiffness=0.0,\n        damping=0.0,\n        armature=0.0,\n        friction=0.0,\n        dynamic_friction=0.0,\n        viscous_friction=0.0,\n        effort_limit=100.0,\n        velocity_limit=100.0,\n    )\n    return actuator\n\n\ndef _make_action(actuator):\n    return _DummyArticulationActions(\n        joint_positions=torch.zeros_like(actuator.computed_effort),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n\ndef test_unitree_erfi_reset_samples_all_rfi(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\"rfi_probability\": 1.0},\n    )\n\n    actuator.reset(torch.tensor([0, 1, 2, 3], dtype=torch.long))\n\n    assert 
torch.all(actuator._mode_is_rfi)\n    assert torch.allclose(\n        actuator._rao_scale, torch.zeros_like(actuator._rao_scale)\n    )\n\n\ndef test_unitree_erfi_reset_samples_all_rao(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\"rfi_probability\": 0.0},\n    )\n\n    actuator.reset(torch.tensor([0, 1, 2, 3], dtype=torch.long))\n\n    assert not torch.any(actuator._mode_is_rfi)\n    assert torch.any(actuator._rao_scale != 0.0)\n\n\ndef test_unitree_erfi_rfi_without_randomized_limit_uses_effort_limit_ratio(\n    monkeypatch,\n):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"rfi_probability\": 1.0,\n            \"randomize_rfi_lim\": False,\n            \"rfi_lim\": 0.1,\n        },\n        num_envs=2,\n        num_joints=2,\n    )\n    actuator.reset(torch.tensor([0, 1], dtype=torch.long))\n\n    torch.manual_seed(0)\n    actuator.compute(\n        _make_action(actuator),\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n\n    injected = actuator.super_compute_inputs[-1]\n    assert torch.all(injected.abs() <= 10.0 + 1.0e-6)\n\n\ndef test_unitree_erfi_reset_randomizes_rfi_scale_within_range(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"rfi_probability\": 1.0,\n            \"rfi_lim_range\": (0.5, 1.5),\n        },\n        num_envs=2,\n        num_joints=2,\n    )\n\n    actuator.reset(torch.tensor([0, 1], dtype=torch.long))\n\n    assert torch.all(actuator._rfi_lim_scale >= 0.5)\n    assert torch.all(actuator._rfi_lim_scale <= 1.5)\n\n\ndef test_unitree_erfi_rao_bias_stays_constant_between_resets(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\"rfi_probability\": 0.0, \"rao_lim\": 0.1},\n        num_envs=2,\n        num_joints=2,\n    )\n    actuator.reset(torch.tensor([0, 1], dtype=torch.long))\n    action = _make_action(actuator)\n\n    actuator.compute(\n        action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    first = actuator.super_compute_inputs[-1].clone()\n    actuator.compute(\n        action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    second = actuator.super_compute_inputs[-1].clone()\n\n    assert torch.allclose(first, second)\n\n\ndef test_unitree_erfi_rfi_changes_each_compute(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\"rfi_probability\": 1.0, \"randomize_rfi_lim\": False},\n        num_envs=2,\n        num_joints=2,\n    )\n    actuator.reset(torch.tensor([0, 1], dtype=torch.long))\n    action = _make_action(actuator)\n\n    torch.manual_seed(0)\n    actuator.compute(\n        action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    first = actuator.super_compute_inputs[-1].clone()\n    torch.manual_seed(1)\n    actuator.compute(\n        action,\n        
joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    second = actuator.super_compute_inputs[-1].clone()\n\n    assert not torch.allclose(first, second)\n\n\ndef test_unitree_erfi_disabled_matches_plain_unitree(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\"erfi_enabled\": False},\n        num_envs=2,\n        num_joints=2,\n    )\n    action = _make_action(actuator)\n\n    actuator.reset(torch.tensor([0, 1], dtype=torch.long))\n    actuator.compute(\n        action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n\n    assert torch.allclose(\n        actuator.super_compute_inputs[-1],\n        torch.zeros_like(actuator.super_compute_inputs[-1]),\n    )\n\n\ndef test_unitree_erfi_ema_filters_joint_positions(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"erfi_enabled\": False,\n            \"ema_filter_enabled\": True,\n            \"ema_filter_alpha\": 0.25,\n        },\n        num_envs=2,\n        num_joints=2,\n    )\n    first_action = _DummyArticulationActions(\n        joint_positions=torch.tensor(\n            [[1.0, -1.0], [0.5, -0.5]], dtype=torch.float32\n        ),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n    second_action = _DummyArticulationActions(\n        joint_positions=torch.tensor(\n            [[3.0, 1.0], [1.5, 0.5]], dtype=torch.float32\n        ),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n    actuator.compute(\n        first_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    actuator.compute(\n        second_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n\n    assert torch.allclose(\n        actuator.super_compute_joint_positions[0],\n        first_action.joint_positions,\n    )\n    expected_second = (\n        0.25 * second_action.joint_positions\n        + 0.75 * first_action.joint_positions\n    )\n    assert torch.allclose(\n        actuator.super_compute_joint_positions[1], expected_second\n    )\n\n\ndef test_unitree_erfi_ema_reset_clears_only_selected_envs(monkeypatch):\n    module = _load_unitree_actuator_module(monkeypatch)\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"erfi_enabled\": False,\n            \"ema_filter_enabled\": True,\n            \"ema_filter_alpha\": 0.5,\n        },\n        num_envs=2,\n        num_joints=1,\n    )\n    one_action = _DummyArticulationActions(\n        joint_positions=torch.tensor([[1.0], [1.0]], dtype=torch.float32),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n    two_action = _DummyArticulationActions(\n        joint_positions=torch.tensor([[2.0], [2.0]], dtype=torch.float32),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    
)\n    zero_action = _DummyArticulationActions(\n        joint_positions=torch.zeros_like(actuator.computed_effort),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n    actuator.compute(\n        one_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    actuator.compute(\n        two_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    actuator.reset(torch.tensor([1], dtype=torch.long))\n    actuator.compute(\n        zero_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n\n    assert torch.allclose(\n        actuator.super_compute_joint_positions[1],\n        torch.tensor([[1.5], [1.5]], dtype=torch.float32),\n    )\n    assert torch.allclose(\n        actuator.super_compute_joint_positions[2],\n        torch.tensor([[0.75], [0.0]], dtype=torch.float32),\n    )\n\n\ndef test_unitree_erfi_ema_debug_dump_records_formula(monkeypatch, tmp_path):\n    module = _load_unitree_actuator_module(monkeypatch)\n    dump_path = tmp_path / \"ema_verify.json\"\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"erfi_enabled\": False,\n            \"ema_filter_enabled\": True,\n            \"ema_filter_alpha\": 0.25,\n            \"ema_filter_debug_dump_path\": str(dump_path),\n        },\n        num_envs=2,\n        num_joints=2,\n    )\n    first_action = _DummyArticulationActions(\n        joint_positions=torch.tensor(\n            [[1.0, -1.0], [0.5, -0.5]], dtype=torch.float32\n        ),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n    second_action = _DummyArticulationActions(\n        joint_positions=torch.tensor(\n            [[3.0, 1.0], [1.5, 0.5]], dtype=torch.float32\n        ),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n    actuator.compute(\n        first_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    actuator.compute(\n        second_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n\n    assert dump_path.is_file()\n    payload = json.loads(dump_path.read_text())\n    expected_second = (\n        0.25 * second_action.joint_positions[0]\n        + 0.75 * first_action.joint_positions[0]\n    )\n    assert payload[\"alpha\"] == 0.25\n    assert payload[\"matched\"] is True\n    assert payload[\"env_index\"] == 0\n    assert payload[\"raw_joint_positions\"] == [3.0, 1.0]\n    assert payload[\"previous_filtered_joint_positions\"] == [1.0, -1.0]\n    assert payload[\"expected_filtered_joint_positions\"] == pytest.approx(\n        expected_second.tolist()\n    )\n    assert payload[\"actual_filtered_joint_positions\"] == pytest.approx(\n        expected_second.tolist()\n    )\n\n\ndef test_unitree_erfi_ema_debug_stop_after_dump(monkeypatch, tmp_path):\n    module = _load_unitree_actuator_module(monkeypatch)\n    dump_path = tmp_path / \"ema_verify.json\"\n    actuator = _make_erfi_actuator(\n        
module,\n        cfg_kwargs={\n            \"erfi_enabled\": False,\n            \"ema_filter_enabled\": True,\n            \"ema_filter_alpha\": 0.5,\n            \"ema_filter_debug_dump_path\": str(dump_path),\n            \"ema_filter_debug_stop_after_dump\": True,\n        },\n        num_envs=1,\n        num_joints=1,\n    )\n    first_action = _DummyArticulationActions(\n        joint_positions=torch.tensor([[1.0]], dtype=torch.float32),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n    second_action = _DummyArticulationActions(\n        joint_positions=torch.tensor([[3.0]], dtype=torch.float32),\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n    actuator.compute(\n        first_action,\n        joint_pos=torch.zeros_like(actuator.computed_effort),\n        joint_vel=torch.zeros_like(actuator.computed_effort),\n    )\n    with pytest.raises(RuntimeError, match=\"EMA verification dump written\"):\n        actuator.compute(\n            second_action,\n            joint_pos=torch.zeros_like(actuator.computed_effort),\n            joint_vel=torch.zeros_like(actuator.computed_effort),\n        )\n\n    assert dump_path.is_file()\n\n\ndef test_unitree_erfi_ema_debug_dump_records_skip_reason(\n    monkeypatch, tmp_path\n):\n    module = _load_unitree_actuator_module(monkeypatch)\n    dump_path = tmp_path / \"ema_verify_skip.json\"\n    actuator = _make_erfi_actuator(\n        module,\n        cfg_kwargs={\n            \"erfi_enabled\": False,\n            \"ema_filter_enabled\": True,\n            \"ema_filter_debug_dump_path\": str(dump_path),\n            \"ema_filter_debug_stop_after_dump\": True,\n        },\n        num_envs=1,\n        num_joints=1,\n    )\n    action = _DummyArticulationActions(\n        joint_positions=None,\n        joint_velocities=torch.zeros_like(actuator.computed_effort),\n        joint_efforts=torch.zeros_like(actuator.computed_effort),\n    )\n\n    with pytest.raises(RuntimeError, match=\"EMA verification dump written\"):\n        actuator.compute(\n            action,\n            joint_pos=torch.zeros_like(actuator.computed_effort),\n            joint_vel=torch.zeros_like(actuator.computed_effort),\n        )\n\n    payload = json.loads(dump_path.read_text())\n    assert payload[\"applied\"] is False\n    assert payload[\"reason\"] == \"joint_positions_none\"\n\n\ndef _load_scene_module(monkeypatch):\n    actuator_module = _load_unitree_actuator_module(monkeypatch)\n\n    isaaclab_root = ModuleType(\"isaaclab\")\n    isaaclab_sim = ModuleType(\"isaaclab.sim\")\n    isaaclab_sim.UrdfFileCfg = lambda **kwargs: SimpleNamespace(**kwargs)\n    isaaclab_sim.RigidBodyPropertiesCfg = lambda **kwargs: SimpleNamespace(\n        **kwargs\n    )\n    isaaclab_sim.ArticulationRootPropertiesCfg = (\n        lambda **kwargs: SimpleNamespace(**kwargs)\n    )\n    isaaclab_sim.UrdfConverterCfg = SimpleNamespace(\n        JointDriveCfg=SimpleNamespace(\n            PDGainsCfg=lambda **kwargs: SimpleNamespace(**kwargs)\n        )\n    )\n    isaaclab_actuators = ModuleType(\"isaaclab.actuators\")\n    isaaclab_actuators.ImplicitActuatorCfg = lambda **kwargs: SimpleNamespace(\n        **kwargs\n    )\n    isaaclab_assets = ModuleType(\"isaaclab.assets\")\n    isaaclab_assets.ArticulationCfg = SimpleNamespace(\n        InitialStateCfg=lambda **kwargs: SimpleNamespace(**kwargs)\n   
 )\n    isaaclab_assets.ArticulationCfg = lambda **kwargs: SimpleNamespace(\n        **kwargs\n    )\n    isaaclab_assets.AssetBaseCfg = lambda **kwargs: SimpleNamespace(**kwargs)\n    isaaclab_scene = ModuleType(\"isaaclab.scene\")\n    isaaclab_scene.InteractiveSceneCfg = object\n    isaaclab_sensors = ModuleType(\"isaaclab.sensors\")\n    isaaclab_sensors.ContactSensorCfg = lambda **kwargs: SimpleNamespace(\n        **kwargs\n    )\n    isaaclab_sensors.RayCasterCfg = SimpleNamespace(\n        OffsetCfg=lambda **kwargs: SimpleNamespace(**kwargs)\n    )\n    isaaclab_sensors.patterns = SimpleNamespace(\n        GridPatternCfg=lambda **kwargs: SimpleNamespace(**kwargs)\n    )\n    isaaclab_terrains = ModuleType(\"isaaclab.terrains\")\n    isaaclab_terrains.TerrainImporterCfg = object\n    isaaclab_utils = ModuleType(\"isaaclab.utils\")\n    isaaclab_utils.configclass = _configclass\n    loguru = ModuleType(\"loguru\")\n    loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None)\n\n    fake_terrain = ModuleType(\n        \"holomotion.src.env.isaaclab_components.isaaclab_terrain\"\n    )\n    fake_terrain.build_terrain_config = lambda *args, **kwargs: None\n\n    fake_unitree = ModuleType(\n        \"holomotion.src.env.isaaclab_components.unitree_actuators\"\n    )\n    fake_unitree.UnitreeActuator = actuator_module.UnitreeActuator\n    fake_unitree.UnitreeActuatorCfg = actuator_module.UnitreeActuatorCfg\n    fake_unitree.UnitreeErfiActuator = actuator_module.UnitreeErfiActuator\n    fake_unitree.UnitreeErfiActuatorCfg = (\n        actuator_module.UnitreeErfiActuatorCfg\n    )\n\n    for name, module in {\n        \"isaaclab\": isaaclab_root,\n        \"isaaclab.sim\": isaaclab_sim,\n        \"isaaclab.actuators\": isaaclab_actuators,\n        \"isaaclab.assets\": isaaclab_assets,\n        \"isaaclab.scene\": isaaclab_scene,\n        \"isaaclab.sensors\": isaaclab_sensors,\n        \"isaaclab.terrains\": isaaclab_terrains,\n        \"isaaclab.utils\": isaaclab_utils,\n        \"loguru\": loguru,\n        (\n            \"holomotion.src.env.isaaclab_components.isaaclab_terrain\"\n        ): fake_terrain,\n        (\n            \"holomotion.src.env.isaaclab_components.unitree_actuators\"\n        ): fake_unitree,\n    }.items():\n        monkeypatch.setitem(sys.modules, name, module)\n\n    module_name = \"_test_isaaclab_scene\"\n    spec = importlib.util.spec_from_file_location(\n        module_name, SCENE_MODULE_PATH\n    )\n    module = importlib.util.module_from_spec(spec)\n    assert spec is not None\n    assert spec.loader is not None\n    sys.modules[module_name] = module\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_scene_builder_selects_unitree_erfi_cfg(monkeypatch):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree_erfi\"},\n        {\"erfi\": {\"enabled\": True, \"rfi_lim\": 0.2}},\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeErfiActuatorCfg)\n    assert actuators[\"all_joints\"].erfi_enabled is True\n    assert actuators[\"all_joints\"].rfi_lim == 0.2\n\n\ndef test_scene_builder_keeps_plain_unitree_cfg(monkeypatch):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree\"}, {}\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeActuatorCfg)\n    assert not hasattr(actuators[\"all_joints\"], \"rfi_lim\")\n\n\ndef 
def test_scene_builder_disables_erfi_when_domain_rand_missing(monkeypatch):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree_erfi\"}, {}\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeErfiActuatorCfg)\n    assert actuators[\"all_joints\"].erfi_enabled is False\n\n\ndef test_scene_builder_applies_domain_rand_action_delay_to_unitree(\n    monkeypatch,\n):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree\"},\n        {\"action_delay\": {\"enabled\": True, \"min_delay\": 1, \"max_delay\": 3}},\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeActuatorCfg)\n    assert actuators[\"all_joints\"].min_delay == 1\n    assert actuators[\"all_joints\"].max_delay == 3\n\n\ndef test_scene_builder_applies_domain_rand_action_delay_to_unitree_erfi(\n    monkeypatch,\n):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree_erfi\"},\n        {\n            \"erfi\": {\"enabled\": True},\n            \"action_delay\": {\n                \"enabled\": True,\n                \"min_delay\": 2,\n                \"max_delay\": 4,\n            },\n        },\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeErfiActuatorCfg)\n    assert actuators[\"all_joints\"].min_delay == 2\n    assert actuators[\"all_joints\"].max_delay == 4\n\n\ndef test_scene_builder_applies_erfi_ema_filter_config(monkeypatch):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\n            \"actuator_type\": \"unitree_erfi\",\n            \"ema_filter_enabled\": True,\n            \"ema_filter_alpha\": 0.37,\n        },\n        {\"erfi\": {\"enabled\": True}},\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeErfiActuatorCfg)\n    assert actuators[\"all_joints\"].class_type.__name__ == \"UnitreeErfiActuator\"\n    assert actuators[\"all_joints\"].ema_filter_enabled is True\n    assert actuators[\"all_joints\"].ema_filter_alpha == 0.37\n\n\ndef test_scene_builder_disables_action_delay_when_domain_rand_missing(\n    monkeypatch,\n):\n    module = _load_scene_module(monkeypatch)\n\n    actuators = module._build_unitree_actuator_cfg(\n        {\"actuator_type\": \"unitree\"}, {}\n    )\n\n    assert isinstance(actuators[\"all_joints\"], module.UnitreeActuatorCfg)\n    assert actuators[\"all_joints\"].min_delay == 0\n    assert actuators[\"all_joints\"].max_delay == 0\n"
  },
  {
    "path": "tests/test_visualize_with_mujoco.py",
    "content": "import sys\nfrom pathlib import Path\n\nimport numpy as np\n\nsys.path.insert(0, str(Path(__file__).resolve().parents[1]))\n\nfrom holomotion.src.motion_retargeting.utils.visualize_with_mujoco import (\n    _resolve_visualization_arrays,\n)\n\n\ndef test_resolve_visualization_arrays_uses_robot_for_pose_and_ref_for_overlay():\n    arrays = {\n        \"robot_dof_pos\": np.array([[1.0, 2.0]], dtype=np.float32),\n        \"robot_global_translation\": np.array(\n            [[[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]], dtype=np.float32\n        ),\n        \"robot_global_rotation_quat\": np.array(\n            [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]],\n            dtype=np.float32,\n        ),\n        \"ref_global_translation\": np.array(\n            [[[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]]], dtype=np.float32\n        ),\n    }\n\n    resolved = _resolve_visualization_arrays(\n        arrays=arrays,\n        key_prefix_order=[\"robot_\"],\n        draw_ref_body_spheres=True,\n        ref_key_prefix_order=[\"ref_\"],\n    )\n\n    np.testing.assert_allclose(resolved[\"dof_pos\"], arrays[\"robot_dof_pos\"])\n    np.testing.assert_allclose(\n        resolved[\"global_translation\"], arrays[\"robot_global_translation\"]\n    )\n    np.testing.assert_allclose(\n        resolved[\"global_rotation_quat\"],\n        arrays[\"robot_global_rotation_quat\"],\n    )\n    np.testing.assert_allclose(\n        resolved[\"ref_body_positions\"],\n        arrays[\"ref_global_translation\"],\n    )\n"
  },
  {
    "path": "train.env",
    "content": "# This is the environment file for running HoloMotion scripts.\n\nexport CONDA_BASE=$(conda info --base)\nexport Train_CONDA_PREFIX=\"$CONDA_BASE/envs/holomotion_train\"\n\n# export CUDA_HOME=$Train_CONDA_PREFIX\nexport CUDA_HOME=/usr/local/cuda\n# export LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:$Train_CONDA_PREFIX/lib/:$Train_CONDA_PREFIX/lib/stubs\"\n# export LIBRARY_PATH=\"$Train_CONDA_PREFIX/lib/stubs:$Train_CONDA_PREFIX/lib:$LIBRARY_PATH\"\nexport HYDRA_FULL_ERROR=1\n\nexport OMNI_KIT_ACCEPT_EULA=\"YES\"\nexport ACCEPT_EULA=\"YES\"\n# export CUDA_LAUNCH_BLOCKING=1\nexport USE_NVRTC=1\nexport HDF5_USE_FILE_LOCKING=FALSE\nexport HOLOMOTION_ISAAC_STAGGER_SEC=1\nexport HOLOMOTION_HDF5_RDCC_NBYTES=$((4 * 1024 * 1024)) # 4MB\nexport HOLOMOTION_HDF5_MAX_OPEN_SHARDS=16               # 16 shards\nexport TORCH_DISTRIBUTED_DEBUG=INFO\n\n# export TORCHDYNAMO_VERBOSE=0\n\necho \"--------------------------------\"\necho \"Train_CONDA_PREFIX: $Train_CONDA_PREFIX\"\necho \"CUDA_HOME: $CUDA_HOME\"\necho \"LD_LIBRARY_PATH: $LD_LIBRARY_PATH\"\necho \"LIBRARY_PATH: $LIBRARY_PATH\"\necho \"HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR\"\necho \"OMNI_KIT_ACCEPT_EULA: $OMNI_KIT_ACCEPT_EULA\"\necho \"HDF5_USE_FILE_LOCKING: $HDF5_USE_FILE_LOCKING\"\necho \"HOLOMOTION_ISAAC_STAGGER_SEC: $HOLOMOTION_ISAAC_STAGGER_SEC\"\necho \"HOLOMOTION_HDF5_RDCC_NBYTES: $HOLOMOTION_HDF5_RDCC_NBYTES\"\necho \"HOLOMOTION_HDF5_MAX_OPEN_SHARDS: $HOLOMOTION_HDF5_MAX_OPEN_SHARDS\"\necho \"--------------------------------\"\n\n# Graceful shutdown function for training scripts\n# Note: Scripts must set TRAIN_PID variable and call: trap cleanup SIGINT SIGTERM\ncleanup() {\n    echo \"\"\n    echo \"🛑 Cleanup triggered - shutting down training process ${TRAIN_PID}...\"\n    exec 2>/dev/null # Suppress error messages during cleanup\n    [[ -n \"${TRAIN_PID}\" ]] && kill -TERM \"${TRAIN_PID}\" 2>/dev/null && echo \"  ✓ Sent TERM signal to process ${TRAIN_PID}\"\n    sleep 2\n    [[ -n \"${TRAIN_PID}\" ]] && pkill -P \"${TRAIN_PID}\" 2>/dev/null && echo \"  ✓ Killed child processes\"\n    exec 2>&1\n    echo \"  ✓ Cleanup complete\"\n}\n"
  }
]