Repository: HorizonRobotics/HoloMotion Branch: master Commit: 75ca12750541 Files: 249 Total size: 2.8 MB Directory structure: gitextract_3fyqn965/ ├── .gitattributes ├── .githooks/ │ ├── README.md │ └── pre-commit ├── .gitignore ├── .gitlab-ci.yml ├── .gitmodules ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── assets/ │ ├── robots/ │ │ └── unitree/ │ │ └── G1/ │ │ └── 29dof/ │ │ ├── g1_29dof.xml │ │ ├── g1_29dof_rev_1_0.urdf │ │ ├── g1_29dof_rev_1_0.xml │ │ ├── g1_29dof_rev_1_0_s100.urdf │ │ └── scene_29dof.xml │ └── test_data/ │ └── motion_retargeting/ │ └── ACCAD/ │ └── Male1Walking_c3d/ │ └── Walk_B10_-_Walk_turn_left_45_stageii.npz ├── deploy.env ├── deployment/ │ ├── deploy_environment.sh │ ├── holomotion_teleop/ │ │ ├── holomotion_teleop_node.py │ │ ├── holomotion_teleop_setup.md │ │ └── setup_holomotion_teleop_x86_ubuntu2204.sh │ └── unitree_g1_ros2_29dof/ │ ├── launch_holomotion_29dof.sh │ ├── launch_holomotion_29dof_docker.sh │ ├── src/ │ │ ├── CMakeLists.txt │ │ ├── config/ │ │ │ └── g1_29dof_holomotion.yaml │ │ ├── humanoid_policy/ │ │ │ ├── __init__.py │ │ │ ├── holomotion_fk_root_only.py │ │ │ ├── obs_builder/ │ │ │ │ ├── __init__.py │ │ │ │ └── obs_builder.py │ │ │ ├── policy_node_29dof.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── command_helper.py │ │ │ ├── maths.py │ │ │ ├── motor_crc.py │ │ │ ├── remote_controller_filter.py │ │ │ ├── rotation_helper.py │ │ │ └── rotations.py │ │ ├── include/ │ │ │ └── common/ │ │ │ ├── motor_crc.h │ │ │ ├── motor_crc_hg.h │ │ │ ├── ros2_sport_client.h │ │ │ └── wireless_controller.h │ │ ├── launch/ │ │ │ └── holomotion_29dof_launch.py │ │ ├── models/ │ │ │ └── .gitkeep │ │ ├── motion_data/ │ │ │ └── .gitkeep │ │ ├── package.xml │ │ ├── resource/ │ │ │ └── humanoid_control │ │ ├── setup.cfg │ │ ├── setup.py │ │ └── src/ │ │ ├── common/ │ │ │ ├── motor_crc.cpp │ │ │ ├── motor_crc_hg.cpp │ │ │ ├── ros2_sport_client.cpp │ │ │ └── wireless_controller.cpp │ │ └── main_node.cpp │ └── start_container.sh ├── 
docs/ │ ├── environment_setup.md │ ├── evaluate_motion_tracking.md │ ├── holomotion_motion_file_spec.md │ ├── motion_retargeting.md │ ├── mujoco_sim2sim.md │ ├── realworld_deployment.md │ ├── smpl_data_curation.md │ └── train_motion_tracking.md ├── environments/ │ ├── environment_deploy.yaml │ ├── environment_train_isaaclab_cu118.yaml │ ├── environment_train_isaaclab_cu128.yaml │ ├── requirements_base.txt │ ├── requirements_deploy.txt │ ├── requirements_torch_cu118.txt │ ├── requirements_torch_cu128.txt │ └── requirements_torch_cu130.txt ├── holomotion/ │ ├── config/ │ │ ├── algo/ │ │ │ ├── ppo.yaml │ │ │ └── ppo_tf.yaml │ │ ├── data_curation/ │ │ │ ├── joints2smpl.yaml │ │ │ └── smplify_base.yaml │ │ ├── env/ │ │ │ ├── domain_randomization/ │ │ │ │ ├── NO_domain_rand.yaml │ │ │ │ ├── domain_rand_medium.yaml │ │ │ │ ├── domain_rand_small.yaml │ │ │ │ └── domain_rand_strong.yaml │ │ │ ├── motion_tracking.yaml │ │ │ ├── observations/ │ │ │ │ ├── motion_tracking/ │ │ │ │ │ ├── obs_motion_tracking_mlp.yaml │ │ │ │ │ └── obs_motion_tracking_tf-moe.yaml │ │ │ │ └── velocity_tracking/ │ │ │ │ └── obs_velocity_tracking.yaml │ │ │ ├── rewards/ │ │ │ │ ├── motion_tracking/ │ │ │ │ │ └── rew_motion_tracking.yaml │ │ │ │ └── velocity_tracking/ │ │ │ │ └── rew_velocity_tracking.yaml │ │ │ ├── terminations/ │ │ │ │ ├── NO_termination.yaml │ │ │ │ ├── termination_motion_tracking.yaml │ │ │ │ └── termination_velocity_tracking.yaml │ │ │ ├── terrain/ │ │ │ │ ├── isaaclab_plane.yaml │ │ │ │ └── isaaclab_rough.yaml │ │ │ └── velocity_tracking.yaml │ │ ├── evaluation/ │ │ │ ├── eval_isaaclab.yaml │ │ │ ├── eval_mujoco_sim2sim.yaml │ │ │ └── eval_velocity_tracking.yaml │ │ ├── modules/ │ │ │ ├── motion_tracking/ │ │ │ │ ├── motion_tracking_mlp.yaml │ │ │ │ └── motion_tracking_tf-moe.yaml │ │ │ └── velocity_tracking/ │ │ │ └── velocity_tracking_mlp.yaml │ │ ├── motion_retargeting/ │ │ │ ├── gmr_to_holomotion.yaml │ │ │ ├── holomotion_preprocess.yaml │ │ │ ├── kinematic_filter.yaml │ │ │ 
├── pack_hdf5_database.yaml │ │ │ ├── pack_hdf5_v2.yaml │ │ │ └── unitree_G1_29dof_retargeting.yaml │ │ ├── mujoco_eval/ │ │ │ └── sim2sim.yaml │ │ ├── robot/ │ │ │ └── unitree/ │ │ │ └── G1/ │ │ │ └── 29dof/ │ │ │ ├── 29dof_training_isaaclab.yaml │ │ │ └── 29dof_training_isaaclab_s100.yaml │ │ └── training/ │ │ ├── motion_tracking/ │ │ │ ├── train_g1_29dof_motion_tracking_mlp.yaml │ │ │ └── train_g1_29dof_motion_tracking_tf-moe.yaml │ │ ├── train_base.yaml │ │ └── velocity_tracking/ │ │ └── train_g1_29dof_velocity_tracking_mlp.yaml │ ├── scripts/ │ │ ├── data_curation/ │ │ │ ├── convert_to_amass.sh │ │ │ ├── filter_smpl_data.sh │ │ │ ├── video_to_smpl_gvhmr.sh │ │ │ └── visualize_smpl_npz.sh │ │ ├── evaluation/ │ │ │ ├── calc_offline_eval_metrics.sh │ │ │ ├── eval_motion_tracking.sh │ │ │ ├── eval_mujoco_sim2sim.sh │ │ │ ├── eval_velocity_tracking.sh │ │ │ ├── mean_process_5metrics.py │ │ │ └── multi_model_metrics_analysis.sh │ │ ├── motion_retargeting/ │ │ │ ├── apply_gmr_motion_retarget_patch.sh │ │ │ ├── pack_hdf5_v2.sh │ │ │ ├── run_holomotion_preprocessing.sh │ │ │ ├── run_kinematic_filter.sh │ │ │ ├── run_motion_retargeting_gmr_bvh.sh │ │ │ ├── run_motion_retargeting_gmr_smplx.sh │ │ │ ├── run_motion_retargeting_gmr_to_holomotion.sh │ │ │ └── run_motion_viz_mujoco.sh │ │ └── training/ │ │ ├── train_motion_tracking.sh │ │ └── train_velocity_tracking.sh │ ├── src/ │ │ ├── algo/ │ │ │ ├── __init__.py │ │ │ ├── algo_base.py │ │ │ ├── algo_utils.py │ │ │ ├── ppo.py │ │ │ └── ppo_tf.py │ │ ├── data_curation/ │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── data_smplify.py │ │ │ ├── filter/ │ │ │ │ ├── filter.py │ │ │ │ └── label_data.py │ │ │ ├── smpl_npz_to_html.py │ │ │ ├── smplify/ │ │ │ │ ├── __init__.py │ │ │ │ ├── smplify_humanact12.py │ │ │ │ ├── smplify_motionx.py │ │ │ │ ├── smplify_omomo.py │ │ │ │ └── smplify_zjumocap.py │ │ │ ├── templates/ │ │ │ │ └── index_wooden_static.html │ │ │ ├── video_to_smpl_gvhmr.py │ │ │ ├── vison_mocap/ │ │ │ │ └── 
joints2smpl.py │ │ │ └── visualize_smpl_npz.py │ │ ├── env/ │ │ │ ├── __init__.py │ │ │ ├── isaaclab_components/ │ │ │ │ ├── __init__.py │ │ │ │ ├── isaaclab_actions.py │ │ │ │ ├── isaaclab_curriculum.py │ │ │ │ ├── isaaclab_domain_rand.py │ │ │ │ ├── isaaclab_motion_tracking_command.py │ │ │ │ ├── isaaclab_observation.py │ │ │ │ ├── isaaclab_rewards.py │ │ │ │ ├── isaaclab_scene.py │ │ │ │ ├── isaaclab_simulator.py │ │ │ │ ├── isaaclab_termination.py │ │ │ │ ├── isaaclab_terrain.py │ │ │ │ ├── isaaclab_utils.py │ │ │ │ ├── isaaclab_velocity_tracking_command.py │ │ │ │ └── unitree_actuators.py │ │ │ ├── motion_tracking.py │ │ │ └── velocity_tracking.py │ │ ├── evaluation/ │ │ │ ├── __init__.py │ │ │ ├── eval_motion_tracking.py │ │ │ ├── eval_motion_tracking_single.py │ │ │ ├── eval_mujoco_sim2sim.py │ │ │ ├── eval_velocity_tracking.py │ │ │ ├── find_worst_clips.py │ │ │ ├── metrics.py │ │ │ ├── multi_model_metrics_report.py │ │ │ ├── obs/ │ │ │ │ ├── __init__.py │ │ │ │ └── obs_builder.py │ │ │ ├── ray_evaluator_actor.py │ │ │ └── ray_metrics_postprocess.py │ │ ├── modules/ │ │ │ ├── __init__.py │ │ │ ├── agent_modules.py │ │ │ └── network_modules.py │ │ ├── motion_retargeting/ │ │ │ ├── __init__.py │ │ │ ├── gmr_to_holomotion.py │ │ │ ├── holomotion_fk.py │ │ │ ├── holomotion_preprocess.py │ │ │ ├── kinematic_filter.py │ │ │ ├── pack_hdf5_v2.py │ │ │ ├── reference_filtering.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── _schema.json │ │ │ ├── rotation_conversions.py │ │ │ ├── torch_humanoid_batch.py │ │ │ └── visualize_with_mujoco.py │ │ ├── training/ │ │ │ ├── __init__.py │ │ │ ├── h5_dataloader.py │ │ │ ├── reference_filter_export.py │ │ │ └── train.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── config.py │ │ ├── frame_utils.py │ │ ├── isaac_utils/ │ │ │ ├── __init__.py │ │ │ ├── maths.py │ │ │ ├── rotations.py │ │ │ └── setup.py │ │ ├── onnx_export.py │ │ ├── reference_prefix.py │ │ ├── torch_utils.py │ │ └── unitree_g1_actuator_calculator.py │ └── tests/ │ 
└── __init__.py ├── pyproject.toml ├── tests/ │ ├── benchmark_legacy_onnx_attention.py │ ├── benchmark_moe_router_export.py │ ├── test_actor_export_config.py │ ├── test_algo_base_iteration_logging.py │ ├── test_build_quantization_dataset.py │ ├── test_cache_curriculum_sampler.py │ ├── test_domain_rand_config_builder.py │ ├── test_eval_mujoco_action_delay.py │ ├── test_eval_mujoco_action_ema.py │ ├── test_eval_mujoco_contact_export.py │ ├── test_eval_mujoco_s100_horizon_ptq.py │ ├── test_eval_mujoco_use_gpu.py │ ├── test_eval_onnx_io_dump.py │ ├── test_evaluation_metrics.py │ ├── test_isaaclab_termination.py │ ├── test_mean_process_5metrics.py │ ├── test_motion_cache_gather_state.py │ ├── test_motion_cache_startup.py │ ├── test_motion_tracking_command_reference_prefix.py │ ├── test_motion_tracking_timing.py │ ├── test_mujoco_filtered_ref_compat.py │ ├── test_obs_norm_compile.py │ ├── test_observation_frames.py │ ├── test_onnx_attention_export.py │ ├── test_onnx_export.py │ ├── test_plot_moe_expert_heatmap.py │ ├── test_plot_state_series.py │ ├── test_ppo_checkpoint_sigma_override.py │ ├── test_ppo_entropy_annealing.py │ ├── test_ppo_symmetry_loss.py │ ├── test_ppo_tf_aux_keybody.py │ ├── test_ref_router_actor.py │ ├── test_ref_router_seq_actor.py │ ├── test_reference_filter_export.py │ ├── test_reference_motion_config_wiring.py │ ├── test_root_rel_rewards.py │ ├── test_unitree_actuators.py │ └── test_visualize_with_mujoco.py └── train.env ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ *.ipynb filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text *.obj filter=lfs diff=lfs merge=lfs -text *.dae filter=lfs diff=lfs merge=lfs -text *.onnx filter=lfs diff=lfs merge=lfs -text *.pkl filter=lfs diff=lfs merge=lfs -text *.npz filter=lfs diff=lfs merge=lfs -text *.npy 
filter=lfs diff=lfs merge=lfs -text assets/smplx filter=lfs diff=lfs merge=lfs -text assets/smpl filter=lfs diff=lfs merge=lfs -text assets/test_data filter=lfs diff=lfs merge=lfs -text assets/media/*.png filter=lfs diff=lfs merge=lfs -text assets/media/*.jpg filter=lfs diff=lfs merge=lfs -text assets/media/*.jpeg filter=lfs diff=lfs merge=lfs -text assets/videos/*.mp4 filter=lfs diff=lfs merge=lfs -text *.gif filter=lfs diff=lfs merge=lfs -text *.svg filter=lfs diff=lfs merge=lfs -text *.png filter=lfs diff=lfs merge=lfs -text ================================================ FILE: .githooks/README.md ================================================ # Git hooks Pre-commit runs [ruff](https://docs.astral.sh/ruff/) format on staged Python files (using `train.env` for the correct environment). **Install ruff** in the holomotion train environment if it is absent: ```bash conda activate holomotion_train pip install ruff ``` Ruff is also listed in `environments/requirements_base.txt` if you install deps from there. **Enable hooks** (run once from repo root): ```bash git config core.hooksPath .githooks ``` Ensure the hook is executable: `chmod +x .githooks/pre-commit` **Requirement:** Run `git commit` from a shell where conda is available so `train.env` can set `Train_CONDA_PREFIX`. **Skip hook for one commit:** `git commit --no-verify` ================================================ FILE: .githooks/pre-commit ================================================
#!/bin/bash
# Run ruff format on staged Python files before commit.
# Use from holomotion repo root (standalone clone or submodule).
# Requires: run git commit from a shell where conda is available (so train.env can set Train_CONDA_PREFIX).
#
# Fully staged files are re-staged after formatting. Files that ALSO have
# unstaged edits in the working tree are formatted but NOT re-staged, because
# `git add` would otherwise sweep those unrelated, unstaged hunks into the
# commit. They are listed on stderr so they can be reviewed and staged by hand.
#
# Only -e is enabled: train.env (not shown here) is sourced into this shell
# and may not be written to run under nounset/pipefail.
set -e

# Separate assignment so a failing rev-parse aborts under set -e
# (inside `cd "$(...)"` the failure is masked and `cd ""` succeeds).
repo_root=$(git rev-parse --show-toplevel)
cd "$repo_root"

# Staged Added/Copied/Modified/Renamed *.py paths. -z yields NUL-separated,
# unquoted paths, so names with spaces or non-ASCII characters survive intact.
staged_py=()
while IFS= read -r -d '' path; do
  # `if` rather than `[[ ]] && ...`: a false test as the loop's last command
  # would give the while loop a non-zero status and trip set -e.
  if [[ "$path" == *.py ]]; then
    staged_py+=("$path")
  fi
done < <(git diff --cached --name-only --diff-filter=ACMR -z)

if (( ${#staged_py[@]} == 0 )); then
  exit 0
fi

# Classify BEFORE ruff rewrites the working tree:
# `git diff --quiet -- PATH` succeeds only when index and working tree agree.
restage=()
partial=()
for path in "${staged_py[@]}"; do
  if git diff --quiet -- "$path"; then
    restage+=("$path")
  else
    partial+=("$path")
  fi
done

# train.env sets Train_CONDA_PREFIX; conda must be on PATH when you run git commit
source train.env
: "${Train_CONDA_PREFIX:?not set after sourcing train.env (is conda on PATH?)}"
ruff_bin="$Train_CONDA_PREFIX/bin/ruff"
if [[ ! -x "$ruff_bin" ]]; then
  printf 'pre-commit: ruff not found at %s (pip install ruff in the train env)\n' "$ruff_bin" >&2
  exit 1
fi

"$ruff_bin" format --config pyproject.toml -- "${staged_py[@]}"

if (( ${#restage[@]} > 0 )); then
  git add -- "${restage[@]}"
fi

if (( ${#partial[@]} > 0 )); then
  printf 'pre-commit: formatted but NOT re-staged (they have unstaged changes):\n' >&2
  printf '  %s\n' "${partial[@]}" >&2
  printf 'pre-commit: review with "git diff" and "git add" them to commit the formatting.\n' >&2
fi
exit 0
================================================ FILE: .gitignore ================================================ # ignore logs and cache logs/ logs_eval/ data/ outputs/ .archive/ tmp/ # ignore deployment bag_record, install, log deployment/unitree_g1_ros2/bag_record/ deployment/unitree_g1_ros2/install/ deployment/unitree_g1_ros2/log/ # ignore data and outputs data data/ outputs/ build/ install/ log/ .DS_Store/ **.egg-info/ **.log **.LOG # ignore large files *.log *.pkl *.pt *.onnx *.npy *.npz *.zip *.tar.gz *.obj *.dae *.STL # ignore video, image, etc. *.mp4 *.avi *.mov *.png *.pdf __pycache__/ *.pyc *.egg-info .agents/ .cursor/ .cursorignore .vscode/ .*_cache/ **/usd/ assets/isaac/ not_for_commit/ thirdparties/smpl_models/ MUJOCO_LOG.TXT # allow certain files !deployment/unitree_g1_ros2/src/models/*.onnx !deployment/unitree_g1_ros2/src/motion_data/*.pkl !assets/smpl/*.pkl !assets/smpl/*.npz !assets/smplx/*.pkl !assets/smplx/*.npz !assets/test_data/** !assets/media/** # macOS system files .DS_Store **/.DS_Store ================================================ FILE: .gitlab-ci.yml ================================================ workflow: rules: - if: $CI_PIPELINE_SOURCE == 'merge_request_event' job1: script: - echo "This job runs in merge request pipelines" ================================================ FILE: .gitmodules ================================================ [submodule "thirdparties/SMPLSim"] path = thirdparties/SMPLSim url = https://github.com/ZhengyiLuo/SMPLSim branch = master [submodule "thirdparties/joints2smpl"] path = thirdparties/joints2smpl url =
https://github.com/wangsen1312/joints2smpl.git branch = main [submodule "thirdparties/omomo_release"] path = thirdparties/omomo_release url = https://github.com/lijiaman/omomo_release.git branch = main [submodule "thirdparties/unitree_ros"] path = thirdparties/unitree_ros url = https://github.com/unitreerobotics/unitree_ros branch = master [submodule "thirdparties/unitree_ros2"] path = thirdparties/unitree_ros2 url = https://github.com/unitreerobotics/unitree_ros2 branch = master [submodule "thirdparties/unitree_sdk2_python"] path = thirdparties/unitree_sdk2_python url = https://github.com/unitreerobotics/unitree_sdk2_python.git [submodule "thirdparties/cyclonedds"] path = thirdparties/cyclonedds url = https://github.com/eclipse-cyclonedds/cyclonedds branch = 0.10.2 [submodule "thirdparties/unitree_sdk2"] path = thirdparties/unitree_sdk2 url = https://github.com/unitreerobotics/unitree_sdk2.git [submodule "thirdparties/GMR"] path = thirdparties/GMR url = https://github.com/YanjieZe/GMR.git branch = master [submodule "thirdparties/HoloMotion_assets"] path = thirdparties/HoloMotion_assets url = https://huggingface.co/datasets/HorizonRobotics/HoloMotion_assets [submodule "thirdparties/smplx"] path = thirdparties/smplx url = https://github.com/vchoutas/smplx [submodule "thirdparties/GVHMR"] path = thirdparties/GVHMR url = https://github.com/zju3dv/GVHMR.git ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2025 maiyue01.chen Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================
# Variables
# Comments sit on their own lines: a trailing "# ..." after an assignment
# would leave the whitespace before it inside the variable's value.

# Your Python code directory
PY_SRC := holomotion/
# Assumes Ruff is installed locally
RUFF := ruff
PYTEST := pytest -v
# The test suite lives in the top-level tests/ directory; holomotion/tests/
# only holds an __init__.py, so pointing pytest there collects nothing.
TESTS := tests/
COV := --cov=holomotion/src --cov-report=term-missing

# Directory to lint/format - can be overridden with DIR=path
DIR ?= holomotion/src

.PHONY: lint format check lint-dir format-dir test test-cov

# Run Ruff linter on default directory
lint:
	@echo "Linting with Ruff..."
	@$(RUFF) check $(PY_SRC)

# Format code in default directory, then auto-fix lint errors
format:
	@echo "Formatting with Ruff..."
	@$(RUFF) format $(PY_SRC)
	@$(RUFF) check --fix $(PY_SRC)

# Run Ruff linter on a specific directory (override with DIR=path)
lint-dir:
	@echo "Linting directory: $(DIR)"
	@$(RUFF) check $(DIR)

# Format code in a specific directory (override with DIR=path), then auto-fix lint errors
format-dir:
	@echo "Formatting directory: $(DIR)"
	@$(RUFF) format $(DIR)
	@$(RUFF) check --fix $(DIR)

# Strict check (for CI)
check:
	@$(RUFF) check $(PY_SRC) --exit-non-zero-on-fix

# Run all tests
test:
	$(PYTEST) $(TESTS)

# Run all tests with coverage of holomotion/src (requires pytest-cov)
test-cov:
	$(PYTEST) $(COV) $(TESTS)
================================================ FILE: NOTICE ================================================ ======================================================================= ASAP's MIT License ======================================================================= Code derived from implementations in ASAP should mention its derivation and reference the following license: MIT License Copyright (c) 2025 ASAP Team Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
======================================================================= omomo_release's MIT License ======================================================================= Code derived from implementations in omomo_release should mention its derivation and reference the following license: MIT License Copyright (c) 2023 Jiaman Li Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ======================================================================= NVIDIA License ======================================================================= Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. NVIDIA CORPORATION and its licensors retain all intellectual property and proprietary rights in and to this software, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this software and related documentation without an express license agreement from NVIDIA CORPORATION is strictly prohibited. 
================================================ FILE: README.md ================================================
HoloMotion Logo --- [![Python](https://img.shields.io/badge/Python3.11-3776AB?logo=python&logoColor=fff)](#) [![Ubuntu](https://img.shields.io/badge/Ubuntu22.04-E95420?logo=ubuntu&logoColor=white)](#) [![License](https://img.shields.io/badge/License-Apache_2.0-green?logo=apache&logoColor=white)](./LICENSE) [![Safari](https://img.shields.io/badge/Website-006CFF?logo=safari&logoColor=fff)](https://horizonrobotics.github.io/robot_lab/holomotion/) [![HuggingFace](https://img.shields.io/badge/-HuggingFace-3B4252?style=flat&logo=huggingface&logoColor=)](https://huggingface.co/collections/HorizonRobotics/holomotion) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/HorizonRobotics/HoloMotion) [![WeChat](https://img.shields.io/badge/Wechat-7BB32E?logo=wechat&logoColor=white)](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c)
# HoloMotion: A Foundation Model for Whole-Body Humanoid Control ## NEWS - [2026.04.04] The v1.2 version of HoloMotion has been released, providing pre-trained motion tracking and velocity tracking models that the community can deploy directly. - [2026.01.06] The v1.1 version of HoloMotion has been released, representing a major step forward toward a fully engineered, stable, and reproducible humanoid motion intelligence system. - [2025.11.05] The v1.0 version of HoloMotion has been released, and the WeChat user group is now open! Please scan the [QR Code](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c) to join. ## Pre-trained Models - Motion Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model) - Velocity Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model) See the [real-world deployment](docs/realworld_deployment.md) documentation for details on how to use the models. ## Introduction HoloMotion is a foundation model for humanoid robotics, designed to deliver robust, real-time, and generalizable whole-body control. Our framework provides an end-to-end solution, encompassing the entire workflow from data curation and motion retargeting to distributed model training, evaluation, and seamless deployment on physical hardware via ROS2. HoloMotion's modular architecture allows for flexible adaptation and extension, enabling researchers and developers to build and benchmark agents that can imitate, generalize, and master complex whole-body motions. For those at the forefront of creating the next generation of humanoid robots, HoloMotion serves as a powerful, extensible, and open-source foundation for achieving whole-body control. --- ### 🛠️ Roadmap: Progress Toward Any Humanoid Control We envision HoloMotion as a general-purpose foundation for humanoid motion and control.
Its development is structured around four core generalization goals: Any Pose, Any Command, Any Terrain, and Any Embodiment. Each goal corresponds to a major version milestone. | Version | Target Capability | Description | | -------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------- | | **v1.0** | 🔄 Any Pose | Achieve robust tracking and imitation of diverse, whole-body human motions, forming the core of the imitation learning capability. | | **v2.0** | ⏳ Any Command | Enable language- and task-conditioned motion generation, allowing for goal-directed and interactive behaviors. | | **v3.0** | ⏳ Any Terrain | Master adaptation to uneven, dynamic, and complex terrains, enhancing real-world operational robustness. | | **v4.0** | ⏳ Any Embodiment | Generalize control policies across humanoids with varying morphologies and kinematics, achieving true embodiment-level abstraction. | > Each stage builds on the previous one, moving from motion imitation to instruction following, terrain adaptation, and embodiment-level generalization. ## Pipeline Overview ```mermaid flowchart LR A["🔧 1. Environment Setup
Dependencies & conda"] subgraph dataFrame ["DATA"] B["📊 2. Dataset Preparation
Download & curate"] C["🔄 3. Motion Retargeting
Human to robot motion"] B --> C end subgraph modelFrame ["TRAIN & EVAL"] D["🧠 4. Model Training
Train with HoloMotion"] E["📈 5. Evaluation
Test & export"] D --> E end F["🚀 6. Deployment
Deploy to robots"] A --> dataFrame dataFrame --> modelFrame modelFrame --> F classDef subgraphStyle fill:#f9f9f9,stroke:#333,stroke-width:2px,stroke-dasharray:5 5,rx:10,ry:10,font-size:16px,font-weight:bold classDef nodeStyle fill:#e1f5fe,stroke:#0277bd,stroke-width:2px,rx:10,ry:10 class dataFrame,modelFrame subgraphStyle class A,B,C,D,E,F nodeStyle ``` ## Quick Start ### 🔧 1. Environment Setup [[Doc](docs/environment_setup.md)] Set up your development and deployment environments using Conda. This initial step ensures all dependencies are correctly configured for both training and real-world execution. If you only intend to use our pretrained models, you can skip the training environment setup and proceed directly to configure the deployment environment. See the [real-world deployment documentation](docs/realworld_deployment.md) for details. ### 📊 2. Dataset Preparation [[Doc](docs/smpl_data_curation.md)] Acquire and process large-scale motion datasets. Our tools help you curate high-quality data by converting it to the AMASS-compatible SMPL format and filtering out anomalies using kinematic metrics. ### 🔄 3. Motion Retargeting [[Doc](docs/motion_retargeting.md)] Translate human motion data into robot-specific kinematic data. Our pipeline leverages [GMR](https://github.com/YanjieZe/GMR) to map human movements onto your robot's morphology, producing optimized HDF5 datasets ready for high-speed, distributed training. ### 🧠 4. Model Training [[Doc](docs/train_motion_tracking.md)] Train your foundation model using our reinforcement learning framework. HoloMotion supports versatile training tasks, including motion tracking and velocity tracking. ### 📈 5. Evaluation [[Doc](docs/evaluate_motion_tracking.md)] Evaluate your trained policies in IsaacLab. Visualize performance and export trained models in ONNX format for seamless deployment. ### 🚀 6.
Real-world Deployment [[Doc](docs/realworld_deployment.md)] Our ROS2 package facilitates the deployment of the exported ONNX models, enabling real-time control on hardware like the Unitree G1. ## Join Us We are hiring full-time engineers, new graduates, and interns who are excited about humanoid robots, motion control, and embodied intelligence. Scan the **WeChat** QR code below to get in touch with us and send us your resume.

## Citation ``` @software{HoloMotion, author = {Maiyue Chen and Kaihui Wang and Bo Zhang and Yi Ren and Zihao Zhu and Xihan Ma and Qijun Huang and Zhiyuan Yang and Yucheng Wang and Zhizhong Su}, title = {HoloMotion: A Foundation Model for Whole-Body Humanoid Control}, year = {2026}, month = apr, version = {1.2.0}, url = {https://github.com/HorizonRobotics/HoloMotion}, license = {Apache-2.0} } ``` ## License This project is released under the **[Apache 2.0](./LICENSE)** license. ## Acknowledgements This project is built upon and inspired by several outstanding open source projects: - [GMR](https://github.com/YanjieZe/GMR) - [BeyondMimic](https://github.com/HybridRobotics/whole_body_tracking/tree/dcecabd8c24c68f59d143fdf8e3a670f420c972d) - [ASAP](https://github.com/LeCAR-Lab/ASAP) - [HumanoidVerse](https://github.com/LeCAR-Lab/HumanoidVerse) - [PHC](https://github.com/ZhengyiLuo/PHC?tab=readme-ov-file) - [ProtoMotions](https://github.com/NVlabs/ProtoMotions/tree/main/protomotions) - [Mink](https://github.com/kevinzakka/mink) - [PBHC](https://github.com/TeleHuman/PBHC) ================================================ FILE: assets/robots/unitree/G1/29dof/g1_29dof.xml ================================================ ================================================ FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.urdf ================================================ ================================================ FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.xml ================================================ ================================================ FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0_s100.urdf ================================================ ================================================ FILE: assets/robots/unitree/G1/29dof/scene_29dof.xml ================================================ ================================================ FILE:
assets/test_data/motion_retargeting/ACCAD/Male1Walking_c3d/Walk_B10_-_Walk_turn_left_45_stageii.npz ================================================ version https://git-lfs.github.com/spec/v1 oid sha256:738f96eb1767e281d78631ca697079adefb5f171d581acd622a27740f2503b4e size 5876184 ================================================ FILE: deploy.env ================================================
# Environment variables pointing build/runtime tooling at the holomotion_deploy
# conda environment. The file only exports variables, so it is presumably meant
# to be sourced (`source deploy.env`) rather than executed.

# Assign first, then export, so a failing `conda info` is not masked by `export`.
CONDA_BASE="$(conda info --base)"
export CONDA_BASE
export Deploy_CONDA_PREFIX="$CONDA_BASE/envs/holomotion_deploy"
export CUDA_HOME="$Deploy_CONDA_PREFIX"

# ${VAR:+...} only inserts the separator when VAR is non-empty. A leading or
# trailing ':' would add an empty entry, which the dynamic loader and the
# linker interpret as "the current working directory".
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$Deploy_CONDA_PREFIX/lib/"
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$Deploy_CONDA_PREFIX/lib/stubs"
export LIBRARY_PATH="$Deploy_CONDA_PREFIX/lib${LIBRARY_PATH:+:$LIBRARY_PATH}"
export LIBRARY_PATH="$Deploy_CONDA_PREFIX/lib/stubs:$LIBRARY_PATH"

# Make Hydra print full stack traces on configuration errors.
export HYDRA_FULL_ERROR=1

echo "--------------------------------"
echo "Deploy_CONDA_PREFIX: $Deploy_CONDA_PREFIX"
echo "CUDA_HOME: $CUDA_HOME"
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
echo "LIBRARY_PATH: $LIBRARY_PATH"
echo "HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR"
echo "--------------------------------"
================================================ FILE: deployment/deploy_environment.sh ================================================
#!/bin/bash
##############################################################################
# HoloMotion Environment Deployment Script
#
# This script sets up the complete environment for HoloMotion humanoid robot
# system deployment. It handles:
#   1. Conda environment creation with all dependencies (CUDA, PyTorch, etc.)
#   2. Special dependencies (unitree_sdk2_python)
#   3. ROS2 workspace compilation
#
# Prerequisites:
#   - Anaconda/Miniconda installed
#   - ROS2 Humble installed at /opt/ros/humble/
#   - Unitree ROS2 SDK at ~/unitree_ros2/
#
# Usage:
#   chmod +x deploy_environment.sh
#   ./deploy_environment.sh [environment_name]
#
# Arguments:
#   environment_name: Optional.
#                     Name for the conda environment (default: holomotion_deploy)
#
# Examples:
#   ./deploy_environment.sh                 # Uses default name 'holomotion_deploy'
#   ./deploy_environment.sh my_robot_env    # Uses custom name 'my_robot_env'
#
# Author: HoloMotion Team
##############################################################################

set -e # Exit on any error

# Parse command line arguments
ENV_NAME="${1:-holomotion_deploy}"

echo "🚀 Starting HoloMotion Environment Deployment..."
echo "📝 Environment name: $ENV_NAME"

# Get script directory. This script lives in <repo>/deployment, so the
# repository root is exactly one level up. Do not assume the checkout directory
# is named "holomotion": a default `git clone` creates "HoloMotion".
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
ENV_YAML="$PROJECT_ROOT/environments/environment_deploy.yaml"
ROS2_WS_DIR="$PROJECT_ROOT/deployment/unitree_g1_ros2_29dof"

echo "📁 Project root: $PROJECT_ROOT"
echo "📁 Script directory: $SCRIPT_DIR"

# Validate the environment file before touching any existing conda env, so a
# bad checkout cannot leave the user without their previous environment.
if [[ ! -f "$ENV_YAML" ]]; then
    echo "❌ Conda environment file not found: $ENV_YAML"
    exit 1
fi

# Step 1: Create conda environment with all dependencies
echo ""
echo "📦 Step 1: Creating conda environment with all dependencies..."

# Match the first column of `conda env list` exactly (no regex interpretation
# of ENV_NAME). awk consumes all input, so conda never sees SIGPIPE.
if conda env list | awk -v name="$ENV_NAME" '$1 == name { found = 1 } END { exit !found }'; then
    echo "⚠️ Environment '$ENV_NAME' already exists. Removing it..."
    conda env remove -n "$ENV_NAME" -y
fi

echo "🔧 Creating new environment from environment_deploy.yaml..."
echo " This will install: PyTorch (CUDA), NumPy, SciPy, ONNX Runtime, and all other dependencies..."
cd "$PROJECT_ROOT"
conda env create -f "$ENV_YAML" -n "$ENV_NAME"
echo "✅ Conda environment with all dependencies created successfully!"

# Step 2: Install unitree_sdk2_python
echo ""
echo "📦 Step 2: Installing unitree_sdk2_python..."

# Run a command inside the target conda env without activating it.
run_in_env() {
    conda run -n "$ENV_NAME" "$@"
}

echo "🔧 Installing unitree_sdk2_python..."
if [[ ! -d "$HOME/unitree_sdk2_python" ]]; then
    echo "📥 Cloning unitree_sdk2_python repository..."
    git clone https://github.com/unitreerobotics/unitree_sdk2_python.git "$HOME/unitree_sdk2_python"
fi

echo "🔧 Installing unitree_sdk2_python in development mode..."
cd "$HOME/unitree_sdk2_python"
run_in_env pip install -e .
echo "✅ unitree_sdk2_python installed successfully!"

# Step 3: Setup ROS2 workspace
echo ""
echo "📦 Step 3: Setting up ROS2 workspace..."

# ROS2/colcon must build against the system Python, not a conda one.
echo "🔧 Ensuring conda environment is completely deactivated..."

# Initialize conda shell functions for this (non-interactive) script
eval "$(conda shell.bash hook)"

# Deactivate any stacked conda environments. Stop if `conda deactivate` makes
# no progress, instead of looping forever.
while [[ -n "$CONDA_DEFAULT_ENV" && "$CONDA_DEFAULT_ENV" != "base" ]]; do
    echo " Deactivating conda environment: $CONDA_DEFAULT_ENV"
    previous_env="$CONDA_DEFAULT_ENV"
    conda deactivate
    if [[ "$CONDA_DEFAULT_ENV" == "$previous_env" ]]; then
        echo "⚠️ Could not deactivate conda environment: $previous_env"
        break
    fi
done

# If we're still in base environment, deactivate it too
if [[ "$CONDA_DEFAULT_ENV" == "base" ]]; then
    echo " Deactivating base conda environment"
    conda deactivate
fi

echo " ✅ Conda environment fully deactivated"

# Check ROS2 installation
if [[ ! -f "/opt/ros/humble/setup.bash" ]]; then
    echo "❌ ROS2 Humble not found at /opt/ros/humble/"
    echo " Please install ROS2 Humble first: https://docs.ros.org/en/humble/Installation.html"
    exit 1
fi

# Check Unitree ROS2 SDK
if [[ ! -f "$HOME/unitree_ros2/setup.sh" ]]; then
    echo "❌ Unitree ROS2 SDK not found at ~/unitree_ros2/"
    echo " Please install Unitree ROS2 SDK first"
    exit 1
fi

echo "🔧 Compiling ROS2 workspace..."
cd "$ROS2_WS_DIR"

# Create necessary directories
echo "📁 Creating required directories..."
mkdir -p src/models
mkdir -p src/motion_data

# Clean previous build (relative to the workspace we just cd'd into)
rm -rf -- build install log

# Source ROS2 and Unitree setup
source /opt/ros/humble/setup.bash
source "$HOME/unitree_ros2/setup.sh"

# Build workspace
echo "🏗️ Building workspace with colcon..."
colcon build
echo "✅ ROS2 workspace compiled successfully!"
echo ""
echo "🎉 Deployment completed successfully!"
echo ""
echo "📋 Summary of installed packages:"
echo " ✅ PyTorch 2.3.1 with CUDA 12.1 support"
echo " ✅ ONNX Runtime for neural network inference"
echo " ✅ SMPLX for humanoid motion processing"
echo " ✅ Scientific computing packages (NumPy, SciPy, etc.)"
echo " ✅ Unitree SDK2 Python bindings"
echo " ✅ ROS2 workspace compiled"
echo ""
echo "📋 To run the system:"
echo "1. Activate the conda environment:"
echo " conda activate $ENV_NAME"
echo ""
echo "2. Launch the system:"
echo " cd $ROS2_WS_DIR"
echo " bash launch_holomotion_29dof.sh"
echo ""
echo "✅ Environment '$ENV_NAME' setup complete!"
echo "🚀 Ready for robot deployment!"
================================================ FILE: deployment/holomotion_teleop/holomotion_teleop_node.py ================================================ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Single-process teleoperation pipeline. This node reads raw Pico body tracking data from XRoboToolkit, converts it to SMPL, applies GMR retargeting, and publishes a 65D observation vector to the robot over ZMQ.
Data flow: xrobotoolkit_sdk (body_poses 24x7) -> body_poses_to_smpl_pose_trans -> SMPL_Parser / humanoid_fk -> GMR -> latest_obs(65) -> ZMQ PUB Message format: [topic_bytes][1280-byte JSON header][binary payload] Default payload fields: - latest_obs: (65,) float32 - frame_index: (1,) int64 - timestamp_realtime: (1,) float64 - timestamp_monotonic: (1,) float64 - timestamp_ns: (1,) int64 - pico_dt: (1,) float32 - pico_fps: (1,) float32 """ from __future__ import annotations import argparse from collections import defaultdict from dataclasses import dataclass import json import logging import os import subprocess import sys import threading import time import traceback from types import SimpleNamespace from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F import zmq from scipy.spatial.transform import Rotation as R FILE_DIR = os.path.dirname(os.path.abspath(__file__)) HOLOMOTION_ROOT_DIR = os.path.abspath(os.path.join(FILE_DIR, "..", "..")) SMPL_ASSET_DIR = os.path.join(HOLOMOTION_ROOT_DIR, "assets", "smpl") for extra_path in ( FILE_DIR, os.path.join(FILE_DIR, "GMR"), os.path.join(FILE_DIR, "SMPLSim"), ): if extra_path not in sys.path: sys.path.insert(0, extra_path) try: import xrobotoolkit_sdk as xrt except ImportError: xrt = None from third_party.GMR.general_motion_retargeting.motion_retarget import GeneralMotionRetargeting as GMR from smpl_sim.smpllib.smpl_parser import SMPL_Parser MIRROR_POSE = False MIRROR_AXIS = "x" HEADER_SIZE = 1280 OUT_TOPIC = b"obs65" GMR_LR_SWAP_PAIRS = [ ("left_hip", "right_hip"), ("left_knee", "right_knee"), ("left_foot", "right_foot"), ("left_shoulder", "right_shoulder"), ("left_elbow", "right_elbow"), ("left_wrist", "right_wrist"), ] SMPL_PARENTS_24 = np.array( [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21], dtype=np.int32, ) def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: ret = torch.zeros_like(x) positive_mask = x > 0 
ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotation matrices to quaternions in wxyz format. """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) quat_by_rijk = torch.stack( [ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), ], dim=-2, ) floor = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor)) return quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :, ].reshape(batch_dim + (4,)) def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: """ Convert axis-angle vectors to rotation matrices. 
Input shape: (..., 3) Output shape: (..., 3, 3) """ orig_shape = axis_angle.shape[:-1] aa = axis_angle.reshape(-1, 3) theta = torch.linalg.norm(aa, dim=-1, keepdim=True) axis = aa / torch.clamp(theta, min=1e-8) x = axis[:, 0] y = axis[:, 1] z = axis[:, 2] zeros = torch.zeros_like(x) K = torch.stack( [ zeros, -z, y, z, zeros, -x, -y, x, zeros, ], dim=-1, ).reshape(-1, 3, 3) eye = torch.eye(3, dtype=aa.dtype, device=aa.device).unsqueeze(0).expand(aa.shape[0], -1, -1) sin_theta = torch.sin(theta).unsqueeze(-1) cos_theta = torch.cos(theta).unsqueeze(-1) axis_outer = axis.unsqueeze(-1) @ axis.unsqueeze(-2) small = (theta.squeeze(-1) < 1e-8).unsqueeze(-1).unsqueeze(-1) rot = cos_theta * eye + (1.0 - cos_theta) * axis_outer + sin_theta * K rot = torch.where(small, eye, rot) return rot.reshape(orig_shape + (3, 3)) class Humanoid_Batch_V2: """ Minimal per-frame SMPL kinematics helper used by this script only. Keeping it local avoids importing the much larger training/visualization module. """ def __init__(self, device: torch.device = torch.device("cpu")): self.device = device self.smpl_24_parents = [ -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21, ] @staticmethod def _relative_link_position(joints_world: torch.Tensor, root_pos: torch.Tensor) -> torch.Tensor: return joints_world - root_pos.unsqueeze(0) def _relative_link_pose(self, full_pose_aa: torch.Tensor) -> torch.Tensor: joint_count = full_pose_aa.shape[0] assert joint_count == len(self.smpl_24_parents), ( f"Joint count mismatch: {joint_count} vs {len(self.smpl_24_parents)}" ) rotation_local = axis_angle_to_matrix(full_pose_aa) rotation_global = torch.empty_like(rotation_local) for joint_idx in range(joint_count): parent = self.smpl_24_parents[joint_idx] if parent == -1: rotation_global[joint_idx] = rotation_local[joint_idx] else: rotation_global[joint_idx] = rotation_global[parent] @ rotation_local[joint_idx] return rotation_global def step_per_frame( self, full_pose_aa: torch.Tensor, 
root_pos: torch.Tensor, joints: torch.Tensor, ) -> SimpleNamespace: global_joints_position = joints global_joints2root_pos = self._relative_link_position(joints[1:, :], root_pos) global_joints_rotation_mat = self._relative_link_pose(full_pose_aa) return SimpleNamespace( global_joints2root_pos=global_joints2root_pos, global_joints_rotation_mat=global_joints_rotation_mat, global_joints_position=global_joints_position, ) humanoid_fk = Humanoid_Batch_V2() @dataclass class PicoToSmplConfig: quat_scalar_first: bool = False apply_global_y_180: bool = True apply_root_rx90: bool = True root_align_degrees: float = 90.0 root_align_axis: str = "x" def body_poses_to_smpl_pose_trans( body_poses: np.ndarray, parents: np.ndarray = SMPL_PARENTS_24, cfg: Optional[PicoToSmplConfig] = None, ) -> Tuple[np.ndarray, np.ndarray]: if cfg is None: cfg = PicoToSmplConfig() body_poses = np.asarray(body_poses, dtype=np.float32) if body_poses.shape != (24, 7): raise ValueError(f"body_poses shape must be (24,7), got {body_poses.shape}") positions = body_poses[:, 0:3].astype(np.float32) qx, qy, qz, qw = body_poses[:, 3], body_poses[:, 4], body_poses[:, 5], body_poses[:, 6] global_quats_sfirst = np.stack([qw, qx, qy, qz], axis=1).astype(np.float32) global_rots = R.from_quat(global_quats_sfirst, scalar_first=True) if cfg.apply_global_y_180: global_rots = global_rots * R.from_euler("y", 180.0, degrees=True) local_rots = [] for i in range(24): parent = int(parents[i]) if parent == -1: local_rots.append(global_rots[i]) else: local_rots.append(global_rots[parent].inv() * global_rots[i]) pose_aa_24x3 = np.stack([rot.as_rotvec() for rot in local_rots], axis=0).astype(np.float32) trans = positions[0].astype(np.float32) if cfg.apply_root_rx90: rot_align = R.from_euler(cfg.root_align_axis, cfg.root_align_degrees, degrees=True).as_matrix().astype( np.float32 ) root_matrix = R.from_rotvec(pose_aa_24x3[0]).as_matrix().astype(np.float32) pose_aa_24x3[0] = R.from_matrix(rot_align @ 
root_matrix).as_rotvec().astype(np.float32) trans = (rot_align @ trans.reshape(3, 1)).reshape(3).astype(np.float32) return pose_aa_24x3, trans def _mirror_matrix(mirror_axis: str) -> np.ndarray: if mirror_axis == "x": return np.diag([-1.0, 1.0, 1.0]).astype(np.float32) if mirror_axis == "y": return np.diag([1.0, -1.0, 1.0]).astype(np.float32) if mirror_axis == "z": return np.diag([1.0, 1.0, -1.0]).astype(np.float32) raise ValueError(f"mirror_axis must be one of x/y/z, got {mirror_axis}") def safe_normalize_quat_wxyz(q: np.ndarray, eps: float = 1e-8) -> np.ndarray: q = np.asarray(q, dtype=np.float32).reshape(4,) n = float(np.linalg.norm(q)) if n < eps: return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) return (q / n).astype(np.float32) def mirror_pos_and_quat_wxyz(pos: np.ndarray, quat_wxyz: np.ndarray, mirror_axis: str) -> Tuple[np.ndarray, np.ndarray]: pos = np.asarray(pos, dtype=np.float32).reshape(3,) q = safe_normalize_quat_wxyz(quat_wxyz) M = _mirror_matrix(mirror_axis) pos_m = (M @ pos).astype(np.float32) q_xyzw = np.array([q[1], q[2], q[3], q[0]], dtype=np.float32) rot_m = R.from_quat(q_xyzw).as_matrix().astype(np.float32) rot_m = (M @ rot_m @ M).astype(np.float32) q_m_xyzw = R.from_matrix(rot_m).as_quat().astype(np.float32) quat_m_wxyz = np.array([q_m_xyzw[3], q_m_xyzw[0], q_m_xyzw[1], q_m_xyzw[2]], dtype=np.float32) return pos_m, safe_normalize_quat_wxyz(quat_m_wxyz) def mirror_and_swap_gmr_input(gmr_input: Dict[str, Any], mirror_axis: str = "x") -> Dict[str, Any]: mirrored: Dict[str, Any] = {} for key, (pos, quat) in gmr_input.items(): mirrored[key] = mirror_pos_and_quat_wxyz(pos, quat, mirror_axis) out = dict(mirrored) for a, b in GMR_LR_SWAP_PAIRS: if a in out and b in out: out[a], out[b] = out[b], out[a] return out def pack_numpy_message(payload: dict, topic: bytes = OUT_TOPIC, version: int = 1) -> bytes: fields = [] binary_data = [] for key, value in payload.items(): if not isinstance(value, np.ndarray): continue if value.dtype == np.float32: 
dtype_str = "f32" elif value.dtype == np.float64: dtype_str = "f64" elif value.dtype == np.int32: dtype_str = "i32" elif value.dtype == np.int64: dtype_str = "i64" elif value.dtype == np.uint8: dtype_str = "u8" elif value.dtype == bool: dtype_str = "bool" else: dtype_str = "f32" value = value.astype(np.float32) if not value.flags["C_CONTIGUOUS"]: value = np.ascontiguousarray(value) if value.dtype.byteorder == ">": value = value.astype(value.dtype.newbyteorder("<")) fields.append({"name": key, "dtype": dtype_str, "shape": list(value.shape)}) binary_data.append(value.tobytes()) header = {"v": version, "endian": "le", "count": 1, "fields": fields} header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") if len(header_bytes) > HEADER_SIZE: raise ValueError(f"Header too large: {len(header_bytes)} > {HEADER_SIZE}") header_bytes = header_bytes.ljust(HEADER_SIZE, b"\x00") return topic + header_bytes + b"".join(binary_data) class PicoReader: def __init__(self): self._stop = threading.Event() self._thread = threading.Thread(target=self._run, daemon=True) self._last_stamp_ns = None self._fps_ema = 0.0 self._latest = None self._lock = threading.Lock() def start(self): self._thread.start() def stop(self): self._stop.set() self._thread.join(timeout=1.0) def get_latest(self): with self._lock: return self._latest def _run(self): last_report = time.time() while not self._stop.is_set(): if not xrt.is_body_data_available(): time.sleep(0.001) continue stamp_ns = xrt.get_time_stamp_ns() prev_stamp_ns = self._last_stamp_ns if prev_stamp_ns is not None and stamp_ns == prev_stamp_ns: time.sleep(0.000001) continue device_dt = ((stamp_ns - prev_stamp_ns) * 1e-9) if prev_stamp_ns is not None else 0.0 if device_dt > 0.0: inst_fps = 1.0 / device_dt self._fps_ema = inst_fps if self._fps_ema == 0.0 else (0.9 * self._fps_ema + 0.1 * inst_fps) self._last_stamp_ns = stamp_ns t_realtime = time.time() t_monotonic = time.monotonic() try: body_poses = xrt.get_body_joints_pose() 
body_poses_np = np.asarray(body_poses, dtype=np.float32) if body_poses_np.shape != (24, 7): print(f"[PicoReader] WARNING: unexpected body_poses shape: {body_poses_np.shape}") sample = { "body_poses_np": body_poses_np, "timestamp_realtime": t_realtime, "timestamp_monotonic": t_monotonic, "timestamp_ns": int(stamp_ns), "dt": float(device_dt), "fps": float(self._fps_ema), } with self._lock: self._latest = sample now = time.time() if now - last_report >= 5.0: print( f"[PicoReader] shape={body_poses_np.shape}, " f"dt_ts={device_dt * 1000.0:.2f} ms, fps={self._fps_ema:.2f}" ) last_report = now except Exception as exc: print(f"[PicoReader] read error: {exc}") class ZmqObsSender: def __init__(self, uri: str, logger, topic: bytes = OUT_TOPIC, mode: str = "bind", conflate: bool = True): self.logger = logger self.topic = topic self._context = zmq.Context() self._socket = self._context.socket(zmq.PUB) self._socket.setsockopt(zmq.SNDHWM, 1) if conflate and hasattr(zmq, "CONFLATE"): self._socket.setsockopt(zmq.CONFLATE, 1) if mode == "bind": self._socket.bind(uri) elif mode == "connect": self._socket.connect(uri) else: raise ValueError("mode must be 'bind' or 'connect'") self._last_send_time = None self._send_freq_log = [] self._frame_count = 0 self.logger.info(f"[ZMQOut] sender ready: mode={mode}, uri={uri}, topic={topic.decode('utf-8')}") def send(self, latest_obs: np.ndarray, frame_index: int, sample_meta: dict): payload = { "latest_obs": np.asarray(latest_obs, dtype=np.float32), "frame_index": np.array([frame_index], dtype=np.int64), "timestamp_realtime": np.array([sample_meta["timestamp_realtime"]], dtype=np.float64), "timestamp_monotonic": np.array([sample_meta["timestamp_monotonic"]], dtype=np.float64), "timestamp_ns": np.array([sample_meta["timestamp_ns"]], dtype=np.int64), "pico_dt": np.array([sample_meta["dt"]], dtype=np.float32), "pico_fps": np.array([sample_meta["fps"]], dtype=np.float32), } packet = pack_numpy_message(payload, topic=self.topic) 
self._socket.send(packet) now = time.time() if self._last_send_time is not None: dt = now - self._last_send_time if dt > 0: self._send_freq_log.append(1.0 / dt) self._frame_count += 1 if self._frame_count >= 50: avg_freq = sum(self._send_freq_log) / len(self._send_freq_log) self.logger.info(f"Average ZMQ send rate: {avg_freq:.2f} Hz") self._send_freq_log.clear() self._frame_count = 0 self._last_send_time = now def stop(self): self._socket.close(0) self._context.term() self.logger.info("🛑 ZMQ obs sender stopped") class VRNodeXRTPicoGMRZmqOut: def __init__( self, robot_zmq_uri: str, robot_zmq_mode: str = "bind", loop_hz: float = 55.0, timing_log_every: int = 100, save_obs_path: str = "", ): self.device = "cpu" logging.getLogger("websockets").setLevel(logging.WARNING) self.info(f"✅ VRNodeXRTPicoGMRZmqOut running on device={self.device}") self.info("starting xrt pico -> gmr -> robot zmq node") self.gmr = GMR(src_human="smplx", tgt_robot="unitree_g1") self.smpl_parser = SMPL_Parser(model_path=SMPL_ASSET_DIR, gender="neutral") if hasattr(self.smpl_parser, "to"): self.smpl_parser = self.smpl_parser.to(self.device) self.betas = torch.zeros(1, 10, device=self.device) self.gmr_input_data: Dict[str, Any] = {} self.prev_dof_pos = None self.lasttime = None self.timing_log_every = max(1, timing_log_every) self.save_obs_path = save_obs_path self.mirror_pose = MIRROR_POSE self.mirror_axis = MIRROR_AXIS self.tick_count = 0 self.frame_index = 0 self.timing_sums_ms = defaultdict(float) self.saved_obs = [] self.latest_sample = None self.reader = PicoReader() self.reader.start() self.sender = ZmqObsSender(uri=robot_zmq_uri, logger=self, mode=robot_zmq_mode) self.start_loop(loop_hz) def info(self, msg): print(f"[INFO] {msg}") def error(self, msg): print(f"[ERROR] {msg}") def warning(self, msg): print(f"[WARN] {msg}") def debug(self, msg): print(f"[DEBUG] {msg}") def _accumulate_timing(self, name: str, start_time: float) -> float: elapsed_ms = (time.perf_counter() - start_time) * 1000.0 
self.timing_sums_ms[name] += elapsed_ms return elapsed_ms def _maybe_log_timing(self): if self.tick_count <= 0 or self.tick_count % self.timing_log_every != 0: return avg_parts = [] for key in ("body_poses_to_smpl", "smpl_to_joints", "gmr_retarget", "postprocess_send", "tick_total"): if key in self.timing_sums_ms: avg_ms = self.timing_sums_ms[key] / self.timing_log_every avg_parts.append(f"{key}={avg_ms:.2f}ms") if avg_parts: self.info("[Timing] " + ", ".join(avg_parts)) self.timing_sums_ms.clear() def process_smpl_pose_trans_to_gmr_input(self, smpl_pose_aa, smpl_trans) -> Dict[str, Any]: stage_start = time.perf_counter() if not isinstance(smpl_pose_aa, torch.Tensor): smpl_pose_aa = torch.tensor(smpl_pose_aa, dtype=torch.float32) if not isinstance(smpl_trans, torch.Tensor): smpl_trans = torch.tensor(smpl_trans, dtype=torch.float32) pose = smpl_pose_aa.to(self.device, dtype=torch.float32) trans = smpl_trans.to(self.device, dtype=torch.float32) if pose.ndim == 2: pose = pose.unsqueeze(0) if trans.ndim == 1: trans = trans.unsqueeze(0) verts, joints = self.smpl_parser.get_joints_verts(pose, self.betas, trans) # joints[..., 2] -= verts[0, :, 2].min().item() pose = pose.squeeze(0) trans = trans.squeeze(0) joints = joints.squeeze(0) motion_state = humanoid_fk.step_per_frame(pose, trans, joints) global_joints_position = motion_state.global_joints_position global_joints_rotation_mat = motion_state.global_joints_rotation_mat global_joints_qua_wxyz = matrix_to_quaternion(global_joints_rotation_mat) smpl_to_gmr = { "pelvis": 0, "spine3": 9, "left_hip": 1, "right_hip": 2, "left_knee": 4, "right_knee": 5, "left_foot": 10, "right_foot": 11, "left_shoulder": 16, "right_shoulder": 17, "left_elbow": 18, "right_elbow": 19, "left_wrist": 20, "right_wrist": 21, } gmr_input_data: Dict[str, Any] = {} for name, idx in smpl_to_gmr.items(): pos = global_joints_position[idx].detach().cpu().numpy() quat = global_joints_qua_wxyz[idx].detach().cpu().numpy() gmr_input_data[name] = (pos, quat) if 
self.mirror_pose: gmr_input_data = mirror_and_swap_gmr_input(gmr_input_data, mirror_axis=self.mirror_axis) self._accumulate_timing("smpl_to_joints", stage_start) return gmr_input_data def process_xrt_frame_to_gmr_input(self, sample: dict): body_poses = np.asarray(sample["body_poses_np"], dtype=np.float32) if body_poses.shape != (24, 7): raise ValueError(f"[XRT] body_poses_np must have shape (24,7), got {body_poses.shape}") stage_start = time.perf_counter() pose_aa, trans = body_poses_to_smpl_pose_trans( body_poses, cfg=PicoToSmplConfig( apply_global_y_180=True, apply_root_rx90=True, root_align_axis="x", root_align_degrees=90.0, ), ) self._accumulate_timing("body_poses_to_smpl", stage_start) self.gmr_input_data = self.process_smpl_pose_trans_to_gmr_input(pose_aa, trans) def process_gmr_output(self): stage_start = time.perf_counter() qpos = self.gmr.retarget(self.gmr_input_data) self._accumulate_timing("gmr_retarget", stage_start) stage_start = time.perf_counter() root_pos = qpos[:3] root_rot = qpos[3:7] dof_pos = qpos[7:] now = time.time() delta_time = 1 / 50 if self.lasttime is None else (now - self.lasttime) self.lasttime = now if self.prev_dof_pos is None: dof_vel = np.zeros_like(dof_pos, dtype=np.float32) else: dof_vel = (dof_pos - self.prev_dof_pos) / max(delta_time, 1e-6) self.prev_dof_pos = dof_pos.copy() latest_obs = np.concatenate([dof_pos, dof_vel, root_pos, root_rot], axis=0).astype(np.float32) self.publish_data(latest_obs) self.sender.send(latest_obs, self.frame_index, self.latest_sample) self.saved_obs.append(latest_obs.copy()) self.frame_index += 1 self._accumulate_timing("postprocess_send", stage_start) return latest_obs def publish_data(self, motion_state: np.ndarray): if motion_state.size != 65: self.error(f"Output dim {motion_state.size} != expected 65") return if np.isnan(motion_state).any(): self.error("NaN detected") return def save_observations(self): if not self.save_obs_path: return if len(self.saved_obs) == 0: self.warning(f"[SaveObs] no 
observations to save: {self.save_obs_path}") return obs_array = np.stack(self.saved_obs, axis=0).astype(np.float32) save_dir = os.path.dirname(self.save_obs_path) if save_dir: os.makedirs(save_dir, exist_ok=True) if self.save_obs_path.endswith(".npy"): np.save(self.save_obs_path, obs_array) else: np.savez_compressed( self.save_obs_path, latest_obs=obs_array, columns=np.array(["dof_pos(29)", "dof_vel(29)", "root_pos(3)", "root_rot_wxyz(4)"], dtype=object), ) self.info(f"[SaveObs] saved {obs_array.shape[0]} frames to {self.save_obs_path}") def start_loop(self, hz=50): self.info(f"Starting main loop at {hz} Hz") interval = 1.0 / hz def loop(): next_time = time.time() while True: self._tick() next_time += interval sleep_time = next_time - time.time() if sleep_time > 0: time.sleep(sleep_time) else: next_time = time.time() threading.Thread(target=loop, daemon=True).start() def _tick(self): tick_start = time.perf_counter() sample = self.reader.get_latest() if sample is not None: try: self.latest_sample = sample self.process_xrt_frame_to_gmr_input(sample) self.process_gmr_output() except Exception as exc: self.error(f"[tick_error] {exc}") self.error(traceback.format_exc()) return elif self.prev_dof_pos is not None: try: self.process_gmr_output() except Exception as exc: self.error(f"[tick_error] {exc}") self.error(traceback.format_exc()) return self.tick_count += 1 self._accumulate_timing("tick_total", tick_start) self._maybe_log_timing() def stop(self): self.reader.stop() self.sender.stop() self.save_observations() try: if xrt is not None and hasattr(xrt, "close"): xrt.close() except Exception: pass def init_xrt(start_service: bool = True): if xrt is None: raise ImportError("XRoboToolkit SDK not available. 
Install xrobotoolkit_sdk first.") if start_service: subprocess.Popen(["bash", "/opt/apps/roboticsservice/runService.sh"]) xrt.init() print("Waiting for body tracking data...") while not xrt.is_body_data_available(): print("waiting for body data...") time.sleep(1) def main(): parser = argparse.ArgumentParser(description="XRT Pico -> GMR -> robot ZMQ(65D)") parser.add_argument("--robot-zmq-uri", default="tcp://*:6001", help="Robot-side ZMQ uri for 65D obs output") parser.add_argument("--robot-zmq-mode", default="bind", choices=["bind", "connect"]) parser.add_argument("--hz", type=float, default=55.0, help="Main loop frequency / publish cap") parser.add_argument("--timing-log-every", type=int, default=200, help="Print average stage timing every N ticks") parser.add_argument("--save-obs-path", type=str, default="", help="Optional path to save emitted 65D observations") parser.add_argument("--skip-start-service", action="store_true", help="Do not auto-run /opt/apps/roboticsservice/runService.sh") args = parser.parse_args() init_xrt(start_service=not args.skip_start_service) node = VRNodeXRTPicoGMRZmqOut( robot_zmq_uri=args.robot_zmq_uri, robot_zmq_mode=args.robot_zmq_mode, loop_hz=args.hz, timing_log_every=args.timing_log_every, save_obs_path=args.save_obs_path, ) try: while True: time.sleep(1) except KeyboardInterrupt: node.stop() print("🛑 Program terminated by user.") if __name__ == "__main__": main() ================================================ FILE: deployment/holomotion_teleop/holomotion_teleop_setup.md ================================================ # Holomotion Teleop Single-process pipeline for: `PICO / XRoboToolkit -> SMPL conversion -> GMR retargeting -> robot ZMQ` ## Prerequisites Before setting up the Python environment, install XRoboToolkit PC Service manually. 1. On Ubuntu 22.04, download the XRoboToolkit PC Service `.deb` package, or build it from source. 
```bash
sudo dpkg -i XRoboToolkit_PC_Service_1.0.0_ubuntu_22.04_amd64.deb
```

## Environment Setup

```bash
cd /path/to/holomotion_teleop
bash setup_holomotion_teleop_x86_ubuntu2204.sh
```

This script will:

- create the Conda environment `holomotion_teleop`
- automatically clone and install `GMR` and `SMPLSim`
- install runtime dependencies such as `numpy==1.23.5`, `torch`, and `pyzmq`
- build and install `xrobotoolkit_sdk` from source

Optional environment variables:

```bash
ENV_NAME=holomotion_teleop
PYTHON_VERSION=3.10
INSTALL_APT_DEPS=auto
THIRD_PARTY_DIR=/path/to/third_party
GMR_SOURCE_DIR=/path/to/GMR
SMPLSIM_SOURCE_DIR=/path/to/SMPLSim
XRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind
```

- `INSTALL_APT_DEPS=auto`: run the apt installation only if one of the required build tools (`gcc`, `g++`, `make`, `git`, `cmake`) is missing
- `INSTALL_APT_DEPS=0` (default): skip apt installation entirely; use this if your machine already has the tools or apt is unusable
- `INSTALL_APT_DEPS=1`: force the apt installation step
- `THIRD_PARTY_DIR`: default directory used for auto-cloned third-party repositories
- `GMR_SOURCE_DIR` / `SMPLSIM_SOURCE_DIR` / `XRT_PYBIND_REPO_DIR`: point to existing source checkouts; if omitted, the script auto-clones them into `THIRD_PARTY_DIR`

## Input and Output

### Input

`holomotion_teleop_node.py` reads raw body tracking data directly from `xrobotoolkit_sdk.get_body_joints_pose()`:

- shape: `(24, 7)`
- row format: `[x, y, z, qx, qy, qz, qw]`

### Output

The robot-side ZMQ payload contains `latest_obs` as `float32[65]`:

1. `dof_pos[29]`
2. `dof_vel[29]`
3. `root_pos[3]`
4. `root_rot_wxyz[4]`

Additional metadata is included in the same payload:

- `frame_index`
- `timestamp_realtime`
- `timestamp_monotonic`
- `timestamp_ns`
- `pico_dt`
- `pico_fps`

## Next Steps

Before running teleoperation on the real robot, make sure the operators are already familiar with the offline `.npz` motion-performance workflow and the robot's basic mode-switching behavior. Teleoperation should not be the first time the team tests motion-mode entry on hardware.
### Real Robot Workflow Use the following checklist when running the teleoperation stack on the real robot. #### 1. Hardware and Network Required hardware: - PICO 4 / PICO 4 Pro headset - 2 PICO controllers - 2 PICO motion trackers attached to the ankles - One workstation running `holomotion_teleop_node.py` - One robot computer running the policy / control stack - A low-latency Wi-Fi network shared by the PICO headset and the workstation Make sure the robot, the workstation and the PICO headset are on the same Wi-Fi network. Low network latency is important for stable teleoperation. The PICO-side setup steps below follow the XRoboToolkit / PICO workflow described in the [GR00T VR Teleop Setup (PICO)](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/vr_teleop_setup.html). #### 2. Install and Configure PICO 1. Install the XRoboToolkit PICO app on the headset. - Enable Developer Mode on the headset. - Open the browser on PICO and download the XRoboToolkit PICO APK. - Install the APK from the downloads page and confirm it appears in the app library. 2. Pair the two PICO motion trackers. - Attach one tracker to each ankle. - Open the motion tracker settings on the headset. - Unpair any old trackers first, then pair both trackers again. 3. Calibrate the motion trackers on the headset. - Follow the standing calibration step. - Then look down at the foot trackers so the headset cameras can detect them. 4. Connect the headset to the workstation. - Confirm the headset and workstation are on the same Wi-Fi network. - Open the XRoboToolkit app on the headset. - Enter the workstation IP address into the PC Service field. - Verify the status shows a successful connection. 5. In XRoboToolkit, enable the required streaming options. - Enable `Head` and `Controller` tracking. - Set `Pico Motion Tracker` to `Full body`. - Enable the `Send` option for data/control streaming. #### 3. 
Configure the Robot-Side Policy

Before starting the robot-side policy, update the robot config file:

`HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml`

Recommended settings (under the `vr:` section):

- `enable_teleop_reference: true`
- `require_vr_data_for_motion: true`
- `latest_obs_zmq_uri: "tcp://<workstation_ip>:6001"`

Replace `<workstation_ip>` with the actual IP address of the workstation that runs `holomotion_teleop_node.py`. This ensures the robot waits for live VR data before switching into motion mode and connects to the correct ZMQ publisher endpoint.

#### 4. Launch Order

Start the system in the following order:

1. Start the robot control / policy stack on the robot computer.
2. Wait until the control policy is fully initialized, then press `Start` to move the robot into the default pose.
3. Start XRoboToolkit on the PICO headset and confirm that body-tracking data is being streamed.
4. Start the teleoperation node on the workstation:

```bash
conda activate holomotion_teleop
cd /path/to/holomotion_teleop
python holomotion_teleop_node.py
```

If needed, pass explicit ZMQ arguments such as:

```bash
python holomotion_teleop_node.py \
  --robot-zmq-uri tcp://*:6001 \
  --robot-zmq-mode bind \
  --hz 50
```

5. After the robot-side policy is receiving live teleoperation data, perform the runtime mode sequence:
   - press `A` to enter walking / velocity mode
   - press `B` to enter teleoperation motion mode
   - press `Y` whenever you want to leave teleoperation and return to walking mode

## Optional Arguments

- `--robot-zmq-uri`: robot-side ZMQ endpoint for the 65D output (default: `tcp://*:6001`)
- `--robot-zmq-mode`: `bind` or `connect` (default: `bind`)
- `--hz`: main loop frequency / processing cap (default: `55`)
- `--timing-log-every`: print average stage timing every N ticks (default: `200`)
- `--save-obs-path`: save emitted 65D observations on exit as `.npy` or `.npz`
- `--skip-start-service`: do not auto-run `/opt/apps/roboticsservice/runService.sh` at startup

#### 5.
Runtime Check Before enabling motion on the robot: - confirm XRoboToolkit PC Service is running - confirm the PICO headset is connected to the workstation - confirm `holomotion_teleop_node.py` is publishing ZMQ data - confirm the robot-side policy is using the correct workstation IP in `latest_obs_zmq_uri` - confirm the robot-side config keeps `enable_teleop_reference: true` - confirm the robot-side config keeps `require_vr_data_for_motion: true` - confirm the team has already validated the offline `.npz` motion-performance pipeline before attempting live teleoperation Once the ZMQ stream is stable, enable the robot policy and switch into motion mode. ================================================ FILE: deployment/holomotion_teleop/setup_holomotion_teleop_x86_ubuntu2204.sh ================================================ #!/usr/bin/env bash set -euo pipefail # One-click setup script for the holomotion teleoperation environment. # # This script automates the manually verified workflow: # 1. create/activate conda env # 2. clone/install GMR # 3. clone/build/install XRoboToolkit pybind SDK # 4. clone/install SMPLSim # 5. 
#    install runtime Python dependencies
#
# Usage:
#   bash setup_holomotion_teleop_x86_ubuntu2204.sh
#
# Optional env vars:
#   ENV_NAME=holomotion_teleop
#   PYTHON_VERSION=3.10
#   INSTALL_APT_DEPS=0   # 0 (default): never run apt
#                        # 1: always install build tools via apt
#                        # auto: run apt only if a required build tool is missing
#   THIRD_PARTY_DIR=/path/to/third_party
#   GMR_SOURCE_DIR=/path/to/GMR
#   SMPLSIM_SOURCE_DIR=/path/to/SMPLSim
#   XRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind

ENV_NAME="${ENV_NAME:-holomotion_teleop}"
PYTHON_VERSION="${PYTHON_VERSION:-3.10}"
INSTALL_APT_DEPS="${INSTALL_APT_DEPS:-0}"

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="${SCRIPT_DIR}"
THIRD_PARTY_DIR="${THIRD_PARTY_DIR:-$PROJECT_ROOT/third_party}"

GMR_REPO_URL="${GMR_REPO_URL:-https://github.com/YanjieZe/GMR.git}"
SMPLSIM_REPO_URL="${SMPLSIM_REPO_URL:-https://github.com/ZhengyiLuo/SMPLSim.git}"
XRT_PYBIND_REPO_URL="${XRT_PYBIND_REPO_URL:-https://github.com/YanjieZe/XRoboToolkit-PC-Service-Pybind.git}"
XRT_PC_SERVICE_REPO_URL="${XRT_PC_SERVICE_REPO_URL:-https://github.com/XR-Robotics/XRoboToolkit-PC-Service.git}"

GMR_SOURCE_DIR="${GMR_SOURCE_DIR:-$THIRD_PARTY_DIR/GMR}"
SMPLSIM_SOURCE_DIR="${SMPLSIM_SOURCE_DIR:-$THIRD_PARTY_DIR/SMPLSim}"
XRT_PYBIND_REPO_DIR="${XRT_PYBIND_REPO_DIR:-$THIRD_PARTY_DIR/XRoboToolkit-PC-Service-Pybind}"

# Build tools checked by apt_deps_missing (INSTALL_APT_DEPS=auto).
REQUIRED_BUILD_TOOLS=(gcc g++ make git cmake)

# Logging helpers: info -> stdout, warn/error -> stderr; error also aborts.
info() {
  echo "[INFO] $*"
}

warn() {
  echo "[WARN] $*" >&2
}

error() {
  echo "[ERROR] $*" >&2
  exit 1
}

# require_command CMD [HINT]: abort (via error) unless CMD is on PATH.
require_command() {
  local cmd="$1"
  local hint="${2:-}"
  if ! command -v "$cmd" >/dev/null 2>&1; then
    if [[ -n "$hint" ]]; then
      error "$cmd not found. $hint"
    else
      error "$cmd not found."
    fi
  fi
}

# run_conda_relaxed CMD...: run CMD with `set -u` temporarily disabled and
# return CMD's exit status.
run_conda_relaxed() {
  # Some conda activation/deactivation hooks are not compatible with `set -u`
  # and may reference unset variables such as SETVARS_CALL.
  set +u
  "$@"
  local status=$?
  set -u
  return "$status"
}

# Log the effective configuration before doing any work.
show_env_summary() {
  info "project root: $PROJECT_ROOT"
  info "env name: $ENV_NAME"
  info "python version: $PYTHON_VERSION"
  info "install apt deps: $INSTALL_APT_DEPS"
  info "third party dir: $THIRD_PARTY_DIR"
  info "gmr source dir: $GMR_SOURCE_DIR"
  info "smplsim source dir: $SMPLSIM_SOURCE_DIR"
  info "xrt pybind dir: $XRT_PYBIND_REPO_DIR"
}

# Abort on non-Linux hosts; warn (but continue) on anything but Ubuntu 22.04.
check_platform() {
  if [[ "$(uname -s)" != "Linux" ]]; then
    error "This setup script currently supports Linux only."
  fi
  if [[ -f /etc/os-release ]]; then
    # shellcheck disable=SC1091
    source /etc/os-release
    info "detected OS: ${PRETTY_NAME:-unknown}"
    if [[ "${ID:-}" != "ubuntu" || "${VERSION_ID:-}" != "22.04" ]]; then
      warn "This script is primarily tested on Ubuntu 22.04. Continuing anyway."
    fi
  fi
}

# Succeed (status 0) when at least one of REQUIRED_BUILD_TOOLS is missing
# from PATH, so that `if apt_deps_missing; then ...` reads as intended.
apt_deps_missing() {
  local tool
  for tool in "${REQUIRED_BUILD_TOOLS[@]}"; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      return 0
    fi
  done
  return 1
}

# Install compiler/build tooling via apt according to INSTALL_APT_DEPS:
#   0/false/no -> skip (default)
#   1/true/yes -> always install
#   auto       -> install only when apt_deps_missing reports a missing tool
# Any other value aborts.
install_apt_deps_if_needed() {
  case "$INSTALL_APT_DEPS" in
    0|false|False|FALSE|no|NO)
      info "Skipping apt dependency installation because INSTALL_APT_DEPS=$INSTALL_APT_DEPS"
      info "This matches the manually verified workflow and avoids unrelated apt source failures"
      return
      ;;
    auto|Auto|AUTO)
      if ! apt_deps_missing; then
        info "Skipping apt dependency installation: ${REQUIRED_BUILD_TOOLS[*]} already available"
        return
      fi
      info "INSTALL_APT_DEPS=auto and some build tools are missing; installing them via apt"
      ;;
    1|true|True|TRUE|yes|YES)
      ;;
    *)
      error "Unsupported INSTALL_APT_DEPS value: $INSTALL_APT_DEPS (expected auto, 1 or 0)"
      ;;
  esac

  require_command sudo "Install sudo or run the equivalent apt commands manually."
  require_command apt-get "This script needs apt-get to install build tools."

  info "Installing apt packages needed for build"
  if ! sudo apt-get update; then
    error "apt-get update failed. Common causes: broken apt sources, third-party repository timeouts, or proxy/network issues."
  fi
  if ! sudo apt-get install -y build-essential git cmake; then
    cat >&2 <<'EOF'
[ERROR] apt package installation failed.
Try one of the following:
  1. sudo apt --fix-broken install
  2. disable broken third-party apt repositories temporarily
  3. rerun with INSTALL_APT_DEPS=0 if gcc/g++/make/git/cmake already exist
EOF
    exit 1
  fi
}

# Create the conda env ENV_NAME (python=PYTHON_VERSION) if it does not exist,
# then activate it in the current shell.
setup_conda_env() {
  require_command conda "Please install Miniconda or Anaconda first."
  # shellcheck disable=SC1091
  source "$(conda info --base)/etc/profile.d/conda.sh"

  if ! conda env list | awk '{print $1}' | grep -Fx "$ENV_NAME" >/dev/null 2>&1; then
    info "Creating conda env: $ENV_NAME"
    run_conda_relaxed conda create -n "$ENV_NAME" "python=$PYTHON_VERSION" -y
  else
    info "Conda env already exists: $ENV_NAME"
  fi

  run_conda_relaxed conda activate "$ENV_NAME"
}

# clone_repo_if_missing DIR URL NAME: ensure DIR holds a source tree for NAME.
# DIR is cloned from URL only when it is absent or empty.
clone_repo_if_missing() {
  local repo_dir="$1"
  local repo_url="$2"
  local repo_name="$3"

  # `.git` is a directory in a plain clone but a *file* in submodules and
  # linked worktrees, so test for existence rather than for a directory.
  if [[ -e "$repo_dir/.git" ]]; then
    info "Using existing $repo_name checkout at $repo_dir"
  elif [[ -d "$repo_dir" && -n "$(find "$repo_dir" -mindepth 1 -maxdepth 1 -print -quit)" ]]; then
    # Non-empty, non-git directory (e.g. an unpacked source archive passed via
    # *_SOURCE_DIR): `git clone` refuses to write into it, so reuse it as-is.
    warn "$repo_dir is not a git checkout; using its contents as $repo_name source"
  else
    info "Cloning $repo_name"
    mkdir -p "$(dirname "$repo_dir")"
    git clone "$repo_url" "$repo_dir"
  fi
}

# Clone GMR if needed and install it into the active env in editable mode.
install_gmr() {
  clone_repo_if_missing "$GMR_SOURCE_DIR" "$GMR_REPO_URL" "GMR"
  info "Installing GMR in editable mode"
  python -m pip install -e "$GMR_SOURCE_DIR"
}

# Build PXREARobotSDK from the XRoboToolkit PC Service sources. Stage its
# header, the nlohmann headers and the shared library into the pybind repo.
# Then (re)install the xrobotoolkit_sdk Python package from that repo.
build_xrt_python_sdk() {
  local sdk_dir="tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK"

  clone_repo_if_missing "$XRT_PYBIND_REPO_DIR" "$XRT_PYBIND_REPO_URL" "XRoboToolkit pybind repository"
  pushd "$XRT_PYBIND_REPO_DIR" >/dev/null

  mkdir -p tmp
  clone_repo_if_missing "tmp/XRoboToolkit-PC-Service" "$XRT_PC_SERVICE_REPO_URL" "XRoboToolkit PC Service source"

  info "Building PXREARobotSDK"
  pushd "$sdk_dir" >/dev/null
  bash build.sh
  popd >/dev/null

  mkdir -p lib include/nlohmann
  cp "$sdk_dir/PXREARobotSDK.h" include/
  # Copy the directory *contents* so reruns refresh include/nlohmann instead
  # of nesting another copy at include/nlohmann/nlohmann.
  cp -r "$sdk_dir/nlohmann/." include/nlohmann/
  cp "$sdk_dir/build/libPXREARobotSDK.so" lib/

  info "Installing pybind11 into conda env"
  run_conda_relaxed conda install -y -c conda-forge pybind11

  info "Reinstalling xrobotoolkit_sdk"
  # The uninstall legitimately fails when the package is not installed yet.
  python -m pip uninstall -y xrobotoolkit_sdk || true
  python setup.py install
  popd >/dev/null
}
# Clone SMPLSim if needed and install it into the active env in editable mode.
install_smplsim() {
  clone_repo_if_missing "$SMPLSIM_SOURCE_DIR" "$SMPLSIM_REPO_URL" "SMPLSim"
  info "Installing SMPLSim in editable mode"
  python -m pip install -e "$SMPLSIM_SOURCE_DIR"
}

# Upgrade pip/setuptools/wheel, then install pyzmq and open3d.
install_runtime_python_deps() {
  info "Upgrading pip toolchain"
  python -m pip install --upgrade pip setuptools wheel
  info "Installing runtime Python packages"
  python -m pip install pyzmq
  python -m pip install open3d
}

# Install chumpy and pin numpy to 1.23.5 for chumpy compatibility.
# NOTE(review): main() calls this after every other install step, presumably
# so later pip installs cannot move numpy off the pin; keep it last.
install_compat_python_deps() {
  info "Installing compatibility packages"
  python -m pip install chumpy
  info "Pinning numpy for chumpy compatibility"
  python -m pip install --upgrade "numpy==1.23.5"
}

# Print the manual follow-up steps and an example teleop node invocation.
print_next_steps() {
  echo
  info "Environment setup complete"
  echo
  info "Manual prerequisite:"
  echo " Install XRoboToolkit PC Service manually from the Ubuntu 22.04 .deb package."
  echo " Launch xrobotoolkit-pc-service before teleoperation."
  echo
  info "Activate with:"
  echo " conda activate $ENV_NAME"
  echo
  info "Example command:"
  echo " python \"$PROJECT_ROOT/holomotion_teleop_node.py\" \\"
  echo " --robot-zmq-uri tcp://*:6001 \\"
  echo " --robot-zmq-mode bind \\"
  echo " --hz 50 \\"
  echo " --timing-log-every 250"
}

# Entry point: run each setup stage in sequence. With `set -euo pipefail`
# (enabled at the top of this script), the first failing stage aborts the run.
main() {
  check_platform
  show_env_summary
  install_apt_deps_if_needed
  setup_conda_env
  install_gmr
  build_xrt_python_sdk
  install_runtime_python_deps
  install_smplsim
  install_compat_python_deps
  print_next_steps
}

main "$@"
================================================
FILE: deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof.sh
================================================
#!/bin/bash
##############################################################################
# HoloMotion Deployment Launch Script
#
# This script sets up the complete environment and launches the HoloMotion
# humanoid robot control system for the Unitree G1 robot. It handles:
# 1. ROS2 environment setup and workspace building
# 2. Conda environment configuration for GPU/CUDA support
# 3. Library path configuration for proper linking
# 4.
#    Launch of the complete HoloMotion control pipeline
#
# Prerequisites:
# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/
# - Conda environment 'holomotion_deploy' with required packages
# - Network interface configured for robot communication
# - Proper permissions for robot hardware access
#
# Usage:
# ./launch_holomotion_29dof.sh [--record]
# --record: Enable topic recording (optional, disabled by default)
#
# The script may be invoked from any directory: it switches to its own
# directory before cleaning and rebuilding the colcon workspace.
#
# Author: HoloMotion Team
# License: See project LICENSE file
##############################################################################

# Default values
ENABLE_RECORDING=false

# Usage text shared by --help and the unknown-option path.
print_usage() {
  echo "Usage: $0 [--record]"
  echo " --record: Enable topic recording (optional, disabled by default)"
}

# Parse command line arguments
while [[ $# -gt 0 ]]; do
  case "$1" in
    --record)
      ENABLE_RECORDING=true
      shift
      ;;
    -h|--help)
      print_usage
      exit 0
      ;;
    *)
      echo "Unknown option $1"
      print_usage
      exit 1
      ;;
  esac
done

echo "Starting HoloMotion 29DOF..."
echo "Recording enabled: $ENABLE_RECORDING"

# build/, install/, log/, the src/ tree colcon builds and ../../deploy.env are
# all relative to this script's directory. Work from there so that an
# invocation from another directory cannot wipe unrelated build/install/log.
cd "$(dirname "${BASH_SOURCE[0]}")" || { echo "Cannot cd to script directory" >&2; exit 1; }

rm -rf build/ install/ log/ 2>/dev/null || sudo rm -rf build/ install/ log/

# Activate conda's base, then deactivate every level so no conda env remains
# active for the ROS2 build and launch.
source ~/miniconda3/bin/activate
while [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do
  conda deactivate
done

source /opt/ros/humble/setup.sh
source ~/unitree_ros2/setup.sh

# Do not launch a stale or missing workspace if the build fails.
if ! colcon build; then
  echo "colcon build failed; not launching" >&2
  exit 1
fi
source install/setup.bash
source ../../deploy.env

# Launch with recording parameter
ros2 launch humanoid_control holomotion_29dof_launch.py enable_recording:="$ENABLE_RECORDING"
================================================
FILE: deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof_docker.sh
================================================
#!/bin/bash
##############################################################################
# HoloMotion Deployment Launch Script
#
# This script sets up the complete environment and launches the HoloMotion
# humanoid robot control system for the Unitree G1 robot. It handles:
# 1. ROS2 environment setup and workspace building
# 2.
#    Conda environment configuration for GPU/CUDA support
# 3. Library path configuration for proper linking
# 4. Launch of the complete HoloMotion control pipeline
#
# Prerequisites:
# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/
# - Conda environment 'holomotion_deploy' with required packages
# - Network interface configured for robot communication
# - Proper permissions for robot hardware access
#
# Usage:
# ./launch_holomotion_29dof_docker.sh [--record]
# --record: Enable topic recording (optional, disabled by default)
#
# Author: HoloMotion Team
# License: See project LICENSE file
##############################################################################

# Topic recording stays off unless --record is passed.
ENABLE_RECORDING=false

# Usage text shared by --help and the unknown-option path.
print_usage() {
  echo "Usage: $0 [--record]"
  echo " --record: Enable topic recording (optional, disabled by default)"
}

# Walk the argument list: accepted flags are consumed with `shift`; --help
# prints usage and exits 0; anything else prints usage and exits 1.
while (( $# )); do
  if [[ "$1" == "--record" ]]; then
    ENABLE_RECORDING=true
    shift
  elif [[ "$1" == "-h" || "$1" == "--help" ]]; then
    print_usage
    exit 0
  else
    echo "Unknown option $1"
    print_usage
    exit 1
  fi
done

echo "Starting HoloMotion 29DOF Docker..."
echo "Recording enabled: $ENABLE_RECORDING" source /root/miniconda3/etc/profile.d/conda.sh while [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do conda deactivate done rm -rf build/ install/ log/ source /opt/ros/humble/setup.sh source /root/unitree_ros2/setup.sh colcon build source install/setup.bash # Configure conda environment paths for CUDA and library linking # NOTE: Update this path to match your actual conda environment location export CYCLONEDDS_HOME=/root/cyclonedds/install export CMAKE_PREFIX_PATH=$CYCLONEDDS_HOME:$CMAKE_PREFIX_PATH source ../../deploy.env export LD_LIBRARY_PATH=/host_gpu:/cuda_base:/usr/lib/aarch64-linux-gnu/tegra:/usr/lib/aarch64-linux-gnu:/usr/local/cuda/lib64:/lib/aarch64-linux-gnu/:$LD_LIBRARY_PATH # Launch with recording parameter ros2 launch humanoid_control holomotion_29dof_launch.py enable_recording:=$ENABLE_RECORDING ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.8) project(humanoid_control) # Default to C99 if(NOT CMAKE_C_STANDARD) set(CMAKE_C_STANDARD 99) endif() # Default to C++14 if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) endif() if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") add_compile_options(-Wall -Wextra -Wpedantic) endif() include_directories(include include/common include/nlohmann) link_directories(src) set( DEPENDENCY_LIST unitree_go unitree_hg unitree_api rclcpp std_msgs rosbag2_cpp yaml-cpp ) # find dependencies find_package(ament_cmake REQUIRED) find_package(ament_cmake_python REQUIRED) find_package(unitree_go REQUIRED) find_package(unitree_hg REQUIRED) find_package(unitree_api REQUIRED) find_package(rclcpp REQUIRED) find_package(std_msgs REQUIRED) find_package(rosbag2_cpp REQUIRED) find_package(yaml-cpp REQUIRED) # Main control executable add_executable( humanoid_control src/main_node.cpp src/common/motor_crc_hg.cpp src/common/wireless_controller.cpp ) 
ament_target_dependencies(humanoid_control ${DEPENDENCY_LIST}) # Install Python modules ament_python_install_package(humanoid_policy) # Install Python scripts as executables install(PROGRAMS humanoid_policy/policy_node_29dof.py DESTINATION lib/${PROJECT_NAME} RENAME policy_node_29dof ) # Install your models directory install(DIRECTORY models/ DESTINATION share/${PROJECT_NAME}/models ) install(TARGETS humanoid_control DESTINATION lib/${PROJECT_NAME}) # motion folder install(DIRECTORY config/ DESTINATION share/${PROJECT_NAME}/config ) install(DIRECTORY motion_data/ DESTINATION share/${PROJECT_NAME}/motion_data ) # Install launch files install( DIRECTORY launch/ DESTINATION share/${PROJECT_NAME} ) if(BUILD_TESTING) find_package(ament_lint_auto REQUIRED) ament_lint_auto_find_test_dependencies() endif() ament_package() ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml ================================================ device: "cuda" policy_freq: 50 # Hz control_freq: 500 # Hz lowstate_topic: "/lowstate" action_topic: "/humanoid/action" # walking policy velocity_tracking_model_folder: "velocity_tracking_model" # motion policy motion_tracking_model_folder: "motion_tracking_model" # motion data motion_clip_dir: "motion_data" cpu_affinity_main: "" cpu_affinity_zmq_sub: "" # VR / ZMQ: the robot acts as a SUB socket and receives latest_obs from the sender. 
vr: enable_teleop_reference: false latest_obs_zmq_uri: "tcp://192.168.124.29:6001" latest_obs_zmq_topic: "obs65" latest_obs_zmq_mode: "connect" latest_obs_zmq_conflate: true zmq_jitter_delay_frames: 5 max_data_age: 0.6 require_vr_data_for_motion: false timing_debug_enabled: false timing_debug_log_interval_sec: 5.0 timing_debug_log_per_loop: false complete_dof_order: - left_hip_pitch_joint - left_hip_roll_joint - left_hip_yaw_joint - left_knee_joint - left_ankle_pitch_joint - left_ankle_roll_joint - right_hip_pitch_joint - right_hip_roll_joint - right_hip_yaw_joint - right_knee_joint - right_ankle_pitch_joint - right_ankle_roll_joint - waist_yaw_joint - waist_roll_joint - waist_pitch_joint - left_shoulder_pitch_joint - left_shoulder_roll_joint - left_shoulder_yaw_joint - left_elbow_joint - left_wrist_roll_joint - left_wrist_pitch_joint - left_wrist_yaw_joint - right_shoulder_pitch_joint - right_shoulder_roll_joint - right_shoulder_yaw_joint - right_elbow_joint - right_wrist_roll_joint - right_wrist_pitch_joint - right_wrist_yaw_joint policy_dof_order: - left_hip_pitch_joint - left_hip_roll_joint - left_hip_yaw_joint - left_knee_joint - left_ankle_pitch_joint - left_ankle_roll_joint - right_hip_pitch_joint - right_hip_roll_joint - right_hip_yaw_joint - right_knee_joint - right_ankle_pitch_joint - right_ankle_roll_joint - waist_yaw_joint - waist_roll_joint - waist_pitch_joint - left_shoulder_pitch_joint - left_shoulder_roll_joint - left_shoulder_yaw_joint - left_elbow_joint - left_wrist_roll_joint - left_wrist_pitch_joint - left_wrist_yaw_joint - right_shoulder_pitch_joint - right_shoulder_roll_joint - right_shoulder_yaw_joint - right_elbow_joint - right_wrist_roll_joint - right_wrist_pitch_joint - right_wrist_yaw_joint dof2motor_idx_mapping: # https://support.unitree.com/home/zh/G1_developer/about_G1 left_hip_pitch_joint: 0 left_hip_roll_joint: 1 left_hip_yaw_joint: 2 left_knee_joint: 3 left_ankle_pitch_joint: 4 left_ankle_roll_joint: 5 right_hip_pitch_joint: 6 
right_hip_roll_joint: 7 right_hip_yaw_joint: 8 right_knee_joint: 9 right_ankle_pitch_joint: 10 right_ankle_roll_joint: 11 waist_yaw_joint: 12 waist_roll_joint: 13 waist_pitch_joint: 14 left_shoulder_pitch_joint: 15 left_shoulder_roll_joint: 16 left_shoulder_yaw_joint: 17 left_elbow_joint: 18 left_wrist_roll_joint: 19 left_wrist_pitch_joint: 20 left_wrist_yaw_joint: 21 right_shoulder_pitch_joint: 22 right_shoulder_roll_joint: 23 right_shoulder_yaw_joint: 24 right_elbow_joint: 25 right_wrist_roll_joint: 26 right_wrist_pitch_joint: 27 right_wrist_yaw_joint: 28 default_joint_angles: # Left leg joints (indices 0-5) left_hip_pitch_joint: -0.312 left_hip_roll_joint: 0.0 left_hip_yaw_joint: 0.0 left_knee_joint: 0.669 left_ankle_pitch_joint: -0.33 left_ankle_roll_joint: 0.0 # Right leg joints (indices 6-11) right_hip_pitch_joint: -0.312 right_hip_roll_joint: 0.0 right_hip_yaw_joint: 0.0 right_knee_joint: 0.669 right_ankle_pitch_joint: -0.33 right_ankle_roll_joint: 0.0 # Waist joints (indices 12-14) waist_yaw_joint: 0.0 waist_roll_joint: 0.0 waist_pitch_joint: 0.2 # Left arm joints (indices 15-21) left_shoulder_pitch_joint: 0.2 left_shoulder_roll_joint: 0.2 left_shoulder_yaw_joint: 0.0 left_elbow_joint: 0.6 left_wrist_roll_joint: 0.0 left_wrist_pitch_joint: 0.0 left_wrist_yaw_joint: 0.0 # Right arm joints (indices 22-28) right_shoulder_pitch_joint: 0.2 right_shoulder_roll_joint: -0.2 right_shoulder_yaw_joint: 0.0 right_elbow_joint: 0.6 right_wrist_roll_joint: 0.0 right_wrist_pitch_joint: 0.0 right_wrist_yaw_joint: 0.0 # Joint limits joint_limits: position: # Left leg joints left_hip_pitch_joint: [-2.5307, 2.8798] left_hip_roll_joint: [-0.5236, 2.9671] left_hip_yaw_joint: [-2.7576, 2.7576] left_knee_joint: [-0.087267, 2.8798] left_ankle_pitch_joint: [-0.87267, 0.5236] left_ankle_roll_joint: [-0.2618, 0.2618] # Right leg joints right_hip_pitch_joint: [-2.5307, 2.8798] right_hip_roll_joint: [-2.9671, 0.5236] right_hip_yaw_joint: [-2.7576, 2.7576] right_knee_joint: [-0.087267, 
2.8798] right_ankle_pitch_joint: [-0.87267, 0.5236] right_ankle_roll_joint: [-0.2618, 0.2618] # Waist joints waist_yaw_joint: [-2.618, 2.618] waist_roll_joint: [-0.52, 0.52] waist_pitch_joint: [-0.52, 0.52] # Left arm joints left_shoulder_pitch_joint: [-3.0892, 2.6704] left_shoulder_roll_joint: [-1.5882, 2.2515] left_shoulder_yaw_joint: [-2.618, 2.618] left_elbow_joint: [-1.0472, 2.0944] left_wrist_roll_joint: [-1.972222054, 1.972222054] left_wrist_pitch_joint: [-1.614429558, 1.614429558] left_wrist_yaw_joint: [-1.614429558, 1.614429558] # Right arm joints right_shoulder_pitch_joint: [-3.0892, 2.6704] right_shoulder_roll_joint: [-2.2515, 1.5882] right_shoulder_yaw_joint: [-2.618, 2.618] right_elbow_joint: [-1.0472, 2.0944] right_wrist_roll_joint: [-1.972222054, 1.972222054] right_wrist_pitch_joint: [-1.614429558, 1.614429558] right_wrist_yaw_joint: [-1.614429558, 1.614429558] velocity: # Left leg joints left_hip_pitch_joint: 32 left_hip_roll_joint: 20 left_hip_yaw_joint: 32 left_knee_joint: 20 left_ankle_pitch_joint: 30 left_ankle_roll_joint: 30 # Right leg joints right_hip_pitch_joint: 32 right_hip_roll_joint: 20 right_hip_yaw_joint: 32 right_knee_joint: 20 right_ankle_pitch_joint: 30 right_ankle_roll_joint: 30 # Waist joints waist_yaw_joint: 32 waist_roll_joint: 30 waist_pitch_joint: 30 # Left arm joints left_shoulder_pitch_joint: 37 left_shoulder_roll_joint: 37 left_shoulder_yaw_joint: 37 left_elbow_joint: 37 left_wrist_roll_joint: 37 left_wrist_pitch_joint: 22 left_wrist_yaw_joint: 22 # Right arm joints right_shoulder_pitch_joint: 37 right_shoulder_roll_joint: 37 right_shoulder_yaw_joint: 37 right_elbow_joint: 37 right_wrist_roll_joint: 37 right_wrist_pitch_joint: 22 right_wrist_yaw_joint: 22 effort: # Left leg joints left_hip_pitch_joint: 88 left_hip_roll_joint: 139 left_hip_yaw_joint: 88 left_knee_joint: 139 left_ankle_pitch_joint: 35 left_ankle_roll_joint: 35 # Right leg joints right_hip_pitch_joint: 88 right_hip_roll_joint: 139 right_hip_yaw_joint: 88 
right_knee_joint: 139 right_ankle_pitch_joint: 35 right_ankle_roll_joint: 35 # Waist joints waist_yaw_joint: 88 waist_roll_joint: 35 waist_pitch_joint: 35 # Left arm joints left_shoulder_pitch_joint: 25 left_shoulder_roll_joint: 25 left_shoulder_yaw_joint: 25 left_elbow_joint: 25 left_wrist_roll_joint: 25 left_wrist_pitch_joint: 5 left_wrist_yaw_joint: 5 # Right arm joints right_shoulder_pitch_joint: 25 right_shoulder_roll_joint: 25 right_shoulder_yaw_joint: 25 right_elbow_joint: 25 right_wrist_roll_joint: 25 right_wrist_pitch_joint: 5 right_wrist_yaw_joint: 5 limit_scales: position: 2.0 # Allows 50% more range of motion velocity: 2.0 effort: 2.0 # move to default position # joint_names and default_position are auto-generated from complete_dof_order and default_joint_angles # Only kp and kd arrays need to be specified here kp: [ 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, 350.0, 200.0, 200.0, 300.0, 300.0, 150.0, 200.0, 200.0, 200.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0 ] kd: [ 5.0, 5.0, 5.0, 10.0, 5.0, 5.0, 5.0, 5.0, 5.0, 10.0, 5.0, 5.0, 5.0, 5.0, 5.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 ] ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/__init__.py ================================================ ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/holomotion_fk_root_only.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. from __future__ import annotations import logging import time from typing import Callable, Dict, Sequence import numpy as np import torch def _xyzw_to_wxyz(q: np.ndarray) -> np.ndarray: return np.concatenate([q[..., 3:4], q[..., 0:3]], axis=-1) def _wxyz_to_xyzw(q: np.ndarray) -> np.ndarray: return np.concatenate([q[..., 1:4], q[..., 0:1]], axis=-1) def _quat_conjugate_wxyz(q: np.ndarray) -> np.ndarray: out = np.array(q, copy=True) out[..., 1:4] *= -1.0 return out def _quat_mul_wxyz(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: w1 = q1[..., 0] x1 = q1[..., 1] y1 = q1[..., 2] z1 = q1[..., 3] w2 = q2[..., 0] x2 = q2[..., 1] y2 = q2[..., 2] z2 = q2[..., 3] return np.stack( [ w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, ], axis=-1, ) def _standardize_quaternion_wxyz(q: np.ndarray) -> np.ndarray: return np.where(q[..., 0:1] < 0.0, -q, q) def _axis_angle_from_wxyz(q: np.ndarray) -> np.ndarray: q = _standardize_quaternion_wxyz(q) q = q / np.linalg.norm(q, axis=-1, keepdims=True).clip(min=1.0e-9) quat_w = q[..., 0] quat_xyz = q[..., 1:4] mag = np.linalg.norm(quat_xyz, axis=-1) half_angle = np.arctan2(mag, quat_w) angle = 2.0 * half_angle use_taylor = np.abs(angle) <= 1.0e-6 angle_safe = np.where(use_taylor, 1.0, angle) sin_half_over_angle = np.where( use_taylor, 0.5 - angle * angle / 48.0, np.sin(half_angle) / angle_safe, ) return quat_xyz / sin_half_over_angle[..., None] def _grad_t(x: np.ndarray, dt: float) -> np.ndarray: if dt <= 0.0: raise ValueError(f"Invalid 
dt: {dt}") if x.shape[1] < 2: return np.zeros_like(x) grad = np.empty_like(x) inv_dt = 1.0 / dt grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt if x.shape[1] > 2: grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt) return grad class HoloMotionFKRootOnly(torch.nn.Module): """Root-only online FK. This lightweight variant is intended for policy-time VR reference building when only the root body pose/velocity are consumed by observation terms. """ def __init__( self, dof_names: Sequence[str], device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, timing_logger_enabled: bool = False, timing_log_interval_sec: float = 5.0, timing_log_per_call: bool = False, timing_name: str = "HoloMotionFKRootOnly", timing_log_fn: Callable[[str], None] | None = None, ) -> None: super().__init__() self.body_names = ["root"] self.dof_names = list(dof_names) self.num_bodies = 1 self.num_dof = len(self.dof_names) if self.num_dof <= 0: raise ValueError("dof_names must not be empty") self._device = torch.device(device) self._dtype = dtype if self._dtype == torch.float64: self._np_dtype = np.float64 else: self._np_dtype = np.float32 self._timing_logger_enabled = bool(timing_logger_enabled) self._timing_log_interval_sec = float(timing_log_interval_sec) self._timing_log_per_call = bool(timing_log_per_call) self._timing_name = str(timing_name) self._timing_logger = logging.getLogger(__name__) self._timing_log_fn = timing_log_fn self._timing_last_log_time = None self._timing_count = 0 self._timing_sum_ms = {} self._timing_max_ms = {} self.last_timing_ms = {} self._gaussian_kernel_cache: Dict[tuple[float, str], np.ndarray] = {} def set_timing_logger( self, enabled: bool, interval_sec: float | None = None, per_call: bool | None = None, log_fn: Callable[[str], None] | None = None, ) -> None: self._timing_logger_enabled = bool(enabled) if interval_sec is not None: self._timing_log_interval_sec = float(interval_sec) if per_call is not None: 
self._timing_log_per_call = bool(per_call) if log_fn is not None: self._timing_log_fn = log_fn def _timing_ms(self, t0: float) -> float: return (time.perf_counter() - t0) * 1000.0 def _to_numpy(self, x: torch.Tensor | np.ndarray) -> np.ndarray: if isinstance(x, np.ndarray): return np.asarray(x, dtype=self._np_dtype) if not isinstance(x, torch.Tensor): return np.asarray(x, dtype=self._np_dtype) if x.device.type != "cpu" or x.dtype != self._dtype: x = x.detach().to(device="cpu", dtype=self._dtype) else: x = x.detach() return x.numpy() def _to_output_tensor(self, x: np.ndarray) -> torch.Tensor: tensor = torch.from_numpy(np.ascontiguousarray(x)) if self._device.type != "cpu" or tensor.dtype != self._dtype: tensor = tensor.to(device=self._device, dtype=self._dtype) return tensor def _get_gaussian_kernel(self, sigma: float) -> np.ndarray | None: if sigma <= 0.0: return None key = (float(sigma), np.dtype(self._np_dtype).str) kernel = self._gaussian_kernel_cache.get(key, None) if kernel is not None: return kernel radius = int(4.0 * sigma + 0.5) kernel_x = np.arange(-radius, radius + 1, dtype=self._np_dtype) kernel = np.exp(-0.5 * np.square(kernel_x / sigma)).astype( self._np_dtype, copy=False ) kernel /= kernel.sum(dtype=self._np_dtype) self._gaussian_kernel_cache[key] = kernel return kernel def _gaussian_filter_time(self, x: np.ndarray, kernel: np.ndarray | None) -> np.ndarray: if kernel is None or x.shape[1] < 2: return x radius = kernel.shape[0] // 2 padded = np.pad(x, ((0, 0), (radius, radius), (0, 0)), mode="edge") windows = np.lib.stride_tricks.sliding_window_view( padded, window_shape=kernel.shape[0], axis=1 ) return np.tensordot(windows, kernel, axes=([-1], [0])).astype( x.dtype, copy=False ) def _log_timing_message(self, message: str) -> None: if self._timing_log_fn is not None: self._timing_log_fn(message) else: self._timing_logger.info(message) def _record_timing(self, sample: Dict[str, float]) -> None: self.last_timing_ms = dict(sample) if not 
self._timing_logger_enabled: return self._timing_count += 1 for key, value in sample.items(): v = float(value) self._timing_sum_ms[key] = self._timing_sum_ms.get(key, 0.0) + v self._timing_max_ms[key] = max(self._timing_max_ms.get(key, v), v) if self._timing_log_per_call: self._log_timing_message( ( f"[{self._timing_name}][Timing] " f"total={sample['total_ms']:.3f}ms " f"input={sample['input_ms']:.3f}ms " f"quat={sample['quat_ms']:.3f}ms " f"linvel={sample['linvel_ms']:.3f}ms " f"angvel={sample['angvel_ms']:.3f}ms " f"smooth={sample['smooth_ms']:.3f}ms " f"output={sample['output_ms']:.3f}ms" ) ) now = time.time() if self._timing_last_log_time is None: self._timing_last_log_time = now return if now - self._timing_last_log_time < self._timing_log_interval_sec: return if self._timing_count == 0: self._timing_last_log_time = now return keys = [ "total_ms", "input_ms", "quat_ms", "linvel_ms", "angvel_ms", "smooth_ms", "output_ms", ] self._log_timing_message( f"[{self._timing_name}][Timing-Agg] " + " ".join( f"{key}=mean:{self._timing_sum_ms.get(key, 0.0) / self._timing_count:.3f}ms/" f"max:{self._timing_max_ms.get(key, 0.0):.3f}ms" for key in keys ) + f" n={self._timing_count}" ) self._timing_count = 0 self._timing_sum_ms.clear() self._timing_max_ms.clear() self._timing_last_log_time = now @torch.inference_mode() def forward( self, root_pos: torch.Tensor, root_quat: torch.Tensor, dof_pos: torch.Tensor, fps: float, quat_format: str = "xyzw", sub_batch_size: int = 64, vel_smoothing_sigma: float = 2.0, compute_velocity: bool = True, ) -> Dict[str, torch.Tensor]: t_total = time.perf_counter() del sub_batch_size del compute_velocity # kept for call-site compatibility if fps <= 0.0: raise ValueError(f"Invalid fps: {fps}") if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3: raise ValueError("Inputs must be (B, T, ...)") if ( root_pos.shape[:2] != root_quat.shape[:2] or root_pos.shape[:2] != dof_pos.shape[:2] ): raise ValueError("Mismatched batch/time shapes 
among inputs") if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4: raise ValueError( "root_pos must be (B,T,3) and root_quat must be (B,T,4)" ) if dof_pos.shape[-1] != self.num_dof: raise ValueError( f"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}" ) t_input = time.perf_counter() root_pos_np = self._to_numpy(root_pos) root_quat_np = self._to_numpy(root_quat) dof_pos_np = self._to_numpy(dof_pos) input_ms = self._timing_ms(t_input) t_quat = time.perf_counter() if quat_format == "xyzw": root_quat_xyzw_np = root_quat_np root_quat_wxyz_np = _xyzw_to_wxyz(root_quat_np) elif quat_format == "wxyz": root_quat_wxyz_np = root_quat_np root_quat_xyzw_np = _wxyz_to_xyzw(root_quat_np) else: raise ValueError(f"Unsupported quat_format: {quat_format}") quat_ms = self._timing_ms(t_quat) dt = 1.0 / fps kernel = self._get_gaussian_kernel(float(vel_smoothing_sigma)) t_linvel = time.perf_counter() root_vel_np = _grad_t(root_pos_np, dt) linvel_ms = self._timing_ms(t_linvel) t_angvel = time.perf_counter() root_angvel_np = np.zeros_like(root_pos_np) if root_quat_wxyz_np.shape[1] >= 2: q1 = root_quat_wxyz_np[:, 1:] q0_inv = _quat_conjugate_wxyz(root_quat_wxyz_np[:, :-1]) q_rel = _quat_mul_wxyz(q1, q0_inv) root_angvel_np[:, :-1] = _axis_angle_from_wxyz(q_rel) / dt angvel_ms = self._timing_ms(t_angvel) t_smooth = time.perf_counter() if kernel is not None and root_pos_np.shape[1] >= 2: vel_and_ang_np = np.concatenate([root_vel_np, root_angvel_np], axis=-1) vel_and_ang_np = self._gaussian_filter_time(vel_and_ang_np, kernel) root_vel_np = vel_and_ang_np[..., :3] root_angvel_np = vel_and_ang_np[..., 3:6] smooth_ms = self._timing_ms(t_smooth) t_output = time.perf_counter() out = { "global_translation": self._to_output_tensor(root_pos_np[:, :, None, :]), "global_rotation_quat": self._to_output_tensor( root_quat_xyzw_np[:, :, None, :] ), "global_velocity": self._to_output_tensor(root_vel_np[:, :, None, :]), "global_angular_velocity": self._to_output_tensor( root_angvel_np[:, 
:, None, :] ), "dof_pos": self._to_output_tensor(dof_pos_np), "dof_vel": self._to_output_tensor(np.zeros_like(dof_pos_np)), } output_ms = self._timing_ms(t_output) self._record_timing( { "total_ms": self._timing_ms(t_total), "input_ms": input_ms, "quat_ms": quat_ms, "linvel_ms": linvel_ms, "angvel_ms": angvel_ms, "smooth_ms": smooth_ms, "output_ms": output_ms, } ) return out ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/__init__.py ================================================ from .obs_builder import PolicyObsBuilder, get_gravity_orientation __all__ = [ "PolicyObsBuilder", "get_gravity_orientation", ] ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/obs_builder.py ================================================ import numpy as np import torch from typing import Dict, List, Sequence, Any, Optional def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray: """Calculate gravity orientation from quaternion. Args: quaternion: Array-like [w, x, y, z] Returns: np.ndarray of shape (3,) representing gravity projection. """ qw = float(quaternion[0]) qx = float(quaternion[1]) qy = float(quaternion[2]) qz = float(quaternion[3]) gravity_orientation = np.zeros(3, dtype=np.float32) gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy) gravity_orientation[1] = -2.0 * (qz * qy + qw * qx) gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_orientation class _CircularBuffer: """History buffer for batched tensor data (batch==1 in our eval/deploy). Stores history in oldest->newest order when accessed via .buffer. 
""" def __init__(self, max_len: int, feat_dim: int): if max_len < 1: raise ValueError(f"max_len must be >= 1, got {max_len}") self._max_len = int(max_len) self._feat_dim = int(feat_dim) self._pointer = -1 self._num_pushes = 0 self._buffer: torch.Tensor = torch.zeros( (self._max_len, 1, self._feat_dim), dtype=torch.float32, device="cpu", ) @property def buffer(self) -> torch.Tensor: """Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1.""" if self._num_pushes == 0: raise RuntimeError( "Attempting to read from an empty history buffer." ) # roll such that oldest is at index=0 along the history axis rolled = torch.roll( self._buffer, shifts=self._max_len - self._pointer - 1, dims=0 ) return torch.transpose(rolled, 0, 1) # [1, max_len, feat] def append(self, data: torch.Tensor) -> None: """Append one step: data shape [1, feat_dim] on the configured device.""" if ( data.ndim != 2 or data.shape[0] != 1 or data.shape[1] != self._feat_dim ): raise ValueError( f"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}" ) self._pointer = (self._pointer + 1) % self._max_len self._buffer[self._pointer] = data if self._num_pushes == 0: # duplicate first push across entire history for warm start self._buffer[:] = data self._num_pushes += 1 class PolicyObsBuilder: """Builds policy observations from Unitree lowstate with temporal history. Designed to be shared between MuJoCo sim2sim evaluation and ROS2 deployment. History management is internal and produces a flattened vector of size sum_i(context_length * feat_i) across the configured observation items. 
Supports two command modes: - "motion_tracking": uses reference motion states - "velocity_tracking": uses velocity commands [vx, vy, vyaw] """ def __init__( self, dof_names_onnx: Sequence[str], default_angles_onnx: np.ndarray, evaluator: Optional[Any] = None, obs_policy_cfg: Optional[Dict[str, Any]] = None, ) -> None: self.dof_names_onnx: List[str] = list(dof_names_onnx) self.num_actions: int = len(self.dof_names_onnx) self.evaluator = evaluator self.obs_policy_cfg = obs_policy_cfg if default_angles_onnx.shape[0] != self.num_actions: raise ValueError( "default_angles_onnx length must match num actions" ) self.default_angles_onnx = default_angles_onnx.astype(np.float32) self.default_angles_dict: Dict[str, float] = { name: float(self.default_angles_onnx[idx]) for idx, name in enumerate(self.dof_names_onnx) } # Build observation schema from config if provided self.term_specs: List[Dict[str, Any]] = [] for term_dict in self.obs_policy_cfg["atomic_obs_list"]: for name, cfg in term_dict.items(): term_dict = {**cfg} term_dict["name"] = name self.term_specs.append(term_dict) # Buffers are created lazily after first dimension inference self._buffers: Dict[str, _CircularBuffer] = {} def reset(self) -> None: for buf in self._buffers.values(): buf._pointer = -1 buf._num_pushes = 0 buf._buffer.zero_() def _compute_term( self, name: str, ) -> np.ndarray: # Prefer evaluator-provided methods; no legacy fallbacks if self.evaluator is not None: meth = getattr(self.evaluator, f"_get_obs_{name}", None) if callable(meth): out = meth() return np.asarray(out, dtype=np.float32).reshape(-1) raise ValueError( f"Unknown observation term '{name}' or evaluator method missing." 
) def build_policy_obs(self) -> np.ndarray: """Append one step using evaluator-provided observation terms and return flattened obs.""" # Compute per-term outputs values: Dict[str, np.ndarray] = {} for spec in self.term_specs: name = spec["name"] scale = float(spec.get("scale", 1.0)) values[name] = self._compute_term(name) * scale # Lazily initialize buffers with inferred feature dims if len(self._buffers) == 0: for spec in self.term_specs: name = spec["name"] hist_len = int(spec.get("history_length", 0)) if hist_len <= 0: continue feat_dim = int(values[name].reshape(-1).shape[0]) self._buffers[name] = _CircularBuffer( hist_len, feat_dim, ) # Append current step to buffers (skip terms without history) for spec in self.term_specs: name = spec["name"] if name in self._buffers: item = torch.as_tensor( values[name], dtype=torch.float32, device="cpu", )[None, :] self._buffers[name].append(item) # Assemble flat list according to term ordering and history flatten rules flat_list: List[np.ndarray] = [] for spec in self.term_specs: name = spec["name"] if name in self._buffers: buf = self._buffers[name].buffer[0] # [hist, feat] arr = buf.reshape(-1).detach().cpu().numpy() flat_list.append(arr.astype(np.float32)) else: # no history -> use computed value directly flat_list.append(values[name].reshape(-1).astype(np.float32)) if len(flat_list) == 0: return np.zeros(0, dtype=np.float32) return np.concatenate(flat_list, axis=0).astype(np.float32) ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/policy_node_29dof.py ================================================ #! /your_dir/miniconda3/envs/holomotion_deploy/bin/python """ HoloMotion Policy Node This module implements the main policy execution node for the HoloMotion humanoid robot system using ZMQ latest_obs transport. 
It handles neural network policy inference, motion sequence management, remote controller input, and robot state coordination for humanoid behaviors including velocity tracking and motion tracking. The policy node serves as the high-level decision maker that: - Processes sensor observations and builds state representations - Executes trained neural network policies for motion generation (velocity tracking and motion tracking) - Manages multiple motion sequences (motion clips) loaded from offline files - Handles remote controller input for motion selection - Coordinates with the main control node for safe operation Key Features: - Dual policy support: velocity tracking and motion tracking - Offline motion file loading (.npz format) - Runtime policy switching with button controls - Separate hyperparameters (kps, kds, action_scale, default_angles) for each model Author: HoloMotion Team License: See project LICENSE file """ import os import torch import time import json import threading from collections import deque import easydict import numpy as np import onnx import onnxruntime import rclpy import zmq import yaml from ament_index_python.packages import get_package_share_directory from omegaconf import OmegaConf from rclpy.node import Node from rclpy.qos import QoSProfile from std_msgs.msg import Float32MultiArray, String from unitree_hg.msg import LowState from humanoid_policy.obs_builder import PolicyObsBuilder from humanoid_policy.utils.remote_controller_filter import KeyMap, RemoteController from humanoid_policy.holomotion_fk_root_only import HoloMotionFKRootOnly def _parse_cpu_affinity_str(s): """Parse '0,1' or '2' -> [0,1] or [2]. Empty/invalid -> [].""" s = str(s).strip() if not s: return [] out = [] for x in s.split(","): x = x.strip() if x.isdigit(): out.append(int(x)) return out def set_thread_cpu_affinity(cpu_ids): """Pin current thread to given CPU core IDs (Linux only). cpu_ids: list of int, e.g. [0], [0,1]. 
Returns True if set successfully.""" if not cpu_ids: return False try: import ctypes libc = ctypes.CDLL("libc.so.6") CPU_SETSIZE = 1024 ncpubits = 8 * ctypes.sizeof(ctypes.c_ulong) nlongs = (CPU_SETSIZE + ncpubits - 1) // ncpubits class CpuSetT(ctypes.Structure): _fields_ = [("__bits", ctypes.c_ulong * nlongs)] libc.pthread_self.restype = ctypes.c_ulong libc.pthread_setaffinity_np.argtypes = [ ctypes.c_ulong, ctypes.c_size_t, ctypes.POINTER(CpuSetT) ] cs = CpuSetT() for i in range(nlongs): cs.__bits[i] = 0 for c in cpu_ids: if 0 <= c < CPU_SETSIZE: idx = c // ncpubits bit = c % ncpubits cs.__bits[idx] |= 1 << bit tid = libc.pthread_self() sz = ctypes.sizeof(CpuSetT) ret = libc.pthread_setaffinity_np(tid, sz, ctypes.byref(cs)) return ret == 0 except Exception: return False HEADER_SIZE = 1280 DEFAULT_ZMQ_TOPIC = b"obs65" _DTYPE_BY_NAME = { "f32": np.float32, "f64": np.float64, "i32": np.int32, "i64": np.int64, "u8": np.uint8, "bool": np.bool_, } def _decode_zmq_topic(topic_value) -> bytes: if isinstance(topic_value, bytes): return topic_value return str(topic_value).encode("utf-8") def _coerce_config_bool(value, default: bool = False) -> bool: if value is None: return default if isinstance(value, (bool, np.bool_)): return bool(value) if isinstance(value, str): value = value.strip().lower() if value in {"1", "true", "yes", "y", "on"}: return True if value in {"0", "false", "no", "n", "off", ""}: return False return bool(value) def _infer_onnx_dim(dim, default: int = 1) -> int: if isinstance(dim, int) and dim > 0: return dim return int(default) def _infer_numpy_dtype_from_onnx_type(type_str: str): type_str = str(type_str).lower() if "float16" in type_str: return np.float16 if "float64" in type_str or "double" in type_str: return np.float64 if "int64" in type_str: return np.int64 if "int32" in type_str: return np.int32 if "bool" in type_str: return np.bool_ return np.float32 def unpack_numpy_message(packet: bytes, expected_topic: bytes | None = None) -> dict: if 
expected_topic is not None: if not packet.startswith(expected_topic): raise ValueError("ZMQ packet topic prefix mismatch") packet = packet[len(expected_topic) :] if len(packet) < HEADER_SIZE: raise ValueError(f"ZMQ packet too short: {len(packet)} < {HEADER_SIZE}") header_bytes = packet[:HEADER_SIZE].rstrip(b"\x00") if not header_bytes: raise ValueError("ZMQ packet has empty header") header = json.loads(header_bytes.decode("utf-8")) payload = memoryview(packet)[HEADER_SIZE:] result = {} offset = 0 for field in header.get("fields", []): name = str(field["name"]) dtype_name = str(field["dtype"]) shape = tuple(int(x) for x in field.get("shape", [])) if dtype_name not in _DTYPE_BY_NAME: raise ValueError(f"Unsupported dtype in ZMQ packet: {dtype_name}") dtype = np.dtype(_DTYPE_BY_NAME[dtype_name]).newbyteorder("<") count = int(np.prod(shape, dtype=np.int64)) if len(shape) > 0 else 1 nbytes = count * dtype.itemsize end = offset + nbytes if end > len(payload): raise ValueError( f"ZMQ packet field '{name}' exceeds payload size: end={end}, payload={len(payload)}" ) arr = np.frombuffer(payload[offset:end], dtype=dtype, count=count) if len(shape) > 0: arr = arr.reshape(shape) else: arr = arr.reshape(()) result[name] = np.array(arr, copy=True) offset = end return result class LatestObsBuffer: """Thread-safe buffer for delayed latest_obs access.""" def __init__(self, max_queue_size: int = 20): self._lock = threading.Lock() self._data = None self._timestamp = None self._sender_timestamp = None self._frame_index = None self._data_queue = deque(maxlen=max_queue_size) self._timestamp_queue = deque(maxlen=max_queue_size) self._sender_timestamp_queue = deque(maxlen=max_queue_size) self._frame_index_queue = deque(maxlen=max_queue_size) def set( self, arr: np.ndarray, sender_timestamp: float | None = None, frame_index: int | None = None, ): with self._lock: current_time = time.time() arr_copy = np.asarray(arr, dtype=np.float32).copy() self._data = arr_copy self._timestamp = current_time 
self._sender_timestamp = sender_timestamp self._frame_index = frame_index self._data_queue.append(arr_copy) self._timestamp_queue.append(current_time) self._sender_timestamp_queue.append(sender_timestamp) self._frame_index_queue.append(frame_index) def get_with_age_and_delay(self, max_age: float = 0.1, delay_steps: int = 0): """Return a delayed frame and report whether it is stale.""" with self._lock: if len(self._data_queue) == 0: if self._data is None or self._timestamp is None: return None, None, True, None, None current_time = time.time() age = current_time - self._timestamp return ( self._data, self._timestamp, age > max_age, self._frame_index, self._sender_timestamp, ) if delay_steps < 0: delay_steps = 0 idx = len(self._data_queue) - 1 - delay_steps if idx < 0: idx = 0 data = self._data_queue[idx] ts = self._timestamp_queue[idx] frame_index = self._frame_index_queue[idx] sender_timestamp = self._sender_timestamp_queue[idx] current_time = time.time() age = current_time - ts is_stale = age > max_age return data, ts, is_stale, frame_index, sender_timestamp def get_queue_stats(self): with self._lock: if len(self._data_queue) < 2: return {"queue_size": len(self._data_queue), "avg_interval": None} intervals = [] for i in range(1, len(self._timestamp_queue)): interval = self._timestamp_queue[i] - self._timestamp_queue[i - 1] intervals.append(interval) avg_interval = float(np.mean(intervals)) if intervals else None return { "queue_size": len(self._data_queue), "avg_interval": avg_interval, "expected_freq": 1.0 / avg_interval if avg_interval and avg_interval > 0 else None, } class ZmqLatestObsSubscriber: """Background ZMQ SUB receiver for latest_obs packets.""" def __init__( self, uri: str, topic: bytes, buffer: LatestObsBuffer, logger, mode: str = "connect", cpu_affinity=None, conflate: bool = True, ): self.uri = uri self.topic = topic self.buffer = buffer self.logger = logger self.mode = str(mode).strip().lower() self.cpu_affinity = cpu_affinity or [] self.conflate 
= bool(conflate) self._thread = None self._stop_event = threading.Event() self._context = None self._socket = None self._poller = None self._recv_count = 0 def _process_packet(self, packet: bytes): payload = unpack_numpy_message(packet, expected_topic=self.topic) latest_obs = payload.get("latest_obs", None) if latest_obs is None: raise ValueError("ZMQ packet missing latest_obs field") frame_index = payload.get("frame_index", None) if frame_index is not None: frame_index = int(np.asarray(frame_index).reshape(-1)[0]) sender_timestamp = payload.get("timestamp_realtime", None) if sender_timestamp is not None: sender_timestamp = float(np.asarray(sender_timestamp).reshape(-1)[0]) self.buffer.set( np.asarray(latest_obs, dtype=np.float32), sender_timestamp=sender_timestamp, frame_index=frame_index, ) self._recv_count += 1 if self._recv_count == 1: self.logger.info( f"[ZMQ] first latest_obs packet received from {self.uri}, " f"topic={self.topic.decode('utf-8', errors='ignore')}" ) def _run(self): if self.cpu_affinity and set_thread_cpu_affinity(self.cpu_affinity): self.logger.info(f"[ZMQ] subscriber thread pinned to CPUs {self.cpu_affinity}") self._context = zmq.Context() self._socket = self._context.socket(zmq.SUB) self._socket.setsockopt(zmq.RCVHWM, 1) self._socket.setsockopt(zmq.SUBSCRIBE, self.topic) if self.conflate and hasattr(zmq, "CONFLATE"): self._socket.setsockopt(zmq.CONFLATE, 1) if self.mode == "bind": self._socket.bind(self.uri) elif self.mode == "connect": self._socket.connect(self.uri) else: raise ValueError("latest_obs_zmq_mode must be 'bind' or 'connect'") self._poller = zmq.Poller() self._poller.register(self._socket, zmq.POLLIN) self.logger.info( f"[ZMQ] latest_obs subscriber ready: mode={self.mode}, uri={self.uri}, " f"topic={self.topic.decode('utf-8', errors='ignore')}, conflate={self.conflate}" ) try: while not self._stop_event.is_set(): events = dict(self._poller.poll(50)) if self._socket not in events: continue try: packet = 
self._socket.recv(flags=zmq.NOBLOCK) except zmq.Again: continue self._process_packet(packet) except Exception as exc: if not self._stop_event.is_set(): self.logger.error(f"[ZMQ] subscriber loop failed: {exc}") finally: try: if self._poller is not None and self._socket is not None: self._poller.unregister(self._socket) except Exception: pass try: if self._socket is not None: self._socket.close(0) except Exception: pass try: if self._context is not None: self._context.term() except Exception: pass self._socket = None self._context = None self._poller = None def start(self): if self._thread is not None: return self._stop_event.clear() self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() self.logger.info("[ZMQ] subscriber thread started") def stop(self): self._stop_event.set() if self._thread: self._thread.join(timeout=2.0) self._thread = None self.logger.info("[ZMQ] subscriber thread stopped") class HoloMotionPolicyNode(Node): """Main policy execution node for HoloMotion humanoid robot control with dual policy support. This node implements the high-level control logic for a humanoid robot capable of performing both velocity tracking and motion sequence execution. It supports two neural network policies and allows runtime switching between them. 
Key Features: - Dual neural network policy inference (velocity + motion) using ONNX Runtime - Runtime policy switching with A/B/Y button controls - Velocity tracking mode with joystick control - Motion tracking mode with motion clip sequence selection - Safety-aware state machine with motion prerequisites - Real-time observation processing and action generation Policy Control: - A button: Enable policy (defaults to velocity mode) - B button: Switch from velocity to motion mode - Y button: Switch from motion back to velocity mode Input Controls: - Motion mode: B button (for mode switch) - Velocity mode: Y button (for mode switch) + Joystick +UP/DOWN/LEFT/RIGHT (for motion selection) State Machine: - ZERO_TORQUE: Initial safe state, waiting for activation - MOVE_TO_DEFAULT: Ready state, allows policy operations - Policy execution with mode switching - Emergency stop handling """ def __init__(self): """Initialize the policy node with configuration, models, and ROS2 interfaces. Sets up the complete policy execution pipeline including: - Configuration loading from YAML file - Neural network model initialization - Motion data loading for all sequences - ROS2 publishers, subscribers, and timers - State machine initialization The node starts in a safe state and waits for proper robot state before allowing motion execution. 
""" super().__init__("policy_node") # Get config path from ROS parameter config_path = self.declare_parameter("config_path", "").value with open(config_path, "r", encoding="utf-8") as config_file: self.config_yaml = easydict.EasyDict(yaml.safe_load(config_file)) # Read policy frequency from config, default to 50 Hz if not specified policy_freq = self.config_yaml.get("policy_freq", 50) self.dt = 1.0 / policy_freq self.get_logger().info(f"Policy frequency set to: {policy_freq} Hz (dt = {self.dt:.4f} s)") # Initialize basic parameters - will be updated after config loading self.actions_dim = 29 # Default value, will be updated from config self.real_dof_names = [] # Will be loaded from config self.current_motion_clip_index = 0 # Current motion clip index # Button state tracking for preventing multiple triggers self.last_button_states = { KeyMap.up: 0, KeyMap.down: 0, KeyMap.left: 0, KeyMap.right: 0, KeyMap.A: 0, KeyMap.B: 0, KeyMap.Y: 0, } # Safety check related flags self.policy_enabled = False # Controls whether policy is enabled # Robot state related flags self.robot_state_ready = False # Marks whether MOVE_TO_DEFAULT state is received, allowing key operations self._setup_subscribers() self._setup_publishers() self._setup_timers() # Initialize variables for dual policy self.velocity_policy_session = None self.motion_policy_session = None self.use_kv_cache = False self.motion_kv_cache = None self.motion_kv_input_name = None self.motion_kv_output_name = None self.motion_step_idx_input_name = None self.current_policy_mode = "velocity" self.velocity_config = None self.motion_config = None self.motion_frame_idx = 0 self.ref_dof_pos = None self.ref_dof_vel = None self.ref_raw_bodylink_pos = None self.ref_raw_bodylink_rot = None self.n_motion_frames = 0 self.external_latest_obs = None self.external_obs_received = False self.last_external_obs_time = None self._latest_sender_timestamp = None self.latest_obs_flag = False self.latest_obs_expected_dim = 65 
self.external_fut_dof_pos_queue = None self.external_fut_dof_vel_queue = None self.external_fut_root_pos_queue = None self.external_fut_root_rot_queue = None self.external_fut_frame_idx_queue = None self._prev_external_dof_pos = None self._prev_external_dof_vel = None self._prev_external_root_pos = None self._prev_external_root_rot = None self._prev_external_frame_idx = None self.max_data_age = 0.6 self.stale_data_warning_count = 0 self.last_poll_time = None self._last_vr_status_log_time = None self.latest_obs_zmq_uri = self.declare_parameter( "latest_obs_zmq_uri", "tcp://192.168.124.29:6001" ).value self.latest_obs_zmq_topic = self.declare_parameter( "latest_obs_zmq_topic", DEFAULT_ZMQ_TOPIC.decode("utf-8") ).value self.latest_obs_zmq_mode = self.declare_parameter( "latest_obs_zmq_mode", "connect" ).value self.latest_obs_zmq_conflate = self.declare_parameter( "latest_obs_zmq_conflate", True ).value self.zmq_jitter_delay_frames = self.declare_parameter( "zmq_jitter_delay_frames", 5 ).value self.require_vr_data_for_motion = self.declare_parameter( "require_vr_data_for_motion", True ).value self.enable_teleop_reference = self.declare_parameter( "enable_teleop_reference", True ).value self._cpu_affinity_main_str = self.declare_parameter( "cpu_affinity_main", "" ).value self._cpu_affinity_zmq_sub_str = self.declare_parameter( "cpu_affinity_zmq_sub", "" ).value self.timing_debug_enabled = self.declare_parameter( "timing_debug_enabled", True ).value self.timing_debug_log_interval_sec = self.declare_parameter( "timing_debug_log_interval_sec", 5.0 ).value self.timing_debug_log_per_loop = self.declare_parameter( "timing_debug_log_per_loop", False ).value self._timing_debug_last_log_time = None self._timing_debug_samples = deque(maxlen=500) self._root_only_fk_keybody_warned = False _vr = getattr(self.config_yaml, "vr", None) or {} if _vr: self.latest_obs_zmq_uri = str(_vr.get("latest_obs_zmq_uri", self.latest_obs_zmq_uri)) self.latest_obs_zmq_topic = str( 
_vr.get("latest_obs_zmq_topic", self.latest_obs_zmq_topic) ) self.latest_obs_zmq_mode = str( _vr.get("latest_obs_zmq_mode", self.latest_obs_zmq_mode) ) self.latest_obs_zmq_conflate = bool( _vr.get("latest_obs_zmq_conflate", self.latest_obs_zmq_conflate) ) self.zmq_jitter_delay_frames = int( _vr.get("zmq_jitter_delay_frames", self.zmq_jitter_delay_frames) ) self.max_data_age = float(_vr.get("max_data_age", self.max_data_age)) self.require_vr_data_for_motion = bool( _vr.get("require_vr_data_for_motion", self.require_vr_data_for_motion) ) self.enable_teleop_reference = bool( _vr.get("enable_teleop_reference", self.enable_teleop_reference) ) self.timing_debug_enabled = bool( _vr.get("timing_debug_enabled", self.timing_debug_enabled) ) self.timing_debug_log_interval_sec = float( _vr.get( "timing_debug_log_interval_sec", self.timing_debug_log_interval_sec, ) ) self.timing_debug_log_per_loop = bool( _vr.get("timing_debug_log_per_loop", self.timing_debug_log_per_loop) ) self._cpu_affinity_main_str = str( getattr(self.config_yaml, "cpu_affinity_main", self._cpu_affinity_main_str) ) self._cpu_affinity_zmq_sub_str = str( getattr( self.config_yaml, "cpu_affinity_zmq_sub", self._cpu_affinity_zmq_sub_str, ) ) self._ros_latest_obs_buffer = None self._npz_replay_frame_index = None self._external_seen_frames = 0 self._vr_ready_logged = False self._latest_obs_buffer = LatestObsBuffer() self._latest_obs_zmq_topic_bytes = _decode_zmq_topic(self.latest_obs_zmq_topic) if str(self.latest_obs_zmq_mode).strip().lower() == "connect": uri_str = str(self.latest_obs_zmq_uri) if "*" in uri_str or "0.0.0.0" in uri_str: self.get_logger().warn( "[ZMQ] connect mode requires a concrete peer address. " "Do not use '*' or '0.0.0.0'; use the sender IP instead, " "for example tcp://192.168.124.29:6001." 
) zmq_cpu_affinity = _parse_cpu_affinity_str(self._cpu_affinity_zmq_sub_str) self._zmq_subscriber = ZmqLatestObsSubscriber( uri=self.latest_obs_zmq_uri, topic=self._latest_obs_zmq_topic_bytes, mode=self.latest_obs_zmq_mode, conflate=bool(self.latest_obs_zmq_conflate), buffer=self._latest_obs_buffer, logger=self.get_logger(), cpu_affinity=zmq_cpu_affinity if zmq_cpu_affinity else None, ) self._zmq_subscriber.start() self.get_logger().info( f"ZMQ latest_obs subscriber started: mode={self.latest_obs_zmq_mode}, " f"uri={self.latest_obs_zmq_uri}, topic={self.latest_obs_zmq_topic}, " f"jitter_delay={self.zmq_jitter_delay_frames}" ) self.dof_names_ref_motion = [] self.num_actions = 29 self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32) self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32) self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32) self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32) self.target_dof_pos_onnx = self.default_angles_onnx.copy() self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) self._lowstate_msg = None self.target_dof_pos_real = None self.motion_in_progress = False self._keybody_indices_by_term_name = {} self.fk = None self.fk_initialized = False self.motion_action_ema_filter_enabled = False self.motion_action_ema_filter_alpha = 1.0 self._motion_filtered_actions_onnx = None def _is_vr_ready_for_motion(self) -> bool: """Return whether the ZMQ reference stream is ready for motion mode.""" if not getattr(self, "enable_teleop_reference", True): return False if not ( getattr(self, "external_obs_received", False) and getattr(self, "external_latest_obs", None) is not None ): return False n_fut = int(getattr(self, "n_fut_frames", 0) or 0) if n_fut <= 0: return True delay = int(getattr(self, "zmq_jitter_delay_frames", 0) or 0) needed = n_fut + max(delay, 0) + 1 return int(getattr(self, "_external_seen_frames", 0)) >= needed def _init_keybody_indices_cache(self): if self.motion_config is None: 
raise ValueError("motion_config is not loaded; cannot init keybody index cache") atomic_list = self._get_policy_atomic_obs_list(self.motion_config)["atomic_obs_list"] body_names = [str(name) for name in self.motion_config.robot.body_names] body_name_to_idx = {body_name: idx for idx, body_name in enumerate(body_names)} cache = {} for term_dict in atomic_list: term_name = str(list(term_dict.keys())[0]) term_cfg = term_dict[term_name] params = {} if isinstance(term_cfg, dict): params = term_cfg.get("params", {}) or {} if not isinstance(params, dict): raise ValueError( f"Observation term '{term_name}' params must be a dict, got {type(params)}" ) needs_keybody = ("keybody" in term_name) or ("keybody_names" in params) if not needs_keybody: continue keybody_names = params.get("keybody_names", None) if keybody_names is None: keybody_idxs = np.arange(len(body_names), dtype=np.int64) else: keybody_names = [str(name) for name in keybody_names] missing_names = [ name for name in keybody_names if name not in body_name_to_idx ] if len(missing_names) > 0: raise ValueError( f"Unknown keybody_names in '{term_name}': {missing_names}. " f"Available body names: {body_names}" ) keybody_idxs = np.asarray( [body_name_to_idx[name] for name in keybody_names], dtype=np.int64, ) cache[term_name] = keybody_idxs self._keybody_indices_by_term_name = cache def _get_policy_atomic_obs_list(self, config): """Resolve the atomic obs list used to build the ONNX policy input. Aligns with MuJoCo sim2sim eval ordering by honoring modules.actor.obs_schema when available, to guarantee the policy input term order matches training/export. 
""" def _to_plain_obs_cfg(cfg): if OmegaConf.is_config(cfg): plain_cfg = OmegaConf.to_container(cfg, resolve=True) elif cfg is None: plain_cfg = {} else: plain_cfg = dict(cfg) if plain_cfg is None: plain_cfg = {} if not isinstance(plain_cfg, dict): raise ValueError( f"Observation term config must be a mapping, got {type(plain_cfg)}" ) return plain_cfg def _get_actor_atomic_obs_entries(): obs_cfg = config.get("obs", None) if obs_cfg is None: raise ValueError("Missing config.obs for policy obs") obs_groups = obs_cfg.get("obs_groups", None) if obs_groups is None: raise ValueError("Missing config.obs.obs_groups for policy obs") if obs_groups.get("policy", None) is not None: entries = [] for term_dict in obs_groups.policy.atomic_obs_list: term_name = str(list(term_dict.keys())[0]) entries.append( ( "policy", term_name, _to_plain_obs_cfg(term_dict[term_name]), ) ) return entries if obs_groups.get("unified", None) is not None: entries = [] for term_dict in obs_groups.unified.atomic_obs_list: term_name = str(list(term_dict.keys())[0]) if term_name.startswith("actor_"): entries.append( ( "unified", term_name, _to_plain_obs_cfg(term_dict[term_name]), ) ) if not entries: raise ValueError( "obs_groups.unified found but contains no actor_* terms." ) return entries raise ValueError( "Unsupported obs config : expected obs_groups.policy or obs_groups.unified." 
) def _get_actor_obs_schema_terms(): modules_cfg = config.get("modules", None) if modules_cfg is None: return [] actor_cfg = modules_cfg.get("actor", None) if actor_cfg is None: return [] obs_schema = actor_cfg.get("obs_schema", None) if obs_schema is None: return [] if OmegaConf.is_config(obs_schema): obs_schema_plain = OmegaConf.to_container(obs_schema, resolve=True) else: obs_schema_plain = obs_schema if not isinstance(obs_schema_plain, dict): return [] ordered_terms = [] def _collect_terms(node): if node is None: return if isinstance(node, dict): if "terms" in node and isinstance(node["terms"], list): ordered_terms.extend(str(term) for term in node["terms"]) return for v in node.values(): _collect_terms(v) return if isinstance(node, list): for v in node: _collect_terms(v) return _collect_terms(obs_schema_plain) return ordered_terms actor_atomic_entries = _get_actor_atomic_obs_entries() schema_terms = _get_actor_obs_schema_terms() if len(schema_terms) == 0: return { "atomic_obs_list": [ {term_name: term_cfg} for _, term_name, term_cfg in actor_atomic_entries ] } by_full_key = {} by_leaf_key = {} ambiguous_leaf_keys = set() for group_name, term_name, term_cfg in actor_atomic_entries: full_key = f"{group_name}/{term_name}" by_full_key[full_key] = (term_name, term_cfg) if term_name in by_leaf_key: ambiguous_leaf_keys.add(term_name) else: by_leaf_key[term_name] = (term_name, term_cfg) ordered_atomic_list = [] for schema_term in schema_terms: schema_term_key = str(schema_term) if schema_term_key in by_full_key: term_name, term_cfg = by_full_key[schema_term_key] ordered_atomic_list.append({term_name: term_cfg}) continue leaf_key = schema_term_key.split("/")[-1] if leaf_key in ambiguous_leaf_keys: raise ValueError( f"Ambiguous obs_schema term '{schema_term_key}': " f"multiple atomic obs share leaf key '{leaf_key}'." 
) if leaf_key not in by_leaf_key: available = sorted(list(by_leaf_key.keys())) raise ValueError( f"obs_schema term '{schema_term_key}' not found in atomic_obs_list. " f"Available terms: {available}" ) term_name, term_cfg = by_leaf_key[leaf_key] ordered_atomic_list.append({term_name: term_cfg}) return {"atomic_obs_list": ordered_atomic_list} def _find_actor_place_holder_ndim(self): n_dim = 0 atomic_list = self._get_policy_atomic_obs_list(self.motion_config)[ "atomic_obs_list" ] for obs_dict in atomic_list: name = str(list(obs_dict.keys())[0]) if name == "place_holder" or name == "actor_place_holder": cfg = obs_dict[name] params = cfg.get("params", {}) if isinstance(cfg, dict) else {} n_dim = int(params.get("n_dim", 0)) return n_dim def _init_obs_buffers(self): """Initialize observation builders for both velocity and motion policies. Each obs_builder uses its own model's dof_names_onnx and default_angles_onnx to ensure correct observation computation for each policy. """ # Use velocity model's parameters for velocity obs_builder self.velocity_obs_builder = PolicyObsBuilder( dof_names_onnx=self.velocity_dof_names_onnx, default_angles_onnx=self.velocity_default_angles_onnx, evaluator=self, obs_policy_cfg=self._get_policy_atomic_obs_list(self.velocity_config), ) # Use motion model's parameters for motion obs_builder self.motion_obs_builder = PolicyObsBuilder( dof_names_onnx=self.motion_dof_names_onnx, default_angles_onnx=self.motion_default_angles_onnx, evaluator=self, obs_policy_cfg=self._get_policy_atomic_obs_list(self.motion_config), ) if hasattr(self, "n_fut_frames") and int(self.n_fut_frames) > 0: n_fut = int(self.n_fut_frames) self.external_fut_dof_pos_queue = np.zeros((n_fut, self.num_actions), dtype=np.float32) self.external_fut_dof_vel_queue = np.zeros((n_fut, self.num_actions), dtype=np.float32) self.external_fut_root_pos_queue = np.zeros((n_fut, 3), dtype=np.float32) self.external_fut_root_rot_queue = np.zeros((n_fut, 4), dtype=np.float32) 
self._fk_root_pos_seq_np = np.zeros((1, n_fut + 1, 3), dtype=np.float32) self._fk_root_rot_seq_np = np.zeros((1, n_fut + 1, 4), dtype=np.float32) self._fk_dof_pos_seq_np = np.zeros( (1, n_fut + 1, self.num_actions), dtype=np.float32 ) self._fk_root_pos_seq_tensor = torch.from_numpy(self._fk_root_pos_seq_np) self._fk_root_rot_seq_tensor = torch.from_numpy(self._fk_root_rot_seq_np) self._fk_dof_pos_seq_tensor = torch.from_numpy(self._fk_dof_pos_seq_np) self.external_fut_frame_idx_queue = np.full((n_fut,), -1, dtype=np.int32) self.get_logger().info( f"Initialized VR future frame queues: n_fut_frames={n_fut}, num_actions={self.num_actions}" ) else: self.external_fut_dof_pos_queue = None self.external_fut_dof_vel_queue = None self.external_fut_root_pos_queue = None self.external_fut_root_rot_queue = None self.external_fut_frame_idx_queue = None self._fk_root_pos_seq_np = None self._fk_root_rot_seq_np = None self._fk_dof_pos_seq_np = None self._fk_root_pos_seq_tensor = None self._fk_root_rot_seq_tensor = None self._fk_dof_pos_seq_tensor = None # Set default obs_builder to velocity mode self.obs_builder = self.velocity_obs_builder def _reset_counter(self): """Reset motion timing counters to start of sequence.""" self.motion_frame_idx = 0 self.motion_step_idx = 0 if self.use_kv_cache and self.motion_kv_cache is not None: self.motion_kv_cache.fill(0) def _switch_to_velocity_mode(self, reason: str = ""): """Switch to velocity tracking mode and clear action cache. Uses velocity model's default_angles_onnx to ensure correct initialization. Also publishes velocity model's control parameters (kps/kds). 
""" self.current_policy_mode = "velocity" self.latest_obs_flag = False self.motion_in_progress = False self._fk_vr_out = None self._use_fk_vr = False self._reset_motion_action_ema_filter() self._reset_counter() self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) # Use velocity model's default angles self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy() # Publish velocity model's control parameters self._publish_control_params() if reason: self.get_logger().info(f"Switched to velocity tracking mode ({reason})") else: self.get_logger().info("Switched to velocity tracking mode") def _is_button_pressed(self, button_key): """Check if button was just pressed (rising edge detection).""" current_state = self.remote_controller.button[button_key] last_state = self.last_button_states[button_key] # Update the last state self.last_button_states[button_key] = current_state # Return True only on rising edge (0 -> 1) return current_state == 1 and last_state == 0 def load_policy(self): """Load both velocity and motion policy models using ONNX Runtime.""" self.get_logger().info("Loading dual policies...") providers = [ ( "CUDAExecutionProvider", { "device_id": 0, }, ), "CPUExecutionProvider", ] onnx_threads = int(self.config_yaml.get("onnx_intra_op_threads", 2)) sess_options = onnxruntime.SessionOptions() sess_options.intra_op_num_threads = onnx_threads sess_options.inter_op_num_threads = 1 # Load velocity policy from model folder velocity_model_folder = self.config_yaml.velocity_tracking_model_folder velocity_model_path = os.path.join( get_package_share_directory("humanoid_control"), "models", velocity_model_folder, "exported", ) # Find ONNX file in exported folder velocity_onnx_files = [f for f in os.listdir(velocity_model_path) if f.endswith('.onnx')] if not velocity_onnx_files: raise FileNotFoundError(f"No ONNX files found in {velocity_model_path}") velocity_onnx_path = os.path.join(velocity_model_path, velocity_onnx_files[0]) 
self.get_logger().info(f"Loading velocity policy from {velocity_onnx_path}") self.velocity_policy_session = onnxruntime.InferenceSession( str(velocity_onnx_path), sess_options=sess_options, providers=providers ) self.get_logger().info( f"Velocity policy loaded successfully using: " f"{self.velocity_policy_session.get_providers()}" ) # Load motion policy from model folder motion_model_folder = self.config_yaml.motion_tracking_model_folder motion_model_path = os.path.join( get_package_share_directory("humanoid_control"), "models", motion_model_folder, "exported", ) # Find ONNX file in exported folder motion_onnx_files = [f for f in os.listdir(motion_model_path) if f.endswith('.onnx')] if not motion_onnx_files: raise FileNotFoundError(f"No ONNX files found in {motion_model_path}") motion_onnx_path = os.path.join(motion_model_path, motion_onnx_files[0]) self.get_logger().info(f"Loading motion policy from {motion_onnx_path}") self.motion_policy_session = onnxruntime.InferenceSession( str(motion_onnx_path), sess_options=sess_options, providers=providers ) self.get_logger().info( f"Motion policy loaded successfully using: " f"{self.motion_policy_session.get_providers()}" ) # Set input/output names for both policies self.velocity_input_name = self.velocity_policy_session.get_inputs()[0].name self.velocity_output_name = self.velocity_policy_session.get_outputs()[0].name self.motion_input_name = self.motion_policy_session.get_inputs()[0].name self.motion_output_name = self.motion_policy_session.get_outputs()[0].name self.get_logger().info( f"Velocity policy - Input: {self.velocity_input_name}, " f"Output: {self.velocity_output_name}" ) self.get_logger().info( f"Motion policy - Input: {self.motion_input_name}, " f"Output: {self.motion_output_name}" ) # Store ONNX paths for metadata reading self.velocity_onnx_path = velocity_onnx_path self.motion_onnx_path = motion_onnx_path self.get_logger().info("Initializing KV-Cache for Motion Policy...") self.motion_kv_input_name = None 
self.motion_kv_output_name = None self.motion_kv_shape = None self.motion_step_idx_input_name = None self.motion_kv_dtype = np.float32 for node in self.motion_policy_session.get_inputs(): name = node.name shape = node.shape node_type = node.type self.get_logger().info(f"Motion policy input: name={name}, shape={shape}, type={node_type}") if "obs" in name: self.motion_input_name = name elif "past_key_values" in name: self.motion_kv_input_name = name self.motion_kv_shape = shape if isinstance(node_type, str) and "float16" in node_type: self.motion_kv_dtype = np.float16 elif "step_idx" in name or name == "step_idx": self.motion_step_idx_input_name = name elif "current_pos" in name or name == "current_pos": self.motion_step_idx_input_name = name elif ( self.motion_step_idx_input_name is None and isinstance(node_type, str) and "int64" in node_type and name not in (self.motion_input_name, self.motion_kv_input_name) ): self.motion_step_idx_input_name = name motion_outputs = self.motion_policy_session.get_outputs() action_output_name = None kv_output_name = None for node in motion_outputs: self.get_logger().info(f"Motion policy output: name={node.name}, shape={node.shape}, type={node.type}") if "present_key_values" in node.name: kv_output_name = node.name elif "actions" in node.name: action_output_name = node.name if action_output_name is None: for node in motion_outputs: if kv_output_name is not None and node.name == kv_output_name: continue action_output_name = node.name break if action_output_name is None: action_output_name = motion_outputs[0].name self.motion_output_name = action_output_name self.motion_kv_output_name = kv_output_name if self.motion_kv_input_name is not None and self.motion_kv_output_name is None: self.get_logger().warn( "Motion policy has past_key_values input but no present_key_values output was found. " "KV cache will not update and transformer performance will degrade." 
) if self.motion_kv_input_name and self.motion_kv_shape: shape = [d if isinstance(d, int) else 1 for d in self.motion_kv_shape] self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype) self.motion_model_context_len = int(shape[3]) if len(shape) > 3 else 0 self.motion_max_context_len = int( self.motion_config.get("algo", {}) .get("config", {}) .get("num_steps_per_env", 0) ) if self.motion_max_context_len > 0 and self.motion_model_context_len > 0: self.motion_effective_context_len = min( self.motion_max_context_len, self.motion_model_context_len ) else: self.motion_effective_context_len = self.motion_model_context_len self.use_kv_cache = True self.get_logger().info( f"KV-Cache initialized with shape {shape} " f"(model_ctx={self.motion_model_context_len}, " f"effective_ctx={self.motion_effective_context_len})" ) else: self.use_kv_cache = False self.motion_kv_cache = None self.motion_model_context_len = 0 self.motion_effective_context_len = 0 self.get_logger().warn("No KV-Cache inputs found in Motion Policy model!") self.get_logger().info("Dual policies loaded successfully") def load_model_config(self): """Load config.yaml from both velocity and motion model folders.""" # Load velocity model config velocity_model_folder = self.config_yaml.velocity_tracking_model_folder velocity_config_dir = os.path.join( get_package_share_directory("humanoid_control"), "models", velocity_model_folder, ) # Try different config file names for velocity model config_names = ["config.yaml"] velocity_config_path = None for config_name in config_names: potential_path = os.path.join(velocity_config_dir, config_name) if os.path.exists(potential_path): velocity_config_path = potential_path break if velocity_config_path is None: raise FileNotFoundError( f"No config file found in {velocity_config_dir}. 
Tried: {config_names}" ) self.get_logger().info( f"Loading velocity model config from {velocity_config_path}" ) self.velocity_config = OmegaConf.load(velocity_config_path) # Load motion model config motion_model_folder = self.config_yaml.motion_tracking_model_folder motion_config_dir = os.path.join( get_package_share_directory("humanoid_control"), "models", motion_model_folder, ) # Try different config file names for motion model motion_config_path = None for config_name in config_names: potential_path = os.path.join(motion_config_dir, config_name) if os.path.exists(potential_path): motion_config_path = potential_path break if motion_config_path is None: raise FileNotFoundError( f"No config file found in {motion_config_dir}. Tried: {config_names}" ) self.get_logger().info(f"Loading motion model config from {motion_config_path}") self.motion_config = OmegaConf.load(motion_config_path) self._load_motion_action_ema_filter_cfg() self.actor_place_holder_ndim = self._find_actor_place_holder_ndim() self.n_fut_frames = int(self.motion_config.obs.n_fut_frames) self.torso_body_idx = self.motion_config.robot.body_names.index("torso_link") self.get_logger().info("Both model configs loaded successfully") def _load_motion_action_ema_filter_cfg(self) -> None: actuator_cfg = self.motion_config.get("robot", {}).get("actuators", {}) enabled_raw = actuator_cfg.get("ema_filter_enabled", None) alpha_raw = actuator_cfg.get("ema_filter_alpha", None) if enabled_raw is None or alpha_raw is None: self.motion_action_ema_filter_enabled = False self.motion_action_ema_filter_alpha = 1.0 self.get_logger().info( "[Motion EMA] ema_filter_enabled/ema_filter_alpha not found in motion config; EMA disabled." 
) return self.motion_action_ema_filter_enabled = _coerce_config_bool( enabled_raw, default=False ) self.motion_action_ema_filter_alpha = float(alpha_raw) if not 0.0 <= self.motion_action_ema_filter_alpha <= 1.0: raise ValueError( "motion_config robot.actuators.ema_filter_alpha must be within [0, 1], " f"got {self.motion_action_ema_filter_alpha}." ) self.get_logger().info( "[Motion EMA] Loaded from motion config: " f"enabled={self.motion_action_ema_filter_enabled}, " f"alpha={self.motion_action_ema_filter_alpha:.4f}" ) def _reset_motion_action_ema_filter(self) -> None: self._motion_filtered_actions_onnx = None def _apply_motion_action_ema_filter( self, raw_actions: np.ndarray ) -> np.ndarray: raw_actions = np.asarray(raw_actions, dtype=np.float32).reshape(-1) if not self.motion_action_ema_filter_enabled: return raw_actions.copy() if self._motion_filtered_actions_onnx is None: self._motion_filtered_actions_onnx = raw_actions.copy() return self._motion_filtered_actions_onnx.copy() alpha = float(self.motion_action_ema_filter_alpha) filtered_actions = ( alpha * raw_actions + (1.0 - alpha) * self._motion_filtered_actions_onnx ).astype(np.float32, copy=False) self._motion_filtered_actions_onnx = filtered_actions.copy() return self._motion_filtered_actions_onnx.copy() def _build_dummy_input_from_onnx_node(self, node, fallback_last_dim: int | None = None): shape = list(getattr(node, "shape", []) or []) if not shape: shape = [1] inferred_shape = [_infer_onnx_dim(dim, default=1) for dim in shape] if fallback_last_dim is not None and len(inferred_shape) >= 2: last_dim = shape[-1] if not isinstance(last_dim, int) or last_dim <= 0: inferred_shape[-1] = int(fallback_last_dim) dtype = _infer_numpy_dtype_from_onnx_type(getattr(node, "type", "tensor(float)")) return np.zeros(inferred_shape, dtype=dtype) def _warmup_motion_policy(self, num_iters: int = 2) -> None: if self.motion_policy_session is None: return try: input_nodes = {node.name: node for node in 
self.motion_policy_session.get_inputs()} obs_node = input_nodes.get(self.motion_input_name, None) if obs_node is None: raise ValueError( f"Motion warmup failed to find obs input '{self.motion_input_name}'." ) motion_obs_dim = None try: motion_obs_dim = int( self.motion_obs_builder.build_policy_obs().shape[0] ) except Exception: motion_obs_dim = None obs_dummy = self._build_dummy_input_from_onnx_node( obs_node, fallback_last_dim=motion_obs_dim ) output_names = [self.motion_output_name] if self.motion_kv_output_name: output_names.append(self.motion_kv_output_name) local_kv_cache = None if self.use_kv_cache and self.motion_kv_input_name is not None: if self.motion_kv_cache is not None: local_kv_cache = np.zeros_like(self.motion_kv_cache) else: shape = [ _infer_onnx_dim(dim, default=1) for dim in (self.motion_kv_shape or []) ] local_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype) for warmup_step in range(max(1, int(num_iters))): input_feed = {self.motion_input_name: obs_dummy} if self.use_kv_cache and self.motion_kv_input_name is not None: input_feed[self.motion_kv_input_name] = local_kv_cache if self.motion_step_idx_input_name is not None: step_node = input_nodes.get(self.motion_step_idx_input_name, None) step_dtype = np.int64 if step_node is not None: step_dtype = _infer_numpy_dtype_from_onnx_type( getattr(step_node, "type", "tensor(int64)") ) input_feed[self.motion_step_idx_input_name] = np.array( [warmup_step], dtype=step_dtype ) warmup_output = self.motion_policy_session.run(output_names, input_feed) if ( local_kv_cache is not None and self.motion_kv_output_name and len(warmup_output) > 1 ): local_kv_cache = warmup_output[1] if self.motion_kv_cache is not None: self.motion_kv_cache.fill(0) self.motion_step_idx = 0 self.get_logger().info( f"[Warmup] Motion policy warmup completed ({max(1, int(num_iters))} iterations, KV cache kept clean)." 
) except Exception as exc: if self.motion_kv_cache is not None: self.motion_kv_cache.fill(0) self.motion_step_idx = 0 self.get_logger().warn(f"[Warmup] Motion policy warmup skipped: {exc}") def update_config_parameters(self): """Update configuration parameters from loaded configs.""" # Check if both models have the same basic parameters velocity_actions_dim = self.velocity_config.get("robot", {}).get("actions_dim", 29) motion_actions_dim = self.motion_config.get("robot", {}).get("actions_dim", 29) velocity_dof_names = self.velocity_config.get("robot", {}).get("dof_names", []) motion_dof_names = self.motion_config.get("robot", {}).get("dof_names", []) # Verify that both models have compatible configurations if velocity_actions_dim != motion_actions_dim: self.get_logger().warn( f"Different actions_dim: velocity={velocity_actions_dim}, " f"motion={motion_actions_dim}" ) if velocity_dof_names != motion_dof_names: self.get_logger().warn(f"Different dof_names between models") self.get_logger().warn(f"Velocity dof_names: {len(velocity_dof_names)} items") self.get_logger().warn(f"Motion dof_names: {len(motion_dof_names)} items") # Use velocity config as the primary source for basic parameters config = self.velocity_config # Update basic parameters self.actions_dim = config.get("robot", {}).get("actions_dim", 29) self.real_dof_names = config.get("robot", {}).get("dof_names", []) self.dof_names_ref_motion = list(config.robot.dof_names) self.num_actions = len(self.dof_names_ref_motion) # Update arrays with correct sizes self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32) self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32) self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32) self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32) self.target_dof_pos_onnx = self.default_angles_onnx.copy() self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) self.get_logger().info( f"Updated config parameters: 
actions_dim={self.actions_dim}, " f"dof_names={len(self.real_dof_names)}" ) def load_motion_data(self): """Load motion clip data from .npz files.""" motion_clips_dir = os.path.join( get_package_share_directory("humanoid_control"), self.config_yaml.motion_clip_dir, ) self.get_logger().info(f"Looking for motion clip data in: {motion_clips_dir}") self.get_logger().info(f"Directory exists: {os.path.exists(motion_clips_dir)}") if not os.path.exists(motion_clips_dir): self.get_logger().warn(f"Motion clips directory not found: {motion_clips_dir}") return # Only collect .npz files motion_clip_files = [f for f in os.listdir(motion_clips_dir) if f.endswith(".npz")] motion_clip_files.sort() self.get_logger().info( f"Found {len(motion_clip_files)} motion clip files (.npz): {motion_clip_files}" ) if not motion_clip_files: self.get_logger().warn( f"No motion clip files (.npz) found in directory: {motion_clips_dir}" ) return # Load each .npz file self.all_motion_data = [] self.motion_file_names = [] for motion_clip_file in motion_clip_files: motion_path = os.path.join(motion_clips_dir, motion_clip_file) motion_data_dict = dict(np.load(motion_path, allow_pickle=True)) self.all_motion_data.append( { "dof_pos": motion_data_dict["ref_dof_pos"], "dof_vel": motion_data_dict["ref_dof_vel"], "global_translation": motion_data_dict[ "ref_global_translation" ], "global_rotation_quat": motion_data_dict[ "ref_global_rotation_quat" ], "global_velocity": motion_data_dict["ref_global_velocity"], "global_angular_velocity": motion_data_dict["ref_global_angular_velocity"], "n_frames": motion_data_dict["ref_dof_pos"].shape[0], } ) self.motion_file_names.append(motion_clip_file) if not self.all_motion_data: self.get_logger().error("Failed to load any motion clip files") return # Initialize with the first motion clip self.current_motion_clip_index = 0 self._load_current_motion() self.get_logger().info(f"Loaded {len(self.all_motion_data)} motion clips successfully") self.get_logger().info( f"Current 
motion clip: {self.motion_file_names[self.current_motion_clip_index]}" ) def _load_current_motion(self): """Load the current selected motion clip data.""" if not self.all_motion_data: return self.motion_frame_idx = 0 current_motion = self.all_motion_data[self.current_motion_clip_index] self.ref_dof_pos = current_motion["dof_pos"] self.ref_dof_vel = current_motion["dof_vel"] self.ref_raw_bodylink_pos = current_motion["global_translation"] self.ref_raw_bodylink_rot = current_motion["global_rotation_quat"] self.ref_global_velocity = current_motion["global_velocity"] self.ref_global_angular_velocity = current_motion["global_angular_velocity"] self.n_motion_frames = current_motion["n_frames"] if self.ref_dof_pos is None or self.ref_dof_vel is None: raise ValueError("Motion clip is missing ref_dof_pos/ref_dof_vel arrays") if self.ref_raw_bodylink_pos is None or self.ref_raw_bodylink_rot is None: raise ValueError( "Motion clip is missing ref_global_translation/ref_global_rotation_quat arrays" ) if int(self.ref_dof_pos.shape[1]) != int(len(self.dof_names_ref_motion)): raise ValueError( "ref_dof_pos DOF dimension mismatch: " f"ref_dof_pos.shape[1]={int(self.ref_dof_pos.shape[1])} " f"but len(dof_names_ref_motion)={int(len(self.dof_names_ref_motion))}" ) if int(self.ref_raw_bodylink_pos.shape[1]) != int( len(self.motion_config.robot.body_names) ): raise ValueError( "ref_global_translation body dimension mismatch: " f"ref_raw_bodylink_pos.shape[1]={int(self.ref_raw_bodylink_pos.shape[1])} " f"but len(motion_config.robot.body_names)={int(len(self.motion_config.robot.body_names))}" ) self.motion_in_progress = True self.get_logger().info( f"Loaded motion clip {self.current_motion_clip_index}: " f"{self.motion_file_names[self.current_motion_clip_index]} ({self.n_motion_frames} frames)" ) def _setup_subscribers(self): """Set up ROS2 subscribers for robot state and remote controller input.""" self.remote_controller = RemoteController() self.low_state_sub = self.create_subscription( 
LowState, self.config_yaml.lowstate_topic, self._low_state_callback, QoSProfile(depth=10), ) # Add robot_state topic subscription self.robot_state_sub = self.create_subscription( String, "/robot_state", self._robot_state_callback, QoSProfile(depth=10), ) self.latest_obs_ros_sub = self.create_subscription( Float32MultiArray, "latest_obs_ros", self._latest_obs_ros_callback, QoSProfile(depth=10), ) def _latest_obs_ros_callback(self, msg: Float32MultiArray): """Receive replayed latest_obs_ros messages for offline validation.""" data = np.asarray(msg.data, dtype=np.float32) if data.size == 66: frame_idx = int(data[0]) obs = data[1:66] self._ros_latest_obs_buffer = (frame_idx, obs) elif data.size >= 65: self._ros_latest_obs_buffer = (None, data[:65]) def _setup_publishers(self): """Set up ROS2 publishers for action commands and status information.""" self.action_pub = self.create_publisher( Float32MultiArray, self.config_yaml.action_topic, QoSProfile(depth=10), ) # Add publishers for kps and kds parameters self.kps_pub = self.create_publisher( Float32MultiArray, "/humanoid/kps", QoSProfile(depth=10), ) self.kds_pub = self.create_publisher( Float32MultiArray, "/humanoid/kds", QoSProfile(depth=10), ) # Add publisher for policy mode status self.policy_mode_pub = self.create_publisher( String, "policy_mode", QoSProfile(depth=10), ) self.latest_obs_pub = self.create_publisher( Float32MultiArray, "latest_obs", QoSProfile(depth=10), ) def _setup_timers(self): """Set up ROS2 timer for main execution loop.""" # Create a one-time timer to call setup after ROS2 initialization self.create_timer(0.1, self._delayed_setup) self.create_timer(self.dt, self.run) def _delayed_setup(self): """Call setup after ROS2 initialization is complete.""" if not hasattr(self, '_setup_completed'): self.get_logger().info("Starting policy node setup...") try: self.setup() self._setup_completed = True self.get_logger().info("Policy node setup completed successfully") except Exception as e: 
self.get_logger().error(f"Policy node setup failed: {e}") # Cancel the timer to avoid repeated attempts return def _robot_state_callback(self, msg: String): """Handle robot state messages for safety coordination. Processes robot state updates from the main control node to ensure safe operation. Button operations are only allowed when the robot is in MOVE_TO_DEFAULT state. Args: msg: String message containing robot state information Valid states: ZERO_TORQUE, MOVE_TO_DEFAULT, EMERGENCY_STOP, POLICY """ robot_state = msg.data # Only allow button operations when robot state is MOVE_TO_DEFAULT if robot_state == "MOVE_TO_DEFAULT": self.robot_state_ready = True elif robot_state == "ZERO_TORQUE": self.robot_state_ready = False elif robot_state == "EMERGENCY_STOP": self.robot_state_ready = False # =========== Properties =========== @property def robot_root_rot_quat_wxyz(self): return np.array(self._lowstate_msg.imu_state.quaternion, dtype=np.float32) @property def robot_root_ang_vel(self): return np.array(self._lowstate_msg.imu_state.gyroscope, dtype=np.float32) @property def robot_dof_pos_by_name(self): """Get DOF positions by name.""" if self._lowstate_msg is None: return {} return { self.real_dof_names[i]: float(self._lowstate_msg.motor_state[i].q) for i in range(self.actions_dim) } @property def robot_dof_vel_by_name(self): """Get DOF velocities by name.""" if self._lowstate_msg is None: return {} return { self.real_dof_names[i]: float(self._lowstate_msg.motor_state[i].dq) for i in range(self.actions_dim) } @property def ref_motion_frame_idx(self): return min(self.motion_frame_idx, self.n_motion_frames - 1) @property def ref_dof_pos_raw(self): if not self.latest_obs_flag: return self.ref_dof_pos[self.ref_motion_frame_idx] if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None: if self._prev_external_dof_pos is not None: return self._prev_external_dof_pos return self.external_fut_dof_pos_queue[0] if self.external_latest_obs is None: return 
self.ref_dof_pos[self.ref_motion_frame_idx] return self.external_latest_obs[0, :29] @property def ref_dof_vel_raw(self): if not self.latest_obs_flag: return self.ref_dof_vel[self.ref_motion_frame_idx] if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None: if self._prev_external_dof_vel is not None: return self._prev_external_dof_vel return self.external_fut_dof_vel_queue[0] if self.external_latest_obs is None: return self.ref_dof_vel[self.ref_motion_frame_idx] return self.external_latest_obs[0, 29:58] @property def ref_dof_pos_onnx_order(self): return self.ref_dof_pos_raw[self.ref_to_onnx] @property def ref_dof_vel_onnx_order(self): return self.ref_dof_vel_raw[self.ref_to_onnx] @property def ref_root_pos_raw(self): if not self.latest_obs_flag: return np.asarray( self.ref_raw_bodylink_pos[self.ref_motion_frame_idx, self.root_body_idx], dtype=np.float32, ) if self.n_fut_frames > 0 and self.external_fut_root_pos_queue is not None: if self._prev_external_root_pos is not None: return self._prev_external_root_pos.astype(np.float32) return self.external_fut_root_pos_queue[0].astype(np.float32) if self.external_latest_obs is None: return np.zeros(3, dtype=np.float32) return self.external_latest_obs[0, 58:61].astype(np.float32) @property def root_body_idx(self): return 0 @property def last_valid_ref_motion_frame_idx(self): return self.n_motion_frames - 1 # =========== Policy Obeservation Methods =========== def _xyzw_to_wxyz(self, q_xyzw: np.ndarray) -> np.ndarray: """Convert quaternions from xyzw to wxyz order.""" q_xyzw = np.asarray(q_xyzw, dtype=np.float32) if q_xyzw.shape[-1] != 4: raise ValueError(f"_xyzw_to_wxyz expects (...,4) but got shape {q_xyzw.shape}") # q_xyzw[..., 0:3] -> xyz, q_xyzw[..., 3:4] -> w w = q_xyzw[..., 3:4] xyz = q_xyzw[..., 0:3] return np.concatenate([w, xyz], axis=-1) def _standardize_quaternion_wxyz(self, q_wxyz: np.ndarray) -> np.ndarray: """Standardize quaternion sign so that w >= 0.""" q_wxyz = np.asarray(q_wxyz, 
dtype=np.float32) if q_wxyz.shape[-1] != 4: raise ValueError(f"_standardize_quaternion_wxyz expects (...,4) but got shape {q_wxyz.shape}") mask = q_wxyz[..., 0:1] < 0.0 q_wxyz = np.where(mask, -q_wxyz, q_wxyz) return q_wxyz def _quat_rotate_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray: q_wxyz = np.asarray(q_wxyz, dtype=np.float32) v = np.asarray(v, dtype=np.float32) qvec = q_wxyz[..., 1:4] w = q_wxyz[..., 0:1] t = 2.0 * np.cross(qvec, v) return v + w * t + np.cross(qvec, t) def _quat_rotate_inv_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray: q_wxyz = np.asarray(q_wxyz, dtype=np.float32) n = int(np.prod(q_wxyz.shape[:-1])) if q_wxyz.ndim > 1 else 1 q_conj = self._q_conj_buffer[:n].reshape(q_wxyz.shape) q_conj[..., 0] = q_wxyz[..., 0] q_conj[..., 1:4] = -q_wxyz[..., 1:4] return self._quat_rotate_wxyz(q_conj, v) def _quat_rotate_inv_wxyz_single( self, q_wxyz: np.ndarray, v: np.ndarray, out: np.ndarray ) -> np.ndarray: """Rotate one 3D vector by the inverse quaternion into a preallocated output.""" q_conj = self._q_conj_buffer[0] q_conj[0] = q_wxyz[0] q_conj[1] = -q_wxyz[1] q_conj[2] = -q_wxyz[2] q_conj[3] = -q_wxyz[3] qvec = q_conj[1:4] w = q_conj[0] self._cross_t_buffer[:] = np.cross(qvec, v) self._cross_t_buffer *= 2.0 out[:] = v + w * self._cross_t_buffer self._cross_t_buffer[:] = np.cross(qvec, self._cross_t_buffer) out += self._cross_t_buffer return out def _get_future_frame_indices(self) -> np.ndarray: frame_idx = self.ref_motion_frame_idx last_valid = self.last_valid_ref_motion_frame_idx np.minimum( frame_idx + self._future_frame_offsets, last_valid, out=self._future_frame_indices_buffer, ) return self._future_frame_indices_buffer def _cache_fk_vr_for_obs(self): """Cache FK outputs used repeatedly during observation construction.""" fk = getattr(self, "_fk_vr_out", None) if not getattr(self, "latest_obs_flag", False) or fk is None: self._use_fk_vr = False return self._use_fk_vr = True T = self.n_fut_frames_int rb = self.root_body_idx 
np.copyto(self._fk_vel_0_root, fk["global_velocity"][0, 0, rb]) np.copyto(self._fk_angvel_0_root, fk["global_angular_velocity"][0, 0, rb]) np.copyto(self._fk_quat_0_root, fk["global_rotation_quat"][0, 0, rb]) self._fk_quat_0_root_wxyz[0] = self._fk_quat_0_root[3] self._fk_quat_0_root_wxyz[1:4] = self._fk_quat_0_root[:3] if self._fk_quat_0_root_wxyz[0] < 0.0: self._fk_quat_0_root_wxyz *= -1.0 trans_0 = fk["global_translation"][0, 0] if self._fk_trans_0 is None or self._fk_trans_0.shape != trans_0.shape: self._fk_trans_0 = np.empty_like(trans_0) np.copyto(self._fk_trans_0, trans_0) if T > 0: np.copyto(self._fk_vel_fut[:T], fk["global_velocity"][0, 1 : 1 + T, rb]) np.copyto(self._fk_angvel_fut[:T], fk["global_angular_velocity"][0, 1 : 1 + T, rb]) np.copyto(self._fk_quat_fut[:T], fk["global_rotation_quat"][0, 1 : 1 + T, rb]) self._fk_quat_fut_wxyz[:T, 0] = self._fk_quat_fut[:T, 3] self._fk_quat_fut_wxyz[:T, 1:4] = self._fk_quat_fut[:T, :3] neg = self._fk_quat_fut_wxyz[:T, 0] < 0.0 self._fk_quat_fut_wxyz[:T][neg] *= -1.0 trans_fut = fk["global_translation"][0, 1 : 1 + T] if self._fk_trans_fut is None or self._fk_trans_fut.shape != trans_fut.shape: self._fk_trans_fut = np.empty_like(trans_fut) np.copyto(self._fk_trans_fut, trans_fut) self._fill_vr_base_linvel_angvel_fut() def _fill_vr_base_linvel_angvel_fut(self): """Rotate future linear and angular velocity buffers in one pass.""" T = self.n_fut_frames_int if T <= 0: return vel_T6 = self._vel_fut_T6[:T] vel_T6[:, :3] = self._fk_vel_fut[:T] vel_T6[:, 3:6] = self._fk_angvel_fut[:T] q = self._fk_quat_fut_wxyz[:T] q_conj = self._q_conj_buffer[:T].reshape(T, 4) q_conj[:, 0] = q[:, 0] q_conj[:, 1:4] = -q[:, 1:4] qvec = q_conj[:, 1:4] w = q_conj[:, 0:1] rt = self._rot_t_buffer[:T] rc = self._rot_cross_buffer[:T] rt[:] = np.cross(qvec, vel_T6[:, :3]) rt *= 2.0 rc[:] = np.cross(qvec, rt) self._base_linvel_fut_buffer[:T] = vel_T6[:, :3] + w * rt + rc rt[:] = np.cross(qvec, vel_T6[:, 3:6]) rt *= 2.0 rc[:] = np.cross(qvec, rt) 
self._base_angvel_fut_buffer[:T] = vel_T6[:, 3:6] + w * rt + rc def _prepare_vr_fk_tensors( self, cur_root_pos: np.ndarray, cur_root_rot: np.ndarray, cur_dof_pos: np.ndarray, n_fut: int, ): """Fill preallocated FK input buffers and return torch views without reallocation.""" if ( n_fut <= 0 or self._fk_root_pos_seq_np is None or self._fk_root_rot_seq_np is None or self._fk_dof_pos_seq_np is None ): raise ValueError("VR FK sequence buffers are not initialized") np.copyto(self._fk_root_pos_seq_np[0, 0], cur_root_pos) np.copyto(self._fk_root_rot_seq_np[0, 0], cur_root_rot) np.copyto(self._fk_dof_pos_seq_np[0, 0], cur_dof_pos) np.copyto( self._fk_root_pos_seq_np[0, 1 : 1 + n_fut], self.external_fut_root_pos_queue[:n_fut], ) np.copyto( self._fk_root_rot_seq_np[0, 1 : 1 + n_fut], self.external_fut_root_rot_queue[:n_fut], ) np.copyto( self._fk_dof_pos_seq_np[0, 1 : 1 + n_fut], self.external_fut_dof_pos_queue[:n_fut], ) return ( self._fk_root_pos_seq_tensor, self._fk_root_rot_seq_tensor, self._fk_dof_pos_seq_tensor, ) def _get_future_root_quat_wxyz(self) -> np.ndarray: if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable; future_root_quat_wxyz will return zeros." 
) return self._future_root_quat_wxyz_buffer fut_idx = self._get_future_frame_indices() q_root_xyzw = np.asarray( self.ref_raw_bodylink_rot[fut_idx, self.root_body_idx], dtype=np.float32, ) q_root_wxyz = self._future_root_quat_wxyz_buffer q_root_wxyz[:, 0] = q_root_xyzw[:, 3] q_root_wxyz[:, 1] = q_root_xyzw[:, 0] q_root_wxyz[:, 2] = q_root_xyzw[:, 1] q_root_wxyz[:, 3] = q_root_xyzw[:, 2] neg_mask = q_root_wxyz[:, 0] < 0.0 q_root_wxyz[neg_mask] *= -1.0 return self._future_root_quat_wxyz_buffer def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray: keybody_idxs = self._keybody_indices_by_term_name.get(term_name, None) if keybody_idxs is None: raise ValueError( f"Keybody indices for term '{term_name}' were not cached. " "Ensure the term exists in motion policy obs and cache is initialized." ) return keybody_idxs def _get_obs_actor_velocity_command(self): return self._get_obs_velocity_command() def _get_obs_actor_projected_gravity(self): return self._get_obs_projected_gravity() def _get_obs_actor_rel_robot_root_ang_vel(self): return self._get_obs_rel_robot_root_ang_vel() def _get_obs_actor_dof_pos(self): return self._get_obs_dof_pos() def _get_obs_actor_dof_vel(self): return self._get_obs_dof_vel() def _get_obs_actor_last_action(self): return self._get_obs_last_action() def _get_obs_actor_ref_gravity_projection_cur(self): return self._get_obs_ref_gravity_projection_cur() def _get_obs_actor_ref_gravity_projection_fut(self): return self._get_obs_ref_gravity_projection_fut() def _get_obs_actor_ref_base_linvel_cur(self): return self._get_obs_ref_base_linvel_cur() def _get_obs_actor_ref_base_linvel_fut(self): return self._get_obs_ref_base_linvel_fut() def _get_obs_actor_ref_base_angvel_cur(self): return self._get_obs_ref_base_angvel_cur() def _get_obs_actor_ref_base_angvel_fut(self): return self._get_obs_ref_base_angvel_fut() def _get_obs_actor_ref_dof_pos_cur(self): return self._get_obs_ref_dof_pos_cur() def _get_obs_actor_ref_dof_pos_fut(self): return 
self._get_obs_ref_dof_pos_fut() def _get_obs_actor_ref_root_height_cur(self): return self._get_obs_ref_root_height_cur() def _get_obs_actor_ref_root_height_fut(self): return self._get_obs_ref_root_height_fut() def _get_obs_actor_ref_keybody_rel_pos_cur(self): return self._get_obs_ref_keybody_rel_pos_cur() def _get_obs_actor_ref_keybody_rel_pos_fut(self): return self._get_obs_ref_keybody_rel_pos_fut() def _get_obs_velocity_command(self): """Get velocity command observation (reuses pre-allocated array).""" self._velocity_cmd_obs[1] = self.vx self._velocity_cmd_obs[2] = self.vy self._velocity_cmd_obs[3] = self.vyaw self._velocity_cmd_obs[0] = float( np.linalg.norm(self._velocity_cmd_obs[1:4]) > 0.1 ) return self._velocity_cmd_obs def _get_obs_projected_gravity(self): return get_gravity_orientation(self.robot_root_rot_quat_wxyz) def _get_obs_rel_robot_root_ang_vel(self): return self.robot_root_ang_vel def _get_obs_dof_pos(self): """Get DOF position observation (pre-allocated buffer + index lookup, no dict/list).""" if self._lowstate_msg is None: return self._dof_pos_obs_buffer[: len(self.motion_dof_names_onnx)] if self.current_policy_mode == "motion": buf = self._dof_pos_obs_buffer ms = self._lowstate_msg.motor_state def_angles = self.motion_default_angles_onnx for i, ri in enumerate(self.motion_dof_real_indices): buf[i] = ms[ri].q - def_angles[i] return buf[: len(self.motion_dof_names_onnx)] def_angles = self.velocity_default_angles_onnx for i, ri in enumerate(self.velocity_dof_real_indices): self._dof_pos_obs_buffer[i] = ( self._lowstate_msg.motor_state[ri].q - def_angles[i] ) return self._dof_pos_obs_buffer[: len(self.velocity_dof_names_onnx)] def _get_obs_dof_vel(self): """Get DOF velocity observation (pre-allocated buffer + index lookup, no dict/list).""" if self._lowstate_msg is None: return self._dof_vel_obs_buffer[: len(self.motion_dof_names_onnx)] if self.current_policy_mode == "motion": buf = self._dof_vel_obs_buffer ms = self._lowstate_msg.motor_state for i, 
ri in enumerate(self.motion_dof_real_indices): buf[i] = ms[ri].dq return buf[: len(self.motion_dof_names_onnx)] for i, ri in enumerate(self.velocity_dof_real_indices): self._dof_vel_obs_buffer[i] = self._lowstate_msg.motor_state[ri].dq return self._dof_vel_obs_buffer[: len(self.velocity_dof_names_onnx)] def _get_obs_last_action(self): return self.actions_onnx.copy() def _get_obs_ref_motion_states(self): return np.concatenate( [self.ref_dof_pos_onnx_order, self.ref_dof_vel_onnx_order] ) def _get_obs_ref_dof_pos_fut(self): """Get future DOF position observation (reuses pre-allocated buffer).""" T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if getattr(self, "latest_obs_flag", False): if ( getattr(self, "external_fut_dof_pos_queue", None) is not None and self.external_fut_dof_pos_queue.shape[0] >= T ): pos_fut = self._pos_fut_buffer pos_fut[:, :] = self.external_fut_dof_pos_queue[:T].T pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T] return pos_fut_onnx.reshape(-1).astype(np.float32) return np.zeros(self.num_actions * T, dtype=np.float32) if not hasattr(self, "ref_dof_pos") or self.ref_dof_pos is None: self.get_logger().warn( "[VR] ref_dof_pos is unavailable and latest_obs is not active; returning zeros for ref_dof_pos_fut." 
) return np.zeros(self.num_actions * T, dtype=np.float32) fut_idx = self._get_future_frame_indices() pos_fut = self._pos_fut_buffer pos_fut[:, :] = self.ref_dof_pos[fut_idx].T # Reorder to ONNX and flatten per training layout pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T] return pos_fut_onnx.reshape(-1).astype(np.float32) def _get_obs_ref_root_height_fut(self): """Get future root height observation (reuses pre-allocated buffer).""" T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if self.latest_obs_flag and self.external_fut_root_pos_queue is not None: root_pos_fut = self.external_fut_root_pos_queue[:, 2].astype(np.float32) return root_pos_fut.reshape(-1) if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None: self.get_logger().warn( "[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_height_fut." ) return np.zeros(T, dtype=np.float32) fut_idx = self._get_future_frame_indices() h_fut = self._h_fut_buffer h_fut[0, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, 2] return h_fut.reshape(-1).astype(np.float32) def _get_obs_ref_root_pos_fut(self): """Get future root position observation (reuses pre-allocated buffer).""" T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if self.latest_obs_flag and self.external_fut_root_pos_queue is not None: pos_fut = self.external_fut_root_pos_queue.astype(np.float32) return pos_fut.reshape(-1).astype(np.float32) if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None: self.get_logger().warn( "[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_pos_fut." 
) return np.zeros(3 * T, dtype=np.float32) fut_idx = self._get_future_frame_indices() pos_fut = self._root_pos_fut_buffer pos_fut[:, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, :] return pos_fut.reshape(-1).astype(np.float32) def _get_obs_ref_dof_pos_cur(self): return self.ref_dof_pos_onnx_order def _get_obs_ref_dof_vel_cur(self): return self.ref_dof_vel_onnx_order def _get_obs_ref_root_height_cur(self): if not self.latest_obs_flag: return self.ref_raw_bodylink_pos[ self.ref_motion_frame_idx, self.root_body_idx, 2 ] return float(self.ref_root_pos_raw[2]) def _get_obs_ref_root_pos_cur(self): return self.ref_root_pos_raw.astype(np.float32) def _get_obs_ref_gravity_projection_cur(self): if getattr(self, "_use_fk_vr", False): return get_gravity_orientation(self._fk_quat_0_root_wxyz) if getattr(self, "latest_obs_flag", False) and getattr( self, "external_latest_obs", None ) is not None: q_root_wxyz = self.external_latest_obs[0, 61:65].astype(np.float32) q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz) return get_gravity_orientation(q_root_wxyz) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for gravity_projection_cur." 
) return np.zeros(3, dtype=np.float32) q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx] q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw) q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz) return get_gravity_orientation(q_root_wxyz) def _get_obs_ref_gravity_projection_fut(self): T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if getattr(self, "_use_fk_vr", False): q_root_wxyz = self._fk_quat_fut_wxyz[:T] gravity_fut = self._gravity_fut_buffer qw = q_root_wxyz[:, 0] qx = q_root_wxyz[:, 1] qy = q_root_wxyz[:, 2] qz = q_root_wxyz[:, 3] gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy) gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx) gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_fut.reshape(-1).astype(np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_gravity_projection_fut." 
) return np.zeros(3 * T, dtype=np.float32) q_root_wxyz = self._get_future_root_quat_wxyz() gravity_fut = self._gravity_fut_buffer qw = q_root_wxyz[:, 0] qx = q_root_wxyz[:, 1] qy = q_root_wxyz[:, 2] qz = q_root_wxyz[:, 3] gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy) gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx) gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_fut.reshape(-1).astype(np.float32) def _get_obs_ref_base_linvel_cur(self): if getattr(self, "_use_fk_vr", False): self._quat_rotate_inv_wxyz_single( self._fk_quat_0_root_wxyz, self._fk_vel_0_root, self._rotated_3vec_buffer ) return self._rotated_3vec_buffer if getattr(self, "latest_obs_flag", False) and getattr( self, "external_latest_obs", None ) is not None: return np.zeros(3, dtype=np.float32) if not hasattr(self, "ref_global_velocity") or self.ref_global_velocity is None: self.get_logger().warn( "[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur." ) return np.zeros(3, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur." 
) return np.zeros(3, dtype=np.float32) q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx] q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw) q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz) v_root_w = np.asarray( self.ref_global_velocity[self.ref_motion_frame_idx, self.root_body_idx], dtype=np.float32, ) v_root = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w) return np.asarray(v_root, dtype=np.float32).reshape(3) def _get_obs_ref_base_linvel_fut(self): T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if getattr(self, "_use_fk_vr", False): return self._base_linvel_fut_buffer[:T].reshape(-1).astype(np.float32) if not hasattr(self, "ref_global_velocity") or self.ref_global_velocity is None: self.get_logger().warn( "[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut." ) return np.zeros(3 * T, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut." 
) return np.zeros(3 * T, dtype=np.float32) fut_idx = self._get_future_frame_indices() q_root_wxyz = self._get_future_root_quat_wxyz() v_root_w = np.asarray( self.ref_global_velocity[fut_idx, self.root_body_idx], dtype=np.float32, ) base_linvel_fut = self._base_linvel_fut_buffer base_linvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w) return base_linvel_fut.reshape(-1).astype(np.float32) def _get_obs_ref_base_angvel_cur(self): if getattr(self, "_use_fk_vr", False): self._quat_rotate_inv_wxyz_single( self._fk_quat_0_root_wxyz, self._fk_angvel_0_root, self._rotated_angvel_cur_buffer, ) return self._rotated_angvel_cur_buffer if getattr(self, "latest_obs_flag", False) and getattr( self, "external_latest_obs", None ) is not None: return np.zeros(3, dtype=np.float32) if not hasattr(self, "ref_global_angular_velocity") or self.ref_global_angular_velocity is None: self.get_logger().warn( "[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur." ) return np.zeros(3, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur." 
) return np.zeros(3, dtype=np.float32) q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx] q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw) q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz) w_root_w = np.asarray( self.ref_global_angular_velocity[self.ref_motion_frame_idx, self.root_body_idx], dtype=np.float32, ) w_root = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w) return np.asarray(w_root, dtype=np.float32).reshape(3) def _get_obs_ref_base_angvel_fut(self): T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if getattr(self, "_use_fk_vr", False): return self._base_angvel_fut_buffer[:T].reshape(-1).astype(np.float32) if not hasattr(self, "ref_global_angular_velocity") or self.ref_global_angular_velocity is None: self.get_logger().warn( "[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut." ) return np.zeros(3 * T, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut." 
) return np.zeros(3 * T, dtype=np.float32) fut_idx = self._get_future_frame_indices() q_root_wxyz = self._get_future_root_quat_wxyz() w_root_w = np.asarray( self.ref_global_angular_velocity[fut_idx, self.root_body_idx], dtype=np.float32, ) base_angvel_fut = self._base_angvel_fut_buffer base_angvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w) return base_angvel_fut.reshape(-1).astype(np.float32) def _get_obs_ref_keybody_rel_pos_cur(self): if getattr(self, "_use_fk_vr", False) and self._fk_trans_0 is not None: keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) if not self._root_only_fk_has_required_keybodies(keybody_idxs): return np.zeros(3 * n_keybodies, dtype=np.float32) root_pos = self._fk_trans_0[self.root_body_idx] keybody_pos = self._fk_trans_0[keybody_idxs] rel_pos_w = keybody_pos - root_pos[None, :] rel_pos_root = self._quat_rotate_inv_wxyz(self._fk_quat_0_root_wxyz, rel_pos_w) return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1) if getattr(self, "latest_obs_flag", False) and getattr( self, "external_latest_obs", None ) is not None: keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) return np.zeros(3 * n_keybodies, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None: self.get_logger().warn( "[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur." 
) keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) return np.zeros(3 * n_keybodies, dtype=np.float32) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur." ) keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) return np.zeros(3 * n_keybodies, dtype=np.float32) keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) frame_idx = self.ref_motion_frame_idx ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[frame_idx], dtype=np.float32) ref_root_global_pos = ref_body_global_pos[self.root_body_idx] q_root_xyzw = self.ref_raw_bodylink_rot[frame_idx, self.root_body_idx] q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw) q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz) rel_pos_w = ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :] rel_pos_root = self._quat_rotate_inv_wxyz(q_root_wxyz, rel_pos_w) return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1) def _get_obs_ref_keybody_rel_pos_fut(self): T = self.n_fut_frames_int if T <= 0: return np.zeros(0, dtype=np.float32) if getattr(self, "_use_fk_vr", False) and self._fk_trans_fut is not None: keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_fut") n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros((T, 0), dtype=np.float32).reshape(-1) if not self._root_only_fk_has_required_keybodies(keybody_idxs): return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1) ref_body = self._fk_trans_fut[:T] # (T, 
num_bodies, 3) ref_root = ref_body[:, self.root_body_idx, :] # (T, 3) if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies: self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32) self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32) elif ( self._keybody_rel_pos_w_buffer is None or self._keybody_rel_pos_w_buffer.shape[0] < T or self._keybody_rel_pos_w_buffer.shape[1] != n_keybodies ): self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32) rel_pos_fut = self._keybody_rel_pos_fut_buffer np.subtract( ref_body[:, keybody_idxs, :], ref_root[:, None, :], out=self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :], ) rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz( self._fk_quat_fut_wxyz[:T, None, :], self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :], ) return rel_pos_fut.reshape(-1).astype(np.float32) keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_fut") n_keybodies = int(keybody_idxs.shape[0]) if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None: self.get_logger().warn( "[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut." ) if n_keybodies == 0: return np.zeros((T, 0), dtype=np.float32).reshape(-1) return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1) if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None: self.get_logger().warn( "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut." 
) if n_keybodies == 0: return np.zeros((T, 0), dtype=np.float32).reshape(-1) return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1) if n_keybodies == 0: return np.zeros((T, 0), dtype=np.float32).reshape(-1) fut_idx = self._get_future_frame_indices() q_root_wxyz = self._get_future_root_quat_wxyz() ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[fut_idx], dtype=np.float32) ref_root_global_pos = ref_body_global_pos[:, self.root_body_idx, :] rel_pos_w = ref_body_global_pos[:, keybody_idxs, :] - ref_root_global_pos[:, None, :] if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies: self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32) rel_pos_fut = self._keybody_rel_pos_fut_buffer rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz(q_root_wxyz[:, None, :], rel_pos_w) return rel_pos_fut.reshape(-1).astype(np.float32) def _get_obs_place_holder(self): return np.zeros(self.actor_place_holder_ndim, dtype=np.float32) # =========== Policy Obeservation Methods =========== def _warmup_fk_for_vr(self): """Run one FK warmup step when entering VR motion mode.""" try: if ( getattr(self, "fk", None) is None or not getattr(self, "fk_initialized", False) ): return if getattr(self, "external_latest_obs", None) is None: return if getattr(self, "external_fut_dof_pos_queue", None) is None: return n_fut = int(getattr(self, "n_fut_frames", 0)) if ( n_fut <= 0 or self.external_fut_root_pos_queue is None or self.external_fut_root_rot_queue is None ): return latest = self.external_latest_obs[0] cur_root_pos = latest[58:61] cur_root_rot = latest[61:65] cur_dof_pos = latest[0:29] root_pos_tensor, root_rot_tensor, dof_pos_tensor = ( self._prepare_vr_fk_tensors( cur_root_pos=cur_root_pos, cur_root_rot=cur_root_rot, cur_dof_pos=cur_dof_pos, n_fut=n_fut, ) ) fk_out = self.fk( root_pos=root_pos_tensor, root_quat=root_rot_tensor, dof_pos=dof_pos_tensor, fps=float(1.0 / self.dt), quat_format="wxyz", vel_smoothing_sigma=0.0, compute_velocity=False, ) 
self._fk_vr_out = { k: v.detach().cpu().numpy() for k, v in fk_out.items() } except Exception as e: self.get_logger().warn(f"[VR] FK warmup failed, fallback to zeros: {e}") def _low_state_callback(self, ls_msg: LowState): """Process low-level robot state and remote controller input. Main callback that handles: - Remote controller input processing - Motion selection based on button presses - Safety state checking - Velocity command extraction Motion Button Mapping: - A button: Enable policy (defaults to velocity mode) - B button: Switch from velocity to motion mode - Y button: Switch from motion back to velocity mode - UP/DOWN/LEFT/RIGHT: Motion clip selection (only in velocity tracking mode) Args: ls_msg: LowState message containing robot sensor data and remote controller input """ self._lowstate_msg = ls_msg self.remote_controller.set(ls_msg.wireless_remote) # A button: Toggle policy enable state (default to velocity mode) if ( self._is_button_pressed(KeyMap.A) and self.robot_state_ready ): self.policy_enabled = True self.current_policy_mode = "velocity" # Default to velocity mode self.latest_obs_flag = False self._reset_motion_action_ema_filter() self._reset_counter() if hasattr(self, "use_kv_cache") and self.use_kv_cache: self.motion_kv_cache.fill(0) self.motion_step_idx = 0 # Initialize with velocity model's default angles self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy() # Publish velocity model's control parameters (kps/kds) self._publish_control_params() self.get_logger().info( f"Policy enabled in {self.current_policy_mode} tracking mode" ) # B button: Switch to motion tracking mode (only when policy is enabled) if ( self._is_button_pressed(KeyMap.B) and self.robot_state_ready and self.policy_enabled and self.current_policy_mode == "velocity" # Only allow switch from velocity mode ): vr_data_available = bool( getattr(self, "enable_teleop_reference", True) and getattr(self, 
"external_obs_received", False) and getattr(self, "external_latest_obs", None) is not None ) vr_ready = self._is_vr_ready_for_motion() if ( self.enable_teleop_reference and self.require_vr_data_for_motion and not vr_ready ): self.get_logger().warn( "require_vr_data_for_motion=True but the VR queue is not ready yet; staying in velocity mode." ) else: # Don't automatically switch to next motion clip - keep current selection if hasattr(self, "all_motion_data") and self.all_motion_data: # Load the current motion clip data (don't change current_motion_clip_index) self._load_current_motion() self.current_policy_mode = "motion" self._reset_motion_action_ema_filter() self._reset_counter() if hasattr(self, "use_kv_cache") and self.use_kv_cache: self.motion_kv_cache.fill(0) self.get_logger().info("Motion KV-Cache reset.") self.motion_step_idx = 0 self.get_logger().info("Motion Step Index reset to 0.") # Clear any pending actions to prevent conflicts between policies # Use motion model's default angles self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) self.target_dof_pos_onnx = self.motion_default_angles_onnx.copy() # Publish motion model's control parameters (kps/kds) self._publish_control_params() self.latest_obs_flag = bool(vr_data_available) source_mode = "ZMQ latest_obs" if self.latest_obs_flag else "offline motion" self.get_logger().info( f"Switched to motion tracking mode ({source_mode}) - motion clip index: {self.current_motion_clip_index}" ) if self.latest_obs_flag: self.get_logger().info("[VR] Reference trajectory source: ZMQ latest_obs") self._warmup_fk_for_vr() self.motion_in_progress = True if ( self._is_button_pressed(KeyMap.Y) and self.robot_state_ready and self.policy_enabled and self.current_policy_mode == "motion" # Only allow switch from motion mode ): self._switch_to_velocity_mode() # Get velocity commands only in velocity tracking mode if self.current_policy_mode == "velocity": self.vx, self.vy, self.vyaw = 
self.remote_controller.get_velocity_commands() else: # In motion tracking mode, ignore joystick input self.vx, self.vy, self.vyaw = 0.0, 0.0, 0.0 # Handle motion clip selection in velocity tracking mode (UP/DOWN/LEFT/RIGHT) if ( self.current_policy_mode == "velocity" and self.policy_enabled and self.robot_state_ready ): # Handle motion clip selection with UP/DOWN/LEFT/RIGHT buttons if self._is_button_pressed(KeyMap.up): # Switch to previous motion clip if hasattr(self, "all_motion_data") and self.all_motion_data: self.current_motion_clip_index = ( self.current_motion_clip_index - 1 ) % len(self.all_motion_data) self.get_logger().info( f"Selected previous motion clip: " f"{self.motion_file_names[self.current_motion_clip_index]}" ) elif self._is_button_pressed(KeyMap.down): # Switch to next motion clip if hasattr(self, "all_motion_data") and self.all_motion_data: self.current_motion_clip_index = ( self.current_motion_clip_index + 1 ) % len(self.all_motion_data) self.get_logger().info( f"Selected next motion clip: " f"{self.motion_file_names[self.current_motion_clip_index]}" ) elif self._is_button_pressed(KeyMap.left): # Select first motion clip if hasattr(self, "all_motion_data") and self.all_motion_data: self.current_motion_clip_index = 0 self.get_logger().info( f"Selected first motion clip: " f"{self.motion_file_names[self.current_motion_clip_index]}" ) elif self._is_button_pressed(KeyMap.right): # Select last motion clip if hasattr(self, "all_motion_data") and self.all_motion_data: self.current_motion_clip_index = len(self.all_motion_data) - 1 self.get_logger().info( f"Selected last motion clip: " f"{self.motion_file_names[self.current_motion_clip_index]}" ) def run(self): """Main execution loop for policy inference and action publication.""" # Only run if setup is completed if not hasattr(self, '_setup_completed') or not self._setup_completed: return t_loop_start = time.perf_counter() now = time.time() t_io = time.perf_counter() buf = getattr(self, 
"_ros_latest_obs_buffer", None) if buf is not None: self._ros_latest_obs_buffer = None frame_idx, obs_arr = buf if frame_idx is not None: self._npz_replay_frame_index = frame_idx self._store_external_latest_obs(obs_arr[None, :]) self._poll_zmq_latest_obs() if getattr(self, "current_policy_mode", None) == "motion": if self._last_vr_status_log_time is None: self._last_vr_status_log_time = now elif now - self._last_vr_status_log_time >= 5.0: vr_available = bool( getattr(self, "external_obs_received", False) and getattr(self, "external_latest_obs", None) is not None ) queue_stats = self._latest_obs_buffer.get_queue_stats() freq = queue_stats.get("expected_freq") if vr_available: self.get_logger().info( "[VR-STATUS] ZMQ latest_obs streaming | " f"buffer_size={queue_stats['queue_size']} " f"expected_freq={freq:.1f}Hz" if freq else f"buffer_size={queue_stats['queue_size']} expected_freq=unknown" ) else: self.get_logger().warn( "[VR-STATUS] No new ZMQ latest_obs received in the last 5 seconds; using offline reference or the last buffered VR state." 
) self._last_vr_status_log_time = now if ( getattr(self, "require_vr_data_for_motion", False) and getattr(self, "policy_enabled", False) and not getattr(self, "_vr_ready_logged", False) and self._is_vr_ready_for_motion() ): self.get_logger().info( f"[VR] VR queue is ready for motion mode (seen_frames={int(getattr(self, '_external_seen_frames', 0))}, " f"n_fut={int(getattr(self, 'n_fut_frames', 0) or 0)}, " f"delay={int(getattr(self, 'zmq_jitter_delay_frames', 0) or 0)})" ) self._vr_ready_logged = True self._publish_latest_obs() io_ms = self._timing_ms(t_io) policy_timing = self._run_without_profiling() _run_elapsed = 0.0 if policy_timing is not None: _run_elapsed = float(policy_timing.get("policy_total_ms", 0.0)) / 1000.0 if ( getattr(self, "current_policy_mode", None) == "motion" and getattr(self, "latest_obs_flag", False) and _run_elapsed > 0.5 and not getattr(self, "_vr_cold_start_logged", False) ): self._vr_cold_start_logged = True self.get_logger().info( "[VR] The first motion step is a cold start (FK/ONNX initialization) and may take about 1 second." ) if ( getattr(self, "current_policy_mode", None) == "motion" and getattr(self, "latest_obs_flag", False) and _run_elapsed > 1.15 * self.dt and _run_elapsed <= 0.5 ): self._policy_slow_count = getattr(self, "_policy_slow_count", 0) + 1 if self._policy_slow_count == 1 or self._policy_slow_count % 50 == 0: self.get_logger().warn( f"[VR] Policy step latency {_run_elapsed*1000:.1f} ms exceeds the target {self.dt*1000:.1f} ms. " f"Estimated /humanoid/action rate: {1.0/_run_elapsed:.1f} Hz (target {1.0/self.dt:.0f} Hz). " "The main bottleneck is usually FK or ONNX inference; if the system settles near 30 Hz, consider setting policy_freq to 30." 
) if policy_timing is not None: sample = dict(policy_timing) sample["io_ms"] = io_ms sample["loop_total_ms"] = self._timing_ms(t_loop_start) self._record_timing_sample(sample) def _read_onnx_metadata(self, onnx_model_path: str) -> dict: """Read model metadata from ONNX file and parse into Python types.""" model = onnx.load(str(onnx_model_path)) meta = {p.key: p.value for p in model.metadata_props} def _parse_floats(csv_str: str): return np.array( [float(x) for x in csv_str.split(",") if x != ""], dtype=np.float32, ) result = {} result["action_scale"] = _parse_floats(meta["action_scale"]) result["kps"] = _parse_floats(meta["joint_stiffness"]) result["kds"] = _parse_floats(meta["joint_damping"]) result["default_joint_pos"] = _parse_floats(meta["default_joint_pos"]) result["joint_names"] = [x for x in meta["joint_names"].split(",") if x != ""] return result def _store_external_latest_obs(self, arr: np.ndarray): """Store latest_obs and maintain the current/future frame queues.""" if arr.ndim == 1: arr = arr[None, :] if arr.shape[1] < self.latest_obs_expected_dim: self.get_logger().warn( f"Received latest_obs dim={arr.shape[1]}, expected >= {self.latest_obs_expected_dim}" ) return clipped = arr[:, : self.latest_obs_expected_dim].astype(np.float32, copy=False) current_time = time.time() self.external_latest_obs = clipped self.external_obs_received = True self.last_external_obs_time = current_time self._external_seen_frames = int(getattr(self, "_external_seen_frames", 0)) + 1 latest_root_pos = clipped[0, 58:61] latest_root_rot = clipped[0, 61:65] latest_dof_pos = clipped[0, :29] latest_dof_vel = clipped[0, 29:58] if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None: raw_idx = getattr(self, "_npz_replay_frame_index", None) try: latest_frame_idx = int(raw_idx) if raw_idx is not None else -1 except Exception: latest_frame_idx = -1 if self._prev_external_dof_pos is None: self._prev_external_dof_pos = np.empty_like(self.external_fut_dof_pos_queue[0]) 
self._prev_external_dof_vel = np.empty_like(self.external_fut_dof_vel_queue[0]) self._prev_external_root_pos = np.empty_like(self.external_fut_root_pos_queue[0]) if self.external_fut_root_rot_queue is not None: self._prev_external_root_rot = np.empty_like( self.external_fut_root_rot_queue[0] ) np.copyto(self._prev_external_dof_pos, self.external_fut_dof_pos_queue[0]) np.copyto(self._prev_external_dof_vel, self.external_fut_dof_vel_queue[0]) np.copyto(self._prev_external_root_pos, self.external_fut_root_pos_queue[0]) if self.external_fut_root_rot_queue is not None: np.copyto(self._prev_external_root_rot, self.external_fut_root_rot_queue[0]) if self.external_fut_frame_idx_queue is not None: try: self._prev_external_frame_idx = int(self.external_fut_frame_idx_queue[0]) except Exception: self._prev_external_frame_idx = -1 self.external_fut_dof_pos_queue[:-1] = self.external_fut_dof_pos_queue[1:] self.external_fut_dof_pos_queue[-1] = latest_dof_pos self.external_fut_dof_vel_queue[:-1] = self.external_fut_dof_vel_queue[1:] self.external_fut_dof_vel_queue[-1] = latest_dof_vel self.external_fut_root_pos_queue[:-1] = self.external_fut_root_pos_queue[1:] self.external_fut_root_pos_queue[-1] = latest_root_pos if self.external_fut_root_rot_queue is not None: self.external_fut_root_rot_queue[:-1] = self.external_fut_root_rot_queue[1:] self.external_fut_root_rot_queue[-1] = latest_root_rot if self.external_fut_frame_idx_queue is not None: self.external_fut_frame_idx_queue[:-1] = self.external_fut_frame_idx_queue[1:] self.external_fut_frame_idx_queue[-1] = latest_frame_idx def _poll_zmq_latest_obs(self): """Poll the ZMQ latest_obs buffer with stale-data checks and delay.""" current_time = time.time() data, timestamp, is_stale, frame_index, sender_timestamp = self._latest_obs_buffer.get_with_age_and_delay( max_age=self.max_data_age, delay_steps=int(getattr(self, "zmq_jitter_delay_frames", 0)), ) if data is None: return if frame_index is not None: self._npz_replay_frame_index = 
int(frame_index) self._latest_sender_timestamp = sender_timestamp if is_stale: self.stale_data_warning_count += 1 if self.stale_data_warning_count % 50 == 0: age_ms = (current_time - timestamp) * 1000.0 self.get_logger().warn( f"ZMQ latest_obs is stale: age={age_ms:.1f}ms " f"(threshold={self.max_data_age*1000:.1f}ms), " f"stale_count={self.stale_data_warning_count}" ) queue_stats = self._latest_obs_buffer.get_queue_stats() if queue_stats.get("expected_freq"): self.get_logger().warn( f"latest_obs buffer: size={queue_stats['queue_size']}, " f"avg_interval={queue_stats['avg_interval']*1000:.1f}ms, " f"expected_freq={queue_stats['expected_freq']:.1f}Hz" ) else: if self.stale_data_warning_count > 0: self.stale_data_warning_count = 0 if self.last_poll_time is not None: poll_interval = current_time - self.last_poll_time if poll_interval > 0.03: self.get_logger().debug( f"Policy poll interval {poll_interval*1000:.1f}ms (>30ms)" ) self.last_poll_time = current_time self._store_external_latest_obs(np.asarray(data, dtype=np.float32)) if ( getattr(self, "enable_teleop_reference", True) and getattr(self, "require_vr_data_for_motion", False) and not getattr(self, "latest_obs_flag", False) and self._is_vr_ready_for_motion() ): self.latest_obs_flag = True if not getattr(self, "_vr_fk_started_logged", False): self.get_logger().info( "[VR] ZMQ data is ready; the main thread will build the reference trajectory from live ZMQ input." ) self._vr_fk_started_logged = True def _publish_latest_obs(self): """Publish the latest_obs topic for debugging or reuse.""" if self.external_latest_obs is None: return try: msg = Float32MultiArray() msg.data = self.external_latest_obs[0].tolist() self.latest_obs_pub.publish(msg) except Exception as e: self.get_logger().error(f"Failed to publish latest_obs: {e}") def _apply_onnx_metadata(self): """Apply PD/scale/defaults from ONNX metadata as authoritative values. 
Load separate metadata for velocity and motion models.""" # Load velocity model metadata velocity_meta = self._read_onnx_metadata(self.velocity_onnx_path) self.velocity_dof_names_onnx = velocity_meta["joint_names"] self.velocity_action_scale_onnx = velocity_meta["action_scale"].astype(np.float32) self.velocity_kps_onnx = velocity_meta["kps"].astype(np.float32) self.velocity_kds_onnx = velocity_meta["kds"].astype(np.float32) self.velocity_default_angles_onnx = velocity_meta["default_joint_pos"].astype(np.float32) # Load motion model metadata motion_meta = self._read_onnx_metadata(self.motion_onnx_path) self.motion_dof_names_onnx = motion_meta["joint_names"] self.motion_action_scale_onnx = motion_meta["action_scale"].astype(np.float32) self.motion_kps_onnx = motion_meta["kps"].astype(np.float32) self.motion_kds_onnx = motion_meta["kds"].astype(np.float32) self.motion_default_angles_onnx = motion_meta["default_joint_pos"].astype(np.float32) # Use velocity model metadata as default (for backward compatibility) self.dof_names_onnx = self.velocity_dof_names_onnx self.action_scale_onnx = self.velocity_action_scale_onnx self.kps_onnx = self.velocity_kps_onnx self.kds_onnx = self.velocity_kds_onnx self.default_angles_onnx = self.velocity_default_angles_onnx self.default_angles_dict = { name: float(self.default_angles_onnx[idx]) for idx, name in enumerate(self.dof_names_onnx) } def _build_dof_mappings(self): # Map ONNX <-> MJCF for control # Check if all ONNX names exist in real_dof_names (use velocity as reference) missing_names = [name for name in self.velocity_dof_names_onnx if name not in self.real_dof_names] if missing_names: self.get_logger().warn(f"Missing names in real_dof_names: {missing_names}") # Build mappings for velocity model self.velocity_onnx_to_real = [ self.velocity_dof_names_onnx.index(name) for name in self.real_dof_names ] self.velocity_kps_real = self.velocity_kps_onnx[self.velocity_onnx_to_real].astype(np.float32) self.velocity_kds_real = 
self.velocity_kds_onnx[self.velocity_onnx_to_real].astype(np.float32) # Build mappings for motion model self.motion_onnx_to_real = [ self.motion_dof_names_onnx.index(name) for name in self.real_dof_names ] self.motion_kps_real = self.motion_kps_onnx[self.motion_onnx_to_real].astype(np.float32) self.motion_kds_real = self.motion_kds_onnx[self.motion_onnx_to_real].astype(np.float32) # Use velocity model mappings as default (for backward compatibility) self.onnx_to_real = self.velocity_onnx_to_real self.kps_real = self.velocity_kps_real self.kds_real = self.velocity_kds_real self.default_angles_mu = self.velocity_default_angles_onnx[self.velocity_onnx_to_real].astype(np.float32) self.action_scale_mu = self.velocity_action_scale_onnx[self.velocity_onnx_to_real].astype(np.float32) # Build ref_to_onnx mapping (for motion model) self.ref_to_onnx = [ self.dof_names_ref_motion.index(name) for name in self.motion_dof_names_onnx ] # Pre-compute default angles dictionaries for efficient observation building self.velocity_default_angles_dict = { name: float(self.velocity_default_angles_onnx[idx]) for idx, name in enumerate(self.velocity_dof_names_onnx) } self.motion_default_angles_dict = { name: float(self.motion_default_angles_onnx[idx]) for idx, name in enumerate(self.motion_dof_names_onnx) } # Pre-compute dof_names_onnx arrays for each mode (avoid repeated selection) self.velocity_dof_names_onnx_array = np.array(self.velocity_dof_names_onnx) self.motion_dof_names_onnx_array = np.array(self.motion_dof_names_onnx) self.motion_dof_real_indices = [ self.real_dof_names.index(n) for n in self.motion_dof_names_onnx ] self.velocity_dof_real_indices = [ self.real_dof_names.index(n) for n in self.velocity_dof_names_onnx ] n_dof = max(len(self.motion_dof_names_onnx), len(self.velocity_dof_names_onnx)) self._dof_pos_obs_buffer = np.zeros(n_dof, dtype=np.float32) self._dof_vel_obs_buffer = np.zeros(n_dof, dtype=np.float32) # Pre-allocate arrays for future frame observations if 
hasattr(self, "n_fut_frames") and self.n_fut_frames is not None: self.n_fut_frames_int = int(self.n_fut_frames) if self.n_fut_frames_int > 0: self._pos_fut_buffer = np.zeros( (len(self.dof_names_ref_motion), self.n_fut_frames_int), dtype=np.float32 ) self._h_fut_buffer = np.zeros((1, self.n_fut_frames_int), dtype=np.float32) self._root_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32) else: self.n_fut_frames_int = 0 else: self.n_fut_frames_int = 0 self._future_frame_offsets = np.arange(1, self.n_fut_frames_int + 1, dtype=np.int64) self._future_frame_indices_buffer = np.zeros(self.n_fut_frames_int, dtype=np.int64) self._future_root_quat_wxyz_buffer = np.zeros((self.n_fut_frames_int, 4), dtype=np.float32) self._gravity_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32) self._base_linvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32) self._base_angvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32) self._keybody_rel_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 0, 3), dtype=np.float32) self._keybody_rel_pos_w_buffer = None max_t = max(1, self.n_fut_frames_int) self._vel_fut_T6 = np.zeros((max_t, 6), dtype=np.float32) self._rot_t_buffer = np.zeros((max_t, 3), dtype=np.float32) self._rot_cross_buffer = np.zeros((max_t, 3), dtype=np.float32) self._use_fk_vr = False self._fk_vel_0_root = np.zeros(3, dtype=np.float32) self._fk_angvel_0_root = np.zeros(3, dtype=np.float32) self._fk_quat_0_root = np.zeros(4, dtype=np.float32) self._fk_trans_0 = None max_t = max(1, self.n_fut_frames_int) self._fk_vel_fut = np.zeros((max_t, 3), dtype=np.float32) self._fk_angvel_fut = np.zeros((max_t, 3), dtype=np.float32) self._fk_quat_fut = np.zeros((max_t, 4), dtype=np.float32) self._fk_trans_fut = None self._q_conj_buffer = np.zeros((max_t + 1, 4), dtype=np.float32) self._rotated_3vec_buffer = np.zeros(3, dtype=np.float32) self._rotated_angvel_cur_buffer = np.zeros(3, dtype=np.float32) self._cross_t_buffer = 
np.zeros(3, dtype=np.float32) self._fk_quat_0_root_wxyz = np.zeros(4, dtype=np.float32) self._fk_quat_fut_wxyz = np.zeros((max_t, 4), dtype=np.float32) # Pre-allocate velocity command observation array self._velocity_cmd_obs = np.zeros(4, dtype=np.float32) # Publish kps and kds parameters (use velocity as default) self._publish_control_params() def _publish_control_params(self): """Publish kps and kds control parameters based on current policy mode. Called during initialization and mode switching to ensure control node receives the correct parameters for the current policy mode. """ try: # Use appropriate parameters based on current policy mode if self.current_policy_mode == "motion": current_kps = self.motion_kps_real current_kds = self.motion_kds_real else: # velocity mode current_kps = self.velocity_kps_real current_kds = self.velocity_kds_real # Publish kps kps_msg = Float32MultiArray() kps_msg.data = current_kps.tolist() self.kps_pub.publish(kps_msg) # Publish kds kds_msg = Float32MultiArray() kds_msg.data = current_kds.tolist() self.kds_pub.publish(kds_msg) self.get_logger().info( f"Published control parameters ({self.current_policy_mode} mode): " f"kps={len(current_kps)}, kds={len(current_kds)}" ) except Exception as e: self.get_logger().error(f"Failed to publish control parameters: {e}") def _publish_policy_mode(self): """Publish current policy mode status.""" try: mode_msg = String() mode_msg.data = f"{self.current_policy_mode}_{'enabled' if self.policy_enabled else 'disabled'}" self.policy_mode_pub.publish(mode_msg) except Exception as e: self.get_logger().error(f"Failed to publish policy mode: {e}") def _timing_ms(self, t0: float) -> float: return (time.perf_counter() - t0) * 1000.0 def _record_timing_sample(self, sample: dict): if not getattr(self, "timing_debug_enabled", False): return self._timing_debug_samples.append(sample) if getattr(self, "timing_debug_log_per_loop", False): self.get_logger().info( "[Timing] " 
f"loop_total={sample['loop_total_ms']:.2f}ms " f"io={sample['io_ms']:.2f}ms " f"policy_total={sample['policy_total_ms']:.2f}ms " f"fk={sample['fk_ms']:.2f}ms " f"obs={sample['obs_ms']:.2f}ms " f"onnx={sample['onnx_ms']:.2f}ms " f"post={sample['post_ms']:.2f}ms" ) now = time.time() last = getattr(self, "_timing_debug_last_log_time", None) interval = float(getattr(self, "timing_debug_log_interval_sec", 5.0)) if last is None: self._timing_debug_last_log_time = now return if now - last < interval: return if len(self._timing_debug_samples) == 0: self._timing_debug_last_log_time = now return keys = [ "loop_total_ms", "io_ms", "policy_total_ms", "fk_ms", "obs_ms", "onnx_ms", "post_ms", ] stats = {} for key in keys: vals = np.array( [float(s.get(key, 0.0)) for s in self._timing_debug_samples], dtype=np.float64, ) stats[key] = (float(np.mean(vals)), float(np.max(vals))) self.get_logger().info( "[Timing-Agg] " + " ".join( f"{key}=mean:{stats[key][0]:.2f}ms/max:{stats[key][1]:.2f}ms" for key in keys ) + f" n={len(self._timing_debug_samples)}" ) self._timing_debug_samples.clear() self._timing_debug_last_log_time = now def _root_only_fk_has_required_keybodies(self, keybody_idxs: np.ndarray) -> bool: if keybody_idxs.size == 0: return True available_bodies = 0 if self._fk_trans_0 is None else int(self._fk_trans_0.shape[0]) if available_bodies <= int(np.max(keybody_idxs)): if not self._root_only_fk_keybody_warned: self.get_logger().warn( "[RootOnlyFK] FK output only contains root body, but obs schema still " "requests non-root keybody positions. Returning zeros for keybody obs." 
) self._root_only_fk_keybody_warned = True return False return True def _run_without_profiling(self): """Run the main loop without performance profiling.""" if self._lowstate_msg is None or not self.policy_enabled: return None timing_info = { "policy_total_ms": 0.0, "fk_ms": 0.0, "obs_ms": 0.0, "onnx_ms": 0.0, "post_ms": 0.0, } _t_policy_start = time.perf_counter() if self.current_policy_mode == "motion": if self.latest_obs_flag: current_time = time.time() if self.last_external_obs_time is None: data_age = float("inf") else: data_age = current_time - self.last_external_obs_time if data_age > self.max_data_age: self.get_logger().warn( f"ZMQ latest_obs is stale: age={data_age*1000:.1f}ms > {self.max_data_age*1000:.1f}ms; " "switching to velocity tracking mode for safety." ) self._switch_to_velocity_mode(reason="VR latest_obs stale") return None if not self.latest_obs_flag and ( not hasattr(self, "n_motion_frames") or not hasattr(self, "ref_dof_pos") ): self.get_logger().warn("Motion data not loaded, skipping policy execution") return None if ( self.latest_obs_flag and self.fk is not None and self.external_fut_dof_pos_queue is not None ): try: n_fut = int(getattr(self, "n_fut_frames", 0)) if ( n_fut > 0 and self.external_fut_root_pos_queue is not None and self.external_fut_root_rot_queue is not None ): t_fk = time.perf_counter() cur_root_pos = self.ref_root_pos_raw.astype(np.float32) cur_root_rot = ( self._prev_external_root_rot if self._prev_external_root_rot is not None else self.external_fut_root_rot_queue[0].astype(np.float32) ) cur_dof_pos = self.ref_dof_pos_raw.astype(np.float32) root_pos_tensor, root_rot_tensor, dof_pos_tensor = ( self._prepare_vr_fk_tensors( cur_root_pos=cur_root_pos, cur_root_rot=cur_root_rot, cur_dof_pos=cur_dof_pos, n_fut=n_fut, ) ) fk_out = self.fk( root_pos=root_pos_tensor, root_quat=root_rot_tensor, dof_pos=dof_pos_tensor, fps=float(1.0 / self.dt), quat_format="wxyz", vel_smoothing_sigma=0.0, compute_velocity=False, ) self._fk_vr_out = { 
k: v.detach().cpu().numpy() for k, v in fk_out.items() } timing_info["fk_ms"] = self._timing_ms(t_fk) else: self._fk_vr_out = None except Exception as e: self.get_logger().error( f"VR FK computation failed; falling back to offline reference: {e}" ) self._fk_vr_out = None self.obs_builder = self.motion_obs_builder # Use motion model metadata current_action_scale = self.motion_action_scale_onnx current_default_angles = self.motion_default_angles_onnx current_onnx_to_real = self.motion_onnx_to_real else: # velocity mode self.obs_builder = self.velocity_obs_builder # Use velocity model metadata current_action_scale = self.velocity_action_scale_onnx current_default_angles = self.velocity_default_angles_onnx current_onnx_to_real = self.velocity_onnx_to_real t_obs = time.perf_counter() if self.current_policy_mode == "motion": self._cache_fk_vr_for_obs() policy_obs_np = self.obs_builder.build_policy_obs()[None, :].astype( np.float32, copy=False ) timing_info["obs_ms"] = self._timing_ms(t_obs) # Run ONNX inference with the appropriate policy session and correct input/output names t_onnx = time.perf_counter() if self.current_policy_mode == "velocity": input_feed = {self.velocity_input_name: policy_obs_np} onnx_output = self.velocity_policy_session.run([self.velocity_output_name], input_feed) else: # motion mode if self.use_kv_cache: if self.motion_kv_cache is None: shape = [ d if isinstance(d, int) else 1 for d in self.motion_kv_shape ] self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype) # if ( # self.motion_effective_context_len > 0 # and self.motion_step_idx > 0 # and self.motion_step_idx % self.motion_effective_context_len == 0 # ): # self.motion_kv_cache.fill(0.0) input_feed = { self.motion_input_name: policy_obs_np, self.motion_kv_input_name: self.motion_kv_cache, } if self.motion_step_idx_input_name is not None: step_idx = self.motion_step_idx # if self.motion_effective_context_len > 0: # step_idx = ( # self.motion_step_idx # % 
self.motion_effective_context_len # ) input_feed[self.motion_step_idx_input_name] = np.array( [step_idx], dtype=np.int64 ) output_names = [self.motion_output_name] if self.motion_kv_output_name: output_names.append(self.motion_kv_output_name) onnx_output = self.motion_policy_session.run( output_names, input_feed ) if len(onnx_output) > 1: self.motion_kv_cache = onnx_output[1] self.motion_step_idx += 1 else: input_feed = {self.motion_input_name: policy_obs_np} onnx_output = self.motion_policy_session.run( [self.motion_output_name], input_feed ) timing_info["onnx_ms"] = self._timing_ms(t_onnx) t_post = time.perf_counter() raw_actions_onnx = np.asarray(onnx_output[0], dtype=np.float32).reshape(-1) if self.current_policy_mode == "motion": self.actions_onnx = self._apply_motion_action_ema_filter(raw_actions_onnx) else: self.actions_onnx = raw_actions_onnx.copy() # Use the appropriate metadata based on current policy mode self.target_dof_pos_onnx = ( self.actions_onnx * current_action_scale + current_default_angles ) self.target_dof_pos_real = self.target_dof_pos_onnx[current_onnx_to_real] # Action processing and publishing self._process_and_publish_actions() if self.current_policy_mode == "motion": if ( not getattr(self, "latest_obs_flag", False) and self.motion_frame_idx >= self.n_motion_frames and self.motion_in_progress ): self.get_logger().info("Motion action completed (offline reference)") self.motion_in_progress = False # Publish policy mode status self._publish_policy_mode() timing_info["post_ms"] = self._timing_ms(t_post) timing_info["policy_total_ms"] = self._timing_ms(_t_policy_start) return timing_info def _process_and_publish_actions(self): """Process and publish action commands.""" if self.target_dof_pos_real is not None: action_msg = Float32MultiArray() action_msg.data = self.target_dof_pos_real.tolist() # Check for NaN values if np.isnan(self.target_dof_pos_real).any(): self.get_logger().error("Action contains NaN values") 
self.action_pub.publish(action_msg) self.motion_frame_idx += 1 def setup(self): """Set up the evaluator by loading all required components.""" main_affinity = _parse_cpu_affinity_str( getattr(self, "_cpu_affinity_main_str", "") or "" ) if main_affinity and set_thread_cpu_affinity(main_affinity): self.get_logger().info(f"[Policy] main thread pinned to CPUs {main_affinity}") self.load_model_config() # Load config first self.update_config_parameters() # Update parameters from config # Initialize FK for online VR reference reconstruction self._init_fk() self.load_policy() # Then load policies self._apply_onnx_metadata() self._init_obs_buffers() self._build_dof_mappings() self._warmup_motion_policy() self._init_keybody_indices_cache() # Always load motion data since we support both modes self.load_motion_data() self.get_logger().info("Synchronous root-only policy setup completed") def _init_fk(self): """Initialize lightweight root-only FK for synchronous VR reference updates.""" try: self.get_logger().info( "Initializing root-only FK (no URDF, sync main-thread mode)" ) self.fk = HoloMotionFKRootOnly( dof_names=self.dof_names_ref_motion, device="cpu", timing_logger_enabled=True, timing_log_interval_sec=5.0, timing_log_per_call=False, timing_name="FKRootOnlyVR", timing_log_fn=self.get_logger().info, ) try: ndof = len(self.fk.dof_names) root_pos_dummy = torch.zeros((1, 4, 3), dtype=torch.float32) root_quat_dummy = torch.zeros((1, 4, 4), dtype=torch.float32) root_quat_dummy[..., 0] = 1.0 dof_pos_dummy = torch.zeros((1, 4, ndof), dtype=torch.float32) _ = self.fk( root_pos=root_pos_dummy, root_quat=root_quat_dummy, dof_pos=dof_pos_dummy, fps=float(1.0 / self.dt), quat_format="wxyz", vel_smoothing_sigma=0.0, compute_velocity=False, ) self.get_logger().info("[FK] Root-only warmup completed (B=1,T=4)") except Exception as e_dummy: self.get_logger().warn(f"[FK] Root-only warmup failed (ignored): {e_dummy}") self.fk_initialized = True self.get_logger().info( f"Root-only FK 
initialized successfully with {len(self.fk.dof_names)} dofs" ) except Exception as e: self.get_logger().error(f"Failed to initialize root-only FK: {e}") self.fk = None self.fk_initialized = False def destroy_node(self): try: if getattr(self, "_zmq_subscriber", None) is not None: self._zmq_subscriber.stop() except Exception: pass super().destroy_node() def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray: """Calculate gravity orientation from quaternion. Args: quaternion: Array-like [w, x, y, z] Returns: np.ndarray of shape (3,) representing gravity projection. """ qw = float(quaternion[0]) qx = float(quaternion[1]) qy = float(quaternion[2]) qz = float(quaternion[3]) gravity_orientation = np.zeros(3, dtype=np.float32) gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy) gravity_orientation[1] = -2.0 * (qz * qy + qw * qx) gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_orientation def main(): """Main entry point for the policy node.""" rclpy.init() policy_node = HoloMotionPolicyNode() rclpy.spin(policy_node) if __name__ == "__main__": main() ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/__init__.py ================================================ ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/command_helper.py ================================================ from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG from typing import Union class MotorMode: PR = 0 # Series Control for Pitch/Roll Joints AB = 1 # Parallel Control for A/B Joints def create_damping_cmd(cmd: Union[LowCmdGo, LowCmdHG]): size = len(cmd.motor_cmd) for i in range(size): cmd.motor_cmd[i].q = 0 cmd.motor_cmd[i].qd = 0 cmd.motor_cmd[i].kp = 0 cmd.motor_cmd[i].kd = 8 cmd.motor_cmd[i].tau = 0 def create_zero_cmd(cmd: Union[LowCmdGo, LowCmdHG]): 
size = len(cmd.motor_cmd) for i in range(size): cmd.motor_cmd[i].q = 0 cmd.motor_cmd[i].qd = 0 cmd.motor_cmd[i].kp = 0 cmd.motor_cmd[i].kd = 0 cmd.motor_cmd[i].tau = 0 def init_cmd_hg(cmd: LowCmdHG, mode_machine: int, mode_pr: int): cmd.mode_machine = mode_machine cmd.mode_pr = mode_pr size = len(cmd.motor_cmd) for i in range(size): cmd.motor_cmd[i].mode = 1 cmd.motor_cmd[i].q = 0 cmd.motor_cmd[i].qd = 0 cmd.motor_cmd[i].kp = 0 cmd.motor_cmd[i].kd = 0 cmd.motor_cmd[i].tau = 0 def init_cmd_go(cmd: LowCmdGo, weak_motor: list): cmd.head[0] = 0xFE cmd.head[1] = 0xEF cmd.level_flag = 0xFF cmd.gpio = 0 PosStopF = 2.146e9 VelStopF = 16000.0 size = len(cmd.motor_cmd) for i in range(size): if i in weak_motor: cmd.motor_cmd[i].mode = 1 else: cmd.motor_cmd[i].mode = 0x0A cmd.motor_cmd[i].q = PosStopF cmd.motor_cmd[i].qd = VelStopF cmd.motor_cmd[i].kp = 0 cmd.motor_cmd[i].kd = 0 cmd.motor_cmd[i].tau = 0 ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/maths.py ================================================ import torch import numpy as np import random import os @torch.jit.script def normalize(x, eps: float = 1e-9): return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1) @torch.jit.script def torch_rand_float(lower, upper, shape, device): # type: (float, float, Tuple[int, int], str) -> Tensor return (upper - lower) * torch.rand(*shape, device=device) + lower @torch.jit.script def copysign(a, b): # type: (float, Tensor) -> Tensor a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) return torch.abs(a) * torch.sign(b) def set_seed(seed, torch_deterministic=False): """set seed across modules""" if seed == -1 and torch_deterministic: seed = 42 elif seed == -1: seed = np.random.randint(0, 10000) print("Setting seed: {}".format(seed)) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.cuda.manual_seed(seed) 
torch.cuda.manual_seed_all(seed) if torch_deterministic: # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) else: torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False return seed def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False): return torch.tensor( x, dtype=dtype, device=device, requires_grad=requires_grad ) @torch.compile def quat_mul_legacy( a: torch.Tensor, b: torch.Tensor, w_last: bool = True ) -> torch.Tensor: """Multiply two quaternions. Args: a (torch.Tensor): (..., 4) quaternion. b (torch.Tensor): (..., 4) quaternion. w_last (bool): Whether the scalar part w is the last element. If True, format is [x, y, z, w]; if False, format is [w, x, y, z]. Returns: torch.Tensor: (..., 4) quaternion result of a * b. """ assert a.shape == b.shape shape = a.shape a = a.reshape(-1, 4) b = b.reshape(-1, 4) if w_last: # Format: [x, y, z, w] x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3] x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3] else: # Format: [w, x, y, z] w1, x1, y1, z1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3] w2, x2, y2, z2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3] ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) xx = ww + yy + zz qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) w = qq - ww + (z1 - y1) * (y2 - z2) x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) if w_last: quat = torch.stack([x, y, z, w], dim=-1).view(shape) else: quat = torch.stack([w, x, y, z], dim=-1).view(shape) return quat @torch.jit.script def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: """Multiply two quaternions together. Args: q1: The first quaternion in (w, x, y, z). Shape is (..., 4). q2: The second quaternion in (w, x, y, z). 
Shape is (..., 4). Returns: The product of the two quaternions in (w, x, y, z). Shape is (..., 4). Raises: ValueError: Input shapes of ``q1`` and ``q2`` are not matching. """ # check input is correct if q1.shape != q2.shape: msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}." raise ValueError(msg) # reshape to (N, 4) for multiplication shape = q1.shape q1 = q1.reshape(-1, 4) q2 = q2.reshape(-1, 4) # extract components from quaternions w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] # perform multiplication ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) xx = ww + yy + zz qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) w = qq - ww + (z1 - y1) * (y2 - z2) x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) return torch.stack([w, x, y, z], dim=-1).view(shape) @torch.jit.script def normalize(x, eps: float = 1e-9): return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1) @torch.jit.script def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """Apply a quaternion rotation to a vector. Args: quat: The quaternion in (w, x, y, z). Shape is (..., 4). vec: The vector in (x, y, z). Shape is (..., 3). Returns: The rotated vector in (x, y, z). Shape is (..., 3). """ # store shape shape = vec.shape # reshape to (N, 3) for multiplication quat = quat.reshape(-1, 4) vec = vec.reshape(-1, 3) # extract components from quaternions xyz = quat[:, 1:] t = xyz.cross(vec, dim=-1) * 2 return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """Apply an inverse quaternion rotation to a vector. Args: quat: The quaternion in (w, x, y, z). Shape is (..., 4). vec: The vector in (x, y, z). Shape is (..., 3). Returns: The rotated vector in (x, y, z). Shape is (..., 3). 
""" # store shape shape = vec.shape # reshape to (N, 3) for multiplication quat = quat.reshape(-1, 4) vec = vec.reshape(-1, 3) # extract components from quaternions xyz = quat[:, 1:] t = xyz.cross(vec, dim=-1) * 2 return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_rotate(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c # @torch.jit.script def quat_rotate_inverse(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a - b + c @torch.jit.script def quat_conjugate(a): shape = a.shape a = a.reshape(-1, 4) # return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape) return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape) @torch.jit.script def quat_unit(a): return normalize(a) @torch.jit.script def quat_from_angle_axis(angle, axis): theta = (angle / 2).unsqueeze(-1) xyz = normalize(axis) * theta.sin() w = theta.cos() return quat_unit(torch.cat([xyz, w], dim=-1)) @torch.jit.script def normalize_angle(x): return torch.atan2(torch.sin(x), torch.cos(x)) @torch.jit.script def get_basis_vector(q, v): return quat_rotate(q, v) def get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3): """Construct arguments to `Vec` according to axis index.""" zs = np.zeros((n_dims,)) assert axis_idx < n_dims, ( "the axis dim should be within the vector dimensions" ) zs[axis_idx] = 1.0 params = np.where(zs == 1.0, value, zs) params[0] = x_value return list(params.astype(dtype)) # @torch.jit.script # def copysign(a, b): # a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) # 
return torch.abs(a) * torch.sign(b) @torch.jit.script def get_euler_xyz(q: torch.Tensor) -> tuple: qx, qy, qz, qw = 0, 1, 2, 3 # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = ( q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] ) roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) pitch = torch.where( torch.abs(sinp) >= 1, copysign(torch.tensor(np.pi / 2.0, device=sinp.device), sinp), torch.asin(sinp), ) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = ( q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] ) yaw = torch.atan2(siny_cosp, cosy_cosp) return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) @torch.jit.script def quat_from_euler_xyz(roll, pitch, yaw): cy = torch.cos(yaw * 0.5) sy = torch.sin(yaw * 0.5) cr = torch.cos(roll * 0.5) sr = torch.sin(roll * 0.5) cp = torch.cos(pitch * 0.5) sp = torch.sin(pitch * 0.5) qw = cy * cr * cp + sy * sr * sp qx = cy * sr * cp - sy * cr * sp qy = cy * cr * sp + sy * sr * cp qz = sy * cr * cp - cy * sr * sp return torch.stack([qx, qy, qz, qw], dim=-1) def torch_rand_float(lower, upper, shape, device): return (upper - lower) * torch.rand(*shape, device=device) + lower # @torch.jit.script @torch.compile def torch_random_dir_2(shape, device): angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1) return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1) @torch.jit.script def tensor_clamp(t, min_t, max_t): return torch.max(torch.min(t, max_t), min_t) @torch.jit.script def scale(x, lower, upper): return 0.5 * (x + 1.0) * (upper - lower) + lower @torch.jit.script def unscale(x, lower, upper): return (2.0 * x - upper - lower) / (upper - lower) def unscale_np(x, lower, upper): return (2.0 * x - upper - lower) / (upper - lower) @torch.jit.script def 
quat_to_angle_axis(q): # computes axis-angle representation from quaternion q # q must be normalized min_theta = 1e-5 qx, _, _, qw = 0, 1, 2, 3 sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw]) angle = 2 * torch.acos(q[..., qw]) angle = normalize_angle(angle) sin_theta_expand = sin_theta.unsqueeze(-1) axis = q[..., qx:qw] / sin_theta_expand mask = torch.abs(sin_theta) > min_theta default_axis = torch.zeros_like(axis) default_axis[..., -1] = 1 angle = torch.where(mask, angle, torch.zeros_like(angle)) mask_expand = mask.unsqueeze(-1) axis = torch.where(mask_expand, axis, default_axis) return angle, axis @torch.jit.script def angle_axis_to_exp_map(angle, axis): # compute exponential map from axis-angle angle_expand = angle.unsqueeze(-1) exp_map = angle_expand * axis return exp_map @torch.jit.script def quat_to_exp_map(q): # compute exponential map from quaternion # q must be normalized angle, axis = quat_to_angle_axis(q) exp_map = angle_axis_to_exp_map(angle, axis) return exp_map @torch.jit.script def slerp(q0, q1, t): cos_half_theta = torch.sum(q0 * q1, dim=-1) neg_mask = cos_half_theta < 0 q1 = q1.clone() q1[neg_mask] = -q1[neg_mask] cos_half_theta = torch.abs(cos_half_theta) cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1) half_theta = torch.acos(cos_half_theta) sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta) ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta ratio_b = torch.sin(t * half_theta) / sin_half_theta new_q = ratio_a * q0 + ratio_b * q1 new_q = torch.where( torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q ) new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q) return new_q @torch.jit.script def my_quat_rotate(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script 
def calc_heading(q): # calculate heading direction from quaternion # the heading is the direction on the xy plane # q must be normalized # this is the x axis heading ref_dir = torch.zeros_like(q[..., 0:3]) ref_dir[..., 0] = 1 rot_dir = my_quat_rotate(q, ref_dir) heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) return heading @torch.jit.script def calc_heading_quat(q): # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(heading, axis) return heading_q @torch.jit.script def calc_heading_quat_inv(q): # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(-heading, axis) return heading_q @torch.compile def axis_angle_from_quat( quat: torch.Tensor, w_last: bool = True, ) -> torch.Tensor: """Compute axis-angle (log map) vector from a quaternion. Args: quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z]. w_last (bool): Whether the scalar part w is the last element. Returns: torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi]. Notes: - The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability. - Uses a stable small-angle handling to avoid NaNs and gradient issues. 
""" # Handle different quaternion formats if w_last: # Quaternion is [q_x, q_y, q_z, q_w] quat_w_orig = quat[..., -1:] else: # Quaternion is [q_w, q_x, q_y, q_z] quat_w_orig = quat[..., 0:1] # Normalize quaternion to have w > 0 quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0)) # Ensure unit quaternion for stability quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min( 1.0e-9 ) # Recompute quat_xyz and quat_w after potential sign flip if w_last: quat_w = quat[..., -1:] quat_xyz = quat[..., :3] else: quat_w = quat[..., 0:1] quat_xyz = quat[..., 1:4] mag = torch.linalg.norm(quat_xyz, dim=-1) half_angle = torch.atan2(mag, quat_w.squeeze(-1)) angle = 2.0 * half_angle # check whether to apply Taylor approximation use_taylor = angle.abs() <= 1.0e-6 # To prevent NaN gradients with torch.where, we compute both branches and blend # based on the condition. # See: https://pytorch.org/docs/1.9.0/generated/torch.where.html#torch-where # "However, if you need the gradients to flow through the branches, please use torch.lerp" # Although we are not using lerp, the principle of avoiding sharp branches is the same. sin_half_angles_over_angles_approx = 0.5 - angle * angle / 48 # Clamp angle to avoid division by zero in the non-taylor branch when angle is exactly 0. angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle) sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe sin_half_angles_over_angles = torch.where( use_taylor, sin_half_angles_over_angles_approx, sin_half_angles_over_angles_exact, ) return quat_xyz / sin_half_angles_over_angles[..., None] @torch.jit.script def quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: """Computes the inverse of a quaternion. Args: q: The quaternion orientation in (w, x, y, z). Shape is (N, 4). eps: A small value to avoid division by zero. Defaults to 1e-9. Returns: The inverse quaternion in (w, x, y, z). Shape is (N, 4). 
""" return quat_conjugate(q) / q.pow(2).sum(dim=-1, keepdim=True).clamp( min=eps ) # --------------------- WXYZ helpers (torch) --------------------- def xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor: """ Convert quaternion from XYZW to WXYZ. Args: q (torch.Tensor): (..., 4) quaternion in XYZW. Returns: torch.Tensor: (..., 4) quaternion in WXYZ. """ return torch.cat([q[..., 3:4], q[..., 0:3]], dim=-1) def wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor: """ Convert quaternion from WXYZ to XYZW. Args: q (torch.Tensor): (..., 4) quaternion in WXYZ. Returns: torch.Tensor: (..., 4) quaternion in XYZW. """ return torch.cat([q[..., 1:4], q[..., 0:1]], dim=-1) @torch.compile def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: """ Hamilton product in WXYZ layout using fused implementation. Args: q1 (torch.Tensor): (..., 4) WXYZ. q2 (torch.Tensor): (..., 4) WXYZ. Returns: torch.Tensor: (..., 4) WXYZ. """ return quat_mul(q1, q2, w_last=False) def subtract_frame_transforms( t01: torch.Tensor, q01: torch.Tensor, t02: torch.Tensor = None, q02: torch.Tensor = None, ): r"""Subtract transformations between two reference frames into a stationary frame. It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`, where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B. Args: t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3). q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3). Defaults to None, in which case the position is assumed to be zero. q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). Defaults to None, in which case the orientation is assumed to be identity. Returns: A tuple containing the position and orientation of frame 2 w.r.t. frame 1. Shape of the tensors are (N, 3) and (N, 4) respectively. 
""" # compute orientation q10 = quat_inv(q01) if q02 is not None: q12 = quat_mul(q10, q02) else: q12 = q10 # compute translation if t02 is not None: t12 = quat_apply(q10, t02 - t01) else: t12 = quat_apply(q10, -t01) return t12, q12 @torch.compile def quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor: """ Normalize quaternion in WXYZ layout. Args: q_wxyz (torch.Tensor): (..., 4) WXYZ. Returns: torch.Tensor: (..., 4) normalized WXYZ. """ return q_wxyz / torch.linalg.norm(q_wxyz, dim=-1, keepdim=True).clamp_min( 1.0e-9 ) @torch.jit.script def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor: """Convert rotations given as quaternions to rotation matrices. Args: quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4). Returns: Rotation matrices. The shape is (..., 3, 3). Reference: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70 """ r, i, j, k = torch.unbind(quaternions, -1) # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/motor_crc.py ================================================ import struct import numpy as np from ctypes import Structure, c_uint8, c_float, c_uint32, Array def crc32_core(data_array, length): CRC32 = 0xFFFFFFFF dwPolynomial = 0x04C11DB7 for i in range(length): data = data_array[i] for bit in range(32): # Process all 32 bits if (CRC32 >> 31) & 1: # Check MSB before shift CRC32 = ((CRC32 << 1) & 0xFFFFFFFF) ^ dwPolynomial else: CRC32 = (CRC32 << 1) & 0xFFFFFFFF if (data >> (31 - bit)) & 1: # Match C++ bit processing order CRC32 ^= dwPolynomial return CRC32 def calc_crc(cmd) -> int: """Calculate CRC for LowCmd message""" buffer = bytearray() # Pack header (mode_pr, mode_machine + 2 padding) buffer.extend(struct.pack("> i # 读取原始值 lx_raw = struct.unpack("f", data[4:8])[0] rx_raw = struct.unpack("f", data[8:12])[0] ry_raw = struct.unpack("f", data[12:16])[0] ly_raw = struct.unpack("f", data[20:24])[0] # 应用滤波和死区 self.lx = self.apply_filter_and_deadzone(lx_raw, self.lx_prev) self.ly = self.apply_filter_and_deadzone(ly_raw, self.ly_prev) self.rx = self.apply_filter_and_deadzone(rx_raw, self.rx_prev) self.ry = self.apply_filter_and_deadzone(ry_raw, self.ry_prev) # 更新前一次的值 self.lx_prev = self.lx self.ly_prev = self.ly self.rx_prev = self.rx self.ry_prev = self.ry def get_velocity_commands(self): """ 将摇杆值转换为速度命令 Returns: tuple: (vx, vy, vyaw) - vx: 前进/后退速度 (m/s),由左摇杆前后(ly)控制 - vy: 左右平移速度 (m/s),由左摇杆左右(lx)控制 - vyaw: 转向角速度 (rad/s),由右摇杆左右(rx)控制 """ # 前进/后退速度,使用左摇杆的y轴 vx = ( self.ly * self.max_linear_speed_x 
) # 注意:通常需要取反,因为向前推摇杆时ly为负 # 限制x速度最小值为-0.5 if vx < -0.5: vx = -0.5 # 左右平移速度,使用左摇杆的x轴 vy = ( -self.lx * self.max_linear_speed_y ) # 注意:可能需要取反,取决于坐标系定义 # 转向角速度,使用右摇杆的x轴 vyaw = ( -self.rx * self.max_angular_speed ) # 注意:可能需要取反,取决于坐标系定义 # 应用速度阈值 - 当速度小于阈值时设为0 if abs(vx) < self.velocity_threshold_x: vx = 0.0 if abs(vy) < self.velocity_threshold_y: vy = 0.0 if abs(vyaw) < self.velocity_threshold_yaw: vyaw = 0.0 return vx, vy, vyaw ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotation_helper.py ================================================ import numpy as np from scipy.spatial.transform import Rotation as R def get_gravity_orientation(quaternion): qw = quaternion[0] qx = quaternion[1] qy = quaternion[2] qz = quaternion[3] gravity_orientation = np.zeros(3) gravity_orientation[0] = 2 * (-qz * qx + qw * qy) gravity_orientation[1] = -2 * (qz * qy + qw * qx) gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) return gravity_orientation def transform_imu_data(waist_yaw, waist_yaw_omega, imu_quat, imu_omega): RzWaist = R.from_euler("z", waist_yaw).as_matrix() R_torso = R.from_quat( [imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]] ).as_matrix() R_pelvis = np.dot(R_torso, RzWaist.T) w = np.dot(RzWaist, imu_omega[0]) - np.array([0, 0, waist_yaw_omega]) return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotations.py ================================================ import torch from torch import Tensor import torch.nn.functional as F from humanoid_policy.utils.maths import ( normalize, copysign, ) from typing import Tuple import numpy as np from typing import List, Optional @torch.jit.script def quat_unit(a): return normalize(a) @torch.jit.script def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor: shape = b.shape a = a.reshape(-1, 4) b = b.reshape(-1, 3) if w_last: xyz = a[:, :3] 
w = a[:, 3:] else: xyz = a[:, 1:] w = a[:, :1] t = xyz.cross(b, dim=-1) * 2 return (b + w * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor: quat_yaw = quat.clone().view(-1, 4) quat_yaw[:, :2] = 0.0 quat_yaw = normalize(quat_yaw) return quat_apply(quat_yaw, vec, w_last) @torch.jit.script def wrap_to_pi(angles): angles %= 2 * np.pi angles -= 2 * np.pi * (angles > np.pi) return angles @torch.jit.script def quat_conjugate(a: Tensor, w_last: bool) -> Tensor: shape = a.shape a = a.reshape(-1, 4) if w_last: return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape) else: return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape) @torch.jit.script def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor: shape = b.shape a = a.reshape(-1, 4) b = b.reshape(-1, 3) if w_last: xyz = a[:, :3] w = a[:, 3:] else: xyz = a[:, 1:] w = a[:, :1] t = xyz.cross(b, dim=-1) * 2 return (b + w * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor: shape = q.shape if w_last: q_w = q[:, -1] q_vec = q[:, :3] else: q_w = q[:, 0] q_vec = q[:, 1:] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script def quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor: shape = q.shape if w_last: q_w = q[:, -1] q_vec = q[:, :3] else: q_w = q[:, 0] q_vec = q[:, 1:] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a - b + c @torch.jit.script def quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]: """ The (angle, axis) representation of the rotation. The axis is normalized to unit length. 
The angle is guaranteed to be between [0, pi]. """ if w_last: w = x[..., -1] axis = x[..., :3] else: w = x[..., 0] axis = x[..., 1:] s = 2 * (w**2) - 1 angle = s.clamp(-1, 1).arccos() # just to be safe axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9) return angle, axis @torch.jit.script def quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor: theta = (angle / 2).unsqueeze(-1) xyz = normalize(axis) * theta.sin() w = theta.cos() if w_last: return quat_unit(torch.cat([xyz, w], dim=-1)) else: return quat_unit(torch.cat([w, xyz], dim=-1)) @torch.jit.script def vec_to_heading(h_vec): h_theta = torch.atan2(h_vec[..., 1], h_vec[..., 0]) return h_theta @torch.jit.script def heading_to_quat(h_theta, w_last: bool): axis = torch.zeros( h_theta.shape + [ 3, ], device=h_theta.device, ) axis[..., 2] = 1 heading_q = quat_from_angle_axis(h_theta, axis, w_last=w_last) return heading_q @torch.jit.script def quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor: basis_vec = torch.zeros(q.shape[0], 3, device=q.device) basis_vec[:, axis] = 1 return quat_rotate(q, basis_vec, w_last) @torch.jit.script def normalize_angle(x): return torch.atan2(torch.sin(x), torch.cos(x)) @torch.jit.script def get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor: return quat_rotate(q, v, w_last) @torch.jit.script def quat_to_angle_axis(q): # type: (Tensor) -> Tuple[Tensor, Tensor] # computes axis-angle representation from quaternion q # q must be normalized # ZL: could have issues. 
min_theta = 1e-5 qx, qy, qz, qw = 0, 1, 2, 3 sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw]) angle = 2 * torch.acos(q[..., qw]) angle = normalize_angle(angle) sin_theta_expand = sin_theta.unsqueeze(-1) axis = q[..., qx:qw] / sin_theta_expand mask = torch.abs(sin_theta) > min_theta default_axis = torch.zeros_like(axis) default_axis[..., -1] = 1 angle = torch.where(mask, angle, torch.zeros_like(angle)) mask_expand = mask.unsqueeze(-1) axis = torch.where(mask_expand, axis, default_axis) return angle, axis @torch.jit.script def slerp(q0, q1, t): # type: (Tensor, Tensor, Tensor) -> Tensor cos_half_theta = torch.sum(q0 * q1, dim=-1) neg_mask = cos_half_theta < 0 q1 = q1.clone() q1[neg_mask] = -q1[neg_mask] cos_half_theta = torch.abs(cos_half_theta) cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1) half_theta = torch.acos(cos_half_theta) sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta) ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta ratioB = torch.sin(t * half_theta) / sin_half_theta new_q = ratioA * q0 + ratioB * q1 new_q = torch.where( torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q ) new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q) return new_q @torch.jit.script def angle_axis_to_exp_map(angle, axis): # type: (Tensor, Tensor) -> Tensor # compute exponential map from axis-angle angle_expand = angle.unsqueeze(-1) exp_map = angle_expand * axis return exp_map @torch.jit.script def my_quat_rotate(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script def calc_heading(q): # type: (Tensor) -> Tensor # calculate heading direction from quaternion # the heading is the direction on the xy plane # q must be normalized # this is the x axis heading ref_dir = torch.zeros_like(q[..., 
0:3]) ref_dir[..., 0] = 1 rot_dir = my_quat_rotate(q, ref_dir) heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) return heading @torch.jit.script def quat_to_exp_map(q): # type: (Tensor) -> Tensor # compute exponential map from quaternion # q must be normalized angle, axis = quat_to_angle_axis(q) exp_map = angle_axis_to_exp_map(angle, axis) return exp_map @torch.jit.script def calc_heading_quat(q, w_last): # type: (Tensor, bool) -> Tensor # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(heading, axis, w_last=w_last) return heading_q @torch.jit.script def calc_heading_quat_inv(q, w_last): # type: (Tensor, bool) -> Tensor # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(-heading, axis, w_last=w_last) return heading_q @torch.jit.script def quat_inverse(x, w_last): # type: (Tensor, bool) -> Tensor """ The inverse of the rotation """ return quat_conjugate(x, w_last=w_last) @torch.jit.script def get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]: if w_last: qx, qy, qz, qw = 0, 1, 2, 3 else: qw, qx, qy, qz = 0, 1, 2, 3 # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = ( q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] ) roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) pitch = torch.where( torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp) ) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = ( q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, 
qz] ) yaw = torch.atan2(siny_cosp, cosy_cosp) return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) # @torch.jit.script def get_euler_xyz_in_tensor(q): qx, qy, qz, qw = 0, 1, 2, 3 # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = ( q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] ) roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) pitch = torch.where( torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp) ) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = ( q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] ) yaw = torch.atan2(siny_cosp, cosy_cosp) return torch.stack((roll, pitch, yaw), dim=-1) @torch.jit.script def quat_pos(x): """ make all the real part of the quaternion positive """ q = x z = (q[..., 3:] < 0).float() q = (1 - 2 * z) * q return q @torch.jit.script def is_valid_quat(q): x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] return (w * w + x * x + y * y + z * z).allclose(torch.ones_like(w)) @torch.jit.script def quat_normalize(q): """ Construct 3D rotation from quaternion (the quaternion needs not to be normalized). 
""" q = quat_unit(quat_pos(q)) # normalized to positive and unit quaternion return q @torch.jit.script def quat_mul(a, b, w_last: bool): assert a.shape == b.shape shape = a.shape a = a.reshape(-1, 4) b = b.reshape(-1, 4) if w_last: x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] else: w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) xx = ww + yy + zz qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) w = qq - ww + (z1 - y1) * (y2 - z2) x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) if w_last: quat = torch.stack([x, y, z, w], dim=-1).view(shape) else: quat = torch.stack([w, x, y, z], dim=-1).view(shape) return quat @torch.jit.script def quat_mul_norm(x, y, w_last): # type: (Tensor, Tensor, bool) -> Tensor """ Combine two set of 3D rotations together using \**\* operator. The shape needs to be broadcastable """ return quat_normalize(quat_mul(x, y, w_last)) @torch.jit.script def quat_mul_norm(x, y, w_last): # type: (Tensor, Tensor, bool) -> Tensor """ Combine two set of 3D rotations together using \**\* operator. The shape needs to be broadcastable """ return quat_unit(quat_mul(x, y, w_last)) @torch.jit.script def quat_identity(shape: List[int]): """ Construct 3D identity rotation given shape """ w = torch.ones(shape + [1]) xyz = torch.zeros(shape + [3]) q = torch.cat([xyz, w], dim=-1) return quat_normalize(q) @torch.jit.script def quat_identity_like(x): """ Construct identity 3D rotation with the same shape """ return quat_identity(x.shape[:-1]) @torch.jit.script def transform_from_rotation_translation( r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None ): """ Construct a transform from a quaternion and 3D translation. Only one of them can be None. 
""" assert r is not None or t is not None, ( "rotation and translation can't be all None" ) if r is None: assert t is not None r = quat_identity(list(t.shape)) if t is None: t = torch.zeros(list(r.shape) + [3]) return torch.cat([r, t], dim=-1) @torch.jit.script def transform_rotation(x): """Get rotation from transform""" return x[..., :4] @torch.jit.script def transform_translation(x): """Get translation from transform""" return x[..., 4:] @torch.jit.script def transform_mul(x, y): """ Combine two transformation together """ z = transform_from_rotation_translation( r=quat_mul_norm( transform_rotation(x), transform_rotation(y), w_last=True ), t=quat_rotate( transform_rotation(x), transform_translation(y), w_last=True ) + transform_translation(x), ) return z ##################################### FROM PHC rotation_conversions.py ##################################### @torch.jit.script def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: """ Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions with real part first, as tensor of shape (..., 4). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ r, i, j, k = torch.unbind(quaternions, -1) two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) @torch.jit.script def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: """ Convert rotations given as axis/angle to quaternions. Args: axis_angle: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. Returns: quaternions with real part first, as tensor of shape (..., 4). 
""" angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) half_angles = angles * 0.5 eps = 1e-6 small_angles = angles.abs() < eps sin_half_angles_over_angles = torch.empty_like(angles) sin_half_angles_over_angles[~small_angles] = ( torch.sin(half_angles[~small_angles]) / angles[~small_angles] ) # for x small, sin(x/2) is about x/2 - (x/2)^3/6 # so sin(x/2)/x is about 1/2 - (x*x)/48 sin_half_angles_over_angles[small_angles] = ( 0.5 - (angles[small_angles] * angles[small_angles]) / 48 ) quaternions = torch.cat( [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1, ) return quaternions # @torch.jit.script def wxyz_to_xyzw(quat): return quat[..., [1, 2, 3, 0]] # @torch.jit.script def xyzw_to_wxyz(quat): return quat[..., [3, 0, 1, 2]] def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """ w x y z Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part first, as tensor of shape (..., 4). """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_rijk = torch.stack( [ torch.stack( [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1 ), torch.stack( [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1 ), torch.stack( [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1 ), torch.stack( [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1 ), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. 
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) return quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :, # pyre-ignore[16] ].reshape(batch_dim + (4,)) def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. """ ret = torch.zeros_like(x) positive_mask = x > 0 ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret def quat_w_first(rot): rot = torch.cat([rot[..., [-1]], rot[..., :-1]], -1) return rot @torch.jit.script def quat_from_euler_xyz(roll, pitch, yaw): cy = torch.cos(yaw * 0.5) sy = torch.sin(yaw * 0.5) cr = torch.cos(roll * 0.5) sr = torch.sin(roll * 0.5) cp = torch.cos(pitch * 0.5) sp = torch.sin(pitch * 0.5) qw = cy * cr * cp + sy * sr * sp qx = cy * sr * cp - sy * cr * sp qy = cy * cr * sp + sy * sr * cp qz = sy * cr * cp - cy * sr * sp return torch.stack([qx, qy, qz, qw], dim=-1) ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc.h ================================================ /***************************************************************** Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. 
******************************************************************/ #ifndef _MOTOR_CRC_H_ #define _MOTOR_CRC_H_ #include #include #include "rclcpp/rclcpp.hpp" #include "unitree_go/msg/low_cmd.hpp" #include "unitree_go/msg/motor_cmd.hpp" #include "unitree_go/msg/bms_cmd.hpp" constexpr int HIGHLEVEL = 0xee; constexpr int LOWLEVEL = 0xff; constexpr int TRIGERLEVEL = 0xf0; constexpr double PosStopF = (2.146E+9f); constexpr double VelStopF = (16000.0f); // joint index constexpr int FR_0 = 0; constexpr int FR_1 = 1; constexpr int FR_2 = 2; constexpr int FL_0 = 3; constexpr int FL_1 = 4; constexpr int FL_2 = 5; constexpr int RR_0 = 6; constexpr int RR_1 = 7; constexpr int RR_2 = 8; constexpr int RL_0 = 9; constexpr int RL_1 = 10; constexpr int RL_2 = 11; typedef struct { uint8_t off; // off 0xA5 std::array reserve; } BmsCmd; typedef struct { uint8_t mode; // desired working mode float q; // desired angle (unit: radian) float dq; // desired velocity (unit: radian/second) float tau; // desired output torque (unit: N.m) float Kp; // desired position stiffness (unit: N.m/rad ) float Kd; // desired velocity stiffness (unit: N.m/(rad/s) ) std::array reserve; } MotorCmd; // motor control typedef struct { std::array head; uint8_t levelFlag; uint8_t frameReserve; std::array SN; std::array version; uint16_t bandWidth; std::array motorCmd; BmsCmd bms; std::array wirelessRemote; std::array led; std::array fan; uint8_t gpio; uint32_t reserve; uint32_t crc; } LowCmd; uint32_t crc32_core(uint32_t* ptr, uint32_t len); void get_crc(unitree_go::msg::LowCmd& msg); #endif ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc_hg.h ================================================ /***************************************************************** Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. 
******************************************************************/ #ifndef _MOTOR_CRC_H_ #define _MOTOR_CRC_H_ #include #include #include "rclcpp/rclcpp.hpp" #include "unitree_hg/msg/low_cmd.hpp" #include "unitree_hg/msg/motor_cmd.hpp" typedef struct { uint8_t mode; // desired working mode float q; // desired angle (unit: radian) float dq; // desired velocity (unit: radian/second) float tau; // desired output torque (unit: N.m) float Kp; // desired position stiffness (unit: N.m/rad ) float Kd; // desired velocity stiffness (unit: N.m/(rad/s) ) uint32_t reserve = 0; } MotorCmd; // motor control typedef struct { uint8_t modePr; uint8_t modeMachine; std::array motorCmd; std::array reserve; uint32_t crc; } LowCmd; uint32_t crc32_core(uint32_t *ptr, uint32_t len); void get_crc(unitree_hg::msg::LowCmd &msg); #endif ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/include/common/ros2_sport_client.h ================================================ #ifndef _ROS2_SPORT_CLIENT_ #define _ROS2_SPORT_CLIENT_ #include #include "nlohmann/json.hpp" #include "unitree_api/msg/request.hpp" #pragma pack(1) const int32_t ROBOT_SPORT_API_ID_DAMP = 1001; const int32_t ROBOT_SPORT_API_ID_BALANCESTAND = 1002; const int32_t ROBOT_SPORT_API_ID_STOPMOVE = 1003; const int32_t ROBOT_SPORT_API_ID_STANDUP = 1004; const int32_t ROBOT_SPORT_API_ID_STANDDOWN = 1005; const int32_t ROBOT_SPORT_API_ID_RECOVERYSTAND = 1006; const int32_t ROBOT_SPORT_API_ID_EULER = 1007; const int32_t ROBOT_SPORT_API_ID_MOVE = 1008; const int32_t ROBOT_SPORT_API_ID_SIT = 1009; const int32_t ROBOT_SPORT_API_ID_RISESIT = 1010; const int32_t ROBOT_SPORT_API_ID_SWITCHGAIT = 1011; const int32_t ROBOT_SPORT_API_ID_TRIGGER = 1012; const int32_t ROBOT_SPORT_API_ID_BODYHEIGHT = 1013; const int32_t ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT = 1014; const int32_t ROBOT_SPORT_API_ID_SPEEDLEVEL = 1015; const int32_t ROBOT_SPORT_API_ID_HELLO = 1016; const int32_t ROBOT_SPORT_API_ID_STRETCH = 1017; 
const int32_t ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW = 1018; const int32_t ROBOT_SPORT_API_ID_CONTINUOUSGAIT = 1019; const int32_t ROBOT_SPORT_API_ID_CONTENT = 1020; const int32_t ROBOT_SPORT_API_ID_WALLOW = 1021; const int32_t ROBOT_SPORT_API_ID_DANCE1 = 1022; const int32_t ROBOT_SPORT_API_ID_DANCE2 = 1023; const int32_t ROBOT_SPORT_API_ID_GETBODYHEIGHT = 1024; const int32_t ROBOT_SPORT_API_ID_GETFOOTRAISEHEIGHT = 1025; const int32_t ROBOT_SPORT_API_ID_GETSPEEDLEVEL = 1026; const int32_t ROBOT_SPORT_API_ID_SWITCHJOYSTICK = 1027; const int32_t ROBOT_SPORT_API_ID_POSE = 1028; const int32_t ROBOT_SPORT_API_ID_SCRAPE = 1029; const int32_t ROBOT_SPORT_API_ID_FRONTFLIP = 1030; const int32_t ROBOT_SPORT_API_ID_FRONTJUMP = 1031; const int32_t ROBOT_SPORT_API_ID_FRONTPOUNCE = 1032; typedef struct { float timeFromStart; float x; float y; float yaw; float vx; float vy; float vyaw; } PathPoint; class SportClient { public: /* * @brief Damp * @api: 1001 */ void Damp(unitree_api::msg::Request &req); /* * @brief BalanceStand * @api: 1002 */ void BalanceStand(unitree_api::msg::Request &req); /* * @brief StopMove * @api: 1003 */ void StopMove(unitree_api::msg::Request &req); /* * @brief StandUp * @api: 1004 */ void StandUp(unitree_api::msg::Request &req); /* * @brief StandDown * @api: 1005 */ void StandDown(unitree_api::msg::Request &req); /* * @brief RecoveryStand * @api: 1006 */ void RecoveryStand(unitree_api::msg::Request &req); /* * @brief Euler * @api: 1007 */ void Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw); /* * @brief Move * @api: 1008 */ void Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw); /* * @brief Sit * @api: 1009 */ void Sit(unitree_api::msg::Request &req); /* * @brief RiseSit * @api: 1010 */ void RiseSit(unitree_api::msg::Request &req); /* * @brief SwitchGait * @api: 1011 */ void SwitchGait(unitree_api::msg::Request &req, int d); /* * @brief Trigger * @api: 1012 */ void Trigger(unitree_api::msg::Request &req); /* * 
@brief BodyHeight * @api: 1013 */ void BodyHeight(unitree_api::msg::Request &req, float height); /* * @brief FootRaiseHeight * @api: 1014 */ void FootRaiseHeight(unitree_api::msg::Request &req, float height); /* * @brief SpeedLevel * @api: 1015 */ void SpeedLevel(unitree_api::msg::Request &req, int level); /* * @brief Hello * @api: 1016 */ void Hello(unitree_api::msg::Request &req); /* * @brief Stretch * @api: 1017 */ void Stretch(unitree_api::msg::Request &req); /* * @brief TrajectoryFollow * @api: 1018 */ void TrajectoryFollow(unitree_api::msg::Request &req, std::vector &path); /* * @brief SwitchJoystick * @api: 1027 */ void SwitchJoystick(unitree_api::msg::Request &req, bool flag); /* * @brief ContinuousGait * @api: 1019 */ void ContinuousGait(unitree_api::msg::Request &req, bool flag); /* * @brief Wallow * @api: 1021 */ void Wallow(unitree_api::msg::Request &req); /* * @brief Content * @api: 1020 */ void Content(unitree_api::msg::Request &req); /* * @brief Pose * @api: 1028 */ void Pose(unitree_api::msg::Request &req, bool flag); /* * @brief Scrape * @api: 1029 */ void Scrape(unitree_api::msg::Request &req); /* * @brief FrontFlip * @api: 1030 */ void FrontFlip(unitree_api::msg::Request &req); /* * @brief FrontJump * @api: 1031 */ void FrontJump(unitree_api::msg::Request &req); /* * @brief FrontPounce * @api: 1032 */ void FrontPounce(unitree_api::msg::Request &req); /* * @brief Dance1 * @api: 1022 */ void Dance1(unitree_api::msg::Request &req); /* * @brief Dance2 * @api: 1023 */ void Dance2(unitree_api::msg::Request &req); }; #endif ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/include/common/wireless_controller.h ================================================ #pragma once #include "unitree_go/msg/wireless_controller.hpp" #include #include class KeyMap { public: static const int R1; static const int L1; static const int start; static const int select; static const int R2; static const int L2; static const int F1; 
static const int F2; static const int A; static const int B; static const int X; static const int Y; static const int up; static const int right; static const int down; static const int left; }; class RemoteController { public: // Constructor RemoteController(); // Add overloaded set method for raw data void set(const std::array& data); // Keep original method for compatibility void set(const unitree_go::msg::WirelessController::SharedPtr msg); // Member variables double lx; double ly; double rx; double ry; int button[16]; }; ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py ================================================ """ HoloMotion ROS2 Launch Configuration This module defines the ROS2 launch configuration for the HoloMotion humanoid robot control system. It sets up a complete robotics pipeline including robot control, motion policy execution, and data recording for the Unitree G1 humanoid robot. The launch file coordinates three main components: 1. Main control node (C++) - Handles low-level robot control and communication 2. Policy node (Python) - Executes motion policies and high-level decision making 3. 
Recording node - Captures sensor data and commands for analysis Key Features: - Configures network interface for robot communication - Sets up CycloneDDS middleware with specific network interface - Launches coordinated multi-node system with shared configuration - Automatically records operational data with timestamped bags Author: HoloMotion Team License: See project LICENSE file """ from datetime import datetime import os from launch import LaunchDescription from launch.actions import SetEnvironmentVariable, DeclareLaunchArgument from launch.substitutions import LaunchConfiguration from launch_ros.actions import Node from ament_index_python.packages import get_package_share_directory from launch.actions import ExecuteProcess from launch.conditions import IfCondition def generate_launch_description(): """ Generate the complete launch description for the HoloMotion humanoid control system. This function creates a comprehensive ROS2 launch configuration that coordinates multiple nodes required for humanoid robot operation. It sets up the necessary environment, launches control nodes, and optionally initiates data recording. Network Configuration: - Uses specific network interface (eth0) for robot communication - Configures CycloneDDS middleware to use designated network interface - Ensures proper isolation and communication with the robot hardware Node Architecture: 1. Main Control Node (C++): - Handles real-time robot control and sensor data processing - Manages low-level motor commands and feedback loops - Interfaces directly with robot hardware via configured network 2. Policy Node (Python): - Executes trained motion policies for humanoid locomotion - Processes high-level commands and translates to robot actions - Handles motion planning and behavior coordination 3. 
Recording Node (Optional): - Automatically captures all relevant system data when enabled - Records sensor states, commands, and system metrics - Creates timestamped bag files for later analysis Configuration: - Robot: Unitree G1 with 29 DOF configuration - Config file: g1_29dof_holomotion.yaml - Recording format: MCAP for efficient data storage - Recording: Disabled by default, can be enabled with --record parameter Recorded Topics (when recording enabled): - /lowcmd: Low-level motor commands sent to robot - /lowstate: Robot sensor feedback and joint states - /humanoid/action: High-level action commands from policy Parameters: - enable_recording: Boolean flag to enable/disable topic recording (default: false) Returns: LaunchDescription: Complete ROS2 launch configuration with all nodes, environment variables, and optional recording setup Raises: FileNotFoundError: If the configuration file cannot be located PermissionError: If unable to create recording directory Example: Launch without recording (default): $ ros2 launch humanoid_control holomotion_29dof.launch.py Launch with recording enabled: $ ros2 launch humanoid_control holomotion_29dof.launch.py enable_recording:=true Or using the shell script: $ ./launch_holomotion_29dof.sh --record """ # Declare launch arguments enable_recording_arg = DeclareLaunchArgument( "enable_recording", default_value="false", description="Enable topic recording (true/false)", ) network_interface = "eth0" config_name = "g1_29dof_holomotion.yaml" pkg_dir = get_package_share_directory("humanoid_control") config_file = os.path.join(pkg_dir, "config", config_name) # Allow overriding python interpreter via env var (set by the shell script) python_executable = os.environ["Deploy_CONDA_PREFIX"] + "/bin/python" print(f"Using Python executable: {python_executable}") return LaunchDescription( [ # Declare launch arguments enable_recording_arg, # Main control node (C++) SetEnvironmentVariable( name="CYCLONEDDS_URI", value=f"{network_interface}", 
), Node( package="humanoid_control", executable="humanoid_control", name="main_node", parameters=[{"config_path": config_file}], output="screen", ), # Policy node (Python) Node( package="humanoid_control", executable="policy_node_29dof", name="policy_node", parameters=[{"config_path": config_file}], output="screen", prefix=python_executable, ), # Recording node (conditional) ExecuteProcess( cmd=[ "ros2", "bag", "record", "--storage", "mcap", "-o", ( "./bag_record/" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_" + config_name.split(".")[0] ), "/lowcmd", "/lowstate", "/humanoid/action", ], output="screen", condition=IfCondition(LaunchConfiguration("enable_recording")), ), ] ) ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/models/.gitkeep ================================================ ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/motion_data/.gitkeep ================================================ ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/package.xml ================================================ humanoid_control 0.0.0 Humanoid locomotion control package from Horizon Robotics unitree TODO: License declaration ament_cmake ament_cmake_python rclcpp sensor_msgs unitree_hg rclpy python3-numpy python3-torch python3-yaml ros2launch ament_copyright ament_flake8 ament_pep257 python3-pytest ament_cmake ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/resource/humanoid_control ================================================ ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/setup.cfg ================================================ [develop] script_dir=$base/lib/humanoid_control [install] install_scripts=$base/lib/humanoid_control ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/setup.py 
================================================
# Python packaging metadata for the humanoid_control package.
from setuptools import setup, find_packages
import os

package_name = "humanoid_control"

data_files = [
    (
        "share/ament_index/resource_index/packages",
        ["resource/" + package_name],
    ),
    ("share/" + package_name, ["package.xml"]),
]

# Add files from config, launch and model directories
# (paths are relative to the current working directory when setup.py runs)
for dir_name in ["config", "launch", "models"]:
    if os.path.exists(dir_name):  # Only process if directory exists
        for root, dirs, files in os.walk(dir_name):
            # Mirror each source sub-directory under share/<package_name>/
            install_dir = os.path.join("share", package_name, root)
            list_entry = (install_dir, [os.path.join(root, f) for f in files])
            data_files.append(list_entry)

setup(
    name=package_name,
    version="0.0.1",
    packages=find_packages(),
    data_files=data_files,
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="Horizon Robotics",
    maintainer_email="maiyue01.chen@horizon.auto",
    description="Humanoid locomotion control package from Horizon Robotics",
    license="Apache License 2.0",
    tests_require=["pytest"],
    entry_points={
        "console_scripts": [
            # NOTE(review): the repository tree lists no
            # humanoid_policy/policy_node_performance.py, and the launch file
            # starts executable `policy_node_29dof`; verify this entry point.
            "policy_node_performance = humanoid_policy.policy_node_performance:main",
        ],
    },
)

================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc.cpp
================================================
#include "motor_crc.h"

// Copies the ROS unitree_go LowCmd into the raw LowCmd struct, computes
// crc32_core over every 32-bit word except the trailing crc field, and
// writes the result back into msg.crc.
void get_crc(unitree_go::msg::LowCmd& msg)
{
    LowCmd raw{};
    memcpy(&raw.head[0], &msg.head[0], 2);

    raw.levelFlag=msg.level_flag;
    raw.frameReserve=msg.frame_reserve;

    memcpy(&raw.SN[0],&msg.sn[0], 8);
    memcpy(&raw.version[0], &msg.version[0], 8);

    raw.bandWidth=msg.bandwidth;

    // Only the first 20 motor slots are copied.
    for(int i = 0; i<20; i++)
    {
        raw.motorCmd[i].mode=msg.motor_cmd[i].mode;
        raw.motorCmd[i].q=msg.motor_cmd[i].q;
        raw.motorCmd[i].dq=msg.motor_cmd[i].dq;
        raw.motorCmd[i].tau=msg.motor_cmd[i].tau;
        raw.motorCmd[i].Kp=msg.motor_cmd[i].kp;
        raw.motorCmd[i].Kd=msg.motor_cmd[i].kd;

        memcpy(&raw.motorCmd[i].reserve[0], &msg.motor_cmd[i].reserve[0], 12);
    }

    raw.bms.off=msg.bms_cmd.off;
    memcpy(&raw.bms.reserve[0],&msg.bms_cmd.reserve[0], 3);

    memcpy(&raw.wirelessRemote[0], &msg.wireless_remote[0], 40);

    memcpy(&raw.led[0], &msg.led[0], 12);  // go2
    memcpy(&raw.fan[0], &msg.fan[0], 2);
    raw.gpio=msg.gpio;  // go2
    raw.reserve=msg.reserve;

    // sizeof(LowCmd)/4 words, minus the final word that holds the crc itself.
    raw.crc=crc32_core((uint32_t *)&raw, (sizeof(LowCmd)>>2)-1);
    msg.crc=raw.crc;
}

// Word-wise CRC-32 variant: polynomial 0x04C11DB7, initial value 0xFFFFFFFF,
// no final XOR. Each 32-bit word is consumed MSB-first.
// NOTE(review): `1 << 31` shifts into the sign bit of an int; consider
// `1u << 31` if this is ever touched.
uint32_t crc32_core(uint32_t* ptr, uint32_t len)
{
    uint32_t xbit = 0;
    uint32_t data = 0;
    uint32_t CRC32 = 0xFFFFFFFF;
    const uint32_t dwPolynomial = 0x04c11db7;
    for (uint32_t i = 0; i < len; i++)
    {
        xbit = 1 << 31;
        data = ptr[i];
        for (uint32_t bits = 0; bits < 32; bits++)
        {
            if (CRC32 & 0x80000000)
            {
                CRC32 <<= 1;
                CRC32 ^= dwPolynomial;
            }
            else
                CRC32 <<= 1;
            if (data & xbit)
                CRC32 ^= dwPolynomial;

            xbit >>= 1;
        }
    }
    return CRC32;
}

================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc_hg.cpp
================================================
#include "motor_crc_hg.h"

// Copies the ROS unitree_hg LowCmd (35 motor slots) into the raw LowCmd
// struct, computes crc32_core over every 32-bit word except the trailing crc
// field, and writes the result back into msg.crc.
void get_crc(unitree_hg::msg::LowCmd &msg)
{
    LowCmd raw{};
    raw.modePr = msg.mode_pr;
    raw.modeMachine = msg.mode_machine;

    for (int i = 0; i < 35; i++)
    {
        raw.motorCmd[i].mode = msg.motor_cmd[i].mode;
        raw.motorCmd[i].q = msg.motor_cmd[i].q;
        raw.motorCmd[i].dq = msg.motor_cmd[i].dq;
        raw.motorCmd[i].tau = msg.motor_cmd[i].tau;
        raw.motorCmd[i].Kp = msg.motor_cmd[i].kp;
        raw.motorCmd[i].Kd = msg.motor_cmd[i].kd;
        raw.motorCmd[i].reserve = msg.motor_cmd[i].reserve;
    }

    memcpy(&raw.reserve[0], &msg.reserve[0], 4);
    // sizeof(LowCmd)/4 words, minus the final word that holds the crc itself.
    raw.crc = crc32_core((uint32_t *)&raw, (sizeof(LowCmd) >> 2) - 1);
    msg.crc = raw.crc;
}

// Same word-wise CRC-32 variant as in motor_crc.cpp (polynomial 0x04C11DB7,
// initial value 0xFFFFFFFF, no final XOR).
uint32_t crc32_core(uint32_t *ptr, uint32_t len)
{
    uint32_t xbit = 0;
    uint32_t data = 0;
    uint32_t CRC32 = 0xFFFFFFFF;
    const uint32_t dwPolynomial = 0x04c11db7;
    for (uint32_t i = 0; i < len; i++)
    {
        xbit = 1 << 31;
        data = ptr[i];
        for (uint32_t bits = 0; bits < 32; bits++)
        {
            if (CRC32 & 0x80000000)
            {
                CRC32 <<= 1;
                CRC32 ^= dwPolynomial;
            }
            else
                CRC32 <<= 1;
            if (data & xbit)
                CRC32 ^= dwPolynomial;

            xbit >>= 1;
        }
    }
    return CRC32;
}

================================================
FILE:
deployment/unitree_g1_ros2_29dof/src/src/common/ros2_sport_client.cpp ================================================ #include "ros2_sport_client.h" void SportClient::Damp(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_DAMP; } void SportClient::BalanceStand(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_BALANCESTAND; } void SportClient::StopMove(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STOPMOVE; } void SportClient::StandUp(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDUP; } void SportClient::StandDown(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDDOWN; } void SportClient::RecoveryStand(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_RECOVERYSTAND; } void SportClient::Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw) { nlohmann::json js; js["x"] = roll; js["y"] = pitch; js["z"] = yaw; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_EULER; } void SportClient::Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw) { nlohmann::json js; js["x"] = vx; js["y"] = vy; js["z"] = vyaw; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_MOVE; } void SportClient::Sit(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_SIT; } void SportClient::RiseSit(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_RISESIT; } void SportClient::SwitchGait(unitree_api::msg::Request &req, int d) { nlohmann::json js; js["data"] = d; req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHGAIT; req.parameter = js.dump(); } void SportClient::Trigger(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_TRIGGER; } void SportClient::BodyHeight(unitree_api::msg::Request &req, float height) { 
nlohmann::json js; js["data"] = height; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_BODYHEIGHT; } void SportClient::FootRaiseHeight(unitree_api::msg::Request &req, float height) { nlohmann::json js; js["data"] = height; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT; } void SportClient::SpeedLevel(unitree_api::msg::Request &req, int level) { nlohmann::json js; js["data"] = level; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_SPEEDLEVEL; } void SportClient::Hello(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_HELLO; } void SportClient::Stretch(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STRETCH; } void SportClient::TrajectoryFollow(unitree_api::msg::Request &req, std::vector &path) { nlohmann::json js_path; req.header.identity.api_id = ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW; for (int i = 0; i < 30; i++) { nlohmann::json js_point; js_point["t_from_start"] = path[i].timeFromStart; js_point["x"] = path[i].x; js_point["y"] = path[i].y; js_point["yaw"] = path[i].yaw; js_point["vx"] = path[i].vx; js_point["vy"] = path[i].vy; js_point["vyaw"] = path[i].vyaw; js_path.push_back(js_point); } req.parameter =js_path.dump(); } void SportClient::SwitchJoystick(unitree_api::msg::Request &req, bool flag) { nlohmann::json js; js["data"] = flag; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHJOYSTICK; } void SportClient::ContinuousGait(unitree_api::msg::Request &req, bool flag) { nlohmann::json js; js["data"] = flag; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTINUOUSGAIT; } void SportClient::Wallow(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_WALLOW; } void SportClient::Content(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTENT; } void 
SportClient::Pose(unitree_api::msg::Request &req, bool flag) { nlohmann::json js; js["data"] = flag; req.parameter = js.dump(); req.header.identity.api_id = ROBOT_SPORT_API_ID_POSE; } void SportClient::Scrape(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_SCRAPE; } void SportClient::FrontFlip(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTFLIP; } void SportClient::FrontJump(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTJUMP; } void SportClient::FrontPounce(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTPOUNCE; } void SportClient::Dance1(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_DANCE1; } void SportClient::Dance2(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_DANCE2; } ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/src/common/wireless_controller.cpp ================================================ #include "common/wireless_controller.h" #include // Define static constants const int KeyMap::R1 = 0; const int KeyMap::L1 = 1; const int KeyMap::start = 2; const int KeyMap::select = 3; const int KeyMap::R2 = 4; const int KeyMap::L2 = 5; const int KeyMap::F1 = 6; const int KeyMap::F2 = 7; const int KeyMap::A = 8; const int KeyMap::B = 9; const int KeyMap::X = 10; const int KeyMap::Y = 11; const int KeyMap::up = 12; const int KeyMap::right = 13; const int KeyMap::down = 14; const int KeyMap::left = 15; // Implement RemoteController methods RemoteController::RemoteController() { lx = 0; ly = 0; rx = 0; ry = 0; std::fill(button, button + 16, 0); } void RemoteController::set(const std::array &data) { // Debug print raw bytes // printf("Raw data bytes: "); // for (int i = 0; i < 40; i++) { // printf("%02x ", data[i]); // } // printf("\n"); // Extract keys from bytes 2-3 uint16_t keys = (data[3] << 8) | 
data[2]; // printf("Keys value: 0x%04x\n", keys); for (int i = 0; i < 16; i++) { button[i] = (keys & (1 << i)) >> i; } // Extract and print floats before memcpy float lx_temp, rx_temp, ry_temp, ly_temp; std::memcpy(&lx_temp, &data[4], 4); // bytes 4-7 std::memcpy(&rx_temp, &data[8], 4); // bytes 8-11 std::memcpy(&ry_temp, &data[12], 4); // bytes 12-15 std::memcpy(&ly_temp, &data[20], 4); // bytes 20-23 // printf("Values before assignment: lx=%f, ly=%f, rx=%f, ry=%f\n", lx_temp, // ly_temp, rx_temp, ry_temp); // Assign to class members lx = lx_temp; rx = rx_temp; ry = ry_temp; ly = ly_temp; // printf("Values after assignment: lx=%f, ly=%f, rx=%f, ry=%f\n", lx, ly, // rx, // ry); } void RemoteController::set( const unitree_go::msg::WirelessController::SharedPtr msg) { uint16_t keys = msg->keys; for (int i = 0; i < 16; i++) { button[i] = (keys & (1 << i)) >> i; } lx = msg->lx; rx = msg->rx; ry = msg->ry; ly = msg->ly; } ================================================ FILE: deployment/unitree_g1_ros2_29dof/src/src/main_node.cpp ================================================ /** * This example demonstrates how to use ROS2 to send low-level motor commands of *unitree g1 robot 29 dof **/ #include "common/motor_crc_hg.h" #include "common/wireless_controller.h" #include "rclcpp/rclcpp.hpp" #include "unitree_go/msg/wireless_controller.hpp" #include "unitree_hg/msg/low_cmd.hpp" #include "unitree_hg/msg/low_state.hpp" #include "unitree_hg/msg/motor_cmd.hpp" #include #include #include #include #include #include #include #include #include #define INFO_IMU 0 // Set 1 to info IMU states #define INFO_MOTOR 0 // Set 1 to info motor states enum PRorAB { PR = 0, AB = 1 }; using std::placeholders::_1; const int G1_NUM_MOTOR = 29; enum class RobotState { ZERO_TORQUE, MOVE_TO_DEFAULT, EMERGENCY_STOP, POLICY }; enum class EmergencyStopPhase { DAMPING, DISABLE }; // New enum for emergency stop phases // Create a humanoid_controller class for low state receive class humanoid_controller : 
public rclcpp::Node { public: humanoid_controller() : Node("humanoid_controller") { RCLCPP_INFO(this->get_logger(), "Using main_node !!!"); // Get config path from ROS parameter std::string config_path = this->declare_parameter("config_path", ""); RCLCPP_INFO(this->get_logger(), "Config file path: %s", config_path.c_str()); // Load configuration loadConfig(config_path); RCLCPP_INFO(this->get_logger(), "Entered ZERO_TORQUE state, press start to switch to " "MOVE_TO_DEFAULT state, press A to switch to POLICY state, " "press select to emergency stop. Waiting for start signal..."); lowstate_subscriber_ = this->create_subscription( "/lowstate", 10, std::bind(&humanoid_controller::LowStateHandler, this, _1)); policy_action_subscriber_ = this->create_subscription( "/humanoid/action", 10, std::bind(&humanoid_controller::PolicyActionHandler, this, _1)); // Add subscribers for kps and kds parameters from policy node kps_subscriber_ = this->create_subscription( "/humanoid/kps", 10, std::bind(&humanoid_controller::KpsHandler, this, _1)); kds_subscriber_ = this->create_subscription( "/humanoid/kds", 10, std::bind(&humanoid_controller::KdsHandler, this, _1)); lowcmd_publisher_ = this->create_publisher("/lowcmd", 10); robot_state_publisher_ = this->create_publisher("/robot_state", 10); timer_ = this->create_wall_timer(std::chrono::milliseconds(timer_dt), std::bind(&humanoid_controller::Control, this)); time_ = 0; duration_ = 3; // 3 s } private: std::map dof2motor_idx; std::map default_dof_pos; std::map kps; std::map kds; std::vector complete_dof_order; std::vector policy_dof_order; // Policy-subscribed control parameters std::vector policy_kps_data; std::vector policy_kds_data; bool kps_received_ = false; bool kds_received_ = false; RemoteController remote_controller; std::map target_dof_pos; std::vector policy_action_data; // Optional arrays from YAML for start (MOVE_TO_DEFAULT) behavior bool has_joint_arrays_ = false; std::vector joint_names_array_; std::vector 
default_position_array_; std::vector kp_array_; std::vector kd_array_; std::map move_to_default_kps; std::map move_to_default_kds; RobotState current_state_ = RobotState::ZERO_TORQUE; bool should_shutdown_ = false; // Add safety limit parameters using existing structure in YAML std::map> joint_position_limits; // min, max std::map joint_velocity_limits; std::map joint_effort_limits; // Scaling coefficients for limits double position_limit_scale = 1.0; double velocity_limit_scale = 1.0; double effort_limit_scale = 1.0; EmergencyStopPhase emergency_stop_phase_ = EmergencyStopPhase::DAMPING; double emergency_stop_time_ = 0.0; double emergency_damping_duration_ = 2.0; // 1 second of damping before disabling // Add a helper function to calculate expected torque double calculateExpectedTorque(const std::string& dof_name, double q_des, double q, double dq) { double kp = kps[dof_name]; double kd = kds[dof_name]; // dq_des is assumed to be 0 in your control scheme return kp * (q_des - q) + kd * (0.0 - dq); } // Add a helper function to scale kp and kd to limit torque std::pair limitTorque(const std::string& dof_name, double q_des, double q, double dq) { double kp = kps[dof_name]; double kd = kds[dof_name]; // Calculate expected torque double expected_torque = calculateExpectedTorque(dof_name, q_des, q, dq); double abs_expected_torque = std::abs(expected_torque); // Check if torque would exceed limit if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) { double max_torque = joint_effort_limits[dof_name] * effort_limit_scale; if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) { // Scale both kp and kd by the same factor to preserve damping characteristics double scale_factor = max_torque / abs_expected_torque; return std::make_pair(kp * scale_factor, kd * scale_factor); } } // If no scaling needed, return original values return std::make_pair(kp, kd); } // Add a helper function to scale custom kp and kd to limit torque std::pair 
limitTorqueWithCustomGains( const std::string& dof_name, double q_des, double q, double dq, double custom_kp, double custom_kd) { // Calculate expected torque double expected_torque = custom_kp * (q_des - q) + custom_kd * (0.0 - dq); double abs_expected_torque = std::abs(expected_torque); // Check if torque would exceed limit if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) { double max_torque = joint_effort_limits[dof_name] * effort_limit_scale; if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) { // Scale both kp and kd by the same factor to preserve damping characteristics double scale_factor = max_torque / abs_expected_torque; return std::make_pair(custom_kp * scale_factor, custom_kd * scale_factor); } } // If no scaling needed, return original values return std::make_pair(custom_kp, custom_kd); } void loadConfig(const std::string &config_path) { try { YAML::Node config = YAML::LoadFile(config_path); // Load motor indices auto indices = config["dof2motor_idx_mapping"]; for (const auto &it : indices) { dof2motor_idx[it.first.as()] = it.second.as(); } // Load default angles auto angles = config["default_joint_angles"]; for (const auto &it : angles) { default_dof_pos[it.first.as()] = it.second.as(); } // Set target dof pos to default dof pos for (const auto &it : default_dof_pos) { target_dof_pos[it.first] = it.second; } // Note: kps and kds are now received from policy node via ROS topics // No longer loading from config file to avoid conflicts // Load dof order for (const auto &it : config["complete_dof_order"]) { complete_dof_order.push_back(it.as()); } for (const auto &it : config["policy_dof_order"]) { policy_dof_order.push_back(it.as()); } // Load control frequency control_freq_ = config["control_freq"].as(); control_dt_ = 1.0 / control_freq_; timer_dt = static_cast(control_dt_ * 1000); RCLCPP_INFO(this->get_logger(), "Control frequency set to: %f Hz", control_freq_); // Load joint limits auto pos_limits = 
config["joint_limits"]["position"]; for (const auto &it : pos_limits) { std::string dof_name = it.first.as(); auto limits = it.second.as>(); joint_position_limits[dof_name] = std::make_pair(limits[0], limits[1]); } auto vel_limits = config["joint_limits"]["velocity"]; for (const auto &it : vel_limits) { joint_velocity_limits[it.first.as()] = it.second.as(); } auto effort_limits = config["joint_limits"]["effort"]; for (const auto &it : effort_limits) { joint_effort_limits[it.first.as()] = it.second.as(); } // Load joint limits scaling coefficients (optional, default to 1.0) position_limit_scale = config["limit_scales"]["position"].as(1.0); velocity_limit_scale = config["limit_scales"]["velocity"].as(1.0); effort_limit_scale = config["limit_scales"]["effort"].as(1.0); RCLCPP_INFO(this->get_logger(), "Joint limit scales - Position: %f, Velocity: %f, Effort: %f", position_limit_scale, velocity_limit_scale, effort_limit_scale); // Optional: arrays for joint configuration on Start // If kp and kd arrays are provided, use them with joint names and positions // Auto-generate joint_names and default_position from complete_dof_order and default_joint_angles if not provided if (config["kp"] && config["kd"]) { joint_names_array_.clear(); default_position_array_.clear(); kp_array_.clear(); kd_array_.clear(); // Auto-generate joint_names and default_position from existing config if not explicitly provided if (config["joint_names"] && config["default_position"]) { // Use explicitly provided arrays for (const auto &it : config["joint_names"]) { joint_names_array_.push_back(it.as()); } for (const auto &it : config["default_position"]) { default_position_array_.push_back(it.as()); } } else { // Auto-generate from complete_dof_order and default_joint_angles for (const auto &dof_name : complete_dof_order) { joint_names_array_.push_back(dof_name); if (default_dof_pos.find(dof_name) != default_dof_pos.end()) { default_position_array_.push_back(default_dof_pos[dof_name]); } else { 
RCLCPP_WARN(this->get_logger(), "Default position not found for joint %s, using 0.0", dof_name.c_str()); default_position_array_.push_back(0.0); } } RCLCPP_INFO(this->get_logger(), "Auto-generated joint_names and default_position from complete_dof_order and default_joint_angles"); } // Load kp and kd arrays for (const auto &it : config["kp"]) { kp_array_.push_back(it.as()); } for (const auto &it : config["kd"]) { kd_array_.push_back(it.as()); } // Basic validation if (joint_names_array_.size() == default_position_array_.size() && joint_names_array_.size() == kp_array_.size() && joint_names_array_.size() == kd_array_.size()) { has_joint_arrays_ = true; // Store MoveToDefault-specific kps/kds and default positions for (size_t i = 0; i < joint_names_array_.size(); ++i) { const std::string &name = joint_names_array_[i]; double pos = default_position_array_[i]; double kp_v = kp_array_[i]; double kd_v = kd_array_[i]; default_dof_pos[name] = pos; // Store MoveToDefault kp/kd move_to_default_kps[name] = kp_v; move_to_default_kds[name] = kd_v; } RCLCPP_INFO(this->get_logger(), "Using joint arrays for Start behavior (size: %zu)", joint_names_array_.size()); } else { RCLCPP_WARN(this->get_logger(), "joint_names/default_position/kp/kd size mismatch; ignoring arrays"); } } } catch (const YAML::Exception &e) { RCLCPP_ERROR(this->get_logger(), "Error parsing config file: %s", e.what()); } } void Control() { // First check if we're already in emergency stop if (current_state_ == RobotState::EMERGENCY_STOP) { emergency_stop_time_ += control_dt_; if (emergency_stop_phase_ == EmergencyStopPhase::DAMPING) { SendDampedEmergencyStop(); if (emergency_stop_time_ >= emergency_damping_duration_) { emergency_stop_phase_ = EmergencyStopPhase::DISABLE; RCLCPP_INFO(this->get_logger(), "Damping complete, disabling motors"); } } else { SendFinalEmergencyStop(); if (timer_) { timer_->cancel(); } rclcpp::shutdown(); return; } get_crc(low_command); lowcmd_publisher_->publish(low_command); return; // 
Exit early, ignore all other commands } // If not in emergency stop, check for emergency stop command first if (remote_controller.button[KeyMap::select] == 1) { current_state_ = RobotState::EMERGENCY_STOP; should_shutdown_ = true; publishRobotState(); return; } // Process other commands only if not in emergency stop if (remote_controller.button[KeyMap::L1] == 1 && current_state_ != RobotState::ZERO_TORQUE) { RCLCPP_INFO(this->get_logger(), "Switching to ZERO_TORQUE state"); current_state_ = RobotState::ZERO_TORQUE; publishRobotState(); } // Start button only works in ZERO_TORQUE state if (remote_controller.button[KeyMap::start] == 1) { if (current_state_ == RobotState::ZERO_TORQUE) { RCLCPP_INFO(this->get_logger(), "Switching to MOVE_TO_DEFAULT state"); current_state_ = RobotState::MOVE_TO_DEFAULT; time_ = 0.0; publishRobotState(); } else { RCLCPP_INFO(this->get_logger(), "Start button only works in ZERO_TORQUE state. Current state: %d", static_cast(current_state_)); } } // A button only works in MOVE_TO_DEFAULT state if (remote_controller.button[KeyMap::A] == 1) { if (current_state_ == RobotState::MOVE_TO_DEFAULT) { // Check if kps and kds parameters have been received from policy node if (!kps_received_ || !kds_received_) { RCLCPP_ERROR(this->get_logger(), "Cannot switch to POLICY state. Control parameters not received from policy node! kps_received: %s, kds_received: %s", kps_received_ ? "true" : "false", kds_received_ ? 
"true" : "false"); return; } // Check lower body joint positions before allowing transition bool positions_ok = true; std::stringstream deviation_msg; const double position_threshold = 0.4; // List of lower body joints to check std::vector lower_body_joints = { "left_hip_yaw", "left_hip_roll", "left_hip_pitch", "left_knee", "left_ankle_pitch", "left_ankle_roll", "right_hip_yaw", "right_hip_roll", "right_hip_pitch", "right_knee", "right_ankle_pitch", "right_ankle_roll" }; for (int i = 0; i < G1_NUM_MOTOR; ++i) { std::string dof_name = complete_dof_order[i]; // Skip if not a lower body joint if (std::find(lower_body_joints.begin(), lower_body_joints.end(), dof_name) == lower_body_joints.end()) { continue; } double current_pos = motor[i].q; double default_pos = default_dof_pos[dof_name]; double diff = std::abs(current_pos - default_pos); if (diff > position_threshold) { positions_ok = false; deviation_msg << dof_name << "(" << diff << "), "; } } if (positions_ok) { RCLCPP_INFO(this->get_logger(), "Switching to POLICY state"); current_state_ = RobotState::POLICY; time_ = 0.0; publishRobotState(); } else { RCLCPP_WARN(this->get_logger(), "Cannot switch to POLICY state. Lower body joints with large deviations: %s", deviation_msg.str().c_str()); } } else { RCLCPP_INFO(this->get_logger(), "A button only works in MOVE_TO_DEFAULT state. 
Current state: %d", static_cast(current_state_)); } } // Normal state machine logic switch (current_state_) { case RobotState::ZERO_TORQUE: SendZeroTorqueCommand(); get_crc(low_command); lowcmd_publisher_->publish(low_command); break; case RobotState::MOVE_TO_DEFAULT: SendDefaultPositionCommand(); get_crc(low_command); lowcmd_publisher_->publish(low_command); break; case RobotState::POLICY: SendPolicyCommand(); get_crc(low_command); lowcmd_publisher_->publish(low_command); break; case RobotState::EMERGENCY_STOP: // Emergency stop is handled at the beginning of the function // This case should not be reached due to early return break; } // Publish current robot state publishRobotState(); } void SendZeroTorqueCommand() { low_command.mode_pr = mode_; low_command.mode_machine = mode_machine; for (int i = 0; i < G1_NUM_MOTOR; ++i) { low_command.motor_cmd[i].mode = 1; // Enable low_command.motor_cmd[i].q = 0.0; low_command.motor_cmd[i].dq = 0.0; low_command.motor_cmd[i].kp = 0.0; low_command.motor_cmd[i].kd = 0.0; low_command.motor_cmd[i].tau = 0.0; } } void SendDefaultPositionCommand() { time_ += control_dt_; low_command.mode_pr = mode_; low_command.mode_machine = mode_machine; // Print kp/kd values on first execution static bool first_move_to_default = true; if (first_move_to_default) { RCLCPP_INFO(this->get_logger(), "=== First MOVE_TO_DEFAULT execution ==="); first_move_to_default = false; } if (has_joint_arrays_) { // Use provided arrays and dof2motor mapping to command motors double ratio = clamp(time_ / duration_, 0.0, 1.0); for (size_t j = 0; j < joint_names_array_.size(); ++j) { const std::string &dof_name = joint_names_array_[j]; if (dof2motor_idx.find(dof_name) == dof2motor_idx.end()) { continue; // skip unknown names } int motor_idx = dof2motor_idx[dof_name]; double target_final = default_position_array_[j]; double target_pos = (1. 
- ratio) * motor[motor_idx].q + ratio * target_final; // Current state double current_pos = motor[motor_idx].q; double current_vel = motor[motor_idx].dq; // Use MoveToDefault specialized kp/kd double kp_to_use = move_to_default_kps[dof_name]; double kd_to_use = move_to_default_kds[dof_name]; // Print kp/kd values on first execution if (time_ <= control_dt_ * 2) { // Print for first few iterations RCLCPP_INFO(this->get_logger(), "MoveToDefault - %s: kp=%.2f, kd=%.2f", dof_name.c_str(), kp_to_use, kd_to_use); } // Apply torque limiting with MoveToDefault gains auto [limited_kp, limited_kd] = limitTorqueWithCustomGains( dof_name, target_pos, current_pos, current_vel, kp_to_use, kd_to_use); low_command.motor_cmd[motor_idx].mode = 1; low_command.motor_cmd[motor_idx].tau = 0.0; low_command.motor_cmd[motor_idx].q = target_pos; low_command.motor_cmd[motor_idx].dq = 0.0; low_command.motor_cmd[motor_idx].kp = limited_kp; low_command.motor_cmd[motor_idx].kd = limited_kd; } } else { // Fall back to map-driven order with default MoveToDefault gains // Use default kp/kd values for MoveToDefault since policy kps/kds are not available yet const double default_move_kp = 50.0; // Default stiffness for MoveToDefault const double default_move_kd = 5.0; // Default damping for MoveToDefault // Print default kp/kd values on first execution if (time_ <= control_dt_ * 2) { // Print for first few iterations RCLCPP_INFO(this->get_logger(), "MoveToDefault (fallback) - Using default kp=%.2f, kd=%.2f", default_move_kp, default_move_kd); } for (int i = 0; i < G1_NUM_MOTOR; ++i) { std::string dof_name = complete_dof_order[i]; double ratio = clamp(time_ / duration_, 0.0, 1.0); double target_pos = (1. 
- ratio) * motor[i].q + ratio * default_dof_pos[dof_name]; // Current state double current_pos = motor[i].q; double current_vel = motor[i].dq; // Use default MoveToDefault gains with torque limiting auto [limited_kp, limited_kd] = limitTorqueWithCustomGains( dof_name, target_pos, current_pos, current_vel, default_move_kp, default_move_kd); low_command.motor_cmd[i].mode = 1; low_command.motor_cmd[i].tau = 0.0; low_command.motor_cmd[i].q = target_pos; low_command.motor_cmd[i].dq = 0.0; low_command.motor_cmd[i].kp = limited_kp; low_command.motor_cmd[i].kd = limited_kd; } } } void SendPolicyCommand() { time_ += control_dt_; low_command.mode_pr = mode_; low_command.mode_machine = mode_machine; // Print kp/kd values on first execution static bool first_policy_command = true; if (first_policy_command) { RCLCPP_INFO(this->get_logger(), "=== First POLICY command execution ==="); first_policy_command = false; } // Check if kps and kds parameters have been received from policy node if (!kps_received_ || !kds_received_) { RCLCPP_ERROR(this->get_logger(), "Policy control parameters not received! kps_received: %s, kds_received: %s", kps_received_ ? "true" : "false", kds_received_ ? "true" : "false"); RCLCPP_ERROR(this->get_logger(), "Cannot execute POLICY commands without control parameters. 
Triggering emergency stop."); current_state_ = RobotState::EMERGENCY_STOP; should_shutdown_ = true; publishRobotState(); return; } for (const auto &pair : target_dof_pos) { const std::string &dof_name = pair.first; const double &target_pos = pair.second; int motor_idx = dof2motor_idx[dof_name]; // Get policy kp/kd values double policy_kp = kps[dof_name]; double policy_kd = kds[dof_name]; // Print kp/kd values on first execution if (time_ <= control_dt_ * 2) { // Print for first few iterations RCLCPP_INFO(this->get_logger(), "Policy - %s: kp=%.2f, kd=%.2f", dof_name.c_str(), policy_kp, policy_kd); } // Use policy kp/kd values directly without torque limiting low_command.motor_cmd[motor_idx].mode = 1; low_command.motor_cmd[motor_idx].tau = 0.0; low_command.motor_cmd[motor_idx].q = target_pos; low_command.motor_cmd[motor_idx].dq = 0.0; low_command.motor_cmd[motor_idx].kp = policy_kp; low_command.motor_cmd[motor_idx].kd = policy_kd; } } void SendDampedEmergencyStop() { low_command.mode_pr = mode_; low_command.mode_machine = mode_machine; // Use default damping value for emergency stop since kds may not be available const double default_emergency_kd = 10.0; // Higher damping for faster stopping for (int i = 0; i < G1_NUM_MOTOR; ++i) { std::string dof_name = complete_dof_order[i]; low_command.motor_cmd[i].mode = 1; // Keep enabled low_command.motor_cmd[i].q = motor[i].q; // Current position low_command.motor_cmd[i].dq = 0.0; // Target zero velocity low_command.motor_cmd[i].kp = 0.0; // No position control low_command.motor_cmd[i].kd = default_emergency_kd; // Use default damping low_command.motor_cmd[i].tau = 0.0; } } void SendFinalEmergencyStop() { low_command.mode_pr = mode_; low_command.mode_machine = mode_machine; for (int i = 0; i < G1_NUM_MOTOR; ++i) { low_command.motor_cmd[i].mode = 0; // Disable low_command.motor_cmd[i].q = 0.0; low_command.motor_cmd[i].dq = 0.0; low_command.motor_cmd[i].kp = 0.0; low_command.motor_cmd[i].kd = 0.0; low_command.motor_cmd[i].tau = 
0.0; } } void LowStateHandler(unitree_hg::msg::LowState::SharedPtr message) { mode_machine = (int)message->mode_machine; imu = message->imu_state; for (int i = 0; i < G1_NUM_MOTOR; i++) { motor[i] = message->motor_state[i]; } // Check joint limits for all joints bool limits_exceeded = false; std::string exceeded_msg; // Trigger emergency stop if any limits are exceeded if (limits_exceeded) { RCLCPP_ERROR(this->get_logger(), "%s", exceeded_msg.c_str()); RCLCPP_ERROR(this->get_logger(), "Joint limits exceeded! Triggering emergency stop."); // current_state_ = RobotState::EMERGENCY_STOP; // should_shutdown_ = true; // publishRobotState(); } remote_controller.set(message->wireless_remote); } void PolicyActionHandler( const std_msgs::msg::Float32MultiArray::SharedPtr message) { // RCLCPP_INFO(this->get_logger(), "PolicyActionHandler called!"); policy_action_data = message->data; // Check if message size matches expected size if (policy_action_data.size() != policy_dof_order.size()) { RCLCPP_ERROR(this->get_logger(), "Policy action data size mismatch: got %zu, expected %zu", policy_action_data.size(), policy_dof_order.size()); current_state_ = RobotState::EMERGENCY_STOP; should_shutdown_ = true; publishRobotState(); return; } // set target dof pos for (size_t i = 0; i < policy_dof_order.size(); i++) { const auto &dof_name = policy_dof_order[i]; double calculated_pos = policy_action_data[i]; // Check if the target position is within joint limits (with scaling) if (joint_position_limits.find(dof_name) != joint_position_limits.end()) { // Calculate the middle point of the range double mid_pos = (joint_position_limits[dof_name].first + joint_position_limits[dof_name].second) / 2.0; // Calculate the half-range and scale it double half_range = (joint_position_limits[dof_name].second - joint_position_limits[dof_name].first) / 2.0; double scaled_half_range = half_range * position_limit_scale; // Calculate scaled min and max by expanding from midpoint double min_pos = mid_pos - 
scaled_half_range; double max_pos = mid_pos + scaled_half_range; if (calculated_pos < min_pos || calculated_pos > max_pos) { // RCLCPP_WARN(this->get_logger(), // "Target position would exceed limit for joint %s: %f (scaled limits: %f, %f)", // dof_name.c_str(), calculated_pos, min_pos, max_pos); // Clamp the position to within limits calculated_pos = std::clamp(calculated_pos, min_pos, max_pos); } } // Set the target position (clamped to safe values if needed) target_dof_pos[dof_name] = calculated_pos; } } void KpsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) { policy_kps_data = message->data; kps_received_ = true; // Check if message size matches expected size if (policy_kps_data.size() != policy_dof_order.size()) { RCLCPP_ERROR(this->get_logger(), "Policy kps data size mismatch: got %zu, expected %zu", policy_kps_data.size(), policy_dof_order.size()); current_state_ = RobotState::EMERGENCY_STOP; should_shutdown_ = true; publishRobotState(); return; } // Update kps map with policy data for (size_t i = 0; i < policy_dof_order.size(); i++) { const auto &dof_name = policy_dof_order[i]; kps[dof_name] = policy_kps_data[i]; } RCLCPP_INFO(this->get_logger(), "Received kps parameters from policy node (size: %zu)", policy_kps_data.size()); } void KdsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) { policy_kds_data = message->data; kds_received_ = true; // Check if message size matches expected size if (policy_kds_data.size() != policy_dof_order.size()) { RCLCPP_ERROR(this->get_logger(), "Policy kds data size mismatch: got %zu, expected %zu", policy_kds_data.size(), policy_dof_order.size()); current_state_ = RobotState::EMERGENCY_STOP; should_shutdown_ = true; publishRobotState(); return; } // Update kds map with policy data for (size_t i = 0; i < policy_dof_order.size(); i++) { const auto &dof_name = policy_dof_order[i]; kds[dof_name] = policy_kds_data[i]; } RCLCPP_INFO(this->get_logger(), "Received kds parameters from policy node 
(size: %zu)", policy_kds_data.size()); } double clamp(double value, double low, double high) { if (value < low) return low; if (value > high) return high; return value; } std::string robotStateToString(RobotState state) { switch (state) { case RobotState::ZERO_TORQUE: return "ZERO_TORQUE"; case RobotState::MOVE_TO_DEFAULT: return "MOVE_TO_DEFAULT"; case RobotState::EMERGENCY_STOP: return "EMERGENCY_STOP"; case RobotState::POLICY: return "POLICY"; default: return "UNKNOWN"; } } void publishRobotState() { std_msgs::msg::String state_msg; state_msg.data = robotStateToString(current_state_); robot_state_publisher_->publish(state_msg); } rclcpp::TimerBase::SharedPtr timer_; // ROS2 timer rclcpp::Publisher::SharedPtr lowcmd_publisher_; // ROS2 Publisher rclcpp::Subscription::SharedPtr lowstate_subscriber_; // ROS2 Subscriber rclcpp::Subscription::SharedPtr policy_action_subscriber_; rclcpp::Subscription::SharedPtr kps_subscriber_; rclcpp::Subscription::SharedPtr kds_subscriber_; rclcpp::Publisher::SharedPtr robot_state_publisher_; unitree_hg::msg::LowCmd low_command; // Unitree hg lowcmd message unitree_hg::msg::IMUState imu; // Unitree hg IMU message unitree_hg::msg::MotorState motor[G1_NUM_MOTOR]; // Unitree hg motor state message double control_freq_; double control_dt_; int timer_dt; double time_; // Running time count double duration_; PRorAB mode_ = PRorAB::PR; int mode_machine; RemoteController wireless_remote_; }; // End of humanoid_controller class int main(int argc, char **argv) { rclcpp::init(argc, argv); // Initialize rclcpp auto node = std::make_shared(); // Create a ROS2 node rclcpp::spin(node); // Run ROS2 node rclcpp::shutdown(); // Exit return 0; } ================================================ FILE: deployment/unitree_g1_ros2_29dof/start_container.sh ================================================ #!/bin/bash docker kill holomotion_orin_deploy docker rm holomotion_orin_deploy echo "Old holomotion_orin_deploy container removed !" 
# Prompt until the user provides a non-empty repository path. `read -r` keeps
# backslashes literal. On EOF (closed stdin), abort instead of re-prompting
# forever.
holomotion_repo_path=""
while [[ -z "$holomotion_repo_path" ]]; do
  if ! read -r -p "Please enter the holomotion local repository path: " holomotion_repo_path \
    && [[ -z "$holomotion_repo_path" ]]; then
    echo "Error: no input received." >&2
    exit 1
  fi
  if [[ -z "$holomotion_repo_path" ]]; then
    echo "Input cannot be empty."
  fi
done

# `read` performs no tilde expansion, so expand a leading "~" manually.
holomotion_repo_path="${holomotion_repo_path/#\~/$HOME}"

# Validate the directory exists before running Docker
if [[ ! -d "$holomotion_repo_path" ]]; then
  echo "Error: Directory '$holomotion_repo_path' does not exist." >&2
  exit 1
fi

# `docker -v` treats a source without a leading "/" as a named volume, so
# resolve the directory to an absolute path first.
if ! repo_abs_path="$(CDPATH='' cd -- "$holomotion_repo_path" && pwd -P)"; then
  echo "Error: Cannot access directory '$holomotion_repo_path'." >&2
  exit 1
fi
holomotion_repo_path="$repo_abs_path"

# ":" separates the fields of a -v spec and cannot appear in the host path.
if [[ "$holomotion_repo_path" == *:* ]]; then
  echo "Error: Directory path '$holomotion_repo_path' must not contain ':'." >&2
  exit 1
fi

echo "Mounting path: $holomotion_repo_path"

# The docker kill/rm at the top of this script run without sudo and fail when
# docker requires root. Remove any leftover container here so `--name` below
# does not conflict.
sudo docker rm -f holomotion_orin_deploy >/dev/null 2>&1 || true

sudo docker run -it \
  --name holomotion_orin_deploy \
  --runtime nvidia \
  --gpus all \
  --privileged \
  --network host \
  -e "ACCEPT_EULA=Y" \
  -v "$holomotion_repo_path:/home/unitree/holomotion" \
  -v "/usr/local/cuda-11.4/targets/aarch64-linux/lib:/cuda_base:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn.so.8.6.0:/host_gpu/libcudnn.so.8.6.0:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.6.0:/host_gpu/libcudnn_ops_infer.so.8.6.0:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.6.0:/host_gpu/libcudnn_cnn_infer.so.8.6.0:ro" \
  horizonrobotics/holomotion:orin_foxy_jp5.1_humble_deploy_zmq_20260319 \
  bash -c "ln -sf /host_gpu/libcudnn.so.8.6.0 /host_gpu/libcudnn.so.8 && \
    ln -sf /host_gpu/libcudnn_ops_infer.so.8.6.0 /host_gpu/libcudnn_ops_infer.so.8 && \
    ln -sf /host_gpu/libcudnn_cnn_infer.so.8.6.0 /host_gpu/libcudnn_cnn_infer.so.8 && \
    source /root/miniconda3/bin/activate && conda activate holomotion_deploy && exec bash"
================================================ FILE: docs/environment_setup.md ================================================ # Environment Setup ## Step 1: Setup Conda This project uses conda to manage Python environments. We recommend using [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install#linux-installer). **For users in China:** Configure the conda mirror following [TUNA](https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/) for faster downloads.
## Step 2: Setup Third-party Dependencies ### 2.1 Download SMPL/SMPLX Models We use SMPL/SMPLX models to retarget mocap data into robot motion data. Register your account and download the models from: - [SMPL](https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.1.0.zip) - [SMPLX](https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip) Place both zip files (`SMPL_python_v.1.1.0.zip` and `models_smplx_v1_1.zip`) in the `thirdparties/` folder, then extract: ```shell mkdir thirdparties/smpl_models unzip thirdparties/SMPL_python_v.1.1.0.zip -d thirdparties/smpl_models/ unzip thirdparties/models_smplx_v1_1.zip -d thirdparties/smpl_models/ ``` The resulting file structure for smpl models would be: ```shell thirdparties/ ├── smpl_models ├── models └── SMPL_python_v.1.1.0 ``` ### 2.2 Pull Submodules After cloning this repository, run the following command to get all submodule dependencies: ```shell git submodule update --init --recursive ``` ### 2.3 Create Asset Symlinks This project uses symbolic links to connect robot and SMPL assets from submodules to the main `assets` directory. Symlinks are created automatically when you clone the repository. 
### 2.4 Verify Third-party File Structure After completing the above steps, your file structure should look like this: ```shell thirdparties/ ├── HoloMotion_assets ├── GMR ├── smplx ├── joints2smpl ├── omomo_release ├── smpl_models ├── SMPLSim ├── unitree_ros └── unitree_ros2 ``` ## Step 3: Create the Conda Environment Create the two conda environments, `holomotion_train` and `holomotion_deploy`: ```shell conda env create -f environments/environment_train_isaaclab_cu118.yaml # for newer GPUs like RTX 5090, use environment_train_isaaclab_cu128.yaml conda env create -f environments/environment_deploy.yaml ``` Install smplx and GMR into the `holomotion_train` environment: ```shell cd thirdparties conda activate holomotion_train pip install -e ./smplx # use --no-deps to avoid pulling GMR's dependencies pip install -e ./GMR --no-deps ``` ## Step 4: Configure the Environment Variables HoloMotion uses `train.env` and `deploy.env` files to export environment variables in the shell entry scripts. Make sure the `Train_CONDA_PREFIX` and `Deploy_CONDA_PREFIX` variables in `train.env` and `deploy.env` are set correctly. You can source these files manually and check the output in the shell, for example: ```shell source train.env ``` These `.env` files are sourced by the shell scripts (in `holomotion/scripts`) to locate and use your conda environments. ================================================ FILE: docs/evaluate_motion_tracking.md ================================================ ## Evaluate the Motion Tracking Model After training for a while and saving model checkpoints, run the evaluation pipeline to assess your model's performance both visually and quantitatively. The evaluation pipeline also exports the model for later deployment.
**Overall Workflow:** ```mermaid flowchart LR A[Trained Checkpoints] B[HDF5 Database] C[Evaluation Config] D[Evaluation Entry] E[Offline Evaluation] F[Calculate Metrics] G[MuJoCo Visualization] A --> D B --> D C --> D D --> E E --> F E --> G classDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5 classDef normal fill:#c9d9f5, rx:10, ry:10 class A,B dashed class C,D,E,F,G normal ``` ### 1 Offline Evaluation ```bash bash ./holomotion/scripts/evaluation/eval_motion_tracking.sh ``` Before running, set `checkpoint_path` (e.g., `logs/HoloMotion/your_checkpoint_dir/model_1000.pt`) and `eval_h5_dataset_path` in the evaluation script. ### 2 Calculate Metrics Process the `.npz` files generated in the previous step into a final quantitative JSON metrics report: ```bash bash ./holomotion/scripts/evaluation/calc_offline_eval_metrics.sh ``` - `npz_dir`: Path to the folder containing `.npz` result files. - `dataset_suffix`: Name of the evaluation dataset, used to tell reports for different datasets apart. ### 3 MuJoCo Visualization Render videos from the `.npz` result files to check motion tracking quality, by setting `motion_npz_root` to the evaluation npz folder. To visualize the recorded robot states rather than the reference motion, also set `+key_prefix="robot_"`. ```bash bash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh ``` - `motion_npz_root`: Path to the folder containing `.npz` result files. - Output: one `video_rendering/{motion_name}.mp4` file per `.npz` result file. ### 4 Export Trained Model to ONNX To deploy the policy on real-world robots, the PyTorch module is converted to ONNX, a format supported by most inference frameworks.
After running the evaluation script, the `.onnx` file will be generated and saved to the checkpoint directory: ``` logs/HoloMotion/your_checkpoint_dir/ ├── config.yaml ├── exported │ └── model_10000.onnx └── model_10000.pt ``` ================================================ FILE: docs/holomotion_motion_file_spec.md ================================================ ## HoloMotion NPZ Format — Keys and Values This document lists the exact keys saved in a HoloMotion NPZ and their value types/shapes. - Prefix policy - `ref_*`: reference motion (source-of-truth produced by preprocessing) - `ft_ref_*`: filtered reference motion (post-filtering; never overwrites `ref_*`) - `robot_*`: robot states (only present in offline evaluation exports) - Legacy (no prefix): kept only for backward compatibility; new files prefer `ref_*` - metadata - type: JSON string - fields: - motion_key: str - raw_motion_key: str - motion_fps: float - num_frames: int - wallclock_len: float (seconds, approx (num_frames - 1) / motion_fps) - num_dofs: int - num_bodies: int - clip_length: int (original clip length in frames) - valid_prefix_len: int (contiguous valid frames from the start) - ref_dof_pos - dtype: float32 - shape: [T, num_dofs] (URDF joint order; reference motion) - ref_dof_vel - dtype: float32 - shape: [T, num_dofs] (URDF joint order; reference motion) - ref_global_translation - dtype: float32 - shape: [T, num_bodies, 3] (meters, world frame; reference motion) - ref_global_rotation_quat - dtype: float32 - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; reference motion) - ref_global_velocity - dtype: float32 - shape: [T, num_bodies, 3] (m/s, world frame; reference motion) - ref_global_angular_velocity - dtype: float32 - shape: [T, num_bodies, 3] (rad/s, world frame; reference motion) - ft_ref_dof_pos - dtype: float32 - shape: [T, num_dofs] (filtered reference motion) - ft_ref_dof_vel - dtype: float32 - shape: [T, num_dofs] (derived from filtered positions) - ft_ref_global_translation -
dtype: float32 - shape: [T, num_bodies, 3] (filtered reference motion) - ft_ref_global_rotation_quat - dtype: float32 - shape: [T, num_bodies, 4] (filtered, normalized XYZW) - ft_ref_global_velocity - dtype: float32 - shape: [T, num_bodies, 3] (derived from filtered positions) - ft_ref_global_angular_velocity - dtype: float32 - shape: [T, num_bodies, 3] (derived from filtered quaternions) - robot_dof_pos - dtype: float32 - shape: [T, num_dofs] (URDF joint order; robot) - robot_dof_vel - dtype: float32 - shape: [T, num_dofs] (URDF joint order; robot) - robot_global_translation - dtype: float32 - shape: [T, num_bodies, 3] (meters, world frame; robot) - robot_global_rotation_quat - dtype: float32 - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; robot) - robot_global_velocity - dtype: float32 - shape: [T, num_bodies, 3] (m/s, world frame; robot) - robot_global_angular_velocity - dtype: float32 - shape: [T, num_bodies, 3] (rad/s, world frame; robot) - dof_pos (deprecated legacy key) - dtype: float32 - shape: [T, num_dofs] (URDF joint order; ref or robot) - deprecated - dof_vels (deprecated legacy key) - dtype: float32 - shape: [T, num_dofs] (URDF joint order; ref or robot) - deprecated - global_translation (deprecated legacy key) - dtype: float32 - shape: [T, num_bodies, 3] (meters, world frame; ref or robot) - deprecated - global_rotation_quat (deprecated legacy key) - dtype: float32 - shape: [T, num_bodies, 4] (quaternion XYZW, world frame; ref or robot) - deprecated - global_velocity (deprecated legacy key) - dtype: float32 - shape: [T, num_bodies, 3] (m/s, world frame; ref or robot) - deprecated - global_angular_velocity (deprecated legacy key) - dtype: float32 - shape: [T, num_bodies, 3] (rad/s, world frame; ref or robot) - deprecated Notes: - T == num_frames from metadata. - All arrays are float32. 
================================================ FILE: docs/motion_retargeting.md ================================================ # Motion Retargeting Transform human motion data into robot-compatible joint trajectories for subsequent training. We support GMR for retargeting (https://github.com/YanjieZe/GMR) ## Prerequisites Before running the motion retargeting pipeline, ensure you have: ### 1. Environment Setup Please make sure smplx and GMR are properly installed according to the [[environment setup doc](./environment_setup.md)]. ### 2. Data Preparation Place your AMASS motion data in `assets/test_data/motion_retargeting/{dataset_name}` (relative to the repository root), or modify `amass_dir` in `holomotion/scripts/motion_retargeting/*.sh`. Please double-check that all related paths in the `.sh` and `.yaml` files are correct. ### 3. Model Preparation Put the SMPL-X models under the following path: thirdparties/ └── GMR/ ├── assets/ │ └── body_models/ │ └── smplx/ │ ├── SMPLX_FEMALE.npz │ ├── SMPLX_FEMALE.pkl │ ├── SMPLX_MALE.npz │ ├── SMPLX_MALE.pkl │ ├── SMPLX_NEUTRAL.npz │ └── SMPLX_NEUTRAL.pkl ### 4. Path Verification Check data paths in the configuration scripts: - `holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh` Before using GMR, it is recommended to run `bash ./holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh` first, which can help reduce singular solutions to some extent. ## Quick Start ### 1. Motion Retargeting ```bash bash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh ``` > Reminder: if you encounter CUDA errors, change `device = "cuda:0"` to `device = "cpu"` in `smplx_to_robot_dataset.py`. After GMR retargeting, convert the dataset into the HoloMotion-compatible npz format by running: ```bash bash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh ``` ### 2.
Motion Visualization Generate video outputs to validate the retargeting quality of the HoloMotion npz files: ```bash bash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh ``` **Output**: `video_rendering/{motion_name}.mp4` files in the retargeted data directories ### 3. Pack to HDF5 for Training After retargeting, we need to pack the npz files into a compact HDF5 database: ```bash bash ./holomotion/scripts/motion_retargeting/pack_hdf5_dataset.sh ``` ================================================ FILE: docs/mujoco_sim2sim.md ================================================ # Sim2Sim Verification After generating the ONNX file in the evaluation stage, you can verify your Isaac-trained policy in another simulator, such as MuJoCo, before deploying it to the real robot. The entry script is `holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh`. Set the following variables before running it: - `robot_xml_path`: The scene MJCF `.xml` file for the robot - `ONNX_PATH`: The exported ONNX model file - `motion_npz_path`: The npz file containing the reference motion ================================================ FILE: docs/realworld_deployment.md ================================================ # HoloMotion Deployment Guide This guide describes how to set up the deployment environment and run the trained policy on a physical Unitree G1 robot.
## Robot Configuration for 29 DOF The 29 DOF configuration includes: - 12 leg joints (6 per leg) - 3 waist joints (yaw, roll, pitch) - 14 arm joints (7 per arm) --- ## Deployment Options This guide provides two deployment methods: | Deployment Method | Target Platform | | ----------------------------------------------- | ----------------- | | [Laptop Deployment](#laptop-deployment) | Laptop/Desktop PC | | [PC2 Docker Deployment](#pc2-docker-deployment) | G1 Robot's PC2 | Choose the appropriate method based on your setup: - **For laptop/desktop deployment**: Follow the [Laptop Deployment](#laptop-deployment) section - **For PC2 on robot hardware**: Follow the [PC2 Docker Deployment](#pc2-docker-deployment) section ### ⚠️ Important Safety Notice > **For safety reasons, it is strongly recommended to remove the dexterous hands before running the policy.** --- ## Laptop Deployment ### Quick Environment Setup #### Prerequisites Ensure the following are installed before proceeding: - Anaconda or Miniconda - ROS 2 Humble installed at `/opt/ros/humble` - MCAP for efficient ROS 2 data recording - Unitree ROS 2 SDK installed at `~/unitree_ros2/` #### One-Click Deployment ```bash cd /deployment chmod +x deploy_environment.sh ./deploy_environment.sh ``` This script will: - Create a new conda environment (with CUDA support if available) - Install Python packages from `environments/requirements_deploy.txt` - Install Unitree SDK Python bindings - Build the ROS 2 workspace under `unitree_g1_ros2_29dof/` --- ### Deploy on Physical G1 Robot (Laptop) ### Setup Overview The deployment process consists of two types of steps: | **One-Time Setup** (per computer) | **Every Run** (each time you use the robot) | | --------------------------------- | ------------------------------------------- | | Step 1: Network Configuration | Step 3: Power On & Initialize Robot | | Step 2: Launch Script Setup | Step 4: Launch Policy Controller | > **Note**: Once you complete Steps 1-2, you only need to do
Steps 3-4 for each robot session! ### Step 1: Connect and Configure Network #### Prerequisites for Network Setup: 1. **Power on the robot** and wait for it to fully boot 2. **Use an Ethernet cable** to connect your PC to the robot's LAN port 3. **Ensure both devices are powered on** during configuration #### Network Configuration: Configure your PC's network interface with the following static IP settings: - **Static IP**: `192.168.123.222` - **Netmask**: `255.255.255.0` - **Gateway**: (leave empty) #### Automatic Configuration Script: You can use the following script to configure it automatically (use command `nmcli con show` to check your actual connection name):
Click to view set_static_ip.sh ```bash #!/bin/bash # Replace with your actual connection name (use `nmcli con show` to check) CON_NAME="Wired connection 1" IP_ADDRESS="192.168.123.222" NETMASK="24" GATEWAY="" nmcli con modify "$CON_NAME" ipv4.addresses "$IP_ADDRESS/$NETMASK" nmcli con modify "$CON_NAME" ipv4.method manual if [ -n "$GATEWAY" ]; then nmcli con modify "$CON_NAME" ipv4.gateway "$GATEWAY" fi nmcli con modify "$CON_NAME" ipv4.dns "" nmcli con down "$CON_NAME" && nmcli con up "$CON_NAME" ```
--- ### Step 2: Prepare Launch Script #### Configure Network Interface: 1. **Check your network interface name** (while connected to the robot): ```bash ifconfig ``` Look for the interface connected to the robot (e.g., `enxf8e43ba00afd`, `eth0`, `enp0s31f6`) 2. **Update the launch configuration**: ```bash # Edit the launch file nano /deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py ``` Find and update the `network_interface` parameter with your actual interface name. --- ### Step 3: Power On and Initialize the Robot > **Do this every time** you want to run the robot. #### Robot Initialization Sequence for 29 DOF: 1. **Power on the robot** - Start the robot in the **hanging position** 2. **Wait for zero torque mode** - The robot will automatically enter zero torque mode (joints feel loose) 3. **Connect your computer** - Use the same Ethernet cable to connect to the robot's LAN port 4. **Enter debugging mode** - On the remote controller, press `L2 + R2` simultaneously. Note: the new deployment automatically enters this mode on startup, so manual entry is usually not required. --- ### Step 4: Launch the Policy Controller #### Preflight Checklist Before running, ensure the following are ready. - Model folders configured in `g1_29dof_holomotion.yaml` exist - `motion_tracking_model_folder`: under `src/models/` - `velocity_tracking_model_folder`: under `src/models/` - Motion data directory exists and contains .npz files (retargeted results) - `motion_clip_dir`: under `src/motion_data/` - Config file path used by launch is correct #### Motion Reference Source Motion tracking supports two reference sources: - **Offline motion mode**: the robot executes the selected `.npz` motion clip from `motion_clip_dir`. - **Online teleoperation mode**: the robot uses live `latest_obs` data streamed from the teleoperation workstation / VR pipeline. 
See [Holomotion teleop setup](../deployment/holomotion_teleop/holomotion_teleop_setup.md) for Pico / XRoboToolkit, ZMQ publishing, and launch order on the workstation. The mode is selected by YAML settings in `g1_29dof_holomotion.yaml`: - **Offline motion mode** - `vr.enable_teleop_reference: false` - `vr.require_vr_data_for_motion: false` - Result: even if ZMQ data is still arriving, the robot ignores it and motion tracking uses offline `.npz` clips only. - **Online teleoperation mode** - `vr.enable_teleop_reference: true` - `vr.require_vr_data_for_motion: true` - `vr.latest_obs_zmq_uri: "tcp://:6001"` - Result: motion mode waits for live teleoperation data before entering and uses the incoming `latest_obs` stream as the motion reference. If you want teleoperation to be available but not mandatory, you can also use: - `vr.enable_teleop_reference: true` - `vr.require_vr_data_for_motion: false` In that configuration, motion mode can still start without waiting for VR readiness, but live teleoperation data may be used when available at mode entry. #### One-click start ```bash cd /deployment/unitree_g1_ros2_29dof bash launch_holomotion_29dof.sh ``` > **Success indicator**: On startup, the robot joints should remain in zero torque state and feel free to move. #### Motion Control Modes The 29 DOF robot operates in two main modes: | Mode | How to Enter | Controls | Switch | | ------- | --------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------ | | Velocity tracking | 1) Press Start to stand up, then press A
2) From motion tracking: press Y | Left stick: move (vx, vy)
Right stick: rotate (yaw)
D-Pad: select motion clip (Left=first, Right=last, Up=prev, Down=next) | B: enter motion tracking | | Motion tracking | Press B | Executes selected motion clip or online teleoperation automatically | Y: back to velocity tracking | #### Control Flow Here is the robot control flowchart for 29 DOF: ```mermaid flowchart TD subgraph prepPhase ["Setup Phase"] direction TB A[Set Robot
to Hanging Position] --> B[Power On and
Zero Torque Mode] B --> C[L2+R2: Enter Debug Mode] C --> D[Launch Program] end %% Main flow prepPhase --> E[Start: Stand Up] E --> F[Lower Robot to Ground] F --> G[A: Enter Velocity tracking Mode] %% Velocity tracking mode controls G --> H[Velocity tracking Mode] H --> H1[Left Stick: Move] H --> H2[Right Stick: Rotate] H --> H3[D-Pad: Select Motion Clip] H --> H4[B: Enter Motion tracking Mode] %% Motion tracking mode H4 --> I[Motion tracking Mode] I --> I1[Execute Motion Clip or Teleoperation] I --> I2[Y: Back to Velocity tracking] %% Mode switching I2 --> H %% Emergency stop D --> N[Select: Emergency Stop] E --> N F --> N G --> N H --> N I --> N N --> O[Close Program] classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef control fill:#f3e5f5,stroke:#4a148c,stroke-width:2px classDef velocityTracking fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px classDef motionTracking fill:#fff3e0,stroke:#e65100,stroke-width:2px classDef emergency fill:#ffebee,stroke:#b71c1c,stroke-width:2px,stroke-dasharray: 5 5 classDef preparationFrame fill:#f9f9f9,stroke:#666,stroke-width:2px,stroke-dasharray: 5 5 class A,O startEnd class B,C,D,E,F,G control class H,H1,H2,H3,H4 velocityTracking class I,I1,I2 motionTracking class N emergency class prepPhase preparationFrame ``` #### Configuration Files (used by Step 4) **System Configuration** - File: `HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml` - Key parameters: - `motion_tracking_model_folder`: motion tracking model folder under `models/` - `velocity_tracking_model_folder`: velocity tracking model folder under `models/` - `motion_clip_dir`: motion clip data folder under `src/` - `vr.enable_teleop_reference`: enable or disable live teleoperation reference - `vr.require_vr_data_for_motion`: whether motion mode must wait for live teleoperation data - `vr.latest_obs_zmq_uri`: teleoperation ZMQ endpoint used in online mode **Pre-trained Models** We provide a pre-trained velocity tracking model 
that you can download and use: - **Motion Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model) - **Velocity Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model) To use the velocity tracking model: 1. Download the `holomotion_v1.2_velocity_tracking_model` folder from the Hugging Face repository 2. Place the downloaded folder under `models/` (e.g., `models/holomotion_v1.2_velocity_tracking_model/`) 3. Update `velocity_tracking_model_folder` in `g1_29dof_holomotion.yaml` to point to this folder **Adding New Motion Tracking Models** 1. Create a new folder under `models/` following the example model folder structure below (e.g., `models/your_model_dir_name/`) 2. Update `motion_tracking_model_folder` in `g1_29dof_holomotion.yaml` 3. Ensure the motion clip data files are in the `motion_clip_dir` Example model folder structure (motion model): ```bash HoloMotion/deployment/unitree_g1_ros2_29dof/src/models/your_model_dir_name ├── config.yaml ├── exported └── your_model_name.onnx ``` ---
--- ## PC2 Docker Deployment ### Setup Overview The deployment process consists of two types of steps: | **One-Time Setup** (per PC2) | **Every Run** (each time you use the robot) | | ----------------------------- | -------------------------------------------- | | Step 1: Configure Docker | Step 4: Start Docker Container | | Step 2: Load Docker Image | Step 5: Power On & Initialize Robot | | Step 3: Configure Launch File Network Interface | Step 6: Launch Policy Controller | > **Note**: Once you complete Steps 1-3, you only need to do Steps 4-6 for each robot session! ### System Requirements - **Platform**: NVIDIA Jetson Orin - **JetPack**: 5.1 - **Ubuntu**: 20.04 - **ROS 2**: Foxy - **Docker**: Installed with NVIDIA Container Runtime support ### Step 1: Configure Docker for NVIDIA Runtime Modify `/etc/docker/daemon.json`: ```json { "runtimes": { "nvidia": { "path": "nvidia-container-runtime", "runtimeArgs": [] } }, "default-runtime": "nvidia" } ``` Restart Docker and verify: ```bash sudo systemctl restart docker sudo docker info | grep -i runtime ``` ### Step 2: Load Docker Image Pull the image from dockerhub with: ```bash docker pull horizonrobotics/holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319 ``` Or if you have the image locally, tag it appropriately: ```bash docker tag holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319 ``` ### Step 3: Configure Launch File Network Interface: 1. **Check your network interface name on the robot**: ```bash ifconfig ``` Look for the interface with IP `192.168.123.164`. The interface is typically `eth0`. 2. 
**Update the launch configuration** if your interface is not `eth0`: ```bash nano /deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py ``` Find the `network_interface` assignment (around line 103; the exact line may shift between versions) and update it: ```python network_interface = "eth0" # Change to your actual interface name ``` ### Step 4: Start Docker Container > **Important**: Before running Docker commands, ensure your user is added to the docker group. If you encounter permission errors, add your user to the docker group: ```bash sudo usermod -aG docker $USER ``` After adding your user to the docker group, you need to log out and log back in (or restart your session) for the changes to take effect. Verify with: ```bash groups ``` You should see `docker` in the list of groups. > **Important**: You need to run this step **every time** you want to use the robot. The script will automatically remove any existing container and start a fresh one. ```bash cd /deployment/unitree_g1_ros2_29dof bash start_container.sh ``` **When prompted, enter the holomotion repository path:** - The script will ask: `Please enter the holomotion local repository path:` - Enter the full path to your holomotion repository, for example: - `/home/unitree/HoloMotion` (if the repository is at this location) - Or the actual path where your holomotion repository is located ### Step 5: Power On and Initialize Robot > **Do this every time** before launching the policy controller. 1. Put the robot in the hanging position 2. Wait for zero torque mode 3. Press `L2 + R2` on the remote controller to enter debug mode ### Step 6: Launch Policy Controller > **Do this every time** you want to run the robot (after Steps 4 and 5). **Preflight Checklist** Before running, ensure the following are ready.
- Model folders configured in `g1_29dof_holomotion.yaml` exist - `motion_tracking_model_folder`: under `src/models/` - `velocity_tracking_model_folder`: under `src/models/` - Motion data directory exists and contains .npz files (retargeted results) - `motion_clip_dir`: under `src/motion_data/` - Config file path used by launch is correct - Motion reference source is configured as intended (see [Motion Reference Source](#motion-reference-source)) **Pre-trained Models** You can download and use the pre-trained velocity tracking model. Refer to the [Pre-trained Models](#configuration-files-used-by-step-4) section above for general instructions. > **Note**: The model folder should be placed in your local repository before starting the Docker container, as the repository is mounted into the container. **One-click start in the docker** ```bash cd /home/unitree/holomotion/deployment/unitree_g1_ros2_29dof bash launch_holomotion_29dof_docker.sh ``` > **Note**: The control flow is the same as described in the [Control Flow](#control-flow) section above. --- ### Safety Notice This deployment is intended for demonstration only. It is not a production-grade control system. Do not interfere with the robot during operation. If unexpected behavior occurs, exit control immediately via the controller or keyboard to ensure safety. To stop the control process: - Press `Select` on the remote controller, or - Use `Ctrl+C` in the terminal (inside Docker container) ================================================ FILE: docs/smpl_data_curation.md ================================================ # Dataset Preparation Guide This guide describes the workflow and setup for preparing datasets to train the motion tracking model. We use **AMASS-compatible SMPL-format motion capture data** as the training input. --- ## Overview The dataset preparation pipeline has the following steps: 1. 
**Download datasets** - To train with diverse and rich motion data, you first need to collect raw motion capture datasets from various sources. - Then place all downloaded datasets under the `data/raw_datasets/` directory in their original structure. 2. **Convert datasets to AMASS format** - To ensure that all motion data is compatible with the AMASS-style .npz format used by the training pipeline, you need to convert the raw datasets. - Then run the conversion script to generate .npz files under data/amass_compatible_datasets/. 3. **Filter datasets** - To improve data quality by removing abnormal, noisy, or unwanted motion samples, you can optionally run the filtering step. - Then run the filtering script to generate filtered .yaml files under holomotion/config/data_curation/. 4. **Visualize Prepared Data** - Use the included visualization utility to preview and inspect the generated AMASS-compatible `.npz` motion files. - Quickly check for anomalies or errors before training. 5. **Generate Motion from Monocular Video** - You can also generate SMPL-format motion capture files **directly from monocular RGB videos** using GVHMR. - This allows you to create training data or test the model with real-world video footage. - The pipeline is described in the [Video2SMPL Instructions](#video2smpl-instructions) section below. ### Directory Structure After Full Setup ``` data/ ├── raw_datasets/ │ ├── humanact12/ │ ├── OMOMO/ │ ├── MotionX/ │ └── ZJU_Mocap/ ├── amass_compatible_datasets/ │ ├── amass/ │ │ ├── ACCAD/ │ │ ├── BioMotionLab_NTroje/ │ │ ├── ... │ ├── humanact12/ │ ├── OMOMO/ │ ├── MotionX/ │ └── ZJU_Mocap/ ├── dataset_labels/ │ ├── humanact12.jsonl │ ├── OMOMO.jsonl │ ├── MotionX.jsonl │ ├── ZJU_Mocap.jsonl │ ├── amass.jsonl ``` --- ## Step-by-Step Instructions ### 1. Download Datasets Download and extract the datasets into the `data/` folder as follows: - `data/amass_compatible_datasets/amass/` (required) - [AMASS dataset](https://amass.is.tue.mpg.de/download.php) — choose **SMPL-X G** format.
- `data/raw_datasets/humanact12/` (optional) - [HumanAct12](https://github.com/EricGuo5513/action-to-motion?tab=readme-ov-file) - `data/raw_datasets/OMOMO/` (optional) - [OMOMO dataset](https://github.com/lijiaman/omomo_release?tab=readme-ov-file) - `data/raw_datasets/MotionX/` (optional) - [MotionX dataset](https://github.com/IDEA-Research/Motion-X) - `data/raw_datasets/ZJU_Mocap/` (optional) - [EasyMocap](https://github.com/zju3dv/EasyMocap) --- ### 2. Convert Optional Datasets to AMASS Format (optional) Skip this step if you only use amass dataset. Step: 1. Initialize Submodules: Some datasets require external repositories or models for proper conversion. These are included as **submodules** under: ``` thirdparties/ ``` Initialize them with: ```bash git submodule update --init --recursive ``` If you need to modify or update these submodules, refer to their individual `README` files. 2. Download SMPL Model: - Download `SMPL_NEUTRAL.npz` from the [SMPL official website](https://smpl.is.tue.mpg.de/download.php). - Place the file into: ``` ./assets/smpl/ ``` 3. Modify `thirdparties/joints2smpl/src/customloss.py`: Before running the pipeline, make sure to modify the `body_fitting_loss_3d` function in `thirdparties/joints2smpl/src/customloss.py` to include the following change: ```python joint3d_loss = (joint_loss_weight ** 2) * joint3d_loss_part.sum(dim=-1) ``` 4. Modify `thirdparties/joints2smpl/src/smplify.py`: Next, ensure the following modification in the `__call__` function of `SMPLify3D` inside `thirdparties/joints2smpl/src/smplify.py`: ```python init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach() ``` 5. Start Conversion: Run the provided script to convert all available datasets to AMASS `.npz` files: ```bash bash holomotion/scripts/data_curation/convert_to_amass.sh ``` This script reads from `data/{dataset}/` and writes to `data/amass_compatible_datasets/{dataset}/`. --- ### 3. 
Filter Datasets (optional) Skip this step if you prefer to use all available data for training. #### Why filter? The raw datasets may contain motions that are irrelevant, undesirable, or of poor quality for training. This step helps improve the overall quality and consistency of your dataset. #### Filtering criteria The filtering process excludes samples based on the following rules: - **Upstairs/Downstairs motion:** Paths containing keywords like stairs, staircase, upstairs, downstairs, or motions with large upward/downward Z translation and velocity are excluded. - **Sitting motion:** Paths containing sitting/Sitting keywords or frames that match a reference sitting pose are excluded. - **Known abnormal datasets:** Samples from known abnormal datasets such as aist are excluded. - **Unrealistic velocity:** Motions where the mean velocity exceeds a threshold (default: 100.0) are excluded. #### How to run You can use the `-l` option to specify which datasets to filter (space-separated list). Run the filtering script to identify and exclude abnormal or unwanted samples: ```bash bash holomotion/scripts/data_curation/filter_smpl_data.sh -l "amass humanact12 OMOMO MotionX ZJU_Mocap" ``` The output `.yaml` files will be placed in `holomotion/config/data_curation/`. --- ## Notes - Paths are relative to the project root. - The AMASS dataset must be manually requested from their website. - Dataset conversion and filtering may take time depending on your hardware. --- This guide assumes that you only need the basic configuration to run the complete pipeline. For further customization, refer to the relevant scripts in the repository and optional steps in the documentation. ## Video2SMPL Instructions ### 1.
Environment setup Create a new conda environment named 'gvhmr' following the official installation guide [[GVHMR setup doc](../thirdparties/GVHMR/docs/INSTALL.md)] > Reminder: If you encounter missing-module errors for 'hmr4d' at runtime, install it separately in the `gvhmr` environment: ```bash pip install hmr4d ``` Rename the SMPL model files as follows: `basicmodel_{GENDER}_lbs_*.pkl` → `SMPL_{GENDER}.pkl` Place the SMPL and SMPL-X model files into the directory structure below: ``` thirdparties/GVHMR/inputs/checkpoints/ ├── body_models/smplx │ └── SMPLX_{GENDER}.npz └── body_models/smpl └── SMPL_{GENDER}.pkl ``` ### 2. Video to SMPL motion data Make sure all input videos have a frame rate of 30 FPS; otherwise the generated motion will be sped up or slowed down. ```bash bash ./holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh ``` > Reminder: Set the directories in the .sh file to absolute paths. ### 3. Visualize generated SMPL motion data Visualize the SMPL motion sequences generated by GVHMR for inspection and debugging. ```bash bash ./holomotion/scripts/data_curation/visualize_smpl_gvhmr.sh ``` ### 4. Convert SMPL data to SMPLX Motion data from GVHMR is in SMPL format. Use `./thirdparties/GMR/scripts/smpl_to_smplx.py` to convert it to SMPL-X format for retargeting. ================================================ FILE: docs/train_motion_tracking.md ================================================ ## Train the Motion Tracking Model After completing motion retargeting, you can train a motion tracking model with HoloMotion using the following process. **Overall Workflow:** ```mermaid flowchart LR A[Motion Retargeting] --> B[HDF5 Database] B --> C[Training Config] C --> D[Training Entry] D --> E[Distributed PPO Training] classDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5 classDef normal fill:#c9d9f5, rx:10, ry:10 class A dashed class B,C,D,E normal ``` ### 1.
Train the Motion Tracking Model The training entry point is `holomotion/src/training/train.py`, which uses the training config to start distributed training across multiple GPUs. #### 1.1 Prepare the Training Config Use the demo config at `holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking.yaml` as a template. Key configuration groups to modify (configs are located in the `holomotion/config/` directory): - **`/algo`**: Algorithm settings (PPO) and network configurations - **`/robot`**: Robot-specific config including DOF, body links, and control parameters - **`/env`**: Environment settings including motion sampling and curriculum learning - **`/env/observations`**: Observation dimensions, noise, and scaling for the policy - **`/env/rewards`**: Reward function definitions - **`/env/domain_randomization`**: Domain randomization settings (start with `NO_domain_rand`) - **`/env/terrain`**: Terrain configuration - **`/modules`**: Policy network module definitions ```yaml # @package _global_ defaults: - /training: train_base - /algo: ppo - /robot: unitree/G1/29dof/29dof_training_isaaclab - /env: motion_tracking - /env/terminations: termination_motion_tracking - /env/observations: motion_tracking/obs_motion_tracking_tf-moe - /env/rewards: motion_tracking/rew_motion_tracking - /env/domain_randomization: domain_rand_medium - /env/terrain: isaaclab_plane - /modules: motion_tracking/motion_tracking_tf-moe project_name: HoloMotion ``` #### 1.2 Train your Policy Review and modify the training script at `holomotion/scripts/training/train_motion_tracking.sh`. Ensure `config_name` matches your training config and that the motion database path points to your packed HDF5 database. Start training by running: ```shell bash holomotion/scripts/training/train_motion_tracking.sh # or bash holomotion/scripts/training/train_velocity_tracking.sh ``` Note that IsaacLab relies on an internet connection to pull assets from NVIDIA's cloud storage.
If training gets stuck at scene creation, it is very likely that the cloud-hosted assets are unreachable. Enabling your proxy and trying again usually solves the issue. ### Training Tips #### How to use less GPU memory? Training requires significant GPU memory. Reduce `num_envs` if your GPU has limited VRAM. This reduces both the rollout burden and the memory consumed by PPO training, at the risk of a significantly less stable policy optimization process. #### How to start multiple training sessions? If you would like to run multiple training sessions at the same time, explicitly add the `--main_process_port=port_number` option in the training entry bash script to avoid port conflicts in the accelerate backend. Note that `port_number` **cannot** be `0`. If you would like to run training on a specific GPU, modify the GPU id in the `export CUDA_VISIBLE_DEVICES="X"` statement. #### How to set the save/log intervals? You may want more or less frequent logging and checkpoint saving. You can alter these intervals by adding the following options: - `algo.config.save_interval=X`: The checkpoint will be saved every `X` learning iterations. - `algo.config.log_interval=Y`: Logging information will be displayed every `Y` learning iterations. #### Where is the checkpoint dumped? By default, model checkpoints are saved into a folder named `logs/HoloMotion`. You can change this path by explicitly setting `project_name=X`, which saves the checkpoints into the `logs/X` directory. #### How to resume training from a checkpoint?
To resume training from a pretrained checkpoint, you can find the checkpoint in the log directory, and then add the option like this: `checkpoint=logs/HoloMotion/20250728_214414-train_unitree_g1_21dof_teacher/model_X.pt` ================================================ FILE: environments/environment_deploy.yaml ================================================ name: holomotion_deploy channels: - pytorch - nvidia - conda-forge - defaults dependencies: # Python runtime - python=3.10 # PyTorch with CUDA support - pytorch-cuda=12.1 - cudnn>=9 - cudatoolkit>=11.7,<12 - pytorch==2.3.1 - torchvision==0.18.1 - torchaudio==2.3.1 # Scientific computing packages (via conda for better compatibility) - numpy==1.24.3 - scipy - matplotlib - pandas # System utilities and development tools - sshpass=1.06 - git - curl - wget - pyyaml - easydict - joblib # Basic Python package management - pip - setuptools - wheel # Install additional packages via pip - pip: - -r environments/requirements_deploy.txt ================================================ FILE: environments/environment_train_isaaclab_cu118.yaml ================================================ name: holomotion_train channels: - conda-forge dependencies: - python=3.11 - pip - mesalib - pip: - -r requirements_torch_cu118.txt - -r requirements_base.txt - -e ../thirdparties/SMPLSim - -e ../ ================================================ FILE: environments/environment_train_isaaclab_cu128.yaml ================================================ name: holomotion_train channels: - conda-forge dependencies: - python=3.11 - pip - mesalib - pip: - -r requirements_torch_cu128.txt - -r requirements_base.txt - -e ../thirdparties/SMPLSim - -e ../ ================================================ FILE: environments/requirements_base.txt ================================================ isaacsim[all,extscache]==5.0.0 isaaclab[isaacsim,all]==2.2.0 setuptools wheel numpy==1.26.0 smplx==0.1.28 hydra-core==1.3.2 easydict tqdm open3d lxml ray ipdb 
joblib scipy jupyter loguru tensorboard mujoco mink dm_control loop_rate_limiters qpsolvers[quadprog,proxqp] accelerate tabulate matplotlib pandas termcolor rich pytorch-tcn einops onnxruntime-gpu onnx pre-commit ruff pytest imageio>=2.9 imageio-ffmpeg opencv-python natsort psutil redis[hiredis] chumpy pyvirtualdisplay pynput xxhash h5py>=3.8 pygame tensordict==0.11.0 pytorch_kinematics onnxscript # human_body_prior dependency human-body-prior ================================================ FILE: environments/requirements_deploy.txt ================================================ # Machine Learning Runtime onnxruntime-gpu # SMPL/SMPLX support smplx==0.1.28 # Configuration management hydra-core==1.3.2 omegaconf # Progress and logging tqdm loguru termcolor rich # Data processing lmdb einops # Protobuf (specific version for compatibility) protobuf==3.20.3 onnx # Development tools ipdb mujoco pygame # Note: The following packages are installed via conda in environment_deploy.yaml: # - torch, torchvision, torchaudio (with CUDA support) # - numpy, scipy, matplotlib, pandas # - pyyaml, easydict, joblib # - system utilities (git, curl, wget, sshpass) ================================================ FILE: environments/requirements_torch_cu118.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://pypi.nvidia.com torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 ================================================ FILE: environments/requirements_torch_cu128.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu128 --extra-index-url https://pypi.nvidia.com torch==2.10.0+cu128 torchvision==0.25.0+cu128 torchaudio==2.10.0+cu128 ================================================ FILE: environments/requirements_torch_cu130.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 --extra-index-url
https://pypi.nvidia.com torch==2.10.0+cu130 torchvision==0.25.0+cu130 torchaudio==2.10.0+cu130 ================================================ FILE: holomotion/config/algo/ppo.yaml ================================================ # @package _global_ algo: _target_: holomotion.src.algo.ppo.PPO _recursive_: false config: # --- General Settings --- enable_online_eval: false num_learning_iterations: 10001 log_interval: 5 save_interval: 500 export_policy: true onnx_name_suffix: null use_kv_cache: true eval_interval: null load_optimizer: true headless: ${headless} # --- # --- Accelerate Settings --- mixed_precision: null # "fp16", "bf16", or null. Use "bf16" for A100/H100, "fp16" for older GPUs dynamo_backend: "inductor" # "inductor", "aot_eager", "cudagraphs", or null. Enables automatic model compilation during prepare() # --- # --- PPO Related Settings --- init_at_random_ep_len: true num_steps_per_env: 32 num_learning_epochs: 3 num_mini_batches: 4 clip_param: 0.2 gamma: 0.99 lam: 0.95 value_loss_coef: 1.0 entropy_coef: 5.0e-3 anneal_entropy: false zero_entropy_point: 1.0 max_grad_norm: 1.0 use_clipped_value_loss: true desired_kl: 0.01 init_noise_std: 1.0 # --- Optimizer Settings --- optimizer_type: AdamW # Options: "AdamW", "Adam" schedule: adaptive actor_learning_rate: 3.0e-4 critic_learning_rate: 5.0e-4 adaptive_lr: adapt_critic: false lr_scaler: 1.2 kl_high_factor: 2.0 kl_low_factor: 0.5 min_learning_rate: 1.0e-6 max_learning_rate: 1.0 distributed_update: mode: scalable # Options: "legacy", "scalable" lr_scale: mode: sqrt_world_size # Options: "none", "sqrt_world_size", "linear_world_size" reference_world_size: 1 max_scale: null kl_early_stop: enabled: true signal: window_mean # Shared windowed KL control signal window_size: 3 factor: 1.8 min_updates: 1 # Distributed training settings normalize_advantage_per_mini_batch: false # Use global advantage norm for DDP global_advantage_norm: true # Sync advantages across all ranks # --- Sampling Strategy --- 
sampling_strategy: uniform curriculum: p_a_ratio: 0.5 ema_alpha_signal: 0.2 ema_alpha_rel_improve: 0.2 relative_eps: 1.0e-6 dump_whole_window_scores_json: false dump_whole_window_scores_every_swaps: 10 weighted_bin: bin_regex_patterns: [] dump_sampled_keys: false dump_sampled_keys_interval: 1000 # --- Module Settings --- module_dict: actor: ${modules.actor} critic: ${modules.critic} symmetry_loss: enabled: false coef: 0.1 dof_sign_by_name: ${robot.dof_sign_by_name} ================================================ FILE: holomotion/config/algo/ppo_tf.yaml ================================================ # @package _global_ defaults: - ppo algo: _target_: holomotion.src.algo.ppo_tf.PPOTF config: num_steps_per_env: 32 kl_coef: 0.0 schedule: adaptive actor_learning_rate: 3.0e-5 critic_learning_rate: 5.0e-5 num_learning_epochs: 3 num_mini_batches: 24 clip_param: 0.2 entropy_coef: 5.0e-3 desired_kl: 0.01 noise_std_type: log fix_sigma: false init_noise_std: 1.0 min_sigma: 0.01 max_sigma: 1.2 aux_state_pred: enabled: true w_keybody_contact: 1.0e-2 w_base_lin_vel: 1.0e-2 w_ref_keybody_rel_pos: 1.0e-1 w_robot_keybody_rel_pos: 1.0e-1 min_std: 0.01 max_std: 2.0 keybody_contact_names: - left_hip_pitch_link - right_hip_pitch_link - left_knee_link - right_knee_link - left_ankle_roll_link - right_ankle_roll_link - left_elbow_link - right_elbow_link - left_wrist_yaw_link - right_wrist_yaw_link keybody_rel_pos_names: - left_knee_link - right_knee_link - left_ankle_roll_link - right_ankle_roll_link - left_elbow_link - right_elbow_link - left_wrist_yaw_link - right_wrist_yaw_link dead_expert_margin_to_topk: enabled: true weight: 10.0 aux_router_command_recon: enabled: false weight: 0.0 hidden_dim: 0 term_prefix: actor_ref_ aux_router_switch_penalty: enabled: false weight: 0.0 router_expert_orthogonal: enabled: false weight: 0.0 min_active_usage: 1.0e-3 eps: 1.0e-8 selected_expert_margin_to_unselected: enabled: false weight: 0.0 target: 0.0 moe_router: routing_score_fn: softmax 
routing_scale: 1.0 use_dynamic_bias: false bias_update_rate: 0.001 expert_bias_clip: 0.0 ================================================ FILE: holomotion/config/data_curation/joints2smpl.yaml ================================================ ================================================ FILE: holomotion/config/data_curation/smplify_base.yaml ================================================ ================================================ FILE: holomotion/config/env/domain_randomization/NO_domain_rand.yaml ================================================ # @package _global_ domain_rand: action_delay: enabled: false erfi: enabled: false motion_init_perturb: root_pose_perturb_range: x: [0.0, 0.0] y: [0.0, 0.0] z: [0.0, 0.0] roll: [0.0, 0.0] pitch: [0.0, 0.0] yaw: [0.0, 0.0] root_vel_perturb_range: x: [0.0, 0.0] y: [0.0, 0.0] z: [0.0, 0.0] roll: [0.0, 0.0] pitch: [0.0, 0.0] yaw: [0.0, 0.0] dof_pos_perturb_range: [0.0, 0.0] dof_vel_perturb_range: [0.0, 0.0] obs_noise: actor_ref_gravity_projection_cur: n_min: 0 n_max: 0 actor_ref_gravity_projection_fut: n_min: 0 n_max: 0 actor_ref_base_linvel_cur: n_min: 0 n_max: 0 n_min_z: 0 n_max_z: 0 actor_ref_base_linvel_fut: n_min: 0 n_max: 0 n_min_z: 0 n_max_z: 0 actor_ref_base_angvel_cur: n_min: 0 n_max: 0 n_min_z: 0 n_max_z: 0 actor_ref_base_angvel_fut: n_min: 0 n_max: 0 n_min_z: 0 n_max_z: 0 actor_ref_dof_pos_cur: n_min: 0 n_max: 0 actor_ref_dof_pos_fut: n_min: 0 n_max: 0 actor_ref_root_height_cur: n_min: 0 n_max: 0 actor_ref_root_height_fut: n_min: 0 n_max: 0 actor_ref_keybody_rel_pos_cur: n_min: 0 n_max: 0 actor_ref_keybody_rel_pos_fut: n_min: 0 n_max: 0 actor_projected_gravity: n_min: 0.0 n_max: 0.0 actor_rel_robot_root_ang_vel: n_min: 0.0 n_max: 0.0 actor_dof_pos: n_min: 0.0 n_max: 0.0 actor_dof_vel: n_min: 0.0 n_max: 0.0 ================================================ FILE: holomotion/config/env/domain_randomization/domain_rand_medium.yaml ================================================ # @package _global_ domain_rand: 
action_delay: enabled: true min_delay: 0 max_delay: 2 erfi: enabled: false rfi_probability: 0.5 rfi_lim: 0.1 randomize_rfi_lim: true rfi_lim_range: [0.5, 1.5] rao_lim: 0.1 obs_noise: actor_ref_gravity_projection_cur: n_min: -0.1 n_max: 0.1 actor_ref_gravity_projection_fut: n_min: -0.1 n_max: 0.1 actor_ref_base_linvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_linvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_angvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_base_angvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_dof_pos_cur: n_min: -0.05 n_max: 0.05 actor_ref_dof_pos_fut: n_min: -0.05 n_max: 0.05 actor_ref_root_height_cur: n_min: -0.1 n_max: 0.1 actor_ref_root_height_fut: n_min: -0.1 n_max: 0.1 actor_ref_keybody_rel_pos_cur: n_min: -0.1 n_max: 0.1 actor_ref_keybody_rel_pos_fut: n_min: -0.1 n_max: 0.1 actor_projected_gravity: n_min: -0.1 n_max: 0.1 actor_rel_robot_root_ang_vel: n_min: -0.2 n_max: 0.2 actor_dof_pos: n_min: -0.01 n_max: 0.01 actor_dof_vel: n_min: -0.5 n_max: 0.5 motion_init_perturb: root_pose_perturb_range: x: [-0.05, 0.05] y: [-0.05, 0.05] z: [-0.01, 0.01] roll: [-0.1, 0.1] pitch: [-0.1, 0.1] yaw: [-0.2, 0.2] root_vel_perturb_range: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.5, 0.5] pitch: [-0.5, 0.5] yaw: [-0.2, 0.2] dof_pos_perturb_range: [-0.1, 0.1] dof_vel_perturb_range: [0.0, 0.0] default_dof_pos_bias: mode: startup params: joint_names: [".*"] pos_distribution_params: [-0.01, 0.01] operation: add distribution: uniform rigid_body_com: mode: startup params: body_names: torso_link com_range: x: [-0.075, 0.075] y: [-0.1, 0.1] z: [-0.1, 0.1] randomize_mass: mode: startup params: body_names: - "pelvis" - "torso_link" mass_range: [-1.0, 2.0] rigid_body_material: mode: startup params: body_names: ".*" static_friction_range: [0.3, 1.6] dynamic_friction_range: [0.3, 1.2] restitution_range: [0.0, 0.5] num_buckets: 64 push_by_setting_velocity: 
mode: interval interval_range_s: [1.0, 3.0] params: velocity_range: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.52, 0.52] pitch: [-0.52, 0.52] yaw: [-0.78, 0.78] randomize_actuator_gains: mode: startup params: asset_name: robot body_names: ".*" stiffness_distribution_params: [0.9, 1.1] damping_distribution_params: [0.9, 1.1] operation: scale distribution: uniform ================================================ FILE: holomotion/config/env/domain_randomization/domain_rand_small.yaml ================================================ # @package _global_ domain_rand: action_delay: enabled: false min_delay: 0 max_delay: 0 erfi: enabled: false rfi_probability: 0.5 rfi_lim: 0.1 randomize_rfi_lim: true rfi_lim_range: [0.5, 1.5] rao_lim: 0.1 obs_noise: actor_ref_gravity_projection_cur: n_min: -0.1 n_max: 0.1 actor_ref_gravity_projection_fut: n_min: -0.1 n_max: 0.1 actor_ref_base_linvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_linvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_angvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_base_angvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_dof_pos_cur: n_min: -0.05 n_max: 0.05 actor_ref_dof_pos_fut: n_min: -0.05 n_max: 0.05 actor_ref_root_height_cur: n_min: -0.1 n_max: 0.1 actor_ref_root_height_fut: n_min: -0.1 n_max: 0.1 actor_ref_keybody_rel_pos_cur: n_min: -0.1 n_max: 0.1 actor_ref_keybody_rel_pos_fut: n_min: -0.1 n_max: 0.1 actor_projected_gravity: n_min: -0.1 n_max: 0.1 actor_rel_robot_root_ang_vel: n_min: -0.2 n_max: 0.2 actor_dof_pos: n_min: -0.01 n_max: 0.01 actor_dof_vel: n_min: -0.5 n_max: 0.5 motion_init_perturb: root_pose_perturb_range: x: [-0.05, 0.05] y: [-0.05, 0.05] z: [-0.01, 0.01] roll: [-0.1, 0.1] pitch: [-0.1, 0.1] yaw: [-0.2, 0.2] root_vel_perturb_range: x: [-0.3, 0.3] y: [-0.3, 0.3] z: [-0.1, 0.1] roll: [-0.3, 0.3] pitch: [-0.3, 0.3] yaw: [-0.4, 0.4] dof_pos_perturb_range: [-0.1, 0.1] 
dof_vel_perturb_range: [0.0, 0.0] default_dof_pos_bias: mode: startup params: joint_names: [".*"] pos_distribution_params: [-0.01, 0.01] operation: add distribution: uniform rigid_body_com: mode: startup params: body_names: torso_link com_range: x: [-0.025, 0.025] y: [-0.05, 0.05] z: [-0.05, 0.05] randomize_mass: mode: startup params: body_names: - "pelvis" - "torso_link" mass_range: [-1.0, 2.0] rigid_body_material: mode: startup params: body_names: ".*" static_friction_range: [0.3, 1.6] dynamic_friction_range: [0.3, 1.2] restitution_range: [0.0, 0.5] num_buckets: 64 push_by_setting_velocity: mode: interval interval_range_s: [1.0, 3.0] params: velocity_range: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.52, 0.52] pitch: [-0.52, 0.52] yaw: [-0.78, 0.78] randomize_actuator_gains: mode: startup params: asset_name: robot body_names: ".*" stiffness_distribution_params: [0.9, 1.1] damping_distribution_params: [0.9, 1.1] operation: scale distribution: uniform ================================================ FILE: holomotion/config/env/domain_randomization/domain_rand_strong.yaml ================================================ # @package _global_ domain_rand: action_delay: enabled: true min_delay: 0 max_delay: 4 erfi: enabled: true rfi_probability: 0.5 rfi_lim: 0.1 randomize_rfi_lim: true rfi_lim_range: [0.5, 1.5] rao_lim: 0.1 obs_noise: actor_ref_gravity_projection_cur: n_min: -0.1 n_max: 0.1 actor_ref_gravity_projection_fut: n_min: -0.1 n_max: 0.1 actor_ref_base_linvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_linvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.05 n_max_z: 0.05 actor_ref_base_angvel_cur: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_base_angvel_fut: n_min: -0.1 n_max: 0.1 n_min_z: -0.1 n_max_z: 0.1 actor_ref_dof_pos_cur: n_min: -0.05 n_max: 0.05 actor_ref_dof_pos_fut: n_min: -0.05 n_max: 0.05 actor_ref_root_height_cur: n_min: -0.1 n_max: 0.1 actor_ref_root_height_fut: n_min: -0.1 n_max: 0.1 
actor_ref_keybody_rel_pos_cur: n_min: -0.1 n_max: 0.1 actor_ref_keybody_rel_pos_fut: n_min: -0.1 n_max: 0.1 actor_projected_gravity: n_min: -0.1 n_max: 0.1 actor_rel_robot_root_ang_vel: n_min: -0.2 n_max: 0.2 actor_dof_pos: n_min: -0.01 n_max: 0.01 actor_dof_vel: n_min: -0.5 n_max: 0.5 motion_init_perturb: root_pose_perturb_range: x: [-0.05, 0.05] y: [-0.05, 0.05] z: [-0.01, 0.01] roll: [-0.1, 0.1] pitch: [-0.1, 0.1] yaw: [-0.2, 0.2] root_vel_perturb_range: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.5, 0.5] pitch: [-0.5, 0.5] yaw: [-0.2, 0.2] dof_pos_perturb_range: [-0.1, 0.1] dof_vel_perturb_range: [0.0, 0.0] default_dof_pos_bias: mode: startup params: joint_names: [".*"] pos_distribution_params: [-0.01, 0.01] operation: add distribution: uniform rigid_body_com: mode: startup params: body_names: torso_link com_range: x: [-0.075, 0.075] y: [-0.1, 0.1] z: [-0.1, 0.1] randomize_mass: mode: startup params: body_names: - "pelvis" - "torso_link" mass_range: [-1.0, 2.0] rigid_body_material: mode: startup params: body_names: ".*" static_friction_range: [0.3, 1.6] dynamic_friction_range: [0.3, 1.2] restitution_range: [0.0, 0.5] num_buckets: 64 push_by_setting_velocity: mode: interval interval_range_s: [1.0, 3.0] params: velocity_range: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.52, 0.52] pitch: [-0.52, 0.52] yaw: [-0.78, 0.78] randomize_actuator_gains: mode: startup params: asset_name: robot body_names: ".*" stiffness_distribution_params: [0.9, 1.1] damping_distribution_params: [0.9, 1.1] operation: scale distribution: uniform ================================================ FILE: holomotion/config/env/motion_tracking.yaml ================================================ # @package _global_ env: _target_: holomotion.src.env.motion_tracking.MotionTrackingEnv _recursive_: False config: experiment_name: ${experiment_name} num_envs: ${num_envs} env_spacing: 2.5 # meters replicate_physics: true headless: ${headless} num_processes: ${num_processes} 
main_process: ${main_process} process_id: ${process_id} ckpt_dir: null disable_ref_viz: false eval_log_dir: null save_rendering_dir: null robot: ${robot} domain_rand: ${domain_rand} rewards: ${rewards} terrain: ${terrain} obs: ${obs} terminations: ${terminations} simulation: episode_length_s: 10 # Long episodes for fluid motion-based termination sim_freq: 200 control_decimation: 4 physx: bounce_threshold_velocity: 0.5 gpu_max_rigid_patch_count: 327680 # 10 * 2**15 scene: terrain: ${terrain} lighting: distant_light_intensity: 3000.0 dome_light_intensity: 1000.0 contact_sensor: history_length: 3 force_threshold: 10.0 track_air_time: true debug_vis: false actions: dof_pos: type: joint_position params: asset_name: robot joint_names: - ".*" use_default_offset: true scale: ${robot.actuators.action_scale} commands: ref_motion: type: MotionCommandCfg params: command_obs_name: bydmmc_ref_motion motion_lib_cfg: ${robot.motion} urdf_dof_names: ${robot.dof_names} urdf_body_names: ${robot.body_names} arm_dof_names: ${robot.arm_dof_names} waist_dof_names: ${robot.waist_dof_names} leg_dof_names: ${robot.leg_dof_names} arm_body_names: ${robot.arm_body_names} torso_body_names: ${robot.torso_body_names} leg_body_names: ${robot.leg_body_names} anchor_bodylink_name: ${robot.anchor_body} asset_name: robot debug_vis: true root_pose_perturb_range: ${domain_rand.motion_init_perturb.root_pose_perturb_range} root_vel_perturb_range: ${domain_rand.motion_init_perturb.root_vel_perturb_range} dof_pos_perturb_range: ${domain_rand.motion_init_perturb.dof_pos_perturb_range} dof_vel_perturb_range: ${domain_rand.motion_init_perturb.dof_vel_perturb_range} resample_time_interval_s: 100 n_fut_frames: ${obs.n_fut_frames} target_fps: 50 normalization: clip_actions: 100.0 clip_observations: 100.0 resample_motion_when_training: True curriculum: enabled: false robot_friction_completion_rate: enabled: True func: robot_friction_range_by_completion_rate params: num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 
0.34, 0.40] static_friction_target: [0.3, 1.6] dynamic_friction_target: [0.3, 1.2] body_names: ".*" restitution_range: [0.0, 0.5] num_buckets: 64 rigid_body_com_completion_rate: enabled: True func: rigid_body_com_by_completion_rate params: num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40] state_prefix: "_cr_curr" asset_name: "robot" body_names: "torso_link" com_range_target: x: [-0.025, 0.025] y: [-0.05, 0.05] z: [-0.05, 0.05] default_dof_pos_bias_completion_rate: enabled: True func: default_dof_pos_bias_by_completion_rate params: num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40] state_prefix: "_cr_curr" joint_names: - ".*" pos_distribution_params_target: [-0.01, 0.01] operation: add distribution: uniform push_by_setting_velocity_completion_rate: enabled: True func: isaaclab_mdp.modify_term_cfg params: address: "events.push_by_setting_velocity.params" modify_fn: push_by_setting_velocity_range_by_completion_rate modify_params: num_updates: 3 cr_thresholds: [0.20, 0.30, 0.40] velocity_range_target: x: [-0.5, 0.5] y: [-0.5, 0.5] z: [-0.2, 0.2] roll: [-0.52, 0.52] pitch: [-0.52, 0.52] yaw: [-0.78, 0.78] randomize_actuator_gains_completion_rate: enabled: True func: randomize_actuator_gains_by_completion_rate params: num_updates: 3 cr_thresholds: [0.20, 0.30, 0.40] asset_name: "robot" body_names: ".*" stiffness_distribution_params_target: [0.9, 1.1] damping_distribution_params_target: [0.9, 1.1] operation: scale distribution: uniform action_rate_l2_completion_rate: enabled: true func: reward_term_weight_by_completion_rate params: reward_term_name: "action_rate_l2" final_weight: -0.1 start_scale: 0.1 num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40] joint_pos_limits_completion_rate: enabled: true func: reward_term_weight_by_completion_rate params: reward_term_name: "joint_pos_limits" final_weight: -10.0 start_scale: 0.1 num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40] undesired_contacts_completion_rate: enabled: true func: 
reward_term_weight_by_completion_rate params: reward_term_name: "undesired_contacts" final_weight: -0.1 start_scale: 0.1 num_updates: 5 cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40] ================================================ FILE: holomotion/config/env/observations/motion_tracking/obs_motion_tracking_mlp.yaml ================================================ # @package _global_ obs: context_length: 32 n_fut_frames: 10 target_fps: 50 actor_obs_prefix: "ref_" critic_obs_prefix: "ref_" obs_groups: unified: atomic_obs_list: - actor_ref_gravity_projection_cur: func: ref_gravity_projection_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max} - actor_ref_gravity_projection_fut: func: ref_gravity_projection_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max} # Reference base linear velocity - actor_ref_base_linvel_cur: func: ref_base_linvel_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z} - actor_ref_base_linvel_fut: func: ref_base_linvel_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max} n_min_z: 
${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z} - actor_ref_base_angvel_cur: func: ref_base_angvel_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z} - actor_ref_base_angvel_fut: func: ref_base_angvel_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z} - actor_ref_dof_pos_cur: func: ref_dof_pos_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max} - actor_ref_dof_pos_fut: func: ref_dof_pos_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max} - actor_ref_motion_filter_cutoff_hz: func: ref_motion_filter_cutoff_hz - actor_ref_root_height_cur: func: ref_root_height_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max} - 
actor_ref_root_height_fut: func: ref_root_height_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max} - actor_ref_keybody_rel_pos_cur: func: ref_keybody_rel_pos_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} keybody_names: - "left_knee_link" - "right_knee_link" - "left_ankle_roll_link" - "right_ankle_roll_link" - "left_elbow_link" - "right_elbow_link" - "left_wrist_yaw_link" - "right_wrist_yaw_link" noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max} - actor_ref_keybody_rel_pos_fut: func: ref_keybody_rel_pos_fut params: ref_prefix: ${obs.actor_obs_prefix} keybody_names: - "left_knee_link" - "right_knee_link" - "left_ankle_roll_link" - "right_ankle_roll_link" - "left_elbow_link" - "right_elbow_link" - "left_wrist_yaw_link" - "right_wrist_yaw_link" noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max} - actor_projected_gravity: func: projected_gravity history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min} n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max} - actor_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min} n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max} - actor_dof_pos: func: dof_pos history_length: ${obs.context_length} flatten_history_dim: false noise: type: 
AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min} n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max} - actor_dof_vel: func: dof_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min} n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max} - actor_last_action: func: last_action history_length: ${obs.context_length} flatten_history_dim: false - critic_ref_dof_pos_cur: func: ref_dof_pos_cur params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_dof_pos_fut: func: ref_dof_pos_fut params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_root_height_fut: func: ref_root_height_fut params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_root_height_cur: func: ref_root_height_cur params: ref_prefix: ${obs.critic_obs_prefix} - critic_global_anchor_diff: func: global_anchor_diff params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_pos: func: ref_motion_cur_heading_aligned_root_pos params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_pos: func: ref_motion_fut_heading_aligned_root_pos params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_rot6d: func: ref_motion_cur_heading_aligned_root_rot6d params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_rot6d: func: ref_motion_fut_heading_aligned_root_rot6d params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_lin_vel: func: ref_motion_cur_heading_aligned_root_lin_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_lin_vel: func: ref_motion_fut_heading_aligned_root_lin_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_ang_vel: func: ref_motion_cur_heading_aligned_root_ang_vel params: ref_prefix: ${obs.critic_obs_prefix} - 
critic_ref_motion_fut_heading_aligned_root_ang_vel: func: ref_motion_fut_heading_aligned_root_ang_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_rel_robot_root_lin_vel: func: rel_robot_root_lin_vel - critic_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel - critic_global_robot_bodylink_lin_vel_flat: func: global_robot_bodylink_lin_vel_flat - critic_global_robot_bodylink_ang_vel_flat: func: global_robot_bodylink_ang_vel_flat - critic_root_rel_robot_bodylink_pos_flat: func: root_rel_robot_bodylink_pos_flat - critic_root_rel_robot_bodylink_rot_mat_flat: func: root_rel_robot_bodylink_rot_mat_flat - critic_dof_pos: func: dof_pos - critic_dof_vel: func: dof_vel - critic_last_action: func: last_action enable_corruption: true concatenate_terms: false ================================================ FILE: holomotion/config/env/observations/motion_tracking/obs_motion_tracking_tf-moe.yaml ================================================ # @package _global_ obs: context_length: 1 n_fut_frames: 10 target_fps: 50 actor_obs_prefix: "ref_" critic_obs_prefix: "ref_" obs_groups: unified: atomic_obs_list: - actor_ref_gravity_projection_cur: func: ref_gravity_projection_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max} - actor_ref_gravity_projection_fut: func: ref_gravity_projection_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max} # Reference base linear velocity - actor_ref_base_linvel_cur: func: ref_base_linvel_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: 
AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z} - actor_ref_base_linvel_fut: func: ref_base_linvel_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z} - actor_ref_base_angvel_cur: func: ref_base_angvel_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z} - actor_ref_base_angvel_fut: func: ref_base_angvel_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max} n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z} n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z} - actor_ref_dof_pos_cur: func: ref_dof_pos_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max} - actor_ref_dof_pos_fut: func: ref_dof_pos_fut params: ref_prefix: ${obs.actor_obs_prefix} 
noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max} - actor_ref_motion_filter_cutoff_hz: func: ref_motion_filter_cutoff_hz - actor_ref_root_height_cur: func: ref_root_height_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max} - actor_ref_root_height_fut: func: ref_root_height_fut params: ref_prefix: ${obs.actor_obs_prefix} noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max} - actor_ref_keybody_rel_pos_cur: func: ref_keybody_rel_pos_cur history_length: ${obs.context_length} flatten_history_dim: false params: ref_prefix: ${obs.actor_obs_prefix} keybody_names: - "left_knee_link" - "right_knee_link" - "left_ankle_roll_link" - "right_ankle_roll_link" - "left_elbow_link" - "right_elbow_link" - "left_wrist_yaw_link" - "right_wrist_yaw_link" noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min} n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max} - actor_ref_keybody_rel_pos_fut: func: ref_keybody_rel_pos_fut params: ref_prefix: ${obs.actor_obs_prefix} keybody_names: - "left_knee_link" - "right_knee_link" - "left_ankle_roll_link" - "right_ankle_roll_link" - "left_elbow_link" - "right_elbow_link" - "left_wrist_yaw_link" - "right_wrist_yaw_link" noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min} n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max} - actor_projected_gravity: func: projected_gravity history_length: ${obs.context_length} flatten_history_dim: false noise: 
type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min} n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max} - actor_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min} n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max} - actor_dof_pos: func: dof_pos history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min} n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max} - actor_dof_vel: func: dof_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min} n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max} - actor_last_action: func: last_action history_length: ${obs.context_length} flatten_history_dim: false - critic_ref_dof_pos_cur: func: ref_dof_pos_cur params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_dof_pos_fut: func: ref_dof_pos_fut params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_root_height_fut: func: ref_root_height_fut params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_root_height_cur: func: ref_root_height_cur params: ref_prefix: ${obs.critic_obs_prefix} - critic_global_anchor_diff: func: global_anchor_diff params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_pos: func: ref_motion_cur_heading_aligned_root_pos params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_pos: func: ref_motion_fut_heading_aligned_root_pos params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_rot6d: func: ref_motion_cur_heading_aligned_root_rot6d params: ref_prefix: ${obs.critic_obs_prefix} - 
critic_ref_motion_fut_heading_aligned_root_rot6d: func: ref_motion_fut_heading_aligned_root_rot6d params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_lin_vel: func: ref_motion_cur_heading_aligned_root_lin_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_lin_vel: func: ref_motion_fut_heading_aligned_root_lin_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_cur_heading_aligned_root_ang_vel: func: ref_motion_cur_heading_aligned_root_ang_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_ref_motion_fut_heading_aligned_root_ang_vel: func: ref_motion_fut_heading_aligned_root_ang_vel params: ref_prefix: ${obs.critic_obs_prefix} - critic_rel_robot_root_lin_vel: func: rel_robot_root_lin_vel - critic_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel - critic_global_robot_bodylink_lin_vel_flat: func: global_robot_bodylink_lin_vel_flat - critic_global_robot_bodylink_ang_vel_flat: func: global_robot_bodylink_ang_vel_flat - critic_root_rel_robot_bodylink_pos_flat: func: root_rel_robot_bodylink_pos_flat - critic_root_rel_robot_bodylink_rot_mat_flat: func: root_rel_robot_bodylink_rot_mat_flat - critic_dof_pos: func: dof_pos - critic_dof_vel: func: dof_vel - critic_last_action: func: last_action enable_corruption: true concatenate_terms: false ================================================ FILE: holomotion/config/env/observations/velocity_tracking/obs_velocity_tracking.yaml ================================================ # @package _global_ obs: context_length: 8 n_fut_frames: 0 target_fps: 50 obs_groups: unified: atomic_obs_list: # Actor terms (sequence-style; flatten at serializer) - actor_velocity_command: func: velocity_command history_length: ${obs.context_length} flatten_history_dim: false mirror_func: mirror_velocity_command mirror_config: {} - actor_projected_gravity: func: projected_gravity history_length: ${obs.context_length} flatten_history_dim: false 
noise: type: AdditiveUniformNoiseCfg params: n_min: -0.1 n_max: 0.1 mirror_func: mirror_vec3 mirror_config: {} - actor_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: -0.2 n_max: 0.2 mirror_func: mirror_axial_vec3 mirror_config: {} - actor_dof_pos: func: dof_pos history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: -0.01 n_max: 0.01 mirror_func: mirror_dof mirror_config: {} - actor_dof_vel: func: dof_vel history_length: ${obs.context_length} flatten_history_dim: false noise: type: AdditiveUniformNoiseCfg params: n_min: -0.5 n_max: 0.5 mirror_func: mirror_dof mirror_config: {} - actor_last_action: func: last_action history_length: ${obs.context_length} flatten_history_dim: false mirror_func: mirror_dof mirror_config: {} # Critic terms - critic_velocity_command: func: velocity_command - critic_rel_robot_root_lin_vel: func: rel_robot_root_lin_vel - critic_rel_robot_root_ang_vel: func: rel_robot_root_ang_vel - critic_root_rel_robot_bodylink_pos_flat: func: root_rel_robot_bodylink_pos_flat params: keybody_names: ${robot.key_bodies} - critic_root_rel_robot_bodylink_rot_mat_flat: func: root_rel_robot_bodylink_rot_mat_flat params: keybody_names: ${robot.key_bodies} - critic_dof_pos: func: dof_pos - critic_dof_vel: func: dof_vel - critic_last_action: func: last_action enable_corruption: true concatenate_terms: false ================================================ FILE: holomotion/config/env/rewards/motion_tracking/rew_motion_tracking.yaml ================================================ # @package _global_ rewards: _config: reward_prefix: "ref_" is_alive: weight: 0.5 params: {} root_pos_xy_tracking_exp: weight: 1.0 params: std: 0.2 ref_prefix: ${rewards._config.reward_prefix} root_rot_tracking_exp: weight: 0.5 params: std: 0.4 ref_prefix: ${rewards._config.reward_prefix} 
root_rel_keybodylink_pos_tracking_l2_exp: weight: 1.0 params: keybody_names: ${robot.key_bodies} std: 0.3 ref_prefix: ${rewards._config.reward_prefix} root_rel_keybodylink_rot_tracking_l2_exp: weight: 2.0 params: keybody_names: ${robot.key_bodies} std: 0.4 ref_prefix: ${rewards._config.reward_prefix} global_keybodylink_lin_vel_tracking_l2_exp: weight: 1.0 params: keybody_names: ${robot.key_bodies} std: 1.0 ref_prefix: ${rewards._config.reward_prefix} global_keybodylink_ang_vel_tracking_l2_exp: weight: 1.0 params: keybody_names: ${robot.key_bodies} std: 3.14 ref_prefix: ${rewards._config.reward_prefix} action_rate_l2: weight: -0.1 params: {} # joint_acc_l2: # weight: -1.0e-6 # params: {} joint_pos_limits: weight: -10.0 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: - ".*" undesired_contacts: weight: -0.1 params: threshold: 1.0 sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: - ${robot.undesired_contacts_regrex} ================================================ FILE: holomotion/config/env/rewards/velocity_tracking/rew_velocity_tracking.yaml ================================================ # @package _global_ rewards: stand_still_action_rate: weight: -1.0 params: command_name: base_velocity feet_contact_without_cmd: weight: 1.0 params: command_name: base_velocity sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: - ".*ankle_roll.*" track_stand_still_exp: weight: 5.0 params: command_name: base_velocity std: 0.2 track_lin_vel_xy_heading_aligned_frame_exp: weight: 3.0 params: std: 0.5 command_name: base_velocity track_ang_vel_z_heading_aligned_frame_exp: weight: 3.0 params: std: 0.5 command_name: base_velocity is_alive: weight: 0.15 params: {} lin_vel_z_l2: weight: -1.0 params: {} ang_vel_xy_l2: weight: -5.0e-2 params: {} joint_acc_l2: weight: -1.0e-6 params: {} action_rate_l2: weight: -1.0e-1 params: {} 
joint_pos_limits: weight: -5.0 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: [".*"] feet_air_time_v4: weight: 1.0 params: threshold: 0.5 sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: [".*ankle_roll.*"] command_name: base_velocity fly: weight: -1.0 params: threshold: 1.0 sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: [".*ankle_roll.*"] feet_too_near: weight: -10.0 threshold: 0.2 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: - .*ankle_roll.* joint_deviation_l1_arms: weight: -0.3 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: - .*_hip_roll.* - .*waist_roll.* - .*waist_pitch.* - .*_shoulder_roll.* - .*_shoulder_yaw.* - .*_wrist.* joint_deviation_l1_legs_yaw: weight: -0.15 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: - .*waist_yaw.* - .*_hip_yaw.* - .*_elbow.* - .*_ankle.* joint_deviation_l1_legs: weight: -0.02 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot joint_names: - .*_shoulder_pitch.* - .*_hip_pitch.* - .*_knee.* flat_orientation_l2: weight: -5.0 params: {} base_height_l2: weight: -10.0 params: target_height: 0.78 feet_slide: weight: -1.0 params: asset_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: robot body_names: [".*ankle_roll.*"] sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: [".*ankle_roll.*"] undesired_contacts: weight: -1.0 params: threshold: 1.0 sensor_cfg: _target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg name: contact_forces body_names: - ${robot.undesired_contacts_regrex} torso_xy_ang_vel_l2_penalty: weight: -1.0 params: {} torso_upright_l2_penalty: weight: -1.0 params: {} 
================================================ FILE: holomotion/config/env/terminations/NO_termination.yaml ================================================ # @package _global_ terminations: {} ================================================ FILE: holomotion/config/env/terminations/termination_motion_tracking.yaml ================================================ # @package _global_ terminations: time_out: time_out: true ref_gravity_projection_far: params: threshold: 0.8 ref_prefix: ${rewards._config.reward_prefix} keybody_ref_z_far: params: threshold: 0.25 ref_prefix: ${rewards._config.reward_prefix} keybody_names: - pelvis - left_ankle_roll_link - right_ankle_roll_link - left_wrist_yaw_link - right_wrist_yaw_link # keybody_ref_pos_far: # params: # threshold: 0.5 # ref_prefix: ${rewards._config.reward_prefix} # keybody_names: # - pelvis ================================================ FILE: holomotion/config/env/terminations/termination_velocity_tracking.yaml ================================================ # @package _global_ terminations: time_out: time_out: true root_height_below_minimum: params: minimum_height: 0.2 bad_orientation: params: limit_angle: 0.8 ================================================ FILE: holomotion/config/env/terrain/isaaclab_plane.yaml ================================================ # @package _global_ # Simple flat terrain generated via IsaacLab height-field TerrainGenerator. # Uses random-uniform height field with zero noise as a flat patch. terrain: terrain_type: generator prim_path: /World/ground static_friction: 1.0 dynamic_friction: 1.0 restitution: 0.0 friction_combine_mode: multiply restitution_combine_mode: multiply debug_vis: false max_init_terrain_level: 0 # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch. # When false, env origins are placed on a regular grid as in the default importer. 
random_spawn: true # Keep random spawn points away from terrain edges to avoid spawning onto the outer border. random_spawn_margin: 2.0 # TerrainGeneratorCfg parameters. generator: num_rows: 1 num_cols: 1 size: [10.0, 10.0] border_width: 1000.0 horizontal_scale: 0.1 vertical_scale: 0.005 slope_threshold: null difficulty_range: [0.0, 0.0] color_scheme: height sub_terrains: plane: type: plane proportion: 1.0 # Offline visual material configuration (PreviewSurface, no MDL/Nucleus). visual_material: type: color diffuse_color: [0.25, 0.25, 0.25] metallic: 0.0 roughness: 0.5 ================================================ FILE: holomotion/config/env/terrain/isaaclab_rough.yaml ================================================ # @package _global_ # Rough height-field terrain for locomotion training. # Uses random-uniform height field to create continuous noise-like terrain. terrain: terrain_type: generator prim_path: /World/ground static_friction: 1.0 dynamic_friction: 1.0 restitution: 0.0 friction_combine_mode: multiply restitution_combine_mode: multiply debug_vis: false max_init_terrain_level: 4 # Randomize spawn position within each sub-terrain (recommended for locomotion). random_spawn: true random_spawn_margin: 4.0 # TerrainGeneratorCfg parameters. generator: num_rows: 4 # Number of terrain rows (difficulty levels) num_cols: 4 # Number of terrain columns (types) size: [20.0, 20.0] # Size of each sub-terrain in meters [length, width] border_width: 1000.0 # Border around terrain in meters horizontal_scale: 0.1 # Resolution in x-y plane vertical_scale: 0.005 # Height resolution slope_threshold: null # Slopes above this become vertical difficulty_range: [0.0, 1.0] # Min and max difficulty color_scheme: height # Color terrain vertices by height ("none" would disable vertex colors and rely on the visual material) sub_terrains: rough: type: random_uniform proportion: 1.0 noise_range: [0.0, 0.04] noise_step: 0.05 downsampled_scale: 1.0 # Offline visual material configuration (PreviewSurface, no MDL/Nucleus). 
visual_material: type: color diffuse_color: [0.25, 0.25, 0.25] metallic: 0.0 roughness: 0.5 ================================================ FILE: holomotion/config/env/velocity_tracking.yaml ================================================ # @package _global_ env: _target_: holomotion.src.env.velocity_tracking.VelocityTrackingEnv _recursive_: False config: experiment_name: ${experiment_name} num_envs: ${num_envs} env_spacing: 2.5 replicate_physics: true headless: ${headless} num_processes: ${num_processes} main_process: ${main_process} process_id: ${process_id} ckpt_dir: null disable_ref_viz: false eval_log_dir: null save_rendering_dir: null robot: ${robot} domain_rand: ${domain_rand} rewards: ${rewards} terrain: ${terrain} obs: ${obs} terminations: ${terminations} simulation: episode_length_s: 20 sim_freq: 200 control_decimation: 4 physx: bounce_threshold_velocity: 0.5 gpu_max_rigid_patch_count: 327680 scene: terrain: ${terrain} lighting: distant_light_intensity: 3000.0 dome_light_intensity: 1000.0 contact_sensor: history_length: 3 force_threshold: 10.0 track_air_time: true debug_vis: false actions: dof_pos: type: joint_position params: asset_name: robot joint_names: - ".*" use_default_offset: true scale: ${robot.actuators.action_scale} commands: base_velocity: type: HoloMotionUniformVelocityCommandCfg params: asset_name: robot resampling_time_range: [3, 10.0] rel_standing_envs: 0.20 rel_yaw_envs: 0.30 # actual prob for sampled yaw-only is 0.3 * (1-0.2) = 0.24 rel_heading_envs: 1.0 heading_command: false heading_control_stiffness: 0.5 debug_vis: true ranges: lin_vel_x: [-0.6, 1.0] lin_vel_y: [-0.5, 0.5] ang_vel_z: [-1.0, 1.0] heading: [-3.14, 3.14] # limit_ranges: # lin_vel_x: [-0.5, 1.0] # lin_vel_y: [-0.3, 0.3] # ang_vel_z: [-0.2, 0.2] # heading: [-3.14, 3.14] ================================================ FILE: holomotion/config/evaluation/eval_isaaclab.yaml ================================================ # @package _global_ defaults: - /robot: 
unitree/G1/29dof/29dof_training_isaaclab - /env: motion_tracking - /env/terrain: isaaclab_plane - /env/terminations: NO_termination - /env/domain_randomization: NO_domain_rand project_name: ??? experiment_name: ??? num_envs: ??? headless: ??? motion_h5_path: null checkpoint: null log_dir: null ckpt_pt_names: null num_processes: ??? main_process: ??? process_id: ??? timestamp: ${now:%Y%m%d_%H%M%S} base_dir: logs experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name} save_dir: ${experiment_dir}/.hydra output_dir: ${experiment_dir}/output experiment_save_dir: ??? export_policy: false export_only: false dump_npzs: false calc_per_clip_metrics: false generate_report: false dof_mode: "23" obs: critic_obs_prefix: "ref_" rewards: _config: reward_prefix: "ref_" algo: config: dynamo_backend: null sampling_strategy: uniform seed: 114514 env: config: seed: 42 simulation: episode_length_s: 36000 robot: motion: backend: "hdf5_v2" cache_max_num_clips: ${num_envs} train_hdf5_roots: ${robot.motion.val_hdf5_roots} val_hdf5_roots: ${motion_h5_path} max_frame_length: 10000 # 200 s at 50 fps min_frame_length: 1 # handpicked_motion_names: ${handpicked_motion_names} world_frame_normalization: false dataloader: num_workers: 2 pin_memory: true persistent_workers: false prefetch_factor: 1 timeout: 600 batch_progress_bar: true terrain: terrain_type: generator prim_path: /World/ground static_friction: 1.0 dynamic_friction: 1.0 restitution: 0.0 friction_combine_mode: multiply restitution_combine_mode: multiply debug_vis: false max_init_terrain_level: 0 # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch. # When false, env origins are placed on a regular grid as in the default importer. random_spawn: true # Keep random spawn points away from terrain edges to avoid spawning onto the outer border. random_spawn_margin: 2.0 # TerrainGeneratorCfg parameters. 
generator: num_rows: 1 num_cols: 1 size: [10.0, 10.0] border_width: 1000.0 horizontal_scale: 0.1 vertical_scale: 0.005 slope_threshold: null difficulty_range: [0.0, 0.0] color_scheme: height sub_terrains: plane: type: plane proportion: 1.0 # Offline visual material configuration (PreviewSurface, no MDL/Nucleus). visual_material: type: color diffuse_color: [0.25, 0.25, 0.25] metallic: 0.0 roughness: 0.5 ================================================ FILE: holomotion/config/evaluation/eval_mujoco_sim2sim.yaml ================================================ # @package _global_ defaults: - _self_ # Evaluation toggles headless: false # true to run without GUI record_video: false # true to save MP4 recordings video_width: 1280 video_height: 720 video_fps: 30 camera_tracking: true # true to make camera follow robot root body camera_height_offset: 0.3 # small offset (meters) above robot root for camera lookat point # NOTE: This offsets where camera LOOKS AT, not camera position # Use small values (0.2-0.5m) for proper framing, not large values camera_distance: 4.0 # camera distance from lookat point (meters) # Larger values = camera further away, smaller values = closer camera_azimuth: 150.0 # default viewer/offscreen azimuth (deg), side-ish view camera_elevation: -20.0 # default viewer/offscreen elevation (deg), slight downward angle # Offline evaluation pipeline (dataset mode) motion_npz_dir: null dump_npzs: true dump_onnx_io_npy: false calc_per_clip_metrics: false generate_report: false metric_calculation: "per_clip" # "per_clip" (Macro) or "per_frame" (Micro) dof_mode: "23" # "29" for full DoF, "23" for reduced DoF failure_pos_err_thresh_m: 0.25 ray_actors_per_gpu: 16 # persistent Ray actors per GPU for batch eval ray_multi_ckpt_mode: "split" # "split" or "per_checkpoint" for multi-ONNX eval ckpt_onnx_root_dir: null # optional ONNX root directory for multi-checkpoint eval ckpt_onnx_names: null # optional list of ONNX file names to evaluate 
ray_parallel_metrics_postprocess: true # parallelize per-checkpoint metrics/report/export with Ray when evaluating multiple ckpts ray_metrics_postprocess_num_cpus: 24 # Ray resource accounting per checkpoint-postprocess task (0 = don't reserve CPUs) metrics_threadpool_max_workers: 24 # per-checkpoint ThreadPoolExecutor workers inside metrics.py (null => auto=min(num_files, 24)) # Termination / scheduling max_policy_steps: 0 # 0 = unlimited; used in headless if no motion policy_action_delay_step: 0 # max random action delay in 50 Hz policy steps; 0 disables delay action_delay_type: "episode" # "episode" samples once per reset, "step" re-samples every policy step unitree_viewer_dt: 0.0167 # ~60 Hz viewer sync unitree_domain_id: 0 unitree_interface: "lo" unitree_use_joystick: false unitree_joystick_type: "xbox" unitree_print_scene_information: false # Debug options debug_anchor_obs: false # when true, dump anchor pose/obs debug CSV for sim2sim debug_anchor_obs_interval: 50 # log every N policy steps (>=1) use_isaac_root_alignment: true # align free root state to IsaacLab reference root at frame 0 isaac_action_playback: false # when true, use per-frame actions recorded from IsaacLab instead of ONNX policy robot_xml_path: ??? use_gpu: true ================================================ FILE: holomotion/config/evaluation/eval_velocity_tracking.yaml ================================================ # @package _global_ defaults: - /robot: unitree/G1/29dof/29dof_training_isaaclab - /env: velocity_tracking - /env/terrain: isaaclab_plane - /env/terminations: NO_termination - /env/domain_randomization: NO_domain_rand project_name: ??? experiment_name: ??? num_envs: ??? headless: ??? motion_h5_path: null checkpoint: null num_processes: ??? main_process: ??? process_id: ??? 
timestamp: ${now:%Y%m%d_%H%M%S} base_dir: logs experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name} save_dir: ${experiment_dir}/.hydra output_dir: ${experiment_dir}/output experiment_save_dir: ??? dump_npzs: false export_policy: true algo: config: dynamo_backend: null # Video recording for offline evaluation (env.render() -> MP4 at target_fps) record_video: false # enable MP4 recording during offline evaluation env: config: simulation: episode_length_s: 3600 robot: motion: cache_max_num_clips: ${num_envs} max_frame_length: 10000 min_frame_length: 0 hdf5_root: ${motion_h5_path} val_hdf5_root: ${motion_h5_path} dataloader: num_workers: 0 # terrain: # terrain_type: usd # usd_path: assets/isaac/4.1/Isaac/Environments/Grid/gridroom_black.usd # usd_path: assets/isaac/4.1/Isaac/Environments/Terrains/rough_plane.usd terrain: terrain_type: generator prim_path: /World/ground static_friction: 1.0 dynamic_friction: 1.0 restitution: 0.0 friction_combine_mode: multiply restitution_combine_mode: multiply debug_vis: false max_init_terrain_level: 0 # Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch. # When false, env origins are placed on a regular grid as in the default importer. random_spawn: true # TerrainGeneratorCfg parameters. generator: num_rows: 1 num_cols: 1 size: [8.0, 8.0] border_width: 10.0 horizontal_scale: 0.1 vertical_scale: 0.005 slope_threshold: null difficulty_range: [0.0, 0.0] color_scheme: height sub_terrains: plane: type: random_uniform proportion: 1.0 noise_range: [0.0, 0.0] noise_step: 0.25 downsampled_scale: 0.5 # Offline visual material configuration (PreviewSurface, no MDL/Nucleus). 
visual_material: type: color diffuse_color: [0.25, 0.25, 0.25] metallic: 0.0 roughness: 0.5 ================================================ FILE: holomotion/config/modules/motion_tracking/motion_tracking_mlp.yaml ================================================ # @package _global_ modules: actor: type: MLP hidden_norm: none layer_config: hidden_dims: - 2048 - 1024 - 512 - 256 activation: SiLU obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout # Observation schema for motion tracking, from the actor's perspective. obs_schema: flattened_obs: seq_len: ${obs.context_length} terms: - unified/actor_ref_gravity_projection_cur - unified/actor_ref_base_linvel_cur - unified/actor_ref_base_angvel_cur - unified/actor_ref_dof_pos_cur - unified/actor_ref_root_height_cur - unified/actor_projected_gravity - unified/actor_rel_robot_root_ang_vel - unified/actor_dof_pos - unified/actor_dof_vel - unified/actor_last_action flattened_obs_fut: seq_len: ${obs.n_fut_frames} terms: - unified/actor_ref_dof_pos_fut - unified/actor_ref_root_height_fut - unified/actor_ref_gravity_projection_fut - unified/actor_ref_base_linvel_fut - unified/actor_ref_base_angvel_fut output_dim: robot_action_dim critic: type: MLP obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout hidden_norm: rmsnorm layer_config: hidden_dims: - 2048 - 2048 - 2048 - 2048 
activation: SiLU obs_schema: flattened_obs: seq_len: 1 terms: - unified/critic_ref_dof_pos_cur - unified/critic_global_anchor_diff - unified/critic_ref_motion_cur_heading_aligned_root_pos - unified/critic_ref_motion_cur_heading_aligned_root_rot6d - unified/critic_ref_motion_cur_heading_aligned_root_lin_vel - unified/critic_ref_motion_cur_heading_aligned_root_ang_vel - unified/critic_rel_robot_root_lin_vel - unified/critic_rel_robot_root_ang_vel - unified/critic_global_robot_bodylink_lin_vel_flat - unified/critic_global_robot_bodylink_ang_vel_flat - unified/critic_root_rel_robot_bodylink_pos_flat - unified/critic_root_rel_robot_bodylink_rot_mat_flat - unified/critic_dof_pos - unified/critic_dof_vel - unified/critic_last_action flattened_obs_fut: seq_len: ${obs.n_fut_frames} terms: - unified/critic_ref_dof_pos_fut - unified/critic_ref_root_height_fut - unified/critic_ref_motion_fut_heading_aligned_root_pos - unified/critic_ref_motion_fut_heading_aligned_root_rot6d - unified/critic_ref_motion_fut_heading_aligned_root_lin_vel - unified/critic_ref_motion_fut_heading_aligned_root_ang_vel output_dim: 1 ================================================ FILE: holomotion/config/modules/motion_tracking/motion_tracking_tf-moe.yaml ================================================ # @package _global_ modules: actor: type: ReferenceRoutedGroupedMoETransformerPolicy use_checkpointing: false # use gradient checkpointing to save GRAM significantly # MoE-specific hyperparameters num_fine_experts: 16 num_shared_experts: 1 top_k: 2 moe_loss_coef: 0.0 routing_score_fn: ${algo.config.moe_router.routing_score_fn} routing_scale: ${algo.config.moe_router.routing_scale} use_dynamic_bias: ${algo.config.moe_router.use_dynamic_bias} bias_update_rate: ${algo.config.moe_router.bias_update_rate} expert_bias_clip: ${algo.config.moe_router.expert_bias_clip} # Transformer hyperparameters - smaller model for stability obs_embed_mlp_hidden: 2048 router_embed_mlp_hidden: 2048 d_model: 512 n_heads: 8 
n_kv_heads: 4 use_gated_attn: true n_layers: 3 ff_mult: 2.0 ff_mult_dense: 4 attn_dropout: 0.0 mlp_dropout: 0.0 max_ctx_len: 32 # Auxiliary dynamics prediction weights (0.0 = disabled) aux_sys_id_weight: 0.0 aux_dynamics_weight: 0.0 obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout # Observation schema for motion tracking, from the actor's perspective. obs_schema: flattened_obs: seq_len: ${obs.context_length} terms: - unified/actor_ref_gravity_projection_cur - unified/actor_ref_base_linvel_cur - unified/actor_ref_base_angvel_cur - unified/actor_ref_dof_pos_cur - unified/actor_ref_root_height_cur - unified/actor_projected_gravity - unified/actor_rel_robot_root_ang_vel - unified/actor_dof_pos - unified/actor_dof_vel - unified/actor_last_action flattened_obs_fut: seq_len: ${obs.n_fut_frames} terms: - unified/actor_ref_dof_pos_fut - unified/actor_ref_root_height_fut - unified/actor_ref_gravity_projection_fut - unified/actor_ref_base_linvel_fut - unified/actor_ref_base_angvel_fut output_dim: robot_action_dim critic: type: MLP obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout hidden_norm: rmsnorm layer_config: hidden_dims: - 2048 - 2048 - 2048 - 2048 activation: SiLU obs_schema: flattened_obs: seq_len: 1 terms: - unified/critic_ref_dof_pos_cur - unified/critic_global_anchor_diff - 
unified/critic_ref_motion_cur_heading_aligned_root_pos - unified/critic_ref_motion_cur_heading_aligned_root_rot6d - unified/critic_ref_motion_cur_heading_aligned_root_lin_vel - unified/critic_ref_motion_cur_heading_aligned_root_ang_vel - unified/critic_rel_robot_root_lin_vel - unified/critic_rel_robot_root_ang_vel - unified/critic_global_robot_bodylink_lin_vel_flat - unified/critic_global_robot_bodylink_ang_vel_flat - unified/critic_root_rel_robot_bodylink_pos_flat - unified/critic_root_rel_robot_bodylink_rot_mat_flat - unified/critic_dof_pos - unified/critic_dof_vel - unified/critic_last_action flattened_obs_fut: seq_len: ${obs.n_fut_frames} terms: - unified/critic_ref_dof_pos_fut - unified/critic_ref_root_height_fut - unified/critic_ref_motion_fut_heading_aligned_root_pos - unified/critic_ref_motion_fut_heading_aligned_root_rot6d - unified/critic_ref_motion_fut_heading_aligned_root_lin_vel - unified/critic_ref_motion_fut_heading_aligned_root_ang_vel output_dim: 1 ================================================ FILE: holomotion/config/modules/velocity_tracking/velocity_tracking_mlp.yaml ================================================ # @package _global_ modules: actor: type: MLP fix_sigma: false noise_std_type: scalar obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout hidden_norm: rmsnorm layer_config: hidden_dims: - 512 - 512 - 512 activation: SiLU obs_schema: flattened_obs: seq_len: ${obs.context_length} terms: - unified/actor_velocity_command - unified/actor_projected_gravity - unified/actor_rel_robot_root_ang_vel - unified/actor_dof_pos - unified/actor_dof_vel - unified/actor_last_action output_dim: robot_action_dim critic: type: 
MLP obs_norm: enabled: true epsilon: 1.0e-8 # Reduced for better stability in DDP update_method: ema # ema or cumulative ema_momentum: 1.0e-4 update_at_train: true update_at_eval: false enable_clipping: true # Enable clipping for DDP stability clip_range: 10.0 # Reduced clip range for better stability sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout hidden_norm: rmsnorm layer_config: hidden_dims: - 512 - 512 - 512 activation: SiLU obs_schema: flattened_obs: seq_len: 1 terms: - unified/critic_velocity_command - unified/critic_rel_robot_root_lin_vel - unified/critic_rel_robot_root_ang_vel - unified/critic_root_rel_robot_bodylink_pos_flat - unified/critic_root_rel_robot_bodylink_rot_mat_flat - unified/critic_dof_pos - unified/critic_dof_vel - unified/critic_last_action output_dim: 1 ================================================ FILE: holomotion/config/motion_retargeting/gmr_to_holomotion.yaml ================================================ # @package _global_ defaults: - _self_ hydra: job: chdir: false io: src_dir: ??? robot_config: ??? out_root: ??? 
ref_dir: holomotion/src/motion_retargeting/utils processing: target_fps: 50 fast_interpolate: true skip_existing: true debug_mode: false ray: num_workers: 0 ray_address: "" naming: emit_prefixed: true emit_legacy: false preprocess: # Available stages: # ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys', # 'slicing','apply_butterworth_filter','add_padding','tagging'] # Empty list [] means no preprocessing stages applied pipeline: [] slicing: window_size: 500 overlap: 50 filtering: type: butterworth butter_cutoff_hz: 3.0 butter_order: 4 padding: # Robot config path for FK and default joint angles # If empty, uses io.robot_config robot_config_path: ${io.robot_config} # Duration of stand-still padding before/after motion (seconds) stand_still_time: 1.0 # Duration of transition between default pose and motion (seconds) transition_time: 1.5 tagging: # when empty, write tags to: /kinematic_tags.json output_json_path: "" ================================================ FILE: holomotion/config/motion_retargeting/holomotion_preprocess.yaml ================================================ # @package _global_ defaults: - _self_ hydra: job: chdir: false io: src_root: ??? out_root: ??? 
preprocess: # Available stages: # ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys', # 'slicing','apply_butterworth_filter','add_padding','tagging'] pipeline: [] slicing: window_size: 500 overlap: 50 filtering: type: butterworth butter_cutoff_hz: 3.0 butter_order: 4 padding: # Robot config path for FK and default joint angles robot_config_path: "" # Duration of stand-still padding before/after motion (seconds) stand_still_time: 1.0 # Duration of transition between default pose and motion (seconds) transition_time: 1.0 tagging: # when empty, write tags to: /kinematic_tags.json output_json_path: "" ray: enabled: true num_workers: 2 # 0 = use all available CPUs ray_address: "" # empty = local; otherwise connect to existing Ray cluster ================================================ FILE: holomotion/config/motion_retargeting/kinematic_filter.yaml ================================================ # @package _global_ defaults: - _self_ hydra: job: chdir: false io: dataset_root: "" # absolute path to dataset root containing kinematic_tags.json filtering: output_yaml: "" # optional; defaults to /excluded_kinematic_motion_names.yaml schema: across: union thresholds: kinematic_features.root_linear_speed.max: { op: ">", value: 10.0 } kinematic_features.root_angular_speed.max: { op: ">", value: 20.0 } kinematic_features.root_delta_z.max: { op: ">", value: 2.0 } kinematic_features.jerk.max: { op: ">", value: 2000.0 } ================================================ FILE: holomotion/config/motion_retargeting/pack_hdf5_database.yaml ================================================ # @package _global_ defaults: - _self_ - /robot: unitree/G1/29dof/29dof_training_isaaclab hydra: job: chdir: false # IO precomputed_npz_root: ??? hdf5_root: ??? 
# Runtime # Guidance for distributed JuiceFS training with millions of clips: # - chunks_t: Larger chunks (e.g. 2048-4096) reduce metadata overhead and improve sequential read performance on JuiceFS # Balance: larger chunks = better sequential I/O, but too large wastes memory; the value below (1024) is a smaller default, raise it for very large datasets chunks_t: 1024 # - compression: lzf provides fast decompression suitable for training workloads # Alternatives: gzip (better compression, slower), none (fastest but largest files) compression: lzf # lzf|gzip|none # - shard_target_gb: Larger shards (e.g. 5-10 GB) reduce shard count and metadata overhead for millions of clips # Balance: fewer shards = less metadata overhead, better for distributed access; more shards = better parallelism # The value below (1.0 GB) is a smaller default; for millions of clips, 5-10 GB reduces shard count while maintaining good parallelism shard_target_gb: 1.0 # - num_jobs: Parallel workers for packing process (should match available CPU cores) num_jobs: 16 debug_local_mode: false ================================================ FILE: holomotion/config/motion_retargeting/pack_hdf5_v2.yaml ================================================ # @package _global_ defaults: - _self_ - /robot: unitree/G1/29dof/29dof_training_isaaclab hydra: job: chdir: false # IO holomotion_npz_root: ??? hdf5_root: ???
# Runtime chunks_t: 1024 compression: lzf # lzf|gzip|none shard_target_gb: 1.0 shard_target_mode: h5_filesize # h5_filesize|npz_filesize|uncompressed_nbytes num_jobs: 16 debug_local_mode: false ================================================ FILE: holomotion/config/motion_retargeting/unitree_G1_29dof_retargeting.yaml ================================================ robot: humanoid_type: unitree/G1/29dof asset: smpl_dir: "assets/smpl" assetRoot: "./" assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml" training_mjcfName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml" video_dir: ${motion_npz_root}/video_rendering skip_frames: 3 # when skip frames=1, it means 30hz, when skip frames=2, it means 15hz, etc. show_markers: False max_workers: 12 ================================================ FILE: holomotion/config/mujoco_eval/sim2sim.yaml ================================================ # @package _group_ defaults: - _self_ enabled: false model_type: "holomotion" # Evaluation toggles headless: false record_video: false video_width: 1280 video_height: 720 video_fps: 30 camera_tracking: true camera_height_offset: 0.3 camera_distance: 4.0 camera_azimuth: 150.0 camera_elevation: -20.0 # Input/output robot_xml_path: null motion_npz_dir: null motion_npz_path: null ckpt_onnx_path: null ckpt_onnx_root_dir: null ckpt_onnx_names: null # Offline evaluation pipeline dump_npzs: true calc_per_clip_metrics: false generate_report: false metric_calculation: "per_clip" dof_mode: "23" failure_pos_err_thresh_m: 0.25 ray_actors_per_gpu: 16 ray_multi_ckpt_mode: "split" # Runtime use_gpu: true ================================================ FILE: holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml ================================================ # @package _global_ robot: humanoid_type: unitree/G1/29dof dof_obs_size: 29 actions_dim: 29 num_bodies: 30 num_extend_bodies: 0 undesired_contacts_regrex: 
"^(?!left_ankle_roll_link$)(?!right_ankle_roll_link$)(?!left_wrist_yaw_link$)(?!right_wrist_yaw_link$).+$" torso_name: "torso_link" anchor_body: "torso_link" key_bodies: - "pelvis" - "left_hip_roll_link" - "left_knee_link" - "left_ankle_pitch_link" - "right_hip_roll_link" - "right_knee_link" - "right_ankle_pitch_link" - "torso_link" - "left_shoulder_roll_link" - "left_elbow_link" - "left_wrist_yaw_link" - "right_shoulder_roll_link" - "right_elbow_link" - "right_wrist_yaw_link" key_dofs: - "left_knee_joint" - "right_knee_joint" - "left_elbow_joint" - "right_elbow_joint" dof_names: - "left_hip_pitch_joint" - "left_hip_roll_joint" - "left_hip_yaw_joint" - "left_knee_joint" - "left_ankle_pitch_joint" - "left_ankle_roll_joint" - "right_hip_pitch_joint" - "right_hip_roll_joint" - "right_hip_yaw_joint" - "right_knee_joint" - "right_ankle_pitch_joint" - "right_ankle_roll_joint" - "waist_yaw_joint" - "waist_roll_joint" - "waist_pitch_joint" - "left_shoulder_pitch_joint" - "left_shoulder_roll_joint" - "left_shoulder_yaw_joint" - "left_elbow_joint" - "left_wrist_roll_joint" - "left_wrist_pitch_joint" - "left_wrist_yaw_joint" - "right_shoulder_pitch_joint" - "right_shoulder_roll_joint" - "right_shoulder_yaw_joint" - "right_elbow_joint" - "right_wrist_roll_joint" - "right_wrist_pitch_joint" - "right_wrist_yaw_joint" # ========== Unified DOF Groupings ========== # Main anatomical groupings for DOF arm_dof_names: - "left_shoulder_pitch_joint" - "left_shoulder_roll_joint" - "left_shoulder_yaw_joint" - "left_elbow_joint" - "left_wrist_roll_joint" - "left_wrist_pitch_joint" - "left_wrist_yaw_joint" - "right_shoulder_pitch_joint" - "right_shoulder_roll_joint" - "right_shoulder_yaw_joint" - "right_elbow_joint" - "right_wrist_roll_joint" - "right_wrist_pitch_joint" - "right_wrist_yaw_joint" waist_dof_names: - "waist_yaw_joint" - "waist_roll_joint" - "waist_pitch_joint" leg_dof_names: - "left_hip_pitch_joint" - "left_hip_roll_joint" - "left_hip_yaw_joint" - "left_knee_joint" - 
"left_ankle_pitch_joint" - "left_ankle_roll_joint" - "right_hip_pitch_joint" - "right_hip_roll_joint" - "right_hip_yaw_joint" - "right_knee_joint" - "right_ankle_pitch_joint" - "right_ankle_roll_joint" # Side-specific groupings for DOF left_arm_dof_names: - "left_shoulder_pitch_joint" - "left_shoulder_roll_joint" - "left_shoulder_yaw_joint" - "left_elbow_joint" - "left_wrist_roll_joint" - "left_wrist_pitch_joint" - "left_wrist_yaw_joint" right_arm_dof_names: - "right_shoulder_pitch_joint" - "right_shoulder_roll_joint" - "right_shoulder_yaw_joint" - "right_elbow_joint" - "right_wrist_roll_joint" - "right_wrist_pitch_joint" - "right_wrist_yaw_joint" left_leg_dof_names: - "left_hip_pitch_joint" - "left_hip_roll_joint" - "left_hip_yaw_joint" - "left_knee_joint" - "left_ankle_pitch_joint" - "left_ankle_roll_joint" right_leg_dof_names: - "right_hip_pitch_joint" - "right_hip_roll_joint" - "right_hip_yaw_joint" - "right_knee_joint" - "right_ankle_pitch_joint" - "right_ankle_roll_joint" # Combined groupings for DOF (for backward compatibility and common usage) upper_body_dof_names: ${robot.arm_dof_names} # Alias for arm_dof_names lower_body_dof_names: - "left_hip_pitch_joint" - "left_hip_roll_joint" - "left_hip_yaw_joint" - "left_knee_joint" - "left_ankle_pitch_joint" - "left_ankle_roll_joint" - "right_hip_pitch_joint" - "right_hip_roll_joint" - "right_hip_yaw_joint" - "right_knee_joint" - "right_ankle_pitch_joint" - "right_ankle_roll_joint" - "waist_yaw_joint" - "waist_roll_joint" - "waist_pitch_joint" # ========== Unified Body Groupings ========== # Main anatomical groupings for bodies arm_body_names: - "left_shoulder_pitch_link" - "left_shoulder_roll_link" - "left_shoulder_yaw_link" - "left_elbow_link" - "left_wrist_roll_link" - "left_wrist_pitch_link" - "left_wrist_yaw_link" - "right_shoulder_pitch_link" - "right_shoulder_roll_link" - "right_shoulder_yaw_link" - "right_elbow_link" - "right_wrist_roll_link" - "right_wrist_pitch_link" - "right_wrist_yaw_link" 
head_hand_bodies: - "torso_link" - "left_wrist_yaw_link" - "right_wrist_yaw_link" torso_body_names: - "waist_yaw_link" - "waist_roll_link" - "torso_link" leg_body_names: - "left_hip_pitch_link" - "left_hip_roll_link" - "left_hip_yaw_link" - "left_knee_link" - "left_ankle_pitch_link" - "left_ankle_roll_link" - "right_hip_pitch_link" - "right_hip_roll_link" - "right_hip_yaw_link" - "right_knee_link" - "right_ankle_pitch_link" - "right_ankle_roll_link" # Side-specific groupings for bodies left_arm_body_names: - "left_shoulder_pitch_link" - "left_shoulder_roll_link" - "left_shoulder_yaw_link" - "left_elbow_link" - "left_wrist_roll_link" - "left_wrist_pitch_link" - "left_wrist_yaw_link" right_arm_body_names: - "right_shoulder_pitch_link" - "right_shoulder_roll_link" - "right_shoulder_yaw_link" - "right_elbow_link" - "right_wrist_roll_link" - "right_wrist_pitch_link" - "right_wrist_yaw_link" left_leg_body_names: - "left_hip_pitch_link" - "left_hip_roll_link" - "left_hip_yaw_link" - "left_knee_link" - "left_ankle_pitch_link" - "left_ankle_roll_link" right_leg_body_names: - "right_hip_pitch_link" - "right_hip_roll_link" - "right_hip_yaw_link" - "right_knee_link" - "right_ankle_pitch_link" - "right_ankle_roll_link" body_names: - "pelvis" - "left_hip_pitch_link" - "left_hip_roll_link" - "left_hip_yaw_link" - "left_knee_link" - "left_ankle_pitch_link" - "left_ankle_roll_link" - "right_hip_pitch_link" - "right_hip_roll_link" - "right_hip_yaw_link" - "right_knee_link" - "right_ankle_pitch_link" - "right_ankle_roll_link" - "waist_yaw_link" - "waist_roll_link" - "torso_link" - "left_shoulder_pitch_link" - "left_shoulder_roll_link" - "left_shoulder_yaw_link" - "left_elbow_link" - "left_wrist_roll_link" - "left_wrist_pitch_link" - "left_wrist_yaw_link" - "right_shoulder_pitch_link" - "right_shoulder_roll_link" - "right_shoulder_yaw_link" - "right_elbow_link" - "right_wrist_roll_link" - "right_wrist_pitch_link" - "right_wrist_yaw_link" init_state: pos: [0.0, 0.0, 0.8] # x,y,z [m] 
rot: [0.0, 0.929, 0.341, 0.298] # x,y,z,w [quat] lin_vel: [0.0, 0.0, 0.0] # x,y,z [m/s] ang_vel: [0.0, 0.0, 0.0] # x,y,z [rad/s] default_joint_angles: # = target angles [rad] when action = 0.0 left_hip_pitch_joint: -0.312 left_hip_roll_joint: 0.0 left_hip_yaw_joint: 0.0 left_knee_joint: 0.669 left_ankle_pitch_joint: -0.363 left_ankle_roll_joint: 0.0 right_hip_pitch_joint: -0.312 right_hip_roll_joint: 0.0 right_hip_yaw_joint: 0.0 right_knee_joint: 0.669 right_ankle_pitch_joint: -0.363 right_ankle_roll_joint: 0.0 waist_yaw_joint: 0. waist_roll_joint: 0. waist_pitch_joint: 0.1 left_shoulder_pitch_joint: 0.2 left_shoulder_roll_joint: 0.2 left_shoulder_yaw_joint: 0.0 left_elbow_joint: 0.6 left_wrist_roll_joint: 0.0 left_wrist_pitch_joint: 0.0 left_wrist_yaw_joint: 0.0 right_shoulder_pitch_joint: 0.2 right_shoulder_roll_joint: -0.2 right_shoulder_yaw_joint: 0.0 right_elbow_joint: 0.6 right_wrist_roll_joint: 0.0 right_wrist_pitch_joint: 0.0 right_wrist_yaw_joint: 0.0 actuators: actuator_type: unitree_erfi # implicit, unitree, or unitree_erfi ema_filter_enabled: false ema_filter_alpha: 1.0 all_joints: joint_names_expr: - ".*_hip_yaw_joint" - ".*_hip_roll_joint" - ".*_hip_pitch_joint" - ".*_knee_joint" - ".*_ankle_pitch_joint" - ".*_ankle_roll_joint" - "waist_yaw_joint" - "waist_roll_joint" - "waist_pitch_joint" - ".*_shoulder_pitch_joint" - ".*_shoulder_roll_joint" - ".*_shoulder_yaw_joint" - ".*_elbow_joint" - ".*_wrist_roll_joint" - ".*_wrist_pitch_joint" - ".*_wrist_yaw_joint" effort_limit_sim: ".*_hip_yaw_joint": 88.0 ".*_hip_roll_joint": 139.0 ".*_hip_pitch_joint": 88.0 ".*_knee_joint": 139.0 ".*_ankle_pitch_joint": 50.0 ".*_ankle_roll_joint": 50.0 "waist_yaw_joint": 88.0 "waist_roll_joint": 50.0 "waist_pitch_joint": 50.0 ".*_shoulder_pitch_joint": 25.0 ".*_shoulder_roll_joint": 25.0 ".*_shoulder_yaw_joint": 25.0 ".*_elbow_joint": 25.0 ".*_wrist_roll_joint": 25.0 ".*_wrist_pitch_joint": 5.0 ".*_wrist_yaw_joint": 5.0 velocity_limit_sim: ".*_hip_yaw_joint": 32.0 
".*_hip_roll_joint": 20.0 ".*_hip_pitch_joint": 32.0 ".*_knee_joint": 20.0 ".*_ankle_pitch_joint": 37.0 ".*_ankle_roll_joint": 37.0 "waist_yaw_joint": 32.0 "waist_roll_joint": 37.0 "waist_pitch_joint": 37.0 ".*_shoulder_pitch_joint": 37.0 ".*_shoulder_roll_joint": 37.0 ".*_shoulder_yaw_joint": 37.0 ".*_elbow_joint": 37.0 ".*_wrist_roll_joint": 37.0 ".*_wrist_pitch_joint": 22.0 ".*_wrist_yaw_joint": 22.0 stiffness: ".*_hip_pitch_joint": 40.17923847 ".*_hip_roll_joint": 99.09842778 ".*_hip_yaw_joint": 40.17923847 ".*_knee_joint": 99.09842778 ".*_ankle_pitch_joint": 28.50124620 ".*_ankle_roll_joint": 28.50124620 "waist_yaw_joint": 40.17923847 "waist_roll_joint": 28.50124620 "waist_pitch_joint": 28.50124620 ".*_shoulder_pitch_joint": 14.25062310 ".*_shoulder_roll_joint": 14.25062310 ".*_shoulder_yaw_joint": 14.25062310 ".*_elbow_joint": 14.25062310 ".*_wrist_roll_joint": 14.25062309787429 ".*_wrist_pitch_joint": 16.77832748089279 ".*_wrist_yaw_joint": 16.77832748089279 damping: ".*_hip_pitch_joint": 2.55788977 ".*_hip_roll_joint": 6.30880185 ".*_hip_yaw_joint": 2.55788977 ".*_knee_joint": 6.30880185 ".*_ankle_pitch_joint": 1.81444569 ".*_ankle_roll_joint": 1.81444569 "waist_yaw_joint": 2.55788977 "waist_roll_joint": 1.81444569 "waist_pitch_joint": 1.81444569 ".*_shoulder_pitch_joint": 0.90722284 ".*_shoulder_roll_joint": 0.90722284 ".*_shoulder_yaw_joint": 0.90722284 ".*_elbow_joint": 0.90722284 ".*_wrist_roll_joint": 0.907222843292423 ".*_wrist_pitch_joint": 1.06814150219 ".*_wrist_yaw_joint": 1.06814150219 armature: ".*_hip_pitch_joint": 0.010177520 ".*_hip_roll_joint": 0.025101925 ".*_hip_yaw_joint": 0.010177520 ".*_knee_joint": 0.025101925 ".*_ankle_pitch_joint": 0.007219450 ".*_ankle_roll_joint": 0.007219450 "waist_yaw_joint": 0.010177520 "waist_roll_joint": 0.007219450 "waist_pitch_joint": 0.007219450 ".*_shoulder_pitch_joint": 0.003609725 ".*_shoulder_roll_joint": 0.003609725 ".*_shoulder_yaw_joint": 0.003609725 ".*_elbow_joint": 0.003609725 
".*_wrist_roll_joint": 0.003609725 ".*_wrist_pitch_joint": 0.00425 ".*_wrist_yaw_joint": 0.00425 action_scale: ".*_hip_pitch_joint": 0.548 ".*_hip_roll_joint": 0.351 ".*_hip_yaw_joint": 0.548 ".*_knee_joint": 0.351 ".*_ankle_pitch_joint": 0.439 ".*_ankle_roll_joint": 0.439 "waist_yaw_joint": 0.548 "waist_roll_joint": 0.439 "waist_pitch_joint": 0.439 ".*_shoulder_pitch_joint": 0.439 ".*_shoulder_roll_joint": 0.439 ".*_shoulder_yaw_joint": 0.439 ".*_elbow_joint": 0.439 ".*_wrist_roll_joint": 0.439 ".*_wrist_pitch_joint": 0.075 ".*_wrist_yaw_joint": 0.075 dof_sign_by_name: left_hip_pitch_joint: 1.0 left_hip_roll_joint: -1.0 left_hip_yaw_joint: -1.0 left_knee_joint: 1.0 left_ankle_pitch_joint: 1.0 left_ankle_roll_joint: -1.0 right_hip_pitch_joint: 1.0 right_hip_roll_joint: -1.0 right_hip_yaw_joint: -1.0 right_knee_joint: 1.0 right_ankle_pitch_joint: 1.0 right_ankle_roll_joint: -1.0 waist_yaw_joint: -1.0 waist_roll_joint: -1.0 waist_pitch_joint: 1.0 left_shoulder_pitch_joint: 1.0 left_shoulder_roll_joint: -1.0 left_shoulder_yaw_joint: -1.0 left_elbow_joint: 1.0 left_wrist_roll_joint: -1.0 left_wrist_pitch_joint: 1.0 left_wrist_yaw_joint: -1.0 right_shoulder_pitch_joint: 1.0 right_shoulder_roll_joint: -1.0 right_shoulder_yaw_joint: -1.0 right_elbow_joint: 1.0 right_wrist_roll_joint: -1.0 right_wrist_pitch_joint: 1.0 right_wrist_yaw_joint: -1.0 asset: collapse_fixed_joints: True replace_cylinder_with_capsule: True flip_visual_attachments: False max_angular_velocity: 1000. max_linear_velocity: 1000. density: 0.001 angular_damping: 0. linear_damping: 0. 
asset_root: "./" urdf_file: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.urdf" assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml" fix_base_link: false force_usd_conversion: true extend_config: [] motion: asset: assetRoot: "./" assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml" sampling_strategy: ${algo.config.sampling_strategy} weighted_bin: ${algo.config.weighted_bin} curriculum: ${algo.config.curriculum} dump_sampled_motion_keys: false dump_sampled_motion_keys_interval: 1 dump_sampled_motion_keys_dir: "sampled_motion_cache_keys" max_frame_length: 300 # 6s min_frame_length: 50 # 1s handpicked_motion_names: null excluded_motion_names: null world_frame_normalization: true backend: "hdf5_v2" # hdf5, hdf5_v2 train_hdf5_roots: ${train_hdf5_roots} val_hdf5_roots: ${train_hdf5_roots} dof_names: ${robot.dof_names} body_names: ${robot.body_names} key_bodies: ${robot.key_bodies} extend_config: ${robot.extend_config} dataloader: num_workers: 2 prefetch_factor: 1 pin_memory: true persistent_workers: false timeout: 600 fk_robot_file_path: ${robot.asset.urdf_file} fk_vel_smoothing_sigma: 2.0 online_filter: enabled: false butter_order: 4 butter_cutoff_hz_pool: [] cache: batch_progress_bar: false max_num_clips: ${num_envs} # Batch size for motion clips device: "cuda" # "cuda" or "cpu"; cuda stages on GPU swap_interval_steps: ${robot.motion.max_frame_length} # Swap cache every N steps allowed_prefixes: - "ref_" - "ft_ref_" ================================================ FILE: holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab_s100.yaml ================================================ # @package _global_ defaults: - unitree/G1/29dof/29dof_training_isaaclab robot: asset: urdf_file: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0_s100.urdf" ================================================ FILE: holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_mlp.yaml 
================================================ # @package _global_ defaults: - /training: train_base - /algo: ppo - /robot: unitree/G1/29dof/29dof_training_isaaclab - /env: motion_tracking - /env/terminations: termination_motion_tracking - /env/observations: motion_tracking/obs_motion_tracking_mlp - /env/rewards: motion_tracking/rew_motion_tracking - /env/domain_randomization: domain_rand_medium - /env/terrain: isaaclab_rough - /modules: motion_tracking/motion_tracking_mlp project_name: HoloMotionMotrackV1.2 # checkpoint: ??? train_hdf5_roots: - data/h5v2_datasets/AMASS_test ================================================ FILE: holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_tf-moe.yaml ================================================ # @package _global_ defaults: - /training: train_base - /algo: ppo_tf - /robot: unitree/G1/29dof/29dof_training_isaaclab - /env: motion_tracking - /env/terminations: termination_motion_tracking - /env/observations: motion_tracking/obs_motion_tracking_tf-moe - /env/rewards: motion_tracking/rew_motion_tracking - /env/domain_randomization: domain_rand_medium - /env/terrain: isaaclab_rough - /modules: motion_tracking/motion_tracking_tf-moe project_name: HoloMotionMotrackV1.2 # checkpoint: ??? train_hdf5_roots: - data/h5v2_datasets/AMASS_test ================================================ FILE: holomotion/config/training/train_base.yaml ================================================ # @package _global_ defaults: - _self_ - /mujoco_eval: sim2sim project_name: ??? experiment_name: ??? num_envs: ??? headless: ??? motion_h5_path: ??? checkpoint: null num_processes: ??? main_process: ??? process_id: ??? timestamp: ${now:%Y%m%d_%H%M%S} base_dir: logs experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name} save_dir: ${experiment_dir}/.hydra output_dir: ${experiment_dir}/output experiment_save_dir: ??? 
================================================ FILE: holomotion/config/training/velocity_tracking/train_g1_29dof_velocity_tracking_mlp.yaml ================================================ # @package _global_ defaults: - /training: train_base - /algo: ppo - /robot: unitree/G1/29dof/29dof_training_isaaclab - /env: velocity_tracking - /env/terminations: termination_velocity_tracking - /env/observations: velocity_tracking/obs_velocity_tracking - /env/rewards: velocity_tracking/rew_velocity_tracking - /env/domain_randomization: domain_rand_medium - /env/terrain: isaaclab_rough - /modules: velocity_tracking/velocity_tracking_mlp project_name: HoloMotionVelocityTrackingG1 # checkpoint: ??? train_hdf5_roots: - /horizon-bucket/robot_lab/users/maiyue01.chen/h5_datasets/h5_unitree_walk_20160119 env: config: commands: base_velocity: params: resampling_time_range: [3.0, 6.0] rel_standing_envs: 0.2 rel_yaw_envs: 0.2 heading_command: false heading_control_stiffness: 0.5 ranges: lin_vel_x: [-1.0, 1.0] lin_vel_y: [-0.5, 0.5] ang_vel_z: [-1.0, 1.0] heading: [-3.14, 3.14] algo: config: symmetry_loss: enabled: true coef: 0.1 robot: init_state: default_joint_angles: waist_pitch_joint: 0.0
================================================
FILE: holomotion/scripts/data_curation/convert_to_amass.sh
================================================
# Run data_smplify.py on a raw-data root.
# Usage: bash holomotion/scripts/data_curation/convert_to_amass.sh [DATA_ROOT]
source train.env

# Raw-data root; a non-empty first argument overrides the default.
DATA_ROOT="${1:-./data/raw_datasets}"

"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/data_curation/data_smplify.py \
  --data_root "$DATA_ROOT"
================================================
FILE: holomotion/scripts/data_curation/filter_smpl_data.sh
================================================
# Run label_data.py over the selected dataset label lists, then run
# filter.py once per dataset.
# Usage: bash holomotion/scripts/data_curation/filter_smpl_data.sh [-l "name1 name2 ..."]
source train.env

# Default list of dataset label names (each maps to <name>.jsonl below).
default_jsonl_list=("humanact12" "MotionX" "OMOMO" "ZJU_Mocap" "amass")
jsonl_list=("${default_jsonl_list[@]}")

# Parse command-line options.
while getopts "l:" opt; do
  case "$opt" in
    l)
      # A user-supplied, space-separated list replaces the default.
      IFS=' ' read -r -a jsonl_list <<<"$OPTARG"
      ;;
    *)
      echo "Usage: $0 [-l \"file1 file2 ...\"]" >&2
      exit 1
      ;;
  esac
done

echo "Running label_data.py first..."
"${Train_CONDA_PREFIX}/bin/python" \
  ./holomotion/src/data_curation/filter/label_data.py \
  --jsonl_list "${jsonl_list[@]}"
echo "label_data.py finished."
echo "=============================="

for json in "${jsonl_list[@]}"; do
  echo "Processing $json"
  # The "amass" data sits one directory deeper than the other datasets.
  if [[ "$json" == "amass" ]]; then
    parent_folder="./data/amass_compatible_datasets/amass"
  else
    parent_folder="./data/amass_compatible_datasets"
  fi
  # Per-dataset label file and exclusion YAML path passed to filter.py.
  json_path="./data/dataset_labels/${json}.jsonl"
  yaml_path="./holomotion/config/data_curation/${json}_excluded.yaml"
  # Run the filter script for this dataset.
  "${Train_CONDA_PREFIX}/bin/python" \
    ./holomotion/src/data_curation/filter/filter.py \
    --parent_folder "$parent_folder" \
    --json_path "$json_path" \
    --yaml_path "$yaml_path"
  echo "Finished $json"
  echo "-----------------------"
done
echo "All done"
================================================
FILE: holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh
================================================
# Run the GVHMR video-to-SMPL wrapper on a folder of videos, then copy each
# per-clip smpl.npz result into one flat output directory.
# Replace the holomotion_abs_path placeholders before running.
CONDA_BASE=$(conda info --base)
export CONDA_BASE
export Train_CONDA_PREFIX="$CONDA_BASE/envs/gvhmr"

video_folder_root="holomotion_abs_path/data/video_data"
npz_data_root="holomotion_abs_path/data/gvhmr_converted/gvhmr_result"
out_dir="holomotion_abs_path/data/gvhmr_converted/collected_smpl"

# The wrapper is launched from inside the GVHMR checkout (note the
# ../../ script path); stop instead of running from the wrong directory.
cd thirdparties/GVHMR/ || exit 1

"$Train_CONDA_PREFIX/bin/python" \
  ../../holomotion/src/data_curation/video_to_smpl_gvhmr.py \
  --folder="${video_folder_root}" \
  --output_root="${npz_data_root}" \
  -s

mkdir -p "${out_dir}"
for subdir in "${npz_data_root}"/*; do
  # Also skips the literal pattern left behind when nothing matches.
  if [[ ! -d "${subdir}" ]]; then
    continue
  fi
  sub_name=$(basename "${subdir}")
  src_npz="${subdir}/smpl.npz"
  if [[ ! -f "${src_npz}" ]]; then
    echo "[SKIP] ${sub_name}: smpl.npz not found"
    continue
  fi
  dst_npz="${out_dir}/${sub_name}_smpl.npz"
  cp -f -- "${src_npz}" "${dst_npz}"
  echo "[COPY] ${src_npz} -> ${dst_npz}"
done
================================================
FILE: holomotion/scripts/data_curation/visualize_smpl_npz.sh
================================================
# Run visualize_smpl_npz.py with the gvhmr conda environment.
# NOTE: the script path is relative (../../holomotion/...), so this is
# presumably meant to be launched from thirdparties/GVHMR/.
CONDA_BASE=$(conda info --base)
export CONDA_BASE
export Train_CONDA_PREFIX="$CONDA_BASE/envs/gvhmr"

"$Train_CONDA_PREFIX/bin/python" ../../holomotion/src/data_curation/visualize_smpl_npz.py
================================================
FILE: holomotion/scripts/evaluation/calc_offline_eval_metrics.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Compute offline evaluation metrics from a directory of .npz files with
# holomotion/src/evaluation/metrics.py.
# Each setting can be overridden from the environment, e.g.
#   npz_dir=/path/to/npzs dof_mode=29 bash holomotion/scripts/evaluation/calc_offline_eval_metrics.sh
source train.env

npz_dir="${npz_dir:-your_npz_dir}"
dataset_suffix="${dataset_suffix:-HoloMotion_eval}"
metric_calculation="${metric_calculation:-per_clip}" # Options: "per_clip" or "per_frame"
dof_mode="${dof_mode:-23}" # Options: "29" (all DoFs) or "23" (a 23-DoF subset; see metrics.py)

# Fail fast with a clear message instead of handing metrics.py a
# placeholder or missing directory.
if [[ ! -d "$npz_dir" ]]; then
  echo "npz_dir is not a directory: $npz_dir" >&2
  exit 1
fi

"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/evaluation/metrics.py \
  --npz_dir="${npz_dir}" \
  --dataset_suffix="${dataset_suffix}" \
  --failure_pos_err_thresh_m=0.25 \
  --metric_calculation="${metric_calculation}" \
  --dof_mode="${dof_mode}"
================================================
FILE: holomotion/scripts/evaluation/eval_motion_tracking.sh
================================================
#!/bin/bash
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Evaluate a motion-tracking checkpoint with the evaluation/eval_isaaclab
# config, launched through accelerate.
source train.env

export CUDA_VISIBLE_DEVICES="0"

HEADLESS=true
CONFIG_NAME="eval_isaaclab"
CKPT_PATH="logs/HoloMotionMotrackV1.2/your_log_dir/model_xxx.pt"
# Hydra list literal: it must stay quoted when expanded so the shell never
# treats the brackets as a glob pattern.
eval_h5_dataset_path="['data/h5v2_datasets/lafan1']"
num_envs=4

"${Train_CONDA_PREFIX}/bin/accelerate" launch \
  holomotion/src/evaluation/eval_motion_tracking_single.py \
  --config-name="evaluation/${CONFIG_NAME}" \
  headless="${HEADLESS}" \
  num_envs="${num_envs}" \
  export_policy=true \
  dump_npzs=true \
  calc_per_clip_metrics=true \
  generate_report=true \
  motion_h5_path="${eval_h5_dataset_path}" \
  +use_kv_cache=true \
  export_only=false \
  checkpoint="${CKPT_PATH}" \
  project_name="HoloMotionMoTrack"
================================================
FILE: holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Run MuJoCo sim2sim evaluation of an ONNX policy on a motion .npz file.
source train.env

export CUDA_VISIBLE_DEVICES="0"
export HEADLESS=false

# Headless runs use OSMesa and record video; windowed runs use EGL and do
# not record. Compare as a string rather than executing $HEADLESS as a
# command.
if [[ "$HEADLESS" == true ]]; then
  export MUJOCO_GL="osmesa"
  export RECORD_VIDEO=true
else
  export MUJOCO_GL="egl"
  export RECORD_VIDEO=false
fi

model_type="${model_type:-holomotion}"
robot_xml_path="assets/robots/unitree/G1/29dof/scene_29dof.xml"
ONNX_PATH="your_onnx_model.onnx"
# Exported so the ${oc.env:motion_npz_path} interpolation below can read it.
export motion_npz_path="your_npz.npz"

# The motion_npz_path override is single-quoted on purpose: the
# ${oc.env:...} interpolation must reach Hydra unexpanded by the shell.
"${Train_CONDA_PREFIX}/bin/python" holomotion/src/evaluation/eval_mujoco_sim2sim.py \
  record_video="$RECORD_VIDEO" \
  headless="$HEADLESS" \
  camera_tracking=true \
  camera_distance=7.0 \
  +model_type="${model_type}" \
  use_gpu=true \
  dump_npzs=true \
  dump_onnx_io_npy=false \
  calc_per_clip_metrics=true \
  generate_report=true \
  ray_actors_per_gpu=12 \
  policy_action_delay_step=0 \
  action_delay_type=step \
  +ckpt_onnx_path="$ONNX_PATH" \
  +motion_npz_path='${oc.env:motion_npz_path}' \
  robot_xml_path="${robot_xml_path}"
================================================
FILE: holomotion/scripts/evaluation/eval_velocity_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Evaluate a velocity-tracking checkpoint (windowed, one env by default).
source train.env

export CUDA_VISIBLE_DEVICES="0"

config_name="eval_velocity_tracking"
num_envs=1
ckpt_path="logs/HoloMotionVelocityTracking/xxxxx-train_g1_29dof_velocity_tracking/model_xxx.pt"

# The resampling_time_range override is quoted so the shell cannot
# glob-expand its [3,5] list literal before Hydra sees it.
"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/evaluation/eval_velocity_tracking.py \
  --config-name="evaluation/${config_name}" \
  project_name="HoloMotionVelocityTracking" \
  num_envs="${num_envs}" \
  headless=false \
  experiment_name="${config_name}" \
  checkpoint="${ckpt_path}" \
  '+env.config.commands.base_velocity.params.resampling_time_range=[3,5]'
================================================
FILE: holomotion/scripts/evaluation/mean_process_5metrics.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse import csv import glob import json import os import re import numpy as np import pandas as pd from tabulate import tabulate # 需要统计的指标 Key METRICS = [ "mpjpe_g", "mpjpe_l", "whole_body_joints_dist", "root_vel_error", "root_r_error", "root_p_error", "root_y_error", "root_height_error", "mean_dof_vel", "mean_dof_acc", "mean_dof_torque", "mean_action_rate", "success", "mean_torque_jump_norm", "p95_torque_jump_norm", "mean_torque_jump_ratio", "p95_torque_jump_ratio", ] # 表头映射 (Json Key -> 表格显示名称) COLUMN_MAPPING = { "mpjpe_g": "Global Bodylink Pos Err", "mpjpe_l": "Local Bodylink Pos Err", "whole_body_joints_dist": "Dof Position Err", "root_vel_error": "Root Vel Err", "root_r_error": "Root Roll Err", "root_p_error": "Root Pitch Err", "root_y_error": "Root Yaw Err", "root_height_error": "Root Height Err", "mean_dof_vel": "Mean Dof Vel", "mean_dof_acc": "Mean Dof Acc", "mean_dof_torque": "Mean Dof Torque", "mean_action_rate": "Mean Action Rate", "success": "Success Rate", "mean_torque_jump_norm": "Mean Torque Jump Norm", "p95_torque_jump_norm": "P95 Torque Jump Norm", "mean_torque_jump_ratio": "Mean Torque Jump Ratio", "p95_torque_jump_ratio": "P95 Torque Jump Ratio", } def get_dataset_name(motion_key): if not isinstance(motion_key, str): return "Unknown" match_old = re.search(r"clips_([a-zA-Z0-9]+)_", motion_key) if match_old: return match_old.group(1) match_new = re.search(r"v1.1_eval_([a-zA-Z0-9]+)_", motion_key) if match_new: return match_new.group(1) return motion_key.split("_")[0] def process_data(folder_path): folder_path = os.path.expanduser(folder_path) search_pattern = os.path.join(folder_path, "*.json") json_files = glob.glob(search_pattern) json_files = [ file for file in json_files if "batch_" not in os.path.basename(file) ] if not json_files: raise FileNotFoundError( f"No .json files found in directory: {folder_path}" ) all_records = [] for file_path in json_files: model_name = os.path.splitext(os.path.basename(file_path))[0] with 
open(file_path, "r", encoding="utf-8") as f: data = json.load(f) # 结构兼容性处理 if isinstance(data, dict) and "per_clip" in data: clips_data = data["per_clip"] elif isinstance(data, list): clips_data = data elif isinstance(data, dict) and "motion_key" in data: clips_data = [data] else: continue for entry in clips_data: if "motion_key" not in entry: continue dataset_name = get_dataset_name(entry["motion_key"]) record = {"Method": model_name, "Dataset": dataset_name} for metric in METRICS: val = entry.get(metric, None) if val is not None: record[metric] = val all_records.append(record) if not all_records: raise ValueError( f"No valid per-clip metric records extracted from: {folder_path}" ) df = pd.DataFrame(all_records) df = df.reindex(columns=["Method", "Dataset", *METRICS]) grouped_ds = df.groupby(["Method", "Dataset"])[METRICS] df_mean_ds = grouped_ds.mean().reset_index() df_median_ds = grouped_ds.median().reset_index() # Macro-Mean calculation df_mean_total = ( df_mean_ds.groupby(["Method"])[METRICS].mean().reset_index() ) # Macro-Median calculation df_median_total = ( df_median_ds.groupby(["Method"])[METRICS].mean().reset_index() ) df_mean_total["Dataset"] = "Total (Macro)" df_median_total["Dataset"] = "Total (Macro)" final_mean = pd.concat([df_mean_ds, df_mean_total], ignore_index=True) final_median = pd.concat( [df_median_ds, df_median_total], ignore_index=True ) return final_mean, final_median def highlight_best(val, best_val): """Return a highlighted HTML string when value is best.""" if val is None or pd.isna(val): return str(val) val_float = float(val) best_val_float = float(best_val) formatted_val = f"{val_float:.4f}" if np.isclose(val_float, best_val_float, atol=1e-6): return f"{formatted_val}" return formatted_val def generate_report( df, folder_path, file_name="result_table_mean.md", title="Evaluation Results (Mean)", ): out_md = os.path.join(folder_path, file_name) all_datasets = df["Dataset"].unique().tolist() # 排序:将 Total 放到最后 total_key = "Total (Macro)" 
if total_key in all_datasets: all_datasets.remove(total_key) all_datasets.sort() all_datasets.append(total_key) else: all_datasets.sort() md_content_accumulator = f"# {title}\n\n" md_content_accumulator += ( "> **Note:** 'Total (Macro)' represents the **Macro-Average**, " "calculated as the arithmetic mean of the scores across all datasets, " "treating each dataset equally regardless of sample size.\n\n" ) for ds_name in all_datasets: sub_df = df[df["Dataset"] == ds_name].copy() for metric in METRICS: if metric in sub_df.columns: if metric == "success": best_val = sub_df[metric].max() else: best_val = sub_df[metric].min() sub_df[metric] = sub_df[metric].apply( lambda x, best_val=best_val: highlight_best(x, best_val) ) sub_df = sub_df.drop(columns=["Dataset"]) sub_df.rename(columns=COLUMN_MAPPING, inplace=True) cols = list(sub_df.columns) if "Method" in cols: cols.insert(0, cols.pop(cols.index("Method"))) sub_df = sub_df[cols] md_content_accumulator += f"### Dataset: {ds_name}\n" # 使用 to_markdown 生成表格 table_str = sub_df.to_markdown(index=False) md_content_accumulator += table_str + "\n\n" with open(out_md, "w", encoding="utf-8") as f: f.write(md_content_accumulator) return os.path.abspath(out_md) def _format_metric_values_for_cli(sub_df: pd.DataFrame) -> pd.DataFrame: cli_df = sub_df.copy() for metric in METRICS: if metric in cli_df.columns: cli_df[metric] = cli_df[metric].apply( lambda x: f"{float(x):.4f}" if pd.notna(x) else "nan" ) return cli_df def _print_cli_tables(df: pd.DataFrame, title: str, folder_path: str) -> None: total_key = "Total (Macro)" all_datasets = df["Dataset"].unique().tolist() dataset_order = sorted([d for d in all_datasets if d != total_key]) if total_key in all_datasets: dataset_order.append(total_key) merged_df = df.copy() merged_df["Dataset"] = pd.Categorical( merged_df["Dataset"], categories=dataset_order, ordered=True ) merged_df = merged_df.sort_values( by=["Dataset", "Method"], kind="stable" ).reset_index(drop=True) 
merged_df["Dataset"] = merged_df["Dataset"].astype(str) merged_df = _format_metric_values_for_cli(merged_df) merged_df.rename(columns=COLUMN_MAPPING, inplace=True) metric_display_cols = [ COLUMN_MAPPING[m] for m in METRICS if COLUMN_MAPPING[m] in merged_df ] # table_cols = ["Dataset", "Method"] + metric_display_cols table_cols = ["Dataset"] + metric_display_cols merged_df = merged_df[table_cols] output_tsv_path = os.path.join( folder_path, "sub_dataset_macro_mean_metrics.tsv" ) with open(output_tsv_path, "w", encoding="utf-8", newline="") as f: writer = csv.writer(f, delimiter="\t", lineterminator="\n") writer.writerow(merged_df.columns.tolist()) writer.writerows(merged_df.values.tolist()) table_str = tabulate( merged_df.values.tolist(), headers=merged_df.columns.tolist(), tablefmt="simple_outline", colalign=("left",) * len(merged_df.columns), ) block = ( "\n" + "=" * 80 + f"\n{title}\n" + "=" * 80 + f"\n\n{table_str}\n" + "=" * 80 + "\n" ) print(block) metric_log_path = os.path.join(folder_path, "metric.log") with open(metric_log_path, "a", encoding="utf-8") as file: file.write(block) def generate_macro_mean_report_from_json_dir(folder_path: str) -> str: mean_df, _ = process_data(folder_path) report_path = generate_report( df=mean_df, folder_path=folder_path, file_name="result_table_macro_mean.md", title="Evaluation Results (Macro-Averaging Mean)", ) _print_cli_tables( df=mean_df, title="DATASET-WISE METRICS (MACRO-AVERAGING MEAN)", folder_path=folder_path, ) return report_path if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--dir", type=str, help="json文件夹路径") args = parser.parse_args() out_md = generate_macro_mean_report_from_json_dir(args.dir) print(f"报告已生成: {out_md}") ================================================ FILE: holomotion/scripts/evaluation/multi_model_metrics_analysis.sh ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

# Build the multi-model metrics report from the JSON files under
# metrics_json_dir (the path is relative to the working directory).
source train.env

metrics_json_dir="logs/Holomotion/metrics_output_dataset"

${Train_CONDA_PREFIX}/bin/python \
    holomotion/src/evaluation/multi_model_metrics_report.py \
    --json_dir="$metrics_json_dir"

================================================
FILE: holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh
================================================
#!/usr/bin/env bash
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally adapted from the [GMR] repository:
# https://github.com/YanjieZe/GMR/blob/master/general_motion_retargeting/motion_retarget.py
#

# Patch GMR's GeneralMotionRetargeting class in place.
#
# Usage: apply_gmr_motion_retarget_patch.sh [TARGET_FILE]
#   TARGET_FILE defaults to
#   <cwd>/thirdparties/GMR/general_motion_retargeting/motion_retarget.py,
#   so run it from the repository root when relying on the default.
#
# The embedded Python program:
#   * exits 0 without changes if every PATCH_MARKERS string is already
#     present in the target;
#   * otherwise replaces the __init__, setup_retarget_configuration and
#     retarget methods with the PATCHED_* snippets. The retarget
#     replacement also adds _is_arm_body, _set_first_frame_arm_task_costs
#     and _solve_task_group;
#   * re-parses the patched source before writing it back.
# NOTE(review): this invokes whichever `python` is first on PATH, unlike the
# other scripts that use ${Train_CONDA_PREFIX}/bin/python — confirm intended.
set -euo pipefail

REPO_ROOT="$(pwd)"
TARGET_FILE="${1:-$REPO_ROOT/thirdparties/GMR/general_motion_retargeting/motion_retarget.py}"

if [[ ! -f "$TARGET_FILE" ]]; then
    echo "Target file not found: $TARGET_FILE" >&2
    exit 1
fi

# The quoted heredoc delimiter ('PY') keeps bash from expanding anything
# inside the Python program.
python - "$TARGET_FILE" <<'PY'
from pathlib import Path
import ast
import sys
import textwrap

# Strings that only occur in an already-patched file; if all are present the
# patch is skipped.
PATCH_MARKERS = (
    "self.first_frame_damping = max(float(damping), 2.0)",
    "self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)",
    "def _solve_task_group(",
)

PATCHED_INIT = """
    def __init__(
        self,
        src_human: str,
        tgt_robot: str,
        actual_human_height: float = None,
        solver: str="daqp", # change from "quadprog" to "daqp".
        damping: float=5e-1, # change from 1e-1 to 1e-2.
        verbose: bool=True,
        use_velocity_limit: bool=False,
    ) -> None:

        # load the robot model
        self.xml_file = str(ROBOT_XML_DICT[tgt_robot])
        if verbose:
            print("Use robot model: ", self.xml_file)
        self.model = mj.MjModel.from_xml_path(self.xml_file)

        # Print DoF names in order
        print("[GMR] Robot Degrees of Freedom (DoF) names and their order:")
        self.robot_dof_names = {}
        for i in range(self.model.nv):  # 'nv' is the number of DoFs
            dof_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, self.model.dof_jntid[i])
            self.robot_dof_names[dof_name] = i
            if verbose:
                print(f"DoF {i}: {dof_name}")

        print("[GMR] Robot Body names and their IDs:")
        self.robot_body_names = {}
        for i in range(self.model.nbody):  # 'nbody' is the number of bodies
            body_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, i)
            self.robot_body_names[body_name] = i
            if verbose:
                print(f"Body ID {i}: {body_name}")

        print("[GMR] Robot Motor (Actuator) names and their IDs:")
        self.robot_motor_names = {}
        for i in range(self.model.nu):  # 'nu' is the number of actuators (motors)
            motor_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_ACTUATOR, i)
            self.robot_motor_names[motor_name] = i
            if verbose:
                print(f"Motor ID {i}: {motor_name}")

        # Load the IK config
        with open(IK_CONFIG_DICT[src_human][tgt_robot]) as f:
            ik_config = json.load(f)
        if verbose:
            print("Use IK config: ", IK_CONFIG_DICT[src_human][tgt_robot])

        # compute the scale ratio based on given human height and the assumption in the IK config
        if actual_human_height is not None:
            ratio = actual_human_height / ik_config["human_height_assumption"]
        else:
            ratio = 1.0

        # adjust the human scale table
        for key in ik_config["human_scale_table"].keys():
            ik_config["human_scale_table"][key] = ik_config["human_scale_table"][key] * ratio

        # used for retargeting
        self.ik_match_table1 = ik_config["ik_match_table1"]
        self.ik_match_table2 = ik_config["ik_match_table2"]
        self.human_root_name = ik_config["human_root_name"]
        self.robot_root_name = ik_config["robot_root_name"]
        self.use_ik_match_table1 = ik_config["use_ik_match_table1"]
        self.use_ik_match_table2 = ik_config["use_ik_match_table2"]
        self.human_scale_table = ik_config["human_scale_table"]
        self.ground = ik_config["ground_height"] * np.array([0, 0, 1])

        self.max_iter = 10

        self.solver = solver
        self.damping = damping
        self.first_frame_damping = max(float(damping), 2.0)
        self.first_frame_max_iter = max(int(self.max_iter), 10)
        self._is_first_frame = True

        self.human_body_to_task1 = {}
        self.human_body_to_task2 = {}
        self.pos_offsets1 = {}
        self.rot_offsets1 = {}
        self.pos_offsets2 = {}
        self.rot_offsets2 = {}
        self._arm_task_original_orientation_costs = {}
        self._first_frame_arm_orientation_cost = 1.0

        self.task_errors1 = {}
        self.task_errors2 = {}

        self.ik_limits = [mink.ConfigurationLimit(self.model)]
        if use_velocity_limit:
            VELOCITY_LIMITS = {k: 3*np.pi for k in self.robot_motor_names.keys()}
            self.ik_limits.append(mink.VelocityLimit(self.model, VELOCITY_LIMITS))

        self.setup_retarget_configuration()

        self.ground_offset = 0.0
"""

PATCHED_SETUP = """
    def setup_retarget_configuration(self):
        self.configuration = mink.Configuration(self.model)
        self._default_qpos = self.configuration.data.qpos.copy()
        self.posture_task = mink.PostureTask(self.model, cost=1e-2)
        self.posture_task.set_target(self._default_qpos)
        self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)
        self.prev_posture_task.set_target(self._default_qpos)

        self.tasks1 = []
        self.tasks2 = []

        for frame_name, entry in self.ik_match_table1.items():
            body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry
            if pos_weight != 0 or rot_weight != 0:
                task = mink.FrameTask(
                    frame_name=frame_name,
                    frame_type="body",
                    position_cost=pos_weight,
                    orientation_cost=rot_weight,
                    lm_damping=1,
                )
                self.human_body_to_task1[body_name] = task
                self.pos_offsets1[body_name] = np.array(pos_offset) - self.ground
                self.rot_offsets1[body_name] = R.from_quat(
                    rot_offset, scalar_first=True
                )
                self.tasks1.append(task)
                self.task_errors1[task] = []
                if self._is_arm_body(body_name):
                    self._arm_task_original_orientation_costs[task] = float(
                        rot_weight
                    )

        for frame_name, entry in self.ik_match_table2.items():
            body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry
            if pos_weight != 0 or rot_weight != 0:
                task = mink.FrameTask(
                    frame_name=frame_name,
                    frame_type="body",
                    position_cost=pos_weight,
                    orientation_cost=rot_weight,
                    lm_damping=1,
                )
                self.human_body_to_task2[body_name] = task
                self.pos_offsets2[body_name] = np.array(pos_offset) - self.ground
                self.rot_offsets2[body_name] = R.from_quat(
                    rot_offset, scalar_first=True
                )
                self.tasks2.append(task)
                self.task_errors2[task] = []
                if self._is_arm_body(body_name):
                    self._arm_task_original_orientation_costs[task] = float(
                        rot_weight
                    )
"""

PATCHED_RETARGET_BLOCK = """
    @staticmethod
    def _is_arm_body(body_name):
        return any(
            token in body_name
            for token in (
                "left_shoulder",
                "right_shoulder",
                "left_elbow",
                "right_elbow",
                "left_wrist",
                "right_wrist",
            )
        )

    def _set_first_frame_arm_task_costs(self, enabled):
        for task, original_orientation_cost in (
            self._arm_task_original_orientation_costs.items()
        ):
            orientation_cost = (
                self._first_frame_arm_orientation_cost
                if enabled
                else original_orientation_cost
            )
            task.set_orientation_cost(orientation_cost)

    def _solve_task_group(
        self,
        tasks,
        error_fn,
        *,
        damping,
        max_iter,
        include_posture,
        include_prev_posture,
    ):
        solve_tasks = list(tasks)
        if include_posture:
            solve_tasks.append(self.posture_task)
        if include_prev_posture:
            solve_tasks.append(self.prev_posture_task)

        curr_error = error_fn()
        dt = self.configuration.model.opt.timestep
        vel = mink.solve_ik(
            self.configuration,
            solve_tasks,
            dt,
            self.solver,
            damping,
            limits=self.ik_limits,
        )
        self.configuration.integrate_inplace(vel, dt)
        next_error = error_fn()
        num_iter = 0
        while curr_error - next_error > 0.001 and num_iter < max_iter:
            curr_error = next_error
            dt = self.configuration.model.opt.timestep
            vel = mink.solve_ik(
                self.configuration,
                solve_tasks,
                dt,
                self.solver,
                damping,
                limits=self.ik_limits,
            )
            self.configuration.integrate_inplace(vel, dt)
            next_error = error_fn()
            num_iter += 1

    def retarget(self, human_data, offset_to_ground=False):
        prev_q = self.configuration.data.qpos.copy()
        # Update the task targets
        self.update_targets(human_data, offset_to_ground)

        include_posture = self._is_first_frame
        include_prev_posture = True
        solve_damping = (
            self.first_frame_damping if self._is_first_frame else self.damping
        )
        solve_max_iter = (
            self.first_frame_max_iter if self._is_first_frame else self.max_iter
        )
        self.prev_posture_task.set_target(prev_q)

        if self._is_first_frame:
            self._set_first_frame_arm_task_costs(True)

        if self.use_ik_match_table1:
            self._solve_task_group(
                self.tasks1,
                self.error1,
                damping=solve_damping,
                max_iter=solve_max_iter,
                include_posture=include_posture,
                include_prev_posture=include_prev_posture,
            )

        if self.use_ik_match_table2:
            self._solve_task_group(
                self.tasks2,
                self.error2,
                damping=solve_damping,
                max_iter=solve_max_iter,
                include_posture=include_posture,
                include_prev_posture=include_prev_posture,
            )

        if self._is_first_frame:
            self._set_first_frame_arm_task_costs(False)
            self._is_first_frame = False

        return self.configuration.data.qpos.copy()
"""


# Dedent a replacement snippet, then re-indent every non-blank line by
# `indent` (one class-body level). Blank lines stay empty; a trailing
# newline is added.
def indent_block(src: str, indent: str = "    ") -> str:
    body = textwrap.dedent(src).strip("\n")
    return "\n".join(indent + line if line else "" for line in body.splitlines()) + "\n"


# Return the top-level class named `class_name`, or abort the script.
def find_class(module: ast.Module, class_name: str) -> ast.ClassDef:
    for node in module.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    raise SystemExit(f"Class {class_name!r} not found in target file.")


# Return the method defined directly in `class_node`'s body, or abort.
def find_method(class_node: ast.ClassDef, method_name: str) -> ast.FunctionDef:
    for node in class_node.body:
        if isinstance(node, ast.FunctionDef) and node.name == method_name:
            return node
    raise SystemExit(f"Method {method_name!r} not found in class {class_node.name}.")


# Replace the source lines spanned by `node` (lineno..end_lineno, 1-based
# inclusive) with the re-indented replacement snippet.
def apply_replacement(lines, node: ast.AST, replacement: str):
    start = node.lineno - 1
    end = node.end_lineno
    return lines[:start] + [indent_block(replacement)] + lines[end:]


path = Path(sys.argv[1])
text = path.read_text(encoding="utf-8")
# Idempotence: a fully patched file is left untouched.
if all(marker in text for marker in PATCH_MARKERS):
    print(f"Patch already present: {path}")
    raise SystemExit(0)

module = ast.parse(text)
class_node = find_class(module, "GeneralMotionRetargeting")
init_node = find_method(class_node, "__init__")
setup_node = find_method(class_node, "setup_retarget_configuration")
retarget_node = find_method(class_node, "retarget")

lines = text.splitlines(keepends=True)
# Apply replacements bottom-up (descending lineno) so the line numbers of
# the methods still to be replaced remain valid.
for node, replacement in sorted(
    [
        (retarget_node, PATCHED_RETARGET_BLOCK),
        (setup_node, PATCHED_SETUP),
        (init_node, PATCHED_INIT),
    ],
    key=lambda item: item[0].lineno,
    reverse=True,
):
    lines = apply_replacement(lines, node, replacement)

new_text = "".join(lines)
# Fail before writing if the patched source no longer parses.
ast.parse(new_text)
path.write_text(new_text, encoding="utf-8")
print(f"Patch applied to: {path}")
PY

================================================
FILE: holomotion/scripts/motion_retargeting/pack_hdf5_v2.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

# Packing does not need a GPU: expose none to CUDA.
export CUDA_VISIBLE_DEVICES=""

# Hydra list value; the single quotes keep the inner double quotes.
holomotion_npz_root='["data/holomotion_retargeted/AMASS_test"]'
hdf5_root="data/h5v2_datasets/AMASS_test"
robot_config="unitree/G1/29dof/29dof_training_isaaclab"

# Overrides are quoted: the list value contains "[...]", which bash would
# otherwise treat as a bracket glob during pathname expansion.
"${Train_CONDA_PREFIX}/bin/python" \
    holomotion/src/motion_retargeting/pack_hdf5_v2.py \
    "robot=${robot_config}" \
    "holomotion_npz_root=${holomotion_npz_root}" \
    "hdf5_root=${hdf5_root}"

================================================
FILE: holomotion/scripts/motion_retargeting/run_holomotion_preprocessing.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

holo_src_dir="src_holomotion_npz_dir"
holo_tgt_dir="output_holomotion_npz_dir"
# Ordered list of preprocessing stages (Hydra list syntax).
pipeline="['filename_as_motionkey','legacy_to_ref_keys','tagging']"
robot_config="holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml"

# Quoted so the "[...]" pipeline list is not subject to pathname expansion
# and paths with spaces stay one argument.
"${Train_CONDA_PREFIX}/bin/python" \
    holomotion/src/motion_retargeting/holomotion_preprocess.py \
    "padding.robot_config_path=${robot_config}" \
    "io.src_root=${holo_src_dir}" \
    "io.out_root=${holo_tgt_dir}" \
    "preprocess.pipeline=${pipeline}" \
    ray.enabled=true \
    padding.stand_still_time=20.0 \
    ray.num_workers=2

================================================
FILE: holomotion/scripts/motion_retargeting/run_kinematic_filter.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

dataset_root="data/holomotion_retargeted/processed_datasets/AMASS_test"

"${Train_CONDA_PREFIX}/bin/python" \
    holomotion/src/motion_retargeting/kinematic_filter.py \
    "io.dataset_root=${dataset_root}"

================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_bvh.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

bvh_src_dir="data/lafan1_bvh"
gmr_tgt_dir="data/gmr_retargeted/lafan1/"

# Step 1: retarget the BVH motions to the robot with GMR.
# `mkdir -p` is a no-op when the directory already exists.
mkdir -p -- "$gmr_tgt_dir"

"${Train_CONDA_PREFIX}/bin/python" \
    thirdparties/GMR/scripts/bvh_to_robot_dataset.py \
    --src_folder "${bvh_src_dir}/" \
    --tgt_folder "${gmr_tgt_dir}/" \
    --robot unitree_g1

================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

smplx_src_dir="assets/test_data/motion_retargeting/"
gmr_tgt_dir="data/gmr_retargeted/AMASS_test/"

# Create the output directory (no-op if it already exists).
mkdir -p -- "$gmr_tgt_dir"

"${Train_CONDA_PREFIX}/bin/python" \
    thirdparties/GMR/scripts/smplx_to_robot_dataset.py \
    "--src_folder=${smplx_src_dir}/" \
    "--tgt_folder=${gmr_tgt_dir}/" \
    --num_cpus=16 \
    --robot=unitree_g1

================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

dir_name="AMASS_test"
gmr_tgt_dir="data/gmr_retargeted/${dir_name}"
holo_retargeted_dir="data/holomotion_retargeted/processed_datasets/${dir_name}"

robot_cfg="holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml"
# Ordered preprocessing stages (Hydra list syntax).
preprocess_pipeline="['filename_as_motionkey','legacy_to_ref_keys','slicing','add_padding','tagging']"

# Quoted so the "[...]" pipeline list is not glob-expanded by bash.
"${Train_CONDA_PREFIX}/bin/python" \
    holomotion/src/motion_retargeting/gmr_to_holomotion.py \
    "io.robot_config=${robot_cfg}" \
    "io.src_dir=${gmr_tgt_dir}" \
    "io.out_root=${holo_retargeted_dir}" \
    processing.target_fps=50 \
    "preprocess.pipeline=${preprocess_pipeline}" \
    ray.num_workers=16

================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

source train.env

export MUJOCO_GL="osmesa"

motion_npz_root="path_to_your_npz_dir"
# Exported so the override below can resolve it through ${oc.env:...}.
export motion_name="all"

# The motion root is quoted so paths with spaces or glob characters stay a
# single argument. The motion_name override is single-quoted so the
# OmegaConf interpolation reaches Python unexpanded.
"${Train_CONDA_PREFIX}/bin/python" holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py \
    +key_prefix="robot_" \
    +draw_ref_body_spheres=true \
    +ref_key_prefix="ref_" \
    "+motion_npz_root=${motion_npz_root}" \
    skip_frames=6 \
    max_workers=11 \
    +motion_name='${oc.env:motion_name}'

================================================
FILE: holomotion/scripts/training/train_motion_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

export CUDA_VISIBLE_DEVICES=0

# Count the comma-separated GPU ids: at most one id means a single-process
# launch, more than one means a --multi_gpu launch.
IFS=',' read -r -a gpu_ids <<< "${CUDA_VISIBLE_DEVICES}"
if (( ${#gpu_ids[@]} <= 1 )); then
    USE_MULTI_GPU=false
else
    USE_MULTI_GPU=true
fi

config_name="train_g1_29dof_motion_tracking_mlp"
# config_name="train_g1_29dof_motion_tracking_tf-moe"
num_envs=4096

COMMON_ARGS=(
    "holomotion/src/training/train.py"
    "--config-name=training/motion_tracking/${config_name}"
    "num_envs=${num_envs}"
    "headless=true"
    "experiment_name=${config_name}"
)

launch_args=()
if [[ "${USE_MULTI_GPU}" == "true" ]]; then
    launch_args+=(--multi_gpu)
fi

# accelerate runs in the foreground, so its exit status is the script's
# exit status. No signal trap is installed: the script starts no background
# process to clean up. The former trap pointed at an undefined `cleanup`,
# and `wait` was given a TRAIN_PID that was never set.
"${Train_CONDA_PREFIX}/bin/accelerate" launch \
    "${launch_args[@]}" \
    "${COMMON_ARGS[@]}"

================================================
FILE: holomotion/scripts/training/train_velocity_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

export CUDA_VISIBLE_DEVICES=0

config_name="train_g1_29dof_velocity_tracking_mlp"
num_envs=4096

COMMON_ARGS=(
    "holomotion/src/training/train.py"
    "--config-name=training/velocity_tracking/${config_name}"
    "experiment_name=${config_name}"
    "num_envs=${num_envs}"
    "headless=true"
)

# Pass --multi_gpu only when more than one GPU id is visible. The previous
# condition was inverted: it added --multi_gpu for exactly one GPU.
IFS=',' read -r -a gpu_ids <<< "${CUDA_VISIBLE_DEVICES}"
launch_args=()
if (( ${#gpu_ids[@]} > 1 )); then
    launch_args+=(--multi_gpu)
fi

# accelerate runs in the foreground, so its exit status is the script's
# exit status. No signal trap is installed: the script starts no background
# process to clean up. The former trap pointed at an undefined `cleanup`,
# and `wait` was given a TRAIN_PID that was never set.
"${Train_CONDA_PREFIX}/bin/accelerate" launch \
    "${launch_args[@]}" \
    "${COMMON_ARGS[@]}"

================================================
FILE: holomotion/src/algo/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

================================================
FILE: holomotion/src/algo/algo_base.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import os import random import statistics import sys import time from collections import deque from typing import Any, Dict import numpy as np import torch from accelerate import Accelerator from accelerate.utils import ( ProjectConfiguration, TorchDynamoPlugin, load_checkpoint_in_model, load_state_dict, ) from hydra.utils import get_class from loguru import logger from tensordict import TensorDict from holomotion.src.algo.algo_utils import AlgoLogger class BaseOnpolicyRL: """Base class for on-policy RL algorithms in HoloMotion.""" def __init__( self, env_config, config, log_dir=None, headless: bool = True, is_offline_eval: bool = False, ) -> None: self.config = config self.env_config = env_config self.log_dir = log_dir self.headless = headless self.is_offline_eval = is_offline_eval self._setup_accelerator() self.algo_logger = AlgoLogger( self.accelerator, self.log_dir, is_main_process=self.is_main_process, ) self._setup_environment() self._setup_configs() self._setup_seeding() self._setup_data_buffers() self._setup_algo_components() self._setup_models_and_optimizer() def _setup_accelerator(self) -> None: if not self.is_offline_eval: os.makedirs(self.log_dir, exist_ok=True) accelerator_kwargs = {} mixed_precision = self.config.get("mixed_precision", None) if mixed_precision in ("fp16", "bf16"): accelerator_kwargs["mixed_precision"] = mixed_precision dynamo_backend = self.config.get("dynamo_backend", None) if os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1": dynamo_backend = None if dynamo_backend in ("inductor", "aot_eager", "cudagraphs"): dynamo_dynamic = 
bool(self.config.get("dynamo_dynamic", True)) dynamo_fullgraph = bool(self.config.get("dynamo_fullgraph", False)) dynamo_mode = self.config.get("dynamo_mode", "default") accelerator_kwargs["dynamo_plugin"] = TorchDynamoPlugin( backend=str(dynamo_backend), mode=str(dynamo_mode), fullgraph=bool(dynamo_fullgraph), dynamic=bool(dynamo_dynamic), ) accelerator_kwargs["log_with"] = "tensorboard" project_config = ProjectConfiguration( project_dir=self.log_dir, logging_dir=self.log_dir, ) accelerator_kwargs["project_config"] = project_config self.accelerator = Accelerator(**accelerator_kwargs) self.local_rank = getattr( self.accelerator, "local_process_index", None ) if self.local_rank is None: self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) self.device = self.accelerator.device if torch.cuda.is_available() and self.device.type == "cuda": dev_index = self.device.index if dev_index is None: dev_index = int(self.local_rank) self.device = torch.device("cuda", dev_index) else: dev_index = int(dev_index) torch.cuda.set_device(dev_index) self.is_main_process = self.accelerator.is_main_process self.accelerator.init_trackers( project_name="holomotion", config={ "precision": mixed_precision if mixed_precision else "fp32", "dynamo_backend": dynamo_backend if dynamo_backend else "none", "dynamo_dynamic": bool(self.config.get("dynamo_dynamic", True)) if dynamo_backend else False, }, ) self._release_cuda_cache() logger.remove() log_level = os.environ.get("LOGURU_LEVEL", "INFO").upper() if self.log_dir: rank_log_file_name = ( "offline_eval_rank" if self.is_offline_eval else "run_rank" ) logger.add( os.path.join( self.log_dir, f"{rank_log_file_name}_{int(self.accelerator.process_index):04d}.log", ), level=log_level, colorize=False, ) if self.is_main_process: logger.add( sys.stdout, level=log_level, colorize=True, ) log_file_name = ( "offline_eval.log" if self.is_offline_eval else "run.log" ) logger.add( os.path.join(self.log_dir, log_file_name), level=log_level, colorize=False, ) 
used_precision = mixed_precision if mixed_precision else "fp32" logger.info( f"Accelerator initialized with precision: {used_precision}" ) if dynamo_backend: logger.info(f"Accelerator dynamo_backend: {dynamo_backend}") logger.info(f"TensorBoard logging enabled at: {self.log_dir}") self.process_rank = self.accelerator.process_index self.gpu_world_size = self.accelerator.num_processes self.gpu_global_rank = self.accelerator.process_index self.is_distributed = self.gpu_world_size > 1 env_rank = os.environ.get("RANK", "unset") env_local_rank = os.environ.get("LOCAL_RANK", "unset") env_world_size = os.environ.get("WORLD_SIZE", "unset") env_local_world_size = os.environ.get("LOCAL_WORLD_SIZE", "unset") env_node_rank = os.environ.get( "NODE_RANK", os.environ.get("MACHINE_RANK", "unset") ) env_master_addr = os.environ.get("MASTER_ADDR", "unset") env_master_port = os.environ.get("MASTER_PORT", "unset") env_cuda_visible_devices = os.environ.get( "CUDA_VISIBLE_DEVICES", "unset" ) cuda_device_count = ( int(torch.cuda.device_count()) if torch.cuda.is_available() else 0 ) logger.info( "[Accelerate setup] " f"distributed_type={self.accelerator.distributed_type}, " f"num_processes={int(self.accelerator.num_processes)}, " f"process_index={int(self.accelerator.process_index)}, " f"local_process_index={int(self.local_rank)}, " f"is_main_process={bool(self.accelerator.is_main_process)}" ) logger.info( "[Accelerate env] " f"RANK={env_rank}, LOCAL_RANK={env_local_rank}, " f"WORLD_SIZE={env_world_size}, " f"LOCAL_WORLD_SIZE={env_local_world_size}, " f"NODE_RANK={env_node_rank}, MASTER_ADDR={env_master_addr}, " f"MASTER_PORT={env_master_port}" ) logger.info( "[Accelerate cuda] " f"CUDA_VISIBLE_DEVICES={env_cuda_visible_devices}, " f"torch_cuda_device_count={cuda_device_count}, " f"selected_device={self.device}" ) def _setup_environment(self) -> None: """Setup IsaacLab AppLauncher and environment instance.""" # Device string from accelerator (handles distributed training) device_str = 
str(self.device) # Delayed import to ensure Accelerate is fully initialized before IsaacLab from isaaclab.app import AppLauncher # Stagger IsaacSim AppLauncher initialization across distributed ranks # Use local rank per node to stagger independently on each node if self.is_distributed: self.accelerator.wait_for_everyone() base_delay_s = float( os.environ.get("HOLOMOTION_ISAAC_STAGGER_SEC", "5.0") ) local_rank = int(self.local_rank) delay_s = base_delay_s * float(local_rank) if delay_s > 0.0: logger.info( f"[Global Rank {self.gpu_global_rank}, Local Rank {local_rank}] " f"Sleeping {delay_s:.1f}s before IsaacSim AppLauncher init" ) time.sleep(delay_s) # Create AppLauncher with accelerator device # Enable cameras only when needed: # - headless & recording: True (offscreen rendering) # - headless & not recording: False (maximize performance) # - with GUI: True _record_video = bool(self.config.get("record_video", False)) enable_cameras = _record_video or (not self.headless) # Explicitly disable Omniverse multi-GPU rendering to avoid per-process # MGPU context creation across all visible GPUs. 
kit_args_str = ( "--/renderer/multiGpu/enabled=false " "--/renderer/multiGpu/autoEnable=false " "--/renderer/multiGpu/maxGpuCount=1" ) app_launcher_flags = { "headless": self.headless, "enable_cameras": enable_cameras, "video": _record_video, "device": device_str, "kit_args": kit_args_str, } self._sim_app_launcher = AppLauncher(**app_launcher_flags) self._sim_app = self._sim_app_launcher.app logger.info( f"AppLauncher initialized with flags: {app_launcher_flags}" ) env_class = get_class(self.env_config._target_) render_mode = ( "rgb_array" if bool(self.config.get("record_video", False)) else None ) self.env = env_class( config=self.env_config.config, device=device_str, headless=self.headless, log_dir=self.log_dir, accelerator=self.accelerator, render_mode=render_mode, ) _ = self.env.reset_all() logger.info(f"Environment initialized with render_mode: {render_mode}") def _setup_configs(self) -> None: self.num_envs: int = self.env.config.num_envs self.num_privileged_obs = 0 self.num_actions = self.env.config.robot.actions_dim self.command_name = list(self.env.config.commands.keys())[0] self.command_term = self.env._env.command_manager.get_term( self.command_name ) if self.command_name == "ref_motion": self.command_term.set_runtime_distributed_context( process_id=int(self.accelerator.process_index), num_processes=int(self.accelerator.num_processes), ) self.command_term.setup_dumping_dir(self.log_dir) self.save_interval = self.config.save_interval self.log_interval = self.config.log_interval self.num_steps_per_env = self.config.num_steps_per_env self.num_learning_iterations = self.config.num_learning_iterations self.total_learning_iterations = int(self.num_learning_iterations) def _setup_seeding(self) -> None: if self.command_name == "ref_motion": self.seed = int(self.command_term.cfg.seed) self.base_seed = int(self.seed - int(self.process_rank)) else: self.base_seed = int(self.config.get("seed", int(time.time()))) self.seed = int(self.base_seed + 
int(self.process_rank)) random.seed(self.seed) np.random.seed(self.seed) torch.manual_seed(self.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(self.seed) self.env.seed(self.seed) if self.command_name == "ref_motion": self.command_term.set_motion_cache_seed( self.seed, reinitialize=False ) def _setup_data_buffers(self) -> None: self.tot_timesteps = 0 self.tot_time = 0 self.current_learning_iteration = 0 self.start_time = 0 self.stop_time = 0 self.collection_time = 0 self.learn_time = 0 self.ep_infos = [] self.rewbuffer = deque(maxlen=100) self.lenbuffer = deque(maxlen=100) self.cur_reward_sum = torch.zeros( self.env.num_envs, dtype=torch.float, device=self.device, ) self.cur_episode_length = torch.zeros( self.env.num_envs, dtype=torch.float, device=self.device, ) self.storage = None self.transition_td = None self._last_rollout_dones = None self._last_rollout_actions = None def _setup_algo_components(self) -> None: """Hook for algorithm-specific components (AMP, DAgger, PULSE).""" return def _setup_models_and_optimizer(self) -> None: raise NotImplementedError( "Subclasses must implement _setup_models_and_optimizer." ) def _build_storage(self, obs_td: TensorDict) -> Any: """Hook for custom RolloutStorage. 
Override for specialized storage; default no-op.""" return None def _post_env_step_hook( self, rewards: torch.Tensor, dones: torch.Tensor, time_outs: torch.Tensor, infos: Dict[str, Any], ) -> None: """Hook after each env step for auxiliary data collection.""" if self.command_name != "ref_motion": return motion_term = self.env._env.command_manager.get_term("ref_motion") if motion_term is None: return motion_term.update_curriculum_reward_accumulators(rewards) def _post_update_hook(self, loss_dict: Dict[str, Any]) -> None: """Hook after each PPO update for auxiliary losses or logging.""" return def _extra_checkpoint_state(self) -> Dict[str, Any]: """Additional state to save in checkpoints.""" return {} def _load_extra_checkpoint_state( self, loaded_dict: Dict[str, Any] ) -> None: """Load additional checkpoint state if present.""" return def _build_transition( self, obs_td: TensorDict, actor_out: TensorDict, critic_out: TensorDict, ): raise NotImplementedError( "Subclasses must implement _build_transition." ) def _post_iteration_hook(self, it: int) -> None: return def _post_training_hook(self) -> None: return def _release_cuda_cache(self) -> None: if torch.cuda.is_available() and self.device.type == "cuda": torch.cuda.empty_cache() def _get_additional_log_metrics(self) -> Dict[str, Any]: return {} def train_mode(self) -> None: self.actor.train() self.critic.train() def _ensure_storage(self, obs_td: TensorDict) -> None: if self.storage is not None: return self.storage = self._build_storage(obs_td) if self.storage is None: raise RuntimeError( "Storage is not initialized. Override _build_storage() or initialize self.storage in subclass." 
) def _reset_rollout_forward_state(self) -> None: """Hook for algorithm-specific rollout state reset.""" return def _rollout_forward( self, obs_td: TensorDict, *, actor_mode: str = "sampling", collect_transition: bool = True, track_episode_stats: bool = True, ) -> TensorDict: update_obs_norm = not self.is_offline_eval with self.accelerator.autocast(): actor_out: TensorDict = self.actor( obs_td, actions=None, mode=actor_mode, update_obs_norm=update_obs_norm, ) critic_out: TensorDict | None = None if collect_transition: critic_out = self.critic( obs_td, update_obs_norm=update_obs_norm ) if collect_transition: self.transition_td = self._build_transition( obs_td, actor_out, critic_out, ) actions = actor_out.get("actions") self._last_rollout_actions = actions obs_dict, rewards, dones, time_outs, infos = self.env.step(actions) next_obs_td = self._wrap_obs_dict(obs_dict) dones = dones.to(self.device) self._last_rollout_dones = dones if collect_transition: rewards = rewards.to(self.device) time_outs = time_outs.to(self.device) self.process_env_step(rewards, dones, time_outs, infos) if track_episode_stats: rewards_for_stats = rewards.to(self.device) self._track_episode_stats(rewards_for_stats, dones, infos) return next_obs_td def _track_episode_stats( self, rewards: torch.Tensor, dones: torch.Tensor, infos: Dict[str, Any], ) -> None: log_info = infos.get("log") if self.is_main_process and isinstance(log_info, dict): cpu_log_info: Dict[str, torch.Tensor] = {} for key, value in log_info.items(): cpu_value = self._log_value_to_cpu_tensor(value) if cpu_value is not None and cpu_value.numel() > 0: cpu_log_info[key] = cpu_value if len(cpu_log_info) > 0: self.ep_infos.append(cpu_log_info) self.cur_reward_sum += rewards self.cur_episode_length += 1 done_ids = (dones > 0).nonzero(as_tuple=False) self.rewbuffer.extend( self.cur_reward_sum[done_ids][:, 0].cpu().numpy().tolist() ) self.lenbuffer.extend( self.cur_episode_length[done_ids][:, 0].cpu().numpy().tolist() ) 
self.cur_reward_sum[done_ids] = 0 self.cur_episode_length[done_ids] = 0 def _compute_returns(self, obs_td: TensorDict) -> None: update_obs_norm = not self.is_offline_eval with self.accelerator.autocast(): last_values = ( self.critic(obs_td, update_obs_norm=update_obs_norm) .get("values") .detach() ) self.storage.compute_returns( last_values, self.gamma, self.lam, normalize_advantage=False, ) if getattr(self, "global_advantage_norm", False): accelerator = self.accelerator if self.is_distributed else None self.storage.normalize_advantages_global_by_command( command_name=self.command_name, accelerator=accelerator, eps=1.0e-8, ) def rollout_policy(self, obs_td: TensorDict) -> TensorDict: """Collect one rollout with current policy and compute returns.""" actor_was_training = self.actor.training critic_was_training = self.critic.training self.actor.eval() self.critic.eval() with torch.no_grad(): self._reset_rollout_forward_state() for _ in range(self.num_steps_per_env): obs_td = self._rollout_forward(obs_td) self._compute_returns(obs_td) if actor_was_training: self.actor.train() if critic_was_training: self.critic.train() return obs_td def learn(self): """Main learning loop with runner logic shared across on-policy algorithms.""" obs_dict = self.env.reset_all()[0] obs_td = self._wrap_obs_dict(obs_dict) self._ensure_storage(obs_td) self.train_mode() start_it = self.current_learning_iteration total_it = start_it + int(self.num_learning_iterations) self.total_learning_iterations = total_it self.accelerator.wait_for_everyone() if self.is_main_process: logger.info( f"Starting training for {self.num_learning_iterations} iterations " f"from iteration {self.current_learning_iteration}" ) for it in range(start_it, total_it): self.current_learning_iteration = it start = time.time() obs_td = self.rollout_policy(obs_td) stop = time.time() collection_time = stop - start start = stop loss_dict = self.update() stop = time.time() learn_time = stop - start if self.is_main_process and it 
% self.log_interval == 0: self._log_iteration( it=it, loss_dict=loss_dict, collection_time=collection_time, learn_time=learn_time, ) if self.is_main_process and it % self.save_interval == 0: self.save( os.path.join( self.log_dir, f"model_{self.current_learning_iteration}.pt", ) ) self._release_cuda_cache() self._post_iteration_hook(it) self.ep_infos.clear() self.accelerator.wait_for_everyone() final_checkpoint_path = os.path.join( self.log_dir, f"model_{self.current_learning_iteration}.pt" ) if self.is_main_process: self.save(final_checkpoint_path) self._release_cuda_cache() self._post_training_hook() if self.log_dir: self.accelerator.wait_for_everyone() self.accelerator.end_training() if self.is_main_process: logger.info( f"Training completed. Model saved to {self.log_dir}" ) def process_env_step( self, rewards: torch.Tensor, dones: torch.Tensor, time_outs: torch.Tensor, infos: Dict[str, Any], ) -> None: """Process env step results and append to storage. Args: rewards: [N, 1] rewards (env step output). dones: [N, 1] done flags (env step output). time_outs: [N] time out flags (env step output). infos: Environment info dictionary. """ raw_rewards = rewards.clone().view(-1, 1) rewards = raw_rewards.clone() dones = dones.view(-1, 1) # Bootstrapping on time outs rewards += self.gamma * ( self.transition_td.values * time_outs[:, None] ) self.transition_td.rewards = rewards self.transition_td.dones = dones.to(dtype=torch.bool) self.storage.add(self.transition_td) self._post_env_step_hook(raw_rewards, dones, time_outs, infos) self.transition_td = None def _wrap_obs_dict(self, obs_dict: dict) -> TensorDict: """Wrap env obs dict into a native nested TensorDict on device.""" return TensorDict.from_dict( obs_dict, batch_size=[self.env.num_envs], device=self.device, ) @staticmethod def _clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]: """Remove '_orig_mod.' prefix from torch.compile wrapped models. Args: state_dict: State dict that may contain '_orig_mod.' 
prefixed keys Returns: Cleaned state dict with prefixes removed """ cleaned_dict = {} prefix = "_orig_mod." prefix_len = len(prefix) for k, v in state_dict.items(): new_k = k[prefix_len:] if k.startswith(prefix) else k cleaned_dict[new_k] = v return cleaned_dict def _load_model_state(self, model, state_dict, *, strict: bool = True): """Load a state dict into a (possibly compiled) model safely. - Always unwrap Accelerate wrappers first. - If the model is a compiled OptimizedModule (has ``_orig_mod``), load into the original module and strip any ``_orig_mod.`` prefixes from the incoming state dict for robustness. """ target = self.accelerator.unwrap_model(model) cleaned = self._clean_state_dict(state_dict) if hasattr(target, "_orig_mod"): target._orig_mod.load_state_dict(cleaned, strict=strict) else: target.load_state_dict(cleaned, strict=strict) def _resolve_model_file_path(self, ckpt_path: str, model_name: str) -> str: """Resolve per-model Accelerate checkpoint directory from *.pt path.""" base_path = ckpt_path.replace(".pt", "") model_path = os.path.join(base_path, model_name) if not os.path.isdir(model_path): raise FileNotFoundError( f"Missing accelerate checkpoint directory for {model_name}: " f"{model_path}" ) return model_path def _load_accelerate_model( self, model, model_path: str, *, strict: bool = True ) -> None: """Load model params from Accelerate checkpoint directory/file.""" checkpoint_path = model_path if os.path.isdir(model_path): safetensors_path = os.path.join(model_path, "model.safetensors") pytorch_bin_path = os.path.join(model_path, "pytorch_model.bin") if os.path.isfile(safetensors_path): checkpoint_path = safetensors_path elif os.path.isfile(pytorch_bin_path): checkpoint_path = pytorch_bin_path else: target = self.accelerator.unwrap_model(model) load_checkpoint_in_model(target, model_path, strict=strict) return state_dict = load_state_dict(checkpoint_path) self._load_model_state(model, state_dict, strict=strict) def 
_aggregate_episode_log_metrics( self, ) -> Dict[str, float]: metrics: Dict[str, float] = {} if len(self.ep_infos) == 0: return metrics metric_sums: Dict[str, float] = {} metric_counts: Dict[str, int] = {} for ep_info in self.ep_infos: for key, value in ep_info.items(): cpu_value = self._log_value_to_cpu_tensor(value) if cpu_value is None or cpu_value.numel() == 0: continue metric_sums[key] = metric_sums.get(key, 0.0) + float( cpu_value.sum().item() ) metric_counts[key] = metric_counts.get(key, 0) + int( cpu_value.numel() ) for key, total in metric_sums.items(): count = metric_counts.get(key, 0) if count <= 0: continue mean_value = total / float(count) metric_key = key if "/" in key else f"Episode/{key}" metrics[metric_key] = mean_value return metrics @staticmethod def _log_value_to_cpu_tensor(value: Any) -> torch.Tensor | None: if isinstance(value, torch.Tensor): tensor = value.detach() if tensor.ndim == 0: tensor = tensor.unsqueeze(0) return tensor.to(device="cpu", dtype=torch.float32).reshape(-1) if isinstance(value, np.ndarray): return torch.as_tensor(value, dtype=torch.float32).reshape(-1) if isinstance(value, (int, float)): return torch.tensor([float(value)], dtype=torch.float32) return None def _log_iteration( self, *, it: int, loss_dict: Dict[str, Any], collection_time: float, learn_time: float, synced_mean_reward: float | None = None, synced_mean_episode_length: float | None = None, ) -> None: if not self.log_dir: return world_size = max(1, int(self.gpu_world_size)) fps = int( self.num_steps_per_env * self.num_envs * world_size / max(collection_time + learn_time, 1.0e-8) ) total_learning_iterations = int( getattr( self, "total_learning_iterations", self.current_learning_iteration + int(self.num_learning_iterations), ) ) iteration_metrics: Dict[str, Any] = { "0-Train/iteration": int(it), "0-Train/iterations_total": total_learning_iterations, } for key, value in loss_dict.items(): if value is None: continue scalar = float(value) 
iteration_metrics[f"Loss/{key}"] = scalar iteration_metrics.update( { "1-Perf/total_fps": float(fps), "1-Perf/collection_time": float(collection_time), "1-Perf/learning_time": float(learn_time), } ) if ( synced_mean_reward is not None and synced_mean_episode_length is not None ): iteration_metrics["0-Train/mean_reward"] = float( synced_mean_reward ) iteration_metrics["0-Train/mean_episode_length"] = float( synced_mean_episode_length ) elif len(self.rewbuffer) > 0: mean_reward = float(statistics.mean(self.rewbuffer)) mean_episode_length = float(statistics.mean(self.lenbuffer)) iteration_metrics["0-Train/mean_reward"] = mean_reward iteration_metrics["0-Train/mean_episode_length"] = ( mean_episode_length ) iteration_metrics.update(self._aggregate_episode_log_metrics()) iteration_metrics.update(self._get_additional_log_metrics()) self.algo_logger.log_iteration( step=it, total_learning_iterations=total_learning_iterations, metrics=iteration_metrics, ) def load(self, ckpt_path): raise NotImplementedError("Subclasses must implement load().") def save(self, path, infos=None): raise NotImplementedError("Subclasses must implement save().") ================================================ FILE: holomotion/src/algo/algo_utils.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os from collections.abc import Mapping from typing import Any, Generator import torch import torch.nn as nn from loguru import logger from tabulate import tabulate from tensordict import TensorDict, tensorclass class AlgoLogger: def __init__( self, accelerator, log_dir: str | None, *, is_main_process: bool, ) -> None: self.accelerator = accelerator self.log_dir = log_dir self.is_main_process = bool(is_main_process) @staticmethod def _is_scalar_metric(value: Any) -> bool: if isinstance(value, (int, float)): return True if isinstance(value, torch.Tensor): return value.numel() == 1 return False @staticmethod def _to_scalar(value: Any) -> float: if isinstance(value, torch.Tensor): return float(value.item()) return float(value) @staticmethod def _format_console_value(value: Any) -> str: if isinstance(value, (int, float)): value_f = float(value) abs_value = abs(value_f) if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4): return f"{value_f:.4e}" return f"{value_f:.4f}" if isinstance(value, torch.Tensor) and value.numel() == 1: value_f = float(value.item()) abs_value = abs(value_f) if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4): return f"{value_f:.4e}" return f"{value_f:.4f}" return str(value) def _build_console_log( self, *, step: int, total_learning_iterations: int | None, console_metrics: Mapping[str, Any], ) -> str: if total_learning_iterations is None: title = f"TRAINING LOG - Iteration {step}" else: title = ( f"TRAINING LOG - Iteration {step}/{total_learning_iterations}" ) table_data = [ [key, str(console_metrics[key])] for key in sorted(console_metrics.keys()) ] log_lines = [ "\n" + "=" * 80, title, "=" * 80, tabulate( table_data, headers=["Metric", "Value"], tablefmt="simple_outline", colalign=("left", "left"), disable_numparse=True, ), "=" * 80, f"Logging Directory: {os.path.abspath(self.log_dir)}", "=" * 80 + "\n", ] return "\n".join(log_lines) def log_iteration( self, *, step: int, metrics: Mapping[str, Any], 
total_learning_iterations: int | None = None, ) -> None: if not self.log_dir or not self.is_main_process: return tensorboard_metrics: dict[str, float] = {} for key in sorted(metrics.keys()): value = metrics[key] if value is None or not self._is_scalar_metric(value): continue tensorboard_metrics[key] = self._to_scalar(value) if len(tensorboard_metrics) > 0: self.accelerator.log(tensorboard_metrics, step=int(step)) console_metrics = { key: self._format_console_value(value) for key, value in metrics.items() if value is not None } console_log = self._build_console_log( step=step, total_learning_iterations=total_learning_iterations, console_metrics=console_metrics, ) logger.info(console_log) @tensorclass(shadow=True) class PpoTransition: """PPO rollout transition tensorclass. Batch axes: - N: num_envs (per-step) - B: minibatch_size (for minibatches) Shapes (batch dims = [N] or [B]): - obs: TensorDict with leaf tensors [*, ...] - actions, teacher_actions, mu, sigma: [*, A] - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1] All float tensors are float32. `dones` is bool. 
""" FIELD_SPECS = { "obs": {"kind": "obs"}, "actions": {"shape": ("A",), "dtype": torch.float32}, "teacher_actions": {"shape": ("A",), "dtype": torch.float32}, "mu": {"shape": ("A",), "dtype": torch.float32}, "sigma": {"shape": ("A",), "dtype": torch.float32}, "actions_log_prob": {"shape": (1,), "dtype": torch.float32}, "values": {"shape": (1,), "dtype": torch.float32}, "rewards": {"shape": (1,), "dtype": torch.float32}, "dones": {"shape": (1,), "dtype": torch.bool}, "returns": {"shape": (1,), "dtype": torch.float32}, "advantages": {"shape": (1,), "dtype": torch.float32}, } obs: TensorDict actions: torch.Tensor teacher_actions: torch.Tensor mu: torch.Tensor sigma: torch.Tensor actions_log_prob: torch.Tensor values: torch.Tensor rewards: torch.Tensor dones: torch.Tensor returns: torch.Tensor advantages: torch.Tensor @tensorclass(shadow=True) class PpoVelocityTransition: """PPO rollout transition tensorclass. Batch axes: - N: num_envs (per-step) - B: minibatch_size (for minibatches) Shapes (batch dims = [N] or [B]): - obs: TensorDict with leaf tensors [*, ...] - actions, teacher_actions, mu, sigma: [*, A] - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1] - velocity_commands: [*, 4] All float tensors are float32. `dones` is bool. 
""" FIELD_SPECS = { "obs": {"kind": "obs"}, "actions": {"shape": ("A",), "dtype": torch.float32}, "teacher_actions": {"shape": ("A",), "dtype": torch.float32}, "mu": {"shape": ("A",), "dtype": torch.float32}, "sigma": {"shape": ("A",), "dtype": torch.float32}, "actions_log_prob": {"shape": (1,), "dtype": torch.float32}, "values": {"shape": (1,), "dtype": torch.float32}, "rewards": {"shape": (1,), "dtype": torch.float32}, "dones": {"shape": (1,), "dtype": torch.bool}, "returns": {"shape": (1,), "dtype": torch.float32}, "advantages": {"shape": (1,), "dtype": torch.float32}, "velocity_commands": {"shape": (4,), "dtype": torch.float32}, } obs: TensorDict actions: torch.Tensor teacher_actions: torch.Tensor mu: torch.Tensor sigma: torch.Tensor actions_log_prob: torch.Tensor values: torch.Tensor rewards: torch.Tensor dones: torch.Tensor returns: torch.Tensor advantages: torch.Tensor velocity_commands: torch.Tensor @tensorclass(shadow=True) class PpoAuxTransition: """PPO transition with auxiliary state-prediction supervision targets.""" SHAPE_TOKENS = {"C": 0, "K": 0} FIELD_SPECS = { "obs": {"kind": "obs"}, "actions": {"shape": ("A",), "dtype": torch.float32}, "teacher_actions": {"shape": ("A",), "dtype": torch.float32}, "mu": {"shape": ("A",), "dtype": torch.float32}, "sigma": {"shape": ("A",), "dtype": torch.float32}, "actions_log_prob": {"shape": (1,), "dtype": torch.float32}, "values": {"shape": (1,), "dtype": torch.float32}, "rewards": {"shape": (1,), "dtype": torch.float32}, "dones": {"shape": (1,), "dtype": torch.bool}, "returns": {"shape": (1,), "dtype": torch.float32}, "advantages": {"shape": (1,), "dtype": torch.float32}, "gt_base_lin_vel_b": {"shape": (3,), "dtype": torch.float32}, "gt_root_height_rel_terrain": {"shape": (1,), "dtype": torch.float32}, "gt_keybody_contacts": {"shape": ("C",), "dtype": torch.float32}, "gt_ref_keybody_rel_pos": { "shape": ("K", 3), "dtype": torch.float32, }, "gt_robot_keybody_rel_pos": { "shape": ("K", 3), "dtype": torch.float32, 
}, "gt_denoise_ref_root_lin_vel": { "shape": (3,), "dtype": torch.float32, }, "gt_denoise_ref_root_ang_vel": { "shape": (3,), "dtype": torch.float32, }, "gt_denoise_ref_dof_pos": { "shape": ("A",), "dtype": torch.float32, }, } obs: TensorDict actions: torch.Tensor teacher_actions: torch.Tensor mu: torch.Tensor sigma: torch.Tensor actions_log_prob: torch.Tensor values: torch.Tensor rewards: torch.Tensor dones: torch.Tensor returns: torch.Tensor advantages: torch.Tensor gt_base_lin_vel_b: torch.Tensor gt_root_height_rel_terrain: torch.Tensor gt_keybody_contacts: torch.Tensor gt_ref_keybody_rel_pos: torch.Tensor gt_robot_keybody_rel_pos: torch.Tensor gt_denoise_ref_root_lin_vel: torch.Tensor gt_denoise_ref_root_ang_vel: torch.Tensor gt_denoise_ref_dof_pos: torch.Tensor class RolloutStorage(nn.Module): """Rollout storage as a single TensorDict buffer with batch size [T, N].""" def __init__( self, num_envs, num_transitions_per_env, obs_template: TensorDict, actions_shape, device="cpu", command_name: str | None = None, transition_cls: type[PpoTransition] = PpoTransition, ): super().__init__() self.device = device self.num_transitions_per_env = num_transitions_per_env self.num_envs = num_envs self.command_name = command_name self._float_dtype = torch.float32 self._dones_dtype = torch.bool self._transition_cls = transition_cls obs_template = obs_template.to(self.device) self.data = TensorDict( {}, batch_size=[num_transitions_per_env, num_envs], device=self.device, ) self._allocate_from_transition( obs_template=obs_template, actions_shape=actions_shape, ) self.step = 0 def _resolve_shape(self, spec_shape, actions_shape) -> tuple: if spec_shape is None: return tuple() resolved = [] shape_tokens = getattr(self._transition_cls, "SHAPE_TOKENS", {}) for dim in spec_shape: if dim == "A": resolved.extend(tuple(actions_shape)) elif isinstance(dim, str) and dim in shape_tokens: resolved.append(int(shape_tokens[dim])) else: resolved.append(int(dim)) return tuple(resolved) def 
_allocate_from_transition( self, *, obs_template: TensorDict, actions_shape, ) -> None: specs = getattr(self._transition_cls, "FIELD_SPECS", None) if not isinstance(specs, dict): raise ValueError( "Transition class must define FIELD_SPECS for allocation." ) for name, spec in specs.items(): if spec.get("kind") == "obs": leaf_keys = obs_template.keys( include_nested=True, leaves_only=True ) for key in leaf_keys: value = obs_template.get(key) if not torch.is_tensor(value): continue dtype = ( self._float_dtype if torch.is_floating_point(value) else value.dtype ) key_tuple = key if isinstance(key, tuple) else (key,) self.data.set( ("obs",) + key_tuple, torch.empty( (self.num_transitions_per_env, self.num_envs) + tuple(value.shape[1:]), device=self.device, dtype=dtype, ), ) continue shape_spec = spec.get("shape") dtype = spec.get("dtype", self._float_dtype) resolved = self._resolve_shape(shape_spec, actions_shape) self.data.set( name, torch.empty( (self.num_transitions_per_env, self.num_envs) + resolved, device=self.device, dtype=dtype, ), ) def _to_storage_tensor(self, tensor: torch.Tensor) -> torch.Tensor: if not torch.is_tensor(tensor): raise TypeError("Expected a tensor for RolloutStorage update.") if tensor.device != self.device: tensor = tensor.to(self.device) if ( torch.is_floating_point(tensor) and tensor.dtype != self._float_dtype ): tensor = tensor.to(dtype=self._float_dtype) return tensor def add(self, transition: PpoTransition) -> None: if self.step >= self.num_transitions_per_env: raise OverflowError("Rollout buffer overflow!") if not isinstance(transition, self._transition_cls): raise TypeError( "Transition must match the RolloutStorage transition class." ) if transition.batch_size is None or len(transition.batch_size) < 1: raise ValueError("Transition must have batch size [N].") if int(transition.batch_size[0]) != int(self.num_envs): raise ValueError( f"Transition batch size {transition.batch_size} " f"does not match num_envs={self.num_envs}." 
) td = transition.to_tensordict() td = td.apply(self._to_storage_tensor, inplace=False) if "dones" in td.keys(): dones = td.get("dones") if torch.is_tensor(dones) and dones.dtype != self._dones_dtype: td.set("dones", dones.to(dtype=self._dones_dtype)) self.data[self.step].update_(td) self.step += 1 def clear(self) -> None: self.step = 0 def compute_returns( self, last_values, gamma, lam, normalize_advantage: bool = False, ) -> None: advantage = 0 rewards = self.data["rewards"] values = self.data["values"] dones = self.data["dones"] returns = self.data["returns"] advantages = self.data["advantages"] for step in reversed(range(self.num_transitions_per_env)): if step == self.num_transitions_per_env - 1: next_values = last_values else: next_values = values[step + 1] next_is_not_terminal = 1.0 - dones[step].float() delta = ( rewards[step] + next_is_not_terminal * gamma * next_values - values[step] ) advantage = delta + next_is_not_terminal * gamma * lam * advantage returns[step] = advantage + values[step] advantages.copy_(returns - values) if normalize_advantage: flat = advantages.view(-1) mean = flat.mean() std = flat.std().clamp_min(1.0e-8) advantages.copy_((advantages - mean) / std) @torch.no_grad() def normalize_advantages_global( self, *, accelerator=None, eps: float = 1.0e-8, ) -> None: """Global advantage normalization over the full rollout buffer. This normalizes `self.data["advantages"]` in-place using mean/std over all `[T * N]` samples. If `accelerator` is provided, the moments are aggregated across processes via `accelerator.reduce(..., reduction="sum")`. 
""" advantages = self.data["advantages"] advantages_flat = advantages.view(-1).float() count = torch.tensor( [advantages_flat.numel()], device=self.device, dtype=torch.float32 ) sum_local = advantages_flat.sum() sqsum_local = (advantages_flat * advantages_flat).sum() if accelerator is not None and int(accelerator.num_processes) > 1: count_g = accelerator.reduce(count, reduction="sum") sum_g = accelerator.reduce(sum_local, reduction="sum") sqsum_g = accelerator.reduce(sqsum_local, reduction="sum") else: count_g = count sum_g = sum_local sqsum_g = sqsum_local mean = sum_g / count_g var = (sqsum_g / count_g) - mean * mean std = torch.sqrt(var.clamp_min(eps)) advantages.copy_((advantages - mean) / std) @torch.no_grad() def normalize_advantages_global_by_move_mask( self, *, accelerator=None, eps: float = 1.0e-8, move_threshold: float = 0.5, ) -> None: """Global advantage normalization split by move vs static (velocity commands). Assumes: - `advantages`: [T, N, 1] - `velocity_commands`: [T, N, 4], where channel 0 is move_mask in {0,1}. """ velocity_commands = self.data.get("velocity_commands", None) if velocity_commands is None: raise ValueError( "velocity_commands is required for global advantage normalization by move mask." 
) advantages = self.data["advantages"] advantages_flat = advantages.view(-1).float() vel_flat = velocity_commands.view(-1, int(velocity_commands.shape[-1])) move_mask = vel_flat[:, 0] > float(move_threshold) static_mask = ~move_mask count_all = torch.tensor( [advantages_flat.numel()], device=self.device, dtype=torch.float32 ) sum_all_local = advantages_flat.sum() sqsum_all_local = (advantages_flat * advantages_flat).sum() if accelerator is not None and int(accelerator.num_processes) > 1: count_all_g = accelerator.reduce(count_all, reduction="sum") sum_all_g = accelerator.reduce(sum_all_local, reduction="sum") sqsum_all_g = accelerator.reduce(sqsum_all_local, reduction="sum") else: count_all_g = count_all sum_all_g = sum_all_local sqsum_all_g = sqsum_all_local mean_all = sum_all_g / count_all_g var_all = (sqsum_all_g / count_all_g) - mean_all * mean_all std_all = torch.sqrt(var_all.clamp_min(eps)) def _group_stats( mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if not bool(mask.any().item()): return mean_all, std_all mask_f = mask.to(dtype=torch.float32) count_local = mask_f.sum() sum_local = (advantages_flat * mask_f).sum() sqsum_local = (advantages_flat * advantages_flat * mask_f).sum() if accelerator is not None and int(accelerator.num_processes) > 1: count_g = accelerator.reduce(count_local, reduction="sum") sum_g = accelerator.reduce(sum_local, reduction="sum") sqsum_g = accelerator.reduce(sqsum_local, reduction="sum") else: count_g = count_local sum_g = sum_local sqsum_g = sqsum_local if float(count_g.item()) <= 0.0: return mean_all, std_all mean = sum_g / count_g var = (sqsum_g / count_g) - mean * mean std = torch.sqrt(var.clamp_min(eps)) return mean, std move_mean, move_std = _group_stats(move_mask) static_mean, static_std = _group_stats(static_mask) advantages_norm = advantages_flat.clone() if bool(move_mask.any().item()): advantages_norm[move_mask] = ( advantages_flat[move_mask] - move_mean ) / move_std if bool(static_mask.any().item()): 
advantages_norm[static_mask] = ( advantages_flat[static_mask] - static_mean ) / static_std self.data["advantages"].copy_(advantages_norm.view_as(advantages)) @torch.no_grad() def normalize_advantages_global_by_command( self, *, command_name: str | None, accelerator=None, eps: float = 1.0e-8, ) -> None: """Dispatch global advantage normalization based on command type/name.""" if ( command_name == "base_velocity" and self.data.get("velocity_commands", None) is not None ): self.normalize_advantages_global_by_move_mask( accelerator=accelerator, eps=eps ) return self.normalize_advantages_global(accelerator=accelerator, eps=eps) def iter_minibatches( self, num_mini_batches: int, num_epochs: int, ) -> Generator[PpoTransition, None, None]: if self.step != self.num_transitions_per_env: raise RuntimeError( f"RolloutStorage buffer not full: step={self.step}, " f"expected={self.num_transitions_per_env}. " "This would mix stale entries from a previous rollout." ) batch_size = self.num_envs * self.num_transitions_per_env ( effective_num_mini_batches, mini_batch_size, ) = self.resolve_mini_batch_partition(batch_size, num_mini_batches) indices = torch.randperm( batch_size, requires_grad=False, device=self.device, )[: effective_num_mini_batches * mini_batch_size] flat = self.data.flatten(0, 1) # [T * N, ...] for _ in range(num_epochs): for i in range(effective_num_mini_batches): start = i * mini_batch_size end = (i + 1) * mini_batch_size batch_idx = indices[start:end] batch = flat[batch_idx] yield self._transition_cls.from_tensordict(batch) @staticmethod def resolve_mini_batch_partition( batch_size: int, num_mini_batches: int, ) -> tuple[int, int]: if batch_size <= 0: raise RuntimeError( "RolloutStorage minibatch partition requires batch_size > 0." 
) effective_num_mini_batches = max( 1, min(int(num_mini_batches), int(batch_size)) ) mini_batch_size = max(1, batch_size // effective_num_mini_batches) return effective_num_mini_batches, mini_batch_size ================================================ FILE: holomotion/src/algo/ppo.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import inspect import json import math import os from concurrent.futures import ThreadPoolExecutor, as_completed import torch import torch.nn.functional as F import torch.optim as optim import numpy as np from loguru import logger from tabulate import tabulate from tqdm import tqdm import imageio from omegaconf import OmegaConf from holomotion.src.algo.algo_base import BaseOnpolicyRL from holomotion.src.algo.algo_utils import ( PpoTransition, PpoVelocityTransition, RolloutStorage, ) from holomotion.src.utils.onnx_export import ( export_policy_to_onnx as export_policy_to_onnx_common, ) from tensordict import TensorDict def _checkpoint_state_to_cpu(value): if isinstance(value, torch.Tensor): return value.detach().cpu() if isinstance(value, dict): return {k: _checkpoint_state_to_cpu(v) for k, v in value.items()} if isinstance(value, list): return [_checkpoint_state_to_cpu(v) for v in value] if isinstance(value, tuple): return tuple(_checkpoint_state_to_cpu(v) for v in value) return value class PPO(BaseOnpolicyRL): def 
_setup_configs(self):
        """Read and validate the PPO hyper-parameters from ``self.config``.

        Also resolves the distributed-update mode, the adaptive-LR and KL
        early-stop settings, the distributed LR scale factor, the motion
        sampling strategy and, when active, the symmetry-loss mirroring.

        Raises:
            ValueError: if a validated option is outside its allowed range.
        """
        super()._setup_configs()
        self.desired_kl = self.config.desired_kl
        self.schedule = self.config.schedule
        # Per-network learning rates fall back to the shared "learning_rate".
        self.actor_learning_rate = self.config.get(
            "actor_learning_rate", self.config.get("learning_rate", 3e-4)
        )
        self.critic_learning_rate = self.config.get(
            "critic_learning_rate", self.config.get("learning_rate", 3e-4)
        )
        # Unscaled copies; the distributed LR scale factor is applied below.
        self.base_actor_learning_rate = float(self.actor_learning_rate)
        self.base_critic_learning_rate = float(self.critic_learning_rate)
        self.actor_beta1 = self.config.get("actor_beta1", 0.9)
        self.actor_beta2 = self.config.get("actor_beta2", 0.999)
        self.critic_beta1 = self.config.get("critic_beta1", 0.9)
        self.critic_beta2 = self.config.get("critic_beta2", 0.999)
        self.optimizer_type = self.config.optimizer_type
        self.clip_param = self.config.clip_param
        self.num_learning_epochs = int(self.config.num_learning_epochs)
        self.configured_num_mini_batches = int(self.config.num_mini_batches)
        if self.configured_num_mini_batches < 1:
            raise ValueError("num_mini_batches must be >= 1.")
        distributed_update_cfg = self.config.get("distributed_update", {})
        self.distributed_update_mode = str(
            distributed_update_cfg.get("mode", "legacy")
        ).lower()
        if self.distributed_update_mode not in {"legacy", "scalable"}:
            raise ValueError(
                "distributed_update.mode must be one of "
                "{'legacy', 'scalable'}."
            )
        self.requested_num_mini_batches = self._resolve_num_mini_batches(
            self.configured_num_mini_batches
        )
        self.num_mini_batches = self.requested_num_mini_batches
        self.gamma = self.config.gamma
        self.lam = self.config.lam
        self.value_loss_coef = self.config.value_loss_coef
        self.initial_entropy_coef = float(self.config.entropy_coef)
        self.anneal_entropy = bool(self.config.get("anneal_entropy", False))
        self.zero_entropy_point = float(
            self.config.get("zero_entropy_point", 1.0)
        )
        self._validate_entropy_schedule_config(
            initial_entropy_coef=self.initial_entropy_coef,
            anneal_entropy=self.anneal_entropy,
            zero_entropy_point=self.zero_entropy_point,
        )
        self.entropy_coef = self.initial_entropy_coef
        self.max_grad_norm = self.config.max_grad_norm
        self.use_clipped_value_loss = self.config.use_clipped_value_loss

        # Adaptive LR schedule: multiplicative step size, KL band factors
        # relative to desired_kl, and hard LR bounds.
        adaptive_lr_cfg = self.config.get("adaptive_lr", {})
        self.adaptive_lr_adapt_critic = bool(
            adaptive_lr_cfg.get("adapt_critic", False)
        )
        self.adaptive_lr_factor = float(adaptive_lr_cfg.get("lr_scaler", 1.2))
        self.adaptive_lr_kl_high_factor = float(
            adaptive_lr_cfg.get("kl_high_factor", 2.0)
        )
        self.adaptive_lr_kl_low_factor = float(
            adaptive_lr_cfg.get("kl_low_factor", 0.5)
        )
        self.adaptive_lr_min = float(
            adaptive_lr_cfg.get("min_learning_rate", 1.0e-6)
        )
        self.adaptive_lr_max = float(
            adaptive_lr_cfg.get("max_learning_rate", 1.0)
        )
        if self.adaptive_lr_factor <= 1.0:
            raise ValueError("adaptive_lr.lr_scaler must be > 1.")
        if self.adaptive_lr_kl_high_factor <= 0.0:
            raise ValueError("adaptive_lr.kl_high_factor must be > 0.")
        if self.adaptive_lr_kl_low_factor <= 0.0:
            raise ValueError("adaptive_lr.kl_low_factor must be > 0.")
        if self.adaptive_lr_min <= 0.0:
            raise ValueError("adaptive_lr.min_learning_rate must be > 0.")
        if self.adaptive_lr_max < self.adaptive_lr_min:
            raise ValueError(
                "adaptive_lr.max_learning_rate must be >= "
                "adaptive_lr.min_learning_rate."
            )

        # KL-based early stopping of the update loop (windowed mean only).
        kl_early_stop_cfg = distributed_update_cfg.get("kl_early_stop", {})
        self.kl_early_stop_enabled = bool(
            kl_early_stop_cfg.get("enabled", False)
        )
        kl_signal_mode = str(
            kl_early_stop_cfg.get("signal", "window_mean")
        ).lower()
        if kl_signal_mode != "window_mean":
            raise ValueError(
                "Only distributed_update.kl_early_stop.signal='window_mean' "
                "is supported."
            )
        self.kl_early_stop_window_size = int(
            kl_early_stop_cfg.get("window_size", 3)
        )
        self.kl_early_stop_factor = float(kl_early_stop_cfg.get("factor", 2.0))
        self.kl_early_stop_min_updates = int(
            kl_early_stop_cfg.get("min_updates", 1)
        )
        if self.kl_early_stop_window_size < 1:
            raise ValueError(
                "distributed_update.kl_early_stop.window_size must be >= 1."
            )
        if self.kl_early_stop_factor <= 0.0:
            raise ValueError(
                "distributed_update.kl_early_stop.factor must be > 0."
            )
        if self.kl_early_stop_min_updates < 1:
            raise ValueError(
                "distributed_update.kl_early_stop.min_updates must be >= 1."
            )
        if self.kl_early_stop_enabled and self.desired_kl is None:
            raise ValueError(
                "distributed_update.kl_early_stop requires desired_kl to be set."
            )
        self.global_advantage_norm = bool(
            self.config.get("global_advantage_norm", True)
        )
        self.normalize_advantage_per_mini_batch = bool(
            self.config.get("normalize_advantage_per_mini_batch", False)
        )
        self.distributed_lr_scale_factor = self._compute_lr_scale_factor(
            distributed_update_cfg.get("lr_scale", {})
        )
        self.actor_learning_rate = (
            self.base_actor_learning_rate * self.distributed_lr_scale_factor
        )
        self.critic_learning_rate = (
            self.base_critic_learning_rate * self.distributed_lr_scale_factor
        )
        # Initial values for the per-update diagnostics that get logged.
        self._last_update_metrics = {
            "0-Train/configured_num_mini_batches": float(
                self.configured_num_mini_batches
            ),
            "0-Train/requested_num_mini_batches": float(
                self.requested_num_mini_batches
            ),
            "0-Train/effective_num_mini_batches": float(
                self.requested_num_mini_batches
            ),
            "0-Train/mini_batch_size_per_rank": 0.0,
            "0-Train/num_updates_executed": 0.0,
            "0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor),
            "0-Train/scalable_distributed_update": float(
                self.distributed_update_mode == "scalable"
            ),
            "0-Train/kl_windowed": 0.0,
            "0-Train/kl_stop_triggered": 0.0,
            "0-Train/kl_stop_analytic": 0.0,
        }
        self._offline_evaluating: bool = False
        motion_cfg = self.env_config.config.robot.motion
        sampling_strategy_cfg = motion_cfg.get("sampling_strategy", None)
        if sampling_strategy_cfg is None:
            sampling_strategy = "uniform"
        else:
            sampling_strategy = str(sampling_strategy_cfg).lower()
        valid_strategies = {"uniform", "weighted_bin", "curriculum"}
        if sampling_strategy not in valid_strategies:
            raise ValueError(
                f"Invalid sampling_strategy '{sampling_strategy}'. "
                f"Expected one of {sorted(valid_strategies)}."
            )
        self.sampling_strategy: str = sampling_strategy
        self.weighted_bin_cfg = dict(motion_cfg.get("weighted_bin", {}))

        # Symmetry loss: DOF permutation/signs and per-term obs mirror
        # functions (keyed by observation path) are filled in by
        # _setup_symmetry() only when the loss is active.
        sym_cfg = self.config.get("symmetry_loss", {})
        self.symmetry_loss_enabled = bool(sym_cfg.get("enabled", False))
        self.symmetry_loss_coef = float(sym_cfg.get("coef", 0.0))
        self._sym_dof_perm: torch.Tensor | None = None
        self._sym_dof_sign: torch.Tensor | None = None
        self._obs_mirror_map: dict[str, callable] = {}
        if self._symmetry_loss_active():
            self._setup_symmetry()

    def _resolve_num_mini_batches(self, base_num_mini_batches: int) -> int:
        """Return the minibatch count requested for the PPO update.

        In ``legacy`` mode under distributed training the configured count
        is multiplied by ``gpu_world_size``; the result is always >= 1.
        """
        if self.distributed_update_mode == "legacy" and self.is_distributed:
            return max(1, base_num_mini_batches * int(self.gpu_world_size))
        return max(1, base_num_mini_batches)

    def _compute_lr_scale_factor(self, lr_scale_cfg) -> float:
        """Return the learning-rate multiplier for distributed training.

        With ``ratio = runtime_world_size / reference_world_size``: mode
        ``none`` gives 1, ``sqrt_world_size`` gives ``sqrt(ratio)``, and
        ``linear_world_size`` gives ``ratio``. The result is optionally
        capped at ``max_scale``.

        Raises:
            ValueError: on an unknown mode or a non-positive
                ``reference_world_size`` / ``max_scale``.
        """
        scale_mode = str(lr_scale_cfg.get("mode", "none")).lower()
        if scale_mode not in {
            "none",
            "sqrt_world_size",
            "linear_world_size",
        }:
            raise ValueError(
                "distributed_update.lr_scale.mode must be one of "
                "{'none', 'sqrt_world_size', 'linear_world_size'}."
            )
        reference_world_size = float(
            lr_scale_cfg.get("reference_world_size", 1)
        )
        if reference_world_size <= 0.0:
            raise ValueError(
                "distributed_update.lr_scale.reference_world_size must be > 0."
            )
        runtime_world_size = float(
            self.gpu_world_size if self.is_distributed else 1
        )
        world_ratio = runtime_world_size / reference_world_size
        if scale_mode == "none":
            scale = 1.0
        elif scale_mode == "sqrt_world_size":
            scale = math.sqrt(world_ratio)
        else:
            scale = world_ratio
        max_scale = lr_scale_cfg.get("max_scale", None)
        if max_scale is not None:
            max_scale = float(max_scale)
            if max_scale <= 0.0:
                raise ValueError(
                    "distributed_update.lr_scale.max_scale must be > 0 when set."
)
            # Optional upper cap on the distributed LR scale.
            scale = min(scale, max_scale)
        return float(scale)

    def _symmetry_loss_active(self) -> bool:
        """Return True when the symmetry loss should be applied.

        Requires the ``base_velocity`` command, ``symmetry_loss_enabled``,
        and a strictly positive ``symmetry_loss_coef``. The ``getattr``
        defaults make this safe to call before those attributes exist.
        """
        return bool(
            getattr(self, "command_name", None) == "base_velocity"
            and getattr(self, "symmetry_loss_enabled", False)
            and float(getattr(self, "symmetry_loss_coef", 0.0)) > 0.0
        )

    @staticmethod
    def _omega_or_obj_to_dict(value):
        """Best-effort conversion of a config node to plain Python data.

        OmegaConf nodes go through ``OmegaConf.to_container(resolve=True)``.
        Dicts are returned as-is. Other objects with ``__dict__`` return
        ``vars(value)``, which is the live attribute dict, not a copy.
        ``None`` and anything else give ``{}``.
        """
        if value is None:
            return {}
        if OmegaConf.is_config(value):
            return OmegaConf.to_container(value, resolve=True)
        if isinstance(value, dict):
            return value
        if hasattr(value, "__dict__"):
            return vars(value)
        return {}

    def _setup_symmetry(self) -> None:
        """Build the DOF mirror permutation/signs, then the obs mirror map.

        Each ``left_*`` joint is paired with its ``right_*`` counterpart and
        vice versa. Joints without a counterpart map to themselves.
        Per-joint signs come from ``dof_sign_by_name`` in the
        ``symmetry_loss`` config, falling back to the robot config.
        Joints missing from that mapping get ``+1``.

        Raises:
            ValueError: if the simulator joint count differs from
                ``num_actions`` or no ``dof_sign_by_name`` is configured.
        """
        robot_asset = self.env._env.scene["robot"]
        joint_names = list(getattr(robot_asset, "joint_names", []))
        if len(joint_names) != int(self.num_actions):
            raise ValueError(
                "symmetry_loss requires simulator joint_names to match "
                f"num_actions, got {len(joint_names)} vs {self.num_actions}."
            )
        name_to_idx = {name: idx for idx, name in enumerate(joint_names)}
        perm: list[int] = []
        for name in joint_names:
            if name.startswith("left_"):
                mirror_name = "right_" + name[len("left_") :]
            elif name.startswith("right_"):
                mirror_name = "left_" + name[len("right_") :]
            else:
                mirror_name = name
            # No counterpart in joint_names -> the joint mirrors onto itself.
            perm.append(int(name_to_idx.get(mirror_name, name_to_idx[name])))
        sym_cfg = self._omega_or_obj_to_dict(
            self.config.get("symmetry_loss", {})
        )
        sign_by_name = sym_cfg.get("dof_sign_by_name", None)
        if not sign_by_name:
            # Fall back to the robot section of the env config.
            robot_cfg = self._omega_or_obj_to_dict(
                getattr(
                    getattr(self.env_config, "config", None), "robot", None
                )
            )
            sign_by_name = robot_cfg.get("dof_sign_by_name", None)
        sign_by_name = self._omega_or_obj_to_dict(sign_by_name)
        if len(sign_by_name) == 0:
            raise ValueError(
                "symmetry_loss requires dof_sign_by_name in algo or robot config."
            )
        sign = [float(sign_by_name.get(name, 1.0)) for name in joint_names]
        self._sym_dof_perm = torch.tensor(
            perm, device=self.device, dtype=torch.long
        )
        self._sym_dof_sign = torch.tensor(
            sign, device=self.device, dtype=torch.float32
        )
        self._build_obs_mirror_map()

    def _extract_obs_mirror_metadata(self) -> dict[str, dict]:
        """Collect per-term mirror settings from the env observation config.

        Walks ``obs.obs_groups.<group>.atomic_obs_list``. For every term that
        declares a ``mirror_func``, it records ``{"mirror_func",
        "mirror_config"}`` keyed by ``"<group>/<term>"``. Non-dict groups and
        entries are skipped.
        """
        obs_cfg = getattr(
            getattr(self.env_config, "config", None), "obs", None
        )
        obs_root = self._omega_or_obj_to_dict(obs_cfg)
        obs_groups = obs_root.get("obs_groups", {})
        metadata: dict[str, dict] = {}
        for group_name, group_cfg in obs_groups.items():
            if not isinstance(group_cfg, dict):
                continue
            for term_entry in group_cfg.get("atomic_obs_list", []):
                if not isinstance(term_entry, dict):
                    continue
                for term_name, term_cfg in term_entry.items():
                    term_cfg = self._omega_or_obj_to_dict(term_cfg)
                    mirror_func = term_cfg.get("mirror_func", None)
                    if not mirror_func:
                        continue
                    metadata[f"{group_name}/{term_name}"] = {
                        "mirror_func": str(mirror_func),
                        "mirror_config": self._omega_or_obj_to_dict(
                            term_cfg.get("mirror_config", {})
                        ),
                    }
        return metadata

    def _get_actor_schema_terms(self) -> set[str]:
        """Return every term name listed in the actor's ``obs_schema``.

        Reads ``module_dict.actor.obs_schema.<seq>.terms`` from the algo
        config. Non-dict schema entries are ignored.
        """
        module_dict = self._omega_or_obj_to_dict(
            self.config.get("module_dict", {})
        )
        actor_cfg = self._omega_or_obj_to_dict(module_dict.get("actor", {}))
        actor_schema = self._omega_or_obj_to_dict(
            actor_cfg.get("obs_schema", {})
        )
        actor_terms: set[str] = set()
        for seq_cfg in actor_schema.values():
            if not isinstance(seq_cfg, dict):
                continue
            for term in seq_cfg.get("terms", []):
                actor_terms.add(str(term))
        return actor_terms

    def _build_obs_mirror_map(self) -> None:
        """Map actor observation terms to callables that mirror them.

        A term gets an entry only if it is in the actor schema AND declares a
        supported ``mirror_func``: ``mirror_dof``, ``mirror_vec3``,
        ``mirror_axial_vec3`` or ``mirror_velocity_command``. The map is
        left empty until the DOF permutation/signs exist.
        """
        from holomotion.src.env.isaaclab_components.isaaclab_observation import (
            MirrorFunctions,
        )

        self._obs_mirror_map = {}
        if self._sym_dof_perm is None or self._sym_dof_sign is None:
            return
        term_meta = self._extract_obs_mirror_metadata()
        actor_terms = self._get_actor_schema_terms()
        for term in actor_terms:
            meta = term_meta.get(term)
            if meta is None:
                continue
            mirror_func = str(meta.get("mirror_func", ""))
            if mirror_func == "mirror_dof":
                perm = self._sym_dof_perm
                sign = self._sym_dof_sign

                # Default args bind the tensors at definition time; they are
                # moved to the input's device/dtype on every call.
                def _fn(x, perm=perm, sign=sign):
                    perm_local = perm.to(device=x.device, dtype=torch.long)
                    sign_local = sign.to(device=x.device, dtype=x.dtype)
                    return MirrorFunctions.mirror_dof(
                        x, perm=perm_local, sign=sign_local
                    )

            elif mirror_func == "mirror_vec3":

                def _fn(x):
                    return MirrorFunctions.mirror_vec3(x)

            elif mirror_func == "mirror_axial_vec3":

                def _fn(x):
                    return MirrorFunctions.mirror_axial_vec3(x)

            elif mirror_func == "mirror_velocity_command":

                def _fn(x):
                    return MirrorFunctions.mirror_velocity_command(x)

            else:
                # Unsupported mirror function names are ignored.
                continue
            self._obs_mirror_map[term] = _fn

    @staticmethod
    def _td_key_to_path(key) -> str:
        """Join a (possibly nested) TensorDict key into an ``"a/b"`` path."""
        if isinstance(key, tuple):
            return "/".join(str(part) for part in key)
        return str(key)

    def _mirror_actor_obs(self, obs_td: TensorDict) -> TensorDict:
        """Return a mirrored copy of an actor observation TensorDict.

        Leaves with no registered mirror function are copied through
        unchanged. ``obs_td`` itself is returned when the symmetry loss is
        inactive, the input is not a TensorDict, or no mirror functions are
        registered.
        """
        if (
            not self._symmetry_loss_active()
            or not isinstance(obs_td, TensorDict)
            or len(getattr(self, "_obs_mirror_map", {})) == 0
        ):
            return obs_td
        mirrored = TensorDict(
            {},
            batch_size=list(obs_td.batch_size),
            device=obs_td.device,
        )
        for key in obs_td.keys(include_nested=True, leaves_only=True):
            key_tuple = key if isinstance(key, tuple) else (key,)
            value = obs_td.get(key_tuple)
            mirror_fn = self._obs_mirror_map.get(
                self._td_key_to_path(key_tuple)
            )
            mirrored.set(
                key_tuple,
                mirror_fn(value) if mirror_fn is not None else value,
            )
        return mirrored

    def _mirror_env_action(self, actions: torch.Tensor) -> torch.Tensor:
        """Mirror an action tensor with the DOF permutation and signs.

        Returns ``actions`` unchanged when the symmetry loss is inactive.

        Raises:
            RuntimeError: if the permutation/signs were never built.
        """
        from holomotion.src.env.isaaclab_components.isaaclab_observation import (
            MirrorFunctions,
        )

        if not self._symmetry_loss_active():
            return actions
        if self._sym_dof_perm is None or self._sym_dof_sign is None:
            raise RuntimeError(
                "Symmetry DOF permutation/signs are not initialized."
)
        # Permutation/signs are moved to the action's device (and dtype).
        return MirrorFunctions.mirror_action(
            actions,
            perm=self._sym_dof_perm.to(
                device=actions.device, dtype=torch.long
            ),
            sign=self._sym_dof_sign.to(
                device=actions.device, dtype=actions.dtype
            ),
        )

    def _compute_analytic_kl(
        self,
        old_mu: torch.Tensor,
        old_sigma: torch.Tensor,
        new_mu: torch.Tensor,
        new_sigma: torch.Tensor,
        weight: torch.Tensor | None = None,
    ) -> float:
        """Mean analytic KL(old || new) between diagonal Gaussian policies.

        The per-sample KL is summed over the last dim, then averaged, using
        ``weight`` as per-sample weights when given. Sums and counts are
        reduced across processes when distributed. ``1e-8`` terms guard the
        log and the division, and the denominator is clamped to >= 1.
        """
        with torch.no_grad():
            kl_vec = torch.sum(
                torch.log((new_sigma + 1.0e-8) / (old_sigma + 1.0e-8))
                + (torch.square(old_sigma) + torch.square(old_mu - new_mu))
                / (2.0 * torch.square(new_sigma) + 1.0e-8)
                - 0.5,
                dim=-1,
            )
            if weight is None:
                kl_sum = kl_vec.sum()
                kl_count = torch.tensor(
                    float(kl_vec.numel()),
                    device=self.device,
                    dtype=torch.float32,
                )
            else:
                kl_weight = weight.to(dtype=torch.float32)
                kl_sum = (kl_vec * kl_weight).sum()
                kl_count = kl_weight.sum()
            if self.is_distributed:
                kl_sum = self.accelerator.reduce(kl_sum, reduction="sum")
                kl_count = self.accelerator.reduce(kl_count, reduction="sum")
            kl_mean = kl_sum / kl_count.clamp_min(1.0)
            return float(kl_mean.item())

    def _compute_clip_fraction(
        self,
        ratio: torch.Tensor,
        weight: torch.Tensor | None = None,
    ) -> float:
        """Fraction of ratios outside ``[1 - clip_param, 1 + clip_param]``.

        Optionally weighted per element, and reduced across processes when
        distributed. The denominator is clamped to >= 1.
        """
        with torch.no_grad():
            clipped = (
                (ratio < (1.0 - self.clip_param))
                | (ratio > (1.0 + self.clip_param))
            ).to(torch.float32)
            if weight is None:
                clip_sum = clipped.sum()
                clip_count = torch.tensor(
                    float(clipped.numel()),
                    device=self.device,
                    dtype=torch.float32,
                )
            else:
                clip_weight = weight.to(dtype=torch.float32)
                clip_sum = (clipped * clip_weight).sum()
                clip_count = clip_weight.sum()
            if self.is_distributed:
                clip_sum = self.accelerator.reduce(clip_sum, reduction="sum")
                clip_count = self.accelerator.reduce(
                    clip_count, reduction="sum"
                )
            clip_fraction = clip_sum / clip_count.clamp_min(1.0)
            return float(clip_fraction.item())

    def _compute_explained_variance(
        self,
        target: torch.Tensor,
        prediction: torch.Tensor,
        weight: torch.Tensor | None = None,
    ) -> float:
        """Return ``1 - Var(target - prediction) / Var(target)``.

        Moments are (optionally weighted) sums, reduced across processes when
        distributed. Returns ``0.0`` when the target variance is <= 1e-8.
        """
        with torch.no_grad():
            target_f = target.float()
            prediction_f = prediction.float()
            residual = target_f - prediction_f
            if weight is None:
                weight_f = torch.ones_like(target_f, dtype=torch.float32)
            else:
                weight_f = weight.to(dtype=torch.float32)
            count = weight_f.sum()
            target_sum = (target_f * weight_f).sum()
            target_sq_sum = (target_f.square() * weight_f).sum()
            residual_sum = (residual * weight_f).sum()
            residual_sq_sum = (residual.square() * weight_f).sum()
            if self.is_distributed:
                count = self.accelerator.reduce(count, reduction="sum")
                target_sum = self.accelerator.reduce(
                    target_sum, reduction="sum"
                )
                target_sq_sum = self.accelerator.reduce(
                    target_sq_sum, reduction="sum"
                )
                residual_sum = self.accelerator.reduce(
                    residual_sum, reduction="sum"
                )
                residual_sq_sum = self.accelerator.reduce(
                    residual_sq_sum, reduction="sum"
                )
            denom = count.clamp_min(1.0)
            target_mean = target_sum / denom
            residual_mean = residual_sum / denom
            # Single-pass variance: E[x^2] - E[x]^2.
            target_var = target_sq_sum / denom - target_mean.square()
            residual_var = residual_sq_sum / denom - residual_mean.square()
            if float(target_var.item()) <= 1.0e-8:
                return 0.0
            explained_variance = 1.0 - residual_var / target_var
            return float(explained_variance.item())

    def _set_optimizer_learning_rates(self) -> None:
        """Push the current actor/critic learning rates into the optimizers."""
        for param_group in self.actor_optimizer.param_groups:
            param_group["lr"] = self.actor_learning_rate
        for param_group in self.critic_optimizer.param_groups:
            param_group["lr"] = self.critic_learning_rate

    @staticmethod
    def _validate_entropy_schedule_config(
        *,
        initial_entropy_coef: float,
        anneal_entropy: bool,
        zero_entropy_point: float,
    ) -> None:
        """Validate the entropy-coefficient schedule settings.

        Raises:
            ValueError: if ``initial_entropy_coef < 0``, or if annealing is
                enabled and ``zero_entropy_point`` is not in ``(0, 1]``.
        """
        if float(initial_entropy_coef) < 0.0:
            raise ValueError("entropy_coef must be >= 0.")
        if anneal_entropy and not (0.0 < float(zero_entropy_point) <= 1.0):
            raise ValueError(
                "zero_entropy_point must be in (0.0, 1.0] when "
                "anneal_entropy is enabled."
            )

    def _get_effective_entropy_coef(self) -> float:
        """Return the entropy coefficient for the current iteration.

        Without annealing (or with a zero coefficient) this is the initial
        value. Otherwise the coefficient decays linearly to 0 at
        ``zero_entropy_point * total_learning_iterations`` and stays at 0
        afterwards. ``total_learning_iterations`` falls back to
        ``current_learning_iteration + num_learning_iterations``.
        """
        if self.initial_entropy_coef <= 0.0 or not self.anneal_entropy:
            return float(self.initial_entropy_coef)
        total_learning_iterations = int(
            getattr(
                self,
                "total_learning_iterations",
                self.current_learning_iteration
                + int(self.num_learning_iterations),
            )
        )
        total_learning_iterations = max(1, total_learning_iterations)
        zero_entropy_iteration = float(self.zero_entropy_point) * float(
            total_learning_iterations
        )
        anneal_scale = max(
            0.0,
            1.0
            - float(self.current_learning_iteration) / zero_entropy_iteration,
        )
        return float(self.initial_entropy_coef * anneal_scale)

    def _apply_adaptive_lr(self, kl_signal: float | None) -> None:
        """Adjust learning rates from a KL measurement (adaptive schedule).

        If ``kl_signal > desired_kl * kl_high_factor``, the actor LR (and the
        critic LR if ``adapt_critic``) is divided by ``lr_scaler``. If
        ``0 < kl_signal < desired_kl * kl_low_factor``, it is multiplied by
        ``lr_scaler``. Results are clamped to ``[min, max]``. No-op unless
        the schedule is ``"adaptive"`` and both ``desired_kl`` and
        ``kl_signal`` are set.
        """
        if (
            self.desired_kl is None
            or self.schedule != "adaptive"
            or kl_signal is None
        ):
            return
        if kl_signal > self.desired_kl * self.adaptive_lr_kl_high_factor:
            self.actor_learning_rate = max(
                self.adaptive_lr_min,
                self.actor_learning_rate / self.adaptive_lr_factor,
            )
            if self.adaptive_lr_adapt_critic:
                self.critic_learning_rate = max(
                    self.adaptive_lr_min,
                    self.critic_learning_rate / self.adaptive_lr_factor,
                )
        elif (
            kl_signal > 0.0
            and kl_signal < self.desired_kl * self.adaptive_lr_kl_low_factor
        ):
            self.actor_learning_rate = min(
                self.adaptive_lr_max,
                self.actor_learning_rate * self.adaptive_lr_factor,
            )
            if self.adaptive_lr_adapt_critic:
                self.critic_learning_rate = min(
                    self.adaptive_lr_max,
                    self.critic_learning_rate * self.adaptive_lr_factor,
                )
        self._set_optimizer_learning_rates()

    def _compute_windowed_kl_signal(
        self, recent_analytic_kls: list[float]
    ) -> float | None:
        """Mean of the last ``kl_early_stop_window_size`` KL values.

        Returns ``None`` until that many values have been collected.
        """
        if len(recent_analytic_kls) < self.kl_early_stop_window_size:
            return None
        window = recent_analytic_kls[-self.kl_early_stop_window_size :]
        return float(sum(window) / len(window))

    def _should_early_stop_for_kl(
        self,
        kl_signal: float | None,
        num_kl_measurements: int,
    ) -> bool:
        """Decide whether to stop the current update because KL is too high.

        True only when KL early stop is enabled, ``desired_kl`` and
        ``kl_signal`` are set, at least
        ``max(min_updates, window_size)`` measurements exist, and
        ``kl_signal > desired_kl * kl_early_stop_factor``.
        """
        if not self.kl_early_stop_enabled or self.desired_kl is None:
            return False
        if kl_signal is None:
            return False
        required_measurements = max(
self.kl_early_stop_min_updates, self.kl_early_stop_window_size
        )
        if num_kl_measurements < required_measurements:
            return False
        return kl_signal > self.desired_kl * self.kl_early_stop_factor

    def _setup_data_buffers(self):
        """Select the transition class used for rollout storage.

        ``base_velocity`` runs use ``PpoVelocityTransition``; all others use
        ``PpoTransition``.
        """
        super()._setup_data_buffers()
        self.use_velocity_transition: bool = (
            self.command_name == "base_velocity"
        )
        self.transition_cls = (
            PpoVelocityTransition
            if self.use_velocity_transition
            else PpoTransition
        )
        self.transition_td: PpoTransition | PpoVelocityTransition | None = None

    def _build_optimizer_kwargs(self, optimizer_class: type) -> dict:
        """Extra optimizer kwargs (AdamW only) selecting a fast kernel path.

        ``fused=True`` is chosen on CUDA when enabled and supported by the
        optimizer's signature. Otherwise ``foreach=True`` is chosen when
        enabled and supported. In all other cases the result is ``{}``.
        """
        if self.optimizer_type != "AdamW":
            return {}
        signature = inspect.signature(optimizer_class.__init__)
        parameters = signature.parameters
        use_fused = bool(
            self.config.get(
                "adamw_use_fused", bool(self.device.type == "cuda")
            )
        )
        use_foreach = bool(self.config.get("adamw_use_foreach", True))
        if (
            use_fused
            and ("fused" in parameters)
            and (self.device.type == "cuda")
        ):
            return {"fused": True}
        if use_foreach and ("foreach" in parameters):
            return {"foreach": True}
        return {}

    def _setup_models_and_optimizer(self):
        """Build actor/critic and their optimizers, then ``accelerator.prepare`` them.

        A sample observation from ``env.reset_all()`` is wrapped into a
        TensorDict and passed to both modules for schema-based assembly.
        On the main process, the module structure and parameter counts are
        logged.
        """
        from holomotion.src.modules.agent_modules import PPOActor, PPOCritic

        # Build sample TensorDict for schema-based assembly
        sample_obs_dict = self.env.reset_all()[0]
        sample_td = self._wrap_obs_dict(sample_obs_dict)

        actor_cfg = OmegaConf.to_container(
            self.config.module_dict.actor, resolve=True
        )
        critic_cfg = OmegaConf.to_container(
            self.config.module_dict.critic, resolve=True
        )
        self.actor_type = actor_cfg.get("type", "MLP")
        self.critic_type = critic_cfg.get("type", "MLP")
        actor_schema = actor_cfg.get("obs_schema", None)
        critic_schema = critic_cfg.get("obs_schema", None)

        self.actor = PPOActor(
            obs_schema=actor_schema,
            module_config_dict=actor_cfg,
            num_actions=self.num_actions,
            init_noise_std=self.config.init_noise_std,
            obs_example=sample_td,
        ).to(self.device)
        self.critic = PPOCritic(
            obs_schema=critic_schema,
            module_config_dict=critic_cfg,
            obs_example=sample_td,
        ).to(self.device)

        if self.is_main_process:
            actor = self.accelerator.unwrap_model(self.actor)
            critic = self.accelerator.unwrap_model(self.critic)
            logger.info("Actor (TensorDict module):\n{!r}", actor)
            logger.info(
                "Actor keys: in_keys={} out_keys={}",
                list(actor.in_keys),
                list(actor.out_keys),
            )
            logger.info("Actor core nn module:\n{!r}", actor.actor_module)
            logger.info("Critic (TensorDict module):\n{!r}", critic)
            logger.info(
                "Critic keys: in_keys={} out_keys={}",
                list(critic.in_keys),
                list(critic.out_keys),
            )
            logger.info("Critic core nn module:\n{!r}", critic.critic_module)
            # Log actor and critic parameter counts (in millions)
            actor_params = sum(p.numel() for p in self.actor.parameters())
            critic_params = sum(p.numel() for p in self.critic.parameters())
            params_table = [
                ["Actor", f"{actor_params / 1.0e6:.3f}"],
                ["Critic", f"{critic_params / 1.0e6:.3f}"],
                ["Total", f"{(actor_params + critic_params) / 1.0e6:.3f}"],
            ]
            logger.info(
                "Model Summary:\n"
                + tabulate(
                    params_table,
                    headers=["Model", "Params (M)"],
                    tablefmt="simple_outline",
                )
            )

        optimizer_class = getattr(optim, self.optimizer_type)
        optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class)
        self.actor_optimizer = optimizer_class(
            self.actor.parameters(),
            lr=self.actor_learning_rate,
            betas=(self.actor_beta1, self.actor_beta2),
            **optimizer_kwargs,
        )
        self.critic_optimizer = optimizer_class(
            self.critic.parameters(),
            lr=self.critic_learning_rate,
            betas=(self.critic_beta1, self.critic_beta2),
            **optimizer_kwargs,
        )

        dynamo_backend = self.config.get("dynamo_backend", None)
        if dynamo_backend and self.is_main_process:
            logger.info(
                f"Models will be compiled with dynamo_backend='{dynamo_backend}' "
                "during accelerator.prepare()"
            )

        (
            self.actor,
            self.critic,
            self.actor_optimizer,
            self.critic_optimizer,
        ) = self.accelerator.prepare(
            self.actor,
            self.critic,
            self.actor_optimizer,
            self.critic_optimizer,
        )

    def _build_storage(self, obs_td: TensorDict):
        """Create the ``[num_steps_per_env, num_envs]`` rollout storage."""
        return RolloutStorage(
            self.num_envs,
            self.num_steps_per_env,
            obs_template=obs_td,
            actions_shape=[self.num_actions],
            device=self.device,
            command_name=self.command_name,
            transition_cls=self.transition_cls,
        )

    def _build_transition(
        self,
        obs_td: TensorDict,
        actor_out: TensorDict,
        critic_out: TensorDict,
    ):
        """Pack one env step's policy/critic outputs into a transition.

        ``rewards``/``returns``/``advantages`` are zero and ``dones`` is
        False here; presumably the rollout loop fills them in later (verify
        against caller). ``teacher_actions`` is a zero tensor shaped like
        ``actions``. For ``base_velocity`` runs, ``velocity_commands`` is
        ``[move_mask, cmd[..., :3]]``, where ``move_mask`` is
        ``|cmd[..., :3]| > 0.1``.
        """
        actions = actor_out.get("actions")
        actions_log_prob = actor_out.get("actions_log_prob")
        mu = actor_out.get("mu")
        sigma = actor_out.get("sigma")
        values = critic_out.get("values")
        zero_scalar = torch.zeros(
            self.num_envs,
            1,
            device=self.device,
            dtype=torch.float32,
        )
        zero_scalar_bool = torch.zeros(
            self.num_envs,
            1,
            device=self.device,
            dtype=torch.bool,
        )
        transition_kwargs = {
            "obs": obs_td,
            "actions": actions.detach(),
            "teacher_actions": torch.zeros_like(actions),
            "mu": mu.detach(),
            "sigma": sigma.detach(),
            "actions_log_prob": actions_log_prob[..., None].detach(),
            "values": values.detach(),
            "rewards": zero_scalar.clone(),
            "dones": zero_scalar_bool,
            "returns": zero_scalar.clone(),
            "advantages": zero_scalar.clone(),
            "batch_size": [self.num_envs],
            "device": self.device,
        }
        if self.use_velocity_transition:
            import isaaclab.envs.mdp as isaaclab_mdp

            velocity_cmd = isaaclab_mdp.generated_commands(
                self.env._env, command_name="base_velocity"
            )
            # Keep only the first three command channels.
            if velocity_cmd.shape[-1] > 3:
                velocity_cmd = velocity_cmd[..., :3]
            move_mask = (velocity_cmd.norm(dim=-1) > 0.1).to(
                dtype=velocity_cmd.dtype
            )
            transition_kwargs["velocity_commands"] = torch.cat(
                [move_mask[..., None], velocity_cmd],
                dim=-1,
            ).detach()
        return self.transition_cls(**transition_kwargs)

    def _post_iteration_hook(self, it: int) -> None:
        """For ``ref_motion`` runs, apply any pending motion-cache swap."""
        if self.command_name == "ref_motion":
            motion_cmd = self.env._env.command_manager.get_term("ref_motion")
            motion_cmd.apply_cache_swap_if_pending_barrier(
                accelerator=self.accelerator
            )

    def _post_training_hook(self) -> None:
        """For ``ref_motion`` runs, close the motion command term."""
        if self.command_name == "ref_motion":
            motion_cmd = self.env._env.command_manager.get_term("ref_motion")
            if motion_cmd is not None:
                motion_cmd.close()

    def _get_mean_policy_std(self) -> torch.Tensor:
        """Mean action std of the unwrapped actor.

        Uses ``std`` if present, else ``exp(log_std)``, else returns 0.
        """
        base_actor = self.accelerator.unwrap_model(self.actor)
        if hasattr(base_actor, "std"):
            return base_actor.std.mean()
        if hasattr(base_actor, "log_std"):
            return torch.exp(base_actor.log_std).mean()
        return torch.tensor(0.0, device=self.device)

    def _maybe_override_loaded_actor_sigma(self) -> None:
        """Re-apply ``sigma_override`` to the actor when ``override_sigma`` is set.

        A torch.compile ``_orig_mod`` wrapper is unwrapped first.

        Raises:
            ValueError: if ``override_sigma`` is set without ``sigma_override``.
            AttributeError: if the actor has no ``override_sigma`` method.
        """
        if not bool(self.config.get("override_sigma", False)):
            return
        sigma_override = self.config.get("sigma_override", None)
        if sigma_override is None:
            raise ValueError(
                "config.override_sigma is enabled but config.sigma_override is not set."
            )
        actor_unwrapped = self.accelerator.unwrap_model(self.actor)
        orig_mod = getattr(actor_unwrapped, "_orig_mod", None)
        if orig_mod is not None:
            actor_unwrapped = orig_mod
        override_sigma = getattr(actor_unwrapped, "override_sigma", None)
        if override_sigma is None:
            raise AttributeError(
                f"{type(actor_unwrapped).__name__} does not implement override_sigma()."
            )
        override_sigma(sigma_override)
        if self.is_main_process:
            logger.info(
                "Reapplied sigma override after checkpoint load: {}",
                sigma_override,
            )

    def _get_additional_log_metrics(self) -> dict[str, float]:
        """Build auxiliary training/cache metrics.

        Includes the current LRs, the effective entropy coefficient, the last
        update's diagnostics and the mean policy std. ``ref_motion`` runs add
        motion-cache statistics. The ``self.__dict__`` checks skip entries
        whose attributes have not been set yet.
        """
        iteration_metrics = {}
        if "actor_learning_rate" in self.__dict__:
            iteration_metrics["0-Train/actor_learning_rate"] = float(
                self.actor_learning_rate
            )
        if "critic_learning_rate" in self.__dict__:
            iteration_metrics["0-Train/critic_learning_rate"] = float(
                self.critic_learning_rate
            )
        if "initial_entropy_coef" in self.__dict__:
            iteration_metrics["0-Train/entropy_coef_effective"] = float(
                self._get_effective_entropy_coef()
            )
        if "_last_update_metrics" in self.__dict__:
            iteration_metrics.update(self._last_update_metrics)
        mean_std = self._get_mean_policy_std()
        iteration_metrics["0-Train/mean_noise_std"] = float(mean_std.item())
        if self.command_name != "ref_motion":
            return iteration_metrics
        motion_cmd = self.env._env.command_manager.get_term("ref_motion")
        cache = motion_cmd._motion_cache
        iteration_metrics["1-Perf/Cache/swap_index"] = float(cache.swap_index)
        pool_stats = cache.cache_curriculum_pool_statistics()
        if pool_stats is not None:
# Copy the known curriculum-pool statistics into namespaced cache metrics.
            core_cache_metric_keys = {
                "prioritized_pool_size": "1-Perf/Cache/prioritized_pool_size",
                "prioritized_pool_mean_score": "1-Perf/Cache/prioritized_pool_mean_score",
                "uniform_pool_mean_score": "1-Perf/Cache/uniform_pool_mean_score",
                "entered_prioritized_pool_count": "1-Perf/Cache/entered_prioritized_pool_count",
                "exited_prioritized_pool_count": "1-Perf/Cache/exited_prioritized_pool_count",
            }
            for src_key, dst_key in core_cache_metric_keys.items():
                # Keys missing from pool_stats are simply not logged.
                if src_key in pool_stats:
                    iteration_metrics[dst_key] = float(pool_stats[src_key])
        return iteration_metrics

    def update(self):
        """Run the PPO optimisation epochs over the filled rollout storage.

        Mini-batches come from ``self.storage.iter_minibatches``. When
        ``self.desired_kl`` is not None, an analytic KL between the stored
        and current policy (from ``mu``/``sigma``) is measured per
        mini-batch. A windowed KL signal may stop the update early, and the
        mini-batch that triggers the stop is not applied. Diagnostics are
        written to ``self._last_update_metrics`` and the storage is cleared.

        Returns:
            dict with mean ``value_function``, ``critic_explained_variance``,
            ``surrogate``, ``entropy``, ``kl_analytic`` (plus
            ``symmetry_loss`` when the symmetry loss is active), mean-reduced
            across processes when distributed.
        """
        mean_value_loss = 0.0
        mean_surrogate_loss = 0.0
        mean_entropy = 0.0
        mean_kl_analytic = 0.0
        mean_symmetry_loss = 0.0
        # Computed on the whole rollout before any parameter update.
        critic_explained_variance = self._compute_explained_variance(
            target=self.storage.data["returns"],
            prediction=self.storage.data["values"],
        )
        batch_size = int(
            self.storage.num_envs * self.storage.num_transitions_per_env
        )
        (
            effective_num_mini_batches,
            mini_batch_size,
        ) = RolloutStorage.resolve_mini_batch_partition(
            batch_size, self.num_mini_batches
        )
        # Defaults; the KL/clip/update-count entries are overwritten below.
        self._last_update_metrics = {
            "0-Train/configured_num_mini_batches": float(
                self.configured_num_mini_batches
            ),
            "0-Train/requested_num_mini_batches": float(
                self.requested_num_mini_batches
            ),
            "0-Train/effective_num_mini_batches": float(
                effective_num_mini_batches
            ),
            "0-Train/mini_batch_size_per_rank": float(mini_batch_size),
            "0-Train/num_updates_executed": 0.0,
            "0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor),
            "0-Train/scalable_distributed_update": float(
                self.distributed_update_mode == "scalable"
            ),
            "0-Train/kl_windowed": 0.0,
            "0-Train/kl_stop_triggered": 0.0,
            "0-Train/kl_stop_analytic": 0.0,
            "0-Train/kl_analytic_batch_last": 0.0,
            "0-Train/kl_analytic_batch_max": 0.0,
            "0-Train/clip_fraction_batch_mean": 0.0,
            "0-Train/clip_fraction_batch_last": 0.0,
        }
        entropy_coef = self._get_effective_entropy_coef()
        generator = self.storage.iter_minibatches(
            effective_num_mini_batches,
            self.num_learning_epochs,
        )
        measure_analytic_kl = self.desired_kl is not None
        num_updates = 0
        num_kl_measurements = 0
        kl_stop_triggered = False
        kl_stop_analytic = 0.0
        kl_windowed = None
        # Sliding window of the most recent per-mini-batch KL values.
        recent_analytic_kls: list[float] = []
        kl_analytic_batch_last = 0.0
        kl_analytic_batch_max = 0.0
        clip_fraction_batch_mean = 0.0
        clip_fraction_batch_last = 0.0
        for batch in generator:
            obs_batch = batch.obs
            actions_batch = batch.actions
            target_values_batch = batch.values
            advantages_batch = batch.advantages
            returns_batch = batch.returns
            old_actions_log_prob_batch = batch.actions_log_prob
            old_mu_batch = batch.mu
            old_sigma_batch = batch.sigma
            # NOTE(review): the extent of this autocast block was
            # reconstructed from newline-stripped source; confirm against
            # the repository which statements it should cover.
            with self.accelerator.autocast():
                # Re-evaluate log-probs of the stored actions under the
                # current policy.
                actor_out = self.actor(
                    obs_batch,
                    actions=actions_batch,
                    mode="logp",
                    update_obs_norm=False,
                )
                critic_out = self.critic(obs_batch, update_obs_norm=False)
                actions_log_prob_batch = actor_out.get("actions_log_prob")
                mu_batch = actor_out.get("mu")
                sigma_batch = actor_out.get("sigma")
                entropy_batch = actor_out.get("entropy")
                value_pred = critic_out.get("values")
                symmetry_loss = None
                if self._symmetry_loss_active():
                    # MSE between the current mean action and the
                    # mirrored-back action predicted on mirrored obs.
                    mirrored_obs_batch = self._mirror_actor_obs(obs_batch)
                    mirrored_actor_out = self.actor(
                        mirrored_obs_batch,
                        actions=None,
                        mode="inference",
                        update_obs_norm=False,
                    )
                    mirrored_actions = mirrored_actor_out.get("actions")
                    mirrored_actions_back = self._mirror_env_action(
                        mirrored_actions
                    )
                    symmetry_loss = F.mse_loss(
                        mu_batch.float(),
                        mirrored_actions_back.float(),
                    )
            # Plain aliases: no return/value normalisation is applied here.
            value_batch = value_pred
            returns_batch_norm = returns_batch
            target_values_batch_norm = target_values_batch
            analytic_kl = None
            if measure_analytic_kl:
                analytic_kl = self._compute_analytic_kl(
                    old_mu=old_mu_batch.float(),
                    old_sigma=old_sigma_batch.float(),
                    new_mu=mu_batch.float(),
                    new_sigma=sigma_batch.float(),
                )
                # Note: the KL averages include the mini-batch that
                # triggers an early stop, although that batch is not applied.
                mean_kl_analytic += analytic_kl
                num_kl_measurements += 1
                kl_analytic_batch_last = analytic_kl
                kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl)
                recent_analytic_kls.append(analytic_kl)
                if len(recent_analytic_kls) > self.kl_early_stop_window_size:
                    recent_analytic_kls.pop(0)
                kl_windowed = self._compute_windowed_kl_signal(
                    recent_analytic_kls
                )
                if self._should_early_stop_for_kl(
                    kl_windowed, num_kl_measurements
                ):
                    kl_stop_triggered = True
                    kl_stop_analytic = analytic_kl
                    # Stop before this mini-batch's gradient step.
                    break
            # Importance ratio new/old; the stored log-prob has a trailing
            # singleton dim, squeezed away here.
            ratio = torch.exp(
                actions_log_prob_batch
                - torch.squeeze(old_actions_log_prob_batch).float()
            )
            clip_fraction = self._compute_clip_fraction(ratio)
            clip_fraction_batch_mean += clip_fraction
            clip_fraction_batch_last = clip_fraction
            # Clipped surrogate objective (negated for minimisation).
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
                ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
            )
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
            if self.use_clipped_value_loss:
                # Value clipping reuses clip_param around the rollout values.
                value_clipped = target_values_batch_norm + (
                    value_batch - target_values_batch_norm
                ).clamp(-self.clip_param, self.clip_param)
                value_losses = (value_batch - returns_batch_norm).pow(2)
                value_losses_clipped = (
                    value_clipped - returns_batch_norm
                ).pow(2)
                value_loss = torch.max(
                    value_losses, value_losses_clipped
                ).mean()
            else:
                value_loss = (returns_batch_norm - value_batch).pow(2).mean()
            actor_loss = surrogate_loss
            critic_loss = self.value_loss_coef * value_loss
            if entropy_coef > 0.0:
                entropy_loss = entropy_batch.mean()
                actor_loss = actor_loss - entropy_coef * entropy_loss
            if symmetry_loss is not None:
                actor_loss = (
                    actor_loss + self.symmetry_loss_coef * symmetry_loss
                )
            # Separate optimisers/backward passes for actor and critic.
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            self.accelerator.backward(actor_loss)
            self.accelerator.backward(critic_loss)
            if self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(
                    self.actor.parameters(),
                    self.max_grad_norm,
                )
                self.accelerator.clip_grad_norm_(
                    self.critic.parameters(),
                    self.max_grad_norm,
                )
            self.actor_optimizer.step()
            self.critic_optimizer.step()
            num_updates += 1
            mean_value_loss += float(value_loss.item())
            mean_surrogate_loss += float(surrogate_loss.item())
            mean_entropy += float(entropy_batch.mean().item())
            if symmetry_loss is not None:
                mean_symmetry_loss += float(symmetry_loss.item())
        # max(1, ...) avoids dividing by zero when no update ran (e.g. a KL
        # stop on the very first mini-batch).
        denom = max(1, num_updates)
        mean_value_loss /= denom
        mean_surrogate_loss /= denom
        mean_entropy /= denom
        mean_kl_analytic /= max(1, num_kl_measurements)
        mean_symmetry_loss /= denom
        clip_fraction_batch_mean /= denom
        if self.schedule == "adaptive":
            self._apply_adaptive_lr(kl_windowed)
        self._last_update_metrics["0-Train/num_updates_executed"] = float(
            num_updates
        )
        self._last_update_metrics["0-Train/kl_windowed"] = float(
            kl_windowed or 0.0
        )
        self._last_update_metrics["0-Train/kl_stop_triggered"] = float(
            kl_stop_triggered
        )
        self._last_update_metrics["0-Train/kl_stop_analytic"] = float(
            kl_stop_analytic
        )
        self._last_update_metrics["0-Train/kl_analytic_batch_last"] = float(
            kl_analytic_batch_last
        )
        self._last_update_metrics["0-Train/kl_analytic_batch_max"] = float(
            kl_analytic_batch_max
        )
        self._last_update_metrics["0-Train/clip_fraction_batch_mean"] = float(
            clip_fraction_batch_mean
        )
        self._last_update_metrics["0-Train/clip_fraction_batch_last"] = float(
            clip_fraction_batch_last
        )
        self.storage.clear()
        loss_out = {
            "value_function": mean_value_loss,
            "critic_explained_variance": critic_explained_variance,
            "surrogate": mean_surrogate_loss,
            "entropy": mean_entropy,
            "kl_analytic": mean_kl_analytic,
        }
        if self._symmetry_loss_active():
            loss_out["symmetry_loss"] = mean_symmetry_loss
        # Reduce losses across processes for consistent logging on rank 0
        if self.is_distributed:
            reduced_out = {}
            for k, v in loss_out.items():
                if v is None:
                    reduced_out[k] = None
                    continue
                t = torch.tensor(v, device=self.device, dtype=torch.float32)
                reduced_t = self.accelerator.reduce(t, reduction="mean")
                reduced_out[k] = float(reduced_t.item())
            loss_out = reduced_out
        self._post_update_hook(loss_out)
        return loss_out

    def load(self, ckpt_path):
        """Restore actor/critic weights and training state from ``ckpt_path``.

        Returns the checkpoint's ``infos`` entry (``None`` if absent), or
        ``None`` immediately when ``ckpt_path`` is ``None``.
        """
        if ckpt_path is None:
            return None
        if self.is_main_process:
            logger.info(f"Loading checkpoint from {ckpt_path}")
        actor_model_path = self._resolve_model_file_path(ckpt_path, "actor")
        critic_model_path = self._resolve_model_file_path(
            ckpt_path,
"critic") self._load_accelerate_model(self.actor, actor_model_path, strict=True) self._load_accelerate_model( self.critic, critic_model_path, strict=True ) loaded_dict = torch.load(ckpt_path, map_location=self.device) if not getattr(self, "is_offline_eval", False): self._restore_optimizer_state( self.actor_optimizer, loaded_dict["actor_optimizer_state_dict"], optimizer_name="actor", ) self._restore_optimizer_state( self.critic_optimizer, loaded_dict["critic_optimizer_state_dict"], optimizer_name="critic", ) elif self.is_main_process: logger.info( "Skipping optimizer state restore during offline evaluation." ) self.current_learning_iteration = loaded_dict.get("iter", 0) self._maybe_override_loaded_actor_sigma() self._load_extra_checkpoint_state(loaded_dict) return loaded_dict.get("infos", None) def _restore_optimizer_state( self, optimizer, loaded_state_dict, *, optimizer_name: str, ) -> bool: compatible, reason = self._optimizer_state_is_compatible( optimizer, loaded_state_dict ) if not compatible: if self.is_main_process: logger.warning( "Skipping {} optimizer state restore from checkpoint: {}", optimizer_name, reason, ) return False try: optimizer.load_state_dict(loaded_state_dict) except ValueError as exc: if self.is_main_process: logger.warning( "Skipping {} optimizer state restore from checkpoint: {}", optimizer_name, exc, ) return False return True def _optimizer_state_is_compatible( self, optimizer, loaded_state_dict ) -> tuple[bool, str | None]: current_state_dict = optimizer.state_dict() current_groups = current_state_dict.get("param_groups") loaded_groups = loaded_state_dict.get("param_groups") if not isinstance(current_groups, list) or not isinstance( loaded_groups, list ): return True, None if len(current_groups) != len(loaded_groups): return ( False, "param group count mismatch " f"(current={len(current_groups)}, loaded={len(loaded_groups)})", ) for group_idx, (current_group, loaded_group) in enumerate( zip(current_groups, loaded_groups) ): 
current_param_count = len(current_group.get("params", [])) loaded_param_count = len(loaded_group.get("params", [])) if current_param_count != loaded_param_count: return ( False, "param group size mismatch for group " f"{group_idx} (current={current_param_count}, " f"loaded={loaded_param_count})", ) return True, None def save(self, path, infos=None): if not self.is_main_process: return logger.info(f"Saving checkpoint to {path}") base_path = path.replace(".pt", "") os.makedirs( os.path.dirname(base_path) if os.path.dirname(base_path) else ".", exist_ok=True, ) self.accelerator.save_model( self.actor, os.path.join(base_path, "actor") ) self.accelerator.save_model( self.critic, os.path.join(base_path, "critic") ) custom_state = { "actor_optimizer_state_dict": self.actor_optimizer.state_dict(), "critic_optimizer_state_dict": self.critic_optimizer.state_dict(), "iter": self.current_learning_iteration, "infos": infos, } custom_state.update(self._extra_checkpoint_state()) torch.save(_checkpoint_state_to_cpu(custom_state), path) if bool(self.config.get("export_policy", False)): export_policy_to_onnx_common( self, path, onnx_name_suffix=self.config.get("onnx_name_suffix", None), use_kv_cache=bool(self.config.get("use_kv_cache", True)), ) def offline_evaluate_policy(self, dump_npzs: bool = False): """Dump NPZs (no metrics) from validation cache using ref_motion command. - Iterates validation batches; env i -> clip i (deterministic) starting at frame 0. - Collect robot and reference sequences each step and save one NPZ per clip. - NPZ conforms to holomotion_retargeted format keys. - Optionally records viewport MP4(s) aligned with target_fps and rollout length. 
""" ckpt_path = self.config.checkpoint n_fut_frames = self.env.config.commands.ref_motion.params.get( "n_fut_frames", 8 ) # log_dir is already set to checkpoint directory in eval script model_name = os.path.basename(ckpt_path).replace(".pt", "") # Eval modes (freeze normalizers if enabled) self.actor.eval() self.critic.eval() # Require ref_motion command and simple cache backend command_name = list(self.env.config.commands.keys())[0] if command_name != "ref_motion": logger.warning( "Offline evaluation only supported for ref_motion command" ) return {} motion_cmd = self.env._env.command_manager.get_term("ref_motion") cache = getattr(motion_cmd, "_motion_cache", None) if cache is None: logger.error( "Offline evaluation requires hdf5_simple cache backend (no LMDB support)" ) return {} self._offline_evaluating = True # Evaluation flag and cache batch-size adjustment (ensure batch_size == num_envs) motion_cmd._is_evaluating = True num_envs = self.env.num_envs try: if getattr(cache, "_batch_size", None) != num_envs: from holomotion.src.training.h5_dataloader import ( MotionClipBatchCache, ) cache = MotionClipBatchCache( train_dataset=cache._datasets["train"], val_dataset=cache._datasets["val"], batch_size=num_envs, stage_device=getattr(cache, "_stage_device", None), num_workers=getattr(cache, "_num_workers", 0), prefetch_factor=getattr(cache, "_prefetch_factor", None), pin_memory=getattr(cache, "_pin_memory", True), persistent_workers=getattr( cache, "_persistent_workers", False ), sampler_rank=getattr(cache, "_sampler_rank", 0), sampler_world_size=getattr( cache, "_sampler_world_size", 1 ), allowed_prefixes=getattr(cache, "_allowed_prefixes", None), swap_interval_steps=getattr( cache, "swap_interval_steps", None ), force_timeout_on_swap=getattr( cache, "force_timeout_on_swap", True ), seed=getattr(cache, "_seed", None), loader_timeout=getattr(cache, "_loader_timeout", 0.0), ) motion_cmd._motion_cache = cache except Exception as e: logger.warning( f"Offline eval: failed 
to rebuild cache to batch_size={num_envs}: {e}" ) # Derive HDF5 dataset base name (from validation dataset root) for output naming dataset_suffix = None val_dataset = cache._datasets["val"] dataset_root = None if hasattr(val_dataset, "hdf5_root"): dataset_root = str(val_dataset.hdf5_root).rstrip(os.sep) elif hasattr(val_dataset, "ts_roots"): ts_roots = getattr(val_dataset, "ts_roots") if ts_roots: dataset_root = str(ts_roots[0]).rstrip(os.sep) if dataset_root: dataset_suffix = os.path.basename(dataset_root) # Output directory (respect existing log_dir derived from checkpoint) suffix = f"isaaclab_eval_output_{model_name}" if dataset_suffix is not None: suffix = f"{suffix}_{dataset_suffix}" output_dir = os.path.join(self.log_dir, suffix) os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving evaluation outputs to: {output_dir}") # Switch to validation cache and iterate all batches if hasattr(cache, "set_mode"): cache.set_mode("val") # Determine policy/video FPS from command config (align wallclock time) motion_fps = int(getattr(motion_cmd.cfg, "target_fps", 50)) total_batches = int(getattr(cache, "num_batches", 1)) with torch.no_grad(): for batch_idx in tqdm( range(total_batches), desc="Evaluating batches" ): if batch_idx > 0: cache.advance() # Reset envs first, then apply deterministic mapping on the active cache batch _ = self.env.reset_all() if hasattr(motion_cmd, "setup_offline_eval_deterministic"): motion_cmd.setup_offline_eval_deterministic( apply_pending_swap=False ) self._reset_rollout_forward_state() # Read current batch metadata AFTER reset + setup current = getattr(cache, "current_batch", None) if current is None or not hasattr(current, "motion_keys"): logger.warning( "Current cache batch missing motion_keys; skipping batch" ) continue motion_keys = list(current.motion_keys) raw_motion_keys = list( getattr(current, "raw_motion_keys", current.motion_keys) ) # Determine active env count for this batch clip_count = int(cache.clip_count) active_count = 
min(num_envs, clip_count) if active_count > 0: active_ids = torch.arange( active_count, dtype=torch.long, device=self.device, ) motion_cmd.force_realign_offline_eval_no_perturb( active_ids ) # Recompute observations after deterministic setup obs_mgr = self.env._env.observation_manager if active_count > 0: obs_mgr.reset(active_ids) obs_dict = obs_mgr.compute(update_history=True) else: obs_dict = obs_mgr.compute(update_history=True) obs = self._wrap_obs_dict(obs_dict) # Map env -> motion_key for active envs env_motion_keys = { int(i): motion_keys[int(i)] for i in range(active_count) } env_raw_motion_keys = { int(i): raw_motion_keys[int(i)] for i in range(active_count) } # Prepare per-env collectors env_has_done = torch.zeros( num_envs, dtype=torch.bool, device=self.device ) episode_lengths = torch.zeros( num_envs, dtype=torch.long, device=self.device ) active_mask = torch.zeros( num_envs, dtype=torch.bool, device=self.device ) if active_count > 0: active_mask[:active_count] = True # Reference collectors (URDF order) ref_dof_pos = [[] for _ in range(active_count)] ref_dof_vel = [[] for _ in range(active_count)] ref_body_pos = [[] for _ in range(active_count)] ref_body_rot_xyzw = [[] for _ in range(active_count)] ref_body_vel = [[] for _ in range(active_count)] ref_body_ang_vel = [[] for _ in range(active_count)] # Robot collectors (URDF order) robot_dof_pos = [[] for _ in range(active_count)] robot_dof_vel = [[] for _ in range(active_count)] robot_body_pos = [[] for _ in range(active_count)] robot_body_rot_xyzw = [[] for _ in range(active_count)] robot_body_vel = [[] for _ in range(active_count)] robot_body_ang_vel = [[] for _ in range(active_count)] robot_dof_acc = [[] for _ in range(active_count)] robot_dof_torque = [[] for _ in range(active_count)] robot_action_rate = [[] for _ in range(active_count)] prev_robot_dof_vel = [None for _ in range(active_count)] prev_robot_actions = [None for _ in range(active_count)] step_dt = float(self.env._env.step_dt) # Per-env 
bookkeeping clip_lengths_np = ( current.lengths.detach().cpu().numpy() if hasattr(current, "lengths") else np.array( [getattr(cache, "max_frame_length", 1000)] * active_count ) ) # Persist an explicit mapping file for verification try: mapping_records = [] for i in range(active_count): mapping_records.append( { "env_id": int(i), "motion_key": env_motion_keys[int(i)], "raw_motion_key": env_raw_motion_keys[int(i)], "clip_length": int(clip_lengths_np[int(i)]), } ) mapping_path = os.path.join( output_dir, f"batch_{batch_idx:04d}_mapping.json" ) with open(mapping_path, "w") as f: json.dump(mapping_records, f, indent=2) except Exception: pass env_frame_counts = [0 for _ in range(active_count)] encountered_done = [False for _ in range(active_count)] valid_masks = [[] for _ in range(active_count)] def _sanitize_key(key: str) -> str: return ( key.replace("/", "+") .replace(" ", "_") .replace("\\", "+") ) def _get_out_path(idx: int) -> str: out_name = f"{_sanitize_key(env_motion_keys[idx])}.npz" return os.path.join(output_dir, out_name) def _save_env_npz(idx: int): if idx >= active_count: return # Total collected frames total_len = int(min(env_frame_counts[idx], max_steps)) if total_len <= 0: return # Compute contiguous valid prefix length and slice_len vm = valid_masks[idx][:total_len] valid_prefix_len = 0 for b in vm: if b: valid_prefix_len += 1 else: break clip_len = int(clip_lengths_np[idx]) slice_len = int(min(valid_prefix_len, clip_len, total_len)) if slice_len <= 0: return # Reference arrays (sliced) ref_dof_pos_arr = np.stack( ref_dof_pos[idx][:slice_len], axis=0 ).astype(np.float32) ref_dof_vel_arr = np.stack( ref_dof_vel[idx][:slice_len], axis=0 ).astype(np.float32) ref_body_pos_arr = np.stack( ref_body_pos[idx][:slice_len], axis=0 ).astype(np.float32) ref_body_rot_xyzw_arr = np.stack( ref_body_rot_xyzw[idx][:slice_len], axis=0 ).astype(np.float32) ref_body_vel_arr = np.stack( ref_body_vel[idx][:slice_len], axis=0 ).astype(np.float32) ref_body_ang_vel_arr = 
np.stack( ref_body_ang_vel[idx][:slice_len], axis=0 ).astype(np.float32) # Robot arrays (sliced) robot_dof_pos_arr = np.stack( robot_dof_pos[idx][:slice_len], axis=0 ).astype(np.float32) robot_dof_vel_arr = np.stack( robot_dof_vel[idx][:slice_len], axis=0 ).astype(np.float32) robot_dof_acc_arr = np.stack( robot_dof_acc[idx][:slice_len], axis=0 ).astype(np.float32) robot_dof_torque_arr = np.stack( robot_dof_torque[idx][:slice_len], axis=0 ).astype(np.float32) robot_action_rate_arr = np.asarray( robot_action_rate[idx][:slice_len], dtype=np.float32 ) robot_body_pos_arr = np.stack( robot_body_pos[idx][:slice_len], axis=0 ).astype(np.float32) robot_body_rot_xyzw_arr = np.stack( robot_body_rot_xyzw[idx][:slice_len], axis=0 ).astype(np.float32) robot_body_vel_arr = np.stack( robot_body_vel[idx][:slice_len], axis=0 ).astype(np.float32) robot_body_ang_vel_arr = np.stack( robot_body_ang_vel[idx][:slice_len], axis=0 ).astype(np.float32) # Metadata motion_fps = int(getattr(motion_cmd.cfg, "target_fps", 50)) num_dofs = int(ref_dof_pos_arr.shape[1]) num_bodies = int(ref_body_pos_arr.shape[1]) wallclock_len = ( float(slice_len - 1) / float(motion_fps) if motion_fps > 0 and slice_len > 0 else 0.0 ) meta = { "motion_key": env_motion_keys[idx], "raw_motion_key": env_raw_motion_keys[idx], "motion_fps": float(motion_fps), "num_frames": int(slice_len), "wallclock_len": float(wallclock_len), "num_dofs": int(num_dofs), "num_bodies": int(num_bodies), "clip_length": int(clip_lengths_np[idx]), "valid_prefix_len": int(valid_prefix_len), } # Output filename: flattened motion_key out_path = _get_out_path(idx) np.savez_compressed( out_path, metadata=json.dumps(meta), robot_dof_pos=robot_dof_pos_arr, robot_dof_vel=robot_dof_vel_arr, robot_dof_acc=robot_dof_acc_arr, robot_dof_torque=robot_dof_torque_arr, robot_action_rate=robot_action_rate_arr, robot_global_translation=robot_body_pos_arr, robot_global_rotation_quat=robot_body_rot_xyzw_arr, robot_global_velocity=robot_body_vel_arr, 
robot_global_angular_velocity=robot_body_ang_vel_arr, ref_dof_pos=ref_dof_pos_arr, ref_dof_vel=ref_dof_vel_arr, ref_global_translation=ref_body_pos_arr, ref_global_rotation_quat=ref_body_rot_xyzw_arr, ref_global_velocity=ref_body_vel_arr, ref_global_angular_velocity=ref_body_ang_vel_arr, ) max_steps = int( getattr(cache, "max_frame_length", 1000) ) # decide the max_length to evaluate for rollout_step in tqdm( range(max_steps), desc="Rollout steps" ): # PRE-STEP: collect states for all active envs active = [i for i in range(active_count)] if len(active) > 0: # Reference step tensors (URDF order) ref_dp = ( motion_cmd.get_ref_motion_dof_pos_cur_urdf_order() .detach() .cpu() .numpy() ) ref_dv = ( motion_cmd.get_ref_motion_dof_vel_cur_urdf_order() .detach() .cpu() .numpy() ) ref_bp = ( motion_cmd.get_ref_motion_bodylink_global_pos_cur_urdf_order() .detach() .cpu() .numpy() ) ref_br = ( motion_cmd.get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order() .detach() .cpu() .numpy() ) ref_bv = ( motion_cmd.get_ref_motion_bodylink_global_lin_vel_cur_urdf_order() .detach() .cpu() .numpy() ) ref_bav = ( motion_cmd.get_ref_motion_bodylink_global_ang_vel_cur_urdf_order() .detach() .cpu() .numpy() ) # Robot step tensors (URDF order) rob_dp = ( motion_cmd.robot_dof_pos_cur_urdf_order.detach() .cpu() .numpy() ) rob_dv = ( motion_cmd.robot_dof_vel_cur_urdf_order.detach() .cpu() .numpy() ) rob_bp = ( motion_cmd.robot_bodylink_global_pos_cur_urdf_order.detach() .cpu() .numpy() ) rob_br = ( motion_cmd.robot_bodylink_global_rot_xyzw_cur_urdf_order.detach() .cpu() .numpy() ) rob_bv = ( motion_cmd.robot_bodylink_global_lin_vel_cur_urdf_order.detach() .cpu() .numpy() ) rob_bav = ( motion_cmd.robot_bodylink_global_ang_vel_cur_urdf_order.detach() .cpu() .numpy() ) for idx in active: if prev_robot_dof_vel[idx] is None: dof_acc_cur = np.zeros_like( rob_dv[idx], dtype=np.float32 ) else: dof_acc_cur = ( rob_dv[idx] - prev_robot_dof_vel[idx] ) / step_dt prev_robot_dof_vel[idx] = 
rob_dv[idx].copy() ref_dof_pos[idx].append(ref_dp[idx]) ref_dof_vel[idx].append(ref_dv[idx]) ref_body_pos[idx].append(ref_bp[idx]) ref_body_rot_xyzw[idx].append(ref_br[idx]) ref_body_vel[idx].append(ref_bv[idx]) ref_body_ang_vel[idx].append(ref_bav[idx]) robot_dof_pos[idx].append(rob_dp[idx]) robot_dof_vel[idx].append(rob_dv[idx]) robot_dof_acc[idx].append( dof_acc_cur.astype(np.float32) ) robot_body_pos[idx].append(rob_bp[idx]) robot_body_rot_xyzw[idx].append(rob_br[idx]) robot_body_vel[idx].append(rob_bv[idx]) robot_body_ang_vel[idx].append(rob_bav[idx]) # Record valid mask for current frame (before step) clip_limit = int(clip_lengths_np[idx]) valid_now = ( (idx < active_count) and (not encountered_done[idx]) and ( env_frame_counts[idx] < clip_limit - n_fut_frames ) ) valid_masks[idx].append(bool(valid_now)) # Increment local frame counter env_frame_counts[idx] += 1 # No mid-rollout finalize; we defer to end using valid masks # Inference and step (advance sim) obs = self._rollout_forward( obs, actor_mode="inference", collect_transition=False, track_episode_stats=False, ) dones = self._last_rollout_dones if dones is None: raise RuntimeError( "Rollout forward did not return dones during offline evaluation." ) actions_step = self._last_rollout_actions if actions_step is None: raise RuntimeError( "Rollout forward did not return actions during offline evaluation." 
) actions_np = actions_step.detach().cpu().numpy() torque_urdf = ( motion_cmd.robot.data.applied_torque[ ..., motion_cmd.sim2urdf_dof_idx ] .detach() .cpu() .numpy() ) for idx in range(active_count): if prev_robot_actions[idx] is None: action_rate_cur = 0.0 else: action_rate_cur = float( np.linalg.norm( actions_np[idx] - prev_robot_actions[idx] ) / step_dt ) prev_robot_actions[idx] = actions_np[idx].copy() robot_action_rate[idx].append( np.float32(action_rate_cur) ) robot_dof_torque[idx].append( torque_urdf[idx].astype(np.float32) ) # Handle RL dones (first-done policy): mark done for future frames step_dones = ( dones.bool().reshape(-1).detach().cpu().numpy() ) for idx in range(min(active_count, len(step_dones))): if step_dones[idx] and not encountered_done[idx]: encountered_done[idx] = True if rollout_step == max_steps - 1: # End of rollout: save once per env with full rollout arrays + valid_mask if dump_npzs and active_count > 0: out_path_to_last_idx = {} for idx in range(active_count): out_path_to_last_idx[_get_out_path(idx)] = idx save_indices = list(out_path_to_last_idx.values()) max_npz_save_workers = max( 1, min(16, len(save_indices)) ) with ThreadPoolExecutor( max_workers=max_npz_save_workers ) as executor: futures = [ executor.submit(_save_env_npz, idx) for idx in save_indices ] for future in tqdm( as_completed(futures), total=len(futures), desc="Saving NPZs", ): future.result() break logger.info( f"Offline evaluation complete: saved clips to {output_dir}" ) return {"output_dir": output_dir} ================================================ FILE: holomotion/src/algo/ppo_tf.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Generator

import torch
import torch.nn.functional as F
import torch.optim as optim
from holomotion.src.algo.algo_utils import PpoAuxTransition
from holomotion.src.algo.ppo import PPO
from holomotion.src.modules.agent_modules import (
    PPOCondTFActor,
    PPOCritic,
    PPOTFActor,
    PPOTFRefRouterActor,
    PPOTFRefRouterSeqActor,
    PPOTFRefRouterV3Actor,
    TensorDictAssembler,
)
from holomotion.src.modules.network_modules import GroupedMoEBlock
from loguru import logger
from omegaconf import OmegaConf
from tabulate import tabulate
from tensordict import TensorDict


class PPOTF(PPO):
    """Transformer-policy PPO with TensorDict rollout and sequence update."""

    @staticmethod
    def _select_actor_wrapper_cls(actor_cfg: dict):
        """Choose the actor wrapper class from the actor config.

        The three ``ReferenceRoutedGroupedMoETransformerPolicy`` variants
        (base, ``V2``, ``V3``) map to ``PPOTFRefRouterActor``,
        ``PPOTFRefRouterSeqActor`` and ``PPOTFRefRouterV3Actor``. Any other
        type gets ``PPOCondTFActor`` when ``use_future_cross_attn`` is set,
        else ``PPOTFActor``.

        Raises:
            ValueError: a reference-routed type is combined with
                ``use_future_cross_attn=True``.
        """
        actor_type = str(actor_cfg.get("type", ""))
        use_future_cross_attn = bool(
            actor_cfg.get("use_future_cross_attn", False)
        )
        if actor_type == "ReferenceRoutedGroupedMoETransformerPolicy":
            if use_future_cross_attn:
                raise ValueError(
                    "ReferenceRoutedGroupedMoETransformerPolicy does not "
                    "support use_future_cross_attn=True."
                )
            return PPOTFRefRouterActor
        if actor_type == "ReferenceRoutedGroupedMoETransformerPolicyV2":
            if use_future_cross_attn:
                raise ValueError(
                    "ReferenceRoutedGroupedMoETransformerPolicyV2 does not "
                    "support use_future_cross_attn=True."
                )
            return PPOTFRefRouterSeqActor
        if actor_type == "ReferenceRoutedGroupedMoETransformerPolicyV3":
            if use_future_cross_attn:
                raise ValueError(
                    "ReferenceRoutedGroupedMoETransformerPolicyV3 does not "
                    "support use_future_cross_attn=True."
                )
            return PPOTFRefRouterV3Actor
        if use_future_cross_attn:
            return PPOCondTFActor
        return PPOTFActor

    @staticmethod
    def _summarize_moe_layer_stats(moe_layers) -> dict[str, float | None]:
        """Average MoE routing statistics over ``moe_layers``.

        Each layer must expose the ``last_*`` attributes read below as
        tensors (each is cast to float32, stacked across layers and
        averaged). Returns a dict of floats, or all-``None`` values when
        ``moe_layers`` is empty. Note that ``moe_least_expert_frac`` is
        sourced from ``last_min_expert_frac``.
        """
        if len(moe_layers) == 0:
            return {
                "moe_active_expert_ratio": None,
                "moe_max_expert_frac": None,
                "moe_least_expert_frac": None,
                "moe_dead_expert_ratio": None,
                "moe_expert_count_cv": None,
                "moe_selected_expert_margin_to_unselected": None,
            }

        def _mean_attr(attr_name: str) -> float:
            # Mean over layers of one per-layer statistic.
            values = torch.stack(
                [
                    getattr(layer, attr_name).to(torch.float32)
                    for layer in moe_layers
                ]
            )
            return float(values.mean().item())

        return {
            "moe_active_expert_ratio": _mean_attr("last_active_expert_ratio"),
            "moe_max_expert_frac": _mean_attr("last_max_expert_frac"),
            "moe_least_expert_frac": _mean_attr("last_min_expert_frac"),
            "moe_dead_expert_ratio": _mean_attr("last_dead_expert_ratio"),
            "moe_expert_count_cv": _mean_attr("last_expert_count_cv"),
            "moe_selected_expert_margin_to_unselected": _mean_attr(
                "last_selected_expert_margin_to_unselected"
            ),
        }

    def _setup_configs(self):
        """Extend the base PPO config setup with ``aux_state_pred`` options.

        Reads loss weights and options from ``config.aux_state_pred``,
        derives ``use_aux_*`` flags (auxiliary prediction enabled and the
        corresponding weight > 0), and validates the values, raising
        ``ValueError`` on invalid settings.
        """
        super()._setup_configs()
        aux_cfg = self.config.get("aux_state_pred", {})
        self.use_aux_state_pred: bool = bool(aux_cfg.get("enabled", False))
        # Per-target auxiliary loss weights (0.0 disables a target).
        self.aux_state_pred_w_base_lin_vel = float(
            aux_cfg.get("w_base_lin_vel", 0.0)
        )
        self.aux_state_pred_w_root_height = float(
            aux_cfg.get("w_root_height", 0.0)
        )
        self.aux_state_pred_w_keybody_contact = float(
            aux_cfg.get("w_keybody_contact", 0.0)
        )
        self.aux_state_pred_w_ref_keybody_rel_pos = float(
            aux_cfg.get("w_ref_keybody_rel_pos", 0.0)
        )
        self.aux_state_pred_w_robot_keybody_rel_pos = float(
            aux_cfg.get("w_robot_keybody_rel_pos", 0.0)
        )
        self.aux_state_pred_w_denoise_ref_root_lin_vel = float(
            aux_cfg.get("w_denoise_ref_root_lin_vel", 0.0)
        )
        self.aux_state_pred_w_denoise_ref_root_ang_vel = float(
            aux_cfg.get("w_denoise_ref_root_ang_vel", 0.0)
        )
        self.aux_state_pred_w_denoise_ref_dof_pos = float(
            aux_cfg.get("w_denoise_ref_dof_pos", 0.0)
        )
        self.aux_state_pred_keybody_contact_names = [
            str(name) for name in aux_cfg.get("keybody_contact_names", [])
        ]
        self.aux_state_pred_keybody_rel_pos_names = [
            str(name) for name in aux_cfg.get("keybody_rel_pos_names", [])
        ]
        self.aux_state_pred_num_contact_bodies = int(
            len(self.aux_state_pred_keybody_contact_names)
        )
        self.aux_state_pred_num_keybody_bodies = int(
            len(self.aux_state_pred_keybody_rel_pos_names)
        )
        # A target is active only when aux prediction is enabled and its
        # weight is strictly positive.
        self.use_aux_root_height = bool(
            self.use_aux_state_pred and self.aux_state_pred_w_root_height > 0.0
        )
        self.use_aux_denoise_ref_root_lin_vel = bool(
            self.use_aux_state_pred
            and self.aux_state_pred_w_denoise_ref_root_lin_vel > 0.0
        )
        self.use_aux_denoise_ref_root_ang_vel = bool(
            self.use_aux_state_pred
            and self.aux_state_pred_w_denoise_ref_root_ang_vel > 0.0
        )
        self.use_aux_denoise_ref_dof_pos = bool(
            self.use_aux_state_pred
            and self.aux_state_pred_w_denoise_ref_dof_pos > 0.0
        )
        self.aux_state_pred_min_std = float(aux_cfg.get("min_std", 1.0e-3))
        self.aux_state_pred_max_std = float(aux_cfg.get("max_std", 5.0))
        self.aux_denoise_residual_huber_beta = float(
            aux_cfg.get("denoise_residual_huber_beta", 0.1)
        )
        self.aux_state_pred_raycast_z_offset = float(
            aux_cfg.get("raycast_z_offset", 1.0)
        )
        self.aux_state_pred_raycast_max_dist = float(
            aux_cfg.get("raycast_max_dist", 20.0)
        )
        # --- validation ---
        if self.aux_state_pred_min_std <= 0.0:
            raise ValueError("aux_state_pred.min_std must be > 0.")
        if self.aux_state_pred_max_std <= self.aux_state_pred_min_std:
            raise ValueError(
                "aux_state_pred.max_std must be > aux_state_pred.min_std."
            )
        if self.aux_denoise_residual_huber_beta <= 0.0:
            raise ValueError(
                "aux_state_pred.denoise_residual_huber_beta must be > 0."
            )
        if self.aux_state_pred_w_base_lin_vel < 0.0:
            raise ValueError("aux_state_pred.w_base_lin_vel must be >= 0.")
        if self.aux_state_pred_w_root_height < 0.0:
            raise ValueError("aux_state_pred.w_root_height must be >= 0.")
        if self.aux_state_pred_w_keybody_contact < 0.0:
            raise ValueError("aux_state_pred.w_keybody_contact must be >= 0.")
        if self.aux_state_pred_w_ref_keybody_rel_pos < 0.0:
            raise ValueError(
                "aux_state_pred.w_ref_keybody_rel_pos must be >= 0."
            )
        if self.aux_state_pred_w_robot_keybody_rel_pos < 0.0:
            raise ValueError(
                "aux_state_pred.w_robot_keybody_rel_pos must be >= 0."
            )
        if self.aux_state_pred_w_denoise_ref_root_lin_vel < 0.0:
            raise ValueError(
                "aux_state_pred.w_denoise_ref_root_lin_vel must be >= 0."
            )
        if self.aux_state_pred_w_denoise_ref_root_ang_vel < 0.0:
            raise ValueError(
                "aux_state_pred.w_denoise_ref_root_ang_vel must be >= 0."
            )
        if self.aux_state_pred_w_denoise_ref_dof_pos < 0.0:
            raise ValueError(
                "aux_state_pred.w_denoise_ref_dof_pos must be >= 0."
            )
        # Raycast parameters are only checked when root-height prediction
        # is active.
        if self.use_aux_root_height:
            if self.aux_state_pred_raycast_max_dist <= 0.0:
                raise ValueError(
                    "aux_state_pred.raycast_max_dist must be > 0."
                )
            if self.aux_state_pred_raycast_z_offset < 0.0:
                raise ValueError(
                    "aux_state_pred.raycast_z_offset must be >= 0."
                )
        if (
            self.aux_state_pred_w_keybody_contact > 0.0
            and self.aux_state_pred_num_contact_bodies == 0
        ):
            raise ValueError(
                "aux_state_pred.w_keybody_contact > 0 requires "
                "aux_state_pred.keybody_contact_names to be non-empty."
            )
        if (
            self.aux_state_pred_w_ref_keybody_rel_pos > 0.0
            or self.aux_state_pred_w_robot_keybody_rel_pos > 0.0
        ) and self.aux_state_pred_num_keybody_bodies == 0:
            raise ValueError(
                "aux_state_pred keybody position weights > 0 require "
                "aux_state_pred.keybody_rel_pos_names to be non-empty."
            )
        if self.use_aux_state_pred and self.command_name != "ref_motion":
            raise ValueError(
                "aux_state_pred is only supported for PPOTF motion tracking "
                "(command_name='ref_motion')."
) PpoAuxTransition.SHAPE_TOKENS["C"] = ( self.aux_state_pred_num_contact_bodies ) PpoAuxTransition.SHAPE_TOKENS["K"] = ( self.aux_state_pred_num_keybody_bodies ) aux_cmd_cfg = self.config.get("aux_router_command_recon", {}) self.use_aux_router_command_recon: bool = bool( aux_cmd_cfg.get("enabled", False) ) self.aux_router_command_recon_weight = float( aux_cmd_cfg.get("weight", 0.0) ) self.aux_router_command_recon_hidden_dim = int( aux_cmd_cfg.get("hidden_dim", 0) ) self.aux_router_command_recon_term_prefix = str( aux_cmd_cfg.get("term_prefix", "actor_ref_") ) aux_switch_cfg = self.config.get("aux_router_switch_penalty", {}) self.use_aux_router_switch_penalty = bool( aux_switch_cfg.get("enabled", False) ) self.aux_router_switch_penalty_weight = float( aux_switch_cfg.get("weight", 0.0) ) self.aux_router_switch_penalty_metric = str( aux_switch_cfg.get("metric", "js") ).lower() self.aux_router_switch_penalty_beta = float( aux_switch_cfg.get("beta", 1.0) ) aux_router_future_cfg = self.config.get("aux_router_future_recon", {}) self.use_aux_router_future_recon = bool( aux_router_future_cfg.get("enabled", False) ) self.aux_router_future_recon_weight = float( aux_router_future_cfg.get("weight", 0.0) ) self.aux_router_future_recon_hidden_dim = int( aux_router_future_cfg.get("hidden_dim", 0) ) self.aux_router_future_recon_huber_beta = float( aux_router_future_cfg.get("huber_beta", 1.0) ) dead_margin_cfg = self.config.get("dead_expert_margin_to_topk", {}) self.use_dead_expert_margin_to_topk = bool( dead_margin_cfg.get("enabled", False) ) self.dead_expert_margin_to_topk_weight = float( dead_margin_cfg.get("weight", 0.0) ) orth_cfg = self.config.get("router_expert_orthogonal", {}) self.use_router_expert_orthogonal = bool( orth_cfg.get("enabled", False) ) self.router_expert_orthogonal_weight = float( orth_cfg.get("weight", 0.0) ) self.router_expert_orthogonal_min_active_usage = float( orth_cfg.get("min_active_usage", 1.0e-3) ) self.router_expert_orthogonal_eps = 
float(orth_cfg.get("eps", 1.0e-8)) selected_margin_cfg = self.config.get( "selected_expert_margin_to_unselected", {} ) self.use_selected_expert_margin_to_unselected = bool( selected_margin_cfg.get("enabled", False) ) self.selected_expert_margin_to_unselected_weight = float( selected_margin_cfg.get("weight", 0.0) ) self.selected_expert_margin_to_unselected_target = float( selected_margin_cfg.get("target", 0.0) ) if self.aux_router_switch_penalty_metric not in { "js", "normed_smooth_l1", }: raise ValueError( "aux_router_switch_penalty.metric must be one of " "{'js', 'normed_smooth_l1'}, got " f"{self.aux_router_switch_penalty_metric!r}." ) if self.aux_router_command_recon_weight < 0.0: raise ValueError("aux_router_command_recon.weight must be >= 0.") if self.aux_router_future_recon_weight < 0.0: raise ValueError("aux_router_future_recon.weight must be >= 0.") if self.aux_router_switch_penalty_weight < 0.0: raise ValueError("aux_router_switch_penalty.weight must be >= 0.") if self.dead_expert_margin_to_topk_weight < 0.0: raise ValueError("dead_expert_margin_to_topk.weight must be >= 0.") if self.router_expert_orthogonal_weight < 0.0: raise ValueError("router_expert_orthogonal.weight must be >= 0.") if self.router_expert_orthogonal_min_active_usage < 0.0: raise ValueError( "router_expert_orthogonal.min_active_usage must be >= 0." ) if self.selected_expert_margin_to_unselected_weight < 0.0: raise ValueError( "selected_expert_margin_to_unselected.weight must be >= 0." ) if self.router_expert_orthogonal_eps <= 0.0: raise ValueError("router_expert_orthogonal.eps must be > 0.") if self.aux_router_switch_penalty_beta <= 0.0: raise ValueError("aux_router_switch_penalty.beta must be > 0.") if self.aux_router_future_recon_huber_beta <= 0.0: raise ValueError("aux_router_future_recon.huber_beta must be > 0.") if self.selected_expert_margin_to_unselected_target < 0.0: raise ValueError( "selected_expert_margin_to_unselected.target must be >= 0." 
) if ( self.use_dead_expert_margin_to_topk and self.dead_expert_margin_to_topk_weight == 0.0 ): logger.warning( "dead_expert_margin_to_topk.enabled=True but weight=0.0; " "dead-expert margin loss will have no effect." ) if ( self.use_router_expert_orthogonal and not self.use_dead_expert_margin_to_topk ): raise ValueError( "router_expert_orthogonal.enabled=True requires " "dead_expert_margin_to_topk.enabled=True in sparse top-k MoE." ) if ( self.use_router_expert_orthogonal and self.router_expert_orthogonal_weight == 0.0 ): logger.warning( "router_expert_orthogonal.enabled=True but weight=0.0; " "orthogonal regularization will have no effect." ) if ( self.use_selected_expert_margin_to_unselected and self.selected_expert_margin_to_unselected_weight == 0.0 ): logger.warning( "selected_expert_margin_to_unselected.enabled=True but " "weight=0.0; selected-expert margin loss will have no effect." ) if ( self.use_aux_router_switch_penalty and self.aux_router_switch_penalty_weight == 0.0 ): logger.warning( "aux_router_switch_penalty.enabled=True but weight=0.0; " "router switch penalty will have no effect." ) if ( self.use_aux_router_future_recon and self.aux_router_future_recon_weight == 0.0 ): logger.warning( "aux_router_future_recon.enabled=True but weight=0.0; " "future reconstruction loss will have no effect." ) if ( self.use_aux_router_command_recon or self.use_aux_router_switch_penalty or self.use_aux_router_future_recon ) and self.command_name != "ref_motion": raise ValueError( "aux_router_command_recon, aux_router_future_recon, and " "aux_router_switch_penalty are " "only supported for PPOTF motion tracking " "(command_name='ref_motion')." 
) self.aux_command_router_num_moe_layers = 0 self.aux_command_router_num_fine_experts = 0 self.aux_router_command_recon_assembler: TensorDictAssembler | None = ( None ) actor_cfg = self.config.get("module_dict", {}).get("actor", {}) actor_type = str(actor_cfg.get("type", "")) if actor_type in { "ReferenceRoutedGroupedMoETransformerPolicyV2", "ReferenceRoutedGroupedMoETransformerPolicyV3", }: if self.use_aux_router_command_recon: raise ValueError( f"{actor_type} does not support aux_router_command_recon." ) unsupported_aux_weights = { "w_root_height": self.aux_state_pred_w_root_height, "w_denoise_ref_root_lin_vel": self.aux_state_pred_w_denoise_ref_root_lin_vel, "w_denoise_ref_root_ang_vel": self.aux_state_pred_w_denoise_ref_root_ang_vel, "w_denoise_ref_dof_pos": self.aux_state_pred_w_denoise_ref_dof_pos, } enabled_unsupported = [ name for name, value in unsupported_aux_weights.items() if float(value) > 0.0 ] if enabled_unsupported: raise ValueError( f"{actor_type} only supports " "aux_state_pred weights for base_lin_vel, keybody_contact, " "ref_keybody_rel_pos, and robot_keybody_rel_pos. Unsupported " "weights: " + ", ".join(enabled_unsupported) ) elif self.use_aux_router_future_recon: raise ValueError( "aux_router_future_recon requires " "ReferenceRoutedGroupedMoETransformerPolicyV2 or V3." 
) @staticmethod def _unwrap_obs_schema(schema: dict | None) -> dict | None: if schema is None: return None has_terms = any( isinstance(v, dict) and ("terms" in v) for v in schema.values() ) if has_terms: return schema if len(schema) == 1: only_value = next(iter(schema.values())) if isinstance(only_value, dict): return only_value return schema @staticmethod def _schema_term_leaf_name(term: str) -> str: return str(term).split("/")[-1] @classmethod def _is_aux_command_term(cls, term: str, term_prefix: str) -> bool: return cls._schema_term_leaf_name(term).startswith(term_prefix) @classmethod def _build_aux_router_command_recon_schema( cls, actor_schema: dict, term_prefix: str ) -> dict: command_schema = {} for group_name, seq_cfg in actor_schema.items(): terms = [ str(term) for term in seq_cfg.get("terms", []) if cls._is_aux_command_term(str(term), term_prefix) ] if len(terms) == 0: continue next_seq_cfg = dict(seq_cfg) next_seq_cfg["terms"] = terms command_schema[group_name] = next_seq_cfg if len(command_schema) == 0: raise ValueError( "aux_router_command_recon could not find any actor command terms in " f"obs_schema with prefix '{term_prefix}'." ) return command_schema @staticmethod def _masked_aux_keybody_mse( pred: torch.Tensor, target: torch.Tensor, valid_tok: torch.Tensor, ) -> torch.Tensor: if pred.shape != target.shape: raise ValueError( "pred and target must have the same shape for keybody MSE, " f"got {tuple(pred.shape)} and {tuple(target.shape)}." ) if pred.ndim != 4: raise ValueError( "Keybody MSE expects [B, T, K, 3] tensors, " f"got pred with shape {tuple(pred.shape)}." ) per_token_mse = torch.square(pred - target).mean(dim=(-1, -2)) valid_tok = valid_tok.to(per_token_mse.dtype) if valid_tok.shape != per_token_mse.shape: raise ValueError( "valid_tok must match per-token keybody MSE shape, " f"got {tuple(valid_tok.shape)} and " f"{tuple(per_token_mse.shape)}." 
) valid_count = valid_tok.sum().clamp_min(1.0) return (per_token_mse * valid_tok).sum() / valid_count @staticmethod def _masked_aux_mse( pred: torch.Tensor, target: torch.Tensor, valid_tok: torch.Tensor, ) -> torch.Tensor: if pred.shape != target.shape: raise ValueError( "pred and target must share the same shape for auxiliary MSE, " f"got {tuple(pred.shape)} and {tuple(target.shape)}." ) if pred.ndim < 3: raise ValueError( "Auxiliary MSE expects tensors with shape [B, T, ...], " f"got {tuple(pred.shape)}." ) reduce_dims = tuple(range(2, pred.ndim)) per_token_mse = torch.square(pred - target).mean(dim=reduce_dims) valid_tok = valid_tok.to(per_token_mse.dtype) if valid_tok.shape != per_token_mse.shape: raise ValueError( "valid_tok must match per-token auxiliary MSE shape, got " f"{tuple(valid_tok.shape)} and {tuple(per_token_mse.shape)}." ) valid_count = valid_tok.sum().clamp_min(1.0) return (per_token_mse * valid_tok).sum() / valid_count @staticmethod def _masked_adjacent_router_js( *, router_features: torch.Tensor, valid_tok: torch.Tensor, num_moe_layers: int, num_fine_experts: int, ) -> torch.Tensor: if router_features.ndim != 3: raise ValueError( "router_features must have shape [B, T, L*E], got " f"{tuple(router_features.shape)}." ) if valid_tok.ndim != 2: raise ValueError( "valid_tok must have shape [B, T], got " f"{tuple(valid_tok.shape)}." ) if num_moe_layers <= 0 or num_fine_experts <= 0: raise ValueError( "num_moe_layers and num_fine_experts must be positive, got " f"{num_moe_layers} and {num_fine_experts}." ) bsz, seq_len, feat_dim = router_features.shape expected_dim = num_moe_layers * num_fine_experts if feat_dim != expected_dim: raise ValueError( "router_features last dim must equal num_moe_layers * " "num_fine_experts, got " f"{feat_dim} vs {expected_dim}." ) if valid_tok.shape != (bsz, seq_len): raise ValueError( "valid_tok shape mismatch for router temporal loss: expected " f"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}." 
) if seq_len <= 1: return router_features.new_zeros(()) router_p = router_features.reshape( bsz, seq_len, num_moe_layers, num_fine_experts ).to(torch.float32) prev_p = router_p[:, :-1] curr_p = router_p[:, 1:] mix_p = 0.5 * (prev_p + curr_p) eps = 1.0e-20 prev_safe = prev_p.clamp_min(eps) curr_safe = curr_p.clamp_min(eps) mix_safe = mix_p.clamp_min(eps) kl_prev = (prev_p * (torch.log(prev_safe) - torch.log(mix_safe))).sum( dim=-1 ) kl_curr = (curr_p * (torch.log(curr_safe) - torch.log(mix_safe))).sum( dim=-1 ) js = 0.5 * (kl_prev + kl_curr) adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to(js.dtype) valid_count = adjacent_valid.sum().clamp_min(1.0) * float( num_moe_layers ) return (js * adjacent_valid.unsqueeze(-1)).sum() / valid_count @staticmethod def _masked_adjacent_router_normed_smooth_l1( *, router_temporal_features: torch.Tensor, valid_tok: torch.Tensor, num_moe_layers: int, num_fine_experts: int, beta: float = 1.0, ) -> torch.Tensor: if router_temporal_features.ndim != 3: raise ValueError( "router_temporal_features must have shape [B, T, L*E], got " f"{tuple(router_temporal_features.shape)}." ) if valid_tok.ndim != 2: raise ValueError( "valid_tok must have shape [B, T], got " f"{tuple(valid_tok.shape)}." ) if num_moe_layers <= 0 or num_fine_experts <= 0: raise ValueError( "num_moe_layers and num_fine_experts must be positive, got " f"{num_moe_layers} and {num_fine_experts}." ) if beta <= 0.0: raise ValueError( f"beta must be positive for SmoothL1, got {beta}." ) bsz, seq_len, feat_dim = router_temporal_features.shape expected_dim = num_moe_layers * num_fine_experts if feat_dim != expected_dim: raise ValueError( "router_temporal_features last dim must equal " "num_moe_layers * num_fine_experts, got " f"{feat_dim} vs {expected_dim}." ) if valid_tok.shape != (bsz, seq_len): raise ValueError( "valid_tok shape mismatch for router temporal loss: expected " f"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}." 
) if seq_len <= 1: return router_temporal_features.new_zeros(()) router_logits = router_temporal_features.reshape( bsz, seq_len, num_moe_layers, num_fine_experts ).to(torch.float32) router_logits = router_logits - router_logits.mean( dim=-1, keepdim=True ) router_logits = F.normalize(router_logits, p=2.0, dim=-1, eps=1.0e-5) prev_logits = router_logits[:, :-1] curr_logits = router_logits[:, 1:] smooth_l1 = F.smooth_l1_loss( curr_logits, prev_logits, reduction="none", beta=beta, ).mean(dim=(-1, -2)) adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to( smooth_l1.dtype ) valid_count = adjacent_valid.sum().clamp_min(1.0) return (smooth_l1 * adjacent_valid).sum() / valid_count @staticmethod def _masked_aux_gaussian_nll( *, loc: torch.Tensor, log_std: torch.Tensor, target: torch.Tensor, valid_tok: torch.Tensor, min_std: float, max_std: float, ) -> tuple[torch.Tensor, torch.Tensor]: if loc.shape != log_std.shape or loc.shape != target.shape: raise ValueError( "loc, log_std, and target must share the same shape for " "Gaussian aux loss, got " f"{tuple(loc.shape)}, {tuple(log_std.shape)}, " f"{tuple(target.shape)}." ) if loc.ndim < 3: raise ValueError( "Gaussian aux loss expects tensors with shape [B, T, ...], " f"got {tuple(loc.shape)}." ) per_elem_std = torch.clamp( torch.exp(log_std), min=float(min_std), max=float(max_std), ) reduce_dims = tuple(range(2, loc.ndim)) per_token_nll = 0.5 * ( torch.square((target - loc) / per_elem_std) + 2.0 * torch.log(per_elem_std + 1.0e-8) ).sum(dim=reduce_dims) valid_tok = valid_tok.to(per_token_nll.dtype) if valid_tok.shape != per_token_nll.shape: raise ValueError( "valid_tok must match per-token Gaussian loss shape, got " f"{tuple(valid_tok.shape)} and {tuple(per_token_nll.shape)}." 
) valid_count = valid_tok.sum().clamp_min(1.0) loss = (per_token_nll * valid_tok).sum() / valid_count per_token_std = per_elem_std.reshape( per_elem_std.shape[0], per_elem_std.shape[1], -1 ).mean(dim=-1) mean_std = (per_token_std * valid_tok).sum() / valid_count return loss, mean_std @staticmethod def _masked_aux_huber( *, pred: torch.Tensor, target: torch.Tensor, valid_tok: torch.Tensor, beta: float, ) -> torch.Tensor: if pred.shape != target.shape: raise ValueError( "pred and target must share the same shape for Huber aux loss, " f"got {tuple(pred.shape)} and {tuple(target.shape)}." ) if pred.ndim < 3: raise ValueError( "Huber aux loss expects tensors with shape [B, T, ...], " f"got {tuple(pred.shape)}." ) per_elem = F.smooth_l1_loss(pred, target, reduction="none", beta=beta) reduce_dims = tuple(range(2, pred.ndim)) per_token = per_elem.mean(dim=reduce_dims) valid_tok = valid_tok.to(per_token.dtype) if valid_tok.shape != per_token.shape: raise ValueError( "valid_tok must match per-token Huber loss shape, got " f"{tuple(valid_tok.shape)} and {tuple(per_token.shape)}." ) valid_count = valid_tok.sum().clamp_min(1.0) return (per_token * valid_tok).sum() / valid_count def _compute_aux_router_future_recon_loss( self, *, actor_wrapper: PPOTFActor, actor_out: TensorDict, obs_b: TensorDict, valid_tok: torch.Tensor, ) -> torch.Tensor: future_assembler = actor_wrapper.aux_router_future_recon_assembler if future_assembler is None: raise ValueError( "aux_router_future_recon is enabled but future assembler was " "not initialized on the actor wrapper." 
) aux_router_future_recon_pred = actor_out.get("aux_router_future_recon") bsz, seq_len = int(obs_b.batch_size[0]), int(obs_b.batch_size[1]) future_target = future_assembler(obs_b.flatten(0, 1)).reshape( bsz, seq_len, -1 ) normalized_future_target = actor_wrapper.actor_module.normalize_aux_router_future_recon_target( future_target ).to(aux_router_future_recon_pred.dtype) return self._masked_aux_huber( pred=aux_router_future_recon_pred, target=normalized_future_target, valid_tok=valid_tok, beta=self.aux_router_future_recon_huber_beta, ) def _compute_routed_expert_orthogonal_loss( self, moe_layer: GroupedMoEBlock, *, dtype: torch.dtype, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: usage = moe_layer.last_routed_expert_usage.to( device=device, dtype=torch.float32 ) active_mask = usage > float( self.router_expert_orthogonal_min_active_usage ) active_idx = torch.nonzero(active_mask, as_tuple=False).squeeze(-1) active_count = torch.tensor( float(active_idx.numel()), device=device, dtype=torch.float32 ) if active_idx.numel() < 2: zero = torch.zeros((), device=device, dtype=dtype) zero_f = torch.zeros((), device=device, dtype=torch.float32) return zero, active_count, zero_f expert_vecs = moe_layer.down_proj.index_select(0, active_idx) expert_vecs = expert_vecs.reshape(active_idx.numel(), -1).to( device=device, dtype=torch.float32 ) expert_vecs = F.normalize( expert_vecs, p=2.0, dim=-1, eps=float(self.router_expert_orthogonal_eps), ) gram = expert_vecs @ expert_vecs.transpose(0, 1) offdiag_mask = ~torch.eye( gram.shape[0], dtype=torch.bool, device=gram.device ) offdiag = gram.masked_select(offdiag_mask) if offdiag.numel() == 0: zero = torch.zeros((), device=device, dtype=dtype) zero_f = torch.zeros((), device=device, dtype=torch.float32) return zero, active_count, zero_f orth_loss = offdiag.square().sum().to(dtype) mean_offdiag_similarity = offdiag.abs().mean() return orth_loss, active_count, mean_offdiag_similarity @staticmethod def 
_root_relative_body_pos_from_mixed_position_frames( *, body_pos_w: torch.Tensor, root_pos_env: torch.Tensor, root_quat_w: torch.Tensor, env_origins: torch.Tensor, ) -> torch.Tensor: """Convert world-frame body positions using an env-frame root pose. In IsaacLab, `isaaclab_mdp.root_pos_w(env)` is already in the environment frame (simulator world minus `env.scene.env_origins`), while `robot.data.body_pos_w` stays in simulator-world coordinates. """ if body_pos_w.ndim != 3 or body_pos_w.shape[-1] != 3: raise ValueError( "body_pos_w must have shape [B, N, 3], " f"got {tuple(body_pos_w.shape)}." ) if root_pos_env.ndim != 2 or root_pos_env.shape[-1] != 3: raise ValueError( "root_pos_env must have shape [B, 3], " f"got {tuple(root_pos_env.shape)}." ) if root_quat_w.ndim != 2 or root_quat_w.shape[-1] != 4: raise ValueError( "root_quat_w must have shape [B, 4], " f"got {tuple(root_quat_w.shape)}." ) if env_origins.ndim != 2 or env_origins.shape[-1] != 3: raise ValueError( "env_origins must have shape [B, 3], " f"got {tuple(env_origins.shape)}." ) if body_pos_w.shape[0] != root_pos_env.shape[0]: raise ValueError( "Batch size mismatch between body_pos_w and root_pos_env: " f"{body_pos_w.shape[0]} vs {root_pos_env.shape[0]}." ) if body_pos_w.shape[0] != root_quat_w.shape[0]: raise ValueError( "Batch size mismatch between body_pos_w and root_quat_w: " f"{body_pos_w.shape[0]} vs {root_quat_w.shape[0]}." ) if body_pos_w.shape[0] != env_origins.shape[0]: raise ValueError( "Batch size mismatch between body_pos_w and env_origins: " f"{body_pos_w.shape[0]} vs {env_origins.shape[0]}." 
) body_pos_env = body_pos_w - env_origins[:, None, :] rel_pos_env = body_pos_env - root_pos_env[:, None, :] quat_vec = root_quat_w[:, None, 1:].expand_as(rel_pos_env) quat_real = root_quat_w[:, None, :1].expand( -1, rel_pos_env.shape[1], -1 ) t = 2.0 * torch.cross(quat_vec, rel_pos_env, dim=-1) return rel_pos_env - quat_real * t + torch.cross(quat_vec, t, dim=-1) def _setup_models_and_optimizer(self): sample_obs_dict = self.env.reset_all()[0] sample_td = self._wrap_obs_dict(sample_obs_dict) actor_cfg = OmegaConf.to_container( self.config.module_dict.actor, resolve=True ) critic_cfg = OmegaConf.to_container( self.config.module_dict.critic, resolve=True ) actor_cfg["noise_std_type"] = getattr( self.config, "noise_std_type", "log" ) actor_cfg["min_sigma"] = getattr(self.config, "min_sigma", 0.1) actor_cfg["max_sigma"] = getattr(self.config, "max_sigma", 1.5) actor_cfg["fix_sigma"] = getattr(self.config, "fix_sigma", False) self._future_mask_prob = float(actor_cfg.get("future_mask_prob", 0.0)) self._future_mask_mode = str( actor_cfg.get("future_mask_mode", "random_suffix") ).lower() aux_cfg = self.config.get("aux_state_pred", {}) if isinstance(aux_cfg, dict): actor_cfg["aux_state_pred"] = dict(aux_cfg) else: actor_cfg["aux_state_pred"] = OmegaConf.to_container( aux_cfg, resolve=True ) aux_cmd_cfg = self.config.get("aux_router_command_recon", {}) if isinstance(aux_cmd_cfg, dict): actor_aux_cmd_cfg = dict(aux_cmd_cfg) else: actor_aux_cmd_cfg = OmegaConf.to_container( aux_cmd_cfg, resolve=True ) aux_switch_cfg = self.config.get("aux_router_switch_penalty", {}) if isinstance(aux_switch_cfg, dict): actor_aux_switch_cfg = dict(aux_switch_cfg) else: actor_aux_switch_cfg = OmegaConf.to_container( aux_switch_cfg, resolve=True ) aux_router_future_cfg = self.config.get("aux_router_future_recon", {}) if isinstance(aux_router_future_cfg, dict): actor_aux_router_future_cfg = dict(aux_router_future_cfg) else: actor_aux_router_future_cfg = OmegaConf.to_container( 
aux_router_future_cfg, resolve=True ) dead_margin_cfg = self.config.get("dead_expert_margin_to_topk", {}) if isinstance(dead_margin_cfg, dict): actor_dead_margin_cfg = dict(dead_margin_cfg) else: actor_dead_margin_cfg = OmegaConf.to_container( dead_margin_cfg, resolve=True ) selected_margin_cfg = self.config.get( "selected_expert_margin_to_unselected", {} ) if isinstance(selected_margin_cfg, dict): actor_selected_margin_cfg = dict(selected_margin_cfg) else: actor_selected_margin_cfg = OmegaConf.to_container( selected_margin_cfg, resolve=True ) actor_schema = self._unwrap_obs_schema( actor_cfg.get("obs_schema", None) ) critic_schema = self._unwrap_obs_schema( critic_cfg.get("obs_schema", None) ) if actor_schema is None: raise ValueError( "PPOTF requires actor obs_schema to infer flattened obs dim." ) if self.use_aux_router_command_recon: aux_command_schema = self._build_aux_router_command_recon_schema( actor_schema, self.aux_router_command_recon_term_prefix ) self.aux_router_command_recon_assembler = TensorDictAssembler( aux_command_schema, output_mode="flat" ) actor_aux_cmd_cfg["output_dim"] = int( self.aux_router_command_recon_assembler.infer_output_dim( sample_td ) ) if self.aux_router_command_recon_hidden_dim > 0: actor_aux_cmd_cfg["hidden_dim"] = ( self.aux_router_command_recon_hidden_dim ) actor_cfg["aux_router_command_recon"] = actor_aux_cmd_cfg actor_cfg["aux_router_future_recon"] = actor_aux_router_future_cfg actor_cfg["aux_router_switch_penalty"] = actor_aux_switch_cfg actor_cfg["dead_expert_margin_to_topk"] = actor_dead_margin_cfg actor_cfg["selected_expert_margin_to_unselected"] = ( actor_selected_margin_cfg ) actor_obs_dim = int( TensorDictAssembler( actor_schema, output_mode="flat" ).infer_output_dim(sample_td) ) use_future_cross_attn = bool( actor_cfg.get("use_future_cross_attn", False) ) actor_cls = self._select_actor_wrapper_cls(actor_cfg) if use_future_cross_attn: if "flattened_obs" not in actor_schema: raise ValueError( "use_future_cross_attn=True 
requires " "actor obs_schema.flattened_obs." ) if "flattened_obs_fut" not in actor_schema: raise ValueError( "use_future_cross_attn=True requires " "actor obs_schema.flattened_obs_fut." ) state_schema = {"flattened_obs": actor_schema["flattened_obs"]} future_schema = { "flattened_obs_fut": actor_schema["flattened_obs_fut"] } state_obs_dim = int( TensorDictAssembler( state_schema, output_mode="flat" ).infer_output_dim(sample_td) ) future_asm = TensorDictAssembler(future_schema, output_mode="seq") future_token_dim = int(future_asm.infer_output_dim(sample_td)) future_seq_len = int(future_asm.seq_len) actor_cfg["state_obs_dim"] = state_obs_dim actor_cfg["future_token_dim"] = future_token_dim actor_cfg["future_seq_len"] = future_seq_len actor_cfg["input_dim_override"] = state_obs_dim else: actor_cfg["input_dim_override"] = actor_obs_dim self.actor = actor_cls( obs_schema=actor_schema, module_config_dict=actor_cfg, num_actions=self.num_actions, init_noise_std=self.config.init_noise_std, obs_example=sample_td, ).to(self.device) actor_module_unwrapped = self.actor.actor_module self.aux_command_router_num_moe_layers = int( getattr(actor_module_unwrapped, "_num_moe_layers", 0) ) self.aux_command_router_num_fine_experts = int( getattr(actor_module_unwrapped, "num_fine_experts", 0) ) if ( self.use_aux_router_switch_penalty and self.aux_command_router_num_moe_layers <= 0 ): raise ValueError( "aux_router_switch_penalty requires at least one " "GroupedMoEBlock." 
) self.critic = PPOCritic( obs_schema=critic_schema, module_config_dict=critic_cfg, obs_example=sample_td, ).to(self.device) if self.is_main_process: actor = self.accelerator.unwrap_model(self.actor) critic = self.accelerator.unwrap_model(self.critic) logger.info("Actor (TensorDict module):\n{!r}", actor) logger.info( "Actor keys: in_keys={} out_keys={}", list(actor.in_keys), list(actor.out_keys), ) logger.info("Actor core nn module:\n{!r}", actor.actor_module) logger.info("Critic (TensorDict module):\n{!r}", critic) logger.info( "Critic keys: in_keys={} out_keys={}", list(critic.in_keys), list(critic.out_keys), ) logger.info("Critic core nn module:\n{!r}", critic.critic_module) actor_params = sum(p.numel() for p in self.actor.parameters()) critic_params = sum(p.numel() for p in self.critic.parameters()) params_table = [ ["Actor(Transformer)", f"{actor_params / 1.0e6:.3f}"], ["Critic", f"{critic_params / 1.0e6:.3f}"], ["Total", f"{(actor_params + critic_params) / 1.0e6:.3f}"], ] logger.info( "Model Summary:\n" + tabulate( params_table, headers=["Model", "Params (M)"], tablefmt="simple_outline", ) ) optimizer_class = getattr(optim, self.optimizer_type) optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class) if self.optimizer_type == "AdamW": decay_params = [] non_decay_params = [] for name, p in self.actor.named_parameters(): if not p.requires_grad: continue if ( p.ndim < 2 or ("log_std" in name) or ("bias" in name) or ("norm" in name) ): non_decay_params.append(p) else: decay_params.append(p) self.actor_optimizer = optimizer_class( [ {"params": decay_params, "weight_decay": 0.01}, {"params": non_decay_params, "weight_decay": 0.0}, ], lr=self.actor_learning_rate, betas=(self.actor_beta1, self.actor_beta2), **optimizer_kwargs, ) else: self.actor_optimizer = optimizer_class( self.actor.parameters(), lr=self.actor_learning_rate, betas=(self.actor_beta1, self.actor_beta2), **optimizer_kwargs, ) self.critic_optimizer = optimizer_class( self.critic.parameters(), 
lr=self.critic_learning_rate, betas=(self.critic_beta1, self.critic_beta2), **optimizer_kwargs, ) ( self.actor, self.critic, self.actor_optimizer, self.critic_optimizer, ) = self.accelerator.prepare( self.actor, self.critic, self.actor_optimizer, self.critic_optimizer, ) actor_for_kv = self.accelerator.unwrap_model(self.actor) if hasattr(actor_for_kv, "reset_kv_cache"): actor_for_kv.reset_kv_cache(self.env.num_envs, self.device) self._kv_reset_pending = torch.zeros( self.env.num_envs, dtype=torch.bool, device=self.device ) self._rollout_future_masks = None self._rollout_step_idx = 0 def _setup_data_buffers(self): super()._setup_data_buffers() self._aux_height_scanner = None self._aux_contact_sensor = None self._aux_contact_body_ids = None self._aux_keybody_body_ids = None if not self.use_aux_state_pred: return if self.use_velocity_transition: raise ValueError( "aux_state_pred is not supported with velocity " "tracking in PPOTF." ) self.transition_cls = PpoAuxTransition if self.use_aux_root_height: if "height_scanner" not in self.env._env.scene.sensors: raise ValueError( "aux_state_pred requires a RayCaster sensor " "named 'height_scanner' " "in env.scene.sensors." ) height_scanner = self.env._env.scene.sensors["height_scanner"] height_scanner.cfg.max_distance = ( self.aux_state_pred_raycast_max_dist ) height_scanner.cfg.ray_alignment = "world" height_scanner.cfg.offset.pos = ( 0.0, 0.0, self.aux_state_pred_raycast_z_offset, ) if height_scanner.is_initialized: height_scanner.ray_starts[..., 2] = ( self.aux_state_pred_raycast_z_offset ) self._aux_height_scanner = height_scanner if self.aux_state_pred_num_contact_bodies > 0: if "contact_forces" not in self.env._env.scene.sensors: raise ValueError( "aux_state_pred.keybody_contact_names requires " "a ContactSensor " "named 'contact_forces' in env.scene.sensors." 
) contact_sensor = self.env._env.scene.sensors["contact_forces"] sensor_body_names = list(contact_sensor.body_names) body_ids = [] for body_name in self.aux_state_pred_keybody_contact_names: if body_name not in sensor_body_names: raise ValueError( f"Body '{body_name}' not found in contact " "sensor body_names." ) body_ids.append(sensor_body_names.index(body_name)) self._aux_contact_sensor = contact_sensor self._aux_contact_body_ids = torch.tensor( body_ids, dtype=torch.long, device=self.device ) if self.aux_state_pred_num_keybody_bodies > 0: robot_body_names = list(self.env._env.scene["robot"].body_names) body_ids = [] for body_name in self.aux_state_pred_keybody_rel_pos_names: if body_name not in robot_body_names: raise ValueError( f"Body '{body_name}' not found in robot body_names." ) body_ids.append(robot_body_names.index(body_name)) self._aux_keybody_body_ids = torch.tensor( body_ids, dtype=torch.long, device=self.device ) def _build_transition( self, obs_td: TensorDict, actor_out: TensorDict, critic_out: TensorDict, ): if not self.use_aux_state_pred: return super()._build_transition(obs_td, actor_out, critic_out) import isaaclab.envs.mdp as isaaclab_mdp actions = actor_out.get("actions") actions_log_prob = actor_out.get("actions_log_prob") mu = actor_out.get("mu") sigma = actor_out.get("sigma") values = critic_out.get("values") zero_scalar = torch.zeros( self.num_envs, 1, device=self.device, dtype=torch.float32, ) zero_scalar_bool = torch.zeros( self.num_envs, 1, device=self.device, dtype=torch.bool, ) gt_base_lin_vel_b = isaaclab_mdp.base_lin_vel(self.env._env) if self.use_aux_root_height: root_pos_w = isaaclab_mdp.root_pos_w(self.env._env) if self._aux_height_scanner is None: raise RuntimeError( "Aux state prediction expected " "_aux_height_scanner to be initialized." 
) terrain_z = self._aux_height_scanner.data.ray_hits_w[:, 0, 2:3] env_origin_z = self.env._env.scene.env_origins[:, 2:3] terrain_z = torch.where( torch.isfinite(terrain_z), terrain_z, env_origin_z ) gt_root_height_rel_terrain = root_pos_w[:, 2:3] - terrain_z else: gt_root_height_rel_terrain = torch.zeros( self.num_envs, 1, device=self.device, dtype=torch.float32 ) if self.aux_state_pred_num_contact_bodies > 0: if ( self._aux_contact_sensor is None or self._aux_contact_body_ids is None ): raise RuntimeError( "Aux keybody contact prediction expects contact sensor " "and body ids to be initialized." ) contact_time = self._aux_contact_sensor.data.current_contact_time[ :, self._aux_contact_body_ids ] gt_keybody_contacts = (contact_time > 0.0).to(torch.float32) else: gt_keybody_contacts = torch.zeros( self.num_envs, 0, device=self.device, dtype=torch.float32 ) command = self.env._env.command_manager.get_term(self.command_name) if self.aux_state_pred_num_keybody_bodies > 0: if self._aux_keybody_body_ids is None: raise RuntimeError( "Aux keybody position prediction expects body " "ids to be initialized." ) # Both the ref-motion command and robot asset expose bodies in # simulator order, so the cached robot body indices align here. 
gt_ref_keybody_rel_pos = ( command.get_ref_motion_bodylink_rel_pos_cur()[ :, self._aux_keybody_body_ids, : ] ) robot = self.env._env.scene["robot"] robot_keybody_global_pos = robot.data.body_pos_w[ :, self._aux_keybody_body_ids, : ] env_origins = self.env._env.scene.env_origins root_pos_w = isaaclab_mdp.root_pos_w(self.env._env) root_quat_w = isaaclab_mdp.root_quat_w(self.env._env) gt_robot_keybody_rel_pos = ( self._root_relative_body_pos_from_mixed_position_frames( body_pos_w=robot_keybody_global_pos, root_pos_env=root_pos_w, root_quat_w=root_quat_w, env_origins=env_origins, ) ) else: gt_ref_keybody_rel_pos = torch.zeros( self.num_envs, 0, 3, device=self.device, dtype=torch.float32 ) gt_robot_keybody_rel_pos = torch.zeros( self.num_envs, 0, 3, device=self.device, dtype=torch.float32 ) gt_denoise_ref_root_lin_vel = torch.zeros( self.num_envs, 3, device=self.device, dtype=torch.float32 ) gt_denoise_ref_root_ang_vel = torch.zeros( self.num_envs, 3, device=self.device, dtype=torch.float32 ) gt_denoise_ref_dof_pos = torch.zeros( self.num_envs, actions.shape[-1], device=self.device, dtype=torch.float32, ) if ( self.use_aux_denoise_ref_root_lin_vel or self.use_aux_denoise_ref_root_ang_vel or self.use_aux_denoise_ref_dof_pos ): try: if self.use_aux_denoise_ref_root_lin_vel: gt_denoise_ref_root_lin_vel = ( command.get_ref_motion_base_linvel_cur( prefix="ft_ref_" ) - command.get_ref_motion_base_linvel_cur(prefix="ref_") ) if self.use_aux_denoise_ref_root_ang_vel: gt_denoise_ref_root_ang_vel = ( command.get_ref_motion_base_angvel_cur( prefix="ft_ref_" ) - command.get_ref_motion_base_angvel_cur(prefix="ref_") ) if self.use_aux_denoise_ref_dof_pos: gt_denoise_ref_dof_pos = ( command.get_ref_motion_dof_pos_cur(prefix="ft_ref_") - command.get_ref_motion_dof_pos_cur(prefix="ref_") ) expected_shape = (self.num_envs, actions.shape[-1]) if tuple(gt_denoise_ref_dof_pos.shape) != expected_shape: raise ValueError( "gt_denoise_ref_dof_pos must match the action-aligned " "DoF shape " 
f"{expected_shape}, got " f"{tuple(gt_denoise_ref_dof_pos.shape)}." ) except KeyError as exc: raise RuntimeError( "Filtered reference tensors are unavailable for " "aux_denoise_* targets. Enable online filtering or " "materialize ft_ref_* tensors in the motion cache." ) from exc return self.transition_cls( obs=obs_td, actions=actions.detach(), teacher_actions=torch.zeros_like(actions), mu=mu.detach(), sigma=sigma.detach(), actions_log_prob=actions_log_prob[..., None].detach(), values=values.detach(), rewards=zero_scalar.clone(), dones=zero_scalar_bool, returns=zero_scalar.clone(), advantages=zero_scalar.clone(), gt_base_lin_vel_b=gt_base_lin_vel_b.detach(), gt_root_height_rel_terrain=gt_root_height_rel_terrain.detach(), gt_keybody_contacts=gt_keybody_contacts.detach(), gt_ref_keybody_rel_pos=gt_ref_keybody_rel_pos.detach(), gt_robot_keybody_rel_pos=gt_robot_keybody_rel_pos.detach(), gt_denoise_ref_root_lin_vel=gt_denoise_ref_root_lin_vel.detach(), gt_denoise_ref_root_ang_vel=gt_denoise_ref_root_ang_vel.detach(), gt_denoise_ref_dof_pos=gt_denoise_ref_dof_pos.detach(), batch_size=[self.num_envs], device=self.device, ) def _build_storage(self, obs_td: TensorDict): actor_for_kv = self.accelerator.unwrap_model(self.actor) actor_policy = actor_for_kv.actor_module if bool(getattr(actor_policy, "use_future_cross_attn", False)): n_fut = int(getattr(actor_policy, "future_seq_len", 0)) if n_fut <= 0: raise ValueError( "future_seq_len must be positive when " "use_future_cross_attn=True" ) obs_td = obs_td.clone(recurse=False) obs_td.set( "future_mask", torch.ones( self.env.num_envs, n_fut, dtype=torch.bool, device=self.device, ), ) return super()._build_storage(obs_td) def _sample_iteration_future_masks(self) -> torch.Tensor | None: actor_for_kv = self.accelerator.unwrap_model(self.actor) actor_policy = actor_for_kv.actor_module if not bool(getattr(actor_policy, "use_future_cross_attn", False)): return None n_fut = int(getattr(actor_policy, "future_seq_len", 0)) if n_fut <= 0: 
raise ValueError( "future_seq_len must be positive when " "use_future_cross_attn=True" ) if self._future_mask_mode != "random_suffix": raise ValueError( "Unsupported future_mask_mode: " f"{self._future_mask_mode}. " "Expected 'random_suffix'." ) num_steps = int(self.num_steps_per_env) num_envs = int(self.env.num_envs) keep = torch.ones( num_steps, num_envs, n_fut, dtype=torch.bool, device=self.device, ) if bool(getattr(self, "_offline_evaluating", False)): return keep if self._future_mask_prob <= 0.0: return keep apply_mask = ( torch.rand(num_steps, num_envs, device=self.device) < self._future_mask_prob ) keep_len = torch.randint( 1, n_fut + 1, (num_steps, num_envs), device=self.device, ) full_len = torch.full( (num_steps, num_envs), n_fut, dtype=torch.long, device=self.device, ) keep_len = torch.where(apply_mask, keep_len, full_len) token_idx = torch.arange(n_fut, device=self.device, dtype=torch.long)[ None, None, : ] return token_idx < keep_len[:, :, None] def _reset_rollout_forward_state(self) -> None: actor_for_kv = self.accelerator.unwrap_model(self.actor) actor_for_kv.clear_env_cache(None) actor_policy = actor_for_kv.actor_module actor_policy.reset_routing_stats() actor_policy.set_collect_routing_stats(True) self._kv_reset_pending.zero_() self._rollout_future_masks = self._sample_iteration_future_masks() self._rollout_step_idx = 0 def _rollout_forward( self, obs_td: TensorDict, *, actor_mode: str = "sampling", collect_transition: bool = True, track_episode_stats: bool = True, ) -> TensorDict: if collect_transition and self._rollout_future_masks is not None: if self._rollout_step_idx >= int( self._rollout_future_masks.shape[0] ): raise RuntimeError( "Rollout future-mask step index exceeded " "pre-sampled mask length." 
) obs_td = obs_td.clone(recurse=False) obs_td.set( "future_mask", self._rollout_future_masks[self._rollout_step_idx], ) actor_for_kv = self.accelerator.unwrap_model(self.actor) if torch.any(self._kv_reset_pending): env_ids = torch.nonzero(self._kv_reset_pending).squeeze(-1) if env_ids.numel() > 0: actor_for_kv.clear_env_cache(env_ids) self._kv_reset_pending[env_ids] = False next_obs_td = super()._rollout_forward( obs_td, actor_mode=actor_mode, collect_transition=collect_transition, track_episode_stats=track_episode_stats, ) if collect_transition and self._rollout_future_masks is not None: self._rollout_step_idx += 1 if not collect_transition: dones = self._last_rollout_dones if dones is not None: self._kv_reset_pending |= ( dones.view(-1).to(torch.bool).to(self.device) ) return next_obs_td def process_env_step( self, rewards: torch.Tensor, dones: torch.Tensor, time_outs: torch.Tensor, infos: dict, ) -> None: super().process_env_step(rewards, dones, time_outs, infos) if getattr(self, "_kv_reset_pending", None) is not None: self._kv_reset_pending |= ( dones.view(-1).to(torch.bool).to(self.device) ) @staticmethod def _build_episode_causal_mask(dones_seq: torch.Tensor) -> torch.Tensor: """Build [N, T, T] mask: causal and within the same episode segment.""" n, t, _ = dones_seq.shape device = dones_seq.device dones = dones_seq.squeeze(-1).to(torch.long) seg = torch.cumsum(dones, dim=1) - dones same = seg[:, :, None] == seg[:, None, :] causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=device)) return same & causal @staticmethod def _resolve_sequence_batch_partition( num_envs: int, num_mini_batches: int, ) -> tuple[int, int]: if num_envs <= 0: raise RuntimeError( "PPOTF sequence batching requires at least one " "environment on each rank." 
) effective_num_mini_batches = max( 1, min(int(num_mini_batches), int(num_envs)) ) mini_batch_envs = max( 1, (num_envs + effective_num_mini_batches - 1) // effective_num_mini_batches, ) return effective_num_mini_batches, mini_batch_envs def _sequence_batches( self, num_mini_batches: int, num_epochs: int ) -> Generator[tuple, None, None]: data = self.storage.data obs_seq = data["obs"].transpose(0, 1) actions_seq = data["actions"].transpose(0, 1) values_seq = data["values"].transpose(0, 1) rewards_seq = data["rewards"].transpose(0, 1) returns_seq = data["returns"].transpose(0, 1) adv_seq = data["advantages"].transpose(0, 1) old_logp_seq = data["actions_log_prob"].transpose(0, 1) old_mu_seq = data["mu"].transpose(0, 1) old_sigma_seq = data["sigma"].transpose(0, 1) dones_seq = data["dones"].transpose(0, 1) gt_base_lin_vel_seq = None gt_root_height_seq = None gt_keybody_contact_seq = None gt_ref_keybody_rel_pos_seq = None gt_robot_keybody_rel_pos_seq = None gt_denoise_ref_root_lin_vel_seq = None gt_denoise_ref_root_ang_vel_seq = None gt_denoise_ref_dof_pos_seq = None if self.use_aux_state_pred: gt_base_lin_vel_seq = data["gt_base_lin_vel_b"].transpose(0, 1) gt_root_height_seq = data["gt_root_height_rel_terrain"].transpose( 0, 1 ) gt_keybody_contact_seq = data["gt_keybody_contacts"].transpose( 0, 1 ) gt_ref_keybody_rel_pos_seq = data[ "gt_ref_keybody_rel_pos" ].transpose(0, 1) gt_robot_keybody_rel_pos_seq = data[ "gt_robot_keybody_rel_pos" ].transpose(0, 1) gt_denoise_ref_root_lin_vel_seq = data[ "gt_denoise_ref_root_lin_vel" ].transpose(0, 1) gt_denoise_ref_root_ang_vel_seq = data[ "gt_denoise_ref_root_ang_vel" ].transpose(0, 1) gt_denoise_ref_dof_pos_seq = data[ "gt_denoise_ref_dof_pos" ].transpose(0, 1) num_envs = int(actions_seq.shape[0]) if num_envs <= 0: raise RuntimeError( "PPOTF sequence batching requires at least one " "environment on each rank, " f"got num_envs={num_envs}." 
) num_mini_batches, mb_env = self._resolve_sequence_batch_partition( num_envs, num_mini_batches ) env_indices = torch.randperm(num_envs, device=self.device) for _ in range(num_epochs): for i in range(num_mini_batches): start = i * mb_env if start >= num_envs: break end = min(num_envs, (i + 1) * mb_env) idx = env_indices[start:end] obs_b = obs_seq[idx] actions_b = actions_seq[idx] values_b = values_seq[idx] rewards_b = rewards_seq[idx] returns_b = returns_seq[idx] adv_b = adv_seq[idx] old_logp_b = old_logp_seq[idx] old_mu_b = old_mu_seq[idx] old_sigma_b = old_sigma_seq[idx] dones_b = dones_seq[idx] gt_base_lin_vel_b = ( gt_base_lin_vel_seq[idx] if gt_base_lin_vel_seq is not None else None ) gt_root_height_b = ( gt_root_height_seq[idx] if gt_root_height_seq is not None else None ) gt_keybody_contact_b = ( gt_keybody_contact_seq[idx] if gt_keybody_contact_seq is not None else None ) gt_ref_keybody_rel_pos_b = ( gt_ref_keybody_rel_pos_seq[idx] if gt_ref_keybody_rel_pos_seq is not None else None ) gt_robot_keybody_rel_pos_b = ( gt_robot_keybody_rel_pos_seq[idx] if gt_robot_keybody_rel_pos_seq is not None else None ) gt_denoise_ref_root_lin_vel_b = ( gt_denoise_ref_root_lin_vel_seq[idx] if gt_denoise_ref_root_lin_vel_seq is not None else None ) gt_denoise_ref_root_ang_vel_b = ( gt_denoise_ref_root_ang_vel_seq[idx] if gt_denoise_ref_root_ang_vel_seq is not None else None ) gt_denoise_ref_dof_pos_b = ( gt_denoise_ref_dof_pos_seq[idx] if gt_denoise_ref_dof_pos_seq is not None else None ) attn_mask = self._build_episode_causal_mask(dones_b) yield ( obs_b, actions_b, values_b, adv_b, returns_b, rewards_b, old_logp_b, old_mu_b, old_sigma_b, attn_mask, gt_base_lin_vel_b, gt_root_height_b, gt_keybody_contact_b, gt_ref_keybody_rel_pos_b, gt_robot_keybody_rel_pos_b, gt_denoise_ref_root_lin_vel_b, gt_denoise_ref_root_ang_vel_b, gt_denoise_ref_dof_pos_b, ) def update(self): actor_unwrapped = self.accelerator.unwrap_model(self.actor) actor_policy = actor_unwrapped.actor_module 
actor_policy.set_collect_routing_stats(False) mean_value_loss = 0.0 mean_surrogate_loss = 0.0 mean_entropy = 0.0 mean_kl_token = 0.0 mean_kl_loss = 0.0 mean_kl_analytic = 0.0 critic_explained_variance = self._compute_explained_variance( target=self.storage.data["returns"], prediction=self.storage.data["values"], ) mean_aux_base_lin_vel_nll = 0.0 mean_aux_root_height_nll = 0.0 mean_aux_base_lin_vel_std = 0.0 mean_aux_root_height_std = 0.0 mean_aux_keybody_contact_bce = 0.0 mean_aux_keybody_contact_acc = 0.0 mean_aux_ref_keybody_rel_pos_mse = 0.0 mean_aux_robot_keybody_rel_pos_mse = 0.0 mean_aux_denoise_ref_root_lin_vel_huber = 0.0 mean_aux_denoise_ref_root_ang_vel_huber = 0.0 mean_aux_denoise_ref_dof_pos_huber = 0.0 mean_aux_router_command_recon_mse = 0.0 mean_aux_router_future_recon_huber = 0.0 mean_aux_router_switch_penalty_js = 0.0 mean_dead_expert_margin_to_topk_loss = 0.0 mean_router_expert_orthogonal_loss = 0.0 mean_selected_expert_margin_to_unselected_loss = 0.0 moe_layers = [ layer for layer in actor_policy.layers if isinstance(layer, GroupedMoEBlock) ] ( effective_num_mini_batches, mini_batch_envs, ) = self._resolve_sequence_batch_partition( self.storage.num_envs, self.num_mini_batches ) self._last_update_metrics = { "0-Train/configured_num_mini_batches": float( self.configured_num_mini_batches ), "0-Train/requested_num_mini_batches": float( self.requested_num_mini_batches ), "0-Train/effective_num_mini_batches": float( effective_num_mini_batches ), "0-Train/mini_batch_size_per_rank": float( mini_batch_envs * self.num_steps_per_env ), "0-Train/mini_batch_num_envs_per_rank": float(mini_batch_envs), "0-Train/num_updates_executed": 0.0, "0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor), "0-Train/scalable_distributed_update": float( self.distributed_update_mode == "scalable" ), "0-Train/kl_windowed": 0.0, "0-Train/kl_stop_triggered": 0.0, "0-Train/kl_stop_analytic": 0.0, "0-Train/kl_analytic_batch_last": 0.0, "0-Train/kl_analytic_batch_max": 
0.0, "0-Train/clip_fraction_batch_mean": 0.0, "0-Train/clip_fraction_batch_last": 0.0, } entropy_coef = self._get_effective_entropy_coef() generator = self._sequence_batches( effective_num_mini_batches, self.num_learning_epochs, ) measure_analytic_kl = self.desired_kl is not None normalize_per_mb = bool(self.normalize_advantage_per_mini_batch) num_updates = 0 num_kl_measurements = 0 kl_stop_triggered = False kl_stop_analytic = 0.0 kl_windowed = None recent_analytic_kls: list[float] = [] kl_analytic_batch_last = 0.0 kl_analytic_batch_max = 0.0 clip_fraction_batch_mean = 0.0 clip_fraction_batch_last = 0.0 for ( obs_b, actions_b, target_values_b, advantages_b, returns_b, _rewards_b, old_logp_b, old_mu_b, old_sigma_b, attn_mask_b, gt_base_lin_vel_b, gt_root_height_b, gt_keybody_contact_b, gt_ref_keybody_rel_pos_b, gt_robot_keybody_rel_pos_b, gt_denoise_ref_root_lin_vel_b, gt_denoise_ref_root_ang_vel_b, gt_denoise_ref_dof_pos_b, ) in generator: valid_tok = attn_mask_b.diagonal(dim1=1, dim2=2).to(torch.float32) valid_count = valid_tok.sum().clamp_min(1.0) if normalize_per_mb: with torch.no_grad(): flat = advantages_b.view(-1).float() if self.global_advantage_norm and self.is_distributed: count = torch.tensor( [flat.numel()], device=self.device, dtype=torch.float32, ) sum_g = self.accelerator.reduce( flat.sum(), reduction="sum" ) sqsum_g = self.accelerator.reduce( (flat * flat).sum(), reduction="sum" ) count_g = self.accelerator.reduce( count, reduction="sum" ) mean = sum_g / count_g var = (sqsum_g / count_g) - mean * mean std = torch.sqrt(var.clamp_min(1.0e-8)) else: mean = flat.mean() std = flat.std().clamp_min(1.0e-8) advantages_b = (advantages_b - mean) / std b, t = int(obs_b.batch_size[0]), int(obs_b.batch_size[1]) critic_obs_flat = obs_b.flatten(0, 1) with self.accelerator.autocast(): actor_out = self.actor( obs_b, actions=actions_b, mode="sequence_logp", attn_mask=attn_mask_b, update_obs_norm=False, ) critic_out = self.critic( critic_obs_flat, update_obs_norm=False 
) logp_new_b = actor_out.get("actions_log_prob") mu_b = actor_out.get("mu") sigma_b = actor_out.get("sigma") entropy_b = actor_out.get("entropy") v_pred_flat = critic_out.get("values") value_batch = v_pred_flat.reshape(b, t, -1) returns_batch_norm = returns_b target_values_batch_norm = target_values_b analytic_kl = None if measure_analytic_kl: analytic_kl = self._compute_analytic_kl( old_mu=old_mu_b.float(), old_sigma=old_sigma_b.float(), new_mu=mu_b.float(), new_sigma=sigma_b.float(), weight=valid_tok, ) mean_kl_analytic += analytic_kl num_kl_measurements += 1 kl_analytic_batch_last = analytic_kl kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl) recent_analytic_kls.append(analytic_kl) if len(recent_analytic_kls) > self.kl_early_stop_window_size: recent_analytic_kls.pop(0) kl_windowed = self._compute_windowed_kl_signal( recent_analytic_kls ) if self._should_early_stop_for_kl( kl_windowed, num_kl_measurements ): kl_stop_triggered = True kl_stop_analytic = analytic_kl break logp_new = logp_new_b.squeeze(-1).float() logp_old = old_logp_b.squeeze(-1).float() ratio = torch.exp(logp_new - logp_old) clip_fraction = self._compute_clip_fraction( ratio, weight=valid_tok ) clip_fraction_batch_mean += clip_fraction clip_fraction_batch_last = clip_fraction adv = advantages_b.squeeze(-1) s1 = ratio * adv s2 = ( torch.clamp( ratio, 1.0 - self.clip_param, 1.0 + self.clip_param ) * adv ) surrogate_loss = ( -torch.min(s1, s2) * valid_tok ).sum() / valid_count if self.use_clipped_value_loss: value_clipped = target_values_batch_norm + ( value_batch - target_values_batch_norm ).clamp(-self.clip_param, self.clip_param) value_losses = (value_batch - returns_batch_norm).pow(2) value_losses_clipped = ( value_clipped - returns_batch_norm ).pow(2) v_max = torch.max(value_losses, value_losses_clipped).squeeze( -1 ) value_loss = (v_max * valid_tok).sum() / valid_count else: v_err = (returns_batch_norm - value_batch).pow(2).squeeze(-1) value_loss = (v_err * valid_tok).sum() / 
valid_count actor_loss = surrogate_loss critic_loss = self.value_loss_coef * value_loss aux_base_lin_vel_loss = None aux_root_height_loss = None aux_base_lin_vel_std = None aux_root_height_std = None aux_keybody_contact_loss = None aux_keybody_contact_acc = None aux_ref_keybody_rel_pos_loss = None aux_robot_keybody_rel_pos_loss = None aux_denoise_ref_root_lin_vel_loss = None aux_denoise_ref_root_ang_vel_loss = None aux_denoise_ref_dof_pos_loss = None aux_router_command_recon_loss = None aux_router_future_recon_loss = None aux_router_switch_penalty_loss = None dead_expert_margin_to_topk_loss = None router_expert_orthogonal_loss = None selected_expert_margin_to_unselected_loss = None if self.use_aux_state_pred: aux_base_lin_vel_loc = actor_out.get("aux_base_lin_vel_loc") aux_base_lin_vel_log_std = actor_out.get( "aux_base_lin_vel_log_std" ) aux_base_lin_vel_std = torch.clamp( torch.exp(aux_base_lin_vel_log_std), min=self.aux_state_pred_min_std, max=self.aux_state_pred_max_std, ) aux_base_lin_vel_nll = 0.5 * ( torch.square( (gt_base_lin_vel_b - aux_base_lin_vel_loc) / aux_base_lin_vel_std ) + 2.0 * torch.log(aux_base_lin_vel_std + 1.0e-8) ).sum(dim=-1) aux_base_lin_vel_loss = ( aux_base_lin_vel_nll * valid_tok ).sum() / valid_count actor_loss = ( actor_loss + self.aux_state_pred_w_base_lin_vel * aux_base_lin_vel_loss ) aux_root_height_loc = actor_out.get("aux_root_height_loc") aux_root_height_log_std = actor_out.get( "aux_root_height_log_std" ) if self.use_aux_root_height and gt_root_height_b is not None: aux_root_height_std = torch.clamp( torch.exp(aux_root_height_log_std), min=self.aux_state_pred_min_std, max=self.aux_state_pred_max_std, ) aux_root_height_nll = 0.5 * ( torch.square( (gt_root_height_b - aux_root_height_loc) / aux_root_height_std ) + 2.0 * torch.log(aux_root_height_std + 1.0e-8) ).sum(dim=-1) aux_root_height_loss = ( aux_root_height_nll * valid_tok ).sum() / valid_count actor_loss = ( actor_loss + self.aux_state_pred_w_root_height * 
aux_root_height_loss ) else: actor_loss = actor_loss + 0.0 * ( aux_root_height_loc.sum() + aux_root_height_log_std.sum() ) if ( self.aux_state_pred_num_contact_bodies > 0 and gt_keybody_contact_b is not None ): aux_keybody_contact_logits = actor_out.get( "aux_keybody_contact_logits" ) contact_bce = F.binary_cross_entropy_with_logits( aux_keybody_contact_logits, gt_keybody_contact_b, reduction="none", ).mean(dim=-1) aux_keybody_contact_loss = ( contact_bce * valid_tok ).sum() / valid_count actor_loss = ( actor_loss + self.aux_state_pred_w_keybody_contact * aux_keybody_contact_loss ) contact_pred = (aux_keybody_contact_logits > 0.0).to( gt_keybody_contact_b.dtype ) contact_acc_tok = ( (contact_pred == gt_keybody_contact_b) .to(torch.float32) .mean(dim=-1) ) aux_keybody_contact_acc = ( contact_acc_tok * valid_tok ).sum() / valid_count aux_ref_keybody_rel_pos = actor_out.get( "aux_ref_keybody_rel_pos" ) aux_robot_keybody_rel_pos = actor_out.get( "aux_robot_keybody_rel_pos" ) if ( self.aux_state_pred_num_keybody_bodies > 0 and gt_ref_keybody_rel_pos_b is not None ): aux_ref_keybody_rel_pos_loss = ( self._masked_aux_keybody_mse( aux_ref_keybody_rel_pos, gt_ref_keybody_rel_pos_b, valid_tok, ) ) actor_loss = ( actor_loss + self.aux_state_pred_w_ref_keybody_rel_pos * aux_ref_keybody_rel_pos_loss ) elif aux_ref_keybody_rel_pos.numel() > 0: actor_loss = ( actor_loss + 0.0 * aux_ref_keybody_rel_pos.sum() ) if ( self.aux_state_pred_num_keybody_bodies > 0 and gt_robot_keybody_rel_pos_b is not None ): aux_robot_keybody_rel_pos_loss = ( self._masked_aux_keybody_mse( aux_robot_keybody_rel_pos, gt_robot_keybody_rel_pos_b, valid_tok, ) ) actor_loss = ( actor_loss + self.aux_state_pred_w_robot_keybody_rel_pos * aux_robot_keybody_rel_pos_loss ) elif aux_robot_keybody_rel_pos.numel() > 0: actor_loss = ( actor_loss + 0.0 * aux_robot_keybody_rel_pos.sum() ) if self.use_aux_denoise_ref_root_lin_vel: aux_denoise_ref_root_lin_vel_residual = actor_out.get( 
"aux_denoise_ref_root_lin_vel_residual" ) aux_denoise_ref_root_lin_vel_loss = self._masked_aux_huber( pred=aux_denoise_ref_root_lin_vel_residual, target=gt_denoise_ref_root_lin_vel_b, valid_tok=valid_tok, beta=self.aux_denoise_residual_huber_beta, ) actor_loss = ( actor_loss + self.aux_state_pred_w_denoise_ref_root_lin_vel * aux_denoise_ref_root_lin_vel_loss ) if self.use_aux_denoise_ref_root_ang_vel: aux_denoise_ref_root_ang_vel_residual = actor_out.get( "aux_denoise_ref_root_ang_vel_residual" ) aux_denoise_ref_root_ang_vel_loss = self._masked_aux_huber( pred=aux_denoise_ref_root_ang_vel_residual, target=gt_denoise_ref_root_ang_vel_b, valid_tok=valid_tok, beta=self.aux_denoise_residual_huber_beta, ) actor_loss = ( actor_loss + self.aux_state_pred_w_denoise_ref_root_ang_vel * aux_denoise_ref_root_ang_vel_loss ) if self.use_aux_denoise_ref_dof_pos: aux_denoise_ref_dof_pos_residual = actor_out.get( "aux_denoise_ref_dof_pos_residual" ) aux_denoise_ref_dof_pos_loss = self._masked_aux_huber( pred=aux_denoise_ref_dof_pos_residual, target=gt_denoise_ref_dof_pos_b, valid_tok=valid_tok, beta=self.aux_denoise_residual_huber_beta, ) actor_loss = ( actor_loss + self.aux_state_pred_w_denoise_ref_dof_pos * aux_denoise_ref_dof_pos_loss ) if self.use_aux_router_command_recon: if self.aux_router_command_recon_assembler is None: raise ValueError( "aux_router_command_recon is enabled but command " "assembler was not initialized." 
) aux_router_command_recon_pred = actor_out.get( "aux_router_command_recon" ) gt_aux_router_command_recon_b = ( self.aux_router_command_recon_assembler( obs_b.flatten(0, 1) ).reshape(b, t, -1) ) aux_router_command_recon_loss = self._masked_aux_mse( aux_router_command_recon_pred, gt_aux_router_command_recon_b, valid_tok, ) actor_loss = ( actor_loss + self.aux_router_command_recon_weight * aux_router_command_recon_loss ) if self.use_aux_router_future_recon: aux_router_future_recon_loss = ( self._compute_aux_router_future_recon_loss( actor_wrapper=actor_unwrapped, actor_out=actor_out, obs_b=obs_b, valid_tok=valid_tok, ) ) actor_loss = ( actor_loss + self.aux_router_future_recon_weight * aux_router_future_recon_loss ) if self.use_aux_router_switch_penalty: if self.aux_router_switch_penalty_metric == "js": aux_router_features = actor_out.get("router_features") aux_router_switch_penalty_loss = self._masked_adjacent_router_js( router_features=aux_router_features, valid_tok=valid_tok, num_moe_layers=self.aux_command_router_num_moe_layers, num_fine_experts=self.aux_command_router_num_fine_experts, ) else: aux_router_temporal_features = actor_out.get( "router_temporal_features" ) aux_router_switch_penalty_loss = self._masked_adjacent_router_normed_smooth_l1( router_temporal_features=aux_router_temporal_features, valid_tok=valid_tok, num_moe_layers=self.aux_command_router_num_moe_layers, num_fine_experts=self.aux_command_router_num_fine_experts, beta=self.aux_router_switch_penalty_beta, ) aux_router_switch_penalty_loss = ( aux_router_switch_penalty_loss.to(actor_loss.dtype) ) actor_loss = ( actor_loss + self.aux_router_switch_penalty_weight * aux_router_switch_penalty_loss ) if self.use_dead_expert_margin_to_topk and len(moe_layers) > 0: margin_losses = [ layer.last_dead_expert_margin_to_topk_loss for layer in moe_layers if layer.last_dead_expert_margin_to_topk_loss is not None ] if len(margin_losses) > 0: dead_expert_margin_to_topk_loss = torch.stack( [ 
loss.to(actor_loss.device, dtype=actor_loss.dtype) for loss in margin_losses ] ).mean() actor_loss = ( actor_loss + self.dead_expert_margin_to_topk_weight * dead_expert_margin_to_topk_loss ) if self.use_router_expert_orthogonal and len(moe_layers) > 0: orth_losses = [] for layer in moe_layers: layer_orth_loss, _, _ = ( self._compute_routed_expert_orthogonal_loss( layer, dtype=actor_loss.dtype, device=actor_loss.device, ) ) orth_losses.append(layer_orth_loss) if len(orth_losses) > 0: router_expert_orthogonal_loss = torch.stack( orth_losses ).mean() actor_loss = ( actor_loss + self.router_expert_orthogonal_weight * router_expert_orthogonal_loss ) if ( self.use_selected_expert_margin_to_unselected and len(moe_layers) > 0 ): selected_margin_losses = [ layer.last_selected_expert_margin_to_unselected_loss for layer in moe_layers if layer.last_selected_expert_margin_to_unselected_loss is not None ] if len(selected_margin_losses) > 0: selected_expert_margin_to_unselected_loss = torch.stack( [ loss.to(actor_loss.device, dtype=actor_loss.dtype) for loss in selected_margin_losses ] ).mean() actor_loss = ( actor_loss + self.selected_expert_margin_to_unselected_weight * selected_expert_margin_to_unselected_loss ) kl_coef = float( getattr(self.config, "kl_coef", self.desired_kl or 0.0) or 0.0 ) if kl_coef > 0.0: delta_logp = logp_new - logp_old kl_token = ( ratio.detach() * delta_logp * valid_tok ).sum() / valid_count kl_loss = kl_coef * kl_token actor_loss = actor_loss + kl_loss mean_kl_token += float(kl_token.item()) mean_kl_loss += float(kl_loss.item()) if entropy_coef > 0.0: ent_tok = entropy_b.squeeze(-1) entropy_loss = (ent_tok * valid_tok).sum() / valid_count actor_loss = actor_loss - entropy_coef * entropy_loss self.actor_optimizer.zero_grad() self.critic_optimizer.zero_grad() self.accelerator.backward(actor_loss) self.accelerator.backward(critic_loss) if self.max_grad_norm is not None: self.accelerator.clip_grad_norm_( self.actor.parameters(), self.max_grad_norm ) 
self.accelerator.clip_grad_norm_( self.critic.parameters(), self.max_grad_norm ) self.actor_optimizer.step() self.critic_optimizer.step() num_updates += 1 mean_value_loss += float(value_loss.item()) mean_surrogate_loss += float(surrogate_loss.item()) mean_entropy += float(entropy_b.mean().item()) if aux_base_lin_vel_loss is not None: mean_aux_base_lin_vel_nll += float( aux_base_lin_vel_loss.item() ) if aux_root_height_loss is not None: mean_aux_root_height_nll += float(aux_root_height_loss.item()) if aux_base_lin_vel_std is not None: mean_aux_base_lin_vel_std += float( aux_base_lin_vel_std.mean().item() ) if aux_root_height_std is not None: mean_aux_root_height_std += float( aux_root_height_std.mean().item() ) if aux_keybody_contact_loss is not None: mean_aux_keybody_contact_bce += float( aux_keybody_contact_loss.item() ) if aux_keybody_contact_acc is not None: mean_aux_keybody_contact_acc += float( aux_keybody_contact_acc.item() ) if aux_ref_keybody_rel_pos_loss is not None: mean_aux_ref_keybody_rel_pos_mse += float( aux_ref_keybody_rel_pos_loss.item() ) if aux_robot_keybody_rel_pos_loss is not None: mean_aux_robot_keybody_rel_pos_mse += float( aux_robot_keybody_rel_pos_loss.item() ) if aux_denoise_ref_root_lin_vel_loss is not None: mean_aux_denoise_ref_root_lin_vel_huber += float( aux_denoise_ref_root_lin_vel_loss.item() ) if aux_denoise_ref_root_ang_vel_loss is not None: mean_aux_denoise_ref_root_ang_vel_huber += float( aux_denoise_ref_root_ang_vel_loss.item() ) if aux_denoise_ref_dof_pos_loss is not None: mean_aux_denoise_ref_dof_pos_huber += float( aux_denoise_ref_dof_pos_loss.item() ) if aux_router_command_recon_loss is not None: mean_aux_router_command_recon_mse += float( aux_router_command_recon_loss.item() ) if aux_router_future_recon_loss is not None: mean_aux_router_future_recon_huber += float( aux_router_future_recon_loss.item() ) if aux_router_switch_penalty_loss is not None: mean_aux_router_switch_penalty_js += float( 
aux_router_switch_penalty_loss.item() ) if dead_expert_margin_to_topk_loss is not None: mean_dead_expert_margin_to_topk_loss += float( dead_expert_margin_to_topk_loss.item() ) if router_expert_orthogonal_loss is not None: mean_router_expert_orthogonal_loss += float( router_expert_orthogonal_loss.item() ) if selected_expert_margin_to_unselected_loss is not None: mean_selected_expert_margin_to_unselected_loss += float( selected_expert_margin_to_unselected_loss.item() ) actor_policy.apply_dynamic_bias_update_from_stats() denom = max(1, num_updates) mean_value_loss /= denom mean_surrogate_loss /= denom mean_entropy /= denom mean_kl_token /= denom mean_kl_loss /= denom mean_kl_analytic /= max(1, num_kl_measurements) clip_fraction_batch_mean /= denom if self.schedule == "adaptive": self._apply_adaptive_lr(kl_windowed) mean_aux_base_lin_vel_nll /= denom mean_aux_root_height_nll /= denom mean_aux_base_lin_vel_std /= denom mean_aux_root_height_std /= denom mean_aux_keybody_contact_bce /= denom mean_aux_keybody_contact_acc /= denom mean_aux_ref_keybody_rel_pos_mse /= denom mean_aux_robot_keybody_rel_pos_mse /= denom mean_aux_denoise_ref_root_lin_vel_huber /= denom mean_aux_denoise_ref_root_ang_vel_huber /= denom mean_aux_denoise_ref_dof_pos_huber /= denom mean_aux_router_command_recon_mse /= denom mean_aux_router_future_recon_huber /= denom mean_aux_router_switch_penalty_js /= denom mean_dead_expert_margin_to_topk_loss /= denom mean_router_expert_orthogonal_loss /= denom mean_selected_expert_margin_to_unselected_loss /= denom self._last_update_metrics["0-Train/num_updates_executed"] = float( num_updates ) self._last_update_metrics["0-Train/kl_windowed"] = float( kl_windowed or 0.0 ) self._last_update_metrics["0-Train/kl_stop_triggered"] = float( kl_stop_triggered ) self._last_update_metrics["0-Train/kl_stop_analytic"] = float( kl_stop_analytic ) self._last_update_metrics["0-Train/kl_analytic_batch_last"] = float( kl_analytic_batch_last ) 
self._last_update_metrics["0-Train/kl_analytic_batch_max"] = float( kl_analytic_batch_max ) self._last_update_metrics["0-Train/clip_fraction_batch_mean"] = float( clip_fraction_batch_mean ) self._last_update_metrics["0-Train/clip_fraction_batch_last"] = float( clip_fraction_batch_last ) moe_layers = [ layer for layer in actor_unwrapped.actor_module.layers if isinstance(layer, GroupedMoEBlock) ] moe_active_expert_ratio = None moe_max_expert_frac = None moe_least_expert_frac = None moe_dead_expert_ratio = None moe_expert_count_cv = None moe_selected_expert_margin_to_unselected = None moe_last_router_js_step = None moe_last_router_top1_switch_rate = None if len(moe_layers) > 0: moe_metrics = self._summarize_moe_layer_stats(moe_layers) moe_active_expert_ratio = moe_metrics["moe_active_expert_ratio"] moe_max_expert_frac = moe_metrics["moe_max_expert_frac"] moe_least_expert_frac = moe_metrics["moe_least_expert_frac"] moe_dead_expert_ratio = moe_metrics["moe_dead_expert_ratio"] moe_expert_count_cv = moe_metrics["moe_expert_count_cv"] moe_selected_expert_margin_to_unselected = moe_metrics[ "moe_selected_expert_margin_to_unselected" ] router_shift_stats = actor_policy.get_last_moe_router_shift_stats() js_sum = router_shift_stats["js_sum"] js_count = router_shift_stats["js_count"] top1_switch_sum = router_shift_stats["top1_switch_sum"] top1_switch_count = router_shift_stats["top1_switch_count"] if ( js_sum is not None and js_count is not None and top1_switch_sum is not None and top1_switch_count is not None ): js_sum = js_sum.detach().to(self.device, dtype=torch.float32) js_count = js_count.detach().to( self.device, dtype=torch.float32 ) top1_switch_sum = top1_switch_sum.detach().to( self.device, dtype=torch.float32 ) top1_switch_count = top1_switch_count.detach().to( self.device, dtype=torch.float32 ) if self.is_distributed: js_sum = self.accelerator.reduce(js_sum, reduction="sum") js_count = self.accelerator.reduce( js_count, reduction="sum" ) top1_switch_sum = 
self.accelerator.reduce( top1_switch_sum, reduction="sum" ) top1_switch_count = self.accelerator.reduce( top1_switch_count, reduction="sum" ) if float(js_count.item()) > 0.0: moe_last_router_js_step = float((js_sum / js_count).item()) if float(top1_switch_count.item()) > 0.0: moe_last_router_top1_switch_rate = float( (top1_switch_sum / top1_switch_count).item() ) self.storage.clear() loss_out = { "value_function": mean_value_loss, "critic_explained_variance": critic_explained_variance, "surrogate": mean_surrogate_loss, "entropy": mean_entropy, "kl_token": mean_kl_token, "kl_loss": mean_kl_loss, "kl_analytic": mean_kl_analytic, "aux_base_lin_vel_nll": mean_aux_base_lin_vel_nll, "aux_root_height_nll": mean_aux_root_height_nll, "aux_base_lin_vel_std": mean_aux_base_lin_vel_std, "aux_root_height_std": mean_aux_root_height_std, "aux_keybody_contact_bce": mean_aux_keybody_contact_bce, "aux_keybody_contact_acc": mean_aux_keybody_contact_acc, "aux_ref_keybody_rel_pos_mse": mean_aux_ref_keybody_rel_pos_mse, "aux_robot_keybody_rel_pos_mse": ( mean_aux_robot_keybody_rel_pos_mse ), "aux_denoise_ref_root_lin_vel_huber": ( mean_aux_denoise_ref_root_lin_vel_huber ), "aux_denoise_ref_root_ang_vel_huber": ( mean_aux_denoise_ref_root_ang_vel_huber ), "aux_denoise_ref_dof_pos_huber": ( mean_aux_denoise_ref_dof_pos_huber ), "aux_router_command_recon_mse": mean_aux_router_command_recon_mse, "aux_router_future_recon_huber": ( mean_aux_router_future_recon_huber ), "aux_router_switch_penalty_js": ( mean_aux_router_switch_penalty_js ), "dead_expert_margin_to_topk": ( mean_dead_expert_margin_to_topk_loss ), "router_expert_orthogonal": mean_router_expert_orthogonal_loss, "selected_expert_margin_to_unselected": ( mean_selected_expert_margin_to_unselected_loss ), "moe_active_expert_ratio": moe_active_expert_ratio, "moe_max_expert_frac": moe_max_expert_frac, "moe_least_expert_frac": moe_least_expert_frac, "moe_dead_expert_ratio": moe_dead_expert_ratio, "moe_expert_count_cv": 
moe_expert_count_cv, "moe_selected_expert_margin_to_unselected": ( moe_selected_expert_margin_to_unselected ), "moe_last_router_js_step": moe_last_router_js_step, "moe_last_router_top1_switch_rate": ( moe_last_router_top1_switch_rate ), } if self.is_distributed: reduced_out = {} for k, v in loss_out.items(): if v is None: reduced_out[k] = None continue t = torch.tensor(v, device=self.device, dtype=torch.float32) reduced_t = self.accelerator.reduce(t, reduction="mean") reduced_out[k] = float(reduced_t.item()) loss_out = reduced_out self._post_update_hook(loss_out) return loss_out ================================================ FILE: holomotion/src/data_curation/.gitignore ================================================ _generated/ ================================================ FILE: holomotion/src/data_curation/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/data_curation/data_smplify.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # import argparse import os from smplify.smplify_humanact12 import humanact12_to_amass from smplify.smplify_motionx import motionx_to_amass from smplify.smplify_omomo import omomo_to_amass from holomotion.holomotion.src.data_curation.smplify.smplify_zjumocap import ( zju_to_amass, ) def ensure_dir(path): """Make sure the dir exist. Args: path: The path of the dir. """ if not os.path.exists(path): os.makedirs(path) def main(): """Convert multiple motion capture datasets to AMASS format. This function parses command-line arguments to specify the root directory of raw datasets and an optional save directory. It iterates over the supported datasets (MotionX, ZJU_Mocap, HumanAct12, OMOMO), and if the corresponding data directory exists, converts it to AMASS format and saves it in a unified directory structure. Raises: SystemExit: If required command-line arguments are missing. Side Effects: Creates output directories and writes converted data files. Prints progress and warning messages to stdout. 
""" parser = argparse.ArgumentParser( description="Convert all datasets to AMASS format" ) parser.add_argument( "--data_root", type=str, required=True, help="Path to the root directory of raw datasets", ) parser.add_argument( "--save_root", type=str, default=None, help="Path to save the unified data (default: data_root/smplx_data)", ) args = parser.parse_args() data_root = os.path.abspath(args.data_root) save_root = args.save_root or "./data/amass_compatible_datasets" print(f"Raw data root: {data_root}") print(f"Unified data will be saved to: {save_root}") ensure_dir(save_root) datasets = [ ("MotionX", motionx_to_amass), ("ZJU_Mocap", zju_to_amass), ("humanact12", humanact12_to_amass), ("OMOMO", omomo_to_amass), ] for name, func in datasets: data_dir = os.path.join(data_root, name) save_dir = os.path.join(save_root, name) ensure_dir(save_dir) if not os.path.exists(data_dir): print(f"Warning: {data_dir} does not exist. Skipping {name}.") continue print(f"Processing {name}...") func(data_dir, save_dir) print(f"{name} done. Saved to {save_dir}.\n") print("All datasets processed.") if __name__ == "__main__": main() ================================================ FILE: holomotion/src/data_curation/filter/filter.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import argparse import json import os import numpy as np def checksitpose( npz_path, ref_pose_path, threshold=0.75, frame_thresh=1 ) -> bool: """Check if the given motion sequence is close to the pose. Args: npz_path (str): Path to the .npz file of the motion sequence. ref_pose_path (str): the reference sitting pose. threshold (float, optional): Euclidean distance threshold. frame_thresh (int, optional): Minimum number of frames. Returns: bool: True if the sequence contains sitting-like frames. """ count = 0 try: sitdata = np.load(ref_pose_path) sitpose = sitdata["poses"][535][:66] # reference sitting pose except Exception: return False sitpose_down = sitpose[3:36] # lower-body joints only bdata = np.load(npz_path) curposes = bdata["poses"] # shape: (N, 165) for pose in curposes: pose_down = pose[3:36] dist = np.linalg.norm(pose_down - sitpose_down) if dist < threshold: count += 1 if count >= frame_thresh: return True return False def process_dataset( parent_folder, json_path, output_path, abnormal_path, sit_pose_reference, stair_keywords=None, sit_keywords=None, sit_threshold=0.75, frame_threshold=20, velocity_threshold=100.0, ): """Label the dataset under parent folder.""" stair_keywords = stair_keywords or [ "stairs", "staircase", "upstairs", "downstairs", ] sit_keywords = sit_keywords or ["sitting", "Sitting"] abnormal_dataset = ["aist"] stairs = sit = untrack = 0 filtered_paths = set() with ( open(json_path) as f_in, open(output_path, "w") as f_out_normal, open(abnormal_path, "w") as f_out_abnormal, ): for line in f_in: line = line.strip() if not line: continue try: content = json.loads(line) path = content.get("path", "") npz_path = os.path.join(parent_folder, path) # skip the path if it is abnormal if path in filtered_paths: f_out_abnormal.write(line + "\n") continue up_z = content.get("max_up_z_velocity", 0) down_z = content.get("max_down_z_velocity", 0) max_z = content.get("max_z_translation", 0) min_z = content.get("min_z_translation", 0) mean_v = 
content.get("mean_velocity", 0) # filter by keywords if any(kw in path for kw in stair_keywords): f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) stairs += 1 continue elif any(kw in path for kw in sit_keywords): f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) sit += 1 continue elif any(kw in path for kw in abnormal_dataset): f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) continue if mean_v > velocity_threshold: f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) untrack += 1 continue if up_z >= 0.6 and max_z > 0.7: f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) stairs += 1 continue elif down_z <= -0.7 and min_z < -0.7: f_out_abnormal.write(line + "\n") filtered_paths.clear() filtered_paths.add(path) stairs += 1 continue if checksitpose( npz_path, sit_pose_reference, sit_threshold, frame_threshold, ): f_out_abnormal.write(line + "\n") filtered_paths.add(path) sit += 1 continue # normal motion f_out_normal.write(line + "\n") except Exception as e: print(f"Error processing line: {line}\nException: {e}") print( f"total abnormal data:upstairs {stairs}, sitting {sit}, \ velocity {untrack}" ) def jsonl_to_yaml(jsonl_path, yaml_output_path): """Convert jsonl file into yaml file.""" output_set = set() with open(jsonl_path) as f: for line in f: if not line.strip(): continue try: data = json.loads(line) path = data.get("path", "") if path: clean_path = os.path.splitext(path.strip().lstrip("/"))[0] new_name = "0-" + clean_path.replace("/", "_").replace( "\\", "_" ) output_set.add(f"{new_name}") except json.JSONDecodeError: print(f"skip json line: {line.strip()}") continue with open(yaml_output_path, "w") as out: out.write("[\n") for item in sorted(output_set): out.write(f" {item},\n") out.write("]\n") print(f"done, total {len(output_set)} items -> {yaml_output_path}") if __name__ == "__main__": parser = 
argparse.ArgumentParser( description="Filter AMASS dataset and save results." ) parser.add_argument( "--parent_folder", type=str, default="./data/amass_compatible_datasets", help="Path to the parent folder of AMASS data", ) parser.add_argument( "--json_path", type=str, default="./data/dataset_labels/OMOMO.jsonl", help="Path to the input JSONL file", ) parser.add_argument( "--output_path", type=str, default="./data/dataset_labels/temp.jsonl", help="Path to save the filtered output JSONL", ) parser.add_argument( "--abnormal_path", type=str, default="./data/dataset_labels/temp2.jsonl", help="Path to save abnormal data JSONL", ) parser.add_argument( "--sit_pose_reference", type=str, default="./data/amass_compatible_datasets/amass/BioMotionLab_NTroje/rub062/0016_sitting2_poses.npz", help="Path to the reference sitting pose npz", ) parser.add_argument( "--yaml_path", type=str, default="./holomotion/config/data_curation/base.yaml", help="Path to the excluded yaml file", ) args = parser.parse_args() process_dataset( parent_folder=args.parent_folder, json_path=args.json_path, output_path=args.output_path, abnormal_path=args.abnormal_path, sit_pose_reference=args.sit_pose_reference, ) os.makedirs(os.path.dirname(args.yaml_path), exist_ok=True) jsonl_to_yaml(args.abnormal_path, args.yaml_path) ================================================ FILE: holomotion/src/data_curation/filter/label_data.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. 
See the License for the specific language governing # permissions and limitations under the License. # import argparse import json import os import sys import numpy as np sys.path.append( "./holomotion/src/data_curation/omomo_release/human_body_prior/src" ) def calc_max_xy_translation(motion_data: dict): """Calculate max xy translation.""" trans = motion_data["trans"] root_trans_offset = trans max_xy_translation = np.max( np.linalg.norm( root_trans_offset[:, :2] - root_trans_offset[0:1, :2], axis=1, ) ) return max_xy_translation def calc_max_z_translation(motion_data: dict): """Calculate max and min z translation.""" trans = motion_data["trans"] root_trans_offset = trans max_z_translation = np.max( root_trans_offset[:, 2] - root_trans_offset[0:1, 2] ) min_z_translation = np.min( root_trans_offset[:, 2] - root_trans_offset[0:1, 2] ) return max_z_translation, min_z_translation def calc_max_velocity_scale(motion_data: dict, fps: float = 30): """Calculate max velocity scale.""" root_trans_offset = motion_data["trans"] est_root_vel = np.diff(root_trans_offset * fps, axis=0) root_vel_norm = np.linalg.norm(est_root_vel, axis=-1) max_velocity_scale = np.max(root_vel_norm) return max_velocity_scale def calc_mean_velocity_scale(motion_data: dict, fps: float = 30): """Calculate mean velocity scale.""" root_trans_offset = motion_data["trans"] est_root_vel = np.diff(root_trans_offset * fps, axis=0) root_vel_norm = np.linalg.norm(est_root_vel, axis=-1) mean_velocity_scale = np.mean(root_vel_norm) return mean_velocity_scale def calc_std_velocity_scale(motion_data: dict, fps: float = 30): """Calculate std velocity scale.""" root_trans_offset = motion_data["trans"] est_root_vel = np.diff(root_trans_offset * fps, axis=0) root_vel_norm = np.linalg.norm(est_root_vel, axis=-1) std_velocity_scale = np.std(root_vel_norm) return std_velocity_scale def calc_max_vxy_scale(motion_data: dict, fps: float = 30): """Calculate smax vx, vy scale.""" root_trans_offset = motion_data["trans"] 
est_root_vel = np.diff(root_trans_offset * fps, axis=0) root_vel_norm = np.linalg.norm(est_root_vel[:, :2], axis=-1) max_vxy_scale = np.max(root_vel_norm) mean_vxy_scale = np.mean(root_vel_norm) std_vxy_scale = np.std(root_vel_norm) return max_vxy_scale, mean_vxy_scale, std_vxy_scale def calc_std_accel(motion_data: dict, fps: float = 30.0) -> float: """Calculate the standard deviation of root joint acceleration. This function computes the per-frame acceleration of the root joint in the XY plane from its translation data and returns the standard deviation of those values. Args: motion_data (dict): A dictionary that must contain a 'trans' key representing global translation of the root joint. Shape should be (T, 3), where T is the number of frames. fps (float): Frames per second of the motion sequence. Returns: float: Standard deviation of the acceleration magnitudes on the XY plane. Returns 0.0 if there are fewer than 3 frames. """ trans = motion_data["trans"] # shape: (T, 3) if trans.shape[0] < 3: return 0.0 # At least 3 frames are needed to compute two differences # Compute velocity (frame-to-frame displacement * fps) velocities = np.diff(trans, axis=0) * fps # shape: (T-1, 3) # Compute acceleration (frame-to-frame velocity difference * fps) accelerations = np.diff(velocities, axis=0) * fps # shape: (T-2, 3) # Compute acceleration magnitude in XY plane accel_xy_norm = np.linalg.norm( accelerations[:, :2], axis=1 ) # shape: (T-2,) # Return standard deviation return np.std(accel_xy_norm) def calc_max_vz_scale(motion_data: dict, fps: float = 30): """Calculate max vz scale.""" root_trans_offset = motion_data["trans"] est_root_vel = np.diff(root_trans_offset * fps, axis=0) root_vel_norm = np.abs(est_root_vel[:, 2]) max_vz_scale = np.max(root_vel_norm) mean_vz_scale = np.mean(root_vel_norm) std_vz_scale = np.std(root_vel_norm) return max_vz_scale, mean_vz_scale, std_vz_scale def calc_vz_scale_with_direction(motion_data: dict, fps: float = 30): """Calculate vz scale with 
direction.""" root_trans_offset = motion_data["trans"] est_root_vel = np.diff(root_trans_offset * fps, axis=0) vz = est_root_vel[:, 2] max_up_vz = np.max(vz[vz > 0]) if np.any(vz > 0) else 0.0 max_down_vz = np.min(vz[vz < 0]) if np.any(vz < 0) else 0.0 mean_vz = np.mean(vz) std_vz = np.std(vz) return max_up_vz, max_down_vz, mean_vz, std_vz def beyond_upper_dof_limits( motion_data: dict, upper_dof_mapping: dict, upper_dof_max_limits: dict, ): """Check whether or not the motion data is beyond upper dof limits.""" for dof_name, dof_idx in upper_dof_mapping.items(): dof_data = motion_data["dof"][:, dof_idx] max_dof_scale = np.max(dof_data) min_dof_scale = np.min(dof_data) if ( max_dof_scale < upper_dof_max_limits[dof_name][0] or max_dof_scale > upper_dof_max_limits[dof_name][1] or min_dof_scale < upper_dof_max_limits[dof_name][0] or min_dof_scale > upper_dof_max_limits[dof_name][1] ): return True return False class HyperParams: max_xy_translation: float = 2.0 max_z_translation: float = 0.3 max_velocity_scale: float = 1.0 max_vxy_scale: float = 1.2 max_vz_scale: float = 0.3 upper_dof_mapping: dict = { "left_shoulder_pitch_joint": 13, "left_shoulder_roll_joint": 14, "left_shoulder_yaw_joint": 15, "left_elbow_joint": 16, "right_shoulder_pitch_joint": 17, "right_shoulder_roll_joint": 18, "right_shoulder_yaw_joint": 19, "right_elbow_joint": 20, } upper_dof_max_limits: dict = { "left_shoulder_pitch_joint": [-1.0, 1.0], "left_shoulder_roll_joint": [0.0, 0.5], "left_shoulder_yaw_joint": [-0.5, 0.5], "left_elbow_joint": [0.5, 1.3], "right_shoulder_pitch_joint": [-1.0, 1.0], "right_shoulder_roll_joint": [-0.5, 0.0], "right_shoulder_yaw_joint": [-0.5, 0.3], "right_elbow_joint": [0.5, 1.5], } def label_data_with_metrics(data_folder, jsonl_path: str, parent_folder: str): """Calculate the metics and load them into a jsonl file.""" assert jsonl_path.endswith(".jsonl") with open(jsonl_path, "w") as f_out: for root, _, files in os.walk(data_folder): for file in files: if 
file.endswith(".npz"): npz_path = os.path.join(root, file) content = {} content["path"] = os.path.relpath(npz_path, parent_folder) data = np.load(npz_path) fps = data.get("mocap_frame_rate") if fps is None: fps = data.get("mocap_framerate") if fps is None: continue try: content["max_xy_translation"] = round( calc_max_xy_translation(data), 2 ) max_z_translation, min_z_translation = ( calc_max_z_translation(data) ) content["max_z_translation"] = round( max_z_translation, 2 ) content["min_z_translation"] = round( min_z_translation, 2 ) content["max_velocity"] = round( calc_max_velocity_scale(data, fps), 2 ) content["mean_velocity"] = round( calc_mean_velocity_scale(data, fps), 2 ) content["std_velocity"] = round( calc_std_velocity_scale(data, fps), 2 ) content["std_accel"] = round( calc_std_accel(data, fps), 2 ) max_xy_v, mean_xy_v, std_xy_v = calc_max_vxy_scale( data, fps ) content["max_xy_velocity"] = round(max_xy_v, 2) content["mean_xy_velocity"] = round(mean_xy_v, 2) content["std_xy_velocity"] = round(std_xy_v, 2) max_up_z_v, max_down_z_v, mean_z_v, std_z_v = ( calc_vz_scale_with_direction(data, fps) ) content["max_up_z_velocity"] = round(max_up_z_v, 2) content["max_down_z_velocity"] = round(max_down_z_v, 2) content["mean_z_velocity"] = round(mean_z_v, 2) content["std_z_velocity"] = round(std_z_v, 2) except Exception as e: print(f"Error: {e}") def convert_to_builtin_type(obj): if isinstance(obj, dict): return { k: convert_to_builtin_type(v) for k, v in obj.items() } elif isinstance(obj, list): return [convert_to_builtin_type(i) for i in obj] elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, (np.float32, np.float64)): return float(obj) elif isinstance(obj, (np.int32, np.int64)): return int(obj) else: return obj f_out.write( json.dumps(convert_to_builtin_type(content)) + "\n" ) print(f"Annotated file saved to: {jsonl_path}") def main(): parser = argparse.ArgumentParser() parser.add_argument( "--jsonl_list", nargs="+", default=["humanact12", 
"MotionX", "OMOMO", "ZJU_Mocap", "amass"], help="List of jsonl files to process.", ) args = parser.parse_args() amass_folder = "./data/amass_compatible_datasets/amass" other_folder = "./data/amass_compatible_datasets" caption_folder = "./data/dataset_labels" os.makedirs(caption_folder, exist_ok=True) for name in args.jsonl_list: file = name + ".jsonl" if name != "amass": label_data_with_metrics( os.path.join(other_folder, name), os.path.join(caption_folder, file), other_folder, ) else: label_data_with_metrics( amass_folder, os.path.join(caption_folder, file), amass_folder ) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/data_curation/smpl_npz_to_html.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import argparse import json from dataclasses import dataclass from pathlib import Path from typing import Tuple import numpy as np from scipy.spatial.transform import Rotation as R # ----------------------------- # Defaults # ----------------------------- DEFAULT_TEMPLATE_PATH = Path("index_wooden_static.html") DEFAULT_OUT_HTML = Path("vis.html") POSE_JOINTS = 22 EULER_FIX_DEG = (-90.0, 180.0, 0.0) EULER_ORDER = "xyz" # Empirical vertical offset (in meters) to align wooden_static visualization mesh # with canonical SMPL coordinates (e.g., GVHMR pipelines). 
WOODEN_SMPL_HEIGHT_OFFSET = 0.2 @dataclass(frozen=True) class SmplSequence: """A minimal SMPL motion sequence loaded from npz.""" poses: np.ndarray # (T, 66) = root(3) + body(63), axis-angle trans: np.ndarray # (T, 3) betas: np.ndarray # (B,) fps: float gender: str def parse_args() -> argparse.Namespace: ap = argparse.ArgumentParser( description="Generate vis.html from a SMPL npz using a HTML template." ) ap.add_argument("--npz", type=Path, help="Path to input .npz") ap.add_argument( "--template", type=Path, default=DEFAULT_TEMPLATE_PATH, help="Path to HTML template", ) ap.add_argument( "--out", type=Path, default=DEFAULT_OUT_HTML, help="Path to output HTML", ) ap.add_argument( "--pose_joints", type=int, default=POSE_JOINTS, help=f"Number of pose joints in poses (default: {POSE_JOINTS}).", ) ap.add_argument( "--height_axis", type=int, default=1, choices=[0, 1, 2], help="Axis index for height in Th (default: 1 for Y-up).", ) ap.add_argument( "--height_offset", type=float, default=WOODEN_SMPL_HEIGHT_OFFSET, help=( "Subtract from Th height axis (Y-up), in meters. " "Default is an empirical offset to align wooden_static mesh " "with canonical SMPL coordinates (e.g., GVHMR)." ), ) return ap.parse_args() def euler_fix_rot(euler_deg=EULER_FIX_DEG, order=EULER_ORDER) -> R: """Rotation for world-frame correction: R_new = R_fix * R_old.""" return R.from_euler(order.lower(), euler_deg, degrees=True) def _require_key(data: np.lib.npyio.NpzFile, key: str) -> np.ndarray: if key not in data: raise KeyError( f"Missing key '{key}' in npz. 
Available: {list(data.keys())}" ) return data[key] def load_npz(path: Path) -> SmplSequence: if not path.exists(): raise FileNotFoundError(f"Missing {path}") data = np.load(path, allow_pickle=False) poses = _require_key(data, "poses").astype(np.float32) trans = _require_key(data, "trans").astype(np.float32) betas = _require_key(data, "betas").astype(np.float32) fps = float(np.asarray(_require_key(data, "mocap_framerate"))) gender = str(np.asarray(_require_key(data, "gender"))) return SmplSequence( poses=poses, trans=trans, betas=betas, fps=fps, gender=gender ) def validate_sequence(seq: SmplSequence, pose_joints: int) -> int: """Validate shapes and return T.""" if seq.poses.ndim != 2: raise ValueError(f"poses must be 2D, got shape={seq.poses.shape}") if seq.trans.ndim != 2 or seq.trans.shape[1] != 3: raise ValueError(f"trans must be (T,3), got shape={seq.trans.shape}") T = int(seq.poses.shape[0]) exp_dim = int(pose_joints) * 3 if seq.poses.shape[1] != exp_dim: raise ValueError( f"unexpected poses shape: {seq.poses.shape}, expected (T,{exp_dim})" ) if seq.trans.shape[0] != T: raise ValueError( f"poses frames ({T}) != trans frames ({seq.trans.shape[0]})" ) return T def build_smpl_frames( seq: SmplSequence, *, pose_joints: int, height_axis: int, height_offset: float, ) -> Tuple[list, int]: """ Build frames in the format expected by index_wooden_static.html template. Notes: height_offset is a visualization-only correction to compensate for the vertical origin mismatch between wooden_static mesh and canonical SMPL coordinates (e.g., GVHMR). Override via --height_offset if needed. 
""" T = validate_sequence(seq, pose_joints) rot_fix = euler_fix_rot() root_aa = seq.poses[:, :3] body_aa = seq.poses[:, 3:] # root: left-multiply world rotation Rh = (rot_fix * R.from_rotvec(root_aa)).as_rotvec().astype(np.float32) # trans: rotate in world frame, then apply visualization height offset Th = rot_fix.apply(seq.trans).astype(np.float32) if height_offset != 0.0: Th[:, int(height_axis)] -= float(height_offset) # pad hands (6) -> body(63) + hand(6) = 69 poses_js = np.concatenate([body_aa, np.zeros((T, 6), np.float32)], axis=1) shapes = seq.betas.reshape(-1).tolist() frames = [ [ { "id": 0, "gender": seq.gender, "Rh": [Rh[f].tolist()], "Th": [Th[f].tolist()], "poses": [poses_js[f].tolist()], "shapes": shapes, } ] for f in range(T) ] return frames, T def render_html(template_path: Path, frames: list, T: int, fps: float) -> str: template = template_path.read_text(encoding="utf-8") smpl_data_json = json.dumps(frames, ensure_ascii=False) caption_html = ( "
" f"Frames: {T}    Framerate: {fps:.1f} fps" "
" ) return template.replace("{{ smpl_data_json }}", smpl_data_json).replace( "{{ caption_html }}", caption_html ) def main( npz_path: Path, template_path: Path, out_html: Path, *, pose_joints: int, height_axis: int, height_offset: float, ) -> None: if not template_path.exists(): raise FileNotFoundError(f"Missing {template_path}") seq = load_npz(npz_path) frames, T = build_smpl_frames( seq, pose_joints=pose_joints, height_axis=height_axis, height_offset=height_offset, ) html = render_html(template_path, frames, T, seq.fps) out_html.write_text(html, encoding="utf-8") print(f"[OK] wrote {out_html.resolve()}") if __name__ == "__main__": args = parse_args() main( args.npz, args.template, args.out, pose_joints=args.pose_joints, height_axis=args.height_axis, height_offset=args.height_offset, ) ================================================ FILE: holomotion/src/data_curation/smplify/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/data_curation/smplify/smplify_humanact12.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import os import random import h5py import numpy as np import smplx import torch from scipy.spatial.transform import Rotation from tqdm import tqdm from thirdparties.joints2smpl.src import config from thirdparties.joints2smpl.src.smplify import SMPLify3D SMPL_MODEL_DIR = "./assets/smpl/" device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # device = torch.device("cpu") num_joints = 22 joint_category = "AMASS" num_smplify_iters = 150 fix_foot = False def joints2smpl(file_name, data_dir, save_dir): """Convert 3D joint positions to SMPL-X parameters. Args: file_name (str): Name of the input .npy joint file data_dir (str): Directory containing input joint files save_dir (str): Directory to save processed output files """ # print(file_name) input_joints = np.load(os.path.join(data_dir, file_name)) input_joints = input_joints[:, :, [0, 1, 2]] # amass stands on x, y """XY at origin""" input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]] """Put on Floor""" floor_height = input_joints[:, :, 2].min() input_joints[:, :, 2] -= floor_height batch_size = input_joints.shape[0] smplmodel = smplx.create( SMPL_MODEL_DIR, model_type="smpl", gender="neutral", ext="npz", batch_size=batch_size, ).to(device) # ## --- load the mean pose as original ---- smpl_mean_file = config.SMPL_MEAN_FILE file = h5py.File(smpl_mean_file, "r") init_mean_pose = ( torch.from_numpy(file["pose"][:]) .unsqueeze(0) .repeat(batch_size, 1) .float() .to(device) ) init_mean_shape = ( torch.from_numpy(file["shape"][:]) .unsqueeze(0) .repeat(batch_size, 1) .float() .to(device) ) cam_trans_zero 
= torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device) # # #-------------initialize SMPLify smplify = SMPLify3D( smplxmodel=smplmodel, batch_size=batch_size, joints_category=joint_category, num_iters=num_smplify_iters, device=device, ) keypoints_3d = torch.Tensor(input_joints).to(device).float() pred_betas = init_mean_shape pred_pose = init_mean_pose pred_cam_t = cam_trans_zero if joint_category == "AMASS": confidence_input = torch.ones(num_joints) # make sure the foot and ankle if fix_foot: confidence_input[7] = 1.5 confidence_input[8] = 1.5 confidence_input[10] = 1.5 confidence_input[11] = 1.5 else: print("Such category not settle down!") ( new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, new_opt_cam_t, new_opt_joint_loss, ) = smplify( pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(), keypoints_3d, conf_3d=confidence_input.to(device), # seq_ind=idx ) poses = new_opt_pose.detach().cpu().numpy() betas = new_opt_betas.mean(axis=0).detach().cpu().numpy() trans = keypoints_3d[:, 0].detach().cpu().numpy() target_dim = 165 current_dim = poses.shape[-1] pad_dim = target_dim - current_dim if pad_dim > 0: pad_array = np.zeros((*poses.shape[:-1], pad_dim), dtype=poses.dtype) poses = np.concatenate([poses, pad_array], axis=-1) root_orient = poses[:, :3] root_mat = Rotation.from_rotvec(root_orient).as_matrix() rx_minus_100 = Rotation.from_euler("x", -100, degrees=True).as_matrix() align_r = rx_minus_100 @ root_mat align_axis_angle = Rotation.from_matrix(align_r).as_rotvec() poses[:, :3] = align_axis_angle trans_rotated = rx_minus_100 @ (trans.T) trans_rotated = trans_rotated.T input_joints = input_joints[:, :, [0, 2, 1]] # jts stands on x, z input_joints[..., 0] *= -1 param = { "poses": poses, "trans": trans_rotated, "betas": betas, "gender": "neutral", "jtr": input_joints, "mocap_frame_rate": 30, } file_name = file_name.split(".")[0] + ".npz" print(file_name) np.savez_compressed(os.path.join(save_dir, file_name), **param) def 
humanact12_to_amass(data_dir, save_dir): """Convert HumanAct12 dataset to AMASS-compatible format. Args: data_dir (str): Directory containing HumanAct12 .npy joint files save_dir (str): Directory to save processed AMASS .npz files """ os.makedirs(save_dir, exist_ok=True) file_list = os.listdir(data_dir) random.shuffle(file_list) for file_name in tqdm(file_list): if os.path.exists(os.path.join(save_dir, file_name)): print(f"{os.path.join(save_dir, file_name)} already exists") continue joints2smpl(file_name, data_dir, save_dir) if __name__ == "__main__": data_dir = "" save_dir = "" ================================================ FILE: holomotion/src/data_curation/smplify/smplify_motionx.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import os import numpy as np from scipy.spatial.transform import Rotation def motionx_to_amass(src_root, dst_root): """Convert MotionX format motion data to AMASS format. 
Args: src_root (str): Source directory containing MotionX .npy files dst_root (str): Destination directory for processed AMASS .npz files Side effects: Creates directory structure mirroring src_root under dst_root Generates compressed .npz files in destination directory Prints file paths of processed files Processed data contains: poses: [T, 165] float array of joint rotations (root first) trans: [T, 3] float array of root translations betas: [10] float array of shape parameters gender: str (always "neutral") mocap_frame_rate: int (always 30) """ os.makedirs(dst_root, exist_ok=True) for root, _, files in os.walk(src_root): # print(files) for file in files: src_file_path = os.path.join(root, file) motion = np.load(src_file_path) poses = motion[:, :156] # 最终 shape: (T, 156) num_frames = poses.shape[0] sl = poses.shape[1] pad = np.zeros((num_frames, 165 - sl), dtype=poses.dtype) # (T, 9) poses = np.concatenate([poses, pad], axis=1) # (T, 165) align_axis_angle = poses[:, :3] root_orient = poses[:, :3] root_mat = Rotation.from_rotvec(root_orient).as_matrix() rotate_matrix = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) align_r = rotate_matrix @ root_mat align_axis_angle = Rotation.from_matrix(align_r).as_rotvec() poses[:, :3] = align_axis_angle trans = motion[:, 309:312] # (T, 3) trans[:, 2] = trans[:, 2] * (-1) trans_matrix = np.array( [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]] ) trans = np.dot(trans, trans_matrix) trans = rotate_matrix @ (trans.T) trans = trans.T betas = motion[0, 312:] amass_data = { "poses": poses, "trans": trans, "betas": betas, "gender": "neutral", "mocap_frame_rate": 30, } relative_path = src_file_path.replace(src_root, "") file_name = dst_root + relative_path save_dir = file_name.split("/")[-1] save_dir = file_name.replace(save_dir, "") os.makedirs(save_dir, exist_ok=True) file_name = file_name.replace(".npy", ".npz") print(file_name) np.savez_compressed(file_name, **amass_data) if __name__ == "__main__": src_root = "./data/smplx_322" 
dst_root = "./data/smplx_data/MotionX" motionx_to_amass(src_root, dst_root) ================================================ FILE: holomotion/src/data_curation/smplify/smplify_omomo.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # # ----------------------------------------------------------------------------- # Portions of this file are derived from omomo_release (https://github.com/lijiaman/omomo_release). # The original omomo_release code is licensed under the MIT license. # ----------------------------------------------------------------------------- import os import numpy as np import pytorch3d.transforms as transforms import torch from torch.utils import data from thirdparties.omomo_release.manip.data.hand_foot_dataset import ( HandFootManipDataset, quat_ik_torch, ) class MyHandFootManipDataset(HandFootManipDataset): """Modified dataset class for hand-foot manipulation tasks. This class overrides the __getitem__ method. """ def __init__(self, *args, **kwargs): """Initialize the dataset instance by forwarding all arguments. This constructor ensures proper initialization of the HandFootManipDataset parent class. All parameters and keyword arguments are passed through unchanged. 
Args: *args: Variable length argument list for parent class **kwargs: Arbitrary keyword arguments for parent class """ super().__init__(*args, **kwargs) def __getitem__(self, index): """Retrieve and process a data sample by index. Try not to padding when retrieve motion data. Args: index (int): Index of the sample to retrieve Reference: https://github.com/lijiaman/omomo_release/blob/main/manip/data/hand_foot_dataset.py """ # index = 0 # For debug data_input = self.window_data_dict[index]["motion"] data_input = torch.from_numpy(data_input).float() seq_name = self.window_data_dict[index]["seq_name"] object_name = seq_name.split("_")[1] trans2joint = self.window_data_dict[index]["trans2joint"] if self.use_object_splits: ori_w_idx = self.window_data_dict[index]["ori_w_idx"] obj_bps_npy_path = os.path.join( self.dest_obj_bps_npy_folder, seq_name + "_" + str(ori_w_idx) + ".npy", ) else: obj_bps_npy_path = os.path.join( self.dest_obj_bps_npy_folder, seq_name + "_" + str(index) + ".npy", ) obj_bps_data = np.load(obj_bps_npy_path) # T X N X 3 obj_bps_data = torch.from_numpy(obj_bps_data) num_joints = 24 normalized_jpos = self.normalize_jpos_min_max( data_input[:, : num_joints * 3].reshape(-1, num_joints, 3) ) # T X 22 X 3 global_joint_rot = data_input[:, 2 * num_joints * 3 :] # T X (22*6) new_data_input = torch.cat( (normalized_jpos.reshape(-1, num_joints * 3), global_joint_rot), dim=1, ) ori_data_input = torch.cat( (data_input[:, : num_joints * 3], global_joint_rot), dim=1 ) # Add padding. 
actual_steps = new_data_input.shape[0] # pass paded_new_data_input = new_data_input paded_ori_data_input = ori_data_input paded_obj_bps = obj_bps_data.reshape(actual_steps, -1) paded_obj_com_pos = torch.from_numpy( self.window_data_dict[index]["window_obj_com_pos"] ).float() paded_obj_rot_mat = torch.from_numpy( self.window_data_dict[index]["obj_rot_mat"] ).float() paded_obj_scale = torch.from_numpy( self.window_data_dict[index]["obj_scale"] ).float() paded_obj_trans = torch.from_numpy( self.window_data_dict[index]["obj_trans"] ).float() if object_name in ["mop", "vacuum"]: paded_obj_bottom_rot_mat = torch.from_numpy( self.window_data_dict[index]["obj_bottom_rot_mat"] ).float() paded_obj_bottom_scale = torch.from_numpy( self.window_data_dict[index]["obj_bottom_scale"] ).float() paded_obj_bottom_trans = ( torch.from_numpy( self.window_data_dict[index]["obj_bottom_trans"] ) .float() .squeeze(-1) ) data_input_dict = {} data_input_dict["motion"] = paded_new_data_input data_input_dict["ori_motion"] = paded_ori_data_input data_input_dict["obj_bps"] = paded_obj_bps data_input_dict["obj_com_pos"] = paded_obj_com_pos data_input_dict["obj_rot_mat"] = paded_obj_rot_mat data_input_dict["obj_scale"] = paded_obj_scale data_input_dict["obj_trans"] = paded_obj_trans if object_name in ["mop", "vacuum"]: data_input_dict["obj_bottom_rot_mat"] = paded_obj_bottom_rot_mat data_input_dict["obj_bottom_scale"] = paded_obj_bottom_scale data_input_dict["obj_bottom_trans"] = paded_obj_bottom_trans else: data_input_dict["obj_bottom_rot_mat"] = paded_obj_rot_mat data_input_dict["obj_bottom_scale"] = paded_obj_scale data_input_dict["obj_bottom_trans"] = paded_obj_trans data_input_dict["betas"] = self.window_data_dict[index]["betas"] data_input_dict["gender"] = str(self.window_data_dict[index]["gender"]) data_input_dict["seq_name"] = seq_name data_input_dict["obj_name"] = seq_name.split("_")[1] data_input_dict["seq_len"] = actual_steps data_input_dict["trans2joint"] = trans2joint return 
data_input_dict def run_smplx_model(root_trans, aa_rot_rep, betas, gender, fname): """Prepare and save SMPL-X motion data in AMASS npz format. Processes input motion parameters into SMPL-X compatible format and saves as a compressed npz file. Args: root_trans (torch.Tensor): Root translations [BS, T, 3] aa_rot_rep (torch.Tensor): Axis-angle joint rotations [BS, T, num_joints, 3] where num_joints can be either 22 (body only) or 52 (body+hands) betas (torch.Tensor): Shape parameters [BS, 16] gender (list): Gender strings for each sample in batch [BS] fname (str): Output filename/path for saving .npz file Output npz file contains: poses: [BS*T, 165] float array of pose parameters trans: [BS*T, 3] float array of translations betas: [16] float array of shape parameters (from first sample) gender: str (always "neutral") mocap_frame_rate: int (always 30) """ # root_trans: BS X T X 3 # aa_rot_rep: BS X T X 22 X 3 # betas: BS X 16 # gender: BS bs, num_steps, num_joints, _ = aa_rot_rep.shape if num_joints != 52: padding_zeros_hand = torch.zeros(bs, num_steps, 30, 3).to( aa_rot_rep.device ) # BS X T X 30 X 3 aa_rot_rep = torch.cat( (aa_rot_rep, padding_zeros_hand), dim=2 ) # BS X T X 52 X 3 aa_rot_rep = aa_rot_rep.reshape( bs * num_steps, -1, 3 ) # (BS*T) X n_joints X 3 betas = ( betas[:, None, :].repeat(1, num_steps, 1).reshape(bs * num_steps, -1) ) # (BS*T) X 16 gender = np.asarray(gender)[:, np.newaxis].repeat(num_steps, axis=1) gender = gender.reshape(-1).tolist() # (BS*T) smpl_trans = root_trans.reshape(-1, 3) # (BS*T) X 3 smpl_root_orient = aa_rot_rep[:, 0, :] # (BS*T) X 3 # print(smpl_root_orient.shape) smpl_pose_body = aa_rot_rep[:, 1:22, :].reshape(-1, 63) # (BS*T) X 63 smpl_pose_hand = aa_rot_rep[:, 22:, :].reshape(-1, 90) # (BS*T) X 90 poses = torch.cat( [smpl_root_orient, smpl_pose_body, smpl_pose_hand], dim=-1 ) target_dim = 165 current_dim = poses.shape[-1] pad_dim = target_dim - current_dim if pad_dim > 0: pad_tensor = torch.zeros( *poses.shape[:-1], pad_dim, 
device=poses.device, dtype=poses.dtype ) poses_padded = torch.cat([poses, pad_tensor], dim=-1) else: poses_padded = poses # already 165 or more amass_data = { "poses": poses_padded.detach().cpu().numpy(), "trans": smpl_trans.detach().cpu().numpy(), "betas": betas[0].detach().cpu().numpy(), "gender": "neutral", "mocap_frame_rate": 30, } np.savez_compressed(fname, **amass_data) def process_dataset(dl, dataset, target_folder, split_name: str): """Process a motion dataset batch and convert sequences to SMPL-X format. Args: dl (DataLoader): PyTorch DataLoader providing batched data dataset (Dataset): Source dataset object (for denormalization) target_folder (str): target folder for data saving split_name (str): Name of data split being processed Output files: Saved as: {target_folder}/{split_name}_{object_name}_{index}.npz Where: target_folder: (implied from external context) object_name: Extracted from sequence name index: Incremental sequence counter """ index = 0 for data_dict in dl: val_data = data_dict["motion"].cuda() for_vis_gt_data = val_data[:] all_res_list = for_vis_gt_data num_seq = all_res_list.shape[0] print(f"Processing {split_name}, num_seq: {num_seq}") num_joints = 24 normalized_global_jpos = all_res_list[:, :, : num_joints * 3].reshape( num_seq, -1, num_joints, 3 ) global_jpos = dataset.de_normalize_jpos_min_max( normalized_global_jpos.reshape(-1, num_joints, 3) ) global_jpos = global_jpos.reshape(num_seq, -1, num_joints, 3) global_root_jpos = global_jpos[:, :, 0, :].clone() global_rot_6d = all_res_list[:, :, -22 * 6 :].reshape( num_seq, -1, 22, 6 ) global_rot_mat = transforms.rotation_6d_to_matrix(global_rot_6d) trans2joint = data_dict["trans2joint"].to(all_res_list.device) for idx in range(num_seq): curr_global_rot_mat = global_rot_mat[idx] curr_local_rot_mat = quat_ik_torch(curr_global_rot_mat) curr_local_rot_aa_rep = transforms.matrix_to_axis_angle( curr_local_rot_mat ) curr_global_root_jpos = global_root_jpos[idx] curr_trans2joint = trans2joint[idx 
: idx + 1].clone() root_trans = curr_global_root_jpos + curr_trans2joint betas = data_dict["betas"][idx] gender = data_dict["gender"][idx] curr_seq_name = data_dict["seq_name"][idx] object_name = curr_seq_name.split("_")[1] fname = os.path.join( target_folder, f"{split_name}_{object_name}_{index}.npz" ) print(fname) run_smplx_model( root_trans[None].cuda(), curr_local_rot_aa_rep[None].cuda(), betas.cuda(), [gender], fname, ) index += 1 def omomo_to_amass(data_root_folder, target_folder): """Convert Omomo dataset to AMASS-compatible SMPL-X format. Args: data_root_folder (str): Path to the root directory of Omomo dataset target_folder (str): Output directory for processed AMASS files """ use_object_split = True window_size = 120 train_dataset = MyHandFootManipDataset( train=True, data_root_folder=data_root_folder, window=window_size, use_object_splits=use_object_split, ) val_dataset = MyHandFootManipDataset( train=False, data_root_folder=data_root_folder, window=window_size, use_object_splits=use_object_split, ) val_ds = val_dataset train_ds = train_dataset val_dl = data.DataLoader( val_ds, batch_size=1, shuffle=False, pin_memory=True, num_workers=0 ) train_dl = data.DataLoader( train_ds, batch_size=1, shuffle=False, pin_memory=True, num_workers=0 ) process_dataset(train_dl, train_dataset, target_folder, "train") process_dataset(val_dl, val_dataset, target_folder, "val") if __name__ == "__main__": data_root_folder = "" target_folder = "" omomo_to_amass(data_root_folder, target_folder) ================================================ FILE: holomotion/src/data_curation/smplify/smplify_zjumocap.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import os import numpy as np from scipy.spatial.transform import Rotation from tqdm import tqdm def zju_single_to_amass( param_dir, out_path, gender="neutral", fps=30, rotate=False ): """Convert .npy files into a single AMASS-style .npz file. Args: param_dir: Folder containing 0.npy, 1.npy, .... out_path: Output .npz path. gender: Gender to assign ('neutral', 'male', 'female'). fps: Mocap frame rate. rotate: whether or not rotate the body """ pose_list = [] trans_list = [] shape_list = [] # Get sorted list of npy files files = sorted( [f for f in os.listdir(param_dir) if f.endswith(".npy")], key=lambda x: int(os.path.splitext(x)[0]), ) if rotate: ry_minus_180 = Rotation.from_euler("y", -180, degrees=True).as_matrix() else: ry_minus_180 = Rotation.from_euler("y", 0, degrees=True).as_matrix() for fname in tqdm(files, desc="Processing frames"): fpath = os.path.join(param_dir, fname) data = np.load(fpath, allow_pickle=True).item() poses = data["poses"] # (1, 72) global_orient = data["Rh"] root_orient = global_orient root_mat = Rotation.from_rotvec(root_orient).as_matrix() align_r = ry_minus_180 @ root_mat align_axis_angle = Rotation.from_matrix(align_r).as_rotvec() global_orient = align_axis_angle body_pose = poses[:, 3:66] hand_pose = poses[:, 66:72] full_pose = np.concatenate( [global_orient, body_pose, hand_pose], axis=1 ) # (1, 165) pose_list.append(full_pose[0]) # shape: (165,) trans_list.append(data["Th"][0]) # shape: (3,) shape_list.append(data["shapes"][0]) # shape: (10,) poses = np.stack(pose_list, axis=0).astype(np.float32) # (N, 165) trans = 
np.stack(trans_list, axis=0).astype(np.float32) # (N, 3) trans_rotated = ry_minus_180 @ (trans.T) trans_rotated = trans_rotated.T betas = shape_list[0].astype(np.float32) # (10,) same for all frames # Save as AMASS-style npz np.savez_compressed( out_path, poses=poses, trans=trans_rotated, betas=betas, gender=gender, mocap_frame_rate=fps, ) print(f"Saved AMASS-style file to: {out_path}") print(f"Total frames: {poses.shape[0]}") def zju_to_amass(input_dir, output_dir): """Convert multiple ZJU-formatted folders to AMASS-style .npz files. Args: input_dir: Path to ZJU dataset root folder. output_dir: Path to save AMASS-format .npz files. """ os.makedirs(output_dir, exist_ok=True) subjects = sorted( [ d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d)) ] ) for subject in subjects: subject_dir = os.path.join(input_dir, subject) new_params_dir = os.path.join(subject_dir, "new_params") params_dir = os.path.join(subject_dir, "params") if os.path.isdir(new_params_dir): param_dir = new_params_dir print(f"[{subject}] Using new_params") elif os.path.isdir(params_dir): param_dir = params_dir print(f"[{subject}] Using params") else: print(f"[{subject}] No params found, skipping") continue out_path = os.path.join(output_dir, f"{subject}.npz") zju_single_to_amass(param_dir, out_path) print(f"All subjects processed. Output saved to {output_dir}") # Example usage if __name__ == "__main__": zju_to_amass( param_dir="", out_path="", ) ================================================ FILE: holomotion/src/data_curation/templates/index_wooden_static.html ================================================ Motion Visualization
{{ caption_html }}
0 / 0
Loading...
1.0x
================================================ FILE: holomotion/src/data_curation/video_to_smpl_gvhmr.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # # This file was originally copied from the [PBHC] repository: # https://github.com/TeleHuman/PBHC # Modifications have been made to fit the needs of this project. import cv2 import torch import pytorch_lightning as pl import numpy as np import argparse from hmr4d.utils.pylogger import Log import hydra from hydra import initialize_config_module, compose from pathlib import Path from pytorch3d.transforms import quaternion_to_matrix from hmr4d.configs import register_store_gvhmr from hmr4d.utils.video_io_utils import ( get_video_lwh, read_video_np, save_video, merge_videos_horizontal, get_writer, get_video_reader, ) from hmr4d.utils.vis.cv2_utils import ( draw_bbx_xyxy_on_image_batch, draw_coco17_skeleton_batch, ) from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SimpleVO from hmr4d.utils.geo.hmr_cam import ( get_bbx_xys_from_xyxy, estimate_K, convert_K_to_K4, create_camera_sensor, ) from hmr4d.utils.geo_transform import compute_cam_angvel from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL from hmr4d.utils.net_utils import detach_to_cpu, to_cuda from hmr4d.utils.smplx_utils import make_smplx from hmr4d.utils.vis.renderer import ( Renderer, get_global_cameras_static, 
get_ground_params_from_points, ) from tqdm import tqdm from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay from einops import einsum, rearrange import shutil import subprocess from scipy.spatial.transform import Rotation as sRot CRF = 23 # 17 is lossless, every +6 halves the mp4 size def get_video_fps(video_path: Path) -> float: cap = cv2.VideoCapture(str(video_path)) fps = cap.get(cv2.CAP_PROP_FPS) cap.release() if fps is None or fps <= 1e-6: raise RuntimeError(f"Failed to read FPS from video: {video_path}") return float(fps) def is_close_fps(a: float, b: float, tol: float = 0.02) -> bool: return abs(a - b) <= tol def transcode_to_30fps_cfr(src: Path, dst: Path, crf: int) -> None: dst.parent.mkdir(parents=True, exist_ok=True) cmd = [ "ffmpeg", "-y", "-i", str(src), "-vf", "fps=30", "-vsync", "cfr", "-c:v", "libx264", "-crf", str(crf), "-preset", "medium", "-c:a", "copy", str(dst), ] subprocess.run(cmd, check=True) def parse_args_to_cfg(args=None): # Put all args to cfg if args is None: parser = argparse.ArgumentParser() parser.add_argument( "--video", type=str, default="inputs/demo/dance_3.mp4" ) parser.add_argument( "--output_root", type=str, default=None, help="by default to outputs/demo", ) parser.add_argument( "-s", "--static_cam", action="store_true", help="If true, skip DPVO", ) parser.add_argument( "--use_dpvo", action="store_true", help="If true, use DPVO. By default not using DPVO.", ) parser.add_argument( "--f_mm", type=int, default=None, help="Focal length of fullframe camera in mm. Leave it as None to use default values." "For iPhone 15p, the [0.5x, 1x, 2x, 3x] lens have typical values [13, 24, 48, 77]." 
"If the camera zoom in a lot, you can try 135, 200 or even larger values.", ) parser.add_argument( "--verbose", action="store_true", help="If true, draw intermediate results", ) args = parser.parse_args() # Input video_path = Path(args.video) assert video_path.exists(), f"Video not found at {video_path}" length, width, height = get_video_lwh(video_path) Log.info(f"[Input]: {video_path}") Log.info(f"(L, W, H) = ({length}, {width}, {height})") # Cfg with initialize_config_module( version_base="1.3", config_module=f"hmr4d.configs" ): overrides = [ f"video_name='{video_path.stem}'", f"static_cam={args.static_cam}", f"verbose={args.verbose}", f"use_dpvo={args.use_dpvo}", ] if args.f_mm is not None: overrides.append(f"f_mm={args.f_mm}") # Allow to change output root if args.output_root is not None: overrides.append(f"output_root='{args.output_root}'") register_store_gvhmr() cfg = compose(config_name="demo", overrides=overrides) # Output Log.info(f"[Output Dir]: {cfg.output_dir}") Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True) # Copy raw-input-video to video_path Log.info(f"[Prepare Video] {video_path} -> {cfg.video_path}") src_len = get_video_lwh(video_path)[0] dst_path = Path(cfg.video_path) need_regen = (not dst_path.exists()) or ( get_video_lwh(dst_path)[0] != src_len ) src_fps = get_video_fps(video_path) Log.info(f"[Input FPS]: {src_fps:.4f}") if need_regen: if is_close_fps(src_fps, 30.0): Log.info("[FPS OK] ~30fps, copy without re-encoding.") dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy2(video_path, dst_path) else: Log.info("[FPS CONVERT] transcoding to 30fps with constant speed.") transcode_to_30fps_cfr(video_path, Path(cfg.video_path), CRF) return cfg @torch.no_grad() def run_preprocess(cfg): Log.info(f"[Preprocess] Start!") tic = Log.time() video_path = cfg.video_path paths = cfg.paths static_cam = cfg.static_cam verbose = cfg.verbose # Get bbx tracking result if not 
Path(paths.bbx).exists(): tracker = Tracker() bbx_xyxy = tracker.get_one_track(video_path).float() # (L, 4) bbx_xys = get_bbx_xys_from_xyxy( bbx_xyxy, base_enlarge=1.2 ).float() # (L, 3) apply aspect ratio and enlarge torch.save({"bbx_xyxy": bbx_xyxy, "bbx_xys": bbx_xys}, paths.bbx) del tracker else: bbx_xys = torch.load(paths.bbx)["bbx_xys"] Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}") if verbose: video = read_video_np(video_path) bbx_xyxy = torch.load(paths.bbx)["bbx_xyxy"] video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video) save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay) # Get VitPose if not Path(paths.vitpose).exists(): vitpose_extractor = VitPoseExtractor() vitpose = vitpose_extractor.extract(video_path, bbx_xys) torch.save(vitpose, paths.vitpose) del vitpose_extractor else: vitpose = torch.load(paths.vitpose) Log.info(f"[Preprocess] vitpose from {paths.vitpose}") if verbose: video = read_video_np(video_path) video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5) save_video(video_overlay, paths.vitpose_video_overlay) # Get vit features if not Path(paths.vit_features).exists(): extractor = Extractor() vit_features = extractor.extract_video_features(video_path, bbx_xys) torch.save(vit_features, paths.vit_features) del extractor else: Log.info(f"[Preprocess] vit_features from {paths.vit_features}") # Get visual odometry results if not static_cam: # use slam to get cam rotation if not Path(paths.slam).exists(): if not cfg.use_dpvo: simple_vo = SimpleVO( cfg.video_path, scale=0.5, step=8, method="sift", f_mm=cfg.f_mm, ) vo_results = simple_vo.compute() # (L, 4, 4), numpy torch.save(vo_results, paths.slam) else: # DPVO from hmr4d.utils.preproc.slam import SLAMModel length, width, height = get_video_lwh(cfg.video_path) K_fullimg = estimate_K(width, height) intrinsics = convert_K_to_K4(K_fullimg) slam = SLAMModel( video_path, width, height, intrinsics, buffer=4000, resize=0.5, ) bar = tqdm(total=length, desc="DPVO") while 
True: ret = slam.track() if ret: bar.update() else: break slam_results = slam.process() # (L, 7), numpy torch.save(slam_results, paths.slam) else: Log.info(f"[Preprocess] slam results from {paths.slam}") Log.info(f"[Preprocess] End. Time elapsed: {Log.time() - tic:.2f}s") def load_data_dict(cfg): paths = cfg.paths length, width, height = get_video_lwh(cfg.video_path) if cfg.static_cam: R_w2c = torch.eye(3).repeat(length, 1, 1) else: traj = torch.load(cfg.paths.slam) if cfg.use_dpvo: # DPVO traj_quat = torch.from_numpy(traj[:, [6, 3, 4, 5]]) R_w2c = quaternion_to_matrix(traj_quat).mT else: # SimpleVO R_w2c = torch.from_numpy(traj[:, :3, :3]) if cfg.f_mm is not None: K_fullimg = create_camera_sensor(width, height, cfg.f_mm)[2].repeat( length, 1, 1 ) else: K_fullimg = estimate_K(width, height).repeat(length, 1, 1) data = { "length": torch.tensor(length), "bbx_xys": torch.load(paths.bbx)["bbx_xys"], "kp2d": torch.load(paths.vitpose), "K_fullimg": K_fullimg, "cam_angvel": compute_cam_angvel(R_w2c), "f_imgseq": torch.load(paths.vit_features), } return data def save_npz(pred, save_path): out_dir = Path(save_path).parent out_dir.mkdir(parents=True, exist_ok=True) trans = pred["transl"].detach().cpu() body_pose = torch.cat( ( pred["global_orient"].detach().cpu(), pred["body_pose"].detach().cpu(), ), dim=1, ) transform1 = sRot.from_euler( "xyz", np.array([np.pi / 2, 0, np.pi]), degrees=False ) new_root = ( transform1 * sRot.from_rotvec(body_pose[:, :3].numpy()) ).as_rotvec() body_pose[:, :3] = torch.from_numpy(new_root) trans = trans @ torch.tensor(transform1.as_matrix().T, dtype=torch.float32) out_path = out_dir / "smpl.npz" Log.info(f"npz_path {out_path}") np.savez( str(out_path), betas=pred["betas"][0].detach().cpu().numpy(), gender="neutral", poses=body_pose.numpy(), trans=trans.numpy(), mocap_framerate=30.0, ) def render_incam(cfg): incam_video_path = Path(cfg.paths.incam_video) if incam_video_path.exists(): Log.info(f"[Render Incam] Video already exists at 
{incam_video_path}") return pred = torch.load(cfg.paths.hmr4d_results) smplx = make_smplx("supermotion").cuda() smplx2smpl = torch.load( "hmr4d/utils/body_model/smplx2smpl_sparse.pt" ).cuda() faces_smpl = make_smplx("smpl").faces # smpl smplx_out = smplx(**to_cuda(pred["smpl_params_incam"])) pred_c_verts = torch.stack( [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices] ) # -- rendering code -- # video_path = cfg.video_path length, width, height = get_video_lwh(video_path) K = pred["K_fullimg"][0] # renderer renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K) reader = get_video_reader(video_path) # (F, H, W, 3), uint8, numpy bbx_xys_render = torch.load(cfg.paths.bbx)["bbx_xys"] # -- render mesh -- # verts_incam = pred_c_verts writer = get_writer(incam_video_path, fps=30, crf=CRF) for i, img_raw in tqdm( enumerate(reader), total=get_video_lwh(video_path)[0], desc=f"Rendering Incam", ): img = renderer.render_mesh( verts_incam[i].cuda(), img_raw, [0.8, 0.8, 0.8] ) # # bbx # bbx_xys_ = bbx_xys_render[i].cpu().numpy() # lu_point = (bbx_xys_[:2] - bbx_xys_[2:] / 2).astype(int) # rd_point = (bbx_xys_[:2] + bbx_xys_[2:] / 2).astype(int) # img = cv2.rectangle(img, lu_point, rd_point, (255, 178, 102), 2) writer.write_frame(img) writer.close() reader.close() def render_global(cfg): global_video_path = Path(cfg.paths.global_video) # Always save NPZ regardless of whether the video already exists pred = torch.load(cfg.paths.hmr4d_results) save_npz(pred["smpl_params_global"], save_path=global_video_path) if global_video_path.exists(): Log.info( f"[Render Global] Video already exists at {global_video_path}" ) return debug_cam = False smplx = make_smplx("supermotion").cuda() smplx2smpl = torch.load( "hmr4d/utils/body_model/smplx2smpl_sparse.pt" ).cuda() faces_smpl = make_smplx("smpl").faces J_regressor = torch.load( "hmr4d/utils/body_model/smpl_neutral_J_regressor.pt" ).cuda() # smpl smplx_out = smplx(**to_cuda(pred["smpl_params_global"])) # npz already 
saved above pred_ay_verts = torch.stack( [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices] ) def move_to_start_point_face_z(verts): "XZ to origin, Start from the ground, Face-Z" # position verts = verts.clone() # (L, V, 3) offset = einsum(J_regressor, verts[0], "j v, v i -> j i")[0] # (3) offset[1] = verts[:, :, [1]].min() verts = verts - offset # face direction T_ay2ayfz = compute_T_ayfz2ay( einsum(J_regressor, verts[[0]], "j v, l v i -> l j i"), inverse=True, ) verts = apply_T_on_points(verts, T_ay2ayfz) return verts verts_glob = move_to_start_point_face_z(pred_ay_verts) joints_glob = einsum( J_regressor, verts_glob, "j v, l v i -> l j i" ) # (L, J, 3) global_R, global_T, global_lights = get_global_cameras_static( verts_glob.cpu(), beta=2.0, cam_height_degree=20, target_center_height=1.0, ) # -- rendering code -- # video_path = cfg.video_path length, width, height = get_video_lwh(video_path) _, _, K = create_camera_sensor(width, height, 24) # render as 24mm lens # renderer renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K) # renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K, bin_size=0) # -- render mesh -- # scale, cx, cz = get_ground_params_from_points( joints_glob[:, 0], verts_glob ) renderer.set_ground(scale * 1.5, cx, cz) color = torch.ones(3).float().cuda() * 0.8 render_length = length if not debug_cam else 8 writer = get_writer(global_video_path, fps=30, crf=CRF) for i in tqdm(range(render_length), desc=f"Rendering Global"): cameras = renderer.create_camera(global_R[i], global_T[i]) img = renderer.render_with_ground( verts_glob[[i]], color[None], cameras, global_lights ) writer.write_frame(img) writer.close() if __name__ == "__main__": # Top-level parser to support folder batch mode top_parser = argparse.ArgumentParser() top_parser.add_argument("--video", type=str, default=None) top_parser.add_argument("--folder", "-f", type=str, default=None) top_parser.add_argument("--output_root", "-d", type=str, 
default=None) top_parser.add_argument("-s", "--static_cam", action="store_true") top_parser.add_argument("--use_dpvo", action="store_true") top_parser.add_argument("--f_mm", type=int, default=None) top_parser.add_argument("--verbose", action="store_true") top_args = top_parser.parse_args() # Batch mode if top_args.folder is not None: folder = Path(top_args.folder) mp4_paths = sorted( list(folder.glob("*.mp4")) + list(folder.glob("*.MP4")) ) Log.info(f"Found {len(mp4_paths)} .mp4 files in {folder}") for mp4_path in tqdm(mp4_paths): per_args = argparse.Namespace( video=str(mp4_path), output_root=top_args.output_root, static_cam=top_args.static_cam, use_dpvo=top_args.use_dpvo, f_mm=top_args.f_mm, verbose=top_args.verbose, ) try: cfg = parse_args_to_cfg(per_args) paths = cfg.paths Log.info(f"[GPU]: {torch.cuda.get_device_name()}") Log.info(f"[GPU]: {torch.cuda.get_device_properties('cuda')}") run_preprocess(cfg) data = load_data_dict(cfg) if not Path(paths.hmr4d_results).exists(): Log.info("[HMR4D] Predicting") model: DemoPL = hydra.utils.instantiate( cfg.model, _recursive_=False ) model.load_pretrained_model(cfg.ckpt_path) model = model.eval().cuda() tic = Log.sync_time() pred = model.predict(data, static_cam=cfg.static_cam) pred = detach_to_cpu(pred) data_time = data["length"] / 30 Log.info( f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s" ) torch.save(pred, paths.hmr4d_results) render_incam(cfg) render_global(cfg) if not Path(paths.incam_global_horiz_video).exists(): Log.info("[Merge Videos]") merge_videos_horizontal( [paths.incam_video, paths.global_video], paths.incam_global_horiz_video, ) except Exception as e: Log.error(f"Failed on {mp4_path}: {e}") raise SystemExit(0) # Single video mode if top_args.video is None: top_parser.error("Must provide --video or --folder") single_args = argparse.Namespace( video=top_args.video, output_root=top_args.output_root, static_cam=top_args.static_cam, use_dpvo=top_args.use_dpvo, 
f_mm=top_args.f_mm, verbose=top_args.verbose, ) cfg = parse_args_to_cfg(single_args) paths = cfg.paths Log.info(f"[GPU]: {torch.cuda.get_device_name()}") Log.info(f"[GPU]: {torch.cuda.get_device_properties('cuda')}") # ===== Preprocess and save to disk ===== # run_preprocess(cfg) data = load_data_dict(cfg) # ===== HMR4D ===== # if not Path(paths.hmr4d_results).exists(): Log.info("[HMR4D] Predicting") model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False) model.load_pretrained_model(cfg.ckpt_path) model = model.eval().cuda() tic = Log.sync_time() pred = model.predict(data, static_cam=cfg.static_cam) pred = detach_to_cpu(pred) data_time = data["length"] / 30 Log.info( f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s" ) torch.save(pred, paths.hmr4d_results) # ===== Render ===== # render_incam(cfg) render_global(cfg) if not Path(paths.incam_global_horiz_video).exists(): Log.info("[Merge Videos]") merge_videos_horizontal( [paths.incam_video, paths.global_video], paths.incam_global_horiz_video, ) ================================================ FILE: holomotion/src/data_curation/vison_mocap/joints2smpl.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import sys sys.path.append("../../") sys.path.append("../../thirdparties/joints2smpl/src") import h5py import numpy as np import smplx import torch from scipy.spatial.transform import Rotation from thirdparties.joints2smpl.src import config from thirdparties.joints2smpl.src.smplify import SMPLify3D device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # device = torch.device("cpu") num_joints = 22 joint_category = "AMASS" num_smplify_iters = 300 fix_foot = False def joints2smpl(input_joints, save_name): """Save the joints as amass-compatible npz file. Note: This function depends on the `joints2smpl` repository. To use this function properly, you need to manually modify parts of the internal `joints2smpl` repository code to ensure compatibility. """ # print(file_name) input_joints = input_joints[:, :, [0, 1, 2]] # amass stands on x, y """XY at origin""" input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]] """Put on Floor""" floor_height = input_joints[:, :, 2].min() input_joints[:, :, 2] -= floor_height batch_size = input_joints.shape[0] smplmodel = smplx.create( config.SMPL_MODEL_DIR, model_type="smpl", gender="neutral", ext="npz", batch_size=batch_size, ).to(device) # ## --- load the mean pose as original ---- smpl_mean_file = config.SMPL_MEAN_FILE file = h5py.File(smpl_mean_file, "r") init_mean_pose = ( torch.from_numpy(file["pose"][:]) .unsqueeze(0) .repeat(batch_size, 1) .float() .to(device) ) init_mean_shape = ( torch.from_numpy(file["shape"][:]) .unsqueeze(0) .repeat(batch_size, 1) .float() .to(device) ) cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device) # # #-------------initialize SMPLify smplify = SMPLify3D( smplxmodel=smplmodel, batch_size=batch_size, joints_category=joint_category, num_iters=num_smplify_iters, device=device, ) keypoints_3d = torch.Tensor(input_joints).to(device).float() pred_betas = init_mean_shape pred_pose = init_mean_pose pred_cam_t = cam_trans_zero if joint_category == "AMASS": 
confidence_input = torch.ones(num_joints) # make sure the foot and ankle if fix_foot: confidence_input[7] = 1.5 confidence_input[8] = 1.5 confidence_input[10] = 1.5 confidence_input[11] = 1.5 else: print("Such category not settle down!") ( new_opt_vertices, new_opt_joints, new_opt_pose, new_opt_betas, new_opt_cam_t, new_opt_joint_loss, ) = smplify( pred_pose.detach(), pred_betas.detach(), pred_cam_t.detach(), keypoints_3d, conf_3d=confidence_input.to(device), # seq_ind=idx ) poses = new_opt_pose.detach().cpu().numpy() betas = new_opt_betas.mean(axis=0).detach().cpu().numpy() trans = keypoints_3d[:, 0].detach().cpu().numpy() root_orient = poses[:, :3] root_mat = Rotation.from_rotvec(root_orient).as_matrix() rx_minus_90 = Rotation.from_euler("xz", [90, 0], degrees=True).as_matrix() # rotate_matrix = np.array([[1,0,0],[0,0,-1],[0,1,0]]) # Ry_10 = Rotation.from_euler('z',20,degrees=True).as_matrix() align_r = rx_minus_90 @ root_mat # align_r = rotate_matrix@root_mat align_axis_angle = Rotation.from_matrix(align_r).as_rotvec() poses[:, :3] = align_axis_angle input_joints = input_joints[:, :, [0, 2, 1]] # jts stands on x, z input_joints[..., 0] *= -1 trans_rotated = rx_minus_90 @ (trans.T) trans_rotated = trans_rotated.T target_dim = 165 poses_padding = np.zeros((poses.shape[0], target_dim)) if poses.shape[1] < target_dim: poses_padding[:, : poses.shape[1]] = poses else: poses_padding = poses param = { "poses": poses_padding, "trans": trans_rotated, "betas": betas, "gender": "neutral", "jtr": input_joints, "mocap_frame_rate": 30, } np.savez_compressed(save_name, **param) print(f"successfully save file:{save_name}") ================================================ FILE: holomotion/src/data_curation/visualize_smpl_npz.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import argparse import http.server import os import socket import socketserver import subprocess import sys import threading from dataclasses import dataclass from pathlib import Path from typing import Optional import webview # ----------------------------- # UI Shell # ----------------------------- SHELL_HTML = r""" SMPL NPZ Viewer
Select an NPZ file…
""" # ----------------------------- # Config # ----------------------------- @dataclass(frozen=True) class AppConfig: root: Path port: int smpl_npz_to_html: Path template: Path out_html: Path window_title: str width: int height: int auto_pick: bool debug: bool def parse_args() -> argparse.Namespace: ap = argparse.ArgumentParser(description="SMPL NPZ viewer UI.") ap.add_argument( "--port", type=int, default=8000, help="Local HTTP port for serving assets.", ) ap.add_argument( "--smpl_npz_to_html", type=Path, default=Path("smpl_npz_to_html.py"), help="Path to smpl_npz_to_html.py", ) ap.add_argument( "--template", type=Path, default=Path("templates/index_wooden_static.html"), help="HTML template path", ) ap.add_argument( "--out", type=Path, default=Path("_generated/vis.html"), help="Output vis.html path", ) ap.add_argument( "--title", type=str, default="NPZ Viewer", help="Window title" ) ap.add_argument("--width", type=int, default=800, help="Window width") ap.add_argument("--height", type=int, default=600, help="Window height") ap.add_argument( "--no-auto-pick", action="store_false", help="Do not auto-open file picker at startup", ) ap.add_argument( "--debug", action="store_true", help="Enable pywebview debug/devtools" ) return ap.parse_args() # ----------------------------- # Utilities # ----------------------------- def js_escape(s: str) -> str: return s.replace("\\", "\\\\").replace("'", "\\'") def ensure_exists(path: Path, what: str) -> None: if not path.exists(): raise FileNotFoundError(f"Missing {what}: {path}") def is_port_available(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: sock.bind(("127.0.0.1", port)) return True except OSError: return False # ----------------------------- # Core: server + generator + UI API # ----------------------------- class StaticServer: def __init__(self, root: Path, port: int): self.root = root self.port = port self._thread: 
Optional[threading.Thread] = None def start(self) -> None: def _serve(): os.chdir(self.root) # serve assets from project root handler = http.server.SimpleHTTPRequestHandler with socketserver.TCPServer( ("127.0.0.1", self.port), handler ) as httpd: httpd.serve_forever() self._thread = threading.Thread(target=_serve, daemon=True) self._thread.start() class MakeVisRunner: def __init__( self, root: Path, smpl_npz_to_html: Path, template: Path, out_html: Path, ): self.root = root self.smpl_npz_to_html = smpl_npz_to_html self.template = template self.out_html = out_html def run(self, npz_path: Path) -> None: ensure_exists(self.smpl_npz_to_html, "smpl_npz_to_html.py") ensure_exists(self.template, "template html") ensure_exists(npz_path, "npz file") self.out_html.parent.mkdir(parents=True, exist_ok=True) cmd = [ sys.executable, str(self.smpl_npz_to_html), "--npz", str(npz_path), "--template", str(self.template), "--out", str(self.out_html), ] subprocess.check_call(cmd, cwd=str(self.root)) def pick_npz_dialog(window) -> Optional[Path]: file_types = ("NPZ files (*.npz)", "All files (*.*)") # Prefer new enum if available; fallback to deprecated constant. 
try: dialog_open = webview.FileDialog.OPEN # type: ignore[attr-defined] paths = window.create_file_dialog( dialog_open, allow_multiple=False, file_types=file_types ) except Exception: paths = window.create_file_dialog( webview.OPEN_DIALOG, allow_multiple=False, file_types=file_types ) return Path(paths[0]) if paths else None class UIAPI: def __init__(self, window, cfg: AppConfig, runner: MakeVisRunner): self.window = window self.cfg = cfg self.runner = runner self._busy = False def pick_and_generate(self) -> None: if self._busy: return npz = pick_npz_dialog(self.window) if npz is None: return safe_name = js_escape(npz.name) self.window.evaluate_js( f"setBusy(true); setStatus('Generating: {safe_name}');" ) def worker(): self._busy = True try: self.runner.run(npz) rel = self.cfg.out_html.relative_to(self.cfg.root).as_posix() self.window.evaluate_js( f"setBusy(false); setStatus('Loaded: {safe_name}'); " f"showViewer('http://127.0.0.1:{self.cfg.port}/{rel}');" ) except Exception as e: msg = js_escape(str(e)) self.window.evaluate_js( f"setBusy(false); setStatus('Failed: {msg}');" ) finally: self._busy = False threading.Thread(target=worker, daemon=True).start() def auto_pick_once(self) -> None: # Called from window.events.loaded; ensure it runs once. 
if getattr(self, "_auto_done", False): return setattr(self, "_auto_done", True) if self.cfg.auto_pick: self.pick_and_generate() # ----------------------------- # Entrypoint # ----------------------------- def build_config(args: argparse.Namespace) -> AppConfig: root = Path(__file__).resolve().parent smpl_npz_to_html = ( (root / args.smpl_npz_to_html).resolve() if not args.smpl_npz_to_html.is_absolute() else args.smpl_npz_to_html ) template = ( (root / args.template).resolve() if not args.template.is_absolute() else args.template ) out_html = ( (root / args.out).resolve() if not args.out.is_absolute() else args.out ) return AppConfig( root=root, port=int(args.port), smpl_npz_to_html=smpl_npz_to_html, template=template, out_html=out_html, window_title=str(args.title), width=int(args.width), height=int(args.height), auto_pick=not bool(args.no_auto_pick), debug=bool(args.debug), ) def main() -> None: args = parse_args() cfg = build_config(args) if not is_port_available(cfg.port): raise RuntimeError( f"Port {cfg.port} is already in use. Try --port 8001" ) server = StaticServer(cfg.root, cfg.port) server.start() runner = MakeVisRunner( cfg.root, cfg.smpl_npz_to_html, cfg.template, cfg.out_html ) window = webview.create_window( cfg.window_title, html=SHELL_HTML, width=cfg.width, height=cfg.height ) api = UIAPI(window, cfg, runner) window.expose(api.pick_and_generate) # Auto pick once on initial load (optional) window.events.loaded += lambda: threading.Thread( target=api.auto_pick_once, daemon=True ).start() webview.start(debug=cfg.debug) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/env/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/env/isaaclab_components/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from holomotion.src.env.isaaclab_components.isaaclab_actions import ( build_actions_config, ActionsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_scene import ( build_scene_config, MotionTrackingSceneCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_simulator import ( build_simulator_config, ) from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import ( build_motion_tracking_commands_config, MoTrack_CommandsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_rewards import ( build_rewards_config, RewardsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_observation import ( build_observations_config, ObservationsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_termination import ( build_terminations_config, TerminationsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_domain_rand import ( build_domain_rand_config, EventsCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_curriculum import ( build_curriculum_config, CurriculumCfg, ) from holomotion.src.env.isaaclab_components.isaaclab_velocity_tracking_command import ( build_velocity_commands_config, VelTrack_CommandsCfg, ) ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_actions.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from isaaclab.utils import configclass import isaaclab.envs.mdp as mdp class ActionFunctions: """Collection of action function implementations.""" @staticmethod def joint_position_action( asset_name: str = "robot", joint_names: list[str] | None = None, use_default_offset: bool = True, scale: float = 1.0, ) -> mdp.JointPositionActionCfg: """Joint position control action.""" if joint_names is None: joint_names = [".*"] return mdp.JointPositionActionCfg( asset_name=asset_name, joint_names=joint_names, use_default_offset=use_default_offset, scale=scale, ) @staticmethod def joint_velocity_action( asset_name: str = "robot", joint_names: list[str] | None = None, scale: float = 1.0, ) -> mdp.JointVelocityActionCfg: """Joint velocity control action.""" if joint_names is None: joint_names = [".*"] return mdp.JointVelocityActionCfg( asset_name=asset_name, joint_names=joint_names, scale=scale, ) @staticmethod def joint_effort_action( asset_name: str = "robot", joint_names: list[str] | None = None, scale: float = 1.0, ) -> mdp.JointEffortActionCfg: """Joint effort control action.""" if joint_names is None: joint_names = [".*"] return mdp.JointEffortActionCfg( asset_name=asset_name, joint_names=joint_names, scale=scale, ) @configclass class ActionsCfg: """Container for action terms.""" pass def build_actions_config(actions_config_dict: dict) -> ActionsCfg: """Build IsaacLab-compatible ActionsCfg from a config dictionary.""" actions_cfg = ActionsCfg() for action_name, action_config in actions_config_dict.items(): action_type = action_config["type"] params = action_config.get("params", {}) if action_type == "joint_position": action_term = ActionFunctions.joint_position_action(**params) elif action_type == "joint_velocity": action_term = ActionFunctions.joint_velocity_action(**params) elif action_type == "joint_effort": action_term = ActionFunctions.joint_effort_action(**params) else: raise ValueError(f"Unknown action type: {action_type}") setattr(actions_cfg, action_name, 
action_term) return actions_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_curriculum.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. from isaaclab.envs import ManagerBasedRLEnv import torch from typing import Sequence from isaaclab.managers import CurriculumTermCfg from isaaclab.utils import configclass import isaaclab.envs.mdp as isaaclab_mdp from omegaconf import DictConfig, ListConfig, OmegaConf from typing import Any, Callable, Dict from loguru import logger from .isaaclab_domain_rand import DomainRandFunctions def _completion_rate_curriculum_get_level( env, *, term_tag: str = "default", metric_key: str = "Metrics/ref_motion/Task/Completion_Rate", num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, apply_on_startup: bool = True, startup_level: int = 0, state_prefix: str = "_cr_curr", ): base_env = getattr(env, "unwrapped", env) level_key = f"{state_prefix}_level" startup_key = f"{state_prefix}_startup_applied" last_up_key = f"{state_prefix}_last_upgrade_step" level_start_step_key = f"{state_prefix}_level_start_step" if not hasattr(base_env, level_key): setattr(base_env, level_key, -1) if not hasattr(base_env, startup_key): setattr(base_env, startup_key, False) if not hasattr(base_env, last_up_key): 
setattr(base_env, last_up_key, -(10**18)) if not hasattr(base_env, level_start_step_key): setattr(base_env, level_start_step_key, 0) step = int( getattr( base_env, "common_step_counter", getattr(env, "common_step_counter", 0), ) ) def _get_completion_stats(): metrics = getattr(base_env, "metrics", None) if isinstance(metrics, dict) and metric_key in metrics: val = metrics[metric_key] val = float(val.item()) if hasattr(val, "item") else float(val) return val, step return None def _thr_for_next(next_level: int) -> float: if not cr_thresholds: return 1.0 idx = max(0, min(next_level - 1, len(cr_thresholds) - 1)) return float(cr_thresholds[idx]) stats = _get_completion_stats() cur_level = int(getattr(base_env, level_key)) changed = False # -------- startup init -------- if apply_on_startup and not bool(getattr(base_env, startup_key)): init_level = int(max(0, min(int(startup_level), int(num_updates)))) setattr(base_env, level_key, max(cur_level, init_level)) setattr(base_env, startup_key, True) setattr(base_env, last_up_key, step) setattr(base_env, level_start_step_key, step) cur_level = int(getattr(base_env, level_key)) changed = True # -------- level upgrade -------- if cur_level < int(num_updates): if stats is not None: cr_val, _ = stats level_start_step = int(getattr(base_env, level_start_step_key)) stayed_steps = int(step - level_start_step) cooldown_ok = True if int(cooldown_steps) > 0: last_up = int(getattr(base_env, last_up_key)) cooldown_ok = (step - last_up) >= int(cooldown_steps) if cooldown_ok and stayed_steps >= int(min_steps_per_level): next_level = min(cur_level + 1, int(num_updates)) thr = _thr_for_next(next_level) if float(cr_val) >= float(thr): setattr(base_env, level_key, next_level) setattr(base_env, last_up_key, step) setattr(base_env, level_start_step_key, step) cur_level = next_level changed = True applied_key = ( f"{state_prefix}_applied_{str(term_tag)}_level_{int(cur_level)}" ) if not hasattr(base_env, applied_key): setattr(base_env, applied_key, 
False) already_applied = bool(getattr(base_env, applied_key)) need_apply = bool(changed) or (not already_applied) return int(cur_level), stats, bool(changed), bool(need_apply) def lin_vel_cmd_levels( env: ManagerBasedRLEnv, env_ids: Sequence[int], reward_term_name: str = "track_lin_vel_xy", ) -> torch.Tensor: command_term = env.command_manager.get_term("base_velocity") ranges = command_term.cfg.ranges limit_ranges = command_term.cfg.limit_ranges reward_term = env.reward_manager.get_term_cfg(reward_term_name) reward = ( torch.mean(env.reward_manager._episode_sums[reward_term_name][env_ids]) / env.max_episode_length_s ) if env.common_step_counter % env.max_episode_length == 0: if reward > reward_term.weight * 0.8: delta_command = torch.tensor([-0.1, 0.1], device=env.device) ranges.lin_vel_x = torch.clamp( torch.tensor(ranges.lin_vel_x, device=env.device) + delta_command, limit_ranges.lin_vel_x[0], limit_ranges.lin_vel_x[1], ).tolist() ranges.lin_vel_y = torch.clamp( torch.tensor(ranges.lin_vel_y, device=env.device) + delta_command, limit_ranges.lin_vel_y[0], limit_ranges.lin_vel_y[1], ).tolist() return torch.tensor(ranges.lin_vel_x[1], device=env.device) def ang_vel_cmd_levels( env: ManagerBasedRLEnv, env_ids: Sequence[int], reward_term_name: str = "track_ang_vel_z", ) -> torch.Tensor: command_term = env.command_manager.get_term("base_velocity") ranges = command_term.cfg.ranges limit_ranges = command_term.cfg.limit_ranges reward_term = env.reward_manager.get_term_cfg(reward_term_name) reward = ( torch.mean(env.reward_manager._episode_sums[reward_term_name][env_ids]) / env.max_episode_length_s ) if env.common_step_counter % env.max_episode_length == 0: if reward > reward_term.weight * 0.8: delta_command = torch.tensor([-0.1, 0.1], device=env.device) ranges.ang_vel_z = torch.clamp( torch.tensor(ranges.ang_vel_z, device=env.device) + delta_command, limit_ranges.ang_vel_z[0], limit_ranges.ang_vel_z[1], ).tolist() return torch.tensor(ranges.ang_vel_z[1], 
device=env.device) def robot_friction_range_by_completion_rate( env: ManagerBasedRLEnv, env_ids: Sequence[int], *, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", static_friction_target=(0.3, 1.6), dynamic_friction_target=(0.3, 1.2), enforce_dynamic_le_static: bool = True, asset_name: str = "robot", body_names: str = ".*", restitution_range=(0.0, 0.5), num_buckets: int = 64, anchor_quantile: float = 0.5, min_expand_frac: float = 0.0, ): base_env = getattr(env, "unwrapped", env) def _quantile(lo: float, hi: float, q: float) -> float: lo, hi = float(min(lo, hi)), float(max(lo, hi)) q = float(max(0.0, min(1.0, q))) return lo + (hi - lo) * q def _compute_ranges(level: int): level_i = int(max(0, min(level, int(num_updates)))) frac = ( 1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates))) ) s_lo_t, s_hi_t = map(float, static_friction_target) d_lo_t, d_hi_t = map(float, dynamic_friction_target) s_lo_t, s_hi_t = min(s_lo_t, s_hi_t), max(s_lo_t, s_hi_t) d_lo_t, d_hi_t = min(d_lo_t, d_hi_t), max(d_lo_t, d_hi_t) s_anchor = _quantile(s_lo_t, s_hi_t, anchor_quantile) d_anchor = _quantile(d_lo_t, d_hi_t, anchor_quantile) eps = float(min_expand_frac) band = eps + (1.0 - eps) * float(max(frac, 0.0)) s_lo = s_anchor - (s_anchor - s_lo_t) * band s_hi = s_anchor + (s_hi_t - s_anchor) * band d_lo = d_anchor - (d_anchor - d_lo_t) * band d_hi = d_anchor + (d_hi_t - d_anchor) * band s_lo, s_hi = min(s_lo, s_hi), max(s_lo, s_hi) d_lo, d_hi = min(d_lo, d_hi), max(d_lo, d_hi) if enforce_dynamic_le_static: d_hi = min(d_hi, s_hi) d_lo = min(d_lo, d_hi) return ( float(s_lo), float(s_hi), float(d_lo), float(d_hi), float(frac), int(level_i), ) level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag="fric", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, 
state_prefix=state_prefix, ) if not need_apply: return float(level) s_lo, s_hi, d_lo, d_hi, frac, level_i = _compute_ranges(int(level)) DomainRandFunctions._get_dr_rigid_body_material( env=env, env_ids=None, asset_name=asset_name, body_names=body_names, static_friction_range=(s_lo, s_hi), dynamic_friction_range=(d_lo, d_hi), restitution_range=tuple(restitution_range), num_buckets=int(num_buckets), ) setattr(base_env, f"{state_prefix}_applied_fric_level_{int(level)}", True) return float(level) def rigid_body_com_by_completion_rate( env: ManagerBasedRLEnv, env_ids: Sequence[int], *, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", asset_name: str = "robot", body_names: str = "torso_link", com_range_target: dict = { "x": (-0.025, 0.025), "y": (-0.05, 0.05), "z": (-0.05, 0.05), }, anchor_quantile: float = 0.5, min_expand_frac: float = 0.0, ): base_env = getattr(env, "unwrapped", env) def _quantile(lo: float, hi: float, q: float) -> float: lo, hi = float(min(lo, hi)), float(max(lo, hi)) q = float(max(0.0, min(1.0, q))) return lo + (hi - lo) * q level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag="com", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, state_prefix=state_prefix, ) if not need_apply: return float(level) level_i = int(max(0, min(int(level), int(num_updates)))) frac = ( 1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates))) ) band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float( max(frac, 0.0) ) com_range = {} for axis, (lo_t, hi_t) in com_range_target.items(): lo_t, hi_t = float(lo_t), float(hi_t) lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t) anchor = _quantile(lo_t, hi_t, anchor_quantile) lo = anchor - (anchor - lo_t) * band hi = anchor + (hi_t - anchor) * band com_range[axis] = (float(min(lo, hi)), 
float(max(lo, hi))) DomainRandFunctions._get_dr_rigid_body_com( env=env, env_ids=None, com_range=com_range, asset_name=asset_name, body_names=body_names, ) setattr(base_env, f"{state_prefix}_applied_com_level_{int(level)}", True) return float(level) def default_dof_pos_bias_by_completion_rate( env: ManagerBasedRLEnv, env_ids: Sequence[int], *, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", asset_name: str = "robot", joint_names: list[str] = (".*"), pos_distribution_params_target: tuple[float, float] = (-0.01, 0.01), operation: str = "add", distribution: str = "uniform", anchor_quantile: float = 0.5, min_expand_frac: float = 0.0, ): base_env = getattr(env, "unwrapped", env) level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag="dof", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, state_prefix=state_prefix, ) if not need_apply: return float(level) def _quantile(lo: float, hi: float, q: float) -> float: lo, hi = float(min(lo, hi)), float(max(lo, hi)) q = float(max(0.0, min(1.0, q))) return lo + (hi - lo) * q lo_t, hi_t = map(float, pos_distribution_params_target) lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t) level_i = int(max(0, min(int(level), int(num_updates)))) frac = ( 1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates))) ) band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float( max(frac, 0.0) ) anchor = _quantile(lo_t, hi_t, anchor_quantile) lo = anchor - (anchor - lo_t) * band hi = anchor + (hi_t - anchor) * band lo, hi = float(min(lo, hi)), float(max(lo, hi)) DomainRandFunctions._get_dr_default_dof_pos_bias( env=env, env_ids=None, asset_name=asset_name, joint_names=joint_names, pos_distribution_params=(lo, hi), operation=operation, distribution=distribution, ) setattr(base_env, 
f"{state_prefix}_applied_dof_level_{int(level)}", True) return float(level) def push_by_setting_velocity_range_by_completion_rate( env: ManagerBasedRLEnv, env_ids: Sequence[int], old_value, *, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", velocity_range_target: dict = { "x": (-0.5, 0.5), "y": (-0.5, 0.5), "z": (-0.2, 0.2), "roll": (-0.52, 0.52), "pitch": (-0.52, 0.52), "yaw": (-0.78, 0.78), }, anchor_quantile: float = 0.5, min_expand_frac: float = 0.0, ): base_env = getattr(env, "unwrapped", env) def _quantile(lo: float, hi: float, q: float) -> float: lo, hi = float(min(lo, hi)), float(max(lo, hi)) q = float(max(0.0, min(1.0, q))) return lo + (hi - lo) * q level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag="push", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, state_prefix=state_prefix, ) if not need_apply: return isaaclab_mdp.modify_term_cfg.NO_CHANGE level_i = int(max(0, min(int(level), int(num_updates)))) frac = ( 1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates))) ) band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float( max(frac, 0.0) ) new_params = dict(old_value) if isinstance(old_value, dict) else old_value current_velocity_range = {} for axis, (lo_t, hi_t) in velocity_range_target.items(): lo_t, hi_t = float(lo_t), float(hi_t) lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t) anchor = _quantile(lo_t, hi_t, anchor_quantile) lo = anchor - (anchor - lo_t) * band hi = anchor + (hi_t - anchor) * band current_velocity_range[axis] = [float(min(lo, hi)), float(max(lo, hi))] if isinstance(new_params, dict) or hasattr(new_params, "__setitem__"): if isinstance(new_params, dict): new_params = dict(new_params) new_params["velocity_range"] = current_velocity_range else: new_params["velocity_range"] = 
current_velocity_range else: setattr(new_params, "velocity_range", current_velocity_range) setattr(base_env, f"{state_prefix}_applied_push_level_{int(level)}", True) return new_params def randomize_actuator_gains_by_completion_rate( env: ManagerBasedRLEnv, env_ids: Sequence[int], *, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", asset_name: str = "robot", body_names: str = ".*", stiffness_distribution_params_target: tuple[float, float] = (0.9, 1.1), damping_distribution_params_target: tuple[float, float] = (0.9, 1.1), operation: str = "scale", distribution: str = "uniform", anchor_quantile: float = 0.5, min_expand_frac: float = 0.0, ): base_env = getattr(env, "unwrapped", env) level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag="gains", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, state_prefix=state_prefix, ) if not need_apply: return float(level) def _quantile(lo: float, hi: float, q: float) -> float: lo, hi = float(min(lo, hi)), float(max(lo, hi)) q = float(max(0.0, min(1.0, q))) return lo + (hi - lo) * q level_i = int(max(0, min(int(level), int(num_updates)))) frac = ( 1.0 if int(num_updates) <= 0 else (level_i / float(int(num_updates))) ) band = float(min_expand_frac) + (1.0 - float(min_expand_frac)) * float( max(frac, 0.0) ) # stiffness ks_lo_t, ks_hi_t = map(float, stiffness_distribution_params_target) ks_lo_t, ks_hi_t = min(ks_lo_t, ks_hi_t), max(ks_lo_t, ks_hi_t) ks_anchor = _quantile(ks_lo_t, ks_hi_t, anchor_quantile) ks_lo = ks_anchor - (ks_anchor - ks_lo_t) * band ks_hi = ks_anchor + (ks_hi_t - ks_anchor) * band ks_lo, ks_hi = float(min(ks_lo, ks_hi)), float(max(ks_lo, ks_hi)) # damping kd_lo_t, kd_hi_t = map(float, damping_distribution_params_target) kd_lo_t, kd_hi_t = min(kd_lo_t, kd_hi_t), max(kd_lo_t, kd_hi_t) kd_anchor = 
_quantile(kd_lo_t, kd_hi_t, anchor_quantile) kd_lo = kd_anchor - (kd_anchor - kd_lo_t) * band kd_hi = kd_anchor + (kd_hi_t - kd_anchor) * band kd_lo, kd_hi = float(min(kd_lo, kd_hi)), float(max(kd_lo, kd_hi)) DomainRandFunctions._get_dr_randomize_actuator_gains( env=env, env_ids=None, asset_name=asset_name, body_names=body_names, stiffness_distribution_params=(ks_lo, ks_hi), damping_distribution_params=(kd_lo, kd_hi), operation=operation, distribution=distribution, ) setattr(base_env, f"{state_prefix}_applied_gains_level_{int(level)}", True) return float(level) def reward_term_weight_by_completion_rate( env, env_ids, *, reward_term_name: str, final_weight: float, start_scale: float = 0.1, num_updates: int = 5, cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40), min_steps_per_level: int = 300, cooldown_steps: int = 0, state_prefix: str = "_cr_curr", ): base_env = getattr(env, "unwrapped", env) level, stats, changed, need_apply = _completion_rate_curriculum_get_level( env, term_tag=f"reward_{reward_term_name}", num_updates=num_updates, cr_thresholds=cr_thresholds, min_steps_per_level=min_steps_per_level, cooldown_steps=cooldown_steps, state_prefix=state_prefix, ) progress = 1.0 if num_updates <= 0 else float(level) / float(num_updates) start_weight = float(final_weight) * float(start_scale) new_weight = start_weight + progress * (float(final_weight) - start_weight) reward_cfg = env.reward_manager.get_term_cfg(reward_term_name) old_weight = float(reward_cfg.weight) if not need_apply: return float(level) reward_cfg.weight = float(new_weight) env.reward_manager.set_term_cfg(reward_term_name, reward_cfg) setattr( base_env, f"{state_prefix}_reward_weight_{reward_term_name}", float(new_weight), ) setattr( base_env, f"{state_prefix}_applied_reward_{reward_term_name}_level_{int(level)}", True, ) return float(level) @configclass class CurriculumCfg: pass def build_curriculum_config(curriculum_config_dict: dict) -> CurriculumCfg: """ Build IsaacLab-compatible CurriculumCfg from a 
config dictionary. """ if isinstance(curriculum_config_dict, (DictConfig, ListConfig)): curriculum_config_dict = OmegaConf.to_container( curriculum_config_dict, resolve=True ) curriculum_cfg = CurriculumCfg() cfg_dict: Dict[str, Any] = dict(curriculum_config_dict or {}) def _resolve_callable(name: Any) -> Callable: if callable(name): return name if isinstance(name, str) and name.startswith("isaaclab_mdp."): name = name.split(".", 1)[1] fn = globals().get(name) if callable(fn): return fn fn = getattr(isaaclab_mdp, name, None) if callable(fn): return fn if hasattr(isaaclab_mdp, "curriculums"): fn = getattr(isaaclab_mdp.curriculums, name, None) if callable(fn): return fn raise ValueError(f"Unknown curriculum function: {name}") def _normalize_modify_params(x: Any) -> Any: if isinstance(x, list): # many configs express tuples as YAML lists return tuple(_normalize_modify_params(v) for v in x) if isinstance(x, dict): return {k: _normalize_modify_params(v) for k, v in x.items()} return x def _fix_params(params: Dict[str, Any]) -> Dict[str, Any]: params = dict(params or {}) if "modify_fn" in params and isinstance( params["modify_fn"], (str, Callable) ): params["modify_fn"] = _resolve_callable(params["modify_fn"]) if "modify_params" in params and isinstance( params["modify_params"], dict ): params["modify_params"] = _normalize_modify_params( params["modify_params"] ) return params global_enabled = cfg_dict.pop("enabled", True) if not global_enabled: return curriculum_cfg for term_name, term_cfg in cfg_dict.items(): if term_cfg is None: term_cfg = {} if isinstance(term_cfg, bool): if not term_cfg: continue term_cfg = {} if not isinstance(term_cfg, dict): raise TypeError( f"[build_curriculum_config] term '{term_name}' must be a dict/bool/None, got {type(term_cfg)}" ) if not term_cfg.get("enabled", True): continue func_field = term_cfg.get("func", None) if func_field is None: func = _resolve_callable(term_name) else: func = _resolve_callable(func_field) params = 
_fix_params(term_cfg.get("params", {}) or {}) setattr( curriculum_cfg, term_name, CurriculumTermCfg(func=func, params=params), ) return curriculum_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_domain_rand.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import torch from typing import Literal import isaaclab.utils.math as math_utils from isaaclab.assets import Articulation import isaaclab.envs.mdp as isaaclab_mdp from isaaclab.envs.mdp.events import _randomize_prop_by_op from isaaclab.managers import SceneEntityCfg, EventTermCfg from isaaclab.utils import configclass from isaaclab.envs import ManagerBasedEnv from isaaclab.managers import EventTermCfg class DomainRandFunctions: @staticmethod def _get_dr_default_dof_pos_bias( env: ManagerBasedEnv, env_ids: torch.Tensor | None, asset_name: str = "robot", joint_names: list[str] = (".*"), pos_distribution_params: tuple[float, float] | None = None, operation: Literal["add", "scale", "abs"] = "abs", distribution: Literal[ "uniform", "log_uniform", "gaussian" ] = "uniform", ): asset_cfg = SceneEntityCfg(asset_name, joint_names=joint_names) asset_cfg.resolve(env.scene) asset: Articulation = env.scene[asset_name] asset.data.default_joint_pos_nominal = torch.clone( asset.data.default_joint_pos[0] ) if env_ids is None: env_ids = 
torch.arange(env.scene.num_envs, device=asset.device) if asset_cfg.joint_ids == slice(None): joint_ids = slice(None) else: joint_ids = torch.tensor( asset_cfg.joint_ids, dtype=torch.int, device=asset.device, ) if pos_distribution_params is not None: pos = asset.data.default_joint_pos.to(asset.device).clone() pos = _randomize_prop_by_op( pos, pos_distribution_params, env_ids, joint_ids, operation=operation, distribution=distribution, )[env_ids][:, joint_ids] if env_ids != slice(None) and joint_ids != slice(None): env_ids = env_ids[:, None] asset.data.default_joint_pos[env_ids, joint_ids] = pos env.action_manager.get_term("dof_pos")._offset[ env_ids, joint_ids ] = pos @staticmethod def _get_dr_rigid_body_com( env: ManagerBasedEnv, env_ids: torch.Tensor | None, com_range: dict[str, tuple[float, float]], asset_name: str = "robot", body_names: str = "torso_link", ): asset_cfg = SceneEntityCfg(asset_name, body_names=body_names) asset_cfg.resolve(env.scene) return isaaclab_mdp.events.randomize_rigid_body_com( env, env_ids, com_range, asset_cfg, ) @staticmethod def _get_dr_rigid_body_material( env: ManagerBasedEnv, env_ids: torch.Tensor | None, asset_name: str = "robot", body_names: str = ".*", static_friction_range: tuple[float, float] | None = None, dynamic_friction_range: tuple[float, float] | None = None, restitution_range: tuple[float, float] | None = None, num_buckets: int = 64, ): asset_cfg = SceneEntityCfg(asset_name, body_names=body_names) asset_cfg.resolve(env.scene) eveent_cfg = EventTermCfg( func=isaaclab_mdp.events.randomize_rigid_body_material, params={ "asset_cfg": asset_cfg, "static_friction_range": static_friction_range, "dynamic_friction_range": dynamic_friction_range, "restitution_range": restitution_range, "num_buckets": num_buckets, }, ) material_randomizer = ( isaaclab_mdp.events.randomize_rigid_body_material(eveent_cfg, env) ) return material_randomizer(env, env_ids, **eveent_cfg.params) @staticmethod def _get_dr_push_by_setting_velocity( env: 
ManagerBasedEnv, env_ids: torch.Tensor, velocity_range: dict[str, tuple[float, float]], ): return isaaclab_mdp.events.push_by_setting_velocity( env, env_ids, velocity_range, ) @staticmethod def _get_dr_randomize_actuator_gains( env: ManagerBasedEnv, env_ids: torch.Tensor, asset_name: str = "robot", body_names: str = ".*", stiffness_distribution_params: tuple[float, float] | None = None, damping_distribution_params: tuple[float, float] | None = None, operation: Literal["add", "scale", "abs"] = "abs", distribution: Literal[ "uniform", "log_uniform", "gaussian" ] = "uniform", ): asset_cfg = SceneEntityCfg(asset_name, body_names=body_names) asset_cfg.resolve(env.scene) return isaaclab_mdp.events.randomize_actuator_gains( env, env_ids, asset_cfg, stiffness_distribution_params, damping_distribution_params, operation=operation, distribution=distribution, ) @staticmethod def _get_dr_randomize_mass( env: ManagerBasedEnv, env_ids: torch.Tensor, asset_name: str = "robot", body_names: str = ".*", mass_range: tuple[float, float] | None = None, ): asset_cfg = SceneEntityCfg(asset_name, body_names=body_names) asset_cfg.resolve(env.scene) return isaaclab_mdp.events.randomize_rigid_body_mass( env, env_ids, mass_distribution_params=mass_range, asset_cfg=asset_cfg, operation="add", ) @configclass class EventsCfg: pass def build_domain_rand_config(domain_rand_config_dict: dict) -> EventsCfg: """Build IsaacLab-compatible EventsCfg from a config dictionary.""" events_cfg = EventsCfg() for event_name, cfg in domain_rand_config_dict.items(): # Keep non-event config under `domain_rand` available for Hydra # references without forcing it through the Isaac Lab event builder. 
if not (isinstance(cfg, dict) and "mode" in cfg): continue try: func = getattr(DomainRandFunctions, f"_get_dr_{event_name}") except AttributeError as exc: raise AttributeError( f"Unknown domain randomization event '{event_name}'" ) from exc term = EventTermCfg( func=func, **cfg, ) setattr(events_cfg, event_name, term) return events_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_motion_tracking_command.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from dataclasses import MISSING from typing import Sequence import time import json from collections import defaultdict from typing import Dict, List, Optional import numpy as np from tqdm import tqdm from scipy.spatial.transform import Rotation as sRot import isaaclab.envs.mdp as mdp import isaaclab.sim as sim_utils import isaaclab.utils.math as isaaclab_math import torch from isaaclab.actuators import ImplicitActuatorCfg from isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg from isaaclab.envs.mdp.actions import JointEffortActionCfg from isaaclab.managers import ( ActionTermCfg, CommandTerm, CommandTermCfg, EventTermCfg as EventTerm, ObservationGroupCfg, ObservationGroupCfg as ObsGroup, ObservationTermCfg, ObservationTermCfg as ObsTerm, RewardTermCfg, TerminationTermCfg, ) from isaaclab.markers import ( VisualizationMarkers, VisualizationMarkersCfg, ) from holomotion.src.training.h5_dataloader import ( Hdf5MotionDataset, Hdf5RootDofDataset, MotionClipBatchCache, build_motion_datasets_from_cfg, ) import os from isaaclab.markers.config import SPHERE_MARKER_CFG from isaaclab.sim import PreviewSurfaceCfg from isaaclab.scene import InteractiveSceneCfg from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns from isaaclab.sim import PhysxCfg, SimulationCfg from isaaclab.terrains import TerrainImporterCfg from isaaclab.utils import configclass from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise from omegaconf import OmegaConf from holomotion.src.utils.isaac_utils.rotations import ( calc_heading_quat_inv, get_euler_xyz, my_quat_rotate, quat_inverse, quat_mul, quat_rotate, quat_rotate_inverse, quaternion_to_matrix, wrap_to_pi, wxyz_to_xyzw, xyzw_to_wxyz, ) from holomotion.src.utils.reference_prefix import ( resolve_reference_tensor_key, ) from loguru import logger class RefMotionCommand(CommandTerm): cfg: CommandTermCfg def __init__( self, cfg, env: 
ManagerBasedRLEnv, ): # print(cfg) super().__init__(cfg, env) self._env = env self._is_evaluating = self.cfg.is_evaluating self._runtime_process_id = int(self.cfg.process_id) self._runtime_num_processes = max(1, int(self.cfg.num_processes)) self._init_robot_handle() self._init_buffers() self._init_motion_lib() # # self._init_tracking_config() def _init_tracking_config(self, config): self.log_dict_holomotion = {} self.log_dict_nonreduced_holomotion = {} self.log_dict_nonreduced = {} self.log_dict = {} if "head_hand_bodies" in config: self.motion_tracking_id = [ self.robot.body_names.index(link) for link in config.head_hand_bodies ] if "leg_body_names" in config: self.lower_body_id = [ self.robot.body_names.index(link) for link in config.leg_body_names ] if "arm_body_names" in config: self.upper_body_id = [ self.robot.body_names.index(link) for link in config.arm_body_names ] if "leg_dof_names" in config: self.lower_body_joint_ids = [ config.dof_names.index(link) for link in config.leg_dof_names ] if "arm_dof_names" in config: self.upper_body_joint_ids = [ config.dof_names.index(link) for link in config.arm_dof_names ] if "waist_dof_names" in config: self.waist_dof_indices = [ config.dof_names.index(link) for link in config.waist_dof_names ] @staticmethod def _amp_filter_names_by_prefix( names: Sequence[str], prefix: str, keywords: Sequence[str] ) -> list[str]: return [ name for name in names if name.startswith(prefix) and any(key in name for key in keywords) ] @staticmethod def _amp_pick_first_name( names: Sequence[str], patterns: Sequence[str] ) -> str | None: for pattern in patterns: for name in names: if pattern in name: return name return None def _resolve_motion_cache_stage_device( self, cache_cfg: Dict[str, object] ) -> Optional[torch.device]: raw_stage_device = cache_cfg.get("device", "cuda") if isinstance(raw_stage_device, torch.device): if raw_stage_device.type == "cpu": return None if raw_stage_device.type != "cuda": raise ValueError( f"Unsupported motion 
cache device: {raw_stage_device}" ) if raw_stage_device.index is not None: return raw_stage_device if not torch.cuda.is_available(): return None local_rank_env = os.environ.get("LOCAL_RANK") if local_rank_env is not None: local_rank = int(local_rank_env) device_count = int(torch.cuda.device_count()) if 0 <= local_rank < device_count: return torch.device("cuda", local_rank) return torch.device("cuda", int(torch.cuda.current_device())) stage_device = str(raw_stage_device).strip().lower() if stage_device in ("none", "cpu"): return None if stage_device == "cuda": if isinstance(self.device, torch.device): if self.device.type == "cuda": return self.device return None device_str = str(self.device).strip().lower() if device_str.startswith("cuda"): return torch.device(device_str) if not torch.cuda.is_available(): return None local_rank_env = os.environ.get("LOCAL_RANK") if local_rank_env is not None: local_rank = int(local_rank_env) device_count = int(torch.cuda.device_count()) if 0 <= local_rank < device_count: return torch.device("cuda", local_rank) return torch.device("cuda", int(torch.cuda.current_device())) if stage_device.startswith("cuda:"): return torch.device(stage_device) raise ValueError( f"Unsupported motion cache device config: {raw_stage_device}" ) def _init_motion_lib(self): mcfg = OmegaConf.create(self.cfg.motion_lib_cfg) self.mcfg = mcfg backend = str(mcfg.get("backend", "hdf5")).lower() self._motion_cache = None if backend in ("hdf5", "hdf5_simple"): # Support multi-root configuration while keeping single-root # behavior fully backward compatible. 
train_hdf5_roots = mcfg.get("train_hdf5_roots", None) val_hdf5_roots = mcfg.get("val_hdf5_roots", None) if train_hdf5_roots: train_roots = [str(r) for r in train_hdf5_roots] else: hdf5_root = mcfg.get("hdf5_root") if hdf5_root is None: raise ValueError("hdf5_root is required") train_roots = [str(hdf5_root)] val_hdf5_root = mcfg.get("val_hdf5_root", None) if val_hdf5_roots: val_roots = [str(r) for r in val_hdf5_roots] elif val_hdf5_root is not None and str(val_hdf5_root) != str( train_roots[0] ): val_roots = [str(val_hdf5_root)] else: val_roots = None train_manifest_paths = [ os.path.join(root, "manifest.json") for root in train_roots ] for mp in train_manifest_paths: if not os.path.exists(mp): raise FileNotFoundError( f"HDF5 manifest not found at {mp}. " "Please set robot.motion.hdf5_root/train_hdf5_roots to " "the correct path!" ) max_frame_length = int(mcfg.get("max_frame_length", 500)) min_frame_length = int(mcfg.get("min_frame_length", 1)) world_frame_norm = bool( mcfg.get("world_frame_normalization", True) ) cache_cfg = mcfg.get("cache", {}) allowed_prefixes = cache_cfg.get( "allowed_prefixes", ["ref_", "ft_ref_"], ) if len(train_manifest_paths) == 1: logger.info( f"Loading HDF5 training dataset from {train_manifest_paths[0]}" ) else: logger.info( f"Loading HDF5 training dataset from manifests: " f"{train_manifest_paths}" ) train_dataset = Hdf5MotionDataset( manifest_path=train_manifest_paths if len(train_manifest_paths) > 1 else train_manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_frame_length, handpicked_motion_names=mcfg.get( "handpicked_motion_names", None ), excluded_motion_names=mcfg.get("excluded_motion_names", None), world_frame_normalization=world_frame_norm, allowed_prefixes=allowed_prefixes, ) if len(train_dataset) == 0: raise ValueError( "Training dataset is empty. 
Check that all manifests " "contain valid clips with length " f">= {min_frame_length}" ) logger.info(f"Loaded {len(train_dataset)} training motion windows") train_num_clips = len(train_dataset.clips) train_total_frames = sum( int(meta.get("length", 0)) for meta in train_dataset.clips.values() ) fps_used = int(self.cfg.target_fps) train_duration_s = ( float(train_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) if len(train_roots) == 1: logger.info( f"Train dataset: root={train_roots[0]}, " f"manifest={train_manifest_paths[0]}" ) else: logger.info( f"Train dataset: roots={train_roots}, " f"manifests={train_manifest_paths}" ) logger.info( f"Train clips={train_num_clips}, frames={train_total_frames}, " f"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps" ) excluded_names = mcfg.get("excluded_motion_names", None) if excluded_names: excluded_set = set(excluded_names) excluded_clip_keys = [ k for k in train_dataset.clips.keys() if k in excluded_set ] excluded_num_clips = len(excluded_clip_keys) excluded_total_frames = sum( int(train_dataset.clips[k].get("length", 0)) for k in excluded_clip_keys ) excluded_duration_s = ( float(excluded_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) left_num_clips = max(0, train_num_clips - excluded_num_clips) left_total_frames = max( 0, train_total_frames - excluded_total_frames ) left_duration_s = ( float(left_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) logger.info( f"Excluded (by name): clips={excluded_num_clips}, " f"frames={excluded_total_frames}, " f"duration={excluded_duration_s / 3600:.2f}h" ) logger.info( f"Remaining after exclusion: clips={left_num_clips}, " f"frames={left_total_frames}, " f"duration={left_duration_s / 3600:.2f}h" ) val_dataset = None if val_roots is not None: val_manifest_paths = [ os.path.join(root, "manifest.json") for root in val_roots ] for mp in val_manifest_paths: if not os.path.exists(mp): raise FileNotFoundError( f"HDF5 validation manifest not found at {mp}. 
" "Please set robot.motion.val_hdf5_root/" "val_hdf5_roots to the correct path!" ) if len(val_manifest_paths) == 1: logger.info( f"Loading HDF5 validation dataset from {val_manifest_paths[0]}" ) else: logger.info( "Loading HDF5 validation dataset from manifests: " f"{val_manifest_paths}" ) val_dataset = Hdf5MotionDataset( manifest_path=val_manifest_paths if len(val_manifest_paths) > 1 else val_manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_frame_length, handpicked_motion_names=mcfg.get( "handpicked_motion_names", None ), excluded_motion_names=mcfg.get( "excluded_motion_names", None ), world_frame_normalization=world_frame_norm, allowed_prefixes=allowed_prefixes, ) logger.info( f"Loaded {len(val_dataset)} validation motion windows" ) val_num_clips = len(val_dataset.clips) val_total_frames = sum( int(meta.get("length", 0)) for meta in val_dataset.clips.values() ) val_duration_s = ( float(val_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) if len(val_roots) == 1: logger.info( f"Val dataset: root={val_roots[0]}, " f"manifest={val_manifest_paths[0]}" ) else: logger.info( f"Val dataset: roots={val_roots}, " f"manifests={val_manifest_paths}" ) logger.info( f"Val clips={val_num_clips}, frames={val_total_frames}, " f"duration={val_duration_s / 3600:.1f}h @ {fps_used} fps" ) else: logger.info( "Validation dataset: using training dataset " "(no separate val manifest found)" ) dataloader_cfg = mcfg.get("dataloader", {}) stage_device = self._resolve_motion_cache_stage_device(cache_cfg) self._motion_cache = MotionClipBatchCache( train_dataset=train_dataset, val_dataset=val_dataset, batch_size=int(cache_cfg.get("max_num_clips", 1024)), stage_device=stage_device, num_workers=int(dataloader_cfg.get("num_workers", 4)), prefetch_factor=dataloader_cfg.get("prefetch_factor", None), pin_memory=bool(dataloader_cfg.get("pin_memory", True)), persistent_workers=bool( dataloader_cfg.get("persistent_workers", True) ), batch_progress_bar=bool( 
cache_cfg.get("batch_progress_bar", False) ), sampler_rank=int(self.cfg.process_id), sampler_world_size=int(self.cfg.num_processes), allowed_prefixes=allowed_prefixes, swap_interval_steps=int( cache_cfg.get("swap_interval_steps", max_frame_length) ), seed=int(self.cfg.seed), loader_timeout=float(dataloader_cfg.get("timeout", 0.0)), ) cache = self._motion_cache logger.info( "DataLoader params: " f"batch_size={cache._batch_size}, " f"num_workers={cache._num_workers}, " f"prefetch_factor={cache._prefetch_factor}, " f"pin_memory={cache._pin_memory}, " f"persistent_workers={cache._persistent_workers}" ) logger.info( "Sampler/Cache params: " f"rank={cache._sampler_rank}/{cache._sampler_world_size}, " f"device={cache._stage_device}, " f"swap_interval_steps={cache.swap_interval_steps}" ) self._motion_lib = None elif backend == "hdf5_v2": max_frame_length = int(mcfg.get("max_frame_length", 500)) min_frame_length = int(mcfg.get("min_frame_length", 1)) world_frame_norm = bool( mcfg.get("world_frame_normalization", True) ) cache_cfg = mcfg.get("cache", {}) allowed_prefixes = cache_cfg.get( "allowed_prefixes", ["ref_", "ft_ref_"], ) train_hdf5_roots = mcfg.get("train_hdf5_roots", None) if train_hdf5_roots: train_roots = [str(r) for r in train_hdf5_roots] else: hdf5_root = mcfg.get("hdf5_root", None) train_roots = [str(hdf5_root)] if hdf5_root is not None else [] train_manifest_paths = [ os.path.join(root, "manifest.json") for root in train_roots ] ( train_dataset, val_dataset, cache_kwargs, ) = build_motion_datasets_from_cfg( motion_cfg=mcfg, max_frame_length=max_frame_length, min_window_length=min_frame_length, world_frame_normalization=world_frame_norm, handpicked_motion_names=mcfg.get( "handpicked_motion_names", None ), excluded_motion_names=mcfg.get("excluded_motion_names", None), allowed_prefixes=allowed_prefixes, ) if len(train_dataset) == 0: raise ValueError( "Training dataset is empty. 
Check that all HDF5 v2 " "roots contain valid clips with length " f">= {min_frame_length}" ) if len(train_manifest_paths) == 1: logger.info( f"Loading HDF5 v2 training dataset from {train_manifest_paths[0]}" ) else: logger.info( "Loading HDF5 v2 training dataset from manifests: " f"{train_manifest_paths}" ) fps_used = int(self.cfg.target_fps) logger.info(f"Loaded {len(train_dataset)} training motion windows") train_num_clips = len(train_dataset.clips) train_total_frames = sum( int(meta.get("length", 0)) for meta in train_dataset.clips.values() ) train_duration_s = ( float(train_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) logger.info( f"Train clips={train_num_clips}, frames={train_total_frames}, " f"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps" ) if len(train_roots) == 1: logger.info( f"Train dataset: root={train_roots[0]}, " f"manifest={train_manifest_paths[0]}" ) elif len(train_roots) > 1: logger.info( f"Train dataset: roots={train_roots}, " f"manifests={train_manifest_paths}" ) excluded_names = mcfg.get("excluded_motion_names", None) if excluded_names: excluded_set = set(excluded_names) excluded_clip_keys: List[str] = [] if isinstance(train_dataset, Hdf5RootDofDataset): for key, meta in train_dataset.clips.items(): aliases = train_dataset._build_motion_key_aliases( key, meta ) if any(alias in excluded_set for alias in aliases): excluded_clip_keys.append(key) else: excluded_clip_keys = [ k for k in train_dataset.clips.keys() if k in excluded_set ] excluded_num_clips = len(excluded_clip_keys) excluded_total_frames = sum( int(train_dataset.clips[k].get("length", 0)) for k in excluded_clip_keys ) excluded_duration_s = ( float(excluded_total_frames) / float(fps_used) if fps_used > 0 else 0.0 ) remaining_num_clips = train_num_clips - excluded_num_clips remaining_total_frames = ( train_total_frames - excluded_total_frames ) remaining_duration_s = train_duration_s - excluded_duration_s logger.info( "Excluded (by name): " 
f"clips={excluded_num_clips}, frames={excluded_total_frames}, " f"duration={excluded_duration_s / 3600:.2f}h" ) logger.info( "Remaining after exclusion: " f"clips={remaining_num_clips}, frames={remaining_total_frames}, " f"duration={remaining_duration_s / 3600:.2f}h" ) if val_dataset is None: logger.info( "Validation dataset: using training dataset " "(no separate val HDF5 v2 roots found)" ) dataloader_cfg = mcfg.get("dataloader", {}) stage_device = self._resolve_motion_cache_stage_device(cache_cfg) self._motion_cache = MotionClipBatchCache( train_dataset=train_dataset, val_dataset=val_dataset, batch_size=int(cache_cfg.get("max_num_clips", 1024)), stage_device=stage_device, num_workers=int(dataloader_cfg.get("num_workers", 4)), prefetch_factor=dataloader_cfg.get("prefetch_factor", None), pin_memory=bool(dataloader_cfg.get("pin_memory", True)), persistent_workers=bool( dataloader_cfg.get("persistent_workers", True) ), batch_progress_bar=bool( cache_cfg.get("batch_progress_bar", False) ), sampler_rank=int(self.cfg.process_id), sampler_world_size=int(self.cfg.num_processes), allowed_prefixes=allowed_prefixes, swap_interval_steps=int( cache_cfg.get("swap_interval_steps", max_frame_length) ), seed=int(self.cfg.seed), loader_timeout=float(dataloader_cfg.get("timeout", 0.0)), **cache_kwargs, ) cache = self._motion_cache logger.info( "DataLoader params: " f"batch_size={cache._batch_size}, " f"num_workers={cache._num_workers}, " f"prefetch_factor={cache._prefetch_factor}, " f"pin_memory={cache._pin_memory}, " f"persistent_workers={cache._persistent_workers}" ) logger.info( "Sampler/Cache params: " f"rank={cache._sampler_rank}/{cache._sampler_world_size}, " f"device={cache._stage_device}, " f"swap_interval_steps={cache.swap_interval_steps}" ) self._motion_lib = None else: raise ValueError(f"Unsupported motion backend: {backend}") sampling_strategy_cfg = mcfg.get("sampling_strategy", None) if sampling_strategy_cfg is None: sampling_strategy = "uniform" else: sampling_strategy 
= str(sampling_strategy_cfg).lower() if sampling_strategy == "weighted_bin": weighted_bin_cfg = mcfg.get("weighted_bin", {}) self._motion_cache.enable_weighted_bin_sampling( cfg=dict(weighted_bin_cfg or {}) ) elif sampling_strategy == "curriculum": curriculum_cfg = dict(mcfg.get("curriculum", {}) or {}) self._motion_cache.enable_cache_curriculum_sampling( cfg=curriculum_cfg ) elif sampling_strategy not in ("uniform", "curriculum"): raise ValueError( f"Invalid sampling_strategy '{sampling_strategy}'. " "Expected one of ['curriculum', 'uniform', 'weighted_bin']." ) self._sampling_strategy = sampling_strategy self._init_per_env_cache() def setup_dumping_dir(self, log_dir: str): mcfg = self.mcfg base_log_dir = str(log_dir) if self._sampling_strategy == "curriculum": curriculum_dump_dir = os.path.join( base_log_dir, "cache_curriculum_window_scores" ) self._motion_cache.set_cache_curriculum_dump_dir( curriculum_dump_dir ) self._dump_sampled_motion_keys_enabled = bool( mcfg.get("dump_sampled_motion_keys", False) ) if not self._dump_sampled_motion_keys_enabled: return self._dump_sampled_motion_keys_interval = max( 1, int(mcfg.get("dump_sampled_motion_keys_interval", 1)) ) dump_dir_cfg = "sampled_motion_cache_keys" self._dump_sampled_motion_keys_dir = os.path.join( base_log_dir, dump_dir_cfg ) if self._dump_sampled_motion_keys_enabled: os.makedirs(self._dump_sampled_motion_keys_dir, exist_ok=True) logger.info( f"Dumping sampled motion keys to {self._dump_sampled_motion_keys_dir}" ) def set_runtime_distributed_context( self, *, process_id: int, num_processes: int ) -> None: self._runtime_process_id = int(process_id) self._runtime_num_processes = max(1, int(num_processes)) def set_motion_cache_seed( self, seed: int, *, reinitialize: bool = True ) -> None: self._motion_cache.set_seed(int(seed), reinitialize=reinitialize) if reinitialize: self._init_per_env_cache() def close(self) -> None: """Release motion cache resources for this command term.""" if self._motion_cache is not 
None: self._motion_cache.close() self._motion_cache = None def _init_per_env_cache(self): """Initialize per-env cache for motion tracking.""" self._clip_indices = torch.zeros( self.num_envs, dtype=torch.long, device=self.device ) self._frame_indices = torch.zeros( self.num_envs, dtype=torch.long, device=self.device ) self._swap_pending = False self._swap_step_counter = 0 # Initial assignment clip_idx, frame_idx = self._motion_cache.sample_env_assignments( self.num_envs, self.cfg.n_fut_frames, self.device, deterministic_start=(self._is_evaluating), ) self._clip_indices[:] = clip_idx self._frame_indices[:] = frame_idx self._start_frame_indices[:] = frame_idx self._reward_sum_since_assign[:] = 0.0 self._step_count_since_assign[:] = 0.0 self._update_ref_motion_state_from_cache() def _maybe_dump_sampled_motion_keys(self) -> None: if not self._dump_sampled_motion_keys_enabled: return swap_index = int(self._motion_cache.swap_index) if swap_index <= 0: return if swap_index % self._dump_sampled_motion_keys_interval != 0: return current_batch = self._motion_cache.current_batch window_indices = current_batch.window_indices.detach().cpu().tolist() cache_scores = None cache_selection_counts = None cache_in_prioritized_pool = None curriculum_state_step = None score_bundle = ( self._motion_cache.cache_curriculum_scores_for_window_indices( current_batch.window_indices ) ) if score_bundle is not None: score_tensor, state, version = score_bundle cache_scores = score_tensor.detach().cpu().tolist() cache_selection_counts = ( state["selection_count"].detach().cpu().tolist() ) cache_in_prioritized_pool = ( state["in_prioritized_pool"].detach().cpu().tolist() ) curriculum_state_step = int(version) payload = { "swap_index": swap_index, "sampling_strategy": str(self._sampling_strategy), "num_keys": int(len(current_batch.motion_keys)), "motion_keys": list(current_batch.motion_keys), "raw_motion_keys": list(current_batch.raw_motion_keys), "window_indices": window_indices, 
"cache_sampling_score": cache_scores, "cache_sampling_count": cache_selection_counts, "cache_in_prioritized_pool": cache_in_prioritized_pool, "curriculum_state_step": curriculum_state_step, } file_name = ( f"sampled_motion_keys_rank_{self._runtime_process_id:04d}_swap_" f"{swap_index:06d}.json" ) output_path = os.path.join( self._dump_sampled_motion_keys_dir, file_name ) with open(output_path, "w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2) handle.write("\n") def _init_robot_handle(self): self.robot: Articulation = self._env.scene[self.cfg.asset_name] self.anchor_bodylink_name = self.cfg.anchor_bodylink_name self.anchor_bodylink_idx = self.robot.body_names.index( self.anchor_bodylink_name ) self.urdf_dof_names = self.cfg.urdf_dof_names self.urdf_body_names = self.cfg.urdf_body_names self.simulator_dof_names = self.robot.joint_names self.simulator_body_names = self.robot.body_names self.urdf2sim_dof_idx = [ self.urdf_dof_names.index(dof) for dof in self.simulator_dof_names ] self.urdf2sim_body_idx = [ self.urdf_body_names.index(body) for body in self.simulator_body_names ] self.sim2urdf_dof_idx = [ self.simulator_dof_names.index(dof) for dof in self.urdf_dof_names ] self.sim2urdf_body_idx = [ self.simulator_body_names.index(body) for body in self.urdf_body_names ] self.arm_dof_indices = [ self.simulator_dof_names.index(dof) for dof in self.cfg.arm_dof_names ] self.torso_dof_indices = [ self.simulator_dof_names.index(dof) for dof in self.cfg.waist_dof_names ] self.leg_dof_indices = [ self.simulator_dof_names.index(dof) for dof in self.cfg.leg_dof_names ] # Body indices for mpkpe metrics using unified naming self.arm_body_indices = [ self.simulator_body_names.index(body) for body in self.cfg.arm_body_names ] self.torso_body_indices = [ self.simulator_body_names.index(body) for body in self.cfg.torso_body_names ] self.leg_body_indices = [ self.simulator_body_names.index(body) for body in self.cfg.leg_body_names ] # Per-env world origins 
(translation only) # Shape: [num_envs, 3] on the same device as the sim self._env_origins = self._env.scene.env_origins.to(self.device) # AMP-style observation indices (RSL reference alignment) urdf_dof_name_to_idx = { name: idx for idx, name in enumerate(self.urdf_dof_names) } sim_dof_name_to_idx = { name: idx for idx, name in enumerate(self.simulator_dof_names) } urdf_body_name_to_idx = { name: idx for idx, name in enumerate(self.urdf_body_names) } sim_body_name_to_idx = { name: idx for idx, name in enumerate(self.simulator_body_names) } left_arm_dof_names = list( getattr(self.cfg, "left_arm_dof_names", []) or [] ) right_arm_dof_names = list( getattr(self.cfg, "right_arm_dof_names", []) or [] ) left_leg_dof_names = list( getattr(self.cfg, "left_leg_dof_names", []) or [] ) right_leg_dof_names = list( getattr(self.cfg, "right_leg_dof_names", []) or [] ) if not left_arm_dof_names: left_arm_dof_names = self._amp_filter_names_by_prefix( self.urdf_dof_names, "left_", ("shoulder", "elbow", "wrist"), ) if not right_arm_dof_names: right_arm_dof_names = self._amp_filter_names_by_prefix( self.urdf_dof_names, "right_", ("shoulder", "elbow", "wrist"), ) if not left_leg_dof_names: left_leg_dof_names = self._amp_filter_names_by_prefix( self.urdf_dof_names, "left_", ("hip", "knee", "ankle") ) if not right_leg_dof_names: right_leg_dof_names = self._amp_filter_names_by_prefix( self.urdf_dof_names, "right_", ("hip", "knee", "ankle") ) self._amp_left_arm_urdf_dof_idx = [ urdf_dof_name_to_idx[name] for name in left_arm_dof_names ] self._amp_right_arm_urdf_dof_idx = [ urdf_dof_name_to_idx[name] for name in right_arm_dof_names ] self._amp_left_leg_urdf_dof_idx = [ urdf_dof_name_to_idx[name] for name in left_leg_dof_names ] self._amp_right_leg_urdf_dof_idx = [ urdf_dof_name_to_idx[name] for name in right_leg_dof_names ] self._amp_left_arm_sim_dof_idx = [ sim_dof_name_to_idx[name] for name in left_arm_dof_names ] self._amp_right_arm_sim_dof_idx = [ sim_dof_name_to_idx[name] for name in 
right_arm_dof_names ] self._amp_left_leg_sim_dof_idx = [ sim_dof_name_to_idx[name] for name in left_leg_dof_names ] self._amp_right_leg_sim_dof_idx = [ sim_dof_name_to_idx[name] for name in right_leg_dof_names ] left_arm_body_names = list( getattr(self.cfg, "left_arm_body_names", []) or [] ) right_arm_body_names = list( getattr(self.cfg, "right_arm_body_names", []) or [] ) left_leg_body_names = list( getattr(self.cfg, "left_leg_body_names", []) or [] ) right_leg_body_names = list( getattr(self.cfg, "right_leg_body_names", []) or [] ) if not left_arm_body_names: left_arm_body_names = self._amp_filter_names_by_prefix( self.urdf_body_names, "left_", ("shoulder", "elbow", "wrist") ) if not right_arm_body_names: right_arm_body_names = self._amp_filter_names_by_prefix( self.urdf_body_names, "right_", ("shoulder", "elbow", "wrist") ) if not left_leg_body_names: left_leg_body_names = self._amp_filter_names_by_prefix( self.urdf_body_names, "left_", ("hip", "knee", "ankle") ) if not right_leg_body_names: right_leg_body_names = self._amp_filter_names_by_prefix( self.urdf_body_names, "right_", ("hip", "knee", "ankle") ) left_elbow_name = self._amp_pick_first_name( left_arm_body_names, ("left_elbow", "elbow") ) right_elbow_name = self._amp_pick_first_name( right_arm_body_names, ("right_elbow", "elbow") ) left_foot_name = self._amp_pick_first_name( left_leg_body_names, ("left_ankle_roll", "left_ankle_pitch", "left_ankle"), ) right_foot_name = self._amp_pick_first_name( right_leg_body_names, ("right_ankle_roll", "right_ankle_pitch", "right_ankle"), ) self._amp_left_elbow_urdf_body_idx = ( urdf_body_name_to_idx[left_elbow_name] if left_elbow_name is not None else None ) self._amp_right_elbow_urdf_body_idx = ( urdf_body_name_to_idx[right_elbow_name] if right_elbow_name is not None else None ) self._amp_left_foot_urdf_body_idx = ( urdf_body_name_to_idx[left_foot_name] if left_foot_name is not None else None ) self._amp_right_foot_urdf_body_idx = ( 
urdf_body_name_to_idx[right_foot_name] if right_foot_name is not None else None ) self._amp_left_elbow_sim_body_idx = ( sim_body_name_to_idx[left_elbow_name] if left_elbow_name is not None else None ) self._amp_right_elbow_sim_body_idx = ( sim_body_name_to_idx[right_elbow_name] if right_elbow_name is not None else None ) self._amp_left_foot_sim_body_idx = ( sim_body_name_to_idx[left_foot_name] if left_foot_name is not None else None ) self._amp_right_foot_sim_body_idx = ( sim_body_name_to_idx[right_foot_name] if right_foot_name is not None else None ) self._amp_left_hand_local_vec = torch.tensor( [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32 ) self._amp_right_hand_local_vec = torch.tensor( [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32 ) def _init_buffers(self): self.metrics = {} self.ref_motion_global_frame_ids = torch.zeros( self.num_envs, dtype=torch.long, device=self.device, ) # mark envs that timed out (frame id exceeded end frame) in current step self._motion_end_mask = torch.zeros( self.num_envs, dtype=torch.bool, device=self.device, ) # counter for number of motion ends per environment self.motion_end_counter = torch.zeros( self.num_envs, dtype=torch.long, device=self.device, ) # per-environment cached motion indices self._cached_motion_ids = torch.zeros( self.num_envs, dtype=torch.long, device=self.device, ) # env -> cache row indirection (starts as identity mapping) self._env_to_cache_row = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) self._start_frame_indices = torch.zeros( self.num_envs, dtype=torch.long, device=self.device, ) self._reward_sum_since_assign = torch.zeros( self.num_envs, dtype=torch.float32, device=self.device, ) self._mpjpe_sum_since_assign = torch.zeros( self.num_envs, dtype=torch.float32, device=self.device, ) self._mpkpe_sum_since_assign = torch.zeros( self.num_envs, dtype=torch.float32, device=self.device, ) self._step_count_since_assign = torch.zeros( self.num_envs, dtype=torch.float32, 
device=self.device, ) self._completion_rate_sum_by_window: Dict[int, float] = {} self._completion_rate_count_by_window: Dict[int, int] = {} self._mpkpe_signal_sum_by_window: Dict[int, float] = {} self._mpkpe_signal_count_by_window: Dict[int, int] = {} self.pos_history_buffer = None self.rot_history_buffer = None self.ref_pos_history_buffer = None self.current_accel = None self.ref_body_accel = None self.current_ang_accel = None # Placeholder for angular acceleration self.metrics["Task/MPJPE_WholeBody"] = torch.zeros( self.num_envs, device=self.device ) self.metrics["Task/MPKPE_WholeBody"] = torch.zeros( self.num_envs, device=self.device ) def _record_completion_rate_for_envs(self, env_ids: torch.Tensor) -> None: if env_ids.numel() == 0: return selected_clip_indices = self._clip_indices[env_ids] lengths = self._motion_cache.lengths_for_indices(selected_clip_indices) window_indices = self._motion_cache.window_indices_for_indices( selected_clip_indices ) available_steps = torch.clamp( lengths - int(self.cfg.n_fut_frames) - self._start_frame_indices[env_ids], min=1, ) completion_rate = torch.clamp( self._step_count_since_assign[env_ids] / available_steps.float(), min=0.0, max=1.0, ) step_den = torch.clamp(self._step_count_since_assign[env_ids], min=1.0) mpkpe_mean = self._mpkpe_sum_since_assign[env_ids] / step_den completion_values = completion_rate.detach().cpu().tolist() mpkpe_values = mpkpe_mean.detach().cpu().tolist() window_values = window_indices.detach().cpu().tolist() for idx, window_index_obj in enumerate(window_values): completion_value = float(completion_values[idx]) mpkpe_value = float(mpkpe_values[idx]) mpkpe_signal = -mpkpe_value window_index = int(window_index_obj) if window_index in self._completion_rate_sum_by_window: self._completion_rate_sum_by_window[window_index] += ( completion_value ) self._completion_rate_count_by_window[window_index] += 1 else: self._completion_rate_sum_by_window[window_index] = ( completion_value ) 
self._completion_rate_count_by_window[window_index] = 1 if window_index in self._mpkpe_signal_sum_by_window: self._mpkpe_signal_sum_by_window[window_index] += mpkpe_signal self._mpkpe_signal_count_by_window[window_index] += 1 else: self._mpkpe_signal_sum_by_window[window_index] = mpkpe_signal self._mpkpe_signal_count_by_window[window_index] = 1 self._reward_sum_since_assign[env_ids] = 0.0 self._mpjpe_sum_since_assign[env_ids] = 0.0 self._mpkpe_sum_since_assign[env_ids] = 0.0 self._step_count_since_assign[env_ids] = 0.0 def _reset_window_curriculum_stats(self) -> None: self._completion_rate_sum_by_window = {} self._completion_rate_count_by_window = {} self._mpkpe_signal_sum_by_window = {} self._mpkpe_signal_count_by_window = {} def _build_window_curriculum_stats_from_current_batch( self, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: batch_window_indices = self._motion_cache.current_batch.window_indices row_window_indices = batch_window_indices.detach().to( self.device, dtype=torch.long ) count = int(row_window_indices.numel()) row_mpkpe_signal = torch.zeros( count, dtype=torch.float32, device=self.device ) row_completion_rate = torch.zeros( count, dtype=torch.float32, device=self.device ) row_count = torch.zeros(count, dtype=torch.float32, device=self.device) window_values = row_window_indices.detach().cpu().tolist() for row_idx, window_index_obj in enumerate(window_values): window_index = int(window_index_obj) completion_count = int( self._completion_rate_count_by_window.get(window_index, 0) ) mpkpe_count = int( self._mpkpe_signal_count_by_window.get(window_index, 0) ) if completion_count > 0: row_completion_rate[row_idx] = float( self._completion_rate_sum_by_window[window_index] ) / float(completion_count) if mpkpe_count > 0: row_mpkpe_signal[row_idx] = float( self._mpkpe_signal_sum_by_window[window_index] ) / float(mpkpe_count) row_count[row_idx] = float(max(completion_count, mpkpe_count)) return ( row_window_indices, row_mpkpe_signal, 
row_completion_rate, row_count, ) def _update_cache_curriculum_state( self, *, accelerator, swap_index: int, ) -> None: if self._sampling_strategy != "curriculum": self._reset_window_curriculum_stats() return ( row_window_indices, row_mpkpe_signal, row_completion_rate, row_count, ) = self._build_window_curriculum_stats_from_current_batch() if accelerator is not None and int(accelerator.num_processes) > 1: gather_window_indices = accelerator.gather(row_window_indices) gather_mpkpe_signal = accelerator.gather(row_mpkpe_signal) gather_completion_rate = accelerator.gather(row_completion_rate) gather_count = accelerator.gather(row_count) else: gather_window_indices = row_window_indices gather_mpkpe_signal = row_mpkpe_signal gather_completion_rate = row_completion_rate gather_count = row_count self._motion_cache.update_cache_curriculum( window_indices=gather_window_indices, mpkpe_signal_means=gather_mpkpe_signal, completion_rate_means=gather_completion_rate, counts=gather_count, swap_index=int(swap_index), ) self._reset_window_curriculum_stats() def update_curriculum_reward_accumulators( self, rewards: torch.Tensor ) -> None: reward_flat = rewards.view(-1).to(self.device, dtype=torch.float32) all_ids = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) motion_ids = self._filter_env_ids_for_motion_task(all_ids) if motion_ids.numel() == 0: return self._reward_sum_since_assign[motion_ids] += reward_flat[motion_ids] mpjpe = self.metrics["Task/MPJPE_WholeBody"] mpkpe = self.metrics["Task/MPKPE_WholeBody"] self._mpjpe_sum_since_assign[motion_ids] += mpjpe[motion_ids].to( dtype=torch.float32 ) self._mpkpe_sum_since_assign[motion_ids] += mpkpe[motion_ids].to( dtype=torch.float32 ) self._step_count_since_assign[motion_ids] += 1.0 @property def command( self, ) -> torch.Tensor: # call the corresponding method based on configured command_obs_name return getattr(self, f"_get_obs_{self.cfg.command_obs_name}")() @property def command_fut( self, ) -> torch.Tensor: # 
call the corresponding method based on configured command_obs_name return getattr(self, f"_get_obs_{self.cfg.command_obs_name}_fut")() def reset( self, env_ids: Sequence[int] | None = None, ) -> dict[str, float]: extras = super().reset(env_ids) if env_ids is None: env_ids = slice(None) if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(self.device) self._motion_end_mask[env_ids] = False self.motion_end_counter[env_ids] = 0 # Do not apply cache swap inside per-env reset; defer to PPO barrier. # Always resample only the requested envs here. motion_ids = self._filter_env_ids_for_motion_task(env_ids.view(-1)) self._resample_command(motion_ids, eval=self._is_evaluating) return extras def apply_cache_swap_if_pending_barrier(self, accelerator=None) -> bool: """Apply a pending cache swap at a rollout barrier. Returns: bool: True if a swap was applied, otherwise False. """ if not getattr(self, "_swap_pending", False): return False all_ids = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) motion_ids = self._filter_env_ids_for_motion_task(all_ids) if motion_ids.numel() == 0: # No motion envs active under multi-task: keep ref motion inert. 
self._swap_pending = False self._swap_step_counter = 0 return False self._record_completion_rate_for_envs(motion_ids) next_swap_index = int(self._motion_cache.swap_index) + 1 self._update_cache_curriculum_state( accelerator=accelerator, swap_index=next_swap_index, ) # Advance cache and reset counters self._motion_cache.advance() self._maybe_dump_sampled_motion_keys() self._swap_pending = False self._swap_step_counter = 0 # Reassign motion envs to the new cache batch clip_idx, frame_idx = self._motion_cache.sample_env_assignments( int(motion_ids.numel()), self.cfg.n_fut_frames, self.device, deterministic_start=(self._is_evaluating), ) self._clip_indices[motion_ids] = clip_idx self._frame_indices[motion_ids] = frame_idx self._start_frame_indices[motion_ids] = frame_idx self._reward_sum_since_assign[motion_ids] = 0.0 self._step_count_since_assign[motion_ids] = 0.0 self._update_ref_motion_state_from_cache(env_ids=motion_ids) # Realign robot states to the new reference self._align_root_to_ref(motion_ids) self._align_dof_to_ref(motion_ids) # Reset per-episode timeout bookkeeping for consistency self._motion_end_mask[motion_ids] = False self.motion_end_counter[motion_ids] = 0 return True def compute(self, dt: float): all_ids = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) motion_ids = self._filter_env_ids_for_motion_task(all_ids) if motion_ids.numel() == 0: return self._update_metrics() self._update_command() def _update_ref_motion_state(self): """Update reference motion state (unified API).""" return self._update_ref_motion_state_from_cache() def _update_ref_motion_state_from_cache( self, env_ids: torch.Tensor | None = None ): """Compatibility no-op for cache-backed reference access.""" del env_ids return None def _get_ref_state_array( self, base_key: str, prefix: str = "ref_", ) -> torch.Tensor: """Gather a reference tensor from the current cache batch. Args: base_key: Base key in the motion cache (e.g. \"dof_pos\", \"root_pos\"). 
prefix: Optional logical prefix (e.g. \"\", \"ref_\", \"ft_ref_\", \"robot_\"). Returns: Tensor of shape ``[num_envs, 1 + n_fut_frames, ...]`` gathered for the envs' current clip/frame assignments. """ batch_tensors = self._motion_cache.current_batch.tensors tensor_key = resolve_reference_tensor_key( batch_tensors=batch_tensors, base_key=base_key, prefix=prefix, ) return self._motion_cache.gather_tensor( tensor_key, clip_indices=self._clip_indices, frame_indices=self._frame_indices, n_future_frames=self.cfg.n_fut_frames, ) def get_ref_motion_filter_cutoff_hz_cur(self) -> torch.Tensor: try: base = self._get_ref_state_array("filter_cutoff_hz", prefix="") except KeyError: # Older/local datasets may not carry per-clip filter metadata. # Keep the observation available with a neutral default instead of # failing during env construction. return torch.zeros( self.num_envs, 1, device=self.device, dtype=torch.float32 ) return base[:, 0, ...] def _uniform_sample_ref_start_frames(self, env_ids: torch.Tensor): """Uniformly sample start frames within cached windows for env_ids. Sampling range is [start, end - 1 - n_fut_frames] to ensure required future frames exist. If that upper bound is < start, it falls back to start. 
""" if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(self.device).long() starts = self.ref_motion_global_start_frame_ids[env_ids] ends = self.ref_motion_global_end_frame_ids[env_ids] # Ensure room for future frames if requested n_fut = ( int(self.cfg.n_fut_frames) if hasattr(self.cfg, "n_fut_frames") else 0 ) max_start = ends - 1 - n_fut max_start = torch.maximum(max_start, starts) num_choices = (max_start - starts + 1).clamp(min=1) # Sample offsets uniformly rand = torch.rand_like(starts, dtype=torch.float32) offsets = torch.floor(rand * num_choices.float()).long() sampled = starts + offsets self.ref_motion_global_frame_ids[env_ids] = sampled def get_ref_motion_dof_pos_fut( self, prefix: str = "ref_", ) -> torch.Tensor: base = self._get_ref_state_array("dof_pos", prefix) return base[:, 1:, ...][..., self.urdf2sim_dof_idx] def _get_immediate_next_ref_state_array( self, base_key: str, prefix: str = "ref_", ) -> torch.Tensor: base = self._get_ref_state_array(base_key, prefix) if base.shape[1] < 2: raise ValueError( f"Immediate-next reference for '{base_key}' requires at least one future frame." ) return base[:, 1, ...] def get_ref_motion_dof_vel_fut( self, prefix: str = "ref_", ) -> torch.Tensor: base = self._get_ref_state_array("dof_vel", prefix) return base[:, 1:, ...][..., self.urdf2sim_dof_idx] def get_ref_motion_root_global_pos_fut( self, prefix: str = "ref_", ) -> torch.Tensor: base = self._get_ref_state_array("root_pos", prefix) return base[:, 1:, ...] + self._env_origins[:, None, :] def get_ref_motion_root_global_rot_quat_xyzw_fut( self, prefix: str = "ref_", ) -> torch.Tensor: return self._get_ref_state_array("root_rot", prefix)[:, 1:, ...] 
def get_ref_motion_root_global_rot_quat_wxyz_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root rotations reordered from xyzw to wxyz.

    Returns:
        The result of ``get_ref_motion_root_global_rot_quat_xyzw_fut`` with
        its last axis permuted as ``[3, 0, 1, 2]``.
    """
    return self.get_ref_motion_root_global_rot_quat_xyzw_fut(
        prefix=prefix
    )[..., [3, 0, 1, 2]]


def get_ref_motion_root_global_lin_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root linear velocities (frames 1.. of the window)."""
    base = self._get_ref_state_array("root_vel", prefix)
    return base[:, 1:, ...]


def get_ref_motion_root_global_ang_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root angular velocities (frames 1.. of the window)."""
    base = self._get_ref_state_array("root_ang_vel", prefix)
    return base[:, 1:, ...]


def get_ref_motion_bodylink_global_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body positions in simulator body order.

    Each env's world origin is added so positions share the sim frame.
    """
    base = self._get_ref_state_array("rg_pos", prefix)
    return (
        base[:, 1:, ...][..., self.urdf2sim_body_idx, :]
        + self._env_origins[:, None, None, :]
    )


def get_ref_motion_bodylink_rel_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body positions relative to the reference root.

    The world-frame offsets are rotated into the reference root frame with
    ``quat_apply_inverse``.

    Returns:
        torch.Tensor: [B, N, 3]
    """
    ref_body_global_pos = self.get_ref_motion_bodylink_global_pos_cur(
        prefix=prefix
    )  # [B, N, 3]
    ref_root_global_pos = self.get_ref_motion_root_global_pos_cur(
        prefix=prefix
    )  # [B, 3]
    ref_root_global_rot_wxyz = (
        self.get_ref_motion_root_global_rot_quat_wxyz_cur(prefix=prefix)
    )  # [B, 4]
    rel_pos_w = (
        ref_body_global_pos - ref_root_global_pos[:, None, :]
    )  # [B, N, 3]
    num_bodies = rel_pos_w.shape[1]
    expanded_ref_root_global_rot_wxyz = ref_root_global_rot_wxyz[
        :, None, :
    ].expand(-1, num_bodies, -1)
    return isaaclab_math.quat_apply_inverse(
        expanded_ref_root_global_rot_wxyz, rel_pos_w
    )  # [B, N, 3]


def get_ref_motion_bodylink_rel_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body positions relative to each frame's ref root.

    Returns:
        torch.Tensor: [B, T, N, 3]
    """
    ref_body_global_pos_fut = self.get_ref_motion_bodylink_global_pos_fut(
        prefix=prefix
    )  # [B, T, N, 3]
    ref_root_global_pos_fut = self.get_ref_motion_root_global_pos_fut(
        prefix=prefix
    )  # [B, T, 3]
    ref_root_global_rot_wxyz_fut = (
        self.get_ref_motion_root_global_rot_quat_wxyz_fut(prefix=prefix)
    )  # [B, T, 4]
    rel_pos_w_fut = (
        ref_body_global_pos_fut - ref_root_global_pos_fut[:, :, None, :]
    )  # [B, T, N, 3]
    num_bodies = rel_pos_w_fut.shape[2]
    expanded_ref_root_global_rot_wxyz_fut = ref_root_global_rot_wxyz_fut[
        :, :, None, :
    ].expand(-1, -1, num_bodies, -1)
    return isaaclab_math.quat_apply_inverse(
        expanded_ref_root_global_rot_wxyz_fut, rel_pos_w_fut
    )  # [B, T, N, 3]


def get_ref_motion_bodylink_global_rot_xyzw_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body rotations (xyzw) in simulator body order."""
    base = self._get_ref_state_array("rb_rot", prefix)
    return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_lin_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body linear velocities in simulator body order."""
    base = self._get_ref_state_array("body_vel", prefix)
    return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_ang_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body angular velocities in simulator body order."""
    base = self._get_ref_state_array("body_ang_vel", prefix)
    return base[:, 1:, ...][..., self.urdf2sim_body_idx, :]


def get_ref_motion_dof_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions in simulator DoF order."""
    base = self._get_ref_state_array("dof_pos", prefix)
    return base[:, 0, ...][..., self.urdf2sim_dof_idx]


def get_ref_motion_dof_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next (frame 1) reference DoF positions, simulator order."""
    base = self._get_immediate_next_ref_state_array("dof_pos", prefix)
    return base[..., self.urdf2sim_dof_idx]


def get_immediate_next_two_dof_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current and immediate-next reference DoF positions, simulator order.

    Returns frames 0 and 1 of the gathered reference window
    (``base[:, :2]``), i.e. the current frame followed by the immediate next
    one -- NOT frames 1 and 2. This matches the guard below, which only
    requires a single future frame.

    NOTE(review): the method name suggests "next two" future frames; confirm
    consumers expect [current, next] before relying on the name.

    Raises:
        ValueError: If ``cfg.n_fut_frames`` is less than 1.
    """
    n_fut = int(self.cfg.n_fut_frames)
    if n_fut < 1:
        raise ValueError(
            "n_fut_frames must be at least 1 for immediate next two DoF positions."
        )
    base = self._get_ref_state_array("dof_pos", prefix)
    return base[:, :2, ...][..., self.urdf2sim_dof_idx]


def get_ref_motion_dof_pos_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions in URDF (cache) DoF order."""
    base = self._get_ref_state_array("dof_pos", prefix)
    return base[:, 0, ...]
def _robot_root_yaw_quat_wxyz(self, num_fut_frames=None) -> torch.Tensor:
    """Yaw-only (heading) quaternion of the simulated robot root, wxyz.

    Args:
        num_fut_frames: if given, the [B, 4] heading is broadcast with
            ``expand`` to [B, num_fut_frames, 4] for per-future-frame use.
    """
    heading = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    if num_fut_frames is None:
        return heading
    return heading[:, None, :].expand(-1, num_fut_frames, -1)

def get_ref_motion_cur_heading_aligned_root_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Offset from the robot root to the current reference root.

    The world-space offset is rotated into the robot's heading (yaw-only)
    frame.

    Returns:
        torch.Tensor: [B, 3]
    """
    ref_root_pos_w = self.get_ref_motion_root_global_pos_cur(prefix=prefix)
    offset_w = ref_root_pos_w - self.robot.data.root_pos_w
    return isaaclab_math.quat_apply_inverse(
        self._robot_root_yaw_quat_wxyz(), offset_w
    )

def get_ref_motion_fut_heading_aligned_root_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Offsets from the robot root to each future reference root.

    All offsets are expressed in the robot's current heading frame.

    Returns:
        torch.Tensor: [B, T, 3]
    """
    ref_root_pos_w_fut = self.get_ref_motion_root_global_pos_fut(
        prefix=prefix
    )  # [B, T, 3]
    horizon = ref_root_pos_w_fut.shape[1]
    offset_w_fut = ref_root_pos_w_fut - self.robot.data.root_pos_w[:, None, :]
    heading_fut = self._robot_root_yaw_quat_wxyz(num_fut_frames=horizon)
    return isaaclab_math.quat_apply_inverse(heading_fut, offset_w_fut)

def get_ref_motion_cur_heading_aligned_root_rot6d(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root rotation as rot6d, relative to robot heading.

    Returns:
        torch.Tensor: [B, 6]
    """
    heading_inv = isaaclab_math.quat_inv(self._robot_root_yaw_quat_wxyz())
    ref_root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )  # [B, 4] wxyz
    local_quat = isaaclab_math.quat_mul(heading_inv, ref_root_quat)
    rot_mat = isaaclab_math.matrix_from_quat(local_quat)
    # Keep the first two matrix columns and flatten them (row-major) to 6.
    return rot_mat[..., :2].reshape(ref_root_quat.shape[0], 6)

def get_ref_motion_fut_heading_aligned_root_rot6d(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root rotations as rot6d, relative to robot heading.

    Returns:
        torch.Tensor: [B, T, 6]
    """
    heading_inv = isaaclab_math.quat_inv(self._robot_root_yaw_quat_wxyz())
    ref_root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )  # [B, T, 4] wxyz
    env_count, horizon, _ = ref_root_quat_fut.shape
    heading_inv_fut = heading_inv[:, None, :].expand(-1, horizon, -1)
    local_quat_fut = isaaclab_math.quat_mul(heading_inv_fut, ref_root_quat_fut)
    rot_mat_fut = isaaclab_math.matrix_from_quat(local_quat_fut)
    return rot_mat_fut[..., :2].reshape(env_count, horizon, 6)

def get_ref_motion_cur_heading_aligned_root_lin_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root linear velocity in the robot heading frame.

    Returns:
        [B, 3]
    """
    ref_lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(prefix=prefix)
    return isaaclab_math.quat_apply_inverse(
        self._robot_root_yaw_quat_wxyz(), ref_lin_vel_w
    )

def get_ref_motion_fut_heading_aligned_root_lin_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root linear velocities in the robot heading frame.

    Returns:
        [B, T, 3]
    """
    ref_lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(
        prefix=prefix
    )  # [B, T, 3]
    _, horizon, _ = ref_lin_vel_w_fut.shape
    heading_fut = self._robot_root_yaw_quat_wxyz(num_fut_frames=horizon)
    return isaaclab_math.quat_apply_inverse(heading_fut, ref_lin_vel_w_fut)

def get_ref_motion_cur_heading_aligned_root_ang_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root angular velocity in the robot heading frame.

    Returns:
        [B, 3]
    """
    ref_ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(prefix=prefix)
    return isaaclab_math.quat_apply_inverse(
        self._robot_root_yaw_quat_wxyz(), ref_ang_vel_w
    )

def get_ref_motion_fut_heading_aligned_root_ang_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root angular velocities in the robot heading frame.

    Returns:
        [B, T, 3]
    """
    ref_ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(
        prefix=prefix
    )  # [B, T, 3]
    _, horizon, _ = ref_ang_vel_w_fut.shape
    heading_fut = self._robot_root_yaw_quat_wxyz(num_fut_frames=horizon)
    return isaaclab_math.quat_apply_inverse(heading_fut, ref_ang_vel_w_fut)

@property
def robot_dof_pos_cur_urdf_order(self):
    """Simulated joint positions re-indexed into URDF DoF order."""
    joint_pos = self.robot.data.joint_pos
    return joint_pos[..., self.sim2urdf_dof_idx]

def get_ref_motion_dof_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF velocities in sim DoF order."""
    dof_vel_window = self._get_ref_state_array("dof_vel", prefix)
    return dof_vel_window[:, 0][..., self.urdf2sim_dof_idx]

def get_ref_motion_dof_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference DoF velocities in sim DoF order."""
    next_dof_vel = self._get_immediate_next_ref_state_array(
        "dof_vel", prefix
    )
    return next_dof_vel[..., self.urdf2sim_dof_idx]

@property
def robot_dof_vel_cur_urdf_order(self):
    """Simulated joint velocities re-indexed into URDF DoF order."""
    joint_vel = self.robot.data.joint_vel
    return joint_vel[..., self.sim2urdf_dof_idx]

def get_ref_motion_dof_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF velocities, left in URDF order."""
    return self._get_ref_state_array("dof_vel", prefix)[:, 0]

def get_ref_motion_root_global_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference root position shifted by the env origin."""
    root_pos_window = self._get_ref_state_array("root_pos", prefix)
    return root_pos_window[:, 0] + self._env_origins

def get_ref_motion_root_global_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root position shifted by the env origin."""
    next_root_pos = self._get_immediate_next_ref_state_array(
        "root_pos", prefix
    )
    return next_root_pos + self._env_origins

def get_ref_motion_root_global_rot_quat_xyzw_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference root quaternion as stored (xyzw)."""
    root_rot_window = self._get_ref_state_array("root_rot", prefix)
    return root_rot_window[:, 0]
def get_ref_motion_root_global_rot_quat_xyzw_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root quaternion as stored (xyzw)."""
    next_root_rot = self._get_immediate_next_ref_state_array(
        "root_rot", prefix
    )
    return next_root_rot

def get_ref_motion_root_global_rot_quat_wxyz_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference root quaternion reordered to wxyz."""
    xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur(prefix=prefix)
    w_first_order = [3, 0, 1, 2]
    return xyzw[..., w_first_order]

def get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root quaternion reordered to wxyz."""
    xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_immediate_next(
        prefix=prefix
    )
    w_first_order = [3, 0, 1, 2]
    return xyzw[..., w_first_order]

def get_ref_motion_root_global_lin_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference root linear velocity (world frame)."""
    root_vel_window = self._get_ref_state_array("root_vel", prefix)
    return root_vel_window[:, 0]

def get_ref_motion_root_global_lin_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root linear velocity (world frame)."""
    next_root_vel = self._get_immediate_next_ref_state_array(
        "root_vel", prefix
    )
    return next_root_vel

@property
def ref_motion_root_global_lin_vel_cur(self) -> torch.Tensor:
    """Shorthand for the current root linear velocity with prefix ``ref_``."""
    return self.get_ref_motion_root_global_lin_vel_cur()

def get_ref_motion_root_global_ang_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference root angular velocity (world frame)."""
    root_ang_vel_window = self._get_ref_state_array("root_ang_vel", prefix)
    return root_ang_vel_window[:, 0]
def get_ref_motion_root_global_ang_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root angular velocity (world frame)."""
    next_root_ang_vel = self._get_immediate_next_ref_state_array(
        "root_ang_vel", prefix
    )
    return next_root_ang_vel

def get_ref_motion_gravity_projection_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """World gravity vector expressed in the current reference root frame."""
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    ref_root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )  # [B, 4]
    return isaaclab_math.quat_apply_inverse(ref_root_quat, gravity_w)

def get_ref_motion_gravity_projection_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """World gravity vector expressed in the immediate-next reference root frame."""
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    next_root_quat = (
        self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=prefix
        )
    )
    return isaaclab_math.quat_apply_inverse(next_root_quat, gravity_w)

def get_ref_motion_gravity_projection_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """World gravity vector expressed in each future reference root frame.

    Returns:
        torch.Tensor: [B, T, 3]
    """
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    ref_root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )  # [B, T, 4]
    horizon = ref_root_quat_fut.shape[1]
    gravity_w_fut = gravity_w.unsqueeze(1).expand(-1, horizon, -1)
    return isaaclab_math.quat_apply_inverse(ref_root_quat_fut, gravity_w_fut)

def get_ref_motion_base_linvel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root linear velocity in the reference root frame."""
    lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, lin_vel_w)

def get_ref_motion_base_linvel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root linear velocity in its own root frame."""
    lin_vel_w = self.get_ref_motion_root_global_lin_vel_immediate_next(
        prefix=prefix
    )
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, lin_vel_w)

def get_ref_motion_base_linvel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root linear velocities, each in its own root frame.

    Returns:
        torch.Tensor: [B, T, 3]
    """
    lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(prefix=prefix)
    root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat_fut, lin_vel_w_fut)

def get_ref_motion_base_angvel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root angular velocity in the reference root frame."""
    ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, ang_vel_w)

def get_ref_motion_base_angvel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference root angular velocity in its own root frame."""
    ang_vel_w = self.get_ref_motion_root_global_ang_vel_immediate_next(
        prefix=prefix
    )
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, ang_vel_w)

def get_ref_motion_base_angvel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root angular velocities, each in its own root frame.

    Returns:
        torch.Tensor: [B, T, 3]
    """
    ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(prefix=prefix)
    root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat_fut, ang_vel_w_fut)

def get_ref_motion_bodylink_global_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body positions (sim order) plus env origin."""
    rg_pos_window = self._get_ref_state_array("rg_pos", prefix)
    cur_pos_sim_order = rg_pos_window[:, 0][..., self.urdf2sim_body_idx, :]
    return cur_pos_sim_order + self._env_origins[:, None, :]

def get_ref_motion_bodylink_global_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference body positions (sim order) plus env origin."""
    next_rg_pos = self._get_immediate_next_ref_state_array("rg_pos", prefix)
    next_pos_sim_order = next_rg_pos[..., self.urdf2sim_body_idx, :]
    return next_pos_sim_order + self._env_origins[:, None, :]

def get_ref_motion_bodylink_global_pos_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body positions (URDF order) plus env origin."""
    rg_pos_window = self._get_ref_state_array("rg_pos", prefix)
    return rg_pos_window[:, 0] + self._env_origins[:, None, :]

def get_ref_motion_bodylink_global_rot_wxyz_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body rotations (sim order) reordered to wxyz."""
    xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_cur(prefix=prefix)
    w_first_order = [3, 0, 1, 2]
    return xyzw[..., w_first_order]

def get_ref_motion_bodylink_global_rot_xyzw_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body rotations (xyzw) in sim body order."""
    rb_rot_window = self._get_ref_state_array("rb_rot", prefix)
    return rb_rot_window[:, 0][..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_rot_xyzw_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference body rotations (xyzw) in sim body order."""
    next_rb_rot = self._get_immediate_next_ref_state_array("rb_rot", prefix)
    return next_rb_rot[..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference body rotations (sim order) reordered to wxyz."""
    xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_immediate_next(
        prefix=prefix
    )
    w_first_order = [3, 0, 1, 2]
    return xyzw[..., w_first_order]

def get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body rotations (xyzw), left in URDF order."""
    return self._get_ref_state_array("rb_rot", prefix)[:, 0]
@property
def robot_bodylink_global_pos_cur_urdf_order(self):
    """Simulated body positions (world frame) re-indexed into URDF order."""
    body_pos_w = self.robot.data.body_pos_w
    return body_pos_w[:, self.sim2urdf_body_idx]

@property
def robot_bodylink_global_rot_wxyz_cur_urdf_order(self):
    """Simulated body quaternions (wxyz) re-indexed into URDF order."""
    body_quat_w = self.robot.data.body_quat_w
    return body_quat_w[:, self.sim2urdf_body_idx]

@property
def robot_bodylink_global_rot_xyzw_cur_urdf_order(self):
    """Simulated body quaternions in URDF order, reordered wxyz -> xyzw."""
    wxyz = self.robot_bodylink_global_rot_wxyz_cur_urdf_order
    w_last_order = [1, 2, 3, 0]
    return wxyz[..., w_last_order]

@property
def robot_bodylink_global_lin_vel_cur_urdf_order(self):
    """Simulated body linear velocities re-indexed into URDF order."""
    body_lin_vel_w = self.robot.data.body_lin_vel_w
    return body_lin_vel_w[:, self.sim2urdf_body_idx]

@property
def robot_bodylink_global_ang_vel_cur_urdf_order(self):
    """Simulated body angular velocities re-indexed into URDF order."""
    body_ang_vel_w = self.robot.data.body_ang_vel_w
    return body_ang_vel_w[:, self.sim2urdf_body_idx]

def get_ref_motion_bodylink_global_lin_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body linear velocities in sim body order."""
    body_vel_window = self._get_ref_state_array("body_vel", prefix)
    return body_vel_window[:, 0][..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_lin_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference body linear velocities in sim body order."""
    next_body_vel = self._get_immediate_next_ref_state_array(
        "body_vel", prefix
    )
    return next_body_vel[..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_lin_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body linear velocities, left in URDF order."""
    return self._get_ref_state_array("body_vel", prefix)[:, 0]

def get_ref_motion_bodylink_global_ang_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body angular velocities in sim body order."""
    body_ang_vel_window = self._get_ref_state_array("body_ang_vel", prefix)
    return body_ang_vel_window[:, 0][..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_ang_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference body angular velocities in sim body order."""
    next_body_ang_vel = self._get_immediate_next_ref_state_array(
        "body_ang_vel", prefix
    )
    return next_body_ang_vel[..., self.urdf2sim_body_idx, :]

def get_ref_motion_bodylink_global_ang_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference body angular velocities, left in URDF order."""
    return self._get_ref_state_array("body_ang_vel", prefix)[:, 0]
def _build_amp_obs_from_ref_state( self, frame_idx: int, prefix: str = "ft_ref_" ) -> torch.Tensor: if ( not self._amp_left_arm_urdf_dof_idx or not self._amp_right_arm_urdf_dof_idx or not self._amp_left_leg_urdf_dof_idx or not self._amp_right_leg_urdf_dof_idx or self._amp_left_elbow_urdf_body_idx is None or self._amp_right_elbow_urdf_body_idx is None or self._amp_left_foot_urdf_body_idx is None or self._amp_right_foot_urdf_body_idx is None ): raise ValueError( "AMP obs indices are not initialized for ref motion." ) dof_pos = self._get_ref_state_array("dof_pos", prefix)[ :, frame_idx, ... ] dof_vel = self._get_ref_state_array("dof_vel", prefix)[ :, frame_idx, ... ] right_arm_pos = dof_pos[:, self._amp_right_arm_urdf_dof_idx] left_arm_pos = dof_pos[:, self._amp_left_arm_urdf_dof_idx] right_leg_pos = dof_pos[:, self._amp_right_leg_urdf_dof_idx] left_leg_pos = dof_pos[:, self._amp_left_leg_urdf_dof_idx] right_arm_vel = dof_vel[:, self._amp_right_arm_urdf_dof_idx] left_arm_vel = dof_vel[:, self._amp_left_arm_urdf_dof_idx] right_leg_vel = dof_vel[:, self._amp_right_leg_urdf_dof_idx] left_leg_vel = dof_vel[:, self._amp_left_leg_urdf_dof_idx] root_pos = self._get_ref_state_array("root_pos", prefix)[ :, frame_idx, ... ] root_rot = self._get_ref_state_array("root_rot", prefix)[ :, frame_idx, ... ] root_inv = quat_inverse(root_rot, w_last=True) rg_pos = self._get_ref_state_array("rg_pos", prefix)[:, frame_idx, ...] rb_rot = self._get_ref_state_array("rb_rot", prefix)[:, frame_idx, ...] 
left_elbow_pos = rg_pos[:, self._amp_left_elbow_urdf_body_idx, :] right_elbow_pos = rg_pos[:, self._amp_right_elbow_urdf_body_idx, :] left_elbow_rot = rb_rot[:, self._amp_left_elbow_urdf_body_idx, :] right_elbow_rot = rb_rot[:, self._amp_right_elbow_urdf_body_idx, :] left_hand_offset = self._amp_left_hand_local_vec.expand( left_elbow_pos.shape[0], -1 ) right_hand_offset = self._amp_right_hand_local_vec.expand( right_elbow_pos.shape[0], -1 ) left_hand_world = left_elbow_pos + quat_rotate( left_elbow_rot, left_hand_offset, w_last=True ) right_hand_world = right_elbow_pos + quat_rotate( right_elbow_rot, right_hand_offset, w_last=True ) left_hand_rel = quat_rotate( root_inv, left_hand_world - root_pos, w_last=True ) right_hand_rel = quat_rotate( root_inv, right_hand_world - root_pos, w_last=True ) left_foot_world = rg_pos[:, self._amp_left_foot_urdf_body_idx, :] right_foot_world = rg_pos[:, self._amp_right_foot_urdf_body_idx, :] left_foot_rel = quat_rotate( root_inv, left_foot_world - root_pos, w_last=True ) right_foot_rel = quat_rotate( root_inv, right_foot_world - root_pos, w_last=True ) return torch.cat( [ right_arm_pos, left_arm_pos, right_leg_pos, left_leg_pos, right_arm_vel, left_arm_vel, right_leg_vel, left_leg_vel, left_hand_rel, right_hand_rel, left_foot_rel, right_foot_rel, ], dim=-1, ) def get_ref_motion_amp_obs_cur( self, prefix: str = "ft_ref_" ) -> torch.Tensor: """AMP observation aligned with RSL reference (current frame).""" return self._build_amp_obs_from_ref_state(0, prefix=prefix) @property def motion_end_mask(self) -> torch.Tensor: """[B] bool: per-step timeout mask. Uses the per-step `motion_end_mask` set before resampling so the event is observable within the same step, and falls back to a direct comparison if not available. 
""" return self._motion_end_mask @property def global_robot_anchor_pos_cur(self): return self.robot.data.body_pos_w[:, self.anchor_bodylink_idx] def get_ref_motion_anchor_bodylink_global_pos_cur( self, prefix: str = "ref_", ) -> torch.Tensor: pos = self.get_ref_motion_bodylink_global_pos_cur(prefix=prefix) return pos[:, self.anchor_bodylink_idx] def get_ref_motion_anchor_bodylink_global_rot_wxyz_cur( self, prefix: str = "ref_", ) -> torch.Tensor: rot = self.get_ref_motion_bodylink_global_rot_wxyz_cur(prefix=prefix) return rot[:, self.anchor_bodylink_idx] def get_ref_motion_anchor_bodylink_global_pos_immediate_next( self, prefix: str = "ref_", ) -> torch.Tensor: pos = self.get_ref_motion_bodylink_global_pos_immediate_next( prefix=prefix ) return pos[:, self.anchor_bodylink_idx] def get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next( self, prefix: str = "ref_", ) -> torch.Tensor: rot = self.get_ref_motion_bodylink_global_rot_wxyz_immediate_next( prefix=prefix ) return rot[:, self.anchor_bodylink_idx] def _get_obs_bydmmc_ref_motion( self, obs_prefix: str = "ref_", ) -> torch.Tensor: base_pos = self._get_ref_state_array("dof_pos", obs_prefix)[:, 0, ...][ ..., self.urdf2sim_dof_idx ] base_vel = self._get_ref_state_array("dof_vel", obs_prefix)[:, 0, ...][ ..., self.urdf2sim_dof_idx ] num_envs = base_pos.shape[0] cur_ref_dof_pos_flat = base_pos.reshape(num_envs, -1) cur_ref_dof_vel_flat = base_vel.reshape(num_envs, -1) return torch.cat([cur_ref_dof_pos_flat, cur_ref_dof_vel_flat], dim=-1) def _get_obs_bydmmc_ref_motion_fut( self, obs_prefix: str = "ref_", ) -> torch.Tensor: base_pos = self._get_ref_state_array("dof_pos", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_dof_idx] base_vel = self._get_ref_state_array("dof_vel", obs_prefix)[ :, 1:, ... 
][..., self.urdf2sim_dof_idx] num_envs = base_pos.shape[0] n_fut_frames = int(self.cfg.n_fut_frames) fut_ref_dof_pos_flat = base_pos.reshape(num_envs, n_fut_frames, -1) fut_ref_dof_vel_flat = base_vel.reshape(num_envs, n_fut_frames, -1) rel_fut_ref_motion_state_seq = torch.cat( [fut_ref_dof_pos_flat, fut_ref_dof_vel_flat], dim=-1 ) return rel_fut_ref_motion_state_seq.reshape(num_envs, -1) def _get_obs_vr_ref_motion_states( self, obs_prefix: str = "ref_", ) -> torch.Tensor: base_pos = self._get_ref_state_array("dof_pos", obs_prefix)[:, 0, ...][ ..., self.urdf2sim_dof_idx ] num_envs = base_pos.shape[0] cur_ref_dof_pos_flat = base_pos.reshape(num_envs, -1) return torch.cat( [ cur_ref_dof_pos_flat, torch.zeros_like( cur_ref_dof_pos_flat, device=cur_ref_dof_pos_flat.device, ), ], dim=-1, ) def _get_obs_vr_ref_motion_fut( self, obs_prefix: str = "ref_", ) -> torch.Tensor: base_pos = self._get_ref_state_array("dof_pos", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_dof_idx] num_envs = base_pos.shape[0] n_fut_frames = int(self.cfg.n_fut_frames) fut_ref_dof_pos_flat = base_pos.reshape(num_envs, n_fut_frames, -1) rel_fut_ref_motion_state_seq = torch.cat( [ fut_ref_dof_pos_flat, torch.zeros_like( fut_ref_dof_pos_flat, device=fut_ref_dof_pos_flat.device ), ], dim=-1, ) return rel_fut_ref_motion_state_seq.reshape(num_envs, -1) def _get_obs_holomotion_rel_ref_motion_flat( self, obs_prefix: str = "ref_", ) -> torch.Tensor: # Gather all needed arrays with obs prefix fut_rg_pos = self._get_ref_state_array("rg_pos", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_body_idx, :] fut_rb_rot_xyzw = self._get_ref_state_array("rb_rot", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_body_idx, :] fut_root_rot_xyzw = self._get_ref_state_array("root_rot", obs_prefix)[ :, 1:, ... ] fut_root_lin_vel = self._get_ref_state_array("root_vel", obs_prefix)[ :, 1:, ... ] fut_root_ang_vel = self._get_ref_state_array( "root_ang_vel", obs_prefix )[:, 1:, ...] 
fut_dof_pos = self._get_ref_state_array("dof_pos", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_dof_idx] fut_dof_vel = self._get_ref_state_array("dof_vel", obs_prefix)[ :, 1:, ... ][..., self.urdf2sim_dof_idx] num_envs, num_fut_timesteps, num_bodies, _ = fut_rg_pos.shape assert num_envs == self.num_envs assert num_fut_timesteps == self.cfg.n_fut_frames fut_ref_root_rot_quat = fut_root_rot_xyzw # [B, T, 4] fut_ref_root_rot_quat_inv = quat_inverse( fut_ref_root_rot_quat, w_last=True ) # [B, T, 4] fut_ref_root_rot_quat_body_flat = ( fut_ref_root_rot_quat[:, :, None, :] .repeat(1, 1, num_bodies, 1) .reshape(-1, 4) ) fut_ref_root_rot_quat_body_flat_inv = quat_inverse( fut_ref_root_rot_quat_body_flat, w_last=True ) ref_fut_heading_quat_inv = calc_heading_quat_inv( fut_root_rot_xyzw.reshape(-1, 4), w_last=True, ) # [B*T, 4] ref_fut_quat_rp = quat_mul( ref_fut_heading_quat_inv, fut_root_rot_xyzw.reshape(-1, 4), w_last=True, ) # [B*T, 4] ref_fut_roll, ref_fut_pitch, _ = get_euler_xyz( ref_fut_quat_rp, w_last=True, ) ref_fut_roll = wrap_to_pi(ref_fut_roll).reshape( num_envs, num_fut_timesteps, -1 ) # [B, T, 1] ref_fut_pitch = wrap_to_pi(ref_fut_pitch).reshape( num_envs, num_fut_timesteps, -1 ) # [B, T, 1] ref_fut_rp = torch.cat( [ref_fut_roll, ref_fut_pitch], dim=-1 ) # [B, T, 2] ref_fut_rp_flat = ref_fut_rp.reshape(num_envs, -1) # [B, T * 2] # --- fut_ref_root_quat_inv_fut_flat = fut_ref_root_rot_quat_inv.reshape( -1, 4 ) fut_ref_cur_root_rel_base_lin_vel = quat_rotate( fut_ref_root_quat_inv_fut_flat, # [B*T, 4] fut_root_lin_vel.reshape(-1, 3), # [B*T, 3] w_last=True, ).reshape(num_envs, -1) # [B, num_fut_timesteps * 3] fut_ref_cur_root_rel_base_ang_vel = quat_rotate( fut_ref_root_quat_inv_fut_flat, # [B*T, 4] fut_root_ang_vel.reshape(-1, 3), # [B*T, 3] w_last=True, ).reshape(num_envs, -1) # [B, num_fut_timesteps * 3] # --- # --- calculate the absolute DoF position and velocity --- fut_ref_dof_pos_flat = fut_dof_pos.reshape(num_envs, -1) fut_ref_dof_vel_flat = 
fut_dof_vel.reshape(num_envs, -1) # --- # --- calculate the future per frame bodylink position and rotation --- fut_ref_global_bodylink_pos = fut_rg_pos # [B, T, num_bodies, 3] fut_ref_global_bodylink_rot = fut_rb_rot_xyzw # [B, T, num_bodies, 4] # get root-relative bodylink position fut_ref_root_rel_bodylink_pos = quat_rotate( fut_ref_root_rot_quat_body_flat_inv, ( fut_ref_global_bodylink_pos - fut_ref_global_bodylink_pos[:, :, 0:1, :] ).reshape(-1, 3), w_last=True, ).reshape( num_envs, num_fut_timesteps, num_bodies, -1 ) # [B, num_fut_timesteps, num_bodies, 3] # get root-relative bodylink rotation fut_ref_root_rel_bodylink_rot = quat_mul( fut_ref_root_rot_quat_body_flat_inv, fut_ref_global_bodylink_rot.reshape(-1, 4), w_last=True, ) fut_ref_root_rel_bodylink_rot_mat = quaternion_to_matrix( fut_ref_root_rel_bodylink_rot, w_last=True, )[:, :, :2].reshape( num_envs, num_fut_timesteps, num_bodies, -1 ) # [B, num_fut_timesteps, num_bodies, 6] rel_fut_ref_motion_state_seq = torch.cat( [ ref_fut_rp_flat.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, 2] fut_ref_cur_root_rel_base_lin_vel.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, 3] fut_ref_cur_root_rel_base_ang_vel.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, 3] fut_ref_dof_pos_flat.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, num_dofs] fut_ref_dof_vel_flat.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, num_dofs] fut_ref_root_rel_bodylink_pos.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, num_bodies*3] fut_ref_root_rel_bodylink_rot_mat.reshape( num_envs, num_fut_timesteps, -1 ), # [B, T, num_bodies*6] ], dim=-1, ) # [B, T, 2 + 3 + 3 + num_dofs * 2 + num_bodies * (3 + 6)] return rel_fut_ref_motion_state_seq.reshape(self.num_envs, -1) def _resample_command(self, env_ids: Sequence[int], eval=False): """Resample command for specified environments.""" if len(env_ids) == 0: return if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor(env_ids, device=self.device) else: 
env_ids = env_ids.to(self.device) if isinstance(env_ids, torch.Tensor): idxs = env_ids elif isinstance(env_ids, slice): idxs = torch.arange(self.num_envs, device=self.device) else: idxs = torch.tensor(env_ids, device=self.device, dtype=torch.long) idxs = self._filter_env_ids_for_motion_task(idxs.view(-1)) if idxs.numel() == 0: return self._record_completion_rate_for_envs(idxs) clip_idx, frame_idx = self._motion_cache.sample_env_assignments( len(idxs), self.cfg.n_fut_frames, self.device, deterministic_start=(eval or self._is_evaluating), ) self._clip_indices[idxs] = clip_idx self._frame_indices[idxs] = frame_idx self._start_frame_indices[idxs] = frame_idx self._reward_sum_since_assign[idxs] = 0.0 self._step_count_since_assign[idxs] = 0.0 self._update_ref_motion_state_from_cache(env_ids=idxs) self._align_root_to_ref(idxs) self._align_dof_to_ref(idxs) def _filter_env_ids_for_motion_task( self, env_ids: torch.Tensor ) -> torch.Tensor: """Filter env_ids to those currently assigned to motion_tracking task. In multi-task training, we may keep `ref_motion` registered for observation schemas, but we must avoid applying motion-based state alignment to envs that are not running motion tracking (e.g., velocity tracking only). Behavior: - If env does not expose multi-task task buffers, return env_ids (legacy). - If env exposes task buffers but has no "motion_tracking" task, return empty. - Otherwise, return env_ids where holo_task_ids == holo_task_name_to_id["motion_tracking"]. 
""" if env_ids.numel() == 0: return env_ids task_ids = getattr(self._env, "holo_task_ids", None) task_name_to_id = getattr(self._env, "holo_task_name_to_id", None) if task_ids is None or task_name_to_id is None: return env_ids motion_tid = task_name_to_id.get("motion_tracking", None) if motion_tid is None: return env_ids[:0] task_ids_t = task_ids.to(device=self.device, dtype=torch.long).view(-1) env_ids_t = env_ids.to(device=self.device, dtype=torch.long).view(-1) mask = task_ids_t[env_ids_t] == int(motion_tid) return env_ids_t[mask] def _align_root_to_ref(self, env_ids): if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1) env_ids = self._filter_env_ids_for_motion_task(env_ids) if env_ids.numel() == 0: return root_pos = self.get_ref_motion_root_global_pos_cur().clone() root_rot_xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur() root_rot = root_rot_xyzw[..., [3, 0, 1, 2]].clone() root_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone() root_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone() pos_rot_range_list = [ self.cfg.root_pose_perturb_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"] ] pos_rot_ranges = torch.tensor(pos_rot_range_list, device=self.device) pos_rot_rand_deltas = isaaclab_math.sample_uniform( pos_rot_ranges[:, 0], pos_rot_ranges[:, 1], (len(env_ids), 6), device=self.device, ) translation_delta = pos_rot_rand_deltas[:, 0:3] rotation_delta = isaaclab_math.quat_from_euler_xyz( pos_rot_rand_deltas[:, 3], pos_rot_rand_deltas[:, 4], pos_rot_rand_deltas[:, 5], ) root_pos[env_ids] += translation_delta root_rot[env_ids] = isaaclab_math.quat_mul( rotation_delta, root_rot[env_ids], ) lin_ang_vel_range_list = [ self.cfg.root_vel_perturb_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"] ] lin_ang_vel_ranges = torch.tensor( 
lin_ang_vel_range_list, device=self.device ) lin_ang_vel_rand_deltas = isaaclab_math.sample_uniform( lin_ang_vel_ranges[:, 0], lin_ang_vel_ranges[:, 1], (len(env_ids), 6), device=self.device, ) root_lin_vel[env_ids] += lin_ang_vel_rand_deltas[:, :3] root_ang_vel[env_ids] += lin_ang_vel_rand_deltas[:, 3:] self.robot.write_root_state_to_sim( torch.cat( [ root_pos[env_ids], root_rot[env_ids], root_lin_vel[env_ids], root_ang_vel[env_ids], ], dim=-1, ), env_ids=env_ids, ) def _align_dof_to_ref(self, env_ids): if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1) env_ids = self._filter_env_ids_for_motion_task(env_ids) if env_ids.numel() == 0: return dof_pos = self.get_ref_motion_dof_pos_cur().clone() dof_vel = self.get_ref_motion_dof_vel_cur().clone() dof_pos += isaaclab_math.sample_uniform( *self.cfg.dof_pos_perturb_range, dof_pos.shape, dof_pos.device, ) soft_dof_pos_limits = self.robot.data.soft_joint_pos_limits[env_ids] dof_pos[env_ids] = torch.clip( dof_pos[env_ids], soft_dof_pos_limits[:, :, 0], soft_dof_pos_limits[:, :, 1], ) self.robot.write_joint_state_to_sim( dof_pos[env_ids], dof_vel[env_ids], env_ids=env_ids, ) def force_realign_root_state_to_ref_no_perturb(self, env_ids) -> None: if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1) env_ids = self._filter_env_ids_for_motion_task(env_ids) if env_ids.numel() == 0: return root_pos = self.get_ref_motion_root_global_pos_cur().clone() root_rot_xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur() root_rot = root_rot_xyzw[..., [3, 0, 1, 2]].clone() root_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone() root_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone() self.robot.write_root_state_to_sim( torch.cat( [ 
root_pos[env_ids], root_rot[env_ids], root_lin_vel[env_ids], root_ang_vel[env_ids], ], dim=-1, ), env_ids=env_ids, ) def force_realign_dof_state_to_ref_no_perturb(self, env_ids) -> None: if not isinstance(env_ids, torch.Tensor): env_ids = torch.tensor( env_ids, device=self.device, dtype=torch.long ) else: env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1) env_ids = self._filter_env_ids_for_motion_task(env_ids) if env_ids.numel() == 0: return dof_pos = self.get_ref_motion_dof_pos_cur().clone() dof_vel = self.get_ref_motion_dof_vel_cur().clone() soft_dof_pos_limits = self.robot.data.soft_joint_pos_limits[env_ids] dof_pos[env_ids] = torch.clip( dof_pos[env_ids], soft_dof_pos_limits[:, :, 0], soft_dof_pos_limits[:, :, 1], ) self.robot.write_joint_state_to_sim( dof_pos[env_ids], dof_vel[env_ids], env_ids=env_ids, ) def force_realign_offline_eval_no_perturb(self, env_ids) -> None: self.force_realign_root_state_to_ref_no_perturb(env_ids) self.force_realign_dof_state_to_ref_no_perturb(env_ids) def _update_command(self): all_ids = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) motion_ids = self._filter_env_ids_for_motion_task(all_ids) if motion_ids.numel() == 0: return continue_ids = motion_ids episode_length_buf = getattr(self._env, "episode_length_buf", None) if episode_length_buf is not None: continue_mask = episode_length_buf[motion_ids] != 0 continue_ids = motion_ids[continue_mask] if continue_ids.numel() > 0: self._frame_indices[continue_ids] += 1 self._swap_step_counter += 1 if self._swap_step_counter >= self._motion_cache.swap_interval_steps: self._swap_pending = True # Resample when motion ends self._resample_when_motion_end_cache() self._update_ref_motion_state_from_cache() def _resample_when_motion_end_cache(self): """Resample environments when motion ends (simple cache mode).""" all_ids = torch.arange( self.num_envs, dtype=torch.long, device=self.device ) motion_ids = self._filter_env_ids_for_motion_task(all_ids) if 
motion_ids.numel() == 0: return lengths = self._motion_cache.lengths_for_indices(self._clip_indices) max_valid_frame = torch.clamp( lengths - 1 - self.cfg.n_fut_frames, min=0 ) need_resample = ( self._frame_indices[motion_ids] > max_valid_frame[motion_ids] ) if torch.any(need_resample): resample_ids = motion_ids[torch.nonzero(need_resample).squeeze(-1)] # Resample these envs self._record_completion_rate_for_envs(resample_ids) clip_idx, frame_idx = self._motion_cache.sample_env_assignments( len(resample_ids), self.cfg.n_fut_frames, self.device, deterministic_start=self._is_evaluating, ) self._clip_indices[resample_ids] = clip_idx self._frame_indices[resample_ids] = frame_idx self._start_frame_indices[resample_ids] = frame_idx self._reward_sum_since_assign[resample_ids] = 0.0 self._step_count_since_assign[resample_ids] = 0.0 # Realign robot state self._update_ref_motion_state_from_cache(env_ids=resample_ids) self._align_root_to_ref(resample_ids) self._align_dof_to_ref(resample_ids) # Mark motion end self._motion_end_mask[motion_ids] = False self._motion_end_mask[resample_ids] = True self.motion_end_counter[resample_ids] += 1 def _update_metrics(self): """Update metrics for command progress tracking.""" if not hasattr(self, "metrics"): self.metrics = {} self._update_mpjpe_metrics() self._update_mpkpe_metrics() def _update_mpjpe_metrics(self): """Update MPJPE (Mean Per Joint Position Error) metrics.""" # Get current and reference joint positions current_dof_pos = self.robot.data.joint_pos # [B, num_dofs] ref_dof_pos = self.get_ref_motion_dof_pos_immediate_next() # Compute joint position errors dof_pos_error = torch.abs( current_dof_pos - ref_dof_pos ) # [B, num_dofs] # MPJPE whole body mpjpe_wholebody = torch.mean(dof_pos_error, dim=-1) # [B] # MPJPE arms (using unified naming) mpjpe_arms = torch.mean( dof_pos_error[:, self.arm_dof_indices], dim=-1 ) # [B] # MPJPE torso (using unified naming) mpjpe_waist = torch.mean( dof_pos_error[:, self.torso_dof_indices], dim=-1 ) 
# [B] # MPJPE legs mpjpe_legs = torch.mean( dof_pos_error[:, self.leg_dof_indices], dim=-1 ) # [B] # Initialize metric tensors if needed for metric_name in [ "Task/MPJPE_WholeBody", "Task/MPJPE_Arms", "Task/MPJPE_Waist", "Task/MPJPE_Legs", ]: if metric_name not in self.metrics: self.metrics[metric_name] = torch.zeros( self.num_envs, device=self.device ) # Update metric values self.metrics["Task/MPJPE_WholeBody"][:] = mpjpe_wholebody self.metrics["Task/MPJPE_Arms"][:] = mpjpe_arms self.metrics["Task/MPJPE_Waist"][:] = mpjpe_waist self.metrics["Task/MPJPE_Legs"][:] = mpjpe_legs def _update_mpkpe_metrics(self): """Update MPKPE (Mean Per Keybody Position Error) metrics.""" # Get current and reference body positions current_body_pos = self.robot.data.body_pos_w # [B, num_bodies, 3] ref_body_pos = self.get_ref_motion_bodylink_global_pos_immediate_next() # [B, num_bodies, 3] # Compute body position errors (L2 norm) body_pos_error = torch.norm( current_body_pos - ref_body_pos, dim=-1 ) # [B, num_bodies] # MPKPE whole body mpkpe_wholebody = torch.mean(body_pos_error, dim=-1) # [B] # MPKPE arms (using unified naming) mpkpe_arms = torch.mean( body_pos_error[:, self.arm_body_indices], dim=-1 ) # [B] # MPKPE torso (using unified naming) mpkpe_waist = torch.mean( body_pos_error[:, self.torso_body_indices], dim=-1 ) # [B] # MPKPE legs mpkpe_legs = torch.mean( body_pos_error[:, self.leg_body_indices], dim=-1 ) # [B] # Initialize metric tensors if needed for metric_name in [ "Task/MPKPE_WholeBody", "Task/MPKPE_Arms", "Task/MPKPE_Waist", "Task/MPKPE_Legs", ]: if metric_name not in self.metrics: self.metrics[metric_name] = torch.zeros( self.num_envs, device=self.device ) # Update metric values self.metrics["Task/MPKPE_WholeBody"][:] = mpkpe_wholebody self.metrics["Task/MPKPE_Arms"][:] = mpkpe_arms self.metrics["Task/MPKPE_Waist"][:] = mpkpe_waist self.metrics["Task/MPKPE_Legs"][:] = mpkpe_legs # --- Pose-error getters for curriculum (WholeBody only) --- def get_wholebody_mpjpe( self, 
) -> torch.Tensor: """[B] current whole-body MPJPE (URDF joint-space abs error).""" if not hasattr(self, "metrics") or ( "Task/MPJPE_WholeBody" not in self.metrics ): return torch.zeros(self.num_envs, device=self.device) return self.metrics["Task/MPJPE_WholeBody"] def get_wholebody_mpkpe( self, ) -> torch.Tensor: """[B] current whole-body MPKPE (body position error).""" if not hasattr(self, "metrics") or ( "Task/MPKPE_WholeBody" not in self.metrics ): return torch.zeros(self.num_envs, device=self.device) return self.metrics["Task/MPKPE_WholeBody"] def get_current_motion_keys( self, ) -> list[str]: """Return motion window keys for the envs' current cached clips.""" try: if hasattr(self, "_motion_cache") and hasattr( self._motion_cache, "motion_keys_for_indices" ): return self._motion_cache.motion_keys_for_indices( self._clip_indices ) except Exception: pass return [] def _set_debug_vis_impl(self, debug_vis: bool): if debug_vis: # Just enable debug mode - visualizers will be created lazily in callback self._debug_vis_enabled = True # Set visibility if visualizers already exist if hasattr(self, "ref_body_visualizers"): for visualizer in self.ref_body_visualizers: visualizer.set_visibility(True) else: self._debug_vis_enabled = False # Set visibility to false if hasattr(self, "ref_body_visualizers"): for visualizer in self.ref_body_visualizers: visualizer.set_visibility(False) def setup_offline_eval_from_frame_zero(self): """Setup reference frame indices for offline evaluation from frame 0.""" self._frame_indices[:] = 0 self._update_ref_motion_state() logger.info( f"Offline evaluation setup complete: all {self.num_envs} " f"environments set to frame 0 references" ) def setup_offline_eval_deterministic( self, apply_pending_swap: bool = True ) -> None: """Deterministic multi-env setup for offline evaluation. - Optionally apply a pending cache swap. - Set env i -> cache row i mapping for active clips, frame 0. - Update reference state only. 
Robot realignment is handled by caller. """ if apply_pending_swap and getattr(self, "_swap_pending", False): self._motion_cache.advance() self._swap_pending = False self._swap_step_counter = 0 clip_count = int(self._motion_cache.clip_count) active_count = min(int(self.num_envs), clip_count) # Reset indices self._clip_indices[:] = 0 self._frame_indices[:] = 0 if active_count > 0: active_ids = torch.arange( active_count, dtype=torch.long, device=self.device ) self._clip_indices[active_ids] = torch.arange( active_count, dtype=torch.long, device=self.device ) self._update_ref_motion_state_from_cache() def _debug_vis_callback(self, event): if not self.robot.is_initialized: return # Check if debug visualization is enabled if not getattr(self, "_debug_vis_enabled", False): return # Check if motion cache/assignments are available if ( not hasattr(self, "_motion_cache") or self._motion_cache is None or not hasattr(self, "_clip_indices") or not hasattr(self, "_frame_indices") ): return # Create visualizers lazily if they don't exist if not hasattr(self, "ref_body_visualizers"): self.ref_body_visualizers = [] # Get number of bodies from the reference motion data num_bodies = self.get_ref_motion_bodylink_global_pos_cur().shape[ -2 ] for i in range(num_bodies): # Reference bodylinks as red spheres self.ref_body_visualizers.append( VisualizationMarkers( self.cfg.body_keypoint_visualizer_cfg.replace( prim_path=f"/Visuals/Command/ref_body_{i}" ) ) ) # Visualize reference body keypoints if len(self.ref_body_visualizers) > 0: ref_body_pos = self.get_ref_motion_bodylink_global_pos_cur() # [B, num_bodies, 3] num_bodies = min( len(self.ref_body_visualizers), ref_body_pos.shape[1] ) for i in range(num_bodies): # Visualize reference bodylinks as spheres (position only) self.ref_body_visualizers[i].visualize( ref_body_pos[:, i], # [B, 3] ) @configclass class MotionCommandCfg(CommandTermCfg): """Configuration for the motion command.""" class_type: type = RefMotionCommand command_obs_name: 
str = MISSING urdf_dof_names: list[str] = MISSING urdf_body_names: list[str] = MISSING # DOF name groupings for mpjpe metrics (using unified naming) arm_dof_names: list[str] = MISSING waist_dof_names: list[str] = MISSING leg_dof_names: list[str] = MISSING # Body name groupings for mpkpe metrics (using unified naming) arm_body_names: list[str] = MISSING torso_body_names: list[str] = MISSING leg_body_names: list[str] = MISSING motion_lib_cfg: dict = MISSING seed: int = MISSING process_id: int = MISSING num_processes: int = MISSING is_evaluating: bool = MISSING resample_time_interval_s: float = MISSING n_fut_frames: int = MISSING target_fps: int = MISSING anchor_bodylink_name: str = "pelvis" asset_name: str = MISSING debug_vis: bool = False root_pose_perturb_range: dict[str, tuple[float, float]] = {} root_vel_perturb_range: dict[str, tuple[float, float]] = {} dof_pos_perturb_range: tuple[float, float] = (-0.1, 0.1) dof_vel_perturb_range: tuple[float, float] = (-1.0, 1.0) body_keypoint_visualizer_cfg: VisualizationMarkersCfg = ( SPHERE_MARKER_CFG.replace(prim_path="/Visuals/Command/ref_keypoint") ) body_keypoint_visualizer_cfg.markers["sphere"].radius = 0.03 body_keypoint_visualizer_cfg.markers[ "sphere" ].visual_material = PreviewSurfaceCfg( diffuse_color=(0.0, 0.0, 1.0) # blue ) resampling_time_range: tuple[float, float] = (1.0, 1.0) @configclass class MoTrack_CommandsCfg: pass def build_motion_tracking_commands_config(command_config_dict: dict): """Build isaaclab-compatible CommandsCfg from a config dictionary. Args: command_config_dict: Dictionary mapping command names to command configurations. Each command config should contain the type and parameters. Example: command_config_dict = { "ref_motion": { "type": "MotionCommandCfg", "params": { "command_obs_name": "bydmmc_ref_motion", "motion_lib_cfg": {...}, "process_id": 0, "num_processes": 1, # ... 
other parameters } } } """ commands_cfg = MoTrack_CommandsCfg() # Add command terms dynamically for command_name, command_config in command_config_dict.items(): command_type = command_config.get("type", "MotionCommandCfg") command_params = command_config.get("params", {}) # Get the command class type if command_type == "MotionCommandCfg": command_cfg = MotionCommandCfg(**command_params) else: raise ValueError(f"Unknown command type: {command_type}") # Add command to config setattr(commands_cfg, command_name, command_cfg) return commands_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_observation.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import isaaclab.envs.mdp as isaaclab_mdp
import isaaclab.sim as sim_utils
from dataclasses import fields as dataclass_fields
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg
from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg
from isaaclab.managers import (
    ActionTermCfg,
    CommandTerm,
    CommandTermCfg,
    EventTermCfg as EventTerm,
    ObservationGroupCfg,
    ObservationGroupCfg as ObsGroup,
    ObservationTermCfg,
    ObservationTermCfg as ObsTerm,
    RewardTermCfg,
    SceneEntityCfg,
    TerminationTermCfg,
)
import torch
from isaaclab.markers import (
    VisualizationMarkers,
    VisualizationMarkersCfg,
)
from isaaclab.markers.config import FRAME_MARKER_CFG
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns
from isaaclab.sim import PhysxCfg, SimulationCfg
from isaaclab.terrains import TerrainImporterCfg
from isaaclab.utils import configclass
import isaaclab.utils.math as isaaclab_math
import isaaclab.utils.noise as isaaclab_noise
from omegaconf import DictConfig, ListConfig, OmegaConf

from holomotion.src.env.isaaclab_components.isaaclab_utils import (
    resolve_holo_config,
)
from holomotion.src.utils.frame_utils import (
    positions_world_to_env_frame,
    root_relative_positions_from_env_frame,
)


def _build_noise_cfg(noise_cfg):
    """Turn a config entry into an ``isaaclab.utils.noise`` config instance.

    ``noise_cfg`` is first passed through ``resolve_holo_config``. If the
    result is not a dict carrying a ``"type"`` key, it is returned unchanged.
    Otherwise ``"type"`` names a class in ``isaaclab.utils.noise`` and
    ``"params"`` supplies its keyword arguments.

    Optional ``n_min_z`` / ``n_max_z`` params expand the scalar ``n_min`` /
    ``n_max`` into float32 tensors ``[base, base, z]``, giving the third
    component its own range. When either z key is given, both ``n_min`` and
    ``n_max`` must be present (a ``KeyError`` is raised otherwise).
    """
    noise_cfg = resolve_holo_config(noise_cfg)
    if not (isinstance(noise_cfg, dict) and "type" in noise_cfg):
        return noise_cfg
    noise_cls = getattr(isaaclab_noise, noise_cfg["type"])
    noise_params = resolve_holo_config(noise_cfg.get("params", {}))
    if not isinstance(noise_params, dict):
        return noise_cls(**noise_params)
    # Copy so popping the z-specific keys never mutates the resolved config.
    noise_params = dict(noise_params)
    if "n_min_z" in noise_params or "n_max_z" in noise_params:
        base_n_min = noise_params["n_min"]
        base_n_max = noise_params["n_max"]
        noise_params["n_min"] = torch.tensor(
            [base_n_min, base_n_min, noise_params.pop("n_min_z", base_n_min)],
            dtype=torch.float32,
        )
        noise_params["n_max"] = torch.tensor(
            [base_n_max, base_n_max, noise_params.pop("n_max_z", base_n_max)],
            dtype=torch.float32,
        )
    return noise_cls(**noise_params)


class MirrorFunctions:
    """Generic observation mirroring utilities.

    Sign conventions used below: a true (polar) vector flips its second
    component ``[1, -1, 1]``; an axial vector (e.g. angular velocity) flips
    the first and third ``[-1, 1, -1]``.
    """

    @staticmethod
    def mirror_dof(
        x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
    ) -> torch.Tensor:
        """Mirror DOF-aligned tensor [..., A] with permutation and sign.

        Output element ``i`` is ``x[..., perm[i]] * sign[i]``.

        Raises:
            ValueError: If the last dim of ``x`` differs from ``perm.numel()``.
        """
        if x.shape[-1] != int(perm.numel()):
            raise ValueError(
                f"mirror_dof expected last dim {perm.numel()}, got {x.shape[-1]}"
            )
        if perm.device != x.device or perm.dtype != torch.long:
            perm = perm.to(device=x.device, dtype=torch.long)
        if sign.device != x.device or sign.dtype != x.dtype:
            sign = sign.to(device=x.device, dtype=x.dtype)
        mirrored = torch.index_select(x, dim=x.ndim - 1, index=perm)
        # Broadcast sign over every leading dim of ``x``.
        sign_view = sign.view(*([1] * (mirrored.ndim - 1)), sign.numel())
        return mirrored * sign_view

    @staticmethod
    def mirror_action(
        actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
    ) -> torch.Tensor:
        """Mirror action tensor [..., A] in DOF space with permutation and sign."""
        return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign)

    @staticmethod
    def mirror_vec3(x: torch.Tensor) -> torch.Tensor:
        """Mirror a true vector [..., 3] with sign [1, -1, 1]."""
        if x.shape[-1] != 3:
            raise ValueError(
                f"mirror_vec3 expected last dim 3, got {x.shape[-1]}"
            )
        sign = torch.tensor(
            [1.0, -1.0, 1.0], device=x.device, dtype=x.dtype
        ).view(*([1] * (x.ndim - 1)), 3)
        return x * sign

    @staticmethod
    def mirror_axial_vec3(x: torch.Tensor) -> torch.Tensor:
        """Mirror an axial vector [..., 3] with sign [-1, 1, -1]."""
        if x.shape[-1] != 3:
            raise ValueError(
                f"mirror_axial_vec3 expected last dim 3, got {x.shape[-1]}"
            )
        sign = torch.tensor(
            [-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype
        ).view(*([1] * (x.ndim - 1)), 3)
        return x * sign

    @staticmethod
    def mirror_velocity_command(x: torch.Tensor) -> torch.Tensor:
        """Mirror velocity command [..., 3] or [..., 4] preserving move_mask.

        3-wide commands use sign ``[1, -1, -1]``; 4-wide commands use
        ``[1, 1, -1, -1]``, leaving the first (mask) element untouched.
        """
        last_dim = x.shape[-1]
        if last_dim == 3:
            sign = torch.tensor(
                [1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
            ).view(*([1] * (x.ndim - 1)), 3)
            return x * sign
        if last_dim == 4:
            sign = torch.tensor(
                [1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
            ).view(*([1] * (x.ndim - 1)), 4)
            return x * sign
        raise ValueError(
            f"mirror_velocity_command expected last dim 3 or 4, got {last_dim}"
        )


class ObservationFunctions:
    """Atomic observation functions.

    The most fundamental observation functions are defined here, aiming to
    utilize the convenient functions from the IsaacLab APIs. For complex
    observation composition patterns, we use the custom observation
    serializer.
    """

    @staticmethod
    def _get_body_indices(
        robot: Articulation, keybody_names: list[str] | None
    ) -> list[int]:
        """Convert body names to indices.

        Args:
            robot: Robot articulation asset
            keybody_names: List of body names. If None, returns all body indices.

        Returns:
            List of body indices corresponding to the given names

        Raises:
            ValueError: If a name is not in ``robot.body_names``.
        """
        if keybody_names is None:
            return list(range(robot.num_bodies))

        # Build the lookup once; setdefault keeps the first occurrence,
        # matching list.index semantics.
        name_to_index: dict[str, int] = {}
        for idx, body_name in enumerate(robot.body_names):
            name_to_index.setdefault(body_name, idx)

        body_indices = []
        for name in keybody_names:
            if name not in name_to_index:
                raise ValueError(
                    f"Body '{name}' not found in robot.body_names: {robot.body_names}"
                )
            body_indices.append(name_to_index[name])
        return body_indices

    @staticmethod
    def _slice_future_frames(
        tensor: torch.Tensor,
        *,
        num_frames: int | None,
        obs_name: str,
    ) -> torch.Tensor:
        """Keep the first ``num_frames`` entries along dim 1 (the frame axis).

        Returns ``tensor`` unchanged when ``num_frames`` is None.

        Raises:
            ValueError: If ``num_frames`` is not positive, ``tensor`` has fewer
                than 2 dims, or fewer than ``num_frames`` frames are available.
        """
        if num_frames is None:
            return tensor
        num_frames = int(num_frames)
        if num_frames <= 0:
            raise ValueError(
                f"{obs_name} num_frames must be positive, got {num_frames}."
            )
        if tensor.ndim < 2:
            raise ValueError(
                f"{obs_name} expected future tensor with ndim >= 2, got {tensor.ndim}."
            )
        if int(tensor.shape[1]) < num_frames:
            raise ValueError(
                f"{obs_name} requested {num_frames} future frames, but only "
                f"{int(tensor.shape[1])} are available."
            )
        return tensor[:, :num_frames, ...]
# ------- Robot Head / mid360 States ------- @staticmethod def _get_obs_head_pos_quat_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot" ): """Head (mid360) features in torso frame with first-frame anchor. Returns [B,13]: pos(3), quat_wxyz->xyzw(4), lin_vel(3), ang_vel(3), all in torso frame and anchored. """ robot_ptr = env.scene[robot_asset_name] body_names = robot_ptr.body_names if body_names is None: raise RuntimeError("robot.body_names is empty") try: torso_idx = body_names.index("torso_link") except ValueError: raise ValueError( f"'torso_link' not found in body_names: {body_names}" ) B = env.num_envs device = env.device # Mid360 extrinsics relative to torso (rotation about Y by pitch) rel_pos_t = torch.tensor( [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device ) pitch = torch.tensor( 0.04014257279586953, dtype=torch.float, device=device ) half = pitch * 0.5 # WXYZ rel_quat_wxyz = torch.stack( [ torch.cos(half), torch.zeros_like(half), torch.sin(half), torch.zeros_like(half), ], dim=-1, ) rel_quat_wxyz = rel_quat_wxyz.expand(B, -1) # World pose/vel from torso + extrinsics (WXYZ math) torso_pos_w = robot_ptr.data.body_pos_w[:, torso_idx, :] torso_quat_wxyz = robot_ptr.data.body_quat_w[:, torso_idx, :] torso_lin_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :] torso_ang_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :] rel_pos = rel_pos_t.expand(B, -1) r_world = isaaclab_math.quat_apply(torso_quat_wxyz, rel_pos) pos_w = torso_pos_w + r_world quat_wxyz = isaaclab_math.quat_mul(torso_quat_wxyz, rel_quat_wxyz) lin_w = torso_lin_w + torch.cross(torso_ang_w, r_world, dim=-1) ang_w = torso_ang_w # Convert to torso frame (WXYZ math) rel_p = pos_w - torso_pos_w torso_inv_wxyz = isaaclab_math.quat_inv(torso_quat_wxyz) pos_torso = isaaclab_math.quat_apply(torso_inv_wxyz, rel_p) lin_torso = isaaclab_math.quat_apply( torso_inv_wxyz, lin_w - torch.cross(ang_w, rel_p, dim=-1) ) ang_torso = isaaclab_math.quat_apply(torso_inv_wxyz, ang_w) 
quat_torso_wxyz = isaaclab_math.quat_mul(torso_inv_wxyz, quat_wxyz) # export quaternion as XYZW to match common obs format quat_torso_xyzw = quat_torso_wxyz[..., [1, 2, 3, 0]] # First-frame anchor normalization (in torso frame) if not hasattr(env, "head_anchor_set"): env.head_anchor_set = torch.zeros( B, dtype=torch.bool, device=device ) env.head_anchor_pos = torch.zeros(B, 3, device=device) env.head_anchor_quat_wxyz = torch.zeros(B, 4, device=device) env.head_anchor_quat_wxyz[:, 0] = 1.0 # identity W unset = ~env.head_anchor_set if unset.any(): env.head_anchor_pos[unset] = pos_torso[unset] env.head_anchor_quat_wxyz[unset] = quat_torso_wxyz[unset] env.head_anchor_set[unset] = True q0_inv = isaaclab_math.quat_inv(env.head_anchor_quat_wxyz) pos_rel = isaaclab_math.quat_apply( q0_inv, pos_torso - env.head_anchor_pos ) lin_rel = isaaclab_math.quat_apply(q0_inv, lin_torso) ang_rel = isaaclab_math.quat_apply(q0_inv, ang_torso) quat_rel_wxyz = isaaclab_math.quat_mul(q0_inv, quat_torso_wxyz) quat_rel_xyzw = quat_rel_wxyz[..., [1, 2, 3, 0]] return torch.cat([pos_rel, quat_rel_xyzw, lin_rel, ang_rel], dim=-1) @staticmethod def _get_obs_rel_headlink_lin_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot" ) -> torch.Tensor: # [num_envs, 3] """Headlink relative linear velocity, expressed in the headlink's frame. Definitions: - Headlink: a virtual rigid sensor frame fixed to `torso_link` using the extrinsics defined below (translation `rel_pos_t` and rotation `rel_quat_wxyz`). - Relative linear velocity: v_head - v_torso_origin, both measured in the world frame before re-expression. For a rigid mount, this equals ω_torso × r_world. - Expression frame: the instantaneous headlink frame (i.e., result is in headlink axes). Returns: Tensor of shape [num_envs, 3]: headlink relative linear velocity in headlink frame. 
""" robot_ptr = env.scene[robot_asset_name] body_names = robot_ptr.body_names if body_names is None: raise RuntimeError("robot.body_names is empty") torso_idx = body_names.index("torso_link") num_envs = env.num_envs device = env.device # Headlink extrinsics relative to torso: translation + rotation about Y (pitch) rel_pos_t = torch.tensor( [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device ) # [3] pitch = torch.tensor( 0.04014257279586953, dtype=torch.float, device=device ) half = pitch * 0.5 # Quaternion (WXYZ) for rotation about Y by 'pitch' rel_quat_wxyz = torch.stack( [ torch.cos(half), torch.zeros_like(half), torch.sin(half), torch.zeros_like(half), ], dim=-1, ).expand(num_envs, -1) # [num_envs, 4] # Torso world state torso_quat_wxyz = robot_ptr.data.body_quat_w[ :, torso_idx, : ] # [num_envs, 4] torso_lin_w = robot_ptr.data.body_lin_vel_w[ :, torso_idx, : ] # [num_envs, 3] torso_ang_w = robot_ptr.data.body_ang_vel_w[ :, torso_idx, : ] # [num_envs, 3] # Headlink world pose from torso + extrinsics rel_pos = rel_pos_t.expand(num_envs, -1) # [num_envs, 3] r_world = isaaclab_math.quat_apply( torso_quat_wxyz, rel_pos ) # [num_envs, 3] head_quat_wxyz = isaaclab_math.quat_mul( torso_quat_wxyz, rel_quat_wxyz ) # [num_envs, 4] # World-frame velocities head_lin_w = torso_lin_w + torch.cross( torso_ang_w, r_world, dim=-1 ) # [num_envs, 3] # Relative linear velocity in world frame rel_lin_w = ( head_lin_w - torso_lin_w ) # [num_envs, 3] == ω_torso × r_world # Re-express in headlink frame head_inv_wxyz = isaaclab_math.quat_inv(head_quat_wxyz) # [num_envs, 4] rel_lin_head = isaaclab_math.quat_apply( head_inv_wxyz, rel_lin_w ) # [num_envs, 3] return rel_lin_head @staticmethod def _get_obs_rel_headlink_ang_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot" ) -> torch.Tensor: # [num_envs, 3] """Headlink relative angular velocity, expressed in the headlink's frame. 
Definitions: - Headlink: a virtual rigid sensor frame fixed to `torso_link` using the extrinsics defined below. - Relative angular velocity: ω_head - ω_torso, measured in the world frame, then re-expressed in the headlink frame. - For a rigid mount (no neck articulation), ω_head == ω_torso, so the result is identically zero. If an articulated head exists, replace ω_head with the head link's world angular velocity before the subtraction. Returns: Tensor of shape [num_envs, 3]: headlink relative angular velocity in headlink frame. """ robot_ptr = env.scene[robot_asset_name] body_names = robot_ptr.body_names if body_names is None: raise RuntimeError("robot.body_names is empty") torso_idx = body_names.index("torso_link") num_envs = env.num_envs device = env.device # Headlink extrinsics (rotation about Y by pitch) pitch = torch.tensor( 0.04014257279586953, dtype=torch.float, device=device ) half = pitch * 0.5 rel_quat_wxyz = torch.stack( [ torch.cos(half), torch.zeros_like(half), torch.sin(half), torch.zeros_like(half), ], dim=-1, ).expand(num_envs, -1) # [num_envs, 4] torso_quat_wxyz = robot_ptr.data.body_quat_w[ :, torso_idx, : ] # [num_envs, 4] torso_ang_w = robot_ptr.data.body_ang_vel_w[ :, torso_idx, : ] # [num_envs, 3] # For the rigid mount, ω_head_w == ω_torso_w head_ang_w = torso_ang_w # [num_envs, 3] rel_ang_w = ( head_ang_w - torso_ang_w ) # [num_envs, 3] -> zeros for rigid mount # Re-express in headlink frame head_quat_wxyz = isaaclab_math.quat_mul( torso_quat_wxyz, rel_quat_wxyz ) # [num_envs, 4] head_inv_wxyz = isaaclab_math.quat_inv(head_quat_wxyz) # [num_envs, 4] rel_ang_head = isaaclab_math.quat_apply( head_inv_wxyz, rel_ang_w ) # [num_envs, 3] return rel_ang_head # ------- Robot Root States ------- @staticmethod def _get_obs_global_robot_root_pos(env: ManagerBasedRLEnv): """Asset root position in the environment frame. IsaacLab's root position helpers subtract `env.scene.env_origins`, so this is not the raw simulator-world position. 
""" return isaaclab_mdp.root_pos_w(env) @staticmethod def _get_obs_global_robot_root_rot_wxyz(env: ManagerBasedRLEnv): """Asset root orientation (w, x, y, z) in the environment frame.""" return isaaclab_mdp.root_quat_w(env) @staticmethod def _get_obs_global_robot_root_rot_xyzw(env: ManagerBasedRLEnv): """Asset root orientation (x, y, z, w) in the environment frame.""" return ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)[ ..., [1, 2, 3, 0] ] @staticmethod def _get_obs_global_robot_root_rot_mat(env: ManagerBasedRLEnv): """Asset root orientation as a 3x3 matrix, flattened to the first two rows (6D).""" return isaaclab_math.matrix_from_quat( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) )[..., :2] # [num_envs, 6] @staticmethod def _get_obs_global_robot_root_lin_vel(env: ManagerBasedRLEnv): """Asset root linear velocity in the environment frame.""" return isaaclab_mdp.root_lin_vel_w(env) # [num_envs, 3] @staticmethod def _get_obs_global_robot_root_ang_vel(env: ManagerBasedRLEnv): """Asset root angular velocity in the environment frame.""" return isaaclab_mdp.root_ang_vel_w(env) # [num_envs, 3] @staticmethod def _get_obs_rel_robot_root_lin_vel(env: ManagerBasedRLEnv): """Relative root linear velocity in the root frame.""" return isaaclab_mdp.base_lin_vel(env) # [num_envs, 3] @staticmethod def _get_obs_rel_robot_root_ang_vel(env: ManagerBasedRLEnv): """Relative root angular velocity in the root frame.""" return isaaclab_mdp.base_ang_vel(env) # [num_envs, 3] @staticmethod def _get_obs_rel_anchor_lin_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", anchor_bodylink_name: str = "torso_link", ): """Relative anchor linear velocity in the anchor frame.""" torso_global_rot_quat_wxyz = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, [anchor_bodylink_name] ) ) # [num_envs, 1, 4] torso_global_lin_vel = ( ObservationFunctions._get_obs_global_robot_bodylink_lin_vel( env, robot_asset_name, 
[anchor_bodylink_name] ) ) # [num_envs, 1, 3] torso_rel_lin_vel = isaaclab_math.quat_apply( isaaclab_math.quat_inv(torso_global_rot_quat_wxyz), torso_global_lin_vel, ) # [num_envs, 1, 3] return torso_rel_lin_vel.squeeze(1) # [num_envs, 3] @staticmethod def _get_obs_projected_gravity( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ) -> torch.Tensor: # [num_envs, 3] """Gravity vector projected into the robot's root frame. Projects the world-frame gravity vector into the robot's base frame using the inverse root orientation quaternion. """ robot_ptr = env.scene[robot_asset_name] g_w: torch.Tensor = robot_ptr.data.GRAVITY_VEC_W # [num_envs, 3] root_quat_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) # [num_envs, 4] # Project gravity into root frame using inverse quaternion projected_gravity: torch.Tensor = isaaclab_math.quat_apply_inverse( root_quat_wxyz, g_w ) # [num_envs, 3] return projected_gravity @staticmethod def _get_obs_global_robot_root_yaw( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ): """Robot's yaw heading in the environment frame (in radians).""" robot_ptr = env.scene[robot_asset_name] return robot_ptr.data.heading_w # [num_envs, ] # @torch.compile @staticmethod def _get_obs_robot_root_heading_aligned_quat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ): """A quaternion representing only the robot's yaw heading.""" global_yaw = ObservationFunctions._get_obs_global_robot_root_yaw( env, robot_asset_name, ) # [num_envs, ] zero_roll = torch.zeros_like(global_yaw, device=env.device) zero_pitch = torch.zeros_like(global_yaw, device=env.device) heading_aligned_quat = isaaclab_math.quat_from_angle_axis( roll=zero_roll, pitch=zero_pitch, yaw=global_yaw, ) # [num_envs, 4] return heading_aligned_quat # [num_envs, 4] # @torch.compile @staticmethod def _get_obs_rel_robot_root_roll_pitch( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ): """Robot's roll and pitch relative to its 
heading-aligned frame.""" heading_aligned_quat = ( ObservationFunctions._get_obs_robot_root_heading_aligned_quat( env, robot_asset_name, ) ) # [num_envs, 4] robot_quat_in_heading_aligned_frame = isaaclab_math.quat_mul( isaaclab_math.quat_inv(heading_aligned_quat), ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env), ) # [num_envs, 4] rel_roll, rel_pitch, _ = isaaclab_math.get_euler_xyz( robot_quat_in_heading_aligned_frame ) # [num_envs, 3] return torch.stack([rel_roll, rel_pitch], dim=-1) # [num_envs, 2] # ------- Robot Bodylink States ------- @staticmethod def _get_obs_global_robot_bodylink_pos( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Positions of specified bodylinks in the environment frame. Body link poses are stored in simulator-world coordinates, so this helper subtracts `env.scene.env_origins` to match IsaacLab's environment-frame root helpers. """ robot_ptr = env.scene[robot_asset_name] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) keybody_global_pos = positions_world_to_env_frame( robot_ptr.data.body_pos_w[:, keybody_idxs], env.scene.env_origins, ) return keybody_global_pos # [num_envs, num_keybodies, 3] @staticmethod def _get_obs_global_robot_bodylink_rot_wxyz( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Orientations (w, x, y, z) of specified bodylinks in the environment frame.""" robot_ptr = env.scene[robot_asset_name] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) keybody_global_rot = robot_ptr.data.body_quat_w[:, keybody_idxs] return keybody_global_rot # [num_envs, num_keybodies, 4] @staticmethod def _get_obs_global_robot_bodylink_rot_xyzw( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Orientations (x, y, z, w) of specified bodylinks in the environment frame.""" return 
ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names, )[..., [1, 2, 3, 0]] # [num_envs, num_keybodies, 4] @staticmethod def _get_obs_global_robot_bodylink_rot_mat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Orientations of specified bodylinks as a 3x3 matrix, flattened to the first two rows (6D).""" keybody_global_rot_wxyz = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names, ) ) return isaaclab_math.matrix_from_quat(keybody_global_rot_wxyz)[ ..., :2 ] # [num_envs, num_keybodies, 6] @staticmethod def _get_obs_global_robot_bodylink_lin_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Linear velocities of specified bodylinks in the environment frame.""" robot_ptr = env.scene[robot_asset_name] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) keybody_global_lin_vel = robot_ptr.data.body_lin_vel_w[:, keybody_idxs] return keybody_global_lin_vel # [num_envs, num_keybodies, 3] @staticmethod def _get_obs_global_robot_bodylink_ang_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ): """Angular velocities of specified bodylinks in the environment frame.""" robot_ptr = env.scene[robot_asset_name] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) keybody_global_ang_vel = robot_ptr.data.body_ang_vel_w[:, keybody_idxs] return keybody_global_ang_vel # [num_envs, num_keybodies, 3] @staticmethod def _get_obs_root_rel_robot_bodylink_pos( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 3] """Root-relative bodylink positions from environment-frame positions.""" # Get global states keybody_global_pos: torch.Tensor = ( 
ObservationFunctions._get_obs_global_robot_bodylink_pos( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] global_root_pos: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_pos(env) ) # [num_envs, 3] root_global_rot_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) # [num_envs, 4] return root_relative_positions_from_env_frame( body_pos_env=keybody_global_pos, root_pos_env=global_root_pos, root_quat_w=root_global_rot_wxyz, ) @staticmethod def _get_obs_root_rel_robot_bodylink_rot_wxyz( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 4] """Orientations (w, x, y, z) of specified bodylinks relative to the robot's root frame.""" # Get global states keybody_global_rot: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] root_global_rot_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) # [num_envs, 4] # Transform to root frame by multiplying with inverse root rotation root_inv_rot: torch.Tensor = isaaclab_math.quat_inv( root_global_rot_wxyz ) # [num_envs, 4] num_bodies = keybody_global_rot.shape[1] rel_rot_root: torch.Tensor = isaaclab_math.quat_mul( root_inv_rot[..., None, :].expand(-1, num_bodies, -1), keybody_global_rot, ) # [num_envs, num_keybodies, 4] return rel_rot_root @staticmethod def _get_obs_root_rel_robot_bodylink_rot_xyzw( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 4] """Orientations (x, y, z, w) of specified bodylinks relative to the robot's root frame.""" return ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names )[ ..., [1, 2, 3, 0] ] # [num_envs, num_keybodies, 4] - convert WXYZ to XYZW @staticmethod def 
_get_obs_root_rel_robot_bodylink_rot_mat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 6] """Orientations of specified bodylinks relative to the robot's root frame, as a 3x3 matrix, flattened to the first two rows (6D).""" keybody_rel_rot_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] return isaaclab_math.matrix_from_quat(keybody_rel_rot_wxyz)[ ..., :2 ] # [num_envs, num_keybodies, 6] @staticmethod def _get_obs_root_rel_robot_bodylink_lin_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 3] """Linear velocities of specified bodylinks relative to the robot's root frame.""" # Get global states keybody_global_lin_vel: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_bodylink_lin_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] root_global_lin_vel: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_lin_vel(env) ) # [num_envs, 3] root_global_rot_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) # [num_envs, 4] # Compute relative velocity in world frame rel_lin_vel_w = keybody_global_lin_vel - root_global_lin_vel.unsqueeze( 1 ) # Transform to root frame by rotating with inverse root rotation root_inv_rot: torch.Tensor = isaaclab_math.quat_inv( root_global_rot_wxyz ) # [num_envs, 4] rel_lin_vel_root: torch.Tensor = isaaclab_math.quat_apply( root_inv_rot.unsqueeze(1), rel_lin_vel_w ) # [num_envs, num_keybodies, 3] return rel_lin_vel_root @staticmethod def _get_obs_root_rel_robot_bodylink_ang_vel( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 3] """Angular velocities of specified 
bodylinks relative to the robot's root frame.""" # Get global states keybody_global_ang_vel: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_bodylink_ang_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] root_global_ang_vel: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_ang_vel(env) ) # [num_envs, 3] root_global_rot_wxyz: torch.Tensor = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) # [num_envs, 4] # Compute relative angular velocity in world frame rel_ang_vel_w = keybody_global_ang_vel - root_global_ang_vel.unsqueeze( 1 ) # Transform to root frame by rotating with inverse root rotation root_inv_rot: torch.Tensor = isaaclab_math.quat_inv( root_global_rot_wxyz ) # [num_envs, 4] rel_ang_vel_root: torch.Tensor = isaaclab_math.quat_apply( root_inv_rot.unsqueeze(1), rel_ang_vel_w ) # [num_envs, num_keybodies, 3] return rel_ang_vel_root # ------- Flat Bodylink Observations ------- @staticmethod def _get_obs_global_robot_bodylink_pos_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened positions of specified bodylinks in the environment frame.""" bodylink_pos = ObservationFunctions._get_obs_global_robot_bodylink_pos( env, robot_asset_name, keybody_names ) # [num_envs, num_keybodies, 3] return bodylink_pos.reshape( bodylink_pos.shape[0], -1 ) # [num_envs, num_keybodies * 3] @staticmethod def _get_obs_global_robot_bodylink_rot_wxyz_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 4] """Flattened orientations (w, x, y, z) of specified bodylinks in the environment frame.""" bodylink_rot = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] return bodylink_rot.reshape( bodylink_rot.shape[0], -1 ) # 
[num_envs, num_keybodies * 4] @staticmethod def _get_obs_global_robot_bodylink_rot_xyzw_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 4] """Flattened orientations (x, y, z, w) of specified bodylinks in the environment frame.""" bodylink_rot = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_xyzw( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] return bodylink_rot.reshape( bodylink_rot.shape[0], -1 ) # [num_envs, num_keybodies * 4] @staticmethod def _get_obs_global_robot_bodylink_rot_mat_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 6] """Flattened orientation matrices (6D) of specified bodylinks in the environment frame.""" bodylink_rot_mat = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_mat( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 6] return bodylink_rot_mat.reshape( bodylink_rot_mat.shape[0], -1 ) # [num_envs, num_keybodies * 6] @staticmethod def _get_obs_global_robot_bodylink_lin_vel_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened linear velocities of specified bodylinks in the environment frame.""" bodylink_lin_vel = ( ObservationFunctions._get_obs_global_robot_bodylink_lin_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] return bodylink_lin_vel.reshape( bodylink_lin_vel.shape[0], -1 ) # [num_envs, num_keybodies * 3] @staticmethod def _get_obs_global_robot_bodylink_ang_vel_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened angular velocities of specified bodylinks in the environment frame.""" bodylink_ang_vel = ( 
ObservationFunctions._get_obs_global_robot_bodylink_ang_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] return bodylink_ang_vel.reshape( bodylink_ang_vel.shape[0], -1 ) # [num_envs, num_keybodies * 3] @staticmethod def _get_obs_root_rel_robot_bodylink_pos_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened positions of specified bodylinks relative to the robot's root frame.""" bodylink_pos = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_pos( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] return bodylink_pos.reshape( bodylink_pos.shape[0], -1 ) # [num_envs, num_keybodies * 3] @staticmethod def _get_obs_root_rel_robot_bodylink_rot_wxyz_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 4] """Flattened orientations (w, x, y, z) of specified bodylinks relative to the robot's root frame.""" bodylink_rot = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] return bodylink_rot.reshape( bodylink_rot.shape[0], -1 ) # [num_envs, num_keybodies * 4] @staticmethod def _get_obs_root_rel_robot_bodylink_rot_xyzw_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 4] """Flattened orientations (x, y, z, w) of specified bodylinks relative to the robot's root frame.""" bodylink_rot = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_xyzw( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 4] return bodylink_rot.reshape( bodylink_rot.shape[0], -1 ) # [num_envs, num_keybodies * 4] @staticmethod def _get_obs_root_rel_robot_bodylink_rot_mat_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", 
keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 6] """Flattened orientation matrices (6D) of specified bodylinks relative to the robot's root frame.""" bodylink_rot_mat = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_mat( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 6] return bodylink_rot_mat.reshape( bodylink_rot_mat.shape[0], -1 ) # [num_envs, num_keybodies * 6] @staticmethod def _get_obs_root_rel_robot_bodylink_lin_vel_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened linear velocities of specified bodylinks relative to the robot's root frame.""" bodylink_lin_vel = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_lin_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] return bodylink_lin_vel.reshape( bodylink_lin_vel.shape[0], -1 ) # [num_envs, num_keybodies * 3] @staticmethod def _get_obs_root_rel_robot_bodylink_ang_vel_flat( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies * 3] """Flattened angular velocities of specified bodylinks relative to the robot's root frame.""" bodylink_ang_vel = ( ObservationFunctions._get_obs_root_rel_robot_bodylink_ang_vel( env, robot_asset_name, keybody_names ) ) # [num_envs, num_keybodies, 3] return bodylink_ang_vel.reshape( bodylink_ang_vel.shape[0], -1 ) # [num_envs, num_keybodies * 3] # ------- Robot DoF States ------- @staticmethod def _get_obs_dof_pos(env: ManagerBasedRLEnv): """Joint positions relative to the default joint angles.""" return isaaclab_mdp.joint_pos_rel(env) # [num_envs, num_dofs] @staticmethod def _get_obs_dof_vel(env: ManagerBasedRLEnv): """Joint velocities.""" return isaaclab_mdp.joint_vel_rel(env) # [num_envs, num_dofs] @staticmethod def _get_obs_last_actions(env: ManagerBasedRLEnv): """Last 
action output by the policy.""" return isaaclab_mdp.last_action(env) # [num_envs, num_actions] # ------- Reference Motion States ------- @staticmethod def _get_obs_ref_motion_states( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): """Reference motion states (flattened) via RefMotionCommand schema.""" command = env.command_manager.get_term(ref_motion_command_name) obs_fn_name = f"_get_obs_{command.cfg.command_obs_name}" obs_fn = getattr(command, obs_fn_name) return obs_fn(obs_prefix=ref_prefix) @staticmethod def _get_obs_ref_motion_states_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): """Future reference motion states (flattened).""" command = env.command_manager.get_term(ref_motion_command_name) obs_fn_name = f"_get_obs_{command.cfg.command_obs_name}_fut" obs_fn = getattr(command, obs_fn_name) return obs_fn(obs_prefix=ref_prefix) @staticmethod def _get_obs_vr_ref_motion_states( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): command = env.command_manager.get_term(ref_motion_command_name) return command._get_obs_vr_ref_motion_states(obs_prefix=ref_prefix) @staticmethod def _get_obs_vr_ref_motion_states_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): """Future reference motion states (flattened).""" command = env.command_manager.get_term(ref_motion_command_name) return command._get_obs_vr_ref_motion_fut(obs_prefix=ref_prefix) @staticmethod def _get_obs_ref_dof_pos_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, num_dofs] """Reference current DoF positions in simulator DoF order.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_dof_pos_cur(prefix=ref_prefix) @staticmethod def _get_obs_immediate_next_two_dof_pos( env: 
ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 2 * num_dofs] """Immediate next two DoF positions in simulator DoF order.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_immediate_next_two_dof_pos(prefix=ref_prefix) @staticmethod def _get_obs_ref_motion_cur_heading_aligned_root_pos( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference current heading-aligned root position.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_cur_heading_aligned_root_pos( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_fut_heading_aligned_root_pos( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, T, 3] """Future reference heading-aligned root position.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_fut_heading_aligned_root_pos( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_cur_heading_aligned_root_rot6d( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 6] """Reference current heading-aligned root rotation (rot6d).""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_cur_heading_aligned_root_rot6d( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_fut_heading_aligned_root_rot6d( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, T, 6] """Future reference heading-aligned root rotation (rot6d).""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_fut_heading_aligned_root_rot6d( prefix=ref_prefix ) @staticmethod def 
_get_obs_ref_motion_cur_heading_aligned_root_lin_vel( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference current heading-aligned root linear velocity.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_cur_heading_aligned_root_lin_vel( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_fut_heading_aligned_root_lin_vel( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, T, 3] """Future reference heading-aligned root linear velocity.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_fut_heading_aligned_root_lin_vel( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_cur_heading_aligned_root_ang_vel( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference current heading-aligned root angular velocity.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_cur_heading_aligned_root_ang_vel( prefix=ref_prefix ) @staticmethod def _get_obs_ref_motion_fut_heading_aligned_root_ang_vel( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, T, 3] """Future reference heading-aligned root angular velocity.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_fut_heading_aligned_root_ang_vel( prefix=ref_prefix ) @staticmethod def _get_obs_ref_dof_vel_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, num_dofs] """Reference current DoF velocities in simulator DoF order.""" command = env.command_manager.get_term(ref_motion_command_name) return 
command.get_ref_motion_dof_vel_cur(prefix=ref_prefix) @staticmethod def _get_obs_ref_motion_filter_cutoff_hz( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ) -> torch.Tensor: """Return clip-level filter metadata; this is prefix-independent.""" command = env.command_manager.get_term(ref_motion_command_name) return command.get_ref_motion_filter_cutoff_hz_cur() @staticmethod def _get_obs_ref_root_height_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 1] """Reference current root height: world z minus env-origin z.""" command = env.command_manager.get_term(ref_motion_command_name) world_pos = command.get_ref_motion_root_global_pos_cur( prefix=ref_prefix ) # [B, 3] height = (world_pos[..., 2] - env.scene.env_origins[..., 2]).unsqueeze( -1 ) # [B,1] return height @staticmethod def _get_obs_ref_dof_pos_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, n_fut_frames * num_dofs] """Future reference DoF positions (flattened over time) in simulator DoF order.""" command = env.command_manager.get_term(ref_motion_command_name) dof_pos_fut = command.get_ref_motion_dof_pos_fut( prefix=ref_prefix ) # [B, T, D(sim)] dof_pos_fut = ObservationFunctions._slice_future_frames( dof_pos_fut, num_frames=num_frames, obs_name="ref_dof_pos_fut", ) return dof_pos_fut @staticmethod def _get_obs_ref_gravity_projection_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference gravity projection.""" command = env.command_manager.get_term(ref_motion_command_name) gravity_projection = command.get_ref_motion_gravity_projection_cur( prefix=ref_prefix ) return gravity_projection @staticmethod def _get_obs_ref_gravity_projection_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = 
"ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, T, 3] """Future reference gravity projection.""" command = env.command_manager.get_term(ref_motion_command_name) gravity_projection = command.get_ref_motion_gravity_projection_fut( prefix=ref_prefix ) gravity_projection = ObservationFunctions._slice_future_frames( gravity_projection, num_frames=num_frames, obs_name="ref_gravity_projection_fut", ) return gravity_projection @staticmethod def _get_obs_ref_base_linvel_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference base linear velocity.""" command = env.command_manager.get_term(ref_motion_command_name) base_linvel = command.get_ref_motion_base_linvel_cur(prefix=ref_prefix) return base_linvel @staticmethod def _get_obs_ref_base_linvel_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, T, 3] """Future reference base linear velocity.""" command = env.command_manager.get_term(ref_motion_command_name) base_linvel = command.get_ref_motion_base_linvel_fut(prefix=ref_prefix) base_linvel = ObservationFunctions._slice_future_frames( base_linvel, num_frames=num_frames, obs_name="ref_base_linvel_fut", ) return base_linvel @staticmethod def _get_obs_ref_base_angvel_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: # [num_envs, 3] """Reference base angular velocity.""" command = env.command_manager.get_term(ref_motion_command_name) base_angvel = command.get_ref_motion_base_angvel_cur(prefix=ref_prefix) return base_angvel @staticmethod def _get_obs_ref_keybody_rel_pos_cur( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", keybody_names: list[str] | None = None, ) -> torch.Tensor: # [num_envs, num_keybodies, 3] 
"""Reference keybody root-relative positions.""" command = env.command_manager.get_term(ref_motion_command_name) ref_keybody_rel_pos = command.get_ref_motion_bodylink_rel_pos_cur( prefix=ref_prefix ) # [B, N, 3] if keybody_names is None: return ref_keybody_rel_pos robot_ptr = env.scene["robot"] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) kb_rel_pos = ref_keybody_rel_pos[:, keybody_idxs, :] bs = kb_rel_pos.shape[0] return kb_rel_pos.reshape(bs, -1) @staticmethod def _get_obs_ref_keybody_rel_pos_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", keybody_names: list[str] | None = None, num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, T, num_keybodies, 3] """Future reference keybody root-relative positions.""" command = env.command_manager.get_term(ref_motion_command_name) ref_keybody_rel_pos_fut = command.get_ref_motion_bodylink_rel_pos_fut( prefix=ref_prefix ) # [B, T, N, 3] ref_keybody_rel_pos_fut = ObservationFunctions._slice_future_frames( ref_keybody_rel_pos_fut, num_frames=num_frames, obs_name="ref_keybody_rel_pos_fut", ) if keybody_names is None: return ref_keybody_rel_pos_fut robot_ptr = env.scene["robot"] keybody_idxs = ObservationFunctions._get_body_indices( robot_ptr, keybody_names ) kb_rel_pos_fut = ref_keybody_rel_pos_fut[:, :, keybody_idxs, :] bs, t, _, _ = kb_rel_pos_fut.shape return kb_rel_pos_fut.reshape(bs, t, -1) @staticmethod def _get_obs_ref_base_angvel_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, T, 3] """Future reference base angular velocity.""" command = env.command_manager.get_term(ref_motion_command_name) base_angvel = command.get_ref_motion_base_angvel_fut(prefix=ref_prefix) base_angvel = ObservationFunctions._slice_future_frames( base_angvel, num_frames=num_frames, obs_name="ref_base_angvel_fut", ) return base_angvel 
@staticmethod def _get_obs_ref_dof_vel_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, n_fut_frames * num_dofs] """Future reference DoF velocities (flattened over time) in simulator DoF order.""" command = env.command_manager.get_term(ref_motion_command_name) dof_vel_fut = command.get_ref_motion_dof_vel_fut( prefix=ref_prefix ) # [B, T, D(sim)] dof_vel_fut = ObservationFunctions._slice_future_frames( dof_vel_fut, num_frames=num_frames, obs_name="ref_dof_vel_fut", ) B, T, D = dof_vel_fut.shape return dof_vel_fut.reshape(B, T * D) @staticmethod def _get_obs_ref_root_height_fut( env: ManagerBasedRLEnv, ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", num_frames: int | None = None, ) -> torch.Tensor: # [num_envs, n_fut_frames] """Future reference root heights per frame: world z minus env-origin z.""" command = env.command_manager.get_term(ref_motion_command_name) world_pos = command.get_ref_motion_root_global_pos_fut( prefix=ref_prefix ) # [B, T, 3] world_pos = ObservationFunctions._slice_future_frames( world_pos, num_frames=num_frames, obs_name="ref_root_height_fut", ) heights = ( world_pos[..., 2] - env.scene.env_origins[:, None, 2] ) # [B, T] return heights[..., None] # @torch.compile @staticmethod def _get_obs_global_anchor_diff( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): command = env.command_manager.get_term(ref_motion_command_name) env_ref_motion_anchor_pos = positions_world_to_env_frame( command.get_ref_motion_anchor_bodylink_global_pos_cur( prefix=ref_prefix ), env.scene.env_origins, ) global_ref_motino_anchor_rot_wxyz = ( command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur( prefix=ref_prefix ) ) global_robot_anchor_pos = ( ObservationFunctions._get_obs_global_robot_bodylink_pos( env, robot_asset_name, [command.anchor_bodylink_name] 
).squeeze(1) ) global_robot_anchor_rot_wxyz = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, [command.anchor_bodylink_name] ).squeeze(1) ) pos_diff, rot_diff = isaaclab_math.subtract_frame_transforms( t01=global_robot_anchor_pos, q01=global_robot_anchor_rot_wxyz, t02=env_ref_motion_anchor_pos, q02=global_ref_motino_anchor_rot_wxyz, ) rot_diff_mat = isaaclab_math.matrix_from_quat(rot_diff) return torch.cat( [ pos_diff, rot_diff_mat[..., :2].reshape(env.num_envs, -1), ], dim=-1, ) # [num_envs, 9] @staticmethod def _get_obs_global_anchor_pos_diff( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): command = env.command_manager.get_term(ref_motion_command_name) env_ref_motion_anchor_pos = positions_world_to_env_frame( command.get_ref_motion_anchor_bodylink_global_pos_cur( prefix=ref_prefix ), env.scene.env_origins, ) global_ref_motino_anchor_rot_wxyz = ( command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur( prefix=ref_prefix ) ) global_robot_anchor_pos = ( ObservationFunctions._get_obs_global_robot_bodylink_pos( env, robot_asset_name, [command.anchor_bodylink_name] ).squeeze(1) ) global_robot_anchor_rot_wxyz = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, [command.anchor_bodylink_name] ).squeeze(1) ) pos_diff, _ = isaaclab_math.subtract_frame_transforms( t01=global_robot_anchor_pos, q01=global_robot_anchor_rot_wxyz, t02=env_ref_motion_anchor_pos, q02=global_ref_motino_anchor_rot_wxyz, ) return pos_diff @staticmethod def _get_obs_global_anchor_rot_diff( env: ManagerBasedRLEnv, robot_asset_name: str = "robot", ref_motion_command_name: str = "ref_motion", ref_prefix: str = "ref_", ): command = env.command_manager.get_term(ref_motion_command_name) env_ref_motion_anchor_pos = positions_world_to_env_frame( command.get_ref_motion_anchor_bodylink_global_pos_cur( prefix=ref_prefix ), env.scene.env_origins, ) 
global_ref_motino_anchor_rot_wxyz = ( command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur( prefix=ref_prefix ) ) global_robot_anchor_pos = ( ObservationFunctions._get_obs_global_robot_bodylink_pos( env, robot_asset_name, [command.anchor_bodylink_name] ).squeeze(1) ) global_robot_anchor_rot_wxyz = ( ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz( env, robot_asset_name, [command.anchor_bodylink_name] ).squeeze(1) ) _, rot_diff = isaaclab_math.subtract_frame_transforms( t01=global_robot_anchor_pos, q01=global_robot_anchor_rot_wxyz, t02=env_ref_motion_anchor_pos, q02=global_ref_motino_anchor_rot_wxyz, ) rot_diff_mat = isaaclab_math.matrix_from_quat(rot_diff) return rot_diff_mat[..., :2].reshape(env.num_envs, -1) @staticmethod def _get_obs_velocity_command( env: ManagerBasedRLEnv, ): """Velocity command. This function should return the velocity command which has already been serialized into flattened vectors. Note that we also add a model switch mask dimension, when commands are small, the mode is set to 0, otherwise it is set to 1. """ velocity_command = isaaclab_mdp.generated_commands( env, command_name="base_velocity", ) # Some IsaacLab velocity commands may append extra channels (e.g., heading). # For velocity-tracking PPO we only use (vx, vy, yaw_rate) to keep the # observation contract stable. 
if velocity_command.shape[-1] > 3: velocity_command = velocity_command[..., :3] move_mask = (velocity_command.norm(dim=-1) > 0.1).to( dtype=velocity_command.dtype ) return torch.cat( [ move_mask[..., None], velocity_command, ], dim=-1, ) # [num_envs, 4] @staticmethod def _get_obs_place_holder(env: ManagerBasedRLEnv, n_dim: int): return torch.zeros(env.num_envs, n_dim, device=env.device) @staticmethod def _get_obs_ref_headling_aligned_vel_cmd( env: ManagerBasedRLEnv, ref_prefix: str = "ref_" ): heading_aligned_lin_vel_xyz = ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_lin_vel( env, ref_prefix=ref_prefix ) heading_aligned_ang_vel_xyz = ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_ang_vel( env, ref_prefix=ref_prefix ) heading_aligned_vel_cmd = torch.cat( [ heading_aligned_lin_vel_xyz[:, :2], heading_aligned_ang_vel_xyz[:, 2:3], ], dim=-1, ) move_mask = (heading_aligned_vel_cmd.norm(dim=-1) > 0.1).to( dtype=heading_aligned_vel_cmd.dtype ) heading_aligned_vel_cmd = torch.cat( [ move_mask[..., None], heading_aligned_vel_cmd, ], dim=-1, ) return heading_aligned_vel_cmd @staticmethod def _get_obs_heading_aligned_root_ang_vel(env: ManagerBasedRLEnv): root_global_ang_vel = ( ObservationFunctions._get_obs_global_robot_root_ang_vel(env) ) root_global_rot_wxyz = ( ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env) ) heading_quat_wxyz = isaaclab_math.yaw_quat(root_global_rot_wxyz) heading_aligned_root_ang_vel = isaaclab_math.quat_apply_inverse( heading_quat_wxyz, root_global_ang_vel ) return heading_aligned_root_ang_vel @configclass class ObservationsCfg: pass def build_observations_config(obs_config_dict: dict): """Build isaaclab-compatible ObservationsCfg from a config dictionary.""" if isinstance(obs_config_dict, (DictConfig, ListConfig)): obs_config_dict = OmegaConf.to_container(obs_config_dict, resolve=True) obs_cfg = ObservationsCfg() obs_term_field_names = { field.name for field in dataclass_fields(ObservationTermCfg) } 
# Create observation groups dynamically for group_name, group_cfg in obs_config_dict.items(): group_cfg = resolve_holo_config(group_cfg) isaaclab_obs_group_cfg = ObsGroup() for key, value in group_cfg.items(): if key == "atomic_obs_list": continue if hasattr(isaaclab_obs_group_cfg, key): setattr(isaaclab_obs_group_cfg, key, value) # Add observation terms to the group for obs_term_dict in group_cfg["atomic_obs_list"]: for obs_name, obs_params in obs_term_dict.items(): obs_params = resolve_holo_config(obs_params) func_name = obs_params.get("func", obs_name) method_name = f"_get_obs_{func_name}" if hasattr(ObservationFunctions, method_name): func = getattr(ObservationFunctions, method_name) elif hasattr(isaaclab_mdp, func_name): func = getattr(isaaclab_mdp, func_name) else: raise ValueError( f"Unknown observation function: {func_name}" ) obs_term_kwargs = {"func": func} try: params_cfg = obs_params.get("params", {}) except AttributeError: print(f"No params found for {obs_name}") obs_term_kwargs["params"] = resolve_holo_config(params_cfg) noise_cfg = obs_params.get("noise") if noise_cfg is not None: obs_term_kwargs["noise"] = _build_noise_cfg(noise_cfg) for field_name in obs_term_field_names: if field_name in {"func", "params", "noise"}: continue if field_name in obs_params: obs_term_kwargs[field_name] = obs_params[field_name] obs_term = ObsTerm(**obs_term_kwargs) # Add observation term to group setattr(isaaclab_obs_group_cfg, obs_name, obs_term) # Add group to main observations config setattr(obs_cfg, group_name, isaaclab_obs_group_cfg) return obs_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_rewards.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

import torch
from isaaclab.assets import Articulation
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.managers import ManagerTermBase, RewardTermCfg, SceneEntityCfg
from isaaclab.sensors import ContactSensor
from isaaclab.utils import configclass
import isaaclab.utils.math as isaaclab_math

from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (
    RefMotionCommand,
)
from holomotion.src.utils.frame_utils import (
    positions_world_to_env_frame,
    root_relative_positions_from_env_frame,
    root_relative_positions_from_mixed_position_frames,
)
import isaaclab.envs.mdp as isaaclab_mdp
from hydra.utils import instantiate as hydra_instantiate
from omegaconf import DictConfig, ListConfig, OmegaConf
from loguru import logger
from holomotion.src.env.isaaclab_components.isaaclab_utils import (
    _get_body_indices,
    resolve_holo_config,
    _get_dof_indices,
)


def _joint_ids_to_tensor(
    joint_ids: slice | list[int] | tuple[int, ...] | torch.Tensor | None,
    num_joints: int,
    device: torch.device | str,
) -> torch.Tensor:
    """Convert a joint selector into a dense 1-D ``torch.long`` index tensor.

    Args:
        joint_ids: ``None`` or ``slice(None)`` selects every joint. Other
            slices are applied to ``arange(num_joints)``. Tensors are cast to
            long and flattened. Lists/tuples are converted in the given order.
        num_joints: Number of joints in the articulation.
        device: Device of the returned tensor.

    Returns:
        1-D ``torch.long`` tensor of joint indices on ``device``.
    """
    if joint_ids is None:
        return torch.arange(num_joints, device=device, dtype=torch.long)
    if isinstance(joint_ids, slice):
        if joint_ids == slice(None):
            return torch.arange(num_joints, device=device, dtype=torch.long)
        return torch.arange(num_joints, device=device, dtype=torch.long)[
            joint_ids
        ]
    if isinstance(joint_ids, torch.Tensor):
        return joint_ids.to(device=device, dtype=torch.long).flatten()
    return torch.tensor(joint_ids, device=device, dtype=torch.long)


def _select_effort_limit_vector(
    asset: Articulation,
    selected_joint_ids: torch.Tensor,
) -> torch.Tensor:
    """Gather per-joint effort limits from the asset's instantiated actuators.

    Each actuator's ``effort_limit`` may be a scalar (broadcast to its
    joints), a 1-D tensor (one value per actuated joint), or a 2-D tensor.
    A 2-D tensor must be identical across its first (per-env) axis, and its
    first row is used.

    Args:
        asset: Articulation whose ``actuators`` provide the limits.
        selected_joint_ids: Joint indices to return limits for.

    Returns:
        Effort limits for ``selected_joint_ids``, with the dtype and device of
        ``asset.data.applied_torque``.

    Raises:
        ValueError: If limits vary across envs, have an unsupported rank or a
            size mismatch, do not cover a selected joint, or are non-finite
            or non-positive.
    """
    num_joints = asset.data.applied_torque.shape[1]
    device = asset.data.applied_torque.device
    dtype = asset.data.applied_torque.dtype
    effort_limit_vec = torch.zeros(num_joints, device=device, dtype=dtype)
    # Tracks which joints received a limit from some actuator.
    is_filled = torch.zeros(num_joints, device=device, dtype=torch.bool)

    for actuator in asset.actuators.values():
        actuator_joint_ids = _joint_ids_to_tensor(
            actuator.joint_indices, num_joints=num_joints, device=device
        )
        actuator_effort_limit = torch.as_tensor(
            actuator.effort_limit, device=device, dtype=dtype
        )

        if actuator_effort_limit.ndim == 0:
            actuator_effort_limit = actuator_effort_limit.expand(
                actuator_joint_ids.numel()
            )
        elif actuator_effort_limit.ndim == 2:
            if actuator_effort_limit.shape[0] > 1:
                reference = actuator_effort_limit[0].unsqueeze(0)
                if not torch.allclose(
                    actuator_effort_limit,
                    reference.expand_as(actuator_effort_limit),
                ):
                    raise ValueError(
                        "normed_torque_rate requires actuator effort limits to be static across envs."
                    )
            actuator_effort_limit = actuator_effort_limit[0]
        elif actuator_effort_limit.ndim != 1:
            raise ValueError(
                "normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors."
            )

        if actuator_effort_limit.numel() != actuator_joint_ids.numel():
            raise ValueError(
                "normed_torque_rate found mismatched actuator joint indices and effort limits."
            )

        effort_limit_vec[actuator_joint_ids] = actuator_effort_limit
        is_filled[actuator_joint_ids] = True

    if not torch.all(is_filled[selected_joint_ids]):
        missing_joint_ids = selected_joint_ids[~is_filled[selected_joint_ids]]
        raise ValueError(
            "normed_torque_rate could not resolve actuator effort limits for "
            f"joint ids {missing_joint_ids.tolist()}."
        )

    selected_effort_limits = effort_limit_vec[selected_joint_ids]
    if not torch.all(torch.isfinite(selected_effort_limits)):
        raise ValueError(
            "normed_torque_rate requires finite actuator effort limits for all selected joints."
        )
    if not torch.all(selected_effort_limits > 0.0):
        raise ValueError(
            "normed_torque_rate requires strictly positive actuator effort limits for all selected joints."
        )
    return selected_effort_limits


def key_dof_position_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    key_dofs: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-sum_j (q_robot - q_ref)^2 / std^2) over the selected key DoFs.

    The same robot-derived indices select columns of the reference DoF
    positions. This assumes the reference uses the robot's joint order
    (TODO confirm).
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keydof_idxs = _get_dof_indices(command.robot, key_dofs)
    ref_dof_pos = command.get_ref_motion_dof_pos_immediate_next(
        prefix=ref_prefix
    )
    error = torch.sum(
        torch.square(
            command.robot.data.joint_pos[:, keydof_idxs]
            - ref_dof_pos[:, keydof_idxs]
        ),
        dim=-1,
    )
    return torch.exp(-error / std**2)


def key_dof_velocity_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    key_dofs: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-sum_j (qd_robot - qd_ref)^2 / std^2) over the selected key DoFs."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keydof_idxs = _get_dof_indices(command.robot, key_dofs)
    ref_dof_vel = command.get_ref_motion_dof_vel_immediate_next(
        prefix=ref_prefix
    )
    error = torch.sum(
        torch.square(
            command.robot.data.joint_vel[:, keydof_idxs]
            - ref_dof_vel[:, keydof_idxs]
        ),
        dim=-1,
    )
    return torch.exp(-error / std**2)


def motion_global_anchor_position_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-||p_ref_anchor - p_robot_anchor||^2 / std^2).

    The reference value is the anchor-bodylink global position. The robot
    value is ``command.global_robot_anchor_pos_cur``.
    """
    ref_motion_command: RefMotionCommand = env.command_manager.get_term(
        command_name
    )
    ref_anchor_pos = ref_motion_command.get_ref_motion_anchor_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )
    robot_anchor_pos = ref_motion_command.global_robot_anchor_pos_cur
    error = torch.sum(
        torch.square(ref_anchor_pos - robot_anchor_pos),
        dim=-1,
    )
    return torch.exp(-error / std**2)


def motion_global_anchor_orientation_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-angle(q_ref_anchor, q_robot_anchor)^2 / std^2) in world frame."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    ref_anchor_quat = (
        command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )
    error = (
        isaaclab_math.quat_error_magnitude(
            ref_anchor_quat,
            command.robot.data.body_quat_w[:, command.anchor_bodylink_idx],
        )
        ** 2
    )
    return torch.exp(-error / std**2)


def motion_relative_body_position_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Keybody position tracking after re-anchoring the reference on the robot.

    The reference keybodies are re-expressed around the robot anchor:

    * translation: robot anchor xy, reference anchor z;
    * rotation: yaw part of ``q_robot_anchor * q_ref_anchor^-1``.

    The result is ``exp(-mean_N ||p_ref' - p_robot||^2 / std^2)``.

    NOTE(review): the reference "anchor" is taken from the root getters,
    while the robot anchor uses ``command.anchor_bodylink_idx``. This is
    consistent only if the anchor body is the root — confirm.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # Get body indexes based on body names (similar to whole_body_tracking implementation)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Get reference and robot anchor positions/orientations
    ref_anchor_pos = command.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )  # [B, 3]
    ref_anchor_quat = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4] (w,x,y,z)
    robot_anchor_pos = command.robot.data.body_pos_w[
        :, command.anchor_bodylink_idx
    ]  # [B, 3]
    robot_anchor_quat = command.robot.data.body_quat_w[
        :, command.anchor_bodylink_idx
    ]  # [B, 4] (w,x,y,z)

    # Get reference body positions in global frame
    ref_body_pos_global = (
        command.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, num_bodies, 3]

    # Transform reference body positions to be relative to robot's current anchor
    # This follows the same logic as the whole_body_tracking implementation

    # Select relevant body indices first
    ref_body_pos_selected = ref_body_pos_global[
        :, keybody_idxs
    ]  # [B, selected_bodies, 3]

    # Expand anchor positions/orientations to match number of selected bodies
    num_bodies = len(keybody_idxs)
    ref_anchor_pos_exp = ref_anchor_pos[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 3]
    ref_anchor_quat_exp = ref_anchor_quat[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 4]
    robot_anchor_pos_exp = robot_anchor_pos[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 3]
    robot_anchor_quat_exp = robot_anchor_quat[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 4]

    # Create delta transformation (preserving z from reference, aligning xy to robot)
    delta_pos = robot_anchor_pos_exp.clone()
    delta_pos[..., 2] = ref_anchor_pos_exp[..., 2]  # Keep reference Z height

    delta_ori = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_anchor_quat_exp,
            isaaclab_math.quat_inv(ref_anchor_quat_exp),
        )
    )

    # Transform reference body positions to relative frame
    ref_body_pos_relative = delta_pos + isaaclab_math.quat_apply(
        delta_ori, ref_body_pos_selected - ref_anchor_pos_exp
    )

    # Get robot body positions
    robot_body_pos = command.robot.data.body_pos_w[:, keybody_idxs]

    # Compute error
    error = torch.sum(
        torch.square(ref_body_pos_relative - robot_body_pos),
        dim=-1,
    )

    return torch.exp(-error.mean(-1) / std**2)


def root_rel_keybodylink_pos_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody positions using environment-frame positions.

    IsaacLab MDP root position helpers are expressed in the environment
    frame (simulation-world position minus `env.scene.env_origins`). This
    reward converts body positions into the same environment frame before
    computing root-relative vectors.

    Returns:
        [B] reward ``exp(-mean_N ||r_ref - r_robot||^2 / std^2)``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # Get body indexes based on body names (similar to whole_body_tracking implementation)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Get reference and robot root positions/orientations
    ref_root_pos_env = positions_world_to_env_frame(
        command.get_ref_motion_root_global_pos_immediate_next(
            prefix=ref_prefix
        ),
        env.scene.env_origins,
    )  # [B, 3]
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4] (w,x,y,z)
    robot_root_pos_env = isaaclab_mdp.root_pos_w(env)  # [B, 3]
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4] (w,x,y,z)

    # Select relevant body indices first
    ref_body_pos_env = positions_world_to_env_frame(
        command.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs],
        env.scene.env_origins,
    )
    robot_body_pos_root_rel = (
        root_relative_positions_from_mixed_position_frames(
            body_pos_w=command.robot.data.body_pos_w[:, keybody_idxs],
            root_pos_env=robot_root_pos_env,
            root_quat_w=robot_root_quat_w,
            env_origins=env.scene.env_origins,
        )
    )
    ref_body_pos_root_rel = root_relative_positions_from_env_frame(
        body_pos_env=ref_body_pos_env,
        root_pos_env=ref_root_pos_env,
        root_quat_w=ref_root_quat_w,
    )

    # Compute error
    error = torch.sum(
        torch.square(ref_body_pos_root_rel - robot_body_pos_root_rel),
        dim=-1,
    )

    return torch.exp(-error.mean(-1) / std**2)


def motion_relative_body_orientation_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Keybody orientation tracking after yaw-aligning the reference to the robot.

    Reference keybody quaternions are pre-multiplied by the yaw part of
    ``q_robot_anchor * q_ref_root^-1``. The result is
    ``exp(-mean_N angle^2 / std^2)``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # Get body indexes based on body names (similar to whole_body_tracking implementation)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Get reference and robot anchor orientations
    ref_anchor_quat = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4] (w,x,y,z)
    robot_anchor_quat = command.robot.data.body_quat_w[
        :, command.anchor_bodylink_idx
    ]  # [B, 4] (w,x,y,z)

    # Get reference body orientations in global frame
    ref_body_quat_global = (
        command.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, num_bodies, 4]

    # Select relevant body indices
    ref_body_quat_selected = ref_body_quat_global[
        :, keybody_idxs
    ]  # [B, selected_bodies, 4]

    # Expand anchor orientations to match number of selected bodies
    num_bodies = len(keybody_idxs)
    ref_anchor_quat_exp = ref_anchor_quat[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 4]
    robot_anchor_quat_exp = robot_anchor_quat[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, num_bodies, 4]

    # Compute relative orientation transformation (only yaw component)
    delta_ori = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_anchor_quat_exp,
            isaaclab_math.quat_inv(ref_anchor_quat_exp),
        )
    )

    # Transform reference body orientations to relative frame
    ref_body_quat_relative = isaaclab_math.quat_mul(
        delta_ori, ref_body_quat_selected
    )

    # Get robot body orientations
    robot_body_quat = command.robot.data.body_quat_w[:, keybody_idxs]

    # Compute error
    error = (
        isaaclab_math.quat_error_magnitude(
            ref_body_quat_relative, robot_body_quat
        )
        ** 2
    )

    return torch.exp(-error.mean(-1) / std**2)


def motion_global_body_linear_velocity_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-mean_N ||v_ref - v_robot||^2 / std^2) on world-frame keybody
    linear velocities (same computation as
    ``global_keybodylink_lin_vel_tracking_l2_exp``)."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # Get body indexes based on body names (similar to whole_body_tracking implementation)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Direct comparison of global velocities (no coordinate transformation needed)
    ref_lin_vel = (
        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )
    robot_lin_vel = command.robot.data.body_lin_vel_w[:, keybody_idxs]

    error = torch.sum(torch.square(ref_lin_vel - robot_lin_vel), dim=-1)
    return torch.exp(-error.mean(-1) / std**2)


def motion_global_body_angular_velocity_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-mean_N ||w_ref - w_robot||^2 / std^2) on world-frame keybody
    angular velocities (same computation as
    ``global_keybodylink_ang_vel_tracking_l2_exp``)."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # Get body indexes based on body names (similar to whole_body_tracking implementation)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Direct comparison of global angular velocities (no coordinate transformation needed)
    ref_ang_vel = (
        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )
    robot_ang_vel = command.robot.data.body_ang_vel_w[:, keybody_idxs]

    error = torch.sum(torch.square(ref_ang_vel - robot_ang_vel), dim=-1)
    return torch.exp(-error.mean(-1) / std**2)


def root_pos_xy_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-||xy_ref_root - xy_robot_root||^2 / std^2).

    Compares ``robot.data.root_pos_w`` directly with the reference root
    position. No env-origin conversion is applied on either side.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    ref_root_pos = command.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )
    error = torch.sum(
        torch.square(
            ref_root_pos[:, :2] - command.robot.data.root_pos_w[:, :2]
        ),
        dim=-1,
    )
    return torch.exp(-error / std**2)


def root_rot_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-angle(q_ref_root, q_robot_root)^2 / std^2) in world frame."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    ref_root_quat = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )
    error = (
        isaaclab_math.quat_error_magnitude(
            ref_root_quat,
            isaaclab_mdp.root_quat_w(env),
        )
        ** 2
    )
    return torch.exp(-error / std**2)


def root_pos_rel_z_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """exp(-(z_robot_root - z_ref_root)^2 / std^2) on world root height."""
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    robot_root_z = command.robot.data.root_pos_w[:, 2]
    ref_root_z = command.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )[:, 2]
    dz_rel = robot_root_z - ref_root_z
    error = torch.square(dz_rel)
    return torch.exp(-error / std**2)


def root_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root linear velocity in each entity's own root frame.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # [B, 3], [B, 4]
    robot_root_lin_vel_w = isaaclab_mdp.root_lin_vel_w(env)
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)

    ref_root_lin_vel_w = (
        command.get_ref_motion_root_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )
    )
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )

    # Project to respective root frames
    robot_root_lin_vel = isaaclab_math.quat_apply(
        isaaclab_math.quat_inv(robot_root_quat_w),
        robot_root_lin_vel_w,
    )  # [B, 3]
    ref_root_lin_vel = isaaclab_math.quat_apply(
        isaaclab_math.quat_inv(ref_root_quat_w),
        ref_root_lin_vel_w,
    )  # [B, 3]

    error = torch.sum(
        torch.square(ref_root_lin_vel - robot_root_lin_vel), dim=-1
    )
    return torch.exp(-error / std**2)


def root_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root angular velocity in each entity's own root frame.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)

    # [B, 3], [B, 4]
    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)

    ref_root_ang_vel_w = (
        command.get_ref_motion_root_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )
    )
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )

    # Project to respective root frames
    robot_root_ang_vel = isaaclab_math.quat_apply(
        isaaclab_math.quat_inv(robot_root_quat_w),
        robot_root_ang_vel_w,
    )  # [B, 3]
    ref_root_ang_vel = isaaclab_math.quat_apply(
        isaaclab_math.quat_inv(ref_root_quat_w),
        ref_root_ang_vel_w,
    )  # [B, 3]

    error = torch.sum(
        torch.square(ref_root_ang_vel - robot_root_ang_vel), dim=-1
    )
    return torch.exp(-error / std**2)


def root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody positions after a yaw-only alignment.

    For both robot and reference, the root position is subtracted from each
    keybody position. All positions are first converted into IsaacLab's
    environment frame (simulation world minus `env.scene.env_origins`), so
    robot and reference use the same translation convention.

    The reference vectors are then rotated by the yaw part of
    ``q_robot_root * q_ref_root^-1``. They are compared with the robot
    vectors in world-aligned axes; neither side is rotated into its own
    root frame. Only the yaw difference between the roots is compensated.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Root states in environment frame
    ref_root_pos = positions_world_to_env_frame(
        command.get_ref_motion_root_global_pos_immediate_next(
            prefix=ref_prefix
        ),
        env.scene.env_origins,
    )  # [B, 3]
    ref_root_quat = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]
    robot_root_pos = isaaclab_mdp.root_pos_w(env)  # [B, 3]
    robot_root_quat = isaaclab_mdp.root_quat_w(env)  # [B, 4]

    # Body positions in environment frame
    robot_body_pos = positions_world_to_env_frame(
        command.robot.data.body_pos_w[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]
    ref_body_pos = positions_world_to_env_frame(
        command.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]

    # Expand for broadcasting
    num_bodies = len(keybody_idxs)
    ref_root_pos_exp = ref_root_pos[:, None, :].expand(-1, num_bodies, -1)
    ref_root_quat_exp = ref_root_quat[:, None, :].expand(-1, num_bodies, -1)
    robot_root_pos_exp = robot_root_pos[:, None, :].expand(-1, num_bodies, -1)
    robot_root_quat_exp = robot_root_quat[:, None, :].expand(
        -1, num_bodies, -1
    )

    # Yaw-only delta orientation (root frames)
    delta_ori = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_root_quat_exp, isaaclab_math.quat_inv(ref_root_quat_exp)
        )
    )  # [B, N, 4]

    # Keep origin at root: compare root-relative vectors after yaw alignment
    robot_rel = robot_body_pos - robot_root_pos_exp  # [B, N, 3]
    ref_rel = ref_body_pos - ref_root_pos_exp  # [B, N, 3]
    ref_rel_aligned = isaaclab_math.quat_apply(delta_ori, ref_rel)  # [B, N, 3]

    # Compare in world (root-relative)
    error = torch.sum(
        torch.square(ref_rel_aligned - robot_rel), dim=-1
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


def root_rel_keybodylink_rot_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody rotations in each entity's root frame.

    Computes ``q_rel = q_root^-1 * q_body`` for robot and reference, then
    ``exp(-mean_N angle(q_rel_ref, q_rel_robot)^2 / std^2)``.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Root orientations
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]

    # Body orientations (world)
    robot_body_quat_w = command.robot.data.body_quat_w[
        :, keybody_idxs
    ]  # [B, N, 4]
    ref_body_quat_w = (
        command.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )  # [B, N, 4]

    # Relative (q_rel = q_root^{-1} * q_body)
    num_bodies = len(keybody_idxs)
    robot_root_quat_inv_exp = isaaclab_math.quat_inv(robot_root_quat_w)[
        :, None, :
    ].expand(-1, num_bodies, -1)
    ref_root_quat_inv_exp = isaaclab_math.quat_inv(ref_root_quat_w)[
        :, None, :
    ].expand(-1, num_bodies, -1)
    robot_rel_quat = isaaclab_math.quat_mul(
        robot_root_quat_inv_exp,
        robot_body_quat_w,
    )  # [B, N, 4]
    ref_rel_quat = isaaclab_math.quat_mul(
        ref_root_quat_inv_exp,
        ref_body_quat_w,
    )  # [B, N, 4]

    error = (
        isaaclab_math.quat_error_magnitude(ref_rel_quat, robot_rel_quat) ** 2
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


def root_rel_keybodylink_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track keybody linear velocities with motion_relative frame alignment.

    Compute rigid-body-relative velocities for both entities w.r.t. their
    roots (``v_body - v_root - w_root x r``), yaw-align reference to robot
    using root quats, then compare in world space.

    Root/body positions used for rigid-body radius vectors are first
    converted into IsaacLab's environment frame (simulation world minus
    `env.scene.env_origins`) so the translation convention matches
    `isaaclab_mdp.root_pos_w(env)`.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Root states
    robot_root_pos_w = isaaclab_mdp.root_pos_w(env)  # [B, 3]
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]
    robot_root_lin_vel_w = isaaclab_mdp.root_lin_vel_w(env)  # [B, 3]
    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]

    ref_root_pos_w = positions_world_to_env_frame(
        command.get_ref_motion_root_global_pos_immediate_next(
            prefix=ref_prefix
        ),
        env.scene.env_origins,
    )  # [B, 3]
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]
    ref_root_lin_vel_w = (
        command.get_ref_motion_root_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 3]
    ref_root_ang_vel_w = (
        command.get_ref_motion_root_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 3]

    # Body states (world)
    robot_body_pos_w = positions_world_to_env_frame(
        command.robot.data.body_pos_w[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]
    robot_body_lin_vel_w = command.robot.data.body_lin_vel_w[
        :, keybody_idxs
    ]  # [B, N, 3]
    ref_body_pos_w = positions_world_to_env_frame(
        command.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]
    ref_body_lin_vel_w = (
        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )  # [B, N, 3]

    # Rigid-body relative (world): remove the velocity each body would have
    # if it were rigidly attached to its root.
    robot_r_w = robot_body_pos_w - robot_root_pos_w[:, None, :]
    ref_r_w = ref_body_pos_w - ref_root_pos_w[:, None, :]
    robot_cross = torch.cross(
        robot_root_ang_vel_w[:, None, :], robot_r_w, dim=-1
    )  # [B, N, 3]
    ref_cross = torch.cross(
        ref_root_ang_vel_w[:, None, :], ref_r_w, dim=-1
    )  # [B, N, 3]
    robot_v_rel_w = (
        robot_body_lin_vel_w - robot_root_lin_vel_w[:, None, :] - robot_cross
    )  # [B, N, 3]
    ref_v_rel_w = (
        ref_body_lin_vel_w - ref_root_lin_vel_w[:, None, :] - ref_cross
    )  # [B, N, 3]

    # Yaw-only delta orientation from root quats; rotate reference velocities
    num_bodies = len(keybody_idxs)
    robot_root_quat_exp = robot_root_quat_w[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, N, 4]
    ref_root_quat_exp = ref_root_quat_w[:, None, :].expand(
        -1, num_bodies, -1
    )  # [B, N, 4]
    delta_ori = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_root_quat_exp, isaaclab_math.quat_inv(ref_root_quat_exp)
        )
    )  # [B, N, 4]
    ref_v_rel_aligned_w = isaaclab_math.quat_apply(delta_ori, ref_v_rel_w)

    error = torch.sum(
        torch.square(ref_v_rel_aligned_w - robot_v_rel_w), dim=-1
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


def root_rel_keybodylink_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody angular velocities in root frames.

    Uses w_rel_w = w_body - w_root, then rotates into each entity's root
    frame.

    Returns:
        [B]
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    # Root orientations and angular velocities
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4]
    robot_root_ang_vel_w = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]
    ref_root_quat_w = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]
    ref_root_ang_vel_w = (
        command.get_ref_motion_root_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 3]

    # Body angular velocities (world)
    robot_body_ang_vel_w = command.robot.data.body_ang_vel_w[
        :, keybody_idxs
    ]  # [B, N, 3]
    ref_body_ang_vel_w = (
        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )  # [B, N, 3]

    # Relative (world), then rotate
    robot_w_rel_w = robot_body_ang_vel_w - robot_root_ang_vel_w[:, None, :]
    ref_w_rel_w = ref_body_ang_vel_w - ref_root_ang_vel_w[:, None, :]

    num_bodies = len(keybody_idxs)
    robot_root_quat_inv_exp = isaaclab_math.quat_inv(robot_root_quat_w)[
        :, None, :
    ].expand(-1, num_bodies, -1)
    ref_root_quat_inv_exp = isaaclab_math.quat_inv(ref_root_quat_w)[
        :, None, :
    ].expand(-1, num_bodies, -1)

    robot_w_rel = isaaclab_math.quat_apply(
        robot_root_quat_inv_exp,
        robot_w_rel_w,
    )  # [B, N, 3]
    ref_w_rel = isaaclab_math.quat_apply(
        ref_root_quat_inv_exp,
        ref_w_rel_w,
    )  # [B, N, 3]

    error = torch.sum(torch.square(ref_w_rel - robot_w_rel), dim=-1)  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


def global_keybodylink_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track global keybody linear velocities.

    Returns:
        [B] reward ``exp(-mean_N ||v_ref - v_robot||^2 / std^2)``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    ref_global_keybody_lin_vel = (
        command.get_ref_motion_bodylink_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )  # [B, N, 3]
    robot_keybody_lin_vel = command.robot.data.body_lin_vel_w[
        :, keybody_idxs
    ]  # [B, N, 3]

    error = torch.sum(
        torch.square(ref_global_keybody_lin_vel - robot_keybody_lin_vel),
        dim=-1,
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


def global_keybodylink_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track global keybody angular velocities.

    Returns:
        [B] reward ``exp(-mean_N ||w_ref - w_robot||^2 / std^2)``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)

    ref_global_keybody_ang_vel = (
        command.get_ref_motion_bodylink_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs]
    )  # [B, N, 3]
    robot_keybody_ang_vel = command.robot.data.body_ang_vel_w[
        :, keybody_idxs
    ]  # [B, N, 3]

    error = torch.sum(
        torch.square(ref_global_keybody_ang_vel - robot_keybody_ang_vel),
        dim=-1,
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)


# @torch.compile
def feet_contact_time(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    threshold: float,
) -> torch.Tensor:
    """Count feet that just left the ground after a contact shorter than
    ``threshold`` seconds.

    Per selected sensor body, this is ``compute_first_air(...)`` multiplied by
    ``last_contact_time < threshold``, summed over bodies. Whether the term
    acts as a reward or a penalty depends on the sign of its configured
    weight.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    first_air = contact_sensor.compute_first_air(env.step_dt, env.physics_dt)[
        :, sensor_cfg.body_ids
    ]
    last_contact_time = contact_sensor.data.last_contact_time[
        :, sensor_cfg.body_ids
    ]
    reward = torch.sum((last_contact_time < threshold) * first_air, dim=-1)
    return reward


def track_lin_vel_xy_yaw_frame_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Track linear velocity (xy) in the gravity-aligned yaw frame using exponential kernel.

    This mirrors the implementation in IsaacLab locomotion velocity MDP.
    The first two columns of the command are compared against the root
    linear velocity rotated by the inverse yaw-only root orientation.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    vel_yaw = isaaclab_math.quat_apply_inverse(
        isaaclab_math.yaw_quat(asset.data.root_quat_w),
        asset.data.root_lin_vel_w[:, :3],
    )
    lin_vel_error = torch.sum(
        torch.square(
            env.command_manager.get_command(command_name)[:, :2]
            - vel_yaw[:, :2]
        ),
        dim=1,
    )
    return torch.exp(-lin_vel_error / (std**2))


def feet_slide(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize feet sliding when in contact using contact forces and foot linear velocity.

    A foot counts as in contact if its net contact-force norm exceeded 1.0 at
    any step of the sensor history. For those feet, the world-frame xy speed
    is summed.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    contacts = (
        contact_sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]
        .norm(dim=-1)
        .max(dim=1)[0]
        > 1.0
    )
    asset: Articulation = env.scene[asset_cfg.name]
    body_vel = asset.data.body_lin_vel_w[:, asset_cfg.body_ids, :2]
    reward = torch.sum(body_vel.norm(dim=-1) * contacts, dim=1)
    return reward


def feet_slide_ang_vel(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize feet spinning about the world z axis while in contact.

    Contact uses the same rule as ``feet_slide``: the net-force norm exceeded
    1.0 anywhere in the sensor history. For those feet, ``|w_z|`` of the
    world-frame body angular velocity is summed.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    contacts = (
        contact_sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]
        .norm(dim=-1)
        .max(dim=1)[0]
        > 1.0
    )
    asset: Articulation = env.scene[asset_cfg.name]
    # Keep only the z component (slice keeps a trailing dim of size 1).
    body_ang_vel = asset.data.body_ang_vel_w[:, asset_cfg.body_ids, 2:3]
    reward = torch.sum(body_ang_vel.norm(dim=-1) * contacts, dim=1)
    return reward


def foot_clearance_reward(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
    target_height: float,
    std: float,
    tanh_mult: float,
    sensor_cfg: SceneEntityCfg,
) -> torch.Tensor:
    """Reward swinging feet that reach a target clearance height.

    The per-foot height term is ``exp(-max(target_height - z, 0)^2 / std^2)``,
    where ``z`` is the world-frame foot height. Any foot at or above
    ``target_height`` therefore receives the full height term. The term is
    multiplied by:

    * ``tanh(tanh_mult * |v_xy|)`` of the foot's horizontal velocity;
    * a swing mask (sensor ``current_contact_time == 0``).

    The result is summed over feet.

    NOTE(review): ``asset_cfg.body_ids`` and ``sensor_cfg.body_ids`` are
    assumed to list the same feet in the same order — confirm in config.

    Returns:
        [B] reward.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    foot_z = asset.data.body_pos_w[:, asset_cfg.body_ids, 2]  # [B, N]
    # Only penalize feet below the target; feet above it incur no error.
    delta_z = torch.clamp(target_height - foot_z, min=0.0)
    foot_z_error = torch.square(delta_z)  # [B, N]

    # A foot is swinging when its contact sensor reports no current contact.
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    is_contact = (
        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0
    )  # [B, N]
    is_swinging = ~is_contact

    # Gate reward by horizontal velocity to ensure feet are actually moving
    foot_horizontal_vel = torch.norm(
        asset.data.body_lin_vel_w[:, asset_cfg.body_ids, :2], dim=2
    )  # [B, N]
    velocity_gate = torch.tanh(tanh_mult * foot_horizontal_vel)  # [B, N]

    # Reward: high when the foot clears the target height, is moving
    # horizontally, and is swinging
    reward_per_foot = (
        torch.exp(-foot_z_error / std**2) * velocity_gate * is_swinging.float()
    )
    return torch.sum(reward_per_foot, dim=1)


def feet_gait(
    env: ManagerBasedRLEnv,
    period: float,
    offset: list[float],
    sensor_cfg: SceneEntityCfg,
    threshold: float = 0.5,
    command_name=None,
) -> torch.Tensor:
    """Reward agreement between a periodic stance schedule and foot contact.

    Foot ``i`` has phase ``(t / period + offset[i]) mod 1``, where
    ``t = episode_length_buf * step_dt``. It is scheduled as stance when the
    phase is below ``threshold``. Each foot whose scheduled stance matches
    its sensor contact state (``current_contact_time > 0``) adds 1.

    If ``command_name`` is given, the reward is zeroed wherever the command
    norm is <= 0.1. ``offset`` needs one entry per sensor body.

    Returns:
        [B] reward (float).
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    is_contact = (
        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0
    )

    global_phase = (
        (env.episode_length_buf * env.step_dt) % period / period
    ).unsqueeze(1)
    phases = []
    for offset_ in offset:
        phase = (global_phase + offset_) % 1.0
        phases.append(phase)
    leg_phase = torch.cat(phases, dim=-1)

    reward = torch.zeros(env.num_envs, dtype=torch.float, device=env.device)
    for i in range(len(sensor_cfg.body_ids)):
        is_stance = leg_phase[:, i] < threshold
        # XNOR: 1 when schedule and measured contact agree.
        reward += ~(is_stance ^ is_contact[:, i])

    if command_name is not None:
        cmd_norm = torch.norm(
            env.command_manager.get_command(command_name), dim=1
        )
        reward *= cmd_norm > 0.1
    return reward
joint_deviation_l1_arms = isaaclab_mdp.joint_deviation_l1 joint_deviation_l1_arms_roll = isaaclab_mdp.joint_deviation_l1 joint_deviation_l1_waists = isaaclab_mdp.joint_deviation_l1 joint_deviation_l1_legs = isaaclab_mdp.joint_deviation_l1 joint_deviation_l1_legs_yaw = isaaclab_mdp.joint_deviation_l1 joint_deviation_l1_stand_still = isaaclab_mdp.joint_deviation_l1 def joint_deviation_l2( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """Penalize joint positions that deviate from the default one.""" # extract the used quantities (to enable type-hinting) asset: Articulation = env.scene[asset_cfg.name] # compute out of limits constraints angle = ( asset.data.joint_pos[:, asset_cfg.joint_ids] - asset.data.default_joint_pos[:, asset_cfg.joint_ids] ) return torch.sum(torch.square(angle), dim=1) joint_deviation_l2_arms_roll = joint_deviation_l2 joint_deviation_l2_arms = joint_deviation_l2 joint_deviation_l2_waists = joint_deviation_l2 joint_deviation_l2_legs = joint_deviation_l2 joint_deviation_l2_shoulder_roll = joint_deviation_l2 joint_deviation_l2_hip_roll = joint_deviation_l2 def energy( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """Penalize the energy used by the robot's joints.""" asset: Articulation = env.scene[asset_cfg.name] qvel = asset.data.joint_vel[:, asset_cfg.joint_ids] qfrc = asset.data.applied_torque[:, asset_cfg.joint_ids] return torch.sum(torch.abs(qvel) * torch.abs(qfrc), dim=-1) def positive_work( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """Penalize only the positive mechanical work (energy injected) by the joints.""" asset: Articulation = env.scene[asset_cfg.name] qvel = asset.data.joint_vel[:, asset_cfg.joint_ids] qfrc = asset.data.applied_torque[:, asset_cfg.joint_ids] # Calculate raw mechanical power (positive = motoring, negative = braking) power = qfrc * qvel # Only keep positive values, zero 
out negative (braking) work positive_power = torch.relu(power) return torch.sum(positive_power, dim=-1) class normed_positive_work(ManagerTermBase): """Penalize positive joint work normalized by effort and velocity limits.""" def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv): super().__init__(cfg, env) self._asset_name: str | None = None self._joint_ids: torch.Tensor | None = None self._inv_effort_limit: torch.Tensor | None = None def _maybe_build_cache( self, env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg, ) -> Articulation: asset: Articulation = env.scene[asset_cfg.name] joint_ids = _joint_ids_to_tensor( getattr(asset_cfg, "joint_ids", None), num_joints=asset.data.applied_torque.shape[1], device=asset.data.applied_torque.device, ) cache_needs_refresh = ( self._asset_name != asset_cfg.name or self._joint_ids is None or not torch.equal(self._joint_ids, joint_ids) or self._inv_effort_limit is None or self._inv_effort_limit.shape != (joint_ids.numel(),) or self._inv_effort_limit.device != asset.data.applied_torque.device or self._inv_effort_limit.dtype != asset.data.applied_torque.dtype ) if not cache_needs_refresh: return asset effort_limit = _select_effort_limit_vector(asset, joint_ids) self._asset_name = asset_cfg.name self._joint_ids = joint_ids self._inv_effort_limit = effort_limit.reciprocal() return asset def __call__( self, env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: asset = self._maybe_build_cache(env, asset_cfg) joint_ids = self._joint_ids inv_effort_limit = self._inv_effort_limit assert joint_ids is not None assert inv_effort_limit is not None current_torque = asset.data.applied_torque[:, joint_ids] current_joint_vel = asset.data.joint_vel[:, joint_ids] joint_vel_limits = asset.data.joint_vel_limits[:, joint_ids] if not torch.all(torch.isfinite(joint_vel_limits)) or not torch.all( joint_vel_limits > 0.0 ): raise ValueError( "normed_positive_work requires finite, strictly positive " "joint 
velocity limits for all selected joints." ) normalized_power = (current_torque * inv_effort_limit) * ( current_joint_vel / joint_vel_limits ) return torch.sum(torch.relu(normalized_power), dim=-1) class normed_torque_rate(ManagerTermBase): """Penalize joint torque-rate changes normalized by actuator effort limits.""" def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv): super().__init__(cfg, env) self._asset_name: str | None = None self._joint_ids: torch.Tensor | None = None self._inv_effort_limit: torch.Tensor | None = None self._prev_applied_torque: torch.Tensor | None = None self._needs_reseed = torch.ones( self.num_envs, device=self.device, dtype=torch.bool ) def reset(self, env_ids=None) -> None: if env_ids is None: self._needs_reseed[:] = True return if isinstance(env_ids, slice): self._needs_reseed[env_ids] = True return env_ids_tensor = torch.as_tensor( env_ids, device=self.device, dtype=torch.long ) self._needs_reseed[env_ids_tensor] = True def _maybe_build_cache( self, env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg, ) -> Articulation: asset: Articulation = env.scene[asset_cfg.name] joint_ids = _joint_ids_to_tensor( getattr(asset_cfg, "joint_ids", None), num_joints=asset.data.applied_torque.shape[1], device=asset.data.applied_torque.device, ) cache_needs_refresh = ( self._asset_name != asset_cfg.name or self._joint_ids is None or not torch.equal(self._joint_ids, joint_ids) or self._prev_applied_torque is None or self._prev_applied_torque.shape != (env.num_envs, joint_ids.numel()) or self._prev_applied_torque.device != asset.data.applied_torque.device or self._prev_applied_torque.dtype != asset.data.applied_torque.dtype ) if not cache_needs_refresh: return asset effort_limit = _select_effort_limit_vector(asset, joint_ids) self._asset_name = asset_cfg.name self._joint_ids = joint_ids self._inv_effort_limit = effort_limit.reciprocal() self._prev_applied_torque = torch.zeros( env.num_envs, joint_ids.numel(), 
device=asset.data.applied_torque.device, dtype=asset.data.applied_torque.dtype, ) self._needs_reseed = torch.ones( env.num_envs, device=asset.data.applied_torque.device, dtype=torch.bool, ) return asset def __call__( self, env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: asset = self._maybe_build_cache(env, asset_cfg) joint_ids = self._joint_ids inv_effort_limit = self._inv_effort_limit prev_applied_torque = self._prev_applied_torque assert joint_ids is not None assert inv_effort_limit is not None assert prev_applied_torque is not None current_torque = asset.data.applied_torque[:, joint_ids] reward = torch.zeros( env.num_envs, device=current_torque.device, dtype=current_torque.dtype, ) reseed_mask = self._needs_reseed.clone() if hasattr(env, "episode_length_buf"): reseed_mask |= env.episode_length_buf == 0 active_mask = ~reseed_mask if torch.any(active_mask): delta = ( current_torque[active_mask] - prev_applied_torque[active_mask] ) * inv_effort_limit reward[active_mask] = torch.sum(delta.square(), dim=1) prev_applied_torque.copy_(current_torque) self._needs_reseed[reseed_mask] = False return reward def track_stand_still_exp( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """Track stand still joint position using exponential kernel when command velocity is low. Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] error = torch.sum( torch.square(asset.data.joint_pos - asset.data.default_joint_pos), dim=1, ) # Use generated velocity commands (vx, vy, yaw_rate). Some command terms may # expose additional channels (e.g., heading) via get_command(). 
cmd = isaaclab_mdp.generated_commands(env, command_name=command_name) if cmd.shape[-1] > 3: cmd = cmd[..., :3] cmd_norm = torch.norm(cmd, dim=1) return torch.exp(-error / std**2) * (cmd_norm < 0.1) def stand_still_joint_deviation_l1( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: """Penalize L1 joint deviation from default pose when command velocity is low. Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] # L1 error: sum(|q - q_default|) error = torch.sum( torch.abs(asset.data.joint_pos - asset.data.default_joint_pos), dim=1, ) cmd = isaaclab_mdp.generated_commands(env, command_name=command_name) if cmd.shape[-1] > 3: cmd = cmd[..., :3] cmd_norm = torch.norm(cmd, dim=1) # Return error (to be penalized with negative weight) only when standing still return error * (cmd_norm < 0.1) def stand_still_action_rate( env: ManagerBasedRLEnv, command_name: str = "base_velocity", ) -> torch.Tensor: cmd = isaaclab_mdp.generated_commands(env, command_name=command_name) if cmd.shape[-1] > 3: cmd = cmd[..., :3] stand_still = torch.norm(cmd, dim=1) < 0.1 return ( torch.sum( torch.square( env.action_manager.action - env.action_manager.prev_action ), dim=1, ) * stand_still ) def stand_still_dof_vel_l2( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: cmd = isaaclab_mdp.generated_commands(env, command_name=command_name) if cmd.shape[-1] > 3: cmd = cmd[..., :3] stand_still = torch.norm(cmd, dim=1) < 0.1 return ( torch.sum( torch.square(env.scene[asset_cfg.name].data.joint_vel), dim=1, ) * stand_still ) class action_acc(ManagerTermBase): """Penalize the change in action-rate using a stateful second difference.""" def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv): super().__init__(cfg, env) self._prev_action: torch.Tensor | None = None self._prev_action_rate: torch.Tensor | None = None 
self._needs_reseed = torch.ones( self.num_envs, device=self.device, dtype=torch.bool ) self._needs_prev_rate = torch.ones( self.num_envs, device=self.device, dtype=torch.bool ) def reset(self, env_ids=None) -> None: if env_ids is None: self._needs_reseed[:] = True self._needs_prev_rate[:] = True return if isinstance(env_ids, slice): self._needs_reseed[env_ids] = True self._needs_prev_rate[env_ids] = True return env_ids_tensor = torch.as_tensor( env_ids, device=self.device, dtype=torch.long ) self._needs_reseed[env_ids_tensor] = True self._needs_prev_rate[env_ids_tensor] = True def _maybe_build_cache( self, env: ManagerBasedRLEnv ) -> tuple[torch.Tensor, torch.Tensor]: current_action = env.action_manager.action cache_needs_refresh = ( self._prev_action is None or self._prev_action_rate is None or self._prev_action.shape != current_action.shape or self._prev_action.device != current_action.device or self._prev_action.dtype != current_action.dtype or self._prev_action_rate.shape != current_action.shape or self._prev_action_rate.device != current_action.device or self._prev_action_rate.dtype != current_action.dtype ) if cache_needs_refresh: self._prev_action = torch.zeros_like(current_action) self._prev_action_rate = torch.zeros_like(current_action) self._needs_reseed = torch.ones( env.num_envs, device=current_action.device, dtype=torch.bool, ) self._needs_prev_rate = torch.ones( env.num_envs, device=current_action.device, dtype=torch.bool, ) assert self._prev_action is not None assert self._prev_action_rate is not None return self._prev_action, self._prev_action_rate def __call__(self, env: ManagerBasedRLEnv) -> torch.Tensor: current_action = env.action_manager.action prev_action, prev_action_rate = self._maybe_build_cache(env) reward = torch.zeros( env.num_envs, device=current_action.device, dtype=current_action.dtype, ) reseed_mask = self._needs_reseed.clone() if hasattr(env, "episode_length_buf"): reseed_mask |= env.episode_length_buf == 0 if 
torch.any(reseed_mask): prev_action[reseed_mask] = current_action[reseed_mask] prev_action_rate[reseed_mask].zero_() self._needs_prev_rate[reseed_mask] = True active_mask = ~reseed_mask if torch.any(active_mask): current_action_rate = ( current_action[active_mask] - prev_action[active_mask] ) ready_mask = ~self._needs_prev_rate[active_mask] if torch.any(ready_mask): action_acc_value = ( current_action_rate[ready_mask] - prev_action_rate[active_mask][ready_mask] ) reward[ active_mask.nonzero(as_tuple=False).flatten()[ready_mask] ] = torch.sum(action_acc_value.square(), dim=1) prev_action[active_mask] = current_action[active_mask] prev_action_rate[active_mask] = current_action_rate self._needs_prev_rate[active_mask] = False self._needs_reseed[reseed_mask] = False return reward action_acc_l2 = action_acc def feet_stumble( env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg ) -> torch.Tensor: # extract the used quantities (to enable type-hinting) contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] forces_z = torch.abs( contact_sensor.data.net_forces_w[:, sensor_cfg.body_ids, 2] ) forces_xy = torch.linalg.norm( contact_sensor.data.net_forces_w[:, sensor_cfg.body_ids, :2], dim=2 ) # Penalize feet hitting vertical surfaces reward = torch.any(forces_xy > 4 * forces_z, dim=1).float() return reward def feet_too_near( env: ManagerBasedRLEnv, threshold: float = 0.2, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: asset: Articulation = env.scene[asset_cfg.name] feet_pos = asset.data.body_pos_w[:, asset_cfg.body_ids, :] distance = torch.norm(feet_pos[:, 0] - feet_pos[:, 1], dim=-1) return (threshold - distance).clamp(min=0) def feet_contact_without_cmd( env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg, command_name: str = "base_velocity", ) -> torch.Tensor: """ Reward for feet contact when the command is zero. 
""" # asset: Articulation = env.scene[asset_cfg.name] contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] is_contact = ( contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0 ) cmd = isaaclab_mdp.generated_commands(env, command_name=command_name) if cmd.shape[-1] > 3: cmd = cmd[..., :3] command_norm = torch.norm(cmd, dim=1) reward = torch.sum(is_contact, dim=-1).float() return reward * (command_norm < 0.1) def torso_xy_ang_vel_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") # World-frame torso angular velocity: [B, 3] torso_ang_vel_w: torch.Tensor = robot_ptr.data.body_ang_vel_w[ :, torso_idx, : ] # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading. # Build yaw-only quaternion from stored heading_w (shape [B]). heading_yaw: torch.Tensor = robot_ptr.data.heading_w # [B] zero = torch.zeros_like(heading_yaw, device=env.device) heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz( roll=zero, pitch=zero, yaw=heading_yaw, ) # [B, 4] # Re-express torso angular velocity in heading-aligned frame. heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz) torso_ang_vel_h: torch.Tensor = isaaclab_math.quat_apply( heading_inv_wxyz, torso_ang_vel_w, ) # [B, 3] # Penalize lateral components (x, y) with squared magnitude. penalty: torch.Tensor = torch.sum( torch.square(torso_ang_vel_h[:, :2]), dim=-1, ) # [B] return penalty def torso_upright_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_rot_quat_w = robot_ptr.data.body_quat_w[:, torso_idx, :] # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading. # Build yaw-only quaternion from stored heading_w (shape [B]). 
heading_yaw: torch.Tensor = robot_ptr.data.heading_w # [B] zero = torch.zeros_like(heading_yaw, device=env.device) heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz( roll=zero, pitch=zero, yaw=heading_yaw, ) # [B, 4] # Re-express torso angular velocity in heading-aligned frame. heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz) torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul( heading_inv_wxyz, torso_rot_quat_w, ) # [B, 3] # get the roll and pitch roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h) pitch *= pitch > 0.0 rollpitch = torch.stack([roll * 2.0, pitch], dim=-1) # Penalize lateral components (x, y) with squared magnitude. penalty: torch.Tensor = torch.sum( torch.square(rollpitch), dim=-1, ) # [B] return penalty def torso_upright_l2_penalty_v2( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), target_pitch: float = 0.0, roll_scale: float = 2.0, pitch_scale: float = 1.0, ) -> torch.Tensor: """Penalize torso roll/pitch deviation in a heading-aligned frame (symmetric). Compared to `torso_upright_l2_penalty`, this version penalizes *both* forward and backward pitch w.r.t. `target_pitch`. 
Returns: [B] """ robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_rot_quat_w: torch.Tensor = robot_ptr.data.body_quat_w[ :, torso_idx, : ] # [B, 4] heading_yaw: torch.Tensor = robot_ptr.data.heading_w # [B] zero = torch.zeros_like(heading_yaw, device=env.device) heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz( roll=zero, pitch=zero, yaw=heading_yaw, ) # [B, 4] heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz) torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul( heading_inv_wxyz, torso_rot_quat_w, ) # [B, 4] roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h) # [B] roll_err: torch.Tensor = roll_scale * roll pitch_err: torch.Tensor = pitch_scale * (pitch - target_pitch) roll_pitch = torch.stack([roll_err, pitch_err], dim=-1) # [B, 2] penalty: torch.Tensor = torch.sum(torch.square(roll_pitch), dim=-1) # [B] return penalty def stand_still_torso_upright_exp_v2( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", cmd_threshold: float = 0.1, target_pitch: float = 0.0, roll_scale: float = 2.0, pitch_scale: float = 1.0, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """Reward torso uprightness under stand-still commands using an exp kernel. Reward: exp(-penalty / std^2) if ||cmd|| <= cmd_threshold else 0 where penalty is computed by `torso_upright_l2_penalty_v2`. 
Returns: [B] """ command = env.command_manager.get_command(command_name) stand_still_flag: torch.Tensor = ( torch.norm(command, dim=1) <= cmd_threshold ) penalty = torso_upright_l2_penalty_v2( env, asset_cfg=asset_cfg, target_pitch=target_pitch, roll_scale=roll_scale, pitch_scale=pitch_scale, ) # [B] reward = torch.exp(-penalty / std**2) # [B] return reward * stand_still_flag.to(dtype=reward.dtype) def torso_linacc_xy_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") # World-frame torso angular velocity: [B, 3] torso_linacc_w = robot_ptr.data.body_lin_acc_w[:, torso_idx, :] # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading. # Build yaw-only quaternion from stored heading_w (shape [B]). heading_yaw: torch.Tensor = robot_ptr.data.heading_w # [B] zero = torch.zeros_like(heading_yaw, device=env.device) heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz( roll=zero, pitch=zero, yaw=heading_yaw, ) # [B, 4] # Re-express torso angular velocity in heading-aligned frame. heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz) torso_linacc_h: torch.Tensor = isaaclab_math.quat_apply( heading_inv_wxyz, torso_linacc_w, ) # [B, 3] # Penalize lateral components (x, y) with squared magnitude. penalty: torch.Tensor = torch.sum( torch.square(torso_linacc_h), dim=-1, ) # [B] return penalty def track_lin_vel_xy_heading_aligned_frame_exp( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """ Track linear velocity (xy) in the heading-aligned frame using exponential kernel. 
Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] vel_yaw = isaaclab_math.quat_apply_inverse( isaaclab_math.yaw_quat(asset.data.root_quat_w), asset.data.root_lin_vel_w[:, :3], ) command = env.command_manager.get_command(command_name) stand_still_envs = torch.norm(command, dim=1) <= 0.1 # treat yaw-only envs as zero-translation targets too # (vx, vy are approx 0 by definition) zero_lin_vel_envs = stand_still_envs tracking_targets = torch.where( zero_lin_vel_envs[:, None], 0.0, command[:, :2] ) lin_vel_error = torch.sum( torch.square(tracking_targets - vel_yaw[:, :2]), dim=1, ) return torch.exp(-lin_vel_error / std**2) def track_lin_vel_xy_heading_aligned_frame_exp_v2( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: asset: Articulation = env.scene[asset_cfg.name] vel_yaw = isaaclab_math.quat_apply_inverse( isaaclab_math.yaw_quat(asset.data.root_quat_w), asset.data.root_lin_vel_w[:, :3], ) command = env.command_manager.get_command(command_name) yaw_envs = (torch.norm(command[:, :2], dim=1) < 0.1) & ( torch.abs(command[:, 2]) > 0.1 ) stand_still_envs = torch.norm(command, dim=1) <= 0.1 # treat yaw-only envs as zero-translation targets too # (vx, vy are approx 0 by definition) zero_lin_vel_envs = stand_still_envs | yaw_envs tracking_targets = torch.where( zero_lin_vel_envs[:, None], 0.0, command[:, :2] ) lin_vel_error = torch.sum( torch.square(tracking_targets - vel_yaw[:, :2]), dim=1, ) # encourage zero linear velocity for stand still environments, and encourage yaw-only environments to have more # precise zero linear velocity tracking too reward_weights = torch.where(yaw_envs, 2.0, 1.0) + torch.where( stand_still_envs, 10.0, 0.0 ) return reward_weights * torch.exp(-lin_vel_error / std**2) def track_ang_vel_z_heading_aligned_frame_exp_v2( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = 
SceneEntityCfg("robot"), ) -> torch.Tensor: """ Track angular velocity (z) in the heading-aligned frame using exponential kernel. Note that the angular velocity in the world frame is the same as the angular velocity in the heading-aligned frame. Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] command = env.command_manager.get_command(command_name) yaw_envs = (torch.norm(command[:, :2], dim=1) < 0.1) & ( torch.abs(command[:, 2]) > 0.1 ) stand_still_envs = torch.norm(command, dim=1) <= 0.1 # set the tracking targets to 0.0 for stand still environments tracking_targets = torch.where(stand_still_envs, 0.0, command[:, 2]) ang_vel_error = torch.square( tracking_targets - asset.data.root_ang_vel_w[:, 2] ) # encourage zero angular velocity for stand still environments, and encourage yaw-only environments to have more # precise angular velocity tracking reward_weights = torch.where(yaw_envs, 2.0, 1.0) + torch.where( stand_still_envs, 10.0, 0.0 ) return reward_weights * torch.exp(-ang_vel_error / std**2) def track_ang_vel_z_heading_aligned_frame_exp( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """ Track angular velocity (z) in the heading-aligned frame using exponential kernel. Note that the angular velocity in the world frame is the same as the angular velocity in the heading-aligned frame. 
Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] command = env.command_manager.get_command(command_name) stand_still_envs = torch.norm(command, dim=1) <= 0.1 tracking_targets = torch.where(stand_still_envs, 0.0, command[:, 2]) ang_vel_error = torch.square( tracking_targets - asset.data.root_ang_vel_w[:, 2] ) return torch.exp(-ang_vel_error / std**2) def smoothed_track_ang_vel_z_heading_aligned_frame_exp( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), ) -> torch.Tensor: """ Track angular velocity (z) in the heading-aligned frame using exponential kernel. Note that the angular velocity in the world frame is the same as the angular velocity in the heading-aligned frame. Returns: [B] """ asset: Articulation = env.scene[asset_cfg.name] command = env.command_manager.get_command(command_name) hist_robot_heading_aligned_ang_vel_z = env.observation_manager.compute()[ "unified" ]["rew_heading_aligned_root_ang_vel"][..., 2] ep_len = env.episode_length_buf obs_window_len = hist_robot_heading_aligned_ang_vel_z.shape[1] smooth_window = torch.minimum( torch.full_like(ep_len, obs_window_len), ep_len ) smoothed_robot_heading_aligned_ang_vel_z = ( hist_robot_heading_aligned_ang_vel_z.sum(dim=1) / smooth_window ) ang_vel_error = torch.square( command[:, 2] - smoothed_robot_heading_aligned_ang_vel_z ) return torch.exp(-ang_vel_error / std**2) def feet_air_time( env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg, command_name: str = "base_velocity", ) -> torch.Tensor: contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids] contact_time = contact_sensor.data.current_contact_time[ :, sensor_cfg.body_ids ] in_contact = contact_time > 0.0 in_mode_time = torch.where(in_contact, contact_time, air_time) single_stance = torch.sum(in_contact.int(), dim=1) == 1 reward = torch.min( 
torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1 )[0] reward = torch.clamp(reward, max=threshold) # no reward for zero command command = env.command_manager.get_command(command_name) reward *= ( torch.norm(command[:, :2], dim=1) + torch.abs(command[:, 2]) ) > 0.1 return reward def feet_air_time_v2( env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg, command_name: str = "base_velocity", ) -> torch.Tensor: contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids] contact_time = contact_sensor.data.current_contact_time[ :, sensor_cfg.body_ids ] in_contact = contact_time > 0.0 in_mode_time = torch.where(in_contact, contact_time, air_time) single_stance = torch.sum(in_contact.int(), dim=1) == 1 reward = torch.min( torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1 )[0] reward = torch.clamp(reward, max=threshold) # no reward for zero command command = env.command_manager.get_command(command_name) stand_still_envs_flag = torch.norm(command, dim=1) <= 0.1 ang_z_only_mask = (torch.norm(command[:, :2], dim=1) <= 0.1) & ( torch.abs(command[:, 2]) > 0.1 ) # Stand still: 0.0, yaw-only: 10.0, other: 1.0 reward_weights = torch.where( stand_still_envs_flag, 0.0, 1.0 ) + torch.where(ang_z_only_mask, 5.0, 0.0) return reward * reward_weights def feet_air_time_v3( env: ManagerBasedRLEnv, command_name: str, sensor_cfg: SceneEntityCfg, threshold: float, ) -> torch.Tensor: """Reward long steps taken by the feet using L2-kernel. This function rewards the agent for taking steps that are longer than a threshold. This helps ensure that the robot lifts its feet off the ground and takes steps. The reward is computed as the sum of the time for which the feet are in the air. If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero. 
""" # extract the used quantities (to enable type-hinting) contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] # compute the reward first_contact = contact_sensor.compute_first_contact(env.step_dt)[ :, sensor_cfg.body_ids ] last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids] reward = torch.sum((last_air_time - threshold) * first_contact, dim=1) # no reward for stand still commands, larger reward for yaw-only commands commands = env.command_manager.get_command(command_name) stand_still_envs = torch.norm(commands, dim=1) <= 0.1 yaw_only_envs = (torch.norm(commands[:, :2], dim=1) <= 0.1) & ( torch.abs(commands[:, 2]) > 0.1 ) reward_weights = torch.where(stand_still_envs, 0.0, 1.0) + torch.where( yaw_only_envs, 4.0, 0.0 ) return reward * reward_weights def feet_air_time_v4( env: ManagerBasedRLEnv, command_name: str, sensor_cfg: SceneEntityCfg, threshold: float, ) -> torch.Tensor: """Reward long steps taken by the feet using L2-kernel. This function rewards the agent for taking steps that are longer than a threshold. This helps ensure that the robot lifts its feet off the ground and takes steps. The reward is computed as the sum of the time for which the feet are in the air. If the commands are small (i.e. the agent is not supposed to take a step), then the reward is zero. 
""" # extract the used quantities (to enable type-hinting) contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] # compute the reward first_contact = contact_sensor.compute_first_contact(env.step_dt)[ :, sensor_cfg.body_ids ] last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids] reward = torch.sum((last_air_time - threshold) * first_contact, dim=1) # no reward for stand still commands, larger reward for yaw-only commands commands = env.command_manager.get_command(command_name) stand_still_envs = torch.norm(commands, dim=1) <= 0.1 reward_weights = torch.where(stand_still_envs, 0.0, 1.0) return reward * reward_weights def yaw_rate_only_movement_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: """Penalize world-frame root XY translation during yaw-only commands. When vx, vy are small commands, penalize the squared magnitude of root linear velocity (vx, vy) in world frame. Returns: [B] """ # Velocity command: [B, 3] (vx, vy, yaw_rate). Some command terms may # expose extra channels (e.g., heading) via generated_commands(). command = env.command_manager.get_command(command_name) # Gate only yaw-rate-only envs: vx=vy=0 and v_yaw > 0.0. yaw_only_mask: torch.Tensor = ( torch.norm(command[:, :2], dim=1) <= 0.1 ) # [B] # Penalize global (world-frame) root linear velocity in x/y. 
asset: Articulation = env.scene[asset_cfg.name] root_lin_vel_w: torch.Tensor = asset.data.root_lin_vel_w # [B, 3] penalty: torch.Tensor = torch.sum( torch.square(root_lin_vel_w[:, :2]), dim=1, ) # [B] return penalty * yaw_only_mask.to(dtype=penalty.dtype) def fly( env: ManagerBasedRLEnv, threshold: float, sensor_cfg: SceneEntityCfg, ) -> torch.Tensor: contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name] net_contact_forces = contact_sensor.data.net_forces_w_history is_contact = ( torch.max( torch.norm(net_contact_forces[:, :, sensor_cfg.body_ids], dim=-1), dim=1, )[0] > threshold ) return torch.sum(is_contact, dim=-1) < 0.5 def stand_still_torso_lin_vel_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_lin_vel_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :] penalty = torch.sum(torch.square(torso_lin_vel_w), dim=-1) command = env.command_manager.get_command(command_name) stand_still_flag = torch.norm(command, dim=1) <= 0.1 return penalty * stand_still_flag.to(dtype=penalty.dtype) def stand_still_torso_ang_vel_l2_penalty( env: ManagerBasedRLEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_ang_vel_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :] penalty = torch.sum(torch.square(torso_ang_vel_w), dim=-1) command = env.command_manager.get_command(command_name) stand_still_flag = torch.norm(command, dim=1) <= 0.1 return penalty * stand_still_flag.to(dtype=penalty.dtype) def stand_still_torso_lin_vel_exp( env: ManagerBasedRLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: """Reward staying still (zero torso linear velocity) when 
commanded to stand. Uses exponential kernel: exp(-||v||^2 / std^2) Args: env: Environment instance std: Standard deviation for exponential kernel asset_cfg: Robot asset configuration command_name: Name of velocity command Returns: Reward tensor of shape [B], active only when stand still commanded """ robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_lin_vel_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :] error = torch.sum(torch.square(torso_lin_vel_w), dim=-1) reward = torch.exp(-error / std**2) command = env.command_manager.get_command(command_name) stand_still_flag = torch.norm(command, dim=1) <= 0.1 return reward * stand_still_flag.to(dtype=reward.dtype) def stand_still_torso_ang_vel_exp( env: ManagerBasedRLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: """Reward staying still (zero torso angular velocity) when commanded to stand. Uses exponential kernel: exp(-||omega||^2 / std^2) Args: env: Environment instance std: Standard deviation for exponential kernel asset_cfg: Robot asset configuration command_name: Name of velocity command Returns: Reward tensor of shape [B], active only when stand still commanded """ robot_ptr = env.scene[asset_cfg.name] torso_idx = robot_ptr.body_names.index("torso_link") torso_ang_vel_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :] error = torch.sum(torch.square(torso_ang_vel_w), dim=-1) reward = torch.exp(-error / std**2) command = env.command_manager.get_command(command_name) stand_still_flag = torch.norm(command, dim=1) <= 0.1 return reward * stand_still_flag.to(dtype=reward.dtype) def yaw_rate_only_movement_exp( env: ManagerBasedRLEnv, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), command_name: str = "base_velocity", ) -> torch.Tensor: """Reward minimal XY translation during yaw-only commands. When vx, vy commands are small, reward staying in place using exponential kernel. 
Uses exponential kernel: exp(-||v_xy||^2 / std^2) Args: env: Environment instance std: Standard deviation for exponential kernel asset_cfg: Robot asset configuration command_name: Name of velocity command Returns: Reward tensor of shape [B], active only during yaw-only commands """ command = env.command_manager.get_command(command_name) yaw_only_mask: torch.Tensor = torch.norm(command[:, :2], dim=1) <= 0.1 asset: Articulation = env.scene[asset_cfg.name] root_lin_vel_w: torch.Tensor = asset.data.root_lin_vel_w error: torch.Tensor = torch.sum(torch.square(root_lin_vel_w[:, :2]), dim=1) reward: torch.Tensor = torch.exp(-error / std**2) return reward * yaw_only_mask.to(dtype=reward.dtype) def yaw_rate_only_hip_yaw_usage_exp( env: ManagerBasedRLEnv, std: float, command_name: str = "base_velocity", hip_yaw_dofs: list[str] | None = None, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), lin_threshold: float = 0.1, yaw_threshold: float = 0.1, command_tanh_mult: float = 1.0, ) -> torch.Tensor: """Encourage using hip_yaw joint(s) during yaw-rate-only commands. Active only when commanded to rotate in place (vx, vy small and |yaw_rate| large). Rewards hip_yaw joint velocity magnitude using a saturating exponential kernel: r = (1 - exp(-mean(qd_hip_yaw^2) / std^2)) * tanh(command_tanh_mult * |cmd_yaw|) Shapes: - command: [B, 3] (vx, vy, yaw_rate) - asset.data.joint_vel: [B, num_dofs] - return: [B] """ command: torch.Tensor = env.command_manager.get_command(command_name) yaw_only_mask: torch.Tensor = ( torch.norm(command[:, :2], dim=1) <= lin_threshold ) & (torch.abs(command[:, 2]) > yaw_threshold) # [B] asset: Articulation = env.scene[asset_cfg.name] if hip_yaw_dofs is None: raise ValueError( "yaw_rate_only_hip_yaw_usage_exp requires hip_yaw_dofs (joint names in " f"robot.joint_names). Got hip_yaw_dofs=None. 
robot.joint_names={asset.joint_names}" ) hip_yaw_joint_ids: list[int] = _get_dof_indices(asset, hip_yaw_dofs) hip_yaw_vel: torch.Tensor = asset.data.joint_vel[ :, hip_yaw_joint_ids ] # [B, N] activity_sq: torch.Tensor = torch.mean( torch.square(hip_yaw_vel), dim=-1 ) # [B] usage_reward: torch.Tensor = 1.0 - torch.exp(-activity_sq / std**2) # [B] cmd_yaw_abs: torch.Tensor = torch.abs(command[:, 2]) # [B] cmd_weight: torch.Tensor = torch.tanh( command_tanh_mult * cmd_yaw_abs ) # [B] reward: torch.Tensor = usage_reward * cmd_weight return reward * yaw_only_mask.to(dtype=reward.dtype) @configclass class RewardsCfg: pass class TaskGatedReward: """Callable wrapper to gate reward terms by task_id.""" def __init__(self, func, task_name: str): self.func = func self.task_name = task_name self.__name__ = f"TaskGatedReward[{task_name}]" def __call__(self, env: ManagerBasedRLEnv, *args, **kwargs): task_ids = getattr(env, "holo_task_ids", None) mapping = getattr(env, "holo_task_name_to_id", None) if task_ids is None or mapping is None: return torch.zeros(env.num_envs, device=env.device) target = mapping.get(self.task_name, None) if target is None: return torch.zeros(env.num_envs, device=env.device) mask = task_ids == target if not torch.any(mask): return torch.zeros(env.num_envs, device=env.device) inner_args = kwargs.pop("args", None) inner_kwargs = kwargs.pop("kwargs", None) call_args = args if inner_args is None else (*args, *inner_args) call_kwargs = ( kwargs if inner_kwargs is None else {**kwargs, **inner_kwargs} ) reward = self.func(env, *call_args, **call_kwargs) mask = mask.to(device=reward.device, dtype=reward.dtype) return reward * mask def build_rewards_config(reward_config_dict: dict): if isinstance(reward_config_dict, (DictConfig, ListConfig)): reward_config_dict = OmegaConf.to_container( reward_config_dict, resolve=True ) rewards_cfg = RewardsCfg() # Detect grouped (multi-task) vs flat (legacy) layout def _is_grouped(cfg: dict) -> bool: for k, v in cfg.items(): if 
k == "_config": continue if isinstance(v, dict) and "weight" in v: return False return True return False is_grouped = _is_grouped(reward_config_dict) if not is_grouped: for reward_name, reward_cfg in reward_config_dict.items(): if reward_name == "_config": continue reward_cfg = resolve_holo_config(reward_cfg) base_params = resolve_holo_config(reward_cfg["params"]) method_name = f"{reward_name}" func = globals().get(method_name, None) if func is None: func = getattr(isaaclab_mdp, reward_name, None) if func is None: raise ValueError(f"Unknown reward function: {reward_name}") params = dict(base_params) setattr( rewards_cfg, reward_name, RewardTermCfg( func=func, weight=reward_cfg["weight"], params=params, ), ) return rewards_cfg # Grouped: rewards: {task_name: {term: ...}} for task_name, task_group in reward_config_dict.items(): if task_name.startswith("_"): continue if not isinstance(task_group, dict): raise ValueError(f"Expected dict for task group {task_name}") for reward_name, reward_cfg in task_group.items(): reward_cfg = resolve_holo_config(reward_cfg) base_params = resolve_holo_config(reward_cfg["params"]) method_name = f"{reward_name}" func = globals().get(method_name, None) if func is None: func = getattr(isaaclab_mdp, reward_name, None) if func is None: raise ValueError(f"Unknown reward function: {reward_name}") if task_name != "common": func = TaskGatedReward(func, task_name) params = {"args": [], "kwargs": base_params} else: params = base_params flat_name = f"{task_name}.{reward_name}" setattr( rewards_cfg, flat_name, RewardTermCfg( func=func, weight=reward_cfg["weight"], params=params, ), ) return rewards_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_scene.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import copy import os import time from dataclasses import MISSING import isaaclab.sim as sim_utils from isaaclab.actuators import ImplicitActuatorCfg from isaaclab.assets import ArticulationCfg, AssetBaseCfg from isaaclab.scene import InteractiveSceneCfg from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns from isaaclab.terrains import TerrainImporterCfg from isaaclab.utils import configclass from loguru import logger from holomotion.src.env.isaaclab_components.isaaclab_terrain import ( build_terrain_config, ) from holomotion.src.env.isaaclab_components.unitree_actuators import ( UnitreeActuator, UnitreeActuatorCfg, UnitreeErfiActuator, UnitreeErfiActuatorCfg, ) class SceneFunctions: """Collection of scene component builders.""" @staticmethod def build_robot_config( config: dict, domain_rand_config: dict | None = None, main_process: bool = True, process_id: int = 0, num_processes: int = 1, ) -> ArticulationCfg: """Build robot articulation configuration. 
Args: config: Robot configuration dictionary main_process: Whether this is the main process (from compiled config) process_id: Process ID/rank (from compiled config) num_processes: Total number of processes (from compiled config) """ urdf_path = config.asset.urdf_file init_pos = config.init_state.pos default_joint_positions = config.init_state.default_joint_angles root_link_name = config.get("root_name", "pelvis") prim_path = "{ENV_REGEX_NS}/Robot" actuator_type = config.actuators.get("actuator_type", "implicit") if actuator_type in {"unitree", "unitree_erfi"}: actuators = _build_unitree_actuator_cfg( config.actuators, domain_rand_config or {} ) else: actuators = { "all_joints": ImplicitActuatorCfg( **config.actuators.all_joints ) } logger.info(f"Using {actuator_type} actuators") logger.info(f"Actuators: {actuators}") if not os.path.exists(urdf_path): raise FileNotFoundError(f"URDF file not found: {urdf_path}") # Configure USD output directory. Optionally isolate per rank to avoid races. 
usd_base_dir = os.path.join(os.path.dirname(urdf_path), "usd") unique_usd_per_rank = True if num_processes > 1 and unique_usd_per_rank: usd_dir = os.path.join(usd_base_dir, f"rank_{process_id}") else: usd_dir = usd_base_dir os.makedirs(usd_dir, exist_ok=True) logger.info(f"Using URDF path: {urdf_path}") logger.info(f"Using USD directory: {usd_dir}") force_usd_conversion = config.asset.get("force_usd_conversion", True) if num_processes > 1 and unique_usd_per_rank: # Ensure each rank generates its own USD into its isolated directory force_usd_conversion = True # Handle DDP if num_processes > 1: logger.info( f"[Process {process_id}/{num_processes}] Distributed training detected" ) if unique_usd_per_rank: logger.info( f"[Process {process_id}] Using per-rank USD dir; forcing USD conversion: {force_usd_conversion}" ) else: # Only main process should convert USD to avoid file conflicts if main_process: logger.info( f"[Process {process_id}] Main process - Force USD conversion: {force_usd_conversion}" ) else: logger.info( f"[Process {process_id}] Non-main process - Skipping USD conversion, waiting for main process" ) force_usd_conversion = False # Wait for USD files to be created by main process urdf_basename = os.path.splitext( os.path.basename(urdf_path) )[0] expected_usd_file = os.path.join( usd_dir, f"{urdf_basename}.usd" ) logger.info( f"[Process {process_id}] Waiting for main process to create USD files at {expected_usd_file}..." ) max_wait = 60 wait_interval = 1 waited = 0 while ( not os.path.exists(expected_usd_file) and waited < max_wait ): time.sleep(wait_interval) waited += wait_interval if os.path.exists(expected_usd_file): logger.info( f"[Process {process_id}] USD file found, proceeding with loading" ) else: logger.warning( f"[Process {process_id}] USD file not found after {max_wait}s, proceeding anyway" ) else: logger.info( f"Single process training. 
Force USD conversion: {force_usd_conversion}" ) articulation_cfg = ArticulationCfg( prim_path=prim_path, spawn=sim_utils.UrdfFileCfg( asset_path=os.path.abspath(urdf_path), usd_dir=os.path.abspath(usd_dir), force_usd_conversion=force_usd_conversion, fix_base=False, merge_fixed_joints=True, root_link_name=root_link_name, replace_cylinders_with_capsules=True, activate_contact_sensors=True, rigid_props=sim_utils.RigidBodyPropertiesCfg( disable_gravity=False, retain_accelerations=False, linear_damping=0.0, angular_damping=0.0, max_linear_velocity=1000.0, max_angular_velocity=1000.0, max_depenetration_velocity=1.0, ), articulation_props=sim_utils.ArticulationRootPropertiesCfg( enabled_self_collisions=True, solver_position_iteration_count=8, solver_velocity_iteration_count=4, ), joint_drive=sim_utils.UrdfConverterCfg.JointDriveCfg( gains=sim_utils.UrdfConverterCfg.JointDriveCfg.PDGainsCfg( stiffness=0, damping=0, ) ), ), init_state=ArticulationCfg.InitialStateCfg( pos=init_pos, joint_pos=default_joint_positions, joint_vel={".*": 0.0}, ), soft_joint_pos_limit_factor=0.9, actuators=actuators, ) return articulation_cfg @staticmethod def build_lighting_config( config: dict, ) -> tuple[AssetBaseCfg, AssetBaseCfg]: """Build lighting configuration.""" distant_light_intensity = config.get("distant_light_intensity", 3000.0) dome_light_intensity = config.get("dome_light_intensity", 1000.0) distant_light_color = config.get( "distant_light_color", (0.75, 0.75, 0.75) ) dome_light_color = config.get("dome_light_color", (0.13, 0.13, 0.13)) light = AssetBaseCfg( prim_path="/World/light", spawn=sim_utils.DistantLightCfg( color=distant_light_color, intensity=distant_light_intensity ), ) sky_light = AssetBaseCfg( prim_path="/World/skyLight", spawn=sim_utils.DomeLightCfg( color=dome_light_color, intensity=dome_light_intensity ), ) return light, sky_light @staticmethod def build_contact_sensor_config(config: dict) -> ContactSensorCfg: """Build contact sensor configuration.""" prim_path = 
config.get("prim_path", "{ENV_REGEX_NS}/Robot/.*") history_length = config.get("history_length", 3) force_threshold = config.get("force_threshold", 10.0) track_air_time = config.get("track_air_time", True) debug_vis = config.get("debug_vis", False) return ContactSensorCfg( prim_path=prim_path, history_length=history_length, track_air_time=track_air_time, force_threshold=force_threshold, debug_vis=debug_vis, ) @configclass class MotionTrackingSceneCfg(InteractiveSceneCfg): """Scene configuration for motion tracking environment.""" pass def build_scene_config( scene_config_dict: dict, main_process: bool = True, process_id: int = 0, num_processes: int = 1, ) -> MotionTrackingSceneCfg: """Build IsaacLab-compatible scene configuration from config dictionary. Args: scene_config_dict: Scene configuration dictionary main_process: Whether this is the main process (from compiled config) process_id: Process ID/rank (from compiled config) num_processes: Total number of processes (from compiled config) """ scene_cfg = MotionTrackingSceneCfg() # Basic scene properties scene_cfg.num_envs = scene_config_dict.get("num_envs", MISSING) scene_cfg.env_spacing = scene_config_dict.get("env_spacing", 2.5) scene_cfg.replicate_physics = scene_config_dict.get( "replicate_physics", True ) # Build robot configuration with process info if "robot" in scene_config_dict: robot_config = scene_config_dict["robot"] scene_cfg.robot = SceneFunctions.build_robot_config( robot_config, domain_rand_config=scene_config_dict.get("domain_rand", {}), main_process=main_process, process_id=process_id, num_processes=num_processes, ) # Build terrain configuration if "terrain" in scene_config_dict: terrain_config = scene_config_dict["terrain"] scene_cfg.terrain = build_terrain_config( terrain_config, scene_env_spacing=scene_cfg.env_spacing ) if "robot" in scene_config_dict: scene_cfg.height_scanner = RayCasterCfg( prim_path="{ENV_REGEX_NS}/Robot", offset=RayCasterCfg.OffsetCfg(pos=(0.0, 0.0, 1.0)), 
ray_alignment="world", pattern_cfg=patterns.GridPatternCfg( resolution=1.0, size=(1.0e-3, 1.0e-3) ), debug_vis=False, mesh_prim_paths=[str(scene_cfg.terrain.prim_path)], max_distance=1.0e6, ) # Build lighting configuration if "lighting" in scene_config_dict: lighting_config = scene_config_dict["lighting"] light, sky_light = SceneFunctions.build_lighting_config( lighting_config ) scene_cfg.light = light scene_cfg.sky_light = sky_light # Build contact sensor configuration if "contact_sensor" in scene_config_dict: contact_config = scene_config_dict["contact_sensor"] scene_cfg.contact_forces = SceneFunctions.build_contact_sensor_config( contact_config ) return scene_cfg def _cfg_to_kwargs(cfg: object) -> dict: return { key: copy.deepcopy(value) for key, value in vars(cfg).items() if not key.startswith("_") } def _build_unitree_actuator_cfg( config: dict, domain_rand_config: dict ) -> dict[str, object]: base_cfg = unitree_actuator_config_hardcoded["all_joints"] base_kwargs = _cfg_to_kwargs(base_cfg) action_delay_cfg = copy.deepcopy( domain_rand_config.get("action_delay", {}) ) if action_delay_cfg.get("enabled", False): delay_kwargs = { "min_delay": int(action_delay_cfg.get("min_delay", 0)), "max_delay": int(action_delay_cfg.get("max_delay", 0)), } else: delay_kwargs = {"min_delay": 0, "max_delay": 0} if config.get("actuator_type", "unitree") == "unitree_erfi": erfi_cfg = copy.deepcopy(domain_rand_config.get("erfi", {})) actuator_filter_kwargs = { "ema_filter_enabled": bool( config.get("ema_filter_enabled", False) ), "ema_filter_alpha": config.get("ema_filter_alpha", 1.0), } erfi_kwargs = { "erfi_enabled": bool(erfi_cfg.get("enabled", False)), "rfi_probability": erfi_cfg.get("rfi_probability", 0.5), "rfi_lim": erfi_cfg.get("rfi_lim", 0.1), "randomize_rfi_lim": erfi_cfg.get("randomize_rfi_lim", True), "rfi_lim_range": erfi_cfg.get("rfi_lim_range", (0.5, 1.5)), "rao_lim": erfi_cfg.get("rao_lim", 0.1), } actuator_kwargs = { **base_kwargs, **delay_kwargs, 
**actuator_filter_kwargs, **erfi_kwargs, } actuator_cfg = UnitreeErfiActuatorCfg(**actuator_kwargs) actuator_cfg.class_type = UnitreeErfiActuator else: actuator_kwargs = {**base_kwargs, **delay_kwargs} actuator_cfg = UnitreeActuatorCfg(**actuator_kwargs) actuator_cfg.class_type = UnitreeActuator return {"all_joints": actuator_cfg} unitree_actuator_config_hardcoded = { "all_joints": UnitreeActuatorCfg( joint_names_expr=[ ".*_hip_yaw_joint", ".*_hip_roll_joint", ".*_hip_pitch_joint", ".*_knee_joint", ".*_ankle_pitch_joint", ".*_ankle_roll_joint", "waist_roll_joint", "waist_pitch_joint", "waist_yaw_joint", ".*_shoulder_pitch_joint", ".*_shoulder_roll_joint", ".*_shoulder_yaw_joint", ".*_elbow_joint", ".*_wrist_roll_joint", ".*_wrist_pitch_joint", ".*_wrist_yaw_joint", ], min_delay=0, max_delay=0, effort_limit={ ".*_hip_yaw_joint": 88, ".*_hip_roll_joint": 139, ".*_hip_pitch_joint": 88, ".*_knee_joint": 139, ".*_ankle_pitch_joint": 50, ".*_ankle_roll_joint": 50, "waist_roll_joint": 50, "waist_pitch_joint": 50, "waist_yaw_joint": 88, ".*_shoulder_pitch_joint": 25, ".*_shoulder_roll_joint": 25, ".*_shoulder_yaw_joint": 25, ".*_elbow_joint": 25, ".*_wrist_roll_joint": 25, ".*_wrist_pitch_joint": 5, ".*_wrist_yaw_joint": 5, }, velocity_limit={ ".*_hip_yaw_joint": 32, ".*_hip_roll_joint": 20, ".*_hip_pitch_joint": 32, ".*_knee_joint": 20, ".*_ankle_pitch_joint": 37, ".*_ankle_roll_joint": 37, "waist_roll_joint": 37, "waist_pitch_joint": 37, "waist_yaw_joint": 32, ".*_shoulder_pitch_joint": 37, ".*_shoulder_roll_joint": 37, ".*_shoulder_yaw_joint": 37, ".*_elbow_joint": 37, ".*_wrist_roll_joint": 37, ".*_wrist_pitch_joint": 22, ".*_wrist_yaw_joint": 22, }, stiffness={ ".*_hip_yaw_joint": 40.1792384737, ".*_hip_roll_joint": 99.0984277823, ".*_hip_pitch_joint": 40.1792384737, ".*_knee_joint": 99.0984277823, ".*_ankle_pitch_joint": 28.5012461974, ".*_ankle_roll_joint": 28.5012461974, "waist_roll_joint": 28.5012461974, "waist_pitch_joint": 28.5012461974, "waist_yaw_joint": 
40.1792384737, ".*_shoulder_pitch_joint": 14.2506230987, ".*_shoulder_roll_joint": 14.2506230987, ".*_shoulder_yaw_joint": 14.2506230987, ".*_elbow_joint": 14.2506230987, ".*_wrist_roll_joint": 14.2506230987, ".*_wrist_pitch_joint": 16.7783274819, ".*_wrist_yaw_joint": 16.7783274819, }, damping={ ".*_hip_yaw_joint": 2.5578897651, ".*_hip_roll_joint": 6.30880185368, ".*_hip_pitch_joint": 2.5578897651, ".*_knee_joint": 6.30880185368, ".*_ankle_pitch_joint": 1.81444568664, ".*_ankle_roll_joint": 1.81444568664, "waist_roll_joint": 1.81444568664, "waist_pitch_joint": 1.81444568664, "waist_yaw_joint": 2.5578897651, ".*_shoulder_pitch_joint": 0.907222843318, ".*_shoulder_roll_joint": 0.907222843318, ".*_shoulder_yaw_joint": 0.907222843318, ".*_elbow_joint": 0.907222843318, ".*_wrist_roll_joint": 0.907222843318, ".*_wrist_pitch_joint": 1.06814150222, ".*_wrist_yaw_joint": 1.06814150222, }, armature={ ".*_hip_yaw_joint": 0.01017752, ".*_hip_roll_joint": 0.025101925, ".*_hip_pitch_joint": 0.01017752, ".*_knee_joint": 0.025101925, ".*_ankle_pitch_joint": 0.00721945, ".*_ankle_roll_joint": 0.00721945, "waist_roll_joint": 0.00721945, "waist_pitch_joint": 0.00721945, "waist_yaw_joint": 0.01017752, ".*_shoulder_pitch_joint": 0.003609725, ".*_shoulder_roll_joint": 0.003609725, ".*_shoulder_yaw_joint": 0.003609725, ".*_elbow_joint": 0.003609725, ".*_wrist_roll_joint": 0.003609725, ".*_wrist_pitch_joint": 0.00425, ".*_wrist_yaw_joint": 0.00425, }, friction=0, dynamic_friction=0, viscous_friction=0, X1={ ".*_hip_yaw_joint": 22.63, ".*_hip_roll_joint": 14.5, ".*_hip_pitch_joint": 22.63, ".*_knee_joint": 14.5, ".*_ankle_pitch_joint": 30.86, ".*_ankle_roll_joint": 30.86, "waist_roll_joint": 30.86, "waist_pitch_joint": 30.86, "waist_yaw_joint": 22.63, ".*_shoulder_pitch_joint": 30.86, ".*_shoulder_roll_joint": 30.86, ".*_shoulder_yaw_joint": 30.86, ".*_elbow_joint": 30.86, ".*_wrist_roll_joint": 30.86, ".*_wrist_pitch_joint": 15.3, ".*_wrist_yaw_joint": 15.3, }, X2={ ".*_hip_yaw_joint": 
35.52, ".*_hip_roll_joint": 22.7, ".*_hip_pitch_joint": 35.52, ".*_knee_joint": 22.7, ".*_ankle_pitch_joint": 40.13, ".*_ankle_roll_joint": 40.13, "waist_roll_joint": 40.13, "waist_pitch_joint": 40.13, "waist_yaw_joint": 35.52, ".*_shoulder_pitch_joint": 40.13, ".*_shoulder_roll_joint": 40.13, ".*_shoulder_yaw_joint": 40.13, ".*_elbow_joint": 40.13, ".*_wrist_roll_joint": 40.13, ".*_wrist_pitch_joint": 24.76, ".*_wrist_yaw_joint": 24.76, }, Y1={ ".*_hip_yaw_joint": 71, ".*_hip_roll_joint": 111, ".*_hip_pitch_joint": 71, ".*_knee_joint": 111, ".*_ankle_pitch_joint": 24.8, ".*_ankle_roll_joint": 24.8, "waist_roll_joint": 24.8, "waist_pitch_joint": 24.8, "waist_yaw_joint": 71, ".*_shoulder_pitch_joint": 24.8, ".*_shoulder_roll_joint": 24.8, ".*_shoulder_yaw_joint": 24.8, ".*_elbow_joint": 24.8, ".*_wrist_roll_joint": 24.8, ".*_wrist_pitch_joint": 4.8, ".*_wrist_yaw_joint": 4.8, }, Y2={ ".*_hip_yaw_joint": 83.3, ".*_hip_roll_joint": 131, ".*_hip_pitch_joint": 83.3, ".*_knee_joint": 131, ".*_ankle_pitch_joint": 31.9, ".*_ankle_roll_joint": 31.9, "waist_roll_joint": 31.9, "waist_pitch_joint": 31.9, "waist_yaw_joint": 83.3, ".*_shoulder_pitch_joint": 31.9, ".*_shoulder_roll_joint": 31.9, ".*_shoulder_yaw_joint": 31.9, ".*_elbow_joint": 31.9, ".*_wrist_roll_joint": 31.9, ".*_wrist_pitch_joint": 8.6, ".*_wrist_yaw_joint": 8.6, }, Fs={ ".*_hip_yaw_joint": 1.6, ".*_hip_roll_joint": 2.4, ".*_hip_pitch_joint": 1.6, ".*_knee_joint": 2.4, ".*_ankle_pitch_joint": 0.6, ".*_ankle_roll_joint": 0.6, "waist_roll_joint": 0.6, "waist_pitch_joint": 0.6, "waist_yaw_joint": 1.6, ".*_shoulder_pitch_joint": 0.6, ".*_shoulder_roll_joint": 0.6, ".*_shoulder_yaw_joint": 0.6, ".*_elbow_joint": 0.6, ".*_wrist_roll_joint": 0.6, ".*_wrist_pitch_joint": 0.6, ".*_wrist_yaw_joint": 0.6, }, Fd={ ".*_hip_yaw_joint": 0.16, ".*_hip_roll_joint": 0.24, ".*_hip_pitch_joint": 0.16, ".*_knee_joint": 0.24, ".*_ankle_pitch_joint": 0.06, ".*_ankle_roll_joint": 0.06, "waist_roll_joint": 0.06, "waist_pitch_joint": 
0.06, "waist_yaw_joint": 0.16, ".*_shoulder_pitch_joint": 0.06, ".*_shoulder_roll_joint": 0.06, ".*_shoulder_yaw_joint": 0.06, ".*_elbow_joint": 0.06, ".*_wrist_roll_joint": 0.06, ".*_wrist_pitch_joint": 0.06, ".*_wrist_yaw_joint": 0.06, }, Va={ ".*_hip_yaw_joint": 0.01, ".*_hip_roll_joint": 0.01, ".*_hip_pitch_joint": 0.01, ".*_knee_joint": 0.01, ".*_ankle_pitch_joint": 0.01, ".*_ankle_roll_joint": 0.01, "waist_roll_joint": 0.01, "waist_pitch_joint": 0.01, "waist_yaw_joint": 0.01, ".*_shoulder_pitch_joint": 0.01, ".*_shoulder_roll_joint": 0.01, ".*_shoulder_yaw_joint": 0.01, ".*_elbow_joint": 0.01, ".*_wrist_roll_joint": 0.01, ".*_wrist_pitch_joint": 0.01, ".*_wrist_yaw_joint": 0.01, }, ) } ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_simulator.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from isaaclab.sim import SimulationCfg, PhysxCfg def build_simulator_config(sim_config_dict: dict) -> SimulationCfg: """Build simulation configuration from config dictionary.""" policy_freq = sim_config_dict.get("policy_freq", 50) sim_freq = sim_config_dict.get("sim_freq", 200) decimation = int(sim_freq / policy_freq) dt = 1.0 / sim_freq device = sim_config_dict.get("device", "cuda") # PhysX configuration physx_config = sim_config_dict.get("physx", {}) physx = PhysxCfg( bounce_threshold_velocity=physx_config.get( "bounce_threshold_velocity", 0.2 ), gpu_max_rigid_patch_count=physx_config.get( "gpu_max_rigid_patch_count", int(10 * 2**15) ), ) return SimulationCfg( dt=dt, render_interval=decimation, physx=physx, device=device, ) ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_termination.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import inspect import isaaclab.envs.mdp as isaaclab_mdp import isaaclab.utils.math as isaaclab_math import torch from isaaclab.envs import ManagerBasedRLEnv from isaaclab.managers import TerminationTermCfg from isaaclab.utils import configclass from holomotion.src.env.isaaclab_components import ( isaaclab_motion_tracking_command as motion_tracking_command, isaaclab_utils, ) def _list_supported_terminations() -> list[str]: custom_terminations = { name for name, obj in globals().items() if ( inspect.isfunction(obj) and obj.__module__ == __name__ and not name.startswith("_") ) } native_terminations = { name for name in dir(isaaclab_mdp.terminations) if ( not name.startswith("_") and callable(getattr(isaaclab_mdp.terminations, name)) ) } return sorted(custom_terminations | native_terminations) def _resolve_termination_func(name: str): func = globals().get(name) if inspect.isfunction(func) and func.__module__ == __name__: return func func = getattr(isaaclab_mdp.terminations, name, None) if callable(func): return func supported = _list_supported_terminations() raise ValueError( f"Unknown termination function: {name}. 
Supported: {supported}" ) def global_bodylink_pos_far( env: ManagerBasedRLEnv, threshold: float, command_name: str = "ref_motion", keybody_names: list[str] | None = None, ref_prefix: str = "ref_", ) -> torch.Tensor: """Any body link position deviates more than threshold (world frame).""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next( prefix=ref_prefix ) # [B, Nb, 3] robot_pos_w = command.robot.data.body_pos_w # [B, Nb, 3] keybody_idxs = isaaclab_utils._get_body_indices( command.robot, keybody_names ) if keybody_idxs is not None and len(keybody_idxs) > 0: idxs = torch.as_tensor( keybody_idxs, device=ref_pos_w.device, dtype=torch.long, ) ref_pos_w = ref_pos_w[:, idxs] robot_pos_w = robot_pos_w[:, idxs] error = torch.norm(ref_pos_w - robot_pos_w, dim=-1) # [B, Nb] return torch.any(error > threshold, dim=-1) # [B] def anchor_ref_z_far( env: ManagerBasedRLEnv, threshold: float, command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: """Anchor link z difference exceeds threshold (world frame).""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) ref_z = command.get_ref_motion_anchor_bodylink_global_pos_immediate_next( prefix=ref_prefix )[:, -1] robot_z = command.global_robot_anchor_pos_cur[:, -1] return (ref_z - robot_z).abs() > threshold def ref_gravity_projection_far( env: ManagerBasedRLEnv, threshold: float, asset_name: str = "robot", command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: """Difference in projected gravity z-component exceeds threshold. Project world gravity into the anchor body frames using inverse quaternion rotation and compare z-components. 
""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) g_w = env.scene[asset_name].data.GRAVITY_VEC_W # [B, 3] # Reference anchor orientation (xyzw) from motion cache ref_anchor_quat_xyzw = ( command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next( prefix=ref_prefix ) ) # [B, 4] motion_projected_gravity_b = isaaclab_math.quat_apply_inverse( ref_anchor_quat_xyzw, g_w ) # [B, 3] # motion_projected_gravity_b = isaaclab_math.quat_rotate_inverse( # ref_anchor_quat_xyzw, g_w # ) # [B, 3] # Robot anchor orientation (xyzw) from sim robot_anchor_quat_wxyz = command.robot.data.body_quat_w[ :, command.anchor_bodylink_idx ] # [B, 4] robot_projected_gravity_b = isaaclab_math.quat_apply_inverse( robot_anchor_quat_wxyz, g_w ) # [B, 3] # robot_projected_gravity_b = isaaclab_math.quat_rotate_inverse( # robot_anchor_quat_wxyz, g_w # ) # [B, 3] return ( motion_projected_gravity_b[:, 2] - robot_projected_gravity_b[:, 2] ).abs() > threshold def keybody_ref_pos_far( env: ManagerBasedRLEnv, threshold: float, command_name: str = "ref_motion", keybody_names: list[str] | None = None, ref_prefix: str = "ref_", ) -> torch.Tensor: """Any key body link z difference exceeds threshold (world frame).""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next( prefix=ref_prefix ) # [B, Nb, 3] robot_pos_w = command.robot.data.body_pos_w # [B, Nb, 3] keybody_idxs = isaaclab_utils._get_body_indices( command.robot, keybody_names ) if keybody_idxs is not None and len(keybody_idxs) > 0: idxs = torch.as_tensor( keybody_idxs, device=ref_pos_w.device, dtype=torch.long, ) ref_pos_w = ref_pos_w[:, idxs] robot_pos_w = robot_pos_w[:, idxs] error = torch.norm(ref_pos_w - robot_pos_w, dim=-1) # [B, Nb] return torch.any(error > threshold, dim=-1) # [B] def keybody_ref_z_far( env: ManagerBasedRLEnv, threshold: float, command_name: str = 
"ref_motion", keybody_names: list[str] | None = None, ref_prefix: str = "ref_", ) -> torch.Tensor: """Any key body link z difference exceeds threshold (world frame).""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) ref_pos_w = command.get_ref_motion_bodylink_global_pos_immediate_next( prefix=ref_prefix ) # [B, Nb, 3] robot_pos_w = command.robot.data.body_pos_w # [B, Nb, 3] keybody_idxs = isaaclab_utils._get_body_indices( command.robot, keybody_names ) if keybody_idxs is not None and len(keybody_idxs) > 0: idxs = torch.as_tensor( keybody_idxs, device=ref_pos_w.device, dtype=torch.long, ) ref_pos_w = ref_pos_w[:, idxs] robot_pos_w = robot_pos_w[:, idxs] error_z = (ref_pos_w[..., 2] - robot_pos_w[..., 2]).abs() # [B, Nb] return torch.any(error_z > threshold, dim=-1) # [B] def wholebody_mpjpe_far( env: ManagerBasedRLEnv, threshold: float, command_name: str = "ref_motion", ref_prefix: str = "ref_", ) -> torch.Tensor: """Mean whole-body DOF position error exceeds threshold.""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) ref_dof_pos = command.get_ref_motion_dof_pos_immediate_next( prefix=ref_prefix ) robot_dof_pos = command.robot.data.joint_pos mean_dof_error = torch.mean(torch.abs(robot_dof_pos - ref_dof_pos), dim=-1) return mean_dof_error > threshold def motion_end( env: ManagerBasedRLEnv, command_name: str = "ref_motion", ) -> torch.Tensor: """Terminate when reference motion frames exceed their end frames. Returns a boolean mask of shape [num_envs]. 
""" command: motion_tracking_command.RefMotionCommand = ( env.command_manager.get_term(command_name) ) result = command.motion_end_mask.clone().bool() return result @configclass class TerminationsCfg: pass def build_terminations_config( termination_config_dict: dict, ) -> TerminationsCfg: terminations_cfg = TerminationsCfg() for termination_name, termination_cfg in termination_config_dict.items(): termination_cfg = isaaclab_utils.resolve_holo_config(termination_cfg) func = _resolve_termination_func(termination_name) params = isaaclab_utils.resolve_holo_config( termination_cfg.get("params", {}) ) term_cfg = TerminationTermCfg( func=func, params=params, time_out=(termination_name == "time_out") or termination_cfg.get("time_out", False), ) setattr(terminations_cfg, termination_name, term_cfg) return terminations_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_terrain.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os import isaaclab.sim as sim_utils import isaaclab.terrains as terrain_gen import numpy as np import torch from isaaclab.terrains import TerrainImporter, TerrainImporterCfg from isaaclab.terrains.height_field import ( HfDiscreteObstaclesTerrainCfg, HfPyramidSlopedTerrainCfg, HfRandomUniformTerrainCfg, HfTerrainBaseCfg, ) from isaaclab.terrains.height_field.utils import height_field_to_mesh from isaaclab.utils import configclass from loguru import logger def _convert_range_like_params(params: dict) -> dict: """Convert list values for common range/size keys to tuples. This helps map Hydra YAML list values into IsaacLab config classes that expect tuples (e.g. ``*_range``). """ converted = {} for key, value in params.items(): if isinstance(value, list) and ( key.endswith("_range") or key in ("size", "difficulty_range") ): converted[key] = tuple(value) else: converted[key] = value return converted @height_field_to_mesh def plane_terrain(difficulty: float, cfg: HfTerrainBaseCfg) -> np.ndarray: """Generate a truly flat height-field patch. This is a lightweight alternative to using ``random_uniform`` with a zero noise range. The ``difficulty`` parameter is ignored. """ width_pixels = int(cfg.size[0] / cfg.horizontal_scale) length_pixels = int(cfg.size[1] / cfg.horizontal_scale) return np.zeros((width_pixels, length_pixels), dtype=np.int16) @configclass class HfPlaneTerrainCfg(HfTerrainBaseCfg): """Configuration for a flat height-field plane terrain.""" function = plane_terrain class RandomSpawnTerrainImporter(TerrainImporter): """Terrain importer that spawns robots randomly within each sub-terrain.""" _terrain_width: float | None = None _terrain_length: float | None = None _spawn_margin: float = 0.0 def _compute_env_origins_curriculum( self, num_envs: int, origins: torch.Tensor ) -> torch.Tensor: """Compute env origins with random (x, y) positions. This overrides the default curriculum-based distribution to add random offsets within each sub-terrain's bounds. 
Args: num_envs: Number of environments. origins: Terrain origins tensor of shape (num_rows, num_cols, 3). Returns: Environment origins tensor of shape (num_envs, 3). """ num_rows, num_cols = origins.shape[:2] # Get sub-terrain size from terrain generator config if self.cfg.terrain_generator is None: raise ValueError( "terrain_generator config is required for random spawning" ) sub_terrain_size = self.cfg.terrain_generator.size terrain_width, terrain_length = ( sub_terrain_size[0], sub_terrain_size[1], ) spawn_margin = float(getattr(self.cfg, "random_spawn_margin", 0.0)) spawn_margin = max(0.0, spawn_margin) # Clamp margin to avoid invalid sampling ranges. max_margin = 0.5 * min(float(terrain_width), float(terrain_length)) if spawn_margin >= max_margin: logger.warning( f"random_spawn_margin={spawn_margin} is too large " f"for sub-terrain size={sub_terrain_size}. " "Clamping to 0.0." ) spawn_margin = 0.0 # Maximum initial level possible for the terrains if self.cfg.max_init_terrain_level is None: max_init_level = num_rows - 1 else: max_init_level = min(self.cfg.max_init_terrain_level, num_rows - 1) # Store maximum terrain level possible self.max_terrain_level = num_rows # Use default curriculum-based assignment self.terrain_levels = torch.randint( 0, max_init_level + 1, (num_envs,), device=self.device ) self.terrain_types = torch.div( torch.arange(num_envs, device=self.device), (num_envs / num_cols), rounding_mode="floor", ).to(torch.long) # Create environment origins tensor starting from terrain origins env_origins = torch.zeros(num_envs, 3, device=self.device) env_origins[:] = origins[self.terrain_levels, self.terrain_types] # Add random (x, y) offsets within each sub-terrain's bounds # Offset range: [-size/2 + margin, size/2 - margin] for both x and y x_min = -terrain_width / 2 + spawn_margin x_max = terrain_width / 2 - spawn_margin y_min = -terrain_length / 2 + spawn_margin y_max = terrain_length / 2 - spawn_margin x_offsets = torch.empty(num_envs, 
device=self.device).uniform_( x_min, x_max ) y_offsets = torch.empty(num_envs, device=self.device).uniform_( y_min, y_max ) env_origins[:, 0] += x_offsets env_origins[:, 1] += y_offsets # Store terrain size for use in update_env_origins self._terrain_width = terrain_width self._terrain_length = terrain_length self._spawn_margin = spawn_margin return env_origins def update_env_origins( self, env_ids: torch.Tensor, move_up: torch.Tensor, move_down: torch.Tensor, ): """Update env origins when terrain levels change.""" # Check if grid-like spawning if self.terrain_origins is None: return # Update terrain level for the envs self.terrain_levels[env_ids] += 1 * move_up - 1 * move_down # Robots that solve the last level are sent to a random one # The minimum level is zero self.terrain_levels[env_ids] = torch.where( self.terrain_levels[env_ids] >= self.max_terrain_level, torch.randint_like( self.terrain_levels[env_ids], self.max_terrain_level ), torch.clip(self.terrain_levels[env_ids], 0), ) # Update the env origins with terrain origins self.env_origins[env_ids] = self.terrain_origins[ self.terrain_levels[env_ids], self.terrain_types[env_ids] ] # Add random (x, y) offsets within each sub-terrain's bounds if self._terrain_width is None or self._terrain_length is None: return num_updated = len(env_ids) x_min = -self._terrain_width / 2 + self._spawn_margin x_max = self._terrain_width / 2 - self._spawn_margin y_min = -self._terrain_length / 2 + self._spawn_margin y_max = self._terrain_length / 2 - self._spawn_margin x_offsets = torch.empty(num_updated, device=self.device).uniform_( x_min, x_max ) y_offsets = torch.empty(num_updated, device=self.device).uniform_( y_min, y_max ) self.env_origins[env_ids, 0] += x_offsets self.env_origins[env_ids, 1] += y_offsets def build_terrain_config( config: dict, scene_env_spacing: float = None ) -> TerrainImporterCfg: """Build terrain configuration. 
Preferred usage in Holomotion is via the IsaacLab terrain generator API with height-field sub-terrains fully specified from Hydra configs. For backward compatibility only, two legacy modes are still supported: * ``terrain_type=\"plane\"``: simple infinite plane using Isaac Sim's grid. * ``terrain_type=\"usd\"``: load terrain from a local USD file. All paths are offline by construction. Visual materials must use local data: * ``visual_material.type=\"color\"`` maps to :class:`PreviewSurfaceCfg` with ``diffuse_color``, ``metallic`` and ``roughness``. * ``visual_material.type=\"mdl\"`` is accepted only for local MDL files and never uses NVIDIA Nucleus. When paths are invalid, a neutral color material is used instead. Args: config: Terrain configuration dictionary with fields: * ``terrain_type``: ``\"generator\"`` (preferred), ``\"plane\"`` or ``\"usd\"`` (legacy). * ``generator`` (required when ``terrain_type=\"generator\"``): high-level :class:`TerrainGeneratorCfg` parameters such as ``num_rows``, ``num_cols``, ``size``, ``border_width``, ``horizontal_scale``, ``vertical_scale``, ``slope_threshold``, ``difficulty_range``, ``color_scheme``. * ``height_field`` (required when ``terrain_type=\"generator\"``): height-field sub-terrain configuration with: - ``type``: ``\"plane\"``, ``\"random_uniform\"``, ``\"discrete_obstacles\"`` or ``\"pyramid_sloped\"``. - Remaining keys are forwarded to the corresponding :class:`HfRandomUniformTerrainCfg` or :class:`HfDiscreteObstaclesTerrainCfg`. * ``random_spawn`` (optional): if True, spawns robots at random (x, y) positions within each sub-terrain's bounds. * ``random_spawn_margin`` (optional): if set, keeps random spawn points at least this many meters away from sub-terrain edges (helps avoid spawning near the outer border where robots may fall off). * ``visual_material`` (optional): offline visual material config. * ``static_friction``, ``dynamic_friction``, ``restitution``, etc. 
scene_env_spacing: Environment spacing from scene config (used only when ``terrain_type=\"plane\"`` is selected). Returns: TerrainImporterCfg configured according to the input parameters """ prim_path = config.get("prim_path", "/World/ground") static_friction = config.get("static_friction", 1.0) dynamic_friction = config.get("dynamic_friction", 1.0) restitution = config.get("restitution", 0.0) friction_combine_mode = config.get("friction_combine_mode", "multiply") restitution_combine_mode = config.get( "restitution_combine_mode", "multiply" ) terrain_type = config.get("terrain_type", "generator") if terrain_type == "usd": usd_path = config.get("usd_path") if usd_path is None: raise ValueError( "'usd_path' must be specified for terrain_type 'usd'" ) terrain_cfg = TerrainImporterCfg( prim_path=prim_path, terrain_type="usd", usd_path=usd_path, collision_group=-1, physics_material=sim_utils.RigidBodyMaterialCfg( friction_combine_mode=friction_combine_mode, restitution_combine_mode=restitution_combine_mode, static_friction=static_friction, dynamic_friction=dynamic_friction, restitution=restitution, ), debug_vis=config.get("debug_vis", False), ) return terrain_cfg if terrain_type == "plane": env_spacing = ( scene_env_spacing if scene_env_spacing is not None else 2.5 ) terrain_cfg = TerrainImporterCfg( prim_path=prim_path, terrain_type="plane", collision_group=-1, env_spacing=env_spacing, physics_material=sim_utils.RigidBodyMaterialCfg( friction_combine_mode=friction_combine_mode, restitution_combine_mode=restitution_combine_mode, static_friction=static_friction, dynamic_friction=dynamic_friction, restitution=restitution, ), debug_vis=config.get("debug_vis", False), ) return terrain_cfg if terrain_type != "generator": raise ValueError( f"Unsupported terrain_type '{terrain_type}'. " "Expected 'generator', 'plane', or 'usd'." 
) generator_cfg_dict = config.get("generator") if generator_cfg_dict is None: raise ValueError( "When 'terrain_type' is 'generator', a 'generator' dict must be " "provided in terrain config." ) # Optional new path: multiple sub-terrains defined under # generator.sub_terrains. sub_terrains_cfg_dict = generator_cfg_dict.get("sub_terrains") sub_terrains_cfg = None if sub_terrains_cfg_dict is not None: if not isinstance(sub_terrains_cfg_dict, dict): raise ValueError( "Expected 'generator.sub_terrains' to be a mapping from names " "to sub-terrain configs." ) sub_terrains_cfg = {} for sub_name, sub_cfg_dict in sub_terrains_cfg_dict.items(): if not isinstance(sub_cfg_dict, dict): raise ValueError( f"Sub-terrain '{sub_name}' must be a dictionary with at " "least a 'type' field." ) sub_type = sub_cfg_dict.get("type", "random_uniform") sub_proportion = sub_cfg_dict.get("proportion", 1.0) sub_params_raw = { key: value for key, value in sub_cfg_dict.items() if key not in ("type", "proportion") } sub_params = _convert_range_like_params(sub_params_raw) if sub_type == "random_uniform": hf_cfg = HfRandomUniformTerrainCfg( proportion=sub_proportion, **sub_params ) elif sub_type == "plane": hf_cfg = HfPlaneTerrainCfg( proportion=sub_proportion, **sub_params ) elif sub_type == "discrete_obstacles": hf_cfg = HfDiscreteObstaclesTerrainCfg( proportion=sub_proportion, **sub_params ) elif sub_type == "pyramid_sloped": hf_cfg = HfPyramidSlopedTerrainCfg( proportion=sub_proportion, **sub_params ) else: raise ValueError( f"Unknown sub_terrains['{sub_name}'].type '{sub_type}'. " "Expected 'plane', 'random_uniform', 'discrete_obstacles'," " or 'pyramid_sloped'." ) sub_terrains_cfg[sub_name] = hf_cfg # Deprecated path: single height_field block at top-level. 
if sub_terrains_cfg is None: height_field_cfg_dict = config.get("height_field") if height_field_cfg_dict is None: raise ValueError( "When 'terrain_type' is 'generator', either " "'generator.sub_terrains' or a 'height_field' dict must be " "provided in terrain config." ) logger.warning( "Terrain config is using deprecated 'height_field' key. " "Please migrate to 'generator.sub_terrains' for multi-sub-terrain " "support." ) hf_type = height_field_cfg_dict.get("type", "random_uniform") hf_params_raw = { key: value for key, value in height_field_cfg_dict.items() if key != "type" } hf_params = _convert_range_like_params(hf_params_raw) if hf_type == "random_uniform": height_field_cfg = HfRandomUniformTerrainCfg(**hf_params) elif hf_type == "discrete_obstacles": height_field_cfg = HfDiscreteObstaclesTerrainCfg(**hf_params) else: raise ValueError( f"Unknown height_field.type '{hf_type}'. " "Expected 'random_uniform' or 'discrete_obstacles'." ) sub_terrains_cfg = {"main": height_field_cfg} # Build TerrainGeneratorCfg from Hydra config. generator_params = _convert_range_like_params( { key: value for key, value in generator_cfg_dict.items() if key != "sub_terrains" } ) terrain_generator = terrain_gen.TerrainGeneratorCfg( **{ key: value for key, value in generator_params.items() if key in ( "size", "border_width", "border_height", "num_rows", "num_cols", "horizontal_scale", "vertical_scale", "slope_threshold", "difficulty_range", "color_scheme", "curriculum", "seed", "use_cache", "cache_dir", ) }, sub_terrains=sub_terrains_cfg, ) # Configure visual material for offline use visual_material = None if "visual_material" in config: visual_material_dict = config["visual_material"] material_type = visual_material_dict.get("type", "color") if material_type == "color": # Use PreviewSurfaceCfg with diffuse_color (no internet needed) diffuse_color_raw = visual_material_dict.get( "diffuse_color", (0.8, 0.8, 0.8) ) # Convert list to tuple if needed (YAML loads lists). 
# Ensure it's a tuple of floats as required by PreviewSurfaceCfg if isinstance(diffuse_color_raw, list): diffuse_color = tuple(float(x) for x in diffuse_color_raw) elif isinstance(diffuse_color_raw, tuple): diffuse_color = tuple(float(x) for x in diffuse_color_raw) else: diffuse_color = diffuse_color_raw metallic = float(visual_material_dict.get("metallic", 0.0)) roughness = float(visual_material_dict.get("roughness", 0.5)) visual_material = sim_utils.PreviewSurfaceCfg( diffuse_color=diffuse_color, metallic=metallic, roughness=roughness, ) elif material_type == "none": # No visual material, rely on vertex colors (e.g. from height map) visual_material = None elif material_type == "mdl": # Use MdlFileCfg with local mdl_path mdl_path = visual_material_dict.get("mdl_path") if mdl_path is None: logger.warning( "visual_material type is 'mdl' but no mdl_path specified. " "Falling back to color material to avoid internet " "requirements." ) visual_material = sim_utils.PreviewSurfaceCfg( diffuse_color=(0.5, 0.5, 0.5) ) else: # Resolve relative paths if not os.path.isabs(mdl_path): if os.path.exists(mdl_path): resolved_mdl_path = os.path.abspath(mdl_path) else: workspace_root = os.path.abspath( os.path.join( os.path.dirname(__file__), "../../../.." ) ) resolved_mdl_path = os.path.join( workspace_root, mdl_path ) else: resolved_mdl_path = mdl_path if os.path.exists(resolved_mdl_path): visual_material = sim_utils.MdlFileCfg( mdl_path=resolved_mdl_path ) else: logger.warning( f"MDL file not found at {resolved_mdl_path}. " "Falling back to color material to avoid internet " "requirements." ) visual_material = sim_utils.PreviewSurfaceCfg( diffuse_color=(0.5, 0.5, 0.5) ) else: logger.warning( f"Unknown visual_material type: {material_type}. " "Using default color material." 
) visual_material = sim_utils.PreviewSurfaceCfg( diffuse_color=(0.5, 0.5, 0.5) ) # Configure random spawning within sub-terrains if requested random_spawn = config.get("random_spawn", False) terrain_importer_class = ( RandomSpawnTerrainImporter if random_spawn else TerrainImporter ) terrain_cfg = TerrainImporterCfg( prim_path=prim_path, terrain_type="generator", terrain_generator=terrain_generator, max_init_terrain_level=config.get( "max_init_terrain_level", terrain_generator.num_rows - 1, ), collision_group=-1, visual_material=visual_material, physics_material=sim_utils.RigidBodyMaterialCfg( friction_combine_mode=friction_combine_mode, restitution_combine_mode=restitution_combine_mode, static_friction=static_friction, dynamic_friction=dynamic_friction, restitution=restitution, ), debug_vis=config.get("debug_vis", False), class_type=terrain_importer_class, ) if random_spawn: terrain_cfg.random_spawn_margin = float( config.get("random_spawn_margin", 0.0) ) return terrain_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_utils.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import torch from isaaclab.assets import Articulation from isaaclab.envs import ManagerBasedRLEnv from isaaclab.managers import RewardTermCfg, SceneEntityCfg from isaaclab.sensors import ContactSensor from isaaclab.utils import configclass import isaaclab.utils.math as isaaclab_math from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import ( RefMotionCommand, ) import isaaclab.envs.mdp as isaaclab_mdp from hydra.utils import instantiate as hydra_instantiate from omegaconf import DictConfig, ListConfig, OmegaConf from loguru import logger def _get_dof_indices( robot: Articulation, key_dofs: list[str] | None, ) -> list[int] | None: if key_dofs is None: return list(range(len(robot.joint_names))) dof_indices = [] for name in key_dofs: if name not in robot.joint_names: raise ValueError( f"DOF '{name}' not found in robot.joint_names: {robot.joint_names}" ) dof_indices.append(robot.joint_names.index(name)) return dof_indices def _get_body_indices( robot: Articulation, keybody_names: list[str] | None, ) -> list[int] | None: """Convert body names to indices. Args: robot: Robot articulation asset keybody_names: List of body names. If None, returns None. 
Returns: List of body indices corresponding to the given names, or None if keybody_names is None """ if keybody_names is None: return list(range(len(robot.body_names))) body_indices = [] for name in keybody_names: if name not in robot.body_names: raise ValueError( f"Body '{name}' not found in robot.body_names: {robot.body_names}" ) body_indices.append(robot.body_names.index(name)) return body_indices def resolve_holo_config(value): def _sanitize_config_object(obj): for attr, attr_value in vars(obj).items(): sanitized_value = resolve_holo_config(attr_value) setattr(obj, attr, sanitized_value) return obj if isinstance(value, (DictConfig, ListConfig)): value = OmegaConf.to_container(value, resolve=True) if isinstance(value, dict): if "_target_" in value: instantiated = hydra_instantiate(value) if hasattr(instantiated, "__dict__") and not callable( instantiated ): return _sanitize_config_object(instantiated) return instantiated return {key: resolve_holo_config(item) for key, item in value.items()} if isinstance(value, list): return [resolve_holo_config(item) for item in value] if hasattr(value, "__dict__") and not callable(value): return _sanitize_config_object(value) return value ================================================ FILE: holomotion/src/env/isaaclab_components/isaaclab_velocity_tracking_command.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. 
See the License for the specific language governing # permissions and limitations under the License. from __future__ import annotations from dataclasses import MISSING from isaaclab.managers import CommandTermCfg from isaaclab.utils import configclass import torch from isaaclab.assets import Articulation from isaaclab.managers import CommandTerm from isaaclab.markers import VisualizationMarkers from isaaclab.envs import ManagerBasedEnv import isaaclab.utils.math as math_utils from isaaclab.markers import VisualizationMarkersCfg from isaaclab.markers.config import ( BLUE_ARROW_X_MARKER_CFG, GREEN_ARROW_X_MARKER_CFG, ) from typing import Sequence class HoloMotionUniformVelocityCommand(CommandTerm): r"""Command generator that generates a velocity command in SE(2) from uniform distribution. The command comprises of a linear velocity in x and y direction and an angular velocity around the z-axis. It is given in the robot's base frame. If the :attr:`cfg.heading_command` flag is set to True, the angular velocity is computed from the heading error similar to doing a proportional control on the heading error. The target heading is sampled uniformly from the provided range. Otherwise, the angular velocity is sampled uniformly from the provided range. Mathematically, the angular velocity is computed as follows from the heading command: .. math:: \omega_z = \frac{1}{2} \text{wrap_to_pi}(\theta_{\text{target}} - \theta_{\text{current}}) """ cfg: HoloMotionUniformVelocityCommandCfg """The configuration of the command generator.""" def __init__( self, cfg: HoloMotionUniformVelocityCommandCfg, env: ManagerBasedEnv ): """Initialize the command generator. Args: cfg: The configuration of the command generator. env: The environment. Raises: ValueError: If the heading command is active but the heading range is not provided. 
""" # initialize the base class super().__init__(cfg, env) # check configuration if self.cfg.heading_command and self.cfg.ranges.heading is None: raise ValueError( "The velocity command has heading commands active (heading_command=True) but the `ranges.heading`" " parameter is set to None." ) if self.cfg.rel_yaw_envs > 0.0: yaw_min, yaw_max = self.cfg.ranges.ang_vel_z # obtain the robot asset # -- robot self.robot: Articulation = env.scene[cfg.asset_name] # crete buffers to store the command # -- command: x vel, y vel, yaw vel, heading self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device) self.heading_target = torch.zeros(self.num_envs, device=self.device) self.is_heading_env = torch.zeros( self.num_envs, dtype=torch.bool, device=self.device ) self.is_standing_env = torch.zeros_like(self.is_heading_env) self.is_yaw_env = torch.zeros_like(self.is_heading_env) # -- metrics self.metrics["error_vel_xy"] = torch.zeros( self.num_envs, device=self.device ) self.metrics["error_vel_yaw"] = torch.zeros( self.num_envs, device=self.device ) def __str__(self) -> str: """Return a string representation of the command generator.""" msg = "HoloMotionUniformVelocityCommand:\n" msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n" msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n" msg += f"\tHeading command: {self.cfg.heading_command}\n" if self.cfg.heading_command: msg += f"\tHeading probability: {self.cfg.rel_heading_envs}\n" msg += f"\tStanding probability: {self.cfg.rel_standing_envs}\n" msg += f"\tYaw-only probability: {self.cfg.rel_yaw_envs}" return msg """ Properties """ @property def command(self) -> torch.Tensor: """The desired base velocity command in the base frame. Shape is (num_envs, 3).""" return self.vel_command_b """ Implementation specific functions. 
""" def _update_metrics(self): # time for which the command was executed max_command_time = self.cfg.resampling_time_range[1] max_command_step = max_command_time / self._env.step_dt # logs data self.metrics["error_vel_xy"] += ( torch.norm( self.vel_command_b[:, :2] - self.robot.data.root_lin_vel_b[:, :2], dim=-1, ) / max_command_step ) self.metrics["error_vel_yaw"] += ( torch.abs( self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2] ) / max_command_step ) def _resample_command(self, env_ids: Sequence[int]): # sample velocity commands r = torch.empty(len(env_ids), device=self.device) # -- linear velocity - x direction self.vel_command_b[env_ids, 0] = r.uniform_(*self.cfg.ranges.lin_vel_x) # -- linear velocity - y direction self.vel_command_b[env_ids, 1] = r.uniform_(*self.cfg.ranges.lin_vel_y) # -- ang vel yaw - rotation around z self.vel_command_b[env_ids, 2] = r.uniform_(*self.cfg.ranges.ang_vel_z) # heading target if self.cfg.heading_command: self.heading_target[env_ids] = r.uniform_(*self.cfg.ranges.heading) # update heading envs self.is_heading_env[env_ids] = ( r.uniform_(0.0, 1.0) <= self.cfg.rel_heading_envs ) self.is_yaw_env[env_ids] = ( r.uniform_(0.0, 1.0) <= self.cfg.rel_yaw_envs ) if self.cfg.heading_command: # yaw-only envs should follow directly sampled yaw commands (not heading control) self.is_heading_env[env_ids] &= ~self.is_yaw_env[env_ids] # update standing envs self.is_standing_env[env_ids] = ( r.uniform_(0.0, 1.0) <= self.cfg.rel_standing_envs ) def _update_command(self): """Post-processes the velocity command. This function sets velocity command to zero for standing environments and computes angular velocity from heading direction if the heading_command flag is set. 
""" # Compute angular velocity from heading direction if self.cfg.heading_command: # resolve indices of heading envs env_ids = self.is_heading_env.nonzero(as_tuple=False).flatten() # compute angular velocity heading_error = math_utils.wrap_to_pi( self.heading_target[env_ids] - self.robot.data.heading_w[env_ids] ) self.vel_command_b[env_ids, 2] = torch.clip( self.cfg.heading_control_stiffness * heading_error, min=self.cfg.ranges.ang_vel_z[0], max=self.cfg.ranges.ang_vel_z[1], ) yaw_env_ids = self.is_yaw_env.nonzero(as_tuple=False).flatten() self.vel_command_b[yaw_env_ids, :2] = 0.0 # Enforce standing (i.e., zero velocity command) for standing envs # TODO: check if conversion is needed standing_env_ids = self.is_standing_env.nonzero( as_tuple=False ).flatten() self.vel_command_b[standing_env_ids, :] = 0.0 def _set_debug_vis_impl(self, debug_vis: bool): # set visibility of markers # note: parent only deals with callbacks. not their visibility if debug_vis: # create markers if necessary for the first time if not hasattr(self, "goal_vel_visualizer"): # -- goal self.goal_vel_visualizer = VisualizationMarkers( self.cfg.goal_vel_visualizer_cfg ) # -- current self.current_vel_visualizer = VisualizationMarkers( self.cfg.current_vel_visualizer_cfg ) # set their visibility to true self.goal_vel_visualizer.set_visibility(True) self.current_vel_visualizer.set_visibility(True) else: if hasattr(self, "goal_vel_visualizer"): self.goal_vel_visualizer.set_visibility(False) self.current_vel_visualizer.set_visibility(False) def _debug_vis_callback(self, event): # check if robot is initialized # note: this is needed in-case the robot is de-initialized. 
we can't access the data if not self.robot.is_initialized: return # get marker location # -- base state base_pos_w = self.robot.data.root_pos_w.clone() base_pos_w[:, 2] += 0.5 # -- resolve the scales and quaternions vel_des_arrow_scale, vel_des_arrow_quat = ( self._resolve_xy_velocity_to_arrow(self.command[:, :2]) ) vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow( self.robot.data.root_lin_vel_b[:, :2] ) # display markers self.goal_vel_visualizer.visualize( base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale ) self.current_vel_visualizer.visualize( base_pos_w, vel_arrow_quat, vel_arrow_scale ) """ Internal helpers. """ def _resolve_xy_velocity_to_arrow( self, xy_velocity: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """Converts the XY base velocity command to arrow direction rotation.""" # obtain default scale of the marker default_scale = self.goal_vel_visualizer.cfg.markers["arrow"].scale # arrow-scale arrow_scale = torch.tensor(default_scale, device=self.device).repeat( xy_velocity.shape[0], 1 ) arrow_scale[:, 0] *= torch.linalg.norm(xy_velocity, dim=1) * 3.0 # arrow-direction heading_angle = torch.atan2(xy_velocity[:, 1], xy_velocity[:, 0]) zeros = torch.zeros_like(heading_angle) arrow_quat = math_utils.quat_from_euler_xyz( zeros, zeros, heading_angle ) # convert everything back from base to world frame base_quat_w = self.robot.data.root_quat_w arrow_quat = math_utils.quat_mul(base_quat_w, arrow_quat) return arrow_scale, arrow_quat @configclass class HoloMotionUniformVelocityCommandCfg(CommandTermCfg): """Configuration for the uniform velocity command generator.""" class_type: type = HoloMotionUniformVelocityCommand asset_name: str = MISSING """Name of the asset in the environment for which the commands are generated.""" heading_command: bool = False """Whether to use heading command or angular velocity command. Defaults to False. 
If True, the angular velocity command is computed from the heading error, where the target heading is sampled uniformly from provided range. Otherwise, the angular velocity command is sampled uniformly from provided range. """ heading_control_stiffness: float = 1.0 """Scale factor to convert the heading error to angular velocity command. Defaults to 1.0.""" rel_standing_envs: float = 0.0 """The sampled probability of environments that should be standing still. Defaults to 0.0.""" rel_yaw_envs: float = 0.0 """The sampled probability of environments that should receive yaw-only commands. Defaults to 0.0. For yaw-only environments, the command is post-processed to: - enforce vx=vy=0 This is sampled independently from :attr:`rel_standing_envs`. If an environment is both yaw-only and standing, standing still overrides to zero command. """ rel_heading_envs: float = 1.0 """The sampled probability of environments where the robots follow the heading-based angular velocity command (the others follow the sampled angular velocity command). Defaults to 1.0. This parameter is only used if :attr:`heading_command` is True. """ @configclass class Ranges: """Uniform distribution ranges for the velocity commands.""" lin_vel_x: tuple[float, float] = MISSING """Range for the linear-x velocity command (in m/s).""" lin_vel_y: tuple[float, float] = MISSING """Range for the linear-y velocity command (in m/s).""" ang_vel_z: tuple[float, float] = MISSING """Range for the angular-z velocity command (in rad/s).""" heading: tuple[float, float] | None = None """Range for the heading command (in rad). Defaults to None. This parameter is only used if :attr:`~HoloMotionUniformVelocityCommandCfg.heading_command` is True. """ ranges: Ranges = MISSING """Distribution ranges for the velocity commands.""" goal_vel_visualizer_cfg: VisualizationMarkersCfg = ( GREEN_ARROW_X_MARKER_CFG.replace( prim_path="/Visuals/Command/velocity_goal" ) ) """The configuration for the goal velocity visualization marker. 
Defaults to GREEN_ARROW_X_MARKER_CFG.""" current_vel_visualizer_cfg: VisualizationMarkersCfg = ( BLUE_ARROW_X_MARKER_CFG.replace( prim_path="/Visuals/Command/velocity_current" ) ) """The configuration for the current velocity visualization marker. Defaults to BLUE_ARROW_X_MARKER_CFG.""" # Set the scale of the visualization markers to (0.5, 0.5, 0.5) goal_vel_visualizer_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5) current_vel_visualizer_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5) @configclass class VelTrack_CommandsCfg: pass def _convert_ranges_dict_to_object( ranges_dict: dict, ) -> HoloMotionUniformVelocityCommandCfg.Ranges: """Convert a dict of ranges to a proper Ranges object with tuples.""" ranges_kwargs = {} for key, value in ranges_dict.items(): if value is None: ranges_kwargs[key] = None elif isinstance(value, (list, tuple)): ranges_kwargs[key] = tuple(value) else: ranges_kwargs[key] = value return HoloMotionUniformVelocityCommandCfg.Ranges(**ranges_kwargs) def build_velocity_commands_config( command_config_dict: dict, ) -> VelTrack_CommandsCfg: """Build a CommandsCfg that supports velocity commands via IsaacLab isaaclab_mdp. Expected format: { "base_velocity": { "type": "VelocityCommandCfg" | "HoloMotionUniformVelocityCommandCfg" | "UniformLevelVelocityCommandCfg", "params": { ... } # args compatible with mdp command cfgs } } For ranges and limit_ranges, pass them as dicts with keys like lin_vel_x, lin_vel_y, ang_vel_z, heading. 
""" commands_cfg = VelTrack_CommandsCfg() for name, cfg in command_config_dict.items(): command_type = cfg.get("type", "VelocityCommandCfg") params = cfg.get("params", {}).copy() if "ranges" in params and isinstance(params["ranges"], dict): params["ranges"] = _convert_ranges_dict_to_object(params["ranges"]) if "limit_ranges" in params and isinstance( params["limit_ranges"], dict ): params["limit_ranges"] = _convert_ranges_dict_to_object( params["limit_ranges"] ) if command_type == "HoloMotionUniformVelocityCommandCfg": term_cfg = HoloMotionUniformVelocityCommandCfg(**params) else: raise ValueError(f"Unknown velocity command type: {command_type}") setattr(commands_cfg, name, term_cfg) return commands_cfg ================================================ FILE: holomotion/src/env/isaaclab_components/unitree_actuators.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
# This file is modified from the unitree_rl_lab repository: # https://github.com/unitreerobotics/unitree_rl_lab from __future__ import annotations import json import os from pathlib import Path import torch from dataclasses import MISSING from typing import Sequence from isaaclab.actuators import DelayedPDActuator, DelayedPDActuatorCfg from isaaclab.utils import configclass from isaaclab.utils.types import ArticulationActions from loguru import logger class UnitreeActuator(DelayedPDActuator): """Unitree actuator class that implements a torque-speed curve for the actuators. The torque-speed curve is defined as follows: Torque Limit, N·m ^ Y2──────────| |──────────────Y1 | │\ | │ \ | │ \ | | \ ------------+--------------|------> velocity: rad/s X1 X2 - Y1: Peak Torque Test (Torque and Speed in the Same Direction) - Y2: Peak Torque Test (Torque and Speed in the Opposite Direction) - X1: Maximum Speed at Full Torque (T-N Curve Knee Point) - X2: No-Load Speed Test - Fs: Static friction coefficient - Fd: Dynamic friction coefficient - Va: Velocity at which the friction is fully activated """ cfg: UnitreeActuatorCfg armature: torch.Tensor """The armature of the actuator joints. Shape is (num_envs, num_joints). 
armature = J2 + J1 * i2 ^ 2 + Jr * (i1 * i2) ^ 2 """ def __init__(self, cfg: UnitreeActuatorCfg, *args, **kwargs): super().__init__(cfg, *args, **kwargs) self._joint_vel = torch.zeros_like(self.computed_effort) self._effort_y1 = self._parse_joint_parameter(cfg.Y1, 1e9) self._effort_y2 = self._parse_joint_parameter(cfg.Y2, cfg.Y1) self._velocity_x1 = self._parse_joint_parameter(cfg.X1, 1e9) self._velocity_x2 = self._parse_joint_parameter(cfg.X2, 1e9) self._friction_static = self._parse_joint_parameter(cfg.Fs, 0.0) self._friction_dynamic = self._parse_joint_parameter(cfg.Fd, 0.0) self._activation_vel = self._parse_joint_parameter(cfg.Va, 0.01) def compute( self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor, ) -> ArticulationActions: # save current joint vel self._joint_vel[:] = joint_vel # calculate the desired joint torques control_action = super().compute(control_action, joint_pos, joint_vel) # apply friction model on the torque self.applied_effort -= ( self._friction_static * torch.tanh(joint_vel / self._activation_vel) + self._friction_dynamic * joint_vel ) control_action.joint_positions = None control_action.joint_velocities = None control_action.joint_efforts = self.applied_effort return control_action def _clip_effort(self, effort: torch.Tensor) -> torch.Tensor: # check if the effort is the same direction as the joint velocity same_direction = (self._joint_vel * effort) > 0 max_effort = torch.where( same_direction, self._effort_y1, self._effort_y2 ) # check if the joint velocity is less than the max speed at full torque max_effort = torch.where( self._joint_vel.abs() < self._velocity_x1, max_effort, self._compute_effort_limit(max_effort), ) return torch.clip(effort, -max_effort, max_effort) def _compute_effort_limit(self, max_effort): k = -max_effort / (self._velocity_x2 - self._velocity_x1) limit = k * (self._joint_vel.abs() - self._velocity_x1) + max_effort return limit.clip(min=0.0) class 
UnitreeErfiActuator(UnitreeActuator): """Unitree actuator with per-env ERFI-50 torque perturbations. On environment reset, each env is assigned either step-wise random force injection (RFI) or episode-level random actuation offset (RAO). During rollout, only the selected mode is applied for that env. """ cfg: UnitreeErfiActuatorCfg def __init__(self, cfg: UnitreeErfiActuatorCfg, *args, **kwargs): super().__init__(cfg, *args, **kwargs) self._ema_filter_alpha = float(cfg.ema_filter_alpha) if not 0.0 <= self._ema_filter_alpha <= 1.0: raise ValueError( "ema_filter_alpha must be within [0, 1], " f"got {self._ema_filter_alpha}." ) self._ema_filter_debug_dump_path = ( cfg.ema_filter_debug_dump_path or os.environ.get("HOLOMOTION_EMA_FILTER_DEBUG_DUMP_PATH") ) self._ema_filter_debug_stop_after_dump = self._parse_bool_env( "HOLOMOTION_EMA_FILTER_DEBUG_STOP_AFTER_DUMP", cfg.ema_filter_debug_stop_after_dump, ) self._ema_filter_debug_dumped = False self._ema_filter_state = torch.zeros_like(self.computed_effort) self._ema_filter_initialized = torch.zeros( self._num_envs, dtype=torch.bool, device=self._device ) self._mode_is_rfi = torch.zeros( self._num_envs, dtype=torch.bool, device=self._device ) self._rfi_lim_scale = torch.ones_like(self.computed_effort) self._rao_scale = torch.zeros_like(self.computed_effort) def reset(self, env_ids: Sequence[int] | slice | None): super().reset(env_ids) env_ids_tensor = self._env_ids_to_tensor(env_ids) if env_ids_tensor.numel() == 0: return if self.cfg.ema_filter_enabled: self._ema_filter_state[env_ids_tensor] = 0.0 self._ema_filter_initialized[env_ids_tensor] = False if not self.cfg.erfi_enabled: self._mode_is_rfi[env_ids_tensor] = False self._rfi_lim_scale[env_ids_tensor] = 1.0 self._rao_scale[env_ids_tensor] = 0.0 return sampled_is_rfi = ( torch.rand(env_ids_tensor.numel(), device=self._device) < self.cfg.rfi_probability ) self._mode_is_rfi[env_ids_tensor] = sampled_is_rfi if self.cfg.randomize_rfi_lim: self._rfi_lim_scale[env_ids_tensor] 
= self._sample_uniform( self.cfg.rfi_lim_range[0], self.cfg.rfi_lim_range[1], (env_ids_tensor.numel(), self.num_joints), ) else: self._rfi_lim_scale[env_ids_tensor] = 1.0 self._rao_scale[env_ids_tensor] = self._sample_uniform( -self.cfg.rao_lim, self.cfg.rao_lim, (env_ids_tensor.numel(), self.num_joints), ) rfi_env_ids = env_ids_tensor[sampled_is_rfi] if rfi_env_ids.numel() > 0: self._rao_scale[rfi_env_ids] = 0.0 def compute( self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor, ) -> ArticulationActions: control_action = self._filter_joint_position_action(control_action) if not self.cfg.erfi_enabled: return super().compute(control_action, joint_pos, joint_vel) if control_action.joint_efforts is None: base_joint_efforts = torch.zeros_like(joint_pos) else: base_joint_efforts = control_action.joint_efforts.clone() effort_limit = self.effort_limit.to(base_joint_efforts) rfi_noise = self._sample_uniform(-1.0, 1.0, base_joint_efforts.shape) rfi_term = ( rfi_noise * self.cfg.rfi_lim * self._rfi_lim_scale * effort_limit ) rao_term = self._rao_scale * effort_limit mode_is_rfi = self._mode_is_rfi.unsqueeze(-1) control_action_with_erfi = ArticulationActions( joint_positions=control_action.joint_positions, joint_velocities=control_action.joint_velocities, joint_efforts=base_joint_efforts + torch.where(mode_is_rfi, rfi_term, rao_term), joint_indices=control_action.joint_indices, ) return super().compute(control_action_with_erfi, joint_pos, joint_vel) def _filter_joint_position_action( self, control_action: ArticulationActions ) -> ArticulationActions: if not self.cfg.ema_filter_enabled: self._maybe_dump_ema_filter_debug_skip("ema_filter_disabled") return control_action if control_action.joint_positions is None: self._maybe_dump_ema_filter_debug_skip("joint_positions_none") return control_action raw_joint_positions = control_action.joint_positions previous_filtered_joint_positions = self._ema_filter_state.clone() needs_init = 
~self._ema_filter_initialized filtered_joint_positions = raw_joint_positions.clone() if torch.any(~needs_init): filtered_joint_positions = torch.where( needs_init.unsqueeze(-1), raw_joint_positions, self._ema_filter_alpha * raw_joint_positions + (1.0 - self._ema_filter_alpha) * self._ema_filter_state, ) self._maybe_dump_ema_filter_debug_verification( raw_joint_positions=raw_joint_positions, filtered_joint_positions=filtered_joint_positions, previous_filtered_joint_positions=previous_filtered_joint_positions, needs_init=needs_init, ) self._ema_filter_state[:] = filtered_joint_positions self._ema_filter_initialized[:] = True return ArticulationActions( joint_positions=filtered_joint_positions, joint_velocities=control_action.joint_velocities, joint_efforts=control_action.joint_efforts, joint_indices=control_action.joint_indices, ) def _maybe_dump_ema_filter_debug_verification( self, raw_joint_positions: torch.Tensor, filtered_joint_positions: torch.Tensor, previous_filtered_joint_positions: torch.Tensor, needs_init: torch.Tensor, ) -> None: if ( self._ema_filter_debug_dumped or not self._ema_filter_debug_dump_path ): return rank = os.environ.get("RANK") if rank is not None and rank != "0": return initialized_env_ids = torch.nonzero( ~needs_init, as_tuple=False ).flatten() if initialized_env_ids.numel() == 0: return env_idx = int(initialized_env_ids[0].item()) raw = raw_joint_positions[env_idx].detach().cpu() prev = previous_filtered_joint_positions[env_idx].detach().cpu() actual = filtered_joint_positions[env_idx].detach().cpu() expected = ( self._ema_filter_alpha * raw + (1.0 - self._ema_filter_alpha) * prev ) matched = torch.allclose(actual, expected, atol=1.0e-6, rtol=1.0e-6) dump_path = Path(self._ema_filter_debug_dump_path) dump_path.parent.mkdir(parents=True, exist_ok=True) dump_path.write_text( json.dumps( { "alpha": self._ema_filter_alpha, "env_index": env_idx, "matched": bool(matched), "raw_joint_positions": raw.tolist(), "previous_filtered_joint_positions": 
prev.tolist(), "expected_filtered_joint_positions": expected.tolist(), "actual_filtered_joint_positions": actual.tolist(), "pid": os.getpid(), "rank": rank or "0", }, indent=2, ) ) self._ema_filter_debug_dumped = True logger.info("Wrote EMA verification dump to {}", dump_path) if self._ema_filter_debug_stop_after_dump: raise RuntimeError(f"EMA verification dump written to {dump_path}") def _maybe_dump_ema_filter_debug_skip(self, reason: str) -> None: self._maybe_dump_ema_filter_debug_payload( { "applied": False, "reason": reason, } ) def _maybe_dump_ema_filter_debug_payload(self, payload: dict) -> None: if ( self._ema_filter_debug_dumped or not self._ema_filter_debug_dump_path ): return rank = os.environ.get("RANK") if rank is not None and rank != "0": return dump_path = Path(self._ema_filter_debug_dump_path) dump_path.parent.mkdir(parents=True, exist_ok=True) dump_path.write_text( json.dumps( { **payload, "alpha": self._ema_filter_alpha, "pid": os.getpid(), "rank": rank or "0", }, indent=2, ) ) self._ema_filter_debug_dumped = True logger.info("Wrote EMA verification dump to {}", dump_path) if self._ema_filter_debug_stop_after_dump: raise RuntimeError(f"EMA verification dump written to {dump_path}") @staticmethod def _parse_bool_env(name: str, default: bool) -> bool: raw_value = os.environ.get(name) if raw_value is None: return bool(default) return raw_value.strip().lower() in {"1", "true", "yes", "on"} def _env_ids_to_tensor( self, env_ids: Sequence[int] | slice | None ) -> torch.Tensor: if env_ids is None or env_ids == slice(None): return torch.arange(self._num_envs, device=self._device) if isinstance(env_ids, torch.Tensor): return env_ids.to(device=self._device, dtype=torch.long).flatten() return torch.tensor(env_ids, device=self._device, dtype=torch.long) def _sample_uniform( self, low: float, high: float, shape: tuple[int, ...] 
) -> torch.Tensor: return torch.empty(shape, device=self._device).uniform_(low, high) @configclass class UnitreeActuatorCfg(DelayedPDActuatorCfg): """ Configuration for Unitree actuators. """ class_type: type = UnitreeActuator X1: float = 1e9 """Maximum Speed at Full Torque(T-N Curve Knee Point) Unit: rad/s""" X2: float = 1e9 """No-Load Speed Test Unit: rad/s""" Y1: float = MISSING """Peak Torque Test(Torque and Speed in the Same Direction) Unit: N*m""" Y2: float | None = None """Peak Torque Test(Torque and Speed in the Opposite Direction) Unit: N*m""" Fs: float = 0.0 """ Static friction coefficient """ Fd: float = 0.0 """ Dynamic friction coefficient """ Va: float = 0.01 """ Velocity at which the friction is fully activated """ @configclass class UnitreeErfiActuatorCfg(UnitreeActuatorCfg): """Configuration for Unitree actuators with ERFI-50 perturbations.""" class_type: type = UnitreeErfiActuator erfi_enabled: bool = False """Whether ERFI perturbations are enabled for this actuator.""" ema_filter_enabled: bool = False """Whether to apply EMA filtering to incoming joint-position actions.""" ema_filter_alpha: float = 1.0 """EMA mixing factor using filtered = alpha * raw + (1 - alpha) * prev.""" ema_filter_debug_dump_path: str | None = None """Optional JSON path for a one-shot EMA verification dump during runtime.""" ema_filter_debug_stop_after_dump: bool = False """Whether to stop execution after writing the EMA verification dump.""" rfi_probability: float = 0.5 """Probability of assigning RFI to an environment on reset.""" rfi_lim: float = 0.1 """Base RFI limit, expressed as a ratio of joint effort limits.""" randomize_rfi_lim: bool = True """Whether to randomize the per-episode RFI limit scale.""" rfi_lim_range: tuple[float, float] = (0.5, 1.5) """Multiplicative range for per-episode RFI scaling.""" rao_lim: float = 0.1 """RAO limit, expressed as a ratio of joint effort limits.""" @configclass class UnitreeActuatorCfg_M107_15(UnitreeActuatorCfg): X1 = 14.0 X2 = 
25.6 Y1 = 150.0 Y2 = 182.8 armature = 0.063259741 @configclass class UnitreeActuatorCfg_M107_24(UnitreeActuatorCfg): X1 = 8.8 X2 = 16 Y1 = 240 Y2 = 292.5 armature = 0.160478022 @configclass class UnitreeActuatorCfg_Go2HV(UnitreeActuatorCfg): X1 = 13.5 X2 = 30 Y1 = 20.2 Y2 = 23.4 @configclass class UnitreeActuatorCfg_N7520_14p3(UnitreeActuatorCfg): # Decimal point cannot be used as variable name, use `p` instead X1 = 22.63 X2 = 35.52 Y1 = 71 Y2 = 83.3 Fs = 1.6 Fd = 0.16 """ | rotor | 0.489e-4 kg·m² | gear_1 | 0.098e-4 kg·m² | ratio | 4.5 | gear_2 | 0.533e-4 kg·m² | ratio | 48/22+1 """ armature = 0.01017752 @configclass class UnitreeActuatorCfg_N7520_22p5(UnitreeActuatorCfg): # Decimal point cannot be used as variable name, use `p` instead X1 = 14.5 X2 = 22.7 Y1 = 111.0 Y2 = 131.0 Fs = 2.4 Fd = 0.24 """ | rotor | 0.489e-4 kg·m² | gear_1 | 0.109e-4 kg·m² | ratio | 4.5 | gear_2 | 0.738e-4 kg·m² | ratio | 5.0 """ armature = 0.025101925 @configclass class UnitreeActuatorCfg_N5010_16(UnitreeActuatorCfg): X1 = 27.0 X2 = 41.5 Y1 = 9.5 Y2 = 17.0 """ | rotor | 0.084e-4 kg·m² | gear_1 | 0.015e-4 kg·m² | ratio | 4 | gear_2 | 0.068e-4 kg·m² | ratio | 4 """ armature = 0.0021812 @configclass class UnitreeActuatorCfg_N5020_16(UnitreeActuatorCfg): X1 = 30.86 X2 = 40.13 Y1 = 24.8 Y2 = 31.9 Fs = 0.6 Fd = 0.06 """ | rotor | 0.139e-4 kg·m² | gear_1 | 0.017e-4 kg·m² | ratio | 46/18+1 | gear_2 | 0.169e-4 kg·m² | ratio | 56/16+1 """ armature = 0.003609725 @configclass class UnitreeActuatorCfg_W4010_25(UnitreeActuatorCfg): X1 = 15.3 X2 = 24.76 Y1 = 4.8 Y2 = 8.6 Fs = 0.6 Fd = 0.06 """ | rotor | 0.068e-4 kg·m² | gear_1 | | ratio | 5 | gear_2 | | ratio | 5 """ armature = 0.00425 ================================================ FILE: holomotion/src/env/motion_tracking.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import torch import time import os import yaml from collections import deque from functools import wraps from easydict import EasyDict import random import numpy as np from isaaclab.actuators import ImplicitActuatorCfg from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg from isaaclab.sim import PhysxCfg, SimulationCfg from isaaclab.utils import configclass from isaaclab.utils.io import dump_yaml from loguru import logger from omegaconf import OmegaConf from holomotion.src.env.isaaclab_components import ( ActionsCfg, VelTrack_CommandsCfg, MoTrack_CommandsCfg, EventsCfg, MotionTrackingSceneCfg, ObservationsCfg, RewardsCfg, TerminationsCfg, CurriculumCfg, build_actions_config, build_motion_tracking_commands_config, build_velocity_commands_config, build_domain_rand_config, build_curriculum_config, build_observations_config, build_rewards_config, build_scene_config, build_terminations_config, ) from holomotion.src.env.isaaclab_components.isaaclab_observation import ( ObservationFunctions, ) from holomotion.src.env.isaaclab_components.isaaclab_utils import ( resolve_holo_config, ) # from holomotion.src.modules.agent_modules import ObsSeqSerializer import isaaclab.envs.mdp as isaaclab_mdp from isaaclab.envs.mdp.events import _randomize_prop_by_op from isaaclab.managers import SceneEntityCfg, EventTermCfg from isaaclab.utils import configclass from isaaclab.envs import ManagerBasedEnv from isaaclab.managers import 
EventTermCfg from isaaclab.managers import EventTermCfg as EventTerm import isaaclab.utils.math as math_utils from isaaclab.assets import Articulation from isaaclab.envs.mdp.events import _randomize_prop_by_op from isaaclab.managers import SceneEntityCfg from typing import TYPE_CHECKING, Literal def _joint_ids_to_tensor( joint_ids: slice | list[int] | tuple[int, ...] | torch.Tensor | None, num_joints: int, device: torch.device | str, ) -> torch.Tensor: if joint_ids is None: return torch.arange(num_joints, device=device, dtype=torch.long) if isinstance(joint_ids, slice): if joint_ids == slice(None): return torch.arange(num_joints, device=device, dtype=torch.long) return torch.arange(num_joints, device=device, dtype=torch.long)[ joint_ids ] if isinstance(joint_ids, torch.Tensor): return joint_ids.to(device=device, dtype=torch.long).flatten() return torch.tensor(joint_ids, device=device, dtype=torch.long) def _select_effort_limit_vector( asset: Articulation, selected_joint_ids: torch.Tensor, ) -> torch.Tensor: num_joints = asset.data.applied_torque.shape[1] device = asset.data.applied_torque.device dtype = asset.data.applied_torque.dtype effort_limit_vec = torch.zeros(num_joints, device=device, dtype=dtype) is_filled = torch.zeros(num_joints, device=device, dtype=torch.bool) for actuator in asset.actuators.values(): actuator_joint_ids = _joint_ids_to_tensor( actuator.joint_indices, num_joints=num_joints, device=device ) actuator_effort_limit = torch.as_tensor( actuator.effort_limit, device=device, dtype=dtype ) if actuator_effort_limit.ndim == 0: actuator_effort_limit = actuator_effort_limit.expand( actuator_joint_ids.numel() ) elif actuator_effort_limit.ndim == 2: if actuator_effort_limit.shape[0] > 1: reference = actuator_effort_limit[0].unsqueeze(0) if not torch.allclose( actuator_effort_limit, reference.expand_as(actuator_effort_limit), ): raise ValueError( "normed_torque_rate requires actuator effort limits to be static across envs." 
) actuator_effort_limit = actuator_effort_limit[0] elif actuator_effort_limit.ndim != 1: raise ValueError( "normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors." ) if actuator_effort_limit.numel() != actuator_joint_ids.numel(): raise ValueError( "normed_torque_rate found mismatched actuator joint indices and effort limits." ) effort_limit_vec[actuator_joint_ids] = actuator_effort_limit is_filled[actuator_joint_ids] = True if not torch.all(is_filled[selected_joint_ids]): missing_joint_ids = selected_joint_ids[~is_filled[selected_joint_ids]] raise ValueError( "normed_torque_rate could not resolve actuator effort limits for " f"joint ids {missing_joint_ids.tolist()}." ) selected_effort_limits = effort_limit_vec[selected_joint_ids] if not torch.all(torch.isfinite(selected_effort_limits)): raise ValueError( "normed_torque_rate requires finite actuator effort limits for all selected joints." ) if not torch.all(selected_effort_limits > 0.0): raise ValueError( "normed_torque_rate requires strictly positive actuator effort limits for all selected joints." ) return selected_effort_limits class MotionTrackingEnv: """IsaacLab-based Motion Tracking Environment. This environment integrates motion tracking capabilities with IsaacLab's manager-based architecture, supporting curriculum learning, domain randomization, and various termination conditions. This is a wrapper class that handles Isaac Sim initialization and delegates to an internal ManagerBasedRLEnv instance. """ def __init__( self, config, device: torch.device = None, log_dir: str = None, render_mode: str | None = None, headless: bool = True, accelerator=None, ): """Initialize the Motion Tracking Environment. 
Args: config: Configuration for the environment device: Device for tensor operations log_dir: Logging directory render_mode: Render mode for the environment headless: Whether to run in headless mode accelerator: Accelerator instance for distributed training (optional) """ self.config = config self._device = device self.accelerator = accelerator self.log_dir = log_dir self.headless = headless self.init_done = False self.is_evaluating = False self.render_mode = render_mode # self._init_motion_tracking_components() self._init_isaaclab_env() # self._init_serializers() self._completion_total_queue = deque(maxlen=1000) self._completion_success_queue = deque(maxlen=1000) self.metrics = {} self._robot_prev_joint_vel = None self._robot_prev_applied_torque = None self._robot_torque_rate_inv_effort_limit = None self._robot_torque_rate_needs_reseed = None @property def num_envs(self): return self._env.num_envs @property def device(self): return self._env.device def _init_isaaclab_env(self): _device = self._device curriculum = CurriculumCfg() # Determine per-process seed if provided; else create a deterministic per-rank default seed_val = getattr(self.config, "seed", None) if seed_val is None: if self.accelerator is not None: pid = self.accelerator.process_index else: pid = int(self.config.get("process_id", 0)) seed_val = int(time.time()) + pid _robot_config_dict = EasyDict( OmegaConf.to_container(self.config.robot, resolve=True) ) _terrain_config_dict = EasyDict( OmegaConf.to_container(self.config.terrain, resolve=True) ) _obs_config_dict = EasyDict( OmegaConf.to_container(self.config.obs, resolve=True) ) _rewards_config_dict = EasyDict( OmegaConf.to_container(self.config.rewards, resolve=True) ) _domain_rand_config_dict = ( EasyDict( OmegaConf.to_container( self.config.domain_rand, resolve=True, ) ) if self.config.domain_rand is not None else {} ) _terminations_config_dict = ( EasyDict( OmegaConf.to_container( self.config.terminations, resolve=True, ) ) if 
self.config.terminations is not None else {} ) _scene_config_dict = EasyDict( OmegaConf.to_container( self.config.scene, resolve=True, ) ) _commands_config_dict = OmegaConf.to_container( self.config.commands, resolve=True, ) _simulation_config_dict = EasyDict( OmegaConf.to_container( self.config.simulation, resolve=True, ) ) _actions_config_dict = EasyDict( OmegaConf.to_container( self.config.actions, resolve=True, ) ) if getattr(self.config, "curriculum", None) is not None: _curriculum_config_dict = EasyDict( OmegaConf.to_container(self.config.curriculum, resolve=True) ) else: _curriculum_config_dict = {} @configclass class MotionTrackingEnvCfg(ManagerBasedRLEnvCfg): seed: int = seed_val scene_config_dict = { "num_envs": self.config.num_envs, "env_spacing": self.config.env_spacing, "replicate_physics": self.config.replicate_physics, "robot": _robot_config_dict, "terrain": _terrain_config_dict, "domain_rand": _domain_rand_config_dict, "lighting": _scene_config_dict.lighting, "contact_sensor": _scene_config_dict.contact_sensor, } decimation: int = _simulation_config_dict.control_decimation episode_length_s: int = _simulation_config_dict.episode_length_s sim_freq = _simulation_config_dict.sim_freq dt = 1.0 / sim_freq physx = PhysxCfg( bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity, gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count, enable_stabilization=True, ) if self.accelerator is not None: main_process = self.accelerator.is_main_process process_id = self.accelerator.process_index num_processes = self.accelerator.num_processes else: main_process = self.config.get("main_process", True) process_id = self.config.get("process_id", 0) num_processes = self.config.get("num_processes", 1) scene: MotionTrackingSceneCfg = build_scene_config( scene_config_dict, main_process=main_process, process_id=process_id, num_processes=num_processes, ) sim: SimulationCfg = SimulationCfg( dt=dt, render_interval=decimation, 
physx=physx, device=_device, enable_scene_query_support=True, ) sim.physics_material = scene.terrain.physics_material viewer: ViewerCfg = ViewerCfg(origin_type="world") motion_cmds = {} vel_cmds = {} for k, v in _commands_config_dict.items(): if ( isinstance(v, dict) and v.get("type", "") == "MotionCommandCfg" ): motion_cmds[k] = v else: vel_cmds[k] = v # Populate RefMotionCommand distributed params when present. if "ref_motion" in motion_cmds: if self.accelerator is not None: cmd_process_id = self.accelerator.process_index cmd_num_processes = self.accelerator.num_processes else: cmd_process_id = getattr(self.config, "process_id", 0) cmd_num_processes = getattr( self.config, "num_processes", 1 ) motion_cmds["ref_motion"]["params"].update( { "seed": int(seed_val), "process_id": cmd_process_id, "num_processes": cmd_num_processes, "is_evaluating": self.is_evaluating, } ) # Build a unified commands cfg that may contain both motion and velocity terms. if motion_cmds: commands: MoTrack_CommandsCfg = ( build_motion_tracking_commands_config(motion_cmds) ) else: commands: MoTrack_CommandsCfg = MoTrack_CommandsCfg() if vel_cmds: vel_commands: VelTrack_CommandsCfg = ( build_velocity_commands_config(vel_cmds) ) for name in vel_cmds.keys(): setattr(commands, name, getattr(vel_commands, name)) observations: ObservationsCfg = build_observations_config( _obs_config_dict.obs_groups ) rewards: RewardsCfg = build_rewards_config(_rewards_config_dict) if _terminations_config_dict: terminations: TerminationsCfg = build_terminations_config( _terminations_config_dict ) else: terminations: TerminationsCfg = TerminationsCfg() if _domain_rand_config_dict: events: EventsCfg = build_domain_rand_config( _domain_rand_config_dict ) else: events: EventsCfg = EventsCfg() if "base_velocity" in vel_cmds: events.reset_base = EventTerm( func=isaaclab_mdp.reset_root_state_uniform, mode="reset", params={ "pose_range": { "x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14), }, "velocity_range": { "x": 
(0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0), "roll": (0.0, 0.0), "pitch": (0.0, 0.0), "yaw": (0.0, 0.0), }, }, ) events.reset_robot_joints = EventTerm( func=isaaclab_mdp.reset_joints_by_scale, mode="reset", params={ "position_range": (1.0, 1.0), "velocity_range": (-1.0, 1.0), }, ) curriculum: CurriculumCfg = build_curriculum_config( _curriculum_config_dict ) actions: ActionsCfg = build_actions_config(_actions_config_dict) sim: SimulationCfg = SimulationCfg( dt=dt, render_interval=decimation, physx=physx, device=_device, enable_scene_query_support=True, ) sim.physx.gpu_max_rigid_patch_count = 10 * 2**15 sim.physx.enable_stabilization = True sim.physics_material = scene.terrain.physics_material isaaclab_env_cfg = MotionTrackingEnvCfg() isaaclab_envconfig_dump_path = os.path.join( self.log_dir, "isaaclab_env_cfg.yaml" ) dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg) self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode) logger.info("IsaacLab environment initialized !") return self._env def _init_motion_tracking_components(self): self.n_fut_frames = self.config.commands.ref_motion.params.n_fut_frames self.target_fps = self.config.commands.ref_motion.params.target_fps self._init_serializers() def step(self, actor_state: dict): obs_dict, rewards, terminated, time_outs, infos = self._env.step( actor_state ) # IsaacLab separates terminated vs time_outs, combine them for consistency dones = terminated | time_outs self._update_completion_rate_stats(terminated, time_outs, infos) self._update_robot_metrics(infos) return obs_dict, rewards, dones, time_outs, infos def _update_robot_metrics(self, infos: dict) -> None: """Log robot low-level metrics (scalar means) for TensorBoard/console.""" if ("log" not in infos) or (not isinstance(infos["log"], dict)): infos["log"] = {} dt = float(self._env.step_dt) action = self._env.action_manager.action # [B, A] prev_action = self._env.action_manager.prev_action # [B, A] action_rate = torch.norm(action - prev_action, 
dim=-1) / dt # [B] robot = self._env.scene["robot"] dof_vel = robot.data.joint_vel # [B, Nd] dof_torque = robot.data.applied_torque # [B, Nd] if self._robot_prev_joint_vel is None or ( self._robot_prev_joint_vel.shape != dof_vel.shape ): self._robot_prev_joint_vel = dof_vel.clone() dof_acc = (dof_vel - self._robot_prev_joint_vel) / dt # [B, Nd] self._robot_prev_joint_vel = dof_vel.clone() if self._robot_prev_applied_torque is None or ( self._robot_prev_applied_torque.shape != dof_torque.shape ): joint_ids = torch.arange( dof_torque.shape[1], device=dof_torque.device, dtype=torch.long ) effort_limit = _select_effort_limit_vector(robot, joint_ids) self._robot_torque_rate_inv_effort_limit = ( effort_limit.reciprocal() ) self._robot_prev_applied_torque = torch.zeros_like(dof_torque) self._robot_torque_rate_needs_reseed = torch.ones( dof_torque.shape[0], device=dof_torque.device, dtype=torch.bool ) normed_torque_rate = torch.zeros( dof_torque.shape[0], device=dof_torque.device, dtype=dof_torque.dtype, ) reseed_mask = self._robot_torque_rate_needs_reseed.clone() if hasattr(self._env, "episode_length_buf"): reseed_mask |= self._env.episode_length_buf == 0 active_mask = ~reseed_mask if torch.any(active_mask): delta = ( dof_torque[active_mask] - self._robot_prev_applied_torque[active_mask] ) * self._robot_torque_rate_inv_effort_limit normed_torque_rate[active_mask] = torch.sum(delta.square(), dim=1) self._robot_prev_applied_torque.copy_(dof_torque) self._robot_torque_rate_needs_reseed[reseed_mask] = False dof_acc_norm = torch.norm(dof_acc, dim=-1) # [B] dof_torque_norm = torch.norm(dof_torque, dim=-1) # [B] energy = torch.sum( torch.abs(dof_vel) * torch.abs(dof_torque), dim=-1 ) # [B] self.metrics["Robot/Action_Rate"] = action_rate.mean() self.metrics["Robot/DOF_Acc"] = dof_acc_norm.mean() self.metrics["Robot/DOF_Torque"] = dof_torque_norm.mean() self.metrics["Robot/Energy"] = energy.mean() self.metrics["Robot/Normed_Torque_Rate"] = normed_torque_rate.mean() 
infos["log"]["Metrics/Robot/Action_Rate"] = self.metrics[ "Robot/Action_Rate" ] infos["log"]["Metrics/Robot/DOF_Acc"] = self.metrics["Robot/DOF_Acc"] infos["log"]["Metrics/Robot/DOF_Torque"] = self.metrics[ "Robot/DOF_Torque" ] infos["log"]["Metrics/Robot/Energy"] = self.metrics["Robot/Energy"] infos["log"]["Metrics/Robot/Normed_Torque_Rate"] = self.metrics[ "Robot/Normed_Torque_Rate" ] def _update_completion_rate_stats( self, terminated: torch.Tensor, time_outs: torch.Tensor, infos: dict, ) -> None: """Log completion rate over recent done batches. Definition: - Completed: time_outs==True and terminated==False. - Failed: terminated==True. The rolling window stores per-step done counts (only when any done occurs). """ done_mask = (terminated | time_outs).reshape(-1).bool() if torch.any(done_mask): done_count = int(done_mask.sum().item()) completed_mask = ( time_outs.reshape(-1).bool() & ~terminated.reshape(-1).bool() & done_mask ) completed_count = int(completed_mask.sum().item()) self._completion_total_queue.append(done_count) self._completion_success_queue.append(completed_count) denom = sum(self._completion_total_queue) completion_rate = ( float(sum(self._completion_success_queue)) / float(denom) if denom > 0 else 0.0 ) if ("log" not in infos) or (not isinstance(infos["log"], dict)): infos["log"] = {} infos["log"]["Metrics/ref_motion/Task/Completion_Rate"] = torch.tensor( completion_rate, device=self.device, dtype=torch.float32 ) self.metrics["Metrics/ref_motion/Task/Completion_Rate"] = ( completion_rate ) def reset_idx(self, env_ids: torch.Tensor): return self._env.reset(env_ids=env_ids) def reset_all(self): env_ids = torch.arange(self.num_envs, device=self.device) out = self._env.reset(env_ids=env_ids) return out def set_is_evaluating(self): logger.info("Setting environment to evaluation mode") self.is_evaluating = True def seed(self, seed: int): self._env.seed(seed) ================================================ FILE: holomotion/src/env/velocity_tracking.py 
================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import torch import time import os import yaml from collections import deque from functools import wraps from easydict import EasyDict import random import numpy as np from isaaclab.actuators import ImplicitActuatorCfg from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg from isaaclab.sim import PhysxCfg, SimulationCfg from isaaclab.utils import configclass from isaaclab.utils.io import dump_yaml from loguru import logger from omegaconf import OmegaConf from holomotion.src.env.isaaclab_components import ( ActionsCfg, VelTrack_CommandsCfg, MoTrack_CommandsCfg, EventsCfg, MotionTrackingSceneCfg, ObservationsCfg, RewardsCfg, TerminationsCfg, CurriculumCfg, build_actions_config, build_motion_tracking_commands_config, build_velocity_commands_config, build_domain_rand_config, build_curriculum_config, build_observations_config, build_rewards_config, build_scene_config, build_terminations_config, ) from holomotion.src.env.isaaclab_components.isaaclab_observation import ( ObservationFunctions, ) from holomotion.src.env.isaaclab_components.isaaclab_utils import ( resolve_holo_config, ) import isaaclab.envs.mdp as isaaclab_mdp from isaaclab.envs.mdp.events import _randomize_prop_by_op from isaaclab.managers import SceneEntityCfg, EventTermCfg from isaaclab.utils import configclass from 
isaaclab.envs import ManagerBasedEnv from isaaclab.managers import EventTermCfg from isaaclab.managers import EventTermCfg as EventTerm import isaaclab.utils.math as math_utils from isaaclab.assets import Articulation from isaaclab.envs.mdp.events import _randomize_prop_by_op from isaaclab.managers import SceneEntityCfg from typing import TYPE_CHECKING, Literal class VelocityTrackingEnv: """IsaacLab-based Motion Tracking Environment. This environment integrates motion tracking capabilities with IsaacLab's manager-based architecture, supporting curriculum learning, domain randomization, and various termination conditions. This is a wrapper class that handles Isaac Sim initialization and delegates to an internal ManagerBasedRLEnv instance. """ def __init__( self, config, device: torch.device = None, log_dir: str = None, render_mode: str | None = None, headless: bool = True, accelerator=None, ): """Initialize the Motion Tracking Environment. Args: config: Configuration for the environment device: Device for tensor operations log_dir: Logging directory render_mode: Render mode for the environment headless: Whether to run in headless mode accelerator: Accelerator instance for distributed training (optional) """ self.config = config self._device = device self.accelerator = accelerator self.log_dir = log_dir self.headless = headless self.init_done = False self.is_evaluating = False self.render_mode = render_mode # self._init_motion_tracking_components() self._init_isaaclab_env() # self._init_serializers() self._completion_total_queue = deque(maxlen=1000) self._completion_success_queue = deque(maxlen=1000) self.metrics = {} self._robot_prev_joint_vel = None @property def num_envs(self): return self._env.num_envs @property def device(self): return self._env.device def _init_isaaclab_env(self): _device = self._device # curriculum = CurriculumCfg() # Determine per-process seed if provided; else create a deterministic per-rank default seed_val = getattr(self.config, "seed", 
None) if seed_val is None: if self.accelerator is not None: pid = self.accelerator.process_index else: pid = int(self.config.get("process_id", 0)) seed_val = int(time.time()) + pid _robot_config_dict = EasyDict( OmegaConf.to_container(self.config.robot, resolve=True) ) _terrain_config_dict = EasyDict( OmegaConf.to_container(self.config.terrain, resolve=True) ) _obs_config_dict = EasyDict( OmegaConf.to_container(self.config.obs, resolve=True) ) _rewards_config_dict = EasyDict( OmegaConf.to_container(self.config.rewards, resolve=True) ) _domain_rand_config_dict = ( EasyDict( OmegaConf.to_container( self.config.domain_rand, resolve=True, ) ) if self.config.domain_rand is not None else {} ) _terminations_config_dict = ( EasyDict( OmegaConf.to_container( self.config.terminations, resolve=True, ) ) if self.config.terminations is not None else {} ) _scene_config_dict = EasyDict( OmegaConf.to_container( self.config.scene, resolve=True, ) ) _commands_config_dict = OmegaConf.to_container( self.config.commands, resolve=True, ) # Headless + no rendering: disable base_velocity debug visualization. # In k8s headless runs, IsaacSim/IsaacLab command debug_vis may wedge # during/after simulation start (seen on velocity-tracking only). # Keep an escape hatch for debugging/video. allow_debug_vis = (not self.headless) or (self.render_mode is not None) force_debug_vis = bool( int(os.environ.get("HOLOMOTION_VELCMD_DEBUG_VIS", "0")) ) if ( (not allow_debug_vis) and (not force_debug_vis) and isinstance(_commands_config_dict, dict) and ("base_velocity" in _commands_config_dict) ): bv = _commands_config_dict.get("base_velocity", {}) bv_params = bv.get("params", {}) if isinstance(bv_params, dict) and bool( bv_params.get("debug_vis", False) ): bv_params["debug_vis"] = False bv["params"] = bv_params _commands_config_dict["base_velocity"] = bv logger.warning( "Disabled base_velocity debug_vis for headless non-render runs. " "Set HOLOMOTION_VELCMD_DEBUG_VIS=1 to force-enable." 
) _simulation_config_dict = EasyDict( OmegaConf.to_container( self.config.simulation, resolve=True, ) ) _actions_config_dict = EasyDict( OmegaConf.to_container( self.config.actions, resolve=True, ) ) @configclass class VelocityTrackingEnvCfg(ManagerBasedRLEnvCfg): seed: int = seed_val scene_config_dict = { "num_envs": self.config.num_envs, "env_spacing": self.config.env_spacing, "replicate_physics": self.config.replicate_physics, "robot": _robot_config_dict, "terrain": _terrain_config_dict, "domain_rand": _domain_rand_config_dict, "lighting": _scene_config_dict.lighting, "contact_sensor": _scene_config_dict.contact_sensor, } decimation: int = _simulation_config_dict.control_decimation episode_length_s: int = _simulation_config_dict.episode_length_s sim_freq = _simulation_config_dict.sim_freq dt = 1.0 / sim_freq physx = PhysxCfg( bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity, gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count, enable_stabilization=True, ) if self.accelerator is not None: main_process = self.accelerator.is_main_process process_id = self.accelerator.process_index num_processes = self.accelerator.num_processes else: main_process = self.config.get("main_process", True) process_id = self.config.get("process_id", 0) num_processes = self.config.get("num_processes", 1) scene: MotionTrackingSceneCfg = build_scene_config( scene_config_dict, main_process=main_process, process_id=process_id, num_processes=num_processes, ) sim: SimulationCfg = SimulationCfg( dt=dt, render_interval=decimation, physx=physx, device=_device, enable_scene_query_support=True, ) sim.physics_material = scene.terrain.physics_material viewer: ViewerCfg = ViewerCfg(origin_type="world") command_name = list(_commands_config_dict.keys())[0] commands: VelTrack_CommandsCfg = build_velocity_commands_config( _commands_config_dict ) observations: ObservationsCfg = build_observations_config( _obs_config_dict.obs_groups ) rewards: 
RewardsCfg = build_rewards_config(_rewards_config_dict) if _terminations_config_dict: terminations: TerminationsCfg = build_terminations_config( _terminations_config_dict ) else: terminations: TerminationsCfg = TerminationsCfg() if _domain_rand_config_dict: events: EventsCfg = build_domain_rand_config( _domain_rand_config_dict ) else: events: EventsCfg = EventsCfg() events.reset_base = EventTerm( func=isaaclab_mdp.reset_root_state_uniform, mode="reset", params={ "pose_range": { "x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14), }, "velocity_range": { "x": (0.0, 0.0), "y": (0.0, 0.0), "z": (0.0, 0.0), "roll": (0.0, 0.0), "pitch": (0.0, 0.0), "yaw": (0.0, 0.0), }, }, ) events.reset_robot_joints = EventTerm( func=isaaclab_mdp.reset_joints_by_scale, mode="reset", params={ "position_range": (1.0, 1.0), "velocity_range": (-1.0, 1.0), }, ) # curriculum: CurriculumCfg = build_curriculum_config( # getattr(self.config, "curriculum", {}) # ) actions: ActionsCfg = build_actions_config(_actions_config_dict) sim: SimulationCfg = SimulationCfg( dt=dt, render_interval=decimation, physx=physx, device=_device, enable_scene_query_support=True, ) sim.physx.gpu_max_rigid_patch_count = 10 * 2**15 sim.physx.enable_stabilization = True sim.physics_material = scene.terrain.physics_material isaaclab_env_cfg = VelocityTrackingEnvCfg() isaaclab_envconfig_dump_path = os.path.join( self.log_dir, "isaaclab_env_cfg.yaml" ) dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg) logger.info( "Constructing IsaacLab ManagerBasedRLEnv (velocity_tracking) ..." ) self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode) logger.info( "IsaacLab ManagerBasedRLEnv constructed (velocity_tracking)." 
) logger.info("IsaacLab environment initialized !") return self._env def _init_motion_tracking_components(self): self._init_serializers() def step(self, actor_state: dict): obs_dict, rewards, terminated, time_outs, infos = self._env.step( actor_state ) # IsaacLab separates terminated vs time_outs, combine them for consistency dones = terminated | time_outs self._update_completion_rate_stats(terminated, time_outs, infos) return obs_dict, rewards, dones, time_outs, infos def _update_completion_rate_stats( self, terminated: torch.Tensor, time_outs: torch.Tensor, infos: dict, ) -> None: """Log completion rate over recent done batches. Definition: - Completed: time_outs==True and terminated==False. - Failed: terminated==True. The rolling window stores per-step done counts (only when any done occurs). """ done_mask = (terminated | time_outs).reshape(-1).bool() if torch.any(done_mask): done_count = int(done_mask.sum().item()) completed_mask = ( time_outs.reshape(-1).bool() & ~terminated.reshape(-1).bool() & done_mask ) completed_count = int(completed_mask.sum().item()) self._completion_total_queue.append(done_count) self._completion_success_queue.append(completed_count) denom = sum(self._completion_total_queue) completion_rate = ( float(sum(self._completion_success_queue)) / float(denom) if denom > 0 else 0.0 ) if ("log" not in infos) or (not isinstance(infos["log"], dict)): infos["log"] = {} infos["log"]["Task/Completion_Rate"] = torch.tensor( completion_rate, device=self.device, dtype=torch.float32 ) def reset_idx(self, env_ids: torch.Tensor): return self._env.reset(env_ids=env_ids) def reset_all(self): env_ids = torch.arange(self.num_envs, device=self.device) out = self._env.reset(env_ids=env_ids) return out def set_is_evaluating(self): logger.info("Setting environment to evaluation mode") self.is_evaluating = True def seed(self, seed: int): self._env.seed(seed) ================================================ FILE: holomotion/src/evaluation/__init__.py 
================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/evaluation/eval_motion_tracking.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import os import argparse import subprocess from pathlib import Path from typing import List, Optional, Tuple from loguru import logger def find_checkpoints_to_evaluate( eval_h5_dataset_path: str, root_dir: Path, target_checkpoints: Optional[List[str]], config_name: str, ) -> List[Tuple[str, str]]: """Scan all model subdirectories and collect checkpoints that need evaluation. Behavior: - If `target_checkpoints` is provided and non-empty: only these checkpoint stems are considered (e.g. ['model_17500']). 
- If `target_checkpoints` is None or empty: all checkpoints matching 'model_*.pt' under each model directory will be considered. Returns: A list of (checkpoint_path, config_name) tuples to be evaluated. """ checkpoints_to_evaluate: List[Tuple[str, str]] = [] dataset_path = Path(eval_h5_dataset_path) dataset_suffix = ( dataset_path.name if dataset_path.name else "dataset_unknown" ) if root_dir.is_file(): checkpoint_file = root_dir model_dir_path = checkpoint_file.parent checkpoint_stem = checkpoint_file.stem eval_out_dir = ( model_dir_path / f"isaaclab_eval_output_{checkpoint_stem}_{dataset_suffix}" ) cfg_name = f"evaluation/{config_name}" return [(str(checkpoint_file), cfg_name)] if not root_dir.is_dir(): logger.error( f"Checkpoint root directory '{root_dir}' does not exist or is not a directory." ) return [] if target_checkpoints: logger.info( f"Searching for explicit target checkpoints: {target_checkpoints}" ) # Iterate over each model directory directly under root_dir for model_dir_path in root_dir.iterdir(): if not model_dir_path.is_dir(): continue if target_checkpoints: # Use only the requested checkpoint stems candidate_files = [ model_dir_path / f"{stem}.pt" for stem in target_checkpoints ] else: candidate_files = sorted(model_dir_path.glob("model_*.pt")) if not candidate_files: continue for checkpoint_file in candidate_files: if not checkpoint_file.is_file(): logger.debug(f"Target checkpoint not found: {checkpoint_file}") continue checkpoint_stem = checkpoint_file.stem eval_out_dir = ( model_dir_path / f"isaaclab_eval_output_{checkpoint_stem}_{dataset_suffix}" ) if eval_out_dir.is_dir(): logger.debug( f"Skipping {checkpoint_file.name}, output exists." 
) continue # Construct Hydra config name from the folder name cfg_name = f"evaluation/{config_name}" checkpoints_to_evaluate.append((str(checkpoint_file), cfg_name)) checkpoints_to_evaluate.sort(key=lambda x: x[0]) return checkpoints_to_evaluate def main( checkpoint_dir: str, target_checkpoints: Optional[List[str]], eval_h5_dataset_path: str, config_name: str, num_envs: str, ) -> None: """ Entry point for batch evaluation. Args: checkpoint_root_dir: Root directory containing subdirectories for models. target_checkpoints: Optional list of checkpoint stems to evaluate single_eval_script: Path to the shell script to run a single evaluation. """ root_path = Path(checkpoint_dir) checkpoints_to_evaluate = find_checkpoints_to_evaluate( eval_h5_dataset_path=eval_h5_dataset_path, root_dir=root_path, target_checkpoints=target_checkpoints, config_name=config_name, ) if not checkpoints_to_evaluate: logger.warning( f"No pending evaluations found under '{checkpoint_dir}'." ) return logger.info( f"Found {len(checkpoints_to_evaluate)} checkpoints to evaluate." 
) for i, (ckpt_path, cfg_name) in enumerate( checkpoints_to_evaluate, start=1 ): logger.info( f"[{i}/{len(checkpoints_to_evaluate)}] Evaluating: {cfg_name}/{ckpt_path}" ) command = [ "bash", "holomotion/scripts/evaluation/eval_motion_tracking_single.sh", ckpt_path, cfg_name, eval_h5_dataset_path, num_envs, ] subprocess.run( command, ) def parse_args() -> argparse.Namespace: """Parse CLI arguments for the batch evaluation script.""" parser = argparse.ArgumentParser(description="motion-tracking evaluation.") parser.add_argument("--checkpoint_dir", type=str, required=True) parser.add_argument( "--target_checkpoints", type=str, nargs="*", default=None ) parser.add_argument("--config_name", type=str, required=True) parser.add_argument("--eval_h5_dataset_path", type=str, required=True) parser.add_argument("--num_envs", type=str, required=True) return parser.parse_args() if __name__ == "__main__": args = parse_args() main( checkpoint_dir=args.checkpoint_dir, target_checkpoints=args.target_checkpoints, eval_h5_dataset_path=args.eval_h5_dataset_path, config_name=args.config_name, num_envs=args.num_envs, ) ================================================ FILE: holomotion/src/evaluation/eval_motion_tracking_single.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os import re from pathlib import Path import hydra from hydra.utils import get_class from loguru import logger from omegaconf import ListConfig, OmegaConf from holomotion.src.evaluation.metrics import run_evaluation from holomotion.src.utils.config import compile_config from holomotion.src.utils.onnx_export import export_policy_to_onnx def load_training_config( checkpoint_path: str, eval_config: OmegaConf ) -> OmegaConf: """Load training config from checkpoint directory. Args: checkpoint_path: Path to the checkpoint file. eval_config: Full evaluation config (including command line overrides). Returns: Merged config with training config as base. """ checkpoint = Path(checkpoint_path) config_path = checkpoint.parent / "config.yaml" if not config_path.exists(): config_path = checkpoint.parent.parent / "config.yaml" if not config_path.exists(): logger.warning( f"Training config not found at {config_path}, using evaluation config" ) return eval_config logger.info(f"Loading training config from {config_path}") with open(config_path) as file: train_config = OmegaConf.load(file) # Apply eval_overrides from training config if they exist if train_config.get("eval_overrides") is not None: train_config = OmegaConf.merge( train_config, train_config.eval_overrides ) # Set checkpoint path train_config.checkpoint = checkpoint_path train_config.algo.config.checkpoint = checkpoint_path # For evaluation, merge eval_config into train_config config = OmegaConf.merge(train_config, eval_config) # force set the terminations and domain rand with eval_config's config.env.config.terminations = eval_config.env.config.terminations config.env.config.domain_rand = eval_config.env.config.domain_rand obs_groups = config.env.config.obs.obs_groups if "policy" in obs_groups: obs_groups.policy.enable_corruption = False if "critic" in obs_groups: obs_groups.critic.enable_corruption = False if "unified" in obs_groups: obs_groups.unified.enable_corruption = False return config def 
_infer_dataset_suffix(output_dir: str, checkpoint_path: str) -> str: output_name = Path(output_dir).name model_name = Path(checkpoint_path).stem expected_prefix = f"isaaclab_eval_output_{model_name}_" if output_name.startswith(expected_prefix): return output_name[len(expected_prefix) :] return output_name def _checkpoint_sort_key(checkpoint_path: Path): match = re.search(r"model_(\d+)\.pt$", checkpoint_path.name) if match is not None: return (0, int(match.group(1)), checkpoint_path.name) return (1, checkpoint_path.name) def _normalize_ckpt_pt_names(ckpt_pt_names) -> list[str]: if ckpt_pt_names is None: return [] if isinstance(ckpt_pt_names, ListConfig): raw_names = list(ckpt_pt_names) elif isinstance(ckpt_pt_names, (list, tuple)): raw_names = list(ckpt_pt_names) else: raise TypeError( f"ckpt_pt_names must be a list/tuple, got {type(ckpt_pt_names)}" ) normalized_names = [] for name in raw_names: name_str = str(name).strip() if name_str == "": continue if not name_str.endswith(".pt"): name_str = f"{name_str}.pt" normalized_names.append(name_str) return normalized_names def _resolve_export_ckpt_paths(config: OmegaConf) -> list[Path]: log_dir_value = config.get("log_dir", None) checkpoint_value = config.get("checkpoint", None) if log_dir_value is None or str(log_dir_value).strip() == "": if checkpoint_value is None or str(checkpoint_value).strip() == "": raise ValueError( "When export_only=true, set log_dir or checkpoint." 
) log_dir = Path(str(checkpoint_value)).parent else: log_dir = Path(str(log_dir_value)) if not log_dir.is_dir(): raise NotADirectoryError( f"log_dir does not exist or is not a directory: {log_dir}" ) ckpt_pt_names = _normalize_ckpt_pt_names(config.get("ckpt_pt_names", None)) if len(ckpt_pt_names) > 0: selected_paths = [] missing_names = [] for name in ckpt_pt_names: ckpt_path = log_dir / name if ckpt_path.is_file(): selected_paths.append(ckpt_path) else: missing_names.append(name) if len(missing_names) > 0: raise FileNotFoundError( f"Missing checkpoints in log_dir={log_dir}: {missing_names}" ) return selected_paths discovered_paths = sorted(log_dir.glob("*.pt"), key=_checkpoint_sort_key) if len(discovered_paths) == 0: raise FileNotFoundError( f"No .pt checkpoints found in log_dir={log_dir}" ) return discovered_paths @hydra.main( config_path="../../config", config_name="evaluation/eval_isaaclab", version_base=None, ) def main(config: OmegaConf): """Evaluate the motion tracking model. Args: config: OmegaConf object containing the evaluation configuration. """ export_only = bool(config.get("export_only", False)) if export_only: checkpoint_paths = _resolve_export_ckpt_paths(config) config = load_training_config(str(checkpoint_paths[0]), config) else: if config.checkpoint is None: raise ValueError("Checkpoint path must be provided for evaluation") checkpoint_paths = [Path(str(config.checkpoint))] config = load_training_config(config.checkpoint, config) # Compile config without accelerator (PPO will create it) config = compile_config(config, accelerator=None) # Use checkpoint directory as log_dir for offline evaluation/export. 
log_dir = str(checkpoint_paths[0].parent) headless = config.headless # PPO creates Accelerator, AppLauncher, and environment internally algo_class = get_class(config.algo._target_) algo = algo_class( env_config=config.env, config=config.algo.config, log_dir=log_dir, headless=headless, is_offline_eval=True, ) if ( algo.accelerator.is_main_process and os.environ.get("TORCH_COMPILE_DISABLE", "0") != "1" ): logger.info( "Tip: If you encounter Triton/compilation errors during evaluation," ) logger.info( " set environment variable: export TORCH_COMPILE_DISABLE=1" ) if algo.accelerator.is_main_process: with open(os.path.join(log_dir, "eval_config.yaml"), "w") as f: OmegaConf.save(config, f) if export_only: if algo.accelerator.is_main_process: logger.info( "Running export-only mode for " f"{len(checkpoint_paths)} checkpoints in {log_dir}" ) onnx_name_suffix = config.get("onnx_name_suffix", None) use_kv_cache = config.get("use_kv_cache", True) for i, checkpoint_path in enumerate(checkpoint_paths, start=1): ckpt_path = str(checkpoint_path) if algo.accelerator.is_main_process: logger.info( f"[{i}/{len(checkpoint_paths)}] Loading checkpoint: " f"{ckpt_path}" ) algo.load(ckpt_path) if algo.accelerator.is_main_process: onnx_path = export_policy_to_onnx( algo, ckpt_path, onnx_name_suffix=onnx_name_suffix, use_kv_cache=use_kv_cache, ) logger.info(f"Successfully exported policy to: {onnx_path}") algo.accelerator.wait_for_everyone() if algo.accelerator.is_main_process: logger.info("Export-only mode completed successfully!") return if algo.accelerator.is_main_process: logger.info(f"Loading checkpoint for evaluation: {config.checkpoint}") algo.load(config.checkpoint) command_name = list(config.env.config.commands.keys())[0] if command_name == "ref_motion": motion_cmd = algo.env._env.command_manager.get_term("ref_motion") algo.env._env.reset() motion_cmd._update_ref_motion_state() # Export ONNX if requested if config.get("export_policy", True): if algo.accelerator.is_main_process: 
onnx_name_suffix = config.get("onnx_name_suffix", None) onnx_path = export_policy_to_onnx( algo, config.checkpoint, onnx_name_suffix=onnx_name_suffix, use_kv_cache=config.get("use_kv_cache", True), ) logger.info(f"Successfully exported policy to: {onnx_path}") algo.accelerator.wait_for_everyone() calc_per_clip_metrics = bool(config.get("calc_per_clip_metrics", False)) generate_report = bool(config.get("generate_report", False)) dump_npzs = bool(config.get("dump_npzs", False)) or calc_per_clip_metrics dof_mode = config.get("dof_mode", "29") if ( calc_per_clip_metrics and not bool(config.get("dump_npzs", False)) and algo.accelerator.is_main_process ): logger.info( "calc_per_clip_metrics=true requires dumped NPZs; " "enabling dump_npzs automatically." ) result = algo.offline_evaluate_policy(dump_npzs) algo.accelerator.wait_for_everyone() if algo.accelerator.is_main_process: logger.info("Evaluation completed successfully!") output_dir = ( result.get("output_dir") if isinstance(result, dict) else None ) if output_dir is not None: logger.info(f"NPZs saved to: {output_dir}") if calc_per_clip_metrics: if output_dir is None: logger.warning( "Skipping per-clip metric calculation because " "output_dir is unavailable." ) else: dataset_suffix = _infer_dataset_suffix( output_dir, config.checkpoint ) run_evaluation( npz_dir=output_dir, dataset_suffix=dataset_suffix, failure_pos_err_thresh_m=0.25, dof_mode=dof_mode, ) logger.info( f"Finished per-clip metric calculation for: {output_dir}" ) if generate_report: if output_dir is None: logger.warning( "Skipping report generation because output_dir is unavailable." 
) else: from holomotion.scripts.evaluation import ( mean_process_5metrics, ) report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir( output_dir ) logger.info(f"Generated metrics report at: {report_path}") if __name__ == "__main__": main() ================================================ FILE: holomotion/src/evaluation/eval_mujoco_sim2sim.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os import csv import shutil import sys import threading import time from collections import deque from pathlib import Path from threading import Thread import cv2 import hydra import mujoco import mujoco.viewer import numpy as np import onnx import onnxruntime import torch from loguru import logger from omegaconf import ListConfig, OmegaConf, open_dict from tqdm import tqdm import glob import re import ray from holomotion.src.evaluation.metrics import run_evaluation try: from horizon_tc_ui.hb_runtime import HBRuntime except ImportError: HB_ONNXRuntime = None logger.warning("HB_ONNXRuntime not available!") ONNX_IO_DUMP_DIRNAME = "onnx_io_npy" try: import pynput.keyboard as pynput_kb PYNPUT_AVAILABLE = True except ImportError: PYNPUT_AVAILABLE = False if "headless" in sys.argv and "false" in sys.argv: logger.warning("pynput not available, keyboard control disabled") from holomotion.src.evaluation.obs import PolicyObsBuilder from holomotion.src.utils.torch_utils import ( quat_apply, quat_inv, subtract_frame_transforms, quat_normalize_wxyz, matrix_from_quat, xyzw_to_wxyz, quat_mul, quat_from_euler_xyz, ) from holomotion.src.motion_retargeting.utils.rotation_conversions import ( standardize_quaternion, ) DEFAULT_FEET_GEOM_NAMES = { "left": ["left_foot"], "right": ["right_foot"], } DEFAULT_FEET_BODY_NAMES = { "left": ["left_ankle_roll_link"], "right": ["right_ankle_roll_link"], } def _coerce_config_bool(value, default: bool = False) -> bool: """Interpret config booleans without treating non-empty strings as truthy.""" if value is None: return default if isinstance(value, (bool, np.bool_)): return bool(value) if isinstance(value, str): value = value.strip().lower() if value in {"1", "true", "yes", "y", "on"}: return True if value in {"0", "false", "no", "n", "off", ""}: return False return bool(value) class OffscreenRenderer: """Minimal offscreen renderer for MuJoCo frames.""" def __init__( self, model, height: int, width: int, distance: float | None = None, 
azimuth: float | None = None, elevation: float | None = None, ): self.model = model self.height = height self.width = width self._overlay_callback = None self._gl_ctx = mujoco.GLContext(width, height) self._gl_ctx.make_current() self._scene = mujoco.MjvScene(model, maxgeom=1000) self._cam = mujoco.MjvCamera() self._opt = mujoco.MjvOption() mujoco.mjv_defaultFreeCamera(model, self._cam) self.set_align_view( distance=distance, azimuth=azimuth, elevation=elevation, ) self._con = mujoco.MjrContext( model, mujoco.mjtFontScale.mjFONTSCALE_100, ) self._rgb = np.zeros((height, width, 3), dtype=np.uint8) self._viewport = mujoco.MjrRect(0, 0, width, height) def set_overlay_callback(self, callback) -> None: """Register a callback to draw custom geoms into the scene each frame.""" self._overlay_callback = callback def render(self, data) -> np.ndarray: mujoco.mjv_updateScene( self.model, data, self._opt, None, self._cam, mujoco.mjtCatBit.mjCAT_ALL.value, self._scene, ) if self._overlay_callback is not None: self._overlay_callback(self._scene) mujoco.mjr_render(self._viewport, self._scene, self._con) mujoco.mjr_readPixels(self._rgb, None, self._viewport, self._con) return np.flipud(self._rgb) def set_align_view( self, lookat: np.ndarray | None = None, distance: float | None = None, azimuth: float | None = None, elevation: float | None = None, ): """Set camera to 'align' preset view (default azimuth=60, elevation=-20). Args: lookat: Optional lookat point [x, y, z]. If None, uses current lookat. distance: Optional camera distance from lookat point. If None, uses current distance. 
""" self._cam.type = mujoco.mjtCamera.mjCAMERA_FREE if azimuth is None: self._cam.azimuth = 60.0 # Side view (looking along Y-axis) else: self._cam.azimuth = float(azimuth) if elevation is None: self._cam.elevation = -20.0 # Slight downward angle else: self._cam.elevation = float(elevation) if lookat is not None: self._cam.lookat = np.asarray(lookat, dtype=np.float32) if distance is not None: self._cam.distance = float(distance) def close(self): self._gl_ctx.free() class VelocityKeyboardHandler: """Keyboard handler for interactive velocity commands using WASD and JL keys.""" def __init__( self, vx_increment: float = 0.1, vy_increment: float = 0.1, vyaw_increment: float = 0.05, vx_limits: tuple = (-0.5, 1.0), vy_limits: tuple = (-0.3, 0.3), vyaw_limits: tuple = (-0.5, 0.5), ): self.vx_increment = vx_increment self.vy_increment = vy_increment self.vyaw_increment = vyaw_increment # Velocity limits from training config self.vx_min, self.vx_max = vx_limits self.vy_min, self.vy_max = vy_limits self.vyaw_min, self.vyaw_max = vyaw_limits self.vx = 0.0 self.vy = 0.0 self.vyaw = 0.0 self._listener = None self._lock = threading.Lock() def start_listener(self): """Start keyboard listener thread (requires pynput).""" if not PYNPUT_AVAILABLE: logger.warning("pynput not available, keyboard control disabled") return def on_press(key): try: if hasattr(key, "char") and key.char: self._handle_key(key.char) except AttributeError: pass self._listener = pynput_kb.Listener(on_press=on_press) self._listener.start() logger.info( f"Keyboard listener started. Velocity limits: " f"vx=[{self.vx_min:.1f},{self.vx_max:.1f}], " f"vy=[{self.vy_min:.1f},{self.vy_max:.1f}], " f"vyaw=[{self.vyaw_min:.1f},{self.vyaw_max:.1f}]" ) def stop_listener(self): """Stop keyboard listener thread.""" if self._listener is not None: self._listener.stop() self._listener = None def get_velocity_command(self) -> np.ndarray: """Get velocity command [vx, vy, vyaw]. 
Returns: Velocity command [vx, vy, vyaw] """ with self._lock: return np.array([self.vx, self.vy, self.vyaw], dtype=np.float32) def _handle_key(self, char: str): """Handle keyboard press events.""" with self._lock: # W/S for vx (forward/backward) if char in ["W", "w"]: self.vx = np.clip( self.vx + self.vx_increment, self.vx_min, self.vx_max ) logger.info( f"[W] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) elif char in ["S", "s"]: self.vx = np.clip( self.vx - self.vx_increment, self.vx_min, self.vx_max ) logger.info( f"[S] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) # A/D for vy (left/right) elif char in ["A", "a"]: self.vy = np.clip( self.vy + self.vy_increment, self.vy_min, self.vy_max ) logger.info( f"[A] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) elif char in ["D", "d"]: self.vy = np.clip( self.vy - self.vy_increment, self.vy_min, self.vy_max ) logger.info( f"[D] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) # J/L for vyaw (turn left/right) elif char in ["J", "j"]: self.vyaw = np.clip( self.vyaw + self.vyaw_increment, self.vyaw_min, self.vyaw_max, ) logger.info( f"[J] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) elif char in ["L", "l"]: self.vyaw = np.clip( self.vyaw - self.vyaw_increment, self.vyaw_min, self.vyaw_max, ) logger.info( f"[L] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}" ) # Space to reset all elif char == " ": self.vx = 0.0 self.vy = 0.0 self.vyaw = 0.0 logger.info("[Space] Command reset to zero") # X to stop (emergency brake) elif char in ["X", "x"]: self.vx = 0.0 self.vy = 0.0 self.vyaw = 0.0 logger.info("[X] Emergency stop - all velocities set to zero") class MujocoEvaluator: """Class to handle MuJoCo simulation for policy evaluation.""" def __init__(self, config): """Initialize the MuJoCo evaluator. Args: config: Configuration object with simulation parameters. 
""" self.config = config # Initialize variables self.policy_session = None self.motion_encoding = None self.m = None # MuJoCo model self.d = None # MuJoCo data # Determine command mode from config self.command_mode = self._detect_command_mode() if "motion_npz_dir" not in config: logger.info(f"Command mode: {self.command_mode}") # Motion data self.ref_dof_pos = None self.ref_dof_vel = None self.filter_cutoff_hz = None self.n_motion_frames = 0 self.motion_frame_idx = 0 # Velocity command (for velocity tracking mode) self.velocity_command = np.zeros(3, dtype=np.float32) # [vx, vy, vyaw] self.target_heading = 0.0 # Target heading for velocity tracking self.keyboard_handler = ( None # Will be initialized if velocity_tracking ) # Extract configuration parameters self.simulation_dt = 1 / 200 self.policy_dt = 1 / 50 self.control_decimation = 4 self.dof_names_ref_motion = list(config.robot.dof_names) self.num_actions = len(self.dof_names_ref_motion) self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32) self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32) self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32) self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32) self.target_dof_pos_onnx = self.default_angles_onnx.copy() self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32) self.n_fut_frames = int(config.obs.n_fut_frames) self.actor_place_holder_ndim = self._find_actor_place_holder_ndim() self.use_kv_cache = False self.policy_kv_cache = None self.policy_kv_input_name = None self.policy_kv_output_name = None self.policy_kv_shape = None self.policy_model_context_len = 0 algo_cfg = self.config.get("algo", None) if algo_cfg is None: raise ValueError("Missing config.algo for MuJoCo evaluation.") algo_config = algo_cfg.get("config", None) if algo_config is None: raise ValueError( "Missing config.algo.config for MuJoCo evaluation." 
) max_context_len_cfg = algo_config.get("num_steps_per_env", None) if max_context_len_cfg is None: raise ValueError( "Missing config.algo.config.num_steps_per_env for MuJoCo evaluation." ) self.max_context_len = int(max_context_len_cfg) if self.max_context_len <= 0: raise ValueError( "config.algo.config.num_steps_per_env must be > 0, " f"got {self.max_context_len}" ) self.policy_effective_context_len = 0 self.counter = 0 self.tau_hist = [] # Latest Unitree lowstate message (populated when using Unitree bridge) # self._lowstate_msg = None # Desired target positions keyed by DOF name (updated after each policy step) self.target_dof_pos_by_name = {} # Video/recording related self._video_writer = None self._offscreen = None self._frame_interval = None self._last_frame_time = 0.0 # Reference(global)->Simulation(global) rigid transform (computed at init) self._ref_to_sim_ready = False self._ref_to_sim_q_wxyz = np.array( [1.0, 0.0, 0.0, 0.0], dtype=np.float32 ) self._ref_to_sim_t = np.zeros(3, dtype=np.float32) # Optional offset between reference globals and dataset body names (e.g., world body at index 0) # Robot state recording buffers for offline NPZ dumping self._robot_dof_pos_seq: list[np.ndarray] = [] self._robot_dof_vel_seq: list[np.ndarray] = [] self._robot_dof_acc_seq: list[np.ndarray] = [] self._robot_dof_torque_seq: list[np.ndarray] = [] self._robot_low_level_dof_torque_seq: list[np.ndarray] = [] self._robot_low_level_foot_contact_seq: list[np.ndarray] = [] self._robot_low_level_foot_normal_force_seq: list[np.ndarray] = [] self._robot_low_level_foot_tangent_speed_seq: list[np.ndarray] = [] self._robot_actions_seq: list[np.ndarray] = [] self._robot_action_rate_seq: list[np.float32] = [] self._robot_global_translation_seq: list[np.ndarray] = [] self._robot_global_rotation_quat_seq: list[np.ndarray] = [] self._robot_global_velocity_seq: list[np.ndarray] = [] self._robot_global_angular_velocity_seq: list[np.ndarray] = [] self._robot_moe_expert_indices_seq: 
list[np.ndarray] = [] self._robot_moe_expert_logits_seq: list[np.ndarray] = [] self._prev_recorded_dof_vel_ref: np.ndarray | None = None self._prev_actions_onnx: np.ndarray | None = None ( self.action_ema_filter_enabled, self.action_ema_filter_alpha, ) = self._get_action_ema_filter_cfg() self._filtered_actions_onnx: np.ndarray | None = None ( self.policy_action_delay_step, self.action_delay_type, ) = self._get_action_delay_cfg() self._policy_action_delay_buffer: deque[np.ndarray] = deque( maxlen=max(1, self.policy_action_delay_step + 1) ) self._current_policy_action_delay_step = 0 self._reset_action_delay_randomization() # Camera config (viewer + offscreen) self._camera_tracking_enabled = bool( self.config.get("camera_tracking", True) ) self._camera_height_offset = float( self.config.get("camera_height_offset", 0.3) ) self._camera_distance = float(self.config.get("camera_distance", 4.0)) self._camera_azimuth = float(self.config.get("camera_azimuth", 60.0)) self._camera_elevation = float( self.config.get("camera_elevation", -20.0) ) self._root_body_id = -1 self._foot_contact_logging_enabled = False self._foot_geom_id_groups: list[list[int]] = [[], []] self._foot_geom_id_to_side: dict[int, int] = {} self._prev_low_level_foot_geom_centers: np.ndarray | None = None self.dump_onnx_io_npy = bool( self.config.get("dump_onnx_io_npy", False) ) self.policy_moe_layer_output_names: list[tuple[int, str, str]] = [] self._reset_onnx_io_dump_buffers() def _reset_onnx_io_dump_buffers(self): self._onnx_io_input_names: list[str] = [] self._onnx_io_output_names: list[str] = [] self._onnx_io_inputs: dict[str, list[np.ndarray]] = {} self._onnx_io_outputs: dict[str, list[np.ndarray]] = {} def _get_action_ema_filter_cfg(self) -> tuple[bool, float]: actuator_cfg = self.config.get("robot", {}).get("actuators", {}) actuator_type = actuator_cfg.get("actuator_type", "unitree") if actuator_type != "unitree_erfi": return False, 1.0 enabled = _coerce_config_bool( 
actuator_cfg.get("ema_filter_enabled", False), default=False ) alpha = float(actuator_cfg.get("ema_filter_alpha", 1.0)) if not 0.0 <= alpha <= 1.0: raise ValueError( "robot.actuators.ema_filter_alpha must be within [0, 1], " f"got {alpha}." ) return enabled, alpha def _reset_action_ema_filter(self) -> None: self._filtered_actions_onnx = None def _apply_action_ema_filter(self, raw_actions: np.ndarray) -> np.ndarray: raw_actions = np.asarray(raw_actions, dtype=np.float32) if not self.action_ema_filter_enabled: return raw_actions.copy() if self._filtered_actions_onnx is None: self._filtered_actions_onnx = raw_actions.copy() return self._filtered_actions_onnx.copy() # self.action_ema_filter_alpha = 0.7 filtered_actions = ( self.action_ema_filter_alpha * raw_actions + (1.0 - self.action_ema_filter_alpha) * self._filtered_actions_onnx ).astype(np.float32, copy=False) self._filtered_actions_onnx = filtered_actions.copy() return self._filtered_actions_onnx.copy() def _get_action_delay_cfg(self) -> tuple[int, str]: max_delay_step = int(self.config.get("policy_action_delay_step", 0)) if max_delay_step < 0: raise ValueError( "policy_action_delay_step must be non-negative, " f"got {max_delay_step}." ) delay_type = ( str(self.config.get("action_delay_type", "episode")) .strip() .lower() ) if delay_type not in {"step", "episode"}: raise ValueError( "action_delay_type must be one of {'step', 'episode'}, " f"got {delay_type!r}." 
) return max_delay_step, delay_type def _sample_policy_action_delay_step(self) -> int: if self.policy_action_delay_step <= 0: return 0 return int(np.random.randint(0, self.policy_action_delay_step + 1)) def _reset_action_delay_randomization(self) -> None: self._policy_action_delay_buffer = deque( maxlen=max(1, self.policy_action_delay_step + 1) ) if self.policy_action_delay_step <= 0: self._current_policy_action_delay_step = 0 return if self.action_delay_type == "episode": self._current_policy_action_delay_step = ( self._sample_policy_action_delay_step() ) else: self._current_policy_action_delay_step = 0 def _apply_action_delay(self, raw_actions: np.ndarray) -> np.ndarray: raw_actions = np.asarray(raw_actions, dtype=np.float32) if self.policy_action_delay_step <= 0: return raw_actions.copy() expected_buffer_len = max(1, self.policy_action_delay_step + 1) if ( not hasattr(self, "_policy_action_delay_buffer") or self._policy_action_delay_buffer.maxlen != expected_buffer_len ): self._reset_action_delay_randomization() if self.action_delay_type == "step": self._current_policy_action_delay_step = ( self._sample_policy_action_delay_step() ) self._policy_action_delay_buffer.append(raw_actions.copy()) if self._current_policy_action_delay_step >= len( self._policy_action_delay_buffer ): return self._policy_action_delay_buffer[-1].copy() return self._policy_action_delay_buffer[ -1 - self._current_policy_action_delay_step ].copy() @staticmethod def _normalize_foot_geom_name_groups(raw_spec) -> list[list[str]]: if raw_spec is None: return [[], []] if OmegaConf.is_config(raw_spec): raw_spec = OmegaConf.to_container(raw_spec, resolve=True) def coerce_names(value) -> list[str]: if value is None: return [] if isinstance(value, str): return [value] if isinstance(value, (list, tuple)): return [str(name) for name in value if str(name)] return [] if isinstance(raw_spec, dict): return [ coerce_names(raw_spec.get("left", raw_spec.get("left_foot"))), coerce_names( raw_spec.get("right", 
raw_spec.get("right_foot")) ), ] if isinstance(raw_spec, (list, tuple)) and len(raw_spec) == 2: return [coerce_names(raw_spec[0]), coerce_names(raw_spec[1])] logger.warning( "Unsupported robot.feet_geom_names format. Ignoring configured " "foot geom names." ) return [[], []] @staticmethod def _normalize_foot_body_name_groups(raw_spec) -> list[list[str]]: if raw_spec is None: return [ list(DEFAULT_FEET_BODY_NAMES["left"]), list(DEFAULT_FEET_BODY_NAMES["right"]), ] if OmegaConf.is_config(raw_spec): raw_spec = OmegaConf.to_container(raw_spec, resolve=True) def coerce_names(value) -> list[str]: if value is None: return [] if isinstance(value, str): return [value] if isinstance(value, (list, tuple)): return [str(name) for name in value if str(name)] return [] if isinstance(raw_spec, dict): return [ coerce_names(raw_spec.get("left", raw_spec.get("left_foot"))), coerce_names( raw_spec.get("right", raw_spec.get("right_foot")) ), ] if isinstance(raw_spec, (list, tuple)) and len(raw_spec) == 2: return [coerce_names(raw_spec[0]), coerce_names(raw_spec[1])] logger.warning( "Unsupported robot.feet_body_names format. Falling back to " f"default foot bodies: {DEFAULT_FEET_BODY_NAMES}" ) return [ list(DEFAULT_FEET_BODY_NAMES["left"]), list(DEFAULT_FEET_BODY_NAMES["right"]), ] def _resolve_foot_geom_ids_from_geom_names( self, foot_geom_name_groups: list[list[str]] ) -> list[list[int]]: foot_geom_id_groups: list[list[int]] = [[], []] for side_idx, geom_names in enumerate(foot_geom_name_groups): for geom_name in geom_names: geom_id = mujoco.mj_name2id( self.m, mujoco.mjtObj.mjOBJ_GEOM, geom_name ) if geom_id == -1: logger.warning( f"Foot geom '{geom_name}' was not found in the MuJoCo model." 
) continue foot_geom_id_groups[side_idx].append(int(geom_id)) return foot_geom_id_groups def _resolve_foot_geom_ids_from_body_names( self, foot_body_name_groups: list[list[str]] ) -> list[list[int]]: foot_geom_id_groups: list[list[int]] = [[], []] geom_bodyid = np.asarray(self.m.geom_bodyid, dtype=np.int32) geom_contype = np.asarray(self.m.geom_contype, dtype=np.int32) geom_conaffinity = np.asarray(self.m.geom_conaffinity, dtype=np.int32) collidable_mask = (geom_contype != 0) | (geom_conaffinity != 0) for side_idx, body_names in enumerate(foot_body_name_groups): resolved_geom_ids: list[int] = [] for body_name in body_names: body_id = mujoco.mj_name2id( self.m, mujoco.mjtObj.mjOBJ_BODY, body_name ) if body_id == -1: logger.warning( f"Foot body '{body_name}' was not found in the MuJoCo model." ) continue body_geom_ids = np.flatnonzero(geom_bodyid == int(body_id)) if body_geom_ids.size == 0: logger.warning( f"Foot body '{body_name}' has no attached geoms." ) continue contact_geom_ids = body_geom_ids[ collidable_mask[body_geom_ids] ] if contact_geom_ids.size == 0: contact_geom_ids = body_geom_ids resolved_geom_ids.extend(contact_geom_ids.astype(int).tolist()) # Preserve order while removing duplicates. 
deduped = list(dict.fromkeys(resolved_geom_ids)) foot_geom_id_groups[side_idx] = deduped return foot_geom_id_groups def _init_low_level_foot_contact_logging(self) -> None: self._foot_geom_id_groups = [[], []] self._foot_geom_id_to_side = {} self._foot_contact_logging_enabled = False self._prev_low_level_foot_geom_centers = None foot_geom_name_groups = self._normalize_foot_geom_name_groups( getattr(self.config.robot, "feet_geom_names", None) ) foot_body_name_groups = self._normalize_foot_body_name_groups( getattr(self.config.robot, "feet_body_names", None) ) geom_name_groups = self._resolve_foot_geom_ids_from_geom_names( foot_geom_name_groups ) body_name_groups = self._resolve_foot_geom_ids_from_body_names( foot_body_name_groups ) for side_idx in range(2): resolved_ids = ( geom_name_groups[side_idx] if len(geom_name_groups[side_idx]) > 0 else body_name_groups[side_idx] ) self._foot_geom_id_groups[side_idx] = list(resolved_ids) for geom_id in resolved_ids: self._foot_geom_id_to_side[int(geom_id)] = side_idx if any(len(group) == 0 for group in self._foot_geom_id_groups): logger.warning( "Low-level foot contact logging is unavailable because one or " "both foot geom groups could not be resolved. Contact metrics " "will be written as NaN." 
) return self._foot_contact_logging_enabled = True def _record_low_level_foot_contact_sample(self) -> None: foot_contact = np.full((2,), np.nan, dtype=np.float32) foot_normal_force = np.full((2,), np.nan, dtype=np.float32) foot_tangent_speed = np.full((2,), np.nan, dtype=np.float32) if not self._foot_contact_logging_enabled: self._robot_low_level_foot_contact_seq.append(foot_contact) self._robot_low_level_foot_normal_force_seq.append( foot_normal_force ) self._robot_low_level_foot_tangent_speed_seq.append( foot_tangent_speed ) return current_centers = np.zeros((2, 3), dtype=np.float32) for side_idx, geom_ids in enumerate(self._foot_geom_id_groups): current_centers[side_idx] = np.mean( self.d.geom_xpos[np.asarray(geom_ids, dtype=np.int32)], axis=0, ).astype(np.float32) if self._prev_low_level_foot_geom_centers is None: tangential_speed = np.zeros((2,), dtype=np.float32) else: foot_velocity = ( current_centers - self._prev_low_level_foot_geom_centers ) / np.float32(self.simulation_dt) tangential_speed = np.linalg.norm( foot_velocity[:, :2], axis=1 ).astype(np.float32) self._prev_low_level_foot_geom_centers = current_centers.copy() foot_contact.fill(0.0) foot_normal_force.fill(0.0) foot_tangent_speed = tangential_speed contact_force = np.zeros(6, dtype=np.float64) for contact_idx in range(int(self.d.ncon)): contact = self.d.contact[contact_idx] contact_sides = set() geom1 = int(contact.geom1) geom2 = int(contact.geom2) if geom1 in self._foot_geom_id_to_side: contact_sides.add(self._foot_geom_id_to_side[geom1]) if geom2 in self._foot_geom_id_to_side: contact_sides.add(self._foot_geom_id_to_side[geom2]) if len(contact_sides) != 1: continue side_idx = next(iter(contact_sides)) foot_contact[side_idx] = 1.0 mujoco.mj_contactForce(self.m, self.d, contact_idx, contact_force) foot_normal_force[side_idx] += np.float32(abs(contact_force[0])) self._robot_low_level_foot_contact_seq.append(foot_contact) self._robot_low_level_foot_normal_force_seq.append(foot_normal_force) 
self._robot_low_level_foot_tangent_speed_seq.append(foot_tangent_speed) @staticmethod def _flatten_single_step_output(values, *, dtype=None) -> np.ndarray: arr = np.asarray(values, dtype=dtype) if arr.ndim == 0: raise ValueError( "Expected at least 1D output for single-step ONNX routing dump." ) return arr.reshape(-1, arr.shape[-1])[0] def _discover_policy_moe_outputs(self) -> None: self.policy_moe_layer_output_names: list[tuple[int, str, str]] = [] routing_outputs: dict[int, dict[str, str]] = {} pattern = re.compile(r"^moe_layer_(\d+)_expert_(indices|logits)$") for node in self.policy_session.get_outputs(): match = pattern.fullmatch(node.name) if match is None: continue layer_idx = int(match.group(1)) kind = str(match.group(2)) routing_outputs.setdefault(layer_idx, {})[kind] = node.name for layer_idx in sorted(routing_outputs): layer_outputs = routing_outputs[layer_idx] if "indices" not in layer_outputs or "logits" not in layer_outputs: logger.warning( "Skipping incomplete MoE routing outputs for layer " f"{layer_idx}: {sorted(layer_outputs)}" ) continue self.policy_moe_layer_output_names.append( ( layer_idx, layer_outputs["indices"], layer_outputs["logits"], ) ) if self.policy_moe_layer_output_names: logger.info( "Detected MoE routing outputs for layers: " f"{[layer_idx for layer_idx, _, _ in self.policy_moe_layer_output_names]}" ) def _get_stacked_moe_routing_tensors( self, ) -> tuple[np.ndarray | None, np.ndarray | None]: indices_seq = getattr(self, "_robot_moe_expert_indices_seq", []) logits_seq = getattr(self, "_robot_moe_expert_logits_seq", []) if len(indices_seq) == 0 or len(logits_seq) == 0: return None, None return ( np.stack(indices_seq, axis=0).astype(np.int64), np.stack(logits_seq, axis=0).astype(np.float32), ) def _get_stacked_low_level_foot_contact_tensors( self, ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: contact_seq = getattr(self, "_robot_low_level_foot_contact_seq", []) normal_force_seq = getattr( self, 
"_robot_low_level_foot_normal_force_seq", [] ) tangent_speed_seq = getattr( self, "_robot_low_level_foot_tangent_speed_seq", [] ) if contact_seq and normal_force_seq and tangent_speed_seq: return ( np.stack(contact_seq, axis=0).astype(np.float32), np.stack(normal_force_seq, axis=0).astype(np.float32), np.stack(tangent_speed_seq, axis=0).astype(np.float32), ) num_low_level_samples = len( getattr(self, "_robot_low_level_dof_torque_seq", []) ) if num_low_level_samples <= 0: return None, None, None nan_array = np.full((num_low_level_samples, 2), np.nan, np.float32) return nan_array.copy(), nan_array.copy(), nan_array.copy() def _record_onnx_io_frame(self, input_feed, output_names, onnx_output): if not self._onnx_io_input_names: self._onnx_io_input_names = list(input_feed.keys()) self._onnx_io_inputs = { name: [] for name in self._onnx_io_input_names } if not self._onnx_io_output_names: self._onnx_io_output_names = list(output_names) self._onnx_io_outputs = { name: [] for name in self._onnx_io_output_names } for name in self._onnx_io_input_names: if name not in input_feed: raise KeyError(f"Missing ONNX input tensor: {name}") self._onnx_io_inputs[name].append( np.array(input_feed[name], copy=True) ) for name, value in zip(self._onnx_io_output_names, onnx_output): self._onnx_io_outputs[name].append(np.array(value, copy=True)) @staticmethod def _stack_onnx_io_frames( frame_dict: dict[str, list[np.ndarray]], ) -> dict[str, np.ndarray]: stacked: dict[str, np.ndarray] = {} for name, frames in frame_dict.items(): if frames: stacked[name] = np.stack(frames, axis=0) else: stacked[name] = np.empty((0,), dtype=np.float32) return stacked def save_onnx_io_dump(self, output_path, meta_info): payload = { "input_names": list(self._onnx_io_input_names), "output_names": list(self._onnx_io_output_names), "inputs": self._stack_onnx_io_frames(self._onnx_io_inputs), "outputs": self._stack_onnx_io_frames(self._onnx_io_outputs), "source_npz": meta_info.get( "source_npz", 
meta_info.get("source_file", "") ), "onnx_model": meta_info.get( "onnx_model", meta_info.get("model", "") ), } np.save(output_path, payload, allow_pickle=True) def _find_actor_place_holder_ndim(self): n_dim = 0 for obs_dict in self._get_policy_atomic_obs_list(): name = str(list(obs_dict.keys())[0]) if name == "place_holder": params = obs_dict["place_holder"].get("params", {}) n_dim = int(params.get("n_dim", 0)) if name == "actor_place_holder": params = obs_dict["actor_place_holder"].get("params", {}) n_dim = int(params.get("n_dim", 0)) return n_dim def _get_actor_obs_term_params(self, term_name: str) -> dict: for obs_dict in self._get_policy_atomic_obs_list(): configured_name = str(list(obs_dict.keys())[0]) if configured_name != term_name: continue term_cfg = obs_dict[configured_name] if not isinstance(term_cfg, dict): return {} params = term_cfg.get("params", {}) return dict(params) if isinstance(params, dict) else {} return {} def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray: params = self._get_actor_obs_term_params(term_name) keybody_names = params.get("keybody_names", None) body_names = [str(name) for name in self.config.robot.body_names] if keybody_names is None: return np.arange(len(body_names), dtype=np.int64) keybody_names = [str(name) for name in keybody_names] body_name_to_idx = { body_name: idx for idx, body_name in enumerate(body_names) } missing_names = [ name for name in keybody_names if name not in body_name_to_idx ] if len(missing_names) > 0: raise ValueError( f"Unknown keybody_names in '{term_name}': {missing_names}. 
" f"Available body names: {body_names}" ) return np.asarray( [body_name_to_idx[name] for name in keybody_names], dtype=np.int64, ) @staticmethod def _to_plain_obs_cfg(cfg): if OmegaConf.is_config(cfg): plain_cfg = OmegaConf.to_container(cfg, resolve=True) else: plain_cfg = dict(cfg) if not isinstance(plain_cfg, dict): raise ValueError( f"Observation term config must be a mapping, got {type(plain_cfg)}" ) return plain_cfg def _get_actor_obs_schema_terms(self) -> list[str]: modules_cfg = self.config.get("modules", None) if modules_cfg is None: return [] actor_cfg = modules_cfg.get("actor", None) if actor_cfg is None: return [] obs_schema = actor_cfg.get("obs_schema", None) if obs_schema is None: return [] ordered_terms: list[str] = [] for _, seq_cfg in obs_schema.items(): seq_terms = seq_cfg.get("terms", []) ordered_terms.extend(str(term) for term in seq_terms) return ordered_terms def _get_actor_atomic_obs_entries(self) -> list[tuple[str, str, dict]]: obs_cfg = self.config.get("obs", None) if obs_cfg is None: raise ValueError("Missing config.obs for MuJoCo sim2sim") obs_groups = obs_cfg.get("obs_groups", None) if obs_groups is None: raise ValueError( "Missing config.obs.obs_groups for MuJoCo sim2sim" ) if obs_groups.get("policy", None) is not None: entries: list[tuple[str, str, dict]] = [] for term_dict in obs_groups.policy.atomic_obs_list: term_name = str(list(term_dict.keys())[0]) entries.append( ( "policy", term_name, self._to_plain_obs_cfg(term_dict[term_name]), ) ) return entries if obs_groups.get("unified", None) is not None: entries = [] for term_dict in obs_groups.unified.atomic_obs_list: term_name = str(list(term_dict.keys())[0]) if term_name.startswith("critic_"): continue entries.append( ( "unified", term_name, self._to_plain_obs_cfg(term_dict[term_name]), ) ) if not entries: raise ValueError( "obs_groups.unified found but contains no non-critic terms." 
) return entries raise ValueError( "Unsupported obs config for MuJoCo sim2sim: expected obs_groups.policy or obs_groups.unified." ) def _get_policy_atomic_obs_list(self): """Resolve the atomic obs list used to build the ONNX policy input. Supports both legacy configs (obs_groups.policy) and PULSE-stage2 configs that use a unified group (obs_groups.unified) with actor_/critic_ prefixes. """ actor_atomic_entries = self._get_actor_atomic_obs_entries() schema_terms = self._get_actor_obs_schema_terms() if len(schema_terms) == 0: logger.warning( "modules.actor.obs_schema is unavailable; using obs_groups actor term order for MuJoCo policy input." ) return [ {term_name: cfg} for _, term_name, cfg in actor_atomic_entries ] by_full_key: dict[str, tuple[str, dict]] = {} by_leaf_key: dict[str, tuple[str, dict]] = {} ambiguous_leaf_keys: set[str] = set() for group_name, term_name, term_cfg in actor_atomic_entries: full_key = f"{group_name}/{term_name}" by_full_key[full_key] = (term_name, term_cfg) if term_name in by_leaf_key: ambiguous_leaf_keys.add(term_name) else: by_leaf_key[term_name] = (term_name, term_cfg) ordered_atomic_list = [] for schema_term in schema_terms: schema_term_key = str(schema_term) if schema_term_key in by_full_key: term_name, term_cfg = by_full_key[schema_term_key] ordered_atomic_list.append({term_name: term_cfg}) continue leaf_key = schema_term_key.split("/")[-1] if leaf_key in ambiguous_leaf_keys: raise ValueError( "Actor obs_schema term " f"'{schema_term}' is ambiguous by leaf key '{leaf_key}'. " "Use explicit group/term hierarchy in obs_schema terms." ) if leaf_key not in by_leaf_key: raise ValueError( "Actor obs_schema term " f"'{schema_term}' is not present in obs_groups actor atomic obs list." 
) term_name, term_cfg = by_leaf_key[leaf_key] ordered_atomic_list.append({term_name: term_cfg}) return ordered_atomic_list # ----------------- Kinematics / velocities ----------------- # ----------------- Kinematics / velocities ----------------- def _body_origin_world_velocity( self, body_id: int ) -> tuple[np.ndarray, np.ndarray]: """Compute world-frame spatial velocity (v, w) of a body's frame origin. Returns: tuple: (lin_vel_w[3], ang_vel_w[3]) in world coordinates. """ # World-frame Jacobians for body origin jacp = np.zeros((3, self.m.nv), dtype=np.float64) jacr = np.zeros((3, self.m.nv), dtype=np.float64) mujoco.mj_jacBody(self.m, self.d, jacp, jacr, int(body_id)) # qvel is float64 in MuJoCo; keep computation in float64 then cast lin_vel_w = jacp @ self.d.qvel ang_vel_w = jacr @ self.d.qvel return lin_vel_w.astype(np.float32), ang_vel_w.astype(np.float32) # ----------------- Body name/id resolution ----------------- def _get_anchor_body_name(self) -> str: if not hasattr(self, "anchor_body_name"): self.anchor_body_name = str( getattr(self.config.robot, "anchor_body", "pelvis") ) logger.info(f"Anchor body name: {self.anchor_body_name}") return self.anchor_body_name def _get_torso_body_name(self) -> str: if not hasattr(self, "torso_body_name"): self.torso_body_name = str( getattr(self.config.robot, "torso_name", "torso_link") ) return self.torso_body_name @property def ref_motion_frame_idx(self): return self.motion_frame_idx @property def anchor_body_idx(self) -> int: return self.config.robot.body_names.index( self.config.robot.anchor_body ) @property def root_body_idx(self) -> int: return 0 @property def torso_body_idx(self) -> int: return self.config.robot.body_names.index(self.config.robot.torso_name) @property def robot_global_bodylink_pos(self): """World-frame positions of all robot bodies at their MuJoCo body frame origins. MuJoCo stores body state for a special world body at index 0, which does not correspond to any physical link and is always static. 
We slice it out and return `xpos[1:]` so that row 0 corresponds to the root body (e.g. pelvis) and the body dimension matches the HoloMotion NPZ `*_global_translation` arrays. Returns: np.ndarray: Array of shape [n_bodies, 3] in MuJoCo body order with the world body excluded. """ return self.d.xpos[1:] @property def robot_global_bodylink_rot(self): """World-frame orientations of all robot bodies as WXYZ quaternions. As with positions, the MuJoCo world body at index 0 is excluded so that the returned array is aligned with the body dimension used in HoloMotion NPZ `*_global_rotation_quat` arrays (root at index 0, no world entry). Returns: np.ndarray: Array of shape [n_bodies, 4] in MuJoCo body order with the world body excluded. """ xquat = self.d.xquat[1:] xquat_t = torch.as_tensor(xquat, dtype=torch.float32, device="cpu") xquat_t = standardize_quaternion(xquat_t) return xquat_t.detach().cpu().numpy() @property def robot_global_bodylink_lin_vel(self): """World-frame linear velocities of all robot body frame origins. Uses `mujoco.mj_objectVelocity` with `mjOBJ_BODY` and `flg_centered=0` to query the 6D spatial velocity at each body's frame origin, then slices the translational component. The world body (ID 0) is excluded so that the body dimension matches the NPZ `*_global_velocity` arrays. Returns: np.ndarray: Array of shape [n_bodies, 3] giving linear velocities in the MuJoCo world frame, ordered by body ID starting from the root body. """ nbody = int(self.m.nbody) vel_6d = np.zeros((nbody, 6), dtype=np.float64) for bid in range(1, nbody): mujoco.mj_objectVelocity( self.m, self.d, mujoco.mjtObj.mjOBJ_BODY, bid, vel_6d[bid], 0, ) return vel_6d[1:, 3:6] @property def robot_global_bodylink_ang_vel(self): """World-frame angular velocities of all robot body frame origins. Uses the same `mujoco.mj_objectVelocity` call as `robot_global_bodylink_lin_vel` and slices the rotational component. 
The world body (ID 0) is dropped so that the body dimension is identical to the NPZ `*_global_angular_velocity` arrays and the translation/rotation/velocity tensors all share the same body ordering. Returns: np.ndarray: Array of shape [n_bodies, 3] giving angular velocities in the MuJoCo world frame, ordered by body ID starting from the root body. """ nbody = int(self.m.nbody) vel_6d = np.zeros((nbody, 6), dtype=np.float64) for bid in range(1, nbody): mujoco.mj_objectVelocity( self.m, self.d, mujoco.mjtObj.mjOBJ_BODY, bid, vel_6d[bid], 0, ) return vel_6d[1:, 0:3] @property def robot_dof_pos(self): if hasattr(self, "actuator_qpos_indices"): return self.d.qpos[self.actuator_qpos_indices] return self.d.qpos[7:] @property def robot_dof_vel(self): if hasattr(self, "actuator_qvel_indices"): return self.d.qvel[self.actuator_qvel_indices] return self.d.qvel[6:] # ----------------- Reference->Simulation alignment ----------------- def _ensure_ref_to_sim_transform_rigid(self): """Compute rigid transform (yaw + translation) from reference globals to sim globals. The transform is defined such that the reference **anchor body** pose at frame 0 is mapped onto the robot's current global anchor pose in XY translation and yaw: - `yaw(q_ref_to_sim * q_ref_anchor_0) = yaw(q_robot_anchor_0)` - `t_ref_to_sim + R(q_ref_to_sim) @ t_ref_anchor_0 = t_robot_anchor_0` This uses the robot's initial global pose so that arbitrary initialization offsets in XY position and yaw between the robot and the reference motion are absorbed into the reference->simulation mapping, and all subsequent reference globals are expressed in the same world frame as the robot. """ if self._ref_to_sim_ready: return # If we don't have reference globals, fall back to identity transform. 
if getattr(self, "ref_global_translation", None) is None: self._ref_to_sim_q_wxyz = np.array( [1.0, 0.0, 0.0, 0.0], dtype=np.float32 ) self._ref_to_sim_t = np.zeros(3, dtype=np.float32) self._ref_to_sim_ready = True logger.info( "No reference global translations available; using identity Ref->Sim transform." ) return # If rotations are missing, keep the previous translation-only semantics. if getattr(self, "ref_global_rotation_quat_xyzw", None) is None: t_robot = torch.as_tensor( self.robot_global_bodylink_pos[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) t_ref = torch.as_tensor( self.ref_global_translation[0, self.anchor_body_idx].astype( np.float32 ), dtype=torch.float32, device="cpu", ) t_ref_to_sim = t_robot - t_ref self._ref_to_sim_q_wxyz = np.array( [1.0, 0.0, 0.0, 0.0], dtype=np.float32 ) self._ref_to_sim_t = t_ref_to_sim.detach().cpu().numpy() self._ref_to_sim_ready = True logger.info( "Reference rotations missing; initialized Ref->Sim as translation-only " f"transform. t={self._ref_to_sim_t}" ) return # Anchor body index shared between robot globals and reference globals anchor_idx = self.anchor_body_idx # Robot anchor pose in simulation world frame (after initial state has been set) t_robot = torch.as_tensor( self.robot_global_bodylink_pos[anchor_idx], dtype=torch.float32, device="cpu", ) q_robot_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[anchor_idx], dtype=torch.float32, device="cpu", ) # Reference anchor pose at frame 0 in NPZ global frame t_ref0 = torch.as_tensor( self.ref_global_translation[0, anchor_idx].astype(np.float32), dtype=torch.float32, device="cpu", ) q_ref0_xyzw = torch.as_tensor( self.ref_global_rotation_quat_xyzw[0, anchor_idx].astype( np.float32 ), dtype=torch.float32, device="cpu", ) q_ref0_wxyz = xyzw_to_wxyz(q_ref0_xyzw) # Yaw-only rotation mapping: align reference yaw to robot yaw (keep roll/pitch from reference). 
R_robot = matrix_from_quat(q_robot_wxyz) R_ref0 = matrix_from_quat(q_ref0_wxyz) yaw_robot = torch.atan2(R_robot[1, 0], R_robot[0, 0]) yaw_ref0 = torch.atan2(R_ref0[1, 0], R_ref0[0, 0]) yaw_delta = yaw_robot - yaw_ref0 yaw_quat_xyzw = quat_from_euler_xyz( torch.tensor(0.0, dtype=torch.float32, device="cpu"), torch.tensor(0.0, dtype=torch.float32, device="cpu"), yaw_delta, ) q_ref_to_sim = xyzw_to_wxyz(yaw_quat_xyzw) q_ref_to_sim = quat_normalize_wxyz(q_ref_to_sim) # Translation mapping: t_ref_to_sim + R(q_ref_to_sim) @ t_ref0 = t_robot t_ref0_in_sim = quat_apply(q_ref_to_sim, t_ref0) t_ref_to_sim = t_robot - t_ref0_in_sim self._ref_to_sim_q_wxyz = ( q_ref_to_sim.detach().cpu().numpy().astype(np.float32) ) self._ref_to_sim_t = ( t_ref_to_sim.detach().cpu().numpy().astype(np.float32) ) self._ref_to_sim_ready = True logger.info( "Initialized Ref->Sim rigid transform. " f"q={self._ref_to_sim_q_wxyz}, t={self._ref_to_sim_t}" ) def _detect_command_mode(self) -> str: m_dir = self.config.get("motion_npz_dir") or self.config.get( "eval", {} ).get("motion_npz_dir") m_path = self.config.get("motion_npz_path") or self.config.get( "eval", {} ).get("motion_npz_path") if m_path is not None and not os.path.exists(m_path): raise FileNotFoundError(f"Motion file not found: {m_path}") if (m_dir and str(m_dir) != "") or (m_path and str(m_path) != ""): return "motion_tracking" return "velocity_tracking" def _init_obs_buffers(self): atomic_list = self._get_policy_atomic_obs_list() obs_policy_cfg = {"atomic_obs_list": atomic_list} self.obs_builder = PolicyObsBuilder( dof_names_onnx=self.dof_names_onnx, default_angles_onnx=self.default_angles_onnx, evaluator=self, obs_policy_cfg=obs_policy_cfg, ) def load_policy(self): """Load the policy model using ONNX Runtime.""" onnx_model_path = Path(self.config.ckpt_onnx_path) logger.info(f"Loading ONNX policy from {onnx_model_path}") providers = ["CPUExecutionProvider"] use_gpu = _coerce_config_bool( self.config.get("use_gpu", False), default=False ) 
gpu_id = int(self.config.get("gpu_id", 0)) available_providers = onnxruntime.get_available_providers() if use_gpu: if "CUDAExecutionProvider" in available_providers: cuda_options = {"device_id": gpu_id} if torch.cuda.is_available(): torch.cuda.set_device(gpu_id) cuda_options["user_compute_stream"] = str( torch.cuda.current_stream().cuda_stream ) providers = [ ("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider", ] logger.info( f"Using CUDAExecutionProvider with gpu_id={gpu_id}" ) else: logger.warning( "use_gpu=true but CUDAExecutionProvider is unavailable; " "falling back to CPUExecutionProvider." ) sess_options = onnxruntime.SessionOptions() sess_options.intra_op_num_threads = 1 sess_options.inter_op_num_threads = 1 sess_options.log_severity_level = 3 self.policy_session = onnxruntime.InferenceSession( str(onnx_model_path), sess_options=sess_options, providers=providers, ) logger.info( f"ONNX Runtime session created successfully using: {self.policy_session.get_providers()}" ) self.policy_input_name = self.policy_session.get_inputs()[0].name self.policy_output_name = self.policy_session.get_outputs()[0].name logger.info( f"Policy ONNX Input: {self.policy_input_name}, Output: {self.policy_output_name}" ) logger.info("Initializing KV-Cache for Policy...") self.policy_input_name = "obs" self.policy_kv_input_name = None self.policy_step_input_name = None self.policy_kv_shape = None for node in self.policy_session.get_inputs(): name = node.name shape = node.shape logger.info(f"Model Input: Name={name}, Shape={shape}") if "obs" in name: self.policy_input_name = name elif "past_key_values" in name: self.policy_kv_input_name = name self.policy_kv_shape = shape elif "step_idx" in name or "step" in name or "pos" in name: self.policy_step_input_name = name self.policy_output_name = self.policy_session.get_outputs()[0].name self.policy_kv_output_name = None for node in self.policy_session.get_outputs(): if "present_key_values" in node.name: 
self.policy_kv_output_name = node.name self._discover_policy_moe_outputs() if self.policy_kv_input_name and self.policy_kv_shape: shape = [ d if isinstance(d, int) else 1 for d in self.policy_kv_shape ] self.policy_kv_cache = np.zeros(shape, dtype=np.float32) self.policy_model_context_len = ( int(shape[3]) if len(shape) > 3 else 0 ) if self.max_context_len > 0 and self.policy_model_context_len > 0: self.policy_effective_context_len = min( self.max_context_len, self.policy_model_context_len ) logger.info( "Using context window from " f"algo.config.num_steps_per_env={self.max_context_len} " f"(model cache len={self.policy_model_context_len}, " f"effective={self.policy_effective_context_len})" ) else: self.policy_effective_context_len = ( self.policy_model_context_len ) self.use_kv_cache = True logger.info(f"KV-Cache ENABLED. Shape: {shape}") else: self.use_kv_cache = False self.policy_kv_cache = None self.policy_model_context_len = 0 self.policy_effective_context_len = 0 logger.warning("KV-Cache NOT found in model inputs!") if self.max_context_len > 0: logger.warning( "algo.config.num_steps_per_env is set but KV-Cache is unavailable; " "ignoring context window limit." 
) logger.info("ONNX Policy loaded successfully") def _read_onnx_metadata(self) -> dict: """Read model metadata from ONNX file and parse into Python types.""" onnx_model_path = Path(self.config.ckpt_onnx_path) model = onnx.load(str(onnx_model_path)) meta = {p.key: p.value for p in model.metadata_props} def _parse_floats(csv_str: str): return np.array( [float(x) for x in csv_str.split(",") if x != ""], dtype=np.float32, ) result = {} result["action_scale"] = _parse_floats(meta["action_scale"]) result["kps"] = _parse_floats(meta["joint_stiffness"]) result["kds"] = _parse_floats(meta["joint_damping"]) result["default_joint_pos"] = _parse_floats(meta["default_joint_pos"]) result["joint_names"] = [ x for x in meta["joint_names"].split(",") if x != "" ] # 打印解析后的元数据 logger.info("=== Loaded ONNX Metadata ===") for key, value in result.items(): # 如果关节名称列表很长,进行格式化处理以保持整洁 if key == "joint_names": logger.info(f"{key}: {', '.join(value)}") else: logger.info(f"{key}:\n{value}") logger.info("============================") return result def _apply_onnx_metadata(self): """Apply PD/scale/defaults from ONNX metadata as authoritative values.""" meta = self._read_onnx_metadata() self.dof_names_onnx = meta["joint_names"] self.action_scale_onnx = meta["action_scale"].astype(np.float32) self.kps_onnx = meta["kps"].astype(np.float32) self.kds_onnx = meta["kds"].astype(np.float32) self.default_angles_onnx = meta["default_joint_pos"].astype(np.float32) def _build_dof_mappings(self): # Map ONNX <-> MJCF for control self.onnx_to_mu = [ self.dof_names_onnx.index(name) for name in self.mjcf_dof_names ] self.mu_to_onnx = [ self.mjcf_dof_names.index(name) for name in self.dof_names_onnx ] self.ref_to_onnx = [ self.dof_names_ref_motion.index(name) for name in self.dof_names_onnx ] # Map MuJoCo actuator DOF order -> reference DOF order used in motion npz self.mu_to_ref = [] for mu_idx in range(len(self.mjcf_dof_names)): onnx_idx = self.onnx_to_mu[mu_idx] ref_idx = self.ref_to_onnx[onnx_idx] 
self.mu_to_ref.append(ref_idx) self.kps_mu = self.kps_onnx[self.onnx_to_mu].astype(np.float32) self.kds_mu = self.kds_onnx[self.onnx_to_mu].astype(np.float32) self.default_angles_mu = self.default_angles_onnx[ self.onnx_to_mu ].astype(np.float32) self.action_scale_mu = self.action_scale_onnx[self.onnx_to_mu].astype( np.float32 ) @staticmethod def _normalize_filter_cutoff_hz(raw_values, num_frames: int) -> np.ndarray: num_frames = max(int(num_frames), 0) if num_frames == 0: return np.zeros((0, 1), dtype=np.float32) if raw_values is None: return np.zeros((num_frames, 1), dtype=np.float32) cutoff = np.asarray(raw_values, dtype=np.float32) if cutoff.ndim == 0: cutoff = np.full((num_frames, 1), float(cutoff), dtype=np.float32) return cutoff if cutoff.ndim == 1: cutoff = cutoff[:, None] else: cutoff = cutoff.reshape(cutoff.shape[0], -1)[:, :1] if cutoff.shape[0] == 0: return np.zeros((num_frames, 1), dtype=np.float32) if cutoff.shape[0] == 1 and num_frames > 1: cutoff = np.repeat(cutoff, num_frames, axis=0) elif cutoff.shape[0] < num_frames: pad = np.repeat( cutoff[-1:, :], num_frames - cutoff.shape[0], axis=0 ) cutoff = np.concatenate([cutoff, pad], axis=0) elif cutoff.shape[0] > num_frames: cutoff = cutoff[:num_frames] return cutoff.astype(np.float32, copy=False) def load_motion_data(self): """Load motion data from npz file.""" motion_npz_path = self.config.get("motion_npz_path", None) if motion_npz_path is None: logger.warning( "No motion_npz_path specified in config, using zero reference motion" ) return logger.info(f"Loading motion data from {motion_npz_path}") # Load npz file with np.load(motion_npz_path, allow_pickle=True) as npz: keys = list(npz.keys()) raw_filter_cutoff_hz = ( np.array(npz["filter_cutoff_hz"]).astype(np.float32) if "filter_cutoff_hz" in npz else None ) # Try direct arrays first (dof_pos, dof_vel or variants) naming_pairs = [ ("ref_dof_pos", "ref_dof_vel"), ("dof_pos", "dof_vels"), # backward compat # ("ft_ref_pos", "ft_ref_dof_vel"), ] pos_key = 
None vel_key = None for pos_k, vel_k in naming_pairs: if pos_k in npz and vel_k in npz: pos_key = pos_k vel_key = vel_k break if pos_key is not None and vel_key is not None: # Direct arrays found self.ref_dof_pos = np.array(npz[pos_key]).astype(np.float32) self.ref_dof_vel = np.array(npz[vel_key]).astype(np.float32) elif len(keys) == 1: # Single key - might contain nested dict arr = npz[keys[0]] if getattr(arr, "dtype", None) == object: obj = arr.item() if arr.size == 1 else arr if isinstance(obj, dict): if ( raw_filter_cutoff_hz is None and "filter_cutoff_hz" in obj ): raw_filter_cutoff_hz = np.array( obj["filter_cutoff_hz"] ).astype(np.float32) # Try to find dof_pos/dof_vel in nested dict for pos_k, vel_k in naming_pairs: if pos_k in obj and vel_k in obj: self.ref_dof_pos = np.array(obj[pos_k]).astype( np.float32 ) self.ref_dof_vel = np.array(obj[vel_k]).astype( np.float32 ) break else: raise ValueError( f"Could not find dof_pos/dof_vel in nested dict. " f"Available keys: {list(obj.keys())}" ) else: raise ValueError( f"Single key '{keys[0]}' does not contain a dict. " f"Type: {type(obj)}" ) else: raise ValueError( f"Single key '{keys[0]}' is not an object array. " f"Available keys: {keys}" ) else: raise ValueError( f"Could not find dof_pos/dof_vel arrays. 
Available keys: {keys}" ) # Ensure consistent frame count if self.ref_dof_pos.shape[0] != self.ref_dof_vel.shape[0]: min_frames = min( self.ref_dof_pos.shape[0], self.ref_dof_vel.shape[0] ) self.ref_dof_pos = self.ref_dof_pos[:min_frames] self.ref_dof_vel = self.ref_dof_vel[:min_frames] logger.warning( f"Frame count mismatch, truncated to {min_frames} frames" ) self.n_motion_frames = self.ref_dof_pos.shape[0] # Optional: load reference global body frames as per motion spec ref_pos_keys = ["ref_global_translation", "global_translation"] ref_rot_keys = ["ref_global_rotation_quat", "global_rotation_quat"] ref_vel_keys = ["ref_global_velocity", "global_velocity"] ref_ang_vel_keys = [ "ref_global_angular_velocity", "global_angular_velocity", ] self.ref_global_translation = None self.ref_global_rotation_quat_xyzw = None self.ref_global_velocity = None self.ref_global_angular_velocity = None for k in ref_pos_keys: if k in npz: self.ref_global_translation = np.array(npz[k]).astype( np.float32 ) break for k in ref_rot_keys: if k in npz: self.ref_global_rotation_quat_xyzw = np.array( npz[k] ).astype(np.float32) break for k in ref_vel_keys: if k in npz: self.ref_global_velocity = np.array(npz[k]).astype( np.float32 ) break for k in ref_ang_vel_keys: if k in npz: self.ref_global_angular_velocity = np.array(npz[k]).astype( np.float32 ) break if self.ref_global_translation is not None: # Truncate to motion frames if needed t_tr = min( self.n_motion_frames, self.ref_global_translation.shape[0] ) if t_tr < self.n_motion_frames: logger.warning( f"Global translation shorter than motion frames ({t_tr} < {self.n_motion_frames}), truncating motion." 
)
                    self.n_motion_frames = t_tr
                    self.ref_dof_pos = self.ref_dof_pos[:t_tr]
                    self.ref_dof_vel = self.ref_dof_vel[:t_tr]
                # Clip translations to the (possibly shortened) motion length.
                self.ref_global_translation = self.ref_global_translation[
                    :t_tr
                ]
            if self.ref_global_rotation_quat_xyzw is not None:
                # Same procedure for rotations; if the motion gets shorter
                # here, already-loaded translations are shortened too.
                t_rr = min(
                    self.n_motion_frames,
                    self.ref_global_rotation_quat_xyzw.shape[0],
                )
                if t_rr < self.n_motion_frames:
                    logger.warning(
                        f"Global rotation shorter than motion frames ({t_rr} < {self.n_motion_frames}), truncating motion."
                    )
                    self.n_motion_frames = t_rr
                    self.ref_dof_pos = self.ref_dof_pos[:t_rr]
                    self.ref_dof_vel = self.ref_dof_vel[:t_rr]
                    # Also truncate previously processed globals if necessary
                    if self.ref_global_translation is not None:
                        self.ref_global_translation = (
                            self.ref_global_translation[:t_rr]
                        )
                self.ref_global_rotation_quat_xyzw = (
                    self.ref_global_rotation_quat_xyzw[:t_rr]
                )
            if self.ref_global_velocity is not None:
                # Velocities: same truncation, without a warning.
                t_rv = min(
                    self.n_motion_frames,
                    self.ref_global_velocity.shape[0],
                )
                if t_rv < self.n_motion_frames:
                    self.n_motion_frames = t_rv
                    self.ref_dof_pos = self.ref_dof_pos[:t_rv]
                    self.ref_dof_vel = self.ref_dof_vel[:t_rv]
                    if self.ref_global_translation is not None:
                        self.ref_global_translation = (
                            self.ref_global_translation[:t_rv]
                        )
                    if self.ref_global_rotation_quat_xyzw is not None:
                        self.ref_global_rotation_quat_xyzw = (
                            self.ref_global_rotation_quat_xyzw[:t_rv]
                        )
                self.ref_global_velocity = self.ref_global_velocity[:t_rv]
            if self.ref_global_angular_velocity is not None:
                # Angular velocities: same truncation, without a warning.
                t_ra = min(
                    self.n_motion_frames,
                    self.ref_global_angular_velocity.shape[0],
                )
                if t_ra < self.n_motion_frames:
                    self.n_motion_frames = t_ra
                    self.ref_dof_pos = self.ref_dof_pos[:t_ra]
                    self.ref_dof_vel = self.ref_dof_vel[:t_ra]
                    if self.ref_global_translation is not None:
                        self.ref_global_translation = (
                            self.ref_global_translation[:t_ra]
                        )
                    if self.ref_global_rotation_quat_xyzw is not None:
                        self.ref_global_rotation_quat_xyzw = (
                            self.ref_global_rotation_quat_xyzw[:t_ra]
                        )
                    if self.ref_global_velocity is not None:
                        self.ref_global_velocity = self.ref_global_velocity[
                            :t_ra
                        ]
                self.ref_global_angular_velocity = (
                    self.ref_global_angular_velocity[:t_ra]
                )

        # Per-frame cutoffs, aligned with the final motion length.
        self.filter_cutoff_hz = self._normalize_filter_cutoff_hz(
            raw_filter_cutoff_hz, self.n_motion_frames
        )
        logger.info(
            f"Loaded motion data with {self.n_motion_frames} frames and {self.ref_dof_pos.shape[1]} DOFs"
        )

    def load_mujoco_model(self):
        """Load the MuJoCo model.

        Builds ``self.m`` (``MjModel``) from ``config.robot_xml_path`` and
        ``self.d`` (``MjData``), then overrides the model timestep with
        ``self.simulation_dt``.

        Raises:
            ValueError: If ``robot_xml_path`` is not configured.
        """
        xml_path = self.config.get("robot_xml_path", None)
        if xml_path is None:
            raise ValueError(
                "robot_xml_path should be specified in config !!!"
            )
        logger.info(f"Loading MuJoCo model from {xml_path}")
        self.m = mujoco.MjModel.from_xml_path(xml_path)
        self.d = mujoco.MjData(self.m)
        # The simulation step comes from the evaluator config, not the XML.
        self.m.opt.timestep = self.simulation_dt
        logger.info(
            f"MuJoCo model loaded with {self.m.nq} position DOFs and {self.m.nu} control DOFs"
        )

    def _init_camera_config(self):
        """Initialize shared camera configuration for viewer and offscreen renderers.

        Resolves ``self._root_body_id``, the body the cameras follow. It
        tries the configured anchor body first, then common root-body names.
        It stays -1 when tracking is disabled or no candidate body exists.
        """
        self._root_body_id = -1
        if not self._camera_tracking_enabled:
            logger.info("Camera tracking disabled")
            return

        # Prefer anchor body from robot config, then fall back to common root names
        candidates: list[str] = []
        anchor_name = self._get_anchor_body_name()
        candidates.append(anchor_name)
        for name in ["pelvis", "torso", "base_link", "trunk", "root"]:
            if name not in candidates:
                candidates.append(name)

        for body_name in candidates:
            # mj_name2id returns -1 when no body has this name.
            bid = int(
                mujoco.mj_name2id(self.m, mujoco.mjtObj.mjOBJ_BODY, body_name)
            )
            if bid != -1:
                self._root_body_id = bid
                break

        if self._root_body_id != -1:
            logger.info(
                f"Camera tracking enabled for body '{body_name}' (ID={self._root_body_id}), "
                f"lookat height offset: {self._camera_height_offset:.2f}m"
            )
        else:
            logger.warning(
                "Could not find robot root body for camera tracking; "
                "viewer and offscreen cameras will not track the robot."
) def _configure_viewer_camera(self, viewer): """Apply shared align-view parameters to the interactive viewer camera.""" mujoco.mjv_defaultFreeCamera(self.m, viewer.cam) viewer.cam.azimuth = self._camera_azimuth viewer.cam.elevation = self._camera_elevation viewer.cam.distance = self._camera_distance def _init_video_tools(self, tag: str): """Initialize video writer and offscreen renderer when recording is enabled.""" if not bool(self.config.get("record_video", False)): return width = int(self.config.get("video_width", 1280)) height = int(self.config.get("video_height", 720)) fps = float(self.config.get("video_fps", 30.0)) onnx_stem = os.path.splitext( os.path.basename(self.config.ckpt_onnx_path) )[0] output_dir = os.path.join( os.path.dirname(self.config.ckpt_onnx_path), f"mujoco_output_{onnx_stem}", ) os.makedirs(output_dir, exist_ok=True) motion_npz_path = self.config.get("motion_npz_path", None) if motion_npz_path is not None: motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[ 0 ] out_path = os.path.join(output_dir, f"{motion_stem}.mp4") else: out_path = os.path.join(output_dir, "velocity_tracking.mp4") fourcc = cv2.VideoWriter_fourcc(*"mp4v") self._video_writer = cv2.VideoWriter( out_path, fourcc, fps, (width, height) ) self._offscreen = OffscreenRenderer( self.m, height, width, distance=self._camera_distance, azimuth=self._camera_azimuth, elevation=self._camera_elevation, ) self._frame_interval = 1.0 / max(fps, 1.0) self._last_frame_time = 0.0 if getattr(self, "ref_global_translation", None) is not None: self._offscreen.set_overlay_callback( lambda scene: self._draw_ref_body_spheres_to_scene( scene, reset_ngeom=False ) ) logger.info(f"Recording enabled. Writing to: {out_path}") def _dump_robot_augmented_npz(self) -> None: """Copy original motion npz and append robot_* states, saved next to video output. 
The output follows the holomotion offline-eval spec used in PPO: - robot_dof_pos, robot_dof_vel: [T, num_dofs] - robot_global_translation: [T, num_bodies, 3] - robot_global_rotation_quat: [T, num_bodies, 4] (XYZW) - robot_global_velocity: [T, num_bodies, 3] - robot_global_angular_velocity: [T, num_bodies, 3] """ motion_npz_path = self.config.get("motion_npz_path", None) if motion_npz_path is None: return if len(self._robot_dof_pos_seq) == 0: return # Stack recorded sequences robot_dof_pos = np.stack(self._robot_dof_pos_seq, axis=0).astype( np.float32 ) robot_dof_vel = np.stack(self._robot_dof_vel_seq, axis=0).astype( np.float32 ) robot_dof_acc = np.stack(self._robot_dof_acc_seq, axis=0).astype( np.float32 ) robot_dof_torque = np.stack(self._robot_dof_torque_seq, axis=0).astype( np.float32 ) robot_low_level_dof_torque = None if len(self._robot_low_level_dof_torque_seq) > 0: robot_low_level_dof_torque = np.stack( self._robot_low_level_dof_torque_seq, axis=0 ).astype(np.float32) ( robot_low_level_foot_contact, robot_low_level_foot_normal_force, robot_low_level_foot_tangent_speed, ) = self._get_stacked_low_level_foot_contact_tensors() robot_actions = None if len(getattr(self, "_robot_actions_seq", [])) > 0: robot_actions = np.stack(self._robot_actions_seq, axis=0).astype( np.float32 ) robot_action_rate = np.asarray( self._robot_action_rate_seq, dtype=np.float32 ) robot_global_translation = np.stack( self._robot_global_translation_seq, axis=0 ).astype(np.float32) robot_global_rotation_quat = np.stack( self._robot_global_rotation_quat_seq, axis=0 ).astype(np.float32) robot_global_velocity = np.stack( self._robot_global_velocity_seq, axis=0 ).astype(np.float32) robot_global_angular_velocity = np.stack( self._robot_global_angular_velocity_seq, axis=0 ).astype(np.float32) robot_moe_expert_indices, robot_moe_expert_logits = ( self._get_stacked_moe_routing_tensors() ) # Load original motion npz with np.load(motion_npz_path, allow_pickle=True) as npz: data_dict = {k: npz[k] 
for k in npz.files} # Augment with robot_* arrays (override if already present) data_dict["robot_dof_pos"] = robot_dof_pos data_dict["robot_dof_vel"] = robot_dof_vel data_dict["robot_dof_acc"] = robot_dof_acc data_dict["robot_dof_torque"] = robot_dof_torque if robot_low_level_dof_torque is not None: data_dict["robot_low_level_dof_torque"] = ( robot_low_level_dof_torque ) if robot_low_level_foot_contact is not None: data_dict["robot_low_level_foot_contact"] = ( robot_low_level_foot_contact ) if robot_low_level_foot_normal_force is not None: data_dict["robot_low_level_foot_normal_force"] = ( robot_low_level_foot_normal_force ) if robot_low_level_foot_tangent_speed is not None: data_dict["robot_low_level_foot_tangent_speed"] = ( robot_low_level_foot_tangent_speed ) if robot_actions is not None: data_dict["robot_actions"] = robot_actions data_dict["robot_low_level_torque_dt"] = np.array( self.simulation_dt, dtype=np.float32 ) data_dict["robot_low_level_contact_dt"] = np.array( self.simulation_dt, dtype=np.float32 ) data_dict["robot_action_rate"] = robot_action_rate data_dict["robot_global_translation"] = robot_global_translation data_dict["robot_global_rotation_quat"] = robot_global_rotation_quat data_dict["robot_global_velocity"] = robot_global_velocity data_dict["robot_global_angular_velocity"] = ( robot_global_angular_velocity ) if robot_moe_expert_indices is not None: data_dict["robot_moe_expert_indices"] = robot_moe_expert_indices if robot_moe_expert_logits is not None: data_dict["robot_moe_expert_logits"] = robot_moe_expert_logits # Derive output directory consistent with video writer onnx_stem = os.path.splitext( os.path.basename(self.config.ckpt_onnx_path) )[0] output_dir = os.path.join( os.path.dirname(self.config.ckpt_onnx_path), f"mujoco_output_{onnx_stem}", ) os.makedirs(output_dir, exist_ok=True) motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[0] out_npz_path = os.path.join(output_dir, f"{motion_stem}_robot.npz") 
np.savez_compressed(out_npz_path, **data_dict) logger.info( f"Robot-augmented motion npz saved to: {out_npz_path} " f"(T={robot_dof_pos.shape[0]}, num_dofs={robot_dof_pos.shape[1]}, " f"num_bodies={robot_global_translation.shape[1]})" ) def _close_video_tools(self): if self._video_writer is not None: self._video_writer.release() self._video_writer = None if self._offscreen is not None: self._offscreen.close() self._offscreen = None self._frame_interval = None self._last_frame_time = 0.0 def _update_camera_lookat(self, cam): """Update camera lookat to track the robot root when tracking is enabled.""" if not self._camera_tracking_enabled: return if self._root_body_id == -1: return cam.lookat[:2] = self.d.xpos[self._root_body_id][:2] cam.lookat[2] = ( self.d.xpos[self._root_body_id][2] + self._camera_height_offset ) def _maybe_record_frame(self): if self._video_writer is None or self._offscreen is None: return now = time.time() if ( self._last_frame_time == 0.0 or (now - self._last_frame_time) >= self._frame_interval ): # Update offscreen camera lookat to track robot (if enabled) self._update_camera_lookat(self._offscreen._cam) frame_rgb = self._offscreen.render(self.d) # Convert RGB (MuJoCo) -> BGR (OpenCV) before writing frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) self._video_writer.write(frame_bgr) self._last_frame_time = now def _apply_control(self, sleep: bool): """Apply PD targets via Unitree lowcmd, step MuJoCo, optionally sleep.""" for _ in range(self.control_decimation): record_low_level_torque = ( self.command_mode == "motion_tracking" and self.ref_dof_pos is not None ) if record_low_level_torque: torque_ref = np.zeros( len(self.dof_names_ref_motion), dtype=np.float32 ) current_dof_pos = self.robot_dof_pos current_dof_vel = self.robot_dof_vel for name, act_idx in self.actuator_name_to_index.items(): mu_idx = self.actuator_name_to_mu_idx[name] joint_name = self.mjcf_dof_names[mu_idx] target_q = self.target_dof_pos_by_name.get( joint_name, 
float(self.default_angles_mu[mu_idx]), ) target_dq = 0.0 feedforward_tau = 0.0 kp = self.kps_mu[mu_idx] kd = self.kds_mu[mu_idx] current_q = current_dof_pos[mu_idx] current_dq = current_dof_vel[mu_idx] tau = ( feedforward_tau + kp * (target_q - current_q) + kd * (target_dq - current_dq) ) if ( act_idx in self.actuator_force_range and self.actuator_force_range[act_idx] is not None ): min_force, max_force = self.actuator_force_range[act_idx] tau = np.clip(tau, min_force, max_force) self.d.ctrl[mu_idx] = tau if record_low_level_torque: torque_ref[self.mu_to_ref[mu_idx]] = np.float32(tau) mujoco.mj_step(self.m, self.d) if record_low_level_torque: self._robot_low_level_dof_torque_seq.append(torque_ref) self._record_low_level_foot_contact_sample() if sleep: time.sleep(self.simulation_dt) def _compute_pd_torque_command_ref(self) -> np.ndarray: current_dof_pos = self.robot_dof_pos current_dof_vel = self.robot_dof_vel num_mu_dofs = len(self.mjcf_dof_names) torque_mu = np.zeros(num_mu_dofs, dtype=np.float32) for name, act_idx in self.actuator_name_to_index.items(): mu_idx = self.actuator_name_to_mu_idx[name] joint_name = self.mjcf_dof_names[mu_idx] target_q = self.target_dof_pos_by_name.get( joint_name, float(self.default_angles_mu[mu_idx]), ) target_dq = 0.0 feedforward_tau = 0.0 kp = self.kps_mu[mu_idx] kd = self.kds_mu[mu_idx] current_q = current_dof_pos[mu_idx] current_dq = current_dof_vel[mu_idx] tau = ( feedforward_tau + kp * (target_q - current_q) + kd * (target_dq - current_dq) ) if ( act_idx in self.actuator_force_range and self.actuator_force_range[act_idx] is not None ): min_force, max_force = self.actuator_force_range[act_idx] tau = np.clip(tau, min_force, max_force) torque_mu[mu_idx] = np.float32(tau) num_ref_dofs = len(self.dof_names_ref_motion) torque_ref = np.zeros(num_ref_dofs, dtype=np.float32) for mu_idx, ref_idx in enumerate(self.mu_to_ref): torque_ref[ref_idx] = torque_mu[mu_idx] return torque_ref def _get_obs_ref_motion_states(self): # [2 * num_actions] 
in ONNX order: [ref_pos, ref_vel] if self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(2 * self.num_actions, dtype=np.float32) frame_idx = self.motion_frame_idx ref_pos_mu = self.ref_dof_pos[frame_idx] ref_vel_mu = self.ref_dof_vel[frame_idx] # Map URDF/Mu order -> ONNX order using precomputed indices ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32) ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32) return np.concatenate([ref_pos_onnx, ref_vel_onnx], axis=0).astype( np.float32 ) def _get_obs_ref_motion_states_fut(self): # [T, 2 * num_actions] flattened, ONNX order T = int(self.n_fut_frames) if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(0, dtype=np.float32) N = int(self.num_actions) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 # Build future arrays in Mu order [N, T] pos_fut = np.zeros( (len(self.dof_names_ref_motion), T), dtype=np.float32 ) vel_fut = np.zeros( (len(self.dof_names_ref_motion), T), dtype=np.float32 ) for i in range(T): idx = frame_idx + i + 1 if idx < self.n_motion_frames: pos_fut[:, i] = self.ref_dof_pos[idx] vel_fut[:, i] = self.ref_dof_vel[idx] else: pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx] vel_fut[:, i] = self.ref_dof_vel[last_valid_frame_idx] # Reorder to ONNX and flatten per training layout pos_fut_onnx = pos_fut[self.ref_to_onnx, :] # [N, T] vel_fut_onnx = vel_fut[self.ref_to_onnx, :] # [N, T] fut_concat = np.concatenate( [pos_fut_onnx.T, vel_fut_onnx.T], axis=1 ) # [T, 2N] return fut_concat.reshape(-1).astype(np.float32) def _get_obs_ref_dof_pos_fut(self): # [T, 2 * num_actions] flattened, ONNX order T = int(self.n_fut_frames) if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(0, dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 # Build future arrays in Mu order [N, T] pos_fut = np.zeros( (len(self.dof_names_ref_motion), T), 
dtype=np.float32 ) for i in range(T): idx = frame_idx + i + 1 if idx < self.n_motion_frames: pos_fut[:, i] = self.ref_dof_pos[idx] else: pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx] # Reorder to ONNX and flatten per training layout pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T] return pos_fut_onnx.reshape(-1).astype(np.float32) def _get_obs_ref_root_height_fut(self): T = int(self.n_fut_frames) if ( T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None or getattr(self, "ref_global_translation", None) is None ): return np.zeros(0, dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 # Build future arrays in Mu order [N, T] h_fut = np.zeros((1, T), dtype=np.float32) for i in range(T): idx = frame_idx + i + 1 if idx < self.n_motion_frames: h_fut[:, i] = self.ref_global_translation[ idx, self.root_body_idx, 2 ] else: h_fut[:, i] = self.ref_global_translation[ last_valid_frame_idx, self.root_body_idx, 2 ] return h_fut.reshape(-1).astype(np.float32) def _get_obs_ref_dof_pos_cur(self): # [2 * num_actions] in ONNX order: [ref_pos, ref_vel] if self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(2 * self.num_actions, dtype=np.float32) ref_pos_mu = self.ref_dof_pos[self.motion_frame_idx] # Map URDF/Mu order -> ONNX order using precomputed indices ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32) return ref_pos_onnx def _get_obs_ref_dof_vel_cur(self): # [2 * num_actions] in ONNX order: [ref_pos, ref_vel] if self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(2 * self.num_actions, dtype=np.float32) ref_vel_mu = self.ref_dof_vel[self.motion_frame_idx] # Map URDF/Mu order -> ONNX order using precomputed indices ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32) return ref_vel_onnx def _get_obs_ref_motion_filter_cutoff_hz(self): # cutoff = getattr(self, "filter_cutoff_hz", None) cutoff = 1.0 if cutoff is None: return np.float32(0.0) 
cutoff_flat = np.asarray(cutoff, dtype=np.float32).reshape(-1) if cutoff_flat.size == 0: return np.float32(0.0) frame_idx = min( max(int(getattr(self, "motion_frame_idx", 0)), 0), cutoff_flat.size - 1, ) return np.float32(cutoff_flat[frame_idx]) def _get_obs_ref_root_height_cur(self): if getattr(self, "ref_global_translation", None) is None: return 0.0 return self.ref_global_translation[ self.motion_frame_idx, self.root_body_idx, 2 ] def _get_obs_ref_gravity_projection_cur(self): if ( getattr(self, "ref_global_rotation_quat_xyzw", None) is None or self.n_motion_frames <= 0 ): return np.zeros(3, dtype=np.float32) q_root_xyzw = self.ref_global_rotation_quat_xyzw[ self.motion_frame_idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device="cpu") g_root = quat_apply(quat_inv(q_root_wxyz), g_w) return g_root.detach().cpu().numpy().astype(np.float32) def _get_obs_ref_gravity_projection_fut(self): T = int(self.n_fut_frames) if ( T <= 0 or getattr(self, "ref_global_rotation_quat_xyzw", None) is None or self.n_motion_frames <= 0 ): return np.zeros(0, dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device="cpu") gravity_fut = np.zeros((T, 3), dtype=np.float32) for i in range(T): idx = frame_idx + i + 1 if idx >= self.n_motion_frames: idx = last_valid_frame_idx q_root_xyzw = self.ref_global_rotation_quat_xyzw[ idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) gravity_fut[i] = ( quat_apply(quat_inv(q_root_wxyz), g_w) .detach() .cpu() .numpy() .astype(np.float32) ) return gravity_fut.reshape(-1).astype(np.float32) def 
_get_obs_ref_base_linvel_cur(self): if ( getattr(self, "ref_global_rotation_quat_xyzw", None) is None or getattr(self, "ref_global_velocity", None) is None or self.n_motion_frames <= 0 ): return np.zeros(3, dtype=np.float32) q_root_xyzw = self.ref_global_rotation_quat_xyzw[ self.motion_frame_idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) v_root_w = torch.as_tensor( self.ref_global_velocity[ self.motion_frame_idx, self.root_body_idx ], dtype=torch.float32, device="cpu", ) v_root = quat_apply(quat_inv(q_root_wxyz), v_root_w) return v_root.detach().cpu().numpy().astype(np.float32) def _get_obs_ref_base_linvel_fut(self): T = int(self.n_fut_frames) if ( T <= 0 or getattr(self, "ref_global_rotation_quat_xyzw", None) is None or getattr(self, "ref_global_velocity", None) is None or self.n_motion_frames <= 0 ): return np.zeros(0, dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 base_linvel_fut = np.zeros((T, 3), dtype=np.float32) for i in range(T): idx = frame_idx + i + 1 if idx >= self.n_motion_frames: idx = last_valid_frame_idx q_root_xyzw = self.ref_global_rotation_quat_xyzw[ idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) v_root_w = torch.as_tensor( self.ref_global_velocity[idx, self.root_body_idx], dtype=torch.float32, device="cpu", ) base_linvel_fut[i] = ( quat_apply(quat_inv(q_root_wxyz), v_root_w) .detach() .cpu() .numpy() .astype(np.float32) ) return base_linvel_fut.reshape(-1).astype(np.float32) def _get_obs_ref_base_angvel_cur(self): if ( getattr(self, "ref_global_rotation_quat_xyzw", None) is None or getattr(self, "ref_global_angular_velocity", None) is None or self.n_motion_frames <= 0 ): return np.zeros(3, dtype=np.float32) 
q_root_xyzw = self.ref_global_rotation_quat_xyzw[ self.motion_frame_idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) w_root_w = torch.as_tensor( self.ref_global_angular_velocity[ self.motion_frame_idx, self.root_body_idx ], dtype=torch.float32, device="cpu", ) w_root = quat_apply(quat_inv(q_root_wxyz), w_root_w) return w_root.detach().cpu().numpy().astype(np.float32) def _get_obs_ref_base_angvel_fut(self): T = int(self.n_fut_frames) if ( T <= 0 or getattr(self, "ref_global_rotation_quat_xyzw", None) is None or getattr(self, "ref_global_angular_velocity", None) is None or self.n_motion_frames <= 0 ): return np.zeros(0, dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 base_angvel_fut = np.zeros((T, 3), dtype=np.float32) for i in range(T): idx = frame_idx + i + 1 if idx >= self.n_motion_frames: idx = last_valid_frame_idx q_root_xyzw = self.ref_global_rotation_quat_xyzw[ idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) w_root_w = torch.as_tensor( self.ref_global_angular_velocity[idx, self.root_body_idx], dtype=torch.float32, device="cpu", ) base_angvel_fut[i] = ( quat_apply(quat_inv(q_root_wxyz), w_root_w) .detach() .cpu() .numpy() .astype(np.float32) ) return base_angvel_fut.reshape(-1).astype(np.float32) def _get_obs_ref_keybody_rel_pos_cur(self): keybody_idxs = self._get_ref_keybody_indices( "actor_ref_keybody_rel_pos_cur" ) n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros(0, dtype=np.float32) if ( getattr(self, "ref_global_translation", None) is None or getattr(self, "ref_global_rotation_quat_xyzw", None) is None or self.n_motion_frames <= 0 ): return np.zeros(n_keybodies * 3, dtype=np.float32) frame_idx = 
self.motion_frame_idx ref_body_global_pos = self.ref_global_translation[frame_idx].astype( np.float32 ) # [B, 3] ref_root_global_pos = ref_body_global_pos[ self.root_body_idx ] # [3], world q_root_xyzw = self.ref_global_rotation_quat_xyzw[ frame_idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) rel_pos_w = ( ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :] ) # [K, 3] rel_pos_w_t = torch.as_tensor( rel_pos_w, dtype=torch.float32, device="cpu" ) q_root_expand = q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4) rel_pos_root_t = quat_apply(quat_inv(q_root_expand), rel_pos_w_t) return ( rel_pos_root_t.detach() .cpu() .numpy() .reshape(-1) .astype(np.float32) ) def _get_obs_ref_keybody_rel_pos_fut(self): T = int(self.n_fut_frames) if T <= 0: return np.zeros(0, dtype=np.float32) keybody_idxs = self._get_ref_keybody_indices( "actor_ref_keybody_rel_pos_fut" ) n_keybodies = int(keybody_idxs.shape[0]) if n_keybodies == 0: return np.zeros((T, 0), dtype=np.float32) if ( getattr(self, "ref_global_translation", None) is None or getattr(self, "ref_global_rotation_quat_xyzw", None) is None or self.n_motion_frames <= 0 ): return np.zeros((T, n_keybodies * 3), dtype=np.float32) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 rel_pos_fut = np.zeros((T, n_keybodies, 3), dtype=np.float32) for i in range(T): idx = frame_idx + i + 1 if idx >= self.n_motion_frames: idx = last_valid_frame_idx ref_body_global_pos = self.ref_global_translation[idx].astype( np.float32 ) # [B, 3] ref_root_global_pos = ref_body_global_pos[ self.root_body_idx ] # [3], world q_root_xyzw = self.ref_global_rotation_quat_xyzw[ idx, self.root_body_idx ].astype(np.float32) q_root_wxyz = xyzw_to_wxyz( torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu") ) q_root_wxyz = standardize_quaternion(q_root_wxyz) rel_pos_w = ( 
ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :] ) # [K, 3] rel_pos_w_t = torch.as_tensor( rel_pos_w, dtype=torch.float32, device="cpu" ) q_root_expand = q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4) rel_pos_fut[i] = ( quat_apply(quat_inv(q_root_expand), rel_pos_w_t) .detach() .cpu() .numpy() .astype(np.float32) ) return rel_pos_fut.reshape(T, -1).astype(np.float32) def _get_obs_place_holder(self): return np.zeros(self.actor_place_holder_ndim, dtype=np.float32) def _get_obs_vr_ref_motion_states(self): # [2 * num_actions] in ONNX order: [ref_pos, ref_vel] if self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(2 * self.num_actions, dtype=np.float32) frame_idx = self.motion_frame_idx ref_pos_mu = self.ref_dof_pos[frame_idx] # Map URDF/Mu order -> ONNX order using precomputed indices ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32) return np.concatenate( [ref_pos_onnx, np.zeros_like(ref_pos_onnx)], axis=0, ).astype(np.float32) def _get_obs_vr_ref_motion_states_fut(self): # [T, 2 * num_actions] flattened, ONNX order T = int(self.n_fut_frames) if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None: return np.zeros(0, dtype=np.float32) N = int(self.num_actions) frame_idx = self.motion_frame_idx last_valid_frame_idx = self.n_motion_frames - 1 # Build future arrays in Mu order [N, T] pos_fut = np.zeros( (len(self.dof_names_ref_motion), T), dtype=np.float32 ) for i in range(T): idx = frame_idx + i + 1 if idx < self.n_motion_frames: pos_fut[:, i] = self.ref_dof_pos[idx] else: pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx] # Reorder to ONNX and flatten per training layout pos_fut_onnx = pos_fut[self.ref_to_onnx, :] # [N, T] fut_concat = np.concatenate( [pos_fut_onnx.T, np.zeros_like(pos_fut_onnx.T)], axis=1 ) # [T, 2N] return fut_concat.reshape(-1).astype(np.float32) def _get_obs_rel_robot_root_ang_vel(self): q_root_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[self.root_body_idx], 
dtype=torch.float32, device="cpu", ) w_root_w = torch.as_tensor( self.robot_global_bodylink_ang_vel[self.root_body_idx], dtype=torch.float32, device="cpu", ) w_root_b = quat_apply(quat_inv(q_root_wxyz), w_root_w) return w_root_b.detach().cpu().numpy().astype(np.float32) def _get_obs_last_action(self): return np.array(self.actions_onnx, dtype=np.float32).reshape(-1) def _get_obs_velocity_command(self): # Extended velocity command: [move_mask, vx, vy, vyaw] if ( self.command_mode == "velocity_tracking" and getattr(self, "keyboard_handler", None) is not None ): cmd = np.asarray( self.keyboard_handler.get_velocity_command(), dtype=np.float32 ).reshape(3) else: cmd = np.zeros(3, dtype=np.float32) out = np.zeros(4, dtype=np.float32) out[1:] = cmd out[0] = float(np.linalg.norm(cmd) > 0.1) return out def _get_obs_actor_ref_headling_aligned_vel_cmd(self): return self._get_obs_velocity_command() # ----------------- Actor term aliases (PULSE stage2 unified obs) ----------------- def _get_obs_actor_velocity_command(self): return self._get_obs_velocity_command() def _get_obs_actor_projected_gravity(self): return self._get_obs_projected_gravity() def _get_obs_actor_rel_robot_root_ang_vel(self): return self._get_obs_rel_robot_root_ang_vel() def _get_obs_actor_dof_pos(self): return self._get_obs_dof_pos() def _get_obs_actor_dof_vel(self): return self._get_obs_dof_vel() def _get_obs_actor_last_action(self): return self._get_obs_last_action() def _get_obs_actor_place_holder(self): return self._get_obs_place_holder() def _get_obs_actor_ref_dof_pos_fut(self): return self._get_obs_ref_dof_pos_fut() def _get_obs_actor_ref_dof_pos_cur(self): return self._get_obs_ref_dof_pos_cur() def _get_obs_actor_ref_motion_filter_cutoff_hz(self): return self._get_obs_ref_motion_filter_cutoff_hz() def _get_obs_actor_ref_root_height_fut(self): return self._get_obs_ref_root_height_fut() def _get_obs_actor_ref_root_height_cur(self): return self._get_obs_ref_root_height_cur() def 
_get_obs_actor_ref_gravity_projection_cur(self): return self._get_obs_ref_gravity_projection_cur() def _get_obs_actor_ref_gravity_projection_fut(self): return self._get_obs_ref_gravity_projection_fut() def _get_obs_actor_ref_base_linvel_cur(self): return self._get_obs_ref_base_linvel_cur() def _get_obs_actor_ref_base_linvel_fut(self): return self._get_obs_ref_base_linvel_fut() def _get_obs_actor_ref_base_angvel_cur(self): return self._get_obs_ref_base_angvel_cur() def _get_obs_actor_ref_base_angvel_fut(self): return self._get_obs_ref_base_angvel_fut() def _get_obs_actor_ref_keybody_rel_pos_cur(self): return self._get_obs_ref_keybody_rel_pos_cur() def _get_obs_actor_ref_keybody_rel_pos_fut(self): return self._get_obs_ref_keybody_rel_pos_fut() def _get_obs_global_anchor_diff(self): self._ensure_ref_to_sim_transform_rigid() ref_pos_sim = self.ref_global_bodylink_pos ref_rot_sim = self.ref_global_bodylink_rot if ref_pos_sim is None or ref_rot_sim is None: return np.zeros(9, dtype=np.float32) t_robot = torch.as_tensor( self.robot_global_bodylink_pos[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_robot_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) t_ref_sim = torch.as_tensor( ref_pos_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_ref_sim = torch.as_tensor( ref_rot_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) # Use isaaclab semantics: pose of ref (frame 2) w.r.t. 
robot (frame 1) p_diff_t, q_diff_wxyz_t = subtract_frame_transforms( t01=t_robot, q01=q_robot_wxyz, t02=t_ref_sim, q02=q_ref_sim, ) q_diff_wxyz_t = quat_normalize_wxyz(q_diff_wxyz_t) rot_diff_mat = matrix_from_quat(q_diff_wxyz_t) out = torch.cat( [p_diff_t.reshape(-1), rot_diff_mat[..., :2].reshape(-1)], dim=-1 ) return out.detach().cpu().numpy().astype(np.float32) def _get_obs_global_anchor_pos_diff(self): self._ensure_ref_to_sim_transform_rigid() ref_pos_sim = self.ref_global_bodylink_pos ref_rot_sim = self.ref_global_bodylink_rot if ref_pos_sim is None or ref_rot_sim is None: return np.zeros(3, dtype=np.float32) t_robot = torch.as_tensor( self.robot_global_bodylink_pos[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) # [3], world q_robot_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) # [4], wxyz # Transform reference anchor pose into simulation global frame t_ref_sim = torch.as_tensor( ref_pos_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_ref_sim = torch.as_tensor( ref_rot_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) pos_diff_anchor_t, _ = subtract_frame_transforms( t01=t_robot, q01=q_robot_wxyz, t02=t_ref_sim, q02=q_ref_sim, ) return pos_diff_anchor_t.detach().cpu().numpy().astype(np.float32) def _get_obs_global_anchor_rot_diff(self): self._ensure_ref_to_sim_transform_rigid() ref_pos_sim = self.ref_global_bodylink_pos ref_rot_sim = self.ref_global_bodylink_rot if ref_pos_sim is None or ref_rot_sim is None: return np.zeros(6, dtype=np.float32) t_robot = torch.as_tensor( self.robot_global_bodylink_pos[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_robot_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_robot_wxyz = standardize_quaternion(q_robot_wxyz) t_ref_sim = torch.as_tensor( ref_pos_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_ref_sim = 
torch.as_tensor( ref_rot_sim[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) q_ref_sim = standardize_quaternion(q_ref_sim) _, q_diff_wxyz_t = subtract_frame_transforms( t01=t_robot, q01=q_robot_wxyz, t02=t_ref_sim, q02=q_ref_sim, ) q_diff_wxyz_t = standardize_quaternion(q_diff_wxyz_t) rot_diff_mat = matrix_from_quat(q_diff_wxyz_t) return ( rot_diff_mat[..., :2] .reshape(-1) .detach() .cpu() .numpy() .astype(np.float32) ) def _get_obs_global_bodylink_translation(self) -> np.ndarray: """Global body translations in simulator/URDF order, flattened as [num_bodies * 3]. The body dimension excludes the MuJoCo world body and is assumed to match the NPZ `*_global_translation` arrays (root at index 0). """ pos = self.robot_global_bodylink_pos.astype(np.float32) # [B, 3] return pos.reshape(-1) def _get_obs_global_bodylink_rotation_quat(self) -> np.ndarray: """Global body rotations as XYZW quaternions in simulator/URDF order, flattened [num_bodies * 4].""" q_wxyz = self.robot_global_bodylink_rot # [B, 4] in w, x, y, z q_xyzw = np.empty_like(q_wxyz, dtype=np.float32) q_xyzw[..., 0] = q_wxyz[..., 1] q_xyzw[..., 1] = q_wxyz[..., 2] q_xyzw[..., 2] = q_wxyz[..., 3] q_xyzw[..., 3] = q_wxyz[..., 0] return q_xyzw.reshape(-1) def _get_obs_global_bodylink_velocity(self) -> np.ndarray: """Global body linear velocities in world frame, flattened [num_bodies * 3].""" lin_vel = self.robot_global_bodylink_lin_vel.astype( np.float32 ) # [B, 3] return lin_vel.reshape(-1) def _get_obs_global_bodylink_angular_velocity(self) -> np.ndarray: """Global body angular velocities in world frame, flattened [num_bodies * 3].""" ang_vel = self.robot_global_bodylink_ang_vel.astype( np.float32 ) # [B, 3] return ang_vel.reshape(-1) @property def ref_global_bodylink_pos(self) -> np.ndarray | None: """Reference body positions transformed into the simulator global frame. 
Uses the yaw+translation Ref->Sim rigid transform computed from the initial robot global pose so that the reference motion is expressed in the same world frame as the robot (matching XY translation and yaw at frame 0). Returns: Array of shape [num_bodies, 3] giving reference positions in simulator world frame, or None if reference globals are not available. """ if getattr(self, "ref_global_translation", None) is None: return None if self.n_motion_frames <= 0: return None self._ensure_ref_to_sim_transform_rigid() frame_idx = self.ref_motion_frame_idx ref_pos_world = self.ref_global_translation[frame_idx].astype( np.float32 ) # [B, 3] pos_world_t = torch.as_tensor( ref_pos_world, dtype=torch.float32, device="cpu" ) q_ref_to_sim = torch.as_tensor( self._ref_to_sim_q_wxyz, dtype=torch.float32, device="cpu" ) q_ref_to_sim = q_ref_to_sim.unsqueeze(0).expand( pos_world_t.shape[0], 4 ) t_ref_to_sim = torch.as_tensor( self._ref_to_sim_t, dtype=torch.float32, device="cpu" ) # Apply yaw rotation + translation based on initial robot state pos_sim_t = ( quat_apply(q_ref_to_sim, pos_world_t) + t_ref_to_sim[None, :] ) return pos_sim_t.detach().cpu().numpy().astype(np.float32) @property def ref_global_bodylink_rot(self) -> np.ndarray | None: """Reference body rotations transformed into the simulator global frame. Uses the yaw component of the Ref->Sim transform so that the reference motion's global yaw is aligned with the robot's initial yaw, while preserving roll/pitch from the motion data. Returns: Array of shape [num_bodies, 4] giving reference orientations in WXYZ format, or None if reference globals are not available. 
""" if getattr(self, "ref_global_rotation_quat_xyzw", None) is None: return None if self.n_motion_frames <= 0: return None frame_idx = self.ref_motion_frame_idx ref_rot_xyzw = self.ref_global_rotation_quat_xyzw[frame_idx].astype( np.float32 ) # [B, 4] in XYZW q_ref_xyzw_t = torch.as_tensor( ref_rot_xyzw, dtype=torch.float32, device="cpu" ) q_ref_wxyz_t = xyzw_to_wxyz(q_ref_xyzw_t) q_ref_wxyz_t = standardize_quaternion(q_ref_wxyz_t) q_ref_to_sim = torch.as_tensor( self._ref_to_sim_q_wxyz, dtype=torch.float32, device="cpu" ) q_ref_to_sim = q_ref_to_sim.unsqueeze(0).expand_as(q_ref_wxyz_t) q_ref_sim_wxyz_t = quat_mul(q_ref_to_sim, q_ref_wxyz_t) q_ref_sim_wxyz_t = standardize_quaternion(q_ref_sim_wxyz_t) return q_ref_sim_wxyz_t.detach().cpu().numpy().astype(np.float32) def _draw_ref_body_spheres_to_scene( self, scene, reset_ngeom: bool ) -> None: """Draw blue spheres at reference body positions into a MuJoCo scene.""" ref_positions_sim = self.ref_global_bodylink_pos if ref_positions_sim is None: if reset_ngeom: scene.ngeom = 0 return if reset_ngeom: scene.ngeom = 0 radius = float(self.config.get("ref_marker_radius", 0.03)) rgba = np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32) size = np.array([radius, 0.0, 0.0], dtype=np.float32) mat = np.eye(3, dtype=np.float32).reshape(-1) start = int(scene.ngeom) idx = 0 for pos in ref_positions_sim: geom_id = start + idx if geom_id >= scene.maxgeom: break mujoco.mjv_initGeom( scene.geoms[geom_id], mujoco.mjtGeom.mjGEOM_SPHERE, size, pos.astype(np.float32), mat, rgba, ) idx += 1 scene.ngeom = start + idx def _get_obs_rel_anchor_lin_vel(self): # Anchor linear velocity expressed in the anchor frame (IsaacLab semantics) q_anchor_wxyz = torch.as_tensor( self.robot_global_bodylink_rot[self.anchor_body_idx], dtype=torch.float32, device="cpu", ) v_local_t = quat_apply( quat_inv(q_anchor_wxyz), torch.as_tensor( self.robot_global_bodylink_lin_vel[self.anchor_body_idx], dtype=torch.float32, device="cpu", ), ) return 
v_local_t.detach().cpu().numpy().astype(np.float32) def _get_obs_projected_gravity(self): q = torch.as_tensor( self.robot_global_bodylink_rot[self.root_body_idx], dtype=torch.float32, ) qw, qx, qy, qz = q[0], q[1], q[2], q[3] gravity_orientation = torch.zeros(3, dtype=torch.float32, device="cpu") gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy) gravity_orientation[1] = -2.0 * (qz * qy + qw * qx) gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_orientation.detach().cpu().numpy().astype(np.float32) def _get_obs_dof_pos(self): pos_mu = self.robot_dof_pos pos_onnx = pos_mu[self.mu_to_onnx] return (pos_onnx - self.default_angles_onnx.astype(np.float32)).astype( np.float32 ) def _get_obs_dof_vel(self): vel_mu = self.robot_dof_vel vel_onnx = vel_mu[self.mu_to_onnx] return vel_onnx.astype(np.float32) def _record_robot_states(self) -> None: """Record current robot DOF and global body states for offline NPZ dumping. - DOF states are stored in reference DOF order (config.robot.dof_names). - Body states are stored in dataset/URDF order (config.robot.body_names). 
""" if self.command_mode != "motion_tracking": return if self.ref_dof_pos is None or self.n_motion_frames <= 0: return if len(self._robot_dof_pos_seq) >= self.n_motion_frames: return # Joint positions/velocities from Unitree lowstate in actuator (MuJoCo) order pos_mu = self.robot_dof_pos vel_mu = self.robot_dof_vel # Map MuJoCo actuator order -> reference DOF order num_dofs = len(self.dof_names_ref_motion) pos_ref = np.zeros(num_dofs, dtype=np.float32) vel_ref = np.zeros(num_dofs, dtype=np.float32) for mu_idx, ref_idx in enumerate(self.mu_to_ref): pos_ref[ref_idx] = pos_mu[mu_idx] vel_ref[ref_idx] = vel_mu[mu_idx] self._robot_dof_pos_seq.append(pos_ref) self._robot_dof_vel_seq.append(vel_ref) if self._prev_recorded_dof_vel_ref is None: acc_ref = np.zeros_like(vel_ref, dtype=np.float32) else: acc_ref = (vel_ref - self._prev_recorded_dof_vel_ref) / np.float32( self.policy_dt ) self._prev_recorded_dof_vel_ref = vel_ref.copy() self._robot_dof_acc_seq.append(acc_ref.astype(np.float32)) # Global bodylink states in dataset/URDF order body_count = int(self.robot_global_bodylink_pos.shape[0]) trans = self._get_obs_global_bodylink_translation().reshape( body_count, 3 ) rot = self._get_obs_global_bodylink_rotation_quat().reshape( body_count, 4 ) vel = self._get_obs_global_bodylink_velocity().reshape(body_count, 3) ang_vel = self._get_obs_global_bodylink_angular_velocity().reshape( body_count, 3 ) self._robot_global_translation_seq.append(trans) self._robot_global_rotation_quat_seq.append(rot) self._robot_global_velocity_seq.append(vel) self._robot_global_angular_velocity_seq.append(ang_vel) def load_specific_motion(self, npz_path): with np.load(npz_path, allow_pickle=True) as npz: self.ref_global_translation = npz["ref_global_translation"] self.ref_global_rotation_quat_xyzw = npz[ "ref_global_rotation_quat" ] self.ref_global_velocity = npz["ref_global_velocity"] self.ref_global_angular_velocity = npz[ "ref_global_angular_velocity" ] self.ref_dof_pos = npz["ref_dof_pos"] 
self.ref_dof_vel = npz["ref_dof_vel"] raw_filter_cutoff_hz = ( np.array(npz["filter_cutoff_hz"]).astype(np.float32) if "filter_cutoff_hz" in npz else None ) self.n_motion_frames = self.ref_global_translation.shape[0] # self.filter_cutoff_hz = self._normalize_filter_cutoff_hz( # raw_filter_cutoff_hz, self.n_motion_frames # ) self.filter_cutoff_hz = 1.0 self._ref_to_sim_q_wxyz = np.array( [1.0, 0.0, 0.0, 0.0], dtype=np.float32 ) self._ref_to_sim_t = np.zeros(3, dtype=np.float32) self._ref_to_sim_ready = True def reset_state_teleport(self): self.counter = 0 self.motion_frame_idx = 0 mujoco.mj_resetDataKeyframe(self.m, self.d, 0) has_ref_motion = ( self.ref_dof_pos is not None and self.ref_dof_vel is not None and self.ref_global_translation is not None and self.ref_global_rotation_quat_xyzw is not None and self.ref_global_velocity is not None and self.ref_global_angular_velocity is not None ) if has_ref_motion: root_pos = self.ref_global_translation[0, 0] # (x, y, z) root_rot = self.ref_global_rotation_quat_xyzw[0, 0] # XYZW root_vel = self.ref_global_velocity[0, 0] root_ang = self.ref_global_angular_velocity[0, 0] dof_pos = getattr( self, "stored_full_ref_dof_pos", self.ref_dof_pos )[0] dof_vel = getattr( self, "stored_full_ref_dof_vel", self.ref_dof_vel )[0] self.d.qpos[0:3] = root_pos self.d.qpos[3:7] = [ root_rot[3], root_rot[0], root_rot[1], root_rot[2], ] # XYZW -> WXYZ self.d.qpos[self.actuator_qpos_indices] = dof_pos[self.mu_to_ref] self.d.qvel[0:3] = root_vel self.d.qvel[3:6] = root_ang self.d.qvel[self.actuator_qvel_indices] = dof_vel[self.mu_to_ref] self.target_dof_pos_mu = dof_pos[self.mu_to_ref].astype(np.float32) logger.info( "Teleport reset initialized from reference frame 0 " "(root + dof pos/vel)" ) else: self.d.qpos[self.actuator_qpos_indices] = self.default_angles_mu self.d.qvel[self.actuator_qvel_indices] = 0.0 self.target_dof_pos_mu = self.default_angles_mu.astype(np.float32) logger.info( "Teleport reset initialized from ONNX default joint 
positions" ) self.target_dof_pos_by_name = { self.mjcf_dof_names[i]: float(self.target_dof_pos_mu[i]) for i in range(self.m.nu) } mujoco.mj_forward(self.m, self.d) if self.use_kv_cache and self.policy_kv_shape: shape = [ d if isinstance(d, int) else 1 for d in self.policy_kv_shape ] self.policy_kv_cache = np.zeros(shape, dtype=np.float32) self._robot_dof_pos_seq = [] self._robot_dof_vel_seq = [] self._robot_dof_acc_seq = [] self._robot_dof_torque_seq = [] self._robot_low_level_dof_torque_seq = [] self._robot_low_level_foot_contact_seq = [] self._robot_low_level_foot_normal_force_seq = [] self._robot_low_level_foot_tangent_speed_seq = [] self._robot_actions_seq = [] self._robot_action_rate_seq = [] self._prev_recorded_dof_vel_ref = None self._prev_actions_onnx = None self._reset_action_ema_filter() self._reset_action_delay_randomization() self._prev_low_level_foot_geom_centers = None self._robot_global_translation_seq = [] self._robot_global_rotation_quat_seq = [] self._robot_global_velocity_seq = [] self._robot_global_angular_velocity_seq = [] self._robot_moe_expert_indices_seq = [] self._robot_moe_expert_logits_seq = [] self._reset_onnx_io_dump_buffers() def save_batch_result(self, output_path, meta_info): import json metadata = dict(meta_info) metadata.setdefault( "robot_low_level_torque_dt", float(getattr(self, "simulation_dt", 1.0 / 200.0)), ) metadata.setdefault( "robot_low_level_contact_dt", float(getattr(self, "simulation_dt", 1.0 / 200.0)), ) robot_moe_expert_indices, robot_moe_expert_logits = ( self._get_stacked_moe_routing_tensors() ) ( robot_low_level_foot_contact, robot_low_level_foot_normal_force, robot_low_level_foot_tangent_speed, ) = self._get_stacked_low_level_foot_contact_tensors() res = { "robot_dof_pos": np.stack(self._robot_dof_pos_seq), "robot_dof_vel": np.stack(self._robot_dof_vel_seq), "robot_dof_acc": np.stack(self._robot_dof_acc_seq), "robot_dof_torque": np.stack(self._robot_dof_torque_seq), "robot_low_level_dof_torque": np.stack( 
self._robot_low_level_dof_torque_seq ), "robot_low_level_foot_contact": robot_low_level_foot_contact, "robot_low_level_foot_normal_force": ( robot_low_level_foot_normal_force ), "robot_low_level_foot_tangent_speed": ( robot_low_level_foot_tangent_speed ), "robot_low_level_torque_dt": np.array( getattr(self, "simulation_dt", 1.0 / 200.0), dtype=np.float32 ), "robot_low_level_contact_dt": np.array( getattr(self, "simulation_dt", 1.0 / 200.0), dtype=np.float32 ), "robot_action_rate": np.asarray( self._robot_action_rate_seq, dtype=np.float32 ), "robot_global_translation": np.stack( self._robot_global_translation_seq ), "robot_global_rotation_quat": np.stack( self._robot_global_rotation_quat_seq ), "robot_global_velocity": np.stack(self._robot_global_velocity_seq), "robot_global_angular_velocity": np.stack( self._robot_global_angular_velocity_seq ), "ref_dof_pos": self.ref_dof_pos, "ref_dof_vel": self.ref_dof_vel, "ref_global_translation": self.ref_global_translation, "ref_global_rotation_quat": self.ref_global_rotation_quat_xyzw, "ref_global_velocity": self.ref_global_velocity, "ref_global_angular_velocity": self.ref_global_angular_velocity, "metadata": json.dumps(metadata), } if len(getattr(self, "_robot_actions_seq", [])) > 0: res["robot_actions"] = np.stack( self._robot_actions_seq, axis=0 ).astype(np.float32) if robot_moe_expert_indices is not None: res["robot_moe_expert_indices"] = robot_moe_expert_indices if robot_moe_expert_logits is not None: res["robot_moe_expert_logits"] = robot_moe_expert_logits np.savez_compressed(output_path, **res) def setup(self): """Set up the evaluator by loading all required components.""" self.load_mujoco_model() self._init_low_level_foot_contact_logging() self._build_mjcf_dof_names() self.load_policy() self._apply_onnx_metadata() self._build_actuator_qpos_indices() self._build_dof_mappings() self._build_actuator_name_map() self._build_actuator_force_range_map() self._init_camera_config() self._init_obs_buffers() # Initialize 
keyboard handler for velocity tracking if self.command_mode == "velocity_tracking": self.keyboard_handler = VelocityKeyboardHandler( vx_increment=0.1, vy_increment=0.05, vyaw_increment=0.05, vx_limits=(-0.5, 1.0), vy_limits=(-0.3, 0.3), vyaw_limits=(-0.5, 0.5), ) logger.info( "Velocity tracking mode enabled. Keyboard controls:\n" " W/S: forward/backward velocity\n" " A/D: left/right velocity\n" " J/L: turn left/right\n" " Space/X: reset all\n" " Keep terminal window focused for keyboard input" ) elif self.command_mode == "motion_tracking": m_path = self.config.get("motion_npz_path", "") if m_path and os.path.isfile(m_path): self.load_motion_data() def _create_eval_progress_bar(self, desc: str, max_steps: int): if self.ref_dof_pos is not None: return tqdm(total=self.n_motion_frames, desc=desc, unit="frame") if max_steps > 0: return tqdm(total=max_steps, desc=desc, unit="step") return None def _advance_eval_frame(self, max_steps: int) -> bool: if self.ref_dof_pos is not None: if self.motion_frame_idx >= (self.n_motion_frames - 1): return False self.motion_frame_idx += 1 return True if max_steps > 0 and self.counter >= max_steps: return False return True def _run_eval_step(self, max_steps: int) -> bool: self._update_policy() self.counter += 1 self._apply_control(sleep=True) if self._video_writer is not None: self._maybe_record_frame() return self._advance_eval_frame(max_steps) def _build_mjcf_dof_names(self): """Build MJCF joint name lists used for control/state indexing. 
- mjcf_dof_names: joint names corresponding to each actuator (actuator order) """ names = [] for i in range(self.m.nu): j_id = int(self.m.actuator_trnid[i][0]) j_name = mujoco.mj_id2name( self.m, mujoco._enums.mjtObj.mjOBJ_JOINT, j_id ) names.append(j_name) self.mjcf_dof_names = names def _build_actuator_qpos_indices(self): """Build mapping from actuator index to qpos/qvel indices.""" self.actuator_qpos_indices = np.zeros(self.m.nu, dtype=np.int32) self.actuator_qvel_indices = np.zeros(self.m.nu, dtype=np.int32) for i in range(self.m.nu): j_id = int(self.m.actuator_trnid[i, 0]) self.actuator_qpos_indices[i] = self.m.jnt_qposadr[j_id] self.actuator_qvel_indices[i] = self.m.jnt_dofadr[j_id] def _build_actuator_name_map(self): """Build mappings from actuator name to indices and MJCF DOF indices.""" name_to_index = {} actuator_name_to_mu_idx = {} for i in range(self.m.nu): act_name = mujoco.mj_id2name( self.m, mujoco._enums.mjtObj.mjOBJ_ACTUATOR, i ) name_to_index[act_name] = i j_id = int(self.m.actuator_trnid[i][0]) j_name = mujoco.mj_id2name( self.m, mujoco._enums.mjtObj.mjOBJ_JOINT, j_id ) mu_idx = self.mjcf_dof_names.index(j_name) actuator_name_to_mu_idx[act_name] = mu_idx self.actuator_name_to_index = name_to_index self.actuator_name_to_mu_idx = actuator_name_to_mu_idx def _build_actuator_force_range_map(self): """Build mapping from actuator index to joint actuator force range from XML.""" self.actuator_force_range = {} for i in range(self.m.nu): j_id = int(self.m.actuator_trnid[i][0]) has_limit = False min_force = 0.0 max_force = 0.0 if j_id >= 0 and j_id < self.m.njnt: if self.m.jnt_actfrclimited[j_id]: min_force = float(self.m.jnt_actfrcrange[j_id][0]) max_force = float(self.m.jnt_actfrcrange[j_id][1]) if min_force != 0.0 or max_force != 0.0: has_limit = True if not has_limit: if self.m.actuator_forcelimited[i]: min_force = float(self.m.actuator_forcerange[i][0]) max_force = float(self.m.actuator_forcerange[i][1]) if min_force != 0.0 or max_force != 0.0: 
has_limit = True if has_limit: self.actuator_force_range[i] = (min_force, max_force) else: self.actuator_force_range[i] = None def run_simulation_unitree(self): """Run simulation using Unitree's official threading/viewer pattern.""" # Defer heavy deps to runtime to keep default path light # Ensure thirdparty simulate_python is on sys.path for imports self.counter = 0 self.motion_frame_idx = 0 self.reset_state_teleport() max_steps = int(self.config.get("max_policy_steps", 0)) viewer_dt = float(self.config.get("unitree_viewer_dt", 1.0 / 60.0)) viewer = mujoco.viewer.launch_passive(self.m, self.d) # Configure viewer camera to use shared align / tracking settings self._configure_viewer_camera(viewer) # Start keyboard listener for velocity tracking if ( self.command_mode == "velocity_tracking" and self.keyboard_handler is not None ): self.keyboard_handler.start_listener() # Optional recording in viewer mode if bool(self.config.get("record_video", False)): self._init_video_tools(tag="viewer") pbar = self._create_eval_progress_bar("GUI eval", max_steps) locker = threading.Lock() stop_event = threading.Event() def simulation_thread(): while viewer.is_running() and not stop_event.is_set(): with locker: keep_running = self._run_eval_step(max_steps) if pbar is not None: pbar.update(1) if not keep_running: stop_event.set() viewer.close() def physics_viewer_thread(): while viewer.is_running() and not stop_event.is_set(): with locker: # Update camera lookat to track robot root (with small offset for framing) self._update_camera_lookat(viewer.cam) # Draw reference global bodylink positions as blue spheres when available self._draw_ref_body_spheres_to_scene( viewer.user_scn, reset_ngeom=True ) viewer.sync() time.sleep(viewer_dt) viewer_thread = Thread(target=physics_viewer_thread) sim_thread = Thread(target=simulation_thread) viewer_thread.start() sim_thread.start() # Block until viewer closes viewer_thread.join() sim_thread.join() # Close progress bar if pbar is not None: 
pbar.close() # Stop keyboard listener if ( self.command_mode == "velocity_tracking" and self.keyboard_handler is not None ): self.keyboard_handler.stop_listener() # Teardown recording self._close_video_tools() # Dump robot-augmented motion npz if motion tracking is enabled self._dump_robot_augmented_npz() def run_simulation_unitree_headless(self): """Run simulation headless (no GUI) with optional video recording.""" # Defer heavy deps to runtime to keep default path light # Initialize self.counter = 0 self.motion_frame_idx = 0 self.reset_state_teleport() max_steps = int(self.config.get("max_policy_steps", 0)) # Start keyboard listener for velocity tracking (even in headless mode) if ( self.command_mode == "velocity_tracking" and self.keyboard_handler is not None ): self.keyboard_handler.start_listener() # Optional recording in headless mode if bool(self.config.get("record_video", False)): self._init_video_tools(tag="headless") pbar = self._create_eval_progress_bar("Headless eval", max_steps) running = True while running: running = self._run_eval_step(max_steps) if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() # Stop keyboard listener if ( self.command_mode == "velocity_tracking" and self.keyboard_handler is not None ): self.keyboard_handler.stop_listener() self._close_video_tools() # Dump robot-augmented motion npz if motion tracking is enabled self._dump_robot_augmented_npz() def run_simulation(self): if bool(self.config.get("headless", False)): logger.info("Running MuJoCo sim2sim headless") self.run_simulation_unitree_headless() else: self.run_simulation_unitree() def _update_policy(self): # Record robot states once per policy step for offline NPZ dumping self._record_robot_states() latest_obs = self.obs_builder.build_policy_obs() policy_obs_np = latest_obs[None, :] input_feed = {} input_feed[self.policy_input_name] = policy_obs_np if self.use_kv_cache: if self.policy_kv_cache is None: shape = [ d if isinstance(d, int) else 1 for d in 
self.policy_kv_shape ] self.policy_kv_cache = np.zeros(shape, dtype=np.float32) # if ( # self.policy_effective_context_len > 0 # and self.counter > 0 # and self.counter % self.policy_effective_context_len == 0 # ): # self.policy_kv_cache.fill(0.0) input_feed[self.policy_kv_input_name] = self.policy_kv_cache if self.policy_step_input_name is not None: step_idx = self.counter if self.use_kv_cache and self.policy_effective_context_len > 0: step_idx = self.counter % self.policy_effective_context_len step_tensor = np.array([step_idx], dtype=np.int64) input_feed[self.policy_step_input_name] = step_tensor if torch.cuda.is_available(): torch.cuda.synchronize() output_names = [self.policy_output_name] if self.use_kv_cache and self.policy_kv_output_name: output_names.append(self.policy_kv_output_name) for _, indices_name, logits_name in self.policy_moe_layer_output_names: output_names.extend([indices_name, logits_name]) onnx_output = self.policy_session.run(output_names, input_feed) if self.dump_onnx_io_npy: self._record_onnx_io_frame(input_feed, output_names, onnx_output) if torch.cuda.is_available(): torch.cuda.synchronize() raw_actions_onnx = onnx_output[0].reshape(-1) filtered_actions_onnx = self._apply_action_ema_filter(raw_actions_onnx) self.actions_onnx = self._apply_action_delay(filtered_actions_onnx) if self.use_kv_cache and len(onnx_output) > 1: new_cache = onnx_output[1] self.policy_kv_cache = new_cache output_offset = 1 + int( bool(self.use_kv_cache and self.policy_kv_output_name) ) if self.policy_moe_layer_output_names: step_indices = [] step_logits = [] for ( _layer_idx, _indices_name, _logits_name, ) in self.policy_moe_layer_output_names: step_indices.append( self._flatten_single_step_output( onnx_output[output_offset], dtype=np.int64, ) ) output_offset += 1 step_logits.append( self._flatten_single_step_output( onnx_output[output_offset], dtype=np.float32, ) ) output_offset += 1 self._robot_moe_expert_indices_seq.append( np.stack(step_indices, axis=0) ) 
self._robot_moe_expert_logits_seq.append( np.stack(step_logits, axis=0) ) self.target_dof_pos_onnx = ( self.actions_onnx * self.action_scale_onnx + self.default_angles_onnx ) self.target_dof_pos_mu = self.target_dof_pos_onnx[self.onnx_to_mu] for i, dof_name in enumerate(self.mjcf_dof_names): self.target_dof_pos_by_name[dof_name] = float( self.target_dof_pos_mu[i] ) if ( self.command_mode == "motion_tracking" and self.ref_dof_pos is not None and len(self._robot_action_rate_seq) < len(self._robot_dof_pos_seq) ): self._robot_actions_seq.append( self.actions_onnx.astype(np.float32).copy() ) if self._prev_actions_onnx is None: action_rate = np.float32(0.0) else: action_rate = np.float32( np.linalg.norm(self.actions_onnx - self._prev_actions_onnx) / self.policy_dt ) self._prev_actions_onnx = self.actions_onnx.copy() self._robot_action_rate_seq.append(action_rate) self._robot_dof_torque_seq.append( self._compute_pd_torque_command_ref() ) def _get_config_value(config_obj, key: str): value = config_obj.get(key, None) if value is None and config_obj.get("eval", None) is not None: value = config_obj.eval.get(key, None) return value def _normalize_ckpt_name_list(ckpt_onnx_names): if ckpt_onnx_names is None: return [] if isinstance(ckpt_onnx_names, ListConfig): raw_names = list(ckpt_onnx_names) elif isinstance(ckpt_onnx_names, (list, tuple)): raw_names = list(ckpt_onnx_names) else: raise TypeError( "ckpt_onnx_names must be a list/tuple, " f"got {type(ckpt_onnx_names)}" ) normalized_names = [] for name in raw_names: name_str = str(name).strip() if name_str != "": normalized_names.append(name_str) return normalized_names def _resolve_multi_ckpt_paths(ckpt_onnx_root_dir, ckpt_onnx_names): root_dir_str = str(ckpt_onnx_root_dir).strip() if root_dir_str == "": raise ValueError("ckpt_onnx_root_dir cannot be empty") root_dir = Path(root_dir_str) if not root_dir.is_dir(): raise NotADirectoryError( f"ckpt_onnx_root_dir does not exist or is not a directory: {root_dir}" ) requested_names = 
_normalize_ckpt_name_list(ckpt_onnx_names) if len(requested_names) == 0: raise ValueError( "ckpt_onnx_names is empty. Please provide checkpoint names " 'like ["model_1000.onnx", "model_2000.onnx"].' ) discovered_paths = sorted(root_dir.rglob("*.onnx")) if len(discovered_paths) == 0: raise FileNotFoundError( f"No .onnx files found under ckpt_onnx_root_dir={root_dir}" ) paths_by_name = {} for path in discovered_paths: if path.name not in paths_by_name: paths_by_name[path.name] = [] paths_by_name[path.name].append(path) selected_paths = [] missing_names = [] for name in requested_names: candidates = paths_by_name.get(name, []) if len(candidates) == 0: missing_names.append(name) continue if len(candidates) > 1: logger.warning( f"Found {len(candidates)} ONNX files named '{name}' under " f"{root_dir}; selecting the first one: {candidates[0]}" ) selected_paths.append(candidates[0]) if len(missing_names) > 0: logger.warning( "Some requested checkpoints were not found under " f"{root_dir}: {missing_names}" ) if len(selected_paths) == 0: raise FileNotFoundError( "None of the requested checkpoints were found under " f"{root_dir}. Requested names: {requested_names}" ) return selected_paths def _resolve_eval_ckpt_paths(config_obj): ckpt_onnx_root_dir = _get_config_value(config_obj, "ckpt_onnx_root_dir") if ( ckpt_onnx_root_dir is not None and str(ckpt_onnx_root_dir).strip() != "" ): ckpt_onnx_names = _get_config_value(config_obj, "ckpt_onnx_names") return _resolve_multi_ckpt_paths(ckpt_onnx_root_dir, ckpt_onnx_names) ckpt_onnx_path = _get_config_value(config_obj, "ckpt_onnx_path") if ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == "": raise ValueError( "No ONNX checkpoint is provided. Set ckpt_onnx_path, or set " "ckpt_onnx_root_dir + ckpt_onnx_names." 
) ckpt_path = Path(str(ckpt_onnx_path)) if not ckpt_path.is_file(): raise FileNotFoundError(f"ONNX checkpoint not found: {ckpt_path}") return [ckpt_path] def _checkpoint_tag_from_path(ckpt_path: Path) -> str: match = re.search(r"model_(\d+)", ckpt_path.name) if match: return f"model_{match.group(1)}" return ckpt_path.stem def _build_eval_output_dir(ckpt_path: Path, dataset_name: str) -> Path: ckpt_tag = _checkpoint_tag_from_path(ckpt_path) dir_name = f"mujoco_eval_output_{ckpt_tag}_{dataset_name}" return ckpt_path.parent.parent / dir_name def _build_onnx_io_dump_dir(output_dir: str | Path) -> Path: return Path(output_dir) / ONNX_IO_DUMP_DIRNAME def _build_onnx_io_dump_path( output_dir: str | Path, source_file: str | Path ) -> Path: source_stem = Path(source_file).stem return _build_onnx_io_dump_dir(output_dir) / f"{source_stem}_onnx_io.npy" def _build_onnx_io_dump_readme_text() -> str: return """# ONNX I/O 导出说明 本目录用于保存 MuJoCo sim2sim 评测过程中导出的 ONNX 输入输出数据。 ## 文件组织 - 每个动作片段会生成一个 `.npy` 文件,文件名形如 `_onnx_io.npy` - 每个 `.npy` 文件对应一个原始的动作片段 `.npz` - 当前只支持默认的 `holomotion` / `MujocoEvaluator` 批量目录评测模式(`motion_npz_dir`) ## 读取方式 `.npy` 文件内部保存的是一个 Python `dict`,读取时需要开启 `allow_pickle=True`: ```python import numpy as np npy_path = "onnx_io_npy/example_clip_onnx_io.npy" payload = np.load(npy_path, allow_pickle=True).item() print(payload.keys()) print(payload["input_names"]) print(payload["output_names"]) print(payload["inputs"]["obs"].shape) print(payload["outputs"]["action"].shape) ``` ## 数据字段 - `input_names`: ONNX 实际输入张量名称列表 - `output_names`: ONNX 实际输出张量名称列表 - `inputs`: 按输入张量名称组织的字典,数组第 0 维是帧索引 - `outputs`: 按输出张量名称组织的字典,数组第 0 维是帧索引 - `source_npz`: 原始动作片段文件名 - `onnx_model`: 导出这些张量时使用的 ONNX 模型路径 ## 说明 单个 `.npy` 文件只能保存一个顶层对象,因此这里使用 pickled dict 来同时保存输入名称、输出名称以及逐帧堆叠后的 numpy 数组。 如果某次导出未产生有效 ONNX I/O 数据,`inputs` 和 `outputs` 可能为空字典,读取时请先检查键是否存在。 """ def write_onnx_io_dump_readme(output_dir: str | Path) -> Path: output_dir_path = Path(output_dir) output_dir_path.mkdir(parents=True, 
exist_ok=True) readme_path = output_dir_path / "README.md" readme_path.write_text(_build_onnx_io_dump_readme_text(), encoding="utf-8") return readme_path def _allocate_actor_counts(num_checkpoints: int, total_actors: int): if num_checkpoints <= 0: raise ValueError("num_checkpoints must be > 0") if total_actors <= 0: raise ValueError("total_actors must be > 0") base = total_actors // num_checkpoints rem = total_actors % num_checkpoints return [base + (1 if i < rem else 0) for i in range(num_checkpoints)] def _infer_step_from_ckpt_name(ckpt_name: str): match = re.search(r"model_(\d+)", ckpt_name) if match: return int(match.group(1)) fallback = re.search(r"(\d+)", ckpt_name) if fallback: return int(fallback.group(1)) return None def _read_total_macro_row(tsv_path: Path): if not tsv_path.is_file(): return None with open(tsv_path, "r", encoding="utf-8", newline="") as tsv_file: reader = csv.DictReader(tsv_file, delimiter="\t") for row in reader: dataset_value = str(row.get("Dataset", "")).strip().lower() if "total" in dataset_value and "macro" in dataset_value: return row return None def _write_total_macro_summary_table( eval_targets, job_log_dir: Path | None = None ): rows_by_parent = {} for target in eval_targets: output_dir_path = Path(target["output_dir"]) ckpt_path = target["ckpt_path"] tsv_path = output_dir_path / "sub_dataset_macro_mean_metrics.tsv" total_row = _read_total_macro_row(tsv_path) if total_row is None: logger.warning( "Skipping aggregated total metrics entry because " f"Total (Macro) row is unavailable: {tsv_path}" ) continue parent_dir = output_dir_path.parent if parent_dir not in rows_by_parent: rows_by_parent[parent_dir] = [] rows_by_parent[parent_dir].append( { "step": _infer_step_from_ckpt_name(ckpt_path.stem), "total_row": total_row, "ckpt_name": ckpt_path.stem, } ) for parent_dir, entries in rows_by_parent.items(): if len(entries) == 0: continue entries.sort( key=lambda item: ( item["step"] is None, item["step"] if item["step"] is not None else 
0, item["ckpt_name"], ) ) metric_columns = list(entries[0]["total_row"].keys()) available_steps = [ entry["step"] for entry in entries if entry["step"] is not None ] if len(available_steps) > 0: step_range = f"{min(available_steps)}-{max(available_steps)}" else: step_range = "na-na" output_name = f"mujoco_model-{step_range}_total_metrics.tsv" output_path = parent_dir / output_name generated_artifacts = [output_path] with open(output_path, "w", encoding="utf-8", newline="") as out_file: writer = csv.writer(out_file, delimiter="\t", lineterminator="\n") writer.writerow(["step"] + metric_columns) for entry in entries: step_value = ( str(entry["step"]) if entry["step"] is not None else "" ) writer.writerow( [step_value] + [ entry["total_row"].get(col, "") for col in metric_columns ] ) logger.info(f"Saved aggregated total metrics table at: {output_path}") plot_metric_columns = [ col for col in metric_columns if col != "Dataset" ] if len(plot_metric_columns) > 0: import matplotlib.pyplot as plt ncols = 4 nrows = (len(plot_metric_columns) + ncols - 1) // ncols fig, axes = plt.subplots( nrows=nrows, ncols=ncols, figsize=(4.0 * ncols, 2.8 * nrows), squeeze=False, ) for idx, metric_name in enumerate(plot_metric_columns): ax = axes[idx // ncols][idx % ncols] trend_pairs = [] for entry in entries: step_value = entry["step"] if step_value is None: continue raw_metric = entry["total_row"].get(metric_name, "") if str(raw_metric).strip() == "": continue trend_pairs.append((step_value, float(raw_metric))) if len(trend_pairs) == 0: ax.text( 0.5, 0.5, "No valid data", ha="center", va="center", transform=ax.transAxes, ) ax.set_title(metric_name) ax.set_xticks([]) ax.set_yticks([]) ax.grid(False) continue trend_pairs.sort(key=lambda pair: pair[0]) plot_steps = [pair[0] for pair in trend_pairs] plot_values = [pair[1] for pair in trend_pairs] ax.plot(plot_steps, plot_values, marker="o", linewidth=1.2) ax.set_title(metric_name) ax.set_xlabel("step") ax.grid(True, alpha=0.3) total_axes = 
nrows * ncols for idx in range(len(plot_metric_columns), total_axes): axes[idx // ncols][idx % ncols].axis("off") fig.tight_layout() plot_path = output_path.with_name( f"{output_path.stem}_all_metric_trends.pdf" ) fig.savefig(plot_path, format="pdf") plt.close(fig) generated_artifacts.append(plot_path) logger.info(f"Saved combined metric trend plot at: {plot_path}") if job_log_dir is not None: for artifact_path in generated_artifacts: job_log_path = job_log_dir / artifact_path.name shutil.copy2(artifact_path, job_log_path) logger.info(f"Exported artifact to /job_log: {job_log_path}") def process_config(override_config): """Process the configuration, merging with training config if available.""" ckpt_onnx_path = _get_config_value(override_config, "ckpt_onnx_path") ckpt_onnx_root_dir = _get_config_value( override_config, "ckpt_onnx_root_dir" ) if ( (ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == "") and ckpt_onnx_root_dir is not None and str(ckpt_onnx_root_dir).strip() != "" ): ckpt_onnx_names = _get_config_value(override_config, "ckpt_onnx_names") resolved_paths = _resolve_multi_ckpt_paths( ckpt_onnx_root_dir, ckpt_onnx_names ) ckpt_onnx_path = str(resolved_paths[0]) logger.info( "Using the first resolved checkpoint as config anchor: " f"{ckpt_onnx_path}" ) model_type = override_config.get("model_type") or "holomotion" if model_type == "gmt": config_path = Path( "holomotion/config/evaluation/gmt_eval_mujoco_sim2sim.yaml" ) elif model_type == "any2track": config_path = Path( "holomotion/config/evaluation/any2track_eval_mujoco_sim2sim.json" ) elif model_type == "sonic": config_path = Path( "holomotion/config/evaluation/sonic_eval_mujoco_sim2sim.yaml" ) else: if ckpt_onnx_path is None or str(ckpt_onnx_path).strip() == "": raise ValueError( "Cannot locate training config.yaml for model_type='holomotion' " "without an ONNX checkpoint path. Set ckpt_onnx_path, or set " "ckpt_onnx_root_dir + ckpt_onnx_names." 
) onnx_path = Path(str(ckpt_onnx_path)) # Load training config.yaml from one level above the ONNX path (../onnx_path) config_path = onnx_path.parent.parent / "config.yaml" logger.info(f"Loading training config file from {config_path}") # Ensure ${eval:'...'} expressions are supported during resolution if not OmegaConf.has_resolver("eval"): OmegaConf.register_new_resolver("eval", lambda expr: eval(expr)) with open(config_path) as file: train_config = OmegaConf.load(file) # Merge training config with any overrides config = OmegaConf.merge(train_config, override_config) with open_dict(config): config.model_type = model_type # Resolve config values in-place OmegaConf.resolve(config) if ( ( config.get("ckpt_onnx_path", None) is None or str(config.get("ckpt_onnx_path")).strip() == "" ) and ckpt_onnx_path is not None and str(ckpt_onnx_path).strip() != "" ): with open_dict(config): config.ckpt_onnx_path = str(ckpt_onnx_path) return config def _create_ray_evaluator(config_dict, model_type): """Create evaluator from serializable config dict (used inside Ray actor).""" from omegaconf import OmegaConf, open_dict config = OmegaConf.create(config_dict) if model_type == "gmt": from holomotion.src.evaluation.gmt_sim2sim import GMTEvaluator return GMTEvaluator(config) if model_type == "any2track": from holomotion.src.evaluation.any2track_sim2sim import ( Any2TrackEvaluator, ) return Any2TrackEvaluator(config) if model_type == "sonic": from holomotion.src.evaluation.sonic_mujoco_sim2sim import ( SonicEvaluator, ) return SonicEvaluator(config) return MujocoEvaluator(config) def run_mujoco_sim2sim_eval(override_config: OmegaConf): os.chdir(hydra.utils.get_original_cwd()) config = process_config(override_config) is_eval_mode = False dataset_dir = config.get("motion_npz_dir", None) specific_file = config.get("motion_npz_path", None) calc_per_clip_metrics = bool(config.get("calc_per_clip_metrics", False)) generate_report = bool(config.get("generate_report", False)) dump_npzs_cfg = 
bool(config.get("dump_npzs", False)) dump_onnx_io_npy = bool(config.get("dump_onnx_io_npy", False)) dump_npzs = dump_npzs_cfg or calc_per_clip_metrics if calc_per_clip_metrics and not dump_npzs_cfg: logger.info( "calc_per_clip_metrics=true requires dumped NPZs; " "enabling dump_npzs automatically." ) if ( dataset_dir and os.path.isdir(str(dataset_dir)) and (not specific_file or str(specific_file) == "") ): is_eval_mode = True if is_eval_mode: logger.info(f"Mode: EVALUATION on directory: {dataset_dir}") logger.remove() logger.add(sys.stderr, level="INFO") dataset_name = Path(dataset_dir).name ckpt_paths = _resolve_eval_ckpt_paths(config) logger.info( f"Resolved {len(ckpt_paths)} checkpoint(s) for evaluation." ) for idx, ckpt_path in enumerate(ckpt_paths): logger.info(f" [{idx}] {ckpt_path}") eval_targets = [] for ckpt_path in ckpt_paths: output_dir = _build_eval_output_dir(ckpt_path, dataset_name) eval_targets.append( { "ckpt_path": ckpt_path, "output_dir": str(output_dir), } ) if dump_npzs: for target in eval_targets: os.makedirs(target["output_dir"], exist_ok=True) if dump_onnx_io_npy: write_onnx_io_dump_readme( _build_onnx_io_dump_dir(target["output_dir"]) ) files = sorted( [ os.path.join(root, name) for root, _, filenames in os.walk( dataset_dir, followlinks=True ) for name in filenames if name.endswith(".npz") ] ) logger.info( f"Found {len(files)} files for dataset_dir={dataset_dir}. " f"Will evaluate {len(eval_targets)} checkpoint(s)." ) if len(files) == 0: logger.warning( f"No NPZ files found under dataset_dir={dataset_dir}" ) requested_use_gpu = _coerce_config_bool( config.get("use_gpu", True), default=True ) num_available_gpus = 0 if requested_use_gpu and torch.cuda.is_available(): num_available_gpus = int(torch.cuda.device_count()) if requested_use_gpu and num_available_gpus == 0: logger.warning( "use_gpu=true but no CUDA device is detected; " "Ray actors will run on CPU." 
) if num_available_gpus > 0: logger.info( f"Detected {num_available_gpus} CUDA device(s). " "Using Ray for batch evaluation." ) ray_actors_per_gpu = int(config.get("ray_actors_per_gpu", 4)) if ray_actors_per_gpu <= 0: raise ValueError("ray_actors_per_gpu must be > 0") ray_multi_ckpt_mode = str( config.get("ray_multi_ckpt_mode", "split") ) if ray_multi_ckpt_mode not in ("split", "per_checkpoint"): raise ValueError( "ray_multi_ckpt_mode must be one of: " "'split', 'per_checkpoint'" ) success_count = 0 total_jobs = len(files) * len(eval_targets) if total_jobs > 0: base_config_dict = OmegaConf.to_container(config, resolve=True) base_config_dict.setdefault( "ray_evaluator_module", "holomotion.src.evaluation.eval_mujoco_sim2sim", ) if not ray.is_initialized(): ray.init() from holomotion.src.evaluation.ray_evaluator_actor import ( RayEvaluatorActor, ) if num_available_gpus > 0: base_actor_count = num_available_gpus * ray_actors_per_gpu gpus_per_actor = 1.0 / ray_actors_per_gpu remote_actor = ray.remote(num_gpus=gpus_per_actor)( RayEvaluatorActor ) else: base_actor_count = max(1, ray_actors_per_gpu) gpus_per_actor = 0.0 remote_actor = ray.remote(num_gpus=0)(RayEvaluatorActor) if ray_multi_ckpt_mode == "per_checkpoint": actor_counts = [ base_actor_count for _ in range(len(eval_targets)) ] else: actor_counts = _allocate_actor_counts( len(eval_targets), base_actor_count ) if min(actor_counts) <= 0: raise ValueError( "Not enough actor budget to assign at least one actor " "per checkpoint in split mode. Reduce checkpoint count, " "increase ray_actors_per_gpu, or switch to " "ray_multi_ckpt_mode=per_checkpoint." 
) total_actor_count = sum(actor_counts) logger.info( f"Ray: {total_actor_count} persistent actors " f"({ray_actors_per_gpu} per GPU, {gpus_per_actor} GPU each)" ) logger.info( "Checkpoint actor allocation: " + ", ".join( [ f"{target['ckpt_path'].name}={actor_counts[idx]}" for idx, target in enumerate(eval_targets) ] ) ) refs = [] for target_idx, target in enumerate(eval_targets): target_config_dict = dict(base_config_dict) target_config_dict["ckpt_onnx_path"] = str( target["ckpt_path"] ) num_actors = actor_counts[target_idx] target_actors = [ remote_actor.remote( target_config_dict, target["output_dir"] ) for _ in range(num_actors) ] for file_idx, file_path in enumerate(files): actor = target_actors[file_idx % len(target_actors)] refs.append(actor.run_clip.remote(file_path)) pbar = tqdm( total=total_jobs, desc="Batch Processing (all checkpoints)", unit="job", dynamic_ncols=True, ) while refs: done, refs = ray.wait(refs, num_returns=1) for ref in done: status = ray.get(ref) if status == "success": success_count += 1 pbar.update(1) pbar.close() logger.info( f"Batch processing done. Success: {success_count}/{total_jobs}" ) else: logger.info("Skipping NPZ dumping because dump_npzs=false.") job_log_dir = Path("/job_log") job_log_enabled = job_log_dir.is_dir() and os.access( str(job_log_dir), os.W_OK ) if job_log_enabled: logger.info( f"Detected writable /job_log. Will copy summary TSVs to {job_log_dir}." ) else: logger.info( "/job_log is unavailable or not writable. " "Skipping summary TSV export." 
) postprocess_targets = [] for target in eval_targets: output_dir = target["output_dir"] output_dir_path = Path(output_dir) if not output_dir_path.is_dir(): logger.warning( f"Output directory does not exist, skipping post-processing: {output_dir}" ) continue postprocess_targets.append(target) failure_pos_err_thresh_m = float( config.get("failure_pos_err_thresh_m", 0.25) ) metric_calculation = str(config.get("metric_calculation", "per_clip")) dof_mode = str(config.get("dof_mode", "29")) ray_parallel_metrics = bool( config.get( "ray_parallel_metrics_postprocess", config.get("ray_parallel_metrics", True), ) ) metrics_threadpool_max_workers = config.get( "metrics_threadpool_max_workers", None ) should_parallelize_metrics = ( ray_parallel_metrics and len(postprocess_targets) > 1 and (calc_per_clip_metrics or generate_report or job_log_enabled) ) logger.info( "Metrics config: " f"ray_parallel_metrics_postprocess={ray_parallel_metrics}, " f"metrics_threadpool_max_workers={metrics_threadpool_max_workers}" ) if should_parallelize_metrics: if not ray.is_initialized(): ray.init() from holomotion.src.evaluation.ray_metrics_postprocess import ( run_metrics_postprocess_job, ) ray_metrics_num_cpus_cfg = config.get( "ray_metrics_postprocess_num_cpus", config.get("ray_metrics_num_cpus", None), ) if ray_metrics_num_cpus_cfg is None: ray_metrics_num_cpus = 0.0 else: ray_metrics_num_cpus = float(ray_metrics_num_cpus_cfg) if ray_metrics_num_cpus < 0.0: raise ValueError("ray_metrics_num_cpus must be >= 0") metric_refs = [] for target in postprocess_targets: ckpt_path = target["ckpt_path"] metric_refs.append( run_metrics_postprocess_job.options( num_cpus=ray_metrics_num_cpus ).remote( output_dir=target["output_dir"], dataset_name=dataset_name, calc_per_clip_metrics=calc_per_clip_metrics, failure_pos_err_thresh_m=failure_pos_err_thresh_m, metric_calculation=metric_calculation, dof_mode=dof_mode, metrics_threadpool_max_workers=metrics_threadpool_max_workers, 
generate_report=generate_report, job_log_dir=str(job_log_dir) if job_log_enabled else None, ckpt_stem=ckpt_path.stem, ) ) pbar = tqdm( total=len(metric_refs), desc="Metrics post-processing (all checkpoints)", unit="ckpt", dynamic_ncols=True, ) while metric_refs: done, metric_refs = ray.wait(metric_refs, num_returns=1) for ref in done: result = ray.get(ref) ckpt_stem = str(result.get("ckpt_stem", "")).strip() if ckpt_stem == "": ckpt_stem = "unknown" if calc_per_clip_metrics: logger.info( f"Metric calculation finished for {ckpt_stem}." ) report_path = str(result.get("report_path", "")).strip() if report_path != "": logger.info( f"Generated metrics report for {ckpt_stem} at: {report_path}" ) exported_tsv = str( result.get("exported_summary_tsv", "") ).strip() if exported_tsv != "": logger.info(f"Exported summary TSV to: {exported_tsv}") pbar.update(1) pbar.close() else: mean_process_5metrics = None if generate_report: from holomotion.scripts.evaluation import mean_process_5metrics for target in postprocess_targets: output_dir = target["output_dir"] output_dir_path = Path(output_dir) ckpt_path = target["ckpt_path"] if calc_per_clip_metrics: logger.info( "Starting metric calculation for " f"{ckpt_path.name}: {output_dir}" ) run_evaluation( npz_dir=output_dir, dataset_suffix=dataset_name, failure_pos_err_thresh_m=failure_pos_err_thresh_m, metric_calculation=metric_calculation, dof_mode=dof_mode, threadpool_max_workers=metrics_threadpool_max_workers, ) logger.info( f"Metric calculation finished for {ckpt_path.name}." 
) if generate_report: report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir( output_dir ) logger.info( f"Generated metrics report for {ckpt_path.name} at: {report_path}" ) if job_log_enabled: sub_dataset_tsv = ( output_dir_path / "sub_dataset_macro_mean_metrics.tsv" ) if sub_dataset_tsv.is_file(): export_name = f"{ckpt_path.stem}_sub_dataset_macro_mean_metrics.tsv" export_path = job_log_dir / export_name shutil.copy2(sub_dataset_tsv, export_path) logger.info(f"Exported summary TSV to: {export_path}") else: logger.warning( "Summary TSV not found (skip export): " f"{sub_dataset_tsv}" ) _write_total_macro_summary_table( eval_targets, job_log_dir=job_log_dir if job_log_enabled else None, ) else: if config.get("model_type", "holomotion") == "sonic": from holomotion.src.evaluation.sonic_mujoco_sim2sim import ( SonicEvaluator, ) evaluator = SonicEvaluator(config) else: evaluator = MujocoEvaluator(config) evaluator.setup() evaluator.run_simulation() @hydra.main( config_path="../../config", config_name="evaluation/eval_mujoco_sim2sim", version_base=None, ) def main(override_config: OmegaConf): run_mujoco_sim2sim_eval(override_config) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/evaluation/eval_velocity_tracking.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os from pathlib import Path import hydra from hydra.utils import get_class from loguru import logger from omegaconf import OmegaConf from holomotion.src.utils.config import compile_config from holomotion.src.utils.onnx_export import export_policy_to_onnx def load_training_config( checkpoint_path: str, eval_config: OmegaConf ) -> OmegaConf: """Load training config from checkpoint directory. Args: checkpoint_path: Path to the checkpoint file. eval_config: Full evaluation config (including command line overrides). Returns: Merged config with training config as base. """ checkpoint = Path(checkpoint_path) config_path = checkpoint.parent / "config.yaml" if not config_path.exists(): config_path = checkpoint.parent.parent / "config.yaml" if not config_path.exists(): logger.warning( "Training config not found at " f"{config_path}, using evaluation config" ) return eval_config logger.info(f"Loading training config from {config_path}") with open(config_path) as file: train_config = OmegaConf.load(file) # Apply eval_overrides from training config if they exist if train_config.get("eval_overrides") is not None: train_config = OmegaConf.merge( train_config, train_config.eval_overrides ) # Set checkpoint path train_config.checkpoint = checkpoint_path train_config.algo.config.checkpoint = checkpoint_path # For evaluation, merge eval_config into train_config config = OmegaConf.merge(train_config, eval_config) # For velocity tracking, always keep the robot configuration from training if hasattr(train_config, "robot"): config.robot = train_config.robot # foce set the terminations and domain rand with eval_config's config.env.config.terminations = eval_config.env.config.terminations config.env.config.domain_rand = eval_config.env.config.domain_rand config.env.config.domain_rand = eval_config.env.config.domain_rand return config @hydra.main( config_path="../../config", config_name="evaluation/eval_isaaclab", version_base=None, ) def main(config: OmegaConf): """Evaluate the 
velocity tracking model. Args: config: OmegaConf object containing the evaluation configuration. """ # Load training config first if config.checkpoint is None: raise ValueError("Checkpoint path must be provided for evaluation") config = load_training_config(config.checkpoint, config) # Compile config without accelerator (PPO will create it) config = compile_config(config, accelerator=None) # Use checkpoint directory as log_dir for offline evaluation log_dir = os.path.dirname(config.checkpoint) headless = config.headless # PPO creates Accelerator, AppLauncher, and environment internally algo_class = get_class(config.algo._target_) algo = algo_class( env_config=config.env, config=config.algo.config, log_dir=log_dir, headless=headless, is_offline_eval=True, ) if ( algo.accelerator.is_main_process and os.environ.get("TORCH_COMPILE_DISABLE", "0") != "1" ): logger.info( "Tip: set TORCH_COMPILE_DISABLE=1 if Triton/compile errors occur" ) if algo.accelerator.is_main_process: eval_log_dir = os.path.dirname(config.checkpoint) with open(os.path.join(eval_log_dir, "eval_config.yaml"), "w") as f: OmegaConf.save(config, f) if hasattr(config, "checkpoint") and config.checkpoint is not None: if algo.accelerator.is_main_process: logger.info( f"Loading checkpoint for evaluation: {config.checkpoint}" ) algo.load(config.checkpoint) else: if algo.accelerator.is_main_process: logger.warning("No checkpoint provided for evaluation!") # Export ONNX if requested if config.get("export_policy", True): if algo.accelerator.is_main_process: onnx_name_suffix = config.get("onnx_name_suffix", None) onnx_path = export_policy_to_onnx( algo, config.checkpoint, onnx_name_suffix=onnx_name_suffix, use_kv_cache=config.get("use_kv_cache", True), ) logger.info(f"Successfully exported policy to: {onnx_path}") algo.accelerator.wait_for_everyone() # Run indefinite velocity tracking rollout for visualization algo.offline_evaluate_velocity_tracking() if algo.accelerator.is_main_process: logger.info("Velocity 
tracking evaluation completed!") if __name__ == "__main__": main() ================================================ FILE: holomotion/src/evaluation/find_worst_clips.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import json import math from pathlib import Path from typing import Dict, Any, List JSON_INPUT_FILE = "logs/Holomotion/metrics_output_dataset/model_17500.json" input_path = Path(JSON_INPUT_FILE).expanduser().resolve() OUTPUT_JSON_FILE = str(input_path.parent / "bad_clips.json") WORST_PERCENTAGE = 0.2 METRICS_INFO: Dict[str, Dict[str, str]] = { "whole_body_joints_dist": { "name": "Joint Angle Error (Whole Body Average)", "unit": "rad", "direction": "higher_is_worse", }, } def find_and_save_bad_clips( data: Dict[str, Any], metrics_info: Dict[str, Dict[str, str]], percentage: float, output_file: str, ) -> None: per_clip_data: List[Dict[str, Any]] = data.get("per_clip", []) if not per_clip_data: print("Error: 'per_clip' not found in JSON data.") return total_clips = len(per_clip_data) num_to_select = math.ceil(total_clips * percentage) if num_to_select == 0 and total_clips > 0: num_to_select = 1 bad_clips_report: Dict[str, List[Dict[str, Any]]] = {} for key, info in metrics_info.items(): direction = info.get("direction") if not direction: continue sort_descending = direction == "higher_is_worse" clips_with_metric_value = [ {"motion_key": 
clip["motion_key"], "value": clip[key]} for clip in per_clip_data if key in clip and "motion_key" in clip ] if not clips_with_metric_value: print(f"Warning: no values found for metric '{key}' in data.") continue sorted_clips = sorted( clips_with_metric_value, key=lambda x: x["value"], reverse=sort_descending, ) worst_clips = sorted_clips[:num_to_select] bad_clips_report[key] = worst_clips with open(output_file, "w", encoding="utf-8") as f: json.dump(bad_clips_report, f, indent=4, ensure_ascii=False) print(f"Saved bad-clips report to: {output_file}") def main() -> None: if not Path(JSON_INPUT_FILE).is_file(): print(f"Error: JSON input file '{JSON_INPUT_FILE}' not found.") return with open(JSON_INPUT_FILE, "r", encoding="utf-8") as f: data = json.load(f) find_and_save_bad_clips( data, METRICS_INFO, WORST_PERCENTAGE, OUTPUT_JSON_FILE ) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/evaluation/metrics.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from pathlib import Path from typing import Dict, List, Optional from concurrent.futures import ThreadPoolExecutor, as_completed import argparse import csv import json import os import re from glob import glob from zipfile import BadZipFile import numpy as np import pandas as pd from loguru import logger from scipy.signal import welch from scipy.spatial.transform import Rotation as sRot from tabulate import tabulate from tqdm import tqdm DEFAULT_ROBOT_CONTROL_DT = 1.0 / 50.0 TORQUE_JUMP_RATIO_EPS = 1e-6 MIN_WELCH_SAMPLES = 8 STABILITY_BURST_WINDOW_SECONDS = 0.5 TOUCHDOWN_WINDOW_SECONDS = 0.05 ROOT_BODY_INDEX = 0 PROBABILITY_EPS = 1e-12 def quat_inv(q): return np.concatenate([-q[..., :3], q[..., 3:4]], axis=-1) def quat_apply(q, v): q = np.asarray(q, dtype=np.float64) v = np.asarray(v, dtype=np.float64) xyz = q[:, None, :3] w = q[:, None, 3:4] t = 2.0 * np.cross(xyz, v, axis=-1) return v + w * t + np.cross(xyz, t, axis=-1) def p_mpjpe(predicted: np.ndarray, target: np.ndarray) -> np.ndarray: """Compute Procrustes-aligned MPJPE between predicted and ground truth. Reference: This function is inspired by and partially adapted from the SMPLSim: https://github.com/ZhengyiLuo/SMPLSim/blob/0d672790a7672f28361d59dadd98ae2fc1b9685e/smpl_sim/smpllib/smpl_eval.py. 
""" assert predicted.shape == target.shape mu_x = np.mean(target, axis=1, keepdims=True) mu_y = np.mean(predicted, axis=1, keepdims=True) x0 = target - mu_x y0 = predicted - mu_y norm_x = np.sqrt(np.sum(x0**2, axis=(1, 2), keepdims=True)) norm_y = np.sqrt(np.sum(y0**2, axis=(1, 2), keepdims=True)) x0 /= norm_x y0 /= norm_y h = np.matmul(x0.transpose(0, 2, 1), y0) # Per-frame SVD with graceful handling for non-convergence: mark those frames as NaN batch_size = int(h.shape[0]) jdim = int(h.shape[1]) u = np.empty((batch_size, jdim, jdim), dtype=h.dtype) s = np.empty((batch_size, jdim), dtype=h.dtype) vt = np.empty((batch_size, jdim, jdim), dtype=h.dtype) for i in range(batch_size): try: ui, si, vti = np.linalg.svd(h[i]) u[i] = ui s[i] = si vt[i] = vti except np.linalg.LinAlgError: u[i].fill(np.nan) s[i].fill(np.nan) vt[i].fill(np.nan) v = vt.transpose(0, 2, 1) r = np.matmul(v, u.transpose(0, 2, 1)) # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1 sign_det_r = np.sign(np.expand_dims(np.linalg.det(r), axis=1)) v[:, :, -1] *= sign_det_r s[:, -1] *= sign_det_r.flatten() r = np.matmul(v, u.transpose(0, 2, 1)) # Corrected rotation tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) a = tr * norm_x / norm_y # Scale t = mu_x - a * np.matmul(mu_y, r) # Translation predicted_aligned = a * np.matmul(predicted, r) + t return np.linalg.norm( predicted_aligned - target, axis=len(target.shape) - 1 ) def _parse_clip_len_from_name(filename: str) -> Optional[int]: """Extract clip length from filename suffix '__start_XXX_len_N'.""" m = re.search(r"__start_\d+_len_(\d+)", os.path.basename(filename)) return int(m.group(1)) if m else None def _parse_metadata_entry(raw_metadata) -> Dict[str, object]: if raw_metadata is None: return {} parsed = raw_metadata if isinstance(parsed, np.ndarray): if parsed.shape != (): return {} parsed = parsed.item() if isinstance(parsed, dict): return parsed if isinstance(parsed, bytes): parsed = parsed.decode("utf-8") if 
isinstance(parsed, str): try: obj = json.loads(parsed) except json.JSONDecodeError: return {} return obj if isinstance(obj, dict) else {} return {} def _extract_robot_control_dt( metadata: Dict[str, object], raw_data: Dict[str, np.ndarray] ) -> float: if "robot_low_level_torque_dt" in raw_data: raw_dt = np.asarray(raw_data["robot_low_level_torque_dt"]).item() else: raw_dt = metadata.get( "robot_low_level_torque_dt", metadata.get("robot_control_dt", DEFAULT_ROBOT_CONTROL_DT), ) try: robot_control_dt = float(raw_dt) except (TypeError, ValueError): return DEFAULT_ROBOT_CONTROL_DT if not np.isfinite(robot_control_dt) or robot_control_dt <= 0.0: return DEFAULT_ROBOT_CONTROL_DT return robot_control_dt def _extract_low_level_contact_dt( metadata: Dict[str, object], raw_data: Dict[str, np.ndarray], robot_control_dt: float, ) -> float: if "robot_low_level_contact_dt" in raw_data: raw_dt = np.asarray(raw_data["robot_low_level_contact_dt"]).item() else: raw_dt = metadata.get( "robot_low_level_contact_dt", metadata.get( "robot_low_level_torque_dt", metadata.get("robot_control_dt", robot_control_dt), ), ) try: contact_dt = float(raw_dt) except (TypeError, ValueError): return robot_control_dt if not np.isfinite(contact_dt) or contact_dt <= 0.0: return robot_control_dt return contact_dt def _aggregate_sample_metric_to_frames( sample_metric: np.ndarray, num_frames: int ) -> np.ndarray: if int(sample_metric.shape[0]) == num_frames: return sample_metric.astype(float, copy=False) if num_frames <= 0: return np.empty((0,), dtype=float) aggregated = np.full((num_frames,), np.nan, dtype=float) for frame_idx, chunk in enumerate( np.array_split(sample_metric, num_frames) ): if chunk.size == 0: continue if np.all(np.isnan(chunk)): continue aggregated[frame_idx] = float(np.nanmean(chunk)) return aggregated def _compute_torque_jump_series( torque_samples: np.ndarray, torque_dt: float ) -> tuple[np.ndarray, np.ndarray]: num_samples = int(torque_samples.shape[0]) torque_jump_norm = 
np.full((num_samples,), np.nan, dtype=float) torque_jump_ratio = np.full((num_samples,), np.nan, dtype=float) if num_samples <= 1: return torque_jump_norm, torque_jump_ratio torque_mag = np.linalg.norm(torque_samples, axis=1) torque_delta_norm = np.linalg.norm( torque_samples[1:] - torque_samples[:-1], axis=1 ) torque_jump_norm[1:] = torque_delta_norm / torque_dt torque_scale = np.maximum( np.maximum(torque_mag[1:], torque_mag[:-1]), TORQUE_JUMP_RATIO_EPS ) torque_jump_ratio[1:] = torque_delta_norm / torque_scale return torque_jump_norm, torque_jump_ratio def _safe_nanpercentile(values: np.ndarray, q: float) -> float: arr = np.asarray(values, dtype=float).reshape(-1) arr = arr[np.isfinite(arr)] if arr.size == 0: return float("nan") return float(np.nanpercentile(arr, q)) def _safe_nanmean(values: np.ndarray) -> float: arr = np.asarray(values, dtype=float).reshape(-1) arr = arr[np.isfinite(arr)] if arr.size == 0: return float("nan") return float(np.mean(arr)) def _safe_nanmedian(values: np.ndarray) -> float: arr = np.asarray(values, dtype=float).reshape(-1) arr = arr[np.isfinite(arr)] if arr.size == 0: return float("nan") return float(np.median(arr)) def _compute_rolling_nanmean_max( values: np.ndarray, window_size: int ) -> float: arr = np.asarray(values, dtype=float).reshape(-1) if arr.size == 0: return float("nan") if window_size <= 1: return float(np.nanmax(arr)) best = float("nan") max_start = int(arr.size) - int(window_size) + 1 if max_start <= 0: if np.all(np.isnan(arr)): return float("nan") return float(np.nanmean(arr)) for start in range(max_start): window = arr[start : start + window_size] if np.all(np.isnan(window)): continue mean_value = float(np.nanmean(window)) if np.isnan(best) or mean_value > best: best = mean_value return best def _integrate_psd_band( frequencies: np.ndarray, power_density: np.ndarray, low_hz: float, high_hz: float, ) -> float: if ( not np.isfinite(low_hz) or not np.isfinite(high_hz) or high_hz <= low_hz ): return float("nan") 
band_mask = (frequencies >= low_hz) & (frequencies <= high_hz) if not np.any(band_mask): return float("nan") band_freq = frequencies[band_mask] band_power = power_density[band_mask] if band_freq.size == 1: return float(band_power[0]) return float(np.trapz(band_power, band_freq)) def _compute_psd_high_frequency_ratio( signal_values: np.ndarray, sample_dt: float, *, high_band_low_hz: float, band_high_hz: float, band_low_hz: float = 0.5, ) -> float: samples = np.asarray(signal_values, dtype=float).reshape(-1) samples = samples[np.isfinite(samples)] if samples.size < MIN_WELCH_SAMPLES: return float("nan") sample_rate_hz = 1.0 / float(sample_dt) max_band_hz = min(float(band_high_hz), 0.45 * sample_rate_hz) if max_band_hz <= max(float(band_low_hz), float(high_band_low_hz)): return float("nan") nperseg = min(int(samples.size), 256) frequencies, power_density = welch( samples, fs=sample_rate_hz, nperseg=nperseg, detrend="constant", average="mean", ) total_power = _integrate_psd_band( frequencies, power_density, float(band_low_hz), max_band_hz ) high_power = _integrate_psd_band( frequencies, power_density, float(high_band_low_hz), max_band_hz ) if ( not np.isfinite(total_power) or total_power <= 0.0 or not np.isfinite(high_power) ): return float("nan") return float(high_power / total_power) def _compute_torque_chatter_hf_ratio( low_level_torque: np.ndarray, low_level_dt: float ) -> float: torque_samples = np.asarray(low_level_torque, dtype=float) if torque_samples.ndim != 2 or torque_samples.shape[0] < MIN_WELCH_SAMPLES: return float("nan") ratios = [] for joint_idx in range(int(torque_samples.shape[1])): ratio = _compute_psd_high_frequency_ratio( torque_samples[:, joint_idx], low_level_dt, high_band_low_hz=10.0, band_high_hz=40.0, ) if np.isfinite(ratio): ratios.append(ratio) if len(ratios) == 0: return float("nan") return float(np.mean(ratios)) def _compute_torso_roll_pitch_stability_metrics( robot_global_angular_velocity: np.ndarray, robot_control_dt: float, ) -> 
Dict[str, float]: angular_velocity = np.asarray(robot_global_angular_velocity, dtype=float) if angular_velocity.ndim != 3 or angular_velocity.shape[0] == 0: return { "torso_rp_hf_ratio": float("nan"), "torso_rp_angacc_p95": float("nan"), } torso_roll_pitch_vel = angular_velocity[:, ROOT_BODY_INDEX, :2] torso_roll_pitch_speed = np.linalg.norm(torso_roll_pitch_vel, axis=1) hf_ratio = _compute_psd_high_frequency_ratio( torso_roll_pitch_speed, robot_control_dt, high_band_low_hz=5.0, band_high_hz=20.0, ) if torso_roll_pitch_vel.shape[0] <= 1: angacc_p95 = float("nan") else: roll_pitch_angacc = np.diff(torso_roll_pitch_vel, axis=0) / float( robot_control_dt ) roll_pitch_angacc_mag = np.linalg.norm(roll_pitch_angacc, axis=1) angacc_p95 = _safe_nanpercentile(roll_pitch_angacc_mag, 95.0) return { "torso_rp_hf_ratio": hf_ratio, "torso_rp_angacc_p95": angacc_p95, } def _compute_expert_switching_js_div( robot_moe_expert_logits: np.ndarray | None, ) -> float: if robot_moe_expert_logits is None: return float("nan") logits = np.asarray(robot_moe_expert_logits, dtype=float) if logits.ndim != 3 or logits.shape[0] <= 1 or logits.shape[-1] <= 1: return float("nan") if not np.all(np.isfinite(logits)): return float("nan") shifted_logits = logits - np.max(logits, axis=-1, keepdims=True) probs = np.exp(shifted_logits) probs /= np.sum(probs, axis=-1, keepdims=True) prev_probs = np.clip(probs[:-1], PROBABILITY_EPS, 1.0) next_probs = np.clip(probs[1:], PROBABILITY_EPS, 1.0) mixture = 0.5 * (prev_probs + next_probs) kl_prev = np.sum( prev_probs * (np.log(prev_probs) - np.log(mixture)), axis=-1 ) kl_next = np.sum( next_probs * (np.log(next_probs) - np.log(mixture)), axis=-1 ) js_divergence = 0.5 * (kl_prev + kl_next) / np.log(2.0) return _safe_nanmean(js_divergence) def _compute_contact_stability_metrics( foot_contact_samples: np.ndarray | None, foot_normal_force_samples: np.ndarray | None, foot_tangent_speed_samples: np.ndarray | None, contact_dt: float, ) -> Dict[str, float]: metrics = { 
"foot_contact_toggle_rate": float("nan"), "foot_impact_force_p95": float("nan"), "stance_slip_speed_p95": float("nan"), } if ( foot_contact_samples is None or foot_normal_force_samples is None or foot_tangent_speed_samples is None ): return metrics contact = np.asarray(foot_contact_samples, dtype=float) normal_force = np.asarray(foot_normal_force_samples, dtype=float) tangent_speed = np.asarray(foot_tangent_speed_samples, dtype=float) if ( contact.shape != normal_force.shape or contact.shape != tangent_speed.shape or contact.ndim != 2 or contact.shape[1] != 2 ): return metrics finite_contact = np.isfinite(contact) if not np.any(finite_contact): return metrics contact_binary = np.where(contact >= 0.5, 1.0, 0.0) valid_pair_mask = finite_contact[1:] & finite_contact[:-1] toggle_count = int( np.sum( np.abs(contact_binary[1:] - contact_binary[:-1]) * valid_pair_mask ) ) clip_duration_seconds = float(contact.shape[0]) * float(contact_dt) if clip_duration_seconds > 0.0: metrics["foot_contact_toggle_rate"] = ( float(toggle_count) / clip_duration_seconds ) touchdown_window = max( 1, int(round(TOUCHDOWN_WINDOW_SECONDS / float(contact_dt))) ) touchdown_peaks = [] for foot_idx in range(2): foot_contact = contact_binary[:, foot_idx] foot_force = normal_force[:, foot_idx] onset_mask = np.zeros_like(foot_contact, dtype=bool) onset_mask[0] = foot_contact[0] >= 0.5 onset_mask[1:] = (foot_contact[1:] >= 0.5) & (foot_contact[:-1] < 0.5) for onset_idx in np.flatnonzero(onset_mask): window = foot_force[onset_idx : onset_idx + touchdown_window] if window.size == 0 or np.all(~np.isfinite(window)): continue touchdown_peaks.append(float(np.nanmax(window))) metrics["foot_impact_force_p95"] = _safe_nanpercentile( np.asarray(touchdown_peaks, dtype=float), 95.0 ) stance_slip_mask = (contact_binary >= 0.5) & np.isfinite(tangent_speed) if np.any(stance_slip_mask): metrics["stance_slip_speed_p95"] = _safe_nanpercentile( tangent_speed[stance_slip_mask], 95.0 ) return metrics def 
_compute_clip_stability_summary( data: Dict[str, np.ndarray], robot_control_dt: float, low_level_contact_dt: float, ) -> Dict[str, float]: robot_low_level_dof_torque = ( np.asarray(data["robot_low_level_dof_torque"]) if "robot_low_level_dof_torque" in data else None ) if robot_low_level_dof_torque is None and "robot_dof_torque" in data: robot_low_level_dof_torque = np.asarray(data["robot_dof_torque"]) if robot_low_level_dof_torque is None: torque_chatter_hf_ratio = float("nan") torque_jump_burst_max = float("nan") else: torque_chatter_hf_ratio = _compute_torque_chatter_hf_ratio( robot_low_level_dof_torque, low_level_contact_dt ) _, torque_jump_ratio = _compute_torque_jump_series( robot_low_level_dof_torque, low_level_contact_dt ) torque_jump_window = max( 1, int( round( STABILITY_BURST_WINDOW_SECONDS / float(low_level_contact_dt) ) ), ) torque_jump_burst_max = _compute_rolling_nanmean_max( torque_jump_ratio[1:], torque_jump_window ) torso_metrics = _compute_torso_roll_pitch_stability_metrics( np.asarray(data["robot_global_angular_velocity"]), robot_control_dt, ) contact_metrics = _compute_contact_stability_metrics( np.asarray(data["robot_low_level_foot_contact"]) if "robot_low_level_foot_contact" in data else None, np.asarray(data["robot_low_level_foot_normal_force"]) if "robot_low_level_foot_normal_force" in data else None, np.asarray(data["robot_low_level_foot_tangent_speed"]) if "robot_low_level_foot_tangent_speed" in data else None, low_level_contact_dt, ) expert_switching_js_div = _compute_expert_switching_js_div( np.asarray(data["robot_moe_expert_logits"]) if "robot_moe_expert_logits" in data else None ) return { "torque_chatter_hf_ratio": torque_chatter_hf_ratio, "torque_jump_burst_max": torque_jump_burst_max, "expert_switching_js_div": expert_switching_js_div, **torso_metrics, **contact_metrics, } def _compute_clip_torque_jump_summary( data: Dict[str, np.ndarray], dof_mode: str, torque_dt: float, ) -> Dict[str, float]: robot_dof_torque = ( 
np.asarray(data["robot_dof_torque"]) if "robot_dof_torque" in data else None ) robot_low_level_dof_torque = ( np.asarray(data["robot_low_level_dof_torque"]) if "robot_low_level_dof_torque" in data else None ) if dof_mode == "23" and robot_dof_torque is not None: total_dofs_in_file = int(robot_dof_torque.shape[1]) if total_dofs_in_file == 29: idx_23_in_29_dof = list(range(19)) + list(range(22, 26)) robot_dof_torque = robot_dof_torque[:, idx_23_in_29_dof] if ( robot_low_level_dof_torque is not None and int(robot_low_level_dof_torque.shape[1]) == total_dofs_in_file ): robot_low_level_dof_torque = robot_low_level_dof_torque[ :, idx_23_in_29_dof ] chatter_torque = robot_low_level_dof_torque if chatter_torque is None: chatter_torque = robot_dof_torque if chatter_torque is None or int(chatter_torque.shape[0]) <= 1: return { "mean_torque_jump_norm": float("nan"), "p95_torque_jump_norm": float("nan"), "mean_torque_jump_ratio": float("nan"), "p95_torque_jump_ratio": float("nan"), } torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series( chatter_torque, torque_dt ) return { "mean_torque_jump_norm": float(np.nanmean(torque_jump_norm)), "p95_torque_jump_norm": float( np.nanpercentile(torque_jump_norm[1:], 95) ), "mean_torque_jump_ratio": float(np.nanmean(torque_jump_ratio)), "p95_torque_jump_ratio": float( np.nanpercentile(torque_jump_ratio[1:], 95) ), } def _per_frame_metrics_from_npz( motion_key: str, data: Dict[str, np.ndarray], dof_mode: str = "29", robot_control_dt: float = DEFAULT_ROBOT_CONTROL_DT, ) -> pd.DataFrame: """Compute per-frame metrics for a single motion clip from loaded npz arrays. 
Expects the following keys in `data` (URDF order): - dof_pos, robot_dof_pos - global_translation, robot_global_translation - global_rotation_quat, robot_global_rotation_quat (xyzw) """ # Required arrays jpos_gt = np.asarray(data["ref_global_translation"]) # (T, J, 3) jpos_pred = np.asarray(data["robot_global_translation"]) # (T, J, 3) rot_gt = np.asarray(data["ref_global_rotation_quat"]) # (T, J, 4) xyzw rot_pred = np.asarray(data["robot_global_rotation_quat"]) # (T, J, 4) dof_gt = np.asarray(data["ref_dof_pos"]) # (T, D) dof_pred = np.asarray(data["robot_dof_pos"]) # (T, D) robot_dof_vel = ( np.asarray(data["robot_dof_vel"]) if "robot_dof_vel" in data else None ) robot_dof_acc = ( np.asarray(data["robot_dof_acc"]) if "robot_dof_acc" in data else None ) robot_dof_torque = ( np.asarray(data["robot_dof_torque"]) if "robot_dof_torque" in data else None ) robot_low_level_dof_torque = ( np.asarray(data["robot_low_level_dof_torque"]) if "robot_low_level_dof_torque" in data else None ) robot_action_rate = ( np.asarray(data["robot_action_rate"]) if "robot_action_rate" in data else None ) total_dofs_in_file = int(dof_gt.shape[1]) IDX_23_IN_29_DOF = list(range(19)) + list(range(22, 26)) IDX_23_IN_29_BODY = [0] + [i + 1 for i in IDX_23_IN_29_DOF] if dof_mode == "23": if total_dofs_in_file == 29: dof_gt = dof_gt[:, IDX_23_IN_29_DOF] dof_pred = dof_pred[:, IDX_23_IN_29_DOF] if ( robot_dof_vel is not None and int(robot_dof_vel.shape[1]) == total_dofs_in_file ): robot_dof_vel = robot_dof_vel[:, IDX_23_IN_29_DOF] if ( robot_dof_acc is not None and int(robot_dof_acc.shape[1]) == total_dofs_in_file ): robot_dof_acc = robot_dof_acc[:, IDX_23_IN_29_DOF] if ( robot_dof_torque is not None and int(robot_dof_torque.shape[1]) == total_dofs_in_file ): robot_dof_torque = robot_dof_torque[:, IDX_23_IN_29_DOF] if ( robot_low_level_dof_torque is not None and int(robot_low_level_dof_torque.shape[1]) == total_dofs_in_file ): robot_low_level_dof_torque = robot_low_level_dof_torque[ :, 
IDX_23_IN_29_DOF ] jpos_gt = jpos_gt[:, IDX_23_IN_29_BODY, :] jpos_pred = jpos_pred[:, IDX_23_IN_29_BODY, :] rot_gt = rot_gt[:, IDX_23_IN_29_BODY, :] rot_pred = rot_pred[:, IDX_23_IN_29_BODY, :] assert jpos_gt.shape == jpos_pred.shape assert rot_gt.shape == rot_pred.shape assert dof_gt.shape == dof_pred.shape num_frames = int(jpos_gt.shape[0]) # Global MPJPE [mm] mpjpe_g = ( np.mean(np.linalg.norm(jpos_gt - jpos_pred, axis=2), axis=1) * 1000.0 ) # Per-frame maximum body-link position error [m] (used for failure criterion) # per_joint_err = np.linalg.norm(jpos_pred - jpos_gt, axis=2) # frame_max_body_pos_err = np.max(per_joint_err, axis=1) frame_max_body_pos_err = np.abs(jpos_pred[:, 0, 2] - jpos_gt[:, 0, 2]) # Localize by root (index 0) jpos_gt_local = jpos_gt - jpos_gt[:, [0]] jpos_pred_local = jpos_pred - jpos_pred[:, [0]] ref_body_pos_root_rel = quat_apply( quat_inv(rot_gt[:, 0, :]), jpos_gt - jpos_gt[:, [0]], ) robot_body_pos_root_rel = quat_apply( quat_inv(rot_pred[:, 0, :]), jpos_pred - jpos_pred[:, [0]], ) mpjpe_l = ( np.mean( np.linalg.norm( robot_body_pos_root_rel - ref_body_pos_root_rel, axis=2 ), axis=1, ) * 1000.0 ) # Procrustes-aligned MPJPE [mm] pa_per_joint = p_mpjpe(jpos_pred_local, jpos_gt_local) mpjpe_pa = np.mean(pa_per_joint, axis=1) * 1000.0 # Velocity/acceleration errors from positions (discrete frame diffs) [mm/frame],[mm/frame^2] vel_gt = jpos_gt[1:] - jpos_gt[:-1] vel_pred = jpos_pred[1:] - jpos_pred[:-1] vel_dist = ( np.mean(np.linalg.norm(vel_pred - vel_gt, axis=2), axis=1) * 1000.0 ) acc_gt = jpos_gt[:-2] - 2 * jpos_gt[1:-1] + jpos_gt[2:] acc_pred = jpos_pred[:-2] - 2 * jpos_pred[1:-1] + jpos_pred[2:] accel_dist = ( np.mean(np.linalg.norm(acc_pred - acc_gt, axis=2), axis=1) * 1000.0 ) # DOF angle errors [radians] — whole body average dof_err = np.abs(dof_pred - dof_gt) whole_body_joints_dist = np.mean(dof_err, axis=1) # Root orientation errors [radians] — handle zero-norm/invalid quaternions by NaN q_gt_root = rot_gt[:, 0, :] q_pred_root 
= rot_pred[:, 0, :] norms_gt = np.linalg.norm(q_gt_root, axis=1) norms_pred = np.linalg.norm(q_pred_root, axis=1) valid_mask = ( (norms_gt > 0.0) & (norms_pred > 0.0) & np.isfinite(norms_gt) & np.isfinite(norms_pred) ) root_r_error = np.full((num_frames,), np.nan, dtype=float) root_p_error = np.full((num_frames,), np.nan, dtype=float) root_y_error = np.full((num_frames,), np.nan, dtype=float) if np.any(valid_mask): q_gt_valid = q_gt_root[valid_mask] / norms_gt[valid_mask, None] q_pred_valid = q_pred_root[valid_mask] / norms_pred[valid_mask, None] rel_valid = sRot.from_quat(q_gt_valid).inv() * sRot.from_quat( q_pred_valid ) euler_xyz = rel_valid.as_euler("xyz", degrees=False) root_r_error[valid_mask] = np.abs(euler_xyz[:, 0]) root_p_error[valid_mask] = np.abs(euler_xyz[:, 1]) root_y_error[valid_mask] = np.abs(euler_xyz[:, 2]) # Root velocity error [m/frame] root_pos_gt = jpos_gt[:, 0, :] root_pos_pred = jpos_pred[:, 0, :] root_vel_err = np.linalg.norm( (root_pos_pred[1:] - root_pos_pred[:-1]) - (root_pos_gt[1:] - root_pos_gt[:-1]), axis=1, ) # Root height error [m] root_height_error = np.abs(root_pos_pred[:, 2] - root_pos_gt[:, 2]) # Robot low-level magnitudes (optional) mean_dof_vel = np.full((num_frames,), np.nan, dtype=float) if robot_dof_vel is not None: if int(robot_dof_vel.shape[0]) != num_frames: raise ValueError( "robot_dof_vel frame length mismatch: " f"{robot_dof_vel.shape[0]} vs {num_frames}" ) mean_dof_vel = np.linalg.norm(robot_dof_vel, axis=1) mean_dof_acc = np.full((num_frames,), np.nan, dtype=float) if robot_dof_acc is not None: if int(robot_dof_acc.shape[0]) != num_frames: raise ValueError( "robot_dof_acc frame length mismatch: " f"{robot_dof_acc.shape[0]} vs {num_frames}" ) mean_dof_acc = np.linalg.norm(robot_dof_acc, axis=1) mean_dof_torque = np.full((num_frames,), np.nan, dtype=float) mean_torque_jump_norm = np.full((num_frames,), np.nan, dtype=float) mean_torque_jump_ratio = np.full((num_frames,), np.nan, dtype=float) if robot_dof_torque is not 
None: if int(robot_dof_torque.shape[0]) != num_frames: raise ValueError( "robot_dof_torque frame length mismatch: " f"{robot_dof_torque.shape[0]} vs {num_frames}" ) mean_dof_torque = np.linalg.norm(robot_dof_torque, axis=1) chatter_torque = robot_low_level_dof_torque if chatter_torque is None: chatter_torque = robot_dof_torque if chatter_torque is not None and int(chatter_torque.shape[0]) > 1: torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series( chatter_torque, robot_control_dt ) mean_torque_jump_norm = _aggregate_sample_metric_to_frames( torque_jump_norm, num_frames ) mean_torque_jump_ratio = _aggregate_sample_metric_to_frames( torque_jump_ratio, num_frames ) mean_action_rate = np.full((num_frames,), np.nan, dtype=float) if robot_action_rate is not None: flat_action_rate = robot_action_rate.reshape(-1) if int(flat_action_rate.shape[0]) != num_frames: raise ValueError( "robot_action_rate frame length mismatch: " f"{flat_action_rate.shape[0]} vs {num_frames}" ) mean_action_rate = flat_action_rate # Frame DataFrame (align lengths by padding NaN at the start where needed) def pad_front(x: np.ndarray, pad: int) -> np.ndarray: if pad <= 0: return x return np.concatenate( [np.full((pad,), np.nan, dtype=float), x], axis=0 ) df = pd.DataFrame( { "motion_key": [motion_key] * num_frames, "frame_idx": np.arange(num_frames, dtype=int), "mpjpe_g": mpjpe_g, "mpjpe_l": mpjpe_l, "mpjpe_pa": mpjpe_pa, "vel_dist": pad_front(vel_dist, 1), "accel_dist": pad_front(accel_dist, 2), "frame_max_body_pos_err": frame_max_body_pos_err, "whole_body_joints_dist": whole_body_joints_dist, "root_r_error": root_r_error, "root_p_error": root_p_error, "root_y_error": root_y_error, "root_vel_error": pad_front(root_vel_err, 1), "root_height_error": root_height_error, "mean_dof_vel": mean_dof_vel, "mean_dof_acc": mean_dof_acc, "mean_dof_torque": mean_dof_torque, "mean_torque_jump_norm": mean_torque_jump_norm, "mean_torque_jump_ratio": mean_torque_jump_ratio, "mean_action_rate": 
mean_action_rate, } ) return df def offline_evaluate_dumped_npzs( npz_dir: str, output_json_path: str, failure_pos_err_thresh_m: float = 0.25, metric_calculation: str = "per_clip", dof_mode: str = "29", threadpool_max_workers: Optional[int] = None, ) -> Dict[str, dict]: """Evaluate dumped NPZs in `npz_dir` and write a JSON summary to `output_dir`. The function produces dataset-wide averages and per-clip averages across frames. """ npz_dir_abs = Path(npz_dir).resolve() os.makedirs(npz_dir_abs, exist_ok=True) # Add file handler for logging to metric.log metric_log_path = npz_dir_abs / "metric.log" logger.add( str(metric_log_path), format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", level="INFO", ) logger.info(f"Input NPZ directory (absolute): {npz_dir_abs}") # Gather NPZ files files = sorted(glob(os.path.join(npz_dir_abs, "*.npz"))) if len(files) == 0: raise FileNotFoundError(f"No NPZ files found in: {npz_dir_abs}") # Accumulate per-frame metrics frame_tables: List[pd.DataFrame] = [] clip_meta: Dict[str, dict] = {} skipped_files_count = 0 required_keys = [ "ref_dof_pos", "ref_dof_vel", "ref_global_translation", "ref_global_rotation_quat", "ref_global_velocity", "ref_global_angular_velocity", "robot_dof_pos", "robot_dof_vel", "robot_global_translation", "robot_global_rotation_quat", "robot_global_velocity", "robot_global_angular_velocity", ] optional_keys = [ "robot_dof_acc", "robot_dof_torque", "robot_low_level_dof_torque", "robot_low_level_torque_dt", "robot_low_level_foot_contact", "robot_low_level_foot_normal_force", "robot_low_level_foot_tangent_speed", "robot_low_level_contact_dt", "robot_action_rate", "robot_moe_expert_indices", "robot_moe_expert_logits", ] def _compute_metrics_from_file(fpath: str): try: with np.load(fpath, allow_pickle=True) as npz_data: # Extract arrays and metadata data = {k: npz_data[k] for k in required_keys} for k in optional_keys: if k in npz_data.files: data[k] = npz_data[k] metadata = 
_parse_metadata_entry(npz_data.get("metadata")) robot_control_dt = _extract_robot_control_dt(metadata, data) low_level_contact_dt = _extract_low_level_contact_dt( metadata, data, robot_control_dt ) motion_key = os.path.splitext(os.path.basename(fpath))[0] clip_len_from_name = _parse_clip_len_from_name(fpath) df_frames = _per_frame_metrics_from_npz( motion_key=motion_key, data=data, dof_mode=dof_mode, robot_control_dt=robot_control_dt, ) chatter_summary = _compute_clip_torque_jump_summary( data=data, dof_mode=dof_mode, torque_dt=robot_control_dt ) stability_summary = _compute_clip_stability_summary( data=data, robot_control_dt=robot_control_dt, low_level_contact_dt=low_level_contact_dt, ) # Clip-level info and failure criterion (max body-link pos error > threshold) num_frames_clip = int(df_frames.shape[0]) clip_length = int( metadata.get( "clip_length", clip_len_from_name or num_frames_clip ) ) max_body_err = float( np.nanmax(df_frames["frame_max_body_pos_err"].to_numpy()) ) success = 1.0 if max_body_err <= failure_pos_err_thresh_m else 0.0 clip_meta_entry = { "motion_key": motion_key, "num_frames": num_frames_clip, "clip_length": clip_length, "success": success, "max_body_pos_err": max_body_err, "failure_threshold_m": float(failure_pos_err_thresh_m), **chatter_summary, **stability_summary, } return fpath, df_frames, motion_key, clip_meta_entry, None except (ValueError, KeyError, BadZipFile, EOFError, OSError) as e: return fpath, None, None, None, e if threadpool_max_workers is None: max_workers = max(1, min(len(files), 24)) requested_workers = None else: requested_workers = int(threadpool_max_workers) if requested_workers <= 0: raise ValueError("threadpool_max_workers must be > 0") max_workers = min(requested_workers, len(files)) if max_workers <= 0: max_workers = 1 logger.info( f"Metric ThreadPoolExecutor max_workers={max_workers} " f"(requested={requested_workers}, num_npz_files={len(files)})" ) with ThreadPoolExecutor(max_workers=max_workers) as executor: 
futures = { executor.submit(_compute_metrics_from_file, fpath): file_idx for file_idx, fpath in enumerate(files) } processed_results = [None] * len(files) for future in tqdm( as_completed(futures.keys()), total=len(files), desc="Compute metrics from NPZs", ): processed_results[futures[future]] = future.result() for result in processed_results: ( fpath, df_frames, motion_key, clip_meta_entry, file_error, ) = result if file_error is not None: logger.warning(f"\nCaught an error while processing file: {fpath}") logger.warning(f"Error type: {type(file_error).__name__}") logger.warning(f"Error message: {file_error}") logger.warning("This file will be SKIPPED.") skipped_files_count += 1 continue frame_tables.append(df_frames) clip_meta[motion_key] = clip_meta_entry if skipped_files_count > 0: logger.info( f"\nFinished processing. Skipped a total of {skipped_files_count} files due to errors." ) # If all files were skipped, there's nothing to process further. if not frame_tables: logger.error( "No valid NPZ files could be processed. Aborting evaluation." 
) return {} # Concatenate per-frame metrics all_frames = pd.concat(frame_tables, ignore_index=True) # Per-clip averages frame_metric_cols = [ "mpjpe_g", "mpjpe_l", "whole_body_joints_dist", "root_vel_error", "root_r_error", "root_p_error", "root_y_error", "root_height_error", "mean_dof_vel", "mean_dof_acc", "mean_dof_torque", "mean_torque_jump_norm", "mean_torque_jump_ratio", "mean_action_rate", ] percentile_metric_cols = [ "mean_torque_jump_norm", "mean_torque_jump_ratio", ] percentile_rename_map = { "mean_torque_jump_norm": "p95_torque_jump_norm", "mean_torque_jump_ratio": "p95_torque_jump_ratio", } metric_cols = frame_metric_cols + list(percentile_rename_map.values()) clip_only_metric_cols = [ "torque_chatter_hf_ratio", "torque_jump_burst_max", "expert_switching_js_div", "torso_rp_hf_ratio", "torso_rp_angacc_p95", "foot_contact_toggle_rate", "foot_impact_force_p95", "stance_slip_speed_p95", ] metric_cols += clip_only_metric_cols # Metric display configuration: metric_key -> (display_name, unit) metric_display_map = { "mpjpe_g": ("Global Bodylink Mean Position Error", "mm"), "mpjpe_l": ("Local Bodylink Mean Position Error", "mm"), "whole_body_joints_dist": ("DOF Position Error", "rad"), "root_vel_error": ("Root Velocity Error", "m/s"), "root_r_error": ("Root Roll Error", "rad"), "root_p_error": ("Root Pitch Error", "rad"), "root_y_error": ("Root Yaw Error", "rad"), "root_height_error": ("Root Height Error", "mm"), "mean_dof_vel": ("Mean DOF Velocity", "rad/s"), "mean_dof_acc": ("Mean DOF Acceleration", "rad/s^2"), "mean_dof_torque": ("Mean DOF Torque", "N*m"), "mean_torque_jump_norm": ("Mean Torque Jump Norm", "N*m/s"), "p95_torque_jump_norm": ("P95 Torque Jump Norm", "N*m/s"), "mean_torque_jump_ratio": ("Mean Torque Jump Ratio", "ratio"), "p95_torque_jump_ratio": ("P95 Torque Jump Ratio", "ratio"), "mean_action_rate": ("Mean Action Rate", "1/s"), "torque_chatter_hf_ratio": ("Torque Chatter HF Ratio", "ratio"), "torque_jump_burst_max": ("Torque Jump Burst Max", 
"ratio"), "expert_switching_js_div": ("Expert Switching JS Div", "bits"), "torso_rp_hf_ratio": ("Torso RP HF Ratio", "ratio"), "torso_rp_angacc_p95": ("Torso RP Angular Accel P95", "rad/s^2"), "foot_contact_toggle_rate": ("Foot Contact Toggle Rate", "1/s"), "foot_impact_force_p95": ("Foot Impact Force P95", "N"), "stance_slip_speed_p95": ("Stance Slip Speed P95", "m/s"), } per_clip_mean = ( all_frames.groupby("motion_key")[frame_metric_cols] .mean(numeric_only=True) .reset_index() ) per_clip_p95 = ( all_frames.groupby("motion_key")[percentile_metric_cols] .quantile(0.95) .reset_index() .rename(columns=percentile_rename_map) ) per_clip_summary = per_clip_mean.merge( per_clip_p95, on="motion_key", how="left" ) for metric_key in ( "mean_torque_jump_norm", "p95_torque_jump_norm", "mean_torque_jump_ratio", "p95_torque_jump_ratio", ): per_clip_summary[metric_key] = per_clip_summary["motion_key"].map( {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta} ) for metric_key in clip_only_metric_cols: per_clip_summary[metric_key] = per_clip_summary["motion_key"].map( {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta} ) # Merge with success flags per_clip_records = [] for _, row in per_clip_summary.iterrows(): mk = row["motion_key"] rec = {**row.to_dict(), **clip_meta.get(mk, {})} per_clip_records.append(rec) # Persist per-clip metrics as a tabular CSV for easier downstream analysis. 
per_clip_df = pd.DataFrame(per_clip_records) output_csv_path = str(npz_dir_abs / "per_clip_metrics.csv") per_clip_df.to_csv(output_csv_path, index=False) logger.info(f"Saved per-clip metrics CSV to: {output_csv_path}") dataset_means = {} dataset_medians = {} if metric_calculation == "per_frame": agg_source = all_frames agg_desc = "PER-FRAME" else: agg_source = per_clip_summary agg_desc = "PER-CLIP" for k in metric_cols: if k in agg_source.columns: arr = agg_source[k].to_numpy() else: arr = per_clip_summary[k].to_numpy() dataset_means[k] = _safe_nanmean(arr) dataset_medians[k] = _safe_nanmedian(arr) success_rate = float( np.mean([clip_meta[mk]["success"] for mk in clip_meta]) if len(clip_meta) > 0 else 0.0 ) dataset_means["success_rate"] = success_rate # Compose result and write result = { "dataset": { "calculation_mode": metric_calculation, "mean": dataset_means, "median": dataset_medians, "success_rate": success_rate, }, "num_clips": int(len(clip_meta)), "per_clip": per_clip_records, } with open(output_json_path, "w", encoding="utf-8") as f: json.dump(result, f, indent=2) # Conversion factors for unit conversion (assuming 50Hz) frame_rate_hz = 50.0 unit_conversions = { "root_height_error": 1000.0, # m to mm "root_vel_error": frame_rate_hz, # m/frame to m/s } table_data = [] # Iterate through metric_display_map to preserve order for key in metric_display_map.keys(): if key not in dataset_means: continue val_mean = dataset_means[key] val_median = dataset_medians[key] display_name, unit = metric_display_map[key] # Apply unit conversion if needed if key in unit_conversions: factor = unit_conversions[key] val_mean = val_mean * factor val_median = val_median * factor def fmt(v): return f"{v:.4f}" if isinstance(v, float) else str(v) table_data.append([display_name, fmt(val_mean), fmt(val_median), unit]) table_headers = ["Metric", "Mean", "Median", "Unit"] output_tsv_path = str(npz_dir_abs / "whole_dataset_metrics.tsv") with open(output_tsv_path, "w", encoding="utf-8", 
newline="") as f: writer = csv.writer(f, delimiter="\t", lineterminator="\n") writer.writerow(table_headers) writer.writerows(table_data) logger.info(f"Saved whole-dataset metrics TSV to: {output_tsv_path}") table_str = tabulate( table_data, headers=table_headers, tablefmt="simple_outline", colalign=("left", "left", "left", "left"), ) logger.info( "\n" + "=" * 80 + f"\nDATASET-WISE METRICS ({agg_desc})\n" + "=" * 80 + f"\n\n{table_str}\n" + "=" * 80 + "\n" ) return result def parse_ckpt_and_dataset_from_eval_dirname( eval_dir_name: str, dataset_suffix: str ): VALID_PREFIXES = ["isaaclab_eval_output_", "mujoco_eval_output_"] matched_prefix = None for prefix in VALID_PREFIXES: if eval_dir_name.startswith(prefix): matched_prefix = prefix break if matched_prefix is None: return None, None rest = eval_dir_name[len(matched_prefix) :] if not rest.endswith(dataset_suffix): return None, None model_part = rest[: -len(dataset_suffix)] if model_part.endswith("_"): model_part = model_part[:-1] m = re.search(r"model_(\d+)$", model_part) if not m: return None, dataset_suffix return m.group(1), dataset_suffix def run_evaluation( npz_dir: str, dataset_suffix: str, failure_pos_err_thresh_m: float = 0.25, metric_calculation: str = "per_clip", dof_mode: str = "29", threadpool_max_workers: Optional[int] = None, ): """ Main function to run evaluation. It scans a root directory, runs evaluation for each found subdirectory, and generates a final summary report. Args: npz_dir (str): Top-level directory containing all model evaluation results (e.g., 'logs/test'). output_dir (str): Directory to store all generated JSON files and logs. failure_pos_err_thresh_m (float): The position error threshold in meters to determine a failure. """ root_path = Path(npz_dir) logger.info(f"Starting batch evaluation. 
Root directory: '{root_path}'") logger.info( f"Searching for directories matching pattern: '{dataset_suffix}'" ) def has_npz_files(path: Path) -> bool: return path.is_dir() and any(path.glob("*.npz")) is_single_eval_dir = ( root_path.is_dir() and ( root_path.name.startswith("isaaclab_eval_output_") or root_path.name.startswith("mujoco_eval_output_") ) and has_npz_files(root_path) ) if is_single_eval_dir: output_path = root_path else: output_path = root_path / f"metrics_output_{dataset_suffix}" output_path.mkdir(parents=True, exist_ok=True) if is_single_eval_dir: logger.info( f"Detected '{root_path}' as a single evaluation directory. " "Running offline evaluation only for this directory." ) model_name = root_path.parent.name ckpt_str, ds = parse_ckpt_and_dataset_from_eval_dirname( root_path.name, dataset_suffix ) if ckpt_str is None: logger.warning( f"Could not parse checkpoint/dataset from directory name '{root_path.name}'. " "Using 'checkpoint_unknown' in output filename." ) ckpt_str = "checkpoint_unknown" ds = dataset_suffix output_json_name = f"{model_name}_{ckpt_str}_{dof_mode}dof.json" output_json_path = output_path / output_json_name offline_evaluate_dumped_npzs( npz_dir=str(root_path), output_json_path=str(output_json_path), failure_pos_err_thresh_m=failure_pos_err_thresh_m, metric_calculation=metric_calculation, dof_mode=dof_mode, threadpool_max_workers=threadpool_max_workers, ) logger.success( f"Finished single-directory evaluation: model='{model_name}', checkpoint={ckpt_str}" ) return logger.info( f"Treating '{root_path}' as root directory for batch evaluation." ) # Find all directories matching the evaluation output pattern. eval_dirs = sorted( p for p in root_path.glob(f"**/*eval_output_*_{dataset_suffix}") if p.is_dir() ) if not eval_dirs: logger.error( f"No directories matching the pattern '{dataset_suffix}' found under '{root_path}'. " "Please check the path and pattern." ) return all_results = [] # Process each found evaluation directory. 
for eval_dir in tqdm(eval_dirs, desc="Overall Progress"): # Extract model name from the parent directory. model_name = eval_dir.parent.name # Parse the checkpoint number from the directory name. ckpt_str, ds = parse_ckpt_and_dataset_from_eval_dirname( eval_dir.name, dataset_suffix ) if ckpt_str is None: logger.warning( f"Could not parse ckpt/dataset from '{eval_dir.name}'. Skipping." ) continue checkpoint = int(ckpt_str) logger.info( f"\n--- Processing: model='{model_name}', dataset='{ds}', checkpoint={checkpoint} ---" ) # Construct a unique output JSON filename. output_json_name = f"{model_name}_{checkpoint}.json" output_json_path = output_path / output_json_name # Call the evaluation function for the current directory. result = offline_evaluate_dumped_npzs( npz_dir=str(eval_dir), output_json_path=str(output_json_path), failure_pos_err_thresh_m=failure_pos_err_thresh_m, metric_calculation=metric_calculation, dof_mode=dof_mode, threadpool_max_workers=threadpool_max_workers, ) if result and "dataset" in result: # Collect dataset-level average metrics for the final summary. flat_result = { "model": model_name, "checkpoint": checkpoint, **result["dataset"], } all_results.append(flat_result) logger.success( f"--- Finished processing: model='{model_name}', checkpoint={checkpoint} ---" ) else: logger.error( f"--- Failed to process: model='{model_name}', checkpoint={checkpoint} ---" ) if not all_results: logger.error( "No evaluations succeeded. Cannot generate a summary report." 
) return logger.info("\n" + "=" * 80) logger.info("Batch evaluation finished successfully.") logger.info(f"Total successful evaluations: {len(all_results)}") logger.info("=" * 80) if __name__ == "__main__": argument_parser = argparse.ArgumentParser() argument_parser.add_argument("--npz_dir", type=str, required=True) argument_parser.add_argument( "--dataset_suffix", type=str, required=True, ) argument_parser.add_argument( "--failure_pos_err_thresh_m", type=float, default=0.25 ) argument_parser.add_argument( "--metric_calculation", type=str, choices=["per_clip", "per_frame"], default="per_clip", help="Calculation mode for dataset metrics. 'per_clip' averages clip means (Macro). 'per_frame' averages all frames (Micro).", ) argument_parser.add_argument( "--dof_mode", type=str, choices=["29", "23"], default="29", help="Compute metrics for full 29 DoF or reduced 23 DoF (excluding hands).", ) argument_parser.add_argument( "--threadpool_max_workers", type=int, default=None, help="Max workers for per-NPZ ThreadPoolExecutor. " "Default: None (auto = min(num_files, 24)).", ) args = argument_parser.parse_args() run_evaluation( npz_dir=args.npz_dir, dataset_suffix=args.dataset_suffix, failure_pos_err_thresh_m=args.failure_pos_err_thresh_m, metric_calculation=args.metric_calculation, dof_mode=args.dof_mode, threadpool_max_workers=args.threadpool_max_workers, ) ================================================ FILE: holomotion/src/evaluation/multi_model_metrics_report.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import argparse import itertools import json from collections import defaultdict from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from loguru import logger from scipy.stats import mannwhitneyu import textwrap DEFAULT_METRICS_TO_ANALYZE = [ "mpjpe_g", "mpjpe_l", "whole_body_joints_dist", "root_vel_error", "root_r_error", "root_p_error", "root_y_error", "root_height_error", ] RADAR_METRICS = [ "mpjpe_g", "mpjpe_l", "whole_body_joints_dist", "root_vel_error", ] DEFAULT_RADAR_MAPPING = {m: m for m in RADAR_METRICS} DEFAULT_ALPHA = 0.05 class AnalysisReportGenerator: """Load per-clip JSON metrics, run analysis, generate plots + markdown report.""" def __init__( self, json_dir: str, plots_dir: str, dataset_name: str, metrics_to_analyze: List[str], radar_metric_mapping: Dict[str, str], metric_types_for_radar: Dict[str, str], alpha: float = DEFAULT_ALPHA, plot_quantile_cutoff: float = 0.99, kde_linewidth: float = 2.5, min_normalized_value: float = 0.2, radar_chart_filename: str = "radar_chart_comparison.png", ) -> None: self.json_dir = Path(json_dir) self.plots_dir = Path(plots_dir) self.dataset_name = dataset_name self.metrics_to_analyze = metrics_to_analyze self.radar_metric_mapping = radar_metric_mapping.copy() self.metric_types_for_radar = metric_types_for_radar.copy() self.alpha = alpha self.plot_quantile_cutoff = plot_quantile_cutoff self.kde_linewidth = kde_linewidth self.min_normalized_value = min_normalized_value 
self.radar_chart_filename = radar_chart_filename self.df: Optional[pd.DataFrame] = None self.models: List[str] = [] def run(self) -> None: self.plots_dir.mkdir(exist_ok=True, parents=True) self.df = self._load_and_prepare_data() if self.df is None or self.df.empty: logger.warning("No valid data loaded; aborting analysis.") return self.models = sorted(self.df["model"].unique().tolist()) if len(self.models) < 1: logger.warning("No models found in data; aborting analysis.") return self._create_matplotlib_radar_chart() markdown_content = self._generate_markdown_report() ts = datetime.now().strftime("%Y%m%d_%H%M%S") out_md = ( self.plots_dir / f"analysis_report_{self.dataset_name}_{ts}.md" ) out_md.write_text(markdown_content, encoding="utf-8") logger.info(f"Markdown report written to: {out_md}") def _load_and_prepare_data(self) -> Optional[pd.DataFrame]: if not self.json_dir.is_dir(): logger.error(f"json_dir '{self.json_dir}' is not a directory.") return None json_files = list(self.json_dir.glob("*.json")) if not json_files: logger.error(f"No .json files found in '{self.json_dir}'.") return None all_clips: List[Dict[str, Any]] = [] for jf in json_files: model_name = jf.stem data = json.loads(jf.read_text(encoding="utf-8")) if not isinstance(data, dict) or "per_clip" not in data: logger.warning( f"Skipping non-eval JSON file '{jf.name}' " f"(top-level type={type(data)}, has_per_clip={'per_clip' in data if isinstance(data, dict) else False})." ) continue per_clip = data.get("per_clip") if not per_clip: logger.warning( f"File '{jf.name}' has empty 'per_clip'; skipping." ) continue for clip in per_clip: clip["model"] = model_name all_clips.append(clip) if not all_clips: logger.error("No per_clip data found in any JSON files.") return None df = pd.DataFrame(all_clips) logger.info( f"Loaded {len(df)} clip records from {len(json_files)} JSON files." 
) return df def _create_kde_plot(self, metric: str, save_path: Path) -> None: if self.df is None or metric not in self.df.columns: return if self.df[metric].isnull().all(): return q_high = self.df[metric].quantile(self.plot_quantile_cutoff) df_filtered = self.df[self.df[metric] <= q_high] plt.style.use("seaborn-v0_8-whitegrid") fig, ax = plt.subplots(figsize=(12, 7)) sns.kdeplot( data=df_filtered, x=metric, hue="model", hue_order=self.models, ax=ax, fill=False, common_norm=False, palette="tab10", linewidth=self.kde_linewidth, ) ax.set_title( f'Error Distribution for "{metric}" on {self.dataset_name}', fontsize=16, weight="bold", ) ax.set_xlabel(f"Error Value ({metric})", fontsize=12) ax.set_ylabel("Density", fontsize=12) ax.set_xlim(left=0) legend = ax.get_legend() if legend: legend.set_title("Models", prop={"size": 14, "weight": "bold"}) for text in legend.get_texts(): text.set_fontsize(14) fig.savefig(save_path, dpi=150, bbox_inches="tight") plt.close(fig) def _create_matplotlib_radar_chart(self) -> None: if self.df is None: return original_metrics = list(self.radar_metric_mapping.keys()) raw_labels = [self.radar_metric_mapping[m] for m in original_metrics] display_labels = [ textwrap.fill(label, width=20, break_long_words=False) for label in raw_labels ] num_metrics = len(original_metrics) median_df = self.df.groupby("model")[original_metrics].median() rounded_median_df = median_df.round(2) normalized_df = pd.DataFrame( index=self.models, columns=original_metrics, dtype=float ) scale = 1.0 - self.min_normalized_value for metric in original_metrics: medians = rounded_median_df[metric].dropna() if medians.empty: normalized_df[metric] = self.min_normalized_value continue min_val, max_val = medians.min(), medians.max() rng = max_val - min_val if max_val > min_val else 1.0 for model in self.models: val = rounded_median_df.loc[model, metric] if pd.isna(val): normalized_df.loc[model, metric] = ( self.min_normalized_value ) continue lower_better = ( 
self.metric_types_for_radar.get(metric, "lower") == "lower" ) if lower_better: base = (max_val - val) / rng else: base = (val - min_val) / rng norm = self.min_normalized_value + base * scale normalized_df.loc[model, metric] = norm angles = np.linspace( 0, 2 * np.pi, num_metrics, endpoint=False ).tolist() angles += angles[:1] fig, ax = plt.subplots( figsize=(10, 10), subplot_kw=dict(projection="polar") ) cmap = plt.get_cmap("tab10") colors = {m: cmap(i % 10) for i, m in enumerate(self.models)} for model in self.models: vals = normalized_df.loc[model].tolist() vals += vals[:1] ax.fill(angles, vals, color=colors[model], alpha=0.25) ax.plot( angles, vals, color=colors[model], linewidth=2.5, label=model, marker="o", markersize=7, markeredgecolor="white", markeredgewidth=1, ) for j, metric in enumerate(original_metrics): angle = angles[j] groups: Dict[str, List[float]] = defaultdict(list) for model in self.models: orig_val = rounded_median_df.loc[model, metric] norm_val = normalized_df.loc[model, metric] groups[f"{orig_val:.2f}"].append(norm_val) for label_text, norm_vals in groups.items(): avg_norm = float(np.mean(norm_vals)) offset = 0.05 ax.text( angle, avg_norm + offset, label_text, ha="center", va="center", color="black", weight="bold", fontsize=9, bbox=dict( boxstyle="square,pad=0.3", fc="white", ec="none", alpha=0.8, ), ) ax.set_thetagrids(np.degrees(angles[:-1]), display_labels, fontsize=16) ax.tick_params(axis="x", pad=30) ax.set_rgrids([0.4, 0.6, 0.8, 1.0], labels=[]) ax.set_ylim(0, 1.25) ax.spines["polar"].set_visible(False) ax.grid(color="grey", linestyle="--", linewidth=0.5) handles, labels = ax.get_legend_handles_labels() if handles: legend_map = dict(zip(labels, handles)) ordered_labels = [m for m in self.models if m in legend_map] ordered_handles = [legend_map[m] for m in ordered_labels] ax.legend( handles=ordered_handles, labels=ordered_labels, loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=len(ordered_handles), fontsize=14, frameon=False, ) 
fig.suptitle( f"Model Comparison on {self.dataset_name} Dataset", fontsize=20, weight="bold", y=1.05, ) save_path = self.plots_dir / self.radar_chart_filename fig.savefig(save_path, dpi=300, bbox_inches="tight") plt.close(fig) logger.info(f"Radar chart saved to: {save_path}") def _generate_markdown_report(self) -> str: if self.df is None or len(self.models) < 2: return "" parts: List[str] = [ f"**Dataset**: {self.dataset_name}", f"**Models**: {', '.join(self.models)}", f"**Significance level (alpha)**: {self.alpha}", "### Pairwise metric comparisons and distributions", ] two_models = len(self.models) == 2 if two_models: model1, model2 = self.models[0], self.models[1] for metric in self.metrics_to_analyze: if metric not in self.df.columns: continue p_value_str = "" if two_models: d1 = self.df.loc[self.df["model"] == model1, metric].dropna() d2 = self.df.loc[self.df["model"] == model2, metric].dropna() if not d1.empty and not d2.empty: _, p_val = mannwhitneyu(d1, d2, alternative="two-sided") p_value_str = f" (p = {p_val:.3g})" parts.append(f"#### Metric: `{metric}`{p_value_str}") metric_stats: List[Dict[str, Any]] = [] for name in self.models: data = self.df.loc[self.df["model"] == name, metric].dropna() if data.empty: continue metric_stats.append( { "Model": name, "Median": data.median(), "Q1 (25%)": data.quantile(0.25), "Q3 (75%)": data.quantile(0.75), } ) if metric_stats: stats_df = ( pd.DataFrame(metric_stats) .sort_values(by="Median") .reset_index(drop=True) ) parts.append(stats_df.to_markdown(index=False, floatfmt=".4f")) findings: List[str] = [] lower_better = ( self.metric_types_for_radar.get(metric, "lower") == "lower" ) for m1, m2 in itertools.combinations(self.models, 2): d1 = self.df.loc[self.df["model"] == m1, metric].dropna() d2 = self.df.loc[self.df["model"] == m2, metric].dropna() if d1.empty or d2.empty: continue _, p_val = mannwhitneyu(d1, d2, alternative="two-sided") if p_val >= self.alpha: continue m1_med, m2_med = d1.median(), d2.median() better, 
worse = (m1, m2) if m1_med < m2_med else (m2, m1) if not lower_better: better, worse = worse, better findings.append( f"- **{better}** is significantly better than **{worse}** " f"(p < {self.alpha})." ) if findings: parts.append("\n".join(findings)) else: parts.append( "No statistically significant differences between models." ) safe_metric = metric.replace(" ", "_") plot_filename = f"{safe_metric}.png" plot_path = self.plots_dir / plot_filename self._create_kde_plot(metric, plot_path) parts.append( f"##### Distribution plot\n" f"![{metric} distribution on {self.dataset_name}]({plot_filename})" ) return "\n\n".join(parts) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Analyze per-clip JSON metrics, generate plots and markdown report." ) parser.add_argument("--json_dir", type=str, required=True) return parser.parse_args() if __name__ == "__main__": args = parse_args() json_dir = Path(args.json_dir).resolve() name = json_dir.name for prefix in ("metrics_output_",): if name.startswith(prefix) and len(name) > len(prefix): name = name[len(prefix) :] break dataset_name = name # e.g. 
"AMASS" plots_dir = json_dir / f"analysis_plots_{dataset_name}" metric_types_for_radar = {m: "lower" for m in DEFAULT_METRICS_TO_ANALYZE} analyzer = AnalysisReportGenerator( json_dir=args.json_dir, plots_dir=str(plots_dir), dataset_name=dataset_name, metrics_to_analyze=DEFAULT_METRICS_TO_ANALYZE, radar_metric_mapping=DEFAULT_RADAR_MAPPING, metric_types_for_radar=metric_types_for_radar, alpha=DEFAULT_ALPHA, ) analyzer.run() ================================================ FILE: holomotion/src/evaluation/obs/__init__.py ================================================ from .obs_builder import PolicyObsBuilder, get_gravity_orientation __all__ = [ "PolicyObsBuilder", "get_gravity_orientation", ] ================================================ FILE: holomotion/src/evaluation/obs/obs_builder.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import numpy as np import torch from typing import Dict, List, Sequence, Any, Optional def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray: """Calculate gravity orientation from quaternion. Args: quaternion: Array-like [w, x, y, z] Returns: np.ndarray of shape (3,) representing gravity projection. 
""" qw = float(quaternion[0]) qx = float(quaternion[1]) qy = float(quaternion[2]) qz = float(quaternion[3]) gravity_orientation = np.zeros(3, dtype=np.float32) gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy) gravity_orientation[1] = -2.0 * (qz * qy + qw * qx) gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz) return gravity_orientation class _CircularBuffer: """History buffer for batched tensor data (batch==1 in our eval/deploy). Stores history in oldest->newest order when accessed via .buffer. """ def __init__(self, max_len: int, feat_dim: int, device: str): if max_len < 1: raise ValueError(f"max_len must be >= 1, got {max_len}") self._max_len = int(max_len) self._feat_dim = int(feat_dim) self._device = device self._pointer = -1 self._num_pushes = 0 self._buffer: torch.Tensor = torch.zeros( (self._max_len, 1, self._feat_dim), dtype=torch.float32, device="cpu", ) @property def buffer(self) -> torch.Tensor: """Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1.""" if self._num_pushes == 0: raise RuntimeError( "Attempting to read from an empty history buffer." ) # roll such that oldest is at index=0 along the history axis rolled = torch.roll( self._buffer, shifts=self._max_len - self._pointer - 1, dims=0 ) return torch.transpose(rolled, 0, 1) # [1, max_len, feat] def append(self, data: torch.Tensor) -> None: """Append one step: data shape [1, feat_dim] on the configured device.""" if ( data.ndim != 2 or data.shape[0] != 1 or data.shape[1] != self._feat_dim ): raise ValueError( f"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}" ) self._pointer = (self._pointer + 1) % self._max_len self._buffer[self._pointer] = data if self._num_pushes == 0: # duplicate first push across entire history for warm start self._buffer[:] = data self._num_pushes += 1 class PolicyObsBuilder: """Builds policy observations from Unitree lowstate with temporal history. Designed to be shared between MuJoCo sim2sim evaluation and ROS2 deployment. 
History management is internal and produces a flattened vector of size sum_i(context_length * feat_i) across the configured observation items. Supports two command modes: - "motion_tracking": uses reference motion states - "velocity_tracking": uses velocity commands [vx, vy, vyaw] """ def __init__( self, dof_names_onnx: Sequence[str], default_angles_onnx: np.ndarray, evaluator: Optional[Any] = None, obs_policy_cfg: Optional[Dict[str, Any]] = None, ) -> None: self.dof_names_onnx: List[str] = list(dof_names_onnx) self.num_actions: int = len(self.dof_names_onnx) self.evaluator = evaluator self.obs_policy_cfg = obs_policy_cfg if default_angles_onnx.shape[0] != self.num_actions: raise ValueError( "default_angles_onnx length must match num actions" ) self.default_angles_onnx = default_angles_onnx.astype(np.float32) self.default_angles_dict: Dict[str, float] = { name: float(self.default_angles_onnx[idx]) for idx, name in enumerate(self.dof_names_onnx) } # Build observation schema from config if provided self.term_specs: List[Dict[str, Any]] = [] for term_dict in self.obs_policy_cfg["atomic_obs_list"]: for name, cfg in term_dict.items(): term_dict = {**cfg} term_dict["name"] = name self.term_specs.append(term_dict) # Buffers are created lazily after first dimension inference self._buffers: Dict[str, _CircularBuffer] = {} def reset(self) -> None: for buf in self._buffers.values(): buf._pointer = -1 buf._num_pushes = 0 buf._buffer.zero_() def _compute_term( self, name: str, ) -> np.ndarray: # Prefer evaluator-provided methods; no legacy fallbacks if self.evaluator is not None: meth = getattr(self.evaluator, f"_get_obs_{name}", None) if callable(meth): out = meth() return np.asarray(out, dtype=np.float32).reshape(-1) raise ValueError( f"Unknown observation term '{name}' or evaluator method missing." 
) def build_policy_obs(self) -> np.ndarray: """Append one step using evaluator-provided observation terms and return flattened obs.""" # Compute per-term outputs values: Dict[str, np.ndarray] = {} for spec in self.term_specs: name = spec["name"] scale = spec.get("scale", 1.0) values[name] = self._compute_term(name) * scale # Lazily initialize buffers with inferred feature dims if len(self._buffers) == 0: for spec in self.term_specs: name = spec["name"] hist_len = int(spec.get("history_length", 0)) if hist_len <= 0: continue feat_dim = int(values[name].reshape(-1).shape[0]) self._buffers[name] = _CircularBuffer( hist_len, feat_dim, "cpu" ) # Append current step to buffers (skip terms without history) for spec in self.term_specs: name = spec["name"] if name in self._buffers: item = torch.as_tensor( values[name].reshape(1, -1), dtype=torch.float32, device="cpu", ) self._buffers[name].append(item) # Assemble flat list according to term ordering and history flatten rules flat_list: List[np.ndarray] = [] for spec in self.term_specs: name = spec["name"] flatten = bool(spec.get("flatten", True)) if name in self._buffers: buf = self._buffers[name].buffer[0] # [hist, feat] arr = ( buf.reshape(-1).detach().cpu().numpy() if flatten else buf[-1].detach().cpu().numpy() ) flat_list.append(arr.astype(np.float32)) else: # no history -> use computed value directly flat_list.append(values[name].reshape(-1).astype(np.float32)) if len(flat_list) == 0: return np.zeros(0, dtype=np.float32) return np.concatenate(flat_list, axis=0).astype(np.float32) ================================================ FILE: holomotion/src/evaluation/ray_evaluator_actor.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Minimal Ray actor for batch eval. Lives in its own module so the class can be pickled without pulling in torch/jit from eval_mujoco_sim2sim. """ import importlib import os import sys import ray from loguru import logger class RayEvaluatorActor: """Persistent Ray actor: one evaluator (one ONNX session) per actor. Schedule with num_gpus=1/actors_per_gpu so that multiple actors share one GPU. Ray sets CUDA_VISIBLE_DEVICES so this actor sees a single GPU as device 0. """ def __init__(self, config_dict, output_dir): logger.remove() logger.add(sys.stderr, level="WARNING") self.output_dir = output_dir self.config_dict = config_dict model_type = config_dict.get("model_type") or "holomotion" self.evaluator = _load_ray_evaluator(config_dict, model_type) self.evaluator.setup() if model_type == "gmt": self.evaluator.gmt_proprio_buf.clear() def run_clip(self, file_path): from holomotion.src.evaluation.eval_mujoco_sim2sim import ( _build_onnx_io_dump_dir, _build_onnx_io_dump_path, ) fname = os.path.basename(file_path) save_name = fname.replace(".npz", "_eval.npz") save_path = os.path.join(self.output_dir, save_name) self.evaluator.load_specific_motion(file_path) self.evaluator.reset_state_teleport() for i in range(self.evaluator.n_motion_frames): self.evaluator.motion_frame_idx = i self.evaluator._update_policy() self.evaluator._apply_control(sleep=False) self.evaluator.counter += 1 meta = { "source_file": fname, "model": str(self.config_dict.get("ckpt_onnx_path", "")), "source_npz": fname, "onnx_model": str(self.config_dict.get("ckpt_onnx_path", "")), } 
self.evaluator.save_batch_result(save_path, meta) model_type = self.config_dict.get("model_type") or "holomotion" if bool(self.config_dict.get("dump_onnx_io_npy", False)) and ( model_type == "holomotion" ): onnx_io_dir = _build_onnx_io_dump_dir(self.output_dir) os.makedirs(onnx_io_dir, exist_ok=True) self.evaluator.save_onnx_io_dump( _build_onnx_io_dump_path(self.output_dir, fname), meta ) return "success" def _load_ray_evaluator(config_dict, model_type): module_name = config_dict.get( "ray_evaluator_module", "holomotion.src.evaluation.eval_mujoco_sim2sim", ) factory_module = importlib.import_module(module_name) return factory_module._create_ray_evaluator(config_dict, model_type) ================================================ FILE: holomotion/src/evaluation/ray_metrics_postprocess.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import shutil import sys from pathlib import Path from typing import Any import ray from loguru import logger @ray.remote def run_metrics_postprocess_job( output_dir: str, dataset_name: str, calc_per_clip_metrics: bool, failure_pos_err_thresh_m: float, metric_calculation: str, dof_mode: str, metrics_threadpool_max_workers: int | None, generate_report: bool, job_log_dir: str | None, ckpt_stem: str, ) -> dict[str, Any]: logger.remove() logger.add(sys.stderr, level="WARNING") if calc_per_clip_metrics: from holomotion.src.evaluation.metrics import run_evaluation run_evaluation( npz_dir=output_dir, dataset_suffix=dataset_name, failure_pos_err_thresh_m=failure_pos_err_thresh_m, metric_calculation=metric_calculation, dof_mode=dof_mode, threadpool_max_workers=metrics_threadpool_max_workers, ) report_path = None if generate_report: from holomotion.scripts.evaluation import mean_process_5metrics report_path = ( mean_process_5metrics.generate_macro_mean_report_from_json_dir( output_dir ) ) exported_summary_tsv = None if job_log_dir is not None: job_log_dir_path = Path(job_log_dir) sub_dataset_tsv = ( Path(output_dir) / "sub_dataset_macro_mean_metrics.tsv" ) if sub_dataset_tsv.is_file(): export_name = f"{ckpt_stem}_sub_dataset_macro_mean_metrics.tsv" export_path = job_log_dir_path / export_name shutil.copy2(sub_dataset_tsv, export_path) exported_summary_tsv = export_path return { "ckpt_stem": ckpt_stem, "output_dir": output_dir, "report_path": str(report_path) if report_path is not None else "", "exported_summary_tsv": str(exported_summary_tsv) if exported_summary_tsv is not None else "", } ================================================ FILE: holomotion/src/modules/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/modules/agent_modules.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from __future__ import annotations import io import copy import math from pathlib import Path import holomotion.src.modules.network_modules as NM import torch import torch.nn as nn import torch.nn.functional as F from holomotion.src.modules.network_modules import EmpiricalNormalization from loguru import logger from tensordict import TensorDict from tensordict.nn import TensorDictModuleBase from torch.distributions import Normal def _module_device(module: nn.Module) -> torch.device: for tensor in module.parameters(): return tensor.device for tensor in module.buffers(): return tensor.device return torch.device("cpu") def _clone_module_for_cpu_export(module: nn.Module) -> nn.Module: """Clone a module for CPU-side export without mutating the live module.""" buffer = io.BytesIO() # Keep the training module on-device; rank-local device hops during export # can desynchronize DDP state and hang later collectives. torch.save(module, buffer) buffer.seek(0) clone = torch.load(buffer, map_location="cpu", weights_only=False) clone = clone.to("cpu") clone.eval() return clone class TensorDictAssembler(torch.nn.Module): def __init__(self, schema_config: dict, *, output_mode: str = "flat"): super().__init__() self.schema_config = schema_config self.output_mode = str(output_mode).lower() if self.output_mode not in ("flat", "seq"): raise ValueError( f"output_mode must be one of {{'flat','seq'}}, got {output_mode}" ) self.seq_len_dict: dict[str, int] = { str(k): int(v.get("seq_len", 1)) for k, v in schema_config.items() } _uniq_lens = sorted(set(self.seq_len_dict.values())) self.seq_len: int | None = ( int(_uniq_lens[0]) if len(_uniq_lens) == 1 else None ) if self.output_mode == "seq" and self.seq_len is None: raise ValueError( "TensorDictAssembler(output_mode='seq') requires a single unique seq_len " f"across schema groups, got seq_len_dict={self.seq_len_dict}" ) self.output_dim: int | None = None @staticmethod def _get_from_data(data: TensorDict, key: str): # Support hierarchical 
keys like "latent/z" if key in data.keys(): return data.get(key) if "/" in key: current = data for p in key.split("/"): if isinstance(current, TensorDict) and p in current.keys(): current = current.get(p) else: return None return current return None def _validate_to_seq( self, tensor: torch.Tensor, seq_len: int, term: str, ) -> torch.Tensor: """Return [B, seq_len, d] tensor.""" if tensor.ndim == 2: # [B, d] treat as seq_len=1 if seq_len != 1: raise ValueError( f"Term '{term}' expected seq_len={seq_len} but tensor is 2D {tensor.shape}" ) return tensor[:, None, :] if tensor.ndim == 3: if tensor.shape[1] != seq_len: raise ValueError( f"Term '{term}' seq_len mismatch: expected {seq_len}, got {tensor.shape[1]}" ) return tensor raise ValueError( f"Term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}" ) def _validate_and_flatten( self, tensor: torch.Tensor, seq_len: int, term: str, ) -> torch.Tensor: if tensor.ndim == 2: # [B, D] treat as seq_len=1 if seq_len != 1: raise ValueError( f"Term '{term}' expected seq_len={seq_len} but tensor is 2D {tensor.shape}" ) return tensor if tensor.ndim == 3: if tensor.shape[1] != seq_len: raise ValueError( f"Term '{term}' seq_len mismatch: expected {seq_len}, got {tensor.shape[1]}" ) b, t, d = tensor.shape return tensor.reshape(b, t * d) raise ValueError( f"Term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}" ) def forward(self, data: TensorDict) -> torch.Tensor: if not isinstance(data, TensorDict): raise TypeError("TensorDictAssembler expects TensorDict input.") if self.output_mode == "flat": assembled = [] output_dim = 0 batch_size = None for _, seq_cfg in self.schema_config.items(): seq_len = int(seq_cfg.get("seq_len", 1)) terms = seq_cfg.get("terms", []) for term in terms: tensor = self._get_from_data(data, term) if tensor is None: raise KeyError( f"Missing term '{term}' in TensorDict input for assembler. " "Use explicit hierarchical terms (e.g. 'group/term') " "for nested TensorDict keys." 
) flat = self._validate_and_flatten(tensor, seq_len, term) if batch_size is None: batch_size = flat.shape[0] elif flat.shape[0] != batch_size: raise ValueError( f"Batch size mismatch for term '{term}': {flat.shape[0]} vs {batch_size}" ) assembled.append(flat) output_dim += flat.shape[-1] if not assembled: raise ValueError( "Assembler received an empty schema or no tensors found" ) out = torch.cat(assembled, dim=-1) # Cache output_dim on first successful forward if self.output_dim is None: self.output_dim = output_dim return out # output_mode == "seq" assembled_seq = [] batch_size = None seq_len_ref = None for _, seq_cfg in self.schema_config.items(): seq_len = int(seq_cfg.get("seq_len", 1)) if seq_len_ref is None: seq_len_ref = seq_len elif seq_len != seq_len_ref: raise ValueError( "TensorDictAssembler(output_mode='seq') requires consistent seq_len " f"across schema groups, got {seq_len_ref} vs {seq_len}" ) terms = seq_cfg.get("terms", []) for term in terms: tensor = self._get_from_data(data, term) if tensor is None: raise KeyError( f"Missing term '{term}' in TensorDict input for assembler. " "Use explicit hierarchical terms (e.g. 'group/term') " "for nested TensorDict keys." 
) seq_tensor = self._validate_to_seq(tensor, seq_len, term) if batch_size is None: batch_size = seq_tensor.shape[0] elif seq_tensor.shape[0] != batch_size: raise ValueError( f"Batch size mismatch for term '{term}': {seq_tensor.shape[0]} vs {batch_size}" ) assembled_seq.append(seq_tensor) if not assembled_seq: raise ValueError( "Assembler received an empty schema or no tensors found" ) out = torch.cat(assembled_seq, dim=-1) # Expose seq_len and output_dim for sequence assembly if self.seq_len is None: self.seq_len = int(out.shape[1]) if self.output_dim is None: self.output_dim = int(out.shape[-1]) return out @torch.inference_mode() def infer_output_dim(self, sample: TensorDict) -> int: """Run a dry forward pass to populate output_dim without grads.""" if self.output_dim is not None: return int(self.output_dim) _ = self.forward(sample) return self.output_dim class PPOActorOnnxModule(nn.Module): def __init__( self, actor_module: nn.Module, obs_normalizer: nn.Module, obs_norm_enabled: bool, obs_norm_clip: float, ): super().__init__() self.actor_module = actor_module self.obs_normalizer = obs_normalizer self.obs_norm_enabled = bool(obs_norm_enabled) self.obs_norm_clip = float(obs_norm_clip) def forward(self, obs: torch.Tensor) -> torch.Tensor: actor_obs = obs if self.obs_norm_enabled: actor_obs = self.obs_normalizer.normalize_only(actor_obs) if self.obs_norm_clip > 0.0: actor_obs = torch.clamp( actor_obs, -self.obs_norm_clip, self.obs_norm_clip ) return self.actor_module(actor_obs) class PPOTFActorOnnxModule(nn.Module): def __init__( self, actor_module: nn.Module, obs_normalizer: nn.Module, obs_norm_enabled: bool, obs_norm_clip: float, ): super().__init__() self.actor_module = actor_module self.obs_normalizer = obs_normalizer self.obs_norm_enabled = bool(obs_norm_enabled) self.obs_norm_clip = float(obs_norm_clip) def forward( self, obs: torch.Tensor, past_key_values: torch.Tensor, step_idx: torch.Tensor, ) -> tuple[torch.Tensor, ...]: actor_obs = obs if 
self.obs_norm_enabled: actor_obs = self.obs_normalizer.normalize_only(actor_obs) if self.obs_norm_clip > 0.0: actor_obs = torch.clamp( actor_obs, -self.obs_norm_clip, self.obs_norm_clip ) return self.actor_module( actor_obs, past_key_values=past_key_values, current_pos=step_idx, ) class PPOTFWoKVCacheActorOnnxModule(nn.Module): def __init__( self, actor_module: nn.Module, obs_normalizer: nn.Module, obs_norm_enabled: bool, obs_norm_clip: float, ): super().__init__() self.actor_module = actor_module self.obs_normalizer = obs_normalizer self.obs_norm_enabled = bool(obs_norm_enabled) self.obs_norm_clip = float(obs_norm_clip) def forward(self, obs: torch.Tensor) -> torch.Tensor: if obs.ndim != 3: raise ValueError( f"Expected obs [B, 32, D] for no-kv ONNX path, got {obs.shape}" ) if obs.shape[1] != 32: raise ValueError( f"Expected fixed token length 32, got {int(obs.shape[1])}" ) actor_obs = obs if self.obs_norm_enabled: actor_obs = self.obs_normalizer.normalize_only(actor_obs) if self.obs_norm_clip > 0.0: actor_obs = torch.clamp( actor_obs, -self.obs_norm_clip, self.obs_norm_clip ) action_seq = self.actor_module.sequence_mu(actor_obs, attn_mask=None) return action_seq[:, -1, :] class PPOCondTFActorOnnxModule(nn.Module): def __init__( self, actor_module: nn.Module, state_obs_normalizer: nn.Module, obs_norm_enabled: bool, obs_norm_clip: float, state_dim: int, future_seq_len: int, future_token_dim: int, future_term_dims: list[int], ): super().__init__() self.actor_module = actor_module self.state_obs_normalizer = state_obs_normalizer self.obs_norm_enabled = bool(obs_norm_enabled) self.obs_norm_clip = float(obs_norm_clip) self.state_dim = int(state_dim) self.future_seq_len = int(future_seq_len) self.future_token_dim = int(future_token_dim) self.future_term_dims = [int(x) for x in future_term_dims] if any(d <= 0 for d in self.future_term_dims): raise ValueError( f"future_term_dims must be all positive, got {self.future_term_dims}" ) if sum(self.future_term_dims) != 
self.future_token_dim: raise ValueError( "future_term_dims sum mismatch: expected " f"{self.future_token_dim}, got {sum(self.future_term_dims)}" ) def forward( self, obs: torch.Tensor, past_key_values: torch.Tensor, step_idx: torch.Tensor, ) -> tuple[torch.Tensor, ...]: if obs.ndim != 2: raise ValueError(f"Expected obs [B, D], got {obs.shape}") state_obs = obs[:, : self.state_dim] future_flat = obs[:, self.state_dim :] expected_future_dim = self.future_seq_len * self.future_token_dim if future_flat.shape[-1] != expected_future_dim: raise ValueError( "Future obs dim mismatch for ONNX path: expected " f"{expected_future_dim}, got {future_flat.shape[-1]}" ) if self.obs_norm_enabled: state_obs = self.state_obs_normalizer.normalize_only(state_obs) if self.obs_norm_clip > 0.0: state_obs = torch.clamp( state_obs, -self.obs_norm_clip, self.obs_norm_clip ) # Reconstruct [B, N_fut, D_fut] from term-major flattened layout: # [term1 (N_fut*d1), term2 (N_fut*d2), ...] -> per-step concat along last dim. 
b = int(obs.shape[0]) offset = 0 future_parts = [] for d_term in self.future_term_dims: span = int(self.future_seq_len * d_term) chunk = future_flat[:, offset : offset + span] future_parts.append(chunk.reshape(b, self.future_seq_len, d_term)) offset += span if offset != int(future_flat.shape[-1]): raise ValueError( "Future flat slicing mismatch in ONNX path: " f"consumed={offset}, total={int(future_flat.shape[-1])}" ) future_obs = torch.cat(future_parts, dim=-1) return self.actor_module._forward_inference_onnx_cond( state_obs, future_obs, past_key_values, step_idx, ) class PPOActor(TensorDictModuleBase): def __init__( self, obs_schema: dict | None, module_config_dict: dict, num_actions: int, init_noise_std: float, *, obs_example: dict | None = None, ): super(PPOActor, self).__init__() self.use_logvar = module_config_dict.get("use_logvar", False) obs_norm_cfg = module_config_dict.get("obs_norm", {}) self.obs_norm_enabled = bool(obs_norm_cfg.get("enabled", False)) if self.obs_norm_enabled: self.obs_norm_clip = float(obs_norm_cfg.get("clip_range", 0.0)) self.obs_norm_eps = float(obs_norm_cfg.get("epsilon", 1.0e-8)) self.obs_norm_update_method = str( obs_norm_cfg.get( "update_method", obs_norm_cfg.get("method", "cumulative") ) ).lower() self.obs_norm_ema_momentum = float( obs_norm_cfg.get("ema_momentum") ) module_config_dict = self._process_module_config( module_config_dict, num_actions, ) self.actor_net_type = module_config_dict.get("type", "MLP") logger.info(f"actor_net_type: {self.actor_net_type}") actor_net_class = getattr(NM, self.actor_net_type, None) if actor_net_class is NM.MLP and obs_schema is None: raise ValueError( "PPOActor(Mlp) requires obs_schema so the agent module can assemble" "TensorDict observations into a flat tensor." 
) if obs_schema is not None: output_mode = "seq" if actor_net_class is NM.ConvMLP else "flat" self.assembler = TensorDictAssembler( obs_schema, output_mode=output_mode ) if obs_example is not None: self.assembler.infer_output_dim(obs_example) if self.assembler.output_dim is None: raise ValueError( "TensorDictAssembler could not infer output_dim" ) input_dim_for_net = int(self.assembler.output_dim) else: raise ValueError("obs_schema can't be None!") actor_in_keys: list[str] = [] for _, seq_cfg in obs_schema.items(): if not isinstance(seq_cfg, dict): continue for term in seq_cfg.get("terms", []): actor_in_keys.append(str(term)) self.in_keys = actor_in_keys self.out_keys = [ "actions", "actions_log_prob", "mu", "sigma", "entropy", ] if self.obs_norm_enabled and self.assembler is not None: self.obs_normalizer = EmpiricalNormalization( shape=self.assembler.output_dim, eps=self.obs_norm_eps, update_method=self.obs_norm_update_method, ema_momentum=self.obs_norm_ema_momentum, ) else: self.obs_normalizer = nn.Identity() # Always pass obs_example if available if obs_example is not None: self.actor_module = actor_net_class( input_dim=input_dim_for_net, output_dim=int(module_config_dict["output_dim"]), module_config_dict=module_config_dict, ) else: raise ValueError("Obs example can't be None!") if "output_head_init_scale" in module_config_dict: output_head_init_scale = float( module_config_dict["output_head_init_scale"] ) if output_head_init_scale <= 0.0: raise ValueError("output_head_init_scale must be > 0.") output_head = self.actor_module.output_head if not isinstance(output_head, nn.Linear): raise ValueError( "output_head_init_scale requires actor_module.output_head to be nn.Linear." 
) with torch.no_grad(): output_head.weight.mul_(output_head_init_scale) if output_head.bias is not None: output_head.bias.mul_(output_head_init_scale) self._actor_schema_module = bool( getattr(self.actor_module, "proprio_assembler", None) ) self.fix_sigma = module_config_dict.get("fix_sigma", False) self.max_sigma = module_config_dict.get("max_sigma", 1.0) self.min_sigma = module_config_dict.get("min_sigma", 0.1) if "noise_std_type" in module_config_dict: self.noise_std_type = str( module_config_dict["noise_std_type"] ).lower() elif self.use_logvar: self.noise_std_type = "log" else: self.noise_std_type = "scalar" # Action noise parameters (kept outside nets so optimizer updates them) if self.noise_std_type == "log": logger.info("Using log-std parameterization for action noise") self.log_std = nn.Parameter( torch.log(torch.ones(num_actions) * init_noise_std) ) if self.fix_sigma: self.log_std.requires_grad = False else: # scalar (default) self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) if self.fix_sigma: self.std.requires_grad = False self.distribution = None # disable args validation for speedup Normal.set_default_validate_args = False self.actor_obs_transforms: list[callable] = [] if self.obs_norm_enabled: self.actor_obs_transforms.append(self._normalize_actor_obs) def _process_module_config(self, module_config_dict, num_actions): if module_config_dict.get("output_schema", None) is not None: raise ValueError( "PPOActor no longer supports module_config_dict.output_schema. " "Use scalar module_config_dict.output_dim instead." ) # Resolve output_dim placeholders when present. if "output_dim" in module_config_dict: output_dim = module_config_dict["output_dim"] if isinstance(output_dim, list): raise ValueError( "PPOActor expects module_config_dict.output_dim to be a scalar. " "List-valued output_dim is not supported." 
) if output_dim == "robot_action_dim": module_config_dict["output_dim"] = num_actions return module_config_dict def _sigma_from_params(self) -> torch.Tensor: if self.noise_std_type == "log": return torch.exp(self.log_std) return self.std def _normalize_actor_obs( self, obs: torch.Tensor, update: bool ) -> torch.Tensor: if not self.obs_norm_enabled: return obs clip = float(self.obs_norm_clip) if obs.ndim == 3: b, seq_len, d = obs.shape flat_obs = obs.reshape(b * seq_len, d) if update: self.obs_normalizer.update(flat_obs) flat_obs = self.obs_normalizer.normalize_only(flat_obs) obs = flat_obs.reshape(b, seq_len, d) else: if update: self.obs_normalizer.update(obs) obs = self.obs_normalizer.normalize_only(obs) if clip > 0.0: obs = torch.clamp(obs, -clip, clip) return obs def _sigma_like(self, like: torch.Tensor) -> torch.Tensor: sigma_vec = self._sigma_from_params() sigma_vec = torch.clamp( sigma_vec, min=float(self.min_sigma), max=float(self.max_sigma), ) if sigma_vec.ndim == 1 and like.ndim >= 2: view_shape = [1 for _ in range(like.ndim - 1)] + [ sigma_vec.shape[0] ] return sigma_vec.view(*view_shape).expand_as(like) if sigma_vec.shape != like.shape: return sigma_vec.expand_as(like) return sigma_vec @property def actor(self): return self.actor_module @property def flat_obs_dim(self) -> int: if self.assembler is None: raise ValueError( "PPOActor has no assembler; flat obs dim unavailable." ) if self.assembler.output_dim is None: raise ValueError( "PPOActor assembler output_dim is not initialized." ) return int(self.assembler.output_dim) def export_onnx( self, onnx_path: str | Path, *, opset_version: int = 17, ) -> str: if self._actor_schema_module: raise ValueError( "PPOActor export expects flat-obs actor modules, not schema-native modules." 
) export_path = Path(onnx_path) export_path.parent.mkdir(parents=True, exist_ok=True) if hasattr(self.actor_module, "clear_router_distribution_cache"): self.actor_module.clear_router_distribution_cache() actor_module = _clone_module_for_cpu_export(self.actor_module) if self.obs_norm_enabled: obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer) else: obs_normalizer = nn.Identity() exporter = PPOActorOnnxModule( actor_module=actor_module, obs_normalizer=obs_normalizer, obs_norm_enabled=self.obs_norm_enabled, obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0, ).to("cpu") exporter.eval() obs = torch.zeros( 1, self.flat_obs_dim, device="cpu", dtype=torch.float32 ) torch.onnx.export( exporter, (obs,), str(export_path), export_params=True, opset_version=opset_version, verbose=False, dynamo=False, input_names=["obs"], output_names=["actions"], ) return str(export_path) def forward( self, obs_td: TensorDict, actions: torch.Tensor | None = None, mode: str = "sampling", *, update_obs_norm: bool = True, ) -> TensorDict: """TensorDict-first forward for PPOActor. Returns a TensorDict with keys: - actions: [B, A] - actions_log_prob: [B] (sampling/logp only) - mu: [B, A] - sigma: [B, A] - entropy: [B] (sampling/logp only) """ if mode not in ("sampling", "logp", "inference"): raise ValueError(f"Unsupported mode: {mode}") if not isinstance(obs_td, TensorDict): raise ValueError("PPOActor.forward expects TensorDict input.") td = obs_td.clone( recurse=False ) # this only clones the tree sturcture, not the data if self._actor_schema_module: mu = self.actor_module(obs_td) else: if self.assembler is None: raise ValueError( "Flat-tensor actor module requires obs_schema in PPOActor init." 
) actor_obs = self.assembler(obs_td) update = bool(update_obs_norm) for fn in self.actor_obs_transforms: actor_obs = fn(actor_obs, update) mu = self.actor_module(actor_obs) sigma = self._sigma_like(mu) td.set("mu", mu) td.set("sigma", sigma) if mode == "inference": actions_out = mu td.set("actions", actions_out) return td self.distribution = Normal(mu, sigma) if mode == "sampling": actions_out = self.distribution.sample() else: if actions is None: raise ValueError("actions must be provided when mode='logp'") actions_out = actions td.set("actions", actions_out) td.set( "actions_log_prob", self.distribution.log_prob(actions_out).sum(dim=-1), ) td.set("entropy", self.distribution.entropy().sum(dim=-1)) return td def update_distribution(self, actor_obs): mean = self.actor(actor_obs) # Resolve std according to parameterization std_val = self._sigma_from_params() std_val = torch.clamp(std_val, min=self.min_sigma, max=self.max_sigma) self.distribution = Normal(mean, std_val) def override_sigma(self, sigma_override: float | torch.Tensor) -> None: """Override actor sigma parameters (std) explicitly. Args: sigma_override: scalar or [A] tensor for sigma_theta (std). """ if self.noise_std_type not in ("scalar", "log"): raise ValueError( f"Unsupported noise_std_type for override: {self.noise_std_type}" ) param = self.log_std if self.noise_std_type == "log" else self.std sigma_tensor = torch.as_tensor( sigma_override, device=param.device, dtype=param.dtype ) if sigma_tensor.numel() == 1: sigma_tensor = sigma_tensor.expand_as(param) elif sigma_tensor.shape != param.shape: raise ValueError( f"sigma_override shape {tuple(sigma_tensor.shape)} does not match " f"actor sigma shape {tuple(param.shape)}." 
) if torch.any(sigma_tensor <= 0): raise ValueError("sigma_override must be > 0 for all dims.") if self.noise_std_type == "log": sigma_tensor = torch.log(sigma_tensor) with torch.no_grad(): param.copy_(sigma_tensor) class PPOCritic(TensorDictModuleBase): def __init__( self, obs_schema: dict | None, module_config_dict, *, obs_example: dict | None = None, ): super(PPOCritic, self).__init__() self.critic_net_type = module_config_dict.get("type", "MLP") obs_norm_cfg = module_config_dict.get("obs_norm", {}) self.obs_norm_enabled = bool(obs_norm_cfg.get("enabled", False)) if self.obs_norm_enabled: self.obs_norm_clip = float(obs_norm_cfg.get("clip_range", 0.0)) self.obs_norm_eps = float(obs_norm_cfg.get("epsilon", 1.0e-8)) self.obs_norm_update_method = str( obs_norm_cfg.get( "update_method", obs_norm_cfg.get("method", "cumulative") ) ).lower() self.obs_norm_ema_momentum = float( obs_norm_cfg.get("ema_momentum") ) critic_net_class = getattr(NM, self.critic_net_type, None) if critic_net_class is None: critic_net_class = globals().get(self.critic_net_type, None) if critic_net_class is None or not isinstance(critic_net_class, type): available_classes = [ name for name in dir(NM) if isinstance(getattr(NM, name, None), type) ] + [ name for name, obj in globals().items() if isinstance(obj, type) ] raise NotImplementedError( f"Unknown critic_net_type: {self.critic_net_type}. " f"Available classes: {available_classes}" ) if critic_net_class is NM.MLP and obs_schema is None: raise ValueError( "PPOCritic(MLP) requires obs_schema so the agent module can assemble " "TensorDict observations into a flat tensor." 
) # Build assembler for flat-tensor networks only # Schema-based networks (e.g., MultiTaskCritic) don't need it if obs_schema is not None: output_mode = "seq" if critic_net_class is NM.ConvMLP else "flat" self.assembler = TensorDictAssembler( obs_schema, output_mode=output_mode ) if obs_example is not None: self.assembler.infer_output_dim(obs_example) if self.assembler.output_dim is None: raise ValueError( "TensorDictAssembler could not infer output_dim; provide obs_example." ) input_dim_for_net = int(self.assembler.output_dim) else: # Schema-based modules don't use wrapper's assembler self.assembler = None input_dim_for_net = 0 critic_in_keys: list[str] = [] if obs_schema is not None: for _, seq_cfg in obs_schema.items(): if not isinstance(seq_cfg, dict): continue for term in seq_cfg.get("terms", []): critic_in_keys.append(str(term)) self.in_keys = critic_in_keys self.out_keys = ["values"] if self.obs_norm_enabled and self.assembler is not None: self.obs_normalizer = EmpiricalNormalization( shape=self.assembler.output_dim, eps=self.obs_norm_eps, update_method=self.obs_norm_update_method, ema_momentum=self.obs_norm_ema_momentum, ) else: self.obs_normalizer = nn.Identity() # Always pass obs_example if available if obs_example is not None: self.critic_module = critic_net_class( input_dim=input_dim_for_net, output_dim=int(module_config_dict["output_dim"]), module_config_dict=module_config_dict, ) else: raise ValueError("obs_schema can't be None!") self._critic_schema_module = bool( getattr(self.critic_module, "proprio_assembler", None) ) self.critic_obs_transforms: list[callable] = [] if self.obs_norm_enabled: self.critic_obs_transforms.append(self._normalize_critic_obs) def _normalize_critic_obs( self, obs: torch.Tensor, update: bool ) -> torch.Tensor: if not self.obs_norm_enabled: return obs clip = float(self.obs_norm_clip) if obs.ndim == 3: b, seq_len, d = obs.shape flat_obs = obs.reshape(b * seq_len, d) if update: self.obs_normalizer.update(flat_obs) flat_obs = 
self.obs_normalizer.normalize_only(flat_obs) obs = flat_obs.reshape(b, seq_len, d) else: if update: self.obs_normalizer.update(obs) obs = self.obs_normalizer.normalize_only(obs) if clip > 0.0: obs = torch.clamp(obs, -clip, clip) return obs def forward( self, obs_td: TensorDict, update_obs_norm: bool = True, **kwargs, ) -> TensorDict: """TensorDict-first forward for PPOCritic. Args: obs_td: TensorDict observations keyed by obs terms. update_obs_norm: If False, skip updating running stats. Returns: TensorDict with key: - "values": [B, 1] """ if not isinstance(obs_td, TensorDict): raise ValueError("PPOCritic.forward expects TensorDict input.") td = obs_td.clone(recurse=False) if self._critic_schema_module: values = self.critic_module(obs_td) if values.ndim == 1: values = values[..., None] td.set("values", values) return td if self.assembler is None: raise ValueError( "Flat-tensor critic module requires obs_schema in PPOCritic init." ) critic_obs = self.assembler(obs_td) update = bool(update_obs_norm) for fn in self.critic_obs_transforms: critic_obs = fn(critic_obs, update) values = self.critic_module(critic_obs) if values.ndim == 1: values = values[..., None] td.set("values", values) return td class PPOTFActor(PPOActor): """Transformer-based PPO actor wrapper compatible with PPOActor interface. 
- Uses NM.TransformerDecoderPolicy as actor_module - Provides KV-cache controls - Uses model-predicted diagonal std for distribution """ def __init__( self, obs_schema: dict | None, module_config_dict: dict, num_actions: int, init_noise_std: float, *, obs_example: dict | None = None, ): super().__init__( obs_schema=obs_schema, module_config_dict=module_config_dict, num_actions=num_actions, init_noise_std=init_noise_std, obs_example=obs_example, ) # Ensure initial std is strictly inside [min_sigma, max_sigma] to avoid boundary saturation init_std_val = float(init_noise_std) if not (self.min_sigma < init_std_val < self.max_sigma): # Expand bounds conservatively if needed if init_std_val >= self.max_sigma: self.max_sigma = max(self.max_sigma, init_std_val * 2.0) if init_std_val <= self.min_sigma: self.min_sigma = min(self.min_sigma, init_std_val * 0.1) aux_cfg = module_config_dict.get("aux_state_pred", {}) self.aux_state_pred_enabled = bool(aux_cfg.get("enabled", False)) aux_cmd_cfg = module_config_dict.get("aux_router_command_recon", {}) self.aux_router_command_recon_enabled = bool( aux_cmd_cfg.get("enabled", False) ) aux_switch_cfg = module_config_dict.get( "aux_router_switch_penalty", {} ) self.aux_router_switch_penalty_enabled = bool( aux_switch_cfg.get("enabled", False) ) aux_router_future_cfg = module_config_dict.get( "aux_router_future_recon", {} ) self.aux_router_future_recon_enabled = bool( aux_router_future_cfg.get("enabled", False) ) self.aux_router_future_recon_assembler: TensorDictAssembler | None = ( None ) def _sigma_from_params(self) -> torch.Tensor: # Prefer log-std if present; otherwise use softplus(linear) for positivity if hasattr(self, "log_std"): return torch.exp(self.log_std) return F.softplus(self.std) def reset_kv_cache(self, num_envs: int, device): if hasattr(self.actor_module, "reset_kv_cache"): self.actor_module.reset_kv_cache(num_envs, device) def clear_env_cache(self, env_ids: torch.Tensor): if hasattr(self.actor_module, 
"clear_env_cache"): self.actor_module.clear_env_cache(env_ids) def onnx_past_key_values_shape( self, *, batch_size: int = 1 ) -> tuple[int, int, int, int, int, int]: num_kv_layers = int( getattr( self.actor_module, "onnx_kv_layers", self.actor_module.n_layers ) ) return ( num_kv_layers, 2, int(batch_size), int(self.actor_module.max_ctx_len), int(self.actor_module.n_kv_heads), int(self.actor_module.head_dim), ) def onnx_moe_layer_indices(self) -> list[int]: layers = getattr(self.actor_module, "layers", None) if layers is None: return [] return [ layer_idx for layer_idx, layer in enumerate(layers) if isinstance(layer, NM.GroupedMoEBlock) ] def onnx_routing_output_names(self) -> list[str]: output_names: list[str] = [] for layer_idx in self.onnx_moe_layer_indices(): output_names.extend( [ f"moe_layer_{layer_idx}_expert_indices", f"moe_layer_{layer_idx}_expert_logits", ] ) return output_names def _maybe_update_aux_router_future_recon_norm( self, obs_td: TensorDict, *, update: bool, ) -> None: if ( not update or not self.aux_router_future_recon_enabled or self.aux_router_future_recon_assembler is None ): return future_target = self.aux_router_future_recon_assembler(obs_td) self.actor_module.update_aux_router_future_recon_normalizer( future_target ) def export_onnx( self, onnx_path: str | Path, *, opset_version: int = 17, use_kv_cache: bool = True, ) -> str: export_path = Path(onnx_path) export_path.parent.mkdir(parents=True, exist_ok=True) if hasattr(self.actor_module, "clear_router_distribution_cache"): self.actor_module.clear_router_distribution_cache() actor_module = _clone_module_for_cpu_export(self.actor_module) if self.obs_norm_enabled: obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer) else: obs_normalizer = nn.Identity() obs = torch.zeros( 1, self.flat_obs_dim, device="cpu", dtype=torch.float32 ) if use_kv_cache: exporter = PPOTFActorOnnxModule( actor_module=actor_module, obs_normalizer=obs_normalizer, obs_norm_enabled=self.obs_norm_enabled, 
obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0, ).to("cpu") exporter.eval() cache_shape = self.onnx_past_key_values_shape(batch_size=1) past_key_values = torch.zeros( *cache_shape, device="cpu", dtype=torch.float32 ) step_idx = torch.tensor([0], dtype=torch.long, device="cpu") output_names = [ "actions", "present_key_values", *self.onnx_routing_output_names(), ] torch.onnx.export( exporter, (obs, past_key_values, step_idx), str(export_path), export_params=True, opset_version=opset_version, verbose=False, dynamo=False, input_names=["obs", "past_key_values", "step_idx"], output_names=output_names, ) else: exporter = PPOTFWoKVCacheActorOnnxModule( actor_module=actor_module, obs_normalizer=obs_normalizer, obs_norm_enabled=self.obs_norm_enabled, obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0, ).to("cpu") exporter.eval() obs = torch.zeros( 1, 32, self.flat_obs_dim, device="cpu", dtype=torch.float32 ) torch.onnx.export( exporter, (obs,), str(export_path), export_params=True, opset_version=opset_version, verbose=False, dynamo=False, input_names=["obs"], output_names=["actions"], ) return str(export_path) def update_distribution(self, actor_obs): """Distribution using TransformerDecoderPolicy single-step mu + learnable log-std. Args: actor_obs: [B, D] normalized obs """ mu = self.actor_module.single_step_mu(actor_obs) std = self._sigma_from_params() std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma) self.distribution = Normal(mu, std) def forward( self, obs_td: TensorDict | torch.Tensor, actions: torch.Tensor | None = None, mode: str = "sampling", attn_mask: torch.Tensor | None = None, *, update_obs_norm: bool = True, past_key_values: torch.Tensor | None = None, current_pos: torch.Tensor | None = None, ) -> TensorDict | tuple[torch.Tensor, torch.Tensor]: """TensorDict-first forward for PPOTFActor. 

        Modes:
            - "sampling" / "logp" / "inference": single-step policy with
              KV-cache-aware mean prediction via `actor_module.single_step_mu`
              (falls back to calling `actor_module` directly when it has no
              `single_step_mu`).
            - "sequence_logp": sequence log-prob evaluation with attention
              mask support.

        When ``past_key_values`` is given, ``mode`` is ignored. The
        observations (assembled when a TensorDict, used as-is otherwise) are
        passed to ``actor_module`` together with the cache and
        ``current_pos``, and the module's raw output is returned.

        Args:
            obs_td: TensorDict observations. "sequence_logp" requires
                ``batch_dims == 2`` ([B, T]).
            actions: actions to score. Required for "logp" and
                "sequence_logp".
            mode: one of the modes listed above.
            attn_mask: forwarded to ``sequence_forward_logp``.
            update_obs_norm: if False, running observation statistics are
                not updated.
            past_key_values: KV cache for the cache path described above.
            current_pos: forwarded with ``past_key_values``.

        Returns:
            A shallow clone of ``obs_td`` holding "mu", "sigma" and "actions".
            Outside "inference" mode it also holds "actions_log_prob" and
            "entropy", and in "sequence_logp" mode any ``aux_*``/router
            entries as well. On the cache path, whatever ``actor_module``
            returns.
        """
        # Cache path: bypass the TensorDict/distribution logic entirely.
        if past_key_values is not None:
            if isinstance(obs_td, TensorDict):
                if self.assembler is None:
                    raise ValueError(
                        "PPOTFActor requires obs_schema/assembler for ONNX cache path."
                    )
                actor_obs = self.assembler(obs_td)
            else:
                actor_obs = obs_td
            return self.actor_module(
                actor_obs,
                past_key_values=past_key_values,
                current_pos=current_pos,
            )

        if mode == "sequence_logp":
            if not isinstance(obs_td, TensorDict):
                raise ValueError(
                    "PPOTFActor.forward(mode='sequence_logp') expects TensorDict input."
                )
            if obs_td.batch_dims != 2:
                raise ValueError(
                    "PPOTFActor.forward(mode='sequence_logp') expects TensorDict with "
                    f"batch_dims=2 [B, T], got batch_size={tuple(obs_td.batch_size)}"
                )
            if self.assembler is None:
                raise ValueError(
                    "PPOTFActor requires obs_schema to assemble sequence observations."
                )
            if actions is None:
                raise ValueError(
                    "actions must be provided when mode='sequence_logp'"
                )
            # Assemble/normalize on the flattened [B*T] batch, then restore
            # the [B, T, D] layout for the sequence model.
            b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])
            flat_td = obs_td.flatten(0, 1)
            actor_obs_flat = self.assembler(flat_td)
            update = bool(update_obs_norm)
            for fn in self.actor_obs_transforms:
                actor_obs_flat = fn(actor_obs_flat, update)
            self._maybe_update_aux_router_future_recon_norm(
                flat_td, update=update
            )
            actor_obs_seq = actor_obs_flat.reshape(b, t, -1)
            # NOTE(review): reshape(b, t, -1) always yields ndim == 3, so this
            # check cannot fire.
            if actor_obs_seq.ndim != 3:
                raise ValueError(
                    "PPOTFActor forward(mode='sequence_logp') expects actor_obs "
                    f"with shape [B, T, D], got {actor_obs_seq.shape}"
                )
            mu, sigma, logp, entropy, aux_preds = self.sequence_forward_logp(
                actor_obs_seq, actions, attn_mask
            )
            td = obs_td.clone(recurse=False)
            td.set("mu", mu)
            td.set("sigma", sigma)
            td.set("actions", actions)
            td.set("actions_log_prob", logp)
            td.set("entropy", entropy)
            # Copy auxiliary predictions into the TensorDict; pre-MoE state
            # heads are keyed off the presence of "base_lin_vel_loc".
            if aux_preds is not None:
                if "base_lin_vel_loc" in aux_preds:
                    td.set(
                        "aux_base_lin_vel_loc", aux_preds["base_lin_vel_loc"]
                    )
                    td.set(
                        "aux_base_lin_vel_log_std",
                        aux_preds["base_lin_vel_log_std"],
                    )
                    td.set("aux_root_height_loc", aux_preds["root_height_loc"])
                    td.set(
                        "aux_root_height_log_std",
                        aux_preds["root_height_log_std"],
                    )
                    td.set(
                        "aux_keybody_contact_logits",
                        aux_preds["keybody_contact_logits"],
                    )
                    td.set(
                        "aux_ref_keybody_rel_pos",
                        aux_preds["ref_keybody_rel_pos"],
                    )
                    td.set(
                        "aux_robot_keybody_rel_pos",
                        aux_preds["robot_keybody_rel_pos"],
                    )
                if "denoise_ref_root_lin_vel_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_root_lin_vel_residual",
                        aux_preds["denoise_ref_root_lin_vel_residual"],
                    )
                if "denoise_ref_root_ang_vel_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_root_ang_vel_residual",
                        aux_preds["denoise_ref_root_ang_vel_residual"],
                    )
                if "denoise_ref_dof_pos_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_dof_pos_residual",
                        aux_preds["denoise_ref_dof_pos_residual"],
                    )
                if "router_command_recon" in aux_preds:
                    td.set(
                        "aux_router_command_recon",
                        aux_preds["router_command_recon"],
                    )
                if "router_future_recon" in aux_preds:
                    td.set(
                        "aux_router_future_recon",
                        aux_preds["router_future_recon"],
                    )
                if "router_features" in aux_preds:
                    td.set("router_features", aux_preds["router_features"])
                if "router_temporal_features" in aux_preds:
                    td.set(
                        "router_temporal_features",
                        aux_preds["router_temporal_features"],
                    )
            return td

        # Single-step modes.
        if mode not in ("sampling", "logp", "inference"):
            raise ValueError(f"Unsupported mode: {mode}")
        if not isinstance(obs_td, TensorDict):
            raise ValueError("PPOTFActor.forward expects TensorDict input.")
        if self.assembler is None:
            raise ValueError(
                "Flat-tensor actor module requires obs_schema in PPOTFActor init."
            )
        td = obs_td.clone(recurse=False)
        actor_obs = self.assembler(obs_td)
        update = bool(update_obs_norm)
        for fn in self.actor_obs_transforms:
            actor_obs = fn(actor_obs, update)
        self._maybe_update_aux_router_future_recon_norm(obs_td, update=update)
        if hasattr(self.actor_module, "single_step_mu"):
            mu = self.actor_module.single_step_mu(actor_obs)
        else:
            mu = self.actor_module(actor_obs)
        sigma = self._sigma_like(mu)
        td.set("mu", mu)
        td.set("sigma", sigma)
        # Inference returns the deterministic mean without a distribution.
        if mode == "inference":
            td.set("actions", mu)
            return td
        self.distribution = Normal(mu, sigma)
        if mode == "sampling":
            actions_out = self.distribution.sample()
        else:
            if actions is None:
                raise ValueError("actions must be provided when mode='logp'")
            actions_out = actions
        td.set("actions", actions_out)
        # Independent per-dimension Normals: sum log-prob/entropy over actions.
        td.set(
            "actions_log_prob",
            self.distribution.log_prob(actions_out).sum(dim=-1),
        )
        td.set("entropy", self.distribution.entropy().sum(dim=-1))
        return td

    def sequence_forward_logp(
        self,
        obs_seq: torch.Tensor,
        actions: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, torch.Tensor] | None,
    ]:
        """Sequence log-prob path with learnable per-action log-std.

        Args:
            obs_seq: [B, T, D]
            actions: [B, T, A]
            attn_mask: [B, T, T] boolean (True if attend allowed)
        Returns:
            mu: [B, T, A], sigma: [B, T, A], logp: [B, T, 1], entropy: [B, T, 1],
            aux_preds: dict of auxiliary predictions, or None when
                aux_state_pred, router command recon, router switch penalty
                and router future recon are all disabled.
        """
        aux_preds = None
        aux_router_future_recon_enabled = bool(
            getattr(self, "aux_router_future_recon_enabled", False)
        )
        need_pre_moe_aux = self.aux_state_pred_enabled
        need_router_features = (
            self.aux_router_command_recon_enabled
            or self.aux_router_switch_penalty_enabled
        )
        need_router_aux = (
            need_router_features or aux_router_future_recon_enabled
        )
        # Only request the explicit reference hidden state when the module
        # advertises support for it.
        need_ref_aux_hidden = bool(
            (need_pre_moe_aux or aux_router_future_recon_enabled)
            and getattr(
                self.actor_module, "supports_explicit_ref_aux_hidden", False
            )
        )
        # Four cases: pre-MoE + router aux, pre-MoE only, router only, none.
        # The pop(0) sequences below assume sequence_mu returns its outputs in
        # the order mu, [pre_moe_hidden], [ref_aux_hidden], [router_features],
        # [router_temporal_features], each present only when requested.
        if need_pre_moe_aux and need_router_aux:
            sequence_mu_kwargs = {
                "attn_mask": attn_mask,
                "return_pre_moe_hidden": True,
                "return_router_features": need_router_features,
                "return_router_temporal_features": self.aux_router_switch_penalty_enabled,
            }
            if need_ref_aux_hidden:
                sequence_mu_kwargs["return_ref_aux_hidden"] = True
            actor_outputs = self.actor_module.sequence_mu(
                obs_seq,
                **sequence_mu_kwargs,
            )
            output_parts = list(actor_outputs)
            mu = output_parts.pop(0)
            pre_moe_hidden = output_parts.pop(0)
            ref_aux_hidden = (
                output_parts.pop(0) if need_ref_aux_hidden else None
            )
            router_features = (
                output_parts.pop(0) if need_router_features else None
            )
            router_temporal_features = (
                output_parts.pop(0)
                if self.aux_router_switch_penalty_enabled
                else None
            )
            aux_preds = self.actor_module.predict_aux_from_pre_moe(
                pre_moe_hidden,
                ref_aux_hidden=ref_aux_hidden if need_ref_aux_hidden else None,
            )
            if router_features is not None:
                aux_preds["router_features"] = router_features
            if router_temporal_features is not None:
                aux_preds["router_temporal_features"] = (
                    router_temporal_features
                )
            if self.aux_router_command_recon_enabled:
                aux_preds["router_command_recon"] = (
                    self.actor_module.predict_aux_router_command_from_router_features(
                        router_features
                    )
                )
            # NOTE(review): ref_aux_hidden is None here when the module does
            # not set supports_explicit_ref_aux_hidden. Confirm the predictor
            # accepts that.
            if aux_router_future_recon_enabled:
                aux_preds["router_future_recon"] = (
                    self.actor_module.predict_aux_router_future_recon_from_router_hidden(
                        ref_aux_hidden
                    )
                )
        elif need_pre_moe_aux:
            sequence_mu_kwargs = {
                "attn_mask": attn_mask,
                "return_pre_moe_hidden": True,
            }
            if need_ref_aux_hidden:
                sequence_mu_kwargs["return_ref_aux_hidden"] = True
            actor_outputs = self.actor_module.sequence_mu(
                obs_seq,
                **sequence_mu_kwargs,
            )
            if need_ref_aux_hidden:
                mu, pre_moe_hidden, ref_aux_hidden = actor_outputs
            else:
                mu, pre_moe_hidden = actor_outputs
            aux_preds = self.actor_module.predict_aux_from_pre_moe(
                pre_moe_hidden,
                ref_aux_hidden=ref_aux_hidden if need_ref_aux_hidden else None,
            )
        elif need_router_aux:
            sequence_mu_kwargs = {
                "attn_mask": attn_mask,
                "return_router_features": need_router_features,
                "return_router_temporal_features": self.aux_router_switch_penalty_enabled,
            }
            if need_ref_aux_hidden:
                sequence_mu_kwargs["return_ref_aux_hidden"] = True
            actor_outputs = self.actor_module.sequence_mu(
                obs_seq,
                **sequence_mu_kwargs,
            )
            output_parts = list(actor_outputs)
            mu = output_parts.pop(0)
            ref_aux_hidden = (
                output_parts.pop(0) if need_ref_aux_hidden else None
            )
            router_features = (
                output_parts.pop(0) if need_router_features else None
            )
            router_temporal_features = (
                output_parts.pop(0)
                if self.aux_router_switch_penalty_enabled
                else None
            )
            aux_preds = {}
            if router_features is not None:
                aux_preds["router_features"] = router_features
            if router_temporal_features is not None:
                aux_preds["router_temporal_features"] = (
                    router_temporal_features
                )
            if self.aux_router_command_recon_enabled:
                aux_preds["router_command_recon"] = (
                    self.actor_module.predict_aux_router_command_from_router_features(
                        router_features
                    )
                )
            if aux_router_future_recon_enabled:
                aux_preds["router_future_recon"] = (
                    self.actor_module.predict_aux_router_future_recon_from_router_hidden(
                        ref_aux_hidden
                    )
                )
        else:
            mu = self.actor_module.sequence_mu(obs_seq, attn_mask=attn_mask)
        # Match sampling-time clamping for stability and consistent KL/log-prob
        sigma_vec = self._sigma_from_params().clamp(
            self.min_sigma, self.max_sigma
        )
        # Broadcast the per-action std over the [B, T] leading dims of mu.
        sigma = sigma_vec[None, None, :].expand_as(mu)
        var = sigma * sigma
        # Closed-form diagonal-Gaussian log-density, summed over action dims.
        # The 1e-8 terms guard the division and the log.
        logp = -0.5 * (
            ((actions - mu) ** 2) / (var + 1.0e-8)
            + 2.0 * torch.log(sigma + 1.0e-8)
            + math.log(2.0 * math.pi)
        ).sum(dim=-1, keepdim=True)
        # Per-dimension Gaussian entropy 0.5 + 0.5*log(2*pi) + log(sigma).
        entropy = (
            0.5 + 0.5 * math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)
        ).sum(dim=-1, keepdim=True)
        return mu, sigma, logp, entropy, aux_preds


class PPOTFRefRouterActor(PPOTFActor):
    """PPOTFActor whose router is fed the ``actor_ref_*`` observation columns."""

    @staticmethod
    def _leaf_obs_name(term: str) -> str:
        """Return the last "/"-separated component of an obs term key."""
        return str(term).rsplit("/", maxsplit=1)[-1]

    @classmethod
    def _infer_flat_term_dim(
        cls,
        *,
        obs_example: TensorDict,
        term: str,
        seq_len: int,
    ) -> int:
        """Total flattened width of one obs term (``seq_len * feature_dim``).

        A 2D term must have ``seq_len == 1``. A 3D term must have
        ``shape[1] == seq_len``.

        Raises:
            KeyError: term missing from ``obs_example``.
            TypeError: term value is not a tensor.
            ValueError: seq_len mismatch, or rank other than 2 or 3.
        """
        tensor = TensorDictAssembler._get_from_data(obs_example, str(term))
        if tensor is None:
            raise KeyError(
                f"Missing obs term '{term}' in obs_example while inferring "
                "reference-router feature indices."
            )
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Obs term '{term}' must be a torch.Tensor, got {type(tensor)}."
            )
        if tensor.ndim == 2:
            if seq_len != 1:
                raise ValueError(
                    f"Obs term '{term}' expected seq_len={seq_len} but tensor "
                    f"is 2D with shape {tuple(tensor.shape)}."
                )
            return int(tensor.shape[-1])
        if tensor.ndim == 3:
            if int(tensor.shape[1]) != seq_len:
                raise ValueError(
                    f"Obs term '{term}' seq_len mismatch: expected {seq_len}, "
                    f"got {int(tensor.shape[1])}."
                )
            return int(tensor.shape[1] * tensor.shape[-1])
        raise ValueError(
            f"Obs term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}."
        )

    @classmethod
    def infer_router_feature_indices(
        cls,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> list[int]:
        """Flat column indices occupied by ``actor_ref_*`` terms.

        Raises:
            ValueError: non-TensorDict ``obs_example`` or no reference term.
        """
        if not isinstance(obs_example, TensorDict):
            raise ValueError(
                "PPOTFRefRouterActor requires TensorDict obs_example."
            )
        # Walk every term, tracking its flat offset. All columns of an
        # ``actor_ref_*`` term become router inputs. This assumes the
        # assembler concatenates terms in this same schema order; verify
        # against TensorDictAssembler.
        router_feature_indices: list[int] = []
        offset = 0
        for _, seq_cfg in obs_schema.items():
            if not isinstance(seq_cfg, dict):
                continue
            seq_len = int(seq_cfg.get("seq_len", 1))
            for term in seq_cfg.get("terms", []):
                term_str = str(term)
                flat_dim = cls._infer_flat_term_dim(
                    obs_example=obs_example,
                    term=term_str,
                    seq_len=seq_len,
                )
                leaf_name = cls._leaf_obs_name(term_str)
                if leaf_name.startswith("actor_ref_"):
                    router_feature_indices.extend(
                        range(offset, offset + flat_dim)
                    )
                offset += flat_dim
        if len(router_feature_indices) == 0:
            raise ValueError(
                "PPOTFRefRouterActor could not infer any actor_ref_* features "
                "from obs_schema."
            )
        return router_feature_indices

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Inject router feature indices into the config, then build PPOTFActor.

        A deep copy of ``module_config_dict`` receives
        ``router_feature_indices`` and ``router_input_dim`` (their count).
        ``router_embed_mlp_hidden`` defaults to ``obs_embed_mlp_hidden``
        (or 1024) when absent.

        Raises:
            ValueError: missing ``obs_schema``/``obs_example``, enabled
                ``use_future_cross_attn`` or ``aux_router_future_recon``, or
                no ``actor_ref_*`` term found.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterActor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterActor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterActor does not support use_future_cross_attn=True."
            )
        # Deep copy so the caller's config is not mutated.
        actor_module_cfg = copy.deepcopy(module_config_dict)
        aux_future_cfg = actor_module_cfg.get("aux_router_future_recon", {})
        if bool(aux_future_cfg.get("enabled", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicy does not support "
                "aux_router_future_recon."
            )
        router_feature_indices = self.infer_router_feature_indices(
            obs_schema, obs_example
        )
        actor_module_cfg["router_input_dim"] = int(len(router_feature_indices))
        actor_module_cfg["router_feature_indices"] = list(
            router_feature_indices
        )
        if "router_embed_mlp_hidden" not in actor_module_cfg:
            actor_module_cfg["router_embed_mlp_hidden"] = int(
                actor_module_cfg.get("obs_embed_mlp_hidden", 1024)
            )
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=actor_module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        self.router_feature_indices = list(router_feature_indices)


class PPOTFRefRouterSeqActor(PPOTFActor):
    """PPOTFActor whose flat observation is split into state, current-ref and
    future-ref parts.

    The split is inferred from ``obs_schema``/``obs_example`` and written into
    a copy of the actor-module config before the actor is built.
    """

    # Leaf obs names that must each appear exactly once, with seq_len=1.
    REQUIRED_CURRENT_REF_TERMS = (
        "actor_ref_gravity_projection_cur",
        "actor_ref_base_linvel_cur",
        "actor_ref_base_angvel_cur",
        "actor_ref_dof_pos_cur",
        "actor_ref_root_height_cur",
    )
    # Leaf obs names that must each appear exactly once, sharing one seq_len.
    # This tuple also fixes the order of the reported ``ref_fut_slices``.
    REQUIRED_FUTURE_REF_TERMS = (
        "actor_ref_gravity_projection_fut",
        "actor_ref_base_linvel_fut",
        "actor_ref_base_angvel_fut",
        "actor_ref_dof_pos_fut",
        "actor_ref_root_height_fut",
    )
    # aux_state_pred weights (``w_*``) that may be positive for this actor.
    SUPPORTED_AUX_WEIGHT_NAMES = {
        "w_base_lin_vel",
        "w_keybody_contact",
        "w_ref_keybody_rel_pos",
        "w_robot_keybody_rel_pos",
    }

    @staticmethod
    def _leaf_obs_name(term: str) -> str:
        """Return the last "/"-separated component of an obs term key."""
        return str(term).rsplit("/", maxsplit=1)[-1]

    @classmethod
    def _infer_flat_term_dim(
        cls,
        *,
        obs_example: TensorDict,
        term: str,
        seq_len: int,
    ) -> int:
        """Per-step feature width (last dim) of one obs term.

        Unlike PPOTFRefRouterActor's version, this does not multiply by
        ``seq_len``; callers compute the flat span themselves.
        """
        tensor = TensorDictAssembler._get_from_data(obs_example, str(term))
        if tensor is None:
            raise KeyError(
                f"Missing obs term '{term}' in obs_example while inferring shared ref partitions."
            )
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Obs term '{term}' must be a torch.Tensor, got {type(tensor)}."
            )
        if tensor.ndim == 2:
            if seq_len != 1:
                raise ValueError(
                    f"Obs term '{term}' expected seq_len={seq_len} but tensor "
                    f"is 2D with shape {tuple(tensor.shape)}."
                )
            return int(tensor.shape[-1])
        if tensor.ndim == 3:
            if int(tensor.shape[1]) != seq_len:
                raise ValueError(
                    f"Obs term '{term}' seq_len mismatch: expected {seq_len}, "
                    f"got {int(tensor.shape[1])}."
                )
            return int(tensor.shape[-1])
        raise ValueError(
            f"Obs term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}."
        )

    @classmethod
    def _validate_v2_aux_config(cls, module_config_dict: dict) -> None:
        """Reject auxiliary-loss settings this actor family does not support.

        Raises:
            ValueError: if ``aux_router_command_recon`` is enabled, or if
                ``aux_state_pred`` is enabled with a positive ``w_*`` weight
                outside ``SUPPORTED_AUX_WEIGHT_NAMES``.
        """
        aux_cmd_cfg = module_config_dict.get("aux_router_command_recon", {})
        if bool(aux_cmd_cfg.get("enabled", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 does not support "
                "aux_router_command_recon."
            )
        aux_cfg = module_config_dict.get("aux_state_pred", {})
        if not bool(aux_cfg.get("enabled", False)):
            return
        for key, value in aux_cfg.items():
            if not str(key).startswith("w_"):
                continue
            # Zero/negative weights are treated as disabled.
            if float(value) <= 0.0:
                continue
            if str(key) not in cls.SUPPORTED_AUX_WEIGHT_NAMES:
                raise ValueError(
                    "ReferenceRoutedGroupedMoETransformerPolicyV2 only supports "
                    "aux_state_pred weights for "
                    "base_lin_vel, keybody_contact, ref_keybody_rel_pos, and "
                    "robot_keybody_rel_pos. Unsupported weight: "
                    f"{key}."
                )

    @classmethod
    def _build_aux_router_future_recon_schema(
        cls, obs_schema: dict
    ) -> dict[str, dict]:
        """Sub-schema keeping only the required future-reference terms.

        Each group keeps its other config entries (e.g. ``seq_len``). Groups
        without a matching term are dropped.

        Raises:
            ValueError: if any of ``REQUIRED_FUTURE_REF_TERMS`` is missing.
        """
        required_terms = set(cls.REQUIRED_FUTURE_REF_TERMS)
        matched_terms: set[str] = set()
        future_schema: dict[str, dict] = {}
        for group_name, seq_cfg in obs_schema.items():
            if not isinstance(seq_cfg, dict):
                continue
            terms = [
                str(term)
                for term in seq_cfg.get("terms", [])
                if cls._leaf_obs_name(str(term)) in required_terms
            ]
            if len(terms) == 0:
                continue
            next_seq_cfg = dict(seq_cfg)
            next_seq_cfg["terms"] = terms
            future_schema[str(group_name)] = next_seq_cfg
            matched_terms.update(cls._leaf_obs_name(term) for term in terms)
        missing_terms = sorted(required_terms.difference(matched_terms))
        if missing_terms:
            raise ValueError(
                "PPOTFRefRouterSeqActor could not infer all future ref terms "
                "for aux_router_future_recon. Missing: " + ", ".join(missing_terms)
            )
        return future_schema

    @classmethod
    def _prepare_aux_router_future_recon(
        cls,
        *,
        actor_module_cfg: dict,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> TensorDictAssembler | None:
        """Set up future-reference reconstruction in ``actor_module_cfg``.

        Always writes a deep copy of the ``aux_router_future_recon`` sub-config
        back into ``actor_module_cfg``. When it is enabled, the copy also
        gets ``output_dim`` (the flat width of the future-ref terms), and the
        assembler that produces those targets is returned. Otherwise the
        method returns None.
        """
        aux_future_cfg = copy.deepcopy(
            actor_module_cfg.get("aux_router_future_recon", {})
        )
        if not bool(aux_future_cfg.get("enabled", False)):
            actor_module_cfg["aux_router_future_recon"] = aux_future_cfg
            return None
        future_schema = cls._build_aux_router_future_recon_schema(obs_schema)
        future_assembler = TensorDictAssembler(
            future_schema, output_mode="flat"
        )
        aux_future_cfg["output_dim"] = int(
            future_assembler.infer_output_dim(obs_example)
        )
        actor_module_cfg["aux_router_future_recon"] = aux_future_cfg
        return future_assembler

    @classmethod
    def _infer_shared_ref_layout(
        cls,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> dict[str, int | list[int] | list[tuple[int, int, int]]]:
        """Partition the flat observation into state / current-ref / future-ref.

        Terms are walked in schema order with a running flat offset. Each
        term spans ``seq_len * feature_dim`` columns.

        Returns:
            Dict with ``full_obs_input_dim``, ``state_obs_input_dim``,
            ``ref_cur_token_dim``, ``ref_fut_token_dim`` (summed future span
            divided by ``ref_fut_seq_len``), ``ref_fut_seq_len``, the
            ``state_feature_indices``/``ref_cur_feature_indices`` column
            lists, and ``ref_fut_slices``: ``(start, end, feature_dim)`` per
            future term, in ``REQUIRED_FUTURE_REF_TERMS`` order.

        Raises:
            ValueError: non-TensorDict example, a current-ref term with
                seq_len != 1, duplicate or missing required terms, future
                terms with differing seq_len, or no state features.
        """
        if not isinstance(obs_example, TensorDict):
            raise ValueError(
                "PPOTFRefRouterSeqActor requires TensorDict obs_example."
            )
        required_cur = set(cls.REQUIRED_CURRENT_REF_TERMS)
        required_fut = set(cls.REQUIRED_FUTURE_REF_TERMS)
        found_cur: dict[str, tuple[int, int]] = {}
        found_fut: dict[str, tuple[int, int, int]] = {}
        state_indices: list[int] = []
        ref_cur_indices: list[int] = []
        offset = 0
        ref_fut_seq_len: int | None = None
        for _, seq_cfg in obs_schema.items():
            if not isinstance(seq_cfg, dict):
                continue
            seq_len = int(seq_cfg.get("seq_len", 1))
            for term in seq_cfg.get("terms", []):
                term_str = str(term)
                leaf_name = cls._leaf_obs_name(term_str)
                flat_term_dim = cls._infer_flat_term_dim(
                    obs_example=obs_example,
                    term=term_str,
                    seq_len=seq_len,
                )
                flat_span = int(seq_len * flat_term_dim)
                term_range = list(range(offset, offset + flat_span))
                if leaf_name in required_cur:
                    if seq_len != 1:
                        raise ValueError(
                            "current ref term "
                            f"'{leaf_name}' must have seq_len=1, got {seq_len}."
                        )
                    if leaf_name in found_cur:
                        raise ValueError(
                            f"duplicate current ref term '{leaf_name}' in obs_schema."
                        )
                    found_cur[leaf_name] = (offset, flat_term_dim)
                    ref_cur_indices.extend(term_range)
                elif leaf_name in required_fut:
                    if leaf_name in found_fut:
                        raise ValueError(
                            f"duplicate future ref term '{leaf_name}' in obs_schema."
                        )
                    if ref_fut_seq_len is None:
                        ref_fut_seq_len = seq_len
                    elif ref_fut_seq_len != seq_len:
                        raise ValueError(
                            "future ref terms must share one seq_len, got "
                            f"{ref_fut_seq_len} and {seq_len}."
                        )
                    found_fut[leaf_name] = (
                        offset,
                        offset + flat_span,
                        flat_term_dim,
                    )
                else:
                    # Everything that is not a required ref term is state.
                    state_indices.extend(term_range)
                offset += flat_span
        missing_cur = sorted(required_cur.difference(found_cur.keys()))
        if missing_cur:
            raise ValueError(
                "missing required current ref term(s): " + ", ".join(missing_cur)
            )
        missing_fut = sorted(required_fut.difference(found_fut.keys()))
        if missing_fut:
            raise ValueError(
                "missing required future ref term(s): " + ", ".join(missing_fut)
            )
        if ref_fut_seq_len is None or ref_fut_seq_len <= 0:
            raise ValueError(
                "missing required future ref terms in obs_schema."
            )
        if len(state_indices) == 0:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires at least "
                "one non-reference actor state feature."
            )
        ordered_fut_slices = [
            found_fut[leaf_name] for leaf_name in cls.REQUIRED_FUTURE_REF_TERMS
        ]
        return {
            "full_obs_input_dim": int(offset),
            "state_obs_input_dim": int(len(state_indices)),
            "ref_cur_token_dim": int(len(ref_cur_indices)),
            "ref_fut_token_dim": int(
                sum(end - start for start, end, _ in ordered_fut_slices)
                // ref_fut_seq_len
            ),
            "ref_fut_seq_len": int(ref_fut_seq_len),
            "state_feature_indices": state_indices,
            "ref_cur_feature_indices": ref_cur_indices,
            "ref_fut_slices": ordered_fut_slices,
        }

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Infer the ref/state layout, inject it into the config, build the actor.

        A deep copy of ``module_config_dict`` receives the layout keys from
        ``_infer_shared_ref_layout``, with ``input_dim_override`` set to the
        state width. The ``router_hist_obs_schema`` and
        ``router_fut_obs_schema`` entries are removed from that copy. The
        layout is also stored on ``self``.

        Raises:
            ValueError: missing ``obs_schema``/``obs_example``, enabled
                ``use_future_cross_attn``, unsupported aux settings, or an
                invalid layout.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterSeqActor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterSeqActor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterSeqActor does not support use_future_cross_attn=True."
            )
        self._validate_v2_aux_config(module_config_dict)
        inferred_layout = self._infer_shared_ref_layout(
            obs_schema, obs_example
        )
        actor_module_cfg = copy.deepcopy(module_config_dict)
        actor_module_cfg["input_dim_override"] = int(
            inferred_layout["state_obs_input_dim"]
        )
        actor_module_cfg["state_obs_input_dim"] = int(
            inferred_layout["state_obs_input_dim"]
        )
        actor_module_cfg["ref_cur_token_dim"] = int(
            inferred_layout["ref_cur_token_dim"]
        )
        actor_module_cfg["ref_fut_token_dim"] = int(
            inferred_layout["ref_fut_token_dim"]
        )
        actor_module_cfg["ref_fut_seq_len"] = int(
            inferred_layout["ref_fut_seq_len"]
        )
        actor_module_cfg["state_feature_indices"] = list(
            inferred_layout["state_feature_indices"]
        )
        actor_module_cfg["ref_cur_feature_indices"] = list(
            inferred_layout["ref_cur_feature_indices"]
        )
        # Stored as lists in the config; kept as tuples on self below.
        actor_module_cfg["ref_fut_slices"] = [
            list(item) for item in inferred_layout["ref_fut_slices"]
        ]
        actor_module_cfg.pop("router_hist_obs_schema", None)
        actor_module_cfg.pop("router_fut_obs_schema", None)
        # NOTE(review): unlike the V3 subclass, this constructor does not call
        # _prepare_aux_router_future_recon. Confirm aux_router_future_recon is
        # meant to be unsupported here.
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=actor_module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        self.full_obs_input_dim = int(inferred_layout["full_obs_input_dim"])
        self.state_obs_input_dim = int(inferred_layout["state_obs_input_dim"])
        self.ref_cur_token_dim = int(inferred_layout["ref_cur_token_dim"])
        self.ref_fut_token_dim = int(inferred_layout["ref_fut_token_dim"])
        self.ref_fut_seq_len = int(inferred_layout["ref_fut_seq_len"])
        self.state_feature_indices = list(
            inferred_layout["state_feature_indices"]
        )
        self.ref_cur_feature_indices = list(
            inferred_layout["ref_cur_feature_indices"]
        )
        self.ref_fut_slices = [
            tuple(int(v) for v in item)
            for item in inferred_layout["ref_fut_slices"]
        ]
class PPOTFRefRouterV3Actor(PPOTFRefRouterSeqActor):
    """Reference-routed transformer actor (V3) with aux router future recon.

    Reuses ``_validate_v2_aux_config`` and ``_infer_shared_ref_layout`` and
    copies the inferred reference layout into a deep copy of the module
    config. Compared with ``PPOTFRefRouterSeqActor.__init__``:

    - ``input_dim_override`` is NOT written into the module config.
    - ``_prepare_aux_router_future_recon`` is called on the module config
      before the base actor is built; its result is stored as
      ``self.aux_router_future_recon_assembler``.
    - ``PPOTFActor.__init__`` is invoked directly, bypassing
      ``PPOTFRefRouterSeqActor.__init__``.
    """

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Validate inputs, infer the shared reference layout, build the actor.

        Args:
            obs_schema: Observation schema; must not be None.
            module_config_dict: Actor module config. This method writes its
                layout fields into a deep copy, not into the caller's dict.
            num_actions: Action dimension forwarded to ``PPOTFActor``.
            init_noise_std: Initial action noise std forwarded to
                ``PPOTFActor``.
            obs_example: Example observation used for layout inference;
                must not be None.

        Raises:
            ValueError: if ``obs_schema`` or ``obs_example`` is None, or if
                ``use_future_cross_attn`` is enabled in the module config.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterV3Actor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterV3Actor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterV3Actor does not support use_future_cross_attn=True."
            )
        self._validate_v2_aux_config(module_config_dict)
        inferred_layout = self._infer_shared_ref_layout(
            obs_schema, obs_example
        )
        # Work on a copy so the layout fields below do not leak into the
        # caller's config dict.
        actor_module_cfg = copy.deepcopy(module_config_dict)
        # NOTE(review): unlike PPOTFRefRouterSeqActor.__init__, no
        # "input_dim_override" is set here -- confirm this is intentional.
        actor_module_cfg["state_obs_input_dim"] = int(
            inferred_layout["state_obs_input_dim"]
        )
        actor_module_cfg["ref_cur_token_dim"] = int(
            inferred_layout["ref_cur_token_dim"]
        )
        actor_module_cfg["ref_fut_token_dim"] = int(
            inferred_layout["ref_fut_token_dim"]
        )
        actor_module_cfg["ref_fut_seq_len"] = int(
            inferred_layout["ref_fut_seq_len"]
        )
        actor_module_cfg["state_feature_indices"] = list(
            inferred_layout["state_feature_indices"]
        )
        actor_module_cfg["ref_cur_feature_indices"] = list(
            inferred_layout["ref_cur_feature_indices"]
        )
        # Tuples from the layout become lists in the module config.
        actor_module_cfg["ref_fut_slices"] = [
            list(item) for item in inferred_layout["ref_fut_slices"]
        ]
        actor_module_cfg.pop("router_hist_obs_schema", None)
        actor_module_cfg.pop("router_fut_obs_schema", None)
        # Runs before the base actor is built and receives actor_module_cfg;
        # whether it edits that dict is not visible from this file.
        future_recon_assembler = self._prepare_aux_router_future_recon(
            actor_module_cfg=actor_module_cfg,
            obs_schema=obs_schema,
            obs_example=obs_example,
        )
        # Explicit PPOTFActor.__init__ call: PPOTFRefRouterSeqActor.__init__
        # is skipped, so its layout inference is not repeated here.
        PPOTFActor.__init__(
            self,
            obs_schema=obs_schema,
            module_config_dict=actor_module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        # Mirror the inferred layout on the instance (same attributes that
        # PPOTFRefRouterSeqActor.__init__ sets).
        self.full_obs_input_dim = int(inferred_layout["full_obs_input_dim"])
        self.state_obs_input_dim = int(inferred_layout["state_obs_input_dim"])
        self.ref_cur_token_dim = int(inferred_layout["ref_cur_token_dim"])
        self.ref_fut_token_dim = int(inferred_layout["ref_fut_token_dim"])
        self.ref_fut_seq_len = int(inferred_layout["ref_fut_seq_len"])
        self.state_feature_indices = list(
            inferred_layout["state_feature_indices"]
        )
        self.ref_cur_feature_indices = list(
            inferred_layout["ref_cur_feature_indices"]
        )
        self.ref_fut_slices = [
            tuple(int(v) for v in item)
            for item in inferred_layout["ref_fut_slices"]
        ]
        self.aux_router_future_recon_assembler = future_recon_assembler


class PPOCondTFActor(PPOTFActor):
    """Transformer actor with flat state obs and seq future-token conditioning.

    Observations are split into two parts:

    - a flat state vector assembled from the ``flattened_obs`` schema group
      (optionally normalized by its own ``EmpiricalNormalization``), and
    - a future-token sequence assembled from the ``flattened_obs_fut`` group,
      shaped ``[B, future_seq_len, future_token_dim]``.

    The flat ONNX input is ``[state | future term 0 | future term 1 | ...]``,
    and each future term occupies ``future_seq_len * d_term`` columns.
    """

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Build the base actor, then the state/future assemblers.

        Raises:
            ValueError: if ``obs_schema`` is None or lacks ``flattened_obs`` /
                ``flattened_obs_fut``, if ``obs_example`` is None, or if the
                base flat obs dim differs from
                ``state_dim + future_seq_len * future_token_dim``.
        """
        # NOTE(review): the base initializer runs before the checks below; if
        # it cannot handle obs_schema/obs_example being None, it will fail
        # first with its own error.
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=module_config_dict,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        if obs_schema is None:
            raise ValueError("PPOCondTFActor requires non-empty obs_schema.")
        if "flattened_obs" not in obs_schema:
            raise ValueError("obs_schema must contain 'flattened_obs'.")
        if "flattened_obs_fut" not in obs_schema:
            raise ValueError("obs_schema must contain 'flattened_obs_fut'.")
        if obs_example is None:
            raise ValueError("PPOCondTFActor requires obs_example.")
        self.state_schema = {"flattened_obs": obs_schema["flattened_obs"]}
        self.future_schema = {
            "flattened_obs_fut": obs_schema["flattened_obs_fut"]
        }
        # State is assembled as one flat vector; future terms as a sequence.
        self.state_assembler = TensorDictAssembler(
            self.state_schema, output_mode="flat"
        )
        self.future_assembler = TensorDictAssembler(
            self.future_schema, output_mode="seq"
        )
        self.state_dim = int(
            self.state_assembler.infer_output_dim(obs_example)
        )
        self.future_token_dim = int(
            self.future_assembler.infer_output_dim(obs_example)
        )
        self.future_seq_len = int(self.future_assembler.seq_len)
        self.future_term_dims = self._infer_future_term_dims(obs_example)
        self.full_obs_dim = int(self.flat_obs_dim)
        # The base flat layout must be exactly state + flattened future seq,
        # otherwise _split_flat_obs would slice the wrong columns.
        expected_full = self.state_dim + (
            self.future_seq_len * self.future_token_dim
        )
        if self.full_obs_dim != expected_full:
            raise ValueError(
                "Assembled obs dim mismatch in PPOCondTFActor: "
                f"full={self.full_obs_dim}, expected={expected_full}"
            )
        # Only the state part gets an obs normalizer; future tokens are
        # passed through as assembled.
        if self.obs_norm_enabled:
            self.state_obs_normalizer = EmpiricalNormalization(
                shape=self.state_dim,
                eps=self.obs_norm_eps,
                update_method=self.obs_norm_update_method,
                ema_momentum=self.obs_norm_ema_momentum,
            )
        else:
            self.state_obs_normalizer = nn.Identity()

    def _infer_future_term_dims(self, obs_example: TensorDict) -> list[int]:
        """Return the last-axis width of every future term, in schema order.

        Each term listed under ``flattened_obs_fut.terms`` is looked up in
        ``obs_example``; 2-D and 3-D tensors are accepted and their last
        dimension is recorded. The widths must sum to ``future_token_dim``.

        Raises:
            ValueError: non-TensorDict example, missing/empty term list,
                unsupported tensor rank, or a width-sum mismatch.
            KeyError: a listed term is absent from ``obs_example``.
            TypeError: a listed term is not a ``torch.Tensor``.
        """
        if not isinstance(obs_example, TensorDict):
            raise ValueError("PPOCondTFActor requires TensorDict obs_example.")
        fut_cfg = self.future_schema.get("flattened_obs_fut", None)
        if fut_cfg is None:
            raise ValueError(
                "Missing future schema group 'flattened_obs_fut'."
            )
        terms = fut_cfg.get("terms", [])
        if not isinstance(terms, list) or len(terms) == 0:
            raise ValueError("Future schema terms must be a non-empty list.")
        dims: list[int] = []
        for term in terms:
            tensor = TensorDictAssembler._get_from_data(obs_example, str(term))
            if tensor is None:
                raise KeyError(
                    f"Missing future term '{term}' in obs_example TensorDict."
                )
            if not isinstance(tensor, torch.Tensor):
                raise TypeError(
                    f"Future term '{term}' must be a torch.Tensor, got {type(tensor)}"
                )
            # Both ranks contribute their last-axis width.
            if tensor.ndim == 2:
                dims.append(int(tensor.shape[-1]))
            elif tensor.ndim == 3:
                dims.append(int(tensor.shape[-1]))
            else:
                raise ValueError(
                    f"Future term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}"
                )
        if sum(dims) != int(self.future_token_dim):
            raise ValueError(
                "Inferred future_term_dims sum mismatch: expected "
                f"{int(self.future_token_dim)}, got {sum(dims)} (dims={dims})"
            )
        return dims

    @property
    def flat_obs_dim(self) -> int:
        """Output width of the base (flat) assembler; the ONNX ``obs`` width.

        Raises:
            ValueError: if the base assembler or its ``output_dim`` is unset.
        """
        if self.assembler is None:
            raise ValueError(
                "PPOCondTFActor requires the base flat assembler for ONNX."
            )
        if self.assembler.output_dim is None:
            raise ValueError("Base assembler output_dim is not initialized.")
        return int(self.assembler.output_dim)

    def _normalize_state_obs(
        self, state_obs: torch.Tensor, update: bool
    ) -> torch.Tensor:
        """Normalize a ``[B, D_state]`` batch, optionally updating stats first.

        Returns the input unchanged when obs normalization is disabled. When
        ``obs_norm_clip > 0`` the normalized values are clamped to
        ``[-obs_norm_clip, obs_norm_clip]``.
        """
        if not self.obs_norm_enabled:
            return state_obs
        if state_obs.ndim != 2:
            raise ValueError(
                f"state_obs must be [B, D_state], got {tuple(state_obs.shape)}"
            )
        if update:
            self.state_obs_normalizer.update(state_obs)
        # normalize_only: apply current stats without a second update.
        state_obs = self.state_obs_normalizer.normalize_only(state_obs)
        if self.obs_norm_clip > 0.0:
            state_obs = torch.clamp(
                state_obs, -self.obs_norm_clip, self.obs_norm_clip
            )
        return state_obs

    def _assemble_state_future(
        self, obs_td: TensorDict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Assemble ``(state_obs, future_obs)`` from a TensorDict observation."""
        if not isinstance(obs_td, TensorDict):
            raise ValueError(
                "PPOCondTFActor._assemble_state_future expects TensorDict input."
            )
        state_obs = self.state_assembler(obs_td)
        future_obs = self.future_assembler(obs_td)
        return state_obs, future_obs

    def _split_flat_obs(
        self, obs: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a flat ``[B, D]`` observation into state and future tokens.

        The first ``state_dim`` columns are the state. The remaining columns
        hold the future terms back to back; term ``i`` spans
        ``future_seq_len * future_term_dims[i]`` columns and is reshaped to
        ``[B, future_seq_len, d_i]``. The terms are concatenated on the last
        axis, giving ``[B, future_seq_len, future_token_dim]``.
        """
        if obs.ndim != 2:
            raise ValueError(f"Expected [B, D], got {obs.shape}")
        state_obs = obs[:, : self.state_dim]
        future_flat = obs[:, self.state_dim :]
        expected_dim = self.future_seq_len * self.future_token_dim
        if future_flat.shape[-1] != expected_dim:
            raise ValueError(
                "Future flat obs dim mismatch: expected "
                f"{expected_dim}, got {future_flat.shape[-1]}"
            )
        b = int(obs.shape[0])
        offset = 0
        future_parts = []
        for d_term in self.future_term_dims:
            span = int(self.future_seq_len * d_term)
            chunk = future_flat[:, offset : offset + span]
            future_parts.append(chunk.reshape(b, self.future_seq_len, d_term))
            offset += span
        if offset != int(future_flat.shape[-1]):
            raise ValueError(
                "Future flat slicing mismatch: "
                f"consumed={offset}, total={int(future_flat.shape[-1])}"
            )
        future_obs = torch.cat(future_parts, dim=-1)
        return state_obs, future_obs

    def export_onnx(
        self,
        onnx_path: str | Path,
        *,
        opset_version: int = 17,
    ) -> str:
        """Export a CPU copy of the KV-cached single-step actor to ONNX.

        ONNX inputs: ``obs`` ``[1, flat_obs_dim]``, ``past_key_values`` (shape
        from ``onnx_past_key_values_shape(batch_size=1)``), and ``step_idx``.
        Outputs: ``actions``, ``present_key_values``, then the names from
        ``onnx_routing_output_names()``. Parent directories are created.

        Returns:
            The export path as a string.
        """
        export_path = Path(onnx_path)
        export_path.parent.mkdir(parents=True, exist_ok=True)
        if hasattr(self.actor_module, "clear_router_distribution_cache"):
            self.actor_module.clear_router_distribution_cache()
        # Export from CPU clones so the live (training) modules are untouched.
        actor_module = _clone_module_for_cpu_export(self.actor_module)
        if self.obs_norm_enabled:
            state_obs_normalizer = _clone_module_for_cpu_export(
                self.state_obs_normalizer
            )
        else:
            state_obs_normalizer = nn.Identity()
        exporter = PPOCondTFActorOnnxModule(
            actor_module=actor_module,
            state_obs_normalizer=state_obs_normalizer,
            obs_norm_enabled=self.obs_norm_enabled,
            obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0,
            state_dim=self.state_dim,
            future_seq_len=self.future_seq_len,
            future_token_dim=self.future_token_dim,
            future_term_dims=self.future_term_dims,
        ).to("cpu")
        exporter.eval()
        # Dummy tracing inputs (batch size 1).
        cache_shape = self.onnx_past_key_values_shape(batch_size=1)
        obs = torch.zeros(
            1, self.flat_obs_dim, device="cpu", dtype=torch.float32
        )
        past_key_values = torch.zeros(
            *cache_shape, device="cpu", dtype=torch.float32
        )
        step_idx = torch.tensor([0], dtype=torch.long, device="cpu")
        output_names = [
            "actions",
            "present_key_values",
            *self.onnx_routing_output_names(),
        ]
        torch.onnx.export(
            exporter,
            (obs, past_key_values, step_idx),
            str(export_path),
            export_params=True,
            opset_version=opset_version,
            verbose=False,
            dynamo=False,
            input_names=["obs", "past_key_values", "step_idx"],
            output_names=output_names,
        )
        return str(export_path)

    def update_distribution(self, actor_obs):
        """Set ``self.distribution`` from a ``(state_obs, future_obs)`` tuple.

        No future mask is applied; sigma is clamped to
        ``[min_sigma, max_sigma]``.
        """
        if not isinstance(actor_obs, tuple) or len(actor_obs) != 2:
            raise ValueError(
                "PPOCondTFActor.update_distribution expects tuple(state_obs, future_obs)."
            )
        state_obs, future_obs = actor_obs
        mu = self.actor_module.single_step_mu_cond(
            state_obs,
            future_obs,
            future_mask=None,
        )
        std = self._sigma_from_params()
        std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma)
        self.distribution = Normal(mu, std)

    def forward(
        self,
        obs_td: TensorDict | torch.Tensor,
        actions: torch.Tensor | None = None,
        mode: str = "sampling",
        attn_mask: torch.Tensor | None = None,
        *,
        update_obs_norm: bool = True,
        past_key_values: torch.Tensor | None = None,
        current_pos: torch.Tensor | None = None,
    ) -> TensorDict | tuple[torch.Tensor, torch.Tensor]:
        """Run the actor in one of three regimes.

        - ``past_key_values`` given: KV-cached inference. ``obs_td`` may be a
          TensorDict or a flat ``[B, D]`` tensor; state is normalized without
          updating stats, and the result of
          ``actor_module._forward_inference_onnx_cond`` is returned directly.
        - ``mode == "sequence_logp"``: ``obs_td`` is a TensorDict with
          ``batch_dims == 2`` (``[B, T]``) and ``actions`` is required;
          returns a shallow-cloned TensorDict with mu/sigma/actions/
          log-prob/entropy plus any aux predictions.
        - ``mode`` in ``{"sampling", "logp", "inference"}``: single-step over
          a TensorDict; returns a shallow-cloned TensorDict with the results.

        An optional ``future_mask`` entry in the TensorDict is shape-checked
        and cast to bool before use.
        """
        # KV-cache path (used for ONNX-style step-by-step inference).
        if past_key_values is not None:
            if isinstance(obs_td, TensorDict):
                state_obs, future_obs = self._assemble_state_future(obs_td)
            else:
                state_obs, future_obs = self._split_flat_obs(obs_td)
            state_obs = self._normalize_state_obs(state_obs, update=False)
            return self.actor_module._forward_inference_onnx_cond(
                state_obs,
                future_obs,
                past_key_values,
                current_pos,
            )
        if mode == "sequence_logp":
            if not isinstance(obs_td, TensorDict):
                raise ValueError(
                    "PPOCondTFActor.forward(mode='sequence_logp') expects TensorDict input."
                )
            if obs_td.batch_dims != 2:
                raise ValueError(
                    "PPOCondTFActor.forward(mode='sequence_logp') expects batch_dims=2 [B, T], "
                    f"got batch_size={tuple(obs_td.batch_size)}"
                )
            if actions is None:
                raise ValueError(
                    "actions must be provided when mode='sequence_logp'"
                )
            b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])
            future_mask = None
            if "future_mask" in obs_td.keys():
                future_mask = obs_td.get("future_mask")
                if future_mask.shape != (b, t, self.future_seq_len):
                    raise ValueError(
                        "future_mask shape mismatch in sequence_logp: expected "
                        f"{(b, t, self.future_seq_len)}, got {tuple(future_mask.shape)}"
                    )
                future_mask = future_mask.to(torch.bool)
            # Assemble/normalize on the flattened [B*T] batch, then restore
            # the [B, T, ...] layout for the sequence model.
            flat_td = obs_td.flatten(0, 1)
            state_flat, future_flat = self._assemble_state_future(flat_td)
            update = bool(update_obs_norm)
            state_flat = self._normalize_state_obs(state_flat, update=update)
            state_seq = state_flat.reshape(b, t, -1)
            future_seq = future_flat.reshape(
                b, t, self.future_seq_len, self.future_token_dim
            )
            (
                mu,
                sigma,
                logp,
                entropy,
                aux_preds,
            ) = self.sequence_forward_logp_cond(
                state_seq,
                future_seq,
                actions,
                attn_mask,
                future_mask,
            )
            td = obs_td.clone(recurse=False)
            td.set("mu", mu)
            td.set("sigma", sigma)
            td.set("actions", actions)
            td.set("actions_log_prob", logp)
            td.set("entropy", entropy)
            if aux_preds is not None:
                # Pre-MoE state-prediction heads; written only when the
                # base_lin_vel prediction is present.
                if "base_lin_vel_loc" in aux_preds:
                    td.set(
                        "aux_base_lin_vel_loc", aux_preds["base_lin_vel_loc"]
                    )
                    td.set(
                        "aux_base_lin_vel_log_std",
                        aux_preds["base_lin_vel_log_std"],
                    )
                    td.set("aux_root_height_loc", aux_preds["root_height_loc"])
                    td.set(
                        "aux_root_height_log_std",
                        aux_preds["root_height_log_std"],
                    )
                    td.set(
                        "aux_keybody_contact_logits",
                        aux_preds["keybody_contact_logits"],
                    )
                    td.set(
                        "aux_ref_keybody_rel_pos",
                        aux_preds["ref_keybody_rel_pos"],
                    )
                    td.set(
                        "aux_robot_keybody_rel_pos",
                        aux_preds["robot_keybody_rel_pos"],
                    )
                # Optional denoise / router outputs, each copied if present.
                if "denoise_ref_root_lin_vel_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_root_lin_vel_residual",
                        aux_preds["denoise_ref_root_lin_vel_residual"],
                    )
                if "denoise_ref_root_ang_vel_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_root_ang_vel_residual",
                        aux_preds["denoise_ref_root_ang_vel_residual"],
                    )
                if "denoise_ref_dof_pos_residual" in aux_preds:
                    td.set(
                        "aux_denoise_ref_dof_pos_residual",
                        aux_preds["denoise_ref_dof_pos_residual"],
                    )
                if "router_command_recon" in aux_preds:
                    td.set(
                        "aux_router_command_recon",
                        aux_preds["router_command_recon"],
                    )
                if "router_features" in aux_preds:
                    td.set("router_features", aux_preds["router_features"])
                if "router_temporal_features" in aux_preds:
                    td.set(
                        "router_temporal_features",
                        aux_preds["router_temporal_features"],
                    )
            return td
        if mode not in ("sampling", "logp", "inference"):
            raise ValueError(f"Unsupported mode: {mode}")
        if not isinstance(obs_td, TensorDict):
            raise ValueError(
                "PPOCondTFActor.forward expects TensorDict input."
            )
        td = obs_td.clone(recurse=False)
        state_obs, future_obs = self._assemble_state_future(obs_td)
        update = bool(update_obs_norm)
        state_obs = self._normalize_state_obs(state_obs, update=update)
        future_mask = None
        if "future_mask" in td.keys():
            future_mask = td.get("future_mask")
            if future_mask.shape != (state_obs.shape[0], self.future_seq_len):
                raise ValueError(
                    "future_mask shape mismatch in single-step forward: expected "
                    f"{(state_obs.shape[0], self.future_seq_len)}, got {tuple(future_mask.shape)}"
                )
            future_mask = future_mask.to(torch.bool)
        mu = self.actor_module.single_step_mu_cond(
            state_obs, future_obs, future_mask=future_mask
        )
        sigma = self._sigma_like(mu)
        td.set("mu", mu)
        td.set("sigma", sigma)
        # Deterministic inference: the mean is the action; no distribution.
        if mode == "inference":
            td.set("actions", mu)
            return td
        self.distribution = Normal(mu, sigma)
        if mode == "sampling":
            actions_out = self.distribution.sample()
        else:
            if actions is None:
                raise ValueError("actions must be provided when mode='logp'")
            actions_out = actions
        td.set("actions", actions_out)
        td.set(
            "actions_log_prob",
            self.distribution.log_prob(actions_out).sum(dim=-1),
        )
        td.set("entropy", self.distribution.entropy().sum(dim=-1))
        return td

    def sequence_forward_logp_cond(
        self,
        state_seq: torch.Tensor,
        future_seq: torch.Tensor,
        actions: torch.Tensor,
        attn_mask: torch.Tensor | None,
        future_mask: torch.Tensor | None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, torch.Tensor] | None,
    ]:
        """Run the actor over a sequence and score ``actions``.

        The ``sequence_mu_cond`` call requests extra outputs depending on
        which aux heads are enabled: pre-MoE hidden states for
        ``aux_state_pred_enabled``, and router features for router command
        reconstruction or the router switch penalty (router temporal
        features only for the latter).

        Returns:
            ``(mu, sigma, logp, entropy, aux_preds)``. ``sigma`` is the
            clamped parameter vector broadcast to ``mu`` via
            ``[None, None, :]`` (so ``mu`` is expected to be 3-D); ``logp``
            and ``entropy`` are summed over the last axis with
            ``keepdim=True``; ``aux_preds`` is None when no aux head is on.
        """
        aux_preds = None
        need_pre_moe_aux = self.aux_state_pred_enabled
        need_router_aux = (
            self.aux_router_command_recon_enabled
            or self.aux_router_switch_penalty_enabled
        )
        if need_pre_moe_aux and need_router_aux:
            # Both pre-MoE and router aux outputs are needed.
            actor_outputs = self.actor_module.sequence_mu_cond(
                state_seq,
                future_seq,
                attn_mask=attn_mask,
                future_mask=future_mask,
                return_pre_moe_hidden=True,
                return_router_features=True,
                return_router_temporal_features=self.aux_router_switch_penalty_enabled,
            )
            if self.aux_router_switch_penalty_enabled:
                (
                    mu,
                    pre_moe_hidden,
                    router_features,
                    router_temporal_features,
                ) = actor_outputs
            else:
                mu, pre_moe_hidden, router_features = actor_outputs
            aux_preds = self.actor_module.predict_aux_from_pre_moe(
                pre_moe_hidden
            )
            aux_preds["router_features"] = router_features
            if self.aux_router_switch_penalty_enabled:
                aux_preds["router_temporal_features"] = (
                    router_temporal_features
                )
            if self.aux_router_command_recon_enabled:
                aux_preds["router_command_recon"] = (
                    self.actor_module.predict_aux_router_command_from_router_features(
                        router_features
                    )
                )
        elif need_pre_moe_aux:
            # Only pre-MoE state-prediction aux heads.
            mu, pre_moe_hidden = self.actor_module.sequence_mu_cond(
                state_seq,
                future_seq,
                attn_mask=attn_mask,
                future_mask=future_mask,
                return_pre_moe_hidden=True,
            )
            aux_preds = self.actor_module.predict_aux_from_pre_moe(
                pre_moe_hidden
            )
        elif need_router_aux:
            # Only router aux outputs.
            actor_outputs = self.actor_module.sequence_mu_cond(
                state_seq,
                future_seq,
                attn_mask=attn_mask,
                future_mask=future_mask,
                return_router_features=True,
                return_router_temporal_features=self.aux_router_switch_penalty_enabled,
            )
            if self.aux_router_switch_penalty_enabled:
                (
                    mu,
                    router_features,
                    router_temporal_features,
                ) = actor_outputs
            else:
                mu, router_features = actor_outputs
            aux_preds = {"router_features": router_features}
            if self.aux_router_switch_penalty_enabled:
                aux_preds["router_temporal_features"] = (
                    router_temporal_features
                )
            if self.aux_router_command_recon_enabled:
                aux_preds["router_command_recon"] = (
                    self.actor_module.predict_aux_router_command_from_router_features(
                        router_features
                    )
                )
        else:
            mu = self.actor_module.sequence_mu_cond(
                state_seq,
                future_seq,
                attn_mask=attn_mask,
                future_mask=future_mask,
            )
        sigma_vec = self._sigma_from_params().clamp(
            self.min_sigma, self.max_sigma
        )
        sigma = sigma_vec[None, None, :].expand_as(mu)
        # Diagonal-Gaussian log-density and entropy, computed in closed form
        # with 1e-8 guards on the variance and on log(sigma).
        var = sigma * sigma
        logp = -0.5 * (
            ((actions - mu) ** 2) / (var + 1.0e-8)
            + 2.0 * torch.log(sigma + 1.0e-8)
            + math.log(2.0 * math.pi)
        ).sum(dim=-1, keepdim=True)
        entropy = (
            0.5 + 0.5 * math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)
        ).sum(dim=-1, keepdim=True)
        return mu, sigma, logp, entropy, aux_preds


================================================
FILE: holomotion/src/modules/network_modules.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import math from contextlib import nullcontext import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint class EmpiricalNormalization(nn.Module): """Normalize mean and variance of values based on empirical values.""" def __init__( self, shape, eps: float = 1e-2, until: int | None = None, *, update_method: str = "cumulative", ema_momentum: float | None = None, ): """Initialize EmpiricalNormalization module. Args: shape (int or tuple of int): Shape of input values except batch axis. eps (float): Small value for stability. until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes exceeds it. update_method: One of {"cumulative", "ema"}. - "cumulative": count-based updates (legacy behavior). - "ema": EMA updates of mean and second moment. ema_momentum: EMA momentum in (0, 1]. Required when update_method == "ema". """ super().__init__() self.eps = eps self.until = until self.update_method = str(update_method).lower() self.ema_momentum = ( float(ema_momentum) if ema_momentum is not None else None ) if self.update_method in ("count", "cumulative"): self.update_method = "cumulative" elif self.update_method in ("ema", "exp", "exponential"): self.update_method = "ema" else: raise ValueError( f"update_method must be one of {{'cumulative','ema'}}, got {update_method}" ) if self.update_method == "ema": if self.ema_momentum is None: raise ValueError( "ema_momentum must be provided when update_method == 'ema'" ) if not (0.0 < self.ema_momentum <= 1.0): raise ValueError( f"ema_momentum must be in (0, 1], got {self.ema_momentum}" ) self.register_buffer("_mean", torch.zeros(shape)[None, ...]) self.register_buffer("_var", torch.ones(shape)[None, ...]) self.register_buffer("_std", torch.ones(shape)[None, ...]) self.register_buffer("_ex2", torch.ones(shape)[None, ...]) self.register_buffer("count", torch.tensor(0, dtype=torch.long)) 
self.register_buffer("_last_sync_mean", torch.zeros(shape)[None, ...]) self.register_buffer("_last_sync_var", torch.ones(shape)[None, ...]) self.register_buffer( "_last_sync_count", torch.tensor(0, dtype=torch.long) ) @property def mean(self): return self._mean.squeeze(0).clone() @property def std(self): return self._std.squeeze(0).clone() def forward(self, x): """Normalize mean and variance of values based on empirical values. Args: x (ndarray or Variable): Input values Returns: ndarray or Variable: Normalized output values """ if self.training: self.update(x) return (x - self._mean) / (self._std + self.eps) def normalize_only(self, x): return (x - self._mean) / (self._std + self.eps) @torch.compiler.disable @torch.jit.unused def update(self, x): """Learn input values without computing the output values of them.""" if self.until is not None and self.count >= self.until: return count_x = x.shape[0] self.count += count_x if self.update_method == "ema": m = float(self.ema_momentum) mean_x = torch.mean(x, dim=0, keepdim=True) ex2_x = torch.mean(x * x, dim=0, keepdim=True) self._mean.mul_(1.0 - m).add_(mean_x, alpha=m) self._ex2.mul_(1.0 - m).add_(ex2_x, alpha=m) var = torch.clamp(self._ex2 - self._mean * self._mean, min=0.0) self._var.copy_(var) self._std.copy_(torch.sqrt(self._var)) return rate = count_x / self.count var_x = torch.var(x, dim=0, unbiased=False, keepdim=True) mean_x = torch.mean(x, dim=0, keepdim=True) delta_mean = mean_x - self._mean self._mean += rate * delta_mean self._var += rate * ( var_x - self._var + delta_mean * (mean_x - self._mean) ) self._std = torch.sqrt(self._var) @torch.jit.unused def inverse(self, y): return y * (self._std + self.eps) + self._mean def sync_stats_across_processes(self, accelerator): """Synchronize normalization statistics across distributed processes.""" if accelerator.num_processes <= 1: return if self.update_method == "ema": # EMA stats are already running estimates. # Sync by averaging across ranks. 
mean_g = accelerator.reduce( self._mean.to(dtype=torch.float32), reduction="mean" ) ex2_g = accelerator.reduce( self._ex2.to(dtype=torch.float32), reduction="mean" ) var_g = torch.clamp(ex2_g - mean_g * mean_g, min=0.0) self._mean.copy_(mean_g.to(self._mean.dtype)) self._ex2.copy_(ex2_g.to(self._ex2.dtype)) self._var.copy_(var_g.to(self._var.dtype)) self._std.copy_(torch.sqrt(self._var)) return # Weighted synchronization with correction to avoid double counting device = self._mean.device count_local = self.count.to(device=device, dtype=torch.float32) mean_local = self._mean.to(device=device, dtype=torch.float32) var_local = self._var.to(device=device, dtype=torch.float32) # Local weighted sums sum_count = accelerator.reduce(count_local, reduction="sum") sum_mean_count = accelerator.reduce( mean_local * count_local, reduction="sum" ) sum_ex2_count = accelerator.reduce( (var_local + mean_local * mean_local) * count_local, reduction="sum", ) # Correct for replication of previously-synced global stats # across ranks. 
last_c = self._last_sync_count.to(device=device, dtype=torch.float32) if last_c.item() > 0: w_minus_1 = float(accelerator.num_processes - 1) last_mean = self._last_sync_mean.to( device=device, dtype=torch.float32 ) last_var = self._last_sync_var.to( device=device, dtype=torch.float32 ) sum_count = sum_count - w_minus_1 * last_c sum_mean_count = sum_mean_count - w_minus_1 * (last_mean * last_c) sum_ex2_count = sum_ex2_count - w_minus_1 * ( (last_var + last_mean * last_mean) * last_c ) if sum_count.item() <= 0: return global_mean = sum_mean_count / sum_count global_ex2 = sum_ex2_count / sum_count global_var = torch.clamp( global_ex2 - global_mean * global_mean, min=0.0 ) global_std = torch.sqrt(global_var) # Copy back (keep original buffer shapes) self._mean.copy_(global_mean.to(self._mean.dtype)) self._var.copy_(global_var.to(self._var.dtype)) self._std.copy_(global_std.to(self._std.dtype)) # Set global sample count and remember snapshot for next correction self.count.copy_(sum_count.to(self.count.dtype)) self._last_sync_mean.copy_(global_mean.to(self._last_sync_mean.dtype)) self._last_sync_var.copy_(global_var.to(self._last_sync_var.dtype)) self._last_sync_count.copy_(self.count) class MLP(nn.Module): def __init__( self, input_dim: int, output_dim: int, module_config_dict: dict, ): super().__init__() self.module_config_dict = module_config_dict self.input_dim = int(input_dim) self.output_dim = int(output_dim) if self.input_dim <= 0: raise ValueError( f"MLP input_dim must be positive, got {self.input_dim}" ) if self.output_dim <= 0: raise ValueError( f"MLP output_dim must be positive, got {self.output_dim}" ) def _make_norm( norm_type: str, dim: int, *, eps: float, ) -> nn.Module: t = str(norm_type).lower() if t in ("none", "identity", "null"): return nn.Identity() if t in ("layernorm", "ln"): return nn.LayerNorm(dim, eps=eps) if t in ("rmsnorm", "rms"): return RMSNorm(dim, eps=eps) raise ValueError( f"Unknown norm '{t}'. 
Expected one of {'none', 'layernorm', 'rmsnorm'}." ) self.hidden_norm_type = module_config_dict.get("hidden_norm", "none") self.hidden_norm_eps = float( module_config_dict.get("hidden_norm_eps", 1.0e-6) ) layer_config = self.module_config_dict["layer_config"] hidden_dims: list[int] = list(layer_config.get("hidden_dims", [])) activation = getattr(nn, str(layer_config["activation"]))() layers: list[nn.Module] = [] prev = self.input_dim for h in hidden_dims: h_i = int(h) layers.append(nn.Linear(prev, h_i)) layers.append( _make_norm( self.hidden_norm_type, h_i, eps=self.hidden_norm_eps, ) ) layers.append(activation) prev = h_i self.trunk = nn.Sequential(*layers) if layers else nn.Identity() self.output_head = nn.Linear(prev, self.output_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward. Args: x: [..., input_dim] assembled tensor observations. Returns: y: [..., output_dim] """ if not isinstance(x, torch.Tensor): raise TypeError(f"MLP expects torch.Tensor input, got {type(x)}") h = self.trunk(x) return self.output_head(h) class ConvMLP(nn.Module): """Conv1d + pooling history encoder with an MLP head.""" def __init__( self, input_dim: int, output_dim: int, module_config_dict: dict, ): super().__init__() self.module_config_dict = module_config_dict self.input_dim = int(input_dim) self.output_dim = int(output_dim) layer_cfg = dict(module_config_dict.get("layer_config", {})) activation = str(layer_cfg.get("activation", "SiLU")) self.conv_channels = int(module_config_dict.get("conv_channels", 128)) self.conv_layers = int(module_config_dict.get("conv_layers", 2)) self.conv_kernel_size = int( module_config_dict.get("conv_kernel_size", 3) ) self.pool_type = str( module_config_dict.get("pool_type", "avg") ).lower() conv_modules: list[nn.Module] = [] padding = self.conv_kernel_size // 2 in_ch = int(self.input_dim) for _ in range(self.conv_layers): conv_modules.append( nn.Conv1d( in_channels=in_ch, out_channels=self.conv_channels, 
kernel_size=self.conv_kernel_size, padding=padding, bias=True, ) ) conv_modules.append(getattr(nn, activation)()) in_ch = self.conv_channels conv_modules.append(nn.AdaptiveAvgPool1d(1)) self.hist_encoder = nn.Sequential(*conv_modules) fused_dim = int(self.conv_channels + self.input_dim) self.mlp_head = MLP( input_dim=fused_dim, output_dim=int(self.output_dim), module_config_dict=module_config_dict, ) def forward(self, hist_seq: torch.Tensor) -> torch.Tensor: ctx = self.hist_encoder(hist_seq.transpose(1, 2)).squeeze(-1) latest = hist_seq[:, -1, :] fused = torch.cat([ctx, latest], dim=-1) return self.mlp_head(fused) class ReferenceMotionConvRouterEncoder(nn.Module): """Conv1d encoder for reference-motion router sequences.""" def __init__( self, input_dim: int, output_dim: int, *, conv_channels: int = 128, conv_layers: int = 2, conv_kernel_size: int = 3, pool_type: str = "avg", ): super().__init__() self.input_dim = int(input_dim) self.output_dim = int(output_dim) self.conv_channels = int(conv_channels) self.conv_layers = int(conv_layers) self.conv_kernel_size = int(conv_kernel_size) self.pool_type = str(pool_type).lower() if self.input_dim <= 0: raise ValueError( f"input_dim must be positive, got {self.input_dim}" ) if self.output_dim <= 0: raise ValueError( f"output_dim must be positive, got {self.output_dim}" ) if self.conv_channels <= 0: raise ValueError( f"conv_channels must be positive, got {self.conv_channels}" ) if self.conv_layers <= 0: raise ValueError( f"conv_layers must be positive, got {self.conv_layers}" ) if self.conv_kernel_size <= 0: raise ValueError( f"conv_kernel_size must be positive, got {self.conv_kernel_size}" ) if self.pool_type not in {"avg", "max"}: raise ValueError( f"pool_type must be one of {{'avg','max'}}, got {self.pool_type}" ) padding = self.conv_kernel_size // 2 conv_modules: list[nn.Module] = [] in_ch = self.input_dim for _ in range(self.conv_layers): conv_modules.append( nn.Conv1d( in_channels=in_ch, out_channels=self.conv_channels, 
kernel_size=self.conv_kernel_size, padding=padding, bias=True, ) ) conv_modules.append(nn.SiLU()) in_ch = self.conv_channels self.temporal_trunk = nn.Sequential(*conv_modules) if self.pool_type == "avg": self.pool = nn.AdaptiveAvgPool1d(1) else: self.pool = nn.AdaptiveMaxPool1d(1) self.out_proj = nn.Sequential( nn.Linear(self.conv_channels, self.output_dim), nn.SiLU(), nn.Linear(self.output_dim, self.output_dim), ) def forward(self, seq: torch.Tensor) -> torch.Tensor: if seq.ndim != 3: raise ValueError( f"Expected router seq with shape [B, T, D], got {tuple(seq.shape)}." ) if int(seq.shape[-1]) != self.input_dim: raise ValueError( "Router seq dim mismatch: expected " f"{self.input_dim}, got {int(seq.shape[-1])}." ) x = seq.transpose(1, 2) x = self.temporal_trunk(x) x = self.pool(x).squeeze(-1) return self.out_proj(x) class SingleQueryAttentionPool(nn.Module): def __init__(self, d_model: int): super().__init__() self.d_model = int(d_model) self.scale = float(self.d_model) ** -0.5 self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False) self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False) self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False) self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False) def forward( self, query: torch.Tensor, tokens: torch.Tensor, ) -> torch.Tensor: if query.ndim == 2: if tokens.ndim != 3: raise ValueError( "SingleQueryAttentionPool expected [B, N, D] tokens for " f"2D query, got {tuple(tokens.shape)}." ) q = self.q_proj(query).unsqueeze(-2) k = self.k_proj(tokens) v = self.v_proj(tokens) attn = torch.softmax( (q * k).sum(dim=-1, keepdim=True) * self.scale, dim=-2, ) return self.out_proj((attn * v).sum(dim=-2)) if query.ndim == 3: if tokens.ndim != 4: raise ValueError( "SingleQueryAttentionPool expected [B, T, N, D] tokens for " f"3D query, got {tuple(tokens.shape)}." 
) q = self.q_proj(query).unsqueeze(-2) k = self.k_proj(tokens) v = self.v_proj(tokens) attn = torch.softmax( (q * k).sum(dim=-1, keepdim=True) * self.scale, dim=-2, ) return self.out_proj((attn * v).sum(dim=-2)) raise ValueError( f"SingleQueryAttentionPool query must be 2D or 3D, got {query.ndim}." ) class GroupedMoETransformerPolicy(nn.Module): """Hybrid Modern Transformer decoder policy with SOTA improvements. Structure: - Layer 0: Dense MLP (ModernTransformerBlock) - Optional final layer: Dense MLP when dense_layer_at_last=True - Intermediate layers: MoE MLP (GroupedMoEBlock) Features: - RealRoPE. - RMSNorm: Root Mean Square Normalization. - GQA: Grouped Query Attention (configurable n_kv_heads). - QK-Norm: RMSNorm on Queries and Keys. - Gated Attention: Qwen-style element-wise sigmoid gating. - SwiGLU MLP: DeepseekV3MLP for feed-forward. - Flash Attention: via F.scaled_dot_product_attention. - Gradient Checkpointing: optional for memory efficiency. """ def __init__( self, input_dim: int, output_dim: int, module_config_dict: dict, ): super().__init__() self.input_dim = int(input_dim) self.output_dim = int(output_dim) self.module_config_dict = module_config_dict self.num_fine_experts = module_config_dict["num_fine_experts"] self.num_shared_experts = module_config_dict["num_shared_experts"] self.top_k = module_config_dict["top_k"] self.use_dynamic_bias = module_config_dict.get( "use_dynamic_bias", False ) self.bias_update_rate = module_config_dict.get( "bias_update_rate", 0.001 ) self.routing_score_fn = str( module_config_dict.get("routing_score_fn", "softmax") ).lower() self.freeze_router = bool( module_config_dict.get("freeze_router", False) ) self.routing_scale = float( module_config_dict.get("routing_scale", 1.0) ) self.expert_bias_clip = float( module_config_dict.get("expert_bias_clip", 0.0) ) dead_margin_cfg = module_config_dict.get( "dead_expert_margin_to_topk", {} ) selected_margin_cfg = module_config_dict.get( "selected_expert_margin_to_unselected", {} ) 
self.dead_expert_margin_to_topk_enabled = bool( dead_margin_cfg.get("enabled", False) ) self.selected_expert_margin_to_unselected_enabled = bool( selected_margin_cfg.get("enabled", False) ) self.selected_expert_margin_to_unselected_target = float( selected_margin_cfg.get("target", 0.0) ) if self.routing_score_fn not in ("softmax", "sigmoid"): raise ValueError( f"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}" ) if self.routing_scale <= 0.0: raise ValueError( f"routing_scale must be > 0, got {self.routing_scale}" ) if self.expert_bias_clip < 0.0: raise ValueError( f"expert_bias_clip must be >= 0, got {self.expert_bias_clip}" ) if self.selected_expert_margin_to_unselected_target < 0.0: raise ValueError( "selected_expert_margin_to_unselected.target must be >= 0, " f"got {self.selected_expert_margin_to_unselected_target}" ) _ov = module_config_dict.get("input_dim_override", None) self.obs_input_dim = ( int(_ov) if isinstance(_ov, (int, float)) else None ) self.obs_embed_mlp_hidden = int( module_config_dict.get("obs_embed_mlp_hidden", 1024) ) self.d_model = int(module_config_dict.get("d_model", 256)) self.n_layers = int(module_config_dict.get("n_layers", 4)) self.dense_layer_at_last = bool( module_config_dict.get("dense_layer_at_last", False) ) self.n_heads = int(module_config_dict.get("n_heads", 4)) self.n_kv_heads = int( module_config_dict.get("n_kv_heads", self.n_heads // 2) ) self.ff_mult = float(module_config_dict.get("ff_mult", 4)) self.ff_mult_dense = int( module_config_dict.get("ff_mult_dense", self.ff_mult * 3) ) self.attn_dropout = float(module_config_dict.get("attn_dropout", 0.0)) self.mlp_dropout = float(module_config_dict.get("mlp_dropout", 0.0)) self.max_ctx_len = int(module_config_dict.get("max_ctx_len", 64)) self.use_qk_norm = module_config_dict.get("use_qk_norm", True) self.use_gated_attn = module_config_dict.get("use_gated_attn", True) self.gated_attn_type = module_config_dict.get( "gated_attn_type", "headwise" ) 
self.use_checkpointing = module_config_dict.get( "use_checkpointing", False ) self.use_future_cross_attn = bool( module_config_dict.get("use_future_cross_attn", False) ) self.state_obs_dim = int( module_config_dict.get( "state_obs_dim", self.obs_input_dim or self.input_dim ) ) self.future_seq_len = int(module_config_dict.get("future_seq_len", 0)) self.future_token_dim = int( module_config_dict.get("future_token_dim", 0) ) self.head_dim = self.d_model // self.n_heads if self.d_model % self.n_heads != 0: raise ValueError( f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})" ) if self.head_dim % 2 != 0: raise ValueError( f"RoPE requires even head_dim, got head_dim={self.head_dim}" ) # RoPE configuration (used in both sequence and KV-cached single-step inference) self.rope_theta = float(module_config_dict.get("rope_theta", 10000.0)) self.inv_freq = 1.0 / ( self.rope_theta ** ( torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim ) ) # [head_dim//2] self.register_buffer("_rope_inv_freq", self.inv_freq, persistent=False) self._set_cos_sin_cache(seq_len=8192) obs_in = self.obs_input_dim or self.input_dim if self.use_future_cross_attn: if self.future_seq_len <= 0: raise ValueError( "future_seq_len must be positive when use_future_cross_attn=True" ) if self.future_token_dim <= 0: raise ValueError( "future_token_dim must be positive when use_future_cross_attn=True" ) self.state_obs_embed = nn.Sequential( nn.Linear(self.state_obs_dim, self.obs_embed_mlp_hidden), nn.SiLU(), nn.Linear(self.obs_embed_mlp_hidden, self.d_model), ) # Keep a single state embedding module so DDP doesn't see unused # parameters from an extra unused `obs_embed` in conditional mode. 
self.obs_embed = self.state_obs_embed self.future_obs_embed = nn.Sequential( nn.Linear(self.future_token_dim, self.obs_embed_mlp_hidden), nn.SiLU(), nn.Linear(self.obs_embed_mlp_hidden, self.d_model), ) self.future_pos_embed = nn.Embedding( self.future_seq_len, self.d_model ) else: self.obs_embed = nn.Sequential( nn.Linear(obs_in, self.obs_embed_mlp_hidden), nn.SiLU(), nn.Linear(self.obs_embed_mlp_hidden, self.d_model), ) self.state_obs_embed = None self.future_obs_embed = None self.future_pos_embed = None # Stack of TransformerBlocks: first layer is always dense; the last # layer is also dense when dense_layer_at_last=True. self.layers = nn.ModuleList() for i in range(self.n_layers): use_dense_layer = i == 0 or ( self.dense_layer_at_last and i == self.n_layers - 1 ) if use_dense_layer: layer = ModernTransformerBlock( d_model=self.d_model, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, ff_mult=self.ff_mult_dense, use_qk_norm=self.use_qk_norm, use_gated_attn=self.use_gated_attn, gated_attn_type=self.gated_attn_type, attn_dropout=self.attn_dropout, mlp_dropout=self.mlp_dropout, use_cross_attn=self.use_future_cross_attn, ) else: layer = GroupedMoEBlock( d_model=self.d_model, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, ff_mult=self.ff_mult, use_qk_norm=self.use_qk_norm, use_gated_attn=self.use_gated_attn, gated_attn_type=self.gated_attn_type, attn_dropout=self.attn_dropout, mlp_dropout=self.mlp_dropout, num_fine_experts=self.num_fine_experts, num_shared_experts=self.num_shared_experts, top_k=self.top_k, use_dynamic_bias=self.use_dynamic_bias, bias_update_rate=self.bias_update_rate, routing_score_fn=self.routing_score_fn, freeze_router=self.freeze_router, routing_scale=self.routing_scale, expert_bias_clip=self.expert_bias_clip, dead_expert_margin_to_topk_enabled=( self.dead_expert_margin_to_topk_enabled ), selected_expert_margin_to_unselected_enabled=( self.selected_expert_margin_to_unselected_enabled ), selected_expert_margin_to_unselected_target=( 
self.selected_expert_margin_to_unselected_target ), use_cross_attn=self.use_future_cross_attn, ) self.layers.append(layer) self._last_moe_layer_idx = None for layer_idx, layer in enumerate(self.layers): if isinstance(layer, GroupedMoEBlock): self._last_moe_layer_idx = layer_idx self.norm_f = RMSNorm(self.d_model) self.action_mu_head = nn.Sequential( nn.Linear(self.d_model, self.d_model), nn.SiLU(), nn.Linear(self.d_model, self.output_dim), ) aux_cfg = module_config_dict.get("aux_state_pred", {}) self.aux_state_pred_enabled = bool(aux_cfg.get("enabled", False)) self.aux_contact_dim = int( len(aux_cfg.get("keybody_contact_names", [])) ) self.aux_keybody_pos_dim = int( len(aux_cfg.get("keybody_rel_pos_names", [])) ) self.use_aux_denoise_ref_root_lin_vel = bool( float(aux_cfg.get("w_denoise_ref_root_lin_vel", 0.0)) > 0.0 ) self.use_aux_denoise_ref_root_ang_vel = bool( float(aux_cfg.get("w_denoise_ref_root_ang_vel", 0.0)) > 0.0 ) self.use_aux_denoise_ref_dof_pos = bool( float(aux_cfg.get("w_denoise_ref_dof_pos", 0.0)) > 0.0 ) if self.aux_state_pred_enabled: self.aux_vel_head = nn.Linear(self.d_model, 6) self.aux_height_head = nn.Linear(self.d_model, 2) self.aux_denoise_ref_root_lin_vel_head = ( nn.Linear(self.d_model, 3) if self.use_aux_denoise_ref_root_lin_vel else None ) self.aux_denoise_ref_root_ang_vel_head = ( nn.Linear(self.d_model, 3) if self.use_aux_denoise_ref_root_ang_vel else None ) self.aux_contact_head = ( nn.Linear(self.d_model, self.aux_contact_dim) if self.aux_contact_dim > 0 else None ) self.aux_ref_keybody_pos_head = ( nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3) if self.aux_keybody_pos_dim > 0 else None ) self.aux_robot_keybody_pos_head = ( nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3) if self.aux_keybody_pos_dim > 0 else None ) self.aux_denoise_ref_dof_pos_head = ( nn.Linear(self.d_model, self.output_dim) if self.use_aux_denoise_ref_dof_pos else None ) else: self.aux_vel_head = None self.aux_height_head = None 
self.aux_denoise_ref_root_lin_vel_head = None self.aux_denoise_ref_root_ang_vel_head = None self.aux_contact_head = None self.aux_ref_keybody_pos_head = None self.aux_robot_keybody_pos_head = None self.aux_denoise_ref_dof_pos_head = None # True per-layer KV cache for single-step inference. # K/V shapes: [B, n_layers, max_ctx_len, n_kv_heads, head_dim] self._k_cache: torch.Tensor | None = None self._v_cache: torch.Tensor | None = None # Cache state per environment self._kv_cache_len: torch.Tensor | None = None # [B] self._kv_cache_write_idx: torch.Tensor | None = None # [B] self._kv_cache_abs_pos: torch.Tensor | None = None # [B] self._prev_last_moe_router_p: torch.Tensor | None = None self._prev_last_moe_router_valid: torch.Tensor | None = None self._last_moe_router_js_sum: torch.Tensor | None = None self._last_moe_router_js_count: torch.Tensor | None = None self._last_moe_router_top1_switch_sum: torch.Tensor | None = None self._last_moe_router_top1_switch_count: torch.Tensor | None = None aux_cmd_cfg = module_config_dict.get("aux_router_command_recon", {}) self.aux_router_command_recon_enabled = bool( aux_cmd_cfg.get("enabled", False) ) self.aux_router_command_recon_output_dim = int( aux_cmd_cfg.get("output_dim", 0) ) self.aux_router_command_recon_hidden_dim = int( aux_cmd_cfg.get("hidden_dim", self.d_model) ) self._num_moe_layers = sum( 1 for layer in self.layers if isinstance(layer, GroupedMoEBlock) ) if self.aux_router_command_recon_enabled: if self._num_moe_layers <= 0: raise ValueError( "aux_router_command_recon requires at least one GroupedMoEBlock." ) if self.aux_router_command_recon_output_dim <= 0: raise ValueError( "aux_router_command_recon.output_dim must be positive when enabled." 
) router_feature_dim = self._num_moe_layers * self.num_fine_experts self.aux_router_command_recon_head = nn.Sequential( nn.Linear( router_feature_dim, self.aux_router_command_recon_hidden_dim, ), nn.SiLU(), nn.Linear( self.aux_router_command_recon_hidden_dim, self.aux_router_command_recon_output_dim, ), ) else: self.aux_router_command_recon_head = None aux_router_future_cfg = module_config_dict.get( "aux_router_future_recon", {} ) self.aux_router_future_recon_enabled = bool( aux_router_future_cfg.get("enabled", False) ) self.aux_router_future_recon_output_dim = int( aux_router_future_cfg.get("output_dim", 0) ) self.aux_router_future_recon_hidden_dim = int( aux_router_future_cfg.get("hidden_dim", self.d_model) ) aux_router_future_norm_cfg = aux_router_future_cfg.get( "target_norm", {} ) self.aux_router_future_recon_norm_eps = float( aux_router_future_norm_cfg.get("epsilon", 1.0e-2) ) self.aux_router_future_recon_norm_update_method = str( aux_router_future_norm_cfg.get("update_method", "cumulative") ).lower() aux_router_future_norm_ema = aux_router_future_norm_cfg.get( "ema_momentum", None ) self.aux_router_future_recon_norm_ema_momentum = ( float(aux_router_future_norm_ema) if aux_router_future_norm_ema is not None else None ) if self.aux_router_future_recon_enabled: if self.aux_router_future_recon_output_dim <= 0: raise ValueError( "aux_router_future_recon.output_dim must be positive when enabled." 
) self.aux_router_future_recon_head = nn.Sequential( nn.Linear( self.d_model, self.aux_router_future_recon_hidden_dim, ), nn.SiLU(), nn.Linear( self.aux_router_future_recon_hidden_dim, self.aux_router_future_recon_output_dim, ), ) self.aux_router_future_recon_normalizer = EmpiricalNormalization( shape=self.aux_router_future_recon_output_dim, eps=self.aux_router_future_recon_norm_eps, update_method=self.aux_router_future_recon_norm_update_method, ema_momentum=self.aux_router_future_recon_norm_ema_momentum, ) else: self.aux_router_future_recon_head = None self.aux_router_future_recon_normalizer = None self._apply_base_freeze_router_state() def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): if self.use_future_cross_attn: # In conditional mode, `obs_embed` is tied to `state_obs_embed`. # Older checkpoints may contain separate weights for both; ensure we # always load the trained state embedding weights. obs_prefix = prefix + "obs_embed." state_prefix = prefix + "state_obs_embed." for suffix in ("0.weight", "0.bias", "2.weight", "2.bias"): s_key = state_prefix + suffix o_key = obs_prefix + suffix if s_key in state_dict: state_dict[o_key] = state_dict[s_key] legacy_aux_prefix = prefix + "aux_command_recon_head." current_aux_prefix = prefix + "aux_router_command_recon_head." 
# Legacy checkpoints stored the router command head as 'aux_command_recon_head.*': copy those entries to 'aux_router_command_recon_head.*' when that head exists (setdefault keeps keys already present), then drop the legacy keys in every case so they are not reported as unexpected. legacy_aux_keys = [ key for key in list(state_dict.keys()) if key.startswith(legacy_aux_prefix) ] if legacy_aux_keys: if self.aux_router_command_recon_head is not None: for legacy_key in legacy_aux_keys: suffix = legacy_key.removeprefix(legacy_aux_prefix) current_key = current_aux_prefix + suffix state_dict.setdefault(current_key, state_dict[legacy_key]) for legacy_key in legacy_aux_keys: state_dict.pop(legacy_key, None) super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) self._apply_freeze_router_state() def _router_no_grad_context(self): """Return torch.no_grad() when freeze_router is set, otherwise a no-op nullcontext().""" if self.freeze_router: return torch.no_grad() return nullcontext() def _apply_base_freeze_router_state(self) -> None: """Propagate the freeze-router setting to every GroupedMoEBlock via its own _apply_freeze_router_state().""" for layer in self.layers: if isinstance(layer, GroupedMoEBlock): layer._apply_freeze_router_state() def _apply_freeze_router_state(self) -> None: """Apply router freezing to the MoE layers and set requires_grad of aux_router_future_recon_head (when present) to `not freeze_router`. Re-applied after every state-dict load.""" self._apply_base_freeze_router_state() if self.aux_router_future_recon_head is not None: self.aux_router_future_recon_head.requires_grad_( not self.freeze_router ) def _set_cos_sin_cache(self, seq_len): """Precompute RoPE tables for positions 0..seq_len-1. Registers the non-persistent buffers cos_cached / sin_cached, each of shape [seq_len, head_dim], and records seq_len in max_seq_len_cached. NOTE(review): the tables are built on self.inv_freq's device; inv_freq is a plain tensor attribute (only its alias _rope_inv_freq is a registered buffer), so module.to() does not move it. Calling this after moving the module would rebuild the tables on inv_freq's original device. """ self.max_seq_len_cached = seq_len t = torch.arange( seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype ) # outer product: [seq_len, head_dim/2] freqs = torch.outer(t, self.inv_freq) # Concatenate to match rotate_half: [seq_len, head_dim] # Unlike the complex-number formulation, the frequencies are duplicated so the real-valued rotate_half rotation can be applied directly emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim] self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def get_cos_sin(self, x, position_ids): """Gather RoPE cos/sin rows for the given positions. x: [B, T, d_model]; only its dtype is used. position_ids: [B, T] integer positions; each must be < max_seq_len_cached (8192 as built in __init__), since F.embedding rejects out-of-range indices. Returns: cos, sin -> [B, T, head_dim], cast to x.dtype (broadcastable) """ # cos_cached: [MaxLen, head_dim] # F.embedding(pos, cache) -> [B, T, head_dim] cos = F.embedding(position_ids, self.cos_cached) sin = F.embedding(position_ids, self.sin_cached) return cos.to(x.dtype), sin.to(x.dtype) def
_init_last_moe_router_shift_state(self, num_envs: int, device) -> None: """Allocate per-environment state for tracking step-to-step drift of the last MoE layer's router distribution. Creates a zeroed [num_envs, num_fine_experts] previous-distribution buffer, a [num_envs] validity mask (all False) and zeroed scalar JS / top-1-switch sum and count accumulators. When the model has no MoE layer every field is set to None, which disables tracking. """ if self._last_moe_layer_idx is None: self._prev_last_moe_router_p = None self._prev_last_moe_router_valid = None self._last_moe_router_js_sum = None self._last_moe_router_js_count = None self._last_moe_router_top1_switch_sum = None self._last_moe_router_top1_switch_count = None return self._prev_last_moe_router_p = torch.zeros( num_envs, self.num_fine_experts, device=device, dtype=torch.float32, ) self._prev_last_moe_router_valid = torch.zeros( num_envs, device=device, dtype=torch.bool ) self._last_moe_router_js_sum = torch.zeros( (), device=device, dtype=torch.float32 ) self._last_moe_router_js_count = torch.zeros( (), device=device, dtype=torch.float32 ) self._last_moe_router_top1_switch_sum = torch.zeros( (), device=device, dtype=torch.float32 ) self._last_moe_router_top1_switch_count = torch.zeros( (), device=device, dtype=torch.float32 ) def _accumulate_last_moe_router_shift( self, router_distribution: torch.Tensor ) -> None: """Accumulate router-drift statistics for one single-step inference call. For environments whose previous distribution is valid, adds the Jensen-Shannon divergence (natural log) between the previous and current router distributions, and whether the argmax expert changed, to the running sums and counts. Then stores the current distribution and marks every environment valid. Returns silently without changes if the tracking state is uninitialized, if router_distribution is not shaped [B, 1, E], or if B differs from the tracked environment count. """ if ( self._prev_last_moe_router_p is None or self._prev_last_moe_router_valid is None or self._last_moe_router_js_sum is None or self._last_moe_router_js_count is None or self._last_moe_router_top1_switch_sum is None or self._last_moe_router_top1_switch_count is None ): return if ( router_distribution.ndim != 3 or int(router_distribution.shape[1]) != 1 ): return curr_p = router_distribution[:, 0, :].to(torch.float32) if int(curr_p.shape[0]) != int(self._prev_last_moe_router_p.shape[0]): return prev_valid = self._prev_last_moe_router_valid if torch.any(prev_valid): prev_p = self._prev_last_moe_router_p[prev_valid] curr_p_valid = curr_p[prev_valid] mix_p = 0.5 * (curr_p_valid + prev_p) eps = 1.0e-20 curr_safe = curr_p_valid.clamp_min(eps) prev_safe = prev_p.clamp_min(eps) mix_safe = mix_p.clamp_min(eps) kl_curr = ( curr_p_valid * (torch.log(curr_safe) - torch.log(mix_safe)) ).sum(dim=-1) kl_prev = ( prev_p * (torch.log(prev_safe) - torch.log(mix_safe)) ).sum(dim=-1) js
= 0.5 * (kl_curr + kl_prev) self._last_moe_router_js_sum.add_(js.sum()) self._last_moe_router_js_count.add_(float(js.numel())) curr_top1 = curr_p_valid.argmax(dim=-1) prev_top1 = prev_p.argmax(dim=-1) switch = (curr_top1 != prev_top1).to(torch.float32) self._last_moe_router_top1_switch_sum.add_(switch.sum()) self._last_moe_router_top1_switch_count.add_(float(switch.numel())) self._prev_last_moe_router_p.copy_(curr_p) self._prev_last_moe_router_valid.fill_(True) def get_last_moe_router_shift_stats( self, ) -> dict[str, torch.Tensor | None]: """Return the accumulated router-drift sums and counts. The values are the live accumulator tensors (not copies), or None when tracking is disabled; divide each sum by its count to get a mean. """ return { "js_sum": self._last_moe_router_js_sum, "js_count": self._last_moe_router_js_count, "top1_switch_sum": self._last_moe_router_top1_switch_sum, "top1_switch_count": self._last_moe_router_top1_switch_count, } def reset_kv_cache(self, num_envs: int, device): """Initialize per-environment KV cache for single-step inference. Allocates zeroed K/V caches of shape [num_envs, n_layers, max_ctx_len, n_kv_heads, head_dim] (float16 on CUDA devices, float32 otherwise), zeroed per-environment length / write-index / absolute-position counters, and fresh router-drift tracking state.""" cache_dtype = ( torch.float16 if torch.device(device).type == "cuda" else torch.float32 ) self._k_cache = torch.zeros( num_envs, self.n_layers, self.max_ctx_len, self.n_kv_heads, self.head_dim, device=device, dtype=cache_dtype, ) self._v_cache = torch.zeros_like(self._k_cache) self._kv_cache_len = torch.zeros( num_envs, dtype=torch.long, device=device ) self._kv_cache_write_idx = torch.zeros( num_envs, dtype=torch.long, device=device ) self._kv_cache_abs_pos = torch.zeros( num_envs, dtype=torch.long, device=device ) self._init_last_moe_router_shift_state(num_envs, device) def clear_env_cache(self, env_ids: torch.Tensor | None): """Reset KV cache state for specific environments. env_ids=None resets every environment and also zeroes the global router-drift sums and counts; otherwise only the selected environments' K/V entries, counters and previous-router state are reset. No-op before reset_kv_cache() has been called.""" if self._k_cache is None: return if env_ids is None: self._k_cache.zero_() self._v_cache.zero_() self._kv_cache_len.zero_() self._kv_cache_write_idx.zero_() self._kv_cache_abs_pos.zero_() if self._prev_last_moe_router_p is not None: self._prev_last_moe_router_p.zero_() if self._prev_last_moe_router_valid is not None: self._prev_last_moe_router_valid.zero_() if self._last_moe_router_js_sum is not None:
self._last_moe_router_js_sum.zero_() if self._last_moe_router_js_count is not None: self._last_moe_router_js_count.zero_() if self._last_moe_router_top1_switch_sum is not None: self._last_moe_router_top1_switch_sum.zero_() if self._last_moe_router_top1_switch_count is not None: self._last_moe_router_top1_switch_count.zero_() else: self._k_cache[env_ids] = 0.0 self._v_cache[env_ids] = 0.0 self._kv_cache_len[env_ids] = 0 self._kv_cache_write_idx[env_ids] = 0 self._kv_cache_abs_pos[env_ids] = 0 if self._prev_last_moe_router_valid is not None: self._prev_last_moe_router_valid[env_ids] = False if self._prev_last_moe_router_p is not None: self._prev_last_moe_router_p[env_ids] = 0.0 def set_collect_routing_stats(self, collect: bool) -> None: """Enable or disable routing-stat collection on every MoE layer; router-distribution collection is enabled only on the last MoE layer.""" collect_flag = bool(collect) for layer_idx, layer in enumerate(self.layers): if isinstance(layer, GroupedMoEBlock): layer.collect_routing_stats = collect_flag layer.collect_router_distribution = ( collect_flag and layer_idx == self._last_moe_layer_idx ) def reset_routing_stats(self) -> None: """Call reset_routing_stats() on every MoE layer.""" for layer in self.layers: if isinstance(layer, GroupedMoEBlock): layer.reset_routing_stats() def clear_router_distribution_cache(self) -> None: """Drop the cached router distributions/logits on every MoE layer and switch their capture flags off.""" for layer in self.layers: if isinstance(layer, GroupedMoEBlock): layer.last_router_distribution = None layer.last_router_logits = None layer.capture_router_distribution = False layer.capture_router_logits = False def _set_capture_router_distributions(self, capture: bool) -> None: """Toggle router-distribution capture on all MoE layers, with logit capture switched off.""" self._set_capture_router_features( capture_distributions=capture, capture_logits=False, ) def _set_capture_router_features( self, *, capture_distributions: bool, capture_logits: bool, ) -> None: """Set the router-distribution and router-logit capture flags on every MoE layer.""" capture_distribution_flag = bool(capture_distributions) capture_logits_flag = bool(capture_logits) for layer in self.layers: if isinstance(layer, GroupedMoEBlock): layer.capture_router_distribution = capture_distribution_flag layer.capture_router_logits = capture_logits_flag def apply_dynamic_bias_update_from_stats(self) -> None: """Call apply_bias_update_from_counts() on every MoE layer.""" for layer
in self.layers: if isinstance(layer, GroupedMoEBlock): layer.apply_bias_update_from_counts() def _make_causal_mask(self, T: int, device) -> torch.Tensor: """Generate causal attention mask: shape [T, T], True where attend allowed (lower-triangular, diagonal included).""" return torch.tril(torch.ones(T, T, device=device, dtype=torch.bool)) def _forward_layers_range( self, h: torch.Tensor, cos: torch.Tensor | None, sin: torch.Tensor | None, mask: torch.Tensor | None, memory: torch.Tensor | None = None, memory_mask: torch.Tensor | None = None, router_h: torch.Tensor | None = None, router_h_per_layer: list[torch.Tensor | None] | None = None, *, start_layer: int, end_layer: int, return_pre_moe_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward through a contiguous layer range with optional checkpointing. Runs layers[start_layer:end_layer] in order. MoE layers receive router_h, or router_h_per_layer[layer_idx] (indexed by absolute layer index) when that list is given. When use_checkpointing is set and the module is training, each layer call is wrapped in checkpoint.checkpoint(..., use_reentrant=False). Router capture flags are enabled for the duration of the call and always switched off afterwards (finally block). Returns: h alone when no return_* flag is set; otherwise a tuple (h, *extras) with extras in the order pre_moe_hidden (output of layer 0), router_features (router distributions of the MoE layers in range, concatenated on the last dim), router_temporal_features (router logits of those layers, concatenated on the last dim). Raises: ValueError: on an invalid layer range, or when a requested extra could not be collected (e.g. return_pre_moe_hidden with start_layer > 0, or no MoE layer in the range). """ if ( start_layer < 0 or end_layer < start_layer or end_layer > len(self.layers) ): raise ValueError( "Invalid layer range for _forward_layers_range: " f"start_layer={start_layer}, end_layer={end_layer}, " f"num_layers={len(self.layers)}."
) pre_moe_hidden = None router_features = [] router_temporal_features = [] self._set_capture_router_features( capture_distributions=return_router_features, capture_logits=return_router_temporal_features, ) try: for layer_idx in range(start_layer, end_layer): layer = self.layers[layer_idx] layer_router_h = router_h if router_h_per_layer is not None: layer_router_h = router_h_per_layer[layer_idx] if self.use_checkpointing and self.training: if isinstance(layer, GroupedMoEBlock): # NOTE(review): the router input is passed positionally here but as router_x= in the non-checkpointed branch; this assumes router_x is the 7th argument of GroupedMoEBlock.forward - confirm. h = checkpoint.checkpoint( layer, h, cos, sin, mask, memory, memory_mask, layer_router_h, use_reentrant=False, ) else: h = checkpoint.checkpoint( layer, h, cos, sin, mask, memory, memory_mask, use_reentrant=False, ) else: if isinstance(layer, GroupedMoEBlock): h = layer( h, cos, sin, mask, memory, memory_mask, router_x=layer_router_h, ) else: h = layer(h, cos, sin, mask, memory, memory_mask) if return_pre_moe_hidden and layer_idx == 0: pre_moe_hidden = h if return_router_features and isinstance( layer, GroupedMoEBlock ): if layer.last_router_distribution is None: raise ValueError( f"Missing router distribution for MoE layer {layer_idx}." ) router_features.append(layer.last_router_distribution) if return_router_temporal_features and isinstance( layer, GroupedMoEBlock ): if layer.last_router_logits is None: raise ValueError( f"Missing router logits for MoE layer {layer_idx}." ) router_temporal_features.append(layer.last_router_logits) finally: self._set_capture_router_features( capture_distributions=False, capture_logits=False, ) outputs: list[torch.Tensor] = [h] if return_pre_moe_hidden: if pre_moe_hidden is None: raise ValueError( "Missing pre-MoE hidden state from the leading dense layer." ) outputs.append(pre_moe_hidden) if return_router_features: if len(router_features) == 0: raise ValueError( "Missing router features while return_router_features=True."
) outputs.append(torch.cat(router_features, dim=-1)) if return_router_temporal_features: if len(router_temporal_features) == 0: raise ValueError( "Missing router temporal features while " "return_router_temporal_features=True." ) outputs.append(torch.cat(router_temporal_features, dim=-1)) if len(outputs) == 1: return outputs[0] return tuple(outputs) def _forward_layers( self, h: torch.Tensor, cos: torch.Tensor | None, sin: torch.Tensor | None, mask: torch.Tensor | None, memory: torch.Tensor | None = None, memory_mask: torch.Tensor | None = None, router_h: torch.Tensor | None = None, router_h_per_layer: list[torch.Tensor | None] | None = None, return_pre_moe_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward through all layers: a thin wrapper over _forward_layers_range(start_layer=0, end_layer=len(self.layers)) with the same return convention.""" return self._forward_layers_range( h, cos, sin, mask, memory, memory_mask, router_h, router_h_per_layer, start_layer=0, end_layer=len(self.layers), return_pre_moe_hidden=return_pre_moe_hidden, return_router_features=return_router_features, return_router_temporal_features=return_router_temporal_features, ) def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None: """Hook for a separate router input. Returns None here, so MoE layers get no router input; presumably overridden by subclasses - confirm.""" return None def sequence_mu( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, return_hidden: bool = False, return_pre_moe_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Compute per-token action mean for sequences. Args: x: [B, T, D] flat obs per token, embedded with obs_embed. attn_mask: [B, T, T] boolean mask (True if attend allowed), or None. When given, RoPE positions are episode-relative (t minus the first attendable index of each row). When None, mask=None is forwarded to the layers and positions are 0..T-1; NOTE(review): whether the layers then apply causal masking is not visible here - confirm. return_hidden: If True and no other return_* flag is set, also return the final norm_f-normalized hidden states. return_pre_moe_hidden: Also return the output of layer 0; cannot be combined with return_hidden. return_router_features: Also return the router distributions of all MoE layers, concatenated on the last dim. return_router_temporal_features: Also return the router logits of all MoE layers, concatenated on the last dim. Returns: mu: [B, T, A] when no flag is set; (mu, h) with h: [B, T, d_model] when only return_hidden=True; otherwise (mu, *extras) with extras in the order pre_moe_hidden, router_features, router_temporal_features. h is NOT returned in the last case even when return_hidden=True. Raises: ValueError: if return_hidden and return_pre_moe_hidden are both True. """ B, T, _ = x.shape h = self.obs_embed(x) # [B, T, d_model] router_h = self._compute_router_hidden(x) # SDPA bool attention mask uses True = allowed (can attend).
if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) # [B, 1, T, T] # Episode-aware positions: first attendable token is episode start (argmax returns the index of the first True; a row with no True yields 0). start_idx = attn_mask.to(torch.int64).argmax(dim=-1) # [B, T] t_idx = torch.arange(T, device=x.device, dtype=torch.long)[ None, : ].expand(B, T) pos = t_idx - start_idx # [B, T] else: tgt_mask = None pos = torch.arange(T, device=x.device, dtype=torch.long)[ None, : ].expand(B, T) cos, sin = self.get_cos_sin(h, pos) # [B, T, head_dim] if return_hidden and return_pre_moe_hidden: raise ValueError( "return_hidden and return_pre_moe_hidden cannot both be True." ) forward_out = self._forward_layers( h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, return_pre_moe_hidden=return_pre_moe_hidden, return_router_features=return_router_features, return_router_temporal_features=return_router_temporal_features, ) extras: list[torch.Tensor] = [] if isinstance(forward_out, tuple): h = forward_out[0] extras = list(forward_out[1:]) else: h = forward_out h = self.norm_f(h) mu = self.action_mu_head(h) outputs: list[torch.Tensor] = [mu] if return_pre_moe_hidden: outputs.append(extras.pop(0)) if return_router_features: outputs.append(extras.pop(0)) if return_router_temporal_features: outputs.append(extras.pop(0)) if len(outputs) > 1: return tuple(outputs) if return_hidden: return mu, h return mu def sequence_hidden( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: """Compute per-token latent features for sequences. Uses the same embedding, mask handling and RoPE positions as sequence_mu, but returns the final norm_f-normalized hidden states instead of action means; no router features are requested. Args: x: [B, T, D] flat obs per token. attn_mask: [B, T, T] boolean mask (True if attend allowed), or None.
Returns: h_f: [B, T, d_model] """ B, T, _ = x.shape h = self.obs_embed(x) # [B, T, d_model] router_h = self._compute_router_hidden(x) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) # [B, 1, T, T] start_idx = attn_mask.to(torch.int64).argmax(dim=-1) # [B, T] t_idx = torch.arange(T, device=x.device, dtype=torch.long)[ None, : ].expand(B, T) pos = t_idx - start_idx # [B, T] else: tgt_mask = None pos = torch.arange(T, device=x.device, dtype=torch.long)[ None, : ].expand(B, T) cos, sin = self.get_cos_sin(h, pos) h = self._forward_layers( h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, ) h = self.norm_f(h) return h def _embed_future_tokens( self, future_tokens: torch.Tensor ) -> torch.Tensor: """Embed future reference tokens and add learned per-slot position embeddings. future_tokens: [B, N, D_fut] or [B, T, N, D_fut], with N == future_seq_len and D_fut == future_token_dim. Returns a tensor with the same leading shape and last dim d_model. Raises ValueError when cross-attention mode is off or the shape does not match. """ if not self.use_future_cross_attn: raise ValueError( "_embed_future_tokens requires use_future_cross_attn=True" ) if future_tokens.ndim == 3: b, n, d = future_tokens.shape if n != self.future_seq_len: raise ValueError( f"future token length mismatch: expected {self.future_seq_len}, got {n}" ) if d != self.future_token_dim: raise ValueError( f"future token dim mismatch: expected {self.future_token_dim}, got {d}" ) pos = torch.arange( n, device=future_tokens.device, dtype=torch.long ) pos_emb = self.future_pos_embed(pos)[None, :, :] return self.future_obs_embed(future_tokens) + pos_emb if future_tokens.ndim == 4: b, t, n, d = future_tokens.shape if n != self.future_seq_len: raise ValueError( f"future token length mismatch: expected {self.future_seq_len}, got {n}" ) if d != self.future_token_dim: raise ValueError( f"future token dim mismatch: expected {self.future_token_dim}, got {d}" ) pos = torch.arange( n, device=future_tokens.device, dtype=torch.long ) pos_emb = self.future_pos_embed(pos)[None, None, :, :] return self.future_obs_embed(future_tokens) + pos_emb raise ValueError( f"future_tokens must be 3D or 4D, got shape {tuple(future_tokens.shape)}" ) def sequence_mu_cond( self, state_seq: torch.Tensor, future_seq: torch.Tensor, *, attn_mask: torch.Tensor | None
= None, future_mask: torch.Tensor | None = None, return_pre_moe_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Compute per-token action means conditioned on future reference tokens. Args: state_seq: [B, T, state_obs_dim] state tokens, embedded with state_obs_embed. future_seq: [B, T, future_seq_len, future_token_dim] future tokens per step; embedded and passed to every layer as cross-attention memory. attn_mask: [B, T, T] boolean self-attention mask (True if attend allowed) or None; positions are derived as in sequence_mu. future_mask: [B, T, future_seq_len] boolean memory mask; defaults to all True. return_pre_moe_hidden / return_router_features / return_router_temporal_features: as in sequence_mu. Unlike sequence_mu, no router_h is computed, so MoE layers get no separate router input. Returns: mu: [B, T, A], or (mu, *extras) in the order pre_moe_hidden, router_features, router_temporal_features when any return_* flag is set. Raises: ValueError: if cross-attention mode is off or any input shape mismatches. """ if not self.use_future_cross_attn: raise ValueError( "sequence_mu_cond requires use_future_cross_attn=True" ) if state_seq.ndim != 3: raise ValueError( f"state_seq must have shape [B, T, D], got {tuple(state_seq.shape)}" ) if future_seq.ndim != 4: raise ValueError( "future_seq must have shape [B, T, N_fut, D_fut], " f"got {tuple(future_seq.shape)}" ) b, t, d_state = state_seq.shape bf, tf, n_fut, d_fut = future_seq.shape if bf != b or tf != t: raise ValueError( "state_seq and future_seq batch/time mismatch: " f"state={tuple(state_seq.shape)}, future={tuple(future_seq.shape)}" ) if d_state != self.state_obs_dim: raise ValueError( f"state_seq dim mismatch: expected {self.state_obs_dim}, got {d_state}" ) if n_fut != self.future_seq_len: raise ValueError( f"future_seq len mismatch: expected {self.future_seq_len}, got {n_fut}" ) if d_fut != self.future_token_dim: raise ValueError( f"future_seq dim mismatch: expected {self.future_token_dim}, got {d_fut}" ) h = self.state_obs_embed(state_seq) memory = self._embed_future_tokens(future_seq) if future_mask is None: future_mask = torch.ones( b, t, n_fut, dtype=torch.bool, device=state_seq.device, ) if future_mask.shape != (b, t, n_fut): raise ValueError( "future_mask shape mismatch: expected " f"{(b, t, n_fut)}, got {tuple(future_mask.shape)}" ) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) start_idx = attn_mask.to(torch.int64).argmax(dim=-1) t_idx = torch.arange(t, device=state_seq.device, dtype=torch.long)[ None, : ].expand(b, t) pos = t_idx - start_idx else: tgt_mask = None pos = torch.arange(t, device=state_seq.device, dtype=torch.long)[ None, : ].expand(b, t) cos, sin = self.get_cos_sin(h, pos) forward_out = self._forward_layers( h, cos=cos, sin=sin, mask=tgt_mask, memory=memory,
memory_mask=future_mask, return_pre_moe_hidden=return_pre_moe_hidden, return_router_features=return_router_features, return_router_temporal_features=return_router_temporal_features, ) extras: list[torch.Tensor] = [] if isinstance(forward_out, tuple): h = forward_out[0] extras = list(forward_out[1:]) else: h = forward_out h = self.norm_f(h) mu = self.action_mu_head(h) outputs: list[torch.Tensor] = [mu] if return_pre_moe_hidden: outputs.append(extras.pop(0)) if return_router_features: outputs.append(extras.pop(0)) if return_router_temporal_features: outputs.append(extras.pop(0)) if len(outputs) > 1: return tuple(outputs) return mu def predict_aux_from_pre_moe( self, pre_moe_hidden: torch.Tensor, *, ref_aux_hidden: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Predict auxiliary state quantities from the layer-0 (pre-MoE) hidden state. pre_moe_hidden: [B, T, d_model]. Returns a dict with base_lin_vel_loc / base_lin_vel_log_std ([B, T, 3] each), root_height_loc / root_height_log_std ([B, T, 1] each), keybody_contact_logits ([B, T, C], C may be 0), ref_keybody_rel_pos / robot_keybody_rel_pos ([B, T, K, 3], K may be 0), plus the optional denoise_*_residual entries for heads that exist. ref_aux_hidden is accepted but not used. Raises ValueError when aux_state_pred is disabled or the input is not 3D. """ if not self.aux_state_pred_enabled: raise ValueError( "predict_aux_from_pre_moe requires aux_state_pred.enabled=True." ) if pre_moe_hidden.ndim != 3: raise ValueError( f"Expected pre_moe_hidden with shape [B, T, D], got {tuple(pre_moe_hidden.shape)}" ) vel_params = self.aux_vel_head(pre_moe_hidden) height_params = self.aux_height_head(pre_moe_hidden) vel_loc, vel_log_std = vel_params.chunk(2, dim=-1) height_loc, height_log_std = height_params.chunk(2, dim=-1) aux_outputs = { "base_lin_vel_loc": vel_loc, "base_lin_vel_log_std": vel_log_std, "root_height_loc": height_loc, "root_height_log_std": height_log_std, } if self.aux_contact_head is not None: aux_outputs["keybody_contact_logits"] = self.aux_contact_head( pre_moe_hidden ) else: aux_outputs["keybody_contact_logits"] = pre_moe_hidden.new_zeros( pre_moe_hidden.shape[0], pre_moe_hidden.shape[1], 0, ) if self.aux_denoise_ref_root_lin_vel_head is not None: aux_outputs["denoise_ref_root_lin_vel_residual"] = ( self.aux_denoise_ref_root_lin_vel_head(pre_moe_hidden) ) if self.aux_denoise_ref_root_ang_vel_head is not None: aux_outputs["denoise_ref_root_ang_vel_residual"] = ( self.aux_denoise_ref_root_ang_vel_head(pre_moe_hidden) ) if
self.aux_ref_keybody_pos_head is not None: aux_outputs["ref_keybody_rel_pos"] = self.aux_ref_keybody_pos_head( pre_moe_hidden ).reshape( pre_moe_hidden.shape[0], pre_moe_hidden.shape[1], self.aux_keybody_pos_dim, 3, ) aux_outputs["robot_keybody_rel_pos"] = ( self.aux_robot_keybody_pos_head(pre_moe_hidden).reshape( pre_moe_hidden.shape[0], pre_moe_hidden.shape[1], self.aux_keybody_pos_dim, 3, ) ) else: aux_outputs["ref_keybody_rel_pos"] = pre_moe_hidden.new_zeros( pre_moe_hidden.shape[0], pre_moe_hidden.shape[1], 0, 3, ) aux_outputs["robot_keybody_rel_pos"] = pre_moe_hidden.new_zeros( pre_moe_hidden.shape[0], pre_moe_hidden.shape[1], 0, 3, ) if self.aux_denoise_ref_dof_pos_head is not None: aux_outputs["denoise_ref_dof_pos_residual"] = ( self.aux_denoise_ref_dof_pos_head(pre_moe_hidden) ) return aux_outputs def predict_aux_router_command_from_router_features( self, router_features: torch.Tensor ) -> torch.Tensor: if not self.aux_router_command_recon_enabled: raise ValueError( "predict_aux_router_command_from_router_features requires " "aux_router_command_recon.enabled=True." ) if router_features.ndim != 3: raise ValueError( "Expected router_features with shape [B, T, D], got " f"{tuple(router_features.shape)}." ) if self.aux_router_command_recon_head is None: raise ValueError( "aux_router_command_recon_head is not initialized." ) return self.aux_router_command_recon_head(router_features) def update_aux_router_future_recon_normalizer( self, future_target: torch.Tensor ) -> None: if not self.aux_router_future_recon_enabled: raise ValueError( "update_aux_router_future_recon_normalizer requires " "aux_router_future_recon.enabled=True." ) if self.aux_router_future_recon_normalizer is None: raise ValueError( "aux_router_future_recon_normalizer is not initialized." ) if future_target.ndim < 2: raise ValueError( "Expected future_target with shape [B, D] or [B, T, D], got " f"{tuple(future_target.shape)}." 
)
        # Merge all leading dims into one batch axis; detach so the
        # normalizer statistics are updated outside the autograd graph.
        flat_target = future_target.reshape(
            -1, future_target.shape[-1]
        ).detach()
        self.aux_router_future_recon_normalizer.update(flat_target)

    def normalize_aux_router_future_recon_target(
        self, future_target: torch.Tensor
    ) -> torch.Tensor:
        """Normalize a future-reconstruction target with the running normalizer.

        The target is flattened to [N, D], passed to the normalizer's
        ``normalize_only`` (by its name, presumably without updating the
        statistics), and reshaped back to the input shape.

        Raises:
            ValueError: If the feature is disabled, the normalizer is missing,
                or ``future_target`` has fewer than 2 dims.
        """
        if not self.aux_router_future_recon_enabled:
            raise ValueError(
                "normalize_aux_router_future_recon_target requires "
                "aux_router_future_recon.enabled=True."
            )
        if self.aux_router_future_recon_normalizer is None:
            raise ValueError(
                "aux_router_future_recon_normalizer is not initialized."
            )
        if future_target.ndim < 2:
            raise ValueError(
                "Expected future_target with shape [B, D] or [B, T, D], got "
                f"{tuple(future_target.shape)}."
            )
        flat_target = future_target.reshape(-1, future_target.shape[-1])
        norm_target = self.aux_router_future_recon_normalizer.normalize_only(
            flat_target
        )
        return norm_target.reshape_as(future_target)

    def predict_aux_router_future_recon_from_router_hidden(
        self, router_hidden: torch.Tensor
    ) -> torch.Tensor:
        """Predict the future-reconstruction target from router hidden states.

        Args:
            router_hidden: Tensor with shape [B, T, D].

        Returns:
            Output of ``aux_router_future_recon_head``.

        Raises:
            ValueError: If the feature is disabled, the head is missing, or
                the input is not 3-D.
        """
        if not self.aux_router_future_recon_enabled:
            raise ValueError(
                "predict_aux_router_future_recon_from_router_hidden requires "
                "aux_router_future_recon.enabled=True."
            )
        if router_hidden.ndim != 3:
            raise ValueError(
                "Expected router_hidden with shape [B, T, D], got "
                f"{tuple(router_hidden.shape)}."
            )
        if self.aux_router_future_recon_head is None:
            raise ValueError(
                "aux_router_future_recon_head is not initialized."
) return self.aux_router_future_recon_head(router_hidden) def single_step_mu_cond( self, state_x: torch.Tensor, future_tokens: torch.Tensor, *, future_mask: torch.Tensor | None = None, ) -> torch.Tensor: if not self.use_future_cross_attn: raise ValueError( "single_step_mu_cond requires use_future_cross_attn=True" ) if state_x.ndim != 2: raise ValueError(f"Expected state_x [B, D], got {state_x.shape}") if future_tokens.ndim != 3: raise ValueError( "Expected future_tokens [B, N_fut, D_fut], " f"got {future_tokens.shape}" ) b, d_state = state_x.shape bf, n_fut, d_fut = future_tokens.shape if bf != b: raise ValueError( f"Batch mismatch between state_x and future_tokens: {b} vs {bf}" ) if d_state != self.state_obs_dim: raise ValueError( f"state_x dim mismatch: expected {self.state_obs_dim}, got {d_state}" ) if n_fut != self.future_seq_len: raise ValueError( f"future len mismatch: expected {self.future_seq_len}, got {n_fut}" ) if d_fut != self.future_token_dim: raise ValueError( f"future dim mismatch: expected {self.future_token_dim}, got {d_fut}" ) if self._k_cache is None: state_seq = state_x[:, None, :] future_seq = future_tokens[:, None, :, :] if future_mask is not None: future_mask = future_mask[:, None, :] mu_seq = self.sequence_mu_cond( state_seq, future_seq, attn_mask=None, future_mask=future_mask, ) return mu_seq[:, 0, :] if self._k_cache.device != state_x.device: self._k_cache = self._k_cache.to(state_x.device) self._v_cache = self._v_cache.to(state_x.device) self._kv_cache_len = self._kv_cache_len.to(state_x.device) self._kv_cache_write_idx = self._kv_cache_write_idx.to( state_x.device ) self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(state_x.device) h = self.state_obs_embed(state_x)[:, None, :] memory = self._embed_future_tokens(future_tokens) if self._k_cache.dtype != h.dtype: self._k_cache = self._k_cache.to(h.dtype) self._v_cache = self._v_cache.to(h.dtype) cache_len = self._kv_cache_len insert_pos = self._kv_cache_write_idx max_len = 
int(self.max_ctx_len) new_len = torch.clamp(cache_len + 1, max=max_len) self._kv_cache_len = new_len self._kv_cache_write_idx = (insert_pos + 1) % max_len pos = self._kv_cache_abs_pos self._kv_cache_abs_pos = pos + 1 pos_ids = pos.unsqueeze(1) cos, sin = self.get_cos_sin(h, pos_ids) memory_mask = None if future_mask is not None: if future_mask.shape != (b, n_fut): raise ValueError( "future_mask shape mismatch for single-step path: expected " f"{(b, n_fut)}, got {tuple(future_mask.shape)}" ) memory_mask = future_mask[:, None, None, :] for layer_idx, layer in enumerate(self.layers): x_norm = layer.norm1(h) k_cache_l = self._k_cache[:, layer_idx] v_cache_l = self._v_cache[:, layer_idx] attn_out, _, _ = layer.attn.forward_single_token( x_norm, cos, sin, k_cache_l, v_cache_l, new_len, insert_pos, ) h = h + attn_out if layer.use_cross_attn: h = h + layer.cross_attn( layer.norm_cross(h), memory, memory_mask ) h2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): ffn = layer.compute_moe_ffn(h2) if ( layer_idx == self._last_moe_layer_idx and layer.collect_routing_stats and layer.last_router_distribution is not None ): self._accumulate_last_moe_router_shift( layer.last_router_distribution ) else: ffn = layer.mlp_dropout(layer.mlp(h2)) h = h + ffn h = self.norm_f(h) return self.action_mu_head(h[:, 0, :]) def forward( self, input: torch.Tensor, past_key_values: torch.Tensor | None = None, current_pos: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Forward pass for single-step inference (no history).""" if past_key_values is not None: return self._forward_inference_onnx( input, past_key_values, current_pos ) if input.ndim != 2: raise ValueError(f"Expected [B, D], got {input.shape}") mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None) return mu_seq[:, 0, :] def single_step_mu(self, x: torch.Tensor) -> torch.Tensor: """Compute action mean for a single step using per-layer KV cache. 
Uses a ring-buffer KV cache with per-env absolute positions for RoPE. """ if x.ndim != 2: raise ValueError(f"Expected [B, D], got {x.shape}") B, _ = x.shape if self._k_cache is None: mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None) return mu_seq[:, 0, :] # Ensure cache device matches if self._k_cache.device != x.device: self._k_cache = self._k_cache.to(x.device) self._v_cache = self._v_cache.to(x.device) self._kv_cache_len = self._kv_cache_len.to(x.device) self._kv_cache_write_idx = self._kv_cache_write_idx.to(x.device) self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(x.device) h = self.obs_embed(x)[:, None, :] # [B, 1, d_model] router_h = self._compute_router_hidden(x) if router_h is not None: router_h = router_h[:, None, :] # Ensure cache dtype matches compute dtype (convert once if needed) if self._k_cache.dtype != h.dtype: self._k_cache = self._k_cache.to(h.dtype) self._v_cache = self._v_cache.to(h.dtype) cache_len = self._kv_cache_len # [B] insert_pos = self._kv_cache_write_idx # [B] max_len = int(self.max_ctx_len) new_len = torch.clamp(cache_len + 1, max=max_len) # [B] self._kv_cache_len = new_len self._kv_cache_write_idx = (insert_pos + 1) % max_len # RoPE frequencies for current absolute position pos = self._kv_cache_abs_pos # [B] self._kv_cache_abs_pos = pos + 1 pos_ids = pos.unsqueeze(1) # [B, 1] cos, sin = self.get_cos_sin(h, pos_ids) for layer_idx, layer in enumerate(self.layers): x_norm = layer.norm1(h) k_cache_l = self._k_cache[ :, layer_idx ] # [B, L, n_kv_heads, head_dim] v_cache_l = self._v_cache[:, layer_idx] attn_out, _, _ = layer.attn.forward_single_token( x_norm, cos, sin, k_cache_l, v_cache_l, new_len, insert_pos, ) h = h + attn_out # FFN/MoE path for single token (h: [B,1,D]) h2 = layer.norm2(h) # [B,1,D] if isinstance(layer, GroupedMoEBlock): ffn = layer.compute_moe_ffn(h2, router_x=router_h) if ( layer_idx == self._last_moe_layer_idx and layer.collect_routing_stats and layer.last_router_distribution is not None ): 
self._accumulate_last_moe_router_shift( layer.last_router_distribution ) else: ffn = layer.mlp_dropout(layer.mlp(h2)) # 保持 [B, 1, D] h = h + ffn h = self.norm_f(h) return self.action_mu_head(h[:, 0, :]) def _forward_inference_onnx( self, x: torch.Tensor, past_key_values: torch.Tensor, current_pos: torch.Tensor, ) -> tuple[torch.Tensor, ...]: """Single-step inference compatible with ONNX export. Aligns strictly with `single_step_mu` logic using Real-valued RoPE. Args: x: [B, D] (Batch=1 for ONNX usually) past_key_values: [n_layers, 2, B, max_len, n_kv_heads, head_dim] current_pos: [B] or scalar, the absolute step index (0, 1, 2...) Returns: action: [B, A] present_key_values: Updated KV cache tensor """ # Embedding [B, D] -> [B, 1, D] h = self.obs_embed(x)[:, None, :] # [1, 1, 512] router_h = self._compute_router_hidden(x) if router_h is not None: router_h = router_h[:, None, :] B = h.shape[0] # 1 # Calculate Cache Indices (Ring Buffer Logic) # past_key_values shape: [L, 2, B, T, H, D] -> T is index 3 max_len = past_key_values.shape[3] # 32 if current_pos.ndim == 0: current_pos = current_pos.view(1).expand(B) # insert_pos: [B] insert_pos = current_pos % max_len # new_len: [B] new_len = torch.clamp(current_pos + 1, max=max_len) # position_ids: [B, 1] position_ids = current_pos.unsqueeze(1) # cos, sin shape: [B, 1, head_dim] cos, sin = self.get_cos_sin(h, position_ids) present_key_values_list = [] routing_debug_outputs: list[torch.Tensor] = [] export_routing_debug = torch.onnx.is_in_onnx_export() for i, layer in enumerate(self.layers): # Unpack Cache: [2, B, T, H, D] layer_past = past_key_values[i] k_cache = layer_past[0] v_cache = layer_past[1] h_norm = layer.norm1(h) # Attention attn_out, new_k_cache, new_v_cache = ( layer.attn.forward_single_token( x=h_norm, cos=cos, sin=sin, k_cache=k_cache, v_cache=v_cache, new_len=new_len, insert_pos=insert_pos, ) ) h = h + attn_out # FFN / MoE h_norm2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): if 
export_routing_debug: ffn_out, topk_idx, router_logits = layer.compute_moe_ffn( h_norm2, router_x=router_h, return_routing_debug=True, ) routing_debug_outputs.extend([topk_idx, router_logits]) else: ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h) else: # Dense MLP ffn_out = layer.mlp_dropout(layer.mlp(h_norm2)) h = h + ffn_out current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0) present_key_values_list.append(current_layer_kv) h = self.norm_f(h) action = self.action_mu_head(h[:, 0, :]) present_key_values = torch.stack(present_key_values_list, dim=0) if export_routing_debug and routing_debug_outputs: return (action, present_key_values, *routing_debug_outputs) return action, present_key_values class ReferenceRoutedGroupedMoETransformerPolicy(GroupedMoETransformerPolicy): def __init__( self, input_dim: int, output_dim: int, module_config_dict: dict, ): module_config = dict(module_config_dict) if bool(module_config.get("use_future_cross_attn", False)): raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicy does not support " "use_future_cross_attn=True." ) router_input_dim = module_config.get("router_input_dim", None) router_feature_indices = module_config.get( "router_feature_indices", None ) if router_input_dim is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicy requires router_input_dim." ) if router_feature_indices is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicy requires " "router_feature_indices." ) self.router_input_dim = int(router_input_dim) self.router_feature_indices = tuple( int(idx) for idx in router_feature_indices ) if self.router_input_dim <= 0: raise ValueError( f"router_input_dim must be positive, got {self.router_input_dim}." ) if len(self.router_feature_indices) != self.router_input_dim: raise ValueError( "router_input_dim must match len(router_feature_indices): " f"{self.router_input_dim} vs {len(self.router_feature_indices)}." 
)
        if any(idx < 0 for idx in self.router_feature_indices):
            raise ValueError(
                f"router_feature_indices must be non-negative, got {self.router_feature_indices}."
            )
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            module_config_dict=module_config,
        )
        # The upper bound can only be checked after the base class has set
        # obs_input_dim / input_dim.
        obs_in = int(self.obs_input_dim or self.input_dim)
        if any(idx >= obs_in for idx in self.router_feature_indices):
            raise ValueError(
                "router_feature_indices exceed the flat actor obs dim "
                f"{obs_in}: {self.router_feature_indices}"
            )
        self.router_embed_mlp_hidden = int(
            module_config.get(
                "router_embed_mlp_hidden", self.obs_embed_mlp_hidden
            )
        )
        # persistent=False: the index buffer is not stored in state_dict; it
        # is rebuilt from config on construction.
        self.register_buffer(
            "_router_feature_indices",
            torch.tensor(self.router_feature_indices, dtype=torch.long),
            persistent=False,
        )
        # Two-layer MLP embedding the selected router features to d_model.
        self.router_obs_embed = nn.Sequential(
            nn.Linear(self.router_input_dim, self.router_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.router_embed_mlp_hidden, self.d_model),
        )
        # Apply the router freeze state now that router_obs_embed exists.
        self._apply_freeze_router_state()

    def _apply_freeze_router_state(self) -> None:
        """Extend the base router freezing to ``router_obs_embed``."""
        super()._apply_freeze_router_state()
        self.router_obs_embed.requires_grad_(not self.freeze_router)

    def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None:
        """Embed the router feature subset of the flat observation ``x``.

        ``x`` is indexed along its last dim with the router feature indices
        and passed through ``router_obs_embed``; leading dims are preserved.

        Raises:
            ValueError: If the last dim of ``x`` is not the flat obs dim.
        """
        if x.shape[-1] != int(self.obs_input_dim or self.input_dim):
            raise ValueError(
                "Reference-routed policy expected flat obs dim "
                f"{int(self.obs_input_dim or self.input_dim)}, got {x.shape[-1]}."
)
        router_idx = self._router_feature_indices
        # Move the index buffer to x's device on demand (local copy only).
        if router_idx.device != x.device:
            router_idx = router_idx.to(x.device)
        router_obs = torch.index_select(x, dim=x.ndim - 1, index=router_idx)
        return self.router_obs_embed(router_obs)

    def _forward_inference_onnx_cond(
        self,
        state_x: torch.Tensor,
        future_tokens: torch.Tensor,
        past_key_values: torch.Tensor,
        current_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """ONNX-style single-step inference with future-token cross-attention.

        Args:
            state_x: State observation, [B, D_state].
            future_tokens: Future reference tokens, [B, N_fut, D_fut].
            past_key_values: Stacked per-layer KV cache; dim 3 is the
                ring-buffer length.
            current_pos: [B] or scalar absolute step index.

        Returns:
            (action, present_key_values), plus per-MoE-layer
            (topk_idx, router_logits) pairs while exporting to ONNX.

        NOTE(review): this class's ``__init__`` rejects
        ``use_future_cross_attn=True``, so the guard below presumably always
        raises for this class. Unlike ``_forward_inference_onnx`` it also
        passes no router input to ``compute_moe_ffn``. Confirm whether this
        method belongs on the base class instead.
        """
        if not self.use_future_cross_attn:
            raise ValueError(
                "_forward_inference_onnx_cond requires use_future_cross_attn=True"
            )
        if state_x.ndim != 2:
            raise ValueError(
                f"state_x must have shape [B, D_state], got {tuple(state_x.shape)}"
            )
        if future_tokens.ndim != 3:
            raise ValueError(
                "future_tokens must have shape [B, N_fut, D_fut], "
                f"got {tuple(future_tokens.shape)}"
            )
        h = self.state_obs_embed(state_x)[:, None, :]
        memory = self._embed_future_tokens(future_tokens)
        b = h.shape[0]
        # Ring-buffer indices derived from the absolute step index.
        max_len = past_key_values.shape[3]
        if current_pos.ndim == 0:
            current_pos = current_pos.view(1).expand(b)
        insert_pos = current_pos % max_len
        new_len = torch.clamp(current_pos + 1, max=max_len)
        position_ids = current_pos.unsqueeze(1)
        cos, sin = self.get_cos_sin(h, position_ids)
        present_key_values_list = []
        routing_debug_outputs: list[torch.Tensor] = []
        export_routing_debug = torch.onnx.is_in_onnx_export()
        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i]
            k_cache = layer_past[0]
            v_cache = layer_past[1]
            h_norm = layer.norm1(h)
            attn_out, new_k_cache, new_v_cache = (
                layer.attn.forward_single_token(
                    x=h_norm,
                    cos=cos,
                    sin=sin,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    new_len=new_len,
                    insert_pos=insert_pos,
                )
            )
            h = h + attn_out
            if layer.use_cross_attn:
                # Cross-attention to the future tokens with no memory mask.
                h = h + layer.cross_attn(layer.norm_cross(h), memory, None)
            h_norm2 = layer.norm2(h)
            if isinstance(layer, GroupedMoEBlock):
                if export_routing_debug:
                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
                        h_norm2,
                        return_routing_debug=True,
                    )
                    routing_debug_outputs.extend([topk_idx, router_logits])
                else:
                    ffn_out = layer.compute_moe_ffn(h_norm2)
            else:
                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
            h = h + ffn_out
            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
            present_key_values_list.append(current_layer_kv)
        h = self.norm_f(h)
        action = self.action_mu_head(h[:, 0, :])
        present_key_values = torch.stack(present_key_values_list, dim=0)
        if export_routing_debug and routing_debug_outputs:
            return (action, present_key_values, *routing_debug_outputs)
        return action, present_key_values


class ReferenceRoutedGroupedMoETransformerPolicyV2(
    GroupedMoETransformerPolicy
):
    """Grouped-MoE policy that splits the flat obs into state and reference parts.

    The flat observation is split into state features, a current reference
    frame and a window of future reference frames. The reference frames are
    encoded into shared tokens that feed both a router context and a
    FiLM-style modulation of the actor hidden state.
    """

    supports_explicit_ref_aux_hidden = True

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        module_config_dict: dict,
    ):
        """Validate the observation-split config and build all submodules.

        ``input_dim`` is the full flat observation width; the base class is
        built with ``input_dim_override=state_obs_input_dim``.

        Raises:
            ValueError: On missing or inconsistent reference-split config.
        """
        module_config = dict(module_config_dict)
        if bool(module_config.get("use_future_cross_attn", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 does not "
                "support use_future_cross_attn=True."
            )
        # Required config describing how the flat observation is split.
        state_obs_input_dim = module_config.get("state_obs_input_dim", None)
        ref_cur_token_dim = module_config.get("ref_cur_token_dim", None)
        ref_fut_token_dim = module_config.get("ref_fut_token_dim", None)
        ref_fut_seq_len = module_config.get("ref_fut_seq_len", None)
        state_feature_indices = module_config.get(
            "state_feature_indices", None
        )
        ref_cur_feature_indices = module_config.get(
            "ref_cur_feature_indices", None
        )
        ref_fut_slices = module_config.get("ref_fut_slices", None)
        if state_obs_input_dim is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
                "state_obs_input_dim."
            )
        if ref_cur_token_dim is None or ref_fut_token_dim is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
                "ref_cur_token_dim and ref_fut_token_dim."
            )
        if ref_fut_seq_len is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
                "ref_fut_seq_len."
)
        if state_feature_indices is None or ref_cur_feature_indices is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
                "state_feature_indices and ref_cur_feature_indices."
            )
        if ref_fut_slices is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
                "ref_fut_slices."
            )
        # Normalize config values to plain Python ints / tuples.
        self.full_obs_input_dim = int(input_dim)
        self.state_obs_input_dim = int(state_obs_input_dim)
        self.ref_cur_token_dim = int(ref_cur_token_dim)
        self.ref_fut_token_dim = int(ref_fut_token_dim)
        self.ref_fut_seq_len = int(ref_fut_seq_len)
        self.state_feature_indices = tuple(
            int(idx) for idx in state_feature_indices
        )
        self.ref_cur_feature_indices = tuple(
            int(idx) for idx in ref_cur_feature_indices
        )
        # Each entry is (start, end, dim): the flat span [start, end) of the
        # observation holds ref_fut_seq_len frames of `dim` features each.
        self.ref_fut_slices = tuple(
            (int(start), int(end), int(dim))
            for start, end, dim in ref_fut_slices
        )
        if self.state_obs_input_dim <= 0:
            raise ValueError(
                "state_obs_input_dim must be positive, got "
                f"{self.state_obs_input_dim}."
            )
        if self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0:
            raise ValueError(
                "ref token dims must be positive, got "
                f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
            )
        # Current and future frames go through the same ref_frame_embed, so
        # their per-frame feature widths must agree.
        if self.ref_cur_token_dim != self.ref_fut_token_dim:
            raise ValueError(
                "current/future ref token dims must match, got "
                f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
            )
        if self.ref_fut_seq_len <= 0:
            raise ValueError(
                f"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}."
            )
        if len(self.state_feature_indices) != self.state_obs_input_dim:
            raise ValueError(
                "state_obs_input_dim must match len(state_feature_indices): "
                f"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}."
            )
        if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim:
            raise ValueError(
                "ref_cur_token_dim must match len(ref_cur_feature_indices): "
                f"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}."
) fut_flat_dim = 0 for start, end, dim in self.ref_fut_slices: if end <= start or dim <= 0: raise ValueError( f"Invalid ref_fut_slices entry {(start, end, dim)}." ) if (end - start) != self.ref_fut_seq_len * dim: raise ValueError( "Future ref slice span must equal ref_fut_seq_len * dim, got " f"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}." ) fut_flat_dim += end - start expected_full_input_dim = ( self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim ) if self.full_obs_input_dim != expected_full_input_dim: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV2 expected full " f"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}." ) self.ref_hist_n_layers = int(module_config.get("ref_hist_n_layers", 1)) if self.ref_hist_n_layers != 1: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV2 currently supports " "exactly one ref history attention layer." ) self.ref_future_conv_channels = int( module_config.get( "ref_future_conv_channels", self.ref_cur_token_dim ) ) self.ref_future_conv_layers = int( module_config.get("ref_future_conv_layers", 2) ) self.ref_future_conv_kernel_size = int( module_config.get("ref_future_conv_kernel_size", 3) ) self.ref_future_conv_stride = int( module_config.get("ref_future_conv_stride", 2) ) if self.ref_future_conv_layers <= 0: raise ValueError( "ref_future_conv_layers must be positive, got " f"{self.ref_future_conv_layers}." ) if self.ref_future_conv_kernel_size <= 0: raise ValueError( "ref_future_conv_kernel_size must be positive, got " f"{self.ref_future_conv_kernel_size}." ) if self.ref_future_conv_stride <= 0: raise ValueError( "ref_future_conv_stride must be positive, got " f"{self.ref_future_conv_stride}." 
)
        # The base class is built with this override; presumably it sizes the
        # base obs_embed to the state features only (sequence_mu applies
        # obs_embed to state_x) -- confirm in the base __init__.
        module_config["input_dim_override"] = self.state_obs_input_dim
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            module_config_dict=module_config,
        )
        # Presumably the layer count of an exported KV cache:
        # ref-history layer(s) + trunk layers.
        self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers)
        self.register_buffer(
            "_state_feature_indices",
            torch.tensor(self.state_feature_indices, dtype=torch.long),
            persistent=False,
        )
        self.register_buffer(
            "_ref_cur_feature_indices",
            torch.tensor(self.ref_cur_feature_indices, dtype=torch.long),
            persistent=False,
        )
        # Per-frame embedding shared by current and future reference frames.
        self.ref_frame_embed = nn.Sequential(
            nn.Linear(self.ref_cur_token_dim, self.obs_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
        )
        # Attention over the history of current-reference embeddings.
        self.ref_hist_norm = RMSNorm(self.d_model)
        self.ref_hist_attn = ModernAttention(
            d_model=self.d_model,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            use_qk_norm=self.use_qk_norm,
            use_gated_attn=self.use_gated_attn,
            gated_attn_type=self.gated_attn_type,
            attn_dropout=self.attn_dropout,
        )
        self.ref_hist_out_norm = RMSNorm(self.d_model)
        # Conv1d stack over the future window: intermediate layers use
        # ref_future_conv_channels, the last layer maps back to d_model.
        padding = self.ref_future_conv_kernel_size // 2
        conv_modules: list[nn.Module] = []
        in_ch = self.d_model
        for layer_idx in range(self.ref_future_conv_layers):
            out_ch = (
                self.d_model
                if layer_idx == self.ref_future_conv_layers - 1
                else self.ref_future_conv_channels
            )
            conv_modules.append(
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=self.ref_future_conv_kernel_size,
                    stride=self.ref_future_conv_stride,
                    padding=padding,
                    bias=True,
                )
            )
            conv_modules.append(nn.SiLU())
            in_ch = out_ch
        self.ref_future_conv = nn.Sequential(*conv_modules)
        # Separate attention pools for the actor and router contexts.
        self.actor_ref_pool = SingleQueryAttentionPool(self.d_model)
        self.router_ref_pool = SingleQueryAttentionPool(self.d_model)
        # Learned, zero-initialized query used for router pooling.
        self.router_query = nn.Parameter(torch.zeros(self.d_model))
        self.actor_ref_ctx_norm = RMSNorm(self.d_model)
        self.actor_film_hidden_norm = RMSNorm(self.d_model)
        # FiLM generator producing (scale, shift) from the actor ref context.
        self.actor_ref_film = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.SiLU(),
            nn.Linear(self.d_model, 2 * self.d_model),
        )
        # Zero-init the last FiLM layer: scale = shift = tanh(0) = 0, so the
        # FiLM residual starts at zero.
        nn.init.zeros_(self.actor_ref_film[-1].weight)
        nn.init.zeros_(self.actor_ref_film[-1].bias)
        self.actor_film_gain_max = float(
            module_config.get("actor_film_gain_max", 1.0)
        )
        self.actor_film_gain_init = float(
            module_config.get("actor_film_gain_init", 0.05)
        )
        if self.actor_film_gain_max <= 0.0:
            raise ValueError(
                "actor_film_gain_max must be positive, got "
                f"{self.actor_film_gain_max}."
            )
        if not (0.0 < self.actor_film_gain_init < self.actor_film_gain_max):
            raise ValueError(
                "actor_film_gain_init must be in (0, actor_film_gain_max), "
                f"got {self.actor_film_gain_init} with max "
                f"{self.actor_film_gain_max}."
            )
        # gain = gain_max * sigmoid(raw); raw is initialized to
        # logit(init / max) so the initial per-channel gain equals
        # actor_film_gain_init.
        gain_init_ratio = self.actor_film_gain_init / self.actor_film_gain_max
        self.actor_film_gain_raw = nn.Parameter(
            torch.full(
                (self.d_model,),
                math.log(gain_init_ratio / (1.0 - gain_init_ratio)),
            )
        )
        # Bounds for the tanh-squashed FiLM scale / shift.
        self.actor_film_scale_max = 0.5
        self.actor_film_shift_max = 0.5
        self.actor_film_delta_rms_eps = float(
            module_config.get("actor_film_delta_rms_eps", 1.0e-6)
        )
        if self.actor_film_delta_rms_eps <= 0.0:
            raise ValueError(
                "actor_film_delta_rms_eps must be positive, got "
                f"{self.actor_film_delta_rms_eps}."
)
        # Reference-history KV caches; allocated by reset_kv_cache.
        self._ref_hist_k_cache: torch.Tensor | None = None
        self._ref_hist_v_cache: torch.Tensor | None = None
        self._apply_freeze_router_state()

    def _apply_freeze_router_state(self) -> None:
        """Extend router freezing to the shared reference-encoding modules.

        The actor-side pooling / FiLM modules are not affected here.
        """
        super()._apply_freeze_router_state()
        requires_grad = not self.freeze_router
        self.ref_frame_embed.requires_grad_(requires_grad)
        self.ref_hist_norm.requires_grad_(requires_grad)
        self.ref_hist_attn.requires_grad_(requires_grad)
        self.ref_hist_out_norm.requires_grad_(requires_grad)
        self.ref_future_conv.requires_grad_(requires_grad)
        self.router_ref_pool.requires_grad_(requires_grad)
        self.router_query.requires_grad_(requires_grad)

    def _build_shared_ref_tokens(
        self,
        ref_cur_x: torch.Tensor,
        ref_fut_x: torch.Tensor,
        pos: torch.Tensor,
        tgt_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Encode reference frames into shared tokens for a full sequence.

        Args:
            ref_cur_x: Current-reference features, [B, T, ref_cur_token_dim].
            ref_fut_x: Future-reference features, [B, T, N_fut, F].
            pos: RoPE position ids, [B, T].
            tgt_mask: Optional attention mask for the history attention.

        Returns:
            [B, T, 1 + N_fut', d_model]: the history-attended current token
            followed by the conv-encoded future tokens (N_fut' is the length
            after the Conv1d stack).
        """
        with self._router_no_grad_context():
            ref_cur_h = self.ref_frame_embed(ref_cur_x)
            ref_hist_attn = self.ref_hist_attn(
                self.ref_hist_norm(ref_cur_h),
                *self.get_cos_sin(ref_cur_h, pos),
                mask=tgt_mask,
            )
            # Residual around the pre-normed attention, then output norm.
            ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)
            ref_fut_tokens = self._encode_future_tokens(ref_fut_x)
            return torch.cat([ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2)

    def _build_shared_ref_tokens_single_step(
        self,
        ref_cur_x: torch.Tensor,
        ref_fut_x: torch.Tensor,
        pos_ids: torch.Tensor,
        *,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        new_len: torch.Tensor,
        insert_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single-step counterpart of ``_build_shared_ref_tokens`` with a KV cache.

        Returns:
            (shared_ref_tokens, ref_k_cache, ref_v_cache), where the caches
            are those returned by ``ref_hist_attn.forward_single_token``.
        """
        with self._router_no_grad_context():
            ref_cur_h = self.ref_frame_embed(ref_cur_x)[:, None, :]
            ref_cos, ref_sin = self.get_cos_sin(ref_cur_h, pos_ids)
            ref_hist_attn, ref_k_cache, ref_v_cache = (
                self.ref_hist_attn.forward_single_token(
                    self.ref_hist_norm(ref_cur_h),
                    ref_cos,
                    ref_sin,
                    k_cache,
                    v_cache,
                    new_len,
                    insert_pos,
                )
            )
            ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)
            ref_fut_tokens = self._encode_future_tokens(ref_fut_x)
            # One history token followed by the future tokens (dim 1).
            shared_ref_tokens = torch.cat([ref_hist_h, ref_fut_tokens], dim=1)
            return shared_ref_tokens, ref_k_cache, ref_v_cache

    def _split_actor_ref_inputs(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the flat observation into state, current-ref and future-ref parts.

        Args:
            x: Flat observation, [B, D_full] or [B, T, D_full].

        Returns:
            state_x: Columns selected by the state feature indices.
            ref_cur_x: Columns selected by the current-ref feature indices.
            ref_fut_x: [..., ref_fut_seq_len, sum(slice dims)]; each
                ref_fut_slices span is reshaped to (ref_fut_seq_len, dim) and
                the slices are concatenated on the feature axis.

        Raises:
            ValueError: If ``x`` is not 2-D/3-D or its last dim is wrong.
        """
        if x.ndim not in (2, 3):
            raise ValueError(
                f"Expected full obs tensor with ndim 2 or 3, got {x.ndim}."
            )
        if int(x.shape[-1]) != self.full_obs_input_dim:
            raise ValueError(
                "Full obs dim mismatch for reference router V2: expected "
                f"{self.full_obs_input_dim}, got {int(x.shape[-1])}."
            )
        state_idx = self._state_feature_indices.to(x.device)
        ref_cur_idx = self._ref_cur_feature_indices.to(x.device)
        state_x = torch.index_select(x, dim=x.ndim - 1, index=state_idx)
        ref_cur_x = torch.index_select(x, dim=x.ndim - 1, index=ref_cur_idx)
        fut_parts: list[torch.Tensor] = []
        for start, end, dim in self.ref_fut_slices:
            chunk = x[..., start:end]
            if x.ndim == 2:
                fut_parts.append(
                    chunk.reshape(int(x.shape[0]), self.ref_fut_seq_len, dim)
                )
            else:
                fut_parts.append(
                    chunk.reshape(
                        int(x.shape[0]),
                        int(x.shape[1]),
                        self.ref_fut_seq_len,
                        dim,
                    )
                )
        ref_fut_x = torch.cat(fut_parts, dim=-1)
        return state_x, ref_cur_x, ref_fut_x

    def _encode_future_tokens(self, ref_fut_x: torch.Tensor) -> torch.Tensor:
        """Embed future reference frames and run the Conv1d stack over them.

        Accepts [B, N_fut, F] or [B, T, N_fut, F] and returns
        [B, N_fut', d_model] or [B, T, N_fut', d_model], where N_fut' is the
        sequence length after the convolutions (Conv1d runs on the
        channel-first transpose).

        Raises:
            ValueError: If ``ref_fut_x`` is not 3-D or 4-D.
        """
        if ref_fut_x.ndim == 3:
            fut = self.ref_frame_embed(ref_fut_x)
            return self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)
        if ref_fut_x.ndim == 4:
            # Fold time into the batch for the Conv1d, then unfold.
            batch, time, seq_len, dim = ref_fut_x.shape
            fut = self.ref_frame_embed(
                ref_fut_x.reshape(batch * time, seq_len, dim)
            )
            fut = self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)
            return fut.reshape(batch, time, fut.shape[1], self.d_model)
        raise ValueError(
            f"Expected ref_fut_x with ndim 3 or 4, got {ref_fut_x.ndim}."
)

    def _pool_router_context(
        self, shared_ref_tokens: torch.Tensor
    ) -> torch.Tensor:
        """Pool shared reference tokens into the router context.

        The learned ``router_query`` is cast to the tokens' device/dtype,
        broadcast over the leading dims, and used by ``router_ref_pool`` to
        attend over the tokens. Accepts [B, N, d_model] or [B, T, N, d_model].

        Raises:
            ValueError: If ``shared_ref_tokens`` is not 3-D or 4-D.
        """
        with self._router_no_grad_context():
            if shared_ref_tokens.ndim == 3:
                query = self.router_query.to(
                    device=shared_ref_tokens.device,
                    dtype=shared_ref_tokens.dtype,
                )[None, :].expand(int(shared_ref_tokens.shape[0]), -1)
            elif shared_ref_tokens.ndim == 4:
                query = self.router_query.to(
                    device=shared_ref_tokens.device,
                    dtype=shared_ref_tokens.dtype,
                )[None, None, :].expand(
                    int(shared_ref_tokens.shape[0]),
                    int(shared_ref_tokens.shape[1]),
                    -1,
                )
            else:
                raise ValueError(
                    "shared_ref_tokens must have ndim 3 or 4, got "
                    f"{shared_ref_tokens.ndim}."
                )
            return self.router_ref_pool(query, shared_ref_tokens)

    def _apply_actor_ref_film(
        self, state_hidden: torch.Tensor, actor_ref_ctx: torch.Tensor
    ) -> torch.Tensor:
        """Add a bounded, RMS-normalized FiLM residual to the actor hidden state.

        Scale and shift are tanh-squashed and bounded by
        actor_film_scale_max / actor_film_shift_max, then applied to the
        RMS-normed hidden state. The resulting delta is rescaled to unit RMS
        and weighted by the learned per-channel gain before being added to
        ``state_hidden``.
        """
        ctx = self.actor_ref_ctx_norm(actor_ref_ctx)
        scale_raw, shift_raw = self.actor_ref_film(ctx).chunk(2, dim=-1)
        scale = self.actor_film_scale_max * torch.tanh(scale_raw)
        shift = self.actor_film_shift_max * torch.tanh(shift_raw)
        hidden_norm = self.actor_film_hidden_norm(state_hidden)
        delta = scale * hidden_norm + shift
        delta = self._normalize_actor_film_delta(delta)
        gain = self._actor_film_gain().to(
            device=state_hidden.device, dtype=state_hidden.dtype
        )
        # Broadcast the [d_model] gain over all leading dims of delta.
        expand_shape = [1] * (delta.ndim - 1) + [self.d_model]
        return state_hidden + delta * gain.view(*expand_shape)

    def _actor_film_gain(self) -> torch.Tensor:
        """Per-channel FiLM gain in (0, actor_film_gain_max) via a scaled sigmoid."""
        return self.actor_film_gain_max * torch.sigmoid(
            self.actor_film_gain_raw
        )

    def _normalize_actor_film_delta(self, delta: torch.Tensor) -> torch.Tensor:
        """Rescale ``delta`` to unit RMS over its last dim.

        Note: the local ``rms`` actually holds the mean square; ``rsqrt`` of it
        (plus actor_film_delta_rms_eps for stability) gives 1 / RMS.
        """
        rms = delta.pow(2).mean(dim=-1, keepdim=True)
        return delta * torch.rsqrt(rms + self.actor_film_delta_rms_eps)

    def _ensure_internal_cache_device(
        self,
        device,
        *,
        dtype: torch.dtype | None = None,
    ) -> None:
        """Move all KV caches and counters to ``device``; optionally cast caches.

        Device migration is keyed on ``_k_cache`` only; the dtype cast applies
        to the four K/V caches, not to the integer counters. No-op while the
        caches are unallocated.
        """
        if self._k_cache is not None and self._k_cache.device != device:
            self._k_cache = self._k_cache.to(device)
            self._v_cache = self._v_cache.to(device)
            self._ref_hist_k_cache = self._ref_hist_k_cache.to(device)
            self._ref_hist_v_cache = self._ref_hist_v_cache.to(device)
            self._kv_cache_len = self._kv_cache_len.to(device)
            self._kv_cache_write_idx = self._kv_cache_write_idx.to(device)
            self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(device)
        if (
            dtype is not None
            and self._k_cache is not None
            and self._k_cache.dtype != dtype
        ):
            self._k_cache = self._k_cache.to(dtype)
            self._v_cache = self._v_cache.to(dtype)
            self._ref_hist_k_cache = self._ref_hist_k_cache.to(dtype)
            self._ref_hist_v_cache = self._ref_hist_v_cache.to(dtype)

    def reset_kv_cache(self, num_envs: int, device):
        """Allocate zeroed trunk and reference-history KV caches for ``num_envs``.

        Caches are float16 on CUDA devices and float32 otherwise; the length,
        write-index and absolute-position counters are int64 zeros. Also
        initializes the last-MoE router-shift tracking state.
        """
        cache_dtype = (
            torch.float16
            if torch.device(device).type == "cuda"
            else torch.float32
        )
        self._k_cache = torch.zeros(
            num_envs,
            self.n_layers,
            self.max_ctx_len,
            self.n_kv_heads,
            self.head_dim,
            device=device,
            dtype=cache_dtype,
        )
        self._v_cache = torch.zeros_like(self._k_cache)
        self._ref_hist_k_cache = torch.zeros(
            num_envs,
            self.ref_hist_n_layers,
            self.max_ctx_len,
            self.n_kv_heads,
            self.head_dim,
            device=device,
            dtype=cache_dtype,
        )
        self._ref_hist_v_cache = torch.zeros_like(self._ref_hist_k_cache)
        self._kv_cache_len = torch.zeros(
            num_envs, dtype=torch.long, device=device
        )
        self._kv_cache_write_idx = torch.zeros(
            num_envs, dtype=torch.long, device=device
        )
        self._kv_cache_abs_pos = torch.zeros(
            num_envs, dtype=torch.long, device=device
        )
        self._init_last_moe_router_shift_state(num_envs, device)

    def clear_env_cache(self, env_ids: torch.Tensor | None):
        """Zero cache state for all envs (``env_ids is None``) or a subset.

        The full reset also zeroes the _last_moe_router_js_* and
        _last_moe_router_top1_switch_* accumulators; the per-env reset only
        clears the previous-router-distribution state of those envs.
        No-op while the caches are unallocated.
        """
        if self._k_cache is None:
            return
        if env_ids is None:
            self._k_cache.zero_()
            self._v_cache.zero_()
            self._ref_hist_k_cache.zero_()
            self._ref_hist_v_cache.zero_()
            self._kv_cache_len.zero_()
            self._kv_cache_write_idx.zero_()
            self._kv_cache_abs_pos.zero_()
            if self._prev_last_moe_router_p is not None:
                self._prev_last_moe_router_p.zero_()
            if self._prev_last_moe_router_valid is not None:
                self._prev_last_moe_router_valid.zero_()
            if self._last_moe_router_js_sum is not None:
                self._last_moe_router_js_sum.zero_()
            if self._last_moe_router_js_count is not None:
                self._last_moe_router_js_count.zero_()
            if self._last_moe_router_top1_switch_sum is not None:
                self._last_moe_router_top1_switch_sum.zero_()
            if self._last_moe_router_top1_switch_count is not None:
                self._last_moe_router_top1_switch_count.zero_()
            return
        self._k_cache[env_ids] = 0.0
        self._v_cache[env_ids] = 0.0
        self._ref_hist_k_cache[env_ids] = 0.0
        self._ref_hist_v_cache[env_ids] = 0.0
        self._kv_cache_len[env_ids] = 0
        self._kv_cache_write_idx[env_ids] = 0
        self._kv_cache_abs_pos[env_ids] = 0
        if self._prev_last_moe_router_valid is not None:
            self._prev_last_moe_router_valid[env_ids] = False
        if self._prev_last_moe_router_p is not None:
            self._prev_last_moe_router_p[env_ids] = 0.0

    def predict_aux_from_pre_moe(
        self,
        pre_moe_hidden: torch.Tensor,
        *,
        ref_aux_hidden: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Base aux predictions, with ref keybody positions from ``ref_aux_hidden``.

        Overrides "ref_keybody_rel_pos" by decoding it from the shared-ref
        auxiliary hidden state instead of ``pre_moe_hidden``.

        Raises:
            ValueError: If the ref keybody head is enabled but
                ``ref_aux_hidden`` is not provided.
        """
        aux_outputs = super().predict_aux_from_pre_moe(
            pre_moe_hidden, ref_aux_hidden=ref_aux_hidden
        )
        if self.aux_ref_keybody_pos_head is not None:
            if ref_aux_hidden is None:
                raise ValueError(
                    "Missing shared-ref auxiliary hidden state for "
                    "ref_keybody_rel_pos prediction."
) aux_outputs["ref_keybody_rel_pos"] = self.aux_ref_keybody_pos_head( ref_aux_hidden ).reshape( ref_aux_hidden.shape[0], ref_aux_hidden.shape[1], self.aux_keybody_pos_dim, 3, ) return aux_outputs def sequence_mu( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, return_hidden: bool = False, return_pre_moe_hidden: bool = False, return_ref_aux_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: _, time, _ = x.shape state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) state_h = self.obs_embed(state_x) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) start_idx = attn_mask.to(torch.int64).argmax(dim=-1) t_idx = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(int(x.shape[0]), time) pos = t_idx - start_idx else: tgt_mask = None pos = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(int(x.shape[0]), time) shared_ref_tokens = self._build_shared_ref_tokens( ref_cur_x=ref_cur_x, ref_fut_x=ref_fut_x, pos=pos, tgt_mask=tgt_mask, ) actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens) router_h = self._pool_router_context(shared_ref_tokens) cos, sin = self.get_cos_sin(state_h, pos) if return_hidden and return_pre_moe_hidden: raise ValueError( "return_hidden and return_pre_moe_hidden cannot both be True." 
) block0_h = self._forward_layers_range( state_h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, start_layer=0, end_layer=1, ) h = self._apply_actor_ref_film(block0_h, actor_ref_ctx) forward_out = self._forward_layers_range( h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, start_layer=1, end_layer=len(self.layers), return_router_features=return_router_features, return_router_temporal_features=return_router_temporal_features, ) extras: list[torch.Tensor] = [] if isinstance(forward_out, tuple): h = forward_out[0] extras = list(forward_out[1:]) else: h = forward_out h = self.norm_f(h) mu = self.action_mu_head(h) outputs: list[torch.Tensor] = [mu] if return_pre_moe_hidden: outputs.append(block0_h) if return_ref_aux_hidden: outputs.append(router_h) if return_router_features: outputs.append(extras.pop(0)) if return_router_temporal_features: outputs.append(extras.pop(0)) if len(outputs) > 1: return tuple(outputs) if return_hidden: return mu, h return mu def sequence_hidden( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: _, time, _ = x.shape state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) state_h = self.obs_embed(state_x) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) start_idx = attn_mask.to(torch.int64).argmax(dim=-1) t_idx = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(int(x.shape[0]), time) pos = t_idx - start_idx else: tgt_mask = None pos = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(int(x.shape[0]), time) shared_ref_tokens = self._build_shared_ref_tokens( ref_cur_x=ref_cur_x, ref_fut_x=ref_fut_x, pos=pos, tgt_mask=tgt_mask, ) actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens) router_h = self._pool_router_context(shared_ref_tokens) cos, sin = self.get_cos_sin(state_h, pos) h = self._forward_layers_range( state_h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, start_layer=0, end_layer=1, ) h = 
self._apply_actor_ref_film(h, actor_ref_ctx) h = self._forward_layers_range( h, cos=cos, sin=sin, mask=tgt_mask, router_h=router_h, start_layer=1, end_layer=len(self.layers), ) h = self.norm_f(h) return h def forward( self, input: torch.Tensor, past_key_values: torch.Tensor | None = None, current_pos: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if past_key_values is not None: return self._forward_inference_onnx( input, past_key_values, current_pos ) if input.ndim != 2: raise ValueError(f"Expected [B, D], got {input.shape}") mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None) return mu_seq[:, 0, :] def single_step_mu(self, x: torch.Tensor) -> torch.Tensor: if x.ndim != 2: raise ValueError(f"Expected [B, D], got {x.shape}") state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) batch = int(state_x.shape[0]) if self._k_cache is None: mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None) return mu_seq[:, 0, :] state_h = self.obs_embed(state_x) self._ensure_internal_cache_device(x.device, dtype=state_h.dtype) cache_len = self._kv_cache_len insert_pos = self._kv_cache_write_idx max_len = int(self.max_ctx_len) new_len = torch.clamp(cache_len + 1, max=max_len) self._kv_cache_len = new_len self._kv_cache_write_idx = (insert_pos + 1) % max_len pos = self._kv_cache_abs_pos self._kv_cache_abs_pos = pos + 1 pos_ids = pos.unsqueeze(1) shared_ref_tokens, _, _ = self._build_shared_ref_tokens_single_step( ref_cur_x=ref_cur_x, ref_fut_x=ref_fut_x, pos_ids=pos_ids, k_cache=self._ref_hist_k_cache[:, 0], v_cache=self._ref_hist_v_cache[:, 0], new_len=new_len, insert_pos=insert_pos, ) actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[ :, None, : ] router_h = self._pool_router_context(shared_ref_tokens)[:, None, :] cos, sin = self.get_cos_sin(state_h[:, None, :], pos_ids) h = state_h[:, None, :] for layer_idx, layer in enumerate(self.layers[:1]): x_norm = layer.norm1(h) k_cache_l = self._k_cache[:, layer_idx] 
v_cache_l = self._v_cache[:, layer_idx] attn_out, _, _ = layer.attn.forward_single_token( x_norm, cos, sin, k_cache_l, v_cache_l, new_len, insert_pos, ) h = h + attn_out h2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): ffn = layer.compute_moe_ffn(h2, router_x=router_h) if ( layer_idx == self._last_moe_layer_idx and layer.collect_routing_stats and layer.last_router_distribution is not None ): self._accumulate_last_moe_router_shift( layer.last_router_distribution ) else: ffn = layer.mlp_dropout(layer.mlp(h2)) h = h + ffn h = self._apply_actor_ref_film(h, actor_ref_ctx) for layer_idx, layer in enumerate(self.layers[1:], start=1): x_norm = layer.norm1(h) k_cache_l = self._k_cache[:, layer_idx] v_cache_l = self._v_cache[:, layer_idx] attn_out, _, _ = layer.attn.forward_single_token( x_norm, cos, sin, k_cache_l, v_cache_l, new_len, insert_pos, ) h = h + attn_out h2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): ffn = layer.compute_moe_ffn(h2, router_x=router_h) if ( layer_idx == self._last_moe_layer_idx and layer.collect_routing_stats and layer.last_router_distribution is not None ): self._accumulate_last_moe_router_shift( layer.last_router_distribution ) else: ffn = layer.mlp_dropout(layer.mlp(h2)) h = h + ffn h = self.norm_f(h) return self.action_mu_head(h[:, 0, :]).reshape(batch, -1) def _forward_inference_onnx( self, x: torch.Tensor, past_key_values: torch.Tensor, current_pos: torch.Tensor, ) -> tuple[torch.Tensor, ...]: state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) state_h = self.obs_embed(state_x) batch = state_h.shape[0] max_len = past_key_values.shape[3] if current_pos.ndim == 0: current_pos = current_pos.view(1).expand(batch) insert_pos = current_pos % max_len new_len = torch.clamp(current_pos + 1, max=max_len) position_ids = current_pos.unsqueeze(1) ref_layer_past = past_key_values[0] shared_ref_tokens, ref_k_cache, ref_v_cache = ( self._build_shared_ref_tokens_single_step( ref_cur_x=ref_cur_x, ref_fut_x=ref_fut_x, 
pos_ids=position_ids, k_cache=ref_layer_past[0], v_cache=ref_layer_past[1], new_len=new_len, insert_pos=insert_pos, ) ) actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[ :, None, : ] router_h = self._pool_router_context(shared_ref_tokens)[:, None, :] cos, sin = self.get_cos_sin(state_h[:, None, :], position_ids) present_key_values_list = [ torch.stack([ref_k_cache, ref_v_cache], dim=0) ] routing_debug_outputs: list[torch.Tensor] = [] export_routing_debug = torch.onnx.is_in_onnx_export() h = state_h[:, None, :] for i, layer in enumerate(self.layers[:1]): layer_past = past_key_values[self.ref_hist_n_layers + i] k_cache = layer_past[0] v_cache = layer_past[1] h_norm = layer.norm1(h) attn_out, new_k_cache, new_v_cache = ( layer.attn.forward_single_token( x=h_norm, cos=cos, sin=sin, k_cache=k_cache, v_cache=v_cache, new_len=new_len, insert_pos=insert_pos, ) ) h = h + attn_out h_norm2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): if export_routing_debug: ffn_out, topk_idx, router_logits = layer.compute_moe_ffn( h_norm2, router_x=router_h, return_routing_debug=True, ) routing_debug_outputs.extend([topk_idx, router_logits]) else: ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h) else: ffn_out = layer.mlp_dropout(layer.mlp(h_norm2)) h = h + ffn_out current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0) present_key_values_list.append(current_layer_kv) h = self._apply_actor_ref_film(h, actor_ref_ctx) for i, layer in enumerate(self.layers[1:], start=1): layer_past = past_key_values[self.ref_hist_n_layers + i] k_cache = layer_past[0] v_cache = layer_past[1] h_norm = layer.norm1(h) attn_out, new_k_cache, new_v_cache = ( layer.attn.forward_single_token( x=h_norm, cos=cos, sin=sin, k_cache=k_cache, v_cache=v_cache, new_len=new_len, insert_pos=insert_pos, ) ) h = h + attn_out h_norm2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): if export_routing_debug: ffn_out, topk_idx, router_logits = layer.compute_moe_ffn( h_norm2, 
router_x=router_h, return_routing_debug=True, ) routing_debug_outputs.extend([topk_idx, router_logits]) else: ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h) else: ffn_out = layer.mlp_dropout(layer.mlp(h_norm2)) h = h + ffn_out current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0) present_key_values_list.append(current_layer_kv) h = self.norm_f(h) action = self.action_mu_head(h[:, 0, :]) present_key_values = torch.stack(present_key_values_list, dim=0) if export_routing_debug and routing_debug_outputs: return (action, present_key_values, *routing_debug_outputs) return action, present_key_values class ReferenceRoutedGroupedMoETransformerPolicyV3( ReferenceRoutedGroupedMoETransformerPolicyV2 ): supports_explicit_ref_aux_hidden = True def __init__( self, input_dim: int, output_dim: int, module_config_dict: dict, ): module_config = dict(module_config_dict) if bool(module_config.get("use_future_cross_attn", False)): raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 does not " "support use_future_cross_attn=True." ) state_obs_input_dim = module_config.get("state_obs_input_dim", None) ref_cur_token_dim = module_config.get("ref_cur_token_dim", None) ref_fut_token_dim = module_config.get("ref_fut_token_dim", None) ref_fut_seq_len = module_config.get("ref_fut_seq_len", None) state_feature_indices = module_config.get( "state_feature_indices", None ) ref_cur_feature_indices = module_config.get( "ref_cur_feature_indices", None ) ref_fut_slices = module_config.get("ref_fut_slices", None) if state_obs_input_dim is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 requires " "state_obs_input_dim." ) if ref_cur_token_dim is None or ref_fut_token_dim is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 requires " "ref_cur_token_dim and ref_fut_token_dim." ) if ref_fut_seq_len is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 requires " "ref_fut_seq_len." 
) if state_feature_indices is None or ref_cur_feature_indices is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 requires " "state_feature_indices and ref_cur_feature_indices." ) if ref_fut_slices is None: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 requires " "ref_fut_slices." ) self.full_obs_input_dim = int(input_dim) self.state_obs_input_dim = int(state_obs_input_dim) self.ref_cur_token_dim = int(ref_cur_token_dim) self.ref_fut_token_dim = int(ref_fut_token_dim) self.ref_fut_seq_len = int(ref_fut_seq_len) self.state_feature_indices = tuple( int(idx) for idx in state_feature_indices ) self.ref_cur_feature_indices = tuple( int(idx) for idx in ref_cur_feature_indices ) self.ref_fut_slices = tuple( (int(start), int(end), int(dim)) for start, end, dim in ref_fut_slices ) if self.state_obs_input_dim <= 0: raise ValueError( "state_obs_input_dim must be positive, got " f"{self.state_obs_input_dim}." ) if self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0: raise ValueError( "ref token dims must be positive, got " f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}." ) if self.ref_cur_token_dim != self.ref_fut_token_dim: raise ValueError( "current/future ref token dims must match, got " f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}." ) if self.ref_fut_seq_len <= 0: raise ValueError( f"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}." ) if len(self.state_feature_indices) != self.state_obs_input_dim: raise ValueError( "state_obs_input_dim must match len(state_feature_indices): " f"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}." ) if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim: raise ValueError( "ref_cur_token_dim must match len(ref_cur_feature_indices): " f"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}." 
) fut_flat_dim = 0 for start, end, dim in self.ref_fut_slices: if end <= start or dim <= 0: raise ValueError( f"Invalid ref_fut_slices entry {(start, end, dim)}." ) if (end - start) != self.ref_fut_seq_len * dim: raise ValueError( "Future ref slice span must equal ref_fut_seq_len * dim, got " f"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}." ) fut_flat_dim += end - start expected_full_input_dim = ( self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim ) if self.full_obs_input_dim != expected_full_input_dim: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 expected full " f"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}." ) self.ref_hist_n_layers = int(module_config.get("ref_hist_n_layers", 1)) if self.ref_hist_n_layers != 1: raise ValueError( "ReferenceRoutedGroupedMoETransformerPolicyV3 currently supports " "exactly one ref history attention layer." ) layer_proj_hidden_default = int( module_config.get( "router_layer_proj_hidden_dim", module_config.get("d_model", 256), ) ) self.ref_motion_input_dim = int( self.ref_cur_token_dim + self.ref_fut_seq_len * self.ref_fut_token_dim ) self.router_layer_proj_hidden_dim = int(layer_proj_hidden_default) if self.router_layer_proj_hidden_dim <= 0: raise ValueError( "router_layer_proj_hidden_dim must be positive, got " f"{self.router_layer_proj_hidden_dim}." 
) GroupedMoETransformerPolicy.__init__( self, input_dim=input_dim, output_dim=output_dim, module_config_dict=module_config, ) self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers) self.register_buffer( "_state_feature_indices", torch.tensor(self.state_feature_indices, dtype=torch.long), persistent=False, ) self.register_buffer( "_ref_cur_feature_indices", torch.tensor(self.ref_cur_feature_indices, dtype=torch.long), persistent=False, ) self.ref_frame_embed = nn.Sequential( nn.Linear(self.ref_motion_input_dim, self.obs_embed_mlp_hidden), nn.SiLU(), nn.Linear(self.obs_embed_mlp_hidden, self.d_model), ) self.ref_hist_norm = RMSNorm(self.d_model) self.ref_hist_attn = ModernAttention( d_model=self.d_model, n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, use_qk_norm=self.use_qk_norm, use_gated_attn=self.use_gated_attn, gated_attn_type=self.gated_attn_type, attn_dropout=self.attn_dropout, ) self.ref_hist_out_norm = RMSNorm(self.d_model) self._moe_layer_indices = tuple( i for i, layer in enumerate(self.layers) if isinstance(layer, GroupedMoEBlock) ) self.router_layer_projections = nn.ModuleList( [ nn.Sequential( RMSNorm(self.d_model), nn.Linear(self.d_model, self.router_layer_proj_hidden_dim), nn.SiLU(), nn.Linear(self.router_layer_proj_hidden_dim, self.d_model), ) for _ in self._moe_layer_indices ] ) self._ref_hist_k_cache: torch.Tensor | None = None self._ref_hist_v_cache: torch.Tensor | None = None self._apply_freeze_router_state() def _apply_freeze_router_state(self) -> None: GroupedMoETransformerPolicy._apply_freeze_router_state(self) requires_grad = not self.freeze_router self.ref_frame_embed.requires_grad_(requires_grad) self.ref_hist_norm.requires_grad_(requires_grad) self.ref_hist_attn.requires_grad_(requires_grad) self.ref_hist_out_norm.requires_grad_(requires_grad) self.router_layer_projections.requires_grad_(requires_grad) def _build_router_ref_motion( self, ref_cur_x: torch.Tensor, ref_fut_x: torch.Tensor, ) -> torch.Tensor: if ref_cur_x.ndim not 
in (2, 3): raise ValueError( f"Expected ref_cur_x with ndim 2 or 3, got {ref_cur_x.ndim}." ) if ref_fut_x.ndim != ref_cur_x.ndim + 1: raise ValueError( "Expected ref_fut_x to add one future-seq axis relative to " f"ref_cur_x, got cur={tuple(ref_cur_x.shape)}, " f"fut={tuple(ref_fut_x.shape)}." ) ref_fut_flat = torch.flatten(ref_fut_x, start_dim=-2) return torch.cat([ref_cur_x, ref_fut_flat], dim=-1) def _build_shared_router_summary( self, ref_hist_h: torch.Tensor, ) -> torch.Tensor: with self._router_no_grad_context(): return ref_hist_h def _build_router_h_per_layer( self, shared_router_summary: torch.Tensor, ) -> list[torch.Tensor | None]: with self._router_no_grad_context(): router_h_per_layer: list[torch.Tensor | None] = [ None for _ in self.layers ] for proj, layer_idx in zip( self.router_layer_projections, self._moe_layer_indices ): router_h_per_layer[layer_idx] = proj(shared_router_summary) return router_h_per_layer def _build_ref_hist_hidden( self, ref_motion_x: torch.Tensor, pos: torch.Tensor, tgt_mask: torch.Tensor | None, ) -> torch.Tensor: with self._router_no_grad_context(): ref_motion_h = self.ref_frame_embed(ref_motion_x) ref_hist_attn = self.ref_hist_attn( self.ref_hist_norm(ref_motion_h), *self.get_cos_sin(ref_motion_h, pos), mask=tgt_mask, ) return self.ref_hist_out_norm(ref_motion_h + ref_hist_attn) def _build_ref_hist_hidden_single_step( self, ref_motion_x: torch.Tensor, pos_ids: torch.Tensor, *, k_cache: torch.Tensor, v_cache: torch.Tensor, new_len: torch.Tensor, insert_pos: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: with self._router_no_grad_context(): ref_motion_h = self.ref_frame_embed(ref_motion_x)[:, None, :] ref_cos, ref_sin = self.get_cos_sin(ref_motion_h, pos_ids) ref_hist_attn, ref_k_cache, ref_v_cache = ( self.ref_hist_attn.forward_single_token( self.ref_hist_norm(ref_motion_h), ref_cos, ref_sin, k_cache, v_cache, new_len, insert_pos, ) ) ref_hist_h = self.ref_hist_out_norm(ref_motion_h + ref_hist_attn) return 
ref_hist_h, ref_k_cache, ref_v_cache def predict_aux_from_pre_moe( self, pre_moe_hidden: torch.Tensor, *, ref_aux_hidden: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: return GroupedMoETransformerPolicy.predict_aux_from_pre_moe( self, pre_moe_hidden, ref_aux_hidden=ref_aux_hidden, ) def sequence_mu( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, return_hidden: bool = False, return_pre_moe_hidden: bool = False, return_ref_aux_hidden: bool = False, return_router_features: bool = False, return_router_temporal_features: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, ...]: batch, time, _ = x.shape h = self.obs_embed(x) _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) start_idx = attn_mask.to(torch.int64).argmax(dim=-1) t_idx = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(batch, time) pos = t_idx - start_idx else: tgt_mask = None pos = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(batch, time) ref_hist_h = self._build_ref_hist_hidden( ref_motion_x=ref_motion_x, pos=pos, tgt_mask=tgt_mask, ) shared_router_summary = self._build_shared_router_summary(ref_hist_h) router_h_per_layer = self._build_router_h_per_layer( shared_router_summary ) cos, sin = self.get_cos_sin(h, pos) if return_hidden and return_pre_moe_hidden: raise ValueError( "return_hidden and return_pre_moe_hidden cannot both be True." 
) forward_out = self._forward_layers( h, cos=cos, sin=sin, mask=tgt_mask, router_h_per_layer=router_h_per_layer, return_pre_moe_hidden=return_pre_moe_hidden, return_router_features=return_router_features, return_router_temporal_features=return_router_temporal_features, ) extras: list[torch.Tensor] = [] if isinstance(forward_out, tuple): h = forward_out[0] extras = list(forward_out[1:]) else: h = forward_out h = self.norm_f(h) mu = self.action_mu_head(h) outputs: list[torch.Tensor] = [mu] if return_pre_moe_hidden: outputs.append(extras.pop(0)) if return_ref_aux_hidden: outputs.append(shared_router_summary) if return_router_features: outputs.append(extras.pop(0)) if return_router_temporal_features: outputs.append(extras.pop(0)) if len(outputs) > 1: return tuple(outputs) if return_hidden: return mu, h return mu def sequence_hidden( self, x: torch.Tensor, *, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: batch, time, _ = x.shape h = self.obs_embed(x) _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x) if attn_mask is not None: tgt_mask = attn_mask.unsqueeze(1) start_idx = attn_mask.to(torch.int64).argmax(dim=-1) t_idx = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(batch, time) pos = t_idx - start_idx else: tgt_mask = None pos = torch.arange(time, device=x.device, dtype=torch.long)[ None, : ].expand(batch, time) ref_hist_h = self._build_ref_hist_hidden( ref_motion_x=ref_motion_x, pos=pos, tgt_mask=tgt_mask, ) shared_router_summary = self._build_shared_router_summary(ref_hist_h) router_h_per_layer = self._build_router_h_per_layer( shared_router_summary ) cos, sin = self.get_cos_sin(h, pos) h = self._forward_layers( h, cos=cos, sin=sin, mask=tgt_mask, router_h_per_layer=router_h_per_layer, ) h = self.norm_f(h) return h def forward( self, input: torch.Tensor, past_key_values: torch.Tensor | None = None, current_pos: torch.Tensor | None = None, ) -> torch.Tensor | 
tuple[torch.Tensor, torch.Tensor]: if past_key_values is not None: return self._forward_inference_onnx( input, past_key_values, current_pos ) if input.ndim != 2: raise ValueError(f"Expected [B, D], got {input.shape}") mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None) return mu_seq[:, 0, :] def single_step_mu(self, x: torch.Tensor) -> torch.Tensor: if x.ndim != 2: raise ValueError(f"Expected [B, D], got {x.shape}") _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x) batch = int(x.shape[0]) if self._k_cache is None: mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None) return mu_seq[:, 0, :] h = self.obs_embed(x)[:, None, :] self._ensure_internal_cache_device(x.device, dtype=h.dtype) cache_len = self._kv_cache_len insert_pos = self._kv_cache_write_idx max_len = int(self.max_ctx_len) new_len = torch.clamp(cache_len + 1, max=max_len) self._kv_cache_len = new_len self._kv_cache_write_idx = (insert_pos + 1) % max_len pos = self._kv_cache_abs_pos self._kv_cache_abs_pos = pos + 1 pos_ids = pos.unsqueeze(1) ref_hist_h, _, _ = self._build_ref_hist_hidden_single_step( ref_motion_x=ref_motion_x, pos_ids=pos_ids, k_cache=self._ref_hist_k_cache[:, 0], v_cache=self._ref_hist_v_cache[:, 0], new_len=new_len, insert_pos=insert_pos, ) shared_router_summary = self._build_shared_router_summary(ref_hist_h) router_h_per_layer = self._build_router_h_per_layer( shared_router_summary ) cos, sin = self.get_cos_sin(h, pos_ids) for layer_idx, layer in enumerate(self.layers): x_norm = layer.norm1(h) k_cache_l = self._k_cache[:, layer_idx] v_cache_l = self._v_cache[:, layer_idx] attn_out, _, _ = layer.attn.forward_single_token( x_norm, cos, sin, k_cache_l, v_cache_l, new_len, insert_pos, ) h = h + attn_out h2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): ffn = layer.compute_moe_ffn( h2, router_x=router_h_per_layer[layer_idx] ) if ( layer_idx == self._last_moe_layer_idx and layer.collect_routing_stats 
and layer.last_router_distribution is not None ): self._accumulate_last_moe_router_shift( layer.last_router_distribution ) else: ffn = layer.mlp_dropout(layer.mlp(h2)) h = h + ffn h = self.norm_f(h) return self.action_mu_head(h[:, 0, :]).reshape(batch, -1) def _forward_inference_onnx( self, x: torch.Tensor, past_key_values: torch.Tensor, current_pos: torch.Tensor, ) -> tuple[torch.Tensor, ...]: _, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x) ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x) h = self.obs_embed(x)[:, None, :] batch = h.shape[0] max_len = past_key_values.shape[3] if current_pos.ndim == 0: current_pos = current_pos.view(1).expand(batch) insert_pos = current_pos % max_len new_len = torch.clamp(current_pos + 1, max=max_len) position_ids = current_pos.unsqueeze(1) ref_layer_past = past_key_values[0] ref_hist_h, ref_k_cache, ref_v_cache = ( self._build_ref_hist_hidden_single_step( ref_motion_x=ref_motion_x, pos_ids=position_ids, k_cache=ref_layer_past[0], v_cache=ref_layer_past[1], new_len=new_len, insert_pos=insert_pos, ) ) shared_router_summary = self._build_shared_router_summary(ref_hist_h) router_h_per_layer = self._build_router_h_per_layer( shared_router_summary ) cos, sin = self.get_cos_sin(h, position_ids) present_key_values_list = [ torch.stack([ref_k_cache, ref_v_cache], dim=0) ] routing_debug_outputs: list[torch.Tensor] = [] export_routing_debug = torch.onnx.is_in_onnx_export() for i, layer in enumerate(self.layers): layer_past = past_key_values[self.ref_hist_n_layers + i] k_cache = layer_past[0] v_cache = layer_past[1] h_norm = layer.norm1(h) attn_out, new_k_cache, new_v_cache = ( layer.attn.forward_single_token( x=h_norm, cos=cos, sin=sin, k_cache=k_cache, v_cache=v_cache, new_len=new_len, insert_pos=insert_pos, ) ) h = h + attn_out h_norm2 = layer.norm2(h) if isinstance(layer, GroupedMoEBlock): if export_routing_debug: ffn_out, topk_idx, router_logits = layer.compute_moe_ffn( h_norm2, router_x=router_h_per_layer[i], 
return_routing_debug=True, ) routing_debug_outputs.extend([topk_idx, router_logits]) else: ffn_out = layer.compute_moe_ffn( h_norm2, router_x=router_h_per_layer[i] ) else: ffn_out = layer.mlp_dropout(layer.mlp(h_norm2)) h = h + ffn_out current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0) present_key_values_list.append(current_layer_kv) h = self.norm_f(h) action = self.action_mu_head(h[:, 0, :]) present_key_values = torch.stack(present_key_values_list, dim=0) if export_routing_debug and routing_debug_outputs: return (action, present_key_values, *routing_debug_outputs) return action, present_key_values class GroupedMoEBlock(nn.Module): def __init__( self, d_model: int, n_heads: int, num_fine_experts: int, num_shared_experts: int, top_k: int, n_kv_heads: int | None = None, ff_mult: float = 2, use_qk_norm: bool = True, use_gated_attn: bool = True, gated_attn_type: str = "headwise", attn_dropout: float = 0.0, mlp_dropout: float = 0.0, use_dynamic_bias: bool = False, bias_update_rate: float = 0.001, routing_score_fn: str = "softmax", freeze_router: bool = False, routing_scale: float = 1.0, expert_bias_clip: float = 0.0, dead_expert_margin_to_topk_enabled: bool = False, selected_expert_margin_to_unselected_enabled: bool = False, selected_expert_margin_to_unselected_target: float = 0.0, use_cross_attn: bool = False, ): super().__init__() self.d_model = d_model self.n_heads = n_heads self.num_fine_experts = num_fine_experts self.num_shared_experts = num_shared_experts self.top_k = top_k self.use_dynamic_bias = use_dynamic_bias self.bias_update_rate = bias_update_rate self.routing_score_fn = str(routing_score_fn).lower() self.freeze_router = bool(freeze_router) self.routing_scale = float(routing_scale) self.expert_bias_clip = float(expert_bias_clip) self.dead_expert_margin_to_topk_enabled = bool( dead_expert_margin_to_topk_enabled ) self.selected_expert_margin_to_unselected_enabled = bool( selected_expert_margin_to_unselected_enabled ) 
self.selected_expert_margin_to_unselected_target = float( selected_expert_margin_to_unselected_target ) if self.routing_score_fn not in ("softmax", "sigmoid"): raise ValueError( f"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}" ) if self.routing_scale <= 0.0: raise ValueError( f"routing_scale must be > 0, got {self.routing_scale}" ) if self.expert_bias_clip < 0.0: raise ValueError( f"expert_bias_clip must be >= 0, got {self.expert_bias_clip}" ) if self.selected_expert_margin_to_unselected_target < 0.0: raise ValueError( "selected_expert_margin_to_unselected_target must be >= 0, " f"got {self.selected_expert_margin_to_unselected_target}" ) self.register_buffer("expert_bias", torch.zeros(num_fine_experts)) self.register_buffer( "routing_counts_accum", torch.zeros(num_fine_experts, dtype=torch.long), persistent=False, ) self.register_buffer( "last_routed_expert_usage", torch.zeros(num_fine_experts, dtype=torch.float32), persistent=False, ) self.register_buffer( "last_routed_active_expert_count", torch.tensor(0.0), persistent=False, ) self.register_buffer( "last_routed_max_expert_frac", torch.tensor(0.0), persistent=False, ) self.register_buffer( "last_active_expert_ratio", torch.tensor(0.0), persistent=False ) self.register_buffer( "last_max_expert_frac", torch.tensor(0.0), persistent=False ) self.register_buffer( "last_expert_count_cv", torch.tensor(0.0), persistent=False ) self.register_buffer( "last_min_expert_frac", torch.tensor(0.0), persistent=False ) self.register_buffer( "last_dead_expert_ratio", torch.tensor(0.0), persistent=False ) self.register_buffer( "last_dense_expert_usage", torch.zeros(num_fine_experts, dtype=torch.float32), persistent=False, ) self.register_buffer( "last_dead_expert_margin_to_topk_loss_value", torch.tensor(0.0), persistent=False, ) self.register_buffer( "last_dead_expert_margin_to_topk_target", torch.tensor(0.0), persistent=False, ) self.register_buffer( 
"last_selected_expert_margin_to_unselected", torch.tensor(0.0), persistent=False, ) self.register_buffer( "last_selected_expert_margin_to_unselected_loss_value", torch.tensor(0.0), persistent=False, ) self.collect_routing_stats = False self.collect_router_distribution = False self.capture_router_distribution = False self.capture_router_logits = False self.last_router_distribution: torch.Tensor | None = None self.last_router_logits: torch.Tensor | None = None self.last_dead_expert_margin_to_topk_loss: torch.Tensor | None = None self.last_selected_expert_margin_to_unselected_loss: ( torch.Tensor | None ) = None self.use_cross_attn = bool(use_cross_attn) self.norm1 = RMSNorm(d_model) self.attn = ModernAttention( d_model=d_model, n_heads=n_heads, n_kv_heads=n_kv_heads, use_qk_norm=use_qk_norm, use_gated_attn=use_gated_attn, gated_attn_type=gated_attn_type, attn_dropout=attn_dropout, ) if self.use_cross_attn: self.norm_cross = RMSNorm(d_model) self.cross_attn = ModernCrossAttention( d_model=d_model, n_heads=n_heads, n_kv_heads=n_kv_heads, use_qk_norm=use_qk_norm, use_gated_attn=use_gated_attn, gated_attn_type=gated_attn_type, attn_dropout=attn_dropout, ) else: self.norm_cross = None self.cross_attn = None self.norm2 = RMSNorm(d_model) self.intermediate_dim = int(d_model * ff_mult) self.router = nn.Linear(d_model, num_fine_experts, bias=False) self._apply_freeze_router_state() # Gate + Up (Combined) self.gate_up_proj = nn.Parameter( torch.empty( num_fine_experts, self.d_model, 2 * self.intermediate_dim ) ) # Down self.down_proj = nn.Parameter( torch.empty(num_fine_experts, self.intermediate_dim, self.d_model) ) self.shared_experts = DeepseekV3MLP( hidden_size=d_model, intermediate_size=int(d_model * ff_mult * num_shared_experts), ) self.mlp_dropout = ( nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity() ) self.reset_parameters() def reset_parameters(self) -> None: nn.init.xavier_uniform_(self.gate_up_proj) nn.init.xavier_uniform_(self.down_proj) def 
_load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ): gate_up_key = prefix + "gate_up_proj" current_gate_up_shape = tuple(self.gate_up_proj.shape) legacy_gate_up_shape = ( self.num_fine_experts, 2 * self.intermediate_dim, self.d_model, ) down_key = prefix + "down_proj" current_down_shape = tuple(self.down_proj.shape) legacy_down_shape = ( self.num_fine_experts, self.d_model, self.intermediate_dim, ) is_legacy_layout = None if gate_up_key in state_dict: gate_up_shape = tuple(state_dict[gate_up_key].shape) gate_up_is_current = gate_up_shape == current_gate_up_shape gate_up_is_legacy = gate_up_shape == legacy_gate_up_shape if gate_up_is_current and not gate_up_is_legacy: is_legacy_layout = False elif gate_up_is_legacy and not gate_up_is_current: is_legacy_layout = True if is_legacy_layout is None and down_key in state_dict: down_shape = tuple(state_dict[down_key].shape) down_is_current = down_shape == current_down_shape down_is_legacy = down_shape == legacy_down_shape if down_is_current and not down_is_legacy: is_legacy_layout = False elif down_is_legacy and not down_is_current: is_legacy_layout = True if gate_up_key in state_dict: gate_up_w = state_dict[gate_up_key] gate_up_shape = tuple(gate_up_w.shape) gate_up_is_legacy_only = ( gate_up_shape == legacy_gate_up_shape and gate_up_shape != current_gate_up_shape ) gate_up_is_ambiguous = ( gate_up_shape == legacy_gate_up_shape and gate_up_shape == current_gate_up_shape ) if gate_up_is_legacy_only or ( gate_up_is_ambiguous and is_legacy_layout ): state_dict[gate_up_key] = gate_up_w.transpose( -2, -1 ).contiguous() if down_key in state_dict: down_w = state_dict[down_key] down_shape = tuple(down_w.shape) down_is_legacy_only = ( down_shape == legacy_down_shape and down_shape != current_down_shape ) down_is_ambiguous = ( down_shape == legacy_down_shape and down_shape == current_down_shape ) if down_is_legacy_only or (down_is_ambiguous and is_legacy_layout): 
state_dict[down_key] = down_w.transpose(-2, -1).contiguous() super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, ) self._apply_freeze_router_state() def _apply_freeze_router_state(self) -> None: self.router.requires_grad_(not self.freeze_router) def reset_routing_stats(self) -> None: self.routing_counts_accum.zero_() self.last_router_distribution = None self.last_router_logits = None self.last_routed_expert_usage.zero_() self.last_routed_active_expert_count.zero_() self.last_routed_max_expert_frac.zero_() self.last_dense_expert_usage.zero_() self.last_dead_expert_margin_to_topk_loss_value.zero_() self.last_dead_expert_margin_to_topk_target.zero_() self.last_dead_expert_margin_to_topk_loss = None self.last_selected_expert_margin_to_unselected.zero_() self.last_selected_expert_margin_to_unselected_loss_value.zero_() self.last_selected_expert_margin_to_unselected_loss = None def accumulate_routing_stats(self, topk_idx: torch.Tensor) -> None: with torch.no_grad(): counts = torch.bincount( topk_idx.reshape(-1), minlength=self.num_fine_experts ) self.routing_counts_accum.add_(counts) def _apply_bias_update_from_counts(self, counts: torch.Tensor) -> None: with torch.no_grad(): if dist.is_available() and dist.is_initialized(): dist.all_reduce(counts, op=dist.ReduceOp.SUM) total = counts.sum() if int(total.item()) == 0: self.last_active_expert_ratio.zero_() self.last_max_expert_frac.zero_() self.last_expert_count_cv.zero_() self.last_min_expert_frac.zero_() self.last_dead_expert_ratio.zero_() return if self.use_dynamic_bias: avg = counts.float().mean() error = avg - counts.float() self.expert_bias.add_( self.bias_update_rate * torch.sign(error) ) total = total.clamp_min(1) active_ratio = (counts > 0).to(torch.float32).mean() max_expert_frac = counts.max().to(torch.float32) / total.to( torch.float32 ) min_expert_frac = counts.min().to(torch.float32) / total.to( torch.float32 ) dead_expert_ratio = (counts == 
0).to(torch.float32).mean() counts_f = counts.to(torch.float32) counts_mean = counts_f.mean().clamp_min(1.0e-6) counts_std = counts_f.std(unbiased=False) expert_count_cv = counts_std / counts_mean self.last_active_expert_ratio.copy_(active_ratio) self.last_max_expert_frac.copy_(max_expert_frac) self.last_expert_count_cv.copy_(expert_count_cv) self.last_min_expert_frac.copy_(min_expert_frac) self.last_dead_expert_ratio.copy_(dead_expert_ratio) if self.use_dynamic_bias and self.expert_bias_clip > 0.0: self.expert_bias.clamp_( min=-self.expert_bias_clip, max=self.expert_bias_clip ) def apply_bias_update_from_counts(self) -> None: with torch.no_grad(): counts = self.routing_counts_accum.clone() self.routing_counts_accum.zero_() self._apply_bias_update_from_counts(counts) def _update_routed_expert_stats_and_floor_loss( self, topk_idx: torch.Tensor, dense_distribution: torch.Tensor, choice_scores: torch.Tensor, ) -> torch.Tensor: counts = torch.bincount( topk_idx.reshape(-1), minlength=self.num_fine_experts ).to(torch.float32) total_assignments = max(int(topk_idx.numel()), 1) hard_usage = counts / float(total_assignments) active_count = (counts > 0).to(torch.float32).sum() max_frac = hard_usage.max() if hard_usage.numel() > 0 else counts.sum() with torch.no_grad(): self.last_routed_expert_usage.copy_( hard_usage.to(self.last_routed_expert_usage.dtype) ) self.last_routed_active_expert_count.copy_( active_count.to(self.last_routed_active_expert_count.dtype) ) self.last_routed_max_expert_frac.copy_( max_frac.to(self.last_routed_max_expert_frac.dtype) ) dense_usage = dense_distribution.to(torch.float32).mean(dim=(0, 1)) with torch.no_grad(): self.last_dense_expert_usage.copy_( dense_usage.detach().to(self.last_dense_expert_usage.dtype) ) kth_choice_score = choice_scores.gather(-1, topk_idx)[..., -1:] if self.top_k < self.num_fine_experts: selected_mask = F.one_hot( topk_idx, num_classes=self.num_fine_experts ).any(dim=-2) best_unselected_score = ( choice_scores.masked_fill( 
selected_mask, torch.finfo(choice_scores.dtype).min ) .max(dim=-1, keepdim=True) .values ) selected_margin_gap = kth_choice_score - best_unselected_score selected_margin = selected_margin_gap.mean() else: selected_margin_gap = choice_scores.new_zeros( choice_scores.shape[:2] + (1,) ) selected_margin = choice_scores.new_zeros(()) if self.selected_expert_margin_to_unselected_enabled: selected_margin_loss = torch.relu( self.selected_expert_margin_to_unselected_target - selected_margin_gap ).mean() else: selected_margin_loss = choice_scores.new_zeros(()) with torch.no_grad(): self.last_selected_expert_margin_to_unselected.copy_( selected_margin.detach().to( self.last_selected_expert_margin_to_unselected.dtype ) ) self.last_selected_expert_margin_to_unselected_loss_value.copy_( selected_margin_loss.detach().to( self.last_selected_expert_margin_to_unselected_loss_value.dtype ) ) self.last_selected_expert_margin_to_unselected_loss = ( selected_margin_loss ) if not self.dead_expert_margin_to_topk_enabled: margin_loss = dense_distribution.new_zeros(()) with torch.no_grad(): self.last_dead_expert_margin_to_topk_loss_value.zero_() self.last_dead_expert_margin_to_topk_target.zero_() self.last_dead_expert_margin_to_topk_loss = margin_loss return margin_loss dead_mask = (counts == 0).to(choice_scores.dtype) margin_gap = torch.relu(kth_choice_score - choice_scores) dead_margin_sum = ( margin_gap * dead_mask.view(1, 1, self.num_fine_experts) ).sum() dead_count = dead_mask.sum() num_tokens = choice_scores.new_ones( choice_scores.shape[:2], dtype=choice_scores.dtype ).sum() normalizer = dead_count.clamp_min(1.0) * num_tokens margin_loss = dead_margin_sum / normalizer with torch.no_grad(): self.last_dead_expert_margin_to_topk_loss_value.copy_( margin_loss.detach().to( self.last_dead_expert_margin_to_topk_loss_value.dtype ) ) self.last_dead_expert_margin_to_topk_target.copy_( kth_choice_score.mean() .detach() .to(self.last_dead_expert_margin_to_topk_target.dtype) ) 
self.last_dead_expert_margin_to_topk_loss = margin_loss return margin_loss @torch.compiler.disable def _compute_sparse_experts( self, x: torch.Tensor, topk_idx: torch.Tensor, topk_scores: torch.Tensor, ) -> torch.Tensor: B, T, D = x.size() num_top_k = self.top_k is_exporting = torch.onnx.is_in_onnx_export() if is_exporting: # ONNX/runtime path: compute only selected experts (top_k), # avoiding O(num_experts) per-step overhead at bs=1. return self._compute_with_topk_selection(x, topk_idx, topk_scores) x_flat = x.view(-1, D) expert_ids = topk_idx.view(-1) scores = topk_scores.view(-1) raw_token_indices = ( torch.arange(B * T, device=x.device) .unsqueeze(1) .expand(-1, num_top_k) .reshape(-1) ) sorted_expert_ids, perm = torch.sort(expert_ids) sorted_token_indices = raw_token_indices[perm] x_sorted = x_flat[sorted_token_indices] scores_sorted = scores[perm] # Path B: High-Performance Grouped GEMM output_sorted = self._compute_with_grouped_mm( x_sorted, sorted_expert_ids ) output_sorted = output_sorted * scores_sorted.unsqueeze(-1) inv_perm = torch.argsort(perm) output_flat = output_sorted[inv_perm] output_final = output_flat.view(B * T, num_top_k, D).sum(dim=1) return output_final.view(B, T, D) def _compute_with_grouped_mm( self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor ) -> torch.Tensor: """Based on official implementation logic: - offsets must be Cumsum (End-Indices). - offsets length must be exactly Num_Experts (NOT N+1). - dtype must be int32. 
""" tokens_per_expert = torch.bincount( sorted_expert_ids.long(), minlength=self.num_fine_experts ) counts = tokens_per_expert[: self.num_fine_experts] offsets = torch.cumsum(counts, dim=0, dtype=torch.int32) gate_up_out = _grouped_linear( x_sorted, self.gate_up_proj, offs=offsets ) x1, x2 = gate_up_out.chunk(2, dim=-1) hidden = F.silu(x1) * x2 out = _grouped_linear(hidden, self.down_proj, offs=offsets) return out def _compute_with_topk_selection( self, x: torch.Tensor, topk_idx: torch.Tensor, topk_scores: torch.Tensor, ) -> torch.Tensor: """ONNX-friendly sparse expert compute that scales with top_k, not num_fine_experts. """ B, T, D = x.shape N = B * T K = self.top_k orig_dtype = x.dtype x_tokens = x.reshape(N, D) idx = topk_idx.reshape(N, K) scores = topk_scores.reshape(N, K) x_rep = x_tokens[:, None, :].expand(N, K, D).reshape(N * K, D) idx_flat = idx.reshape(N * K) compute_dtype = self.gate_up_proj.dtype if x_rep.dtype != compute_dtype: x_rep = x_rep.to(compute_dtype) gate_up_w = self.gate_up_proj.index_select(0, idx_flat) gate_up_out = torch.bmm(x_rep.unsqueeze(1), gate_up_w).squeeze(1) x1, x2 = gate_up_out.chunk(2, dim=-1) hidden = F.silu(x1) * x2 down_w = self.down_proj.index_select(0, idx_flat) sparse_flat = torch.bmm(hidden.unsqueeze(1), down_w).squeeze(1) if sparse_flat.dtype != orig_dtype: sparse_flat = sparse_flat.to(orig_dtype) sparse = sparse_flat.view(N, K, D) weighted = sparse * scores.to(sparse.dtype).unsqueeze(-1) out = weighted.sum(dim=1) return out.view(B, T, D) def _compute_with_loop_fallback( self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor ) -> torch.Tensor: """Path A: Loop Fallback (Compatible with F.linear and 3D Weights)""" results = [] for i in range(self.num_fine_experts): mask = sorted_expert_ids == i inp_i = x_sorted[mask] # Gate + Up w_gate_up = self.gate_up_proj[i].transpose(0, 1) gate_up_out = F.linear(inp_i, w_gate_up) x1, x2 = gate_up_out.chunk(2, dim=-1) hidden = F.silu(x1) * x2 # Down w_down = 
self.down_proj[i].transpose(0, 1) out_i = F.linear(hidden, w_down) results.append(out_i) return torch.cat(results, dim=0) def compute_moe_ffn( self, x: torch.Tensor, router_x: torch.Tensor | None = None, *, return_routing_debug: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, D = x.shape should_cache_router_distribution = ( self.collect_routing_stats and self.collect_router_distribution ) or self.capture_router_distribution should_cache_router_logits = self.capture_router_logits # 1. Shared Experts (Dense Path) shared_out = self.shared_experts(x) # 2. Router (Gating) router_input = x if router_x is None else router_x if router_input.shape != x.shape: raise ValueError( "router_x shape must match x shape in compute_moe_ffn: " f"x={tuple(x.shape)}, router_x={tuple(router_input.shape)}" ) if self.freeze_router: with torch.no_grad(): logits = self.router(router_input) else: logits = self.router(router_input) logits_fp32 = logits.to(torch.float32) bias_fp32 = None if self.use_dynamic_bias: bias_fp32 = self.expert_bias.to( device=logits.device, dtype=torch.float32 ) if self.routing_score_fn == "softmax": choice_logits = logits_fp32 if bias_fp32 is not None: # Keep dynamic bias as a selection correction, not a mixture-weight shaper. 
choice_logits = choice_logits + bias_fp32 choice_scores = choice_logits _, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1) dense_distribution = torch.softmax(logits_fp32, dim=-1) if torch.onnx.is_in_onnx_export(): selected_probs = dense_distribution.gather(-1, topk_idx) else: selected_logits = logits_fp32.gather(-1, topk_idx) log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True) selected_probs = torch.exp(selected_logits - log_z) topk_scores = selected_probs / selected_probs.sum( dim=-1, keepdim=True ).clamp_min(1.0e-20) router_distribution = None if should_cache_router_distribution: router_distribution = dense_distribution else: # sigmoid scores = torch.sigmoid(logits_fp32) dense_distribution = scores / scores.sum( dim=-1, keepdim=True ).clamp_min(1.0e-20) scores_for_choice = scores if bias_fp32 is not None: # DeepSeek-style correction bias for expert choice. scores_for_choice = scores_for_choice + bias_fp32 choice_scores = scores_for_choice _, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1) selected_scores = scores.gather(-1, topk_idx) # Match DeepSeek-style routing: bias affects only expert choice, # while the expert mixing weights come from the original sigmoid # affinities normalized over the selected experts. 
topk_scores = selected_scores / selected_scores.sum( dim=-1, keepdim=True ).clamp_min(1.0e-20) router_distribution = None if should_cache_router_distribution: router_distribution = dense_distribution if self.collect_routing_stats: self.accumulate_routing_stats(topk_idx) if ( should_cache_router_distribution and router_distribution is not None ): self.last_router_distribution = router_distribution else: self.last_router_distribution = None if should_cache_router_logits: self.last_router_logits = logits_fp32 else: self.last_router_logits = None self._update_routed_expert_stats_and_floor_loss( topk_idx=topk_idx, dense_distribution=dense_distribution, choice_scores=choice_scores, ) if self.routing_scale != 1.0: topk_scores = topk_scores * self.routing_scale topk_scores = topk_scores.to(logits.dtype) # 3. Sparse Experts Computation (Grouped MM / ONNX Loop) sparse_out = self._compute_sparse_experts(x, topk_idx, topk_scores) # 4. Combine output = shared_out + sparse_out output = self.mlp_dropout(output) if return_routing_debug: return output, topk_idx, logits_fp32 return output def forward( self, x: torch.Tensor, cos: torch.Tensor = None, sin: torch.Tensor = None, mask: torch.Tensor | None = None, memory: torch.Tensor | None = None, memory_mask: torch.Tensor | None = None, router_x: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass compatible with ONNX and Attention/Norm.""" norm_x = self.norm1(x) attn_out = self.attn(norm_x, cos, sin, mask) x = x + attn_out if self.use_cross_attn and memory is not None: x_cross = self.norm_cross(x) if memory.ndim == 4: b, t, d_model = x_cross.shape _, _, n_fut, _ = memory.shape q = x_cross.reshape(b * t, 1, d_model) mem = memory.reshape(b * t, n_fut, d_model) mem_mask = None if memory_mask is not None: if memory_mask.ndim != 3: raise ValueError( "memory_mask for 4D memory must have shape [B, T, N_fut]" ) mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut) cross = self.cross_attn(q, mem, mem_mask).reshape( b, t, d_model ) 
else: cross = self.cross_attn(x_cross, memory, memory_mask) x = x + cross h = self.norm2(x) ffn_out = self.compute_moe_ffn(h, router_x=router_x) x = x + ffn_out return x


class ModernTransformerBlock(nn.Module):
    """Modern Transformer block with pre-norm, SwiGLU MLP, and modern attention.

    Features:
    - Pre-normalization with RMSNorm.
    - ModernAttention (GQA, QK-Norm, RealRoPE, Gated Attention).
    - DeepseekV3MLP (SwiGLU) for feed-forward.
    - Optional pre-norm cross-attention sub-layer (``use_cross_attn=True``),
      applied between self-attention and the MLP.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        ff_mult: int = 4,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
        mlp_dropout: float = 0.0,
        use_cross_attn: bool = False,
    ):
        super().__init__()
        self.use_cross_attn = bool(use_cross_attn)
        self.norm1 = RMSNorm(d_model)
        self.attn = ModernAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            use_qk_norm=use_qk_norm,
            use_gated_attn=use_gated_attn,
            gated_attn_type=gated_attn_type,
            attn_dropout=attn_dropout,
        )
        if self.use_cross_attn:
            self.norm_cross = RMSNorm(d_model)
            self.cross_attn = ModernCrossAttention(
                d_model=d_model,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                use_qk_norm=use_qk_norm,
                use_gated_attn=use_gated_attn,
                gated_attn_type=gated_attn_type,
                attn_dropout=attn_dropout,
            )
        else:
            # Attributes stay defined (as None); forward only touches them
            # when use_cross_attn is True.
            self.norm_cross = None
            self.cross_attn = None
        self.norm2 = RMSNorm(d_model)
        self.mlp = DeepseekV3MLP(
            hidden_size=d_model, intermediate_size=d_model * ff_mult
        )
        # Identity when dropout is disabled, so no Dropout module is created.
        self.mlp_dropout = (
            nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: torch.Tensor | None = None,
        memory: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with pre-norm residual connections.

        Args:
            x: Input tensor [B, T, d_model].
            cos: RoPE cosine table, passed unchanged to ``ModernAttention``
                (which documents it as [B, T, head_dim]).
            sin: RoPE sine table, same layout as ``cos``.
            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).
                When None, ``ModernAttention`` runs causal self-attention.
            memory: Optional cross-attention memory; ignored unless the block
                was built with ``use_cross_attn=True``. A 4D tensor is treated
                as per-timestep memory (reshaped to [B * T, N_fut, d_model], so
                presumably [B, T, N_fut, d_model]). Any other rank is passed to
                ``ModernCrossAttention`` unchanged, which requires [B, N, d_model].
            memory_mask: Optional mask for ``memory``. With 4D memory it must be
                [B, T, N_fut].

        Returns:
            out: Output tensor [B, T, d_model].

        Raises:
            ValueError: If 4D ``memory`` is paired with a non-3D ``memory_mask``.
        """
        x = x + self.attn(self.norm1(x), cos, sin, mask)
        if self.use_cross_attn and memory is not None:
            x_cross = self.norm_cross(x)
            if memory.ndim == 4:
                # Per-timestep memory: fold T into the batch so every timestep
                # issues a single query against only its own N_fut slots.
                b, t, d_model = x_cross.shape
                _, _, n_fut, _ = memory.shape
                q = x_cross.reshape(b * t, 1, d_model)
                mem = memory.reshape(b * t, n_fut, d_model)
                mem_mask = None
                if memory_mask is not None:
                    if memory_mask.ndim != 3:
                        raise ValueError(
                            "memory_mask for 4D memory must have shape [B, T, N_fut]"
                        )
                    # [B, T, N_fut] -> [B*T, 1, 1, N_fut]: broadcast over heads
                    # and the single query position.
                    mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut)
                cross = self.cross_attn(q, mem, mem_mask).reshape(
                    b, t, d_model
                )
            else:
                cross = self.cross_attn(x_cross, memory, memory_mask)
            x = x + cross
        x = x + self.mlp_dropout(self.mlp(self.norm2(x)))
        return x


class DeepseekV3MLP(nn.Module):
    """SwiGLU MLP with fused gate+up projection for efficiency.

    Computes ``down_proj(silu(gate) * up)``. ``gate`` and ``up`` are the two
    halves of one fused linear projection. Both linear layers carry a bias.

    Args:
        hidden_size: Input/output feature size. Must be given despite the
            ``None`` default, because it sizes the linear layers.
        intermediate_size: Width of the gated hidden layer. Same caveat.
    """

    def __init__(self, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Fused gate and up projection: outputs [gate, up] concatenated
        self.gate_up_proj = nn.Linear(
            self.hidden_size,
            2 * self.intermediate_size,
            bias=True,
        )
        self.down_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=True,
        )
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., hidden_size]
        gate_up = self.gate_up_proj(x)  # [..., 2 * intermediate_size]
        gate, up = gate_up.chunk(2, dim=-1)  # each [..., intermediate_size]
        return self.down_proj(self.act_fn(gate) * up)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (used in Llama, DeepSeek, Qwen).

    Normalizes over the last dimension as
    ``x * rsqrt(mean(x**2) + eps) * weight``. The statistics are computed in
    the input's own dtype (no upcast to fp32).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        x_normed = x * torch.rsqrt(var + self.eps)
        return x_normed * self.weight


class ModernAttention(nn.Module):
    """Modern attention with GQA, QK-Norm, RealRoPE, Flash Attention, and Gated Attention.

    Features:
    - GQA: Grouped Query Attention (n_kv_heads < n_heads).
    - QK-Norm: RMSNorm on queries and keys for stability.
    - RealRoPE: Real-valued Rotary Positional Embeddings.
    - Flash Attention: via F.scaled_dot_product_attention.
    - Gated Attention: Headwise or element-wise sigmoid gating (Qwen3-style).
    - Fused Projections: Q separate, KV fused for efficiency.

    Reference: https://github.com/qiuzh20/gated_attention
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
    ):
        """Initialize ModernAttention.

        Args:
            d_model: Model dimension.
            n_heads: Number of query heads.
            n_kv_heads: Number of key/value heads (for GQA). Defaults to n_heads.
            use_qk_norm: Apply RMSNorm to Q and K.
            use_gated_attn: Enable gated attention.
            gated_attn_type: "headwise" (Qwen3-style, one gate per head) or
                "elementwise" (one gate per element). Any value other than
                "headwise" selects the element-wise gate.
            attn_dropout: Dropout probability for attention; only applied
                while ``self.training`` is True.

        Note:
            ``head_dim = d_model // n_heads`` and
            ``n_rep = n_heads // n_kv_heads`` use floor division without
            validation, so callers are expected to pass divisible sizes. The
            element-wise gate is d_model wide and is multiplied with an
            ``n_heads * head_dim`` tensor, so it also needs
            ``d_model % n_heads == 0``.
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        self.head_dim = d_model // n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.use_qk_norm = use_qk_norm
        self.use_gated_attn = use_gated_attn
        self.gated_attn_type = gated_attn_type
        self.attn_dropout = attn_dropout

        # Fused projections: Q separate, KV fused (for GQA efficiency)
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.kv_proj = nn.Linear(
            d_model, 2 * self.n_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                # Qwen3-style: one gate scalar per head [B, T, n_heads]
                self.gate_proj = nn.Linear(d_model, n_heads, bias=False)
            else:
                # Element-wise: gate each element [B, T, d_model]
                self.gate_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,  # [B, T, head_dim]
        sin: torch.Tensor,  # [B, T, head_dim]
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor [B, T, d_model].
            cos: RoPE cosine frequencies [B, T, head_dim].
            sin: RoPE sine frequencies [B, T, head_dim].
            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).
                When None, attention is causal (``is_causal=True``).

        Returns:
            out: Output tensor [B, T, d_model].
        """
        B, T, _ = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        kv = self.kv_proj(x).view(B, T, 2, self.n_kv_heads, self.head_dim)
        k, v = kv[:, :, 0], kv[:, :, 1]  # each [B, T, n_kv_heads, head_dim]

        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Transpose for SDPA: [B, n_heads, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Flash Attention via SDPA. Eagerly, SDPA handles GQA itself
        # (enable_gqa=True). During ONNX export the KV heads are repeated
        # explicitly and enable_gqa is turned off instead.
        dropout_p = self.attn_dropout if self.training else 0.0
        is_exporting = torch.onnx.is_in_onnx_export()
        if is_exporting:
            k = repeat_kv(k, self.n_rep)
            v = repeat_kv(v, self.n_rep)
            enable_gqa = False
        else:
            enable_gqa = True

        attn_out = export_safe_scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=dropout_p,
            is_causal=(mask is None),
            enable_gqa=enable_gqa,
        )
        # attn_out: [B, n_heads, T, head_dim]

        # Gated Attention (Qwen3-style)
        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                # Headwise gating: [B, T, n_heads] -> [B, n_heads, T, 1]
                g = torch.sigmoid(self.gate_proj(x))  # [B, T, n_heads]
                g = g.transpose(1, 2)[..., None]  # [B, n_heads, T, 1]
                attn_out = attn_out * g
                # Falls through to the shared reshape + o_proj below.
            else:
                # Element-wise gating: apply after reshaping
                attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
                g = torch.sigmoid(self.gate_proj(x))  # [B, T, d_model]
                attn_out = attn_out * g
                return self.o_proj(attn_out)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(attn_out)

    def forward_single_token(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        new_len: torch.Tensor,
        insert_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward for single token with per-environment KV cache update.

        Args:
            x: Input tensor [B, 1, d_model].
            cos: RoPE cosine frequencies [B, 1, head_dim].
            sin: RoPE sine frequencies [B, 1, head_dim].
            k_cache: K cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].
            v_cache: V cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].
            new_len: New valid cache length per env AFTER inserting this token [B].
                Values above max_ctx_len are clamped.
            insert_pos: Insert position per env [B].

        Returns:
            attn_out: [B, 1, d_model]
            k_cache: Updated K cache
            v_cache: Updated V cache

        Note:
            Outside ONNX export the input caches are updated in place
            (``scatter_``), and the same tensor objects are returned. During
            ONNX export, new tensors are produced instead.
        """
        B = x.shape[0]

        q = self.q_proj(x).view(B, 1, self.n_heads, self.head_dim)
        kv = self.kv_proj(x).view(B, 1, 2, self.n_kv_heads, self.head_dim)
        k_new, v_new = kv[:, :, 0], kv[:, :, 1]  # [B, 1, n_kv_heads, head_dim]

        if self.use_qk_norm:
            q = self.q_norm(q)
            k_new = self.k_norm(k_new)

        # Apply RoPE with per-environment position
        q = q.transpose(1, 2)
        k_new = k_new.transpose(1, 2)
        q, k_new = apply_rotary_pos_emb(q, k_new, cos, sin)
        q = q.transpose(1, 2)
        k_new = k_new.transpose(1, 2)

        # Scatter K, V into cache at per-env insert positions
        # insert_pos: [B] -> [B, 1, 1, 1] for scatter
        idx = (
            insert_pos.view(B, 1, 1, 1)
            .expand(B, 1, self.n_kv_heads, self.head_dim)
            .to(torch.int64)
        )
        if torch.onnx.is_in_onnx_export():
            # === ONNX mode: out-of-place (produces new tensors) ===
            k_cache = k_cache.scatter(1, idx, k_new.to(k_cache.dtype))
            v_cache = v_cache.scatter(1, idx, v_new.to(v_cache.dtype))
        else:
            # === Rollout mode: in-place (modifies the caches directly) ===
            k_cache.scatter_(1, idx, k_new.to(k_cache.dtype))
            v_cache.scatter_(1, idx, v_new.to(v_cache.dtype))

        # Compute attention over cached keys/values
        # Mask out positions >= new_len (after insert)
        max_len = k_cache.shape[1]
        new_len = new_len.clamp(max=max_len)  # [B]

        # Build per-env mask: [B, max_len] where True = valid (can attend)
        pos_idx = torch.arange(max_len, device=x.device, dtype=torch.int64)
        valid_mask = pos_idx[None, :] < new_len[:, None]  # [B, max_len]
        # For SDPA bool mask: True = allowed (can attend)
        attn_mask = valid_mask[:, None, None, :]  # [B, 1, 1, max_len]

        # GQA: Use native SDPA broadcasting (no repeat_interleave)
        k_attn = k_cache.to(q.dtype)
        v_attn = v_cache.to(q.dtype)

        # Transpose for SDPA: [B, n_heads, T, head_dim]
        q_t = q.transpose(1, 2)  # [B, n_heads, 1, head_dim]
        k_t = k_attn.transpose(1, 2)  # [B, n_kv_heads, max_len, head_dim]
        v_t = v_attn.transpose(1, 2)

        dropout_p = self.attn_dropout if self.training else 0.0
        is_exporting = torch.onnx.is_in_onnx_export()
        if is_exporting:
            k_t = repeat_kv(k_t, self.n_rep)
            v_t = repeat_kv(v_t, self.n_rep)
            enable_gqa = False
        else:
            enable_gqa = True
        attn_out = export_safe_scaled_dot_product_attention(
            q_t,
            k_t,
            v_t,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=False,
            enable_gqa=enable_gqa,
        )
        # attn_out: [B, n_heads, 1, head_dim]

        # Gated Attention
        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                g = torch.sigmoid(self.gate_proj(x))  # [B, 1, n_heads]
                g = g.transpose(1, 2)[..., None]  # [B, n_heads, 1, 1]
                attn_out = attn_out * g
            else:
                attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, -1)
                g = torch.sigmoid(self.gate_proj(x))
                attn_out = attn_out * g
                return self.o_proj(attn_out), k_cache, v_cache

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, -1)
        return self.o_proj(attn_out), k_cache, v_cache


class ModernCrossAttention(nn.Module):
    """Cross-attention with GQA/QK-norm and optional gated attention.

    Queries are projected from ``x``; keys and values are projected from
    ``memory``. Unlike ``ModernAttention``, no rotary embedding is applied and
    attention is never causal.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = int(d_model)
        self.n_heads = int(n_heads)
        self.n_kv_heads = (
            int(n_kv_heads) if n_kv_heads is not None else int(n_heads)
        )
        self.head_dim = self.d_model // self.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.use_qk_norm = bool(use_qk_norm)
        self.use_gated_attn = bool(use_gated_attn)
        self.gated_attn_type = str(gated_attn_type)
        self.attn_dropout = float(attn_dropout)

        self.q_proj = nn.Linear(
            self.d_model, self.n_heads * self.head_dim, bias=False
        )
        self.kv_proj = nn.Linear(
            self.d_model, 2 * self.n_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.n_heads * self.head_dim, self.d_model, bias=False
        )
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                self.gate_proj = nn.Linear(
                    self.d_model, self.n_heads, bias=False
                )
            else:
                self.gate_proj = nn.Linear(
                    self.d_model, self.d_model, bias=False
                )

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend from ``x`` to ``memory``.

        Args:
            x: Query source [B, T, D].
            memory: Key/value source [B, N, D].
            mask: Optional mask passed to SDPA as ``attn_mask``. A bool mask
                uses True = allowed.

        Returns:
            Output tensor [B, T, d_model].

        Raises:
            ValueError: If ``x`` or ``memory`` is not 3D, or their batch sizes
                differ.
        """
        if x.ndim != 3:
            raise ValueError(f"x must be [B, T, D], got {tuple(x.shape)}")
        if memory.ndim != 3:
            raise ValueError(
                f"memory must be [B, N, D], got {tuple(memory.shape)}"
            )
        b, t, _ = x.shape
        bm, n, _ = memory.shape
        if bm != b:
            raise ValueError(
                f"batch mismatch between x and memory: {b} vs {bm}"
            )

        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim)
        kv = self.kv_proj(memory).view(b, n, 2, self.n_kv_heads, self.head_dim)
        k, v = kv[:, :, 0], kv[:, :, 1]
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # [B, heads, seq, head_dim] layout for SDPA.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        dropout_p = self.attn_dropout if self.training else 0.0
        is_exporting = torch.onnx.is_in_onnx_export()
        if is_exporting:
            # Materialize the GQA head repetition for the ONNX graph.
            k = repeat_kv(k, self.n_rep)
            v = repeat_kv(v, self.n_rep)
            enable_gqa = False
        else:
            enable_gqa = True

        attn_out = export_safe_scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=dropout_p,
            is_causal=False,
            enable_gqa=enable_gqa,
        )
        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                g = torch.sigmoid(self.gate_proj(x))
                g = g.transpose(1, 2)[..., None]
                attn_out = attn_out * g
            else:
                attn_out = attn_out.transpose(1, 2).contiguous().view(b, t, -1)
                g = torch.sigmoid(self.gate_proj(x))
                attn_out = attn_out * g
                return self.o_proj(attn_out)
        attn_out = attn_out.transpose(1, 2).contiguous().view(b, t, -1)
        return self.o_proj(attn_out)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Standard LLaMA GQA replication logic, optimized for ONNX export.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep).
    Input shape: (batch, num_key_value_heads, seqlen, head_dim)
    Output shape: (batch, num_attention_heads, seqlen, head_dim)
    Returns the input object itself when ``n_rep == 1``.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # 1. Unsqueeze: [batch, n_kv, 1, seq, dim]
    # 2. Expand: [batch, n_kv, n_rep, seq, dim]
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )

    # 3. Reshape: [batch, n_kv * n_rep, seq, dim] -> [batch, n_head, seq, dim]
    return hidden_states.reshape(
        batch, num_key_value_heads * n_rep, slen, head_dim
    )


def export_safe_scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    attn_mask: torch.Tensor | None,
    dropout_p: float,
    is_causal: bool,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Wrapper around ``F.scaled_dot_product_attention`` that is ONNX-export friendly.

    Outside ONNX export, or when the mask is None or not boolean, this is a
    plain SDPA call. During export, a boolean mask (True = allowed) is turned
    into an additive float bias: 0 where allowed, ``finfo(q.dtype).min`` where
    masked.
    """
    if (
        not torch.onnx.is_in_onnx_export()
        or attn_mask is None
        or attn_mask.dtype != torch.bool
    ):
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            enable_gqa=enable_gqa,
        )

    # Use additive float bias during ONNX export so the legacy exporter
    # does not emit the bool-mask SDPA cleanup path with IsNaN.
    mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
    mask_bias = mask_bias.masked_fill(~attn_mask, torch.finfo(q.dtype).min)
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask_bias,
        dropout_p=dropout_p,
        is_causal=is_causal,
        enable_gqa=enable_gqa,
    )


def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension into halves ``[x1, x2]`` and returns
    ``[-x2, x1]``.
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    ``cos``/``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast over the head axis of ``q``/``k``. The rotation is computed in
    fp32. Both outputs are cast back to ``q``'s dtype, so ``k`` is also
    returned in ``q``'s dtype.

    ``position_ids`` is accepted but unused; it appears to be kept only for
    signature compatibility with HF-style callers.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    orig_dtype = q.dtype

    # Force the computation to fp32
    q_fp32 = q.to(torch.float32)
    k_fp32 = k.to(torch.float32)
    cos_fp32 = cos.to(torch.float32)
    sin_fp32 = sin.to(torch.float32)

    q_embed = (q_fp32 * cos_fp32) + (rotate_half(q_fp32) * sin_fp32)
    k_embed = (k_fp32 * cos_fp32) + (rotate_half(k_fp32) * sin_fp32)

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


def _grouped_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    offs: torch.Tensor | None = None,
) -> torch.Tensor:
    """input: [Total_Tokens, In_Dim]
    weight: [Num_Experts, In_Dim, Out_Dim]

    Thin wrapper over ``torch._grouped_mm``. ``offs`` is forwarded unchanged.
    In this file's caller (``_compute_with_grouped_mm``) it holds the int32
    cumulative end index of each expert's token group, and the rows of
    ``input`` are sorted by expert. The matmul runs in ``weight``'s dtype and
    the result is cast back to ``input``'s original dtype before ``bias`` is
    added.
    """
    orig_dtype = input.dtype
    if input.dtype != weight.dtype:
        input = input.to(weight.dtype)
    out = torch._grouped_mm(input, weight, offs=offs)
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    if bias is not None:
        out = out + bias
    return out
================================================
FILE: holomotion/src/motion_retargeting/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/motion_retargeting/gmr_to_holomotion.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import json, os, sys from pathlib import Path from typing import Dict, Tuple, Optional, List import joblib import numpy as np import torch import hydra from omegaconf import OmegaConf, DictConfig, ListConfig from tqdm import tqdm from loguru import logger from holomotion.src.utils import torch_utils from holomotion.src.motion_retargeting.utils.torch_humanoid_batch import ( HumanoidBatch, ) from holomotion.src.motion_retargeting.utils import ( rotation_conversions as rot_conv, ) from holomotion.src.motion_retargeting.holomotion_preprocess import ( HoloMotionPreprocessor, ProcessedClip, ) import ray import logging def quaternion_to_axis_angle(q: torch.Tensor) -> torch.Tensor: q = q / torch.norm(q, dim=-1, keepdim=True) x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] angle = 2 * torch.acos(torch.clamp(w, -1.0, 1.0)) s = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0)) s = torch.clamp(s, min=1e-6) ax = x / s ay = y / s az = z / s axis_angles = torch.stack([ax * angle, ay * angle, az * angle], dim=-1) return axis_angles def dof_to_pose_aa( dof_pos: np.ndarray, robot_config_path: Optional[str], root_rot: Optional[np.ndarray], ) -> np.ndarray: """Compute pose_aa via FK; if no config is provided, return zero placeholders.""" if not robot_config_path: T = dof_pos.shape[0] return np.zeros((T, 27, 3), dtype=np.float32) robot_cfg = OmegaConf.load(robot_config_path) logger.info(f"Loaded robot config for FK from: {robot_config_path}") fk = HumanoidBatch(robot_cfg.robot) num_aug = len(robot_cfg.robot.extend_config) dof_t = torch.as_tensor(dof_pos, dtype=torch.float32) if dof_t.dim() == 3 and dof_t.shape[-1] == 1: dof_t = dof_t.squeeze(-1) T = dof_t.shape[0] if root_rot is None: root_aa = torch.zeros((T, 3), dtype=torch.float32) else: rr = torch.as_tensor(root_rot, dtype=torch.float32) root_aa = quaternion_to_axis_angle(rr) if rr.shape[-1] == 4 else rr joint_aa = fk.dof_axis * dof_t.unsqueeze(-1) pose_aa = torch.cat( [root_aa.unsqueeze(1), joint_aa, torch.zeros((T, num_aug, 
3))], dim=1 ) return pose_aa.numpy().astype(np.float32, copy=False) def load_any_pkl(p: Path): with open(p, "rb") as f: return joblib.load(f) def unwrap_source(obj) -> Dict[str, np.ndarray]: """Accept {top_key: inner} or flat dict (early GMR).""" if isinstance(obj, dict) and len(obj) == 1: inner = next(iter(obj.values())) if isinstance(inner, dict): return inner if isinstance(obj, dict): return obj raise ValueError("Unsupported PKL structure") def make_motion_key(p: Path, src_dir: Path) -> str: rel = p.relative_to(src_dir).with_suffix("") return "/".join(rel.parts) def key_to_filename(key: str) -> str: return key.replace("/", "_") + ".npz" def get_ref_schema( ref_dir: Path, ) -> Tuple[Dict[str, Tuple[Tuple[int, ...], np.dtype]], str]: """ Read schema only from ref_dir/_schema.json. Expected JSON structure: { "schema": { "root_trans_offset": {"shape": [T, 3], "dtype": "float64"}, "pose_aa": {"shape": [T, 27, 3], "dtype": "float32"}, ... }, "sample_top_key": "xxx" } """ ref_dir = Path(ref_dir) cache_path = ref_dir / "_schema.json" if not cache_path.exists(): raise FileNotFoundError(f"Schema JSON not found: {cache_path}") try: with open(cache_path, "r", encoding="utf-8") as f: obj = json.load(f) except Exception as e: raise ValueError(f"Failed to parse _schema.json: {e}") schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]] = {} raw = obj.get("schema", {}) if not isinstance(raw, dict) or not raw: raise ValueError("Schema JSON missing 'schema' object or it's empty.") for k, v in raw.items(): if not isinstance(v, dict) or "shape" not in v or "dtype" not in v: raise ValueError(f"Bad schema entry for key '{k}': {v}") shape = tuple(int(x) for x in v["shape"]) dtype = np.dtype(v["dtype"]) schema[k] = (shape, dtype) sample_top_key = obj.get("sample_top_key", "") return schema, sample_top_key def infer_T(src_inner: Dict[str, np.ndarray]) -> Optional[int]: for key in [ "root_trans_offset", "root_pos", "pose_aa", "dof", "dof_pos", "root_rot", "smpl_joints", ]: v = 
src_inner.get(key) if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] > 0: return int(v.shape[0]) T = 0 for v in src_inner.values(): if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] > T: T = int(v.shape[0]) return T or None def build_inner_from_source( src_inner: Dict[str, np.ndarray], schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]], T_default: int, ) -> Dict[str, object]: alt_map = { "root_trans_offset": [ "root_trans_offset", "root_pos", "trans", "root_trans", ], "pose_aa": ["pose_aa"], "dof": ["dof", "dof_pos"], "root_rot": ["root_rot", "root_orient", "root_quat"], "smpl_joints": ["smpl_joints", "joints", "smpljoints"], "fps": ["fps", "mocap_framerate", "mocap_frame_rate"], } out: Dict[str, object] = {} T = infer_T(src_inner) or T_default for key, (shape, dtype) in schema.items(): if key == "fps": fps = None for cand in alt_map["fps"]: v = src_inner.get(cand) if isinstance(v, (int, np.integer)): fps = int(v) break out["fps"] = int(fps) if fps is not None else 30 continue src_arr = None for cand in alt_map.get(key, []): v = src_inner.get(cand) if isinstance(v, np.ndarray) and v.ndim >= 1: src_arr = v break # Target shape: override leading T; keep source column count for DOF if key == "dof" and isinstance(src_arr, np.ndarray): target_shape = (T, src_arr.shape[1] if src_arr.ndim >= 2 else 1) else: ts = list(shape) if ts: ts[0] = T target_shape = tuple(ts) if src_arr is None: out[key] = np.zeros(target_shape, dtype=dtype) continue arr = src_arr.astype(dtype, copy=False) if key == "dof" and arr.ndim == 1: arr = arr.reshape(-1, 1) if arr.shape[0] > T: arr = arr[:T] elif arr.shape[0] < T: pad = np.repeat(arr[-1:], T - arr.shape[0], axis=0) arr = np.concatenate([arr, pad], axis=0) if ( key != "dof" and len(target_shape) == 2 and arr.shape[1] != target_shape[1] ): if arr.shape[1] > target_shape[1]: arr = arr[:, : target_shape[1]] else: pad = np.zeros( (T, target_shape[1] - arr.shape[1]), dtype=arr.dtype ) arr = np.concatenate([arr, pad], axis=1) 
if len(target_shape) == 3: d1 = min(arr.shape[1], target_shape[1]) d2 = min(arr.shape[2], target_shape[2]) arr = arr[:, :d1, :d2] if (arr.shape[1], arr.shape[2]) != ( target_shape[1], target_shape[2], ): pad = np.zeros( ( T, target_shape[1] - arr.shape[1], target_shape[2] - arr.shape[2], ), dtype=arr.dtype, ) arr = np.concatenate([arr, pad], axis=1) if arr.shape[2] != target_shape[2]: pad2 = np.zeros( (T, target_shape[1], target_shape[2] - arr.shape[2]), dtype=arr.dtype, ) arr = np.concatenate([arr, pad2], axis=2) out[key] = arr.astype(dtype, copy=False) return out def to_torch(tensor): if torch.is_tensor(tensor): return tensor else: return torch.from_numpy(tensor.copy()) def batch_interpolate_tensor( tensor, orig_times, target_times, use_slerp=False ): """Optimized tensor interpolation with batch processing""" target_num_frames = len(target_times) shape = list(tensor.shape) shape[0] = target_num_frames # Create empty output tensor result = torch.zeros(shape, device=tensor.device, dtype=tensor.dtype) if len(tensor.shape) == 2: # For 2D tensors - process all frames at once # Create masks for the three cases before_mask = target_times <= orig_times[0] after_mask = target_times >= orig_times[-1] valid_mask = ~(before_mask | after_mask) # Handle edge cases if before_mask.any(): result[before_mask] = tensor[0] if after_mask.any(): result[after_mask] = tensor[-1] # Process interpolation for valid times if valid_mask.any(): valid_times = target_times[valid_mask] # Get indices for lower frames indices = torch.searchsorted(orig_times, valid_times) - 1 # Ensure indices are valid indices = torch.clamp(indices, 0, len(orig_times) - 2) next_indices = indices + 1 # Calculate weights alphas = (valid_times - orig_times[indices]) / ( orig_times[next_indices] - orig_times[indices] ) alphas = alphas.unsqueeze(-1) # Add dimension for broadcasting if use_slerp and tensor.shape[1] == 4: # Quaternion data # Process in smaller batches to avoid memory issues batch_size = 1000 # Adjust 
based on available memory num_valid = valid_mask.sum() for i in range(0, num_valid, batch_size): end_idx = min(i + batch_size, num_valid) batch_indices = torch.where(valid_mask)[0][i:end_idx] batch_alphas = alphas[i:end_idx] batch_lower_indices = indices[i:end_idx] batch_upper_indices = next_indices[i:end_idx] # Get frame data for this batch frames_low = tensor[batch_lower_indices] frames_high = tensor[batch_upper_indices] # Apply SLERP to this batch result[batch_indices] = torch_utils.slerp( frames_low, frames_high, batch_alphas ) else: # Standard linear interpolation - can be done in one batch frames_low = tensor[indices] frames_high = tensor[next_indices] result[valid_mask] = ( frames_low * (1 - alphas) + frames_high * alphas ) elif len(tensor.shape) == 3: # For 3D tensors - process each joint sequence for j in range(tensor.shape[1]): result[:, j] = batch_interpolate_tensor( tensor[:, j], orig_times, target_times, use_slerp ) return result def fast_interpolate_motion(motion_dict, source_fps, target_fps): """Optimized motion interpolation that preserves correctness""" # Early return if no interpolation needed if source_fps == target_fps: return motion_dict # Calculate timestamps orig_dt = 1.0 / source_fps target_dt = 1.0 / target_fps # Find the first tensor to determine number of frames for v in motion_dict.values(): if torch.is_tensor(v): num_frames = v.shape[0] device = v.device break else: return motion_dict # No tensor data to interpolate orig_times = torch.arange(0, num_frames, device=device) * orig_dt wallclock_len = orig_dt * (num_frames - 1) target_num_frames = int(wallclock_len * target_fps) + 1 target_times = ( torch.arange(0, target_num_frames, device=device) * target_dt ) # Create interpolated motion dictionary interp_motion = {} for k, v in motion_dict.items(): if not torch.is_tensor(v): interp_motion[k] = v continue is_quat = "quat" in k interp_motion[k] = batch_interpolate_tensor( v, orig_times, target_times, is_quat ) return interp_motion def 
process_single_motion( robot_cfg: dict, all_samples, # Can be dict or LazyMotionLoader curr_key: str, target_fps: int = 50, fast_interpolate: bool = True, debug_mode: bool = False, ): logger.debug(f"Starting process_single_motion for key: {curr_key}") humanoid_fk = HumanoidBatch(robot_cfg) motion_sample_dict = all_samples[curr_key] if len(motion_sample_dict) == 1: motion_sample_dict = motion_sample_dict[ list(motion_sample_dict.keys())[0] ] logger.debug("Step 3: Extracting sequence length") if debug_mode: # In debug mode, let exceptions bubble up naturally if "root_trans_offset" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'root_trans_offset' not found in motion data. Available keys: {available_keys}" ) seq_len = motion_sample_dict["root_trans_offset"].shape[0] start, end = 0, seq_len logger.debug(f"Step 3 completed - seq_len: {seq_len}") else: try: if "root_trans_offset" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'root_trans_offset' not found in motion data. 
Available keys: {available_keys}" ) seq_len = motion_sample_dict["root_trans_offset"].shape[0] start, end = 0, seq_len logger.debug(f"Step 3 completed - seq_len: {seq_len}") except Exception as e: logger.error( f"Step 3 failed - Extracting sequence length: {e}", exc_info=True, ) raise RuntimeError( f"Failed to extract sequence length: {e}" ) from e logger.debug("Step 4: Processing root translation") if debug_mode: # In debug mode, let exceptions bubble up naturally trans = to_torch(motion_sample_dict["root_trans_offset"]).clone()[ start:end ] logger.debug(f"Step 4 completed - trans shape: {trans.shape}") else: try: trans = to_torch(motion_sample_dict["root_trans_offset"]).clone()[ start:end ] logger.debug(f"Step 4 completed - trans shape: {trans.shape}") except Exception as e: logger.error( f"Step 4 failed - Processing root translation: {e}", exc_info=True, ) raise RuntimeError( f"Failed to process root translation: {e}" ) from e logger.debug("Step 5: Processing pose_aa") if debug_mode: # In debug mode, let exceptions bubble up naturally if "pose_aa" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'pose_aa' not found in motion data. Available keys: {available_keys}" ) pose_aa = to_torch(motion_sample_dict["pose_aa"][start:end]).clone() # If available, enforce root rotation from input quaternions (XYZW) if "root_rot" in motion_sample_dict: root_quat_xyzw = to_torch( motion_sample_dict["root_rot"][start:end] ).clone() root_quat_wxyz = rot_conv.xyzw_to_wxyz(root_quat_xyzw) root_axis_angle = rot_conv.quaternion_to_axis_angle(root_quat_wxyz) pose_aa[:, 0, :] = root_axis_angle logger.debug(f"Step 5 completed - pose_aa shape: {pose_aa.shape}") else: try: if "pose_aa" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'pose_aa' not found in motion data. 
Available keys: {available_keys}" ) pose_aa = to_torch( motion_sample_dict["pose_aa"][start:end] ).clone() # If available, enforce root rotation from input quaternions (XYZW) if "root_rot" in motion_sample_dict: root_quat_xyzw = to_torch( motion_sample_dict["root_rot"][start:end] ).clone() root_quat_wxyz = rot_conv.xyzw_to_wxyz(root_quat_xyzw) root_axis_angle = rot_conv.quaternion_to_axis_angle( root_quat_wxyz ) pose_aa[:, 0, :] = root_axis_angle logger.debug(f"Step 5 completed - pose_aa shape: {pose_aa.shape}") except Exception as e: logger.error( f"Step 5 failed - Processing pose_aa: {e}", exc_info=True ) raise RuntimeError(f"Failed to process pose_aa: {e}") from e logger.debug("Step 6: Calculating dt") if debug_mode: # In debug mode, let exceptions bubble up naturally if "fps" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'fps' not found in motion data. Available keys: {available_keys}" ) fps = motion_sample_dict["fps"] if fps <= 0: raise ValueError(f"Invalid fps value: {fps}") dt = 1 / fps logger.debug(f"Step 6 completed - fps: {fps}, dt: {dt}") else: try: if "fps" not in motion_sample_dict: available_keys = list(motion_sample_dict.keys()) raise KeyError( f"'fps' not found in motion data. 
Available keys: {available_keys}" ) fps = motion_sample_dict["fps"] if fps <= 0: raise ValueError(f"Invalid fps value: {fps}") dt = 1 / fps logger.debug(f"Step 6 completed - fps: {fps}, dt: {dt}") except Exception as e: logger.error(f"Step 6 failed - Calculating dt: {e}", exc_info=True) raise RuntimeError(f"Failed to calculate dt: {e}") from e logger.debug("Step 8: Running forward kinematics") if debug_mode: # In debug mode, let exceptions bubble up naturally curr_motion = humanoid_fk.fk_batch( pose_aa[None,], trans[None,], return_full=True, dt=dt, ) logger.debug("Step 8 completed") else: try: curr_motion = humanoid_fk.fk_batch( pose_aa[None,], trans[None,], return_full=True, dt=dt, ) logger.debug("Step 8 completed") except Exception as e: logger.error( f"Step 8 failed - Forward kinematics: {e}", exc_info=True ) raise RuntimeError(f"Failed to run forward kinematics: {e}") from e curr_motion = dict( { k: v.squeeze() if torch.is_tensor(v) else v for k, v in curr_motion.items() } ) motion_fps = curr_motion["fps"] motion_dt = 1.0 / motion_fps num_frames = curr_motion["global_rotation"].shape[0] wallclock_len = motion_dt * (num_frames - 1) num_dofs = len(robot_cfg.motion.dof_names) num_bodies = len(robot_cfg.motion.body_names) num_extended_bodies = num_bodies + len( robot_cfg.motion.get("extend_config", []) ) # build a frame_flag array to indicate three status: # start_of_motion: 0, middle_of_motion: 1, end_of_motion: 2 frame_flag = torch.ones(num_frames).int() frame_flag[0] = 0 frame_flag[-1] = 2 curr_motion["frame_flag"] = frame_flag # rename and pop some keys curr_motion["global_rotation_quat"] = curr_motion.pop("global_rotation") curr_motion["local_rotation_quat"] = curr_motion.pop("local_rotation") if "global_translation_extend" in curr_motion: curr_motion["global_rotation_quat_extend"] = curr_motion.pop( "global_rotation_extend" ) curr_motion.pop("fps") curr_motion.pop("global_rotation_mat") if "global_rotation_mat_extend" in curr_motion: 
curr_motion.pop("global_rotation_mat_extend") # add some keys curr_motion["global_root_translation"] = curr_motion["global_translation"][ :, 0 ] curr_motion["global_root_rotation_quat"] = curr_motion[ "global_rotation_quat" ][:, 0] # Interpolate to target_fps if different from original fps if target_fps != motion_fps: curr_motion = fast_interpolate_motion( curr_motion, motion_fps, target_fps ) motion_fps = target_fps motion_dt = 1.0 / target_fps num_frames = ( next(iter(curr_motion.values())).shape[0] if curr_motion else num_frames ) wallclock_len = motion_dt * (num_frames - 1) sample_dict = { "motion_name": curr_key, "motion_fps": motion_fps, "num_frames": num_frames, "wallclock_len": wallclock_len, "num_dofs": num_dofs, "num_bodies": num_bodies, "num_extended_bodies": num_extended_bodies, } sample_dict.update( { k: curr_motion[k].float().cpu().numpy() for k in sorted(curr_motion.keys()) } ) if debug_mode: for k, v in sample_dict.items(): if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray): logger.debug(f"{k}: {v.shape}") else: logger.debug(f"{k}: {v}") return sample_dict class InMemoryAlignedLoader: """Minimal Loader interface: compatible with process_single_motion sample access.""" def __init__(self, mapping: Dict[str, Dict[str, object]]): self._map = mapping def keys(self) -> List[str]: return list(self._map.keys()) def __len__(self): return len(self._map) def __getitem__(self, k: str): return self._map[k] def load(self, k: str): return self._map[k] def get(self, k: str, default=None): return self._map.get(k, default) def arrays_for_npz( sample: Dict, emit_prefixed: bool = True, emit_legacy: bool = False ) -> Dict[str, np.ndarray]: """ Build NPZ arrays: - Always include frame_flag if present - If emit_prefixed: write ref_* arrays mapped from base keys - If emit_legacy: also include legacy, unprefixed keys for compatibility """ base_to_ref = { "dof_pos": "ref_dof_pos", "dof_vel": "ref_dof_vel", "dof_vels": "ref_dof_vel", "global_translation": 
"ref_global_translation", "global_rotation_quat": "ref_global_rotation_quat", "global_velocity": "ref_global_velocity", "global_angular_velocity": "ref_global_angular_velocity", } out: Dict[str, np.ndarray] = {} if isinstance(sample.get("frame_flag"), np.ndarray): out["frame_flag"] = sample["frame_flag"] for base, ref_name in base_to_ref.items(): v = sample.get(base, None) if isinstance(v, np.ndarray): if emit_prefixed: out[ref_name] = v if emit_legacy: out[base] = v return out @ray.remote class MotionProcessorActor: """ Persistent Ray actor that loads robot config once and processes PKLs asynchronously. """ def __init__( self, robot_cfg_path: str, schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]], ): cfg = OmegaConf.load(robot_cfg_path) self.robot_cfg = cfg.robot self.schema = schema # Cached FK holder for DOF → axis-angle conversion (uses dof_axis) self._fk_for_dof = HumanoidBatch(self.robot_cfg) def _dof_to_pose_aa_cached( self, dof_pos: np.ndarray, root_rot: Optional[np.ndarray] ) -> np.ndarray: dof_t = torch.as_tensor(dof_pos, dtype=torch.float32) if dof_t.dim() == 3 and dof_t.shape[-1] == 1: dof_t = dof_t.squeeze(-1) T = int(dof_t.shape[0]) if root_rot is None: root_aa = torch.zeros((T, 3), dtype=torch.float32) else: rr = torch.as_tensor(root_rot, dtype=torch.float32) root_aa = quaternion_to_axis_angle(rr) if rr.shape[-1] == 4 else rr num_aug = len(self.robot_cfg.extend_config) joint_aa = self._fk_for_dof.dof_axis * dof_t[:, :, None] pose_aa = torch.cat( [root_aa[:, None, :], joint_aa, torch.zeros((T, num_aug, 3))], dim=1, ) return pose_aa.numpy().astype(np.float32, copy=False) def process_pkl( self, p_str: str, src_dir_str: str, target_fps: int, fast_interpolate: bool, debug_mode: bool, ) -> Tuple[bool, Dict[str, object]]: """ Returns (success, payload). 
On success, payload contains: { "flat_key": str, "sample": Dict[str, np.ndarray|scalar] } """ p = Path(p_str) src_dir = Path(src_dir_str) motion_key_rel = make_motion_key(p, src_dir) flat_key = motion_key_rel.replace("/", "_") obj = load_any_pkl(p) inner = unwrap_source(obj) T_default = infer_T(inner) or 1 aligned = build_inner_from_source(inner, self.schema, T_default) dof = aligned.get("dof") if isinstance(dof, np.ndarray) and dof.size > 0: root_rot = aligned.get("root_rot", None) aligned["pose_aa"] = self._dof_to_pose_aa_cached(dof, root_rot) loader = InMemoryAlignedLoader({flat_key: aligned}) sample = process_single_motion( self.robot_cfg, loader, flat_key, int(target_fps), bool(fast_interpolate), bool(debug_mode), ) payload: Dict[str, object] = {"flat_key": flat_key, "sample": sample} return True, payload @hydra.main( config_path="../../config", config_name="motion_retargeting/gmr_to_holomotion", version_base=None, ) def main(cfg: DictConfig) -> None: # Setup logging logger.remove() log_level = "DEBUG" if bool(cfg.processing.debug_mode) else "INFO" logger.add(sys.stderr, level=log_level, colorize=True) src_path = Path(str(cfg.io.src_dir)).expanduser().resolve() ref_dir = Path(str(cfg.io.ref_dir)).expanduser().resolve() out_root = Path(str(cfg.io.out_root)).expanduser().resolve() clips_dir = out_root / "clips" clips_dir.mkdir(parents=True, exist_ok=True) # dump resolved config used (out_root).mkdir(parents=True, exist_ok=True) with open(out_root / "config_used.yaml", "w") as f: f.write(OmegaConf.to_yaml(cfg)) # 1) schema from _schema.json schema, _ = get_ref_schema(ref_dir) # 2) gather PKLs if src_path.is_file() and src_path.suffix == ".pkl": src_pkls = [src_path] root_for_keys = src_path.parent else: src_pkls = [] for dirpath, _, filenames in os.walk(src_path, followlinks=True): for filename in filenames: if filename.endswith(".pkl"): p = Path(dirpath) / filename if p.is_file(): src_pkls.append(p) src_pkls = sorted(src_pkls) root_for_keys = src_path # 3) quiet 
third-party DEBUG logs (e.g., filelock/Ray) logging.getLogger("filelock").setLevel(logging.WARNING) logging.getLogger("ray").setLevel(logging.ERROR) os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", "error") # 4) initialize Ray if str(cfg.ray.ray_address): ray.init( address=str(cfg.ray.ray_address), ignore_reinit_error=True, log_to_driver=False, include_dashboard=False, logging_level=logging.ERROR, ) else: num_cpus = ( None if int(cfg.ray.num_workers) <= 0 else int(cfg.ray.num_workers) ) ray.init( num_cpus=num_cpus, ignore_reinit_error=True, log_to_driver=False, include_dashboard=False, logging_level=logging.ERROR, ) # 5) build work list (skip existing if requested) skip_existing = bool(cfg.processing.skip_existing) work_list: List[Path] = [] for p in src_pkls: motion_key = make_motion_key(p, root_for_keys) out_name = key_to_filename(motion_key) if skip_existing and (clips_dir / out_name).exists(): continue work_list.append(p) if not work_list: logger.info("No tasks to run (all outputs exist or no PKLs found).") ray.shutdown() return # 6) create persistent actors (each loads robot config once) if int(cfg.ray.num_workers) > 0: num_actors = min(len(work_list), int(cfg.ray.num_workers)) else: available_cpus = int(ray.available_resources().get("CPU", 1)) num_actors = min(len(work_list), max(1, available_cpus)) actors = [ MotionProcessorActor.remote(str(cfg.io.robot_config), schema) for _ in range(num_actors) ] # Parse pipeline config pipeline_cfg = cfg.get("preprocess", None) pipeline = None if pipeline_cfg is not None: pipeline_val = pipeline_cfg.get("pipeline", None) if pipeline_val is not None: if isinstance(pipeline_val, (list, tuple, ListConfig)): pipeline = [str(s) for s in pipeline_val] elif isinstance(pipeline_val, str): import ast pipeline = ast.literal_eval(pipeline_val) else: logger.warning( f"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}" ) pipeline = [] else: pipeline = [] else: pipeline = [] # Separate per-clip stages from 
dataset-level stages per_clip_pipeline = ( [s for s in pipeline if s != "tagging"] if pipeline else [] ) tagging_enabled = pipeline and "tagging" in pipeline logger.info("=" * 80) logger.info("Preprocessing Configuration:") if pipeline: logger.info(f" Pipeline stages: {pipeline}") logger.info(f" Number of stages: {len(pipeline)}") for i, stage in enumerate(pipeline, 1): logger.info(f" {i}. {stage}") if tagging_enabled: logger.info( " Note: 'tagging' is a dataset-level operation and will run after all clips are processed" ) else: logger.info( " No preprocessing pipeline specified - no processors will be applied" ) logger.info("=" * 80) preprocessor = HoloMotionPreprocessor( slicing_cfg=cfg.slicing, filtering_cfg=cfg.filtering, tagging_cfg=cfg.tagging, padding_cfg=cfg.get("padding", None), pipeline=per_clip_pipeline if per_clip_pipeline else None, ) # 7) asynchronously schedule PKLs to actors (round-robin) pending = {} next_idx = 0 # prime the queue for i in range(min(num_actors, len(work_list))): p = work_list[next_idx] next_idx += 1 ref = actors[i].process_pkl.remote( str(p), str(root_for_keys), int(cfg.processing.target_fps), bool(cfg.processing.fast_interpolate), bool(cfg.processing.debug_mode), ) pending[ref] = i # 8) collect results and keep feeding new tasks (post-process in-memory, then write) total_outputs = 0 with tqdm(total=len(work_list), desc="Ray: PKL→NPZ (Hydra)") as pbar: while pending: done, _ = ray.wait(list(pending.keys()), num_returns=1) ref = done[0] actor_idx = pending.pop(ref) ok, payload = ray.get(ref) if ok: flat_key: str = payload["flat_key"] # type: ignore[assignment] sample: Dict = payload["sample"] # type: ignore[assignment] arrays_ref = arrays_for_npz( sample, emit_prefixed=bool(cfg.naming.emit_prefixed), emit_legacy=bool(cfg.naming.emit_legacy), ) base_meta = { "motion_key": flat_key, "raw_motion_key": flat_key, "motion_fps": float(sample["motion_fps"]), "num_frames": int(sample["num_frames"]), "wallclock_len": 
float(sample["wallclock_len"]), "num_dofs": int(sample["num_dofs"]), "num_bodies": int(sample["num_bodies"]), "num_extended_bodies": int(sample["num_extended_bodies"]), "slice_start": 0, "slice_end": int(sample["num_frames"]), } base_clip = ProcessedClip( motion_key=flat_key, metadata=base_meta, arrays=arrays_ref, ) clips = preprocessor.process_clip(base_clip) for clip in clips: out_name = f"{clip.motion_key}.npz" out_path = clips_dir / out_name np.savez_compressed( out_path, metadata=json.dumps(clip.metadata), **clip.arrays, ) total_outputs += 1 else: logger.warning(f"Processing failed: {payload}") pbar.update(1) if next_idx < len(work_list): p = work_list[next_idx] next_idx += 1 new_ref = actors[actor_idx].process_pkl.remote( str(p), str(root_for_keys), int(cfg.processing.target_fps), bool(cfg.processing.fast_interpolate), bool(cfg.processing.debug_mode), ) pending[new_ref] = actor_idx # 9) Optional kinematic tagging (write to out_root level) if tagging_enabled: tags_path = ( Path(str(cfg.tagging.output_json_path)).expanduser().resolve() if str(cfg.tagging.output_json_path) else (out_root / "kinematic_tags.json") ) preprocessor.tag_directory(clips_dir, tags_path) logger.info( f"Done. NPZ written to: {clips_dir} (total clips: {total_outputs})" ) ray.shutdown() if __name__ == "__main__": main() ================================================ FILE: holomotion/src/motion_retargeting/holomotion_fk.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. 
See the License for the specific language governing # permissions and limitations under the License. from __future__ import annotations import os import xml.etree.ElementTree as ETree from typing import Dict, List, Tuple import torch import pytorch_kinematics as pk from loguru import logger from holomotion.src.utils import torch_utils class MJCFParser: def __init__(self, robot_file_path: str) -> None: self._robot_file_path = robot_file_path @staticmethod def parse_vec( text: str | None, size: int, default: List[float] ) -> List[float]: if text is None: return list(default) values = [float(v) for v in text.strip().split()] if len(values) != size: raise ValueError( f"Expected {size} values, got {len(values)} in '{text}'" ) return values @staticmethod def _find_parent( root: ETree.Element, child: ETree.Element ) -> ETree.Element | None: for parent in root.iter(): for node in list(parent): if node is child: return parent return None @staticmethod def _select_include_children( parent: ETree.Element, inc_root: ETree.Element ) -> List[ETree.Element]: if inc_root.tag == "mujoco": if parent.tag != "mujoco": sub = inc_root.find(parent.tag) if sub is not None: return list(sub) return list(inc_root) if inc_root.tag == parent.tag: return list(inc_root) return list(inc_root) @classmethod def _resolve_includes(cls, root: ETree.Element, base_dir: str) -> None: includes = root.findall(".//include") while includes: for inc in includes: inc_file = inc.attrib.get("file") if inc_file is None: raise ValueError("Include tag missing 'file' attribute") inc_path = os.path.join(base_dir, inc_file) inc_root = ETree.parse(inc_path).getroot() cls._resolve_includes(inc_root, os.path.dirname(inc_path)) parent = cls._find_parent(root, inc) if parent is None: raise ValueError("Failed to resolve include parent") insert_children = cls._select_include_children( parent, inc_root ) insert_index = list(parent).index(inc) for child in list(insert_children): parent.insert(insert_index, child) insert_index 
+= 1 parent.remove(inc) includes = root.findall(".//include") def load_root(self) -> ETree.Element: root = ETree.parse(self._robot_file_path).getroot() self._resolve_includes(root, os.path.dirname(self._robot_file_path)) return root def parse( self, ) -> Tuple[ List[str], torch.Tensor, torch.Tensor, torch.Tensor, List[List[str]], Dict[str, int], Dict[str, List[float]], List[str], torch.Tensor, List[List[int]], ]: root = self.load_root() xml_world = root.find("worldbody") if xml_world is None: raise ValueError("MJCF missing worldbody") xml_body_root = xml_world.find("body") if xml_body_root is None: raise ValueError("MJCF missing root body") body_names: List[str] = [] parents: List[int] = [] local_translation: List[List[float]] = [] local_rotation: List[List[float]] = [] body_joint_order: List[List[str]] = [] joint_body_index: Dict[str, int] = {} joint_axis: Dict[str, List[float]] = {} def _add_body(xml_body: ETree.Element, parent_index: int) -> None: body_idx = len(body_names) body_names.append(xml_body.attrib.get("name", "")) parents.append(parent_index) local_translation.append( self.parse_vec(xml_body.attrib.get("pos"), 3, [0.0, 0.0, 0.0]) ) local_rotation.append( self.parse_vec( xml_body.attrib.get("quat"), 4, [1.0, 0.0, 0.0, 0.0] ) ) joints_in_body: List[str] = [] for joint in xml_body.findall("joint"): joint_name = joint.attrib.get("name") if joint_name is None: raise ValueError("Joint missing name") joint_type = joint.attrib.get("type", "hinge") if joint_type == "free": continue if joint_type != "hinge": raise ValueError(f"Unsupported joint type: {joint_type}") axis = self.parse_vec( joint.attrib.get("axis"), 3, [0.0, 0.0, 1.0] ) joint_body_index[joint_name] = body_idx joint_axis[joint_name] = axis joints_in_body.append(joint_name) body_joint_order.append(joints_in_body) for child in xml_body.findall("body"): _add_body(child, body_idx) _add_body(xml_body_root, -1) if local_translation: local_translation[0] = [0.0, 0.0, 0.0] local_rotation[0] = [1.0, 0.0, 
0.0, 0.0] dof_names: List[str] = [] for elem in root.iter(): if elem.tag == "actuator": for child in list(elem): joint_name = child.attrib.get("joint") if joint_name is not None: dof_names.append(joint_name) if len(dof_names) == 0: raise ValueError("No actuated joints found in MJCF") dof_axis: List[List[float]] = [] for joint_name in dof_names: if joint_name not in joint_body_index: raise ValueError(f"Actuator joint not found: {joint_name}") dof_axis.append(joint_axis[joint_name]) dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)} body_dof_indices: List[List[int]] = [] for joints in body_joint_order: indices: List[int] = [] for name in joints: if name in dof_name_to_index: indices.append(dof_name_to_index[name]) body_dof_indices.append(indices) return ( body_names, torch.tensor(parents, dtype=torch.long), torch.tensor(local_translation, dtype=torch.float32), torch.tensor(local_rotation, dtype=torch.float32), body_joint_order, joint_body_index, joint_axis, dof_names, torch.tensor(dof_axis, dtype=torch.float32), body_dof_indices, ) class URDFParser: def __init__(self, urdf_path: str) -> None: self._urdf_path = urdf_path @staticmethod def _as_tf( tf: torch.Tensor | None, identity: torch.Tensor ) -> torch.Tensor: if tf is None: return identity if tf.ndim == 3: return tf[0] return tf def _load_chain(self) -> pk.Chain: with open(self._urdf_path, mode="r", encoding="utf-8") as f: urdf_text = f.read() return pk.build_chain_from_urdf(urdf_text) def parse( self, ) -> Tuple[ List[str], torch.Tensor, torch.Tensor, torch.Tensor, List[List[str]], Dict[str, int], Dict[str, List[float]], List[str], torch.Tensor, List[List[int]], ]: pk_chain = self._load_chain() dof_names = pk_chain.get_joint_parameter_names() if len(dof_names) == 0: raise ValueError("No actuated joints found in URDF") dof_axis = pk_chain.axes.to(dtype=torch.float32) root_name = pk_chain._root.name moving_frames = pk_chain.get_frame_names(exclude_fixed=True) body_names = [root_name] + [ name for 
name in moving_frames if name != root_name ] body_name_to_index = {name: idx for idx, name in enumerate(body_names)} num_frames = len(pk_chain.idx_to_frame) frame_name_to_index = { name: idx for idx, name in pk_chain.idx_to_frame.items() } full_parent_indices: List[int] = [] for i in range(num_frames): chain_indices = pk_chain.parents_indices[i] if chain_indices.numel() <= 1: full_parent_indices.append(-1) else: full_parent_indices.append(int(chain_indices[-2].item())) identity = torch.eye(4, dtype=torch.float32) frame_transforms: List[torch.Tensor] = [identity] * num_frames for i in range(num_frames): link_offset = self._as_tf(pk_chain.link_offsets[i], identity) joint_offset = self._as_tf(pk_chain.joint_offsets[i], identity) if i == 0: link_offset = identity joint_offset = identity parent = full_parent_indices[i] if parent < 0: frame_tf = identity else: frame_tf = ( frame_transforms[parent] @ link_offset @ joint_offset ) frame_transforms[i] = frame_tf parents: List[int] = [] local_translation: List[List[float]] = [] local_rotation_mat: List[torch.Tensor] = [] body_joint_order: List[List[str]] = [] joint_body_index: Dict[str, int] = {} joint_axis: Dict[str, List[float]] = {} for body_name in body_names: frame_idx = frame_name_to_index[body_name] parent_frame_idx = full_parent_indices[frame_idx] parent_body_idx = -1 while parent_frame_idx >= 0: parent_name = pk_chain.idx_to_frame[parent_frame_idx] if parent_name in body_name_to_index: parent_body_idx = body_name_to_index[parent_name] break parent_frame_idx = full_parent_indices[parent_frame_idx] parents.append(parent_body_idx) if parent_body_idx < 0: local_tf = identity else: local_tf = ( torch.linalg.inv(frame_transforms[parent_frame_idx]) @ frame_transforms[frame_idx] ) local_translation.append(local_tf[:3, 3].tolist()) local_rotation_mat.append(local_tf[:3, :3]) joints_in_body: List[str] = [] joint_index = int(pk_chain.joint_indices[frame_idx].item()) if joint_index >= 0: joint_type = 
int(pk_chain.joint_type_indices[frame_idx].item()) if joint_type != 1: raise ValueError( f"Unsupported joint type index: {joint_type}" ) joint_name = dof_names[joint_index] joints_in_body.append(joint_name) joint_body_index[joint_name] = body_name_to_index[body_name] joint_axis[joint_name] = dof_axis[joint_index].tolist() body_joint_order.append(joints_in_body) local_rotation = torch_utils.quat_from_matrix( torch.stack(local_rotation_mat, dim=0) ) dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)} body_dof_indices: List[List[int]] = [] for joints in body_joint_order: indices: List[int] = [] for name in joints: if name in dof_name_to_index: indices.append(dof_name_to_index[name]) body_dof_indices.append(indices) return ( body_names, torch.tensor(parents, dtype=torch.long), torch.tensor(local_translation, dtype=torch.float32), local_rotation.to(dtype=torch.float32), body_joint_order, joint_body_index, joint_axis, dof_names, dof_axis, body_dof_indices, ) # @torch.compile(dynamic=True) class HoloMotionFK(torch.nn.Module): def __init__( self, robot_file_path: str, device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, ) -> None: super().__init__() self.robot_file_path = robot_file_path _, ext = os.path.splitext(robot_file_path) ext = ext.lower() if ext == ".urdf": parser = URDFParser(robot_file_path) elif ext in [".xml", ".mjcf"]: parser = MJCFParser(robot_file_path) else: raise ValueError(f"Unsupported file extension: {ext}") logger.info( f"Parsing robot file for online forward kinematics: {robot_file_path}..." 
) ( body_names, parents, local_translation, local_rotation, body_joint_order, joint_body_index, joint_axis, dof_names, dof_axis, body_dof_indices, ) = parser.parse() self.body_names = body_names self.dof_names = dof_names self.num_bodies = len(body_names) self.num_dof = len(dof_names) parents = parents.to(device=device) local_translation = local_translation.to(device=device, dtype=dtype) local_rotation = local_rotation.to(device=device, dtype=dtype) local_rotation_mat = torch_utils.matrix_from_quat(local_rotation) dof_axis = dof_axis.to(device=device, dtype=dtype) max_body_dofs = max( (len(indices) for indices in body_dof_indices), default=0 ) body_dof_index_tensor = torch.full( (self.num_bodies, max_body_dofs), -1, dtype=torch.long, ) body_dof_mask = torch.zeros( (self.num_bodies, max_body_dofs), dtype=torch.bool ) for body_idx, indices in enumerate(body_dof_indices): if not indices: continue body_dof_index_tensor[body_idx, : len(indices)] = torch.tensor( indices, dtype=torch.long ) body_dof_mask[body_idx, : len(indices)] = True self.register_buffer("_parents", parents) self.register_buffer("_local_translation", local_translation) self.register_buffer("_local_rotation_mat", local_rotation_mat) self.register_buffer("_dof_axis", dof_axis) self.register_buffer("_body_dof_index_tensor", body_dof_index_tensor) self.register_buffer("_body_dof_mask", body_dof_mask) self._body_joint_order = body_joint_order self._joint_body_index = joint_body_index self._joint_axis = joint_axis self._body_dof_indices = body_dof_indices @torch.no_grad() def forward( self, root_pos: torch.Tensor, root_quat: torch.Tensor, dof_pos: torch.Tensor, fps: float, quat_format: str = "xyzw", sub_batch_size: int = 64, vel_smoothing_sigma: float = 2.0, ) -> Dict[str, torch.Tensor]: """Forward kinematics and smoothed velocities. 
Args: root_pos: (B, T, 3) root_quat: (B, T, 4), XYZW by default dof_pos: (B, T, ndof) fps: frames per second sub_batch_size: split batch into chunks to reduce peak memory vel_smoothing_sigma: Gaussian sigma for smoothing velocity signals along the time axis (set <= 0 to disable). Returns: Dict with global_translation/global_rotation_quat/global_velocity/ global_angular_velocity/dof_pos/dof_vel. """ if fps <= 0.0: raise ValueError(f"Invalid fps: {fps}") if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3: raise ValueError("Inputs must be (B, T, ...)") if ( root_pos.shape[:2] != root_quat.shape[:2] or root_pos.shape[:2] != dof_pos.shape[:2] ): raise ValueError("Mismatched batch/time shapes among inputs") if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4: raise ValueError( "root_pos must be (B,T,3) and root_quat must be (B,T,4)" ) if dof_pos.shape[-1] != self.num_dof: raise ValueError( f"dof_pos last dim {dof_pos.shape[-1]} does not match " f"{self.num_dof}" ) device = self._local_translation.device dtype = self._local_translation.dtype root_pos = root_pos.to(device=device, dtype=dtype) root_quat = root_quat.to(device=device, dtype=dtype) dof_pos = dof_pos.to(device=device, dtype=dtype) batch_size, seq_len = root_pos.shape[:2] if ( sub_batch_size is None or sub_batch_size <= 0 or sub_batch_size >= batch_size ): return self._forward_impl( root_pos=root_pos, root_quat=root_quat, dof_pos=dof_pos, fps=fps, quat_format=quat_format, vel_smoothing_sigma=float(vel_smoothing_sigma), ) global_translation = torch.empty( (batch_size, seq_len, self.num_bodies, 3), device=device, dtype=dtype, ) global_rotation_quat = torch.empty( (batch_size, seq_len, self.num_bodies, 4), device=device, dtype=dtype, ) global_velocity = torch.empty_like(global_translation) global_angular_velocity = torch.empty_like(global_translation) dof_pos_out = torch.empty_like(dof_pos) dof_vel = torch.empty_like(dof_pos) for start in range(0, batch_size, sub_batch_size): end = min(start + 
sub_batch_size, batch_size) out = self._forward_impl( root_pos=root_pos[start:end], root_quat=root_quat[start:end], dof_pos=dof_pos[start:end], fps=fps, quat_format=quat_format, vel_smoothing_sigma=float(vel_smoothing_sigma), ) global_translation[start:end] = out["global_translation"] global_rotation_quat[start:end] = out["global_rotation_quat"] global_velocity[start:end] = out["global_velocity"] global_angular_velocity[start:end] = out["global_angular_velocity"] dof_pos_out[start:end] = out["dof_pos"] dof_vel[start:end] = out["dof_vel"] return { "global_translation": global_translation, "global_rotation_quat": global_rotation_quat, "global_velocity": global_velocity, "global_angular_velocity": global_angular_velocity, "dof_pos": dof_pos_out, "dof_vel": dof_vel, } def _forward_impl( self, root_pos: torch.Tensor, root_quat: torch.Tensor, dof_pos: torch.Tensor, fps: float, quat_format: str, vel_smoothing_sigma: float, ) -> Dict[str, torch.Tensor]: device = self._local_translation.device dtype = self._local_translation.dtype if quat_format == "xyzw": root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat) elif quat_format == "wxyz": root_quat_wxyz = root_quat else: raise ValueError(f"Unsupported quat_format: {quat_format}") root_rotmat = torch_utils.matrix_from_quat(root_quat_wxyz) dof_rotmats = torch_utils.axis_angle_to_matrix(dof_pos, self._dof_axis) positions_world = torch.empty( (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3), device=device, dtype=dtype, ) rotations_world = torch.empty( (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3, 3), device=device, dtype=dtype, ) for i in range(self.num_bodies): parent = int(self._parents[i].item()) if parent < 0: positions_world[:, :, i] = root_pos rotations_world[:, :, i] = root_rotmat continue parent_pos = positions_world[:, :, parent] parent_rot = rotations_world[:, :, parent] offset = self._local_translation[i] pos = parent_pos + torch.einsum("btij,j->bti", parent_rot, offset) rot = torch.matmul(parent_rot, 
self._local_rotation_mat[i]) body_dof_indices = self._body_dof_indices[i] for dof_idx in body_dof_indices: rot = torch.matmul(rot, dof_rotmats[:, :, dof_idx]) positions_world[:, :, i] = pos rotations_world[:, :, i] = rot global_translation = positions_world global_rotation_mat = rotations_world global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat) global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz) dt = 1.0 / fps if dof_pos.shape[1] < 2: dof_vel = torch.zeros_like(dof_pos) else: diff = (dof_pos[:, 1:] - dof_pos[:, :-1]) / dt pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:] dof_vel = torch.cat([diff, pad], dim=1) dof_vel = torch_utils.smooth_time_series( dof_vel, sigma=float(vel_smoothing_sigma), dim=1 ) global_velocity = torch_utils.grad_t(global_translation, dt) global_velocity = torch_utils.smooth_time_series( global_velocity, sigma=float(vel_smoothing_sigma), dim=1 ) if global_quat_xyzw.shape[1] < 2: global_angular_velocity = torch.zeros_like(global_translation) else: q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:]) q0_inv = torch_utils.quat_conjugate( torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1]) ) q_rel = torch_utils.quat_mul(q1, q0_inv) q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True) q_rel = torch_utils.standardize_quaternion(q_rel) identity = torch.tensor( [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype )[None, None, None] q_rel_full = identity.expand( global_quat_xyzw.shape[0], global_quat_xyzw.shape[1], global_quat_xyzw.shape[2], 4, ).clone() q_rel_full[:, :-1] = q_rel global_angular_velocity = ( torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt ) global_angular_velocity = torch_utils.smooth_time_series( global_angular_velocity, sigma=float(vel_smoothing_sigma), dim=1, ) return { "global_translation": global_translation, "global_rotation_quat": global_quat_xyzw, "global_velocity": global_velocity, "global_angular_velocity": global_angular_velocity, "dof_pos": dof_pos, "dof_vel": 
dof_vel, } # class HoloMotionFK_V2(torch.nn.Module): # """ # Use pytorch_kinematics to compute FK. # """ # def __init__( # self, # robot_file_path: str, # device: torch.device | str = "cpu", # dtype: torch.dtype = torch.float32, # ) -> None: # super().__init__() # self.robot_file_path = robot_file_path # urdf_path = os.path.splitext(robot_file_path)[0] + ".urdf" # if not os.path.isfile(urdf_path): # raise FileNotFoundError(f"URDF not found: {urdf_path}") # with open(urdf_path, mode="r", encoding="utf-8") as f: # urdf_text = f.read() # pk_chain = pk.build_chain_from_urdf(urdf_text) # pk_chain = pk_chain.to(dtype=dtype, device=device) # self.dof_names = pk_chain.get_joint_parameter_names() # self.num_dof = len(self.dof_names) # root_name = pk_chain._root.name # moving_frames = pk_chain.get_frame_names(exclude_fixed=True) # self.body_names = [root_name] + [ # name for name in moving_frames if name != root_name # ] # self.num_bodies = len(self.body_names) # body_frame_indices = pk_chain.get_frame_indices(*self.body_names) # self.register_buffer("_body_frame_indices", body_frame_indices) # num_frames = len(pk_chain.idx_to_frame) # identity = torch.eye(4, device=device, dtype=dtype) # link_offsets = [] # joint_offsets = [] # for i in range(num_frames): # link_offset = pk_chain.link_offsets[i] # joint_offset = pk_chain.joint_offsets[i] # if link_offset is None: # link_offset = identity # if joint_offset is None: # joint_offset = identity # if link_offset.ndim == 3: # link_offset = link_offset[0] # if joint_offset.ndim == 3: # joint_offset = joint_offset[0] # link_offsets.append(link_offset) # joint_offsets.append(joint_offset) # if num_frames > 0: # link_offsets[0] = identity # joint_offsets[0] = identity # parent_indices: List[int] = [] # for i in range(num_frames): # chain_indices = pk_chain.parents_indices[i] # if chain_indices.numel() <= 1: # parent_indices.append(-1) # else: # parent_indices.append(int(chain_indices[-2].item())) # self.register_buffer("_pk_axes", 
pk_chain.axes) # self.register_buffer( # "_pk_joint_type_indices", pk_chain.joint_type_indices # ) # self.register_buffer("_pk_joint_indices", pk_chain.joint_indices) # self.register_buffer( # "_pk_link_offsets", torch.stack(link_offsets, dim=0) # ) # self.register_buffer( # "_pk_joint_offsets", torch.stack(joint_offsets, dim=0) # ) # self.register_buffer( # "_pk_parent_indices", # torch.tensor(parent_indices, dtype=torch.long, device=device), # ) # self._num_frames = num_frames # def forward( # self, # root_pos: torch.Tensor, # root_quat: torch.Tensor, # dof_pos: torch.Tensor, # fps: float, # quat_format: str = "xyzw", # ) -> Dict[str, torch.Tensor]: # """ # Args: # root_pos: (B, T, 3) # root_quat: (B, T, 4), XYZW by default # dof_pos: (B, T, ndof) # fps: frames per second # Returns: # Dict with global_translation/global_rotation_quat/global_velocity/ # global_angular_velocity/dof_pos/dof_vel. # """ # if fps <= 0.0: # raise ValueError(f"Invalid fps: {fps}") # if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3: # raise ValueError("Inputs must be (B, T, ...)") # if ( # root_pos.shape[:2] != root_quat.shape[:2] # or root_pos.shape[:2] != dof_pos.shape[:2] # ): # raise ValueError("Mismatched batch/time shapes among inputs") # if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4: # raise ValueError( # "root_pos must be (B,T,3) and root_quat must be (B,T,4)" # ) # if dof_pos.shape[-1] != self.num_dof: # raise ValueError( # f"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}" # ) # device = self._pk_axes.device # dtype = self._pk_axes.dtype # root_pos = root_pos.to(device=device, dtype=dtype) # root_quat = root_quat.to(device=device, dtype=dtype) # dof_pos = dof_pos.to(device=device, dtype=dtype) # if quat_format == "xyzw": # root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat) # elif quat_format == "wxyz": # root_quat_wxyz = root_quat # else: # raise ValueError(f"Unsupported quat_format: {quat_format}") # batch_size, seq_len = 
root_pos.shape[:2] # flat_size = batch_size * seq_len # root_pos_flat = root_pos.reshape(flat_size, 3) # root_quat_flat = root_quat_wxyz.reshape(flat_size, 4) # dof_pos_flat = dof_pos.reshape(flat_size, self.num_dof) # axes_expanded = self._pk_axes[None].expand(flat_size, -1, -1) # revolute_tf = axis_and_angle_to_matrix_44(axes_expanded, dof_pos_flat) # prismatic_tf = axis_and_d_to_pris_matrix(axes_expanded, dof_pos_flat) # frame_transforms = torch.empty( # (flat_size, self._num_frames, 4, 4), device=device, dtype=dtype # ) # identity = torch.eye(4, device=device, dtype=dtype).repeat( # flat_size, 1, 1 # ) # for i in range(self._num_frames): # parent = int(self._pk_parent_indices[i].item()) # if parent < 0: # frame_tf = identity # else: # frame_tf = frame_transforms[:, parent] # frame_tf = frame_tf @ self._pk_link_offsets[i] # frame_tf = frame_tf @ self._pk_joint_offsets[i] # joint_type = int(self._pk_joint_type_indices[i].item()) # if joint_type == 1: # joint_index = int(self._pk_joint_indices[i].item()) # frame_tf = frame_tf @ revolute_tf[:, joint_index] # elif joint_type == 2: # joint_index = int(self._pk_joint_indices[i].item()) # frame_tf = frame_tf @ prismatic_tf[:, joint_index] # frame_transforms[:, i] = frame_tf # chain_tf = torch.index_select( # frame_transforms, 1, self._body_frame_indices # ) # root_rotmat = torch_utils.matrix_from_quat(root_quat_flat) # root_tf = torch.eye(4, device=device, dtype=dtype).repeat( # flat_size, 1, 1 # ) # root_tf[:, :3, :3] = root_rotmat # root_tf[:, :3, 3] = root_pos_flat # world_tf = root_tf[:, None] @ chain_tf # world_tf = world_tf.reshape(batch_size, seq_len, self.num_bodies, 4, 4) # global_translation = world_tf[:, :, :, :3, 3] # global_rotation_mat = world_tf[:, :, :, :3, :3] # global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat) # global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz) # dt = 1.0 / fps # if dof_pos.shape[1] < 2: # dof_vel = torch.zeros_like(dof_pos) # else: # diff = (dof_pos[:, 
1:] - dof_pos[:, :-1]) / dt # pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:] # dof_vel = torch.cat([diff, pad], dim=1) # global_velocity = torch_utils.grad_t(global_translation, dt) # global_velocity = torch_utils.gaussian_filter1d( # global_velocity, sigma=2.0, dim=1 # ) # if global_quat_xyzw.shape[1] < 2: # global_angular_velocity = torch.zeros_like(global_translation) # else: # q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:]) # q0_inv = torch_utils.quat_conjugate( # torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1]) # ) # q_rel = torch_utils.quat_mul(q1, q0_inv) # q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True) # q_rel = torch_utils.standardize_quaternion(q_rel) # identity = torch.tensor( # [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype # )[None, None, None] # q_rel_full = identity.expand( # global_quat_xyzw.shape[0], # global_quat_xyzw.shape[1], # global_quat_xyzw.shape[2], # 4, # ).clone() # q_rel_full[:, :-1] = q_rel # global_angular_velocity = ( # torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt # ) # global_angular_velocity = torch_utils.gaussian_filter1d( # global_angular_velocity, # sigma=2.0, # dim=1, # ) # return { # "global_translation": global_translation, # "global_rotation_quat": global_quat_xyzw, # "global_velocity": global_velocity, # "global_angular_velocity": global_angular_velocity, # "dof_pos": dof_pos, # "dof_vel": dof_vel, # } ================================================ FILE: holomotion/src/motion_retargeting/holomotion_preprocess.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import json import logging import os import sys from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import hydra import numpy as np import ray import torch from loguru import logger from omegaconf import DictConfig, ListConfig, OmegaConf from scipy.spatial.transform import Rotation as sRot from scipy.spatial.transform import Slerp from tqdm import tqdm from holomotion.src.motion_retargeting.utils.torch_humanoid_batch import ( HumanoidBatch, ) from holomotion.src.motion_retargeting.utils import ( rotation_conversions as rot_conv, ) from holomotion.src.motion_retargeting.reference_filtering import ( butterworth_filter_ref_arrays as shared_butterworth_filter_ref_arrays, ) def compute_slices( sequence_len: int, window_size: int, overlap: int ) -> List[Tuple[int, int]]: step = window_size - overlap if step <= 0: raise ValueError("window_size must be > overlap") slices: List[Tuple[int, int]] = [] start = 0 length = int(sequence_len) while start < length: end = min(start + window_size, length) slices.append((start, end)) if end == length: break start += step return slices def _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]: shape = a.shape t = shape[0] return a.reshape(t, -1), shape def _butterworth_lowpass_smooth_time( a: np.ndarray, fps: float, cutoff_hz: float, order: int ) -> np.ndarray: from scipy.signal import butter, filtfilt t = a.shape[0] if t < 3: return a.astype(np.float32, copy=True) if fps <= 0.0 or cutoff_hz <= 0.0: return a.astype(np.float32, copy=True) nyquist = 0.5 * 
float(fps) wn = float(cutoff_hz) / nyquist if wn >= 1.0: wn = 0.999 if wn <= 0.0: return a.astype(np.float32, copy=True) flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False)) b, a_coefs = butter(int(order), wn, btype="low", analog=False) maxlen = max(len(b), len(a_coefs)) padlen_required = max(3 * (maxlen - 1), 3 * maxlen) if t <= padlen_required: return a.astype(np.float32, copy=True) filtered = filtfilt(b, a_coefs, flat, axis=0, method="pad") return filtered.reshape(shape).astype(np.float32, copy=False) def _quat_normalize(q: np.ndarray) -> np.ndarray: norm = np.linalg.norm(q, axis=-1, keepdims=True) norm = np.where(norm == 0.0, 1.0, norm) return (q / norm).astype(np.float32, copy=False) def _quat_hemisphere_align(q: np.ndarray) -> np.ndarray: if q.shape[0] == 0: return q aligned = q.copy() prev = aligned[0] for t in range(1, aligned.shape[0]): dots = np.sum(prev * aligned[t], axis=-1) mask = dots < 0.0 if np.any(mask): aligned[t, mask] = -aligned[t, mask] prev = aligned[t] return aligned def _quat_conjugate(q: np.ndarray) -> np.ndarray: conj = q.copy() conj[..., :3] = -conj[..., :3] return conj def _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray: av = a[..., :3] aw = a[..., 3:4] bv = b[..., :3] bw = b[..., 3:4] cross = np.cross(av, bv) vec = aw * bv + bw * av + cross scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True) return np.concatenate([vec, scalar], axis=-1) def _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray: t = a.shape[0] if t < 2 or dt <= 0.0: return np.zeros_like(a, dtype=np.float32) deriv = np.gradient( a.astype(np.float64, copy=False), dt, axis=0, edge_order=2 if t >= 3 else 1, ) return deriv.astype(np.float32, copy=False) def _angular_velocity_from_quat( q: np.ndarray, q_dot: np.ndarray ) -> np.ndarray: q_conj = _quat_conjugate(q) prod = _quat_multiply(q_conj, q_dot) omega = 2.0 * prod[..., :3] return omega.astype(np.float32, copy=False) def butterworth_filter_ref_arrays( arrays: Dict[str, np.ndarray], 
fps: float, cutoff_hz: float, order: int ) -> Dict[str, np.ndarray]: return shared_butterworth_filter_ref_arrays( arrays=arrays, fps=fps, cutoff_hz=cutoff_hz, order=order, ) def _summary(arr: np.ndarray) -> Dict[str, float]: if arr.size == 0: return { "mean": 0.0, "std": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "q25": 0.0, "q75": 0.0, } return { "mean": float(arr.mean()), "std": float(arr.std()), "median": float(np.median(arr)), "min": float(arr.min()), "max": float(arr.max()), "q25": float(np.quantile(arr, 0.25)), "q75": float(np.quantile(arr, 0.75)), } def _ds_summary(arr: np.ndarray) -> Dict[str, float]: if arr.size == 0: return { "DS_mean": 0.0, "DS_std": 0.0, "DS_median": 0.0, "DS_min": 0.0, "DS_max": 0.0, "DS_q25": 0.0, "DS_q75": 0.0, } return { "DS_mean": float(arr.mean()), "DS_std": float(arr.std()), "DS_median": float(np.median(arr)), "DS_min": float(arr.min()), "DS_max": float(arr.max()), "DS_q25": float(np.quantile(arr, 0.25)), "DS_q75": float(np.quantile(arr, 0.75)), } def _interpolate_linear( start: np.ndarray, end: np.ndarray, num_frames: int ) -> np.ndarray: """Linear interpolation between start and end over num_frames. Returns array where result[0] == start and result[-1] == end. """ start = np.asarray(start, dtype=np.float32) end = np.asarray(end, dtype=np.float32) if num_frames <= 1: return start[None, ...] t = np.linspace(0.0, 1.0, num_frames, dtype=np.float32) for _ in range(start.ndim): t = t[..., None] result = ((1.0 - t) * start + t * end).astype(np.float32) result[0] = start result[-1] = end return result def _interpolate_quaternions_slerp( start_quat: np.ndarray, end_quat: np.ndarray, num_frames: int ) -> np.ndarray: """SLERP interpolation between two quaternions (XYZW format) over num_frames. Args: start_quat: shape [4] in XYZW format end_quat: shape [4] in XYZW format num_frames: number of interpolation frames Returns: shape [num_frames, 4] in XYZW format, with result[0] == start_quat and result[-1] == end_quat. 
""" start_quat = np.asarray(start_quat, dtype=np.float32) end_quat = np.asarray(end_quat, dtype=np.float32) if num_frames <= 1: return start_quat[None, ...] rotations = sRot.from_quat([start_quat, end_quat]) slerp = Slerp([0.0, 1.0], rotations) t = np.linspace(0.0, 1.0, num_frames) result = slerp(t).as_quat().astype(np.float32) result[0] = start_quat result[-1] = end_quat return result def _extract_yaw_only_quat(quat: np.ndarray) -> np.ndarray: """Extract yaw-only quaternion (XYZW format) from a full quaternion. Args: quat: shape [4] in XYZW format Returns: shape [4] in XYZW format with only yaw rotation """ rot = sRot.from_quat(quat) euler = rot.as_euler("xyz", degrees=False) yaw_only_euler = np.array([0.0, 0.0, euler[2]]) yaw_only_rot = sRot.from_euler("xyz", yaw_only_euler, degrees=False) return yaw_only_rot.as_quat().astype(np.float32) def _dof_to_pose_aa( dof_pos: np.ndarray, root_rot_xyzw: np.ndarray, humanoid_fk: "HumanoidBatch", num_augment_joints: int, ) -> np.ndarray: """Convert DOF positions and root rotation to pose axis-angle. Args: dof_pos: shape [T, num_dofs] root_rot_xyzw: shape [T, 4] in XYZW format humanoid_fk: HumanoidBatch instance num_augment_joints: number of augmented joints Returns: pose_aa: shape [T, num_bodies + num_augment_joints, 3] """ dof_t = torch.as_tensor(dof_pos, dtype=torch.float32) T = dof_t.shape[0] root_quat_wxyz = rot_conv.xyzw_to_wxyz( torch.as_tensor(root_rot_xyzw, dtype=torch.float32) ) root_aa = rot_conv.quaternion_to_axis_angle(root_quat_wxyz) joint_aa = humanoid_fk.dof_axis * dof_t[:, :, None] pose_aa = torch.cat( [ root_aa[:, None, :], joint_aa, torch.zeros((T, num_augment_joints, 3), dtype=torch.float32), ], dim=1, ) return pose_aa.numpy().astype(np.float32) def _compute_fk_motion( dof_pos: np.ndarray, root_pos: np.ndarray, root_rot_xyzw: np.ndarray, humanoid_fk: "HumanoidBatch", num_augment_joints: int, fps: float, ) -> Dict[str, np.ndarray]: """Compute all motion arrays from dof_pos, root_pos, and root_rot via FK. 
Args: dof_pos: shape [T, num_dofs] root_pos: shape [T, 3] root_rot_xyzw: shape [T, 4] in XYZW format humanoid_fk: HumanoidBatch instance num_augment_joints: number of augmented joints fps: frames per second Returns: Dict with ref_dof_pos, ref_dof_vel, ref_global_translation, ref_global_rotation_quat, ref_global_velocity, ref_global_angular_velocity, frame_flag """ T = dof_pos.shape[0] dt = 1.0 / fps pose_aa = _dof_to_pose_aa( dof_pos, root_rot_xyzw, humanoid_fk, num_augment_joints ) pose_aa_t = torch.as_tensor(pose_aa, dtype=torch.float32) root_pos_t = torch.as_tensor(root_pos, dtype=torch.float32) fk_result = humanoid_fk.fk_batch( pose_aa_t[None, ...], root_pos_t[None, ...], return_full=True, dt=dt, ) frame_flag = np.ones(T, dtype=np.int32) frame_flag[0] = 0 frame_flag[-1] = 2 return { "ref_dof_pos": fk_result.dof_pos.squeeze(0).numpy().astype(np.float32), "ref_dof_vel": fk_result.dof_vels.squeeze(0) .numpy() .astype(np.float32), "ref_global_translation": fk_result.global_translation.squeeze(0) .numpy() .astype(np.float32), "ref_global_rotation_quat": fk_result.global_rotation.squeeze(0) .numpy() .astype(np.float32), "ref_global_velocity": fk_result.global_velocity.squeeze(0) .numpy() .astype(np.float32), "ref_global_angular_velocity": fk_result.global_angular_velocity.squeeze( 0 ) .numpy() .astype(np.float32), "frame_flag": frame_flag, } @dataclass class ProcessedClip: motion_key: str metadata: Dict[str, Any] arrays: Dict[str, np.ndarray] class HoloMotionPreprocessor: """ Composable preprocessing pipeline operating on standardized HoloMotion NPZ clips. Supports per-clip stages like slicing and Butterworth filtering, plus dataset-level kinematic tagging. 
""" def __init__( self, slicing_cfg: Optional[DictConfig] = None, filtering_cfg: Optional[DictConfig] = None, tagging_cfg: Optional[DictConfig] = None, padding_cfg: Optional[DictConfig] = None, pipeline: Optional[List[str]] = None, ) -> None: self.slicing_cfg = slicing_cfg self.filtering_cfg = filtering_cfg self.tagging_cfg = tagging_cfg self.padding_cfg = padding_cfg self.pipeline = self._resolve_pipeline(pipeline) self._humanoid_fk: Optional[HumanoidBatch] = None self._robot_cfg: Optional[DictConfig] = None def _resolve_pipeline(self, pipeline: Optional[List[str]]) -> List[str]: if pipeline is not None: return list(pipeline) return [] def process_clip(self, clip: ProcessedClip) -> List[ProcessedClip]: clips = [clip] logger.debug( f"Processing clip '{clip.motion_key}' with pipeline: {self.pipeline}" ) for stage in self.pipeline: logger.debug(f"Applying stage: {stage}") if stage in ("slicing", "slice"): next_clips: List[ProcessedClip] = [] for c in clips: next_clips.extend(self._apply_slicing(c)) clips = next_clips logger.debug(f"After slicing: {len(clips)} clips") elif stage in ( "apply_butterworth_filter", "filtering", "butterworth_filter", ): clips = [self._apply_filtering(c) for c in clips] logger.debug( f"After apply_butterworth_filter: {len(clips)} clips" ) elif stage == "filename_as_motionkey": clips = [self._apply_filename_as_motionkey(c) for c in clips] logger.debug( f"After filename_as_motionkey: {len(clips)} clips" ) elif stage == "legacy_to_ref_keys": clips = [self._apply_legacy_to_ref_keys(c) for c in clips] logger.debug(f"After legacy_to_ref_keys: {len(clips)} clips") elif stage == "add_legacy_keys": clips = [self._apply_add_legacy_keys(c) for c in clips] logger.debug(f"After add_legacy_keys: {len(clips)} clips") elif stage == "add_padding": clips = [self._apply_add_padding(c) for c in clips] logger.debug(f"After add_padding: {len(clips)} clips") else: logger.warning( f"Unknown preprocessing stage '{stage}' ignored." 
) return clips def _apply_slicing(self, clip: ProcessedClip) -> List[ProcessedClip]: cfg = self.slicing_cfg if cfg is None: logger.warning( "Slicing requested but slicing_cfg is None - skipping slicing" ) return [clip] window_size = int(getattr(cfg, "window_size", 0)) overlap = int(getattr(cfg, "overlap", 0)) seq_len = int(clip.metadata.get("num_frames", 0)) if seq_len <= 0: return [clip] slice_specs = compute_slices(seq_len, window_size, overlap) if not slice_specs: return [clip] fps = float(clip.metadata.get("motion_fps", 0.0)) raw_motion_key = str( clip.metadata.get( "raw_motion_key", clip.metadata.get("motion_key", "") ) ) base_motion_key = str(clip.metadata.get("motion_key", raw_motion_key)) arrays = clip.arrays out_clips: List[ProcessedClip] = [] for s, e in slice_specs: arrays_window: Dict[str, np.ndarray] = {} for k, v in arrays.items(): if ( isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] == seq_len ): arrays_window[k] = v[s:e] else: arrays_window[k] = v num_frames = int(e - s) if num_frames <= 0: continue wallclock_len = float(num_frames - 1) / fps if fps > 0.0 else 0.0 if s == 0 and e == seq_len: motion_key = base_motion_key else: motion_key = f"{base_motion_key}_s{s}_e{e}" meta = dict(clip.metadata) meta["motion_key"] = motion_key meta["raw_motion_key"] = raw_motion_key meta["num_frames"] = num_frames meta["wallclock_len"] = wallclock_len meta["slice_start"] = int(s) meta["slice_end"] = int(e) out_clips.append( ProcessedClip( motion_key=motion_key, metadata=meta, arrays=arrays_window, ) ) return out_clips def _apply_filtering(self, clip: ProcessedClip) -> ProcessedClip: cfg = self.filtering_cfg if cfg is None: logger.warning( "Filtering requested but filtering_cfg is None - skipping filtering" ) return clip fps = float(clip.metadata.get("motion_fps", 0.0)) cutoff = float(getattr(cfg, "butter_cutoff_hz", 0.0)) order = int(getattr(cfg, "butter_order", 4)) ft = butterworth_filter_ref_arrays( clip.arrays, fps=fps, cutoff_hz=cutoff, order=order ) 
arrays = dict(clip.arrays) arrays.update(ft) return ProcessedClip( motion_key=clip.motion_key, metadata=clip.metadata, arrays=arrays, ) def _apply_filename_as_motionkey( self, clip: ProcessedClip ) -> ProcessedClip: filename = clip.metadata.get("source_filename", None) if filename is None: logger.warning( "filename_as_motionkey requested but source_filename not found in metadata - skipping" ) return clip new_motion_key = str(filename) meta = dict(clip.metadata) meta["motion_key"] = new_motion_key if "raw_motion_key" not in meta: meta["raw_motion_key"] = clip.motion_key return ProcessedClip( motion_key=new_motion_key, metadata=meta, arrays=clip.arrays, ) def _apply_add_legacy_keys(self, clip: ProcessedClip) -> ProcessedClip: """Add deprecated legacy keys for backward compatibility. Maps ref_* keys to legacy unprefixed keys according to spec: - ref_dof_pos -> dof_pos - ref_dof_vel -> dof_vels - ref_global_translation -> global_translation - ref_global_rotation_quat -> global_rotation_quat - ref_global_velocity -> global_velocity - ref_global_angular_velocity -> global_angular_velocity """ ref_to_legacy = { "ref_dof_pos": "dof_pos", "ref_dof_vel": "dof_vels", "ref_global_translation": "global_translation", "ref_global_rotation_quat": "global_rotation_quat", "ref_global_velocity": "global_velocity", "ref_global_angular_velocity": "global_angular_velocity", } arrays = dict(clip.arrays) for ref_key, legacy_key in ref_to_legacy.items(): if ref_key in arrays: if legacy_key not in arrays: arrays[legacy_key] = arrays[ref_key].copy() logger.debug( f"Added legacy key '{legacy_key}' from '{ref_key}'" ) else: logger.debug( f"Legacy key '{legacy_key}' already exists, skipping" ) return ProcessedClip( motion_key=clip.motion_key, metadata=clip.metadata, arrays=arrays, ) def _apply_legacy_to_ref_keys(self, clip: ProcessedClip) -> ProcessedClip: """Add new ref_* keys from legacy unprefixed keys. 
Maps legacy keys to ref_* keys according to spec while keeping the original legacy arrays: - dof_pos -> ref_dof_pos - dof_vels -> ref_dof_vel - global_translation -> ref_global_translation - global_rotation_quat -> ref_global_rotation_quat - global_velocity -> ref_global_velocity - global_angular_velocity -> ref_global_angular_velocity """ legacy_to_ref = { "dof_pos": "ref_dof_pos", "dof_vels": "ref_dof_vel", "global_translation": "ref_global_translation", "global_rotation_quat": "ref_global_rotation_quat", "global_velocity": "ref_global_velocity", "global_angular_velocity": "ref_global_angular_velocity", } arrays = dict(clip.arrays) for legacy_key, ref_key in legacy_to_ref.items(): if legacy_key in arrays: if ref_key not in arrays: arrays[ref_key] = arrays[legacy_key].copy() logger.debug( f"Added ref key '{ref_key}' from legacy key '{legacy_key}'" ) else: logger.debug( f"Ref key '{ref_key}' already exists, skipping" ) return ProcessedClip( motion_key=clip.motion_key, metadata=clip.metadata, arrays=arrays, ) def _get_humanoid_fk(self) -> HumanoidBatch: """Lazy-load and cache HumanoidBatch for FK computation.""" if self._humanoid_fk is not None: return self._humanoid_fk cfg = self.padding_cfg robot_config_path = str(getattr(cfg, "robot_config_path", "")) self._robot_cfg = OmegaConf.load(robot_config_path) self._humanoid_fk = HumanoidBatch(self._robot_cfg.robot) return self._humanoid_fk def _get_default_dof_pos(self) -> np.ndarray: """Get default DOF positions from robot config.""" robot_cfg = self._robot_cfg.robot dof_names = list(robot_cfg.dof_names) init_state = robot_cfg.get("init_state", {}) default_angles = init_state.get("default_joint_angles", {}) default_dof = np.zeros(len(dof_names), dtype=np.float32) for i, name in enumerate(dof_names): default_dof[i] = float(default_angles.get(name, 0.0)) return default_dof def _apply_add_padding(self, clip: ProcessedClip) -> ProcessedClip: """Add transition and static padding to the motion clip. 
Adds stand-still padding at default pose before and after the motion, with smooth transitions between default pose and the motion's first/last frames. Recalculates all states from root pos, rot and dof pos via FK. """ cfg = self.padding_cfg if cfg is None: logger.warning( "Padding requested but padding_cfg is None - skipping padding" ) return clip fps = float(clip.metadata.get("motion_fps", 50.0)) stand_still_time = float(getattr(cfg, "stand_still_time", 1.0)) transition_time = float(getattr(cfg, "transition_time", 1.5)) robot_config_path = str(getattr(cfg, "robot_config_path", "")) if not robot_config_path: raise ValueError( "robot_config_path must be specified in padding_cfg" ) stand_still_frames = max(1, int(stand_still_time * fps)) transition_frames = max(1, int(transition_time * fps)) humanoid_fk = self._get_humanoid_fk() default_dof = self._get_default_dof_pos() extend_config = self._robot_cfg.robot.get("extend_config", []) num_augment = len(extend_config) if extend_config else 0 # Get root offset from HumanoidBatch (usually from MJCF root body pos) # self._offsets is [1, num_bodies, 3] # root_offset = humanoid_fk._offsets[0, 0].cpu().numpy() arrays = clip.arrays dof_pos = arrays.get("ref_dof_pos", arrays.get("dof_pos")) global_trans = arrays.get( "ref_global_translation", arrays.get("global_translation") ) global_rot = arrays.get( "ref_global_rotation_quat", arrays.get("global_rotation_quat") ) if dof_pos is None or global_trans is None or global_rot is None: raise ValueError( "Missing required arrays for padding: ref_dof_pos, " "ref_global_translation, or ref_global_rotation_quat" ) T_orig = dof_pos.shape[0] dof_pos = dof_pos.astype(np.float32, copy=True) root_pos = global_trans[:, 0, :].astype(np.float32, copy=True) root_rot = global_rot[:, 0, :].astype(np.float32, copy=True) first_dof = dof_pos[0].copy() last_dof = dof_pos[-1].copy() first_root_pos = root_pos[0].copy() last_root_pos = root_pos[-1].copy() first_root_rot = root_rot[0].copy() last_root_rot = 
root_rot[-1].copy() logger.debug( f"Padding: T_orig={T_orig}, first_root_pos={first_root_pos}, " f"last_root_pos={last_root_pos}" ) logger.debug( f"Padding: first_dof[:3]={first_dof[:3]}, last_dof[:3]={last_dof[:3]}" ) logger.debug( f"Padding: original dof_pos[-1][:3]={dof_pos[-1][:3]}, " f"original root_pos[-1]={root_pos[-1]}, original root_rot[-1]={root_rot[-1]}" ) first_yaw_quat = _extract_yaw_only_quat(first_root_rot) last_yaw_quat = _extract_yaw_only_quat(last_root_rot) start_stand_dof = np.tile(default_dof, (stand_still_frames, 1)) start_trans_dof = _interpolate_linear( default_dof, first_dof, transition_frames ) end_trans_dof = _interpolate_linear( last_dof, default_dof, transition_frames ) end_stand_dof = np.tile(default_dof, (stand_still_frames, 1)) start_stand_root_pos = np.tile(first_root_pos, (stand_still_frames, 1)) start_trans_root_pos = _interpolate_linear( first_root_pos, first_root_pos, transition_frames ) end_trans_root_pos = _interpolate_linear( last_root_pos, last_root_pos, transition_frames ) end_stand_root_pos = np.tile(last_root_pos, (stand_still_frames, 1)) start_stand_root_rot = np.tile(first_yaw_quat, (stand_still_frames, 1)) start_trans_root_rot = _interpolate_quaternions_slerp( first_yaw_quat, first_root_rot, transition_frames ) end_trans_root_rot = _interpolate_quaternions_slerp( last_root_rot, last_yaw_quat, transition_frames ) end_stand_root_rot = np.tile(last_yaw_quat, (stand_still_frames, 1)) # Construct full sequence of inputs full_dof = np.concatenate( [ start_stand_dof, start_trans_dof, dof_pos, end_trans_dof, end_stand_dof, ], axis=0, ) full_root_pos = np.concatenate( [ start_stand_root_pos, start_trans_root_pos, root_pos, end_trans_root_pos, end_stand_root_pos, ], axis=0, ) full_root_rot = np.concatenate( [ start_stand_root_rot, start_trans_root_rot, root_rot, end_trans_root_rot, end_stand_root_rot, ], axis=0, ) # Compute FK for the entire sequence to ensure continuity new_arrays = _compute_fk_motion( full_dof, full_root_pos, 
full_root_rot, humanoid_fk, num_augment, fps, ) T_new = full_dof.shape[0] wallclock_len = float(T_new - 1) / fps if fps > 0.0 else 0.0 meta = dict(clip.metadata) meta["num_frames"] = T_new meta["wallclock_len"] = wallclock_len meta["padding_stand_still_frames"] = stand_still_frames meta["padding_transition_frames"] = transition_frames meta["original_num_frames"] = T_orig return ProcessedClip( motion_key=clip.motion_key, metadata=meta, arrays=new_arrays, ) def process_npz_file(self, npz_path: Path) -> List[ProcessedClip]: with np.load(npz_path, allow_pickle=False) as data: if "metadata" not in data: raise KeyError(f"'metadata' missing in NPZ: {npz_path}") meta_text = str(data["metadata"]) metadata = json.loads(meta_text) motion_key = str(metadata["motion_key"]) arrays: Dict[str, np.ndarray] = {} for k in data.files: if k == "metadata": continue arrays[k] = np.array(data[k], copy=False) filename_without_ext = npz_path.stem metadata["source_filename"] = filename_without_ext base_clip = ProcessedClip( motion_key=motion_key, metadata=metadata, arrays=arrays, ) return self.process_clip(base_clip) def run_on_directory( self, src_root: Path, out_root: Path, use_ray: bool = False, num_workers: int = 0, ) -> None: if src_root.is_dir(): if (src_root / "clips").is_dir(): clips_src = src_root / "clips" else: clips_src = src_root else: raise ValueError(f"Source root is not a directory: {src_root}") clips_dst = out_root / "clips" clips_dst.mkdir(parents=True, exist_ok=True) files = sorted([p for p in clips_src.rglob("*.npz") if p.is_file()]) if not files: logger.info("No NPZ files found to process.") return if use_ray: if num_workers <= 0: available_cpus = int(ray.available_resources().get("CPU", 1)) effective_workers = max(1, available_cpus) else: effective_workers = num_workers self._run_on_directory_ray(files, clips_dst, effective_workers) else: self._run_on_directory_sequential(files, clips_dst) def _run_on_directory_sequential( self, files: List[Path], clips_dst: Path ) -> 
None: logger.info(f"Processing {len(files)} NPZ files sequentially") logger.info(f"Pipeline stages to apply: {self.pipeline}") total_input_clips = 0 total_output_clips = 0 for p in tqdm(files, desc="HoloMotion preprocess NPZ", unit="file"): clips = self.process_npz_file(p) total_input_clips += 1 for clip in clips: total_output_clips += 1 out_name = f"{clip.motion_key}.npz" out_path = clips_dst / out_name metadata_json = json.dumps(clip.metadata) np.savez_compressed( out_path, metadata=metadata_json, **clip.arrays ) logger.info( f"Processed {total_input_clips} input files into {total_output_clips} output clips" ) def _run_on_directory_ray( self, files: List[Path], clips_dst: Path, num_workers: int ) -> None: if num_workers <= 0: available_cpus = int(ray.available_resources().get("CPU", 1)) num_actors = min(len(files), max(1, available_cpus)) else: num_actors = min(len(files), num_workers) actors = [ PreprocessorActor.remote( slicing_cfg=self.slicing_cfg, filtering_cfg=self.filtering_cfg, tagging_cfg=self.tagging_cfg, padding_cfg=self.padding_cfg, pipeline=self.pipeline, ) for _ in range(num_actors) ] pending = {} next_idx = 0 for i in range(min(num_actors, len(files))): p = files[next_idx] next_idx += 1 ref = actors[i].process_npz_file.remote(str(p)) pending[ref] = i total_outputs = 0 with tqdm( total=len(files), desc="Ray: HoloMotion preprocess NPZ" ) as pbar: while pending: done, _ = ray.wait(list(pending.keys()), num_returns=1) ref = done[0] actor_idx = pending.pop(ref) clips = ray.get(ref) for clip in clips: out_name = f"{clip.motion_key}.npz" out_path = clips_dst / out_name metadata_json = json.dumps(clip.metadata) np.savez_compressed( out_path, metadata=metadata_json, **clip.arrays ) total_outputs += 1 pbar.update(1) if next_idx < len(files): p = files[next_idx] next_idx += 1 new_ref = actors[actor_idx].process_npz_file.remote(str(p)) pending[new_ref] = actor_idx logger.info(f"Processed {total_outputs} clips total.") def tag_directory(self, clips_dir: Path, 
tags_path: Path) -> None: files = sorted([p for p in clips_dir.rglob("*.npz") if p.is_file()]) clip_info: Dict[str, Dict[str, Dict[str, float]]] = {} all_speed: List[np.ndarray] = [] all_wnorm: List[np.ndarray] = [] all_zrel: List[np.ndarray] = [] all_jerk: List[np.ndarray] = [] for f in tqdm(files, desc="Tagging kinematics", unit="file"): with np.load(f, allow_pickle=True) as data: meta_text = str(data["metadata"]) meta = json.loads(meta_text) key = str(meta["motion_key"]) fps = float(meta["motion_fps"]) def pick(name: str) -> np.ndarray: if f"ft_ref_{name}" in data: return np.array(data[f"ft_ref_{name}"], copy=False) if f"ref_{name}" in data: return np.array(data[f"ref_{name}"], copy=False) return np.array([], dtype=np.float32) gv = pick("global_velocity") ga = pick("global_angular_velocity") gt = pick("global_translation") if gv.size > 0: root_vel = gv[:, 0, :] speed = np.linalg.norm(root_vel, axis=1) else: speed = np.array([], dtype=float) if ga.size > 0: root_w = ga[:, 0, :] wnorm = np.linalg.norm(root_w, axis=1) else: wnorm = np.array([], dtype=float) if gt.size > 0: root_pos_z = gt[:, 0, 2] z_rel = np.abs(root_pos_z - float(root_pos_z[0])) else: z_rel = np.array([], dtype=float) if gv.shape[0] >= 3: dt = 1.0 / fps if fps > 0.0 else 0.0 a = ( np.diff(gv, axis=0) / dt if dt > 0.0 else np.zeros_like(gv) ) j = ( np.diff(a, axis=0) / dt if dt > 0.0 else np.zeros_like(a) ) jn = np.linalg.norm(j, axis=2) else: jn = np.array([], dtype=float) clip_info[key] = { "root_linear_speed": _summary(speed), "root_angular_speed": _summary(wnorm), "root_delta_z": _summary(z_rel), "jerk": _summary(jn), } if speed.size > 0: all_speed.append(speed.astype(float)) if wnorm.size > 0: all_wnorm.append(wnorm.astype(float)) if z_rel.size > 0: all_zrel.append(z_rel.astype(float)) if jn.size > 0: all_jerk.append(jn.astype(float)) speed_cat = ( np.concatenate([a for a in all_speed if a.size > 0], axis=0) if len(all_speed) > 0 else np.array([], dtype=float) ) wnorm_cat = ( np.concatenate([a 
for a in all_wnorm if a.size > 0], axis=0) if len(all_wnorm) > 0 else np.array([], dtype=float) ) zrel_cat = ( np.concatenate([a for a in all_zrel if a.size > 0], axis=0) if len(all_zrel) > 0 else np.array([], dtype=float) ) jerk_cat = ( np.concatenate([a for a in all_jerk if a.size > 0], axis=0) if len(all_jerk) > 0 else np.array([], dtype=float) ) result = { "dataset_stats": { "root_linear_speed": _ds_summary(speed_cat), "root_angular_speed": _ds_summary(wnorm_cat), "root_delta_z": _ds_summary(zrel_cat), "jerk": _ds_summary(jerk_cat), }, "clip_info": clip_info, } with open(tags_path, "w") as f: json.dump(result, f, indent=2, sort_keys=True) logger.info(f"Wrote kinematic tags JSON to: {tags_path}") @ray.remote class PreprocessorActor: """Ray actor that holds a HoloMotionPreprocessor instance for parallel processing.""" def __init__( self, slicing_cfg: Optional[DictConfig] = None, filtering_cfg: Optional[DictConfig] = None, tagging_cfg: Optional[DictConfig] = None, padding_cfg: Optional[DictConfig] = None, pipeline: Optional[List[str]] = None, ) -> None: self.preprocessor = HoloMotionPreprocessor( slicing_cfg=slicing_cfg, filtering_cfg=filtering_cfg, tagging_cfg=tagging_cfg, padding_cfg=padding_cfg, pipeline=pipeline, ) logger.debug( f"PreprocessorActor initialized with pipeline: {self.preprocessor.pipeline}" ) def process_npz_file(self, npz_path_str: str) -> List[ProcessedClip]: npz_path = Path(npz_path_str) return self.preprocessor.process_npz_file(npz_path) @hydra.main( config_path="../../config", config_name="motion_retargeting/holomotion_preprocess", version_base=None, ) def main(cfg: DictConfig) -> None: logger.remove() logger.add(sys.stderr, level="INFO", colorize=True) src_root = Path(str(cfg.io.src_root)).expanduser().resolve() out_root = Path(str(cfg.io.out_root)).expanduser().resolve() out_root.mkdir(parents=True, exist_ok=True) # Dump resolved config used with open(out_root / "config_used.yaml", "w") as f: f.write(OmegaConf.to_yaml(cfg)) # Parse 
pipeline pipeline_cfg = cfg.get("preprocess", None) logger.debug(f"Raw preprocess config: {pipeline_cfg}") pipeline = None if pipeline_cfg is not None: pipeline_val = pipeline_cfg.get("pipeline", None) logger.debug( f"Raw pipeline value: {pipeline_val} (type: {type(pipeline_val)})" ) if pipeline_val is not None: if isinstance(pipeline_val, (list, tuple, ListConfig)): pipeline = [str(s) for s in pipeline_val] elif isinstance(pipeline_val, str): import ast pipeline = ast.literal_eval(pipeline_val) else: logger.warning( f"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}" ) pipeline = [] else: logger.debug("pipeline_val is None") else: logger.debug("preprocess config is None") # Separate per-clip stages from dataset-level stages per_clip_pipeline = ( [s for s in pipeline if s != "tagging"] if pipeline else [] ) tagging_enabled = pipeline and "tagging" in pipeline logger.info("=" * 80) logger.info("Preprocessing Configuration:") logger.info(f" Source directory: {src_root}") logger.info(f" Output directory: {out_root}") if pipeline: logger.info(f" Pipeline stages: {pipeline}") logger.info(f" Number of stages: {len(pipeline)}") for i, stage in enumerate(pipeline, 1): logger.info(f" {i}. {stage}") if tagging_enabled: logger.info( " Note: 'tagging' is a dataset-level operation and will run after all clips are processed" ) else: logger.warning( " No preprocessing pipeline specified - no processors will be applied!" 
) logger.info("=" * 80) use_ray = bool(cfg.get("ray", {}).get("enabled", False)) num_workers = int(cfg.get("ray", {}).get("num_workers", 0)) ray_address = str(cfg.get("ray", {}).get("ray_address", "")) if use_ray: logging.getLogger("filelock").setLevel(logging.WARNING) logging.getLogger("ray").setLevel(logging.ERROR) os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", "error") if ray_address: ray.init( address=ray_address, ignore_reinit_error=True, log_to_driver=False, include_dashboard=False, logging_level=logging.ERROR, ) if num_workers <= 0: num_workers = int(ray.available_resources().get("CPU", 1)) else: num_cpus = None if num_workers <= 0 else num_workers ray.init( num_cpus=num_cpus, ignore_reinit_error=True, log_to_driver=False, include_dashboard=False, logging_level=logging.ERROR, ) if num_workers <= 0: num_workers = int(ray.available_resources().get("CPU", 1)) preprocessor = HoloMotionPreprocessor( slicing_cfg=cfg.slicing, filtering_cfg=cfg.filtering, tagging_cfg=cfg.tagging, padding_cfg=cfg.get("padding", None), pipeline=per_clip_pipeline if per_clip_pipeline else None, ) logger.info( f"Preprocessor initialized with pipeline: {preprocessor.pipeline}" ) logger.info( f" Slicing config present: {preprocessor.slicing_cfg is not None}" ) logger.info( f" Filtering config present: {preprocessor.filtering_cfg is not None}" ) logger.info( f" Tagging config present: {preprocessor.tagging_cfg is not None}" ) preprocessor.run_on_directory( src_root, out_root, use_ray=use_ray, num_workers=num_workers ) if use_ray: ray.shutdown() if tagging_enabled: if str(cfg.tagging.output_json_path): tags_path = ( Path(str(cfg.tagging.output_json_path)).expanduser().resolve() ) else: tags_path = out_root / "kinematic_tags.json" clips_dir = out_root / "clips" preprocessor.tag_directory(clips_dir, tags_path) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/motion_retargeting/kinematic_filter.py 
================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import json import sys from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple import hydra import yaml from loguru import logger from omegaconf import DictConfig, OmegaConf from tqdm import tqdm def _eval_rule(val: float, op: str, thr: float) -> bool: if op == ">": return val > thr if op == ">=": return val >= thr if op == "<": return val < thr if op == "<=": return val <= thr if op == "==": return val == thr if op == "!=": return val != thr raise ValueError(f"Unsupported op: {op}") def _deep_get(container: Dict[str, Any], parts: List[str]) -> Optional[float]: cur: Any = container for p in parts: if not isinstance(cur, dict) or p not in cur: return None cur = cur[p] if isinstance(cur, (int, float)): return float(cur) return None def _resolve_value( tags_root: Dict[str, Any], clip_group: Dict[str, Any], path: str, ) -> Optional[float]: """Resolve a threshold path to a numeric value. - dataset_stats.. reads from tags_root - kinematic_features.. reads from clip_group - . (no prefix) also reads from clip_group for convenience. 
""" parts = str(path).split(".") if len(parts) == 0: return None if parts[0] == "dataset_stats": return _deep_get(tags_root, parts) if parts[0] == "kinematic_features": return _deep_get(clip_group, parts[1:]) return _deep_get(clip_group, parts) def filter_with_schema( tags: Dict[str, Any], schema: Dict[str, Any], ) -> Tuple[Set[str], Dict[str, int], Dict[str, int]]: thresholds: Dict[str, Dict[str, Any]] = schema.get("thresholds", {}) or {} across_mode = str(schema.get("across", "union")) out: Set[str] = set() path_counts: Dict[str, int] = {} group_counts: Dict[str, int] = {} clips: Dict[str, Dict[str, Any]] = tags.get("clip_info", {}) or {} for motion_key, groups in tqdm( clips.items(), desc="Evaluating schema", unit="clip" ): hits: List[bool] = [] hits_by_path: Dict[str, bool] = {} group_hit_any: Dict[str, bool] = {} for path, spec in thresholds.items(): parts = str(path).split(".") if len(parts) == 0: continue val = _resolve_value(tags, groups, path) if val is None: continue op = str(spec.get("op", ">")) thr = float(spec["value"]) hit = _eval_rule(val, op, thr) hits.append(hit) hits_by_path[path] = hit grp = parts[0] if hit: group_hit_any[grp] = True if len(hits) == 0: continue if across_mode == "union": excluded = any(hits) elif across_mode == "intersection": excluded = all(hits) else: raise ValueError(f"Invalid across mode: {across_mode}") if not excluded: continue out.add(motion_key) # accumulate counts for excluded clips for pth, hit in hits_by_path.items(): if hit: path_counts[pth] = path_counts.get(pth, 0) + 1 for grp, any_hit in group_hit_any.items(): if any_hit: group_counts[grp] = group_counts.get(grp, 0) + 1 return out, path_counts, group_counts def _default_schema_path() -> Path: # holomotion/src/motion_retargeting/kinematic_filter.py # -> holomotion/config/motion_retargeting/kinematic_filtering_schema.yaml this_file = Path(__file__).resolve() holomotion_dir = this_file.parents[2] return ( holomotion_dir / "config" / "motion_retargeting" / 
"kinematic_filtering_schema.yaml" ) def run( dataset_root: str, schema_yaml_path: Optional[str] = None, output_yaml_path: Optional[str] = None, schema_obj: Optional[Dict[str, Any]] = None, ) -> Set[str]: """Execute kinematic filtering using tags and a schema. - dataset_root: directory containing 'kinematic_tags.json' - schema_yaml_path: external YAML with 'across' and 'thresholds' (optional) - schema_obj: inline dict with 'across' and 'thresholds' (optional) - output_yaml_path: where to write the excluded list YAML (optional) """ root = Path(dataset_root).expanduser().resolve() tags_path = root / "kinematic_tags.json" if not tags_path.is_file(): raise FileNotFoundError(f"Missing kinematic tags JSON: {tags_path}") schema: Dict[str, Any] if schema_obj is not None: schema = dict(schema_obj) else: schema_path = ( Path(schema_yaml_path).expanduser().resolve() if schema_yaml_path else _default_schema_path() ) if not schema_path.is_file(): raise FileNotFoundError(f"Missing schema YAML: {schema_path}") schema = yaml.safe_load(open(schema_path, "r", encoding="utf-8")) out_yaml = ( Path(output_yaml_path).expanduser().resolve() if output_yaml_path else (root / "excluded_kinematic_motion_names.yaml") ) logger.info(f"Dataset root: {root}") logger.info(f"Reading tags from: {tags_path}") logger.info( "Using schema from: inline config" if schema_obj is not None else "Using schema from YAML file" ) # Pretty-print resolved schema to console try: logger.info( "Resolved schema:\n" + yaml.safe_dump(schema, sort_keys=True, default_flow_style=False) ) except Exception: pass tags = json.load(open(tags_path, "r", encoding="utf-8")) # Dump the used filter config into dataset root try: used_cfg = { "dataset_root": str(root), "output_yaml": str(out_yaml), "schema": schema, } with open( root / "kinematic_filter_config_used.yaml", "w", encoding="utf-8" ) as f: yaml.safe_dump( used_cfg, f, sort_keys=True, default_flow_style=False ) except Exception: pass excluded_keys, path_counts, group_counts 
= filter_with_schema(tags, schema) with open(out_yaml, "w", encoding="utf-8") as f: f.write("# @package _global_\n\n") f.write("excluded_motion_names:\n") for k in sorted(excluded_keys): f.write(f"- {k}\n") logger.info(f"Excluded by config: {len(excluded_keys)}") if len(group_counts) > 0: logger.info("Excluded counts by category:") for grp, cnt in sorted( group_counts.items(), key=lambda kv: kv[1], reverse=True ): logger.info(f"- {grp}: {cnt}") if len(path_counts) > 0: logger.info("Excluded counts by threshold path:") for pth, cnt in sorted( path_counts.items(), key=lambda kv: kv[1], reverse=True ): logger.info(f"- {pth}: {cnt}") logger.info(f"Wrote excluded list to: {out_yaml}") return excluded_keys @hydra.main( config_path="../../config", config_name="motion_retargeting/kinematic_filter", version_base=None, ) def main(cfg: DictConfig) -> None: logger.remove() logger.add(sys.stderr, level="INFO", colorize=True) dataset_root = str(cfg.io.dataset_root) # Optional fields (external schema YAML override and output path) schema_val = "" out_val = "" schema_obj = None if "schema" in cfg: # Inline schema object schema_obj = OmegaConf.to_object(cfg.schema) if "filtering" in cfg and hasattr(cfg.filtering, "schema_yaml"): schema_val = str(cfg.filtering.get("schema_yaml", "") or "") out_val = str(cfg.filtering.get("output_yaml", "") or "") elif "filtering" in cfg: out_val = str(cfg.filtering.get("output_yaml", "") or "") schema_yaml_path = schema_val if len(schema_val) > 0 else None output_yaml_path = out_val if len(out_val) > 0 else None run( dataset_root=dataset_root, schema_yaml_path=schema_yaml_path, schema_obj=schema_obj, output_yaml_path=output_yaml_path, ) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/motion_retargeting/pack_hdf5_v2.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import json import os from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import h5py import hydra import numpy as np from loguru import logger from omegaconf import ListConfig, OmegaConf from tqdm import tqdm def _ensure_dir(path: str) -> None: os.makedirs(path, exist_ok=True) @dataclass class ArraySpec: name: str shape_tail: Tuple[int, ...] # shape excluding time dim dtype: np.dtype @dataclass class ClipEntry: clip_id: int name: str path: str class Hdf5ShardWriter: def __init__( self, h5_path: str, array_specs: List[ArraySpec], chunks_t: int, compression: str, ) -> None: self.h5_path = h5_path self.array_specs = array_specs self.chunks_t = int(chunks_t) self.compression = compression _ensure_dir(os.path.dirname(self.h5_path)) self.h5 = h5py.File(self.h5_path, "w") self.datasets: Dict[str, h5py.Dataset] = {} for spec in self.array_specs: chunks = (self.chunks_t, *spec.shape_tail) maxshape = (None, *spec.shape_tail) ds = self.h5.create_dataset( spec.name, shape=(0, *spec.shape_tail), maxshape=maxshape, chunks=chunks, compression=( self.compression if self.compression != "none" else None ), dtype=spec.dtype, shuffle=True if self.compression != "none" else False, ) self.datasets[spec.name] = ds self._clip_starts: List[int] = [] self._clip_lengths: List[int] = [] self._clip_motion_ids: List[int] = [] self._clip_metadata: List[str] = [] self.t_cursor = 0 def append_motion( self, 
motion_id: int, np_arrays: Dict[str, np.ndarray], metadata_json: str, ) -> Tuple[int, int]: if "ref_dof_pos" not in np_arrays: raise KeyError("ref_dof_pos missing for HDF5 v2 packing") t_len = int(np_arrays["ref_dof_pos"].shape[0]) start = self.t_cursor end = start + t_len for spec in self.array_specs: if spec.name not in np_arrays: raise KeyError( f"Missing array '{spec.name}' for HDF5 v2 packing" ) ds = self.datasets[spec.name] ds.resize((end, *spec.shape_tail)) ds[start:end, ...] = np_arrays[spec.name] self._clip_starts.append(start) self._clip_lengths.append(t_len) self._clip_motion_ids.append(motion_id) self._clip_metadata.append(metadata_json) self.t_cursor = end return start, t_len def finalize(self) -> Dict[str, Any]: g = self.h5.create_group("clips") g.create_dataset( "start", data=np.asarray(self._clip_starts, dtype=np.int64) ) g.create_dataset( "length", data=np.asarray(self._clip_lengths, dtype=np.int64) ) g.create_dataset( "motion_key_id", data=np.asarray(self._clip_motion_ids, dtype=np.int64), ) vlen_str = h5py.string_dtype(encoding="utf-8") g.create_dataset( "metadata_json", data=np.asarray(self._clip_metadata, dtype=vlen_str), ) summary = { "file": self.h5_path, "num_clips": len(self._clip_starts), "num_frames": int(self.t_cursor), } self.h5.flush() self.h5.close() return summary def _normalize_root_list(value: Any) -> List[str]: if value is None: return [] if isinstance(value, (str, os.PathLike)): return [str(value)] if isinstance(value, (list, tuple, ListConfig)): return [str(v) for v in list(value)] return [str(value)] def _discover_motion_entries(roots: List[str]) -> List[ClipEntry]: motion_key_to_path: Dict[str, str] = {} for root in roots: root_path = Path(root).expanduser().resolve() parent_dir_name = root_path.name clips_dir = root_path / "clips" base_dir = clips_dir if clips_dir.is_dir() else root_path if not base_dir.is_dir(): raise FileNotFoundError(f"NPZ directory not found: {base_dir}") for dirpath, _, filenames in 
os.walk(str(base_dir)): for fname in filenames: if not fname.endswith(".npz"): continue base_key = os.path.splitext(fname)[0] motion_key = f"{parent_dir_name}_{base_key}" npz_path = os.path.join(dirpath, fname) if motion_key in motion_key_to_path: raise ValueError(f"Duplicate motion key: {motion_key}") motion_key_to_path[motion_key] = npz_path entries = [ ClipEntry(clip_id=i, name=key, path=motion_key_to_path[key]) for i, key in enumerate(sorted(motion_key_to_path.keys())) ] if len(entries) == 0: raise ValueError("No NPZ files found in input directories.") return entries def _load_metadata_json(npz_path: Path) -> Tuple[str, Dict[str, Any]]: with np.load(npz_path, allow_pickle=False) as data: if "metadata" not in data: raise KeyError(f"'metadata' missing in NPZ: {npz_path}") metadata_json = str(data["metadata"]) num_frames_from_dof = data["ref_dof_pos"].shape[0] metadata = json.loads(metadata_json) num_frames_from_metadata = metadata["num_frames"] assert num_frames_from_dof == num_frames_from_metadata, ( f"num_frames_from_dof {num_frames_from_dof} != num_frames_from_metadata {num_frames_from_metadata} in {npz_path}" ) if not isinstance(metadata, dict): raise ValueError(f"metadata must be a JSON object in {npz_path}") return metadata_json, metadata def _cast_array(array: np.ndarray, name: str, npz_path: Path) -> np.ndarray: if array.dtype == np.float32: return array if array.dtype.kind == "O": raise ValueError(f"Array '{name}' in {npz_path} has object dtype.") if np.issubdtype(array.dtype, np.integer): logger.warning( "Casting array '{}' in {} from {} to float32.", name, npz_path, array.dtype, ) return array.astype(np.float32, copy=False) raise ValueError( f"Array '{name}' in {npz_path} has dtype {array.dtype}, " "expected float32 or integer." 
) def _discover_array_specs(first_npz: Path) -> List[ArraySpec]: with np.load(first_npz, allow_pickle=False) as data: if "ref_dof_pos" not in data: raise KeyError(f"'ref_dof_pos' missing in NPZ: {first_npz}") if "ref_global_translation" not in data: raise KeyError( f"'ref_global_translation' missing in NPZ: {first_npz}" ) if "ref_global_rotation_quat" not in data: raise KeyError( f"'ref_global_rotation_quat' missing in NPZ: {first_npz}" ) dof_pos = data["ref_dof_pos"] global_trans = data["ref_global_translation"] global_rot = data["ref_global_rotation_quat"] if dof_pos.ndim < 2: raise ValueError(f"'ref_dof_pos' must be (T, ndof) in {first_npz}") if global_trans.ndim < 2 or global_trans.shape[-1] != 3: raise ValueError( f"'ref_global_translation' must end with 3 in {first_npz}" ) if global_rot.ndim < 2 or global_rot.shape[-1] != 4: raise ValueError( f"'ref_global_rotation_quat' must end with 4 in {first_npz}" ) dof_tail = tuple(dof_pos.shape[1:]) return [ ArraySpec(name="ref_dof_pos", shape_tail=dof_tail, dtype=np.float32), ArraySpec(name="ref_root_pos", shape_tail=(3,), dtype=np.float32), ArraySpec(name="ref_root_rot", shape_tail=(4,), dtype=np.float32), ] def _load_npz_arrays( npz_path: Path, num_frames: int, dof_tail: Tuple[int, ...], ) -> Dict[str, np.ndarray]: with np.load(npz_path, allow_pickle=False) as data: dof_pos = _cast_array(data["ref_dof_pos"], "ref_dof_pos", npz_path) global_trans = _cast_array( data["ref_global_translation"], "ref_global_translation", npz_path ) global_rot = _cast_array( data["ref_global_rotation_quat"], "ref_global_rotation_quat", npz_path, ) if global_trans.ndim == 2: root_pos = global_trans elif global_trans.ndim >= 3: root_pos = global_trans[:, 0, :] else: raise ValueError( f"ref_global_translation must be (T,3) or (T,B,3) in {npz_path}" ) if global_rot.ndim == 2: root_rot = global_rot elif global_rot.ndim >= 3: root_rot = global_rot[:, 0, :] else: raise ValueError( f"ref_global_rotation_quat must be (T,4) or (T,B,4) in 
{npz_path}" ) expected_dof_shape = (num_frames, *dof_tail) if dof_pos.shape != expected_dof_shape: raise ValueError( f"ref_dof_pos shape {dof_pos.shape} does not match {expected_dof_shape} " f"in {npz_path}" ) if root_pos.shape != (num_frames, 3): raise ValueError( f"ref_root_pos shape {root_pos.shape} does not match {(num_frames, 3)} " f"in {npz_path}" ) if root_rot.shape != (num_frames, 4): raise ValueError( f"ref_root_rot shape {root_rot.shape} does not match {(num_frames, 4)} " f"in {npz_path}" ) return { "ref_dof_pos": dof_pos, "ref_root_pos": root_pos, "ref_root_rot": root_rot, } def _relative_npz_path(npz_path: Path, roots: List[str]) -> str: npz_path = npz_path.expanduser().resolve() for root in roots: root_path = Path(root).expanduser().resolve() try: rel = npz_path.relative_to(root_path) except ValueError: continue return str(Path(root_path.name) / rel) return str(npz_path) def _nan_array_names(arrays: Dict[str, np.ndarray]) -> List[str]: nan_names: List[str] = [] for name, array in arrays.items(): if not np.issubdtype(array.dtype, np.floating): continue if np.isnan(array).any(): nan_names.append(name) return nan_names return [] def _estimate_bytes_for_motion(npz_path: Path, mode: str) -> int: """Estimate per-clip byte contribution for shard sizing. Note: - ``uncompressed_nbytes`` matches the in-memory float32 payload size and does *not* correspond to on-disk shard size when compression is enabled. - ``npz_filesize`` uses the compressed input file size as a cheap proxy for on-disk shard size. - ``h5_filesize`` does not use this estimator (it measures actual shard size after writes). 
""" mode_norm = str(mode).lower().strip() if mode_norm in ("uncompressed_nbytes", "nbytes", "uncompressed"): with np.load(npz_path, allow_pickle=False) as data: total = 0 for key in ( "ref_dof_pos", "ref_global_translation", "ref_global_rotation_quat", ): if key in data: total += int(data[key].nbytes) return total if mode_norm in ("npz_filesize", "npz_size", "npz_bytes"): return int(npz_path.stat().st_size) raise ValueError( f"Unsupported shard_target_mode '{mode}'. Expected one of: " "uncompressed_nbytes | npz_filesize | h5_filesize" ) @hydra.main( config_path="../../config", config_name="motion_retargeting/pack_hdf5_v2", version_base=None, ) def main(cfg: OmegaConf) -> None: roots = _normalize_root_list(cfg.get("holomotion_npz_root", None)) if len(roots) == 0: roots = _normalize_root_list( cfg.get("holomotion_retargeted_dirs", None) ) if len(roots) == 0: legacy_root = cfg.get("precomputed_npz_root", None) roots = _normalize_root_list(legacy_root) if len(roots) == 0: raise ValueError("holomotion_npz_root must be provided.") hdf5_root = cfg.get( "hdf5_root", os.path.join(os.getcwd(), "holomotion_hdf5_v2") ) chunks_t = int(cfg.get("chunks_t", 1024)) compression = str(cfg.get("compression", "lzf")).lower() shard_target_gb = float(cfg.get("shard_target_gb", 2.0)) shard_target_bytes = int( cfg.get("shard_target_bytes", shard_target_gb * (1 << 30)) ) shard_target_mode = str( cfg.get("shard_target_mode", "uncompressed_nbytes") ) for root in roots: if not os.path.isdir(root): raise FileNotFoundError(f"NPZ clips directory not found: {root}") entries = _discover_motion_entries(roots) motion_keys: List[str] = [] motion_key2id: Dict[str, int] = {} nan_npz_paths: List[str] = [] first_npz = Path(entries[0].path) array_specs = _discover_array_specs(first_npz) array_names_created = [s.name for s in array_specs] dof_tail = next( spec.shape_tail for spec in array_specs if spec.name == "ref_dof_pos" ) logger.info( "HDF5 v2 datasets: {} (dof_tail={})", array_names_created, dof_tail, 
) dof_names: List[str] = [] body_names: List[str] = [] extended_body_names: List[str] = [] robot_cfg = cfg.get("robot", None) if robot_cfg is not None and "motion" in robot_cfg: motion_cfg = robot_cfg["motion"] dof_names = list(motion_cfg.get("dof_names", [])) body_names = list(motion_cfg.get("body_names", [])) extended_body_names = list( list(motion_cfg.get("body_names", [])) + [ i.get("joint_name") for i in motion_cfg.get("extend_config", []) ] ) shard_dir = os.path.join(str(hdf5_root), "shards") _ensure_dir(shard_dir) hdf5_shards: List[Dict[str, Any]] = [] clips_manifest: Dict[str, Dict[str, Any]] = {} curr_shard_idx = 0 curr_shard_bytes = 0 writer: Optional[Hdf5ShardWriter] = None pbar = tqdm(total=len(entries), desc="Packing HDF5 v2 shards") for entry in entries: npz_path = Path(entry.path) metadata_json, metadata = _load_metadata_json(npz_path) if "num_frames" not in metadata: raise KeyError(f"'num_frames' missing in metadata: {npz_path}") num_frames = int(metadata["num_frames"]) if num_frames <= 0: raise ValueError(f"Invalid num_frames {num_frames} in {npz_path}") arrays_np = _load_npz_arrays( npz_path=npz_path, num_frames=num_frames, dof_tail=dof_tail ) nan_arrays = _nan_array_names(arrays_np) if len(nan_arrays) > 0: rel_npz_path = _relative_npz_path(npz_path, roots) nan_npz_paths.append(rel_npz_path) logger.warning( "NaN detected in NPZ (arrays: {}), skipping: {}", nan_arrays, npz_path, ) pbar.update(1) continue shard_mode_norm = shard_target_mode.lower().strip() if shard_mode_norm in ( "h5_filesize", "h5_size", "output_filesize", "disk", ): if writer is None: shard_name = f"holomotion_{curr_shard_idx:03d}.h5" shard_path = os.path.join(shard_dir, shard_name) writer = Hdf5ShardWriter( shard_path, array_specs, chunks_t=chunks_t, compression=compression, ) else: motion_bytes = _estimate_bytes_for_motion( npz_path, shard_target_mode ) if ( writer is None or (curr_shard_bytes + motion_bytes) > shard_target_bytes ): if writer is not None: shard_summary = 
writer.finalize() hdf5_shards.append( { "file": os.path.relpath( shard_summary["file"], str(hdf5_root) ), "num_clips": shard_summary["num_clips"], "num_frames": shard_summary["num_frames"], } ) curr_shard_idx += 1 curr_shard_bytes = 0 shard_name = f"holomotion_{curr_shard_idx:03d}.h5" shard_path = os.path.join(shard_dir, shard_name) writer = Hdf5ShardWriter( shard_path, array_specs, chunks_t=chunks_t, compression=compression, ) motion_id = motion_key2id.get(entry.name) if motion_id is None: motion_id = len(motion_keys) motion_key2id[entry.name] = motion_id motion_keys.append(entry.name) start, length = writer.append_motion( motion_id=motion_id, np_arrays=arrays_np, metadata_json=metadata_json, ) clips_manifest[entry.name] = { "motion_key": entry.name, "shard": curr_shard_idx, "clip_idx": len(writer._clip_starts) - 1, "start": int(start), "length": int(length), "available_arrays": list(array_names_created), "metadata": metadata, } if shard_mode_norm in ( "h5_filesize", "h5_size", "output_filesize", "disk", ): writer.h5.flush() curr_shard_bytes = int(os.path.getsize(writer.h5_path)) else: curr_shard_bytes += motion_bytes pbar.update(1) if ( shard_mode_norm in ("h5_filesize", "h5_size", "output_filesize", "disk") and curr_shard_bytes >= shard_target_bytes and writer is not None ): shard_summary = writer.finalize() hdf5_shards.append( { "file": os.path.relpath( shard_summary["file"], str(hdf5_root) ), "num_clips": shard_summary["num_clips"], "num_frames": shard_summary["num_frames"], } ) curr_shard_idx += 1 curr_shard_bytes = 0 writer = None pbar.close() if writer is not None: shard_summary = writer.finalize() hdf5_shards.append( { "file": os.path.relpath(shard_summary["file"], str(hdf5_root)), "num_clips": shard_summary["num_clips"], "num_frames": shard_summary["num_frames"], } ) manifest = { "version": 1, "root": str(hdf5_root), "hdf5_shards": hdf5_shards, "clips": clips_manifest, "motion_keys": motion_keys, "dof_names": dof_names, "body_names": body_names, 
"extended_body_names": extended_body_names, "array_names": array_names_created, "chunks_t": int(chunks_t), "compression": compression, "shard_target_mode": str(shard_target_mode), "shard_target_bytes": int(shard_target_bytes), } _ensure_dir(str(hdf5_root)) nan_paths_path = os.path.join(str(hdf5_root), "nan_npz_paths.json") with open(nan_paths_path, "w") as f: json.dump(nan_npz_paths, f, indent=2) if len(nan_npz_paths) > 0: logger.warning( "Skipped {} NPZ files due to NaNs. List: {}", len(nan_npz_paths), nan_paths_path, ) else: logger.info("No NaN detected in NPZ inputs.") with open(os.path.join(str(hdf5_root), "manifest.json"), "w") as f: json.dump(manifest, f, indent=2) logger.info( "HDF5 v2 packing complete. Shards: {}. Root: {}", len(hdf5_shards), hdf5_root, ) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/motion_retargeting/reference_filtering.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. from typing import Dict, Mapping, Tuple import numpy as np # This module keeps the offline preprocess filtering path and the online # root/DoF-before-FK path aligned while still exposing helpers tailored to # each tensor family. 
def _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]: shape = a.shape t = shape[0] return a.reshape(t, -1), shape def _butterworth_lowpass_smooth_time( a: np.ndarray, fps: float, cutoff_hz: float, order: int ) -> np.ndarray: from scipy.signal import butter, filtfilt t = a.shape[0] if t < 3: return a.astype(np.float32, copy=True) if fps <= 0.0 or cutoff_hz <= 0.0: return a.astype(np.float32, copy=True) nyquist = 0.5 * float(fps) wn = float(cutoff_hz) / nyquist if wn >= 1.0: wn = 0.999 if wn <= 0.0: return a.astype(np.float32, copy=True) flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False)) b, a_coefs = butter(int(order), wn, btype="low", analog=False) maxlen = max(len(b), len(a_coefs)) padlen_required = max(3 * (maxlen - 1), 3 * maxlen) if t <= padlen_required: return a.astype(np.float32, copy=True) filtered = filtfilt(b, a_coefs, flat, axis=0, method="pad") return filtered.reshape(shape).astype(np.float32, copy=False) def _quat_normalize(q: np.ndarray) -> np.ndarray: norm = np.linalg.norm(q, axis=-1, keepdims=True) norm = np.where(norm == 0.0, 1.0, norm) return (q / norm).astype(np.float32, copy=False) def _quat_hemisphere_align(q: np.ndarray) -> np.ndarray: if q.shape[0] == 0: return q aligned = q.copy() prev = aligned[0] for t in range(1, aligned.shape[0]): dots = np.sum(prev * aligned[t], axis=-1) mask = dots < 0.0 if np.any(mask): aligned[t, mask] = -aligned[t, mask] prev = aligned[t] return aligned def _quat_conjugate(q: np.ndarray) -> np.ndarray: conj = q.copy() conj[..., :3] = -conj[..., :3] return conj def _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray: av = a[..., :3] aw = a[..., 3:4] bv = b[..., :3] bw = b[..., 3:4] cross = np.cross(av, bv) vec = aw * bv + bw * av + cross scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True) return np.concatenate([vec, scalar], axis=-1) def _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray: t = a.shape[0] if t < 2 or dt <= 0.0: return np.zeros_like(a, 
dtype=np.float32) deriv = np.gradient( a.astype(np.float64, copy=False), dt, axis=0, edge_order=2 if t >= 3 else 1, ) return deriv.astype(np.float32, copy=False) def _angular_velocity_from_quat( q: np.ndarray, q_dot: np.ndarray ) -> np.ndarray: q_conj = _quat_conjugate(q) prod = _quat_multiply(q_conj, q_dot) omega = 2.0 * prod[..., :3] return omega.astype(np.float32, copy=False) def butterworth_filter_ref_arrays( arrays: Mapping[str, np.ndarray], fps: float, cutoff_hz: float, order: int, ) -> Dict[str, np.ndarray]: out: Dict[str, np.ndarray] = {} dt = 1.0 / float(fps) if float(fps) > 0.0 else 0.0 if "ref_dof_pos" in arrays: dof_pos = arrays["ref_dof_pos"].astype(np.float32, copy=True) smooth_dof_pos = _butterworth_lowpass_smooth_time( dof_pos, fps, cutoff_hz, order ) dof_vel = _finite_difference_time(smooth_dof_pos, dt) out["ft_ref_dof_pos"] = smooth_dof_pos out["ft_ref_dof_vel"] = dof_vel if "ref_global_translation" in arrays: body_pos = arrays["ref_global_translation"].astype( np.float32, copy=True ) smooth_body_pos = _butterworth_lowpass_smooth_time( body_pos, fps, cutoff_hz, order ) body_vel = _finite_difference_time(smooth_body_pos, dt) out["ft_ref_global_translation"] = smooth_body_pos out["ft_ref_global_velocity"] = body_vel if "ref_global_rotation_quat" in arrays: body_rot = arrays["ref_global_rotation_quat"].astype( np.float32, copy=True ) body_rot = _quat_normalize(body_rot) body_rot = _quat_hemisphere_align(body_rot) smooth_body_rot = _butterworth_lowpass_smooth_time( body_rot, fps, cutoff_hz, order ) smooth_body_rot = _quat_normalize(smooth_body_rot) body_rot_dot = _finite_difference_time(smooth_body_rot, dt) out["ft_ref_global_rotation_quat"] = _quat_normalize(smooth_body_rot) out["ft_ref_global_angular_velocity"] = _angular_velocity_from_quat( smooth_body_rot, body_rot_dot ) return out def butterworth_filter_root_dof_arrays( arrays: Mapping[str, np.ndarray], fps: float, cutoff_hz: float, order: int, ) -> Dict[str, np.ndarray]: out: Dict[str, 
np.ndarray] = {} if "ref_root_pos" in arrays: root_pos = arrays["ref_root_pos"].astype(np.float32, copy=True) out["ft_ref_root_pos"] = _butterworth_lowpass_smooth_time( root_pos, fps, cutoff_hz, order ) if "ref_root_rot" in arrays: root_rot = arrays["ref_root_rot"].astype(np.float32, copy=True) root_rot = _quat_normalize(root_rot) root_rot = _quat_hemisphere_align(root_rot) smooth_root_rot = _butterworth_lowpass_smooth_time( root_rot, fps, cutoff_hz, order ) out["ft_ref_root_rot"] = _quat_normalize(smooth_root_rot) if "ref_dof_pos" in arrays: dof_pos = arrays["ref_dof_pos"].astype(np.float32, copy=True) out["ft_ref_dof_pos"] = _butterworth_lowpass_smooth_time( dof_pos, fps, cutoff_hz, order ) return out ================================================ FILE: holomotion/src/motion_retargeting/utils/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
================================================
FILE: holomotion/src/motion_retargeting/utils/_schema.json
================================================
{
  "schema": {
    "root_trans_offset": {
      "shape": [
        682,
        3
      ],
      "dtype": "float64"
    },
    "pose_aa": {
      "shape": [
        682,
        27,
        3
      ],
      "dtype": "float32"
    },
    "dof": {
      "shape": [
        682,
        23
      ],
      "dtype": "float32"
    },
    "root_rot": {
      "shape": [
        682,
        4
      ],
      "dtype": "float64"
    },
    "smpl_joints": {
      "shape": [
        682,
        24,
        3
      ],
      "dtype": "float32"
    },
    "fps": {
      "shape": [],
      "dtype": "int64"
    }
  },
  "sample_top_key": "2024-12-28 16.03.15-视频-2025年主播必学120支热门小舞蹈-帅帅的《电话卡点舞》 #电话卡点舞 #...舞 #抖音热歌 #舞蹈教学 #网红必学+p02_1_btws_pad"
}
================================================
FILE: holomotion/src/motion_retargeting/utils/rotation_conversions.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union

import torch
import torch.nn.functional as F


def wxyz_to_xyzw(quat):
    """Reorder quaternion components from (w, x, y, z) to (x, y, z, w).

    The reordering is applied along the last axis, so any batch shape
    (..., 4) is accepted.
    """
    return quat[..., [1, 2, 3, 0]]


def xyzw_to_wxyz(quat):
    """Reorder quaternion components from (x, y, z, w) to (w, x, y, z).

    The reordering is applied along the last axis, so any batch shape
    (..., 4) is accepted.
    """
    return quat[..., [3, 0, 1, 2]]


Device = Union[str, torch.device]


"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as

    R = [
            [Rxx, Rxy, Rxz],
            [Ryx, Ryy, Ryz],
            [Rzx, Rzy, Rzz],
        ]  # (3, 3)

This matrix can be applied to column vectors by post multiplication
by the points e.g.

    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
    transformed_points = R * points

To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:

e.g.
    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
    transformed_points = points * R.transpose(1, 0)
"""


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    The quaternions need not be normalised: the ``2 / |q|^2`` factor
    accounts for their norm.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Combine the magnitudes of ``a`` with the signs of ``b``, elementwise.

    This is like the standard copysign floating-point operation, but it is
    not careful about negative 0 and NaN: signs are compared with ``< 0``.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Return ``torch.sqrt(torch.max(0, x))`` with a safe gradient.

    Entries that are not strictly positive are set to 0 without calling
    ``sqrt``, so they get a zero subgradient instead of an infinite one.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as rotation matrices to quaternions (w, x, y, z).

    Four candidate quaternions are computed, and for each matrix the one
    with the largest denominator (best conditioned) is selected.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
            ),
            torch.stack(
                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
            ),
            torch.stack(
                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
            ),
            torch.stack(
                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
            ),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is
    # small, the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same
    # (up to a sign), forall i; we pick the best-conditioned one
    # (with the largest denominator)
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :,  # pyre-ignore[16]
    ].reshape(batch_dim + (4,))


def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """Return rotation matrices for rotations about one coordinate axis.

    These are the elementary rotations that Euler angles compose; one
    matrix is produced for each value in ``angle``.

    Args:
        axis: Axis label "X" or "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if ``axis`` is not "X", "Y" or "Z".
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        r_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        r_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        r_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(r_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(
    euler_angles: torch.Tensor, convention: str
) -> torch.Tensor:
    """Convert rotations given as Euler angles in radians to rotation matrices.

    The result is ``R(c0) @ R(c1) @ R(c2)`` for the three convention letters.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: on a malformed input shape or convention string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(
            convention, torch.unbind(euler_angles, -1), strict=False
        )
    ]
    # return functools.reduce(torch.matmul, matrices)
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])


def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """Extract the first or third Euler angle from a row or column of R.

    The angle is recovered with ``atan2`` from the two entries of ``data``
    that equal a positive constant times the angle's sine and cosine.

    Args:
        axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
            convention.
        data: One column (``horizontal=False``) or one row
            (``horizontal=True``) of the rotation matrices, as a tensor of
            shape (..., 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def _index_from_letter(letter: str) -> int:
    """Map an axis letter "X"/"Y"/"Z" to index 0/1/2; raise ValueError otherwise."""
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")


def matrix_to_euler_angles(
    matrix: torch.Tensor, convention: str
) -> torch.Tensor:
    """Convert rotations given as rotation matrices to Euler angles in radians.

    The middle angle comes from ``asin`` (Tait-Bryan conventions) or ``acos``
    (proper Euler conventions). The outer angles come from
    ``_angle_from_tan``.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: on a malformed convention string or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(
            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
        )
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)


def random_quaternions(
    n: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Generate random unit quaternions representing rotations.

    The results are versors (unit quaternions) with non-negative real part,
    obtained by normalising Gaussian samples.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)
    o = torch.randn((n, 4), dtype=dtype, device=device)
    s = (o * o).sum(1)
    o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
    return o


def random_rotations(
    n: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Generate random rotations as 3x3 rotation matrices.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    quaternions = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(quaternions)


def random_rotation(
    dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """Generate a single random 3x3 rotation matrix.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    return random_rotations(1, dtype, device)[0]


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert a unit quaternion to standard form: the one whose real part is non-negative.

    ``q`` and ``-q`` represent the same rotation; quaternions with a
    negative real part are negated.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)


def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions (Hamilton product ``a * b``).

    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((ow, ox, oy, oz), -1)


def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions representing rotations.

    Returns the quaternion representing their composition, standardized to
    the versor with non-negative real part. Usual torch rules for
    broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    ab = quaternion_raw_multiply(a, b)
    return standardize_quaternion(ab)


def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
    """Given a quaternion representing a rotation, get the quaternion representing its inverse.

    The vector part is negated (the conjugate). This equals the inverse only
    for unit quaternions.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
    return quaternion * scaling


def quaternion_apply(
    quaternion: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
    """Apply the rotation given by a quaternion to a 3D point.

    Computes ``q * (0, p) * q^-1`` and returns its vector part. Usual torch
    rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: if the last dimension of ``point`` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]


def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as axis/angle to quaternions.

    For angles below 1e-6 rad, ``sin(angle/2)/angle`` is replaced by its
    Taylor expansion to avoid dividing by ~0.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles],
        dim=-1,
    )
    return quaternions


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to axis/angle.

    For angles below 1e-6 rad, ``sin(angle/2)/angle`` is replaced by its
    Taylor expansion to avoid dividing by ~0.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert the 6D rotation representation of Zhou et al. [1] to matrices.

    Uses Gram--Schmidt orthogonalization per Section B of [1]. The two
    resulting orthonormal vectors and their cross product become the rows
    of the matrix.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to the 6D rotation representation of Zhou et al. [1].

    The last row of each matrix is dropped. Note that the 6D representation
    is not unique.

    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    batch_dim = matrix.size()[:-2]
    return matrix[..., :2, :].clone().reshape(batch_dim + (6,))


================================================
FILE: holomotion/src/motion_retargeting/utils/torch_humanoid_batch.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally copied from the [PHC] repository:
# https://github.com/ZhengyiLuo/PHC
# Modifications have been made to fit the needs of this project.
# import os import os.path as osp import sys sys.path.append(os.getcwd()) import copy import logging import xml.etree.ElementTree as ETree from collections import OrderedDict, defaultdict from io import BytesIO import numpy as np import open3d as o3d import scipy.ndimage.filters as filters import smpl_sim.poselib.core.rotation3d as poselib_rotation3d import smpl_sim.utils.rotation_conversions as torch_rotation_conversions import torch from easydict import EasyDict from lxml.etree import XMLParser, parse from omegaconf import DictConfig from scipy.spatial.transform import Rotation as sRot from tqdm import tqdm # from loguru import logger # Configure logging logging.basicConfig( level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s" ) class HumanoidBatch: def __init__(self, cfg, device=None): if device is None: device = torch.device("cpu") self.cfg = cfg self.mjcf_file = cfg.asset.assetFileName parser = XMLParser(remove_blank_text=True) tree = parse( BytesIO(open(self.mjcf_file, "rb").read()), parser=parser, ) self.dof_axis = [] joints = sorted( [ j.attrib["name"] for j in tree.getroot().find("worldbody").findall(".//joint") ] ) motors = sorted( [ m.attrib["name"] for m in tree.getroot().find("actuator").getchildren() ] ) assert len(motors) > 0, "No motors found in the mjcf file" self.num_dof = len(motors) self.num_extend_dof = self.num_dof self.mjcf_data = mjcf_data = self.from_mjcf(self.mjcf_file) self.body_names = copy.deepcopy(mjcf_data["node_names"]) # logger.info(f"Body names from {self.mjcf_file}: {self.body_names}") self._parents = mjcf_data["parent_indices"] self.body_names_augment = copy.deepcopy(mjcf_data["node_names"]) self._proper_kinematic_structure = copy.deepcopy( mjcf_data["node_names"] ) self._offsets = mjcf_data["local_translation"][None,].to(device) self._local_rotation = mjcf_data["local_rotation"][None,].to(device) self.actuated_joints_idx = np.array( [ self.body_names.index(k) for k, v in mjcf_data["body_to_joint"].items() ] ) 
for m in motors: if m not in joints: print(m) if ( "type" in tree.getroot().find("worldbody").findall(".//joint")[0].attrib and tree.getroot() .find("worldbody") .findall(".//joint")[0] .attrib["type"] == "free" ): for j in tree.getroot().find("worldbody").findall(".//joint")[1:]: self.dof_axis.append( [int(i) for i in j.attrib["axis"].split(" ")] ) self.has_freejoint = True elif ( "type" not in tree.getroot() .find("worldbody") .findall(".//joint")[0] .attrib ): for j in tree.getroot().find("worldbody").findall(".//joint"): self.dof_axis.append( [int(i) for i in j.attrib["axis"].split(" ")] ) self.has_freejoint = True else: for j in tree.getroot().find("worldbody").findall(".//joint")[6:]: self.dof_axis.append( [int(i) for i in j.attrib["axis"].split(" ")] ) self.has_freejoint = False axis_list = [] for _i, axis in enumerate(self.dof_axis): if axis == [1, 0, 0]: axis_list.append("x") elif axis == [0, 1, 0]: axis_list.append("y") elif axis == [0, 0, 1]: axis_list.append("z") else: raise ValueError(f"Invalid axis: {axis}") # print("Axis list for this humanoid: ", axis_list) self.dof_axis = torch.tensor(self.dof_axis) for extend_config in cfg.extend_config: self.body_names_augment += [extend_config.joint_name] self._parents = torch.cat( [ self._parents, torch.tensor( [self.body_names.index(extend_config.parent_name)] ).to(device), ], dim=0, ) self._offsets = torch.cat( [ self._offsets, torch.tensor([[extend_config.pos]]).to(device), ], dim=1, ) self._local_rotation = torch.cat( [ self._local_rotation, torch.tensor([[extend_config.rot]]).to(device), ], dim=1, ) self.num_extend_dof += 1 parent_id = self._proper_kinematic_structure.index( extend_config.parent_name ) self._proper_kinematic_structure.insert( parent_id + 1, extend_config.joint_name ) self.num_bodies = len(self.body_names) self.num_bodies_augment = len(self.body_names_augment) self.joints_range = mjcf_data["joints_range"].to(device) self._local_rotation_mat = ( 
torch_rotation_conversions.quaternion_to_matrix( self._local_rotation ).float() ) # w, x, y ,z self.load_mesh() self.extend_to_proper_mapping = [] for _i, name in enumerate(self._proper_kinematic_structure): self.extend_to_proper_mapping.append( self.body_names_augment.index(name) ) self.proper_to_extend_mapping = [] for _i, name in enumerate(self.body_names_augment): self.proper_to_extend_mapping.append( self._proper_kinematic_structure.index(name) ) def from_mjcf(self, path): # function from Poselib: tree = ETree.parse(path) xml_doc_root = tree.getroot() xml_world_body = xml_doc_root.find("worldbody") if xml_world_body is None: raise ValueError("MJCF parsed incorrectly please verify it.") # assume this is the root xml_body_root = xml_world_body.find("body") if xml_body_root is None: raise ValueError("MJCF parsed incorrectly please verify it.") # xml_joint_root = xml_body_root.find("joint") # Unused variable node_names = [] parent_indices = [] local_translation = [] local_rotation = [] joints_range = [] body_to_joint = OrderedDict() # recursively adding all nodes into the skel_tree def _add_xml_node(xml_node, parent_index, node_index): node_name = xml_node.attrib.get("name") # parse the local translation into float list pos = np.fromstring( xml_node.attrib.get("pos", "0 0 0"), dtype=float, sep=" " ) quat = np.fromstring( xml_node.attrib.get("quat", "1 0 0 0"), dtype=float, sep=" " ) node_names.append(node_name) parent_indices.append(parent_index) local_translation.append(pos) local_rotation.append(quat) curr_index = node_index node_index += 1 all_joints = xml_node.findall( "joint" ) # joints need to remove the first 6 joints if len(all_joints) == 6: all_joints = all_joints[6:] for joint in all_joints: if joint.attrib.get("range") is not None: joints_range.append( np.fromstring( joint.attrib.get("range"), dtype=float, sep=" " ) ) else: if not joint.attrib.get("type") == "free": joints_range.append([-np.pi, np.pi]) for joint_node in xml_node.findall("joint"): 
body_to_joint[node_name] = joint_node.attrib.get("name") for next_node in xml_node.findall("body"): node_index = _add_xml_node(next_node, curr_index, node_index) return node_index _add_xml_node(xml_body_root, -1, 0) assert len(joints_range) == self.num_dof return { "node_names": node_names, "parent_indices": torch.from_numpy( np.array(parent_indices, dtype=np.int32) ), "local_translation": torch.from_numpy( np.array(local_translation, dtype=np.float32) ), "local_rotation": torch.from_numpy( np.array(local_rotation, dtype=np.float32) ), "joints_range": torch.from_numpy(np.array(joints_range)), "body_to_joint": body_to_joint, } def fk_batch( self, pose, trans, convert_to_mat=True, return_full=False, dt=1 / 30 ): # device, dtype = pose.device, pose.dtype # Unused variables # pose_input = pose.clone() # Unused variable b, seq_len = pose.shape[:2] pose = pose[ ..., : len(self._parents), : ] # H1 fitted joints might have extra joints if convert_to_mat: pose_quat = torch_rotation_conversions.axis_angle_to_quaternion( pose.clone() ) pose_mat = torch_rotation_conversions.quaternion_to_matrix( pose_quat ) else: pose_mat = pose if pose_mat.shape != 5: pose_mat = pose_mat.reshape(b, seq_len, -1, 3, 3) # j = pose_mat.shape[2] - 1 # Exclude root - unused variable wbody_pos, wbody_mat = self.forward_kinematics_batch( pose_mat[:, :, 1:], pose_mat[:, :, 0:1], trans ) return_dict = EasyDict() wbody_rot = torch_rotation_conversions.wxyz_to_xyzw( torch_rotation_conversions.matrix_to_quaternion(wbody_mat) ) if len(self.cfg.extend_config) > 0: if return_full: return_dict.global_velocity_extend = self._compute_velocity( wbody_pos, dt ) return_dict.global_angular_velocity_extend = ( self._compute_angular_velocity(wbody_rot, dt) ) return_dict.global_translation_extend = wbody_pos.clone() return_dict.global_rotation_mat_extend = wbody_mat.clone() return_dict.global_rotation_extend = wbody_rot wbody_pos = wbody_pos[..., : self.num_bodies, :] wbody_mat = wbody_mat[..., : self.num_bodies, :, 
:] wbody_rot = wbody_rot[..., : self.num_bodies, :] return_dict.global_translation = wbody_pos return_dict.global_rotation_mat = wbody_mat return_dict.global_rotation = wbody_rot if return_full: rigidbody_linear_velocity = self._compute_velocity(wbody_pos, dt) # Isaac gym is [x, y, z, w]. All the previous functions are # [w, x, y, z] rigidbody_angular_velocity = self._compute_angular_velocity( wbody_rot, dt ) return_dict.local_rotation = ( torch_rotation_conversions.wxyz_to_xyzw(pose_quat) ) return_dict.global_root_velocity = rigidbody_linear_velocity[ ..., 0, : ] return_dict.global_root_angular_velocity = ( rigidbody_angular_velocity[..., 0, :] ) return_dict.global_angular_velocity = rigidbody_angular_velocity return_dict.global_velocity = rigidbody_linear_velocity if len(self.cfg.extend_config) > 0: return_dict.dof_pos = pose.sum(dim=-1)[ ..., 1 : self.num_bodies ] # you can sum it up since unitree's each joint has 1 dof. # Last two are for hands. doesn't really matter. else: if not len(self.actuated_joints_idx) == len(self.body_names): return_dict.dof_pos = pose.sum(dim=-1)[ ..., self.actuated_joints_idx ] else: return_dict.dof_pos = pose.sum(dim=-1)[..., 1:] dof_vel = ( return_dict.dof_pos[:, 1:] - return_dict.dof_pos[:, :-1] ) / dt return_dict.dof_vels = torch.cat( [dof_vel, dof_vel[:, -2:-1]], dim=1 ) return_dict.fps = int(1 / dt) return return_dict def convert_to_proper_kinematic(self, return_dict): if len(self.cfg.extend_config) > 0: return_dict.global_translation_extend = ( return_dict.global_translation_extend[ ..., self.extend_to_proper_mapping, : ] ) return_dict.global_rotation_mat_extend = ( return_dict.global_rotation_mat_extend[ ..., self.extend_to_proper_mapping, :, : ] ) return_dict.global_rotation_extend = ( return_dict.global_rotation_extend[ ..., self.extend_to_proper_mapping, : ] ) return_dict.global_velocity_extend = ( return_dict.global_velocity_extend[ ..., self.extend_to_proper_mapping, : ] ) return_dict.global_angular_velocity_extend = ( 
return_dict.global_angular_velocity_extend[ ..., self.extend_to_proper_mapping, : ] ) else: return_dict.global_translation = return_dict.global_translation[ ..., self.extend_to_proper_mapping, : ] return_dict.global_rotation_mat = return_dict.global_rotation_mat[ ..., self.extend_to_proper_mapping, :, : ] return_dict.global_rotation = return_dict.global_rotation[ ..., self.extend_to_proper_mapping, : ] return_dict.global_velocity = return_dict.global_velocity[ ..., self.extend_to_proper_mapping, : ] return_dict.global_angular_velocity = ( return_dict.global_angular_velocity[ ..., self.extend_to_proper_mapping, : ] ) return return_dict def forward_kinematics_batch( self, rotations, root_rotations, root_positions ): """Perform forward kinematics using the trajectory and rotations. Arguments (where B = batch size, J = number of joints): -- rotations: (B, J, 4) tensor of unit quaternions describing the local rotations of each joint. -- root_positions: (B, 3) tensor describing the root joint positions. 
Output: joint positions (B, J, 3) Reference: https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/ torch_humanoid_batch.py """ device, dtype = root_rotations.device, root_rotations.dtype b, seq_len = rotations.size()[0:2] j = self._offsets.shape[1] positions_world = [] rotations_world = [] expanded_offsets = ( self._offsets[:, None] .expand(b, seq_len, j, 3) .to(device) .type(dtype) ) # print(expanded_offsets.shape, j) for i in range(j): if self._parents[i] == -1: positions_world.append(root_positions) rotations_world.append(root_rotations) else: jpos = ( torch.matmul( rotations_world[self._parents[i]][:, :, 0], expanded_offsets[:, :, i, :, None], ).squeeze(-1) + positions_world[self._parents[i]] ) rot_mat = torch.matmul( rotations_world[self._parents[i]], torch.matmul( self._local_rotation_mat[:, (i) : (i + 1)], rotations[:, :, (i - 1) : i, :], ), ) # rot_mat = torch.matmul(rotations_world[self._parents[i]], # rotations[:, :, (i - 1):i, :]) # print(rotations[:, :, (i - 1):i, :].shape, # self._local_rotation_mat.shape) positions_world.append(jpos) rotations_world.append(rot_mat) positions_world = torch.stack(positions_world, dim=2) rotations_world = torch.cat(rotations_world, dim=2) return positions_world, rotations_world @staticmethod def _compute_velocity(p, time_delta, guassian_filter=True): velocity = np.gradient(p.numpy(), axis=-3) / time_delta if guassian_filter: velocity = torch.from_numpy( filters.gaussian_filter1d(velocity, 2, axis=-3, mode="nearest") ).to(p) else: velocity = torch.from_numpy(velocity).to(p) return velocity @staticmethod def _compute_angular_velocity(r, time_delta: float, guassian_filter=True): # assume the second last dimension is the time axis diff_quat_data = poselib_rotation3d.quat_identity_like(r).to(r) diff_quat_data[..., :-1, :, :] = poselib_rotation3d.quat_mul_norm( r[..., 1:, :, :], poselib_rotation3d.quat_inverse(r[..., :-1, :, :]), ) diff_angle, diff_axis = poselib_rotation3d.quat_angle_axis( diff_quat_data ) angular_velocity 
= diff_axis * diff_angle.unsqueeze(-1) / time_delta if guassian_filter: angular_velocity = torch.from_numpy( filters.gaussian_filter1d( angular_velocity.numpy(), 2, axis=-3, mode="nearest" ), ) return angular_velocity def load_mesh(self): xml_base = os.path.dirname(self.mjcf_file) # Read the compiler tag from the g1.xml file to find if there is a # meshdir defined tree = ETree.parse(self.mjcf_file) xml_doc_root = tree.getroot() compiler_tag = xml_doc_root.find("compiler") if compiler_tag is not None and "meshdir" in compiler_tag.attrib: mesh_base = os.path.join(xml_base, compiler_tag.attrib["meshdir"]) else: mesh_base = xml_base self.tree = tree = ETree.parse(self.mjcf_file) xml_doc_root = tree.getroot() xml_world_body = xml_doc_root.find("worldbody") xml_assets = xml_doc_root.find("asset") all_mesh = xml_assets.findall(".//mesh") geoms = xml_world_body.findall(".//geom") # all_joints = xml_world_body.findall(".//joint") # Unused variable # all_motors = tree.findall(".//motor") # Unused variable # all_bodies = xml_world_body.findall(".//body") # Unused variable def find_parent(root, child): for parent in root.iter(): for elem in parent: if elem == child: return parent return None mesh_dict = {} # mesh_parent_dict = {} # Unused variable for mesh_file_node in all_mesh: mesh_name = mesh_file_node.attrib["name"] mesh_file = mesh_file_node.attrib["file"] mesh_full_file = osp.join(mesh_base, mesh_file) mesh_obj = o3d.io.read_triangle_mesh(mesh_full_file) mesh_dict[mesh_name] = mesh_obj geom_transform = {} body_to_mesh = defaultdict(set) mesh_to_body = {} for geom_node in geoms: if "mesh" in geom_node.attrib: parent = find_parent(xml_doc_root, geom_node) body_to_mesh[parent.attrib["name"]].add( geom_node.attrib["mesh"] ) mesh_to_body[geom_node] = parent if "pos" in geom_node.attrib or "quat" in geom_node.attrib: geom_transform[parent.attrib["name"]] = {} geom_transform[parent.attrib["name"]]["pos"] = np.array( [0.0, 0.0, 0.0] ) 
geom_transform[parent.attrib["name"]]["quat"] = np.array( [1.0, 0.0, 0.0, 0.0] ) if "pos" in geom_node.attrib: geom_transform[parent.attrib["name"]]["pos"] = ( np.array( [ float(f) for f in geom_node.attrib["pos"].split(" ") ] ) ) if "quat" in geom_node.attrib: geom_transform[parent.attrib["name"]]["quat"] = ( np.array( [ float(f) for f in geom_node.attrib["quat"].split( " " ) ] ) ) else: pass self.geom_transform = geom_transform self.mesh_dict = mesh_dict self.body_to_mesh = body_to_mesh self.mesh_to_body = mesh_to_body def mesh_fk(self, pose=None, trans=None): """Load the mesh from the XML file and merge into the humanoid. Reference: https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/ torch_humanoid_batch.py """ if pose is None: fk_res = self.fk_batch( torch.zeros(1, 1, len(self.body_names_augment), 3), torch.zeros(1, 1, 3), ) else: fk_res = self.fk_batch(pose, trans) g_trans = fk_res.global_translation.squeeze() g_rot = fk_res.global_rotation_mat.squeeze() geoms = self.tree.find("worldbody").findall(".//geom") joined_mesh_obj = [] for geom in geoms: if "mesh" not in geom.attrib: continue # parent_name = geom.attrib["mesh"] k = self.mesh_to_body[geom].attrib["name"] mesh_names = self.body_to_mesh[k] body_idx = self.body_names.index(k) body_trans = g_trans[body_idx].numpy().copy() body_rot = g_rot[body_idx].numpy().copy() for mesh_name in mesh_names: mesh_obj = copy.deepcopy(self.mesh_dict[mesh_name]) if k in self.geom_transform: pos = self.geom_transform[k]["pos"] quat = self.geom_transform[k]["quat"] body_trans = body_trans + body_rot @ pos global_rot = ( body_rot @ sRot.from_quat(quat[[1, 2, 3, 0]]).as_matrix() ).T else: global_rot = body_rot.T mesh_obj.rotate(global_rot.T, center=(0, 0, 0)) mesh_obj.translate(body_trans) joined_mesh_obj.append(mesh_obj) # Merge all meshes into a single mesh merged_mesh = joined_mesh_obj[0] for mesh in joined_mesh_obj[1:]: merged_mesh += mesh # Save the merged mesh to a file # merged_mesh.compute_vertex_normals() # 
o3d.io.write_triangle_mesh(f"data/{self.cfg.humanoid_type}/ # combined_{self.cfg.humanoid_type}.stl", merged_mesh) return merged_mesh # @hydra.main(version_base=None, config_path="../../phc/data/cfg", # config_name="config") def main(cfg: DictConfig): device = torch.device("cpu") humanoid_fk = HumanoidBatch(cfg.robot, device) humanoid_fk.mesh_fk() if __name__ == "__main__": main() ================================================ FILE: holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # # This file was originally copied from the [PHC] repository: # https://github.com/ZhengyiLuo/PHC # Modifications have been made to fit the needs of this project. 
import glob import os from typing import Any, Dict, List, Tuple import cv2 import hydra import mujoco import numpy as np import ray from omegaconf import DictConfig from tqdm.auto import tqdm class OffscreenRenderer: """Offscreen renderer (no SMPL markers or joint spheres).""" def __init__(self, model, height, width): self.model = model self.height = height self.width = width # Create OpenGL context self.ctx = mujoco.GLContext(width, height) self.ctx.make_current() # Scene and camera setup self.scene = mujoco.MjvScene(model, maxgeom=1000) self.cam = mujoco.MjvCamera() self.opt = mujoco.MjvOption() self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE self.cam.distance = 4.0 self.cam.azimuth = 60.0 self.cam.elevation = -20 self.cam.lookat = np.array([0.0, 0.0, 1.0]) # Rendering context self.con = mujoco.MjrContext( model, mujoco.mjtFontScale.mjFONTSCALE_100 ) # Buffers self.rgb_buffer = np.zeros((height, width, 3), dtype=np.uint8) self.viewport = mujoco.MjrRect(0, 0, width, height) def render( self, data, ref_body_positions: np.ndarray | None = None, ref_marker_radius: float = 0.03, ref_marker_rgba: np.ndarray | None = None, ): mujoco.mjv_updateScene( self.model, data, self.opt, None, self.cam, mujoco.mjtCatBit.mjCAT_ALL.value, self.scene, ) _draw_body_spheres_to_scene( scene=self.scene, body_positions=ref_body_positions, radius=ref_marker_radius, rgba=ref_marker_rgba, ) mujoco.mjr_render(self.viewport, self.scene, self.con) mujoco.mjr_readPixels(self.rgb_buffer, None, self.viewport, self.con) return np.flipud(self.rgb_buffer) def close(self): self.ctx.free() def _get_key_prefix_order(cfg: DictConfig) -> List[str]: """ Determine the key prefix order used to extract arrays from NPZ files. 
Priority: 1) cfg.key_prefix_order (list or single value) 2) cfg.key_prefix (single value) 3) default ["ref_", "", "robot_"] """ configured = cfg.get("key_prefix_order", None) if configured is not None: order_list = ( [str(p) for p in configured] if isinstance(configured, (list, tuple)) else [str(configured)] ) else: single = cfg.get("key_prefix", None) if single is not None: order_list = [str(single)] else: order_list = ["ref_", "", "robot_"] print(f"Using key_prefix_order: {order_list}") return order_list def _get_ref_key_prefix_order(cfg: DictConfig) -> List[str]: """Determine the prefix order used to read reference overlay arrays.""" configured = cfg.get("ref_key_prefix_order", None) if configured is not None: order_list = ( [str(p) for p in configured] if isinstance(configured, (list, tuple)) else [str(configured)] ) else: single = cfg.get("ref_key_prefix", None) if single is not None: order_list = [str(single)] else: order_list = ["ref_"] print(f"Using ref_key_prefix_order: {order_list}") return order_list def _pick_with_prefixes( arrays: Dict[str, np.ndarray], base_name: str, prefixes: List[str], ) -> np.ndarray | None: """ Return arrays[prefix + base_name] for the first matching prefix in order. For non-empty prefixes, also attempts "_". 
""" for prefix in prefixes: if prefix == "": candidate = base_name if candidate in arrays: return arrays[candidate] else: cand1 = f"{prefix}{base_name}" if cand1 in arrays: return arrays[cand1] cand2 = f"{prefix.rstrip('_')}_{base_name}" if cand2 in arrays: return arrays[cand2] return None def _resolve_visualization_arrays( arrays: Dict[str, np.ndarray], key_prefix_order: List[str], draw_ref_body_spheres: bool = False, ref_key_prefix_order: List[str] | None = None, ) -> Dict[str, np.ndarray | None]: """Resolve playback arrays and optional reference overlay arrays.""" dof_pos = _pick_with_prefixes(arrays, "dof_pos", key_prefix_order) global_translation = _pick_with_prefixes( arrays, "global_translation", key_prefix_order ) global_rotation_quat = _pick_with_prefixes( arrays, "global_rotation_quat", key_prefix_order ) ref_body_positions = None if draw_ref_body_spheres: ref_prefixes = ( ref_key_prefix_order if ref_key_prefix_order is not None else ["ref_"] ) ref_body_positions = _pick_with_prefixes( arrays, "global_translation", ref_prefixes ) return { "dof_pos": dof_pos, "global_translation": global_translation, "global_rotation_quat": global_rotation_quat, "ref_body_positions": ref_body_positions, } def _draw_body_spheres_to_scene( scene, body_positions: np.ndarray | None, radius: float, rgba: np.ndarray | None, ) -> None: """Append sphere markers for body positions to the current MuJoCo scene.""" if body_positions is None: return sphere_rgba = ( np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32) if rgba is None else np.asarray(rgba, dtype=np.float32) ) size = np.array([radius, 0.0, 0.0], dtype=np.float32) mat = np.eye(3, dtype=np.float32).reshape(-1) start = int(scene.ngeom) idx = 0 for pos in body_positions: geom_id = start + idx if geom_id >= scene.maxgeom: break mujoco.mjv_initGeom( scene.geoms[geom_id], mujoco.mjtGeom.mjGEOM_SPHERE, size, pos.astype(np.float32), mat, sphere_rgba, ) idx += 1 scene.ngeom = start + idx def _load_npz_as_motion( npz_path: str, ) -> 
Tuple[Dict[str, np.ndarray], Dict[str, Any], str]: """ Load a single .npz file and return (arrays_dict, metadata_dict, motion_name) - metadata: parsed from JSON - motion_name: file name without extension """ with np.load(npz_path) as z: arrays = {k: z[k] for k in z.files if k != "metadata"} meta_raw = z.get("metadata", None) if meta_raw is None: metadata = {} else: metadata = {} try: metadata = dict(np.atleast_1d(meta_raw).tolist()) except Exception: try: metadata = {**(dict()), **(eval(str(meta_raw)))} except Exception: pass # Parse metadata as JSON string try: import json metadata = json.loads(str(np.atleast_1d(meta_raw)[0])) except Exception: pass motion_name = os.path.splitext(os.path.basename(npz_path))[0] return arrays, metadata, motion_name def _collect_all_npz( npz_root: str, motion_name: str ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, Any], str]]: """Collect all NPZ files to process based on configuration.""" print("Collecting NPZ files...", npz_root, motion_name) base = ( os.path.join(npz_root, "clips") if os.path.isdir(os.path.join(npz_root, "clips")) else npz_root ) if motion_name == "all": npz_files = [ p for p in glob.glob( os.path.join(base, "**", "*.npz"), recursive=True ) ] else: # try both base and base/clips candidate = os.path.join(base, f"{motion_name}.npz") npz_files = [candidate] motions = [] for f in tqdm(npz_files, desc="Loading npz files"): try: arrays, metadata, name = _load_npz_as_motion(f) motions.append((arrays, metadata, name)) except Exception as e: print(f"Failed to load {f}: {e}") return motions def _infer_fps_from_meta( metadata: Dict[str, Any], default_fps: float = 50.0 ) -> float: """Infer FPS value from metadata.""" try: return float(metadata.get("motion_fps", default_fps)) except Exception: return float(default_fps) def _time_length(*arrays) -> int: """Return the smallest time dimension length among given arrays, ignoring None.""" T = None for a in arrays: if isinstance(a, np.ndarray) and a.ndim >= 1: t = a.shape[0] T = t 
if T is None else min(T, t) return T if T is not None else 0 @ray.remote def process_single_motion_remote_npz( arrays: Dict[str, np.ndarray], metadata: Dict[str, Any], motion_name: str, cfg_dict: dict, ) -> str: try: cfg = DictConfig(cfg_dict) # MuJoCo model mj_model = mujoco.MjModel.from_xml_path(cfg.robot.asset.assetFileName) mj_data = mujoco.MjData(mj_model) # Renderer width, height = 1280, 720 renderer = OffscreenRenderer(mj_model, height, width) # FPS src_fps = _infer_fps_from_meta(metadata, default_fps=50.0) skip_frames = getattr(cfg, "skip_frames", 1) actual_fps = src_fps / max(1, int(skip_frames)) # Video writer fourcc = cv2.VideoWriter_fourcc(*"mp4v") out_path = os.path.join(cfg.video_dir, f"{motion_name}.mp4") os.makedirs(os.path.dirname(out_path), exist_ok=True) out = cv2.VideoWriter(out_path, fourcc, actual_fps, (width, height)) try: prefix_order = _get_key_prefix_order(cfg) draw_ref_body_spheres = bool( getattr(cfg, "draw_ref_body_spheres", False) ) ref_prefix_order = _get_ref_key_prefix_order(cfg) resolved = _resolve_visualization_arrays( arrays=arrays, key_prefix_order=prefix_order, draw_ref_body_spheres=draw_ref_body_spheres, ref_key_prefix_order=ref_prefix_order, ) dof_pos = resolved["dof_pos"] gpos = resolved["global_translation"] grot = resolved["global_rotation_quat"] ref_body_positions = resolved["ref_body_positions"] if ( not isinstance(dof_pos, np.ndarray) or not isinstance(gpos, np.ndarray) or not isinstance(grot, np.ndarray) ): raise ValueError( "Missing required NPZ keys: dof_pos / global_translation / global_rotation_quat" ) # Time dimension alignment T = _time_length(dof_pos, gpos, grot, ref_body_positions) if T == 0: raise ValueError("No valid frames found.") for t in range(0, T, max(1, int(skip_frames))): # Root position and quaternion: take from body 0 root_pos = gpos[t, 0] root_quat_xyzw = grot[t, 0] root_quat_wxyz = root_quat_xyzw[[3, 0, 1, 2]] mj_data.qpos[:3] = root_pos mj_data.qpos[3:7] = root_quat_wxyz mj_data.qpos[7:] = 
dof_pos[t] mujoco.mj_forward(mj_model, mj_data) safe_lookat = np.array( renderer.cam.lookat ) # 当前相机中心,先取出来 safe_lookat[0] = root_pos[0] safe_lookat[1] = root_pos[1] min_height = 1.0 safe_lookat[2] = max(root_pos[2], min_height) renderer.cam.lookat[:] = safe_lookat frame_ref_body_positions = ( ref_body_positions[t] if isinstance(ref_body_positions, np.ndarray) else None ) frame = renderer.render( mj_data, ref_body_positions=frame_ref_body_positions, ) # Convert RGB (MuJoCo) -> BGR (OpenCV) before writing frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) out.write(frame_bgr) finally: out.release() renderer.close() return motion_name except Exception as e: return f"ERROR_{motion_name}: {str(e)}" class MotionRendererNPZ: def process_single_motion( self, arrays: Dict[str, np.ndarray], metadata: Dict[str, Any], motion_name: str, cfg: DictConfig, ): mj_model = mujoco.MjModel.from_xml_path(cfg.robot.asset.assetFileName) mj_data = mujoco.MjData(mj_model) width, height = 1280, 720 renderer = OffscreenRenderer(mj_model, height, width) src_fps = _infer_fps_from_meta(metadata, default_fps=50.0) skip_frames = getattr(cfg, "skip_frames", 1) actual_fps = src_fps / max(1, int(skip_frames)) fourcc = cv2.VideoWriter_fourcc(*"mp4v") out_path = os.path.join(cfg.video_dir, f"{motion_name}.mp4") os.makedirs(os.path.dirname(out_path), exist_ok=True) out = cv2.VideoWriter(out_path, fourcc, actual_fps, (width, height)) try: prefix_order = _get_key_prefix_order(cfg) draw_ref_body_spheres = bool( getattr(cfg, "draw_ref_body_spheres", False) ) ref_prefix_order = _get_ref_key_prefix_order(cfg) resolved = _resolve_visualization_arrays( arrays=arrays, key_prefix_order=prefix_order, draw_ref_body_spheres=draw_ref_body_spheres, ref_key_prefix_order=ref_prefix_order, ) dof_pos = resolved["dof_pos"] gpos = resolved["global_translation"] grot = resolved["global_rotation_quat"] ref_body_positions = resolved["ref_body_positions"] if ( not isinstance(dof_pos, np.ndarray) or not isinstance(gpos, 
np.ndarray) or not isinstance(grot, np.ndarray) ): raise ValueError( "Missing required NPZ keys: dof_pos / global_translation / global_rotation_quat" ) T = _time_length(dof_pos, gpos, grot, ref_body_positions) if T == 0: raise ValueError("No valid frames found.") for t in tqdm( range(0, T, max(1, int(skip_frames))), desc=f"Rendering {motion_name}", ): root_pos = gpos[t, 0] root_quat_xyzw = grot[t, 0] root_quat_wxyz = root_quat_xyzw[[3, 0, 1, 2]] mj_data.qpos[:3] = root_pos mj_data.qpos[3:7] = root_quat_wxyz mj_data.qpos[7:] = dof_pos[t] mujoco.mj_forward(mj_model, mj_data) safe_lookat = np.array( renderer.cam.lookat ) # 当前相机中心,先取出来 safe_lookat[0] = root_pos[0] safe_lookat[1] = root_pos[1] min_height = 1.0 safe_lookat[2] = max(root_pos[2], min_height) renderer.cam.lookat[:] = safe_lookat frame_ref_body_positions = ( ref_body_positions[t] if isinstance(ref_body_positions, np.ndarray) else None ) frame = renderer.render( mj_data, ref_body_positions=frame_ref_body_positions, ) frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) out.write(frame_bgr) finally: out.release() renderer.close() return motion_name @hydra.main( version_base=None, config_path="../../../config/motion_retargeting", config_name="unitree_G1_29dof_retargeting", ) def main(cfg: DictConfig) -> None: """ Required config fields: - cfg.robot.asset.assetFileName : Path to the MuJoCo XML file - cfg.video_dir : Output video directory - cfg.motion_npz_root : Directory containing NPZ files - cfg.motion_name : "all" or a specific clip name (without extension) - cfg.skip_frames : Frame step size (>=1) Optional: - cfg.key_prefix_order : List[str] or str for key prefix matching order - cfg.key_prefix : Single prefix to use (overridden by key_prefix_order) """ try: # NPZ input motions = _collect_all_npz(cfg.motion_npz_root, cfg.motion_name) if not motions: print("No NPZ motions found.") return # Ray parallel or single-thread mode if cfg.motion_name == "all": if not ray.is_initialized(): num_cpus = 
min(os.cpu_count(), cfg.get("max_workers", 8)) ray.init(num_cpus=num_cpus) print(f"Initialized Ray with {num_cpus} workers") cfg_dict = dict(cfg) tasks = [ process_single_motion_remote_npz.remote( arr, meta, name, cfg_dict ) for (arr, meta, name) in motions ] completed, failed = [], [] with tqdm(total=len(tasks), desc="Processing Motions") as pbar: remaining = list(tasks) while remaining: ready, remaining = ray.wait( remaining, num_returns=1, timeout=1.0 ) for t in ready: try: res = ray.get(t) if isinstance(res, str) and res.startswith( "ERROR_" ): failed.append(res) print(f"Failed: {res}") else: completed.append(res) print(f"Completed: {res}") except Exception as e: failed.append(f"Task exception: {e}") pbar.update(1) print("\nProcessing complete!") print(f"Success: {len(completed)}; Failed: {len(failed)}") if failed: for f in failed: print(" -", f) ray.shutdown() else: renderer = MotionRendererNPZ() for arr, meta, name in motions: res = renderer.process_single_motion(arr, meta, name, cfg) print(f"Processed: {res}") except Exception as e: print(f"Error during processing: {e}") if ray.is_initialized(): ray.shutdown() if __name__ == "__main__": main() ================================================ FILE: holomotion/src/training/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
================================================ FILE: holomotion/src/training/h5_dataloader.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Simplified HDF5 motion cache backed by a PyTorch ``DataLoader``. This module provides two core utilities: * ``Hdf5MotionDataset`` – loads contiguous motion windows directly from HDF5 shards using metadata stored in ``manifest.json``. * ``MotionClipBatchCache`` – maintains a double-buffered cache of motion clips with deterministic swapping semantics suitable for high-throughput reinforcement learning. Compared to the legacy slot-based prefetcher, this implementation keeps the pipeline intentionally simple: * A dataset-worker keeps shard handles open locally; no Ray dependency. * Each cached batch has a fixed shape ``[max_num_clips, max_frame_length, feature_dims]``. * Swapping a batch is handled via an O(1) pointer flip once the next batch is staged on the desired device (CPU or GPU). The cache exposes helper methods that mirror the data access patterns required by ``RefMotionCommand``: * ``sample_env_assignments`` for initial clip/frame sampling. * ``gather_tensor`` to fetch exactly one tensor field for ``1 + n_future`` frames per environment. All tensors returned by this module are ``torch.float32`` unless stated otherwise; tensor shapes are noted explicitly in type annotations.
""" from __future__ import annotations import json import math import os import re import time from concurrent.futures import ThreadPoolExecutor from collections import OrderedDict from dataclasses import dataclass from functools import partial from pathlib import Path from typing import ( Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, ) import h5py import numpy as np import torch import torch.multiprocessing as mp from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler from loguru import logger from tabulate import tabulate from tqdm import tqdm from holomotion.src.motion_retargeting.reference_filtering import ( butterworth_filter_root_dof_arrays, ) from holomotion.src.utils import torch_utils from holomotion.src.motion_retargeting.holomotion_fk import HoloMotionFK Tensor = torch.Tensor def _cpu_only_dataloader_worker_init_fn(worker_id: int) -> None: """Keep cache workers lightweight without mutating CUDA visibility.""" del worker_id torch.set_num_threads(1) def _allocate_batch_counts( raw_counts: List[float], target_total: int ) -> List[int]: """Allocate integer counts that sum exactly to target_total.""" total = int(max(0, target_total)) if len(raw_counts) == 0: return [] base_counts = [max(0, int(c)) for c in raw_counts] residuals = [float(c) - float(int(c)) for c in raw_counts] remaining = total - int(sum(base_counts)) if remaining > 0: order = sorted( range(len(residuals)), key=lambda i: residuals[i], reverse=True, ) idx_pos = 0 while remaining > 0: j = order[idx_pos % len(order)] base_counts[j] += 1 remaining -= 1 idx_pos += 1 elif remaining < 0: order = sorted(range(len(residuals)), key=lambda i: residuals[i]) idx_pos = 0 while remaining < 0: j = order[idx_pos % len(order)] if base_counts[j] > 0: base_counts[j] -= 1 remaining += 1 idx_pos += 1 if sum(base_counts) != total: raise RuntimeError( "Internal error: integer batch-count allocation did not preserve total."
) return [max(0, int(c)) for c in base_counts] def _configure_weighted_bins( keys: List[str], cfg: Mapping[str, Any], batch_size_for_log: int, ) -> Tuple[List[List[int]], List[float], List[Dict[str, Any]]]: """Common helper to parse config, assign bins, and compute batch fractions. Each key goes to the first bin whose regex matches it (``re.search``); unmatched keys form a trailing implicit "others" bin. Returns ``(bin_indices, all_ratios, specs)``: per-bin key indices, per-bin ratios (explicit ratios followed by the "others" ratio), and per-bin summary dicts whose ``ratio``/``batch_fraction`` are the integer-rounded fractions of ``batch_size_for_log``. Raises ValueError for malformed patterns, ratios outside [0, 1] or summing above 1, an empty key set, or a positive-ratio bin that matched no keys.""" if batch_size_for_log <= 0: batch_size_for_log = 1 cfg_local: Dict[str, Any] = dict(cfg or {}) patterns_cfg = cfg_local.get("bin_regex_patterns") if patterns_cfg is None: patterns_cfg = cfg_local.get("bin_regrex_patterns") if not patterns_cfg: raise ValueError( "weighted_bin configuration requires 'bin_regex_patterns' " "(list of {regex, ratio}) to be configured" ) compiled_patterns: List[Dict[str, Any]] = [] ratios: List[float] = [] for idx, entry in enumerate(patterns_cfg): if not isinstance(entry, Mapping): raise ValueError( f"Entry {idx} in bin_regex_patterns must be a mapping, " f"got {type(entry)}" ) regex_str = entry.get("regex", entry.get("regrex", None)) if not isinstance(regex_str, str) or not regex_str: raise ValueError( f"Entry {idx} in bin_regex_patterns is missing a non-empty " f"'regex' field" ) ratio_val = entry.get("ratio", None) if ratio_val is None: raise ValueError( f"Entry {idx} in bin_regex_patterns is missing 'ratio'" ) ratio_f = float(ratio_val) if ratio_f < 0.0 or ratio_f > 1.0: raise ValueError( f"Entry {idx} in bin_regex_patterns has invalid ratio " f"{ratio_f:.6f}; expected in [0.0, 1.0]" ) compiled_patterns.append( { "name": str(entry.get("name", f"bin_{idx}")), "regex": regex_str, "compiled": re.compile(regex_str), } ) ratios.append(ratio_f) sum_explicit = float(sum(ratios)) if sum_explicit > 1.0 + 1.0e-6: raise ValueError( f"Sum of weighted-bin ratios is {sum_explicit:.6f} (> 1.0). " "Please reduce the ratios so that their sum is <= 1.0."
) if sum_explicit > 1.0: sum_explicit = 1.0 others_ratio = max(0.0, 1.0 - sum_explicit) if len(keys) == 0: raise ValueError( "weighted_bin configuration received an empty key set" ) num_items_total = float(len(keys)) num_explicit = len(compiled_patterns) bin_indices: List[List[int]] = [[] for _ in range(num_explicit + 1)] for idx, motion_key in enumerate(keys): assigned = False for b_idx, pat in enumerate(compiled_patterns): if pat["compiled"].search(motion_key): bin_indices[b_idx].append(idx) assigned = True break if not assigned: bin_indices[-1].append(idx) # Combine explicit ratios with implicit "others" ratio all_ratios: List[float] = list(ratios) all_ratios.append(others_ratio) # If all motion keys are covered by explicit regex bins, but the specified # ratios sum to less than 1.0, linearly reweight explicit ratios so that # they sum to 1.0 and disable the implicit "others" bin. others_count = len(bin_indices[-1]) if others_count == 0 and others_ratio > 0.0 and sum_explicit > 0.0: scale = 1.0 / sum_explicit ratios = [r * scale for r in ratios] others_ratio = 0.0 all_ratios = list(ratios) all_ratios.append(others_ratio) logger.info( "Weighted-bin: all regex bins cover the dataset; " "linearly reweighted explicit ratios to sum to 1.0 and disabled " "the implicit 'others' bin."
) # Validate non-empty bins for any positive ratio (including others) for b_idx, r in enumerate(all_ratios): if r > 0.0 and len(bin_indices[b_idx]) == 0: if b_idx < num_explicit: name = compiled_patterns[b_idx]["name"] regex_s = compiled_patterns[b_idx]["regex"] raise ValueError( f"Weighted-bin '{name}' (regex='{regex_s}') has ratio " f"{r:.6f} but matched no motion keys" ) raise ValueError( f"Weighted-bin 'others' has ratio {r:.6f} but matched no motion keys" ) # Prepare logging summary using the configured cache batch size raw_counts_log = [ratio * batch_size_for_log for ratio in all_ratios] base_counts_log = _allocate_batch_counts( raw_counts=raw_counts_log, target_total=batch_size_for_log, ) batch_fractions_log = [ float(c) / float(batch_size_for_log) for c in base_counts_log ] # Build specs using the final, actually used batch fractions specs: List[Dict[str, Any]] = [] total_items = float(max(1, num_items_total)) for b_idx in range(num_explicit): name = compiled_patterns[b_idx]["name"] regex_s = compiled_patterns[b_idx]["regex"] n = len(bin_indices[b_idx]) ds_frac = float(n) / total_items bf = batch_fractions_log[b_idx] specs.append( { "name": name, "regex": regex_s, "ratio": bf, "count": n, "dataset_fraction": ds_frac, "batch_fraction": bf, } ) # Others bin others_name = "others" others_regex = "" n_o = len(bin_indices[-1]) ds_frac_o = float(n_o) / total_items bf_o = batch_fractions_log[-1] specs.append( { "name": others_name, "regex": others_regex, "ratio": bf_o, "count": n_o, "dataset_fraction": ds_frac_o, "batch_fraction": bf_o, } ) return bin_indices, all_ratios, specs def _collect_manifest_keys( manifest_path: str | Sequence[str], ) -> Tuple[List[str], Dict[str, str], List[str]]: """Load the ``clips`` of one or more manifest.json files. Returns ``(keys, key_source, manifest_paths)`` where ``key_source`` maps each clip key to the manifest path it came from. Raises ValueError for an empty path list, a manifest with no clips, or a clip key repeated across manifests, and FileNotFoundError for a missing manifest.""" if isinstance(manifest_path, (str, os.PathLike)): manifest_paths: List[str] = [str(manifest_path)] else: manifest_paths = [str(p) for p in manifest_path] if len(manifest_paths) == 0: raise ValueError("Expected at least one manifest path") key_source: Dict[str, str] = {} for mp
in manifest_paths: if not os.path.exists(mp): raise FileNotFoundError( f"HDF5 manifest not found at {mp}. " "Please set robot.motion.hdf5_root/train_hdf5_roots " "to the correct path." ) with open(mp, "r", encoding="utf-8") as handle: manifest = json.load(handle) clips = manifest.get("clips", {}) if not clips: raise ValueError( f"Manifest at {mp} contains no clips; cannot preview sampling." ) for key in clips.keys(): if key in key_source: raise ValueError( f"Duplicate motion clip key '{key}' found in multiple " "manifests; clip keys must be globally unique." ) key_source[key] = mp return list(key_source.keys()), key_source, manifest_paths def _normalize_online_filter_cfg( cfg: Optional[Mapping[str, Any]], *, default_vel_smoothing_sigma: float = 2.0, ) -> Dict[str, Any]: """Validate an ``online_filter`` config mapping and fill in defaults. Returns a dict with ``enabled`` (default False), ``butter_order`` (default 4), ``butter_cutoff_hz_pool`` (tuple of floats, default empty), and ``ref_vel_smoothing_sigma`` / ``ft_ref_vel_smoothing_sigma`` (both default to ``default_vel_smoothing_sigma``). Raises ValueError if enabled with an empty cutoff pool or if ``butter_order`` is not positive.""" cfg_local = dict(cfg or {}) enabled = bool(cfg_local.get("enabled", False)) cutoff_pool_cfg = cfg_local.get("butter_cutoff_hz_pool", []) cutoff_pool = tuple(float(v) for v in cutoff_pool_cfg) ref_vel_smoothing_sigma = float( cfg_local.get("ref_vel_smoothing_sigma", default_vel_smoothing_sigma) ) ft_ref_vel_smoothing_sigma = float( cfg_local.get( "ft_ref_vel_smoothing_sigma", default_vel_smoothing_sigma ) ) if enabled and len(cutoff_pool) == 0: raise ValueError( "online_filter.enabled=True requires butter_cutoff_hz_pool to " "contain at least one cutoff value" ) butter_order = int(cfg_local.get("butter_order", 4)) if butter_order <= 0: raise ValueError("online_filter.butter_order must be positive") return { "enabled": enabled, "butter_order": butter_order, "butter_cutoff_hz_pool": cutoff_pool, "ref_vel_smoothing_sigma": ref_vel_smoothing_sigma, "ft_ref_vel_smoothing_sigma": ft_ref_vel_smoothing_sigma, } def preview_weighted_bin_from_manifest( manifest_path: str | Sequence[str], batch_size: int, cfg: Mapping[str, Any], ) -> None: """Lightweight preview of weighted-bin sampling using manifest.json only.
This helper is intended to be called at configuration time before any MotionClipBatchCache/DataLoader is constructed, so that invalid regex or ratio settings can fail fast without incurring the cost of cache setup. """ if batch_size <= 0: batch_size = 1 keys, _, _ = _collect_manifest_keys(manifest_path=manifest_path) _, _, specs = _configure_weighted_bins( keys=keys, cfg=cfg, batch_size_for_log=batch_size, ) table_rows = [] for item in specs: table_rows.append( [ item["name"], item["regex"], f"{item['ratio']:.4f}", int(item["count"]), f"{item['dataset_fraction']:.4f}", f"{item['batch_fraction']:.4f}", ] ) headers = [ "bin", "regex", "final_ratio", "num_clips", "clip_fraction", "batch_fraction", ] logger.info( "Weighted-bin config preview (manifest-level):\n" + tabulate(table_rows, headers=headers, tablefmt="simple_outline") ) def preview_uniform_from_manifest( manifest_path: str | Sequence[str], batch_size: int, *, max_frame_length: int, min_window_length: int, handpicked_motion_names: Optional[Sequence[str]] = None, excluded_motion_names: Optional[Sequence[str]] = None, ) -> None: """Manifest-level preview table for uniform/curriculum sampling.""" if batch_size <= 0: batch_size = 1 if max_frame_length <= 0: raise ValueError("max_frame_length must be positive") if min_window_length <= 0: raise ValueError("min_window_length must be positive") _, _, manifest_paths = _collect_manifest_keys(manifest_path=manifest_path) handpicked_set = ( set(handpicked_motion_names) if handpicked_motion_names is not None else None ) excluded_set = ( set(excluded_motion_names) if excluded_motion_names is not None else None ) def _normalize_key(value: Any) -> Optional[str]: if value is None: return None if isinstance(value, bytes): value = value.decode("utf-8") key = value if isinstance(value, str) else str(value) if not key: return None return key def _build_aliases(motion_key: str, meta: Mapping[str, Any]) -> List[str]: aliases: List[str] = [] def _add(value: Any) -> None: key =
_normalize_key(value) if key is None or key in aliases: return aliases.append(key) _add(motion_key) if isinstance(meta, Mapping): _add(meta.get("motion_key")) metadata = meta.get("metadata") if isinstance(metadata, Mapping): _add(metadata.get("motion_key")) _add(metadata.get("raw_motion_key")) return aliases def _count_windows(clip_length: int) -> Tuple[int, int]: remaining = clip_length offset = 0 num_windows = 0 num_frames = 0 while remaining > 0: window_length = min(max_frame_length, remaining) if window_length >= min_window_length: num_windows += 1 num_frames += int(window_length) offset += int(window_length) remaining = max(0, clip_length - offset) return num_windows, num_frames stats_by_manifest: Dict[str, Dict[str, float]] = {} for mp in manifest_paths: with open(mp, "r", encoding="utf-8") as handle: manifest = json.load(handle) clips = manifest.get("clips", {}) if not clips: raise ValueError( f"Manifest at {mp} contains no clips; cannot preview sampling." ) num_windows = 0 num_frames = 0 duration_s = 0.0 for key, meta in clips.items(): if isinstance(meta, Mapping): aliases = _build_aliases(key, meta) else: aliases = [key] if handpicked_set is not None and not any( alias in handpicked_set for alias in aliases ): continue if excluded_set is not None and any( alias in excluded_set for alias in aliases ): continue length = ( int(meta.get("length", 0)) if isinstance(meta, Mapping) else 0 ) if length <= 0: continue metadata = ( meta.get("metadata") if isinstance(meta, Mapping) else None ) motion_fps_val = None if isinstance(metadata, Mapping): motion_fps_val = metadata.get("motion_fps") if motion_fps_val is None and isinstance(meta, Mapping): motion_fps_val = meta.get("motion_fps") if motion_fps_val is None: raise ValueError( f"motion_fps missing for clip {key} in manifest {mp}" ) motion_fps = float(motion_fps_val) if motion_fps <= 0.0: raise ValueError( f"Invalid motion_fps {motion_fps} for clip {key} in {mp}" ) clip_windows, clip_frames = _count_windows(length)
num_windows += int(clip_windows) num_frames += int(clip_frames) duration_s += float(clip_frames) / float(motion_fps) stats_by_manifest[mp] = { "num_windows": float(num_windows), "num_frames": float(num_frames), "duration_s": float(duration_s), } total_windows = int( sum(stats["num_windows"] for stats in stats_by_manifest.values()) ) if total_windows == 0: raise ValueError( "No motion windows satisfy the requested frame length constraints" ) table_rows = [] denom = float(max(1, total_windows)) for mp in manifest_paths: stats = stats_by_manifest.get(mp, {}) count = int(stats.get("num_windows", 0)) frames = int(stats.get("num_frames", 0)) duration_h = float(stats.get("duration_s", 0.0)) / 3600.0 frac = float(count) / denom table_rows.append( [ os.path.dirname(mp), count, f"{frac:.4f}", frames, f"{duration_h:.2f}", f"{frac:.4f}", ] ) headers = [ "dataset_root", "num_windows", "window_fraction", "num_frames", "duration_h", "batch_fraction", ] logger.info( "Uniform sampling preview (manifest-level):\n" + tabulate(table_rows, headers=headers, tablefmt="simple_outline") ) def preview_sampling_from_cfg(motion_cfg: Mapping[str, Any]) -> None: """Preview manifest-level sampling table for uniform/weighted-bin. Returns without logging unless the strategy is uniform, weighted_bin or curriculum (curriculum uses the uniform preview), the backend is one of hdf5/hdf5_simple/hdf5_v2, and train_hdf5_roots or hdf5_root is configured.""" sampling_strategy_cfg = motion_cfg.get("sampling_strategy", None) if sampling_strategy_cfg is None: sampling_strategy = "uniform" else: sampling_strategy = str(sampling_strategy_cfg).lower() if sampling_strategy not in ("uniform", "weighted_bin", "curriculum"): return backend = str(motion_cfg.get("backend", "hdf5")).lower() if backend not in ("hdf5", "hdf5_simple", "hdf5_v2"): return train_roots = _normalize_root_list( motion_cfg.get("train_hdf5_roots", None) ) if len(train_roots) == 0: hdf5_root = motion_cfg.get("hdf5_root", None) if not hdf5_root: return train_roots = [str(hdf5_root)] manifest_paths = [ os.path.join(str(root), "manifest.json") for root in train_roots ] cache_cfg = motion_cfg.get("cache", {}) batch_size = int(cache_cfg.get("max_num_clips", 1)) if
sampling_strategy == "weighted_bin": weighted_bin_cfg = dict(motion_cfg.get("weighted_bin", {})) preview_weighted_bin_from_manifest( manifest_path=manifest_paths if len(manifest_paths) > 1 else manifest_paths[0], batch_size=batch_size, cfg=weighted_bin_cfg, ) return max_frame_length = int(motion_cfg.get("max_frame_length", 1)) min_window_length = int(motion_cfg.get("min_frame_length", 1)) handpicked_motion_names = motion_cfg.get("handpicked_motion_names", None) excluded_motion_names = motion_cfg.get("excluded_motion_names", None) preview_uniform_from_manifest( manifest_path=manifest_paths if len(manifest_paths) > 1 else manifest_paths[0], batch_size=batch_size, max_frame_length=max_frame_length, min_window_length=min_window_length, handpicked_motion_names=handpicked_motion_names, excluded_motion_names=excluded_motion_names, ) MANDATORY_DATASETS = { "dof_pos": "dof_pos", "dof_vel": "dof_vel", "rg_pos": "global_translation", "rb_rot": "global_rotation_quat", "body_vel": "global_velocity", "body_ang_vel": "global_angular_velocity", } class _WorldFrameNormalizeTransform: """Normalize motion tensors into a canonical z-up world frame in-place.""" @staticmethod def _apply_prefix( arrays: Dict[str, Tensor], prefix: str, *, offset_xy: Tensor, q_flat_wxyz: Tensor, ref_rg_pos_shape: torch.Size, ref_rb_rot_shape: torch.Size, ) -> None: """Apply the shared XY offset and heading rotation to one key prefix in-place. Subtracts ``offset_xy`` from the XY of ``{prefix}rg_pos``, rotates ``{prefix}rg_pos``, ``{prefix}body_vel`` and ``{prefix}body_ang_vel`` by ``q_flat_wxyz``, and left-multiplies ``{prefix}rb_rot`` (stored XYZW) by it. Silently returns if any of the four tensors is missing or if the position/rotation shapes differ from the ``ref_`` tensors.""" pos_key = f"{prefix}rg_pos" rot_key = f"{prefix}rb_rot" vel_key = f"{prefix}body_vel" ang_key = f"{prefix}body_ang_vel" if ( pos_key not in arrays or rot_key not in arrays or vel_key not in arrays or ang_key not in arrays ): return pos = arrays[pos_key] rot = arrays[rot_key] vel = arrays[vel_key] ang = arrays[ang_key] if pos.shape != ref_rg_pos_shape or rot.shape != ref_rb_rot_shape: return # Center XY using canonical offset. pos[..., 0] -= offset_xy[0] pos[..., 1] -= offset_xy[1] # Rotate vectors using shared quaternion utilities (WXYZ convention).
pos_flat = pos.reshape(-1, 3) vel_flat = vel.reshape(-1, 3) ang_flat = ang.reshape(-1, 3) pos[:] = torch_utils.quat_apply(q_flat_wxyz, pos_flat).reshape_as(pos) vel[:] = torch_utils.quat_apply(q_flat_wxyz, vel_flat).reshape_as(vel) ang[:] = torch_utils.quat_apply(q_flat_wxyz, ang_flat).reshape_as(ang) # Rotate orientations: q' = q_heading_inv * q. rot_flat_xyzw = rot.reshape(-1, 4) rot_flat_wxyz = torch_utils.xyzw_to_wxyz(rot_flat_xyzw) rot_out_wxyz = torch_utils.quat_mul(q_flat_wxyz, rot_flat_wxyz) rot[:] = torch_utils.wxyz_to_xyzw(rot_out_wxyz).reshape_as(rot) def __call__(self, arrays: Dict[str, Tensor]) -> None: """Re-center XY on the frame-0 root and remove its initial yaw, in-place. The offset and heading come from ``ref_rg_pos[0, 0]`` and ``ref_rb_rot[0, 0]`` (XYZW) and are applied to the ``ref_`` and ``ft_ref_`` prefixes. Raises ValueError if ``ref_rg_pos``, ``ref_rb_rot``, ``ref_body_vel`` or ``ref_body_ang_vel`` is missing.""" if "ref_rg_pos" not in arrays or "ref_rb_rot" not in arrays: raise ValueError("ref_rg_pos and ref_rb_rot are required") if "ref_body_vel" not in arrays or "ref_body_ang_vel" not in arrays: raise ValueError("ref_body_vel and ref_body_ang_vel are required") rg_pos = arrays["ref_rg_pos"] rb_rot = arrays["ref_rb_rot"] # Root pose at frame 0, body 0 (XYZW quaternion, z-up). p_root0 = rg_pos[0, 0] # [3] q_root0 = rb_rot[0, 0] # [4] # Compute XY offset from root at frame 0 (applied per prefix in _apply_prefix). offset_xy = p_root0.clone() offset_xy[2] = 0.0 # Extract yaw from q_root0 (XYZW) using z-up convention. x = q_root0[0] y = q_root0[1] z = q_root0[2] w = q_root0[3] siny_cosp = 2.0 * (w * z + x * y) cosy_cosp = w * w + x * x - y * y - z * z yaw0 = torch.atan2(siny_cosp, cosy_cosp) # Quaternion for rotation around +Z by -yaw0 (remove initial heading).
half = -0.5 * yaw0 sin_half = torch.sin(half) cos_half = torch.cos(half) q_heading_inv = torch.stack( [ torch.zeros_like(sin_half), torch.zeros_like(sin_half), sin_half, cos_half, ], dim=-1, ) # [4], XYZW t, b, _ = rg_pos.shape q_flat = q_heading_inv.view(1, 1, 4).expand(t, b, 4).reshape(-1, 4) q_flat_wxyz = torch_utils.xyzw_to_wxyz(q_flat) for pfx in ("ref_", "ft_ref_"): self._apply_prefix( arrays, pfx, offset_xy=offset_xy, q_flat_wxyz=q_flat_wxyz, ref_rg_pos_shape=rg_pos.shape, ref_rb_rot_shape=rb_rot.shape, ) class _CpuFKTransform: """Compute FK on CPU and store the resulting ``{prefix}*`` tensors in the input dict (``prefix`` defaults to ``ref_``).""" def __init__(self, robot_file_path: str) -> None: self._fk = HoloMotionFK( robot_file_path=str(robot_file_path), device=torch.device("cpu") ) self._fk = self._fk.to(torch.device("cpu")) def __call__( self, arrays: Dict[str, Tensor], fps: float, prefix: str = "ref_", vel_smoothing_sigma: float = 2.0, ) -> None: """Run FK on ``{prefix}root_pos``, ``{prefix}root_rot`` (passed as ``"xyzw"``) and ``{prefix}dof_pos`` with an added batch dimension, then store ``{prefix}rg_pos``, ``{prefix}rb_rot``, ``{prefix}body_vel``, ``{prefix}body_ang_vel`` and ``{prefix}dof_vel`` back into ``arrays``. Raises KeyError if an input tensor is missing.""" root_pos_key = f"{prefix}root_pos" root_rot_key = f"{prefix}root_rot" dof_pos_key = f"{prefix}dof_pos" if ( root_pos_key not in arrays or root_rot_key not in arrays or dof_pos_key not in arrays ): raise KeyError(f"Missing {prefix}root_* or {prefix}dof_pos for FK") with torch.no_grad(): fk_out = self._fk( root_pos=arrays[root_pos_key][None, ...], root_quat=arrays[root_rot_key][None, ...], dof_pos=arrays[dof_pos_key][None, ...], fps=float(fps), vel_smoothing_sigma=float(vel_smoothing_sigma), quat_format="xyzw", ) arrays[f"{prefix}rg_pos"] = fk_out["global_translation"][0] arrays[f"{prefix}rb_rot"] = fk_out["global_rotation_quat"][0] arrays[f"{prefix}body_vel"] = fk_out["global_velocity"][0] arrays[f"{prefix}body_ang_vel"] = fk_out["global_angular_velocity"][0] arrays[f"{prefix}dof_vel"] = fk_out["dof_vel"][0] @dataclass class MotionWindow: """Metadata describing a contiguous motion window within an HDF5 shard.""" motion_key: str # unique per window shard_index: int start: int length: int raw_motion_key: str # original clip key window_index: int @dataclass class
MotionClipSample: """In-memory representation of a motion window. Attributes: motion_key: Unique window identifier (includes slice info). raw_motion_key: Original clip identifier from manifest. window_index: Index of this window within its source clip (presumably ``MotionWindow.window_index``; verify against the dataset's item loader). tensors: Mapping from tensor name to data tensor of shape ``[window_length, ...]`` (float32 unless specified otherwise). length: Number of valid frames contained in the sample (``<=`` ``max_frame_length``). """ motion_key: str raw_motion_key: str window_index: int tensors: Dict[str, Tensor] length: int @dataclass class ClipBatch: """Batch of motion clips ready for consumption by the environment. Attributes: tensors: Mapping from tensor name to tensor with shape ``[batch_size, max_frame_length, ...]`` placed on the staging device; ``collate_fn`` pads frames past a clip's length with its last valid frame. lengths: Valid frame counts per clip ``[batch_size]``. motion_keys: List of motion keys corresponding to each clip. raw_motion_keys: Original (un-windowed) clip keys, one per clip. window_indices: Per-clip window index ``[batch_size]`` (int64). max_frame_length: Padded time dimension; ``collate_fn`` sets it to the longest ``ref_dof_pos`` among the collated samples. """ tensors: Dict[str, Tensor] lengths: Tensor motion_keys: List[str] raw_motion_keys: List[str] window_indices: Tensor max_frame_length: int @staticmethod def collate_fn(samples: List[MotionClipSample]) -> "ClipBatch": """Stack ``samples`` into a ClipBatch, padding every tensor along dim 0 to the longest ``ref_dof_pos``. Raises ValueError if ``samples`` is empty.""" if len(samples) == 0: raise ValueError( "ClipBatch collate_fn received an empty sample list" ) max_frame_length = max( sample.tensors["ref_dof_pos"].shape[0] for sample in samples ) max_frame_length = int(max_frame_length) batched_tensors: Dict[str, Tensor] = {} lengths = torch.zeros(len(samples), dtype=torch.long) motion_keys = [] raw_motion_keys = [] window_indices = torch.zeros(len(samples), dtype=torch.long) for batch_idx, sample in enumerate(samples): lengths[batch_idx] = sample.length motion_keys.append(sample.motion_key) raw_motion_keys.append(sample.raw_motion_key) window_indices[batch_idx] = int(sample.window_index) for name, tensor in sample.tensors.items(): if name not in batched_tensors: pad_shape = ( len(samples), max_frame_length, ) + tensor.shape[1:] batched_tensors[name] = torch.zeros( pad_shape, dtype=tensor.dtype, device=tensor.device, ) target =
batched_tensors[name] valid_frames = sample.length target[batch_idx, :valid_frames] = tensor if valid_frames < max_frame_length and valid_frames > 0: target[batch_idx, valid_frames:] = tensor[valid_frames - 1] return ClipBatch( tensors=batched_tensors, lengths=lengths, motion_keys=motion_keys, raw_motion_keys=raw_motion_keys, window_indices=window_indices, max_frame_length=max_frame_length, ) class Hdf5RootDofDataset(Dataset[MotionClipSample]): """HDF5 dataset reading ref_root_* + ref_dof_pos only.""" def __init__( self, manifest_path: str | Sequence[str], max_frame_length: int, min_window_length: int = 1, handpicked_motion_names: Optional[List[str]] = None, excluded_motion_names: Optional[List[str]] = None, fk_robot_file_path: Optional[str] = None, fk_vel_smoothing_sigma: float = 2.0, fk_world_frame_normalization: bool = True, online_filter_cfg: Optional[Mapping[str, Any]] = None, allowed_prefixes: Optional[Sequence[str]] = None, ) -> None: super().__init__() if max_frame_length <= 0: raise ValueError("max_frame_length must be positive") if min_window_length <= 0: raise ValueError("min_window_length must be positive") self.max_frame_length = int(max_frame_length) self.min_window_length = int(min_window_length) self.handpicked_motion_names = ( set(handpicked_motion_names) if handpicked_motion_names is not None else None ) self.excluded_motion_names = ( set(excluded_motion_names) if excluded_motion_names is not None else None ) self._fk_robot_file_path = ( str(fk_robot_file_path) if fk_robot_file_path is not None else "" ) if not self._fk_robot_file_path: raise ValueError("fk_robot_file_path is required for hdf5_v2 FK") self._fk_world_frame_normalization = bool(fk_world_frame_normalization) self._fk_transform = _CpuFKTransform(self._fk_robot_file_path) self._world_frame_transform = ( _WorldFrameNormalizeTransform() if self._fk_world_frame_normalization else None ) self._fk_vel_smoothing_sigma = float(fk_vel_smoothing_sigma) self._online_filter_cfg =
_normalize_online_filter_cfg( online_filter_cfg, default_vel_smoothing_sigma=self._fk_vel_smoothing_sigma, ) self._online_filter_enabled = bool(self._online_filter_cfg["enabled"]) self._online_filter_butter_order = int( self._online_filter_cfg["butter_order"] ) self._online_filter_cutoff_hz_pool = tuple( float(v) for v in self._online_filter_cfg["butter_cutoff_hz_pool"] ) self._ref_vel_smoothing_sigma = float( self._online_filter_cfg["ref_vel_smoothing_sigma"] ) self._ft_ref_vel_smoothing_sigma = float( self._online_filter_cfg["ft_ref_vel_smoothing_sigma"] ) if allowed_prefixes is None: self._allowed_prefixes = ("ref_", "ft_ref_") else: self._allowed_prefixes = tuple(str(v) for v in allowed_prefixes) if "ref_" not in self._allowed_prefixes: raise ValueError( "Hdf5RootDofDataset requires 'ref_' in allowed_prefixes" ) if isinstance(manifest_path, (str, os.PathLike)): manifest_paths: List[str] = [str(manifest_path)] else: manifest_paths = [str(p) for p in manifest_path] if len(manifest_paths) == 0: raise ValueError("At least one manifest_path must be provided") self.hdf5_root = os.path.dirname(manifest_paths[0]) self._manifest_paths: List[str] = manifest_paths self._shard_paths: List[str] = [] self.shards: List[Dict[str, Any]] = [] self.clips: Dict[str, Dict[str, Any]] = {} for mp in manifest_paths: if not os.path.exists(mp): raise FileNotFoundError( f"HDF5 manifest not found at {mp}. " "Please set robot.motion.hdf5_root/train_hdf5_roots " "to the correct path."
) with open(mp, "r", encoding="utf-8") as handle: manifest = json.load(handle) root = os.path.dirname(mp) shards_local = list(manifest.get("hdf5_shards", [])) clips_local = manifest.get("clips", {}) shard_offset = len(self.shards) for shard_meta in shards_local: self.shards.append(shard_meta) rel = shard_meta.get("file", None) if not isinstance(rel, str) or not rel: raise ValueError( f"Shard entry in manifest {mp} is missing a valid 'file' field" ) self._shard_paths.append(os.path.join(root, rel)) for key, meta in clips_local.items(): if key in self.clips: raise ValueError( f"Duplicate motion clip key '{key}' found in multiple " "manifests; clip keys must be globally unique." ) meta_global = dict(meta) meta_global["shard"] = ( int(meta_global.get("shard", 0)) + shard_offset ) self.clips[key] = meta_global if len(self.shards) == 0: raise ValueError( f"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}" ) self.windows: List[MotionWindow] = self._enumerate_windows() if len(self.windows) == 0: raise ValueError( "No motion windows satisfy the requested frame length constraints" ) # Setting up hdf5 file handles management for bounded host-memory usage self._file_handles: "OrderedDict[int, h5py.File]" = OrderedDict() max_open_env = os.getenv("HOLOMOTION_HDF5_MAX_OPEN_SHARDS") if max_open_env is None: self._h5_max_open_files = 16 else: self._h5_max_open_files = max(1, int(max_open_env)) self._h5_access_counter = 0 self._h5_cleanup_interval = int( 1.0e6 ) # clean h5 handles every 1 million samples def set_progress_counter(self, counter: Optional[mp.Value]) -> None: self._progress_counter = counter @staticmethod def _normalize_motion_key(value: Any) -> Optional[str]: """Return ``value`` as a non-empty str (bytes decoded as UTF-8, other objects via ``str()``), or None for None or empty input.""" if value is None: return None if isinstance(value, bytes): value = value.decode("utf-8") if isinstance(value, str): key = value else: key = str(value) if not key: return None return key def _build_motion_key_aliases( self, motion_key: str, meta: Mapping[str, Any] ) -> Tuple[str, ...]: """Collect, in order and without duplicates, the non-empty names a clip may be referred to by: the manifest key, ``meta["motion_key"]``, and the nested ``meta["metadata"]`` ``motion_key`` / ``raw_motion_key``.""" aliases:
List[str] = [] def _add(value: Any) -> None: key = self._normalize_motion_key(value) if key is None: return if key in aliases: return aliases.append(key) _add(motion_key) if isinstance(meta, Mapping): _add(meta.get("motion_key")) metadata = meta.get("metadata") if isinstance(metadata, Mapping): _add(metadata.get("motion_key")) _add(metadata.get("raw_motion_key")) return tuple(aliases) def _enumerate_windows(self) -> List[MotionWindow]: """Split each selected clip into consecutive windows of at most ``max_frame_length`` frames starting at ``meta["start"]``. Windows shorter than ``min_window_length`` are dropped, and ``window_index`` counts only kept windows. A clip is selected when one of its aliases is in ``handpicked_motion_names`` (if set) and none is in ``excluded_motion_names``; clips with non-positive length are skipped.""" windows: List[MotionWindow] = [] for motion_key, meta in self.clips.items(): aliases = self._build_motion_key_aliases(motion_key, meta) if self.handpicked_motion_names is not None and not any( alias in self.handpicked_motion_names for alias in aliases ): continue if self.excluded_motion_names is not None and any( alias in self.excluded_motion_names for alias in aliases ): continue shard_index = int(meta.get("shard", 0)) start = int(meta.get("start", 0)) length = int(meta.get("length", 0)) if length <= 0: continue remaining = length offset = 0 window_index = 0 while remaining > 0: window_length = min(self.max_frame_length, remaining) if window_length >= self.min_window_length: win_start = start + offset unique_key = ( f"{motion_key}__start_{win_start}_len_{window_length}" ) windows.append( MotionWindow( motion_key=unique_key, shard_index=shard_index, start=win_start, length=window_length, raw_motion_key=motion_key, window_index=window_index, ) ) window_index += 1 offset += window_length remaining = max(0, length - offset) return windows def __len__(self) -> int: return len(self.windows) @staticmethod def _cast_motion_np(np_array: np.ndarray, name: str) -> Tensor: if np_array.dtype == np.float32: pass elif np_array.dtype.kind == "O": raise ValueError(f"{name} has object dtype") elif np.issubdtype(np_array.dtype, np.integer): logger.warning( "Casting {} from {} to float32.", name, np_array.dtype ) np_array = np_array.astype(np.float32, copy=False) else: raise ValueError( f"{name} has dtype {np_array.dtype}, expected float32 or integer."
) return torch.from_numpy(np_array).to(torch.float32) @staticmethod def _make_scalar_metadata_tensor(value: float, length: int) -> Tensor: return torch.full((int(length), 1), float(value), dtype=torch.float32) def _sample_online_filter_cutoff_hz(self) -> float: if not self._online_filter_enabled: return 0.0 cutoff_pool = self._online_filter_cutoff_hz_pool if len(cutoff_pool) == 0: raise ValueError( "Online filter is enabled but butter_cutoff_hz_pool is empty" ) if len(cutoff_pool) == 1: return cutoff_pool[0] sample_idx = int(torch.randint(len(cutoff_pool), size=(1,)).item()) return cutoff_pool[sample_idx] def _add_online_filtered_reference_tensors( self, arrays: Dict[str, Tensor], fps: float, cutoff_hz: float, ) -> None: filtered_inputs_np = butterworth_filter_root_dof_arrays( arrays={ "ref_root_pos": arrays["ref_root_pos"].cpu().numpy(), "ref_root_rot": arrays["ref_root_rot"].cpu().numpy(), "ref_dof_pos": arrays["ref_dof_pos"].cpu().numpy(), }, fps=float(fps), cutoff_hz=float(cutoff_hz), order=self._online_filter_butter_order, ) for tensor_name, np_array in filtered_inputs_np.items(): arrays[tensor_name] = torch.from_numpy(np_array).to(torch.float32) self._fk_transform( arrays, fps, prefix="ft_ref_", vel_smoothing_sigma=self._ft_ref_vel_smoothing_sigma, ) @staticmethod def _derive_root_state_tensors( arrays: Dict[str, Tensor], prefix: str = "ref_", ) -> None: rg_pos_key = f"{prefix}rg_pos" rb_rot_key = f"{prefix}rb_rot" body_vel_key = f"{prefix}body_vel" body_ang_vel_key = f"{prefix}body_ang_vel" if ( rg_pos_key not in arrays or rb_rot_key not in arrays or body_vel_key not in arrays or body_ang_vel_key not in arrays ): return # Keep root-level tensors consistent with the FK-derived body tensors. 
arrays[f"{prefix}root_pos"] = arrays[rg_pos_key][:, 0, :] arrays[f"{prefix}root_rot"] = arrays[rb_rot_key][:, 0, :] arrays[f"{prefix}root_vel"] = arrays[body_vel_key][:, 0, :] arrays[f"{prefix}root_ang_vel"] = arrays[body_ang_vel_key][:, 0, :] def __getitem__(self, index: int) -> MotionClipSample: window = self.windows[index] shard_handle = self._get_shard_handle(window.shard_index) start, end = window.start, window.start + window.length arrays: Dict[str, Tensor] = {} for dataset_name in ("ref_root_pos", "ref_root_rot", "ref_dof_pos"): if dataset_name not in shard_handle: raise KeyError( f"Missing mandatory dataset '{dataset_name}' in shard index " f"{window.shard_index}" ) np_array = np.asarray(shard_handle[dataset_name][start:end, ...]) arrays[dataset_name] = self._cast_motion_np(np_array, dataset_name) if "frame_flag" in shard_handle: frame_flag_np = shard_handle["frame_flag"][start:end] if frame_flag_np.dtype.kind == "O": raise ValueError("frame_flag has object dtype") frame_flag = torch.from_numpy(frame_flag_np).to(torch.long) else: frame_flag = torch.ones(window.length, dtype=torch.long) if window.length > 1: frame_flag[0] = 0 frame_flag[-1] = 2 elif window.length == 1: frame_flag[0] = 2 arrays["frame_flag"] = frame_flag clip_meta = self.clips.get(window.raw_motion_key, {}) metadata = clip_meta.get("metadata", {}) motion_fps_val = metadata.get( "motion_fps", clip_meta.get("motion_fps") ) if motion_fps_val is None: raise ValueError( f"motion_fps missing for clip {window.raw_motion_key}" ) motion_fps = float(motion_fps_val) if motion_fps <= 0.0: raise ValueError( f"Invalid motion_fps {motion_fps} for clip {window.raw_motion_key}" ) arrays["motion_fps"] = self._make_scalar_metadata_tensor( motion_fps, window.length ) cutoff_hz = self._sample_online_filter_cutoff_hz() arrays["filter_cutoff_hz"] = self._make_scalar_metadata_tensor( cutoff_hz, window.length ) self._fk_transform( arrays, motion_fps, vel_smoothing_sigma=self._ref_vel_smoothing_sigma, ) if 
self._online_filter_enabled and "ft_ref_" in self._allowed_prefixes: self._add_online_filtered_reference_tensors( arrays, motion_fps, cutoff_hz, ) if self._world_frame_transform is not None: self._world_frame_transform(arrays) self._derive_root_state_tensors(arrays, prefix="ref_") self._derive_root_state_tensors(arrays, prefix="ft_ref_") if self._progress_counter is not None: with self._progress_counter.get_lock(): self._progress_counter.value += 1 return MotionClipSample( motion_key=window.motion_key, raw_motion_key=window.raw_motion_key, window_index=int(index), tensors=arrays, length=window.length, ) def _get_shard_handle(self, shard_index: int) -> h5py.File: # periodically clean up the file handles self._h5_access_counter += 1 if self._h5_access_counter >= self._h5_cleanup_interval: self.close() self._h5_access_counter = 0 if shard_index in self._file_handles: handle = self._file_handles.pop(shard_index) if handle.id: self._file_handles[shard_index] = handle return handle if shard_index < 0 or shard_index >= len(self._shard_paths): raise IndexError( f"Shard index {shard_index} out of range for " f"{len(self._shard_paths)} available shards" ) shard_path = self._shard_paths[shard_index] rdcc_nbytes_env = os.getenv("HOLOMOTION_HDF5_RDCC_NBYTES") if rdcc_nbytes_env is None: rdcc_nbytes = 4 * 1024 * 1024 else: rdcc_nbytes = int(rdcc_nbytes_env) handle = h5py.File( shard_path, "r", libver="latest", swmr=True, rdcc_nbytes=rdcc_nbytes, rdcc_w0=0.75, ) if ( self._h5_max_open_files is not None and len(self._file_handles) >= self._h5_max_open_files ): old_index, old_handle = self._file_handles.popitem(last=False) old_handle.close() self._file_handles[shard_index] = handle return handle def close(self) -> None: logger.info("Clearing HDF5 file handles ...") for handle in self._file_handles.values(): if handle.id: handle.close() self._file_handles.clear() def __del__(self) -> None: self.close() def _normalize_root_list(value: Any) -> List[str]: if value is None: return [] if 
isinstance(value, (str, os.PathLike)): return [str(value)] return [str(v) for v in value] def build_motion_datasets_from_cfg( motion_cfg: Mapping[str, Any], *, max_frame_length: int, min_window_length: int, world_frame_normalization: bool = True, handpicked_motion_names: Optional[List[str]] = None, excluded_motion_names: Optional[List[str]] = None, allowed_prefixes: Optional[Sequence[str]] = None, ) -> Tuple[ Dataset[MotionClipSample], Optional[Dataset[MotionClipSample]], Dict[str, Any], ]: preview_sampling_from_cfg(motion_cfg=motion_cfg) backend = str(motion_cfg.get("backend", "hdf5")).lower() if backend in ("hdf5", "hdf5_simple"): train_roots = _normalize_root_list( motion_cfg.get("train_hdf5_roots", None) ) if len(train_roots) == 0: hdf5_root = motion_cfg.get("hdf5_root", None) if not hdf5_root: raise ValueError( "HDF5 backend requires train_hdf5_roots or hdf5_root" ) train_roots = [str(hdf5_root)] manifest_paths = [ os.path.join(str(root), "manifest.json") for root in train_roots ] train_dataset = Hdf5MotionDataset( manifest_path=manifest_paths if len(manifest_paths) > 1 else manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_window_length, handpicked_motion_names=handpicked_motion_names, excluded_motion_names=excluded_motion_names, world_frame_normalization=world_frame_normalization, allowed_prefixes=allowed_prefixes, ) val_roots = _normalize_root_list( motion_cfg.get("val_hdf5_roots", motion_cfg.get("val_hdf5_root")) ) val_dataset = None if len(val_roots) > 0: val_manifest_paths = [ os.path.join(str(root), "manifest.json") for root in val_roots ] val_dataset = Hdf5MotionDataset( manifest_path=val_manifest_paths if len(val_manifest_paths) > 1 else val_manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_window_length, handpicked_motion_names=handpicked_motion_names, excluded_motion_names=excluded_motion_names, world_frame_normalization=world_frame_normalization, allowed_prefixes=allowed_prefixes, ) return 
train_dataset, val_dataset, {} if backend == "hdf5_v2": fk_robot_file_path = motion_cfg.get("fk_robot_file_path") fk_vel_smoothing_sigma = float( motion_cfg.get("fk_vel_smoothing_sigma", 2.0) ) fk_world_frame_normalization = bool( motion_cfg.get("online_fk_world_frame_normalization", True) ) cache_cfg = motion_cfg.get("cache", {}) allowed_prefixes = cache_cfg.get( "allowed_prefixes", ["ref_", "ft_ref_"], ) online_filter_cfg = motion_cfg.get("online_filter", {}) train_roots = _normalize_root_list( motion_cfg.get("train_hdf5_roots", None) ) if len(train_roots) == 0: hdf5_root = motion_cfg.get("hdf5_root", None) if not hdf5_root: raise ValueError( "HDF5 v2 backend requires train_hdf5_roots or hdf5_root" ) train_roots = [str(hdf5_root)] train_manifest_paths = [ os.path.join(str(root), "manifest.json") for root in train_roots ] train_dataset = Hdf5RootDofDataset( manifest_path=train_manifest_paths if len(train_manifest_paths) > 1 else train_manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_window_length, handpicked_motion_names=handpicked_motion_names, excluded_motion_names=excluded_motion_names, fk_robot_file_path=fk_robot_file_path, fk_vel_smoothing_sigma=fk_vel_smoothing_sigma, fk_world_frame_normalization=fk_world_frame_normalization, online_filter_cfg=online_filter_cfg, allowed_prefixes=allowed_prefixes, ) val_roots = _normalize_root_list( motion_cfg.get("val_hdf5_roots", motion_cfg.get("val_hdf5_root")) ) val_dataset = None if len(val_roots) > 0: val_manifest_paths = [ os.path.join(str(root), "manifest.json") for root in val_roots ] val_dataset = Hdf5RootDofDataset( manifest_path=val_manifest_paths if len(val_manifest_paths) > 1 else val_manifest_paths[0], max_frame_length=max_frame_length, min_window_length=min_window_length, handpicked_motion_names=handpicked_motion_names, excluded_motion_names=excluded_motion_names, fk_robot_file_path=fk_robot_file_path, fk_vel_smoothing_sigma=fk_vel_smoothing_sigma, 
fk_world_frame_normalization=fk_world_frame_normalization, online_filter_cfg=online_filter_cfg, allowed_prefixes=allowed_prefixes, ) cache_kwargs = { "stage_on_swap_only": bool( motion_cfg.get("stage_on_swap_only", True) ) } return train_dataset, val_dataset, cache_kwargs raise ValueError(f"Unsupported motion backend: {backend}") def _cache_collate_fn( samples: List[MotionClipSample], mode: str, batch_size: int, ) -> ClipBatch: """Collate function for motion cache DataLoader (supports validation padding).""" if mode == "val" and batch_size > len(samples) and len(samples) > 0: extra = batch_size - len(samples) gen = torch.Generator() idx = torch.randint(0, len(samples), size=(extra,), generator=gen) padded = list(samples) for i in idx.tolist(): padded.append(samples[i]) return ClipBatch.collate_fn(padded) return ClipBatch.collate_fn(samples) class InfiniteDistributedSampler(DistributedSampler): """Distributed sampler that yields an infinite stream by cycling epochs.""" def __iter__(self): # Infinite stream by cycling epochs while True: self.set_epoch(getattr(self, "_epoch", 0)) for idx in super().__iter__(): yield idx self._epoch = getattr(self, "_epoch", 0) + 1 class InfiniteRandomSampler(Sampler[int]): """Random sampler that yields infinite reshuffled passes over the dataset.""" def __init__(self, data_source: Dataset, seed: int = 0) -> None: self.data_source = data_source self.seed = int(seed) self.epoch = 0 def __iter__(self): # Yield infinite permutations of indices while True: g = torch.Generator() g.manual_seed(self.seed + self.epoch) perm = torch.randperm(len(self.data_source), generator=g) for idx in perm.tolist(): yield int(idx) self.epoch += 1 def __len__(self) -> int: # Large sentinel to satisfy components that query length return 2**31 - 1 class WeightedBinInfiniteSampler(Sampler[int]): """Infinite sampler that respects regex-based weighted bins over indices.""" def __init__( self, dataset_len: int, bin_indices: List[List[int]], ratios: List[float], 
batch_size: int, seed: int, ) -> None: self._ds_len = int(max(0, dataset_len)) self._bins = [torch.tensor(b, dtype=torch.long) for b in bin_indices] self._ratios = list(ratios) self._batch_size = int(max(1, batch_size)) self._seed = int(seed) self._epoch = 0 raw_counts = [r * float(self._batch_size) for r in self._ratios] self._counts = _allocate_batch_counts( raw_counts=raw_counts, target_total=self._batch_size, ) def __iter__(self): while True: g = torch.Generator() g.manual_seed(self._seed + self._epoch) batch: List[int] = [] for bin_idx, count in zip(self._bins, self._counts): if count <= 0 or bin_idx.numel() == 0: continue choice = torch.randint( 0, int(bin_idx.numel()), size=(count,), generator=g, ) selected = bin_idx[choice].tolist() batch.extend(int(x) for x in selected) if not batch: # Fallback: uniform over dataset indices if self._ds_len == 0: raise ValueError( "WeightedBinInfiniteSampler cannot sample from an empty dataset" ) all_idx = torch.randint( 0, self._ds_len, size=(self._batch_size,), generator=g, ) batch = [int(x) for x in all_idx.tolist()] if len(batch) > self._batch_size: batch = batch[: self._batch_size] elif len(batch) < self._batch_size: pad = self._batch_size - len(batch) if pad > 0: batch.extend(batch[:pad]) perm = torch.randperm(len(batch), generator=g) for idx in perm.tolist(): yield int(batch[idx]) self._epoch += 1 def __len__(self) -> int: return 2**31 - 1 class PrioritizedInfiniteSampler(Sampler[int]): """Infinite sampler with persistent prioritized and fresh uniform pools.""" def __init__( self, dataset_len: int, batch_size: int, seed: int, *, p_a_ratio: float = 0.2, ema_alpha_signal: float = 0.2, ema_alpha_rel_improve: float = 0.2, relative_eps: float = 1.0e-6, ) -> None: self._ds_len = int(max(0, dataset_len)) self._batch_size = int(max(1, batch_size)) self._seed = int(seed) self._epoch = 0 self._p_a_ratio = float(min(1.0, max(0.0, p_a_ratio))) self._ema_alpha_signal = float(min(1.0, max(0.0, ema_alpha_signal))) 
self._ema_alpha_rel_improve = float( min(1.0, max(0.0, ema_alpha_rel_improve)) ) self._relative_eps = float(max(1.0e-12, relative_eps)) if self._ds_len <= 0: self._ema_completion_rate = torch.zeros(0, dtype=torch.float32) self._ema_completion_rate_sq = torch.zeros(0, dtype=torch.float32) self._ema_completion_rel_improve = torch.zeros( 0, dtype=torch.float32 ) self._selection_counts = torch.zeros(0, dtype=torch.long) self._seen_mask = torch.zeros(0, dtype=torch.bool) self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long) self._prioritized_pool_mask = torch.zeros(0, dtype=torch.bool) else: self._ema_completion_rate = torch.zeros( self._ds_len, dtype=torch.float32 ) self._ema_completion_rate_sq = torch.zeros( self._ds_len, dtype=torch.float32 ) self._ema_completion_rel_improve = torch.zeros( self._ds_len, dtype=torch.float32 ) self._selection_counts = torch.zeros( self._ds_len, dtype=torch.long ) self._seen_mask = torch.zeros(self._ds_len, dtype=torch.bool) self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long) self._prioritized_pool_mask = torch.zeros( self._ds_len, dtype=torch.bool ) self._state_version = 0 self._last_updated_swap = -1 self._last_prioritized_pool_mean_score = 0.0 self._last_uniform_pool_mean_score = 0.0 self._last_entered_prioritized_pool_count = 0 self._last_exited_prioritized_pool_count = 0 self._uniform_cycle_start = 0 self._uniform_cycle_step = 1 self._uniform_cycle_offset = self._ds_len self._uniform_cycle_epoch = 0 @property def state_version(self) -> int: return int(self._state_version) def get_pool_statistics(self) -> Optional[Dict[str, float]]: if self._ds_len <= 0: return None return self._pool_metric_stats() @staticmethod def _aggregate_by_index( window_indices: Tensor, values: Tensor, counts: Tensor, ) -> Tuple[Tensor, Tensor, Tensor]: if window_indices.numel() == 0: return ( torch.zeros(0, dtype=torch.long), torch.zeros(0, dtype=torch.float32), torch.zeros(0, dtype=torch.float32), ) unique_indices, inverse = 
torch.unique( window_indices.to(dtype=torch.long), sorted=False, return_inverse=True, ) out_weighted_sum = torch.zeros( unique_indices.numel(), dtype=torch.float32 ) out_count = torch.zeros(unique_indices.numel(), dtype=torch.float32) out_weighted_sum.scatter_add_(0, inverse, values * counts) out_count.scatter_add_(0, inverse, counts) return unique_indices, out_weighted_sum, out_count def _pool_batch_sizes(self) -> Tuple[int, int]: if self._ds_len <= 0: return 0, 0 uniform_count = int(round(self._p_a_ratio * float(self._batch_size))) uniform_count = max(0, min(self._batch_size, uniform_count)) prioritized_count = max(0, self._batch_size - uniform_count) return uniform_count, prioritized_count def _priority_scores_for_indices(self, indices: Tensor) -> Tensor: if indices.numel() == 0 or self._ds_len <= 0: return torch.zeros(0, dtype=torch.float32) idx = indices.to(dtype=torch.long) progress = torch.clamp( self._ema_completion_rel_improve.index_select(0, idx), min=0.0, max=1.0, ) remaining_difficulty = torch.clamp( 1.0 - self._ema_completion_rate.index_select(0, idx), min=0.0, max=1.0, ) seen = self._seen_mask.index_select(0, idx).to(dtype=torch.float32) return progress * remaining_difficulty * seen def _pool_metric_stats(self) -> Dict[str, float]: prioritized_pool_size = int(self._prioritized_pool_indices.numel()) return { "prioritized_pool_size": float(prioritized_pool_size), "prioritized_pool_mean_score": float( self._last_prioritized_pool_mean_score ), "uniform_pool_mean_score": float( self._last_uniform_pool_mean_score ), "entered_prioritized_pool_count": float( self._last_entered_prioritized_pool_count ), "exited_prioritized_pool_count": float( self._last_exited_prioritized_pool_count ), } def get_window_state_for_indices( self, window_indices: Tensor ) -> Dict[str, Tensor]: if self._ds_len <= 0: empty_bool = torch.zeros(0, dtype=torch.bool) empty_float = torch.zeros(0, dtype=torch.float32) return { "ema_completion_rate": empty_float, 
"completion_rate_rel_improve": empty_float, "selection_count": torch.zeros(0, dtype=torch.long), "seen": empty_bool, "in_prioritized_pool": empty_bool, } idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu() if idx.numel() == 0: empty_bool = torch.zeros(0, dtype=torch.bool) empty_float = torch.zeros(0, dtype=torch.float32) return { "ema_completion_rate": empty_float, "completion_rate_rel_improve": empty_float, "selection_count": torch.zeros(0, dtype=torch.long), "seen": empty_bool, "in_prioritized_pool": empty_bool, } return { "ema_completion_rate": self._ema_completion_rate.index_select( 0, idx ).to(dtype=torch.float32), "completion_rate_rel_improve": ( self._ema_completion_rel_improve.index_select(0, idx).to( dtype=torch.float32 ) ), "selection_count": self._selection_counts.index_select(0, idx), "seen": self._seen_mask.index_select(0, idx), "in_prioritized_pool": self._prioritized_pool_mask.index_select( 0, idx ), } def _rebuild_prioritized_pool(self, candidate_indices: Tensor) -> None: if self._ds_len <= 0: return _, prioritized_count = self._pool_batch_sizes() previous_indices = self._prioritized_pool_indices selected = torch.zeros(0, dtype=torch.long) if prioritized_count > 0: candidates = torch.cat( [ previous_indices.to(dtype=torch.long), candidate_indices.to(dtype=torch.long).reshape(-1), ] ) candidates = torch.unique(candidates, sorted=False) scores = self._priority_scores_for_indices(candidates) positive = scores > 0.0 if bool(positive.any().item()): candidates = candidates[positive] scores = scores[positive] order = torch.argsort(scores, descending=True) selected = candidates.index_select( 0, order[: min(prioritized_count, candidates.numel())] ) scores = scores.index_select( 0, order[: min(prioritized_count, scores.numel())] ) self._last_prioritized_pool_mean_score = float( scores.mean().item() ) else: self._last_prioritized_pool_mean_score = 0.0 if candidates.numel() > selected.numel(): selected_mask = torch.zeros( candidates.numel(), 
dtype=torch.bool ) if selected.numel() > 0: matches = candidates[:, None] == selected[None, :] selected_mask = matches.any(dim=1) nonselected_scores = self._priority_scores_for_indices( candidates[~selected_mask] ) self._last_uniform_pool_mean_score = ( float(nonselected_scores.mean().item()) if nonselected_scores.numel() > 0 else 0.0 ) else: self._last_uniform_pool_mean_score = 0.0 else: self._last_prioritized_pool_mean_score = 0.0 self._last_uniform_pool_mean_score = 0.0 if previous_indices.numel() > 0: self._prioritized_pool_mask[previous_indices] = False if selected.numel() > 0: self._prioritized_pool_mask[selected] = True previous_set = set(previous_indices.tolist()) selected_set = set(selected.tolist()) self._last_entered_prioritized_pool_count = len( selected_set - previous_set ) self._last_exited_prioritized_pool_count = len( previous_set - selected_set ) self._prioritized_pool_indices = selected def maybe_update_from_observations( self, *, window_indices: Tensor, mpkpe_signal_means: Tensor, completion_rate_means: Tensor, counts: Tensor, swap_index: int, ) -> bool: if self._ds_len <= 0: return False swap_idx = int(swap_index) if swap_idx <= 0: return False if self._last_updated_swap == swap_idx: return False indices = ( window_indices.detach().to(dtype=torch.long).reshape(-1).cpu() ) # Keep validating the MPKPE tensor shape so the command-side # curriculum aggregation stays aligned with completion-rate updates. mpkpe_signal_numel = int(mpkpe_signal_means.numel()) completion_rate = ( completion_rate_means.detach() .to(dtype=torch.float32) .reshape(-1) .cpu() ) cnt = counts.detach().to(dtype=torch.float32).reshape(-1).cpu() if not ( indices.numel() == mpkpe_signal_numel and mpkpe_signal_numel == completion_rate.numel() and completion_rate.numel() == cnt.numel() ): raise ValueError( "Prioritized sampler update tensors must have matching shape." 
) valid_dataset_idx = (indices >= 0) & (indices < self._ds_len) valid = ( valid_dataset_idx & torch.isfinite(completion_rate) & (cnt > 0.0) ) current_batch_indices = torch.unique( indices[valid_dataset_idx], sorted=False ) if not bool(valid.any().item()): self._last_entered_prioritized_pool_count = 0 self._last_exited_prioritized_pool_count = 0 self._last_updated_swap = swap_idx return False idx_valid = indices[valid] completion_rate_valid = completion_rate[valid] cnt_valid = cnt[valid] touched_idx, completion_rate_sum, completion_rate_count_sum = ( self._aggregate_by_index( idx_valid, completion_rate_valid, cnt_valid, ) ) if touched_idx.numel() == 0: self._last_entered_prioritized_pool_count = 0 self._last_exited_prioritized_pool_count = 0 self._last_updated_swap = swap_idx return False completion_rate_obs = ( completion_rate_sum / completion_rate_count_sum.clamp_min(1.0e-12) ) completion_rate_obs = torch.clamp( completion_rate_obs, min=0.0, max=1.0 ) prev_seen = self._seen_mask[touched_idx] prev_completion_rate = self._ema_completion_rate[touched_idx] prev_completion_rate_sq = self._ema_completion_rate_sq[touched_idx] prev_completion_rate_var = torch.clamp( prev_completion_rate_sq - prev_completion_rate * prev_completion_rate, min=1.0e-6, ) prev_completion_rate_std = torch.sqrt(prev_completion_rate_var) next_completion_rate = torch.where( prev_seen, (1.0 - self._ema_alpha_signal) * prev_completion_rate + self._ema_alpha_signal * completion_rate_obs, completion_rate_obs, ) next_completion_rate_sq = torch.where( prev_seen, (1.0 - self._ema_alpha_signal) * prev_completion_rate_sq + self._ema_alpha_signal * (completion_rate_obs * completion_rate_obs), completion_rate_obs * completion_rate_obs, ) completion_rel_improve_obs = torch.zeros_like(next_completion_rate) completion_rel_improve_obs[prev_seen] = torch.tanh( (completion_rate_obs[prev_seen] - prev_completion_rate[prev_seen]) / (prev_completion_rate_std[prev_seen] + self._relative_eps) ) prev_completion_rel = 
self._ema_completion_rel_improve[touched_idx] next_completion_rel = torch.where( prev_seen, (1.0 - self._ema_alpha_rel_improve) * prev_completion_rel + self._ema_alpha_rel_improve * completion_rel_improve_obs, completion_rel_improve_obs, ) self._ema_completion_rate[touched_idx] = next_completion_rate self._ema_completion_rate_sq[touched_idx] = next_completion_rate_sq self._ema_completion_rel_improve[touched_idx] = next_completion_rel self._seen_mask[touched_idx] = True self._rebuild_prioritized_pool(touched_idx) self._state_version += 1 self._last_updated_swap = swap_idx return True def _reset_uniform_cycle(self) -> None: if self._ds_len <= 0: self._uniform_cycle_start = 0 self._uniform_cycle_step = 1 self._uniform_cycle_offset = 0 return generator = torch.Generator() generator.manual_seed(self._seed + self._uniform_cycle_epoch * 1000003) self._uniform_cycle_epoch += 1 self._uniform_cycle_start = int( torch.randint( low=0, high=self._ds_len, size=(1,), generator=generator, ).item() ) if self._ds_len <= 1: self._uniform_cycle_step = 1 else: step = int( torch.randint( low=1, high=self._ds_len, size=(1,), generator=generator, ).item() ) while math.gcd(step, self._ds_len) != 1: step += 1 if step >= self._ds_len: step = 1 self._uniform_cycle_step = step self._uniform_cycle_offset = 0 def _next_uniform_index(self) -> int: if self._uniform_cycle_offset >= self._ds_len: self._reset_uniform_cycle() next_index = ( self._uniform_cycle_start + self._uniform_cycle_offset * self._uniform_cycle_step ) % self._ds_len self._uniform_cycle_offset += 1 return int(next_index) def _sample_uniform_indices( self, generator: torch.Generator, count: int, *, exclude: Optional[Tensor] = None, ) -> Tensor: del generator if count <= 0 or self._ds_len <= 0: return torch.zeros(0, dtype=torch.long) blocked = set() if exclude is not None and exclude.numel() > 0: blocked.update( exclude.detach().to(dtype=torch.long).reshape(-1).tolist() ) take = min(int(count), max(0, self._ds_len - len(blocked))) 
if take <= 0: return torch.zeros(0, dtype=torch.long) selected: List[int] = [] stagnant_steps = 0 while len(selected) < take and stagnant_steps < self._ds_len: next_index = self._next_uniform_index() if next_index in blocked: stagnant_steps += 1 continue selected.append(next_index) blocked.add(next_index) stagnant_steps = 0 return torch.tensor(selected, dtype=torch.long) def _sample_prioritized_indices( self, generator: torch.Generator, count: int ) -> Tensor: if count <= 0 or self._prioritized_pool_indices.numel() == 0: return torch.zeros(0, dtype=torch.long) perm = torch.randperm( self._prioritized_pool_indices.numel(), generator=generator ) take = min(count, int(self._prioritized_pool_indices.numel())) return self._prioritized_pool_indices.index_select(0, perm[:take]) def _sample_batch_indices(self, generator: torch.Generator) -> Tensor: uniform_count, prioritized_count = self._pool_batch_sizes() prioritized_indices = self._sample_prioritized_indices( generator, prioritized_count ) uniform_indices = self._sample_uniform_indices( generator, uniform_count, exclude=prioritized_indices, ) sampled_indices = torch.cat( [uniform_indices, prioritized_indices], dim=0 ) if sampled_indices.numel() < self._batch_size: extra_indices = self._sample_uniform_indices( generator, self._batch_size - int(sampled_indices.numel()), exclude=sampled_indices, ) sampled_indices = torch.cat( [sampled_indices, extra_indices], dim=0 ) if sampled_indices.numel() != self._batch_size: raise ValueError( "Prioritized sampler failed to assemble a full cache batch." 
) if sampled_indices.numel() > 0: self._selection_counts[sampled_indices] += 1 return sampled_indices def get_scores_for_indices(self, window_indices: Tensor) -> Tensor: if self._ds_len <= 0: return torch.zeros_like(window_indices, dtype=torch.float32) idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu() if idx.numel() == 0: return torch.zeros(0, dtype=torch.float32) scores = self._priority_scores_for_indices(idx) return scores.to(dtype=torch.float32) def __iter__(self): while True: if self._ds_len <= 0: raise ValueError( "PrioritizedInfiniteSampler cannot sample from " "an empty dataset." ) g = torch.Generator() g.manual_seed(self._seed + self._epoch) sampled_indices = self._sample_batch_indices(generator=g) perm = torch.randperm(sampled_indices.numel(), generator=g) yielded_indices = sampled_indices.index_select(0, perm) for idx in yielded_indices.tolist(): yield int(idx) self._epoch += 1 def __len__(self) -> int: return 2**31 - 1 class Hdf5MotionDataset(Dataset[MotionClipSample]): """Dataset that materializes fixed-length motion windows from HDF5 shards.""" def __init__( self, manifest_path: str | Sequence[str], max_frame_length: int, min_window_length: int = 1, handpicked_motion_names: Optional[List[str]] = None, excluded_motion_names: Optional[List[str]] = None, world_frame_normalization: bool = True, allowed_prefixes: Optional[Sequence[str]] = None, ) -> None: super().__init__() if max_frame_length <= 0: raise ValueError("max_frame_length must be positive") self.max_frame_length = int(max_frame_length) self.min_window_length = int(min_window_length) self.handpicked_motion_names = ( set(handpicked_motion_names) if handpicked_motion_names is not None else None ) self.excluded_motion_names = ( set(excluded_motion_names) if excluded_motion_names is not None else None ) self._world_frame_transform = ( _WorldFrameNormalizeTransform() if bool(world_frame_normalization) else None ) self._allowed_prefixes: Tuple[str, ...] 
= ("ref_", "ft_ref_") self._progress_counter: Optional[mp.Value] = None # Normalize manifest path(s) to a list for aggregation. if isinstance(manifest_path, (str, os.PathLike)): manifest_paths: List[str] = [str(manifest_path)] else: manifest_paths = [str(p) for p in manifest_path] if len(manifest_paths) == 0: raise ValueError("At least one manifest_path must be provided") # Aggregate shards and clips across one or many manifests into a single # logical dataset. Clip keys must be globally unique. self.hdf5_root = os.path.dirname(manifest_paths[0]) self._manifest_paths: List[str] = manifest_paths self._shard_paths: List[str] = [] self.shards: List[Dict[str, Any]] = [] self.clips: Dict[str, Dict[str, Any]] = {} for mp in manifest_paths: if not os.path.exists(mp): raise FileNotFoundError( f"HDF5 manifest not found at {mp}. " "Please set robot.motion.hdf5_root/train_hdf5_roots " "to the correct path." ) with open(mp, "r", encoding="utf-8") as handle: manifest = json.load(handle) root = os.path.dirname(mp) shards_local = list(manifest.get("hdf5_shards", [])) clips_local = manifest.get("clips", {}) shard_offset = len(self.shards) for shard_meta in shards_local: self.shards.append(shard_meta) rel = shard_meta.get("file", None) if not isinstance(rel, str) or not rel: raise ValueError( f"Shard entry in manifest {mp} is missing a valid 'file' field" ) self._shard_paths.append(os.path.join(root, rel)) for key, meta in clips_local.items(): if key in self.clips: raise ValueError( f"Duplicate motion clip key '{key}' found in multiple " "manifests; clip keys must be globally unique." 
) meta_global = dict(meta) meta_global["shard"] = ( int(meta_global.get("shard", 0)) + shard_offset ) self.clips[key] = meta_global if len(self.shards) == 0: raise ValueError( f"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}" ) self.windows: List[MotionWindow] = self._enumerate_windows() if len(self.windows) == 0: raise ValueError( "No motion windows satisfy the requested frame length constraints" ) # LRU cache of open HDF5 shard handles; size is bounded to avoid # unbounded host-memory usage from per-file raw chunk caches. self._file_handles: "OrderedDict[int, h5py.File]" = OrderedDict() max_open_env = os.getenv("HOLOMOTION_HDF5_MAX_OPEN_SHARDS") if max_open_env is None: self._max_open_files = 64 else: self._max_open_files = max(1, int(max_open_env)) def set_progress_counter(self, counter: Optional[mp.Value]) -> None: self._progress_counter = counter def _enumerate_windows(self) -> List[MotionWindow]: windows: List[MotionWindow] = [] for motion_key, meta in self.clips.items(): if ( self.handpicked_motion_names is not None and motion_key not in self.handpicked_motion_names ): continue if ( self.excluded_motion_names is not None and motion_key in self.excluded_motion_names ): continue shard_index = int(meta.get("shard", 0)) start = int(meta.get("start", 0)) length = int(meta.get("length", 0)) if length <= 0: continue remaining = length offset = 0 window_index = 0 while remaining > 0: window_length = min(self.max_frame_length, remaining) if window_length >= self.min_window_length: win_start = start + offset unique_key = ( f"{motion_key}__start_{win_start}_len_{window_length}" ) windows.append( MotionWindow( motion_key=unique_key, shard_index=shard_index, start=win_start, length=window_length, raw_motion_key=motion_key, window_index=window_index, ) ) window_index += 1 offset += window_length remaining = max(0, length - offset) return windows def __len__(self) -> int: return len(self.windows) def __getitem__(self, index: int) -> MotionClipSample: 
window = self.windows[index] shard_handle = self._get_shard_handle(window.shard_index) start, end = window.start, window.start + window.length arrays: Dict[str, Tensor] = {} # Mandatory reference source: ref_* for logical_name, dataset_name in MANDATORY_DATASETS.items(): dname = f"ref_{dataset_name}" if dname not in shard_handle: raise KeyError( f"Missing mandatory dataset '{dname}' in shard index {window.shard_index}" ) np_array = shard_handle[dname][start:end] arrays[f"ref_{logical_name}"] = torch.from_numpy(np_array).to( torch.float32 ) # Optional filtered reference source: ft_ref_* for logical_name, dataset_name in MANDATORY_DATASETS.items(): dname = f"ft_ref_{dataset_name}" if dname in shard_handle: np_array = shard_handle[dname][start:end] arrays[f"ft_ref_{logical_name}"] = torch.from_numpy( np_array ).to(torch.float32) if "frame_flag" in shard_handle: frame_flag_np = shard_handle["frame_flag"][start:end] frame_flag = torch.from_numpy(frame_flag_np).to(torch.long) else: frame_flag = torch.ones(window.length, dtype=torch.long) if window.length > 1: frame_flag[0] = 0 frame_flag[-1] = 2 elif window.length == 1: # Single-frame window: mark as both start and end (use 2 for end) frame_flag[0] = 2 arrays["frame_flag"] = frame_flag if self._world_frame_transform is not None: self._world_frame_transform(arrays) # Derived root_* for ref_* (after normalization) arrays["ref_root_pos"] = arrays["ref_rg_pos"][:, 0, :] arrays["ref_root_rot"] = arrays["ref_rb_rot"][:, 0, :] arrays["ref_root_vel"] = arrays["ref_body_vel"][:, 0, :] arrays["ref_root_ang_vel"] = arrays["ref_body_ang_vel"][:, 0, :] # Derived root_* for optional ft_ref_* (after normalization) if ( "ft_ref_rg_pos" in arrays and "ft_ref_rb_rot" in arrays and "ft_ref_body_vel" in arrays and "ft_ref_body_ang_vel" in arrays ): arrays["ft_ref_root_pos"] = arrays["ft_ref_rg_pos"][:, 0, :] arrays["ft_ref_root_rot"] = arrays["ft_ref_rb_rot"][:, 0, :] arrays["ft_ref_root_vel"] = arrays["ft_ref_body_vel"][:, 0, :] 
arrays["ft_ref_root_ang_vel"] = arrays["ft_ref_body_ang_vel"][ :, 0, : ] if self._progress_counter is not None: with self._progress_counter.get_lock(): self._progress_counter.value += 1 return MotionClipSample( motion_key=window.motion_key, raw_motion_key=window.raw_motion_key, window_index=int(index), tensors=arrays, length=window.length, ) def _get_shard_handle(self, shard_index: int) -> h5py.File: if shard_index in self._file_handles: handle = self._file_handles.pop(shard_index) if handle.id: # Mark as most recently used. self._file_handles[shard_index] = handle return handle if shard_index < 0 or shard_index >= len(self._shard_paths): raise IndexError( f"Shard index {shard_index} out of range for " f"{len(self._shard_paths)} available shards" ) shard_path = self._shard_paths[shard_index] # Open with SWMR and a configurable raw chunk cache to speed up repeated reads. # The default cache size (in bytes) can be overridden via the # HOLOMOTION_HDF5_RDCC_NBYTES environment variable. rdcc_nbytes_env = os.getenv("HOLOMOTION_HDF5_RDCC_NBYTES") if rdcc_nbytes_env is None: rdcc_nbytes = 256 * 1024 * 1024 # 256MB default else: rdcc_nbytes = int(rdcc_nbytes_env) handle = h5py.File( shard_path, "r", libver="latest", swmr=True, rdcc_nbytes=rdcc_nbytes, rdcc_w0=0.75, ) # Enforce LRU limit on the number of simultaneously open shard files. 
if ( self._max_open_files is not None and len(self._file_handles) >= self._max_open_files ): old_index, old_handle = self._file_handles.popitem(last=False) old_handle.close() self._file_handles[shard_index] = handle return handle def close(self) -> None: """Close all open HDF5 shard handles for this dataset.""" for handle in self._file_handles.values(): if handle.id: handle.close() self._file_handles.clear() class MotionClipBatchCache: """Double-buffered motion cache for RL training and evaluation.""" @staticmethod def _infer_cuda_device_index() -> int: device_count = int(torch.cuda.device_count()) local_rank_env = os.environ.get("LOCAL_RANK") if local_rank_env is not None: local_rank = int(local_rank_env) if 0 <= local_rank < device_count: return local_rank return int(torch.cuda.current_device()) @classmethod def _normalize_stage_device( cls, stage_device: Optional[object] ) -> Optional[torch.device]: if stage_device is None: return None if isinstance(stage_device, torch.device): if stage_device.type == "cpu": return None if stage_device.type != "cuda": raise ValueError( f"Unsupported stage_device type: {stage_device.type}" ) if not torch.cuda.is_available(): raise RuntimeError( "stage_device requested CUDA but CUDA is not available" ) if stage_device.index is not None: return stage_device return torch.device("cuda", cls._infer_cuda_device_index()) if isinstance(stage_device, str): stage_device_str = stage_device.strip().lower() if stage_device_str in ("none", "cpu"): return None if stage_device_str == "cuda": if not torch.cuda.is_available(): raise RuntimeError( "stage_device requested CUDA but CUDA is not available" ) return torch.device("cuda", cls._infer_cuda_device_index()) if stage_device_str.startswith("cuda:"): if not torch.cuda.is_available(): raise RuntimeError( "stage_device requested CUDA but CUDA is not available" ) return torch.device(stage_device_str) raise ValueError( f"Unsupported stage_device string: {stage_device}" ) raise TypeError( 
f"Unsupported stage_device value type: {type(stage_device)}" ) def __init__( self, train_dataset: Dataset[MotionClipSample], *, val_dataset: Optional[Dataset[MotionClipSample]] = None, batch_size: int, stage_device: Optional[torch.device] = None, num_workers: int = 4, prefetch_factor: int = 2, pin_memory: bool = True, persistent_workers: bool = True, sampler_rank: int = 0, sampler_world_size: int = 1, allowed_prefixes: Optional[Sequence[str]] = None, swap_interval_steps: Optional[int] = None, force_timeout_on_swap: bool = True, stage_on_swap_only: bool = False, batch_progress_bar: bool = False, seed: Optional[int] = None, loader_timeout: float = 0.0, ) -> None: if batch_size <= 0: raise ValueError("batch_size must be positive") if float(loader_timeout) < 0.0: raise ValueError("loader_timeout must be >= 0") self._datasets = { "train": train_dataset, "val": val_dataset if val_dataset is not None else train_dataset, } self._mode = "train" self._seed = ( int(seed) if seed is not None else int(time.time_ns() & 0x7FFFFFFF) ) self._stage_device = self._normalize_stage_device(stage_device) self._sampler_rank = int(sampler_rank) self._sampler_world_size = int(max(1, sampler_world_size)) self._batch_size = int(batch_size) self._allowed_prefixes: Optional[Tuple[str, ...]] = ( tuple(allowed_prefixes) if allowed_prefixes is not None else None ) # If enabled, keep the prefetched batch on CPU (FK on CPU) and stage to GPU # only during cache swapping (advance). 
self._stage_on_swap_only = bool(stage_on_swap_only) self._batch_progress_bar = bool(batch_progress_bar) self._loader_timeout = float(loader_timeout) self.force_timeout_on_swap = bool(force_timeout_on_swap) self._batch_progress_counter: Optional[mp.Value] = None if self._should_use_batch_progress(): ctx = mp.get_context("spawn") self._batch_progress_counter = ctx.Value("i", 0) self.swap_interval_steps = ( swap_interval_steps if swap_interval_steps is not None else train_dataset.max_frame_length ) self._num_workers = int(max(0, num_workers)) self._prefetch_factor = ( prefetch_factor if prefetch_factor is not None else None ) self._pin_memory = bool(pin_memory) self._persistent_workers = bool(persistent_workers and num_workers > 0) self._dataloader: Optional[DataLoader] = None self._sampler: Optional[Sampler[int]] = None self._iterator: Optional[Iterator[ClipBatch]] = None self._current_batch: Optional[ClipBatch] = None self._next_batch: Optional[ClipBatch] = None self._swap_index = 0 self._effective_batch_size: Optional[int] = None self._num_batches: Optional[int] = None # Weighted-bin sampling state self._weighted_bin_enabled: bool = False self._weighted_bin_bins: Optional[List[List[int]]] = None self._weighted_bin_ratios: Optional[List[float]] = None self._weighted_bin_specs: Optional[List[Dict[str, Any]]] = None self._cache_curriculum_enabled: bool = False self._cache_curriculum_cfg: Dict[str, Any] = {} self._cache_curriculum_sampler: Optional[ PrioritizedInfiniteSampler ] = None self._cache_curriculum_dump_enabled: bool = False self._cache_curriculum_dump_every_swaps: int = 10 self._cache_curriculum_dump_chunk_size: int = 4096 self._cache_curriculum_dump_dir: Path = Path( "cache_curriculum_window_scores" ) self._cache_curriculum_last_dump_swap: int = -1 # Async GPU staging helpers self._copy_stream = None self._pending_ready_event = None self._current_ready_event = None self._next_ready_event = None self._build_dataloader() if ( self._stage_device is not None and 
self._stage_device.type == "cuda" ): self._copy_stream = torch.cuda.Stream(device=self._stage_device) self._prime_buffers() @property def current_batch(self) -> ClipBatch: assert self._current_batch is not None return self._current_batch @property def max_frame_length(self) -> int: return self.current_batch.max_frame_length @property def clip_count(self) -> int: return self.current_batch.lengths.shape[0] @property def mode(self) -> str: return self._mode @property def swap_index(self) -> int: return self._swap_index @property def num_batches(self) -> int: if self._num_batches is None: raise RuntimeError("DataLoader is not initialised") return int(self._num_batches) def set_mode(self, mode: str) -> None: if mode == self._mode: return if mode not in self._datasets: raise ValueError(f"Unknown cache mode: {mode}") self._mode = mode self._build_dataloader() self._prime_buffers() def set_seed(self, seed: int, *, reinitialize: bool = True) -> None: self._seed = int(seed) if reinitialize: self._build_dataloader() self._prime_buffers() def advance(self) -> None: if self._stage_on_swap_only: if self._next_batch is None: self._next_batch = self._fetch_next_batch() # Stage the prefetched CPU batch to GPU only at swap time. 
staged = self._stage_batch_blocking(self._next_batch) self._current_batch = staged self._next_batch = self._fetch_next_batch() self._swap_index += 1 return if self._next_batch is None: self._next_batch = self._fetch_next_batch() # Ensure asynchronous staging finished before swapping in next batch if ( self._next_ready_event is not None and self._stage_device is not None and self._stage_device.type == "cuda" ): torch.cuda.current_stream(self._stage_device).wait_event( self._next_ready_event ) self._current_batch = self._next_batch self._next_batch = self._fetch_next_batch() self._swap_index += 1 # ------------------------- # Weighted-bin configuration # ------------------------- def enable_weighted_bin_sampling( self, cfg: Optional[Dict[str, Any]] = None ) -> None: """Enable regex-based weighted-bin sampling over manifest motion keys. The configuration must provide a list under ``bin_regex_patterns`` (or the legacy name ``bin_regrex_patterns``), where each element is a mapping with: - ``regex`` (or ``regrex``): Python regular expression applied to the manifest clip key (e.g., ``AMASS_.*``, ``VR_pico_.*``). - ``ratio``: Target sampling ratio in [0, 1]. The sum of explicit bin ratios must be <= 1.0. Any remaining mass is assigned to an implicit ``others`` bin that collects all clips not matched by any regex. """ cfg_local: Dict[str, Any] = dict(cfg or {}) if self._cache_curriculum_enabled: raise ValueError( "weighted-bin and cache curriculum sampling cannot be enabled together." 
) dataset = self._datasets.get("train") if dataset is None: raise ValueError( "Weighted-bin sampling requires a training dataset" ) # Collect manifest-level motion keys for all windows in order window_keys: List[str] = [] for window in dataset.windows: motion_key = getattr(window, "raw_motion_key", None) if motion_key is None: full_key = getattr(window, "motion_key", "") if "__start_" in full_key: motion_key = full_key.split("__start_", 1)[0] else: motion_key = full_key window_keys.append(motion_key) bin_indices, all_ratios, specs = _configure_weighted_bins( keys=window_keys, cfg=cfg_local, batch_size_for_log=int(self._batch_size), ) # Log summary in terms of windows table_rows = [] for item in specs: table_rows.append( [ item["name"], item["regex"], f"{item['ratio']:.4f}", int(item["count"]), f"{item['dataset_fraction']:.4f}", f"{item['batch_fraction']:.4f}", ] ) headers = [ "bin", "regex", "final_ratio", "num_windows", "dataset_fraction", "batch_fraction", ] logger.info( "Motion cache weighted-bin sampling configured:\n" + tabulate(table_rows, headers=headers, tablefmt="simple_outline") ) # Activate weighted-bin sampling and rebuild dataloader/cache self._weighted_bin_enabled = True self._weighted_bin_bins = bin_indices self._weighted_bin_ratios = all_ratios self._weighted_bin_specs = specs self._build_dataloader() self._prime_buffers() def enable_cache_curriculum_sampling( self, cfg: Optional[Dict[str, Any]] = None ) -> None: if self._weighted_bin_enabled: raise ValueError( "cache curriculum and weighted-bin sampling cannot be enabled together." 
) self._cache_curriculum_enabled = True self._cache_curriculum_cfg = dict(cfg or {}) self._cache_curriculum_dump_enabled = bool( self._cache_curriculum_cfg.get( "dump_whole_window_scores_json", True ) ) self._cache_curriculum_dump_every_swaps = max( 1, int( self._cache_curriculum_cfg.get( "dump_whole_window_scores_every_swaps", 10 ) ), ) self._cache_curriculum_dump_chunk_size = max( 1, int( self._cache_curriculum_cfg.get( "dump_whole_window_scores_chunk_size", 4096 ) ), ) self._cache_curriculum_dump_dir = Path( str( self._cache_curriculum_cfg.get( "dump_whole_window_scores_dir", "cache_curriculum_window_scores", ) ) ) self._cache_curriculum_last_dump_swap = -1 self._prepare_cache_curriculum_dump_dir( self._cache_curriculum_dump_dir, reason="enabled", ) self._build_dataloader() self._prime_buffers() def _prepare_cache_curriculum_dump_dir( self, dump_dir: Path, *, reason: str ) -> None: self._cache_curriculum_dump_dir = Path(str(dump_dir)) if not self._cache_curriculum_dump_enabled: return self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True) logger.info( "Cache curriculum whole-window score dump " f"{reason}: dir={self._cache_curriculum_dump_dir}, " f"every_swaps={self._cache_curriculum_dump_every_swaps}, " f"rank={self._sampler_rank}" ) def set_cache_curriculum_dump_dir(self, dump_dir: str) -> None: self._prepare_cache_curriculum_dump_dir( Path(str(dump_dir)), reason="directory set", ) def update_cache_curriculum( self, *, window_indices: Tensor, mpkpe_signal_means: Tensor, completion_rate_means: Tensor, counts: Tensor, swap_index: int, ) -> bool: if self._cache_curriculum_sampler is None: return False updated = ( self._cache_curriculum_sampler.maybe_update_from_observations( window_indices=window_indices, mpkpe_signal_means=mpkpe_signal_means, completion_rate_means=completion_rate_means, counts=counts, swap_index=swap_index, ) ) if updated: self._refresh_prefetched_batch() self._maybe_dump_cache_curriculum_scores_json(swap_index=swap_index) return 
updated def _refresh_prefetched_batch(self) -> None: if self._next_batch is None: return self._next_batch = self._fetch_next_batch() def _maybe_dump_cache_curriculum_scores_json( self, *, swap_index: int ) -> None: if not self._cache_curriculum_dump_enabled: return if self._cache_curriculum_sampler is None: return swap_idx = int(swap_index) if swap_idx <= 0: return if swap_idx % self._cache_curriculum_dump_every_swaps != 0: return if self._cache_curriculum_last_dump_swap == swap_idx: return dataset = self._datasets["train"] ds_len = int(len(dataset)) if ds_len <= 0: return self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True) output_path = self._cache_curriculum_dump_dir / ( "whole_window_scores_" f"rank_{self._sampler_rank:04d}_swap_{swap_idx:06d}.json" ) sampler_version = int(self._cache_curriculum_sampler.state_version) windows = dataset.windows score_values: List[float] = [] completion_values: List[float] = [] rel_improve_values: List[float] = [] selection_count_values: List[int] = [] seen_values: List[bool] = [] in_pool_values: List[bool] = [] chunk_size = max(1, int(self._cache_curriculum_dump_chunk_size)) for chunk_start in range(0, ds_len, chunk_size): chunk_end = min(ds_len, chunk_start + chunk_size) chunk_indices = torch.arange( chunk_start, chunk_end, dtype=torch.long ) chunk_scores = ( self._cache_curriculum_sampler.get_scores_for_indices( chunk_indices ) ) chunk_state = ( self._cache_curriculum_sampler.get_window_state_for_indices( chunk_indices ) ) if chunk_scores.numel() != chunk_indices.numel(): raise ValueError( "Whole-window score dump shape mismatch for " "cache curriculum sampler." 
) score_values.extend(chunk_scores.tolist()) completion_values.extend( chunk_state["ema_completion_rate"].tolist() ) rel_improve_values.extend( chunk_state["completion_rate_rel_improve"].tolist() ) selection_count_values.extend( chunk_state["selection_count"].tolist() ) seen_values.extend(chunk_state["seen"].tolist()) in_pool_values.extend(chunk_state["in_prioritized_pool"].tolist()) rows: List[Dict[str, Any]] = [] for window_index in range(ds_len): window = windows[window_index] rows.append( { "swap_index": int(swap_idx), "rank": int(self._sampler_rank), "sampler_state_version": sampler_version, "window_index": int(window_index), "raw_motion_key": str(window.raw_motion_key), "motion_key": str(window.motion_key), "start": int(window.start), "length": int(window.length), "score": float(score_values[window_index]), "selection_count": int( selection_count_values[window_index] ), "ema_completion_rate": float( completion_values[window_index] ), "completion_rate_rel_improve": float( rel_improve_values[window_index] ), "seen": bool(seen_values[window_index]), "in_prioritized_pool": bool(in_pool_values[window_index]), } ) payload: Dict[str, Any] = { "swap_index": int(swap_idx), "rank": int(self._sampler_rank), "sampler_state_version": sampler_version, "num_windows": int(ds_len), "pool_metrics": self._cache_curriculum_sampler.get_pool_statistics() or {}, "rows": rows, } with output_path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, indent=2) handle.write("\n") self._cache_curriculum_last_dump_swap = swap_idx def cache_curriculum_scores_for_window_indices( self, window_indices: Tensor ) -> Optional[Tuple[Tensor, Dict[str, Tensor], int]]: if self._cache_curriculum_sampler is None: return None scores = self._cache_curriculum_sampler.get_scores_for_indices( window_indices ) state = self._cache_curriculum_sampler.get_window_state_for_indices( window_indices ) version = self._cache_curriculum_sampler.state_version return scores, state, version def 
cache_curriculum_pool_statistics( self, ) -> Optional[Dict[str, float]]: if self._cache_curriculum_sampler is None: return None return self._cache_curriculum_sampler.get_pool_statistics() def sample_env_assignments( self, num_envs: int, n_future_frames: int, device: torch.device, *, deterministic_start: bool = False, ) -> Tuple[Tensor, Tensor]: batch = self.current_batch lengths = batch.lengths.to(device) if num_envs <= 0: raise ValueError("num_envs must be positive") total = int(lengths.shape[0]) if total == 0: raise ValueError( "Cannot sample from an empty batch. Ensure the cache contains " "at least one motion clip before calling sample_env_assignments." ) clip_indices = torch.randint( low=0, high=total, size=(num_envs,), device=device ) max_start = torch.clamp( lengths[clip_indices] - 1 - n_future_frames, min=0 ) if deterministic_start: frame_starts = torch.zeros_like(max_start) else: rand = torch.rand_like(max_start, dtype=torch.float32) frame_starts = torch.floor(rand * (max_start + 1).float()).to( torch.long ) return clip_indices, frame_starts def _prepare_gather_indices( self, *, clip_indices: Tensor, frame_indices: Tensor, n_future_frames: int, ) -> Tuple[Tensor, Tensor]: batch = self.current_batch staged_device = batch.lengths.device selected_clips = clip_indices.to( staged_device, dtype=torch.long ).clone() frame_indices = frame_indices.to( staged_device, dtype=torch.long ).clone() temporal_span = 1 + int(n_future_frames) time_offsets = torch.arange( temporal_span, device=staged_device, dtype=torch.long ) gather_timesteps = frame_indices[:, None] + time_offsets[None, :] lengths = batch.lengths max_valid = torch.clamp( lengths.index_select(0, selected_clips) - 1, min=0 ) gather_timesteps = torch.minimum( gather_timesteps, max_valid[:, None] ).clone() return selected_clips, gather_timesteps def gather_tensor( self, tensor_name: str, *, clip_indices: Tensor, frame_indices: Tensor, n_future_frames: int, ) -> Tensor: batch = self.current_batch if tensor_name 
not in batch.tensors: raise KeyError( f"Tensor '{tensor_name}' is not present in current_batch" ) selected_clips, gather_timesteps = self._prepare_gather_indices( clip_indices=clip_indices, frame_indices=frame_indices, n_future_frames=n_future_frames, ) tensor = batch.tensors[tensor_name] return tensor[selected_clips[:, None], gather_timesteps, ...] def lengths_for_indices(self, clip_indices: Tensor) -> Tensor: lengths = self.current_batch.lengths.to(clip_indices.device) return lengths.index_select(0, clip_indices.long()) def motion_keys_for_indices(self, clip_indices: Tensor) -> List[str]: result = [] base_keys = self.current_batch.motion_keys for idx in clip_indices.tolist(): result.append(base_keys[int(idx)]) return result def window_indices_for_indices(self, clip_indices: Tensor) -> Tensor: base_indices = self.current_batch.window_indices.to( clip_indices.device ) return base_indices.index_select(0, clip_indices.long()) def _prime_buffers(self) -> None: if self._stage_on_swap_only: # Prefetch on CPU; stage to GPU only for current batch. cpu_current = self._fetch_next_batch() self._current_batch = self._stage_batch_blocking(cpu_current) self._next_batch = self._fetch_next_batch() self._pending_ready_event = None self._current_ready_event = None self._next_ready_event = None return self._current_batch = self._fetch_next_batch() # Ensure first staged batch is ready before consumption if ( self._current_ready_event is not None and self._stage_device is not None and self._stage_device.type == "cuda" ): t0 = time.time() torch.cuda.current_stream(self._stage_device).wait_event( self._current_ready_event ) t1 = time.time() logger.info( f"Perf/Cache/cuda_wait_event_ms={((t1 - t0) * 1e3):.2f} (first)" ) self._next_batch = self._fetch_next_batch() def _fetch_next_batch(self) -> ClipBatch: batch = self._load_next_batch() if self._stage_on_swap_only: # Prefetch raw batch on CPU. 
return batch staged = self._stage_batch(batch, record_event=True) # Move pending event into current/next slot if self._current_batch is None: self._current_ready_event = self._pending_ready_event else: self._next_ready_event = self._pending_ready_event self._pending_ready_event = None return staged def _load_next_batch(self) -> ClipBatch: if self._should_use_batch_progress(): return self._load_next_batch_with_progress() return self._load_next_batch_raw() def _load_next_batch_raw(self) -> ClipBatch: if self._iterator is None: self._iterator = self._build_iterator() try: batch = next(self._iterator) except StopIteration: self._iterator = self._build_iterator(reset_epoch=True) batch = next(self._iterator) return batch def _load_next_batch_with_progress(self) -> ClipBatch: if self._iterator is None: self._iterator = self._build_iterator() expected = int(self._effective_batch_size or self._batch_size) counter = self._batch_progress_counter if counter is None: return self._load_next_batch_raw() with counter.get_lock(): counter.value = 0 pbar = tqdm( total=expected, desc="Collecting motion batch", leave=False, dynamic_ncols=True, ) last = 0 with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(self._load_next_batch_raw) while not future.done(): with counter.get_lock(): value = counter.value if value > last: step = min(value, expected) - last if step > 0: pbar.update(step) last += step time.sleep(0.05) batch = future.result(timeout=self._result_timeout()) with counter.get_lock(): value = counter.value if value > last: step = min(value, expected) - last if step > 0: pbar.update(step) pbar.close() return batch def _stage_batch_blocking(self, batch: ClipBatch) -> ClipBatch: """Stage a CPU batch to the configured device on the current stream. This path is used when `stage_on_swap_only=True` so that only the current cache batch resides on GPU. 
""" if self._stage_device is None: return batch non_blocking = bool( self._pin_memory and self._stage_device.type == "cuda" ) tensors = { name: tensor.to(self._stage_device, non_blocking=non_blocking) for name, tensor in batch.tensors.items() } lengths = batch.lengths.to( self._stage_device, non_blocking=non_blocking ) window_indices = batch.window_indices.to( self._stage_device, non_blocking=non_blocking ) staged = ClipBatch( tensors=tensors, lengths=lengths, motion_keys=batch.motion_keys, raw_motion_keys=getattr( batch, "raw_motion_keys", batch.motion_keys ), window_indices=window_indices, max_frame_length=batch.max_frame_length, ) return staged def _stage_batch( self, batch: ClipBatch, record_event: bool = False, ) -> ClipBatch: if self._stage_device is None: return batch # If CUDA, copy on a dedicated stream and record readiness if self._copy_stream is None and ( self._stage_device is not None and self._stage_device.type == "cuda" ): self._copy_stream = torch.cuda.Stream(device=self._stage_device) logger.info( f"Perf/Cache: created CUDA copy stream lazily on {self._stage_device}" ) if self._copy_stream is not None: # estimate payload size for logging try: total_bytes = 0 for tensor in batch.tensors.values(): total_bytes += int(tensor.element_size() * tensor.numel()) total_bytes += int( batch.lengths.element_size() * batch.lengths.numel() ) total_bytes += int( batch.window_indices.element_size() * batch.window_indices.numel() ) except Exception: total_bytes = -1 with torch.cuda.stream(self._copy_stream): tensors = { name: tensor.to(self._stage_device, non_blocking=True) for name, tensor in batch.tensors.items() } lengths = batch.lengths.to( self._stage_device, non_blocking=True ) window_indices = batch.window_indices.to( self._stage_device, non_blocking=True ) if record_event: ev = torch.cuda.Event() ev.record(self._copy_stream) self._pending_ready_event = ev else: tensors = { name: tensor.to(self._stage_device, non_blocking=True) for name, tensor in 
batch.tensors.items() } lengths = batch.lengths.to(self._stage_device, non_blocking=True) window_indices = batch.window_indices.to( self._stage_device, non_blocking=True ) return ClipBatch( tensors=tensors, lengths=lengths, motion_keys=batch.motion_keys, raw_motion_keys=getattr( batch, "raw_motion_keys", batch.motion_keys ), window_indices=window_indices, max_frame_length=batch.max_frame_length, ) def _build_iterator( self, *, reset_epoch: bool = False ) -> Iterator[ClipBatch]: if self._dataloader is None: raise RuntimeError("DataLoader is not initialised") if isinstance(self._sampler, DistributedSampler) and reset_epoch: self._sampler.set_epoch(self._swap_index + 1) return iter(self._dataloader) def _build_dataloader(self) -> None: dataset = self._datasets[self._mode] dataset.set_progress_counter(self._batch_progress_counter) # Clamp batch size to dataset length to avoid empty iterator when drop_last is disabled effective_batch_size = self._batch_size ds_len = len(dataset) if isinstance(ds_len, int) and ds_len > 0: effective_batch_size = max(1, min(self._batch_size, ds_len)) # Sampler selection: validation uses standard distributed/sequential samplers; # training can optionally use weighted-bin sampling. 
if self._mode == "val": if self._sampler_world_size > 1: self._sampler = DistributedSampler( dataset, num_replicas=self._sampler_world_size, rank=self._sampler_rank, shuffle=False, drop_last=False, ) else: self._sampler = None self._cache_curriculum_sampler = None else: if self._cache_curriculum_enabled: seed = self._seed + self._sampler_rank * 100003 cfg = dict(self._cache_curriculum_cfg) self._cache_curriculum_sampler = PrioritizedInfiniteSampler( dataset_len=ds_len, batch_size=effective_batch_size, seed=seed, p_a_ratio=float(cfg.get("p_a_ratio", 0.2)), ema_alpha_signal=float(cfg.get("ema_alpha_signal", 0.2)), ema_alpha_rel_improve=float( cfg.get("ema_alpha_rel_improve", 0.2) ), relative_eps=float(cfg.get("relative_eps", 1.0e-6)), ) self._cache_curriculum_last_dump_swap = -1 self._sampler = self._cache_curriculum_sampler elif ( self._weighted_bin_enabled and self._weighted_bin_bins is not None and self._weighted_bin_ratios is not None ): seed = self._seed + self._sampler_rank * 100003 self._sampler = WeightedBinInfiniteSampler( dataset_len=ds_len, bin_indices=self._weighted_bin_bins, ratios=self._weighted_bin_ratios, batch_size=effective_batch_size, seed=seed, ) self._cache_curriculum_sampler = None else: if self._sampler_world_size > 1: # Infinite sampler for training: no epoch boundaries self._sampler = InfiniteDistributedSampler( dataset, num_replicas=self._sampler_world_size, rank=self._sampler_rank, shuffle=True, drop_last=False, ) else: # Infinite sampler for single-process training self._sampler = InfiniteRandomSampler(dataset) self._cache_curriculum_sampler = None # Only pass prefetch_factor when using workers pf = ( self._prefetch_factor if (self._num_workers and self._num_workers > 0) else None ) pw = ( self._persistent_workers if (self._num_workers and self._num_workers > 0) else False ) # Collate wrapper: in validation, pad the batch up to cache size by # uniformly repeating samples when dataset is smaller than batch size. 
collate = partial( _cache_collate_fn, mode=self._mode, batch_size=self._batch_size, ) mp_ctx = None if self._num_workers and self._num_workers > 0: mp_ctx = mp.get_context("spawn") worker_init_fn = None if ( self._num_workers > 0 and self._stage_device is not None and self._stage_device.type == "cuda" ): worker_init_fn = _cpu_only_dataloader_worker_init_fn self._dataloader = DataLoader( dataset, batch_size=effective_batch_size, sampler=self._sampler, shuffle=(self._sampler is None and self._mode != "val"), num_workers=self._num_workers, prefetch_factor=pf, pin_memory=self._pin_memory, timeout=self._loader_timeout_seconds(), persistent_workers=pw, collate_fn=collate, drop_last=False, multiprocessing_context=mp_ctx, worker_init_fn=worker_init_fn, ) self._iterator = None self._current_batch = None self._next_batch = None self._swap_index = 0 # Compute number of batches only for validation; training is infinite local_len = ds_len if self._mode == "val": if self._sampler is not None: local_len = ( ds_len + self._sampler_world_size - 1 ) // self._sampler_world_size self._effective_batch_size = int(effective_batch_size) self._num_batches = ( local_len + self._effective_batch_size - 1 ) // self._effective_batch_size else: self._effective_batch_size = int(effective_batch_size) self._num_batches = 2**31 # effectively infinite for logging def close(self) -> None: """Release DataLoader workers and close underlying HDF5 datasets.""" datasets = self.__dict__.get("_datasets") if datasets is None: return self._iterator = None self._current_batch = None self._next_batch = None self._dataloader = None self._copy_stream = None self._pending_ready_event = None self._current_ready_event = None self._next_ready_event = None for ds in datasets.values(): if ds is not None: ds.close() def __del__(self) -> None: self.close() def _loader_timeout_seconds(self) -> float: if not self.force_timeout_on_swap: return 0.0 return self._loader_timeout def _result_timeout(self) -> Optional[float]: 
timeout_s = self._loader_timeout_seconds() if timeout_s <= 0.0: return None return timeout_s + 1.0 def _should_use_batch_progress(self) -> bool: if not self._batch_progress_bar: return False if self._sampler_world_size > 1: return False if self._loader_timeout_seconds() > 0.0: return False return True ================================================ FILE: holomotion/src/training/reference_filter_export.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import json import tempfile from pathlib import Path from typing import Mapping, Sequence import matplotlib import numpy as np from omegaconf import DictConfig, ListConfig from scipy.spatial.transform import Rotation as Rotation3D from holomotion.src.training.h5_dataloader import ( MotionClipSample, build_motion_datasets_from_cfg, ) matplotlib.use("Agg") import matplotlib.pyplot as plt def _to_numpy(array_like) -> np.ndarray: if hasattr(array_like, "detach"): array_like = array_like.detach().cpu().numpy() return np.asarray(array_like, dtype=np.float32) def _require_tensor( tensors: Mapping[str, object], tensor_name: str, error_message: str ) -> np.ndarray: if tensor_name not in tensors: raise ValueError(error_message) return _to_numpy(tensors[tensor_name]) def _quat_xyzw_to_rpy(quat_xyzw: np.ndarray) -> np.ndarray: quat_xyzw = np.asarray(quat_xyzw, dtype=np.float32) flat = quat_xyzw.reshape(-1, 4) euler = Rotation3D.from_quat(flat).as_euler("xyz", degrees=False) return euler.reshape(*quat_xyzw.shape[:-1], 3).astype( np.float32, copy=False ) def _write_npz(output_path: Path, payload: Mapping[str, np.ndarray]) -> None: np.savez(str(output_path), **payload) def _plot_series_groups( output_path: Path, title: str, groups: Sequence[tuple[str, np.ndarray, np.ndarray]], axis_labels: Sequence[str] = ("x", "y", "z"), ) -> None: nrows = len(groups) ncols = len(axis_labels) fig, axes = plt.subplots( nrows=nrows, ncols=ncols, figsize=(4.0 * ncols, 2.8 * max(1, nrows)), squeeze=False, ) plot_steps = np.arange(groups[0][1].shape[0], dtype=np.int32) for row_idx, (group_name, ref_values, ft_values) in enumerate(groups): for col_idx, axis_name in enumerate(axis_labels): ax = axes[row_idx, col_idx] ax.plot( plot_steps, ref_values[:, col_idx], label="raw", linewidth=1.4, ) ax.plot( plot_steps, ft_values[:, col_idx], label="filtered", linewidth=1.2, ) ax.set_title(f"{group_name} {axis_name}") ax.grid(True, alpha=0.3) if row_idx == 0 and col_idx == 0: ax.legend(loc="best") 
fig.suptitle(title) fig.tight_layout() fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) def _plot_dof_matrix( output_path: Path, title: str, dof_names: Sequence[str], ref_values: np.ndarray, ft_values: np.ndarray, ) -> None: num_dofs = len(dof_names) fig, axes = plt.subplots( nrows=num_dofs, ncols=1, figsize=(14.0, max(2.8 * num_dofs, 3.5)), squeeze=False, ) plot_steps = np.arange(ref_values.shape[0], dtype=np.int32) for idx, dof_name in enumerate(dof_names): ax = axes[idx, 0] ax.plot(plot_steps, ref_values[:, idx], label="raw", linewidth=1.4) ax.plot(plot_steps, ft_values[:, idx], label="filtered", linewidth=1.2) ax.set_title(dof_name) ax.grid(True, alpha=0.3) if idx == 0: ax.legend(loc="best") fig.suptitle(title) fig.tight_layout() fig.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig) def export_reference_filter_debug_artifacts( *, sample: MotionClipSample, output_dir: str | Path, body_names: Sequence[str], dof_names: Sequence[str], selected_body_links: Sequence[str], ) -> Path: tensors = sample.tensors if "ft_ref_rg_pos" not in tensors or "ft_ref_dof_pos" not in tensors: raise ValueError( "Filtered reference tensors are unavailable. Ensure online filtering " "is enabled and ft_ref_* tensors are materialized." 
) output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) ref_root_pos = _require_tensor( tensors, "ref_root_pos", "Missing ref_root_pos tensor in sampled clip.", ) ft_root_pos = _require_tensor( tensors, "ft_ref_root_pos", "Missing ft_ref_root_pos tensor in sampled clip.", ) ref_root_rot = _require_tensor( tensors, "ref_root_rot", "Missing ref_root_rot tensor in sampled clip.", ) ft_root_rot = _require_tensor( tensors, "ft_ref_root_rot", "Missing ft_ref_root_rot tensor in sampled clip.", ) ref_root_vel = _require_tensor( tensors, "ref_root_vel", "Missing ref_root_vel tensor in sampled clip.", ) ft_root_vel = _require_tensor( tensors, "ft_ref_root_vel", "Missing ft_ref_root_vel tensor in sampled clip.", ) ref_root_ang_vel = _require_tensor( tensors, "ref_root_ang_vel", "Missing ref_root_ang_vel tensor in sampled clip.", ) ft_root_ang_vel = _require_tensor( tensors, "ft_ref_root_ang_vel", "Missing ft_ref_root_ang_vel tensor in sampled clip.", ) ref_rg_pos = _require_tensor( tensors, "ref_rg_pos", "Missing ref_rg_pos tensor in sampled clip.", ) ft_ref_rg_pos = _require_tensor( tensors, "ft_ref_rg_pos", "Missing ft_ref_rg_pos tensor in sampled clip.", ) ref_body_vel = _require_tensor( tensors, "ref_body_vel", "Missing ref_body_vel tensor in sampled clip.", ) ft_ref_body_vel = _require_tensor( tensors, "ft_ref_body_vel", "Missing ft_ref_body_vel tensor in sampled clip.", ) ref_body_ang_vel = _require_tensor( tensors, "ref_body_ang_vel", "Missing ref_body_ang_vel tensor in sampled clip.", ) ft_ref_body_ang_vel = _require_tensor( tensors, "ft_ref_body_ang_vel", "Missing ft_ref_body_ang_vel tensor in sampled clip.", ) ref_dof_pos = _require_tensor( tensors, "ref_dof_pos", "Missing ref_dof_pos tensor in sampled clip.", ) ft_ref_dof_pos = _require_tensor( tensors, "ft_ref_dof_pos", "Missing ft_ref_dof_pos tensor in sampled clip.", ) ref_dof_vel = _require_tensor( tensors, "ref_dof_vel", "Missing ref_dof_vel tensor in sampled clip.", ) ft_ref_dof_vel = 
_require_tensor( tensors, "ft_ref_dof_vel", "Missing ft_ref_dof_vel tensor in sampled clip.", ) body_name_to_idx = {name: idx for idx, name in enumerate(body_names)} missing_links = [ link_name for link_name in selected_body_links if link_name not in body_name_to_idx ] if missing_links: raise ValueError( f"Requested body links are missing from robot.body_names: {missing_links}" ) ref_root_rpy = _quat_xyzw_to_rpy(ref_root_rot) ft_root_rpy = _quat_xyzw_to_rpy(ft_root_rot) root_payload = { "ref_global_pos": ref_root_pos, "ft_ref_global_pos": ft_root_pos, "ref_rpy": ref_root_rpy, "ft_ref_rpy": ft_root_rpy, "ref_lin_vel": ref_root_vel, "ft_ref_lin_vel": ft_root_vel, "ref_ang_vel": ref_root_ang_vel, "ft_ref_ang_vel": ft_root_ang_vel, } _write_npz(output_dir / "root_signals.npz", root_payload) body_payload: dict[str, np.ndarray] = {} for link_name in selected_body_links: body_idx = body_name_to_idx[link_name] body_payload[f"{link_name}__ref_global_pos"] = ref_rg_pos[ :, body_idx, : ] body_payload[f"{link_name}__ft_ref_global_pos"] = ft_ref_rg_pos[ :, body_idx, : ] body_payload[f"{link_name}__ref_lin_vel"] = ref_body_vel[ :, body_idx, : ] body_payload[f"{link_name}__ft_ref_lin_vel"] = ft_ref_body_vel[ :, body_idx, : ] body_payload[f"{link_name}__ref_ang_vel"] = ref_body_ang_vel[ :, body_idx, : ] body_payload[f"{link_name}__ft_ref_ang_vel"] = ft_ref_body_ang_vel[ :, body_idx, : ] _write_npz(output_dir / "bodylink_signals.npz", body_payload) dof_payload = { "ref_dof_pos": ref_dof_pos, "ft_ref_dof_pos": ft_ref_dof_pos, "ref_dof_vel": ref_dof_vel, "ft_ref_dof_vel": ft_ref_dof_vel, } _write_npz(output_dir / "dof_signals.npz", dof_payload) filter_cutoff_tensor = tensors.get("filter_cutoff_hz") filter_cutoff_hz = None if filter_cutoff_tensor is not None: cutoff_values = _to_numpy(filter_cutoff_tensor).reshape(-1) if cutoff_values.size > 0: filter_cutoff_hz = float(cutoff_values[0]) metadata = { "motion_key": sample.motion_key, "raw_motion_key": sample.raw_motion_key, 
"window_index": int(sample.window_index), "length": int(sample.length), "filter_cutoff_hz": filter_cutoff_hz, "selected_body_links": list(selected_body_links), "body_names": list(body_names), "dof_names": list(dof_names), } (output_dir / "metadata.json").write_text( json.dumps(metadata, indent=2, sort_keys=True), encoding="utf-8", ) _plot_series_groups( output_dir / "root_comparison.png", title="Root Raw vs Filtered Reference Signals", groups=[ ("global_pos", ref_root_pos, ft_root_pos), ("rpy", ref_root_rpy, ft_root_rpy), ("lin_vel", ref_root_vel, ft_root_vel), ("ang_vel", ref_root_ang_vel, ft_root_ang_vel), ], ) for link_name in selected_body_links: _plot_series_groups( output_dir / f"{link_name}_comparison.png", title=f"{link_name} Raw vs Filtered Reference Signals", groups=[ ( "global_pos", body_payload[f"{link_name}__ref_global_pos"], body_payload[f"{link_name}__ft_ref_global_pos"], ), ( "lin_vel", body_payload[f"{link_name}__ref_lin_vel"], body_payload[f"{link_name}__ft_ref_lin_vel"], ), ( "ang_vel", body_payload[f"{link_name}__ref_ang_vel"], body_payload[f"{link_name}__ft_ref_ang_vel"], ), ], ) _plot_dof_matrix( output_dir / "dof_pos_comparison.png", title="DOF Position Raw vs Filtered", dof_names=dof_names, ref_values=ref_dof_pos, ft_values=ft_ref_dof_pos, ) _plot_dof_matrix( output_dir / "dof_vel_comparison.png", title="DOF Velocity Raw vs Filtered", dof_names=dof_names, ref_values=ref_dof_vel, ft_values=ft_ref_dof_vel, ) return output_dir def _to_plain_sequence(values) -> list[str]: if values is None: return [] if isinstance(values, (ListConfig, tuple, list)): return [str(v) for v in values] return [str(values)] def export_reference_filter_artifacts_from_config(config) -> Path: debug_cfg = getattr(config, "debug_reference_filter_export", None) if debug_cfg is None or not bool(debug_cfg.get("enabled", False)): raise ValueError("debug_reference_filter_export.enabled must be true.") motion_cfg = config.robot.motion online_filter_cfg = 
motion_cfg.get("online_filter", {}) if not bool(online_filter_cfg.get("enabled", False)): raise ValueError( "Reference filter debug export requires robot.motion.online_filter.enabled=true." ) output_dir = debug_cfg.get("output_dir", None) if output_dir in (None, ""): output_dir = tempfile.mkdtemp(prefix="motrack-ref-filter-") train_dataset, _, _ = build_motion_datasets_from_cfg( motion_cfg=motion_cfg, max_frame_length=int(motion_cfg.max_frame_length), min_window_length=int(motion_cfg.min_frame_length), world_frame_normalization=bool( motion_cfg.get("world_frame_normalization", True) ), ) sample = train_dataset[0] return export_reference_filter_debug_artifacts( sample=sample, output_dir=Path(str(output_dir)), body_names=_to_plain_sequence(config.robot.body_names), dof_names=_to_plain_sequence(config.robot.dof_names), selected_body_links=_to_plain_sequence( debug_cfg.get("selected_body_links", []) ), ) ================================================ FILE: holomotion/src/training/train.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import os from pathlib import Path import sys import hydra from hydra.utils import get_class from omegaconf import ListConfig, OmegaConf from accelerate import Accelerator from accelerate.utils import ProjectConfiguration from loguru import logger from holomotion.src.training.reference_filter_export import ( export_reference_filter_artifacts_from_config, ) from holomotion.src.utils.config import compile_config def _resolve_mujoco_eval_onnx_names( exported_dir: Path, ckpt_onnx_names ) -> list[str]: if not exported_dir.is_dir(): raise FileNotFoundError( f"Exported ONNX directory not found: {exported_dir}" ) existing = sorted([p.name for p in exported_dir.glob("*.onnx")]) if len(existing) == 0: raise FileNotFoundError(f"No .onnx files found under {exported_dir}") existing_set = set(existing) if ckpt_onnx_names is None: return existing if isinstance(ckpt_onnx_names, ListConfig): requested = list(ckpt_onnx_names) elif isinstance(ckpt_onnx_names, (list, tuple)): requested = list(ckpt_onnx_names) else: raise TypeError( "mujoco_eval.ckpt_onnx_names must be a list/tuple, " f"got {type(ckpt_onnx_names)}" ) requested_norm = [] for name in requested: name_str = str(name).strip() if name_str == "": continue requested_norm.append(Path(name_str).name) if len(requested_norm) == 0: return existing selected = [name for name in requested_norm if name in existing_set] if len(selected) == 0: raise ValueError( "No requested ONNX checkpoints exist under exported directory. 
" f"exported_dir={exported_dir}, requested={requested_norm}, " f"existing={existing}" ) return selected def _exec_mujoco_eval(eval_override_dict: dict) -> None: cli_args = [] for key in sorted(eval_override_dict.keys()): value = eval_override_dict[key] if value is None: continue if isinstance(value, bool): cli_args.append(f"{key}={'true' if value else 'false'}") elif isinstance(value, (int, float)): cli_args.append(f"{key}={value}") elif isinstance(value, str): cli_args.append(f"{key}={value}") elif isinstance(value, (list, tuple)): inner = ",".join([str(v) for v in value]) cli_args.append(f"{key}=[{inner}]") else: cli_args.append(f"{key}={value}") argv = [ sys.executable, "-m", "holomotion.src.evaluation.eval_mujoco_sim2sim", ] + cli_args os.execv(sys.executable, argv) def _maybe_export_reference_filter_artifacts(config: OmegaConf) -> None: debug_cfg = getattr(config, "debug_reference_filter_export", None) if debug_cfg is None or not bool(debug_cfg.get("enabled", False)): return if not bool(getattr(config, "main_process", True)): return export_dir = export_reference_filter_artifacts_from_config(config) logger.info(f"Exported reference filter debug artifacts to: {export_dir}") @hydra.main( config_path="../../config", config_name="training/train_base", version_base=None, ) def main(config: OmegaConf): """Train the motion tracking model. Args: config: OmegaConf object containing the configuration. """ config = compile_config(config, accelerator=None) dist = None # In distributed runs, Hydra resolves ${now:...} per process so experiment_save_dir # can differ by rank (e.g. staggered startup). Use Accelerator to init the process # group, then broadcast rank 0's path so all ranks write to the same directory. 
if getattr(config, "num_processes", 1) > 1: project_config = ProjectConfiguration( project_dir=config.experiment_save_dir, logging_dir=config.experiment_save_dir, ) _accelerator = Accelerator(project_config=project_config) import torch.distributed as dist path_list = ( [config.experiment_save_dir] if _accelerator.is_main_process else [None] ) dist.broadcast_object_list(path_list, src=0) config.experiment_save_dir = path_list[0] _maybe_export_reference_filter_artifacts(config) if dist is not None: dist.barrier() log_dir = config.experiment_save_dir headless = config.headless algo_class = get_class(config.algo._target_) algo = algo_class( env_config=config.env, config=config.algo.config, log_dir=log_dir, headless=headless, ) algo.load(config.checkpoint) algo.learn() if not bool(config.mujoco_eval.get("enabled", False)): return if not bool(config.algo.config.get("export_policy", False)): msg = ( "mujoco_eval.enabled=true requires " "algo.config.export_policy=true to export ONNX " "before post-training evaluation." ) raise ValueError(msg) if not bool(algo.is_main_process): os._exit(0) exported_dir = Path(log_dir) / "exported" selected_onnx_names = _resolve_mujoco_eval_onnx_names( exported_dir, config.mujoco_eval.get("ckpt_onnx_names", None) ) eval_override_dict = OmegaConf.to_container( config.mujoco_eval, resolve=True ) eval_override_dict.pop("enabled", None) eval_override_dict["ckpt_onnx_root_dir"] = str(exported_dir) eval_override_dict["ckpt_onnx_names"] = selected_onnx_names _exec_mujoco_eval(eval_override_dict) if __name__ == "__main__": main() ================================================ FILE: holomotion/src/utils/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: holomotion/src/utils/config.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import copy import math import os from pathlib import Path import torch from accelerate import Accelerator from loguru import logger from omegaconf import OmegaConf def setup_hydra_resolvers(): """Set up custom resolvers for OmegaConf. This function registers a set of custom resolvers with OmegaConf to allow for more dynamic and flexible configurations within Hydra. These resolvers enable performing calculations, conditional logic, and other operations directly in the YAML configuration files. For example, you can use `${sqrt:4}` to get `2.0`. The registered resolvers include: - `eval`: Evaluates a Python expression. - `if`: Conditional logic (if-else). - `eq`: Case-insensitive string comparison. - `sqrt`: Calculates the square root. - `sum`: Sums a list of numbers. - `ceil`: Computes the ceiling of a number. 
- `int`: Casts a value to an integer. - `len`: Returns the length of a list or string. - `sum_list`: Sums a list of numbers. """ try: OmegaConf.register_new_resolver("eval", eval) OmegaConf.register_new_resolver( "if", lambda pred, a, b: a if pred else b ) OmegaConf.register_new_resolver( "eq", lambda x, y: x.lower() == y.lower() ) OmegaConf.register_new_resolver("sqrt", lambda x: math.sqrt(float(x))) OmegaConf.register_new_resolver("sum", lambda x: sum(x)) OmegaConf.register_new_resolver("ceil", lambda x: math.ceil(x)) OmegaConf.register_new_resolver("int", lambda x: int(x)) OmegaConf.register_new_resolver("len", lambda x: len(x)) OmegaConf.register_new_resolver("sum_list", lambda lst: sum(lst)) except Exception as e: logger.warning(f"Warning: Some resolvers already registered: {e}") def compile_config( config: OmegaConf, accelerator: Accelerator = None, eval: bool = False, ) -> None: """Compile the configuration. Args: config: Unresolved configuration. accelerator: Accelerator instance. Returns: Compiled configuration. """ setup_hydra_resolvers() config = copy.deepcopy(config) config = compile_config_hf_accelerate(config, accelerator) config = compile_config_directories(config, eval) config = compile_config_devices(config) return config def compile_config_hf_accelerate( config, accelerator: Accelerator = None, ) -> None: """Compile the configuration for HF Accelerate. Args: config: Configuration. accelerator: Accelerator instance. Returns: Compiled configuration. """ if accelerator is not None: device = accelerator.device is_main_process = accelerator.is_main_process process_idx = accelerator.process_index total_processes = accelerator.num_processes else: # Best-effort distributed metadata when running under torchrun / Accelerate launch, # even if an Accelerator instance is not provided yet. 
process_idx = int( os.environ.get( "RANK", os.environ.get("ACCELERATE_PROCESS_INDEX", "0") ) ) total_processes = int( os.environ.get( "WORLD_SIZE", os.environ.get("ACCELERATE_NUM_PROCESSES", "1") ) ) local_rank = int( os.environ.get( "LOCAL_RANK", os.environ.get("ACCELERATE_LOCAL_PROCESS_INDEX", "0"), ) ) is_main_process = process_idx == 0 if torch.cuda.is_available(): device = torch.device("cuda", local_rank) else: device = torch.device("cpu") config.process_id = process_idx config.num_processes = total_processes config.main_process = is_main_process if hasattr(config, "device"): config.device = str(device) logger.info(f"Using device: {device}") if is_main_process: logger.info(f"Process {process_idx} on device: {device}") return config def compile_config_devices(config): """Propagate device and process metadata into the environment configuration.""" config = copy.deepcopy(config) if hasattr(config, "device"): device_str = str(config.device) else: device_str = str( torch.device("cuda" if torch.cuda.is_available() else "cpu") ) world_size = getattr(config, "num_processes", 1) process_rank = getattr(config, "process_id", 0) is_main_process = getattr(config, "main_process", True) if hasattr(config, "env") and hasattr(config.env, "config"): env_cfg = config.env.config env_cfg_struct = OmegaConf.is_struct(env_cfg) OmegaConf.set_struct(env_cfg, False) env_cfg.num_processes = world_size env_cfg.process_id = process_rank env_cfg.main_process = is_main_process env_cfg.simulation_device = device_str for key in [ "sim_device", "rl_device", "compute_device", "physx_device", ]: setattr(env_cfg, key, device_str) if hasattr(env_cfg, "simulation"): for sim_key in ["device", "compute_device", "rl_device"]: sim_cfg = env_cfg.simulation sim_struct = OmegaConf.is_struct(sim_cfg) OmegaConf.set_struct(sim_cfg, False) setattr(sim_cfg, sim_key, device_str) OmegaConf.set_struct(sim_cfg, sim_struct) OmegaConf.set_struct(env_cfg, env_cfg_struct) return config def 
compile_config_directories(config, eval: bool = False) -> None: """Compile the configuration for folders. Args: config: Configuration. Returns: Compiled configuration. """ if eval: return config config = copy.deepcopy(config) experiment_save_dir = Path(config.experiment_dir) experiment_save_dir.mkdir(exist_ok=True, parents=True) config.experiment_save_dir = str(experiment_save_dir) if hasattr(config, "env"): config.env.config.save_rendering_dir = str( Path(config.experiment_dir) / "renderings_training" ) unresolved_conf = OmegaConf.to_container(config, resolve=False) if config.main_process: logger.info(f"Saving config file to {experiment_save_dir}") with open(experiment_save_dir / "config.yaml", "w") as file: OmegaConf.save(unresolved_conf, file) return config ================================================ FILE: holomotion/src/utils/frame_utils.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import isaaclab.utils.math as isaaclab_math import torch def positions_world_to_env_frame( positions_w: torch.Tensor, env_origins: torch.Tensor, ) -> torch.Tensor: """Convert simulator-world positions to IsaacLab environment frame. IsaacLab's MDP root position helpers return positions in the environment frame, i.e. simulation-world coordinates with per-environment `env_origins` subtracted. 
This helper applies the same translation removal to arbitrary position tensors so position arithmetic stays frame-consistent. """ if positions_w.ndim < 2 or positions_w.shape[-1] != 3: raise ValueError( "positions_w must have shape [B, ..., 3], " f"got {tuple(positions_w.shape)}." ) if env_origins.ndim != 2 or env_origins.shape[-1] != 3: raise ValueError( "env_origins must have shape [B, 3], " f"got {tuple(env_origins.shape)}." ) if positions_w.shape[0] != env_origins.shape[0]: raise ValueError( "Batch size mismatch between positions_w and env_origins: " f"{positions_w.shape[0]} vs {env_origins.shape[0]}." ) origin_view = env_origins.view( env_origins.shape[0], *([1] * (positions_w.ndim - 2)), 3, ) return positions_w - origin_view def root_relative_positions_from_env_frame( body_pos_env: torch.Tensor, root_pos_env: torch.Tensor, root_quat_w: torch.Tensor, ) -> torch.Tensor: """Convert environment-frame body positions into the root frame. The input positions must already be in IsaacLab's environment frame rather than raw simulator-world coordinates. Orientation is unaffected by `env_origins`, so the articulation root quaternion is reused directly. """ if body_pos_env.ndim < 3 or body_pos_env.shape[-1] != 3: raise ValueError( "body_pos_env must have shape [B, ..., 3], " f"got {tuple(body_pos_env.shape)}." ) if root_pos_env.ndim != 2 or root_pos_env.shape[-1] != 3: raise ValueError( "root_pos_env must have shape [B, 3], " f"got {tuple(root_pos_env.shape)}." ) if root_quat_w.ndim != 2 or root_quat_w.shape[-1] != 4: raise ValueError( "root_quat_w must have shape [B, 4], " f"got {tuple(root_quat_w.shape)}." ) if body_pos_env.shape[0] != root_pos_env.shape[0]: raise ValueError( "Batch size mismatch between body_pos_env and root_pos_env: " f"{body_pos_env.shape[0]} vs {root_pos_env.shape[0]}." 
) if body_pos_env.shape[0] != root_quat_w.shape[0]: raise ValueError( "Batch size mismatch between body_pos_env and root_quat_w: " f"{body_pos_env.shape[0]} vs {root_quat_w.shape[0]}." ) root_pos_view = root_pos_env.view( root_pos_env.shape[0], *([1] * (body_pos_env.ndim - 2)), 3, ) root_quat_view = root_quat_w.view( root_quat_w.shape[0], *([1] * (body_pos_env.ndim - 2)), 4, ).expand(*body_pos_env.shape[:-1], 4) rel_pos_env = body_pos_env - root_pos_view return isaaclab_math.quat_apply_inverse(root_quat_view, rel_pos_env) def root_relative_positions_from_mixed_position_frames( body_pos_w: torch.Tensor, root_pos_env: torch.Tensor, root_quat_w: torch.Tensor, env_origins: torch.Tensor, ) -> torch.Tensor: """Build root-relative positions from world-frame bodies. This is the safe adapter for common IsaacLab code paths where body poses are read from `robot.data.body_pos_w` in simulator world coordinates while `isaaclab_mdp.root_pos_w(env)` is already expressed in the environment frame. """ body_pos_env = positions_world_to_env_frame(body_pos_w, env_origins) return root_relative_positions_from_env_frame( body_pos_env=body_pos_env, root_pos_env=root_pos_env, root_quat_w=root_quat_w, ) ================================================ FILE: holomotion/src/utils/isaac_utils/__init__.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
================================================ FILE: holomotion/src/utils/isaac_utils/maths.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # This file was originally copied from the [ASAP] repository: # https://github.com/LeCAR-Lab/ASAP # Modifications have been made to fit the needs of this project. import os import random import numpy as np import torch @torch.jit.script def normalize(x, eps: float = 1e-9): return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1) @torch.jit.script def torch_rand_float(lower, upper, shape, device): # type: (float, float, Tuple[int, int], str) -> Tensor return (upper - lower) * torch.rand(*shape, device=device) + lower @torch.jit.script def copysign(a, b): # type: (float, Tensor) -> Tensor a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) return torch.abs(a) * torch.sign(b) def set_seed(seed, torch_deterministic=False): """Set seed across modules""" if seed == -1 and torch_deterministic: seed = 42 elif seed == -1: seed = np.random.randint(0, 10000) print(f"Setting seed: {seed}") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) if torch_deterministic: # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility 
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) else: torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False return seed ================================================ FILE: holomotion/src/utils/isaac_utils/rotations.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # This file was originally copied from the [ASAP] repository: # https://github.com/LeCAR-Lab/ASAP # Modifications have been made to fit the needs of this project. 
from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import Tensor from holomotion.src.utils.isaac_utils.maths import ( copysign, normalize, ) @torch.jit.script def quat_unit(a): return normalize(a) @torch.jit.script def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor: shape = b.shape a = a.reshape(-1, 4) b = b.reshape(-1, 3) if w_last: xyz = a[:, :3] w = a[:, 3:] else: xyz = a[:, 1:] w = a[:, :1] t = xyz.cross(b, dim=-1) * 2 return (b + w * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor: quat_yaw = quat.clone().view(-1, 4) quat_yaw[:, :2] = 0.0 quat_yaw = normalize(quat_yaw) return quat_apply(quat_yaw, vec, w_last) @torch.jit.script def wrap_to_pi(angles): angles %= 2 * np.pi angles -= 2 * np.pi * (angles > np.pi) return angles @torch.jit.script def quat_conjugate(a: Tensor, w_last: bool) -> Tensor: shape = a.shape a = a.reshape(-1, 4) if w_last: return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape) else: return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape) @torch.jit.script def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor: shape = b.shape a = a.reshape(-1, 4) b = b.reshape(-1, 3) if w_last: xyz = a[:, :3] w = a[:, 3:] else: xyz = a[:, 1:] w = a[:, :1] t = xyz.cross(b, dim=-1) * 2 return (b + w * t + xyz.cross(t, dim=-1)).view(shape) @torch.jit.script def quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor: shape = q.shape if w_last: q_w = q[:, -1] q_vec = q[:, :3] else: q_w = q[:, 0] q_vec = q[:, 1:] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script def quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor: shape = q.shape if w_last: q_w = q[:, -1] q_vec = q[:, :3] else: q_w = q[:, 0] 
q_vec = q[:, 1:] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a - b + c @torch.jit.script def quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]: """The (angle, axis) representation of the rotation. The axis is normalized to unit length. The angle is guaranteed to be between [0, pi]. """ if w_last: w = x[..., -1] axis = x[..., :3] else: w = x[..., 0] axis = x[..., 1:] s = 2 * (w**2) - 1 angle = s.clamp(-1, 1).arccos() # just to be safe axis /= axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9) return angle, axis @torch.jit.script def quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor: theta = (angle / 2).unsqueeze(-1) xyz = normalize(axis) * theta.sin() w = theta.cos() if w_last: return quat_unit(torch.cat([xyz, w], dim=-1)) else: return quat_unit(torch.cat([w, xyz], dim=-1)) @torch.jit.script def vec_to_heading(h_vec): h_theta = torch.atan2(h_vec[..., 1], h_vec[..., 0]) return h_theta @torch.jit.script def heading_to_quat(h_theta, w_last: bool): axis = torch.zeros( h_theta.shape + [ 3, ], device=h_theta.device, ) axis[..., 2] = 1 heading_q = quat_from_angle_axis(h_theta, axis, w_last=w_last) return heading_q @torch.jit.script def quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor: basis_vec = torch.zeros(q.shape[0], 3, device=q.device) basis_vec[:, axis] = 1 return quat_rotate(q, basis_vec, w_last) @torch.jit.script def normalize_angle(x): return torch.atan2(torch.sin(x), torch.cos(x)) @torch.jit.script def get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor: return quat_rotate(q, v, w_last) @torch.jit.script def quat_to_angle_axis(q): # type: (Tensor) -> Tuple[Tensor, Tensor] # computes axis-angle representation from quaternion q # q must be normalized # ZL: could have issues. 
min_theta = 1e-5 qx, qy, qz, qw = 0, 1, 2, 3 sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw]) angle = 2 * torch.acos(q[..., qw]) angle = normalize_angle(angle) sin_theta_expand = sin_theta.unsqueeze(-1) axis = q[..., qx:qw] / sin_theta_expand mask = torch.abs(sin_theta) > min_theta default_axis = torch.zeros_like(axis) default_axis[..., -1] = 1 angle = torch.where(mask, angle, torch.zeros_like(angle)) mask_expand = mask.unsqueeze(-1) axis = torch.where(mask_expand, axis, default_axis) return angle, axis @torch.jit.script def slerp(q0, q1, t): # type: (Tensor, Tensor, Tensor) -> Tensor cos_half_theta = torch.sum(q0 * q1, dim=-1) neg_mask = cos_half_theta < 0 q1 = q1.clone() q1[neg_mask] = -q1[neg_mask] cos_half_theta = torch.abs(cos_half_theta) cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1) half_theta = torch.acos(cos_half_theta) sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta) ratioA = torch.sin((1 - t) * half_theta) / sin_half_theta ratioB = torch.sin(t * half_theta) / sin_half_theta new_q = ratioA * q0 + ratioB * q1 new_q = torch.where( torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q ) new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q) return new_q @torch.jit.script def angle_axis_to_exp_map(angle, axis): # type: (Tensor, Tensor) -> Tensor # compute exponential map from axis-angle angle_expand = angle.unsqueeze(-1) exp_map = angle_expand * axis return exp_map @torch.jit.script def my_quat_rotate(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script def calc_heading(q): # type: (Tensor) -> Tensor # calculate heading direction from quaternion # the heading is the direction on the xy plane # q must be normalized # this is the x axis heading ref_dir = torch.zeros_like(q[..., 
0:3]) ref_dir[..., 0] = 1 rot_dir = my_quat_rotate(q, ref_dir) heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) return heading @torch.jit.script def quat_to_exp_map(q): # type: (Tensor) -> Tensor # compute exponential map from quaternion # q must be normalized angle, axis = quat_to_angle_axis(q) exp_map = angle_axis_to_exp_map(angle, axis) return exp_map @torch.jit.script def calc_heading_quat(q, w_last): # type: (Tensor, bool) -> Tensor # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(heading, axis, w_last=w_last) return heading_q @torch.jit.script def calc_heading_quat_inv(q, w_last): # type: (Tensor, bool) -> Tensor # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(-heading, axis, w_last=w_last) return heading_q @torch.jit.script def quat_inverse(x, w_last): # type: (Tensor, bool) -> Tensor """The inverse of the rotation""" return quat_conjugate(x, w_last=w_last) @torch.jit.script def get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]: if w_last: qx, qy, qz, qw = 0, 1, 2, 3 else: qw, qx, qy, qz = 0, 1, 2, 3 # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = ( q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] ) roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) pitch = torch.where( torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp) ) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = ( q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, 
qz] ) yaw = torch.atan2(siny_cosp, cosy_cosp) return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi) # @torch.jit.script def get_euler_xyz_in_tensor(q): qx, qy, qz, qw = 0, 1, 2, 3 # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = ( q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] ) roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) pitch = torch.where( torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp) ) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = ( q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] ) yaw = torch.atan2(siny_cosp, cosy_cosp) return torch.stack((roll, pitch, yaw), dim=-1) @torch.jit.script def quat_pos(x): """Make all the real part of the quaternion positive""" q = x z = (q[..., 3:] < 0).float() q = (1 - 2 * z) * q return q @torch.jit.script def is_valid_quat(q): x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] return (w * w + x * x + y * y + z * z).allclose(torch.ones_like(w)) @torch.jit.script def quat_normalize(q): """Construct 3D rotation from quaternion (the quaternion needs not to be normalized).""" q = quat_unit(quat_pos(q)) # normalized to positive and unit quaternion return q @torch.jit.script def quat_mul(a, b, w_last: bool): assert a.shape == b.shape shape = a.shape a = a.reshape(-1, 4) b = b.reshape(-1, 4) if w_last: x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] else: w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) xx = ww + yy + zz qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) w = qq - ww + (z1 - y1) * (y2 - z2) x = qq - xx + (x1 + w1) * (x2 
+ w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) if w_last: quat = torch.stack([x, y, z, w], dim=-1).view(shape) else: quat = torch.stack([w, x, y, z], dim=-1).view(shape) return quat @torch.jit.script def quat_mul_norm(x, y, w_last): # type: (Tensor, Tensor, bool) -> Tensor r"""Combine two set of 3D rotations together using \**\* operator. The shape needs to be broadcastable """ return quat_normalize(quat_mul(x, y, w_last)) @torch.jit.script def quat_mul_norm(x, y, w_last): # type: (Tensor, Tensor, bool) -> Tensor r"""Combine two set of 3D rotations together using \**\* operator. The shape needs to be broadcastable """ return quat_unit(quat_mul(x, y, w_last)) @torch.jit.script def quat_identity(shape: List[int]): """Construct 3D identity rotation given shape""" w = torch.ones(shape + [1]) xyz = torch.zeros(shape + [3]) q = torch.cat([xyz, w], dim=-1) return quat_normalize(q) @torch.jit.script def quat_identity_like(x): """Construct identity 3D rotation with the same shape""" return quat_identity(x.shape[:-1]) @torch.jit.script def transform_from_rotation_translation( r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None ): """Construct a transform from a quaternion and 3D translation. 
Only one of them can be None.""" assert r is not None or t is not None, ( "rotation and translation can't be all None" ) if r is None: assert t is not None r = quat_identity(list(t.shape)) if t is None: t = torch.zeros(list(r.shape) + [3]) return torch.cat([r, t], dim=-1) @torch.jit.script def transform_rotation(x): """Get rotation from transform""" return x[..., :4] @torch.jit.script def transform_translation(x): """Get translation from transform""" return x[..., 4:] @torch.jit.script def transform_mul(x, y): """Combine two transformation together""" z = transform_from_rotation_translation( r=quat_mul_norm( transform_rotation(x), transform_rotation(y), w_last=True ), t=quat_rotate( transform_rotation(x), transform_translation(y), w_last=True ) + transform_translation(x), ) return z @torch.compile def quaternion_to_matrix( quaternions: torch.Tensor, w_last: bool = True, ) -> torch.Tensor: """Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions as tensor of shape (..., 4). If w_last=True (default): real part last (x, y, z, w) If w_last=False: real part first (w, x, y, z) w_last: If True, quaternion format is (x, y, z, w). If False, quaternion format is (w, x, y, z). Default: True. Returns: Rotation matrices as tensor of shape (..., 3, 3). """ if w_last: i, j, k, r = torch.unbind(quaternions, -1) else: r, i, j, k = torch.unbind(quaternions, -1) two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) @torch.jit.script def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: """Convert rotations given as axis/angle to quaternions. 
Args: axis_angle: Rotations given as a vector in axis angle form, as a tensor of shape (..., 3), where the magnitude is the angle turned anticlockwise in radians around the vector's direction. Returns: quaternions with real part first, as tensor of shape (..., 4). """ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) half_angles = angles * 0.5 eps = 1e-6 small_angles = angles.abs() < eps sin_half_angles_over_angles = torch.empty_like(angles) sin_half_angles_over_angles[~small_angles] = ( torch.sin(half_angles[~small_angles]) / angles[~small_angles] ) # for x small, sin(x/2) is about x/2 - (x/2)^3/6 # so sin(x/2)/x is about 1/2 - (x*x)/48 sin_half_angles_over_angles[small_angles] = ( 0.5 - (angles[small_angles] * angles[small_angles]) / 48 ) quaternions = torch.cat( [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1, ) return quaternions # @torch.jit.script def wxyz_to_xyzw(quat): return quat[..., [1, 2, 3, 0]] # @torch.jit.script def xyzw_to_wxyz(quat): return quat[..., [3, 0, 1, 2]] def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """W x y z Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part first, as tensor of shape (..., 4). 
""" if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_rijk = torch.stack( [ torch.stack( [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1 ), torch.stack( [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1 ), torch.stack( [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1 ), torch.stack( [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1 ), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) return quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :, # pyre-ignore[16] ].reshape(batch_dim + (4,)) def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. 
""" ret = torch.zeros_like(x) positive_mask = x > 0 ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret def quat_w_first(rot): rot = torch.cat([rot[..., [-1]], rot[..., :-1]], -1) return rot @torch.jit.script def quat_from_euler_xyz(roll, pitch, yaw): cy = torch.cos(yaw * 0.5) sy = torch.sin(yaw * 0.5) cr = torch.cos(roll * 0.5) sr = torch.sin(roll * 0.5) cp = torch.cos(pitch * 0.5) sp = torch.sin(pitch * 0.5) qw = cy * cr * cp + sy * sr * sp qx = cy * sr * cp - sy * cr * sp qy = cy * cr * sp + sy * sr * cp qz = sy * cr * cp - cy * sr * sp return torch.stack([qx, qy, qz, qw], dim=-1) @torch.compile def remove_yaw_component( quat_raw: Tensor, quat_init: Tensor, w_last: bool = True, ) -> Tensor: """Remove yaw component from quaternion while keeping roll and pitch. This function extracts the yaw component from the initial quaternion and uses it to normalize the raw quaternion, effectively removing the initial heading offset while preserving roll and pitch components. Args: quat_raw: Current quaternion from IMU, shape (..., 4) quat_init: Initial quaternion (contains the yaw to be removed), shape (..., 4) w_last: If True, quaternion format is (x, y, z, w). If False, quaternion format is (w, x, y, z). Default: True. Returns: Quaternion with initial yaw component removed, same shape as input. The resulting quaternion represents roll and pitch relative to the heading-aligned coordinate frame. Example: >>> # Initial robot orientation (roll=0°, pitch=0°, yaw=45°) >>> quat_init = quat_from_euler_xyz( ... torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.7854) ... ) >>> # Current IMU reading (roll=10°, pitch=20°, yaw=60°) >>> quat_raw = quat_from_euler_xyz( ... torch.tensor(0.1745), ... torch.tensor(0.3491), ... torch.tensor(1.0472), ... 
) >>> quat_norm = remove_yaw_component(quat_raw, quat_init) >>> # quat_norm contains roll=10°, pitch=20°, with initial yaw offset removed """ # Extract quaternion components based on format if w_last: q_w = quat_init[..., -1] q_vec = quat_init[..., :3] else: q_w = quat_init[..., 0] q_vec = quat_init[..., 1:] # Calculate heading by rotating x-axis with quaternion # ref_dir = [1, 0, 0] (x-axis) ref_dir = torch.zeros_like(q_vec) ref_dir[..., 0] = 1.0 # Quaternion rotation: v' = v + 2 * w * (q_vec × v) + 2 * q_vec × (q_vec × v) cross1 = torch.cross(q_vec, ref_dir, dim=-1) cross2 = torch.cross(q_vec, cross1, dim=-1) rot_dir = ref_dir + 2.0 * q_w.unsqueeze(-1) * cross1 + 2.0 * cross2 # Extract heading angle from rotated x-axis heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) # Create inverse heading quaternion (rotation about negative z-axis) half_heading = (-heading) * 0.5 heading_q_inv = torch.zeros_like(quat_init) if w_last: heading_q_inv[..., 0] = 0.0 # x heading_q_inv[..., 1] = 0.0 # y heading_q_inv[..., 2] = torch.sin(half_heading) # z heading_q_inv[..., 3] = torch.cos(half_heading) # w else: heading_q_inv[..., 0] = torch.cos(half_heading) # w heading_q_inv[..., 1] = 0.0 # x heading_q_inv[..., 2] = 0.0 # y heading_q_inv[..., 3] = torch.sin(half_heading) # z # Quaternion multiplication: heading_q_inv * quat_raw shape = quat_raw.shape a = heading_q_inv.reshape(-1, 4) b = quat_raw.reshape(-1, 4) if w_last: x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] else: w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] # Quaternion multiplication formula ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) xx = ww + yy + zz qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) w = qq - ww + (z1 - y1) * (y2 - z2) x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) 
if w_last: quat_result = torch.stack([x, y, z, w], dim=-1).view(shape) else: quat_result = torch.stack([w, x, y, z], dim=-1).view(shape) # Normalize the result quaternion norm = torch.norm(quat_result, p=2, dim=-1, keepdim=True) quat_norm = quat_result / norm.clamp(min=1e-8) return quat_norm ================================================ FILE: holomotion/src/utils/isaac_utils/setup.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. # This file was originally copied from the [ASAP] repository: # https://github.com/LeCAR-Lab/ASAP # Modifications have been made to fit the needs of this project. from setuptools import setup setup( name="isaac_utils", packages=["isaac_utils"], version="0.0.1", description="Unified torch env_utils for IsaacGym and IsaacSim", author="", classifiers=[], ) ================================================ FILE: holomotion/src/utils/onnx_export.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import inspect import re from pathlib import Path from loguru import logger def _list_to_csv_str(arr, *, decimals: int = 3, delimiter: str = ",") -> str: fmt = f"{{:.{decimals}f}}" return delimiter.join( fmt.format(x) if isinstance(x, (int, float)) else str(x) for x in arr ) def attach_onnx_metadata_holomotion(env, onnx_path: str) -> None: import onnx metadata = { "joint_names": env.scene["robot"].data.joint_names, "joint_stiffness": env.scene["robot"] .data.default_joint_stiffness[0] .cpu() .tolist(), "joint_damping": env.scene["robot"] .data.default_joint_damping[0] .cpu() .tolist(), "default_joint_pos": env.scene["robot"] .data.default_joint_pos[0] .cpu() .tolist(), "action_scale": env.action_manager.get_term("dof_pos") ._scale[0] .cpu() .tolist(), } model = onnx.load(onnx_path) for key, value in metadata.items(): entry = onnx.StringStringEntryProto() entry.key = key entry.value = ( _list_to_csv_str(value) if isinstance(value, list) else str(value) ) model.metadata_props.append(entry) onnx.save(model, onnx_path) def export_policy_to_onnx( algo, checkpoint_path: str, *, onnx_name_suffix: str | None = None, use_kv_cache: bool = True, ) -> str: checkpoint = Path(checkpoint_path) export_dir = checkpoint.parent / "exported" export_dir.mkdir(exist_ok=True) onnx_name = checkpoint.name.replace(".pt", ".onnx") if onnx_name_suffix is not None: suffix = re.sub(r"[\s+]", "_", str(onnx_name_suffix)) onnx_name = onnx_name.replace(".onnx", f"_{suffix}.onnx") onnx_path = export_dir / onnx_name logger.info("Starting ONNX minimal policy export (actions-only)...") 
actor_was_training = getattr(algo.actor, "training", None) critic_was_training = getattr(algo.critic, "training", None) algo.actor.eval() algo.critic.eval() try: actor_for_export = algo.accelerator.unwrap_model(algo.actor) orig_mod = getattr(actor_for_export, "_orig_mod", None) if orig_mod is not None: actor_for_export = orig_mod export_signature = inspect.signature(actor_for_export.export_onnx) export_kwargs = {"onnx_path": onnx_path, "opset_version": 17} if "use_kv_cache" in export_signature.parameters: export_kwargs["use_kv_cache"] = bool(use_kv_cache) onnx_path_str = actor_for_export.export_onnx(**export_kwargs) attach_onnx_metadata_holomotion(algo.env._env, onnx_path=onnx_path_str) logger.info( f"Successfully exported minimal policy to: {onnx_path_str}" ) return onnx_path_str finally: if actor_was_training is not None: algo.actor.train(actor_was_training) if critic_was_training is not None: algo.critic.train(critic_was_training) ================================================ FILE: holomotion/src/utils/reference_prefix.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
from typing import Mapping


def resolve_reference_tensor_key(
    batch_tensors: Mapping[str, object],
    base_key: str,
    prefix: str = "ref_",
) -> str:
    """Choose the key under which a reference tensor is stored in a batch.

    With a non-empty ``prefix``, ``f"{prefix}{base_key}"`` is used when it is
    present in ``batch_tensors``. Otherwise ``base_key`` is the fallback,
    except for the ``"ft_ref_"`` prefix, which has no fallback. With an empty
    ``prefix`` only ``base_key`` is considered.

    Args:
        batch_tensors: Mapping searched for the keys (only ``in`` is used).
        base_key: Unprefixed tensor name.
        prefix: Preferred key prefix; a falsy value disables prefixing.

    Returns:
        A key that is present in ``batch_tensors``.

    Raises:
        KeyError: If the prefix is ``"ft_ref_"`` and the prefixed key is
            missing, or if no candidate key is present.
    """
    tensor_key = base_key
    if prefix:
        prefixed_key = f"{prefix}{base_key}"
        if prefixed_key in batch_tensors:
            tensor_key = prefixed_key
        elif prefix == "ft_ref_":
            # Presumably so filtered references never silently fall back to
            # unfiltered data.
            raise KeyError(
                f"Filtered tensor '{prefixed_key}' is not present in the "
                "current motion cache batch. Ensure online filtering is "
                "enabled and 'ft_ref_' is materialized in allowed_prefixes."
            )
        elif base_key not in batch_tensors:
            raise KeyError(
                f"Neither '{prefixed_key}' nor '{base_key}' is present in "
                "the current motion cache batch."
            )
    elif base_key not in batch_tensors:
        raise KeyError(
            f"Tensor '{base_key}' is not present in the current motion cache batch."
        )
    return tensor_key


================================================
FILE: holomotion/src/utils/torch_utils.py
================================================
"""Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""

# NOTE: this module mixes quaternion layouts. quat_mul, quat_apply,
# quat_apply_inverse, quat_conjugate and the tf_* helpers use (w, x, y, z);
# quat_rotate, quat_rotate_inverse, get_basis_vector, get_euler_xyz,
# quat_from_euler_xyz, quat_from_angle_axis and quat_to_angle_axis use
# (x, y, z, w). Check each function's layout before combining them.

import numpy as np
import torch
import torch.nn.functional as F


def to_torch(x, dtype=torch.float, device="cpu", requires_grad=False):
    """Create a tensor from ``x`` (via ``torch.tensor``) with the given options."""
    return torch.tensor(
        x, dtype=dtype, device=device, requires_grad=requires_grad
    )


@torch.jit.script
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions together.

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).

    Returns:
        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).

    Raises:
        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
    """
    # check input is correct
    if q1.shape != q2.shape:
        msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
        raise ValueError(msg)
    # reshape to (N, 4) for multiplication
    shape = q1.shape
    q1 = q1.reshape(-1, 4)
    q2 = q2.reshape(-1, 4)
    # extract components from quaternions
    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
    # perform multiplication (factored form: 8 multiplications instead of 16)
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)

    return torch.stack([w, x, y, z], dim=-1).view(shape)


@torch.jit.script
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along the last dim (norm clamped by ``eps``)."""
    return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1)


@torch.jit.script
def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Apply a quaternion rotation to a vector.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (..., 3).
    """
    # store shape
    shape = vec.shape
    # reshape to (N, 3) for multiplication
    quat = quat.reshape(-1, 4)
    vec = vec.reshape(-1, 3)
    # extract components from quaternions
    xyz = quat[:, 1:]
    # v' = v + w * t + xyz x t, with t = 2 * (xyz x v)
    t = xyz.cross(vec, dim=-1) * 2
    return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)


@torch.jit.script
def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Apply an inverse quaternion rotation to a vector.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z). Shape is (..., 3).
    """
    # store shape
    shape = vec.shape
    # reshape to (N, 3) for multiplication
    quat = quat.reshape(-1, 4)
    vec = vec.reshape(-1, 3)
    # extract components from quaternions
    xyz = quat[:, 1:]
    t = xyz.cross(vec, dim=-1) * 2
    # Same as quat_apply with the sign of the w term flipped.
    return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape)


@torch.jit.script
def quat_rotate(q, v):
    """Rotate vectors ``v`` (N, 3) by (x, y, z, w) quaternions ``q`` (N, 4).

    Note the (x, y, z, w) layout, unlike ``quat_apply`` in this module.
    Inputs must be 2-D (``torch.bmm`` over ``shape[0]`` rows).
    """
    shape = q.shape
    q_w = q[:, -1]
    q_vec = q[:, :3]
    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    c = (
        q_vec
        * torch.bmm(
            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)
        ).squeeze(-1)
        * 2.0
    )
    return a + b + c


# @torch.jit.script
def quat_rotate_inverse(q, v):
    """Rotate ``v`` (N, 3) by the inverse of (x, y, z, w) quaternions ``q``.

    Same layout and 2-D requirements as ``quat_rotate``; not TorchScript-compiled.
    """
    shape = q.shape
    q_w = q[:, -1]
    q_vec = q[:, :3]
    a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1)
    b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0
    c = (
        q_vec
        * torch.bmm(
            q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)
        ).squeeze(-1)
        * 2.0
    )
    return a - b + c


@torch.jit.script
def quat_conjugate(a):
    """Conjugate of (w, x, y, z) quaternions: negate x, y, z."""
    shape = a.shape
    a = a.reshape(-1, 4)
    return torch.cat((a[:, 0:1], -a[:, 1:]), dim=-1).view(shape)
    # return torch.cat((-a[:, :3], a[:, -1:]), dim=-1).view(shape)


@torch.jit.script
def quat_unit(a):
    """Normalize quaternions to unit length along the last dimension."""
    return normalize(a)


@torch.jit.script
def quat_from_angle_axis(angle, axis):
    """Unit quaternions in (x, y, z, w) from angles (...) and axes (..., 3).

    The axis is normalized internally.
    """
    theta = (angle / 2).unsqueeze(-1)
    xyz = normalize(axis) * theta.sin()
    w = theta.cos()
    return quat_unit(torch.cat([xyz, w], dim=-1))


@torch.jit.script
def normalize_angle(x):
    """Wrap angles into [-pi, pi] via ``atan2(sin(x), cos(x))``."""
    return torch.atan2(torch.sin(x), torch.cos(x))


@torch.jit.script
def tf_inverse(q, t):
    """Inverse of the rigid transform (q, t), with q in (w, x, y, z)."""
    q_inv = quat_conjugate(q)
    return q_inv, -quat_apply(q_inv, t)


@torch.jit.script
def tf_apply(q, t, v):
    """Apply transform (q, t) to points ``v``: ``q * v + t``; q in (w, x, y, z)."""
    return quat_apply(q, v) + t


@torch.jit.script
def tf_vector(q, v):
    """Rotate vectors ``v`` by ``q`` (w, x, y, z) without translation."""
    return quat_apply(q, v)


@torch.jit.script
def tf_combine(q1, t1, q2, t2):
    """Compose transforms: (q1, t1) applied after (q2, t2); q in (w, x, y, z)."""
    return quat_mul(q1, q2), quat_apply(q1, t2) + t1


@torch.jit.script
def get_basis_vector(q, v):
    """Alias of ``quat_rotate`` (q in (x, y, z, w))."""
    return quat_rotate(q, v)


def get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3):
    """Construct arguments to `Vec` according to axis index.

    Returns an ``n_dims`` list with ``value`` at ``axis_idx`` and zeros
    elsewhere, except that element 0 is always overwritten with ``x_value``
    (so ``axis_idx=0`` yields ``x_value``, not ``value``).
    """
    zs = np.zeros((n_dims,))
    assert axis_idx < n_dims, (
        "the axis dim should be within the vector dimensions"
    )
    zs[axis_idx] = 1.0
    params = np.where(zs == 1.0, value, zs)
    params[0] = x_value
    return list(params.astype(dtype))


@torch.jit.script
def copysign(a, b):
    """Return ``|a|`` with the sign of each element of ``b`` (0 where b == 0).

    ``a`` is repeated ``b.shape[0]`` times, so ``b`` is treated as 1-D.
    NOTE(review): ``a`` has no annotation, so TorchScript types it as a
    Tensor; callers in this module pass a 0-d tensor.
    """
    a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
    return torch.abs(a) * torch.sign(b)


@torch.jit.script
def get_euler_xyz(q: torch.Tensor) -> tuple:
    """Roll, pitch, yaw of (x, y, z, w) quaternions (N, 4), each in [0, 2*pi).

    Pitch saturates to +/- pi/2 when ``|sinp| >= 1``.
    """
    qx, qy, qz, qw = 0, 1, 2, 3
    # roll (x-axis rotation)
    sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz])
    cosr_cosp = (
        q[:, qw] * q[:, qw]
        - q[:, qx] * q[:, qx]
        - q[:, qy] * q[:, qy]
        + q[:, qz] * q[:, qz]
    )
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # pitch (y-axis rotation)
    sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx])
    pitch = torch.where(
        torch.abs(sinp) >= 1,
        copysign(torch.tensor(np.pi / 2.0, device=sinp.device), sinp),
        torch.asin(sinp),
    )

    # yaw (z-axis rotation)
    siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy])
    cosy_cosp = (
        q[:, qw] * q[:, qw]
        + q[:, qx] * q[:, qx]
        - q[:, qy] * q[:, qy]
        - q[:, qz] * q[:, qz]
    )
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)


@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    """(x, y, z, w) quaternions from roll, pitch, yaw angles (radians)."""
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)

    qw = cy * cr * cp + sy * sr * sp
    qx = cy * sr * cp - sy * cr * sp
    qy = cy * cr * sp + sy * sr * cp
    qz = sy * cr * cp - cy * sr * sp

    return torch.stack([qx, qy, qz, qw], dim=-1)


def torch_rand_float(lower, upper, shape, device):
    """Uniform samples in ``[lower, upper)`` with the given shape and device."""
    return (upper - lower) * torch.rand(*shape, device=device) + lower


# @torch.jit.script
@torch.compile
def torch_random_dir_2(shape, device):
    """Random 2-D unit vectors (cos, sin) of angles drawn from [-pi, pi).

    Assumes ``shape`` ends in a size-1 dimension that is squeezed away.
    NOTE(review): confirm against callers.
    """
    angle = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
    return torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1)


@torch.jit.script
def tensor_clamp(t, min_t, max_t):
    """Element-wise clamp of ``t`` between tensors ``min_t`` and ``max_t``."""
    return torch.max(torch.min(t, max_t), min_t)


@torch.jit.script
def scale(x, lower, upper):
    """Map ``x`` from [-1, 1] to [lower, upper]."""
    return 0.5 * (x + 1.0) * (upper - lower) + lower


@torch.jit.script
def unscale(x, lower, upper):
    """Map ``x`` from [lower, upper] to [-1, 1] (inverse of ``scale``)."""
    return (2.0 * x - upper - lower) / (upper - lower)


def unscale_np(x, lower, upper):
    """Non-TorchScript ``unscale`` usable on NumPy arrays or scalars."""
    return (2.0 * x - upper - lower) / (upper - lower)


@torch.jit.script
def quat_to_angle_axis(q):
    """Axis-angle from unit (x, y, z, w) quaternions.

    Rotations whose half-angle sine is at most 1e-5 return angle 0 and the +z
    axis. The angle is wrapped with ``normalize_angle``.
    """
    # computes axis-angle representation from quaternion q
    # q must be normalized
    min_theta = 1e-5
    qx, _, _, qw = 0, 1, 2, 3

    # Despite the name this is sin(angle / 2), i.e. |xyz| for unit q.
    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])
    angle = 2 * torch.acos(q[..., qw])
    angle = normalize_angle(angle)
    sin_theta_expand = sin_theta.unsqueeze(-1)
    axis = q[..., qx:qw] / sin_theta_expand

    mask = torch.abs(sin_theta) > min_theta
    default_axis = torch.zeros_like(axis)
    default_axis[..., -1] = 1

    angle = torch.where(mask, angle, torch.zeros_like(angle))
    mask_expand = mask.unsqueeze(-1)
    axis = torch.where(mask_expand, axis, default_axis)
    return angle, axis


@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    """Exponential map (axis scaled by angle) from an angle-axis pair."""
    # compute exponential map from axis-angle
    angle_expand = angle.unsqueeze(-1)
    exp_map = angle_expand * axis
    return exp_map


@torch.jit.script
def quat_to_exp_map(q):
    """Exponential map of unit (x, y, z, w) quaternions."""
    # compute exponential map from quaternion
    # q must be normalized
    angle, axis = quat_to_angle_axis(q)
    exp_map = angle_axis_to_exp_map(angle, axis)
    return exp_map


@torch.jit.script
def slerp(q0, q1, t):
    """Spherical linear interpolation from ``q0`` to ``q1`` at fraction ``t``.

    Layout-agnostic. ``q1`` is sign-flipped where the dot product is negative
    so the shorter arc is taken. ``t`` must broadcast against (..., 1).
    Nearly parallel inputs fall back to the midpoint or to ``q0``.
    """
    cos_half_theta = torch.sum(q0 * q1, dim=-1)

    neg_mask = cos_half_theta < 0
    q1 = q1.clone()
    q1[neg_mask] = -q1[neg_mask]
    cos_half_theta = torch.abs(cos_half_theta)
    cos_half_theta = torch.unsqueeze(cos_half_theta, dim=-1)

    half_theta = torch.acos(cos_half_theta)
    sin_half_theta = torch.sqrt(1.0 - cos_half_theta * cos_half_theta)

    ratio_a = torch.sin((1 - t) * half_theta) / sin_half_theta
    ratio_b = torch.sin(t * half_theta) / sin_half_theta

    new_q = ratio_a * q0 + ratio_b * q1

    # Replace the (possibly inf/NaN) general result in degenerate cases.
    new_q = torch.where(
        torch.abs(sin_half_theta) < 0.001, 0.5 * q0 + 0.5 * q1, new_q
    )
    new_q = torch.where(torch.abs(cos_half_theta) >= 1, q0, new_q)

    return new_q
@torch.jit.script def my_quat_rotate(q, v): shape = q.shape q_w = q[:, -1] q_vec = q[:, :3] a = v * (2.0 * q_w**2 - 1.0).unsqueeze(-1) b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 c = ( q_vec * torch.bmm( q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1) ).squeeze(-1) * 2.0 ) return a + b + c @torch.jit.script def calc_heading(q): # calculate heading direction from quaternion # the heading is the direction on the xy plane # q must be normalized # this is the x axis heading ref_dir = torch.zeros_like(q[..., 0:3]) ref_dir[..., 0] = 1 rot_dir = my_quat_rotate(q, ref_dir) heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0]) return heading @torch.jit.script def calc_heading_quat(q): # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(heading, axis) return heading_q @torch.jit.script def calc_heading_quat_inv(q): # calculate heading rotation from quaternion # the heading is the direction on the xy plane # q must be normalized heading = calc_heading(q) axis = torch.zeros_like(q[..., 0:3]) axis[..., 2] = 1 heading_q = quat_from_angle_axis(-heading, axis) return heading_q @torch.compiler.disable def axis_angle_from_quat( quat: torch.Tensor, w_last: bool = True, ) -> torch.Tensor: """Compute axis-angle (log map) vector from a quaternion. Args: quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z]. w_last (bool): Whether the scalar part w is the last element. Returns: torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi]. Notes: - The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability. - Uses a stable small-angle handling to avoid NaNs and gradient issues. 
""" # Handle different quaternion formats if w_last: # Quaternion is [q_x, q_y, q_z, q_w] quat_w_orig = quat[..., -1:] else: # Quaternion is [q_w, q_x, q_y, q_z] quat_w_orig = quat[..., 0:1] # Normalize quaternion to have w > 0 quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0)) # Ensure unit quaternion for stability quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min( 1.0e-9 ) # Recompute quat_xyz and quat_w after potential sign flip if w_last: quat_w = quat[..., -1:] quat_xyz = quat[..., :3] else: quat_w = quat[..., 0:1] quat_xyz = quat[..., 1:4] mag = torch.linalg.norm(quat_xyz, dim=-1) half_angle = torch.atan2(mag, quat_w.squeeze(-1)) angle = 2.0 * half_angle # check whether to apply Taylor approximation use_taylor = angle.abs() <= 1.0e-6 # To prevent NaN gradients with torch.where, we compute both branches and blend # based on the condition. # See: https://pytorch.org/docs/1.9.0/generated/torch.where.html#torch-where # "However, if you need the gradients to flow through the branches, please use torch.lerp" # Although we are not using lerp, the principle of avoiding sharp branches is the same. sin_half_angles_over_angles_approx = 0.5 - angle * angle / 48 # Clamp angle to avoid division by zero in the non-taylor branch when angle is exactly 0. angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle) sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe sin_half_angles_over_angles = torch.where( use_taylor, sin_half_angles_over_angles_approx, sin_half_angles_over_angles_exact, ) return quat_xyz / sin_half_angles_over_angles[..., None] @torch.compile def quat_box_minus( q1: torch.Tensor, q2: torch.Tensor, w_last: bool = True, ) -> torch.Tensor: """Right-invariant quaternion difference mapped to so(3) via log map. Computes log(q1 * q2^{-1}) using the shortest rotation convention. Args: q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z]. 
q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`. w_last (bool): Whether the scalar part w is the last element. Returns: torch.Tensor: (..., 3) axis-angle error vector. """ if w_last: q1_xyzw = q1 q2_xyzw = q2 else: # Convert from (w, x, y, z) to (x, y, z, w) q1_xyzw = torch.cat([q1[..., 1:4], q1[..., 0:1]], dim=-1) q2_xyzw = torch.cat([q2[..., 1:4], q2[..., 0:1]], dim=-1) quat_diff = quat_mul( q1_xyzw, quat_conjugate(q2_xyzw), w_last=True, ) # q1 * q2^-1 return axis_angle_from_quat(quat_diff, w_last=True) # log(qd) @torch.compile def quat_error_magnitude( q1: torch.Tensor, q2: torch.Tensor, w_last: bool = True, ) -> torch.Tensor: """Geodesic angle between two orientations given as quaternions. Args: q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z]. q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`. w_last (bool): Whether the scalar part w is the last element. Returns: torch.Tensor: (...,) rotation angle in radians in [0, pi]. """ axis_angle_error = quat_box_minus(q1, q2, w_last=w_last) return torch.norm(axis_angle_error, dim=-1) @torch.jit.script def quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: """Computes the inverse of a quaternion. Args: q: The quaternion orientation in (w, x, y, z). Shape is (N, 4). eps: A small value to avoid division by zero. Defaults to 1e-9. Returns: The inverse quaternion in (w, x, y, z). Shape is (N, 4). """ return quat_conjugate(q) / q.pow(2).sum(dim=-1, keepdim=True).clamp( min=eps ) # --------------------- WXYZ helpers (torch) --------------------- def xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor: """ Convert quaternion from XYZW to WXYZ. Args: q (torch.Tensor): (..., 4) quaternion in XYZW. Returns: torch.Tensor: (..., 4) quaternion in WXYZ. """ return torch.cat([q[..., 3:4], q[..., 0:3]], dim=-1) def wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor: """ Convert quaternion from WXYZ to XYZW. 
Args: q (torch.Tensor): (..., 4) quaternion in WXYZ. Returns: torch.Tensor: (..., 4) quaternion in XYZW. """ return torch.cat([q[..., 1:4], q[..., 0:1]], dim=-1) @torch.compile def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: """ Hamilton product in WXYZ layout using fused implementation. Args: q1 (torch.Tensor): (..., 4) WXYZ. q2 (torch.Tensor): (..., 4) WXYZ. Returns: torch.Tensor: (..., 4) WXYZ. """ return quat_mul(q1, q2, w_last=False) def rotate_vec_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """ Rotate vector v by quaternion q (WXYZ). Args: q_wxyz (torch.Tensor): (..., 4) WXYZ. v (torch.Tensor): (..., 3). Returns: torch.Tensor: (..., 3) rotated vector. """ # Support single-vector inputs by promoting to batch single = q_wxyz.ndim == 1 if single: q_in = q_wxyz[None, :] v_in = v[None, :] else: q_in = q_wxyz v_in = v q_xyzw = wxyz_to_xyzw(q_in) out = quat_apply(q_xyzw, v_in) if single: return out[0] return out def rotate_vec_inv_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """ Rotate vector v by inverse of quaternion q (WXYZ). Args: q_wxyz (torch.Tensor): (..., 4) WXYZ. v (torch.Tensor): (..., 3). Returns: torch.Tensor: (..., 3) rotated vector in inverse rotation. """ single = q_wxyz.ndim == 1 if single: q_in = q_wxyz[None, :] v_in = v[None, :] else: q_in = q_wxyz v_in = v q_xyzw = wxyz_to_xyzw(q_in) q_inv_xyzw = quat_conjugate(q_xyzw) out = quat_apply(q_inv_xyzw, v_in) if single: return out[0] return out def subtract_frame_transforms( t01: torch.Tensor, q01: torch.Tensor, t02: torch.Tensor | None = None, q02: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: r"""Subtract transformations between two reference frames into a stationary frame. It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`, where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B. Args: t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3). 
q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3). Defaults to None, in which case the position is assumed to be zero. q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). Defaults to None, in which case the orientation is assumed to be identity. Returns: A tuple containing the position and orientation of frame 2 w.r.t. frame 1. Shape of the tensors are (N, 3) and (N, 4) respectively. """ # compute orientation q10 = quat_inv(q01) if q02 is not None: q12 = quat_mul(q10, q02) else: q12 = q10 # compute translation if t02 is not None: t12 = quat_apply(q10, t02 - t01) else: t12 = quat_apply(q10, -t01) return t12, q12 @torch.compile def quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor: """ Normalize quaternion in WXYZ layout. Args: q_wxyz (torch.Tensor): (..., 4) WXYZ. Returns: torch.Tensor: (..., 4) normalized WXYZ. """ return q_wxyz / torch.linalg.norm(q_wxyz, dim=-1, keepdim=True).clamp_min( 1.0e-9 ) # @torch.compile @torch.jit.script def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor: """Convert rotations given as quaternions to rotation matrices. Args: quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4). Returns: Rotation matrices. The shape is (..., 3, 3). Reference: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70 """ r, i, j, k = torch.unbind(quaternions, -1) # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) @torch.jit.script def rot6d_from_quat(quaternions: torch.Tensor) -> torch.Tensor: """ Convert rotations given as quaternions to 6D rotation representation. Uses the continuous 6D rotation representation from Zhou et al. (CVPR 2019). Args: quaternions: (..., 4) quaternion in (w, x, y, z). Returns: (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened). """ mat = matrix_from_quat(quaternions) # (..., 3, 3) batch_shape = mat.shape[:-2] return mat[..., :, :2].reshape(batch_shape + (6,)) @torch.jit.script def matrix_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor: """ Convert 6D rotation representation to rotation matrix. Uses Gram-Schmidt orthogonalization to reconstruct the rotation matrix from the first two columns. Args: rot6d: (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened). Returns: (..., 3, 3) rotation matrix. """ # Extract first two columns a1 = rot6d[..., :3] # first column a2 = rot6d[..., 3:] # second column # Gram-Schmidt orthogonalization b1 = torch.nn.functional.normalize(a1, dim=-1) b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 b2 = torch.nn.functional.normalize(b2, dim=-1) b3 = torch.cross(b1, b2, dim=-1) # Stack columns to form rotation matrix mat = torch.stack((b1, b2, b3), dim=-1) # (..., 3, 3) return mat @torch.jit.script def quat_from_matrix(mat: torch.Tensor) -> torch.Tensor: """ Convert rotation matrix to quaternion. Args: mat: (..., 3, 3) rotation matrix. Returns: (..., 4) quaternion in (w, x, y, z). 
""" batch_dim = mat.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( mat.reshape(batch_dim + (9,)), dim=-1 ) # Compute q_abs = sqrt(max(0, trace_terms)) q_abs = torch.sqrt( torch.clamp( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ), min=0.0, ) ) # Compute quaternion candidates for each branch quat_by_rijk = torch.stack( [ torch.stack( [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1 ), torch.stack( [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1 ), torch.stack( [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1 ), torch.stack( [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1 ), ], dim=-2, ) # Normalize candidates (floor at 0.1 for numerical stability) flr = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clamp(min=flr)) # Pick the best-conditioned candidate (largest denominator) return quat_candidates[ torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :, ].reshape(batch_dim + (4,)) @torch.jit.script def quat_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor: """ Convert 6D rotation representation to quaternions. Args: rot6d: (..., 6) 6D rotation representation (first two columns of rotation matrix, flattened). Returns: (..., 4) quaternion in (w, x, y, z). """ mat = matrix_from_rot6d(rot6d) return quat_from_matrix(mat) @torch.jit.script def euler_xyz_from_quat( quat: torch.Tensor, wrap_to_2pi: bool = False ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Convert rotations given as quaternions to Euler angles in radians. Note: The euler angles are assumed in XYZ extrinsic convention. Args: quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4). wrap_to_2pi (bool): Whether to wrap output Euler angles into [0, 2π). If False, angles are returned in the default range (−π, π]. Defaults to False. 
Returns: A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,). Reference: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles """ q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3] # roll (x-axis rotation) sin_roll = 2.0 * (q_w * q_x + q_y * q_z) cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y) roll = torch.atan2(sin_roll, cos_roll) # pitch (y-axis rotation) sin_pitch = 2.0 * (q_w * q_y - q_z * q_x) pitch = torch.where( torch.abs(sin_pitch) >= 1, torch.copysign( torch.tensor(torch.pi / 2.0, device=quat.device, dtype=quat.dtype), sin_pitch, ), torch.asin(sin_pitch), ) # yaw (z-axis rotation) sin_yaw = 2.0 * (q_w * q_z + q_x * q_y) cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z) yaw = torch.atan2(sin_yaw, cos_yaw) if wrap_to_2pi: return ( roll % (2 * torch.pi), pitch % (2 * torch.pi), yaw % (2 * torch.pi), ) return roll, pitch, yaw @torch.jit.script def yaw_quat(quat: torch.Tensor) -> torch.Tensor: """Extract the yaw component of a quaternion. Args: quat: The orientation in (w, x, y, z). Shape is (..., 4) Returns: A quaternion with only yaw component. """ shape = quat.shape quat_yaw = quat.view(-1, 4) qw = quat_yaw[:, 0] qx = quat_yaw[:, 1] qy = quat_yaw[:, 2] qz = quat_yaw[:, 3] yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz)) quat_yaw = torch.zeros_like(quat_yaw) quat_yaw[:, 3] = torch.sin(yaw / 2) quat_yaw[:, 0] = torch.cos(yaw / 2) quat_yaw = normalize(quat_yaw) return quat_yaw.view(shape) def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part first, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). 
""" return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) @torch.compiler.disable def gaussian_kernel1d( sigma: float, device: torch.device, dtype: torch.dtype ) -> torch.Tensor: if sigma <= 0.0: raise ValueError(f"Invalid sigma: {sigma}") radius = int(4.0 * sigma + 0.5) x = torch.arange(-radius, radius + 1, device=device, dtype=dtype) kernel = torch.exp(-0.5 * (x / sigma).square()) return kernel / kernel.sum() @torch.compiler.disable def gaussian_filter1d(x: torch.Tensor, sigma: float, dim: int) -> torch.Tensor: if x.shape[dim] < 2: return x kernel = gaussian_kernel1d(sigma, device=x.device, dtype=x.dtype).reshape( 1, 1, -1 ) x_perm = x.movedim(dim, -1) x_flat = x_perm.reshape(-1, 1, x_perm.shape[-1]) pad = kernel.shape[-1] // 2 x_flat = F.pad(x_flat, (pad, pad), mode="replicate") y = F.conv1d(x_flat, kernel) y = y.reshape(x_perm.shape) return y.movedim(-1, dim) def smooth_time_series( x: torch.Tensor, sigma: float, dim: int ) -> torch.Tensor: """Gaussian smooth along a time dimension. This is a thin wrapper around :func:`gaussian_filter1d` that treats non-positive sigma as "no-op" for easy ablations. 
""" if sigma <= 0.0: return x return gaussian_filter1d(x, sigma=float(sigma), dim=int(dim)) @torch.compiler.disable def grad_t(x: torch.Tensor, dt: float) -> torch.Tensor: if dt <= 0.0: raise ValueError(f"Invalid dt: {dt}") if x.shape[1] < 2: return torch.zeros_like(x) grad = torch.empty_like(x) inv_dt = 1.0 / dt grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt if x.shape[1] > 2: grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt) return grad def axis_angle_to_matrix( angles: torch.Tensor, axes: torch.Tensor ) -> torch.Tensor: if axes.shape[-1] != 3: raise ValueError("Axes must have shape (N, 3)") axis_norm = torch.linalg.norm(axes, dim=-1) if torch.any(axis_norm <= 0): raise ValueError("Axis vector has zero norm") axis = axes / axis_norm[:, None] aat = torch.einsum("ni,nj->nij", axis, axis) skew = torch.zeros( (axis.shape[0], 3, 3), device=axes.device, dtype=axes.dtype ) ax, ay, az = axis[:, 0], axis[:, 1], axis[:, 2] skew[:, 0, 1] = -az skew[:, 0, 2] = ay skew[:, 1, 0] = az skew[:, 1, 2] = -ax skew[:, 2, 0] = -ay skew[:, 2, 1] = ax sin_t = torch.sin(angles) cos_t = torch.cos(angles) eye = torch.eye(3, device=axes.device, dtype=axes.dtype)[None, None, None] return ( cos_t[..., None, None] * eye + (1.0 - cos_t)[..., None, None] * aat[None, None] + sin_t[..., None, None] * skew[None, None] ) ================================================ FILE: holomotion/src/utils/unitree_g1_actuator_calculator.py ================================================ # Project HoloMotion # # Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. from __future__ import annotations import math from dataclasses import dataclass from typing import Any @dataclass(frozen=True) class MotorFamily: name: str armature: float x1: float x2: float y1: float y2: float fs: float fd: float va: float = 0.01 @dataclass(frozen=True) class JointSpec: joint_expr: str motor: MotorFamily effort_limit: float velocity_limit: float servo_scale: float = 1.0 envelope_scale: float = 1.0 friction_scale: float = 1.0 # ----------------------------------------------------------------------------- # Base actuator families # ----------------------------------------------------------------------------- N5020_16 = MotorFamily( name="N5020_16", armature=0.003609725, x1=30.86, x2=40.13, y1=24.8, y2=31.9, fs=0.6, fd=0.06, ) N7520_14P3 = MotorFamily( name="N7520_14P3", armature=0.010177520, x1=22.63, x2=35.52, y1=71.0, y2=83.3, fs=1.6, fd=0.16, ) N7520_22P5 = MotorFamily( name="N7520_22P5", armature=0.025101925, x1=14.5, x2=22.7, y1=111.0, y2=131.0, fs=2.4, fd=0.24, ) W4010_25 = MotorFamily( name="W4010_25", armature=0.00425, x1=15.3, x2=24.76, y1=4.8, y2=8.6, fs=0.6, fd=0.06, ) # ----------------------------------------------------------------------------- # Design constants # ----------------------------------------------------------------------------- NATURAL_FREQ_HZ = 10.0 DAMPING_RATIO = 2.0 # Set this to your actual physics dt before running the generator. PHYSICS_DT = 1.0 / 200.0 # Desired action delay budget: at most 2 * (1 / 50) = 0.04 s. 
MIN_DELAY_SECONDS = 0.0 MAX_DELAY_SECONDS = 2.0 / 50.0 def seconds_to_delay_steps(delay_seconds: float, physics_dt: float) -> int: return int(math.floor(delay_seconds / physics_dt + 1e-12)) MIN_DELAY = seconds_to_delay_steps(MIN_DELAY_SECONDS, PHYSICS_DT) MAX_DELAY = seconds_to_delay_steps(MAX_DELAY_SECONDS, PHYSICS_DT) # ----------------------------------------------------------------------------- # Single-group mapping # # ankle / waist: # - servo-side armature/gains are doubled # - torque envelope is NOT doubled # - friction is NOT doubled # ----------------------------------------------------------------------------- ALL_JOINT_SPECS: list[JointSpec] = [ # legs JointSpec( ".*_hip_yaw_joint", N7520_14P3, effort_limit=88.0, velocity_limit=32.0 ), JointSpec( ".*_hip_roll_joint", N7520_22P5, effort_limit=139.0, velocity_limit=20.0, ), JointSpec( ".*_hip_pitch_joint", N7520_14P3, effort_limit=88.0, velocity_limit=32.0, ), JointSpec( ".*_knee_joint", N7520_22P5, effort_limit=139.0, velocity_limit=20.0 ), # feet JointSpec( ".*_ankle_pitch_joint", N5020_16, effort_limit=50.0, velocity_limit=37.0, servo_scale=2.0, ), JointSpec( ".*_ankle_roll_joint", N5020_16, effort_limit=50.0, velocity_limit=37.0, servo_scale=2.0, ), # waist JointSpec( "waist_roll_joint", N5020_16, effort_limit=50.0, velocity_limit=37.0, servo_scale=2.0, ), JointSpec( "waist_pitch_joint", N5020_16, effort_limit=50.0, velocity_limit=37.0, servo_scale=2.0, ), JointSpec( "waist_yaw_joint", N7520_14P3, effort_limit=88.0, velocity_limit=32.0 ), # arms JointSpec( ".*_shoulder_pitch_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0, ), JointSpec( ".*_shoulder_roll_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0, ), JointSpec( ".*_shoulder_yaw_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0, ), JointSpec( ".*_elbow_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0 ), JointSpec( ".*_wrist_roll_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0 ), JointSpec( 
".*_wrist_pitch_joint", W4010_25, effort_limit=5.0, velocity_limit=22.0 ), JointSpec( ".*_wrist_yaw_joint", W4010_25, effort_limit=5.0, velocity_limit=22.0 ), ] def compute_pd_gains( armature: float, natural_freq_hz: float, damping_ratio: float ) -> tuple[float, float]: wn = natural_freq_hz * 2.0 * math.pi kp = armature * wn * wn kd = 2.0 * damping_ratio * armature * wn return kp, kd def fmt_float(x: float) -> str: return format(float(x), ".12g") def fmt_value(value: Any, indent: int = 0) -> str: sp = " " * indent if isinstance(value, dict): if not value: return "{}" lines = ["{"] for k, v in value.items(): lines.append(f"{sp} {k!r}: {fmt_value(v, indent + 4)},") lines.append(f"{sp}}}") return "\n".join(lines) if isinstance(value, list): if not value: return "[]" lines = ["["] for item in value: lines.append(f"{sp} {fmt_value(item, indent + 4)},") lines.append(f"{sp}]") return "\n".join(lines) if isinstance(value, float): return fmt_float(value) return repr(value) def build_single_group_cfg( specs: list[JointSpec], natural_freq_hz: float = NATURAL_FREQ_HZ, damping_ratio: float = DAMPING_RATIO, min_delay: int = MIN_DELAY, max_delay: int = MAX_DELAY, ) -> dict[str, Any]: joint_names_expr = [spec.joint_expr for spec in specs] effort_limit: dict[str, float] = {} velocity_limit: dict[str, float] = {} stiffness: dict[str, float] = {} damping: dict[str, float] = {} armature: dict[str, float] = {} x1: dict[str, float] = {} x2: dict[str, float] = {} y1: dict[str, float] = {} y2: dict[str, float] = {} fs: dict[str, float] = {} fd: dict[str, float] = {} va: dict[str, float] = {} action_scale: dict[str, float] = {} for spec in specs: name = spec.joint_expr servo_armature = spec.motor.armature * spec.servo_scale kp, kd = compute_pd_gains( servo_armature, natural_freq_hz, damping_ratio ) effort_limit[name] = spec.effort_limit velocity_limit[name] = spec.velocity_limit stiffness[name] = kp damping[name] = kd armature[name] = servo_armature x1[name] = spec.motor.x1 x2[name] = 
spec.motor.x2 y1[name] = spec.motor.y1 * spec.envelope_scale y2[name] = spec.motor.y2 * spec.envelope_scale fs[name] = spec.motor.fs * spec.friction_scale fd[name] = spec.motor.fd * spec.friction_scale va[name] = spec.motor.va action_scale[name] = 0.25 * spec.effort_limit / kp return { "joint_names_expr": joint_names_expr, "min_delay": min_delay, "max_delay": max_delay, "effort_limit": effort_limit, "velocity_limit": velocity_limit, "stiffness": stiffness, "damping": damping, "armature": armature, "friction": 0.0, "dynamic_friction": 0.0, "viscous_friction": 0.0, "X1": x1, "X2": x2, "Y1": y1, "Y2": y2, "Fs": fs, "Fd": fd, "Va": va, "action_scale": action_scale, } def render_single_group_cfg( cfg: dict[str, Any], group_name: str = "all_joints" ) -> str: ordered_keys = [ "joint_names_expr", "min_delay", "max_delay", "effort_limit", "velocity_limit", "stiffness", "damping", "armature", "friction", "dynamic_friction", "viscous_friction", "X1", "X2", "Y1", "Y2", "Fs", "Fd", "Va", ] lines = [ "from unitree_actuators import UnitreeActuatorCfg", "", "G1_HIFI_ACTUATORS = {", f" {group_name!r}: UnitreeActuatorCfg(", ] for key in ordered_keys: rendered = fmt_value(cfg[key], indent=8) lines.append(f" {key}={rendered},") lines.append(" )") lines.append("}") lines.append("") lines.append("G1_HIFI_ACTION_SCALE = {") for joint_expr in cfg["joint_names_expr"]: lines.append( f" {joint_expr!r}: {fmt_float(cfg['action_scale'][joint_expr])}," ) lines.append("}") return "\n".join(lines) def print_summary(cfg: dict[str, Any]) -> None: print("# === SUMMARY ===") print(f"# physics_dt = {fmt_float(PHYSICS_DT)}") print(f"# min_delay = {cfg['min_delay']}") print(f"# max_delay = {cfg['max_delay']}") print( "# joint_expr | effort_limit | velocity_limit | armature | kp | kd | " "X1 | X2 | Y1 | Y2 | Fs | Fd | action_scale" ) for joint_expr in cfg["joint_names_expr"]: print( f"# {joint_expr} | " f"{fmt_float(cfg['effort_limit'][joint_expr])} | " f"{fmt_float(cfg['velocity_limit'][joint_expr])} | " 
f"{fmt_float(cfg['armature'][joint_expr])} | " f"{fmt_float(cfg['stiffness'][joint_expr])} | " f"{fmt_float(cfg['damping'][joint_expr])} | " f"{fmt_float(cfg['X1'][joint_expr])} | " f"{fmt_float(cfg['X2'][joint_expr])} | " f"{fmt_float(cfg['Y1'][joint_expr])} | " f"{fmt_float(cfg['Y2'][joint_expr])} | " f"{fmt_float(cfg['Fs'][joint_expr])} | " f"{fmt_float(cfg['Fd'][joint_expr])} | " f"{fmt_float(cfg['action_scale'][joint_expr])}" ) print() def main() -> None: cfg = build_single_group_cfg(ALL_JOINT_SPECS) print_summary(cfg) print(render_single_group_cfg(cfg, group_name="all_joints")) if __name__ == "__main__": main() ================================================ FILE: holomotion/tests/__init__.py ================================================ ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=64.0", "wheel"] build-backend = "setuptools.build_meta" [project] name = "holomotion" version = "1.2.0" description = "HoloMotion" authors = [ {name = "Horizon Robotics"}, ] readme = "README.md" requires-python = ">=3.10" dependencies = [] [project.urls] Homepage = "https://horizonrobotics.github.io/robot_lab/holomotion/ " Repository = "https://github.com/" [tool.setuptools.packages.find] where = ["."] include = ["holomotion*"] [tool.ruff] exclude = [ # common ".bzr", ".direnv", ".eggs", ".git", ".git-rewrite", ".hg", ".ipynb_checkpoints", ".mypy_cache", ".nox", ".pants.d", ".pyenv", ".pytest_cache", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", ".vscode", "__pypackages__", "_build", "buck-out", "build", "dist", "node_modules", "site-packages", "venv", # project "3rdparty/*", "dummy/*", "*.pyi", "*_pb2.py", ] # Same as Black. 
line-length = 79
indent-width = 4

# Target Python 3.11 syntax (note: [project] declares requires-python >= 3.10)
target-version = "py311"

[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors
    "F",   # pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "TID", # flake8-tidy-imports
    "D",   # pydocstyle
    "Q",   # flake8-quotes
    "W",   # pycodestyle warnings
    "N",   # pep8-naming
]
ignore = [
    "D104",
    "D107",
    "D202",
    "D105",
    "D100",
    "D102",
    "D103",
    "D101",
    "D301",
    "F403",
    "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
    "B028", # No explicit `stacklevel` keyword argument found
    "D417", # requires documentation for every function parameter.
]

[tool.ruff.lint.isort]
known-third-party = []
no-lines-before = ["future", "standard-library"]
combine-as-imports = true
force-wrap-aliases = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
ban-relative-imports = "all"

[tool.ruff.lint.flake8-quotes]
avoid-escape = false

[tool.ruff.lint.mccabe]
max-complexity = 18

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["TID252", "F401"]

[tool.ruff.lint.pep8-naming]
classmethod-decorators = [
    # Allow Pydantic's `@validator` decorator to trigger class method treatment.
    "pydantic.validator",
    # Allow SQLAlchemy's dynamic decorators, like `@field.expression`, to trigger class method treatment.
    "declared_attr",
    "expression",
    "comparator",
]
ignore-names = [
    # ruff default (https://docs.astral.sh/ruff/settings/#lintpep8-naming)
    "setUp",
    "tearDown",
    "setUpClass",
    "tearDownClass",
    "setUpModule",
    "tearDownModule",
    "asyncSetUp",
    "asyncTearDown",
    "setUpTestData",
    "failureException",
    "longMessage",
    "maxDiff",
    # project
    "PROJECT_ROOT", # project test environment fixture
    "ROBO_ORCHARD_TEST_WORKSPACE", # project test fixture
    "F", # import torch.nn.functional as F
]

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space" # Like Black, respect magic trailing commas. skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" docstring-code-format = true ================================================ FILE: tests/benchmark_legacy_onnx_attention.py ================================================ import sys import tempfile import time from pathlib import Path import numpy as np import onnx import onnxruntime import torch import torch.nn as nn import torch.nn.functional as F sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.modules.network_modules import ( export_safe_scaled_dot_product_attention, ) class _RawAttentionModule(nn.Module): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: return F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False ) class _SafeAttentionModule(nn.Module): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: return export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) def _export_model( module: nn.Module, export_path: Path, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor, ) -> None: torch.onnx.export( module.eval(), (q, k, v, mask), str(export_path), opset_version=17, input_names=["q", "k", "v", "mask"], output_names=["out"], dynamo=False, verbose=False, ) def _benchmark_session( model_path: Path, provider, feed: dict[str, np.ndarray], *, warmup_iters: int = 50, measure_iters: int = 300, ) -> float: providers = ( ["CPUExecutionProvider"] if provider == "CPUExecutionProvider" else [provider, "CPUExecutionProvider"] ) session = onnxruntime.InferenceSession( str(model_path), providers=providers, ) for _ in range(warmup_iters): session.run(["out"], feed) start = time.perf_counter() for _ in range(measure_iters): session.run(["out"], 
feed) elapsed_s = time.perf_counter() - start return (elapsed_s * 1000.0) / measure_iters def main() -> None: torch.manual_seed(0) q = torch.randn(4, 8, 1, 64) k = torch.randn(4, 8, 32, 64) v = torch.randn(4, 8, 32, 64) valid_lengths = torch.tensor([32, 24, 16, 8], dtype=torch.int64) mask = ( torch.arange(32, dtype=torch.int64)[None, :] < valid_lengths[:, None] ) mask = mask[:, None, None, :] feed = { "q": q.numpy(), "k": k.numpy(), "v": v.numpy(), "mask": mask.numpy(), } with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) raw_path = tmp_path / "raw_attention.onnx" safe_path = tmp_path / "safe_attention.onnx" _export_model(_RawAttentionModule(), raw_path, q, k, v, mask) _export_model(_SafeAttentionModule(), safe_path, q, k, v, mask) raw_model = onnx.load(str(raw_path)) safe_model = onnx.load(str(safe_path)) raw_ops = [node.op_type for node in raw_model.graph.node] safe_ops = [node.op_type for node in safe_model.graph.node] print( "Graph ops: " f"raw_has_isnan={'IsNaN' in raw_ops}, " f"safe_has_isnan={'IsNaN' in safe_ops}" ) cpu_raw = _benchmark_session(raw_path, "CPUExecutionProvider", feed) cpu_safe = _benchmark_session(safe_path, "CPUExecutionProvider", feed) print( f"CPUExecutionProvider: raw={cpu_raw:.4f} ms, " f"safe={cpu_safe:.4f} ms, " f"delta={(cpu_safe - cpu_raw) / cpu_raw * 100.0:.2f}%" ) if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): cuda_raw = _benchmark_session( raw_path, "CUDAExecutionProvider", feed ) cuda_safe = _benchmark_session( safe_path, "CUDAExecutionProvider", feed ) print( f"CUDAExecutionProvider: raw={cuda_raw:.4f} ms, " f"safe={cuda_safe:.4f} ms, " f"delta={(cuda_safe - cuda_raw) / cuda_raw * 100.0:.2f}%" ) if __name__ == "__main__": main() ================================================ FILE: tests/benchmark_moe_router_export.py ================================================ import re import time from pathlib import Path import torch def _extract_int_setting(config_path: Path, key: str) -> int: 
pattern = re.compile(rf"^\s*{re.escape(key)}:\s*([0-9]+)\s*$") for line in config_path.read_text().splitlines(): match = pattern.match(line) if match: return int(match.group(1)) raise ValueError( f"Could not find integer setting {key!r} in {config_path}" ) def _load_b0310_shape_config() -> dict[str, int]: repo_root = Path(__file__).resolve().parents[1] module_cfg = ( repo_root / "holomotion" / "config" / "modules" / "motion_tracking" / "tf_motrack_v3.yaml" ) return { "num_fine_experts": _extract_int_setting( module_cfg, "num_fine_experts" ), "top_k": _extract_int_setting(module_cfg, "top_k"), "max_ctx_len": _extract_int_setting(module_cfg, "max_ctx_len"), } def _router_scores_training( logits_fp32: torch.Tensor, *, top_k: int, bias_fp32: torch.Tensor | None = None, ) -> torch.Tensor: choice_logits = ( logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32 ) _, topk_idx = torch.topk(choice_logits, top_k, dim=-1) selected_logits = logits_fp32.gather(-1, topk_idx) log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True) selected_probs = torch.exp(selected_logits - log_z) return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min( 1.0e-20 ) def _router_scores_export_safe( logits_fp32: torch.Tensor, *, top_k: int, bias_fp32: torch.Tensor | None = None, ) -> torch.Tensor: choice_logits = ( logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32 ) _, topk_idx = torch.topk(choice_logits, top_k, dim=-1) selected_probs = torch.softmax(logits_fp32, dim=-1).gather(-1, topk_idx) return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min( 1.0e-20 ) def _benchmark( fn, logits_fp32: torch.Tensor, *, top_k: int, bias_fp32: torch.Tensor | None = None, warmup_iters: int = 200, measure_iters: int = 2000, ) -> float: is_cuda = logits_fp32.is_cuda with torch.inference_mode(): for _ in range(warmup_iters): fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32) if is_cuda: torch.cuda.synchronize(logits_fp32.device) start = time.perf_counter() for 
_ in range(measure_iters): fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32) if is_cuda: torch.cuda.synchronize(logits_fp32.device) elapsed_s = time.perf_counter() - start return (elapsed_s * 1000.0) / measure_iters def _run_case( device: torch.device, *, case_name: str, batch_size: int, seq_len: int, num_fine_experts: int, top_k: int, ) -> None: seed = 0 generator = torch.Generator(device="cpu") generator.manual_seed(seed) logits_fp32 = torch.randn( batch_size, seq_len, num_fine_experts, generator=generator, dtype=torch.float32, ).to(device) eager_scores = _router_scores_training(logits_fp32, top_k=top_k) export_scores = _router_scores_export_safe(logits_fp32, top_k=top_k) max_abs_diff = (eager_scores - export_scores).abs().max().item() eager_ms = _benchmark( _router_scores_training, logits_fp32, top_k=top_k, ) export_ms = _benchmark( _router_scores_export_safe, logits_fp32, top_k=top_k, ) delta_pct = ((export_ms - eager_ms) / eager_ms) * 100.0 print( f"{device.type}:{case_name}: " f"shape=({batch_size}, {seq_len}, {num_fine_experts}), " f"top_k={top_k}, " f"training={eager_ms:.6f} ms, " f"export_safe={export_ms:.6f} ms, " f"delta={delta_pct:.2f}%, " f"max_abs_diff={max_abs_diff:.3e}" ) def main() -> None: shape_cfg = _load_b0310_shape_config() num_fine_experts = shape_cfg["num_fine_experts"] top_k = shape_cfg["top_k"] max_ctx_len = shape_cfg["max_ctx_len"] cases = [ ("single_step_export", 1, 1), ("training_like_sequence", 16, max_ctx_len), ] devices = [torch.device("cpu")] if torch.cuda.is_available(): devices.append(torch.device("cuda")) print( "Benchmarking MoE router formulas with " f"num_fine_experts={num_fine_experts}, top_k={top_k}, " f"max_ctx_len={max_ctx_len}" ) for device in devices: for case_name, batch_size, seq_len in cases: _run_case( device, case_name=case_name, batch_size=batch_size, seq_len=seq_len, num_fine_experts=num_fine_experts, top_k=top_k, ) if __name__ == "__main__": main() ================================================ FILE: 
tests/test_actor_export_config.py
================================================
"""ONNX export tests for the PPO transformer / MoE actor modules."""

import importlib
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace

# Make the repository root importable so `holomotion` resolves without an
# installed package.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from holomotion.src.modules.agent_modules import (
    PPOTFActor,
    PPOTFRefRouterActor,
    PPOTFRefRouterSeqActor,
    PPOTFRefRouterV3Actor,
    _clone_module_for_cpu_export,
)
from holomotion.src.modules.network_modules import (
    GroupedMoEBlock,
    GroupedMoETransformerPolicy,
    ReferenceRoutedGroupedMoETransformerPolicy,
    ReferenceRoutedGroupedMoETransformerPolicyV2,
    ReferenceRoutedGroupedMoETransformerPolicyV3,
    export_safe_scaled_dot_product_attention,
)
from holomotion.src.utils.onnx_export import export_policy_to_onnx
from tensordict import TensorDict

# Optional dependencies are resolved dynamically; raising unittest.SkipTest
# at import time skips this module when one of them is missing.
# NOTE(review): the holomotion/tensordict imports above run before this
# guard; if they need torch themselves, a missing torch errors out before
# reaching the skip. Confirm.
try:
    onnx = importlib.import_module("onnx")
    torch = importlib.import_module("torch")
    nn = importlib.import_module("torch.nn")
except ModuleNotFoundError as exc:
    raise unittest.SkipTest(
        f"Optional ONNX test dependency missing: {exc.name}"
    ) from exc


class _DummyTFModule(nn.Module):
    """Trivial actor module: returns obs[:, :2] and passes the cache through.

    The n_layers / max_ctx_len / n_kv_heads / head_dim attributes are
    presumably read by PPOTFActor.export_onnx to size the KV-cache input;
    confirm against the exporter.
    """

    def __init__(self):
        super().__init__()
        self.n_layers = 1
        self.max_ctx_len = 4
        self.n_kv_heads = 1
        self.head_dim = 2

    def forward(
        self,
        obs: torch.Tensor,
        past_key_values: torch.Tensor,
        current_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return obs[:, :2], past_key_values


class _DummyAttentionTFModule(nn.Module):
    """Actor module that routes a masked attention through the export-safe SDPA.

    k is all zeros and v is all ones, so the attention output does not depend
    on obs. The module exists to put a length-masked attention pattern into
    the exported graph.
    """

    def __init__(self):
        super().__init__()
        self.n_layers = 1
        self.max_ctx_len = 4
        self.n_kv_heads = 1
        self.head_dim = 2

    def forward(
        self,
        obs: torch.Tensor,
        past_key_values: torch.Tensor,
        current_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = obs.shape[0]
        max_len = past_key_values.shape[3]
        # Number of valid cache positions after writing the current step.
        valid_len = (current_pos + 1).clamp(max=max_len)
        pos_idx = torch.arange(max_len, device=obs.device, dtype=torch.int64)
        # Mask shaped (batch, 1, 1, max_len) to broadcast over heads/queries.
        mask = (pos_idx[None, :] < valid_len[:, None])[:, None, None, :]
        q = obs[:, :2].reshape(batch_size, 1, 1, 2)
        k = torch.zeros(batch_size, 1, max_len, 2, device=obs.device)
        v = torch.ones(batch_size, 1, max_len, 2, device=obs.device)
        attn_out = export_safe_scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=False,
        )
        actions = attn_out.reshape(batch_size, 2)
        return actions, past_key_values


class _RecordingDeviceModule(nn.Module):
    """Module whose .to() only records the requested device and returns self."""

    def __init__(self):
        super().__init__()
        self.current_device = "cuda:0"
        self.to_calls = []

    def to(self, device):
        self.to_calls.append(str(device))
        self.current_device = str(device)
        return self


def _make_minimal_real_transformer_actor(
    *,
    n_layers: int = 1,
    routing_score_fn: str = "softmax",
    num_fine_experts: int = 1,
    top_k: int = 1,
    use_dynamic_bias: bool = False,
    dense_layer_at_last: bool = False,
    selected_expert_margin_to_unselected_enabled: bool = False,
    selected_expert_margin_to_unselected_target: float = 0.0,
) -> PPOTFActor:
    """Build a tiny PPOTFActor around a real GroupedMoETransformerPolicy.

    PPOTFActor.__init__ is bypassed via __new__; only the attributes the
    export path is expected to read are set by hand.
    """
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = GroupedMoETransformerPolicy(
        input_dim=6,
        output_dim=2,
        module_config_dict={
            "type": "GroupedMoETransformerPolicy",
            "num_fine_experts": num_fine_experts,
            "num_shared_experts": 0,
            "top_k": top_k,
            "obs_embed_mlp_hidden": 8,
            "d_model": 8,
            "n_layers": n_layers,
            "n_heads": 2,
            "n_kv_heads": 1,
            "ff_mult": 1.0,
            "ff_mult_dense": 1,
            "attn_dropout": 0.0,
            "mlp_dropout": 0.0,
            "max_ctx_len": 4,
            "dense_layer_at_last": dense_layer_at_last,
            "use_gated_attn": False,
            "use_qk_norm": True,
            "routing_score_fn": routing_score_fn,
            "use_dynamic_bias": use_dynamic_bias,
            "selected_expert_margin_to_unselected": {
                "enabled": selected_expert_margin_to_unselected_enabled,
                "target": selected_expert_margin_to_unselected_target,
            },
        },
    )
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=6)
    return actor


def _capture_moe_router_outputs(
    monkeypatch,
    *,
    export_mode: bool,
    top_k: int,
    use_dynamic_bias: bool,
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_x: torch.Tensor | None = None,
    expert_bias: torch.Tensor | None = None,
):
    """Run GroupedMoEBlock.compute_moe_ffn once and capture its routing.

    Builds a fresh block (one shared expert), copies in ``router_weight`` and
    optionally ``expert_bias``, and forces torch.onnx.is_in_onnx_export() to
    return ``export_mode``. It replaces the sparse-expert compute with a stub
    that records topk_idx / topk_scores and contributes zeros.

    Returns:
        Dict with cloned "topk_idx", "topk_scores" and the block "output".
    """
    block = GroupedMoEBlock(
        d_model=x.shape[-1],
        n_heads=2,
        n_kv_heads=1,
        num_fine_experts=router_weight.shape[0],
        num_shared_experts=1,
        top_k=top_k,
        ff_mult=1.0,
        use_qk_norm=True,
        use_gated_attn=False,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        use_dynamic_bias=use_dynamic_bias,
        routing_score_fn="softmax",
    )
    block.eval()
    with torch.no_grad():
        block.router.weight.copy_(router_weight)
        if expert_bias is not None:
            block.expert_bias.copy_(expert_bias)
    captured = {}

    def _fake_sparse_experts(
        x_input: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_scores: torch.Tensor,
    ) -> torch.Tensor:
        captured["topk_idx"] = topk_idx.detach().clone()
        captured["topk_scores"] = topk_scores.detach().clone()
        return torch.zeros_like(x_input)

    monkeypatch.setattr(torch.onnx, "is_in_onnx_export", lambda: export_mode)
    monkeypatch.setattr(block, "_compute_sparse_experts", _fake_sparse_experts)
    captured["output"] = block.compute_moe_ffn(x, router_x=router_x)
    return captured


def _make_minimal_ref_router_actor() -> PPOTFRefRouterActor:
    """Build a tiny PPOTFRefRouterActor (router reads obs indices 0, 1, 4, 5)."""
    actor = PPOTFRefRouterActor.__new__(PPOTFRefRouterActor)
    nn.Module.__init__(actor)
    actor.actor_module = ReferenceRoutedGroupedMoETransformerPolicy(
        input_dim=8,
        output_dim=2,
        module_config_dict={
            "type": "ReferenceRoutedGroupedMoETransformerPolicy",
            "num_fine_experts": 4,
            "num_shared_experts": 0,
            "top_k": 2,
            "obs_embed_mlp_hidden": 8,
            "router_embed_mlp_hidden": 8,
            "router_input_dim": 4,
            "router_feature_indices": [0, 1, 4, 5],
            "d_model": 8,
            "n_layers": 2,
            "n_heads": 2,
            "n_kv_heads": 1,
            "ff_mult": 1.0,
            "ff_mult_dense": 1,
            "attn_dropout": 0.0,
            "mlp_dropout": 0.0,
            "max_ctx_len": 4,
            "use_gated_attn": False,
            "use_qk_norm": True,
            "routing_score_fn": "softmax",
            "use_dynamic_bias": False,
        },
    )
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=8)
    return actor


def _make_ref_router_v2_obs_schema() -> dict:
    """Obs schema: one current-step group plus a 5-step future-reference group."""
    return {
        "flattened_obs": {
            "seq_len": 1,
            "terms": [
                "unified/actor_ref_gravity_projection_cur",
                "unified/actor_ref_base_linvel_cur",
                "unified/actor_ref_base_angvel_cur",
                "unified/actor_ref_dof_pos_cur",
                "unified/actor_projected_gravity",
                "unified/actor_rel_robot_root_ang_vel",
                "unified/actor_dof_vel",
                "unified/actor_dof_pos",
                "unified/actor_ref_root_height_cur",
                "unified/actor_last_action",
            ],
        },
        "flattened_obs_fut": {
            "seq_len": 5,
            "terms": [
                "unified/actor_ref_gravity_projection_fut",
                "unified/actor_ref_base_linvel_fut",
                "unified/actor_ref_base_angvel_fut",
                "unified/actor_ref_dof_pos_fut",
                "unified/actor_ref_root_height_fut",
            ],
        },
    }


def _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict:
    """Random observations matching _make_ref_router_v2_obs_schema.

    Current-step terms have shape (*batch_size, dim); future terms have an
    extra time axis of length 5: (*batch_size, 5, dim).
    """
    shape = list(batch_size)
    actor_fut_shape = shape + [5]
    unified = TensorDict(
        {
            "actor_ref_gravity_projection_cur": torch.randn(*shape, 3),
            "actor_ref_base_linvel_cur": torch.randn(*shape, 3),
            "actor_ref_base_angvel_cur": torch.randn(*shape, 3),
            "actor_ref_dof_pos_cur": torch.randn(*shape, 2),
            "actor_projected_gravity": torch.randn(*shape, 3),
            "actor_rel_robot_root_ang_vel": torch.randn(*shape, 3),
            "actor_dof_vel": torch.randn(*shape, 3),
            "actor_dof_pos": torch.randn(*shape, 3),
            "actor_ref_root_height_cur": torch.randn(*shape, 1),
            "actor_last_action": torch.randn(*shape, 2),
            "actor_ref_gravity_projection_fut": torch.randn(
                *actor_fut_shape, 3
            ),
            "actor_ref_base_linvel_fut": torch.randn(*actor_fut_shape, 3),
            "actor_ref_base_angvel_fut": torch.randn(*actor_fut_shape, 3),
            "actor_ref_dof_pos_fut": torch.randn(*actor_fut_shape, 2),
            "actor_ref_root_height_fut": torch.randn(*actor_fut_shape, 1),
        },
        batch_size=shape,
    )
    return TensorDict({"unified": unified}, batch_size=shape)


def _make_minimal_ref_router_v2_actor() -> PPOTFRefRouterSeqActor:
    """Build a tiny PPOTFRefRouterSeqActor (conv future-reference encoder)."""
    obs_schema = _make_ref_router_v2_obs_schema()
    obs_example = _make_ref_router_v2_obs([2])
    return PPOTFRefRouterSeqActor(
        obs_schema=obs_schema,
        module_config_dict={
            "type": "ReferenceRoutedGroupedMoETransformerPolicyV2",
            "num_fine_experts": 4,
            "num_shared_experts": 0,
            "top_k": 2,
            "obs_embed_mlp_hidden": 8,
            "d_model": 8,
            "n_layers": 2,
            "n_heads": 2,
            "n_kv_heads": 1,
            "ff_mult": 1.0,
            "ff_mult_dense": 1,
            "attn_dropout": 0.0,
            "mlp_dropout": 0.0,
            "max_ctx_len": 4,
            "use_gated_attn": False,
            "use_qk_norm": True,
            "routing_score_fn": "softmax",
            "use_dynamic_bias": False,
            "ref_hist_n_layers": 1,
            "ref_future_conv_channels": 8,
            "ref_future_conv_layers": 2,
            "ref_future_conv_kernel_size": 3,
            "ref_future_conv_stride": 2,
            "obs_norm": {"enabled": False},
            "output_dim": 2,
        },
        num_actions=2,
        init_noise_std=0.2,
        obs_example=obs_example,
    )


def _make_minimal_ref_router_v3_actor() -> PPOTFRefRouterV3Actor:
    """Build a tiny PPOTFRefRouterV3Actor sharing the V2 obs schema."""
    obs_schema = _make_ref_router_v2_obs_schema()
    obs_example = _make_ref_router_v2_obs([2])
    return PPOTFRefRouterV3Actor(
        obs_schema=obs_schema,
        module_config_dict={
            "type": "ReferenceRoutedGroupedMoETransformerPolicyV3",
            "num_fine_experts": 4,
            "num_shared_experts": 0,
            "top_k": 2,
            "obs_embed_mlp_hidden": 8,
            "d_model": 8,
            "n_layers": 2,
            "n_heads": 2,
            "n_kv_heads": 1,
            "ff_mult": 1.0,
            "ff_mult_dense": 1,
            "attn_dropout": 0.0,
            "mlp_dropout": 0.0,
            "max_ctx_len": 4,
            "use_gated_attn": False,
            "use_qk_norm": True,
            "routing_score_fn": "softmax",
            "use_dynamic_bias": False,
            "ref_hist_n_layers": 1,
            "router_future_hidden_dim": 12,
            "router_layer_proj_hidden_dim": 10,
            "obs_norm": {"enabled": False},
            "output_dim": 2,
        },
        num_actions=2,
        init_noise_std=0.2,
        obs_example=obs_example,
    )


def test_export_policy_to_onnx_uses_opset_17(monkeypatch, tmp_path):
    """export_policy_to_onnx forwards opset 17 and the use_kv_cache flag."""
    captured = {}

    class _FakeActor:
        def eval(self):
            return self

        def export_onnx(
            self,
            *,
            onnx_path,
            opset_version,
            use_kv_cache=True,
        ):
            captured["onnx_path"] = onnx_path
            captured["opset_version"] = opset_version
            captured["use_kv_cache"] = use_kv_cache
            return str(onnx_path)

    actor = _FakeActor()
    algo = SimpleNamespace(
        actor=actor,
        critic=SimpleNamespace(eval=lambda: None),
        accelerator=SimpleNamespace(unwrap_model=lambda model: model),
        env=SimpleNamespace(_env=object()),
    )
    # Metadata attachment is stubbed out; it is not under test here.
    monkeypatch.setattr(
        "holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion",
        lambda env, onnx_path: None,
    )
    checkpoint_path = tmp_path / "model.pt"
    checkpoint_path.write_bytes(b"")
    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)
    assert captured["opset_version"] == 17
    assert captured["use_kv_cache"] is False


def test_export_policy_to_onnx_restores_training_mode(monkeypatch, tmp_path):
    """Actor and critic are back in training mode after export returns."""

    class _FakeActor:
        def __init__(self):
            self.training = True

        def eval(self):
            self.training = False
            return self

        def train(self, mode: bool = True):
            self.training = bool(mode)
            return self

        def export_onnx(
            self,
            *,
            onnx_path,
            opset_version,
            use_kv_cache=True,
        ):
            return str(onnx_path)

    class _FakeCritic:
        def __init__(self):
            self.training = True

        def eval(self):
            self.training = False
            return self

        def train(self, mode: bool = True):
            self.training = bool(mode)
            return self

    actor = _FakeActor()
    critic = _FakeCritic()
    algo = SimpleNamespace(
        actor=actor,
        critic=critic,
        accelerator=SimpleNamespace(unwrap_model=lambda model: model),
        env=SimpleNamespace(_env=object()),
    )
    monkeypatch.setattr(
        "holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion",
        lambda env, onnx_path: None,
    )
    checkpoint_path = tmp_path / "model.pt"
    checkpoint_path.write_bytes(b"")
    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)
    assert actor.training is True
    assert critic.training is True


def test_clone_module_for_cpu_export_does_not_move_live_module(monkeypatch):
    """Cloning for CPU export never calls .to() on the live module."""
    module = _RecordingDeviceModule()
    monkeypatch.setattr(
        "holomotion.src.modules.agent_modules._module_device",
        lambda _: torch.device("cuda:0"),
    )
    cloned = _clone_module_for_cpu_export(module)
    assert module.to_calls == []
    assert module.current_device == "cuda:0"
    assert isinstance(cloned, _RecordingDeviceModule)
    assert cloned is not module


def test_ppotf_actor_export_uses_legacy_torchscript(monkeypatch, tmp_path):
    """PPOTFActor.export_onnx calls torch.onnx.export once with dynamo=False."""
    export_calls = []

    def _fake_export(*args, **kwargs):
        export_calls.append(kwargs)

    monkeypatch.setattr(torch.onnx, "export", _fake_export)
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = _DummyTFModule()
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=6)
    out_path = tmp_path / "policy.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    assert len(export_calls) == 1
    assert export_calls[0]["opset_version"] == 17
    assert export_calls[0]["dynamo"] is False


def test_ppotf_actor_export_onnx_avoids_isnan(tmp_path):
    """A masked export-safe attention exports without any IsNaN node."""
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = _DummyAttentionTFModule()
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=2)
    out_path = tmp_path / "policy.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    op_types = [node.op_type for node in model.graph.node]
    assert "IsNaN" not in op_types


def test_ppotf_real_transformer_export_onnx_avoids_isnan(tmp_path):
    """The real single-layer transformer policy also exports without IsNaN."""
    actor = _make_minimal_real_transformer_actor()
    out_path = tmp_path / "policy_real_tf.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    op_types = [node.op_type for node in model.graph.node]
    assert "IsNaN" not in op_types


def test_ppotf_real_moe_transformer_export_reaches_router_ops(tmp_path):
    """With 4 experts / top-2 the exported graph contains a TopK router op."""
    actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    out_path = tmp_path / "policy_real_moe_tf.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    op_types = [node.op_type for node in model.graph.node]
    assert "TopK" in op_types


def test_ppotf_real_moe_transformer_export_exposes_routing_outputs(tmp_path):
    """Per-MoE-layer expert indices/logits are exposed as graph outputs.

    With n_layers=3 only layers 1 and 2 appear (layer 0 is presumably
    dense; confirm against GroupedMoETransformerPolicy).
    """
    actor = _make_minimal_real_transformer_actor(
        n_layers=3,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    out_path = tmp_path / "policy_real_moe_tf_outputs.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    output_names = [value.name for value in model.graph.output]
    assert output_names == [
        "actions",
        "present_key_values",
        "moe_layer_1_expert_indices",
        "moe_layer_1_expert_logits",
        "moe_layer_2_expert_indices",
        "moe_layer_2_expert_logits",
    ]


def test_ppotf_real_moe_transformer_export_dense_last_uses_actual_moe_indices(
    tmp_path,
):
    """With a dense last layer, outputs are named after the real MoE layers."""
    actor = _make_minimal_real_transformer_actor(
        n_layers=4,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
        dense_layer_at_last=True,
    )
    out_path = tmp_path / "policy_real_moe_tf_dense_last_outputs.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    output_names = [value.name for value in model.graph.output]
    assert output_names == [
        "actions",
        "present_key_values",
        "moe_layer_1_expert_indices",
        "moe_layer_1_expert_logits",
        "moe_layer_2_expert_indices",
        "moe_layer_2_expert_logits",
    ]


def test_ppotf_real_moe_transformer_export_avoids_reduce_log_sum_exp(
    tmp_path,
):
    """The exported router graph contains no ReduceLogSumExp node."""
    actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    out_path = tmp_path / "policy_real_moe_tf_no_rlse.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    op_types = [node.op_type for node in model.graph.node]
    assert "ReduceLogSumExp" not in op_types


def test_export_safe_moe_router_matches_training_scores_for_topk1(monkeypatch):
    """Export-mode routing matches eager routing (indices and scores), top-1."""
    x = torch.tensor([[[1.0, -0.5, 0.25, 2.0]]], dtype=torch.float32)
    router_weight = torch.tensor(
        [
            [0.1, 0.3, -0.2, 0.5],
            [0.2, -0.4, 0.1, 0.7],
            [-0.3, 0.6, 0.2, -0.1],
            [0.4, 0.1, -0.5, 0.2],
        ],
        dtype=torch.float32,
    )
    eager = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=1,
        use_dynamic_bias=False,
        x=x,
        router_weight=router_weight,
    )
    export = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=True,
        top_k=1,
        use_dynamic_bias=False,
        x=x,
        router_weight=router_weight,
    )
    assert torch.equal(export["topk_idx"], eager["topk_idx"])
    torch.testing.assert_close(
        export["topk_scores"],
        eager["topk_scores"],
        atol=1.0e-6,
        rtol=1.0e-5,
    )


def test_export_safe_moe_router_matches_training_scores_with_dynamic_bias(
    monkeypatch,
):
    """Export-mode routing matches eager routing with top-2 and expert_bias."""
    x = torch.tensor(
        [
            [[0.2, -1.0, 0.5, 1.1], [0.4, 0.3, -0.7, 0.9]],
            [[-0.6, 0.8, 1.0, -0.2], [0.1, -0.4, 0.6, 0.7]],
        ],
        dtype=torch.float32,
    )
    router_weight = torch.tensor(
        [
            [0.2, -0.1, 0.5, 0.3],
            [-0.4, 0.7, 0.2, 0.1],
            [0.6, 0.2, -0.3, 0.4],
            [0.1, 0.5, 0.4, -0.6],
        ],
        dtype=torch.float32,
    )
    expert_bias = torch.tensor([0.0, 0.4, -0.3, 0.2], dtype=torch.float32)
    eager = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=2,
        use_dynamic_bias=True,
        x=x,
        router_weight=router_weight,
        expert_bias=expert_bias,
    )
    export = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=True,
        top_k=2,
        use_dynamic_bias=True,
        x=x,
        router_weight=router_weight,
        expert_bias=expert_bias,
    )
    assert torch.equal(export["topk_idx"], eager["topk_idx"])
    torch.testing.assert_close(
        export["topk_scores"],
        eager["topk_scores"],
        atol=1.0e-6,
        rtol=1.0e-5,
    )


def test_grouped_moe_router_x_keeps_topk_when_main_input_changes(monkeypatch):
    """When router_x is given, changing only x leaves the expert choice fixed."""
    router_weight = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=torch.float32,
    )
    router_x = torch.tensor([[[4.0, 1.0, 0.0, 0.0]]], dtype=torch.float32)
    x_a = torch.tensor([[[0.0, 0.5, 1.0, 1.5]]], dtype=torch.float32)
    x_b = torch.tensor([[[3.0, -2.0, -1.0, 6.0]]], dtype=torch.float32)
    out_a = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=1,
        use_dynamic_bias=False,
        x=x_a,
        router_x=router_x,
        router_weight=router_weight,
    )
    out_b = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=1,
        use_dynamic_bias=False,
        x=x_b,
        router_x=router_x,
        router_weight=router_weight,
    )
    assert torch.equal(out_a["topk_idx"], out_b["topk_idx"])
    # NOTE(review): each _capture_moe_router_outputs call builds a fresh
    # GroupedMoEBlock and copies in only the router weight. The shared-expert
    # weights therefore presumably differ between out_a and out_b, so this
    # inequality may hold even if x were ignored. Confirm.
    assert not torch.allclose(out_a["output"], out_b["output"])


def test_grouped_moe_router_x_changes_topk_when_router_input_changes(
    monkeypatch,
):
    """With x fixed, changing router_x changes the selected expert."""
    router_weight = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=torch.float32,
    )
    x = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32)
    router_x_a = torch.tensor([[[3.0, 0.0, 0.0, 0.0]]], dtype=torch.float32)
    router_x_b = torch.tensor([[[0.0, 5.0, 0.0, 0.0]]], dtype=torch.float32)
    out_a = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=1,
        use_dynamic_bias=False,
        x=x,
        router_x=router_x_a,
        router_weight=router_weight,
    )
    out_b = _capture_moe_router_outputs(
        monkeypatch,
        export_mode=False,
        top_k=1,
        use_dynamic_bias=False,
        x=x,
        router_x=router_x_b,
        router_weight=router_weight,
    )
    assert not torch.equal(out_a["topk_idx"], out_b["topk_idx"])


def test_ref_router_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """Reference-router export keeps the 3 standard inputs and has TopK."""
    actor = _make_minimal_ref_router_actor()
    out_path = tmp_path / "policy_ref_router.onnx"
    PPOTFActor.export_onnx(
        actor,
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    input_names = [value.name for value in model.graph.input]
    op_types = [node.op_type for node in model.graph.node]
    assert input_names == ["obs", "past_key_values", "step_idx"]
    assert "TopK" in op_types


def test_ref_router_v2_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """V2 export: expected KV-cache shape, 3 standard inputs, TopK present."""
    actor = _make_minimal_ref_router_v2_actor()
    # Leading 3 presumably = 2 policy layers + 1 ref-history layer
    # (ref_hist_n_layers=1). Confirm against onnx_past_key_values_shape.
    assert actor.onnx_past_key_values_shape(batch_size=1) == (
        3,
        2,
        1,
        4,
        1,
        4,
    )
    out_path = tmp_path / "policy_ref_router_v2.onnx"
    actor.export_onnx(
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    input_names = [value.name for value in model.graph.input]
    op_types = [node.op_type for node in model.graph.node]
    assert input_names == ["obs", "past_key_values", "step_idx"]
    assert "TopK" in op_types


def test_ref_router_v3_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """V3 export: same KV-cache shape and input contract as V2."""
    actor = _make_minimal_ref_router_v3_actor()
    assert actor.onnx_past_key_values_shape(batch_size=1) == (
        3,
        2,
        1,
        4,
        1,
        4,
    )
    out_path = tmp_path / "policy_ref_router_v3.onnx"
    actor.export_onnx(
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    input_names = [value.name for value in model.graph.input]
    op_types = [node.op_type for node in model.graph.node]
    assert input_names == ["obs", "past_key_values", "step_idx"]
    assert "TopK" in op_types


def test_real_transformer_actor_export_supports_selected_expert_margin(
    tmp_path,
):
    """Export still succeeds with the selected-expert margin option enabled."""
    actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        selected_expert_margin_to_unselected_enabled=True,
        selected_expert_margin_to_unselected_target=0.4,
    )
    out_path = tmp_path / "policy_selected_expert_margin.onnx"
    actor.export_onnx(
        out_path,
        opset_version=17,
        use_kv_cache=True,
    )
    model = onnx.load(str(out_path))
    input_names = [value.name for value in model.graph.input]
    op_types = [node.op_type for node in model.graph.node]
    assert input_names == ["obs", "past_key_values", "step_idx"]
    assert "TopK" in op_types
================================================
FILE: tests/test_algo_base_iteration_logging.py
================================================
from pathlib import Path
import sys
from unittest import mock

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from holomotion.src.algo.algo_base import BaseOnpolicyRL


def test_log_iteration_uses_checkpoint_start_for_total_iterations():
    """Total iterations = resumed start (123) + requested iterations (10)."""
    # __init__ is bypassed; only the attributes _log_iteration is expected
    # to read are set.
    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)
    # NOTE(review): hard-coded /tmp path; if _log_iteration writes to
    # log_dir, pytest's tmp_path would be safer.
    algo.log_dir = "/tmp/holomotion-test"
    algo.gpu_world_size = 1
    algo.num_steps_per_env = 8
    algo.num_envs = 16
    algo.num_learning_iterations = 10
    algo.current_learning_iteration = 123
    algo.rewbuffer = []
    algo.lenbuffer = []
    algo._aggregate_episode_log_metrics = lambda: {}
    algo._get_additional_log_metrics = lambda: {}
    algo.algo_logger = mock.Mock()
    BaseOnpolicyRL._log_iteration(
        algo,
        it=123,
        loss_dict={"policy": 1.5},
        collection_time=2.0,
        learn_time=2.0,
    )
    algo.algo_logger.log_iteration.assert_called_once()
    _, kwargs = algo.algo_logger.log_iteration.call_args
    assert kwargs["step"] == 123
    assert kwargs["total_learning_iterations"] == 133
    assert kwargs["metrics"]["0-Train/iteration"] == 123
    assert kwargs["metrics"]["0-Train/iterations_total"] == 133
================================================
FILE: tests/test_build_quantization_dataset.py
================================================
import importlib.util
from pathlib import Path


def _load_module():
    """Import not_for_commit/build_quantization_dataset.py by file path."""
    module_path = (
        Path(__file__).resolve().parents[1]
        / "not_for_commit"
        / "build_quantization_dataset.py"
    )
    spec = importlib.util.spec_from_file_location(
        "build_quantization_dataset", module_path
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def test_allocate_sample_counts_normalizes_weights_and_matches_total():
    """Weights summing to 0.85 are normalized and the counts sum to 17."""
    module = _load_module()
    counts = module.allocate_sample_counts(
        {
            "AMASS": 0.2,
            "lafan1": 0.2,
            "MotionMillion-ft": 0.4,
            "pico_train": 0.05,
        },
        17,
    )
    assert counts == {
        "AMASS": 4,
        "lafan1": 4,
        "MotionMillion-ft": 8,
        "pico_train": 1,
    }
    assert sum(counts.values()) == 17


def test_build_quantization_dataset_creates_symlinks(tmp_path):
    """Selected clips appear as '<dataset>__...' .npz symlinks to real files."""
    module = _load_module()
    npz_root = tmp_path / "retargeted"
    npz_root.mkdir()
    # Fake clip files: 3 for AMASS, 2 for lafan1.
    for dataset_name, clip_count in {"AMASS": 3, "lafan1": 2}.items():
        dataset_dir = npz_root / dataset_name
        dataset_dir.mkdir()
        for clip_idx in range(clip_count):
            (dataset_dir / f"clip_{clip_idx}.npz").write_text(
                f"{dataset_name}-{clip_idx}", encoding="utf-8"
            )
    output_dir = module.build_quantization_dataset(
        npz_root=npz_root,
        dataset_ratios={"AMASS": 2.0, "lafan1": 1.0},
        num_clips=3,
        seed=0,
        current_date="20260324",
    )
    assert output_dir == npz_root / "20260324_quant_dataset"
    created_links = sorted(output_dir.iterdir())
    assert len(created_links) == 3
    assert all(link.is_symlink() for link in created_links)
    assert {link.name.split("__", 1)[0] for link in created_links} == {
        "AMASS",
        "lafan1",
    }
    for link in created_links:
        assert link.resolve().is_file()
        assert link.suffix == ".npz"


def test_build_quantization_dataset_caps_each_dataset_at_available_clips(
    tmp_path,
):
    """A dataset with too few clips is capped; the shortfall is not refilled.

    6 clips at a 2:1 ratio asks AMASS for more clips than its single file, so
    the expected result is 1 AMASS + 2 lafan1 = 3 links, not 6.
    """
    module = _load_module()
    npz_root = tmp_path / "retargeted"
    npz_root.mkdir()
    for dataset_name, clip_count in {"AMASS": 1, "lafan1": 5}.items():
        dataset_dir = npz_root / dataset_name
        dataset_dir.mkdir()
        for clip_idx in range(clip_count):
            (dataset_dir / f"clip_{clip_idx}.npz").write_text(
                f"{dataset_name}-{clip_idx}", encoding="utf-8"
            )
    output_dir = module.build_quantization_dataset(
        npz_root=npz_root,
        dataset_ratios={"AMASS": 2.0, "lafan1": 1.0},
        num_clips=6,
        seed=0,
        current_date="20260324",
    )
    created_links = sorted(output_dir.iterdir())
    assert len(created_links) == 3
    assert sum(link.name.startswith("AMASS__") for link in created_links) == 1
    assert sum(link.name.startswith("lafan1__") for link in created_links) == 2
================================================
FILE: tests/test_cache_curriculum_sampler.py
================================================
import json
from types import SimpleNamespace

import torch

import holomotion.src.training.h5_dataloader as h5_dataloader
from holomotion.src.algo.ppo import PPO
from holomotion.src.training.h5_dataloader import (
    MotionClipBatchCache,
    ClipBatch,
    PrioritizedInfiniteSampler,
)


def _update_sampler(
    sampler: PrioritizedInfiniteSampler,
    completion_rates: list[float],
    *,
    swap_index: int,
) -> bool:
    """Feed one observation round covering windows 0..len-1 to the sampler.

    MPKPE signals are all zero and every window has count 1, so only
    ``completion_rates`` vary. Returns whatever
    maybe_update_from_observations returns.
    """
    num_windows = len(completion_rates)
    window_indices = torch.arange(num_windows, dtype=torch.long)
    completion_rate_means = torch.tensor(completion_rates, dtype=torch.float32)
    mpkpe_signal_means = torch.zeros(num_windows, dtype=torch.float32)
    counts = torch.ones(num_windows, dtype=torch.float32)
    return sampler.maybe_update_from_observations(
        window_indices=window_indices,
        mpkpe_signal_means=mpkpe_signal_means,
        completion_rate_means=completion_rate_means,
        counts=counts,
        swap_index=swap_index,
    )


def _update_sampler_subset(
    sampler: PrioritizedInfiniteSampler,
    *,
    window_indices: list[int],
    completion_rates: list[float],
    swap_index: int,
    counts: list[float] | None = None,
) -> bool:
    """Like _update_sampler, but only for the given window indices.

    ``counts`` defaults to 1.0 per window; MPKPE signals are zero.
    """
    if counts is None:
        counts = [1.0] * len(window_indices)
    window_index_tensor = torch.tensor(window_indices, dtype=torch.long)
    completion_rate_tensor = torch.tensor(
        completion_rates, dtype=torch.float32
    )
    count_tensor = torch.tensor(counts, dtype=torch.float32)
    mpkpe_signal_means = torch.zeros(len(window_indices), dtype=torch.float32)
    return sampler.maybe_update_from_observations(
        window_indices=window_index_tensor,
        mpkpe_signal_means=mpkpe_signal_means,
        completion_rate_means=completion_rate_tensor,
        counts=count_tensor,
        swap_index=swap_index,
    )


class _ChunkLimitedSampler:
    """Fake sampler that records query sizes and rejects oversized queries.

    Any per-index query with more than ``max_query_size`` indices raises
    AssertionError. Presumably used by tests later in this file to check
    that callers access sampler state in chunks.
    """

    def __init__(self, *, max_query_size: int) -> None:
        self.max_query_size = int(max_query_size)
        self.state_version = 7
        self.query_sizes: list[int] = []

    def _checked_indices(self, window_indices: torch.Tensor) -> torch.Tensor:
        # Flatten to 1-D long indices, record the size, enforce the limit.
        indices = window_indices.to(dtype=torch.long).reshape(-1)
        size = int(indices.numel())
        self.query_sizes.append(size)
        if size > self.max_query_size:
            raise AssertionError(
                f"expected chunked access <= {self.max_query_size}, got {size}"
            )
        return indices

    def get_scores_for_indices(
        self, window_indices: torch.Tensor
    ) -> torch.Tensor:
        # Score equals the window index, so results are easy to check.
        indices = self._checked_indices(window_indices)
        return indices.to(dtype=torch.float32)

    def get_window_state_for_indices(
        self, window_indices: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        # selection_count is index + 10; every other field is zero/False.
        indices = self._checked_indices(window_indices)
        count = int(indices.numel())
        return {
            "ema_completion_rate": torch.zeros(count, dtype=torch.float32),
            "completion_rate_rel_improve": torch.zeros(
                count, dtype=torch.float32
            ),
            "selection_count": indices + 10,
            "seen": torch.zeros(count, dtype=torch.bool),
            "in_prioritized_pool": torch.zeros(count, dtype=torch.bool),
        }

    def get_pool_statistics(self) -> dict[str, float]:
        return {
            "prioritized_pool_size": 0.0,
            "prioritized_pool_mean_score": 0.0,
            "uniform_pool_mean_score": 0.0,
            "entered_prioritized_pool_count": 0.0,
            "exited_prioritized_pool_count": 0.0,
        }


class _FakeTrainDataset:
    """Minimal dataset stub exposing ``windows``, ``len()`` and ``close()``."""

    def __init__(self, windows: list[SimpleNamespace]) -> None:
        self.windows = windows

    def __len__(self) -> int:
        return len(self.windows)

    def close(self) -> None:
        return


def test_sampler_uses_configured_uniform_ratio_immediately():
    """p_a_ratio=0.3 with batch 10 splits the batch into (3, 7) from the start."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=8,
        batch_size=10,
        seed=0,
        p_a_ratio=0.3,
    )
    assert sampler._pool_batch_sizes() == (3, 7)


def test_completion_rate_relative_improvement_alone_drives_scores():
    """Only the window whose completion rate improved gets a positive score."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=3,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=0.5,
        ema_alpha_rel_improve=1.0,
    )
    assert _update_sampler(sampler, [0.2, 0.2, 0.2], swap_index=1)
    assert _update_sampler(sampler, [0.8, 0.2, 0.2], swap_index=2)
    scores = sampler.get_scores_for_indices(torch.arange(3, dtype=torch.long))
    assert scores[0].item() > 0.0
    assert scores[1].item() == 0.0
    assert scores[2].item() == 0.0


def test_sampler_weights_progress_by_remaining_difficulty():
    """Equal +0.1 gains score higher for the window with lower completion."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=2,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    assert _update_sampler(sampler, [0.1, 0.8], swap_index=1)
    assert _update_sampler(sampler, [0.2, 0.9], swap_index=2)
    scores = sampler.get_scores_for_indices(torch.arange(2, dtype=torch.long))
    assert scores[0].item() > scores[1].item()


def test_sampler_tracks_cumulative_selection_counts():
    """selection_count equals the histogram of indices yielded by iteration."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=4,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
    )
    iterator = iter(sampler)
    selected_indices = [next(iterator) for _ in range(4)]
    state = sampler.get_window_state_for_indices(torch.arange(4))
    selection_count = state["selection_count"]
    expected_count = torch.bincount(
        torch.tensor(selected_indices, dtype=torch.long), minlength=4
    )
    assert int(selection_count.sum().item()) == 4
    assert torch.equal(selection_count, expected_count)


def test_low_completion_plateau_drops_from_prioritized_replay():
    """A window that stops improving leaves the prioritized pool.

    Window 0 improves once (0.1 -> 0.4) and then plateaus while the others
    improve. It is expected to fall out of the prioritized pool and remain
    reachable by uniform sampling.
    """
    sampler = PrioritizedInfiniteSampler(
        dataset_len=3,
        batch_size=3,
        seed=0,
        p_a_ratio=1.0 / 3.0,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    assert _update_sampler(sampler, [0.1, 0.2, 0.2], swap_index=1)
    assert _update_sampler(sampler, [0.4, 0.2, 0.2], swap_index=2)
    assert _update_sampler(sampler, [0.4, 0.8, 0.8], swap_index=3)
    state = sampler.get_window_state_for_indices(
        torch.arange(3, dtype=torch.long)
    )
    assert not bool(state["in_prioritized_pool"][0].item())
    generator = torch.Generator().manual_seed(0)
    uniform_pick = sampler._sample_uniform_indices(generator, 3)
    assert 0 in uniform_pick.tolist()


def test_prioritized_windows_persist_beyond_immediate_batch():
    """Windows 0/1 stay prioritized after a later update touches only 2/3."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=6,
        batch_size=4,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    assert _update_sampler_subset(
        sampler,
        window_indices=[0, 1],
        completion_rates=[0.2, 0.2],
        swap_index=1,
    )
    assert _update_sampler_subset(
        sampler,
        window_indices=[0, 1],
        completion_rates=[0.8, 0.7],
        swap_index=2,
    )
    assert _update_sampler_subset(
        sampler,
        window_indices=[2, 3],
        completion_rates=[0.3, 0.3],
        swap_index=3,
    )
    state = sampler.get_window_state_for_indices(torch.tensor([0, 1]))
    assert torch.equal(
        state["in_prioritized_pool"],
        torch.tensor([True, True], dtype=torch.bool),
    )


def test_sampler_reports_pool_means_and_membership_churn(): sampler = PrioritizedInfiniteSampler( dataset_len=4, batch_size=4, seed=0, p_a_ratio=0.5, ema_alpha_signal=0.5, ema_alpha_rel_improve=1.0, ) assert _update_sampler(sampler, [0.2, 0.2, 0.2, 0.2], swap_index=1) assert _update_sampler(sampler, [0.9, 0.8, 0.2, 0.2], swap_index=2) assert _update_sampler(sampler, [0.1, 0.2, 0.9, 0.8], swap_index=3) next(iter(sampler)) stats = sampler.get_pool_statistics() assert stats is not None assert set(stats) == { "prioritized_pool_size",
"prioritized_pool_mean_score", "uniform_pool_mean_score", "entered_prioritized_pool_count", "exited_prioritized_pool_count", } assert stats["prioritized_pool_size"] == 2.0 assert stats["entered_prioritized_pool_count"] == 2.0 assert stats["exited_prioritized_pool_count"] == 2.0 assert ( stats["prioritized_pool_mean_score"] > stats["uniform_pool_mean_score"] ) def test_sampler_hot_path_avoids_full_dataset_temporaries(monkeypatch): sampler = PrioritizedInfiniteSampler( dataset_len=1_000_000, batch_size=8, seed=0, p_a_ratio=0.5, ema_alpha_signal=1.0, ema_alpha_rel_improve=1.0, ) orig_zeros = h5_dataloader.torch.zeros orig_arange = h5_dataloader.torch.arange orig_randperm = h5_dataloader.torch.randperm def _guard_size(arg) -> int | None: if isinstance(arg, int): return arg if ( isinstance(arg, tuple) and len(arg) == 1 and isinstance(arg[0], int) ): return arg[0] return None def guarded_zeros(*args, **kwargs): size = _guard_size(args[0]) if args else None if size == sampler._ds_len: raise AssertionError("full-dataset zeros in hot path") return orig_zeros(*args, **kwargs) def guarded_arange(*args, **kwargs): if args and args[0] == sampler._ds_len: raise AssertionError("full-dataset arange in hot path") return orig_arange(*args, **kwargs) def guarded_randperm(*args, **kwargs): if args and args[0] == sampler._ds_len: raise AssertionError("full-dataset randperm in hot path") return orig_randperm(*args, **kwargs) monkeypatch.setattr(h5_dataloader.torch, "zeros", guarded_zeros) monkeypatch.setattr(h5_dataloader.torch, "arange", guarded_arange) monkeypatch.setattr(h5_dataloader.torch, "randperm", guarded_randperm) assert _update_sampler_subset( sampler, window_indices=[5, 25, 125, 625], completion_rates=[0.1, 0.2, 0.3, 0.4], swap_index=1, ) batch_indices = sampler._sample_batch_indices( torch.Generator().manual_seed(0) ) assert int(batch_indices.numel()) == 8 def test_ppo_logs_only_core_curriculum_metrics(): algo = PPO.__new__(PPO) algo.actor_learning_rate = 1.0e-4 
algo.critic_learning_rate = 2.0e-4 algo._last_update_metrics = {} algo.command_name = "ref_motion" algo._get_mean_policy_std = lambda: torch.tensor(0.0) cache = SimpleNamespace( swap_index=12, cache_curriculum_pool_statistics=lambda: { "prioritized_pool_size": 2.0, "prioritized_pool_mean_score": 0.8, "uniform_pool_mean_score": 0.1, "entered_prioritized_pool_count": 1.0, "exited_prioritized_pool_count": 1.0, }, ) motion_cmd = SimpleNamespace(_motion_cache=cache) algo.env = SimpleNamespace( _env=SimpleNamespace( command_manager=SimpleNamespace( get_term=lambda name: motion_cmd, ) ) ) metrics = algo._get_additional_log_metrics() assert metrics["1-Perf/Cache/swap_index"] == 12.0 assert metrics["1-Perf/Cache/prioritized_pool_size"] == 2.0 assert metrics["1-Perf/Cache/prioritized_pool_mean_score"] == 0.8 assert metrics["1-Perf/Cache/uniform_pool_mean_score"] == 0.1 assert metrics["1-Perf/Cache/entered_prioritized_pool_count"] == 1.0 assert metrics["1-Perf/Cache/exited_prioritized_pool_count"] == 1.0 assert ( "1-Perf/Cache/curriculum_probability_coefficient_of_variation" not in metrics ) assert ( "1-Perf/Cache/curriculum_max_probability_over_uniform" not in metrics ) assert "1-Perf/Cache/uniform_floor_ratio" not in metrics def test_cache_curriculum_dumps_on_scheduled_swap_even_without_state_update(): cache = MotionClipBatchCache.__new__(MotionClipBatchCache) cache._datasets = {} cache._cache_curriculum_sampler = SimpleNamespace( maybe_update_from_observations=lambda **kwargs: False, ) dumped_swaps = [] cache._maybe_dump_cache_curriculum_scores_json = ( lambda *, swap_index: dumped_swaps.append(swap_index) ) updated = cache.update_cache_curriculum( window_indices=torch.tensor([0], dtype=torch.long), mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32), completion_rate_means=torch.tensor([0.0], dtype=torch.float32), counts=torch.tensor([1.0], dtype=torch.float32), swap_index=5, ) assert updated is False assert dumped_swaps == [5] def 
test_update_cache_curriculum_refreshes_prefetched_batch_when_state_changes(): cache = MotionClipBatchCache.__new__(MotionClipBatchCache) cache._datasets = {} cache._cache_curriculum_sampler = SimpleNamespace( maybe_update_from_observations=lambda **kwargs: True, ) cache._cache_curriculum_dump_enabled = False cache._next_batch = ClipBatch( tensors={}, lengths=torch.tensor([1], dtype=torch.long), motion_keys=["stale"], raw_motion_keys=["stale"], window_indices=torch.tensor([0], dtype=torch.long), max_frame_length=1, ) refreshed_batch = ClipBatch( tensors={}, lengths=torch.tensor([1], dtype=torch.long), motion_keys=["fresh"], raw_motion_keys=["fresh"], window_indices=torch.tensor([1], dtype=torch.long), max_frame_length=1, ) cache._fetch_next_batch = lambda: refreshed_batch updated = cache.update_cache_curriculum( window_indices=torch.tensor([0], dtype=torch.long), mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32), completion_rate_means=torch.tensor([0.0], dtype=torch.float32), counts=torch.tensor([1.0], dtype=torch.float32), swap_index=5, ) assert updated is True assert cache._next_batch.motion_keys == ["fresh"] def test_cache_curriculum_whole_window_dump_streams_rows_in_chunks( tmp_path, ): cache = MotionClipBatchCache.__new__(MotionClipBatchCache) cache._datasets = { "train": _FakeTrainDataset( [ SimpleNamespace( raw_motion_key=f"raw_{idx}", motion_key=f"motion_{idx}", start=idx, length=idx + 1, ) for idx in range(5) ] ) } sampler = _ChunkLimitedSampler(max_query_size=2) cache._cache_curriculum_sampler = sampler cache._cache_curriculum_dump_enabled = True cache._cache_curriculum_dump_every_swaps = 1 cache._cache_curriculum_dump_chunk_size = 2 cache._cache_curriculum_last_dump_swap = -1 cache._cache_curriculum_dump_dir = tmp_path cache._sampler_rank = 3 cache._maybe_dump_cache_curriculum_scores_json(swap_index=1) output_path = tmp_path / "whole_window_scores_rank_0003_swap_000001.json" payload = json.loads(output_path.read_text()) assert 
output_path.exists() assert payload["swap_index"] == 1 assert payload["rank"] == 3 assert payload["sampler_state_version"] == 7 assert payload["num_windows"] == 5 assert len(payload["rows"]) == 5 assert payload["rows"][0]["window_index"] == 0 assert payload["rows"][0]["selection_count"] == 10 assert payload["rows"][-1]["window_index"] == 4 assert payload["rows"][-1]["selection_count"] == 14 assert "probability" not in payload["rows"][0] assert max(sampler.query_sizes) == 2 ================================================ FILE: tests/test_domain_rand_config_builder.py ================================================ import importlib.util import sys from pathlib import Path from types import ModuleType MODULE_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_domain_rand.py" ) class _DummyEventTermCfg: def __init__(self, **kwargs): self.__dict__.update(kwargs) class _DummySceneEntityCfg: def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs def resolve(self, _scene): return None def _load_domain_rand_module(monkeypatch): isaaclab = ModuleType("isaaclab") isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls isaaclab_utils_math = ModuleType("isaaclab.utils.math") isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedEnv = object isaaclab_envs_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_envs_mdp.events = ModuleType("isaaclab.envs.mdp.events") isaaclab_envs_mdp.events._randomize_prop_by_op = ( lambda *args, **kwargs: None ) isaaclab_managers = ModuleType("isaaclab.managers") isaaclab_managers.SceneEntityCfg = _DummySceneEntityCfg isaaclab_managers.EventTermCfg = _DummyEventTermCfg isaaclab.envs = isaaclab_envs isaaclab.assets = isaaclab_assets isaaclab.utils = isaaclab_utils isaaclab_envs.mdp = isaaclab_envs_mdp isaaclab_utils.math = 
isaaclab_utils_math for name, module in { "isaaclab": isaaclab, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.math": isaaclab_utils_math, "isaaclab.assets": isaaclab_assets, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_envs_mdp, "isaaclab.envs.mdp.events": isaaclab_envs_mdp.events, "isaaclab.managers": isaaclab_managers, }.items(): monkeypatch.setitem(sys.modules, name, module) module_name = "_test_domain_rand_builder" spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def test_build_domain_rand_config_skips_non_event_metadata(monkeypatch): module = _load_domain_rand_module(monkeypatch) events_cfg = module.build_domain_rand_config( { "erfi": { "enabled": True, "rfi_probability": 0.5, "rfi_lim": 0.1, "randomize_rfi_lim": True, "rfi_lim_range": [0.5, 1.5], "rao_lim": 0.1, }, "action_delay": { "enabled": True, "min_delay": 1, "max_delay": 3, }, "motion_init_perturb": { "root_pose_perturb_range": {"x": [-0.1, 0.1]} }, "obs_noise": {"actor_dof_pos": {"n_min": -0.01, "n_max": 0.01}}, "default_dof_pos_bias": { "mode": "startup", "params": { "joint_names": [".*"], "pos_distribution_params": [-0.01, 0.01], "operation": "add", "distribution": "uniform", }, }, } ) assert hasattr(events_cfg, "default_dof_pos_bias") assert events_cfg.default_dof_pos_bias.mode == "startup" assert not hasattr(events_cfg, "erfi") assert not hasattr(events_cfg, "action_delay") assert not hasattr(events_cfg, "motion_init_perturb") assert not hasattr(events_cfg, "obs_noise") ================================================ FILE: tests/test_eval_mujoco_action_delay.py ================================================ from collections import deque import numpy as np from omegaconf import OmegaConf import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim def 
test_action_delay_cfg_defaults_to_disabled_episode(): evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__( eval_mujoco_sim2sim.MujocoEvaluator ) evaluator.config = OmegaConf.create({}) max_delay_step, delay_type = evaluator._get_action_delay_cfg() assert max_delay_step == 0 assert delay_type == "episode"


def test_action_delay_cfg_rejects_invalid_delay_type():
    """An unknown ``action_delay_type`` raises ValueError naming the key."""
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.config = OmegaConf.create(
        {
            "policy_action_delay_step": 2,
            "action_delay_type": "frame",
        }
    )
    try:
        evaluator._get_action_delay_cfg()
    except ValueError as exc:
        assert "action_delay_type" in str(exc)
    else:
        raise AssertionError("Expected ValueError for invalid delay type.")


def test_apply_action_delay_passthrough_when_disabled():
    """With a zero delay the action is returned unchanged."""
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.policy_action_delay_step = 0
    evaluator.action_delay_type = "episode"
    evaluator._policy_action_delay_buffer = deque(maxlen=1)
    evaluator._current_policy_action_delay_step = 0
    delayed = evaluator._apply_action_delay(
        np.array([1.0, -1.0], dtype=np.float32)
    )
    np.testing.assert_allclose(
        delayed, np.array([1.0, -1.0], dtype=np.float32)
    )


def test_apply_action_delay_episode_reuses_single_sample(monkeypatch):
    """Episode mode samples the delay once and applies it on every step.

    ``np.random.randint`` is patched to return 1. It must be called exactly
    once, with bounds (0, 2) for a max delay of 1. The outputs then lag their
    inputs by one step, and the first step yields the first action.
    """
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.policy_action_delay_step = 1
    evaluator.action_delay_type = "episode"
    calls = []

    def fake_randint(low, high):
        calls.append((low, high))
        return 1

    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, "randint", fake_randint)
    evaluator._reset_action_delay_randomization()
    first = evaluator._apply_action_delay(np.array([1.0], dtype=np.float32))
    second = evaluator._apply_action_delay(np.array([2.0], dtype=np.float32))
    third = evaluator._apply_action_delay(np.array([3.0], dtype=np.float32))
    assert calls == [(0, 2)]
    assert evaluator._current_policy_action_delay_step == 1
    np.testing.assert_allclose(first, np.array([1.0], dtype=np.float32))
    np.testing.assert_allclose(second, np.array([1.0], dtype=np.float32))
    np.testing.assert_allclose(third, np.array([2.0], dtype=np.float32))


def test_apply_action_delay_step_resamples_each_policy_step(monkeypatch):
    """Step mode draws a fresh delay (here 2, 0, 1) on every policy step."""
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.policy_action_delay_step = 2
    evaluator.action_delay_type = "step"
    evaluator._policy_action_delay_buffer = deque(maxlen=3)
    evaluator._current_policy_action_delay_step = 0
    sampled_delays = iter([2, 0, 1])
    calls = []

    def fake_randint(low, high):
        calls.append((low, high))
        return next(sampled_delays)

    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, "randint", fake_randint)
    first = evaluator._apply_action_delay(np.array([1.0], dtype=np.float32))
    second = evaluator._apply_action_delay(np.array([2.0], dtype=np.float32))
    third = evaluator._apply_action_delay(np.array([3.0], dtype=np.float32))
    # One draw per step, each with bounds (0, max_delay + 1).
    assert calls == [(0, 3), (0, 3), (0, 3)]
    assert evaluator._current_policy_action_delay_step == 1
    np.testing.assert_allclose(first, np.array([1.0], dtype=np.float32))
    np.testing.assert_allclose(second, np.array([2.0], dtype=np.float32))
    np.testing.assert_allclose(third, np.array([2.0], dtype=np.float32))


================================================
FILE: tests/test_eval_mujoco_action_ema.py
================================================
import numpy as np
from omegaconf import OmegaConf

import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim


def test_action_ema_filter_cfg_reads_erfi_settings():
    """For ``unitree_erfi`` actuators the EMA enable flag and alpha are read."""
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.config = OmegaConf.create(
        {
            "robot": {
                "actuators": {
                    "actuator_type": "unitree_erfi",
                    "ema_filter_enabled": True,
                    "ema_filter_alpha": 0.37,
                }
            },
        }
    )
    enabled, alpha = evaluator._get_action_ema_filter_cfg()
    assert enabled is True
    assert alpha == 0.37


def
test_action_ema_filter_defaults_to_disabled_for_non_erfi(): evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__( eval_mujoco_sim2sim.MujocoEvaluator ) evaluator.config = OmegaConf.create( { "robot": { "actuators": { "actuator_type": "unitree", "ema_filter_enabled": True, "ema_filter_alpha": 0.37, } }, } ) enabled, alpha = evaluator._get_action_ema_filter_cfg() assert enabled is False assert alpha == 1.0


def test_apply_action_ema_filter_uses_previous_filtered_action():
    """The first action passes through; later ones blend with the previous.

    With alpha 0.25 the expected second output [1.5, -0.5] equals
    0.25 * [3, 1] + 0.75 * [1, -1].
    """
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.action_ema_filter_enabled = True
    evaluator.action_ema_filter_alpha = 0.25
    evaluator._filtered_actions_onnx = None
    first = evaluator._apply_action_ema_filter(
        np.array([1.0, -1.0], dtype=np.float32)
    )
    second = evaluator._apply_action_ema_filter(
        np.array([3.0, 1.0], dtype=np.float32)
    )
    np.testing.assert_allclose(first, np.array([1.0, -1.0], dtype=np.float32))
    np.testing.assert_allclose(second, np.array([1.5, -0.5], dtype=np.float32))


def test_reset_action_ema_filter_clears_state():
    """Resetting the filter drops the stored filtered action."""
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator._filtered_actions_onnx = np.array([1.0], dtype=np.float32)
    evaluator._reset_action_ema_filter()
    assert evaluator._filtered_actions_onnx is None


================================================
FILE: tests/test_eval_mujoco_contact_export.py
================================================
import json
from pathlib import Path

import numpy as np
from omegaconf import OmegaConf

import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim


def _build_export_evaluator(tmp_path: Path):
    """Build a bare evaluator pre-filled with recorded traces for export.

    Policy-rate sequences hold 2 entries; low-level contact/torque
    sequences hold 4. ``simulation_dt`` is 0.005, which the exports are
    expected to report as ``robot_low_level_contact_dt``. A ``motion.npz``
    with ``clip_length`` metadata is written under ``tmp_path``.

    Args:
        tmp_path: Directory for the motion npz and export outputs.

    Returns:
        The configured ``MujocoEvaluator`` (``__init__`` is bypassed).
    """
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.simulation_dt = 0.005
    evaluator._get_stacked_moe_routing_tensors = lambda: (None, None)
    evaluator._robot_dof_pos_seq = [
        np.array([0.0, 1.0], dtype=np.float32),
        np.array([0.5, 1.5], dtype=np.float32),
    ]
    evaluator._robot_dof_vel_seq = [
        np.array([0.1, 0.2], dtype=np.float32),
        np.array([0.3, 0.4], dtype=np.float32),
    ]
    evaluator._robot_dof_acc_seq = [
        np.array([1.0, 2.0], dtype=np.float32),
        np.array([3.0, 4.0], dtype=np.float32),
    ]
    evaluator._robot_dof_torque_seq = [
        np.array([5.0, 6.0], dtype=np.float32),
        np.array([7.0, 8.0], dtype=np.float32),
    ]
    # Low-level traces: four samples each.
    evaluator._robot_low_level_dof_torque_seq = [
        np.array([1.0, 2.0], dtype=np.float32),
        np.array([3.0, 4.0], dtype=np.float32),
        np.array([5.0, 6.0], dtype=np.float32),
        np.array([7.0, 8.0], dtype=np.float32),
    ]
    evaluator._robot_low_level_foot_contact_seq = [
        np.array([1.0, 0.0], dtype=np.float32),
        np.array([1.0, 1.0], dtype=np.float32),
        np.array([0.0, 1.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    ]
    evaluator._robot_low_level_foot_normal_force_seq = [
        np.array([50.0, 0.0], dtype=np.float32),
        np.array([60.0, 55.0], dtype=np.float32),
        np.array([0.0, 45.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    ]
    evaluator._robot_low_level_foot_tangent_speed_seq = [
        np.array([0.02, 0.0], dtype=np.float32),
        np.array([0.03, 0.04], dtype=np.float32),
        np.array([0.0, 0.05], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    ]
    evaluator._robot_action_rate_seq = [
        np.float32(0.0),
        np.float32(1.0),
    ]
    evaluator._robot_actions_seq = [
        np.array([0.11, 0.22], dtype=np.float32),
        np.array([0.33, 0.44], dtype=np.float32),
    ]
    evaluator._robot_global_translation_seq = [
        np.zeros((2, 3), dtype=np.float32),
        np.ones((2, 3), dtype=np.float32),
    ]
    evaluator._robot_global_rotation_quat_seq = [
        np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 1)),
        np.tile(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 1)),
    ]
    evaluator._robot_global_velocity_seq = [
        np.zeros((2, 3), dtype=np.float32),
        np.ones((2, 3), dtype=np.float32),
    ]
    evaluator._robot_global_angular_velocity_seq = [
        np.zeros((2, 3), dtype=np.float32),
        np.ones((2, 3), dtype=np.float32),
    ]
    evaluator.ref_dof_pos = np.zeros((2, 2), dtype=np.float32)
    evaluator.ref_dof_vel = np.zeros((2, 2), dtype=np.float32)
    evaluator.ref_global_translation = np.zeros((2, 2, 3), dtype=np.float32)
    evaluator.ref_global_rotation_quat_xyzw = np.tile(
        np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 2, 1)
    )
    evaluator.ref_global_velocity = np.zeros((2, 2, 3), dtype=np.float32)
    evaluator.ref_global_angular_velocity = np.zeros(
        (2, 2, 3), dtype=np.float32
    )
    motion_npz_path = tmp_path / "motion.npz"
    np.savez_compressed(
        motion_npz_path,
        metadata=np.array(json.dumps({"clip_length": 2}), dtype=np.str_),
    )
    evaluator.config = OmegaConf.create(
        {
            "motion_npz_path": str(motion_npz_path),
            "ckpt_onnx_path": str(tmp_path / "model.onnx"),
        }
    )
    return evaluator


def test_save_batch_result_exports_low_level_contact_traces(tmp_path: Path):
    """``save_batch_result`` writes actions and all low-level contact traces."""
    evaluator = _build_export_evaluator(tmp_path)
    output_path = tmp_path / "batch_result.npz"
    evaluator.save_batch_result(str(output_path), {"clip_length": 2})
    with np.load(output_path, allow_pickle=True) as data:
        assert "robot_actions" in data.files
        assert "robot_low_level_foot_contact" in data.files
        assert "robot_low_level_foot_normal_force" in data.files
        assert "robot_low_level_foot_tangent_speed" in data.files
        assert "robot_low_level_contact_dt" in data.files
        np.testing.assert_allclose(
            data["robot_actions"],
            np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32),
        )
        # Four low-level samples x two feet.
        assert data["robot_low_level_foot_contact"].shape == (4, 2)
        np.testing.assert_allclose(
            data["robot_low_level_contact_dt"], np.array(0.005, np.float32)
        )


def test_dump_robot_augmented_npz_exports_low_level_contact_traces(
    tmp_path: Path,
):
    """The augmented ``<motion>_robot.npz`` dump carries the same traces."""
    evaluator = _build_export_evaluator(tmp_path)
    evaluator._dump_robot_augmented_npz()
    output_path = (
        tmp_path
        / "mujoco_output_model"
        / f"{Path(evaluator.config.motion_npz_path).stem}_robot.npz"
    )
    with np.load(output_path, allow_pickle=True) as data:
        assert "robot_actions" in data.files
        assert "robot_low_level_foot_contact" in data.files
        assert "robot_low_level_foot_normal_force" in data.files
        assert "robot_low_level_foot_tangent_speed" in data.files
        assert "robot_low_level_contact_dt" in data.files
        np.testing.assert_allclose(
            data["robot_actions"],
            np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32),
        )
        assert data["robot_low_level_foot_normal_force"].shape == (4, 2)


def test_init_low_level_foot_contact_logging_falls_back_to_ankle_roll_bodies(
    monkeypatch,
):
    """Without named foot geoms, feet are found via the ankle-roll bodies.

    In the fake model, geoms 1 and 2 belong to body 6 (left ankle roll) and
    geom 4 belongs to body 10 (right ankle roll); all three have non-zero
    contype/conaffinity. Geoms 0 and 3 sit on other bodies with zero
    collision flags and must be ignored.
    """
    evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim.MujocoEvaluator
    )
    evaluator.config = OmegaConf.create({"robot": {}})
    evaluator.m = type(
        "FakeModel",
        (),
        {
            "geom_bodyid": np.array([5, 6, 6, 9, 10], dtype=np.int32),
            "geom_contype": np.array([0, 1, 1, 0, 1], dtype=np.int32),
            "geom_conaffinity": np.array([0, 1, 1, 0, 1], dtype=np.int32),
        },
    )()

    def fake_name2id(model, obj_type, name):
        # No geom names resolve, forcing the body-name fallback.
        if obj_type == eval_mujoco_sim2sim.mujoco.mjtObj.mjOBJ_GEOM:
            return -1
        if obj_type == eval_mujoco_sim2sim.mujoco.mjtObj.mjOBJ_BODY:
            return {
                "left_ankle_roll_link": 6,
                "right_ankle_roll_link": 10,
            }.get(name, -1)
        return -1

    monkeypatch.setattr(eval_mujoco_sim2sim.mujoco, "mj_name2id", fake_name2id)
    evaluator._init_low_level_foot_contact_logging()
    assert evaluator._foot_contact_logging_enabled is True
    assert evaluator._foot_geom_id_groups == [[1, 2], [4]]
    assert evaluator._foot_geom_id_to_side == {1: 0, 2: 0, 4: 1}


================================================
FILE: tests/test_eval_mujoco_s100_horizon_ptq.py
================================================
import sys
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import pytest
from omegaconf import OmegaConf

# Make the repository root (parent of tests/) importable.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import holomotion.src.evaluation.eval_mujoco_sim2sim_s100 as eval_mujoco_sim2sim_s100


class _FakeIoNode:
    """Stand-in for a runtime session input/output descriptor."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


# Fake ONNX value info: name plus type.tensor_type.shape.dim[i].dim_value.
def _make_value_info(name, shape): dims = [SimpleNamespace(dim_value=dim) for dim in shape] tensor_shape = SimpleNamespace(dim=dims) tensor_type =
SimpleNamespace(shape=tensor_shape) return SimpleNamespace( name=name, type=SimpleNamespace(tensor_type=tensor_type) )


def _make_fake_onnx_model():
    """Return a fake ONNX model whose graph declares a KV-cache policy's I/O.

    Inputs are obs, past_key_values and step_idx; outputs are action and
    present_key_values. Tests expect ``policy_model_context_len == 4`` to be
    derived from these shapes.
    """
    return SimpleNamespace(
        graph=SimpleNamespace(
            input=[
                _make_value_info("obs", [1, 16]),
                _make_value_info("past_key_values", [1, 2, 3, 4]),
                _make_value_info("step_idx", [1]),
            ],
            output=[
                _make_value_info("action", [1, 12]),
                _make_value_info("present_key_values", [1, 2, 3, 4]),
            ],
        )
    )


def _make_evaluator(model_path: Path, bc_path: Path | None = None):
    """Build a bare CPU evaluator pointed at ``model_path``.

    Args:
        model_path: ONNX checkpoint path stored as ``ckpt_onnx_path``.
        bc_path: Optional explicit Horizon runtime artifact (``bc_path``).

    Returns:
        A ``MujocoEvaluator`` with ``__init__`` bypassed and MoE output
        discovery stubbed out.
    """
    config_dict = {
        "ckpt_onnx_path": str(model_path),
        "use_gpu": False,
        "gpu_id": 0,
    }
    if bc_path is not None:
        config_dict["bc_path"] = str(bc_path)
    evaluator = eval_mujoco_sim2sim_s100.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim_s100.MujocoEvaluator
    )
    evaluator.config = OmegaConf.create(config_dict)
    evaluator.max_context_len = 0
    evaluator._discover_policy_moe_outputs = lambda: None
    return evaluator


def test_load_policy_falls_back_to_horizon_quantized_bc_for_ptq_onnx(
    monkeypatch, tmp_path
):
    """A HzCalibration load failure on a *_ptq_model.onnx uses the .bc sibling.

    onnxruntime's InferenceSession is patched to raise the custom-op error,
    and onnx.load is patched to supply graph metadata. load_policy must open
    ``demo_quantized_model.bc`` through HBRuntime and populate the I/O names
    and context length.
    """
    model_path = tmp_path / "demo_ptq_model.onnx"
    model_path.write_bytes(b"onnx")
    quantized_path = tmp_path / "demo_quantized_model.bc"
    quantized_path.write_bytes(b"bc")
    captured = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            captured["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _raise_hz_calibration(*args, **kwargs):
        raise RuntimeError("Failed to load custom op HzCalibration")

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "get_available_providers",
        lambda: ["CPUExecutionProvider"],
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "InferenceSession",
        _raise_hz_calibration,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100,
        "HBRuntime",
        _FakeHBRuntime,
        raising=False,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx,
        "load",
        lambda _: _make_fake_onnx_model(),
    )
    evaluator = _make_evaluator(model_path)
    evaluator.load_policy()
    assert captured["hb_model_path"] == str(quantized_path)
    assert evaluator.policy_input_name == "obs"
    assert evaluator.policy_kv_input_name == "past_key_values"
    assert evaluator.policy_step_input_name == "step_idx"
    assert evaluator.policy_output_name == "action"
    assert evaluator.policy_kv_output_name == "present_key_values"
    assert evaluator.policy_model_context_len == 4


@pytest.mark.parametrize(
    "runtime_name",
    [
        "demo_model_16000_ptq_model.bc",
        "demo_model_16000_ptq_model.hbm",
        "demo_model_16000_quantized_model.hbm",
    ],
)
def test_load_policy_resolves_common_horizon_runtime_artifact_names(
    monkeypatch, tmp_path, runtime_name
):
    """The PTQ fallback finds each supported runtime artifact naming scheme."""
    model_path = tmp_path / "demo_model_16000_ptq_model.onnx"
    model_path.write_bytes(b"onnx")
    runtime_path = tmp_path / runtime_name
    runtime_path.write_bytes(b"runtime")
    captured = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            captured["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _raise_hz_calibration(*args, **kwargs):
        raise RuntimeError("Failed to load custom op HzCalibration")

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "get_available_providers",
        lambda: ["CPUExecutionProvider"],
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "InferenceSession",
        _raise_hz_calibration,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100,
        "HBRuntime",
        _FakeHBRuntime,
        raising=False,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx,
        "load",
        lambda _: _make_fake_onnx_model(),
    )
    evaluator = _make_evaluator(model_path)
    evaluator.load_policy()
    assert captured["hb_model_path"] == str(runtime_path)
    assert evaluator.policy_model_context_len == 4


def test_load_policy_raises_original_error_when_ptq_fallback_bc_missing(
    monkeypatch, tmp_path
):
    """With no runtime artifact present, the original ORT error propagates."""
    model_path = tmp_path / "demo_ptq_model.onnx"
    model_path.write_bytes(b"onnx")

    def _raise_hz_calibration(*args, **kwargs):
        raise RuntimeError("Failed to load custom op HzCalibration")

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "get_available_providers",
        lambda: ["CPUExecutionProvider"],
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "InferenceSession",
        _raise_hz_calibration,
    )
    evaluator = _make_evaluator(model_path)
    with pytest.raises(RuntimeError, match="HzCalibration"):
        evaluator.load_policy()


def test_load_policy_keeps_standard_onnxruntime_path_for_regular_onnx(
    monkeypatch, tmp_path
):
    """A regular ONNX file is loaded through onnxruntime on CPU."""
    model_path = tmp_path / "demo_model.onnx"
    model_path.write_bytes(b"onnx")
    captured = {}

    class _FakeInferenceSession:
        def __init__(self, model_path, sess_options, providers):
            captured["model_path"] = model_path
            captured["providers"] = providers

        def get_providers(self):
            return ["CPUExecutionProvider"]

        def get_inputs(self):
            return [
                _FakeIoNode("obs", [1, 16]),
                _FakeIoNode("past_key_values", [1, 2, 3, 4]),
                _FakeIoNode("step_idx", [1]),
            ]

        def get_outputs(self):
            return [
                _FakeIoNode("action", [1, 12]),
                _FakeIoNode("present_key_values", [1, 2, 3, 4]),
            ]

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "get_available_providers",
        lambda: ["CPUExecutionProvider"],
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "InferenceSession",
        _FakeInferenceSession,
    )
    evaluator = _make_evaluator(model_path)
    evaluator.load_policy()
    assert captured["model_path"] == str(model_path)
    assert captured["providers"] == ["CPUExecutionProvider"]
    assert evaluator.policy_model_context_len == 4


def test_load_policy_prefers_explicit_bc_path_for_inference_and_onnx_for_metadata(
    monkeypatch, tmp_path
):
    """An explicit ``bc_path`` skips onnxruntime and uses HBRuntime directly.

    I/O names and context length still come out as for the ONNX graph;
    onnx.load is patched to return the fake model.
    """
    model_path = tmp_path / "demo_model.onnx"
    model_path.write_bytes(b"onnx")
    runtime_path = tmp_path / "demo_quantized_model.bc"
    runtime_path.write_bytes(b"bc")
    captured = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            captured["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _unexpected_ort_session(*args, **kwargs):
        raise AssertionError(
            "InferenceSession should not be created when bc_path is set"
        )

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "get_available_providers",
        lambda: ["CPUExecutionProvider"],
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnxruntime,
        "InferenceSession",
        _unexpected_ort_session,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100,
        "HBRuntime",
        _FakeHBRuntime,
        raising=False,
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx,
        "load",
        lambda _: _make_fake_onnx_model(),
    )
    evaluator = _make_evaluator(model_path, bc_path=runtime_path)
    evaluator.load_policy()
    assert captured["hb_model_path"] == str(runtime_path)
    assert evaluator.policy_input_name == "obs"
    assert evaluator.policy_kv_input_name == "past_key_values"
    assert evaluator.policy_step_input_name == "step_idx"
    assert evaluator.policy_output_name == "action"
    assert evaluator.policy_kv_output_name == "present_key_values"
    assert evaluator.policy_model_context_len == 4


def test_bc_runtime_run_normalizes_inputs_for_hbruntime(monkeypatch, tmp_path):
    """``_HbSessionWrapper.run`` hands HBRuntime C-contiguous, retyped inputs.

    obs is a transposed (non-contiguous) float64 view; float inputs must
    arrive as float32 and ``step_idx`` (int32) as int64.
    """
    runtime_path = tmp_path / "demo_quantized_model.bc"
    runtime_path.write_bytes(b"bc")
    captured = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            captured["model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            captured["output_names"] = list(output_names)
            captured["input_feed"] = input_feed
            return ["ok"]

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100,
        "HBRuntime",
        _FakeHBRuntime,
        raising=False,
    )
    wrapper = eval_mujoco_sim2sim_s100._HbSessionWrapper(runtime_path)
    obs = np.arange(6, dtype=np.float64).reshape(2, 3).T
    past_key_values = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
    step_idx = np.array([7], dtype=np.int32)
    outputs = wrapper.run(
        ["action"],
        {
            "obs": obs,
            "past_key_values": past_key_values,
            "step_idx": step_idx,
        },
    )
    assert outputs == ["ok"]
    assert captured["model_path"] == str(runtime_path)
    assert captured["output_names"] == ["action"]
    assert captured["input_feed"]["obs"].dtype == np.float32
    assert captured["input_feed"]["obs"].flags["C_CONTIGUOUS"]
    assert captured["input_feed"]["past_key_values"].dtype == np.float32
    assert captured["input_feed"]["past_key_values"].flags["C_CONTIGUOUS"]
    assert captured["input_feed"]["step_idx"].dtype == np.int64
    assert captured["input_feed"]["step_idx"].flags["C_CONTIGUOUS"]


def test_update_policy_raises_clear_error_before_runtime_on_obs_dim_mismatch():
    """A 425-dim observation against an expected 786 fails before inference."""
    evaluator = eval_mujoco_sim2sim_s100.MujocoEvaluator.__new__(
        eval_mujoco_sim2sim_s100.MujocoEvaluator
    )
    evaluator._record_robot_states = lambda: None
    evaluator.obs_builder = SimpleNamespace(
        build_policy_obs=lambda: np.zeros(425, dtype=np.float32)
    )
    evaluator.policy_input_name = "obs"
    evaluator.policy_output_name = "action"
    evaluator.policy_obs_expected_dim = 786
    evaluator.use_kv_cache = False
    evaluator.policy_step_input_name = None
    evaluator.policy_kv_output_name = None
    evaluator.policy_moe_layer_output_names = []
    evaluator.dump_onnx_io_npy = False
    evaluator.counter = 0
    evaluator.command_mode = "velocity_tracking"
    evaluator.config = OmegaConf.create(
        {"motion_npz_dir": "", "motion_npz_path": ""}
    )
    # Raise from inside a lambda via a generator's throw() if run() is reached.
    evaluator.policy_session = SimpleNamespace(
        run=lambda *args, **kwargs: (_ for _ in ()).throw(
            AssertionError("runtime should not be called on shape mismatch")
        )
    )
    with pytest.raises(
        ValueError, match="expects 786 features but evaluator built 425"
    ):
        evaluator._update_policy()


================================================
FILE: tests/test_eval_mujoco_use_gpu.py
================================================
import sys
from pathlib import Path

from omegaconf import OmegaConf

sys.path.insert(0,
str(Path(__file__).resolve().parents[1])) import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim class _FakeIoNode: def __init__(self, name, shape): self.name = name self.shape = shape def test_load_policy_treats_false_string_use_gpu_as_cpu(monkeypatch): captured = {} class _FakeInferenceSession: def __init__(self, model_path, sess_options, providers): captured["model_path"] = model_path captured["providers"] = providers def get_providers(self): return ["CPUExecutionProvider"] def get_inputs(self): return [_FakeIoNode("obs", [1, 16])] def get_outputs(self): return [_FakeIoNode("action", [1, 12])] monkeypatch.setattr( eval_mujoco_sim2sim.onnxruntime, "get_available_providers", lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"], ) monkeypatch.setattr( eval_mujoco_sim2sim.onnxruntime, "InferenceSession", _FakeInferenceSession, ) evaluator = eval_mujoco_sim2sim.MujocoEvaluator.__new__( eval_mujoco_sim2sim.MujocoEvaluator ) evaluator.config = OmegaConf.create( { "ckpt_onnx_path": "model.onnx", "use_gpu": "false", "gpu_id": 3, } ) evaluator.max_context_len = 0 evaluator.load_policy() assert captured["model_path"] == "model.onnx" assert captured["providers"] == ["CPUExecutionProvider"] def test_create_ray_evaluator_preserves_use_gpu_false(monkeypatch): captured = {} class _FakeEvaluator: def __init__(self, config): captured["use_gpu"] = config.use_gpu captured["gpu_id"] = config.gpu_id monkeypatch.setattr(eval_mujoco_sim2sim, "MujocoEvaluator", _FakeEvaluator) eval_mujoco_sim2sim._create_ray_evaluator( {"use_gpu": False, "gpu_id": 5}, "holomotion" ) assert captured["use_gpu"] is False assert captured["gpu_id"] == 5 def test_run_mujoco_sim2sim_eval_preserves_use_gpu_false( monkeypatch, tmp_path ): captured = {} class _FakeEvaluator: def __init__(self, config): captured["use_gpu"] = config.use_gpu def setup(self): captured["setup"] = True def run_simulation(self): captured["run_simulation"] = True monkeypatch.setattr( 
eval_mujoco_sim2sim.hydra.utils, "get_original_cwd", lambda: str(tmp_path), ) monkeypatch.setattr( eval_mujoco_sim2sim, "process_config", lambda _: OmegaConf.create( { "use_gpu": False, "model_type": "holomotion", } ), ) monkeypatch.setattr(eval_mujoco_sim2sim, "MujocoEvaluator", _FakeEvaluator) eval_mujoco_sim2sim.run_mujoco_sim2sim_eval(OmegaConf.create({})) assert captured["use_gpu"] is False assert captured["setup"] is True assert captured["run_simulation"] is True ================================================ FILE: tests/test_eval_onnx_io_dump.py ================================================ import json import sys import types from pathlib import Path from types import SimpleNamespace import numpy as np sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.evaluation.eval_mujoco_sim2sim import ( MujocoEvaluator, write_onnx_io_dump_readme, ) from holomotion.src.evaluation.ray_evaluator_actor import RayEvaluatorActor class _Config(SimpleNamespace): def get(self, key, default=None): return getattr(self, key, default) def test_save_onnx_io_dump_stacks_per_frame_inputs_and_outputs(tmp_path): evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator._reset_onnx_io_dump_buffers() evaluator._record_onnx_io_frame( input_feed={ "obs": np.array([[1.0, 2.0]], dtype=np.float32), "step": np.array([0], dtype=np.int64), }, output_names=["action", "kv_cache"], onnx_output=[ np.array([[0.1, 0.2]], dtype=np.float32), np.array([[[3.0, 4.0]]], dtype=np.float32), ], ) evaluator._record_onnx_io_frame( input_feed={ "obs": np.array([[5.0, 6.0]], dtype=np.float32), "step": np.array([1], dtype=np.int64), }, output_names=["action", "kv_cache"], onnx_output=[ np.array([[0.3, 0.4]], dtype=np.float32), np.array([[[7.0, 8.0]]], dtype=np.float32), ], ) output_path = tmp_path / "clip_onnx_io.npy" evaluator.save_onnx_io_dump( output_path, { "source_npz": "clip.npz", "onnx_model": "model.onnx", }, ) payload = np.load(output_path, allow_pickle=True).item() 
assert payload["input_names"] == ["obs", "step"] assert payload["output_names"] == ["action", "kv_cache"] np.testing.assert_allclose( payload["inputs"]["obs"], np.array([[[1.0, 2.0]], [[5.0, 6.0]]], dtype=np.float32), ) np.testing.assert_array_equal( payload["inputs"]["step"], np.array([[0], [1]], dtype=np.int64), ) np.testing.assert_allclose( payload["outputs"]["action"], np.array([[[0.1, 0.2]], [[0.3, 0.4]]], dtype=np.float32), ) np.testing.assert_allclose( payload["outputs"]["kv_cache"], np.array([[[[3.0, 4.0]]], [[[7.0, 8.0]]]], dtype=np.float32), ) assert payload["source_npz"] == "clip.npz" assert payload["onnx_model"] == "model.onnx" def test_write_onnx_io_dump_readme_creates_chinese_loading_instructions( tmp_path, ): readme_path = write_onnx_io_dump_readme(tmp_path) assert readme_path == tmp_path / "README.md" content = readme_path.read_text(encoding="utf-8") assert "每个动作片段会生成一个 `.npy` 文件" in content assert "allow_pickle=True" in content assert "np.load(npy_path, allow_pickle=True).item()" in content def test_save_batch_result_persists_low_level_torque_dump_and_dt(tmp_path): evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.policy_dt = 0.02 evaluator.simulation_dt = 0.005 evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)] evaluator._robot_low_level_dof_torque_seq = [ np.array([1.0, -1.0], dtype=np.float32), np.array([-1.0, 1.0], dtype=np.float32), ] evaluator._robot_action_rate_seq = [np.float32(0.0)] evaluator._robot_global_translation_seq = [ np.zeros((1, 3), dtype=np.float32) ] evaluator._robot_global_rotation_quat_seq = [ np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32) ] evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)] evaluator._robot_global_angular_velocity_seq = [ np.zeros((1, 3), dtype=np.float32) ] 
evaluator.ref_dof_pos = np.zeros((1, 2), dtype=np.float32) evaluator.ref_dof_vel = np.zeros((1, 2), dtype=np.float32) evaluator.ref_global_translation = np.zeros((1, 1, 3), dtype=np.float32) evaluator.ref_global_rotation_quat_xyzw = np.array( [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32 ) evaluator.ref_global_velocity = np.zeros((1, 1, 3), dtype=np.float32) evaluator.ref_global_angular_velocity = np.zeros( (1, 1, 3), dtype=np.float32 ) output_path = tmp_path / "demo_eval.npz" evaluator.save_batch_result( output_path, {"source_file": "clip.npz", "clip_length": 1} ) with np.load(output_path, allow_pickle=True) as payload: metadata = json.loads(payload["metadata"].item()) np.testing.assert_allclose( payload["robot_low_level_dof_torque"], np.array([[1.0, -1.0], [-1.0, 1.0]], dtype=np.float32), ) assert metadata["source_file"] == "clip.npz" assert metadata["clip_length"] == 1 assert metadata["robot_low_level_torque_dt"] == 0.005 def test_save_batch_result_persists_moe_routing_tensors(tmp_path): evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.policy_dt = 0.02 evaluator.simulation_dt = 0.005 evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)] evaluator._robot_low_level_dof_torque_seq = [ np.array([0.5, -0.5], dtype=np.float32) ] evaluator._robot_action_rate_seq = [np.float32(0.0)] evaluator._robot_global_translation_seq = [ np.zeros((1, 3), dtype=np.float32) ] evaluator._robot_global_rotation_quat_seq = [ np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32) ] evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)] evaluator._robot_global_angular_velocity_seq = [ np.zeros((1, 3), dtype=np.float32) ] evaluator._robot_moe_expert_indices_seq = [ np.array([[1, 3], [0, 2]], dtype=np.int64) ] evaluator._robot_moe_expert_logits_seq = [ np.array( 
[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]], dtype=np.float32, ) ] evaluator.ref_dof_pos = np.zeros((1, 2), dtype=np.float32) evaluator.ref_dof_vel = np.zeros((1, 2), dtype=np.float32) evaluator.ref_global_translation = np.zeros((1, 1, 3), dtype=np.float32) evaluator.ref_global_rotation_quat_xyzw = np.array( [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32 ) evaluator.ref_global_velocity = np.zeros((1, 1, 3), dtype=np.float32) evaluator.ref_global_angular_velocity = np.zeros( (1, 1, 3), dtype=np.float32 ) output_path = tmp_path / "demo_eval_moe.npz" evaluator.save_batch_result(output_path, {"source_file": "clip.npz"}) with np.load(output_path, allow_pickle=True) as payload: np.testing.assert_array_equal( payload["robot_moe_expert_indices"], np.array([[[1, 3], [0, 2]]], dtype=np.int64), ) np.testing.assert_allclose( payload["robot_moe_expert_logits"], np.array( [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]], dtype=np.float32, ), ) def test_dump_robot_augmented_npz_persists_moe_routing_tensors(tmp_path): source_npz = tmp_path / "clip.npz" np.savez(source_npz, ref=np.array([1], dtype=np.int32)) onnx_path = tmp_path / "model.onnx" onnx_path.write_bytes(b"") evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.simulation_dt = 0.005 evaluator.config = _Config( motion_npz_path=str(source_npz), ckpt_onnx_path=str(onnx_path), ) evaluator._robot_dof_pos_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_vel_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_acc_seq = [np.zeros(2, dtype=np.float32)] evaluator._robot_dof_torque_seq = [np.ones(2, dtype=np.float32)] evaluator._robot_low_level_dof_torque_seq = [ np.array([0.5, -0.5], dtype=np.float32) ] evaluator._robot_action_rate_seq = [np.float32(0.0)] evaluator._robot_global_translation_seq = [ np.zeros((1, 3), dtype=np.float32) ] evaluator._robot_global_rotation_quat_seq = [ np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32) ] evaluator._robot_global_velocity_seq = [np.zeros((1, 3), dtype=np.float32)] 
evaluator._robot_global_angular_velocity_seq = [ np.zeros((1, 3), dtype=np.float32) ] evaluator._robot_moe_expert_indices_seq = [ np.array([[1, 3], [0, 2]], dtype=np.int64) ] evaluator._robot_moe_expert_logits_seq = [ np.array( [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]], dtype=np.float32, ) ] evaluator._dump_robot_augmented_npz() out_path = tmp_path / "mujoco_output_model" / "clip_robot.npz" with np.load(out_path, allow_pickle=True) as payload: np.testing.assert_array_equal( payload["robot_moe_expert_indices"], np.array([[[1, 3], [0, 2]]], dtype=np.int64), ) np.testing.assert_allclose( payload["robot_moe_expert_logits"], np.array( [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]], dtype=np.float32, ), ) def test_ray_actor_run_clip_overwrites_existing_outputs_and_sidecar(tmp_path): class _FakeEvaluator: def __init__(self): self.n_motion_frames = 2 self.calls = [] self.counter = 0 def load_specific_motion(self, file_path): self.calls.append(("load", file_path)) def reset_state_teleport(self): self.calls.append(("reset",)) def _update_policy(self): self.calls.append(("update",)) def _apply_control(self, sleep=False): self.calls.append(("apply", sleep)) def save_batch_result(self, output_path, meta_info): self.calls.append(("save_batch", output_path, meta_info)) Path(output_path).write_text("fresh-npz", encoding="utf-8") def save_onnx_io_dump(self, output_path, meta_info): self.calls.append(("save_onnx", output_path, meta_info)) np.save( output_path, {"source_npz": meta_info["source_file"]}, allow_pickle=True, ) actor = RayEvaluatorActor.__new__(RayEvaluatorActor) actor.output_dir = str(tmp_path) actor.config_dict = { "ckpt_onnx_path": "model.onnx", "dump_onnx_io_npy": True, } actor.evaluator = _FakeEvaluator() clip_path = tmp_path / "demo_clip.npz" np.savez(clip_path, dummy=np.array([1], dtype=np.int32)) existing_npz = tmp_path / "demo_clip_eval.npz" existing_npz.write_text("stale", encoding="utf-8") onnx_dir = tmp_path / "onnx_io_npy" onnx_dir.mkdir() status = 
actor.run_clip(str(clip_path)) assert status == "success" assert existing_npz.read_text(encoding="utf-8") == "fresh-npz" onnx_dump_path = onnx_dir / "demo_clip_onnx_io.npy" assert onnx_dump_path.is_file() payload = np.load(onnx_dump_path, allow_pickle=True).item() assert payload["source_npz"] == "demo_clip.npz" assert ("load", str(clip_path)) in actor.evaluator.calls assert ("reset",) in actor.evaluator.calls assert actor.evaluator.calls.count(("update",)) == 2 def test_ray_actor_skips_sidecar_for_non_default_model_type(tmp_path): class _FakeEvaluator: def __init__(self): self.n_motion_frames = 1 self.calls = [] self.counter = 0 def load_specific_motion(self, file_path): self.calls.append(("load", file_path)) def reset_state_teleport(self): self.calls.append(("reset",)) def _update_policy(self): self.calls.append(("update",)) def _apply_control(self, sleep=False): self.calls.append(("apply", sleep)) def save_batch_result(self, output_path, meta_info): self.calls.append(("save_batch", output_path, meta_info)) Path(output_path).write_text("fresh-npz", encoding="utf-8") def save_onnx_io_dump(self, output_path, meta_info): self.calls.append(("save_onnx", output_path, meta_info)) actor = RayEvaluatorActor.__new__(RayEvaluatorActor) actor.output_dir = str(tmp_path) actor.config_dict = { "ckpt_onnx_path": "model.onnx", "dump_onnx_io_npy": True, "model_type": "gmt", } actor.evaluator = _FakeEvaluator() clip_path = tmp_path / "demo_clip.npz" np.savez(clip_path, dummy=np.array([1], dtype=np.int32)) status = actor.run_clip(str(clip_path)) assert status == "success" assert not any(call[0] == "save_onnx" for call in actor.evaluator.calls) def test_ray_actor_treats_empty_model_type_as_default_holomotion(tmp_path): class _FakeEvaluator: def __init__(self): self.n_motion_frames = 1 self.calls = [] self.counter = 0 def load_specific_motion(self, file_path): self.calls.append(("load", file_path)) def reset_state_teleport(self): self.calls.append(("reset",)) def _update_policy(self): 
self.calls.append(("update",)) def _apply_control(self, sleep=False): self.calls.append(("apply", sleep)) def save_batch_result(self, output_path, meta_info): self.calls.append(("save_batch", output_path, meta_info)) Path(output_path).write_text("fresh-npz", encoding="utf-8") def save_onnx_io_dump(self, output_path, meta_info): self.calls.append(("save_onnx", output_path, meta_info)) np.save(output_path, {"source_npz": meta_info["source_file"]}) actor = RayEvaluatorActor.__new__(RayEvaluatorActor) actor.output_dir = str(tmp_path) actor.config_dict = { "ckpt_onnx_path": "model.onnx", "dump_onnx_io_npy": True, "model_type": "", } actor.evaluator = _FakeEvaluator() clip_path = tmp_path / "demo_clip.npz" np.savez(clip_path, dummy=np.array([1], dtype=np.int32)) status = actor.run_clip(str(clip_path)) assert status == "success" assert any(call[0] == "save_onnx" for call in actor.evaluator.calls) def test_ray_actor_init_uses_configured_evaluator_module( monkeypatch, tmp_path ): class _FakeEvaluator: def __init__(self): self.setup_called = False def setup(self): self.setup_called = True captured = {} fake_evaluator = _FakeEvaluator() def _unexpected_default_factory(*args, **kwargs): raise AssertionError("default evaluator factory should not be used") def _fake_override_factory(config_dict, model_type): captured["config_dict"] = config_dict captured["model_type"] = model_type return fake_evaluator monkeypatch.setattr( "holomotion.src.evaluation.eval_mujoco_sim2sim._create_ray_evaluator", _unexpected_default_factory, ) sys.modules["holomotion.src.evaluation.fake_eval_module"] = ( types.SimpleNamespace(_create_ray_evaluator=_fake_override_factory) ) actor = RayEvaluatorActor( { "ckpt_onnx_path": "model.onnx", "model_type": "holomotion", "ray_evaluator_module": "holomotion.src.evaluation.fake_eval_module", }, str(tmp_path), ) assert actor.evaluator is fake_evaluator assert fake_evaluator.setup_called is True assert captured["model_type"] == "holomotion" assert ( 
captured["config_dict"]["ray_evaluator_module"] == "holomotion.src.evaluation.fake_eval_module" ) ================================================ FILE: tests/test_evaluation_metrics.py ================================================ import csv import json from pathlib import Path import numpy as np from holomotion.src.evaluation.metrics import ( _compute_clip_stability_summary, _per_frame_metrics_from_npz, offline_evaluate_dumped_npzs, ) def _make_eval_data( robot_dof_torque: np.ndarray, *, robot_dof_vel: np.ndarray | None = None, robot_dof_acc: np.ndarray | None = None, robot_action_rate: np.ndarray | None = None, robot_low_level_dof_torque: np.ndarray | None = None, robot_global_angular_velocity: np.ndarray | None = None, robot_low_level_foot_contact: np.ndarray | None = None, robot_low_level_foot_normal_force: np.ndarray | None = None, robot_low_level_foot_tangent_speed: np.ndarray | None = None, robot_moe_expert_logits: np.ndarray | None = None, ): num_frames = int(robot_dof_torque.shape[0]) num_dofs = int(robot_dof_torque.shape[1]) root = np.zeros((num_frames, 1, 3), dtype=np.float32) child = np.tile( np.array([[[0.0, 0.0, 1.0]]], dtype=np.float32), (num_frames, 1, 1) ) global_translation = np.concatenate([root, child], axis=1) global_rotation = np.tile( np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (num_frames, 2, 1), ) zeros_dof = np.zeros((num_frames, num_dofs), dtype=np.float32) payload = { "ref_dof_pos": zeros_dof.copy(), "robot_dof_pos": zeros_dof.copy(), "ref_dof_vel": zeros_dof.copy(), "ref_global_translation": global_translation.copy(), "robot_global_translation": global_translation.copy(), "ref_global_rotation_quat": global_rotation.copy(), "robot_global_rotation_quat": global_rotation.copy(), "ref_global_velocity": np.zeros((num_frames, 2, 3), dtype=np.float32), "ref_global_angular_velocity": np.zeros( (num_frames, 2, 3), dtype=np.float32 ), "robot_global_velocity": np.zeros( (num_frames, 2, 3), dtype=np.float32 ), 
"robot_global_angular_velocity": ( np.zeros((num_frames, 2, 3), dtype=np.float32) if robot_global_angular_velocity is None else robot_global_angular_velocity.astype(np.float32) ), "robot_dof_vel": ( zeros_dof.copy() if robot_dof_vel is None else robot_dof_vel.astype(np.float32) ), "robot_dof_acc": ( zeros_dof.copy() if robot_dof_acc is None else robot_dof_acc.astype(np.float32) ), "robot_dof_torque": robot_dof_torque.astype(np.float32), "robot_action_rate": ( np.zeros((num_frames,), dtype=np.float32) if robot_action_rate is None else robot_action_rate.astype(np.float32) ), } if robot_low_level_dof_torque is not None: payload["robot_low_level_dof_torque"] = ( robot_low_level_dof_torque.astype(np.float32) ) if robot_low_level_foot_contact is not None: payload["robot_low_level_foot_contact"] = ( robot_low_level_foot_contact.astype(np.float32) ) if robot_low_level_foot_normal_force is not None: payload["robot_low_level_foot_normal_force"] = ( robot_low_level_foot_normal_force.astype(np.float32) ) if robot_low_level_foot_tangent_speed is not None: payload["robot_low_level_foot_tangent_speed"] = ( robot_low_level_foot_tangent_speed.astype(np.float32) ) if robot_moe_expert_logits is not None: payload["robot_moe_expert_logits"] = robot_moe_expert_logits.astype( np.float32 ) return payload def test_per_frame_metrics_include_torque_jump_diagnostics(): constant_torque = np.ones((4, 2), dtype=np.float32) constant_df = _per_frame_metrics_from_npz( motion_key="constant", data=_make_eval_data(constant_torque), robot_control_dt=0.5, ) assert "mean_torque_jump_norm" in constant_df.columns assert "mean_torque_jump_ratio" in constant_df.columns assert np.isnan(constant_df["mean_torque_jump_norm"].iloc[0]) assert np.isnan(constant_df["mean_torque_jump_ratio"].iloc[0]) np.testing.assert_allclose( np.nan_to_num(constant_df["mean_torque_jump_norm"].to_numpy()), np.zeros(4, dtype=np.float64), ) np.testing.assert_allclose( np.nan_to_num(constant_df["mean_torque_jump_ratio"].to_numpy()), 
np.zeros(4, dtype=np.float64), ) jump_torque = np.array( [ [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [-1.0, 0.0], ], dtype=np.float32, ) jump_df = _per_frame_metrics_from_npz( motion_key="jump", data=_make_eval_data(jump_torque), robot_control_dt=0.5, ) assert jump_df["mean_torque_jump_norm"].iloc[2] > 3.9 assert jump_df["mean_torque_jump_ratio"].iloc[2] > 1.9 def test_offline_evaluate_dumped_npzs_exports_torque_jump_summary_metrics( tmp_path: Path, ): eval_dir = tmp_path / "eval" eval_dir.mkdir() jump_torque_50hz = np.tile( np.array([[1.0, 0.0]], dtype=np.float32), (4, 1) ) jump_torque_low_level = np.array( [ [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], ], dtype=np.float32, ) payload = _make_eval_data( jump_torque_50hz, robot_low_level_dof_torque=jump_torque_low_level, ) payload["metadata"] = np.array( json.dumps({"clip_length": 4, "robot_low_level_torque_dt": 0.005}), dtype=np.str_, ) np.savez_compressed(eval_dir / "demo_clip.npz", **payload) output_json_path = eval_dir / "summary.json" result = offline_evaluate_dumped_npzs( npz_dir=str(eval_dir), output_json_path=str(output_json_path), ) per_clip = result["per_clip"][0] for key in ( "mean_torque_jump_norm", "p95_torque_jump_norm", "mean_torque_jump_ratio", "p95_torque_jump_ratio", ): assert key in per_clip assert key in result["dataset"]["mean"] assert per_clip["mean_dof_torque"] == 1.0 assert per_clip["p95_torque_jump_norm"] > 300.0 assert per_clip["p95_torque_jump_ratio"] > 1.0 with output_json_path.open("r", encoding="utf-8") as handle: written = json.load(handle) assert "p95_torque_jump_ratio" in written["dataset"]["mean"] csv_path = eval_dir / "per_clip_metrics.csv" with csv_path.open("r", encoding="utf-8", newline="") as handle: reader = csv.DictReader(handle) row = next(reader) assert "p95_torque_jump_ratio" in row assert "mean_torque_jump_norm" in row def 
test_compute_clip_stability_summary_detects_chatter_and_support_events(): num_frames = 50 num_low_level = 200 policy_dt = 0.02 low_level_dt = 0.005 t_policy = np.arange(num_frames, dtype=np.float32) * policy_dt t_low = np.arange(num_low_level, dtype=np.float32) * low_level_dt smooth_ang_vel = np.zeros((num_frames, 2, 3), dtype=np.float32) smooth_ang_vel[:, 0, 0] = 0.2 * np.sin(2.0 * np.pi * 1.0 * t_policy) unstable_ang_vel = smooth_ang_vel.copy() unstable_ang_vel[:, 0, 0] += 0.7 * np.sin( 2.0 * np.pi * 8.0 * t_policy ).astype(np.float32) unstable_ang_vel[:, 0, 1] += 0.4 * np.sin( 2.0 * np.pi * 6.0 * t_policy ).astype(np.float32) smooth_low_level_torque = np.zeros((num_low_level, 2), dtype=np.float32) smooth_low_level_torque[:, 0] = np.sin(2.0 * np.pi * 1.0 * t_low) unstable_low_level_torque = smooth_low_level_torque.copy() unstable_low_level_torque[:, 0] += 0.8 * np.sin( 2.0 * np.pi * 15.0 * t_low ).astype(np.float32) unstable_low_level_torque[80:85, 0] += 2.5 unstable_low_level_torque[120:123, 0] -= 2.5 stable_contact = np.zeros((num_low_level, 2), dtype=np.float32) stable_contact[:100, 0] = 1.0 stable_contact[100:, 1] = 1.0 stable_normal_force = stable_contact * np.array( [[80.0, 75.0]], dtype=np.float32 ) stable_tangent_speed = stable_contact * 0.01 unstable_contact = np.zeros((num_low_level, 2), dtype=np.float32) for start in range(0, num_low_level, 10): unstable_contact[start : start + 5, 0] = 1.0 unstable_contact[start + 5 : start + 10, 1] = 1.0 unstable_normal_force = unstable_contact * 60.0 touchdown_mask = unstable_contact.copy() touchdown_mask[1:] = np.clip( unstable_contact[1:] - unstable_contact[:-1], a_min=0.0, a_max=None ) unstable_normal_force += touchdown_mask * 120.0 unstable_tangent_speed = unstable_contact * 0.25 smooth_metrics = _compute_clip_stability_summary( data=_make_eval_data( np.zeros((num_frames, 2), dtype=np.float32), robot_low_level_dof_torque=smooth_low_level_torque, robot_global_angular_velocity=smooth_ang_vel, 
robot_low_level_foot_contact=stable_contact, robot_low_level_foot_normal_force=stable_normal_force, robot_low_level_foot_tangent_speed=stable_tangent_speed, ), robot_control_dt=policy_dt, low_level_contact_dt=low_level_dt, ) unstable_metrics = _compute_clip_stability_summary( data=_make_eval_data( np.zeros((num_frames, 2), dtype=np.float32), robot_low_level_dof_torque=unstable_low_level_torque, robot_global_angular_velocity=unstable_ang_vel, robot_low_level_foot_contact=unstable_contact, robot_low_level_foot_normal_force=unstable_normal_force, robot_low_level_foot_tangent_speed=unstable_tangent_speed, ), robot_control_dt=policy_dt, low_level_contact_dt=low_level_dt, ) assert ( unstable_metrics["torque_chatter_hf_ratio"] > smooth_metrics["torque_chatter_hf_ratio"] ) assert ( unstable_metrics["torque_jump_burst_max"] > smooth_metrics["torque_jump_burst_max"] ) assert ( unstable_metrics["torso_rp_hf_ratio"] > smooth_metrics["torso_rp_hf_ratio"] ) assert ( unstable_metrics["torso_rp_angacc_p95"] > smooth_metrics["torso_rp_angacc_p95"] ) assert ( unstable_metrics["foot_contact_toggle_rate"] > smooth_metrics["foot_contact_toggle_rate"] ) assert ( unstable_metrics["foot_impact_force_p95"] > smooth_metrics["foot_impact_force_p95"] ) assert ( unstable_metrics["stance_slip_speed_p95"] > smooth_metrics["stance_slip_speed_p95"] ) def test_compute_clip_stability_summary_reports_expert_switching_js_div(): num_frames = 8 stable_logits = np.tile( np.array( [ [8.0, -4.0, -4.0], [-4.0, 8.0, -4.0], ], dtype=np.float32, )[None, :, :], (num_frames, 1, 1), ) switching_logits = stable_logits.copy() switching_logits[1::2, 0, :] = np.array( [-4.0, 8.0, -4.0], dtype=np.float32 ) switching_logits[1::2, 1, :] = np.array( [-4.0, -4.0, 8.0], dtype=np.float32 ) stable_metrics = _compute_clip_stability_summary( data=_make_eval_data( np.zeros((num_frames, 2), dtype=np.float32), robot_moe_expert_logits=stable_logits, ), robot_control_dt=0.02, low_level_contact_dt=0.02, ) switching_metrics = 
_compute_clip_stability_summary( data=_make_eval_data( np.zeros((num_frames, 2), dtype=np.float32), robot_moe_expert_logits=switching_logits, ), robot_control_dt=0.02, low_level_contact_dt=0.02, ) assert stable_metrics["expert_switching_js_div"] < 1e-6 assert ( switching_metrics["expert_switching_js_div"] > stable_metrics["expert_switching_js_div"] ) def test_offline_evaluate_dumped_npzs_reports_nan_contact_metrics_for_legacy_npz( tmp_path: Path, ): eval_dir = tmp_path / "legacy_eval" eval_dir.mkdir() payload = _make_eval_data(np.ones((8, 2), dtype=np.float32)) payload["metadata"] = np.array( json.dumps({"clip_length": 8, "robot_low_level_torque_dt": 0.005}), dtype=np.str_, ) np.savez_compressed(eval_dir / "legacy_clip.npz", **payload) output_json_path = eval_dir / "summary.json" result = offline_evaluate_dumped_npzs( npz_dir=str(eval_dir), output_json_path=str(output_json_path), ) per_clip = result["per_clip"][0] for key in ( "torque_chatter_hf_ratio", "torque_jump_burst_max", "torso_rp_hf_ratio", "torso_rp_angacc_p95", "foot_contact_toggle_rate", "foot_impact_force_p95", "stance_slip_speed_p95", "expert_switching_js_div", ): assert key in per_clip assert key in result["dataset"]["mean"] assert np.isnan(per_clip["foot_contact_toggle_rate"]) assert np.isnan(per_clip["foot_impact_force_p95"]) assert np.isnan(per_clip["stance_slip_speed_p95"]) assert np.isnan(per_clip["expert_switching_js_div"]) ================================================ FILE: tests/test_isaaclab_termination.py ================================================ import importlib.util import sys import types from pathlib import Path from types import SimpleNamespace import pytest import torch MODULE_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_termination.py" ) MOTION_COMMAND_MODULE_NAME = ( "holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command" ) ISAACLAB_UTILS_MODULE_NAME = ( 
"holomotion.src.env.isaaclab_components.isaaclab_utils" ) class _Scene(SimpleNamespace): def __getitem__(self, key): return getattr(self, key) def _load_isaaclab_termination_module(module_name: str): isaaclab_module = types.ModuleType("isaaclab") isaaclab_envs = types.ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedRLEnv = object isaaclab_terminations = types.SimpleNamespace( time_out=lambda env: torch.zeros(1, dtype=torch.bool), bad_orientation=lambda env, limit_angle: torch.zeros( 1, dtype=torch.bool ), root_height_below_minimum=lambda env, minimum_height: torch.zeros( 1, dtype=torch.bool ), native_only_term=lambda env, margin: torch.zeros(1, dtype=torch.bool), ) isaaclab_envs_mdp = types.ModuleType("isaaclab.envs.mdp") isaaclab_envs_mdp.terminations = isaaclab_terminations isaaclab_managers = types.ModuleType("isaaclab.managers") class _TerminationTermCfg: def __init__(self, func, params=None, time_out=False): self.func = func self.params = {} if params is None else params self.time_out = time_out isaaclab_managers.TerminationTermCfg = _TerminationTermCfg isaaclab_managers.SceneEntityCfg = object isaaclab_utils = types.ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls isaaclab_utils_math = types.ModuleType("isaaclab.utils.math") isaaclab_utils_math.quat_apply_inverse = ( lambda quat, vec: torch.zeros_like(vec) ) isaaclab_utils.math = isaaclab_utils_math isaaclab_assets = types.ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_components_package = types.ModuleType( "holomotion.src.env.isaaclab_components" ) motion_command_module = types.ModuleType(MOTION_COMMAND_MODULE_NAME) motion_command_module.RefMotionCommand = object isaaclab_utils_module = types.ModuleType(ISAACLAB_UTILS_MODULE_NAME) isaaclab_utils_module._get_body_indices = lambda robot, keybody_names: None isaaclab_utils_module.resolve_holo_config = lambda cfg: cfg isaaclab_components_package.isaaclab_motion_tracking_command = ( 
motion_command_module ) isaaclab_components_package.isaaclab_utils = isaaclab_utils_module fake_modules = { "isaaclab": isaaclab_module, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_envs_mdp, "isaaclab.managers": isaaclab_managers, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.math": isaaclab_utils_math, "isaaclab.assets": isaaclab_assets, "holomotion.src.env.isaaclab_components": isaaclab_components_package, MOTION_COMMAND_MODULE_NAME: motion_command_module, ISAACLAB_UTILS_MODULE_NAME: isaaclab_utils_module, } original_modules = {name: sys.modules.get(name) for name in fake_modules} sys.modules.update(fake_modules) try: spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) return module finally: for name, original in original_modules.items(): if original is None: sys.modules.pop(name, None) else: sys.modules[name] = original def test_wholebody_mpjpe_far_flags_envs_above_mean_error_threshold(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test" ) current_dof_pos = torch.tensor( [ [0.0, 0.2, 0.6], [0.0, 0.1, 0.2], ] ) ref_dof_pos = torch.zeros_like(current_dof_pos) command = SimpleNamespace( robot=SimpleNamespace(data=SimpleNamespace(joint_pos=current_dof_pos)), get_ref_motion_dof_pos_cur=lambda prefix="ref_": ref_dof_pos, get_ref_motion_dof_pos_immediate_next=lambda prefix="ref_": ref_dof_pos, ) env = SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) result = termination_module.wholebody_mpjpe_far(env, threshold=0.2) assert result.dtype == torch.bool assert torch.equal(result, torch.tensor([True, False])) def test_wholebody_mpjpe_far_uses_immediate_next_reference(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_next_dof" ) current_dof_pos = torch.tensor([[0.0, 0.1, 0.2]]) command = SimpleNamespace( 
robot=SimpleNamespace(data=SimpleNamespace(joint_pos=current_dof_pos)), get_ref_motion_dof_pos_cur=lambda prefix="ref_": (_ for _ in ()).throw( AssertionError("current reference should not be used") ), get_ref_motion_dof_pos_immediate_next=lambda prefix="ref_": current_dof_pos, ) env = SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) result = termination_module.wholebody_mpjpe_far(env, threshold=0.05) assert torch.equal(result, torch.tensor([False])) def test_keybody_ref_pos_far_uses_immediate_next_reference(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_next_keybody" ) body_pos = torch.tensor( [[[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]], dtype=torch.float32 ) robot = SimpleNamespace( body_names=["anchor", "target"], data=SimpleNamespace(body_pos_w=body_pos), ) command = SimpleNamespace( robot=robot, get_ref_motion_bodylink_global_pos_cur=( lambda prefix="ref_": (_ for _ in ()).throw( AssertionError("current reference should not be used") ) ), get_ref_motion_bodylink_global_pos_immediate_next=( lambda prefix="ref_": body_pos ), ) env = SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) result = termination_module.keybody_ref_pos_far( env, threshold=0.1, keybody_names=["target"], ) assert torch.equal(result, torch.tensor([False])) def test_ref_gravity_projection_far_uses_immediate_next_reference(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_next_gravity" ) gravity = torch.tensor([[0.0, 0.0, -1.0]], dtype=torch.float32) anchor_quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32) robot = SimpleNamespace( data=SimpleNamespace( GRAVITY_VEC_W=gravity, body_quat_w=anchor_quat[:, None, :], ) ) command = SimpleNamespace( robot=robot, anchor_bodylink_idx=0, get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=( lambda prefix="ref_": (_ for _ in ()).throw( AssertionError("current reference should not be used") ) ), 
get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next=( lambda prefix="ref_": anchor_quat ), ) env = SimpleNamespace( scene=_Scene(robot=robot), command_manager=SimpleNamespace(get_term=lambda name: command), ) result = termination_module.ref_gravity_projection_far( env, threshold=0.1, ) assert torch.equal(result, torch.tensor([False])) def test_build_terminations_config_registers_wholebody_mpjpe_far(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_for_cfg" ) config = termination_module.build_terminations_config( { "wholebody_mpjpe_far": { "params": {"threshold": 0.3}, } } ) assert ( config.wholebody_mpjpe_far.func is termination_module.wholebody_mpjpe_far ) assert config.wholebody_mpjpe_far.params == {"threshold": 0.3} assert config.wholebody_mpjpe_far.time_out is False def test_build_terminations_config_resolves_native_isaaclab_termination(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_for_native_cfg" ) config = termination_module.build_terminations_config( { "native_only_term": { "params": {"margin": 0.3}, } } ) assert ( config.native_only_term.func is termination_module.isaaclab_mdp.terminations.native_only_term ) assert config.native_only_term.params == {"margin": 0.3} assert config.native_only_term.time_out is False def test_build_terminations_config_raises_on_unknown_termination(): termination_module = _load_isaaclab_termination_module( "isaaclab_termination_under_test_for_unknown_cfg" ) with pytest.raises(ValueError, match="Unknown termination function"): termination_module.build_terminations_config({"missing_term": {}}) ================================================ FILE: tests/test_mean_process_5metrics.py ================================================ import json import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.scripts.evaluation import mean_process_5metrics LEGACY_METRICS = [ "mpjpe_g", 
"mpjpe_l", "whole_body_joints_dist", "root_vel_error", "root_r_error", "root_p_error", "root_y_error", "root_height_error", "mean_dof_vel", "mean_dof_acc", "mean_dof_torque", "mean_action_rate", "success", ] TORQUE_JUMP_METRICS = [ "mean_torque_jump_norm", "p95_torque_jump_norm", "mean_torque_jump_ratio", "p95_torque_jump_ratio", ] def test_macro_report_appends_torque_jump_columns_to_legacy_tables( tmp_path, monkeypatch ): json_path = tmp_path / "model_a.json" payload = { "per_clip": [ { "motion_key": "clips_AMASS_demo", "mpjpe_g": 1.0, "mpjpe_l": 2.0, "whole_body_joints_dist": 3.0, "root_vel_error": 4.0, "root_r_error": 5.0, "root_p_error": 6.0, "root_y_error": 7.0, "root_height_error": 8.0, "mean_dof_vel": 9.0, "mean_dof_acc": 10.0, "mean_dof_torque": 11.0, "mean_action_rate": 12.0, "success": 1.0, "mean_torque_jump_norm": 13.0, "p95_torque_jump_norm": 14.0, "mean_torque_jump_ratio": 15.0, "p95_torque_jump_ratio": 16.0, } ] } json_path.write_text(json.dumps(payload), encoding="utf-8") mean_df, _ = mean_process_5metrics.process_data(str(tmp_path)) assert mean_df.columns.tolist() == [ "Method", "Dataset", *LEGACY_METRICS, *TORQUE_JUMP_METRICS, ] captured_headers = {} def _fake_tabulate(_rows, headers, **_kwargs): captured_headers["headers"] = headers return "fake-table" monkeypatch.setattr(mean_process_5metrics, "tabulate", _fake_tabulate) report_path = ( mean_process_5metrics.generate_macro_mean_report_from_json_dir( str(tmp_path) ) ) tsv_path = tmp_path / "sub_dataset_macro_mean_metrics.tsv" header = tsv_path.read_text(encoding="utf-8").splitlines()[0].split("\t") assert header == [ "Dataset", "Global Bodylink Pos Err", "Local Bodylink Pos Err", "Dof Position Err", "Root Vel Err", "Root Roll Err", "Root Pitch Err", "Root Yaw Err", "Root Height Err", "Mean Dof Vel", "Mean Dof Acc", "Mean Dof Torque", "Mean Action Rate", "Success Rate", "Mean Torque Jump Norm", "P95 Torque Jump Norm", "Mean Torque Jump Ratio", "P95 Torque Jump Ratio", ] assert 
captured_headers["headers"] == header report_text = Path(report_path).read_text(encoding="utf-8") legacy_index = report_text.index("Success Rate") torque_index = report_text.index("Mean Torque Jump Norm") assert torque_index > legacy_index ================================================ FILE: tests/test_motion_cache_gather_state.py ================================================ import sys import unittest from unittest import mock from pathlib import Path import torch sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.training.h5_dataloader import ( ClipBatch, Hdf5RootDofDataset, MotionClipBatchCache, MotionWindow, _CpuFKTransform, _normalize_online_filter_cfg, build_motion_datasets_from_cfg, ) def _expected_field( tensor: torch.Tensor, clip_indices: torch.Tensor, frame_indices: torch.Tensor, n_future_frames: int, lengths: torch.Tensor, ) -> torch.Tensor: temporal_span = 1 + int(n_future_frames) time_offsets = torch.arange(temporal_span, dtype=torch.long) gather_timesteps = frame_indices[:, None] + time_offsets[None, :] max_valid = torch.clamp(lengths.index_select(0, clip_indices) - 1, min=0) gather_timesteps = torch.minimum(gather_timesteps, max_valid[:, None]) return tensor[clip_indices[:, None], gather_timesteps] class MotionCacheGatherStateTests(unittest.TestCase): def test_normalize_online_filter_cfg_includes_velocity_smoothing_sigmas( self, ): default_cfg = _normalize_online_filter_cfg({}) self.assertEqual(default_cfg["ref_vel_smoothing_sigma"], 2.0) self.assertEqual(default_cfg["ft_ref_vel_smoothing_sigma"], 2.0) explicit_cfg = _normalize_online_filter_cfg( { "enabled": True, "butter_cutoff_hz_pool": [3.0], "ref_vel_smoothing_sigma": 0.0, "ft_ref_vel_smoothing_sigma": 2.0, }, default_vel_smoothing_sigma=0.5, ) self.assertEqual(explicit_cfg["ref_vel_smoothing_sigma"], 0.0) self.assertEqual(explicit_cfg["ft_ref_vel_smoothing_sigma"], 2.0) def test_normalize_online_filter_cfg_uses_fk_sigma_fallback_defaults(self): cfg = 
_normalize_online_filter_cfg( {}, default_vel_smoothing_sigma=0.5, ) self.assertEqual(cfg["ref_vel_smoothing_sigma"], 0.5) self.assertEqual(cfg["ft_ref_vel_smoothing_sigma"], 0.5) def test_build_motion_datasets_from_cfg_passes_fk_sigma_fallback(self): with ( mock.patch( "holomotion.src.training.h5_dataloader.preview_sampling_from_cfg" ), mock.patch( "holomotion.src.training.h5_dataloader.Hdf5RootDofDataset" ) as dataset_cls, ): build_motion_datasets_from_cfg( { "backend": "hdf5_v2", "hdf5_root": "/tmp/train", "fk_robot_file_path": "robot.xml", "fk_vel_smoothing_sigma": 0.5, "cache": {"allowed_prefixes": ["ref_", "ft_ref_"]}, "online_filter": {"enabled": False}, }, max_frame_length=16, min_window_length=4, ) self.assertEqual(dataset_cls.call_count, 1) self.assertEqual( dataset_cls.call_args.kwargs["fk_vel_smoothing_sigma"], 0.5, ) def test_build_motion_datasets_from_cfg_defaults_fk_sigma_fallback(self): with ( mock.patch( "holomotion.src.training.h5_dataloader.preview_sampling_from_cfg" ), mock.patch( "holomotion.src.training.h5_dataloader.Hdf5RootDofDataset" ) as dataset_cls, ): build_motion_datasets_from_cfg( { "backend": "hdf5_v2", "hdf5_root": "/tmp/train", "fk_robot_file_path": "robot.xml", "cache": {"allowed_prefixes": ["ref_", "ft_ref_"]}, "online_filter": {"enabled": False}, }, max_frame_length=16, min_window_length=4, ) self.assertEqual(dataset_cls.call_count, 1) self.assertEqual( dataset_cls.call_args.kwargs["fk_vel_smoothing_sigma"], 2.0, ) def test_gather_tensor_returns_expected_values(self): cache = MotionClipBatchCache.__new__(MotionClipBatchCache) ref_dof_pos = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape( 2, 6, 3 ) ref_rg_pos = torch.arange(2 * 6 * 2 * 3, dtype=torch.float32).reshape( 2, 6, 2, 3 ) lengths = torch.tensor([6, 4], dtype=torch.long) window_indices = torch.tensor([10, 11], dtype=torch.long) cache._current_batch = ClipBatch( tensors={ "ref_dof_pos": ref_dof_pos, "ref_rg_pos": ref_rg_pos, }, lengths=lengths, motion_keys=["clip-a", 
"clip-b"], raw_motion_keys=["clip-a", "clip-b"], window_indices=window_indices, max_frame_length=6, ) clip_indices = torch.tensor([1, 0, 1, 1], dtype=torch.long) frame_indices = torch.tensor([0, 2, 3, 1], dtype=torch.long) gathered_dof_pos = cache.gather_tensor( "ref_dof_pos", clip_indices=clip_indices, frame_indices=frame_indices, n_future_frames=2, ) gathered_rg_pos = cache.gather_tensor( "ref_rg_pos", clip_indices=clip_indices, frame_indices=frame_indices, n_future_frames=2, ) expected_dof_pos = _expected_field( ref_dof_pos, clip_indices, frame_indices, n_future_frames=2, lengths=lengths, ) expected_rg_pos = _expected_field( ref_rg_pos, clip_indices, frame_indices, n_future_frames=2, lengths=lengths, ) torch.testing.assert_close(gathered_dof_pos, expected_dof_pos) torch.testing.assert_close(gathered_rg_pos, expected_rg_pos) self.assertEqual(tuple(gathered_dof_pos.shape), (4, 3, 3)) self.assertEqual(tuple(gathered_rg_pos.shape), (4, 3, 2, 3)) def test_gather_tensor_reflects_updated_indices_without_cached_state(self): cache = MotionClipBatchCache.__new__(MotionClipBatchCache) ref_dof_pos = torch.arange(3 * 6 * 3, dtype=torch.float32).reshape( 3, 6, 3 ) lengths = torch.tensor([6, 5, 4], dtype=torch.long) window_indices = torch.tensor([10, 11, 12], dtype=torch.long) cache._current_batch = ClipBatch( tensors={"ref_dof_pos": ref_dof_pos}, lengths=lengths, motion_keys=["clip-a", "clip-b", "clip-c"], raw_motion_keys=["clip-a", "clip-b", "clip-c"], window_indices=window_indices, max_frame_length=6, ) initial_clip_indices = torch.tensor([0, 1, 2, 1], dtype=torch.long) initial_frame_indices = torch.tensor([0, 1, 2, 0], dtype=torch.long) updated_clip_indices = torch.tensor([0, 2, 1, 0], dtype=torch.long) updated_frame_indices = torch.tensor([1, 0, 3, 2], dtype=torch.long) initial_gathered = cache.gather_tensor( "ref_dof_pos", clip_indices=initial_clip_indices, frame_indices=initial_frame_indices, n_future_frames=2, ) updated_gathered = cache.gather_tensor( "ref_dof_pos", 
clip_indices=updated_clip_indices, frame_indices=updated_frame_indices, n_future_frames=2, ) expected_initial = _expected_field( ref_dof_pos, initial_clip_indices, initial_frame_indices, n_future_frames=2, lengths=lengths, ) expected_updated = _expected_field( ref_dof_pos, updated_clip_indices, updated_frame_indices, n_future_frames=2, lengths=lengths, ) torch.testing.assert_close(initial_gathered, expected_initial) torch.testing.assert_close(updated_gathered, expected_updated) def test_cpu_fk_transform_forwards_explicit_vel_smoothing_sigma(self): transform = _CpuFKTransform.__new__(_CpuFKTransform) transform._fk = mock.Mock( return_value={ "global_translation": torch.zeros(1, 4, 2, 3), "global_rotation_quat": torch.zeros(1, 4, 2, 4), "global_velocity": torch.zeros(1, 4, 2, 3), "global_angular_velocity": torch.zeros(1, 4, 2, 3), "dof_vel": torch.zeros(1, 4, 2), } ) arrays = { "ref_root_pos": torch.zeros(4, 3), "ref_root_rot": torch.zeros(4, 4), "ref_dof_pos": torch.zeros(4, 2), } transform( arrays, fps=60.0, prefix="ref_", vel_smoothing_sigma=0.0, ) self.assertEqual( transform._fk.call_args.kwargs["vel_smoothing_sigma"], 0.0, ) def test_cpu_fk_transform_defaults_vel_smoothing_sigma_to_two(self): transform = _CpuFKTransform.__new__(_CpuFKTransform) transform._fk = mock.Mock( return_value={ "global_translation": torch.zeros(1, 4, 2, 3), "global_rotation_quat": torch.zeros(1, 4, 2, 4), "global_velocity": torch.zeros(1, 4, 2, 3), "global_angular_velocity": torch.zeros(1, 4, 2, 3), "dof_vel": torch.zeros(1, 4, 2), } ) arrays = { "ref_root_pos": torch.zeros(4, 3), "ref_root_rot": torch.zeros(4, 4), "ref_dof_pos": torch.zeros(4, 2), } transform(arrays, fps=60.0) self.assertEqual( transform._fk.call_args.kwargs["vel_smoothing_sigma"], 2.0, ) def test_hdf5_v2_sample_exposes_zero_cutoff_metadata_when_disabled(self): dataset = self._make_stub_root_dof_dataset() sample = dataset[0] self.assertIn("filter_cutoff_hz", sample.tensors) torch.testing.assert_close( 
sample.tensors["filter_cutoff_hz"], torch.zeros(4, 1, dtype=torch.float32), ) def test_hdf5_v2_sample_exposes_sampled_cutoff_metadata(self): dataset = self._make_stub_root_dof_dataset( cutoff_pool=(3.0,), online_filter_enabled=True, ) sample = dataset[0] self.assertIn("filter_cutoff_hz", sample.tensors) torch.testing.assert_close( sample.tensors["filter_cutoff_hz"], torch.full((4, 1), 3.0, dtype=torch.float32), ) def test_hdf5_v2_sample_generates_filtered_reference_family(self): dataset = self._make_stub_root_dof_dataset( cutoff_pool=(3.0,), online_filter_enabled=True, ) sample = dataset[0] for tensor_name in ( "ft_ref_root_pos", "ft_ref_root_rot", "ft_ref_dof_pos", "ft_ref_rg_pos", "ft_ref_rb_rot", "ft_ref_body_vel", "ft_ref_body_ang_vel", "ft_ref_dof_vel", "ft_ref_root_vel", "ft_ref_root_ang_vel", ): self.assertIn(tensor_name, sample.tensors) def test_hdf5_v2_sample_uses_split_fk_smoothing_sigmas(self): dataset = self._make_stub_root_dof_dataset( cutoff_pool=(3.0,), online_filter_enabled=True, ref_vel_smoothing_sigma=0.0, ft_ref_vel_smoothing_sigma=2.0, ) sample = dataset[0] self.assertIn("ref_root_vel", sample.tensors) self.assertIn("ft_ref_root_vel", sample.tensors) self.assertEqual( dataset._fk_calls, [("ref_", 0.0), ("ft_ref_", 2.0)], ) def test_hdf5_v2_sample_skips_filtered_reference_family_when_disabled( self, ): dataset = self._make_stub_root_dof_dataset( cutoff_pool=(3.0,), online_filter_enabled=True, allowed_prefixes=("ref_",), ) sample = dataset[0] self.assertNotIn("ft_ref_root_pos", sample.tensors) self.assertNotIn("ft_ref_rg_pos", sample.tensors) @staticmethod def _make_stub_root_dof_dataset( *, cutoff_pool=(0.0,), online_filter_enabled=False, allowed_prefixes=("ref_", "ft_ref_"), ref_vel_smoothing_sigma=2.0, ft_ref_vel_smoothing_sigma=2.0, ): dataset = Hdf5RootDofDataset.__new__(Hdf5RootDofDataset) dataset.windows = [ MotionWindow( motion_key="clip-a__start_0_len_4", shard_index=0, start=0, length=4, raw_motion_key="clip-a", window_index=0, ) ] 
dataset.clips = { "clip-a": { "metadata": { "motion_fps": 60.0, } } } dataset._progress_counter = None dataset._world_frame_transform = None dataset._file_handles = {} dataset._h5_access_counter = 0 dataset._h5_cleanup_interval = int(1e6) dataset._online_filter_enabled = bool(online_filter_enabled) dataset._online_filter_cutoff_hz_pool = tuple(cutoff_pool) dataset._allowed_prefixes = tuple(allowed_prefixes) dataset._ref_vel_smoothing_sigma = float(ref_vel_smoothing_sigma) dataset._ft_ref_vel_smoothing_sigma = float(ft_ref_vel_smoothing_sigma) dataset._fk_calls = [] shard_handle = { "ref_root_pos": torch.arange(12, dtype=torch.float32) .reshape(4, 3) .numpy(), "ref_root_rot": torch.tensor( [[0.0, 0.0, 0.0, 1.0]] * 4, dtype=torch.float32 ).numpy(), "ref_dof_pos": torch.arange(8, dtype=torch.float32) .reshape(4, 2) .numpy(), } dataset._online_filter_butter_order = 4 def fake_fk_transform( arrays, fps, prefix="ref_", vel_smoothing_sigma=2.0, ): del fps dataset._fk_calls.append((prefix, float(vel_smoothing_sigma))) root_pos = arrays[f"{prefix}root_pos"] root_rot = arrays[f"{prefix}root_rot"] arrays[f"{prefix}rg_pos"] = torch.stack( [root_pos, root_pos], dim=1 ) arrays[f"{prefix}rb_rot"] = torch.stack( [root_rot, root_rot], dim=1 ) arrays[f"{prefix}body_vel"] = torch.zeros( 4, 2, 3, dtype=torch.float32 ) arrays[f"{prefix}body_ang_vel"] = torch.zeros( 4, 2, 3, dtype=torch.float32 ) arrays[f"{prefix}dof_vel"] = torch.zeros(4, 2, dtype=torch.float32) dataset._fk_transform = fake_fk_transform dataset._get_shard_handle = lambda shard_index: shard_handle return dataset ================================================ FILE: tests/test_motion_cache_startup.py ================================================ from pathlib import Path import sys from types import SimpleNamespace import unittest from unittest import mock sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.algo.algo_base import BaseOnpolicyRL import holomotion.src.training.h5_dataloader 
as h5_dataloader_module from holomotion.src.training.h5_dataloader import MotionClipBatchCache class _FakeDataset: def __init__(self, length: int = 8) -> None: self._length = int(length) self.max_frame_length = 16 self.progress_counter = None def __len__(self) -> int: return self._length def __getitem__(self, index: int): raise AssertionError("__getitem__ should not be called in these tests") def set_progress_counter(self, counter) -> None: self.progress_counter = counter def close(self) -> None: return class MotionCacheStartupTests(unittest.TestCase): def test_motion_cache_uses_explicit_constructor_seed(self): with ( mock.patch.object( MotionClipBatchCache, "_build_dataloader", lambda self: None ), mock.patch.object( MotionClipBatchCache, "_prime_buffers", lambda self: None ), ): cache = MotionClipBatchCache( train_dataset=_FakeDataset(), batch_size=2, num_workers=0, pin_memory=False, persistent_workers=False, seed=1234, ) self.assertEqual(cache._seed, 1234) def test_setup_seeding_does_not_reinitialize_motion_cache(self): algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL) algo.config = {"seed": 100} algo.process_rank = 2 algo.command_name = "ref_motion" env_seed_calls = [] motion_cache_seed_calls = [] algo.env = SimpleNamespace( seed=lambda seed: env_seed_calls.append(seed) ) algo.command_term = SimpleNamespace( cfg=SimpleNamespace(seed=102), set_motion_cache_seed=lambda seed, reinitialize: motion_cache_seed_calls.append((seed, reinitialize)), ) BaseOnpolicyRL._setup_seeding(algo) self.assertEqual(algo.base_seed, 100) self.assertEqual(algo.seed, 102) self.assertEqual(env_seed_calls, [102]) self.assertEqual(motion_cache_seed_calls, [(102, False)]) def test_motion_cache_passes_loader_timeout_to_dataloader(self): captured_kwargs = {} class _FakeLoader: def __init__(self, *args, **kwargs) -> None: del args captured_kwargs.update(kwargs) with ( mock.patch.object(h5_dataloader_module, "DataLoader", _FakeLoader), mock.patch.object( MotionClipBatchCache, "_prime_buffers", 
lambda self: None ), ): MotionClipBatchCache( train_dataset=_FakeDataset(), batch_size=2, num_workers=0, pin_memory=False, persistent_workers=False, loader_timeout=17, ) self.assertEqual(captured_kwargs["timeout"], 17) def test_motion_cache_disables_progress_bar_in_distributed_runs(self): with ( mock.patch.object( MotionClipBatchCache, "_build_dataloader", lambda self: None ), mock.patch.object( MotionClipBatchCache, "_prime_buffers", lambda self: None ), ): cache = MotionClipBatchCache( train_dataset=_FakeDataset(), batch_size=2, num_workers=0, pin_memory=False, persistent_workers=False, sampler_world_size=8, batch_progress_bar=True, ) self.assertIs(cache._should_use_batch_progress(), False) self.assertIsNone(cache._batch_progress_counter) def test_motion_cache_keeps_progress_bar_for_local_runs(self): with ( mock.patch.object( MotionClipBatchCache, "_build_dataloader", lambda self: None ), mock.patch.object( MotionClipBatchCache, "_prime_buffers", lambda self: None ), ): cache = MotionClipBatchCache( train_dataset=_FakeDataset(), batch_size=2, num_workers=0, pin_memory=False, persistent_workers=False, sampler_world_size=1, batch_progress_bar=True, ) self.assertIs(cache._should_use_batch_progress(), True) self.assertIsNotNone(cache._batch_progress_counter) def test_motion_cache_requires_positive_loader_timeout(self): with ( mock.patch.object( MotionClipBatchCache, "_build_dataloader", lambda self: None ), mock.patch.object( MotionClipBatchCache, "_prime_buffers", lambda self: None ), ): with self.assertRaisesRegex( ValueError, "loader_timeout must be >= 0" ): MotionClipBatchCache( train_dataset=_FakeDataset(), batch_size=2, num_workers=0, pin_memory=False, persistent_workers=False, loader_timeout=-1, ) ================================================ FILE: tests/test_motion_tracking_command_reference_prefix.py ================================================ import sys import unittest from pathlib import Path from types import SimpleNamespace sys.path.insert(0, 
str(Path(__file__).resolve().parents[1])) from holomotion.src.utils.reference_prefix import resolve_reference_tensor_key class MotionTrackingCommandReferencePrefixTests(unittest.TestCase): def test_ft_ref_prefix_uses_filtered_tensor_when_present(self): resolved = resolve_reference_tensor_key( batch_tensors={"ft_ref_root_pos": SimpleNamespace()}, base_key="root_pos", prefix="ft_ref_", ) self.assertEqual(resolved, "ft_ref_root_pos") def test_ft_ref_prefix_requires_filtered_tensor(self): with self.assertRaises(KeyError): resolve_reference_tensor_key( batch_tensors={"root_pos": SimpleNamespace()}, base_key="root_pos", prefix="ft_ref_", ) def test_ref_prefix_falls_back_to_unprefixed_tensor(self): resolved = resolve_reference_tensor_key( batch_tensors={"root_pos": SimpleNamespace()}, base_key="root_pos", prefix="ref_", ) self.assertEqual(resolved, "root_pos") def test_ref_prefix_prefers_prefixed_tensor_when_present(self): resolved = resolve_reference_tensor_key( batch_tensors={ "root_pos": SimpleNamespace(), "ref_root_pos": SimpleNamespace(), }, base_key="root_pos", prefix="ref_", ) self.assertEqual(resolved, "ref_root_pos") ================================================ FILE: tests/test_motion_tracking_timing.py ================================================ import importlib.util import sys from pathlib import Path from types import ModuleType, SimpleNamespace import pytest import torch MODULE_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_motion_tracking_command.py" ) class _DummyConfig: def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs def _install_fake_motion_command_deps(monkeypatch): isaaclab_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_sim = ModuleType("isaaclab.sim") isaaclab_sim.PreviewSurfaceCfg = _DummyConfig isaaclab_sim.PhysxCfg = _DummyConfig isaaclab_sim.SimulationCfg = _DummyConfig isaaclab_math = ModuleType("isaaclab.utils.math") 
isaaclab_math.quat_apply_inverse = lambda quat, vec: vec isaaclab_math.quat_apply = lambda quat, vec: vec isaaclab_math.yaw_quat = lambda quat: quat isaaclab_math.quat_inv = lambda quat: quat isaaclab_math.quat_mul = lambda lhs, rhs: lhs isaaclab_math.sample_uniform = ( lambda low, high, shape, device=None: torch.zeros( *shape, device=device ) ) isaaclab_actuators = ModuleType("isaaclab.actuators") isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_assets.ArticulationCfg = _DummyConfig isaaclab_assets.AssetBaseCfg = _DummyConfig isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedRLEnv = object isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig isaaclab_envs.ViewerCfg = _DummyConfig isaaclab_envs_mdp_actions = ModuleType("isaaclab.envs.mdp.actions") isaaclab_envs_mdp_actions.JointEffortActionCfg = _DummyConfig isaaclab_managers = ModuleType("isaaclab.managers") isaaclab_managers.ActionTermCfg = _DummyConfig isaaclab_managers.CommandTerm = object isaaclab_managers.CommandTermCfg = _DummyConfig isaaclab_managers.EventTermCfg = _DummyConfig isaaclab_managers.ObservationGroupCfg = _DummyConfig isaaclab_managers.ObservationTermCfg = _DummyConfig isaaclab_managers.RewardTermCfg = _DummyConfig isaaclab_managers.TerminationTermCfg = _DummyConfig isaaclab_markers = ModuleType("isaaclab.markers") isaaclab_markers.VisualizationMarkers = _DummyConfig isaaclab_markers.VisualizationMarkersCfg = _DummyConfig isaaclab_markers_config = ModuleType("isaaclab.markers.config") isaaclab_markers_config.SPHERE_MARKER_CFG = SimpleNamespace( replace=lambda **kwargs: SimpleNamespace( markers={"sphere": SimpleNamespace(radius=None)}, **kwargs, ) ) isaaclab_scene = ModuleType("isaaclab.scene") isaaclab_scene.InteractiveSceneCfg = _DummyConfig isaaclab_sensors = ModuleType("isaaclab.sensors") isaaclab_sensors.ContactSensorCfg = _DummyConfig isaaclab_sensors.RayCasterCfg = 
_DummyConfig isaaclab_sensors.patterns = _DummyConfig isaaclab_terrains = ModuleType("isaaclab.terrains") isaaclab_terrains.TerrainImporterCfg = _DummyConfig isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls isaaclab_noise = ModuleType("isaaclab.utils.noise") isaaclab_noise.AdditiveUniformNoiseCfg = _DummyConfig h5_dataloader = ModuleType("holomotion.src.training.h5_dataloader") h5_dataloader.Hdf5MotionDataset = object h5_dataloader.Hdf5RootDofDataset = object h5_dataloader.MotionClipBatchCache = object h5_dataloader.build_motion_datasets_from_cfg = lambda *args, **kwargs: None rotations = ModuleType("holomotion.src.utils.isaac_utils.rotations") rotations.calc_heading_quat_inv = lambda *args, **kwargs: None rotations.get_euler_xyz = lambda *args, **kwargs: None rotations.my_quat_rotate = lambda *args, **kwargs: None rotations.quat_inverse = lambda *args, **kwargs: None rotations.quat_mul = lambda *args, **kwargs: None rotations.quat_rotate = lambda *args, **kwargs: None rotations.quat_rotate_inverse = lambda *args, **kwargs: None rotations.quaternion_to_matrix = lambda *args, **kwargs: None rotations.wrap_to_pi = lambda *args, **kwargs: None rotations.wxyz_to_xyzw = lambda x: x rotations.xyzw_to_wxyz = lambda x: x reference_prefix = ModuleType("holomotion.src.utils.reference_prefix") reference_prefix.resolve_reference_tensor_key = ( lambda batch_tensors, base_key, prefix="ref_": f"{prefix}{base_key}" ) omegaconf = ModuleType("omegaconf") omegaconf.OmegaConf = SimpleNamespace( to_container=lambda value, resolve=True: value ) loguru = ModuleType("loguru") loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None) tqdm = ModuleType("tqdm") tqdm.tqdm = lambda iterable, *args, **kwargs: iterable scipy = ModuleType("scipy") scipy_spatial = ModuleType("scipy.spatial") scipy_transform = ModuleType("scipy.spatial.transform") scipy_transform.Rotation = object for name, module in { "isaaclab.envs.mdp": isaaclab_mdp, 
"isaaclab.sim": isaaclab_sim, "isaaclab.utils.math": isaaclab_math, "isaaclab.actuators": isaaclab_actuators, "isaaclab.assets": isaaclab_assets, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp.actions": isaaclab_envs_mdp_actions, "isaaclab.managers": isaaclab_managers, "isaaclab.markers": isaaclab_markers, "isaaclab.markers.config": isaaclab_markers_config, "isaaclab.scene": isaaclab_scene, "isaaclab.sensors": isaaclab_sensors, "isaaclab.terrains": isaaclab_terrains, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.noise": isaaclab_noise, "holomotion.src.training.h5_dataloader": h5_dataloader, "holomotion.src.utils.isaac_utils.rotations": rotations, "holomotion.src.utils.reference_prefix": reference_prefix, "omegaconf": omegaconf, "loguru": loguru, "tqdm": tqdm, "scipy": scipy, "scipy.spatial": scipy_spatial, "scipy.spatial.transform": scipy_transform, }.items(): monkeypatch.setitem(sys.modules, name, module) def _load_motion_command_module(monkeypatch): _install_fake_motion_command_deps(monkeypatch) module_name = "_test_motion_tracking_timing" spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def test_immediate_next_reference_getters_use_slot_one(monkeypatch): module = _load_motion_command_module(monkeypatch) command = module.RefMotionCommand.__new__(module.RefMotionCommand) command.urdf2sim_dof_idx = torch.tensor([1, 0], dtype=torch.long) command.urdf2sim_body_idx = torch.tensor([1, 0], dtype=torch.long) command._env_origins = torch.tensor( [[10.0, 20.0, 30.0]], dtype=torch.float32 ) base_tensors = { "ref_dof_pos": torch.tensor( [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]], dtype=torch.float32 ), "ref_root_pos": torch.tensor( [[[0.0, 1.0, 2.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]], dtype=torch.float32, ), "ref_body_vel": torch.tensor( [ [ [[0.0, 0.0, 0.0], [1.0, 
1.0, 1.0]], [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0]], ] ], dtype=torch.float32, ), } command._get_ref_state_array = ( lambda base_key, prefix="ref_": base_tensors[f"{prefix}{base_key}"] ) dof_pos = command.get_ref_motion_dof_pos_immediate_next() root_pos = command.get_ref_motion_root_global_pos_immediate_next() body_lin_vel = ( command.get_ref_motion_bodylink_global_lin_vel_immediate_next() ) assert torch.allclose(dof_pos, torch.tensor([[4.0, 3.0]])) assert torch.allclose(root_pos, torch.tensor([[17.0, 28.0, 39.0]])) assert torch.allclose( body_lin_vel, torch.tensor([[[3.0, 3.0, 3.0], [2.0, 2.0, 2.0]]]), ) def test_update_command_skips_just_reset_envs(monkeypatch): module = _load_motion_command_module(monkeypatch) command = module.RefMotionCommand.__new__(module.RefMotionCommand) command.device = torch.device("cpu") command.num_envs = 3 command._frame_indices = torch.tensor([10, 20, 30], dtype=torch.long) command._swap_step_counter = 0 command._swap_pending = False command._motion_cache = SimpleNamespace(swap_interval_steps=100) command._env = SimpleNamespace( episode_length_buf=torch.tensor([5, 0, 2], dtype=torch.long) ) command._filter_env_ids_for_motion_task = lambda env_ids: env_ids command._resample_when_motion_end_cache = lambda: None command._update_ref_motion_state_from_cache = lambda env_ids=None: None command._update_command() assert torch.equal(command._frame_indices, torch.tensor([11, 20, 31])) assert command._swap_step_counter == 1 def test_update_command_resumes_advancing_after_reset_step(monkeypatch): module = _load_motion_command_module(monkeypatch) command = module.RefMotionCommand.__new__(module.RefMotionCommand) command.device = torch.device("cpu") command.num_envs = 1 command._frame_indices = torch.tensor([20], dtype=torch.long) command._swap_step_counter = 0 command._swap_pending = False command._motion_cache = SimpleNamespace(swap_interval_steps=100) command._env = 
SimpleNamespace(episode_length_buf=torch.tensor([0])) command._filter_env_ids_for_motion_task = lambda env_ids: env_ids command._resample_when_motion_end_cache = lambda: None command._update_ref_motion_state_from_cache = lambda env_ids=None: None command._update_command() assert torch.equal(command._frame_indices, torch.tensor([20])) command._env.episode_length_buf = torch.tensor([1]) command._update_command() assert torch.equal(command._frame_indices, torch.tensor([21])) def test_mpjpe_metrics_use_immediate_next_reference(monkeypatch): module = _load_motion_command_module(monkeypatch) command = module.RefMotionCommand.__new__(module.RefMotionCommand) command.device = torch.device("cpu") command.num_envs = 1 command.metrics = {} command.arm_dof_indices = [0] command.torso_dof_indices = [1] command.leg_dof_indices = [2] command.robot = SimpleNamespace( data=SimpleNamespace( joint_pos=torch.tensor([[0.1, 0.2, 0.3]], dtype=torch.float32) ) ) command.get_ref_motion_dof_pos_cur = lambda prefix="ref_": ( _ for _ in () ).throw(AssertionError("current reference should not be used")) command.get_ref_motion_dof_pos_immediate_next = ( lambda prefix="ref_": torch.tensor( [[0.1, 0.2, 0.3]], dtype=torch.float32 ) ) command._update_mpjpe_metrics() assert torch.allclose( command.metrics["Task/MPJPE_WholeBody"], torch.zeros(1) ) def test_mpkpe_metrics_use_immediate_next_reference(monkeypatch): module = _load_motion_command_module(monkeypatch) command = module.RefMotionCommand.__new__(module.RefMotionCommand) command.device = torch.device("cpu") command.num_envs = 1 command.metrics = {} command.arm_body_indices = [0] command.torso_body_indices = [1] command.leg_body_indices = [2] command.robot = SimpleNamespace( data=SimpleNamespace( body_pos_w=torch.tensor( [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], dtype=torch.float32, ) ) ) command.get_ref_motion_bodylink_global_pos_cur = lambda prefix="ref_": ( _ for _ in () ).throw(AssertionError("current reference should not be 
used")) command.get_ref_motion_bodylink_global_pos_immediate_next = ( lambda prefix="ref_": torch.tensor( [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], dtype=torch.float32, ) ) command._update_mpkpe_metrics() assert torch.allclose( command.metrics["Task/MPKPE_WholeBody"], torch.zeros(1) ) ================================================ FILE: tests/test_mujoco_filtered_ref_compat.py ================================================ import tempfile from pathlib import Path import numpy as np from omegaconf import OmegaConf from holomotion.src.evaluation.eval_mujoco_sim2sim import MujocoEvaluator from holomotion.src.evaluation.obs.obs_builder import PolicyObsBuilder PROJECT_ROOT = Path(__file__).resolve().parents[1] OBS_CONFIG_PATH = ( PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3_with_freq.yaml" ) MODULE_CONFIG_PATH = ( PROJECT_ROOT / "holomotion/config/modules/motion_tracking/tf_motrack_v3_with_ft.yaml" ) OBS_CONFIG_PATH_V2 = ( PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3_sonic_router_v2.yaml" ) MODULE_CONFIG_PATH_V2 = ( PROJECT_ROOT / "holomotion/config/modules/motion_tracking/tf_motrack_v3_wo_eepos_ref_route_v2.yaml" ) DOMAIN_RAND_CONFIG_PATH = ( PROJECT_ROOT / "holomotion/config/env/domain_randomization/domain_rand_strong.yaml" ) def _make_minimal_motion_npz(path: Path, *, include_cutoff: bool) -> None: payload = { "ref_global_translation": np.zeros((2, 1, 3), dtype=np.float32), "ref_global_rotation_quat": np.tile( np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32), (2, 1, 1), ), "ref_global_velocity": np.zeros((2, 1, 3), dtype=np.float32), "ref_global_angular_velocity": np.zeros((2, 1, 3), dtype=np.float32), "ref_dof_pos": np.zeros((2, 2), dtype=np.float32), "ref_dof_vel": np.zeros((2, 2), dtype=np.float32), } if include_cutoff: payload["filter_cutoff_hz"] = np.array( [[2.0], [3.0]], dtype=np.float32 ) np.savez(path, **payload) def 
test_policy_obs_list_accepts_shared_cutoff_term(): config = OmegaConf.merge( OmegaConf.load(OBS_CONFIG_PATH), OmegaConf.load(MODULE_CONFIG_PATH), ) evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.config = config atomic_obs_list = evaluator._get_policy_atomic_obs_list() term_names = [str(list(item.keys())[0]) for item in atomic_obs_list] assert term_names[0] == "ref_motion_filter_cutoff_hz" assert "actor_ref_gravity_projection_cur" in term_names def test_cutoff_obs_getters_use_current_frame_and_default_zero(): evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.motion_frame_idx = 1 evaluator.filter_cutoff_hz = np.array([[2.0], [3.0]], dtype=np.float32) assert evaluator._get_obs_ref_motion_filter_cutoff_hz() == np.float32(3.0) assert ( evaluator._get_obs_actor_ref_motion_filter_cutoff_hz() == np.float32(3.0) ) missing = MujocoEvaluator.__new__(MujocoEvaluator) missing.motion_frame_idx = 0 assert missing._get_obs_ref_motion_filter_cutoff_hz() == 0.0 def test_policy_obs_list_v2_uses_only_actor_schema_terms(): config = OmegaConf.merge( OmegaConf.load(OBS_CONFIG_PATH_V2), OmegaConf.load(MODULE_CONFIG_PATH_V2), OmegaConf.load(DOMAIN_RAND_CONFIG_PATH), ) evaluator = MujocoEvaluator.__new__(MujocoEvaluator) evaluator.config = config atomic_obs_list = evaluator._get_policy_atomic_obs_list() term_names = [str(list(item.keys())[0]) for item in atomic_obs_list] assert not any(name.startswith("actor_moe_router_") for name in term_names) assert "actor_ref_gravity_projection_cur" in term_names assert "actor_ref_base_linvel_fut" in term_names def test_load_specific_motion_loads_cutoff_metadata_with_zero_fallback(): with tempfile.TemporaryDirectory() as tmp_dir: with_cutoff = Path(tmp_dir) / "with_cutoff.npz" without_cutoff = Path(tmp_dir) / "without_cutoff.npz" _make_minimal_motion_npz(with_cutoff, include_cutoff=True) _make_minimal_motion_npz(without_cutoff, include_cutoff=False) evaluator = MujocoEvaluator.__new__(MujocoEvaluator) 
evaluator.load_specific_motion(with_cutoff) np.testing.assert_allclose( evaluator.filter_cutoff_hz, np.array([[2.0], [3.0]], dtype=np.float32), ) evaluator.load_specific_motion(without_cutoff) np.testing.assert_allclose( evaluator.filter_cutoff_hz, np.zeros((2, 1), dtype=np.float32), ) ================================================ FILE: tests/test_obs_norm_compile.py ================================================ import torch import torch.nn as nn from holomotion.src.modules.agent_modules import PPOTFActor from holomotion.src.modules.network_modules import EmpiricalNormalization def _make_actor_with_obs_norm(obs_dim: int = 16) -> PPOTFActor: actor = PPOTFActor.__new__(PPOTFActor) nn.Module.__init__(actor) actor.obs_norm_enabled = True actor.obs_norm_clip = 10.0 actor.obs_normalizer = EmpiricalNormalization(shape=(obs_dim,)) return actor def test_obs_norm_update_is_not_captured_by_dynamo(): actor = _make_actor_with_obs_norm() obs = torch.randn(8, 16) def normalize_with_update(x: torch.Tensor) -> torch.Tensor: return actor._normalize_actor_obs(x, True) explanation = torch._dynamo.explain(normalize_with_update)(obs) graph_code = "\n".join(graph.code for graph in explanation.graphs) assert "torch.var" not in graph_code assert "torch.mean" not in graph_code count_before_compile = actor.obs_normalizer.count.item() compiled = torch.compile(normalize_with_update, backend="eager") normalized = compiled(obs) assert normalized.shape == obs.shape assert ( actor.obs_normalizer.count.item() - count_before_compile == obs.shape[0] ) ================================================ FILE: tests/test_observation_frames.py ================================================ import importlib.util import sys from pathlib import Path from types import ModuleType, SimpleNamespace import pytest import torch OBSERVATION_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_observation.py" ) class _DummyConfig: def __init__(self, 
*args, **kwargs): self.args = args self.kwargs = kwargs class _Scene(SimpleNamespace): def __getitem__(self, key): return getattr(self, key) def _identity_quat(*shape: int) -> torch.Tensor: quat = torch.zeros(*shape, 4, dtype=torch.float32) quat[..., 0] = 1.0 return quat def _load_observation_module(monkeypatch): isaaclab = ModuleType("isaaclab") isaaclab_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_math = ModuleType("isaaclab.utils.math") isaaclab_math.quat_apply = lambda quat, vec: vec isaaclab_math.quat_apply_inverse = lambda quat, vec: vec isaaclab_math.quat_inv = lambda quat: quat isaaclab_math.matrix_from_quat = lambda quat: torch.zeros( *quat.shape[:-1], 3, 3, dtype=quat.dtype, device=quat.device ) isaaclab_math.subtract_frame_transforms = lambda t01, q01, t02, q02: ( t02 - t01, q02, ) isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None) isaaclab_noise = ModuleType("isaaclab.utils.noise") isaaclab_noise.__getattr__ = lambda name: _DummyConfig isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedRLEnv = object isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig isaaclab_envs.ViewerCfg = _DummyConfig isaaclab_sim = ModuleType("isaaclab.sim") isaaclab_sim.__getattr__ = lambda name: _DummyConfig isaaclab_actuators = ModuleType("isaaclab.actuators") isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_assets.ArticulationCfg = _DummyConfig isaaclab_assets.AssetBaseCfg = _DummyConfig isaaclab_managers = ModuleType("isaaclab.managers") isaaclab_managers.__getattr__ = lambda name: _DummyConfig isaaclab_markers = ModuleType("isaaclab.markers") isaaclab_markers.VisualizationMarkers = _DummyConfig isaaclab_markers.VisualizationMarkersCfg = _DummyConfig isaaclab_markers_config = ModuleType("isaaclab.markers.config") isaaclab_markers_config.FRAME_MARKER_CFG = _DummyConfig isaaclab_scene = ModuleType("isaaclab.scene") 
isaaclab_scene.InteractiveSceneCfg = _DummyConfig isaaclab_sensors = ModuleType("isaaclab.sensors") isaaclab_sensors.ContactSensorCfg = _DummyConfig isaaclab_sensors.RayCasterCfg = _DummyConfig isaaclab_sensors.patterns = _DummyConfig isaaclab_terrains = ModuleType("isaaclab.terrains") isaaclab_terrains.TerrainImporterCfg = _DummyConfig isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls omegaconf = ModuleType("omegaconf") omegaconf.DictConfig = dict omegaconf.ListConfig = list omegaconf.OmegaConf = SimpleNamespace( to_container=lambda value, resolve=True: value ) fake_utils_module = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_utils" ) fake_utils_module.resolve_holo_config = lambda value: value isaaclab.envs = isaaclab_envs isaaclab.sim = isaaclab_sim isaaclab.actuators = isaaclab_actuators isaaclab.assets = isaaclab_assets isaaclab.managers = isaaclab_managers isaaclab.markers = isaaclab_markers isaaclab.scene = isaaclab_scene isaaclab.sensors = isaaclab_sensors isaaclab.terrains = isaaclab_terrains isaaclab.utils = isaaclab_utils isaaclab_envs.mdp = isaaclab_mdp isaaclab_utils.math = isaaclab_math isaaclab_utils.noise = isaaclab_noise for name, module in { "isaaclab": isaaclab, "isaaclab.envs.mdp": isaaclab_mdp, "isaaclab.utils.math": isaaclab_math, "isaaclab.utils.noise": isaaclab_noise, "isaaclab.envs": isaaclab_envs, "isaaclab.sim": isaaclab_sim, "isaaclab.actuators": isaaclab_actuators, "isaaclab.assets": isaaclab_assets, "isaaclab.managers": isaaclab_managers, "isaaclab.markers": isaaclab_markers, "isaaclab.markers.config": isaaclab_markers_config, "isaaclab.scene": isaaclab_scene, "isaaclab.sensors": isaaclab_sensors, "isaaclab.terrains": isaaclab_terrains, "isaaclab.utils": isaaclab_utils, "omegaconf": omegaconf, ( "holomotion.src.env.isaaclab_components.isaaclab_utils" ): fake_utils_module, }.items(): monkeypatch.setitem(sys.modules, name, module) module_name = "_test_observation_frames" spec = 
importlib.util.spec_from_file_location( module_name, OBSERVATION_PATH ) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def test_ref_future_observations_can_limit_num_frames(monkeypatch): observation = _load_observation_module(monkeypatch) class _Command: def get_ref_motion_dof_pos_fut(self, prefix="ref_"): return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) def get_ref_motion_dof_vel_fut(self, prefix="ref_"): return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) def get_ref_motion_gravity_projection_fut(self, prefix="ref_"): return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) def get_ref_motion_base_linvel_fut(self, prefix="ref_"): return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) def get_ref_motion_base_angvel_fut(self, prefix="ref_"): return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) def get_ref_motion_root_global_pos_fut(self, prefix="ref_"): pos = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3) pos[..., 2] = torch.tensor( [[1.0, 2.0, 3.0, 4.0], [0.5, 1.5, 2.5, 3.5]], dtype=torch.float32, ) return pos class _CommandManager: def get_term(self, name): return _Command() env = SimpleNamespace( command_manager=_CommandManager(), scene=SimpleNamespace(env_origins=torch.zeros(2, 3)), ) dof_pos = observation.ObservationFunctions._get_obs_ref_dof_pos_fut( env, num_frames=2 ) dof_vel = observation.ObservationFunctions._get_obs_ref_dof_vel_fut( env, num_frames=2 ) gravity = ( observation.ObservationFunctions._get_obs_ref_gravity_projection_fut( env, num_frames=2 ) ) base_linvel = ( observation.ObservationFunctions._get_obs_ref_base_linvel_fut( env, num_frames=2 ) ) base_angvel = ( observation.ObservationFunctions._get_obs_ref_base_angvel_fut( env, num_frames=2 ) ) root_height = ( observation.ObservationFunctions._get_obs_ref_root_height_fut( env, num_frames=2 ) ) assert 
dof_pos.shape == (2, 2, 3) assert dof_vel.shape == (2, 6) assert gravity.shape == (2, 2, 3) assert base_linvel.shape == (2, 2, 3) assert base_angvel.shape == (2, 2, 3) assert root_height.shape == (2, 2, 1) torch.testing.assert_close( dof_pos, torch.tensor( [ [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], ], dtype=torch.float32, ), ) torch.testing.assert_close( dof_vel, torch.tensor( [ [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [12.0, 13.0, 14.0, 15.0, 16.0, 17.0], ], dtype=torch.float32, ), ) torch.testing.assert_close( root_height[..., 0], torch.tensor([[1.0, 2.0], [0.5, 1.5]]) ) def _make_env(): env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32) robot_data = SimpleNamespace( body_pos_w=torch.tensor( [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32 ), body_quat_w=_identity_quat(1, 2), ) robot = SimpleNamespace(body_names=["anchor", "target"], data=robot_data) command = SimpleNamespace( anchor_bodylink_name="anchor", get_ref_motion_anchor_bodylink_global_pos_cur=( lambda prefix="ref_": torch.tensor([[10.0, 0.0, 0.0]]) ), get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=( lambda prefix="ref_": _identity_quat(1) ), ) return SimpleNamespace( num_envs=1, scene=_Scene(env_origins=env_origins, robot=robot), command_manager=SimpleNamespace(get_term=lambda name: command), ) def test_global_robot_bodylink_pos_is_in_environment_frame(monkeypatch): observation = _load_observation_module(monkeypatch) env = _make_env() pos = observation.ObservationFunctions._get_obs_global_robot_bodylink_pos( env, keybody_names=["target"], ) assert torch.allclose(pos, torch.tensor([[[1.0, 0.0, 0.0]]])) def test_root_rel_robot_bodylink_pos_uses_consistent_env_frame(monkeypatch): observation = _load_observation_module(monkeypatch) env = _make_env() observation.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros( 1, 3, dtype=torch.float32 ) observation.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1) pos = ( 
observation.ObservationFunctions._get_obs_root_rel_robot_bodylink_pos( env, keybody_names=["target"], ) ) assert torch.allclose(pos, torch.tensor([[[1.0, 0.0, 0.0]]])) def test_global_anchor_pos_diff_uses_environment_frame_consistently( monkeypatch, ): observation = _load_observation_module(monkeypatch) env = _make_env() pos_diff = ( observation.ObservationFunctions._get_obs_global_anchor_pos_diff(env) ) assert torch.allclose(pos_diff, torch.zeros(1, 3)) def test_build_additive_uniform_noise_cfg_supports_optional_z_override( monkeypatch, ): observation = _load_observation_module(monkeypatch) noise = observation._build_noise_cfg( { "type": "AdditiveUniformNoiseCfg", "params": { "n_min": -0.1, "n_max": 0.1, "n_min_z": -0.02, "n_max_z": 0.03, }, } ) assert torch.equal( noise.kwargs["n_min"], torch.tensor([-0.1, -0.1, -0.02]) ) assert torch.equal(noise.kwargs["n_max"], torch.tensor([0.1, 0.1, 0.03])) def test_build_additive_uniform_noise_cfg_keeps_scalar_bounds_without_z_override( monkeypatch, ): observation = _load_observation_module(monkeypatch) noise = observation._build_noise_cfg( { "type": "AdditiveUniformNoiseCfg", "params": { "n_min": -0.1, "n_max": 0.1, }, } ) assert noise.kwargs["n_min"] == pytest.approx(-0.1) assert noise.kwargs["n_max"] == pytest.approx(0.1) ================================================ FILE: tests/test_onnx_attention_export.py ================================================ import sys import tempfile from pathlib import Path import numpy as np import onnx import onnxruntime import torch import torch.nn as nn import torch.nn.functional as F sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.modules.network_modules import ( export_safe_scaled_dot_product_attention, ) class _ExportAttentionModule(nn.Module): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: return export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, 
is_causal=False, ) class _ExportCausalAttentionModule(nn.Module): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, ) -> torch.Tensor: return export_safe_scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, ) def _export_model( export_path: Path, module: nn.Module, inputs: tuple[torch.Tensor, ...], input_names: list[str], ) -> None: torch.onnx.export( module.eval(), inputs, str(export_path), opset_version=17, input_names=input_names, output_names=["out"], dynamo=False, verbose=False, ) def _export_op_types( module: nn.Module, *inputs: torch.Tensor, input_names: list[str], ) -> list[str]: with tempfile.TemporaryDirectory() as tmp_dir: export_path = Path(tmp_dir) / "attention.onnx" _export_model(export_path, module, inputs, input_names) model = onnx.load(str(export_path)) return [node.op_type for node in model.graph.node] def _run_onnx( module: nn.Module, *inputs: torch.Tensor, input_names: list[str], ) -> np.ndarray: with tempfile.TemporaryDirectory() as tmp_dir: export_path = Path(tmp_dir) / "attention.onnx" _export_model(export_path, module, inputs, input_names) session = onnxruntime.InferenceSession( str(export_path), providers=["CPUExecutionProvider"], ) feed = { name: tensor.detach().cpu().numpy() for name, tensor in zip(input_names, inputs, strict=True) } outputs = session.run(["out"], feed) return outputs[0] def test_export_safe_attention_uses_native_bool_mask_outside_export( monkeypatch, ): captured = {} original_sdpa = F.scaled_dot_product_attention def _spy_sdpa( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, enable_gqa: bool = False, ) -> torch.Tensor: captured["mask_dtype"] = None if attn_mask is None else attn_mask.dtype return original_sdpa( q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, enable_gqa=enable_gqa, ) monkeypatch.setattr(torch.onnx, "is_in_onnx_export", lambda: False) 
monkeypatch.setattr(F, "scaled_dot_product_attention", _spy_sdpa) q = torch.randn(1, 2, 3, 4) k = torch.randn(1, 2, 5, 4) v = torch.randn(1, 2, 5, 4) mask = torch.ones(1, 1, 3, 5, dtype=torch.bool) export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) assert captured["mask_dtype"] == torch.bool def test_export_safe_attention_matches_sdpa_for_valid_masks(): torch.manual_seed(0) q = torch.randn(2, 4, 3, 8) k = torch.randn(2, 4, 5, 8) v = torch.randn(2, 4, 5, 8) mask = torch.tensor( [ [[[True, True, False, False, False]]], [[[True, False, False, False, False]]], ], dtype=torch.bool, ).expand(2, 1, 3, 5) expected = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) actual = export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) torch.testing.assert_close(actual, expected, atol=1.0e-6, rtol=1.0e-5) def test_legacy_attention_export_avoids_isnan(): torch.manual_seed(2) q = torch.randn(1, 4, 1, 8) k = torch.randn(1, 4, 16, 8) v = torch.randn(1, 4, 16, 8) mask = torch.ones(1, 1, 1, 16, dtype=torch.bool) op_types = _export_op_types( _ExportAttentionModule(), q, k, v, mask, input_names=["q", "k", "v", "mask"], ) assert "IsNaN" not in op_types def test_legacy_attention_export_ort_matches_pytorch_for_future_mask(): torch.manual_seed(3) q = torch.randn(2, 4, 3, 8) k = torch.randn(2, 4, 5, 8) v = torch.randn(2, 4, 5, 8) mask = torch.tensor( [ [[[True, True, True, False, False]]], [[[True, False, False, False, False]]], ], dtype=torch.bool, ).expand(2, 1, 3, 5) expected = export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) actual = _run_onnx( _ExportAttentionModule(), q, k, v, mask, input_names=["q", "k", "v", "mask"], ) np.testing.assert_allclose( actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5 ) def test_legacy_attention_export_ort_matches_pytorch_for_causal_path(): 
torch.manual_seed(4) q = torch.randn(2, 4, 6, 8) k = torch.randn(2, 4, 6, 8) v = torch.randn(2, 4, 6, 8) expected = export_safe_scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, ) actual = _run_onnx( _ExportCausalAttentionModule(), q, k, v, input_names=["q", "k", "v"], ) np.testing.assert_allclose( actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5 ) def test_legacy_attention_export_ort_matches_pytorch_for_kv_mask(): torch.manual_seed(5) q = torch.randn(2, 4, 1, 8) k = torch.randn(2, 4, 16, 8) v = torch.randn(2, 4, 16, 8) valid_lengths = torch.tensor([16, 5], dtype=torch.int64) mask = ( torch.arange(16, dtype=torch.int64)[None, :] < valid_lengths[:, None] ) mask = mask[:, None, None, :] expected = export_safe_scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, ) actual = _run_onnx( _ExportAttentionModule(), q, k, v, mask, input_names=["q", "k", "v", "mask"], ) np.testing.assert_allclose( actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5 ) ================================================ FILE: tests/test_onnx_export.py ================================================ import sys from types import SimpleNamespace from holomotion.src.utils.onnx_export import attach_onnx_metadata_holomotion class _FakeEntry: def __init__(self): self.key = "" self.value = "" class _FakeTensor: def __init__(self, values): self._values = values def __getitem__(self, index): return _FakeTensor(self._values[index]) def cpu(self): return self def tolist(self): return self._values def test_attach_onnx_metadata_uses_default_joint_gains(monkeypatch): model = SimpleNamespace(metadata_props=[]) fake_onnx = SimpleNamespace( load=lambda path: model, save=lambda loaded_model, path: None, StringStringEntryProto=_FakeEntry, ) monkeypatch.setitem(sys.modules, "onnx", fake_onnx) robot_data = SimpleNamespace( joint_names=["joint_a", "joint_b"], joint_stiffness=_FakeTensor([[0.0, 0.0]]), 
joint_damping=_FakeTensor([[0.0, 0.0]]), default_joint_stiffness=_FakeTensor([[10.0, 20.0]]), default_joint_damping=_FakeTensor([[1.0, 2.0]]), default_joint_pos=_FakeTensor([[0.1, -0.2]]), ) action_term = SimpleNamespace(_scale=_FakeTensor([[0.5, 0.25]])) env = SimpleNamespace( scene={"robot": SimpleNamespace(data=robot_data)}, action_manager=SimpleNamespace( get_term=lambda name: action_term, ), ) attach_onnx_metadata_holomotion(env, "dummy.onnx") metadata = {entry.key: entry.value for entry in model.metadata_props} assert metadata["joint_stiffness"] == "10.000,20.000" assert metadata["joint_damping"] == "1.000,2.000" ================================================ FILE: tests/test_plot_moe_expert_heatmap.py ================================================ import importlib.util from pathlib import Path from unittest.mock import MagicMock import numpy as np SCRIPT_PATH = ( Path(__file__).resolve().parents[1] / "not_for_commit" / "plot_moe_expert_heatmap.py" ) def _load_plot_moe_expert_heatmap_module(): spec = importlib.util.spec_from_file_location( "plot_moe_expert_heatmap", SCRIPT_PATH ) module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) return module def _write_eval_npz(path: Path) -> None: np.savez( path, robot_moe_expert_logits=np.array( [ [[0.0, 1.0, 2.0, 3.0], [1.5, 0.5, -0.5, -1.5]], [[0.1, 1.1, 2.1, 3.1], [1.0, 0.0, -1.0, -2.0]], [[0.2, 1.2, 2.2, 3.2], [0.5, -0.5, -1.5, -2.5]], [[0.3, 1.3, 2.3, 3.3], [0.0, -1.0, -2.0, -3.0]], [[0.4, 1.4, 2.4, 3.4], [-0.5, -1.5, -2.5, -3.5]], ], dtype=np.float32, ), robot_moe_expert_indices=np.array( [ [[3, 2], [0, 1]], [[3, 2], [0, 1]], [[3, 2], [0, 1]], [[3, 2], [0, 1]], [[3, 2], [0, 1]], ], dtype=np.int64, ), robot_dof_torque=np.linspace(-1.0, 1.0, 15, dtype=np.float32).reshape( 5, 3 ), robot_actions=np.linspace(-0.5, 0.5, 15, dtype=np.float32).reshape( 5, 3 ), robot_low_level_dof_torque=np.zeros((20, 3), dtype=np.float32), 
robot_low_level_torque_dt=np.array(0.01, dtype=np.float32), ) def test_plot_dump_exports_moe_heatmap_pdf(tmp_path): module = _load_plot_moe_expert_heatmap_module() npz_path = tmp_path / "demo_eval.npz" _write_eval_npz(npz_path) output_path = module.plot_dump(npz_path) assert output_path == ( tmp_path / "demo_eval_moe_expert_probability_heatmap.pdf" ) assert output_path.is_file() assert (tmp_path / "demo_eval_robot_dof_torque_line_plot.pdf").is_file() assert (tmp_path / "demo_eval_robot_actions_line_plot.pdf").is_file() def test_selected_expert_weights_are_renormalized_within_selected_ids(): module = _load_plot_moe_expert_heatmap_module() probabilities = np.array( [ [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]], [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]], ], dtype=np.float32, ) expert_indices = np.array( [ [[1, 3], [0, 2]], [[0, 2], [1, 3]], ], dtype=np.int64, ) selected_weights = module.compute_selected_expert_weights( probabilities, expert_indices ) np.testing.assert_allclose( selected_weights, np.array( [ [[1.0 / 3.0, 2.0 / 3.0], [0.875, 0.125]], [[0.5, 0.5], [0.23076923, 0.7692308]], ], dtype=np.float32, ), ) def test_selected_expert_heatmap_only_colors_activated_experts(): module = _load_plot_moe_expert_heatmap_module() probabilities = np.array( [ [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]], [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]], ], dtype=np.float32, ) expert_indices = np.array( [ [[1, 3], [0, 2]], [[0, 2], [1, 3]], ], dtype=np.int64, ) selected_heatmap = module.build_selected_expert_heatmap( probabilities, expert_indices ) np.testing.assert_allclose( selected_heatmap, np.array( [ [ [0.0, 1.0 / 3.0, 0.0, 2.0 / 3.0], [0.875, 0.0, 0.125, 0.0], ], [ [0.5, 0.0, 0.5, 0.0], [0.0, 0.23076923, 0.0, 0.7692308], ], ], dtype=np.float32, ), ) def test_collect_npz_paths_recursively_sorts_directory_entries(tmp_path): module = _load_plot_moe_expert_heatmap_module() input_dir = tmp_path / "evals" first_npz = input_dir / "z_branch" / "clip_z.npz" second_npz 
= input_dir / "a_branch" / "nested" / "clip_a.npz" second_npz.parent.mkdir(parents=True) first_npz.parent.mkdir(parents=True) _write_eval_npz(first_npz) _write_eval_npz(second_npz) (input_dir / "ignore.txt").write_text("ignore", encoding="utf-8") assert module.collect_npz_paths(input_dir) == [second_npz, first_npz] def test_plot_input_path_directory_generates_all_heatmaps_with_tqdm( tmp_path, ): module = _load_plot_moe_expert_heatmap_module() input_dir = tmp_path / "evals" npz_paths = [ input_dir / "z_branch" / "clip_z.npz", input_dir / "a_branch" / "nested" / "clip_a.npz", ] for npz_path in npz_paths: npz_path.parent.mkdir(parents=True, exist_ok=True) _write_eval_npz(npz_path) expected_output_paths = [ input_dir / "a_branch" / "nested" / "clip_a_moe_expert_probability_heatmap.pdf", input_dir / "z_branch" / "clip_z_moe_expert_probability_heatmap.pdf", ] fake_tqdm = MagicMock(side_effect=lambda iterable, **_: iterable) original_tqdm = module.tqdm module.tqdm = fake_tqdm try: output_paths = module.plot_input_path(input_dir) finally: module.tqdm = original_tqdm assert output_paths == expected_output_paths assert all(path.is_file() for path in expected_output_paths) expected_torque_paths = [ input_dir / "a_branch" / "nested" / "clip_a_robot_dof_torque_line_plot.pdf", input_dir / "z_branch" / "clip_z_robot_dof_torque_line_plot.pdf", ] assert all(path.is_file() for path in expected_torque_paths) expected_action_paths = [ input_dir / "a_branch" / "nested" / "clip_a_robot_actions_line_plot.pdf", input_dir / "z_branch" / "clip_z_robot_actions_line_plot.pdf", ] assert all(path.is_file() for path in expected_action_paths) assert list(fake_tqdm.call_args.args[0]) == sorted(npz_paths) assert fake_tqdm.call_args.kwargs == { "desc": "Generating plot PDFs", "unit": "file", "dynamic_ncols": True, } def test_plot_dump_requires_2d_robot_dof_torque(tmp_path): module = _load_plot_moe_expert_heatmap_module() npz_path = tmp_path / "bad_eval.npz" np.savez( npz_path, 
robot_moe_expert_logits=np.zeros((2, 1, 3), dtype=np.float32), robot_dof_torque=np.zeros((2,), dtype=np.float32), ) try: module.plot_dump(npz_path) except ValueError as exc: assert "robot_dof_torque must have shape [frames, dofs]" in str(exc) else: raise AssertionError( "Expected plot_dump to reject 1-D robot_dof_torque" ) ================================================ FILE: tests/test_plot_state_series.py ================================================ import importlib.util from pathlib import Path import numpy as np SCRIPT_PATH = ( Path(__file__).resolve().parents[1] / "not_for_commit" / "plot_state_series.py" ) def _load_plot_state_series_module(): spec = importlib.util.spec_from_file_location( "plot_state_series", SCRIPT_PATH ) module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) return module def test_plot_dump_exports_time_matched_scalar_series(tmp_path): module = _load_plot_state_series_module() robot_config_path = tmp_path / "robot.yaml" robot_config_path.write_text( "robot:\n dof_names:\n - joint_a\n - joint_b\n", encoding="utf-8", ) npz_path = tmp_path / "demo_eval.npz" np.savez( npz_path, robot_dof_torque=np.arange(10, dtype=np.float32).reshape(5, 2), robot_dof_acc=np.arange(10, 20, dtype=np.float32).reshape(5, 2), robot_action_rate=np.linspace(0.0, 1.0, 5, dtype=np.float32), reward=np.linspace(1.0, 2.0, 5, dtype=np.float32), bad_scalar=np.array([1.0, 2.0], dtype=np.float32), metadata=np.array("demo", dtype=" events.index("actor_optimizer") assert events.index("override_sigma") > events.index("critic_optimizer") algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict) def test_ppo_load_skips_optimizer_restore_during_offline_eval(): algo = PPO.__new__(PPO) algo.is_main_process = False algo.is_offline_eval = True algo.device = torch.device("cpu") algo.actor = nn.Linear(1, 1) algo.critic = nn.Linear(1, 1) algo.accelerator = SimpleNamespace(unwrap_model=lambda model: model) algo.config 
= {} algo._load_extra_checkpoint_state = mock.Mock() algo._resolve_model_file_path = ( lambda ckpt_path, model_name: f"{ckpt_path}:{model_name}" ) algo._load_accelerate_model = mock.Mock() algo._maybe_override_loaded_actor_sigma = mock.Mock() algo.actor_optimizer = mock.Mock() algo.critic_optimizer = mock.Mock() loaded_dict = { "actor_optimizer_state_dict": {"state": {"stale": {}}}, "critic_optimizer_state_dict": {"state": {"stale": {}}}, "iter": 321, "infos": {"source": "offline-eval"}, } with mock.patch( "holomotion.src.algo.ppo.torch.load", return_value=loaded_dict ): infos = algo.load("checkpoint.pt") assert infos == {"source": "offline-eval"} assert algo.current_learning_iteration == 321 algo.actor_optimizer.load_state_dict.assert_not_called() algo.critic_optimizer.load_state_dict.assert_not_called() algo._maybe_override_loaded_actor_sigma.assert_called_once_with() algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict) def test_ppo_load_skips_incompatible_optimizer_state_restore(): algo = PPO.__new__(PPO) algo.is_main_process = False algo.is_offline_eval = False algo.device = torch.device("cpu") algo.actor = nn.Linear(1, 1) algo.critic = nn.Linear(1, 1) algo.accelerator = SimpleNamespace(unwrap_model=lambda model: model) algo.config = {} algo._load_extra_checkpoint_state = mock.Mock() algo._resolve_model_file_path = ( lambda ckpt_path, model_name: f"{ckpt_path}:{model_name}" ) algo._load_accelerate_model = mock.Mock() algo._maybe_override_loaded_actor_sigma = mock.Mock() algo.actor_optimizer = mock.Mock() algo.actor_optimizer.state_dict.return_value = { "state": {}, "param_groups": [{"params": [0]}], } algo.actor_optimizer.load_state_dict.side_effect = AssertionError( "incompatible actor optimizer state should be skipped" ) algo.critic_optimizer = mock.Mock() algo.critic_optimizer.state_dict.return_value = { "state": {}, "param_groups": [{"params": [0]}], } loaded_dict = { "actor_optimizer_state_dict": { "state": {0: {"step": torch.tensor(1)}}, 
"param_groups": [{"params": [0, 1]}], }, "critic_optimizer_state_dict": { "state": {0: {"step": torch.tensor(2)}}, "param_groups": [{"params": [0]}], }, "iter": 77, "infos": {"source": "resume-training"}, } with mock.patch( "holomotion.src.algo.ppo.torch.load", return_value=loaded_dict ): infos = algo.load("checkpoint.pt") assert infos == {"source": "resume-training"} assert algo.current_learning_iteration == 77 algo.actor_optimizer.load_state_dict.assert_not_called() algo.critic_optimizer.load_state_dict.assert_called_once_with( loaded_dict["critic_optimizer_state_dict"] ) algo._maybe_override_loaded_actor_sigma.assert_called_once_with() algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict) def test_checkpoint_state_to_cpu_moves_nested_tensors(): source = { "state": { 0: { "exp_avg": torch.tensor([1.0, 2.0], requires_grad=True), "exp_avg_sq": torch.tensor([3.0, 4.0]), } }, "param_groups": [{"lr": 1.0e-3}], "step_tensor": torch.tensor([5]), } converted = _checkpoint_state_to_cpu(source) assert converted is not source assert converted["state"] is not source["state"] assert converted["state"][0]["exp_avg"].device.type == "cpu" assert converted["state"][0]["exp_avg_sq"].device.type == "cpu" assert converted["step_tensor"].device.type == "cpu" assert converted["state"][0]["exp_avg"].requires_grad is False torch.testing.assert_close( converted["state"][0]["exp_avg"], source["state"][0]["exp_avg"].detach(), ) ================================================ FILE: tests/test_ppo_entropy_annealing.py ================================================ from pathlib import Path import sys from types import SimpleNamespace import pytest sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.algo.algo_base import BaseOnpolicyRL from holomotion.src.algo.ppo import PPO def _build_entropy_algo( *, initial_entropy_coef: float, anneal_entropy: bool, zero_entropy_point: float, current_learning_iteration: int, total_learning_iterations: int, 
num_learning_iterations: int = 0, ): algo = PPO.__new__(PPO) algo.initial_entropy_coef = float(initial_entropy_coef) algo.anneal_entropy = bool(anneal_entropy) algo.zero_entropy_point = float(zero_entropy_point) algo.current_learning_iteration = int(current_learning_iteration) algo.total_learning_iterations = int(total_learning_iterations) algo.num_learning_iterations = int(num_learning_iterations) return algo def test_entropy_coef_is_constant_when_annealing_disabled(): algo = _build_entropy_algo( initial_entropy_coef=5.0e-3, anneal_entropy=False, zero_entropy_point=1.0, current_learning_iteration=50, total_learning_iterations=100, ) assert algo._get_effective_entropy_coef() == pytest.approx(5.0e-3) def test_entropy_coef_decays_and_respects_resumed_total_iterations(): algo = _build_entropy_algo( initial_entropy_coef=5.0e-3, anneal_entropy=True, zero_entropy_point=1.0, current_learning_iteration=123, total_learning_iterations=133, num_learning_iterations=10, ) expected = 5.0e-3 * max(0.0, 1.0 - 123.0 / 133.0) assert algo._get_effective_entropy_coef() == pytest.approx(expected) def test_entropy_coef_clamps_to_zero_at_and_after_zero_point(): algo = _build_entropy_algo( initial_entropy_coef=5.0e-3, anneal_entropy=True, zero_entropy_point=0.75, current_learning_iteration=75, total_learning_iterations=100, ) assert algo._get_effective_entropy_coef() == pytest.approx(0.0) algo.current_learning_iteration = 90 assert algo._get_effective_entropy_coef() == pytest.approx(0.0) @pytest.mark.parametrize( ("initial_entropy_coef", "anneal_entropy", "zero_entropy_point"), [ (-1.0, False, 1.0), (1.0, True, 0.0), (1.0, True, -0.1), (1.0, True, 1.1), ], ) def test_validate_entropy_schedule_config_rejects_invalid_values( initial_entropy_coef: float, anneal_entropy: bool, zero_entropy_point: float, ): with pytest.raises(ValueError): PPO._validate_entropy_schedule_config( initial_entropy_coef=initial_entropy_coef, anneal_entropy=anneal_entropy, zero_entropy_point=zero_entropy_point, ) def 
test_learn_sets_current_iteration_before_each_update():
    """learn() sets the iteration counters before every update() call.

    Starting at iteration 5 with 3 iterations to run, update() must observe
    iterations 5, 6, 7. It must also see total_learning_iterations == 8
    (5 + 3) on every call. All other collaborators are stubbed with no-ops.
    """
    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)
    algo.env = SimpleNamespace(reset_all=lambda: ({},))
    algo._wrap_obs_dict = lambda obs_dict: obs_dict
    algo._ensure_storage = lambda obs_td: None
    algo.train_mode = lambda: None
    algo.rollout_policy = lambda obs_td: obs_td
    algo.log_dir = "/tmp/holomotion-test"
    algo.num_learning_iterations = 3
    algo.current_learning_iteration = 5
    algo.total_learning_iterations = 0
    algo.log_interval = 100
    algo.save_interval = 100
    algo.is_main_process = False
    algo.ep_infos = []
    algo._post_iteration_hook = lambda it: None
    algo._post_training_hook = lambda: None
    algo._release_cuda_cache = lambda: None
    algo.save = lambda *args, **kwargs: None
    algo.accelerator = SimpleNamespace(
        wait_for_everyone=lambda: None,
        end_training=lambda: None,
    )
    observed_iterations = []
    observed_totals = []

    def _update():
        # Record the loop counters as they are when update() runs.
        observed_iterations.append(algo.current_learning_iteration)
        observed_totals.append(algo.total_learning_iterations)
        return {}

    algo.update = _update
    BaseOnpolicyRL.learn(algo)
    assert observed_iterations == [5, 6, 7]
    assert observed_totals == [8, 8, 8]


================================================
FILE: tests/test_ppo_symmetry_loss.py
================================================
from contextlib import nullcontext
from types import ModuleType, SimpleNamespace
import sys

import pytest
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from tensordict import TensorDict

from holomotion.src.algo.ppo import PPO


class _DummyAccelerator:
    """Single-process stand-in for an accelerator object.

    autocast is a no-op context, backward and gradient clipping delegate to
    plain torch, and reduce returns its input unchanged.
    """

    def autocast(self):
        return nullcontext()

    def backward(self, loss):
        loss.backward()

    def clip_grad_norm_(self, parameters, max_norm):
        torch.nn.utils.clip_grad_norm_(list(parameters), max_norm)

    def reduce(self, tensor, reduction="mean"):
        # Nothing to reduce across in a single process.
        return tensor


class _DummyActor(nn.Module):
    """Actor whose Gaussian parameters ignore the observation.

    mu is a learnable constant vector (initialised to 0.25) and sigma is
    exp(log_std).
    """

    def __init__(self, num_actions: int, mirror_offset: float):
        super().__init__()
        self.mu_param = nn.Parameter(torch.full((num_actions,), 0.25))
        self.log_std =
nn.Parameter(torch.zeros(num_actions)) self.mirror_offset = float(mirror_offset) def forward( self, obs_td: TensorDict, actions: torch.Tensor | None = None, mode: str = "sampling", *, update_obs_norm: bool = True, ) -> TensorDict: del obs_td, update_obs_norm batch_size = int(actions.shape[0]) if actions is not None else 2 mu = self.mu_param.unsqueeze(0).expand(batch_size, -1) sigma = torch.exp(self.log_std).unsqueeze(0).expand(batch_size, -1) out = TensorDict({}, batch_size=[batch_size]) out.set("mu", mu) out.set("sigma", sigma) if mode == "inference": out.set("actions", mu + self.mirror_offset) return out if actions is None: actions = mu out.set("actions", actions) zero_with_grad = mu.sum(dim=-1) * 0.0 out.set("actions_log_prob", zero_with_grad) out.set("entropy", zero_with_grad) return out class _DummyCritic(nn.Module): def __init__(self): super().__init__() self.value = nn.Parameter(torch.tensor([0.1], dtype=torch.float32)) def forward(self, obs_td: TensorDict, *, update_obs_norm: bool = True): del obs_td, update_obs_norm batch_size = 2 out = TensorDict({}, batch_size=[batch_size]) out.set("values", self.value.view(1, 1).expand(batch_size, 1)) return out class _SingleBatchStorage: def __init__(self, batch): self._batch = batch self.data = { "returns": torch.zeros(2, 1, 1, dtype=torch.float32), "values": torch.zeros(2, 1, 1, dtype=torch.float32), } self.num_envs = 1 self.num_transitions_per_env = 2 self.cleared = False def iter_minibatches(self, num_mini_batches: int, num_epochs: int): del num_mini_batches, num_epochs yield self._batch def clear(self): self.cleared = True def _install_mirror_stub(): module = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_observation" ) class MirrorFunctions: @staticmethod def mirror_dof( x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor ): perm = perm.to(device=x.device, dtype=torch.long) sign = sign.to(device=x.device, dtype=x.dtype) mirrored = torch.index_select(x, dim=x.ndim - 1, index=perm) view_shape = 
[1] * (mirrored.ndim - 1) + [int(sign.numel())] return mirrored * sign.view(*view_shape) @staticmethod def mirror_action( actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor ): return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign) @staticmethod def mirror_vec3(x: torch.Tensor): sign = torch.tensor( [1.0, -1.0, 1.0], device=x.device, dtype=x.dtype ) view_shape = [1] * (x.ndim - 1) + [3] return x * sign.view(*view_shape) @staticmethod def mirror_axial_vec3(x: torch.Tensor): sign = torch.tensor( [-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype ) view_shape = [1] * (x.ndim - 1) + [3] return x * sign.view(*view_shape) @staticmethod def mirror_velocity_command(x: torch.Tensor): if x.shape[-1] == 3: sign = torch.tensor( [1.0, -1.0, -1.0], device=x.device, dtype=x.dtype ) else: sign = torch.tensor( [1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype ) view_shape = [1] * (x.ndim - 1) + [int(sign.numel())] return x * sign.view(*view_shape) module.MirrorFunctions = MirrorFunctions sys.modules[module.__name__] = module def test_setup_symmetry_builds_expected_dof_permutation_and_signs(): _install_mirror_stub() algo = PPO.__new__(PPO) algo.device = torch.device("cpu") algo.num_actions = 5 algo.command_name = "base_velocity" algo.symmetry_loss_enabled = True algo.is_main_process = False algo.config = OmegaConf.create( { "module_dict": { "actor": { "obs_schema": { "flattened_obs": { "seq_len": 2, "terms": ["unified/actor_dof_pos"], } } } }, "symmetry_loss": { "enabled": True, "coef": 0.1, "dof_sign_by_name": { "left_hip_pitch_joint": 1.0, "right_hip_pitch_joint": 1.0, "waist_yaw_joint": -1.0, "left_knee_joint": 1.0, "right_knee_joint": 1.0, }, }, } ) algo.env = SimpleNamespace( _env=SimpleNamespace( scene={ "robot": SimpleNamespace( joint_names=[ "left_hip_pitch_joint", "right_hip_pitch_joint", "waist_yaw_joint", "left_knee_joint", "right_knee_joint", ] ) } ) ) algo.env_config = OmegaConf.create( { "config": { "robot": { "dof_sign_by_name": { 
"left_hip_pitch_joint": 1.0, "right_hip_pitch_joint": 1.0, "waist_yaw_joint": -1.0, "left_knee_joint": 1.0, "right_knee_joint": 1.0, } }, "obs": { "obs_groups": { "unified": { "atomic_obs_list": [ { "actor_dof_pos": { "mirror_func": "mirror_dof", } } ] } } }, } } ) algo._setup_symmetry() assert algo._sym_dof_perm.tolist() == [1, 0, 2, 4, 3] assert algo._sym_dof_sign.tolist() == [1.0, 1.0, -1.0, 1.0, 1.0] def test_mirror_actor_obs_uses_slash_qualified_actor_terms_only(): _install_mirror_stub() algo = PPO.__new__(PPO) algo.command_name = "base_velocity" algo.symmetry_loss_enabled = True algo.symmetry_loss_coef = 0.1 algo._obs_mirror_map = { "unified/actor_velocity_command": lambda x: x * 2.0, "unified/actor_dof_pos": lambda x: x + 1.0, } obs_td = TensorDict.from_dict( { "unified": { "actor_velocity_command": torch.tensor( [[[1.0, 2.0, 3.0]]], dtype=torch.float32 ), "actor_dof_pos": torch.tensor( [[[0.1, 0.2]]], dtype=torch.float32 ), "critic_dof_pos": torch.tensor( [[9.0, 8.0]], dtype=torch.float32 ), } }, batch_size=[1], device="cpu", ) mirrored = algo._mirror_actor_obs(obs_td) torch.testing.assert_close( mirrored["unified", "actor_velocity_command"], torch.tensor([[[2.0, 4.0, 6.0]]], dtype=torch.float32), ) torch.testing.assert_close( mirrored["unified", "actor_dof_pos"], torch.tensor([[[1.1, 1.2]]], dtype=torch.float32), ) torch.testing.assert_close( mirrored["unified", "critic_dof_pos"], obs_td["unified", "critic_dof_pos"], ) def test_update_reports_symmetry_loss_only_for_velocity_tracking(): algo = PPO.__new__(PPO) algo.device = torch.device("cpu") algo.accelerator = _DummyAccelerator() algo.actor = _DummyActor(num_actions=2, mirror_offset=1.0) algo.critic = _DummyCritic() algo.actor_optimizer = torch.optim.SGD(algo.actor.parameters(), lr=0.01) algo.critic_optimizer = torch.optim.SGD(algo.critic.parameters(), lr=0.01) algo.storage = _SingleBatchStorage( SimpleNamespace( obs=TensorDict.from_dict( { "unified": { "actor_dof_pos": torch.zeros(2, 1, 2), 
"critic_dof_pos": torch.zeros(2, 2), } }, batch_size=[2], device="cpu", ), actions=torch.zeros(2, 2), values=torch.zeros(2, 1), advantages=torch.zeros(2, 1), returns=torch.zeros(2, 1), actions_log_prob=torch.zeros(2, 1), mu=torch.zeros(2, 2), sigma=torch.ones(2, 2), ) ) algo.value_loss_coef = 1.0 algo.clip_param = 0.2 algo.max_grad_norm = 1.0 algo.schedule = "fixed" algo.desired_kl = None algo.distributed_update_mode = "legacy" algo.num_mini_batches = 1 algo.num_learning_epochs = 1 algo.configured_num_mini_batches = 1 algo.requested_num_mini_batches = 1 algo.distributed_lr_scale_factor = 1.0 algo.entropy_coef = 0.0 algo.initial_entropy_coef = 0.0 algo.anneal_entropy = False algo.use_clipped_value_loss = False algo.actor_learning_rate = 1.0e-3 algo.critic_learning_rate = 1.0e-3 algo.global_advantage_norm = True algo.is_distributed = False algo.symmetry_loss_enabled = True algo.symmetry_loss_coef = 0.5 algo._mirror_actor_obs = lambda obs_td: obs_td algo._mirror_env_action = lambda actions: actions algo._post_update_hook = lambda loss_dict: None algo.command_name = "base_velocity" velocity_loss = algo.update() assert velocity_loss["symmetry_loss"] == pytest.approx(1.0) algo.storage = _SingleBatchStorage(algo.storage._batch) algo.command_name = "ref_motion" non_velocity_loss = algo.update() assert "symmetry_loss" not in non_velocity_loss ================================================ FILE: tests/test_ppo_tf_aux_keybody.py ================================================ import copy import sys from types import ModuleType, SimpleNamespace from unittest import mock import pytest import torch import torch.nn as nn import torch.nn.functional as F from holomotion.src.algo.algo_utils import PpoAuxTransition, RolloutStorage from holomotion.src.algo.ppo import PPO from holomotion.src.algo.ppo_tf import PPOTF from holomotion.src.modules.agent_modules import ( PPOTFActor, TensorDictAssembler, ) from holomotion.src.modules.network_modules import GroupedMoETransformerPolicy from 
holomotion.src.modules.network_modules import GroupedMoEBlock from holomotion.src.modules.network_modules import ModernTransformerBlock from tensordict import TensorDict def _make_aux_policy( *, denoise_root_lin_vel_weight: float = 1.0e-2, denoise_root_ang_vel_weight: float = 1.0e-2, denoise_dof_pos_weight: float = 1.0e-2, ) -> GroupedMoETransformerPolicy: module_config = { "num_fine_experts": 2, "num_shared_experts": 1, "top_k": 1, "routing_score_fn": "softmax", "routing_scale": 1.0, "use_dynamic_bias": False, "bias_update_rate": 0.001, "expert_bias_clip": 0.0, "obs_embed_mlp_hidden": 16, "d_model": 8, "n_heads": 2, "n_kv_heads": 1, "use_gated_attn": False, "n_layers": 1, "ff_mult": 1.0, "ff_mult_dense": 2, "attn_dropout": 0.0, "mlp_dropout": 0.0, "max_ctx_len": 4, "aux_state_pred": { "enabled": True, "w_denoise_ref_root_lin_vel": denoise_root_lin_vel_weight, "w_denoise_ref_root_ang_vel": denoise_root_ang_vel_weight, "w_denoise_ref_dof_pos": denoise_dof_pos_weight, "keybody_contact_names": [ "left_knee_link", "right_knee_link", ], "keybody_rel_pos_names": [ "left_knee_link", "right_knee_link", ], }, } return GroupedMoETransformerPolicy( input_dim=6, output_dim=4, module_config_dict=module_config, ) def _make_aux_actor() -> PPOTFActor: actor = PPOTFActor.__new__(PPOTFActor) nn.Module.__init__(actor) actor.actor_module = _make_aux_policy() actor.aux_state_pred_enabled = True actor.aux_router_command_recon_enabled = False actor.aux_router_switch_penalty_enabled = False actor.obs_norm_enabled = False actor.obs_normalizer = nn.Identity() actor.obs_norm_clip = 0.0 actor.actor_obs_transforms = [] actor.assembler = TensorDictAssembler( {"flat_obs": {"seq_len": 1, "terms": ["flat_obs"]}}, output_mode="flat", ) actor.min_sigma = 0.01 actor.max_sigma = 1.0 actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32)) return actor def _make_aux_command_policy( *, n_layers: int = 3, dense_layer_at_last: bool = False, enable_aux_router_command_recon: bool = True, 
freeze_router: bool = False, ) -> GroupedMoETransformerPolicy: module_config = { "num_fine_experts": 2, "num_shared_experts": 1, "top_k": 1, "routing_score_fn": "softmax", "routing_scale": 1.0, "use_dynamic_bias": False, "bias_update_rate": 0.001, "expert_bias_clip": 0.0, "obs_embed_mlp_hidden": 16, "d_model": 8, "n_heads": 2, "n_kv_heads": 1, "use_gated_attn": False, "n_layers": n_layers, "ff_mult": 1.0, "ff_mult_dense": 2, "attn_dropout": 0.0, "mlp_dropout": 0.0, "max_ctx_len": 4, "dense_layer_at_last": dense_layer_at_last, "freeze_router": freeze_router, "aux_router_command_recon": { "enabled": enable_aux_router_command_recon, "output_dim": 5, "hidden_dim": 7, }, } return GroupedMoETransformerPolicy( input_dim=6, output_dim=4, module_config_dict=module_config, ) def _make_temporal_aux_actor() -> PPOTFActor: actor = PPOTFActor.__new__(PPOTFActor) nn.Module.__init__(actor) actor.actor_module = _make_aux_command_policy() actor.aux_state_pred_enabled = False actor.aux_router_command_recon_enabled = True actor.aux_router_switch_penalty_enabled = True actor.obs_norm_enabled = False actor.obs_normalizer = nn.Identity() actor.obs_norm_clip = 0.0 actor.actor_obs_transforms = [] actor.assembler = TensorDictAssembler( {"flat_obs": {"seq_len": 1, "terms": ["flat_obs"]}}, output_mode="flat", ) actor.min_sigma = 0.01 actor.max_sigma = 1.0 actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32)) return actor def _make_temporal_only_aux_actor() -> PPOTFActor: actor = _make_temporal_aux_actor() actor.aux_router_command_recon_enabled = False return actor def test_rollout_storage_allocates_ref_and_robot_keybody_targets(): original_tokens = dict(PpoAuxTransition.SHAPE_TOKENS) PpoAuxTransition.SHAPE_TOKENS["C"] = 2 PpoAuxTransition.SHAPE_TOKENS["K"] = 8 try: obs_template = TensorDict( {"flat_obs": torch.zeros(2, 5)}, batch_size=[2], ) storage = RolloutStorage( num_envs=2, num_transitions_per_env=3, obs_template=obs_template, actions_shape=(4,), 
transition_cls=PpoAuxTransition, ) finally: PpoAuxTransition.SHAPE_TOKENS = original_tokens assert storage.data["gt_ref_keybody_rel_pos"].shape == (3, 2, 8, 3) assert storage.data["gt_robot_keybody_rel_pos"].shape == (3, 2, 8, 3) assert storage.data["gt_denoise_ref_root_lin_vel"].shape == (3, 2, 3) assert storage.data["gt_denoise_ref_root_ang_vel"].shape == (3, 2, 3) assert storage.data["gt_denoise_ref_dof_pos"].shape == (3, 2, 4) def test_grouped_moe_policy_returns_keybody_position_predictions(): policy = _make_aux_policy() pre_moe_hidden = torch.randn(2, 3, policy.d_model) outputs = policy.predict_aux_from_pre_moe(pre_moe_hidden) assert outputs["base_lin_vel_loc"].shape == (2, 3, 3) assert outputs["base_lin_vel_log_std"].shape == (2, 3, 3) assert outputs["root_height_loc"].shape == (2, 3, 1) assert outputs["root_height_log_std"].shape == (2, 3, 1) assert outputs["keybody_contact_logits"].shape == (2, 3, 2) assert outputs["ref_keybody_rel_pos"].shape == (2, 3, 2, 3) assert outputs["robot_keybody_rel_pos"].shape == (2, 3, 2, 3) assert outputs["denoise_ref_root_lin_vel_residual"].shape == (2, 3, 3) assert outputs["denoise_ref_root_ang_vel_residual"].shape == (2, 3, 3) assert outputs["denoise_ref_dof_pos_residual"].shape == (2, 3, 4) def test_grouped_moe_policy_omits_denoise_predictions_when_weights_zero(): policy = _make_aux_policy( denoise_root_lin_vel_weight=0.0, denoise_root_ang_vel_weight=0.0, denoise_dof_pos_weight=0.0, ) pre_moe_hidden = torch.randn(2, 3, policy.d_model) outputs = policy.predict_aux_from_pre_moe(pre_moe_hidden) assert "denoise_ref_root_lin_vel_residual" not in outputs assert "denoise_ref_root_ang_vel_residual" not in outputs assert "denoise_ref_dof_pos_residual" not in outputs def test_ppotf_actor_sequence_logp_emits_actor_facing_dof_denoise_keys(): actor = _make_aux_actor() obs_td = TensorDict( {"flat_obs": torch.randn(2, 3, 6, dtype=torch.float32)}, batch_size=[2, 3], ) actions = torch.randn(2, 3, 4, dtype=torch.float32) attn_mask = 
torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    outputs = actor(
        obs_td,
        actions=actions,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=False,
    )
    assert outputs["aux_denoise_ref_dof_pos_residual"].shape == (2, 3, 4)
    assert "aux_denoise_ref_keybody_rel_pos_loc" not in outputs.keys()
    assert "aux_denoise_ref_keybody_rel_pos_log_std" not in outputs.keys()


def test_grouped_moe_policy_default_layout_keeps_dense_first_and_moe_tail():
    """By default layer 0 is dense and every later layer is a GroupedMoEBlock."""
    policy = _make_aux_command_policy(
        enable_aux_router_command_recon=False,
    )
    assert len(policy.layers) == 3
    assert isinstance(policy.layers[0], ModernTransformerBlock)
    assert all(
        isinstance(layer, GroupedMoEBlock) for layer in policy.layers[1:]
    )
    assert policy._num_moe_layers == 2
    assert policy._last_moe_layer_idx == 2


def test_grouped_moe_policy_dense_layer_at_last_keeps_only_middle_layers_moe():
    """dense_layer_at_last=True makes the first and last layers dense.

    With 4 layers only layers 1-2 are MoE. router_features has 4 channels,
    presumably 2 MoE layers x 2 fine experts.
    """
    policy = _make_aux_command_policy(
        n_layers=4,
        dense_layer_at_last=True,
    )
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32)
    # Causal mask, broadcast to the batch of 2.
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features = policy.sequence_mu(
        obs_seq,
        attn_mask=attn_mask,
        return_router_features=True,
    )
    assert isinstance(policy.layers[0], ModernTransformerBlock)
    assert isinstance(policy.layers[1], GroupedMoEBlock)
    assert isinstance(policy.layers[2], GroupedMoEBlock)
    assert isinstance(policy.layers[3], ModernTransformerBlock)
    assert policy._num_moe_layers == 2
    assert policy._last_moe_layer_idx == 2
    assert router_features.shape == (2, 3, 4)


def test_grouped_moe_policy_dense_layer_at_last_allows_shallow_fully_dense():
    """With 2 layers and dense_layer_at_last, both layers are dense (no MoE)."""
    policy = _make_aux_command_policy(
        n_layers=2,
        dense_layer_at_last=True,
        enable_aux_router_command_recon=False,
    )
    assert len(policy.layers) == 2
    assert all(
        isinstance(layer, ModernTransformerBlock) for layer in policy.layers
    )
    assert policy._num_moe_layers == 0
    assert policy._last_moe_layer_idx is None


def test_grouped_moe_policy_command_recon_uses_live_router_features():
    """Router features keep their autograd graph.

    Backprop from the command-recon prediction must reach the first MoE
    layer's router weights.
    """
    policy = _make_aux_command_policy()
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features = policy.sequence_mu(
        obs_seq,
        attn_mask=attn_mask,
        return_router_features=True,
    )
    pred = policy.predict_aux_router_command_from_router_features(
        router_features
    )
    assert policy._num_moe_layers == 2
    assert router_features.shape == (2, 3, 4)
    assert pred.shape == (2, 3, 5)
    assert router_features.requires_grad
    pred.sum().backward()
    first_moe = next(
        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)
    )
    assert first_moe.last_router_distribution is not None
    assert first_moe.last_router_distribution.requires_grad
    assert first_moe.router.weight.grad is not None


def test_grouped_moe_policy_freeze_router_detaches_router_features_and_params():
    """freeze_router=True freezes router weights and detaches router outputs.

    After backprop from the recon prediction, the router weight still has
    no gradient.
    """
    policy = _make_aux_command_policy(freeze_router=True)
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features, router_temporal_features = policy.sequence_mu(
        obs_seq,
        attn_mask=attn_mask,
        return_router_features=True,
        return_router_temporal_features=True,
    )
    pred = policy.predict_aux_router_command_from_router_features(
        router_features
    )
    first_moe = next(
        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)
    )
    assert first_moe.freeze_router is True
    assert first_moe.router.weight.requires_grad is False
    assert router_features.requires_grad is False
    assert router_temporal_features.requires_grad is False
    pred.sum().backward()
    assert first_moe.last_router_distribution is not None
    assert first_moe.last_router_distribution.requires_grad is False
    assert first_moe.last_router_logits is not None
    assert first_moe.last_router_logits.requires_grad is False
    assert first_moe.router.weight.grad is None


def test_grouped_moe_policy_loads_legacy_aux_command_recon_head_keys():
    """State dicts using the legacy "aux_command_recon_head." prefix load.

    Loading uses strict=True, and the legacy values must land on the
    current "aux_router_command_recon_head." parameters.
    """
    policy =
_make_aux_command_policy(enable_aux_router_command_recon=True) state_dict = copy.deepcopy(policy.state_dict()) expected_tensors = {} for key in list(state_dict.keys()): if "aux_router_command_recon_head." not in key: continue legacy_key = key.replace( "aux_router_command_recon_head.", "aux_command_recon_head.", ) legacy_value = torch.randn_like(state_dict[key]) expected_tensors[key] = legacy_value state_dict[legacy_key] = legacy_value del state_dict[key] result = policy.load_state_dict(state_dict, strict=True) assert result.missing_keys == [] assert result.unexpected_keys == [] for key, expected in expected_tensors.items(): actual = policy.state_dict()[key] assert torch.allclose(actual, expected) def test_grouped_moe_policy_ignores_legacy_aux_command_recon_head_keys_when_disabled(): policy = _make_aux_command_policy(enable_aux_router_command_recon=False) legacy_policy = _make_aux_command_policy( enable_aux_router_command_recon=True ) state_dict = copy.deepcopy(policy.state_dict()) for key, value in legacy_policy.state_dict().items(): if "aux_router_command_recon_head." 
not in key: continue legacy_key = key.replace( "aux_router_command_recon_head.", "aux_command_recon_head.", ) state_dict[legacy_key] = value.clone() result = policy.load_state_dict(state_dict, strict=True) assert result.missing_keys == [] assert result.unexpected_keys == [] assert policy.aux_router_command_recon_head is None def test_grouped_moe_policy_clears_router_cache_before_deepcopy(): policy = _make_aux_command_policy() obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) _, router_features = policy.sequence_mu( obs_seq, attn_mask=attn_mask, return_router_features=True, ) pred = policy.predict_aux_router_command_from_router_features( router_features ) pred.sum().backward() first_moe = next( layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock) ) assert first_moe.last_router_distribution is not None policy.clear_router_distribution_cache() copied = copy.deepcopy(policy) copied_first_moe = next( layer for layer in copied.layers if isinstance(layer, GroupedMoEBlock) ) assert copied_first_moe.last_router_distribution is None def test_grouped_moe_block_tracks_least_utilized_expert_stats(): block = GroupedMoEBlock( d_model=8, n_heads=2, n_kv_heads=1, num_fine_experts=4, num_shared_experts=1, top_k=1, ff_mult=1.0, use_qk_norm=True, use_gated_attn=False, attn_dropout=0.0, mlp_dropout=0.0, use_dynamic_bias=False, routing_score_fn="softmax", ) block._apply_bias_update_from_counts(torch.tensor([5, 3, 0, 2])) assert block.last_active_expert_ratio.item() == pytest.approx(0.75) assert block.last_max_expert_frac.item() == pytest.approx(0.5) assert block.last_min_expert_frac.item() == pytest.approx(0.0) assert block.last_dead_expert_ratio.item() == pytest.approx(0.25) block._apply_bias_update_from_counts(torch.tensor([5, 3, 1, 1])) assert block.last_min_expert_frac.item() == pytest.approx(0.1) assert block.last_dead_expert_ratio.item() == pytest.approx(0.0) def 
test_grouped_moe_block_tracks_dead_expert_margin_to_topk_loss(): block = GroupedMoEBlock( d_model=4, n_heads=2, n_kv_heads=1, num_fine_experts=3, num_shared_experts=1, top_k=1, ff_mult=1.0, use_qk_norm=True, use_gated_attn=False, attn_dropout=0.0, mlp_dropout=0.0, use_dynamic_bias=False, routing_score_fn="softmax", dead_expert_margin_to_topk_enabled=True, ) topk_idx = torch.tensor([[[0], [0]]], dtype=torch.long) dense_distribution = torch.tensor( [[[0.8, 0.15, 0.05], [0.7, 0.2, 0.1]]], dtype=torch.float32 ) choice_scores = torch.log(dense_distribution) loss = block._update_routed_expert_stats_and_floor_loss( topk_idx=topk_idx, dense_distribution=dense_distribution, choice_scores=choice_scores, ) expected = torch.relu( choice_scores.gather(-1, topk_idx)[..., -1:] - choice_scores ) expected = expected[..., 1:].sum() / 4.0 torch.testing.assert_close(loss, expected) torch.testing.assert_close( block.last_dead_expert_margin_to_topk_loss, expected ) torch.testing.assert_close( block.last_dead_expert_margin_to_topk_loss_value, expected.detach(), ) torch.testing.assert_close( block.last_dead_expert_margin_to_topk_target, choice_scores.gather(-1, topk_idx)[..., -1:].mean(), ) torch.testing.assert_close( block.last_dense_expert_usage, dense_distribution.mean(dim=(0, 1)), ) def test_grouped_moe_block_tracks_selected_expert_margin_to_unselected(): block = GroupedMoEBlock( d_model=4, n_heads=2, n_kv_heads=1, num_fine_experts=4, num_shared_experts=1, top_k=2, ff_mult=1.0, use_qk_norm=True, use_gated_attn=False, attn_dropout=0.0, mlp_dropout=0.0, use_dynamic_bias=False, routing_score_fn="softmax", selected_expert_margin_to_unselected_enabled=True, selected_expert_margin_to_unselected_target=0.4, ) topk_idx = torch.tensor([[[0, 2], [1, 0]]], dtype=torch.long) dense_distribution = torch.tensor( [ [ [0.42, 0.21, 0.28, 0.09], [0.27, 0.36, 0.22, 0.15], ] ], dtype=torch.float32, ) choice_scores = torch.tensor( [[[1.0, 0.3, 0.8, 0.1], [0.9, 1.2, 0.7, 0.4]]], dtype=torch.float32, ) 
block._update_routed_expert_stats_and_floor_loss( topk_idx=topk_idx, dense_distribution=dense_distribution, choice_scores=choice_scores, ) expected_margin = torch.tensor((0.5 + 0.2) / 2.0) expected_loss = torch.tensor((0.0 + 0.2) / 2.0) torch.testing.assert_close( block.last_selected_expert_margin_to_unselected, expected_margin, ) torch.testing.assert_close( block.last_selected_expert_margin_to_unselected_loss, expected_loss, ) torch.testing.assert_close( block.last_selected_expert_margin_to_unselected_loss_value, expected_loss, ) def test_ppotf_summarize_moe_layer_stats_includes_least_utilized_metrics(): moe_layers = [ SimpleNamespace( last_active_expert_ratio=torch.tensor(0.75), last_max_expert_frac=torch.tensor(0.50), last_min_expert_frac=torch.tensor(0.00), last_dead_expert_ratio=torch.tensor(0.25), last_expert_count_cv=torch.tensor(1.20), last_selected_expert_margin_to_unselected=torch.tensor(0.30), ), SimpleNamespace( last_active_expert_ratio=torch.tensor(0.50), last_max_expert_frac=torch.tensor(0.30), last_min_expert_frac=torch.tensor(0.05), last_dead_expert_ratio=torch.tensor(0.50), last_expert_count_cv=torch.tensor(0.80), last_selected_expert_margin_to_unselected=torch.tensor(0.10), ), ] metrics = PPOTF._summarize_moe_layer_stats(moe_layers) assert metrics["moe_active_expert_ratio"] == pytest.approx(0.625) assert metrics["moe_max_expert_frac"] == pytest.approx(0.40) assert metrics["moe_least_expert_frac"] == pytest.approx(0.025) assert metrics["moe_dead_expert_ratio"] == pytest.approx(0.375) assert metrics["moe_expert_count_cv"] == pytest.approx(1.0) assert metrics[ "moe_selected_expert_margin_to_unselected" ] == pytest.approx(0.20) def test_compute_routed_expert_orthogonal_loss_uses_active_experts_only(): algo = PPOTF.__new__(PPOTF) algo.router_expert_orthogonal_min_active_usage = 0.1 algo.router_expert_orthogonal_eps = 1.0e-8 moe_layer = SimpleNamespace( last_routed_expert_usage=torch.tensor( [0.2, 0.12, 0.05], dtype=torch.float32 ), 
down_proj=torch.tensor( [ [[1.0, 0.0]], [[1.0, 1.0]], [[0.0, 1.0]], ], dtype=torch.float32, ), ) loss, active_count, mean_offdiag = ( algo._compute_routed_expert_orthogonal_loss( moe_layer, dtype=torch.float32, device=torch.device("cpu"), ) ) active_vecs = F.normalize( torch.tensor([[1.0, 0.0], [1.0, 1.0]], dtype=torch.float32), p=2.0, dim=-1, eps=1.0e-8, ) gram = active_vecs @ active_vecs.transpose(0, 1) offdiag = gram.masked_select(~torch.eye(2, dtype=torch.bool)) torch.testing.assert_close(active_count, torch.tensor(2.0)) torch.testing.assert_close(loss, offdiag.square().sum()) torch.testing.assert_close(mean_offdiag, offdiag.abs().mean()) def test_compute_routed_expert_orthogonal_loss_returns_zero_below_two_active(): algo = PPOTF.__new__(PPOTF) algo.router_expert_orthogonal_min_active_usage = 0.1 algo.router_expert_orthogonal_eps = 1.0e-8 moe_layer = SimpleNamespace( last_routed_expert_usage=torch.tensor( [0.2, 0.05, 0.01], dtype=torch.float32 ), down_proj=torch.randn(3, 1, 2, dtype=torch.float32), ) loss, active_count, mean_offdiag = ( algo._compute_routed_expert_orthogonal_loss( moe_layer, dtype=torch.float32, device=torch.device("cpu"), ) ) torch.testing.assert_close(loss, torch.tensor(0.0)) torch.testing.assert_close(active_count, torch.tensor(1.0)) torch.testing.assert_close(mean_offdiag, torch.tensor(0.0)) def test_ppotf_actor_sequence_logp_emits_router_features_for_aux_router_losses(): actor = _make_temporal_aux_actor() obs_td = TensorDict( {"flat_obs": torch.randn(2, 3, 6, dtype=torch.float32)}, batch_size=[2, 3], ) actions = torch.randn(2, 3, 4, dtype=torch.float32) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) outputs = actor( obs_td, actions=actions, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert outputs["router_features"].shape == (2, 3, 4) assert outputs["router_temporal_features"].shape == (2, 3, 4) assert outputs["aux_router_command_recon"].shape == (2, 3, 5) def 
test_ppotf_actor_sequence_logp_emits_only_router_features_for_temporal_only_aux():
    """With command recon disabled, router features are still emitted.

    The aux_router_command_recon output must be absent.
    """
    actor = _make_temporal_only_aux_actor()
    obs_td = TensorDict(
        {"flat_obs": torch.randn(2, 3, 6, dtype=torch.float32)},
        batch_size=[2, 3],
    )
    actions = torch.randn(2, 3, 4, dtype=torch.float32)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    outputs = actor(
        obs_td,
        actions=actions,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=False,
    )
    assert outputs["router_features"].shape == (2, 3, 4)
    assert outputs["router_temporal_features"].shape == (2, 3, 4)
    assert "aux_router_command_recon" not in outputs.keys()


def test_masked_adjacent_router_js_averages_only_valid_adjacent_tokens():
    """Adjacent-token Jensen-Shannon divergence, averaged over MoE layers.

    The 4 feature columns are read as 2 layers x 2 experts. Token 2 is
    masked out, so only the (token 0 -> token 1) pair contributes.
    """
    router_features = torch.tensor(
        [
            [
                [0.8, 0.2, 0.6, 0.4],
                [0.6, 0.4, 0.5, 0.5],
                [0.1, 0.9, 0.4, 0.6],
            ]
        ],
        dtype=torch.float32,
    )
    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)
    loss = PPOTF._masked_adjacent_router_js(
        router_features=router_features,
        valid_tok=valid_tok,
        num_moe_layers=2,
        num_fine_experts=2,
    )
    # Reference JS(prev || curr) = 0.5 * (KL(prev || mix) + KL(curr || mix)).
    layer0_prev = torch.tensor([0.8, 0.2], dtype=torch.float32)
    layer0_curr = torch.tensor([0.6, 0.4], dtype=torch.float32)
    mix0 = 0.5 * (layer0_prev + layer0_curr)
    js0 = 0.5 * (
        (layer0_prev * (torch.log(layer0_prev) - torch.log(mix0))).sum()
        + (layer0_curr * (torch.log(layer0_curr) - torch.log(mix0))).sum()
    )
    layer1_prev = torch.tensor([0.6, 0.4], dtype=torch.float32)
    layer1_curr = torch.tensor([0.5, 0.5], dtype=torch.float32)
    mix1 = 0.5 * (layer1_prev + layer1_curr)
    js1 = 0.5 * (
        (layer1_prev * (torch.log(layer1_prev) - torch.log(mix1))).sum()
        + (layer1_curr * (torch.log(layer1_curr) - torch.log(mix1))).sum()
    )
    expected = 0.5 * (js0 + js1)
    assert torch.isclose(loss, expected)


def test_masked_adjacent_router_normed_smooth_l1_averages_only_valid_adjacent_tokens():
    """Smooth-L1 between normalised logits of adjacent valid tokens.

    Each token's logits are mean-centred and L2-normalised (eps 1e-5).
    Only the (token 0 -> token 1) pair is valid; token 2 is masked out.
    """
    router_temporal_features = torch.tensor(
        [
            [
                [3.0, 1.0, 0.0],
                [2.0, 0.0, 2.0],
                [1.0, 1.0, 1.0],
            ]
        ],
        dtype=torch.float32,
    )
    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)
    loss = PPOTF._masked_adjacent_router_normed_smooth_l1(
        router_temporal_features=router_temporal_features,
        valid_tok=valid_tok,
        num_moe_layers=1,
        num_fine_experts=3,
    )
    prev_logits = router_temporal_features[:, :1].reshape(1, 1, 1, 3)
    curr_logits = router_temporal_features[:, 1:2].reshape(1, 1, 1, 3)
    prev_norm = F.normalize(
        prev_logits - prev_logits.mean(dim=-1, keepdim=True),
        p=2.0,
        dim=-1,
        eps=1.0e-5,
    )
    curr_norm = F.normalize(
        curr_logits - curr_logits.mean(dim=-1, keepdim=True),
        p=2.0,
        dim=-1,
        eps=1.0e-5,
    )
    expected = F.smooth_l1_loss(
        curr_norm,
        prev_norm,
        reduction="none",
        beta=1.0,
    ).mean()
    assert torch.isclose(loss, expected)


def test_masked_aux_keybody_mse_averages_only_valid_tokens():
    """MSE averages over the coordinates of valid tokens only: (1+4+9)/3."""
    pred = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])
    target = torch.zeros_like(pred)
    valid_tok = torch.tensor([[1.0, 0.0]])
    loss = PPOTF._masked_aux_keybody_mse(pred, target, valid_tok)
    expected = torch.tensor((1.0 + 4.0 + 9.0) / 3.0)
    assert torch.isclose(loss, expected)


def test_masked_aux_huber_averages_only_valid_tokens():
    """Huber (beta=1) over the valid token only.

    Errors 1, 2, 3 give 0.5, 1.5, 2.5, so the expected loss is their mean.
    """
    pred = torch.zeros(1, 2, 1, 3)
    target = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])
    valid_tok = torch.tensor([[1.0, 0.0]])
    loss = PPOTF._masked_aux_huber(
        pred=pred,
        target=target,
        valid_tok=valid_tok,
        beta=1.0,
    )
    expected = torch.tensor((0.5 + 1.5 + 2.5) / 3.0)
    assert torch.isclose(loss, expected)


def test_setup_configs_rejects_router_aux_terms_outside_motion_tracking():
    """Enabling aux_router_switch_penalty with command_name "velocity" raises."""
    algo = PPOTF.__new__(PPOTF)
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {
            "enabled": False,
        },
        "aux_router_switch_penalty": {"enabled": True, "weight": 1.0},
    }
    algo.command_name = "velocity"
    # Skip the base-class config setup; only PPOTF's own checks run.
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        with pytest.raises(ValueError, match="aux_router_switch_penalty"):
            algo._setup_configs()


def test_setup_configs_rejects_unknown_router_switch_penalty_metric():
    """An unrecognised aux_router_switch_penalty.metric raises ValueError."""
    algo = PPOTF.__new__(PPOTF)
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {
            "enabled": False,
        },
        "aux_router_switch_penalty": {
            "enabled": True,
            "weight": 1.0,
            "metric": "not_a_metric",
        },
    }
    algo.command_name = "ref_motion"
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        with pytest.raises(
            ValueError, match="aux_router_switch_penalty.metric"
        ):
            algo._setup_configs()


def test_setup_configs_reads_dead_expert_margin_to_topk_only():
    """dead_expert_margin_to_topk is read when present and defaults when absent.

    When the section is absent it falls back to disabled with weight 0.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "dead_expert_margin_to_topk": {"enabled": True, "weight": 0.7},
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_dead_expert_margin_to_topk is True
    assert algo.dead_expert_margin_to_topk_weight == pytest.approx(0.7)
    # Same setup without the section: defaults apply.
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_dead_expert_margin_to_topk is False
    assert algo.dead_expert_margin_to_topk_weight == pytest.approx(0.0)


def test_setup_configs_reads_selected_expert_margin_to_unselected():
    """selected_expert_margin_to_unselected is read when present.

    Its enabled/weight/target values are stored. When the section is absent
    it falls back to disabled, weight 0 and target 0.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "selected_expert_margin_to_unselected": {
            "enabled": True,
            "weight": 0.9,
            "target": 0.3,
        },
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_selected_expert_margin_to_unselected is True
    assert algo.selected_expert_margin_to_unselected_weight == pytest.approx(
        0.9
    )
    assert algo.selected_expert_margin_to_unselected_target == pytest.approx(
        0.3
    )
    # Same setup without the section: defaults apply.
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_selected_expert_margin_to_unselected is False
    assert algo.selected_expert_margin_to_unselected_weight == pytest.approx(
        0.0
    )
    assert algo.selected_expert_margin_to_unselected_target == pytest.approx(
        0.0
    )


def test_setup_configs_reads_aux_router_future_recon():
    """aux_router_future_recon settings are copied onto the algorithm.

    The settings are enabled/weight/hidden_dim/huber_beta. The test uses
    the ReferenceRoutedGroupedMoETransformerPolicyV3 actor type.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "aux_router_future_recon": {
            "enabled": True,
            "weight": 0.7,
            "hidden_dim": 13,
            "huber_beta": 0.3,
        },
        "module_dict": {
            "actor": {
                "type": "ReferenceRoutedGroupedMoETransformerPolicyV3",
            }
        },
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_aux_router_future_recon is True
    assert algo.aux_router_future_recon_weight == pytest.approx(0.7)
    assert algo.aux_router_future_recon_hidden_dim == 13
    assert algo.aux_router_future_recon_huber_beta == pytest.approx(0.3)


def test_setup_configs_reads_router_expert_orthogonal():
    """router_expert_orthogonal settings are copied onto the algorithm.

    The visible assertions cover enabled, weight and min_active_usage.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "dead_expert_margin_to_topk": {"enabled": True, "weight": 0.7},
        "router_expert_orthogonal": {
            "enabled": True,
            "weight": 0.9,
            "min_active_usage": 0.2,
            "eps": 1.0e-6,
        },
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()
    assert algo.use_router_expert_orthogonal is True
    assert algo.router_expert_orthogonal_weight == pytest.approx(0.9)
    assert algo.router_expert_orthogonal_min_active_usage == pytest.approx(0.2)
assert algo.router_expert_orthogonal_eps == pytest.approx(1.0e-6) def test_setup_configs_rejects_router_expert_orthogonal_without_dead_margin(): algo = PPOTF.__new__(PPOTF) algo.command_name = "ref_motion" algo.config = { "aux_state_pred": {"enabled": False}, "aux_router_command_recon": {"enabled": False}, "aux_router_switch_penalty": {"enabled": False}, "router_expert_orthogonal": { "enabled": True, "weight": 0.9, }, } with mock.patch.object(PPO, "_setup_configs", return_value=None): with pytest.raises(ValueError, match="requires.*dead_expert"): algo._setup_configs() def test_build_transition_uses_filtered_residual_targets_for_denoise_outputs(): algo = PPOTF.__new__(PPOTF) algo.use_aux_state_pred = True algo.use_aux_root_height = False algo.use_aux_denoise_ref_root_lin_vel = True algo.use_aux_denoise_ref_root_ang_vel = True algo.use_aux_denoise_ref_dof_pos = True algo.aux_state_pred_num_contact_bodies = 0 algo.aux_state_pred_num_keybody_bodies = 0 algo.command_name = "ref_motion" algo.num_envs = 2 algo.device = torch.device("cpu") algo.transition_cls = PpoAuxTransition world_lin_vel = torch.tensor( [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], dtype=torch.float32 ) world_ang_vel = torch.tensor( [[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]], dtype=torch.float32 ) base_lin_vel = torch.tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32 ) base_ang_vel = torch.tensor( [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32 ) dof_pos = torch.tensor( [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32 ) command = SimpleNamespace( get_ref_motion_root_global_lin_vel_cur=lambda prefix="ref_": ( world_lin_vel if prefix == "ft_ref_" else world_lin_vel + 100.0 ), get_ref_motion_root_global_ang_vel_cur=lambda prefix="ref_": ( world_ang_vel if prefix == "ft_ref_" else world_ang_vel + 100.0 ), get_ref_motion_base_linvel_cur=lambda prefix="ref_": ( base_lin_vel if prefix == "ft_ref_" else base_lin_vel + 100.0 ), get_ref_motion_base_angvel_cur=lambda 
prefix="ref_": ( base_ang_vel if prefix == "ft_ref_" else base_ang_vel + 100.0 ), get_ref_motion_dof_pos_cur=lambda prefix="ref_": ( dof_pos if prefix == "ft_ref_" else dof_pos + 100.0 ), ) algo.env = SimpleNamespace( _env=SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) ) obs_td = TensorDict( {"flat_obs": torch.zeros(2, 5, dtype=torch.float32)}, batch_size=[2], ) actor_out = TensorDict( { "actions": torch.zeros(2, 4, dtype=torch.float32), "actions_log_prob": torch.zeros(2, dtype=torch.float32), "mu": torch.zeros(2, 4, dtype=torch.float32), "sigma": torch.ones(2, 4, dtype=torch.float32), }, batch_size=[2], ) critic_out = TensorDict( {"values": torch.zeros(2, 1, dtype=torch.float32)}, batch_size=[2], ) isaaclab_pkg = ModuleType("isaaclab") isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_mdp.base_lin_vel = lambda env: torch.zeros( 2, 3, dtype=torch.float32 ) isaaclab_envs.mdp = isaaclab_mdp isaaclab_pkg.envs = isaaclab_envs with mock.patch.dict( sys.modules, { "isaaclab": isaaclab_pkg, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_mdp, }, ): transition = algo._build_transition(obs_td, actor_out, critic_out) torch.testing.assert_close( transition.gt_denoise_ref_root_lin_vel, torch.full_like(base_lin_vel, -100.0), ) torch.testing.assert_close( transition.gt_denoise_ref_root_ang_vel, torch.full_like(base_ang_vel, -100.0), ) torch.testing.assert_close( transition.gt_denoise_ref_dof_pos, torch.full_like(dof_pos, -100.0) ) def test_build_transition_rejects_mismatched_denoise_dof_target_shape(): algo = PPOTF.__new__(PPOTF) algo.use_aux_state_pred = True algo.use_aux_root_height = False algo.use_aux_denoise_ref_root_lin_vel = False algo.use_aux_denoise_ref_root_ang_vel = False algo.use_aux_denoise_ref_dof_pos = True algo.aux_state_pred_num_contact_bodies = 0 algo.aux_state_pred_num_keybody_bodies = 0 algo.command_name = "ref_motion" algo.num_envs = 2 algo.device = 
torch.device("cpu") algo.transition_cls = PpoAuxTransition command = SimpleNamespace( get_ref_motion_dof_pos_cur=lambda prefix="ref_": torch.zeros( 2, 5, dtype=torch.float32 ) ) algo.env = SimpleNamespace( _env=SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) ) obs_td = TensorDict( {"flat_obs": torch.zeros(2, 5, dtype=torch.float32)}, batch_size=[2], ) actor_out = TensorDict( { "actions": torch.zeros(2, 4, dtype=torch.float32), "actions_log_prob": torch.zeros(2, dtype=torch.float32), "mu": torch.zeros(2, 4, dtype=torch.float32), "sigma": torch.ones(2, 4, dtype=torch.float32), }, batch_size=[2], ) critic_out = TensorDict( {"values": torch.zeros(2, 1, dtype=torch.float32)}, batch_size=[2], ) isaaclab_pkg = ModuleType("isaaclab") isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_mdp.base_lin_vel = lambda env: torch.zeros( 2, 3, dtype=torch.float32 ) isaaclab_envs.mdp = isaaclab_mdp isaaclab_pkg.envs = isaaclab_envs with mock.patch.dict( sys.modules, { "isaaclab": isaaclab_pkg, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_mdp, }, ): with pytest.raises(ValueError, match="gt_denoise_ref_dof_pos"): algo._build_transition(obs_td, actor_out, critic_out) def test_compute_aux_router_future_recon_loss_uses_normalized_future_targets(): algo = PPOTF.__new__(PPOTF) algo.aux_router_future_recon_huber_beta = 0.5 obs_schema = { "flattened_obs_fut": { "seq_len": 2, "terms": [ "unified/actor_ref_base_linvel_fut", "unified/actor_ref_dof_pos_fut", ], } } obs_b = TensorDict( { "unified": TensorDict( { "actor_ref_base_linvel_fut": torch.tensor( [ [ [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], ] ], dtype=torch.float32, ), "actor_ref_dof_pos_fut": torch.tensor( [ [ [[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]], [[0.9, 1.0], [1.1, 1.2]], ] ], dtype=torch.float32, ), }, batch_size=[1, 3], ) }, batch_size=[1, 3], ) 
assembler = TensorDictAssembler(obs_schema, output_mode="flat") class _DummyPolicy(nn.Module): def normalize_aux_router_future_recon_target( self, future_target: torch.Tensor ) -> torch.Tensor: return future_target * 0.25 actor_wrapper = SimpleNamespace( aux_router_future_recon_assembler=assembler, actor_module=_DummyPolicy(), ) raw_target = assembler(obs_b.flatten(0, 1)).reshape(1, 3, -1) normalized_target = raw_target * 0.25 pred = normalized_target + torch.tensor( [ [ [0.0] * raw_target.shape[-1], [0.5] * raw_target.shape[-1], [1.0] * raw_target.shape[-1], ] ], dtype=torch.float32, ) actor_out = TensorDict( {"aux_router_future_recon": pred}, batch_size=[1, 3], ) valid_tok = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) loss = algo._compute_aux_router_future_recon_loss( actor_wrapper=actor_wrapper, actor_out=actor_out, obs_b=obs_b, valid_tok=valid_tok, ) expected = PPOTF._masked_aux_huber( pred=pred, target=normalized_target, valid_tok=valid_tok, beta=0.5, ) assert torch.isclose(loss, expected) def test_root_relative_body_pos_uses_consistent_environment_frame(): body_pos_w = torch.tensor( [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32 ) root_pos_env = torch.zeros(1, 3, dtype=torch.float32) root_quat_w = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32) env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32) rel = PPOTF._root_relative_body_pos_from_mixed_position_frames( body_pos_w=body_pos_w, root_pos_env=root_pos_env, root_quat_w=root_quat_w, env_origins=env_origins, ) expected = torch.tensor( [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], dtype=torch.float32 ) assert torch.allclose(rel, expected) ================================================ FILE: tests/test_ref_router_actor.py ================================================ from __future__ import annotations import pytest import torch from holomotion.src.algo.ppo_tf import PPOTF from holomotion.src.modules.agent_modules import PPOTFRefRouterActor from tensordict import 
TensorDict def _make_ref_router_obs_schema() -> dict: return { "flattened_obs": { "seq_len": 1, "terms": [ "unified/actor_ref_dof_pos_cur", "unified/actor_dof_pos", "unified/actor_ref_root_height_cur", "unified/actor_last_action", ], }, "flattened_obs_fut": { "seq_len": 2, "terms": [ "unified/actor_ref_dof_pos_fut", "unified/actor_ref_root_height_fut", ], }, } def _make_ref_router_obs(batch_size: list[int]) -> TensorDict: shape = list(batch_size) fut_shape = shape + [2] unified = TensorDict( { "actor_ref_dof_pos_cur": torch.randn(*shape, 2), "actor_dof_pos": torch.randn(*shape, 3), "actor_ref_root_height_cur": torch.randn(*shape, 1), "actor_last_action": torch.randn(*shape, 2), "actor_ref_dof_pos_fut": torch.randn(*fut_shape, 2), "actor_ref_root_height_fut": torch.randn(*fut_shape, 1), }, batch_size=shape, ) return TensorDict({"unified": unified}, batch_size=shape) def _make_ref_router_actor( *, num_actions: int = 4, freeze_router: bool = False, aux_router_future_recon: dict | None = None, ) -> PPOTFRefRouterActor: obs_schema = _make_ref_router_obs_schema() obs_example = _make_ref_router_obs([2]) module_config = { "type": "ReferenceRoutedGroupedMoETransformerPolicy", "num_fine_experts": 3, "num_shared_experts": 1, "top_k": 1, "routing_score_fn": "softmax", "routing_scale": 1.0, "use_dynamic_bias": False, "bias_update_rate": 0.001, "expert_bias_clip": 0.0, "obs_embed_mlp_hidden": 16, "d_model": 8, "n_layers": 2, "n_heads": 2, "n_kv_heads": 1, "use_gated_attn": False, "use_qk_norm": True, "ff_mult": 1.0, "ff_mult_dense": 2, "attn_dropout": 0.0, "mlp_dropout": 0.0, "max_ctx_len": 4, "freeze_router": freeze_router, "obs_norm": {"enabled": False}, "output_dim": num_actions, "aux_router_future_recon": aux_router_future_recon or {"enabled": False}, } return PPOTFRefRouterActor( obs_schema=obs_schema, module_config_dict=module_config, num_actions=num_actions, init_noise_std=0.2, obs_example=obs_example, ) def test_ref_router_actor_infers_only_actor_ref_feature_indices(): 
obs_schema = _make_ref_router_obs_schema() obs_example = _make_ref_router_obs([2]) indices = PPOTFRefRouterActor.infer_router_feature_indices( obs_schema, obs_example ) assert indices == [0, 1, 5, 8, 9, 10, 11, 12, 13] def test_ref_router_actor_single_step_and_sequence_logp_match_contract(): actor = _make_ref_router_actor() obs_td = _make_ref_router_obs([2]) inference_out = actor( obs_td, mode="inference", update_obs_norm=False, ) assert inference_out["actions"].shape == (2, 4) assert inference_out["mu"].shape == (2, 4) assert inference_out["sigma"].shape == (2, 4) cache_shape = actor.onnx_past_key_values_shape(batch_size=2) past_key_values = torch.zeros(*cache_shape, dtype=torch.float32) step_idx = torch.zeros(2, dtype=torch.long) with torch.no_grad(): actions, present = actor( obs_td, past_key_values=past_key_values, current_pos=step_idx, ) assert actions.shape == (2, 4) assert present.shape == cache_shape obs_seq = _make_ref_router_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) seq_out = actor( obs_seq, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert seq_out["mu"].shape == (2, 3, 4) assert seq_out["sigma"].shape == (2, 3, 4) assert seq_out["actions_log_prob"].shape == (2, 3, 1) assert seq_out["entropy"].shape == (2, 3, 1) def test_ref_router_actor_rejects_aux_router_future_recon(): with pytest.raises( ValueError, match="does not support aux_router_future_recon", ): _make_ref_router_actor( aux_router_future_recon={"enabled": True, "weight": 1.0} ) def test_ppotf_select_actor_wrapper_rejects_ref_router_cross_attn(): with pytest.raises( ValueError, match="ReferenceRoutedGroupedMoETransformerPolicy", ): PPOTF._select_actor_wrapper_cls( { "type": "ReferenceRoutedGroupedMoETransformerPolicy", "use_future_cross_attn": True, } ) def test_ref_router_actor_freeze_router_freezes_router_obs_embed(): actor = _make_ref_router_actor(freeze_router=True) 
module = actor.actor_module assert module.freeze_router is True assert module.router_obs_embed[0].weight.requires_grad is False assert module.router_obs_embed[0].bias.requires_grad is False assert module.router_obs_embed[2].weight.requires_grad is False assert module.router_obs_embed[2].bias.requires_grad is False def test_ref_router_actor_freeze_router_reapplies_after_load_state_dict(): actor = _make_ref_router_actor(freeze_router=True) module = actor.actor_module state_dict = module.state_dict() module.router_obs_embed.requires_grad_(True) for layer in module.layers: if hasattr(layer, "router"): layer.router.requires_grad_(True) result = module.load_state_dict(state_dict, strict=True) assert result.missing_keys == [] assert result.unexpected_keys == [] assert module.router_obs_embed[0].weight.requires_grad is False assert module.router_obs_embed[0].bias.requires_grad is False assert module.router_obs_embed[2].weight.requires_grad is False assert module.router_obs_embed[2].bias.requires_grad is False for layer in module.layers: if hasattr(layer, "router"): assert layer.router.weight.requires_grad is False ================================================ FILE: tests/test_ref_router_seq_actor.py ================================================ from __future__ import annotations import pytest import torch from holomotion.src.algo.ppo_tf import PPOTF from holomotion.src.modules.agent_modules import ( PPOTFRefRouterSeqActor, PPOTFRefRouterV3Actor, ) from tensordict import TensorDict REF_CUR_TERM_DIMS = { "actor_ref_gravity_projection_cur": 3, "actor_ref_base_linvel_cur": 3, "actor_ref_base_angvel_cur": 3, "actor_ref_dof_pos_cur": 2, "actor_ref_root_height_cur": 1, } REF_FUT_TERM_DIMS = { "actor_ref_gravity_projection_fut": 3, "actor_ref_base_linvel_fut": 3, "actor_ref_base_angvel_fut": 3, "actor_ref_dof_pos_fut": 2, "actor_ref_root_height_fut": 1, } def _make_ref_router_v2_obs_schema( *, include_ref_cur: bool = True, include_ref_fut: bool = True, ) -> dict: flat_terms 
= [] if include_ref_cur: flat_terms.extend( [ "unified/actor_ref_gravity_projection_cur", "unified/actor_ref_base_linvel_cur", "unified/actor_ref_base_angvel_cur", "unified/actor_ref_dof_pos_cur", "unified/actor_ref_root_height_cur", ] ) flat_terms.extend( [ "unified/actor_projected_gravity", "unified/actor_rel_robot_root_ang_vel", "unified/actor_dof_pos", "unified/actor_dof_vel", "unified/actor_last_action", ] ) schema = { "flattened_obs": {"seq_len": 1, "terms": flat_terms}, } if include_ref_fut: schema["flattened_obs_fut"] = { "seq_len": 5, "terms": [ "unified/actor_ref_gravity_projection_fut", "unified/actor_ref_base_linvel_fut", "unified/actor_ref_base_angvel_fut", "unified/actor_ref_dof_pos_fut", "unified/actor_ref_root_height_fut", ], } return schema def _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict: shape = list(batch_size) fut_shape = shape + [5] unified = TensorDict( { "actor_ref_gravity_projection_cur": torch.randn(*shape, 3), "actor_ref_base_linvel_cur": torch.randn(*shape, 3), "actor_ref_base_angvel_cur": torch.randn(*shape, 3), "actor_ref_dof_pos_cur": torch.randn(*shape, 2), "actor_ref_root_height_cur": torch.randn(*shape, 1), "actor_projected_gravity": torch.randn(*shape, 3), "actor_rel_robot_root_ang_vel": torch.randn(*shape, 3), "actor_dof_pos": torch.randn(*shape, 4), "actor_dof_vel": torch.randn(*shape, 4), "actor_last_action": torch.randn(*shape, 2), "actor_ref_gravity_projection_fut": torch.randn(*fut_shape, 3), "actor_ref_base_linvel_fut": torch.randn(*fut_shape, 3), "actor_ref_base_angvel_fut": torch.randn(*fut_shape, 3), "actor_ref_dof_pos_fut": torch.randn(*fut_shape, 2), "actor_ref_root_height_fut": torch.randn(*fut_shape, 1), }, batch_size=shape, ) return TensorDict({"unified": unified}, batch_size=shape) def _make_ref_router_v2_actor( *, obs_schema: dict | None = None, num_actions: int = 4, aux_state_pred: dict | None = None, aux_router_command_recon: dict | None = None, freeze_router: bool = False, ) -> 
PPOTFRefRouterSeqActor: obs_schema = ( _make_ref_router_v2_obs_schema() if obs_schema is None else obs_schema ) obs_example = _make_ref_router_v2_obs([2]) module_config = { "type": "ReferenceRoutedGroupedMoETransformerPolicyV2", "num_fine_experts": 3, "num_shared_experts": 1, "top_k": 1, "routing_score_fn": "softmax", "routing_scale": 1.0, "use_dynamic_bias": False, "bias_update_rate": 0.001, "expert_bias_clip": 0.0, "obs_embed_mlp_hidden": 16, "d_model": 8, "n_layers": 2, "n_heads": 2, "n_kv_heads": 1, "use_gated_attn": False, "use_qk_norm": True, "ff_mult": 1.0, "ff_mult_dense": 2, "attn_dropout": 0.0, "mlp_dropout": 0.0, "max_ctx_len": 4, "freeze_router": freeze_router, "ref_hist_n_layers": 1, "ref_future_conv_channels": 8, "ref_future_conv_layers": 2, "ref_future_conv_kernel_size": 3, "ref_future_conv_stride": 2, "obs_norm": {"enabled": False}, "output_dim": num_actions, "aux_state_pred": aux_state_pred or {"enabled": False}, "aux_router_command_recon": aux_router_command_recon or {"enabled": False}, "aux_router_switch_penalty": {"enabled": False}, } return PPOTFRefRouterSeqActor( obs_schema=obs_schema, module_config_dict=module_config, num_actions=num_actions, init_noise_std=0.2, obs_example=obs_example, ) def _make_ref_router_v3_actor( *, obs_schema: dict | None = None, num_actions: int = 4, freeze_router: bool = False, aux_router_future_recon: dict | None = None, ) -> PPOTFRefRouterV3Actor: obs_schema = ( _make_ref_router_v2_obs_schema() if obs_schema is None else obs_schema ) obs_example = _make_ref_router_v2_obs([2]) module_config = { "type": "ReferenceRoutedGroupedMoETransformerPolicyV3", "num_fine_experts": 3, "num_shared_experts": 1, "top_k": 1, "routing_score_fn": "softmax", "routing_scale": 1.0, "use_dynamic_bias": False, "bias_update_rate": 0.001, "expert_bias_clip": 0.0, "obs_embed_mlp_hidden": 16, "d_model": 8, "n_layers": 2, "n_heads": 2, "n_kv_heads": 1, "use_gated_attn": False, "use_qk_norm": True, "ff_mult": 1.0, "ff_mult_dense": 2, 
"attn_dropout": 0.0, "mlp_dropout": 0.0, "max_ctx_len": 4, "freeze_router": freeze_router, "ref_hist_n_layers": 1, "router_future_hidden_dim": 12, "router_layer_proj_hidden_dim": 10, "obs_norm": {"enabled": False}, "output_dim": num_actions, "aux_state_pred": {"enabled": False}, "aux_router_command_recon": {"enabled": False}, "aux_router_future_recon": aux_router_future_recon or {"enabled": False}, "aux_router_switch_penalty": {"enabled": False}, } return PPOTFRefRouterV3Actor( obs_schema=obs_schema, module_config_dict=module_config, num_actions=num_actions, init_noise_std=0.2, obs_example=obs_example, ) def test_ppotf_select_actor_wrapper_uses_ref_router_seq_actor(): actor_cls = PPOTF._select_actor_wrapper_cls( {"type": "ReferenceRoutedGroupedMoETransformerPolicyV2"} ) assert actor_cls is PPOTFRefRouterSeqActor def test_ppotf_select_actor_wrapper_uses_ref_router_v3_actor(): actor_cls = PPOTF._select_actor_wrapper_cls( {"type": "ReferenceRoutedGroupedMoETransformerPolicyV3"} ) assert actor_cls is PPOTFRefRouterV3Actor def test_ref_router_seq_actor_infers_shared_ref_partitions_without_router_schemas(): actor = _make_ref_router_v2_actor() assert actor.state_obs_input_dim > 0 assert actor.ref_cur_token_dim == sum(REF_CUR_TERM_DIMS.values()) assert actor.ref_fut_token_dim == sum(REF_FUT_TERM_DIMS.values()) assert actor.ref_fut_seq_len == 5 cache_shape = actor.onnx_past_key_values_shape(batch_size=2) assert len(cache_shape) == 6 assert cache_shape[0] == actor.actor_module.onnx_kv_layers assert cache_shape[1] == 2 assert cache_shape[2] == 2 assert cache_shape[-1] == 4 def test_ref_router_v3_actor_keeps_full_obs_backbone_and_layer_router_adapters(): actor = _make_ref_router_v3_actor() module = actor.actor_module assert actor.full_obs_input_dim > actor.state_obs_input_dim assert module.full_obs_input_dim == actor.full_obs_input_dim assert module.obs_embed[0].in_features == actor.full_obs_input_dim assert len(module.router_layer_projections) == sum( isinstance(layer, 
type(module.layers[1])) for layer in module.layers ) def test_ref_router_v3_history_backbone_consumes_flat_ref_motion(): actor = _make_ref_router_v3_actor() module = actor.actor_module x = torch.randn(2, 3, module.full_obs_input_dim) _, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x) assert hasattr(module, "_build_router_ref_motion") ref_motion_x = module._build_router_ref_motion(ref_cur_x, ref_fut_x) assert ref_motion_x.shape == ( 2, 3, actor.ref_cur_token_dim + actor.ref_fut_seq_len * actor.ref_fut_token_dim, ) assert module.ref_frame_embed[0].in_features == ref_motion_x.shape[-1] assert not hasattr(module, "router_future_obs_embed") assert not hasattr(module, "router_future_pool") assert not hasattr(module, "router_summary_fusion") assert not hasattr(module, "router_summary_norm") def test_ref_router_v3_actor_sequence_logp_emits_aux_router_future_recon(): actor = _make_ref_router_v3_actor( aux_router_future_recon={ "enabled": True, "hidden_dim": 9, "weight": 1.0, } ) obs_td = _make_ref_router_v2_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) seq_out = actor( obs_td, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert seq_out["aux_router_future_recon"].shape == ( 2, 3, actor.ref_fut_seq_len * actor.ref_fut_token_dim, ) def test_ref_router_v3_actor_updates_future_recon_empirical_normalizer(): actor = _make_ref_router_v3_actor( aux_router_future_recon={ "enabled": True, "hidden_dim": 9, "weight": 1.0, } ) obs_td = _make_ref_router_v2_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) assert actor.aux_router_future_recon_assembler is not None assert ( int(actor.actor_module.aux_router_future_recon_normalizer.count) == 0 ) actor( obs_td, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=True, ) assert ( 
int(actor.actor_module.aux_router_future_recon_normalizer.count) == 6 ) def test_ref_router_v2_freeze_router_freezes_reference_router_path(): actor = _make_ref_router_v2_actor(freeze_router=True) module = actor.actor_module x = torch.randn(2, 3, module.full_obs_input_dim, requires_grad=True) mu, router_h, router_temporal_features = module.sequence_mu( x, return_ref_aux_hidden=True, return_router_temporal_features=True, ) mu.sum().backward() first_moe = next( layer for layer in module.layers if hasattr(layer, "router") ) assert module.freeze_router is True assert module.ref_frame_embed[0].weight.requires_grad is False assert module.ref_hist_attn.q_proj.weight.requires_grad is False assert module.ref_future_conv[0].weight.requires_grad is False assert module.router_ref_pool.q_proj.weight.requires_grad is False assert module.router_query.requires_grad is False assert first_moe.router.weight.requires_grad is False assert router_h.requires_grad is False assert router_temporal_features.requires_grad is False assert module.ref_frame_embed[0].weight.grad is None assert module.ref_hist_attn.q_proj.weight.grad is None assert module.ref_future_conv[0].weight.grad is None assert module.router_ref_pool.q_proj.weight.grad is None assert module.router_query.grad is None assert first_moe.router.weight.grad is None def test_ref_router_v3_freeze_router_reapplies_after_load_state_dict(): actor = _make_ref_router_v3_actor(freeze_router=True) module = actor.actor_module state_dict = module.state_dict() module.ref_frame_embed.requires_grad_(True) module.ref_hist_attn.requires_grad_(True) module.router_layer_projections.requires_grad_(True) for layer in module.layers: if hasattr(layer, "router"): layer.router.requires_grad_(True) result = module.load_state_dict(state_dict, strict=True) assert result.missing_keys == [] assert result.unexpected_keys == [] assert module.ref_frame_embed[0].weight.requires_grad is False assert module.ref_hist_attn.q_proj.weight.requires_grad is False assert 
module.router_layer_projections[0][1].weight.requires_grad is False assert not hasattr(module, "router_future_obs_embed") assert not hasattr(module, "router_future_pool") assert not hasattr(module, "router_summary_fusion") assert not hasattr(module, "router_summary_norm") for layer in module.layers: if hasattr(layer, "router"): assert layer.router.weight.requires_grad is False def test_ref_router_v2_film_head_starts_near_identity(): actor = _make_ref_router_v2_actor() module = actor.actor_module assert hasattr(module, "_actor_film_gain") gains = module._actor_film_gain().detach() assert gains.shape == (module.d_model,) assert torch.allclose(gains, torch.full_like(gains, 0.05), atol=1.0e-5) last_linear = module.actor_ref_film[-1] assert torch.count_nonzero(last_linear.weight.detach()) == 0 assert torch.count_nonzero(last_linear.bias.detach()) == 0 hidden = torch.randn(2, 1, module.d_model) actor_ref_ctx = torch.randn(2, 1, module.d_model) conditioned = module._apply_actor_ref_film(hidden, actor_ref_ctx) assert torch.allclose(conditioned, hidden) def test_ref_router_v2_pre_moe_hidden_precedes_film_modulation(): actor = _make_ref_router_v2_actor() module = actor.actor_module with torch.no_grad(): module.actor_film_gain_raw.fill_(100.0) module.actor_ref_film[-1].weight.zero_() module.actor_ref_film[-1].bias.fill_(2.0) x = torch.randn(2, module.full_obs_input_dim) x_seq = x[:, None, :] state_x, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x_seq) state_h = module.obs_embed(state_x) ref_cur_h = module.ref_frame_embed(ref_cur_x) ref_hist_attn = module.ref_hist_attn( module.ref_hist_norm(ref_cur_h), *module.get_cos_sin(ref_cur_h, torch.zeros(2, 1, dtype=torch.long)), mask=None, ) ref_hist_h = module.ref_hist_out_norm(ref_cur_h + ref_hist_attn) ref_fut_tokens = module._encode_future_tokens(ref_fut_x) shared_ref_tokens = torch.cat( [ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2 ) router_h = module._pool_router_context(shared_ref_tokens) cos, sin = 
module.get_cos_sin(state_h, torch.zeros(2, 1, dtype=torch.long)) block0_hidden = module._forward_layers_range( state_h, cos=cos, sin=sin, mask=None, router_h=router_h, start_layer=0, end_layer=1, ) _, pre_moe_hidden = module.sequence_mu( x_seq, return_pre_moe_hidden=True, ) assert torch.allclose(pre_moe_hidden, block0_hidden) def test_ref_router_v2_film_gain_is_bounded_per_channel(): actor = _make_ref_router_v2_actor() module = actor.actor_module assert hasattr(module, "actor_film_gain_raw") with torch.no_grad(): module.actor_film_gain_raw.copy_( torch.linspace(-100.0, 100.0, module.d_model) ) gains = module._actor_film_gain() assert gains.shape == (module.d_model,) assert torch.all(gains >= 0.0) assert torch.all(gains <= module.actor_film_gain_max + 1.0e-6) assert torch.unique(gains).numel() > 1 def test_ref_router_v2_film_perturbation_rms_stays_bounded(): actor = _make_ref_router_v2_actor() module = actor.actor_module assert hasattr(module, "actor_film_gain_raw") with torch.no_grad(): module.actor_ref_film[-1].weight.zero_() module.actor_ref_film[-1].bias.fill_(100.0) module.actor_film_gain_raw.fill_(100.0) hidden = torch.randn(4, 3, module.d_model) actor_ref_ctx = torch.randn(4, 3, module.d_model) conditioned = module._apply_actor_ref_film(hidden, actor_ref_ctx) delta = conditioned - hidden delta_rms = delta.pow(2).mean(dim=-1).sqrt() assert torch.all(delta_rms <= module.actor_film_gain_max + 1.0e-5) def test_ref_router_v2_aux_prediction_stays_bound_to_returned_pre_moe_hidden(): actor = _make_ref_router_v2_actor( aux_state_pred={ "enabled": True, "w_base_lin_vel": 1.0, "w_keybody_contact": 1.0, "w_ref_keybody_rel_pos": 1.0, "w_robot_keybody_rel_pos": 1.0, "keybody_contact_names": ["knee"], "keybody_rel_pos_names": ["knee"], } ) module = actor.actor_module module.eval() x_a = torch.randn(1, 1, module.full_obs_input_dim) x_b = x_a + 0.5 with torch.no_grad(): _, pre_a, ref_aux_a = module.sequence_mu( x_a, return_pre_moe_hidden=True, return_ref_aux_hidden=True, ) 
aux_a = module.predict_aux_from_pre_moe( pre_a, ref_aux_hidden=ref_aux_a ) _, pre_b, ref_aux_b = module.sequence_mu( x_b, return_pre_moe_hidden=True, return_ref_aux_hidden=True, ) aux_a_late = module.predict_aux_from_pre_moe( pre_a, ref_aux_hidden=ref_aux_a ) aux_b = module.predict_aux_from_pre_moe( pre_b, ref_aux_hidden=ref_aux_b ) assert torch.allclose( aux_a_late["ref_keybody_rel_pos"], aux_a["ref_keybody_rel_pos"] ) assert torch.allclose( aux_a_late["base_lin_vel_loc"], aux_a["base_lin_vel_loc"] ) assert not torch.allclose( aux_a["ref_keybody_rel_pos"], aux_b["ref_keybody_rel_pos"] ) assert not hasattr(pre_a, "_ref_aux_hidden") def test_ref_router_v2_sequence_single_step_and_cached_onnx_agree(): actor = _make_ref_router_v2_actor() module = actor.actor_module module.eval() x_seq = torch.randn(1, 2, module.full_obs_input_dim) attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0) with torch.no_grad(): mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask) module.reset_kv_cache(num_envs=1, device=x_seq.device) mu_step_0 = module.single_step_mu(x_seq[:, 0, :]) mu_step_1 = module.single_step_mu(x_seq[:, 1, :]) mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1) cache_shape = actor.onnx_past_key_values_shape(batch_size=1) past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype) step_0 = torch.zeros(1, dtype=torch.long) step_1 = torch.ones(1, dtype=torch.long) mu_onnx_0, present_0 = module.forward( x_seq[:, 0, :], past_key_values=past_key_values, current_pos=step_0, ) mu_onnx_1, present_1 = module.forward( x_seq[:, 1, :], past_key_values=present_0, current_pos=step_1, ) mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1) assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4) assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4) assert present_0.shape == cache_shape assert present_1.shape == cache_shape def test_ref_router_v3_sequence_single_step_and_cached_onnx_agree(): actor = _make_ref_router_v3_actor() 
module = actor.actor_module module.eval() x_seq = torch.randn(1, 2, module.full_obs_input_dim) attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0) with torch.no_grad(): mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask) module.reset_kv_cache(num_envs=1, device=x_seq.device) mu_step_0 = module.single_step_mu(x_seq[:, 0, :]) mu_step_1 = module.single_step_mu(x_seq[:, 1, :]) mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1) cache_shape = actor.onnx_past_key_values_shape(batch_size=1) past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype) step_0 = torch.zeros(1, dtype=torch.long) step_1 = torch.ones(1, dtype=torch.long) mu_onnx_0, present_0 = module.forward( x_seq[:, 0, :], past_key_values=past_key_values, current_pos=step_0, ) mu_onnx_1, present_1 = module.forward( x_seq[:, 1, :], past_key_values=present_0, current_pos=step_1, ) mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1) assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4) assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4) assert present_0.shape == cache_shape assert present_1.shape == cache_shape def test_ref_router_seq_actor_single_step_and_sequence_logp_match_contract(): actor = _make_ref_router_v2_actor() obs_td = _make_ref_router_v2_obs([2]) inference_out = actor( obs_td, mode="inference", update_obs_norm=False, ) assert inference_out["actions"].shape == (2, 4) assert inference_out["mu"].shape == (2, 4) assert inference_out["sigma"].shape == (2, 4) cache_shape = actor.onnx_past_key_values_shape(batch_size=2) past_key_values = torch.zeros(*cache_shape, dtype=torch.float32) step_idx = torch.zeros(2, dtype=torch.long) with torch.no_grad(): actions, present = actor( obs_td, past_key_values=past_key_values, current_pos=step_idx, ) assert actions.shape == (2, 4) assert present.shape == cache_shape obs_seq = _make_ref_router_v2_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, 
dtype=torch.bool)).expand( 2, -1, -1 ) seq_out = actor( obs_seq, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert seq_out["mu"].shape == (2, 3, 4) assert seq_out["sigma"].shape == (2, 3, 4) assert seq_out["actions_log_prob"].shape == (2, 3, 1) assert seq_out["entropy"].shape == (2, 3, 1) def test_ref_router_v3_actor_single_step_and_sequence_logp_match_contract(): actor = _make_ref_router_v3_actor() obs_td = _make_ref_router_v2_obs([2]) inference_out = actor( obs_td, mode="inference", update_obs_norm=False, ) assert inference_out["actions"].shape == (2, 4) assert inference_out["mu"].shape == (2, 4) assert inference_out["sigma"].shape == (2, 4) cache_shape = actor.onnx_past_key_values_shape(batch_size=2) past_key_values = torch.zeros(*cache_shape, dtype=torch.float32) step_idx = torch.zeros(2, dtype=torch.long) with torch.no_grad(): actions, present = actor( obs_td, past_key_values=past_key_values, current_pos=step_idx, ) assert actions.shape == (2, 4) assert present.shape == cache_shape obs_seq = _make_ref_router_v2_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) seq_out = actor( obs_seq, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert seq_out["mu"].shape == (2, 3, 4) assert seq_out["sigma"].shape == (2, 3, 4) assert seq_out["actions_log_prob"].shape == (2, 3, 1) assert seq_out["entropy"].shape == (2, 3, 1) def test_ref_router_seq_actor_sequence_logp_emits_aux_preds_without_metadata(): actor = _make_ref_router_v2_actor( aux_state_pred={ "enabled": True, "w_base_lin_vel": 1.0, "w_keybody_contact": 1.0, "w_ref_keybody_rel_pos": 1.0, "w_robot_keybody_rel_pos": 1.0, "keybody_contact_names": ["knee"], "keybody_rel_pos_names": ["knee"], } ) obs_seq = _make_ref_router_v2_obs([2, 3]) actions_seq = torch.randn(2, 3, 4) attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand( 2, -1, -1 ) seq_out 
= actor( obs_seq, actions=actions_seq, mode="sequence_logp", attn_mask=attn_mask, update_obs_norm=False, ) assert "aux_ref_keybody_rel_pos" in seq_out.keys() assert "aux_robot_keybody_rel_pos" in seq_out.keys() assert "aux_base_lin_vel_loc" in seq_out.keys() assert seq_out["aux_ref_keybody_rel_pos"].shape == (2, 3, 1, 3) assert seq_out["aux_robot_keybody_rel_pos"].shape == (2, 3, 1, 3) def test_ref_router_seq_actor_requires_all_shared_ref_terms(): obs_schema = _make_ref_router_v2_obs_schema(include_ref_cur=False) with pytest.raises(ValueError, match="missing required current ref term"): _make_ref_router_v2_actor(obs_schema=obs_schema) def test_ref_router_seq_actor_rejects_aux_router_command_recon(): with pytest.raises(ValueError, match="aux_router_command_recon"): _make_ref_router_v2_actor( aux_router_command_recon={"enabled": True, "hidden_dim": 8} ) def test_ref_router_seq_actor_rejects_unsupported_aux_state_pred_weights(): with pytest.raises(ValueError, match="root_height"): _make_ref_router_v2_actor( aux_state_pred={ "enabled": True, "w_base_lin_vel": 0.0, "w_keybody_contact": 0.0, "w_ref_keybody_rel_pos": 0.0, "w_robot_keybody_rel_pos": 0.0, "w_root_height": 1.0, "keybody_contact_names": [], "keybody_rel_pos_names": [], } ) ================================================ FILE: tests/test_reference_filter_export.py ================================================ import json import sys import tempfile import unittest from unittest import mock from pathlib import Path import numpy as np import torch from omegaconf import OmegaConf sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.training.h5_dataloader import MotionClipSample from holomotion.src.training.reference_filter_export import ( export_reference_filter_artifacts_from_config, export_reference_filter_debug_artifacts, ) def _quat_xyzw_from_rpy(roll: float, pitch: float, yaw: float) -> torch.Tensor: cr = np.cos(roll * 0.5) sr = np.sin(roll * 0.5) cp = np.cos(pitch * 0.5) sp = 
np.sin(pitch * 0.5) cy = np.cos(yaw * 0.5) sy = np.sin(yaw * 0.5) return torch.tensor( [ sr * cp * cy - cr * sp * sy, cr * sp * cy + sr * cp * sy, cr * cp * sy - sr * sp * cy, cr * cp * cy + sr * sp * sy, ], dtype=torch.float32, ) def _make_sample(*, include_filtered: bool = True) -> MotionClipSample: timesteps = 4 num_bodies = 5 num_dofs = 3 ref_rg_pos = torch.arange( timesteps * num_bodies * 3, dtype=torch.float32 ).reshape(timesteps, num_bodies, 3) ref_body_vel = ref_rg_pos + 100.0 ref_body_ang_vel = ref_rg_pos + 200.0 ref_dof_pos = torch.arange( timesteps * num_dofs, dtype=torch.float32 ).reshape(timesteps, num_dofs) ref_dof_vel = ref_dof_pos + 50.0 ref_rb_rot = torch.stack( [ _quat_xyzw_from_rpy(0.0, 0.0, 0.0), _quat_xyzw_from_rpy(0.1, -0.2, 0.3), _quat_xyzw_from_rpy(0.2, -0.1, 0.4), _quat_xyzw_from_rpy(0.3, 0.0, 0.5), ], dim=0, )[:, None, :].repeat(1, num_bodies, 1) tensors = { "ref_rg_pos": ref_rg_pos, "ref_rb_rot": ref_rb_rot, "ref_body_vel": ref_body_vel, "ref_body_ang_vel": ref_body_ang_vel, "ref_root_pos": ref_rg_pos[:, 0, :], "ref_root_rot": ref_rb_rot[:, 0, :], "ref_root_vel": ref_body_vel[:, 0, :], "ref_root_ang_vel": ref_body_ang_vel[:, 0, :], "ref_dof_pos": ref_dof_pos, "ref_dof_vel": ref_dof_vel, "filter_cutoff_hz": torch.full( (timesteps, 1), 2.0, dtype=torch.float32 ), } if include_filtered: tensors.update( { "ft_ref_rg_pos": ref_rg_pos + 0.5, "ft_ref_rb_rot": ref_rb_rot.clone(), "ft_ref_body_vel": ref_body_vel + 0.25, "ft_ref_body_ang_vel": ref_body_ang_vel + 0.25, "ft_ref_root_pos": ref_rg_pos[:, 0, :] + 0.5, "ft_ref_root_rot": ref_rb_rot[:, 0, :].clone(), "ft_ref_root_vel": ref_body_vel[:, 0, :] + 0.25, "ft_ref_root_ang_vel": ref_body_ang_vel[:, 0, :] + 0.25, "ft_ref_dof_pos": ref_dof_pos + 0.75, "ft_ref_dof_vel": ref_dof_vel + 0.75, } ) return MotionClipSample( motion_key="clip-a__start_0_len_4", raw_motion_key="clip-a", window_index=0, tensors=tensors, length=timesteps, ) class ReferenceFilterExportTests(unittest.TestCase): def 
test_export_reference_filter_artifacts_from_config_builds_dataset( self, ): sample = _make_sample() body_names = [ "root_link", "torso_link", "left_wrist_yaw_link", "right_wrist_yaw_link", "left_ankle_roll_link", ] dof_names = [ "waist_yaw_joint", "left_wrist_yaw_joint", "left_ankle_roll_joint", ] with tempfile.TemporaryDirectory() as tmp_dir: config = OmegaConf.create( { "robot": { "body_names": body_names, "dof_names": dof_names, "motion": { "online_filter": {"enabled": True}, "max_frame_length": 4, "min_frame_length": 1, "world_frame_normalization": True, }, }, "debug_reference_filter_export": { "enabled": True, "output_dir": tmp_dir, "selected_body_links": [ "left_wrist_yaw_link", "left_ankle_roll_link", ], }, } ) with mock.patch( "holomotion.src.training.reference_filter_export." "build_motion_datasets_from_cfg", return_value=([sample], None, {}), ) as build_mock: output_dir = export_reference_filter_artifacts_from_config( config ) self.assertEqual(output_dir, Path(tmp_dir)) build_mock.assert_called_once() self.assertTrue((Path(tmp_dir) / "metadata.json").is_file()) def test_export_reference_filter_debug_artifacts_writes_outputs(self): sample = _make_sample() body_names = [ "root_link", "torso_link", "left_wrist_yaw_link", "right_wrist_yaw_link", "left_ankle_roll_link", ] dof_names = [ "waist_yaw_joint", "left_wrist_yaw_joint", "left_ankle_roll_joint", ] selected_body_links = [ "left_wrist_yaw_link", "right_wrist_yaw_link", "left_ankle_roll_link", ] with tempfile.TemporaryDirectory() as tmp_dir: output_dir = Path(tmp_dir) export_reference_filter_debug_artifacts( sample=sample, output_dir=output_dir, body_names=body_names, dof_names=dof_names, selected_body_links=selected_body_links, ) self.assertTrue((output_dir / "metadata.json").is_file()) self.assertTrue((output_dir / "root_signals.npz").is_file()) self.assertTrue((output_dir / "bodylink_signals.npz").is_file()) self.assertTrue((output_dir / "dof_signals.npz").is_file()) self.assertTrue((output_dir / 
"root_comparison.png").is_file()) self.assertTrue( (output_dir / "left_wrist_yaw_link_comparison.png").is_file() ) self.assertTrue((output_dir / "dof_pos_comparison.png").is_file()) self.assertTrue((output_dir / "dof_vel_comparison.png").is_file()) metadata = json.loads( (output_dir / "metadata.json").read_text(encoding="utf-8") ) self.assertEqual(metadata["filter_cutoff_hz"], 2.0) self.assertEqual( metadata["selected_body_links"], selected_body_links ) self.assertEqual(metadata["dof_names"], dof_names) root_payload = np.load(output_dir / "root_signals.npz") self.assertIn("ref_global_pos", root_payload.files) self.assertIn("ft_ref_rpy", root_payload.files) self.assertEqual(root_payload["ref_global_pos"].shape, (4, 3)) self.assertEqual(root_payload["ft_ref_rpy"].shape, (4, 3)) dof_payload = np.load(output_dir / "dof_signals.npz") self.assertEqual(dof_payload["ref_dof_pos"].shape, (4, 3)) self.assertEqual(dof_payload["ft_ref_dof_vel"].shape, (4, 3)) def test_export_reference_filter_debug_artifacts_requires_filtered_tensors( self, ): sample = _make_sample(include_filtered=False) with tempfile.TemporaryDirectory() as tmp_dir: with self.assertRaisesRegex( ValueError, "Filtered reference tensors are unavailable" ): export_reference_filter_debug_artifacts( sample=sample, output_dir=Path(tmp_dir), body_names=["root_link", "left_wrist_yaw_link"], dof_names=["waist_yaw_joint"], selected_body_links=["left_wrist_yaw_link"], ) ================================================ FILE: tests/test_reference_motion_config_wiring.py ================================================ import sys import unittest from pathlib import Path from omegaconf import OmegaConf sys.path.insert(0, str(Path(__file__).resolve().parents[1])) PROJECT_ROOT = Path(__file__).resolve().parents[1] OBS_CONFIG_PATHS = [ PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motion_tracking.yaml", PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_tf_ref_v3.yaml", 
PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_tf_more_info.yaml", PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_mlp_20260210.yaml", PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_tf_20260210.yaml", PROJECT_ROOT / "holomotion/config/env/observations/motion_tracking/obs_motrack_teacher.yaml", ] TERMINATION_CONFIG_PATHS = [ PROJECT_ROOT / "holomotion/config/env/terminations/termination_motion_tracking.yaml", PROJECT_ROOT / "holomotion/config/env/terminations/termination_motion_tracking_simple.yaml", PROJECT_ROOT / "holomotion/config/env/terminations/termination_motrack_with_kpe.yaml", PROJECT_ROOT / "holomotion/config/env/terminations/termination_motrack_with_kpe_jpe.yaml", ] ROBOT_TRAINING_CONFIG_PATH = ( PROJECT_ROOT / "holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml" ) REWARD_CONFIG_PATH = ( PROJECT_ROOT / "holomotion/config/env/rewards/motion_tracking/rew_motrack_robust.yaml" ) class ReferenceMotionConfigWiringTests(unittest.TestCase): def test_motion_tracking_observation_configs_expose_cutoff_term(self): for config_path in OBS_CONFIG_PATHS: with self.subTest(config_path=str(config_path)): config = OmegaConf.load(config_path) self.assertTrue( self._config_has_obs_term( config, term_name="ref_motion_filter_cutoff_hz", ) ) def test_motion_tracking_termination_configs_forward_ref_prefix(self): for config_path in TERMINATION_CONFIG_PATHS: with self.subTest(config_path=str(config_path)): config = OmegaConf.load(config_path) for term_name, term_cfg in config.terminations.items(): if term_name == "time_out": continue self.assertIn("params", term_cfg) self.assertIn("ref_prefix", term_cfg.params) def test_robot_training_config_uses_hdf5_v2_backend(self): config = OmegaConf.load(ROBOT_TRAINING_CONFIG_PATH) self.assertEqual(config.robot.motion.backend, "hdf5_v2") def test_termination_ref_prefix_resolves_with_reward_config(self): reward_config = 
OmegaConf.load(REWARD_CONFIG_PATH) for config_path in TERMINATION_CONFIG_PATHS: with self.subTest(config_path=str(config_path)): termination_config = OmegaConf.load(config_path) merged = OmegaConf.merge(reward_config, termination_config) for term_name, term_cfg in merged.terminations.items(): if term_name == "time_out": continue resolved_ref_prefix = OmegaConf.select( merged, f"terminations.{term_name}.params.ref_prefix", ) self.assertIsInstance(resolved_ref_prefix, str) self.assertTrue(resolved_ref_prefix.endswith("ref_")) @staticmethod def _config_has_obs_term(config, term_name: str) -> bool: for group_cfg in config.obs.obs_groups.values(): for term_cfg in group_cfg.atomic_obs_list: for obs_name, obs_value in term_cfg.items(): if obs_name == term_name: return True if obs_value.get("func") == term_name: return True return False ================================================ FILE: tests/test_root_rel_rewards.py ================================================ import importlib.util import sys from pathlib import Path from types import ModuleType, SimpleNamespace import torch REWARDS_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_rewards.py" ) MOTION_TRACKING_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "motion_tracking.py" ) class _DummyConfig: def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs if args: self.name = args[0] for key, value in kwargs.items(): setattr(self, key, value) if not hasattr(self, "params"): self.params = {} class _DummyManagerTermBase: def __init__(self, cfg, env): self.cfg = cfg self._env = env @property def num_envs(self): return self._env.num_envs @property def device(self): return self._env.device def reset(self, env_ids=None): pass def _identity_quat(*shape: int) -> torch.Tensor: quat = torch.zeros(*shape, 4, dtype=torch.float32) quat[..., 0] = 1.0 return quat def _load_rewards_module(monkeypatch): isaaclab_root = 
ModuleType("isaaclab") isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_envs = ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedRLEnv = object isaaclab_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None) isaaclab_managers = ModuleType("isaaclab.managers") isaaclab_managers.ManagerTermBase = _DummyManagerTermBase isaaclab_managers.RewardTermCfg = _DummyConfig isaaclab_managers.SceneEntityCfg = _DummyConfig isaaclab_sensors = ModuleType("isaaclab.sensors") isaaclab_sensors.ContactSensor = object isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls isaaclab_math = ModuleType("isaaclab.utils.math") isaaclab_math.quat_apply = lambda quat, vec: vec isaaclab_math.quat_apply_inverse = lambda quat, vec: vec isaaclab_math.quat_inv = lambda quat: quat isaaclab_math.quat_mul = lambda lhs, rhs: lhs isaaclab_math.yaw_quat = lambda quat: quat isaaclab_math.quat_error_magnitude = lambda lhs, rhs: torch.linalg.norm( lhs - rhs, dim=-1 ) isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None) hydra_utils = ModuleType("hydra.utils") hydra_utils.instantiate = lambda value, *args, **kwargs: value omegaconf = ModuleType("omegaconf") omegaconf.DictConfig = dict omegaconf.ListConfig = list omegaconf.OmegaConf = SimpleNamespace( to_container=lambda value, resolve=True: value ) loguru = ModuleType("loguru") loguru.logger = SimpleNamespace( info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None, ) fake_command_module = ModuleType( "holomotion.src.env.isaaclab_components." 
"isaaclab_motion_tracking_command" ) fake_command_module.RefMotionCommand = object fake_utils_module = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_utils" ) fake_utils_module._get_body_indices = lambda robot, keybody_names: [ robot.body_names.index(name) for name in keybody_names ] fake_utils_module._get_dof_indices = lambda robot, key_dofs: [] fake_utils_module.resolve_holo_config = lambda value: value for name, module in { "isaaclab": isaaclab_root, "isaaclab.assets": isaaclab_assets, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_mdp, "isaaclab.managers": isaaclab_managers, "isaaclab.sensors": isaaclab_sensors, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.math": isaaclab_math, "hydra.utils": hydra_utils, "omegaconf": omegaconf, "loguru": loguru, ( "holomotion.src.env.isaaclab_components." "isaaclab_motion_tracking_command" ): fake_command_module, ( "holomotion.src.env.isaaclab_components.isaaclab_utils" ): fake_utils_module, }.items(): monkeypatch.setitem(sys.modules, name, module) isaaclab_root.assets = isaaclab_assets isaaclab_root.envs = isaaclab_envs isaaclab_root.managers = isaaclab_managers isaaclab_root.sensors = isaaclab_sensors isaaclab_root.utils = isaaclab_utils isaaclab_envs.mdp = isaaclab_mdp isaaclab_utils.math = isaaclab_math module_name = "_test_root_rel_rewards" spec = importlib.util.spec_from_file_location(module_name, REWARDS_PATH) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def _load_motion_tracking_module(monkeypatch): class _DummyConfigClass: def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs isaaclab_root = ModuleType("isaaclab") isaaclab_actuators = ModuleType("isaaclab.actuators") isaaclab_actuators.ImplicitActuatorCfg = _DummyConfigClass isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.Articulation = object isaaclab_envs = 
ModuleType("isaaclab.envs") isaaclab_envs.ManagerBasedEnv = object isaaclab_envs.ManagerBasedRLEnv = object isaaclab_envs.ManagerBasedRLEnvCfg = object isaaclab_envs.ViewerCfg = _DummyConfigClass isaaclab_envs_mdp = ModuleType("isaaclab.envs.mdp") isaaclab_envs_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None) isaaclab_envs_mdp_events = ModuleType("isaaclab.envs.mdp.events") isaaclab_envs_mdp_events._randomize_prop_by_op = ( lambda *args, **kwargs: None ) isaaclab_managers = ModuleType("isaaclab.managers") isaaclab_managers.EventTermCfg = _DummyConfigClass isaaclab_managers.SceneEntityCfg = _DummyConfig isaaclab_sim = ModuleType("isaaclab.sim") isaaclab_sim.PhysxCfg = _DummyConfigClass isaaclab_sim.SimulationCfg = _DummyConfigClass isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = lambda cls: cls isaaclab_utils_io = ModuleType("isaaclab.utils.io") isaaclab_utils_io.dump_yaml = lambda *args, **kwargs: None isaaclab_utils_math = ModuleType("isaaclab.utils.math") isaaclab_utils_math.__getattr__ = lambda name: ( lambda *args, **kwargs: None ) easydict = ModuleType("easydict") easydict.EasyDict = lambda value=None: value if value is not None else {} omegaconf = ModuleType("omegaconf") omegaconf.OmegaConf = SimpleNamespace( to_container=lambda value, resolve=True: value ) loguru = ModuleType("loguru") loguru.logger = SimpleNamespace( info=lambda *args, **kwargs: None, warning=lambda *args, **kwargs: None, ) isaaclab_components = ModuleType("holomotion.src.env.isaaclab_components") for name in [ "ActionsCfg", "VelTrack_CommandsCfg", "MoTrack_CommandsCfg", "EventsCfg", "MotionTrackingSceneCfg", "ObservationsCfg", "RewardsCfg", "TerminationsCfg", "CurriculumCfg", ]: setattr(isaaclab_components, name, _DummyConfigClass) for name in [ "build_actions_config", "build_motion_tracking_commands_config", "build_velocity_commands_config", "build_domain_rand_config", "build_curriculum_config", "build_observations_config", "build_rewards_config", 
"build_scene_config", "build_terminations_config", ]: setattr(isaaclab_components, name, lambda *args, **kwargs: None) fake_observation_module = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_observation" ) fake_observation_module.ObservationFunctions = object fake_utils_module = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_utils" ) fake_utils_module.resolve_holo_config = lambda value: value for name, module in { "isaaclab": isaaclab_root, "isaaclab.actuators": isaaclab_actuators, "isaaclab.assets": isaaclab_assets, "isaaclab.envs": isaaclab_envs, "isaaclab.envs.mdp": isaaclab_envs_mdp, "isaaclab.envs.mdp.events": isaaclab_envs_mdp_events, "isaaclab.managers": isaaclab_managers, "isaaclab.sim": isaaclab_sim, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.io": isaaclab_utils_io, "isaaclab.utils.math": isaaclab_utils_math, "easydict": easydict, "omegaconf": omegaconf, "loguru": loguru, "holomotion.src.env.isaaclab_components": isaaclab_components, ( "holomotion.src.env.isaaclab_components.isaaclab_observation" ): fake_observation_module, ( "holomotion.src.env.isaaclab_components.isaaclab_utils" ): fake_utils_module, }.items(): monkeypatch.setitem(sys.modules, name, module) isaaclab_root.actuators = isaaclab_actuators isaaclab_root.assets = isaaclab_assets isaaclab_root.envs = isaaclab_envs isaaclab_root.managers = isaaclab_managers isaaclab_root.sim = isaaclab_sim isaaclab_root.utils = isaaclab_utils isaaclab_envs.mdp = isaaclab_envs_mdp isaaclab_utils.io = isaaclab_utils_io isaaclab_utils.math = isaaclab_utils_math module_name = "_test_motion_tracking" spec = importlib.util.spec_from_file_location( module_name, MOTION_TRACKING_PATH ) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def _make_env(): env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32) robot_data = SimpleNamespace( 
body_pos_w=torch.tensor( [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32 ), body_quat_w=_identity_quat(1, 2), body_lin_vel_w=torch.tensor( [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32 ), body_ang_vel_w=torch.tensor( [[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]], dtype=torch.float32 ), ) robot = SimpleNamespace(body_names=["anchor", "target"], data=robot_data) command = SimpleNamespace( robot=robot, anchor_bodylink_idx=0, get_ref_motion_root_global_pos_cur=lambda prefix="ref_": torch.tensor( [[10.0, 0.0, 0.0]], dtype=torch.float32 ), get_ref_motion_root_global_pos_immediate_next=( lambda prefix="ref_": torch.tensor( [[10.0, 0.0, 0.0]], dtype=torch.float32 ) ), get_ref_motion_root_global_rot_quat_wxyz_cur=( lambda prefix="ref_": _identity_quat(1) ), get_ref_motion_root_global_rot_quat_wxyz_immediate_next=( lambda prefix="ref_": _identity_quat(1) ), get_ref_motion_root_global_lin_vel_cur=( lambda prefix="ref_": torch.zeros(1, 3, dtype=torch.float32) ), get_ref_motion_root_global_lin_vel_immediate_next=( lambda prefix="ref_": torch.zeros(1, 3, dtype=torch.float32) ), get_ref_motion_root_global_ang_vel_cur=( lambda prefix="ref_": torch.tensor([[0.0, 0.0, 1.0]]) ), get_ref_motion_root_global_ang_vel_immediate_next=( lambda prefix="ref_": torch.tensor([[0.0, 0.0, 1.0]]) ), get_ref_motion_bodylink_global_pos_cur=( lambda prefix="ref_": torch.tensor( [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32 ) ), get_ref_motion_bodylink_global_pos_immediate_next=( lambda prefix="ref_": torch.tensor( [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32 ) ), get_ref_motion_bodylink_global_lin_vel_cur=( lambda prefix="ref_": torch.tensor( [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32 ) ), get_ref_motion_bodylink_global_lin_vel_immediate_next=( lambda prefix="ref_": torch.tensor( [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=torch.float32 ) ), get_ref_motion_bodylink_global_rot_wxyz_immediate_next=( lambda prefix="ref_": _identity_quat(1, 2) ), 
get_ref_motion_bodylink_global_ang_vel_immediate_next=( lambda prefix="ref_": torch.tensor( [[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]], dtype=torch.float32 ) ), ) return SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command), scene=SimpleNamespace(env_origins=env_origins), ) def _make_torque_rate_env( applied_torque: torch.Tensor, actuators: dict, joint_vel: torch.Tensor | None = None, joint_vel_limits: torch.Tensor | None = None, ): class _Scene(dict): pass if joint_vel is None: joint_vel = torch.zeros_like(applied_torque) if joint_vel_limits is None: joint_vel_limits = torch.ones_like(applied_torque) asset = SimpleNamespace( data=SimpleNamespace( applied_torque=applied_torque.clone(), joint_vel=joint_vel.clone(), joint_vel_limits=joint_vel_limits.clone(), ), actuators=actuators, ) scene = _Scene(robot=asset) return SimpleNamespace( scene=scene, num_envs=applied_torque.shape[0], device=applied_torque.device, episode_length_buf=torch.zeros( applied_torque.shape[0], dtype=torch.long, device=applied_torque.device, ), ) def _make_action_acc_env(action: torch.Tensor): return SimpleNamespace( action_manager=SimpleNamespace(action=action.clone()), num_envs=action.shape[0], device=action.device, episode_length_buf=torch.zeros( action.shape[0], dtype=torch.long, device=action.device ), ) def test_root_rel_keybody_pos_reward_uses_true_root_frame(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_env() rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros( 1, 3, dtype=torch.float32 ) rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1) reward = rewards.root_rel_keybodylink_pos_tracking_l2_exp( env, std=1.0, keybody_names=["target"], ) assert torch.allclose(reward, torch.ones(1)) def test_root_rel_keybody_pos_bydmmc_reward_uses_true_root_frame( monkeypatch, ): rewards = _load_rewards_module(monkeypatch) env = _make_env() rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros( 1, 3, dtype=torch.float32 ) 
rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1) reward = rewards.root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style( env, std=1.0, keybody_names=["target"], ) assert torch.allclose(reward, torch.ones(1)) def test_root_rel_keybody_lin_vel_reward_uses_true_root_frame(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_env() rewards.isaaclab_mdp.root_pos_w = lambda _env: torch.zeros( 1, 3, dtype=torch.float32 ) rewards.isaaclab_mdp.root_quat_w = lambda _env: _identity_quat(1) rewards.isaaclab_mdp.root_lin_vel_w = lambda _env: torch.zeros( 1, 3, dtype=torch.float32 ) rewards.isaaclab_mdp.root_ang_vel_w = lambda _env: torch.tensor( [[0.0, 0.0, 1.0]], dtype=torch.float32 ) reward = rewards.root_rel_keybodylink_lin_vel_tracking_l2_exp( env, std=1.0, keybody_names=["target"], ) assert torch.allclose(reward, torch.ones(1)) def test_root_pos_xy_tracking_uses_immediate_next_reference(monkeypatch): rewards = _load_rewards_module(monkeypatch) robot_data = SimpleNamespace( root_pos_w=torch.tensor([[1.0, 2.0, 0.0]], dtype=torch.float32) ) robot = SimpleNamespace(data=robot_data) command = SimpleNamespace( robot=robot, get_ref_motion_root_global_pos_cur=( lambda prefix="ref_": (_ for _ in ()).throw( AssertionError("current reference should not be used") ) ), get_ref_motion_root_global_pos_immediate_next=( lambda prefix="ref_": torch.tensor( [[1.0, 2.0, 3.0]], dtype=torch.float32 ) ), ) env = SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) reward = rewards.root_pos_xy_tracking_exp(env, std=1.0) assert torch.allclose(reward, torch.ones(1)) def test_global_keybody_lin_vel_tracking_uses_immediate_next_reference( monkeypatch, ): rewards = _load_rewards_module(monkeypatch) robot_data = SimpleNamespace( body_lin_vel_w=torch.tensor( [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]], dtype=torch.float32 ) ) robot = SimpleNamespace( body_names=["anchor", "target"], data=robot_data, ) command = SimpleNamespace( robot=robot, 
get_ref_motion_bodylink_global_lin_vel_cur=( lambda prefix="ref_": (_ for _ in ()).throw( AssertionError("current reference should not be used") ) ), get_ref_motion_bodylink_global_lin_vel_immediate_next=( lambda prefix="ref_": torch.tensor( [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]], dtype=torch.float32 ) ), ) env = SimpleNamespace( command_manager=SimpleNamespace(get_term=lambda name: command) ) reward = rewards.global_keybodylink_lin_vel_tracking_l2_exp( env, std=1.0, keybody_names=["target"], ) assert torch.allclose(reward, torch.ones(1)) def test_normed_torque_rate_matches_selected_joint_math(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.zeros(2, 3, dtype=torch.float32), actuators={ "all_joints": SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor( [[10.0, 20.0, 40.0], [10.0, 20.0, 40.0]], dtype=torch.float32, ), ) }, ) term = rewards.normed_torque_rate(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 2], dtype=torch.long) ) first = term(env, asset_cfg=asset_cfg) assert torch.allclose(first, torch.zeros(2)) env.episode_length_buf[:] = 1 env.scene["robot"].data.applied_torque = torch.tensor( [[1.0, 9.0, 4.0], [2.0, 7.0, 8.0]], dtype=torch.float32, ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [ (1.0 / 10.0) ** 2 + (4.0 / 40.0) ** 2, (2.0 / 10.0) ** 2 + (8.0 / 40.0) ** 2, ], dtype=torch.float32, ) assert torch.allclose(reward, expected) def test_normed_torque_rate_assembles_limits_across_actuator_groups( monkeypatch, ): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.zeros(1, 3, dtype=torch.float32), actuators={ "implicit_group": SimpleNamespace( joint_indices=torch.tensor([0, 2], dtype=torch.long), effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32), ), "unitree_group": SimpleNamespace( joint_indices=torch.tensor([1], dtype=torch.long), 
effort_limit=torch.tensor([[5.0]], dtype=torch.float32), ), }, ) term = rewards.normed_torque_rate(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long) ) _ = term(env, asset_cfg=asset_cfg) env.episode_length_buf[:] = 1 env.scene["robot"].data.applied_torque = torch.tensor( [[1.0, 1.0, 2.0]], dtype=torch.float32 ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [(1.0 / 10.0) ** 2 + (1.0 / 5.0) ** 2 + (2.0 / 20.0) ** 2], dtype=torch.float32, ) assert torch.allclose(reward, expected) def test_normed_torque_rate_resets_first_step_history(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.tensor( [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32 ), actuators={ "all_joints": SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor( [[10.0, 10.0], [10.0, 10.0]], dtype=torch.float32 ), ) }, ) term = rewards.normed_torque_rate(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long) ) first = term(env, asset_cfg=asset_cfg) assert torch.allclose(first, torch.zeros(2)) env.episode_length_buf[:] = 1 env.scene["robot"].data.applied_torque = torch.tensor( [[2.0, 4.0], [5.0, 8.0]], dtype=torch.float32 ) second = term(env, asset_cfg=asset_cfg) assert torch.all(second > 0.0) term.reset(env_ids=[0]) env.scene["robot"].data.applied_torque = torch.tensor( [[7.0, 9.0], [6.0, 10.0]], dtype=torch.float32 ) after_reset = term(env, asset_cfg=asset_cfg) assert torch.isclose(after_reset[0], torch.tensor(0.0)) assert after_reset[1] > 0.0 def test_normed_torque_rate_reuses_cached_normalization(monkeypatch): rewards = _load_rewards_module(monkeypatch) actuator = SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32), ) env = _make_torque_rate_env( applied_torque=torch.zeros(1, 2, dtype=torch.float32), actuators={"all_joints": 
actuator}, ) term = rewards.normed_torque_rate(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long) ) _ = term(env, asset_cfg=asset_cfg) actuator.effort_limit = torch.tensor( [[1000.0, 1000.0]], dtype=torch.float32 ) env.episode_length_buf[:] = 1 env.scene["robot"].data.applied_torque = torch.tensor( [[2.0, 4.0]], dtype=torch.float32 ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [(2.0 / 10.0) ** 2 + (4.0 / 20.0) ** 2], dtype=torch.float32 ) assert torch.allclose(reward, expected) def test_normed_positive_work_matches_selected_joint_math(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.tensor( [[2.0, 5.0, 8.0], [3.0, -4.0, 6.0]], dtype=torch.float32 ), joint_vel=torch.tensor( [[1.0, -5.0, 2.0], [2.0, 3.0, -2.0]], dtype=torch.float32 ), joint_vel_limits=torch.tensor( [[4.0, 10.0, 8.0], [4.0, 10.0, 8.0]], dtype=torch.float32 ), actuators={ "all_joints": SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor( [[4.0, 10.0, 16.0], [4.0, 10.0, 16.0]], dtype=torch.float32, ), ) }, ) term = rewards.normed_positive_work(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 2], dtype=torch.long) ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [ (2.0 / 4.0) * (1.0 / 4.0) + (8.0 / 16.0) * (2.0 / 8.0), (3.0 / 4.0) * (2.0 / 4.0), ], dtype=torch.float32, ) assert torch.allclose(reward, expected) def test_normed_positive_work_assembles_effort_limits_across_actuators( monkeypatch, ): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float32), joint_vel=torch.tensor([[5.0, 2.0, -1.0]], dtype=torch.float32), joint_vel_limits=torch.tensor([[10.0, 4.0, 8.0]], dtype=torch.float32), actuators={ "implicit_group": SimpleNamespace( joint_indices=torch.tensor([0, 2], dtype=torch.long), 
effort_limit=torch.tensor([[4.0, 20.0]], dtype=torch.float32), ), "unitree_group": SimpleNamespace( joint_indices=torch.tensor([1], dtype=torch.long), effort_limit=torch.tensor([[6.0]], dtype=torch.float32), ), }, ) term = rewards.normed_positive_work(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long) ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [ (2.0 / 4.0) * (5.0 / 10.0) + (3.0 / 6.0) * (2.0 / 4.0), ], dtype=torch.float32, ) assert torch.allclose(reward, expected) def test_normed_positive_work_reuses_cached_effort_limits(monkeypatch): rewards = _load_rewards_module(monkeypatch) actuator = SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32), ) env = _make_torque_rate_env( applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32), joint_vel=torch.tensor([[5.0, 10.0]], dtype=torch.float32), joint_vel_limits=torch.tensor([[10.0, 20.0]], dtype=torch.float32), actuators={"all_joints": actuator}, ) term = rewards.normed_positive_work(_DummyConfig(params={}), env) asset_cfg = SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long) ) first = term(env, asset_cfg=asset_cfg) assert torch.allclose( first, torch.tensor( [(2.0 / 10.0) * (5.0 / 10.0) + (4.0 / 20.0) * (10.0 / 20.0)], dtype=torch.float32, ), ) actuator.effort_limit = torch.tensor( [[1000.0, 1000.0]], dtype=torch.float32 ) reward = term(env, asset_cfg=asset_cfg) expected = torch.tensor( [(2.0 / 10.0) * (5.0 / 10.0) + (4.0 / 20.0) * (10.0 / 20.0)], dtype=torch.float32, ) assert torch.allclose(reward, expected) def test_normed_positive_work_requires_positive_finite_velocity_limits( monkeypatch, ): rewards = _load_rewards_module(monkeypatch) env = _make_torque_rate_env( applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32), joint_vel=torch.tensor([[1.0, 1.0]], dtype=torch.float32), joint_vel_limits=torch.tensor([[0.0, torch.inf]], 
dtype=torch.float32), actuators={ "all_joints": SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32), ) }, ) term = rewards.normed_positive_work(_DummyConfig(params={}), env) try: term( env, asset_cfg=SimpleNamespace( name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long) ), ) except ValueError as exc: assert ( "normed_positive_work requires finite, strictly positive" in str(exc) ) else: raise AssertionError( "expected normed_positive_work to reject invalid limits" ) def test_action_acc_matches_second_order_action_change(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_action_acc_env(torch.zeros(2, 2, dtype=torch.float32)) term = rewards.action_acc(_DummyConfig(params={}), env) first = term(env) assert torch.allclose(first, torch.zeros(2)) env.episode_length_buf[:] = 1 env.action_manager.action = torch.tensor( [[1.0, 2.0], [2.0, 1.0]], dtype=torch.float32 ) second = term(env) assert torch.allclose(second, torch.zeros(2)) env.episode_length_buf[:] = 2 env.action_manager.action = torch.tensor( [[3.0, 1.0], [5.0, 1.0]], dtype=torch.float32 ) third = term(env) expected = torch.tensor([10.0, 2.0], dtype=torch.float32) assert torch.allclose(third, expected) def test_action_acc_reset_clears_history(monkeypatch): rewards = _load_rewards_module(monkeypatch) env = _make_action_acc_env(torch.zeros(1, 2, dtype=torch.float32)) term = rewards.action_acc(_DummyConfig(params={}), env) assert torch.allclose(term(env), torch.zeros(1)) env.episode_length_buf[:] = 1 env.action_manager.action = torch.tensor([[1.0, 1.0]], dtype=torch.float32) assert torch.allclose(term(env), torch.zeros(1)) env.episode_length_buf[:] = 2 env.action_manager.action = torch.tensor([[2.0, 0.0]], dtype=torch.float32) assert torch.allclose(term(env), torch.tensor([4.0])) term.reset(env_ids=[0]) env.action_manager.action = torch.tensor([[7.0, 7.0]], dtype=torch.float32) assert torch.allclose(term(env), torch.zeros(1)) def 
test_build_rewards_config_exposes_action_acc_term(monkeypatch): rewards = _load_rewards_module(monkeypatch) rewards_cfg = rewards.build_rewards_config( { "action_acc": { "weight": -2.5, "params": {}, } } ) assert rewards_cfg.action_acc.func is rewards.action_acc assert rewards_cfg.action_acc.weight == -2.5 assert rewards_cfg.action_acc.params == {} def test_motion_tracking_logs_normed_torque_rate_metric(monkeypatch): motion_tracking = _load_motion_tracking_module(monkeypatch) env = motion_tracking.MotionTrackingEnv.__new__( motion_tracking.MotionTrackingEnv ) env.metrics = {} env._robot_prev_joint_vel = None env._robot_prev_applied_torque = None env._robot_torque_rate_inv_effort_limit = None env._robot_torque_rate_needs_reseed = None robot = SimpleNamespace( data=SimpleNamespace( joint_vel=torch.zeros(2, 2, dtype=torch.float32), applied_torque=torch.zeros(2, 2, dtype=torch.float32), ), actuators={ "all_joints": SimpleNamespace( joint_indices=slice(None), effort_limit=torch.tensor( [[10.0, 20.0], [10.0, 20.0]], dtype=torch.float32 ), ) }, ) env._env = SimpleNamespace( step_dt=0.5, action_manager=SimpleNamespace( action=torch.zeros(2, 2, dtype=torch.float32), prev_action=torch.zeros(2, 2, dtype=torch.float32), ), scene={"robot": robot}, episode_length_buf=torch.zeros(2, dtype=torch.long), num_envs=2, device=torch.device("cpu"), ) infos = {"log": {}} env._update_robot_metrics(infos) assert torch.isclose( infos["log"]["Metrics/Robot/Normed_Torque_Rate"], torch.tensor(0.0), ) env._env.episode_length_buf[:] = 1 robot.data.applied_torque = torch.tensor( [[1.0, 4.0], [2.0, 8.0]], dtype=torch.float32 ) env._update_robot_metrics(infos) expected = torch.tensor( [ (1.0 / 10.0) ** 2 + (4.0 / 20.0) ** 2, (2.0 / 10.0) ** 2 + (8.0 / 20.0) ** 2, ], dtype=torch.float32, ).mean() assert torch.allclose( infos["log"]["Metrics/Robot/Normed_Torque_Rate"], expected ) ================================================ FILE: tests/test_unitree_actuators.py 
================================================ import importlib.util import json import sys from pathlib import Path from types import ModuleType, SimpleNamespace import pytest import torch ACTUATOR_MODULE_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "unitree_actuators.py" ) SCENE_MODULE_PATH = ( Path(__file__).resolve().parents[1] / "holomotion" / "src" / "env" / "isaaclab_components" / "isaaclab_scene.py" ) class _DummyArticulationActions: def __init__( self, joint_positions=None, joint_velocities=None, joint_efforts=None, joint_indices=None, ): self.joint_positions = joint_positions self.joint_velocities = joint_velocities self.joint_efforts = joint_efforts self.joint_indices = joint_indices class _DummyDelayedPDActuatorCfg: min_delay = 0 max_delay = 0 def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) class _DummyDelayedPDActuator: def __init__(self, cfg, *args, **kwargs): self.cfg = cfg self._num_envs = kwargs.get("num_envs", 4) self._device = kwargs.get("device", "cpu") self.num_joints = len( kwargs.get("joint_names", ["joint_a", "joint_b"]) ) self.computed_effort = torch.zeros( self._num_envs, self.num_joints, device=self._device ) self.applied_effort = torch.zeros_like(self.computed_effort) effort_limit = kwargs.get("effort_limit", 100.0) if isinstance(effort_limit, torch.Tensor): self.effort_limit = effort_limit.clone().to(device=self._device) else: self.effort_limit = torch.full_like( self.computed_effort, float(effort_limit) ) self.super_compute_inputs = [] self.super_compute_joint_positions = [] self.reset_calls = [] def _parse_joint_parameter(self, value, default): if value is None: value = default if isinstance(value, torch.Tensor): return value.clone().to(device=self._device) if isinstance(value, dict): values = list(value.values()) tensor = torch.tensor( values, dtype=torch.float32, device=self._device ) return tensor.unsqueeze(0).repeat(self._num_envs, 
1) if isinstance(value, (float, int)): return torch.full_like(self.computed_effort, float(value)) raise TypeError(f"Unsupported parameter type: {type(value)}") def reset(self, env_ids): self.reset_calls.append(env_ids) def compute(self, control_action, joint_pos, joint_vel): if control_action.joint_efforts is None: self.super_compute_inputs.append(None) else: self.super_compute_inputs.append( control_action.joint_efforts.clone() ) if control_action.joint_positions is None: self.super_compute_joint_positions.append(None) else: self.super_compute_joint_positions.append( control_action.joint_positions.clone() ) self.computed_effort = control_action.joint_efforts.clone() self.applied_effort = control_action.joint_efforts.clone() return control_action def _configclass(cls): annotations = getattr(cls, "__annotations__", {}) defaults = { name: getattr(cls, name) for name in annotations if hasattr(cls, name) } def __init__(self, **kwargs): for name, value in defaults.items(): setattr(self, name, value) for key, value in kwargs.items(): setattr(self, key, value) cls.__init__ = __init__ return cls def _load_unitree_actuator_module(monkeypatch): isaaclab_root = ModuleType("isaaclab") isaaclab_actuators = ModuleType("isaaclab.actuators") isaaclab_actuators.DelayedPDActuator = _DummyDelayedPDActuator isaaclab_actuators.DelayedPDActuatorCfg = _DummyDelayedPDActuatorCfg isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = _configclass isaaclab_utils_types = ModuleType("isaaclab.utils.types") isaaclab_utils_types.ArticulationActions = _DummyArticulationActions for name, module in { "isaaclab": isaaclab_root, "isaaclab.actuators": isaaclab_actuators, "isaaclab.utils": isaaclab_utils, "isaaclab.utils.types": isaaclab_utils_types, }.items(): monkeypatch.setitem(sys.modules, name, module) isaaclab_root.actuators = isaaclab_actuators isaaclab_root.utils = isaaclab_utils isaaclab_utils.types = isaaclab_utils_types module_name = "_test_unitree_actuators" spec = 
importlib.util.spec_from_file_location( module_name, ACTUATOR_MODULE_PATH ) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def _make_erfi_actuator(module, *, cfg_kwargs=None, num_envs=4, num_joints=3): if cfg_kwargs is None: cfg_kwargs = {} cfg_defaults = { "Y1": 100.0, "Y2": 120.0, "erfi_enabled": True, "ema_filter_enabled": False, "ema_filter_alpha": 1.0, "ema_filter_debug_dump_path": None, "ema_filter_debug_stop_after_dump": False, "rfi_probability": 0.5, "rfi_lim": 0.1, "randomize_rfi_lim": True, "rfi_lim_range": (0.5, 1.5), "rao_lim": 0.1, } cfg_defaults.update(cfg_kwargs) cfg = module.UnitreeErfiActuatorCfg(**cfg_defaults) actuator = module.UnitreeErfiActuator( cfg, joint_names=[f"joint_{idx}" for idx in range(num_joints)], joint_ids=torch.arange(num_joints), num_envs=num_envs, device="cpu", stiffness=0.0, damping=0.0, armature=0.0, friction=0.0, dynamic_friction=0.0, viscous_friction=0.0, effort_limit=100.0, velocity_limit=100.0, ) return actuator def _make_action(actuator): return _DummyArticulationActions( joint_positions=torch.zeros_like(actuator.computed_effort), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) def test_unitree_erfi_reset_samples_all_rfi(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={"rfi_probability": 1.0}, ) actuator.reset(torch.tensor([0, 1, 2, 3], dtype=torch.long)) assert torch.all(actuator._mode_is_rfi) assert torch.allclose( actuator._rao_scale, torch.zeros_like(actuator._rao_scale) ) def test_unitree_erfi_reset_samples_all_rao(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={"rfi_probability": 0.0}, ) actuator.reset(torch.tensor([0, 1, 2, 3], dtype=torch.long)) assert not 
torch.any(actuator._mode_is_rfi) assert torch.any(actuator._rao_scale != 0.0) def test_unitree_erfi_rfi_without_randomized_limit_uses_effort_limit_ratio( monkeypatch, ): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={ "rfi_probability": 1.0, "randomize_rfi_lim": False, "rfi_lim": 0.1, }, num_envs=2, num_joints=2, ) actuator.reset(torch.tensor([0, 1], dtype=torch.long)) torch.manual_seed(0) actuator.compute( _make_action(actuator), joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) injected = actuator.super_compute_inputs[-1] assert torch.all(injected.abs() <= 10.0 + 1.0e-6) def test_unitree_erfi_reset_randomizes_rfi_scale_within_range(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={ "rfi_probability": 1.0, "rfi_lim_range": (0.5, 1.5), }, num_envs=2, num_joints=2, ) actuator.reset(torch.tensor([0, 1], dtype=torch.long)) assert torch.all(actuator._rfi_lim_scale >= 0.5) assert torch.all(actuator._rfi_lim_scale <= 1.5) def test_unitree_erfi_rao_bias_stays_constant_between_resets(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={"rfi_probability": 0.0, "rao_lim": 0.1}, num_envs=2, num_joints=2, ) actuator.reset(torch.tensor([0, 1], dtype=torch.long)) action = _make_action(actuator) actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) first = actuator.super_compute_inputs[-1].clone() actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) second = actuator.super_compute_inputs[-1].clone() assert torch.allclose(first, second) def test_unitree_erfi_rfi_changes_each_compute(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = 
_make_erfi_actuator( module, cfg_kwargs={"rfi_probability": 1.0, "randomize_rfi_lim": False}, num_envs=2, num_joints=2, ) actuator.reset(torch.tensor([0, 1], dtype=torch.long)) action = _make_action(actuator) torch.manual_seed(0) actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) first = actuator.super_compute_inputs[-1].clone() torch.manual_seed(1) actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) second = actuator.super_compute_inputs[-1].clone() assert not torch.allclose(first, second) def test_unitree_erfi_disabled_matches_plain_unitree(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={"erfi_enabled": False}, num_envs=2, num_joints=2, ) action = _make_action(actuator) actuator.reset(torch.tensor([0, 1], dtype=torch.long)) actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) assert torch.allclose( actuator.super_compute_inputs[-1], torch.zeros_like(actuator.super_compute_inputs[-1]), ) def test_unitree_erfi_ema_filters_joint_positions(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={ "erfi_enabled": False, "ema_filter_enabled": True, "ema_filter_alpha": 0.25, }, num_envs=2, num_joints=2, ) first_action = _DummyArticulationActions( joint_positions=torch.tensor( [[1.0, -1.0], [0.5, -0.5]], dtype=torch.float32 ), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) second_action = _DummyArticulationActions( joint_positions=torch.tensor( [[3.0, 1.0], [1.5, 0.5]], dtype=torch.float32 ), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) 
actuator.compute( first_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) actuator.compute( second_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) assert torch.allclose( actuator.super_compute_joint_positions[0], first_action.joint_positions, ) expected_second = ( 0.25 * second_action.joint_positions + 0.75 * first_action.joint_positions ) assert torch.allclose( actuator.super_compute_joint_positions[1], expected_second ) def test_unitree_erfi_ema_reset_clears_only_selected_envs(monkeypatch): module = _load_unitree_actuator_module(monkeypatch) actuator = _make_erfi_actuator( module, cfg_kwargs={ "erfi_enabled": False, "ema_filter_enabled": True, "ema_filter_alpha": 0.5, }, num_envs=2, num_joints=1, ) one_action = _DummyArticulationActions( joint_positions=torch.tensor([[1.0], [1.0]], dtype=torch.float32), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) two_action = _DummyArticulationActions( joint_positions=torch.tensor([[2.0], [2.0]], dtype=torch.float32), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) zero_action = _DummyArticulationActions( joint_positions=torch.zeros_like(actuator.computed_effort), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) actuator.compute( one_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) actuator.compute( two_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) actuator.reset(torch.tensor([1], dtype=torch.long)) actuator.compute( zero_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) 
assert torch.allclose( actuator.super_compute_joint_positions[1], torch.tensor([[1.5], [1.5]], dtype=torch.float32), ) assert torch.allclose( actuator.super_compute_joint_positions[2], torch.tensor([[0.75], [0.0]], dtype=torch.float32), ) def test_unitree_erfi_ema_debug_dump_records_formula(monkeypatch, tmp_path): module = _load_unitree_actuator_module(monkeypatch) dump_path = tmp_path / "ema_verify.json" actuator = _make_erfi_actuator( module, cfg_kwargs={ "erfi_enabled": False, "ema_filter_enabled": True, "ema_filter_alpha": 0.25, "ema_filter_debug_dump_path": str(dump_path), }, num_envs=2, num_joints=2, ) first_action = _DummyArticulationActions( joint_positions=torch.tensor( [[1.0, -1.0], [0.5, -0.5]], dtype=torch.float32 ), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) second_action = _DummyArticulationActions( joint_positions=torch.tensor( [[3.0, 1.0], [1.5, 0.5]], dtype=torch.float32 ), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) actuator.compute( first_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) actuator.compute( second_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) assert dump_path.is_file() payload = json.loads(dump_path.read_text()) expected_second = ( 0.25 * second_action.joint_positions[0] + 0.75 * first_action.joint_positions[0] ) assert payload["alpha"] == 0.25 assert payload["matched"] is True assert payload["env_index"] == 0 assert payload["raw_joint_positions"] == [3.0, 1.0] assert payload["previous_filtered_joint_positions"] == [1.0, -1.0] assert payload["expected_filtered_joint_positions"] == pytest.approx( expected_second.tolist() ) assert payload["actual_filtered_joint_positions"] == pytest.approx( expected_second.tolist() ) def 
test_unitree_erfi_ema_debug_stop_after_dump(monkeypatch, tmp_path): module = _load_unitree_actuator_module(monkeypatch) dump_path = tmp_path / "ema_verify.json" actuator = _make_erfi_actuator( module, cfg_kwargs={ "erfi_enabled": False, "ema_filter_enabled": True, "ema_filter_alpha": 0.5, "ema_filter_debug_dump_path": str(dump_path), "ema_filter_debug_stop_after_dump": True, }, num_envs=1, num_joints=1, ) first_action = _DummyArticulationActions( joint_positions=torch.tensor([[1.0]], dtype=torch.float32), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) second_action = _DummyArticulationActions( joint_positions=torch.tensor([[3.0]], dtype=torch.float32), joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) actuator.compute( first_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) with pytest.raises(RuntimeError, match="EMA verification dump written"): actuator.compute( second_action, joint_pos=torch.zeros_like(actuator.computed_effort), joint_vel=torch.zeros_like(actuator.computed_effort), ) assert dump_path.is_file() def test_unitree_erfi_ema_debug_dump_records_skip_reason( monkeypatch, tmp_path ): module = _load_unitree_actuator_module(monkeypatch) dump_path = tmp_path / "ema_verify_skip.json" actuator = _make_erfi_actuator( module, cfg_kwargs={ "erfi_enabled": False, "ema_filter_enabled": True, "ema_filter_debug_dump_path": str(dump_path), "ema_filter_debug_stop_after_dump": True, }, num_envs=1, num_joints=1, ) action = _DummyArticulationActions( joint_positions=None, joint_velocities=torch.zeros_like(actuator.computed_effort), joint_efforts=torch.zeros_like(actuator.computed_effort), ) with pytest.raises(RuntimeError, match="EMA verification dump written"): actuator.compute( action, joint_pos=torch.zeros_like(actuator.computed_effort), 
joint_vel=torch.zeros_like(actuator.computed_effort), ) payload = json.loads(dump_path.read_text()) assert payload["applied"] is False assert payload["reason"] == "joint_positions_none" def _load_scene_module(monkeypatch): actuator_module = _load_unitree_actuator_module(monkeypatch) isaaclab_root = ModuleType("isaaclab") isaaclab_sim = ModuleType("isaaclab.sim") isaaclab_sim.UrdfFileCfg = lambda **kwargs: SimpleNamespace(**kwargs) isaaclab_sim.RigidBodyPropertiesCfg = lambda **kwargs: SimpleNamespace( **kwargs ) isaaclab_sim.ArticulationRootPropertiesCfg = ( lambda **kwargs: SimpleNamespace(**kwargs) ) isaaclab_sim.UrdfConverterCfg = SimpleNamespace( JointDriveCfg=SimpleNamespace( PDGainsCfg=lambda **kwargs: SimpleNamespace(**kwargs) ) ) isaaclab_actuators = ModuleType("isaaclab.actuators") isaaclab_actuators.ImplicitActuatorCfg = lambda **kwargs: SimpleNamespace( **kwargs ) isaaclab_assets = ModuleType("isaaclab.assets") isaaclab_assets.ArticulationCfg = SimpleNamespace( InitialStateCfg=lambda **kwargs: SimpleNamespace(**kwargs) ) isaaclab_assets.ArticulationCfg = lambda **kwargs: SimpleNamespace( **kwargs ) isaaclab_assets.AssetBaseCfg = lambda **kwargs: SimpleNamespace(**kwargs) isaaclab_scene = ModuleType("isaaclab.scene") isaaclab_scene.InteractiveSceneCfg = object isaaclab_sensors = ModuleType("isaaclab.sensors") isaaclab_sensors.ContactSensorCfg = lambda **kwargs: SimpleNamespace( **kwargs ) isaaclab_sensors.RayCasterCfg = SimpleNamespace( OffsetCfg=lambda **kwargs: SimpleNamespace(**kwargs) ) isaaclab_sensors.patterns = SimpleNamespace( GridPatternCfg=lambda **kwargs: SimpleNamespace(**kwargs) ) isaaclab_terrains = ModuleType("isaaclab.terrains") isaaclab_terrains.TerrainImporterCfg = object isaaclab_utils = ModuleType("isaaclab.utils") isaaclab_utils.configclass = _configclass loguru = ModuleType("loguru") loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None) fake_terrain = ModuleType( "holomotion.src.env.isaaclab_components.isaaclab_terrain" 
) fake_terrain.build_terrain_config = lambda *args, **kwargs: None fake_unitree = ModuleType( "holomotion.src.env.isaaclab_components.unitree_actuators" ) fake_unitree.UnitreeActuator = actuator_module.UnitreeActuator fake_unitree.UnitreeActuatorCfg = actuator_module.UnitreeActuatorCfg fake_unitree.UnitreeErfiActuator = actuator_module.UnitreeErfiActuator fake_unitree.UnitreeErfiActuatorCfg = ( actuator_module.UnitreeErfiActuatorCfg ) for name, module in { "isaaclab": isaaclab_root, "isaaclab.sim": isaaclab_sim, "isaaclab.actuators": isaaclab_actuators, "isaaclab.assets": isaaclab_assets, "isaaclab.scene": isaaclab_scene, "isaaclab.sensors": isaaclab_sensors, "isaaclab.terrains": isaaclab_terrains, "isaaclab.utils": isaaclab_utils, "loguru": loguru, ( "holomotion.src.env.isaaclab_components.isaaclab_terrain" ): fake_terrain, ( "holomotion.src.env.isaaclab_components.unitree_actuators" ): fake_unitree, }.items(): monkeypatch.setitem(sys.modules, name, module) module_name = "_test_isaaclab_scene" spec = importlib.util.spec_from_file_location( module_name, SCENE_MODULE_PATH ) module = importlib.util.module_from_spec(spec) assert spec is not None assert spec.loader is not None sys.modules[module_name] = module spec.loader.exec_module(module) return module def test_scene_builder_selects_unitree_erfi_cfg(monkeypatch): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree_erfi"}, {"erfi": {"enabled": True, "rfi_lim": 0.2}}, ) assert isinstance(actuators["all_joints"], module.UnitreeErfiActuatorCfg) assert actuators["all_joints"].erfi_enabled is True assert actuators["all_joints"].rfi_lim == 0.2 def test_scene_builder_keeps_plain_unitree_cfg(monkeypatch): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree"}, {} ) assert isinstance(actuators["all_joints"], module.UnitreeActuatorCfg) assert not hasattr(actuators["all_joints"], "rfi_lim") def 
test_scene_builder_disables_erfi_when_domain_rand_missing(monkeypatch): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree_erfi"}, {} ) assert isinstance(actuators["all_joints"], module.UnitreeErfiActuatorCfg) assert actuators["all_joints"].erfi_enabled is False def test_scene_builder_applies_domain_rand_action_delay_to_unitree( monkeypatch, ): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree"}, {"action_delay": {"enabled": True, "min_delay": 1, "max_delay": 3}}, ) assert isinstance(actuators["all_joints"], module.UnitreeActuatorCfg) assert actuators["all_joints"].min_delay == 1 assert actuators["all_joints"].max_delay == 3 def test_scene_builder_applies_domain_rand_action_delay_to_unitree_erfi( monkeypatch, ): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree_erfi"}, { "erfi": {"enabled": True}, "action_delay": { "enabled": True, "min_delay": 2, "max_delay": 4, }, }, ) assert isinstance(actuators["all_joints"], module.UnitreeErfiActuatorCfg) assert actuators["all_joints"].min_delay == 2 assert actuators["all_joints"].max_delay == 4 def test_scene_builder_applies_erfi_ema_filter_config(monkeypatch): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( { "actuator_type": "unitree_erfi", "ema_filter_enabled": True, "ema_filter_alpha": 0.37, }, {"erfi": {"enabled": True}}, ) assert isinstance(actuators["all_joints"], module.UnitreeErfiActuatorCfg) assert actuators["all_joints"].class_type.__name__ == "UnitreeErfiActuator" assert actuators["all_joints"].ema_filter_enabled is True assert actuators["all_joints"].ema_filter_alpha == 0.37 def test_scene_builder_disables_action_delay_when_domain_rand_missing( monkeypatch, ): module = _load_scene_module(monkeypatch) actuators = module._build_unitree_actuator_cfg( {"actuator_type": "unitree"}, {} 
) assert isinstance(actuators["all_joints"], module.UnitreeActuatorCfg) assert actuators["all_joints"].min_delay == 0 assert actuators["all_joints"].max_delay == 0 ================================================ FILE: tests/test_visualize_with_mujoco.py ================================================ import sys from pathlib import Path import numpy as np sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from holomotion.src.motion_retargeting.utils.visualize_with_mujoco import ( _resolve_visualization_arrays, ) def test_resolve_visualization_arrays_uses_robot_for_pose_and_ref_for_overlay(): arrays = { "robot_dof_pos": np.array([[1.0, 2.0]], dtype=np.float32), "robot_global_translation": np.array( [[[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]], dtype=np.float32 ), "robot_global_rotation_quat": np.array( [[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]], dtype=np.float32, ), "ref_global_translation": np.array( [[[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]]], dtype=np.float32 ), } resolved = _resolve_visualization_arrays( arrays=arrays, key_prefix_order=["robot_"], draw_ref_body_spheres=True, ref_key_prefix_order=["ref_"], ) np.testing.assert_allclose(resolved["dof_pos"], arrays["robot_dof_pos"]) np.testing.assert_allclose( resolved["global_translation"], arrays["robot_global_translation"] ) np.testing.assert_allclose( resolved["global_rotation_quat"], arrays["robot_global_rotation_quat"], ) np.testing.assert_allclose( resolved["ref_body_positions"], arrays["ref_global_translation"], ) ================================================ FILE: train.env ================================================ # This is the environment file for running HoloMotion scripts. 
# Intended to be sourced by HoloMotion training scripts: exports the
# training environment and defines cleanup() for signal traps.

# Assign and export separately so a failing `conda info` is not masked
# (ShellCheck SC2155). Warn rather than exit: this file is sourced.
if ! CONDA_BASE="$(conda info --base)"; then
  printf '%s\n' 'train.env: warning: "conda info --base" failed; CONDA_BASE is empty' >&2
fi
export CONDA_BASE
export Train_CONDA_PREFIX="$CONDA_BASE/envs/holomotion_train"
# export CUDA_HOME=$Train_CONDA_PREFIX
export CUDA_HOME=/usr/local/cuda
# export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$Train_CONDA_PREFIX/lib/:$Train_CONDA_PREFIX/lib/stubs"
# export LIBRARY_PATH="$Train_CONDA_PREFIX/lib/stubs:$Train_CONDA_PREFIX/lib:$LIBRARY_PATH"
export HYDRA_FULL_ERROR=1
export OMNI_KIT_ACCEPT_EULA="YES"
export ACCEPT_EULA="YES"
# export CUDA_LAUNCH_BLOCKING=1
export USE_NVRTC=1
export HDF5_USE_FILE_LOCKING=FALSE
export HOLOMOTION_ISAAC_STAGGER_SEC=1
export HOLOMOTION_HDF5_RDCC_NBYTES=$((4 * 1024 * 1024)) # 4MB
export HOLOMOTION_HDF5_MAX_OPEN_SHARDS=16              # 16 shards
export TORCH_DISTRIBUTED_DEBUG=INFO
# export TORCHDYNAMO_VERBOSE=0

# ${VAR:-} keeps the summary working for callers that run under `set -u`.
echo "--------------------------------"
echo "Train_CONDA_PREFIX: $Train_CONDA_PREFIX"
echo "CUDA_HOME: $CUDA_HOME"
echo "LD_LIBRARY_PATH: ${LD_LIBRARY_PATH:-}"
echo "LIBRARY_PATH: ${LIBRARY_PATH:-}"
echo "HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR"
echo "OMNI_KIT_ACCEPT_EULA: $OMNI_KIT_ACCEPT_EULA"
echo "HDF5_USE_FILE_LOCKING: $HDF5_USE_FILE_LOCKING"
echo "HOLOMOTION_ISAAC_STAGGER_SEC: $HOLOMOTION_ISAAC_STAGGER_SEC"
echo "HOLOMOTION_HDF5_RDCC_NBYTES: $HOLOMOTION_HDF5_RDCC_NBYTES"
echo "HOLOMOTION_HDF5_MAX_OPEN_SHARDS: $HOLOMOTION_HDF5_MAX_OPEN_SHARDS"
echo "--------------------------------"

#######################################
# Graceful shutdown function for training scripts.
# Note: Scripts must set TRAIN_PID variable and call:
#   trap cleanup SIGINT SIGTERM
# Globals:   TRAIN_PID (read) - PID of the training process; may be
#            empty or unset, in which case nothing is signalled.
# Outputs:   progress messages on stdout; stderr is silenced while
#            signalling and restored to the caller's stderr afterwards.
#######################################
cleanup() {
  local pid="${TRAIN_PID:-}"
  local saved_stderr
  local -a child_pids=()

  echo ""
  echo "🛑 Cleanup triggered - shutting down training process ${pid}..."
  # Suppress error messages during cleanup, but keep a copy of the caller's
  # stderr: a plain `exec 2>&1` afterwards would leave stderr pointing at
  # stdout instead of restoring it.
  exec {saved_stderr}>&2 2>/dev/null

  if [[ -n "${pid}" ]]; then
    # Snapshot the children before signalling: once the parent exits they
    # are re-parented and `pkill -P` can no longer find them.
    mapfile -t child_pids < <(pgrep -P "${pid}")
    kill -TERM "${pid}" && echo " ✓ Sent TERM signal to process ${pid}"
    sleep 2
    if pkill -P "${pid}"; then
      echo " ✓ Killed child processes"
    elif ((${#child_pids[@]} > 0)) && kill -TERM "${child_pids[@]}"; then
      # Parent already exited; signal the children it had.
      echo " ✓ Killed child processes"
    fi
  fi

  exec 2>&"${saved_stderr}" {saved_stderr}>&-
  echo " ✓ Cleanup complete"
}