Repository: HorizonRobotics/HoloMotion
Branch: master
Commit: 75ca12750541
Files: 249
Total size: 2.8 MB
Directory structure:
gitextract_3fyqn965/
├── .gitattributes
├── .githooks/
│ ├── README.md
│ └── pre-commit
├── .gitignore
├── .gitlab-ci.yml
├── .gitmodules
├── LICENSE
├── Makefile
├── NOTICE
├── README.md
├── assets/
│ ├── robots/
│ │ └── unitree/
│ │ └── G1/
│ │ └── 29dof/
│ │ ├── g1_29dof.xml
│ │ ├── g1_29dof_rev_1_0.urdf
│ │ ├── g1_29dof_rev_1_0.xml
│ │ ├── g1_29dof_rev_1_0_s100.urdf
│ │ └── scene_29dof.xml
│ └── test_data/
│ └── motion_retargeting/
│ └── ACCAD/
│ └── Male1Walking_c3d/
│ └── Walk_B10_-_Walk_turn_left_45_stageii.npz
├── deploy.env
├── deployment/
│ ├── deploy_environment.sh
│ ├── holomotion_teleop/
│ │ ├── holomotion_teleop_node.py
│ │ ├── holomotion_teleop_setup.md
│ │ └── setup_holomotion_teleop_x86_ubuntu2204.sh
│ └── unitree_g1_ros2_29dof/
│ ├── launch_holomotion_29dof.sh
│ ├── launch_holomotion_29dof_docker.sh
│ ├── src/
│ │ ├── CMakeLists.txt
│ │ ├── config/
│ │ │ └── g1_29dof_holomotion.yaml
│ │ ├── humanoid_policy/
│ │ │ ├── __init__.py
│ │ │ ├── holomotion_fk_root_only.py
│ │ │ ├── obs_builder/
│ │ │ │ ├── __init__.py
│ │ │ │ └── obs_builder.py
│ │ │ ├── policy_node_29dof.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── command_helper.py
│ │ │ ├── maths.py
│ │ │ ├── motor_crc.py
│ │ │ ├── remote_controller_filter.py
│ │ │ ├── rotation_helper.py
│ │ │ └── rotations.py
│ │ ├── include/
│ │ │ └── common/
│ │ │ ├── motor_crc.h
│ │ │ ├── motor_crc_hg.h
│ │ │ ├── ros2_sport_client.h
│ │ │ └── wireless_controller.h
│ │ ├── launch/
│ │ │ └── holomotion_29dof_launch.py
│ │ ├── models/
│ │ │ └── .gitkeep
│ │ ├── motion_data/
│ │ │ └── .gitkeep
│ │ ├── package.xml
│ │ ├── resource/
│ │ │ └── humanoid_control
│ │ ├── setup.cfg
│ │ ├── setup.py
│ │ └── src/
│ │ ├── common/
│ │ │ ├── motor_crc.cpp
│ │ │ ├── motor_crc_hg.cpp
│ │ │ ├── ros2_sport_client.cpp
│ │ │ └── wireless_controller.cpp
│ │ └── main_node.cpp
│ └── start_container.sh
├── docs/
│ ├── environment_setup.md
│ ├── evaluate_motion_tracking.md
│ ├── holomotion_motion_file_spec.md
│ ├── motion_retargeting.md
│ ├── mujoco_sim2sim.md
│ ├── realworld_deployment.md
│ ├── smpl_data_curation.md
│ └── train_motion_tracking.md
├── environments/
│ ├── environment_deploy.yaml
│ ├── environment_train_isaaclab_cu118.yaml
│ ├── environment_train_isaaclab_cu128.yaml
│ ├── requirements_base.txt
│ ├── requirements_deploy.txt
│ ├── requirements_torch_cu118.txt
│ ├── requirements_torch_cu128.txt
│ └── requirements_torch_cu130.txt
├── holomotion/
│ ├── config/
│ │ ├── algo/
│ │ │ ├── ppo.yaml
│ │ │ └── ppo_tf.yaml
│ │ ├── data_curation/
│ │ │ ├── joints2smpl.yaml
│ │ │ └── smplify_base.yaml
│ │ ├── env/
│ │ │ ├── domain_randomization/
│ │ │ │ ├── NO_domain_rand.yaml
│ │ │ │ ├── domain_rand_medium.yaml
│ │ │ │ ├── domain_rand_small.yaml
│ │ │ │ └── domain_rand_strong.yaml
│ │ │ ├── motion_tracking.yaml
│ │ │ ├── observations/
│ │ │ │ ├── motion_tracking/
│ │ │ │ │ ├── obs_motion_tracking_mlp.yaml
│ │ │ │ │ └── obs_motion_tracking_tf-moe.yaml
│ │ │ │ └── velocity_tracking/
│ │ │ │ └── obs_velocity_tracking.yaml
│ │ │ ├── rewards/
│ │ │ │ ├── motion_tracking/
│ │ │ │ │ └── rew_motion_tracking.yaml
│ │ │ │ └── velocity_tracking/
│ │ │ │ └── rew_velocity_tracking.yaml
│ │ │ ├── terminations/
│ │ │ │ ├── NO_termination.yaml
│ │ │ │ ├── termination_motion_tracking.yaml
│ │ │ │ └── termination_velocity_tracking.yaml
│ │ │ ├── terrain/
│ │ │ │ ├── isaaclab_plane.yaml
│ │ │ │ └── isaaclab_rough.yaml
│ │ │ └── velocity_tracking.yaml
│ │ ├── evaluation/
│ │ │ ├── eval_isaaclab.yaml
│ │ │ ├── eval_mujoco_sim2sim.yaml
│ │ │ └── eval_velocity_tracking.yaml
│ │ ├── modules/
│ │ │ ├── motion_tracking/
│ │ │ │ ├── motion_tracking_mlp.yaml
│ │ │ │ └── motion_tracking_tf-moe.yaml
│ │ │ └── velocity_tracking/
│ │ │ └── velocity_tracking_mlp.yaml
│ │ ├── motion_retargeting/
│ │ │ ├── gmr_to_holomotion.yaml
│ │ │ ├── holomotion_preprocess.yaml
│ │ │ ├── kinematic_filter.yaml
│ │ │ ├── pack_hdf5_database.yaml
│ │ │ ├── pack_hdf5_v2.yaml
│ │ │ └── unitree_G1_29dof_retargeting.yaml
│ │ ├── mujoco_eval/
│ │ │ └── sim2sim.yaml
│ │ ├── robot/
│ │ │ └── unitree/
│ │ │ └── G1/
│ │ │ └── 29dof/
│ │ │ ├── 29dof_training_isaaclab.yaml
│ │ │ └── 29dof_training_isaaclab_s100.yaml
│ │ └── training/
│ │ ├── motion_tracking/
│ │ │ ├── train_g1_29dof_motion_tracking_mlp.yaml
│ │ │ └── train_g1_29dof_motion_tracking_tf-moe.yaml
│ │ ├── train_base.yaml
│ │ └── velocity_tracking/
│ │ └── train_g1_29dof_velocity_tracking_mlp.yaml
│ ├── scripts/
│ │ ├── data_curation/
│ │ │ ├── convert_to_amass.sh
│ │ │ ├── filter_smpl_data.sh
│ │ │ ├── video_to_smpl_gvhmr.sh
│ │ │ └── visualize_smpl_npz.sh
│ │ ├── evaluation/
│ │ │ ├── calc_offline_eval_metrics.sh
│ │ │ ├── eval_motion_tracking.sh
│ │ │ ├── eval_mujoco_sim2sim.sh
│ │ │ ├── eval_velocity_tracking.sh
│ │ │ ├── mean_process_5metrics.py
│ │ │ └── multi_model_metrics_analysis.sh
│ │ ├── motion_retargeting/
│ │ │ ├── apply_gmr_motion_retarget_patch.sh
│ │ │ ├── pack_hdf5_v2.sh
│ │ │ ├── run_holomotion_preprocessing.sh
│ │ │ ├── run_kinematic_filter.sh
│ │ │ ├── run_motion_retargeting_gmr_bvh.sh
│ │ │ ├── run_motion_retargeting_gmr_smplx.sh
│ │ │ ├── run_motion_retargeting_gmr_to_holomotion.sh
│ │ │ └── run_motion_viz_mujoco.sh
│ │ └── training/
│ │ ├── train_motion_tracking.sh
│ │ └── train_velocity_tracking.sh
│ ├── src/
│ │ ├── algo/
│ │ │ ├── __init__.py
│ │ │ ├── algo_base.py
│ │ │ ├── algo_utils.py
│ │ │ ├── ppo.py
│ │ │ └── ppo_tf.py
│ │ ├── data_curation/
│ │ │ ├── .gitignore
│ │ │ ├── __init__.py
│ │ │ ├── data_smplify.py
│ │ │ ├── filter/
│ │ │ │ ├── filter.py
│ │ │ │ └── label_data.py
│ │ │ ├── smpl_npz_to_html.py
│ │ │ ├── smplify/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── smplify_humanact12.py
│ │ │ │ ├── smplify_motionx.py
│ │ │ │ ├── smplify_omomo.py
│ │ │ │ └── smplify_zjumocap.py
│ │ │ ├── templates/
│ │ │ │ └── index_wooden_static.html
│ │ │ ├── video_to_smpl_gvhmr.py
│ │ │ ├── vison_mocap/
│ │ │ │ └── joints2smpl.py
│ │ │ └── visualize_smpl_npz.py
│ │ ├── env/
│ │ │ ├── __init__.py
│ │ │ ├── isaaclab_components/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── isaaclab_actions.py
│ │ │ │ ├── isaaclab_curriculum.py
│ │ │ │ ├── isaaclab_domain_rand.py
│ │ │ │ ├── isaaclab_motion_tracking_command.py
│ │ │ │ ├── isaaclab_observation.py
│ │ │ │ ├── isaaclab_rewards.py
│ │ │ │ ├── isaaclab_scene.py
│ │ │ │ ├── isaaclab_simulator.py
│ │ │ │ ├── isaaclab_termination.py
│ │ │ │ ├── isaaclab_terrain.py
│ │ │ │ ├── isaaclab_utils.py
│ │ │ │ ├── isaaclab_velocity_tracking_command.py
│ │ │ │ └── unitree_actuators.py
│ │ │ ├── motion_tracking.py
│ │ │ └── velocity_tracking.py
│ │ ├── evaluation/
│ │ │ ├── __init__.py
│ │ │ ├── eval_motion_tracking.py
│ │ │ ├── eval_motion_tracking_single.py
│ │ │ ├── eval_mujoco_sim2sim.py
│ │ │ ├── eval_velocity_tracking.py
│ │ │ ├── find_worst_clips.py
│ │ │ ├── metrics.py
│ │ │ ├── multi_model_metrics_report.py
│ │ │ ├── obs/
│ │ │ │ ├── __init__.py
│ │ │ │ └── obs_builder.py
│ │ │ ├── ray_evaluator_actor.py
│ │ │ └── ray_metrics_postprocess.py
│ │ ├── modules/
│ │ │ ├── __init__.py
│ │ │ ├── agent_modules.py
│ │ │ └── network_modules.py
│ │ ├── motion_retargeting/
│ │ │ ├── __init__.py
│ │ │ ├── gmr_to_holomotion.py
│ │ │ ├── holomotion_fk.py
│ │ │ ├── holomotion_preprocess.py
│ │ │ ├── kinematic_filter.py
│ │ │ ├── pack_hdf5_v2.py
│ │ │ ├── reference_filtering.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── _schema.json
│ │ │ ├── rotation_conversions.py
│ │ │ ├── torch_humanoid_batch.py
│ │ │ └── visualize_with_mujoco.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── h5_dataloader.py
│ │ │ ├── reference_filter_export.py
│ │ │ └── train.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── frame_utils.py
│ │ ├── isaac_utils/
│ │ │ ├── __init__.py
│ │ │ ├── maths.py
│ │ │ ├── rotations.py
│ │ │ └── setup.py
│ │ ├── onnx_export.py
│ │ ├── reference_prefix.py
│ │ ├── torch_utils.py
│ │ └── unitree_g1_actuator_calculator.py
│ └── tests/
│ └── __init__.py
├── pyproject.toml
├── tests/
│ ├── benchmark_legacy_onnx_attention.py
│ ├── benchmark_moe_router_export.py
│ ├── test_actor_export_config.py
│ ├── test_algo_base_iteration_logging.py
│ ├── test_build_quantization_dataset.py
│ ├── test_cache_curriculum_sampler.py
│ ├── test_domain_rand_config_builder.py
│ ├── test_eval_mujoco_action_delay.py
│ ├── test_eval_mujoco_action_ema.py
│ ├── test_eval_mujoco_contact_export.py
│ ├── test_eval_mujoco_s100_horizon_ptq.py
│ ├── test_eval_mujoco_use_gpu.py
│ ├── test_eval_onnx_io_dump.py
│ ├── test_evaluation_metrics.py
│ ├── test_isaaclab_termination.py
│ ├── test_mean_process_5metrics.py
│ ├── test_motion_cache_gather_state.py
│ ├── test_motion_cache_startup.py
│ ├── test_motion_tracking_command_reference_prefix.py
│ ├── test_motion_tracking_timing.py
│ ├── test_mujoco_filtered_ref_compat.py
│ ├── test_obs_norm_compile.py
│ ├── test_observation_frames.py
│ ├── test_onnx_attention_export.py
│ ├── test_onnx_export.py
│ ├── test_plot_moe_expert_heatmap.py
│ ├── test_plot_state_series.py
│ ├── test_ppo_checkpoint_sigma_override.py
│ ├── test_ppo_entropy_annealing.py
│ ├── test_ppo_symmetry_loss.py
│ ├── test_ppo_tf_aux_keybody.py
│ ├── test_ref_router_actor.py
│ ├── test_ref_router_seq_actor.py
│ ├── test_reference_filter_export.py
│ ├── test_reference_motion_config_wiring.py
│ ├── test_root_rel_rewards.py
│ ├── test_unitree_actuators.py
│ └── test_visualize_with_mujoco.py
└── train.env
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
*.ipynb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.obj filter=lfs diff=lfs merge=lfs -text
*.dae filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
assets/smplx filter=lfs diff=lfs merge=lfs -text
assets/smpl filter=lfs diff=lfs merge=lfs -text
assets/test_data filter=lfs diff=lfs merge=lfs -text
assets/media/*.png filter=lfs diff=lfs merge=lfs -text
assets/media/*.jpg filter=lfs diff=lfs merge=lfs -text
assets/media/*.jpeg filter=lfs diff=lfs merge=lfs -text
assets/videos/*.mp4 filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.svg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
================================================
FILE: .githooks/README.md
================================================
# Git hooks
Pre-commit runs [ruff](https://docs.astral.sh/ruff/) format on staged Python files (using `train.env` for the correct environment).
**Install ruff** in the holomotion train environment if it is absent:
```bash
conda activate holomotion_train
pip install ruff
```
Ruff is also listed in `environments/requirements_base.txt` if you install deps from there.
**Enable hooks** (run once from repo root):
```bash
git config core.hooksPath .githooks
```
Ensure the hook is executable: `chmod +x .githooks/pre-commit`
**Requirement:** Run `git commit` from a shell where conda is available so `train.env` can set `Train_CONDA_PREFIX`.
**Skip hook for one commit:** `git commit --no-verify`
================================================
FILE: .githooks/pre-commit
================================================
#!/bin/bash
# Run ruff format on staged Python files before commit.
# Use from holomotion repo root (standalone clone or submodule).
# Requires: run git commit from a shell where conda is available (so train.env can set Train_CONDA_PREFIX).
# Requires: bash >= 4.4 (mapfile -d).
#
# Files that are only partially staged (they also carry unstaged edits) are
# still formatted in the working tree, but they are NOT re-staged: `git add`
# would otherwise sweep the deliberately unstaged hunks into this commit.
set -e
cd "$(git rev-parse --show-toplevel)"

# Added/copied/modified/renamed .py paths in the index. NUL-delimited output
# keeps names containing spaces, quotes or newlines intact. In a git pathspec
# `*` also matches `/`, so '*.py' covers every subdirectory.
mapfile -d '' -t staged_py < <(git diff --cached --name-only -z --diff-filter=ACMR -- '*.py')
if [ ${#staged_py[@]} -eq 0 ]; then
  exit 0
fi

# Decide which files are safe to re-stage BEFORE formatting rewrites them.
# `git diff --quiet` (index vs. working tree) exits non-zero when the file
# has unstaged changes. --literal-pathspecs stops names containing glob
# characters from matching other files.
restage=()
for path in "${staged_py[@]}"; do
  if git --literal-pathspecs diff --quiet -- "$path"; then
    restage+=("$path")
  else
    printf 'pre-commit: %s is partially staged; it will be formatted in the working tree but not re-staged\n' "$path" >&2
  fi
done

# train.env sets Train_CONDA_PREFIX; conda must be on PATH when you run git commit
source train.env
if [ -z "${Train_CONDA_PREFIX:-}" ]; then
  echo "pre-commit: Train_CONDA_PREFIX is empty after sourcing train.env (is conda on PATH?)" >&2
  exit 1
fi
ruff_bin="$Train_CONDA_PREFIX/bin/ruff"
if [ ! -x "$ruff_bin" ]; then
  echo "pre-commit: ruff not found at $ruff_bin (run 'pip install ruff' in the train env, see .githooks/README.md)" >&2
  exit 1
fi

# A ruff failure (e.g. a syntax error) aborts the commit via set -e.
"$ruff_bin" format --config pyproject.toml -- "${staged_py[@]}"
if [ ${#restage[@]} -gt 0 ]; then
  git --literal-pathspecs add -- "${restage[@]}"
fi
exit 0
================================================
FILE: .gitignore
================================================
# ignore logs and cache
logs/
logs_eval/
.archive/
tmp/
# ignore deployment bag_record, install, log
# (the `*` matches both the legacy unitree_g1_ros2/ and the current
# unitree_g1_ros2_29dof/ deployment package directories)
deployment/unitree_g1_ros2*/bag_record/
deployment/unitree_g1_ros2*/install/
deployment/unitree_g1_ros2*/log/
# ignore data and outputs
# (`data` has no trailing slash, so it matches files and directories named data)
data
outputs/
build/
install/
log/
*.egg-info
__pycache__/
*.pyc
# ignore large files
*.log
*.LOG
*.pkl
*.pt
*.onnx
*.npy
*.npz
*.zip
*.tar.gz
*.obj
*.dae
*.STL
# ignore video, image, etc.
*.mp4
*.avi
*.mov
*.png
*.pdf
# ignore editor/agent tooling, caches and local-only assets
.agents/
.cursor/
.cursorignore
.vscode/
.*_cache/
**/usd/
assets/isaac/
not_for_commit/
thirdparties/smpl_models/
MUJOCO_LOG.TXT
# allow certain files (negations must stay after the patterns they override)
!deployment/unitree_g1_ros2*/src/models/*.onnx
!deployment/unitree_g1_ros2*/src/motion_data/*.pkl
!assets/smpl/*.pkl
!assets/smpl/*.npz
!assets/smplx/*.pkl
!assets/smplx/*.npz
!assets/test_data/**
!assets/media/**
# macOS system files (an unanchored name matches at every directory level)
.DS_Store
================================================
FILE: .gitlab-ci.yml
================================================
workflow:
rules:
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
job1:
script:
- echo "This job runs in merge request pipelines"
================================================
FILE: .gitmodules
================================================
[submodule "thirdparties/SMPLSim"]
path = thirdparties/SMPLSim
url = https://github.com/ZhengyiLuo/SMPLSim
branch = master
[submodule "thirdparties/joints2smpl"]
path = thirdparties/joints2smpl
url = https://github.com/wangsen1312/joints2smpl.git
branch = main
[submodule "thirdparties/omomo_release"]
path = thirdparties/omomo_release
url = https://github.com/lijiaman/omomo_release.git
branch = main
[submodule "thirdparties/unitree_ros"]
path = thirdparties/unitree_ros
url = https://github.com/unitreerobotics/unitree_ros
branch = master
[submodule "thirdparties/unitree_ros2"]
path = thirdparties/unitree_ros2
url = https://github.com/unitreerobotics/unitree_ros2
branch = master
[submodule "thirdparties/unitree_sdk2_python"]
path = thirdparties/unitree_sdk2_python
url = https://github.com/unitreerobotics/unitree_sdk2_python.git
[submodule "thirdparties/cyclonedds"]
path = thirdparties/cyclonedds
url = https://github.com/eclipse-cyclonedds/cyclonedds
branch = 0.10.2
[submodule "thirdparties/unitree_sdk2"]
path = thirdparties/unitree_sdk2
url = https://github.com/unitreerobotics/unitree_sdk2.git
[submodule "thirdparties/GMR"]
path = thirdparties/GMR
url = https://github.com/YanjieZe/GMR.git
branch = master
[submodule "thirdparties/HoloMotion_assets"]
path = thirdparties/HoloMotion_assets
url = https://huggingface.co/datasets/HorizonRobotics/HoloMotion_assets
[submodule "thirdparties/smplx"]
path = thirdparties/smplx
url = https://github.com/vchoutas/smplx
[submodule "thirdparties/GVHMR"]
path = thirdparties/GVHMR
url = https://github.com/zju3dv/GVHMR.git
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 maiyue01.chen
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: Makefile
================================================
# Variables
# Python code directory used by lint / format / check.
# (Comments sit on their own line: a trailing "# ..." on an assignment
# leaves the whitespace before it inside the variable's value.)
PY_SRC := holomotion/
# Assumes Ruff is installed and on PATH
RUFF := ruff
PYTEST := pytest -v
# The unit tests live in the top-level tests/ directory;
# holomotion/tests/ only contains an __init__.py.
TESTS := tests/
# Coverage flags used by `make test-cov` (requires the pytest-cov plugin)
COV := --cov=holomotion/src --cov-report=term-missing
# Directory to lint/format - can be overridden with DIR=path
DIR ?= holomotion/src
.PHONY: lint format check lint-dir format-dir test test-cov
# Run Ruff linter on default directory
lint:
	@echo "Linting with Ruff..."
	@$(RUFF) check $(PY_SRC)
# Format code in default directory, then auto-fix lint errors
format:
	@echo "Formatting with Ruff..."
	@$(RUFF) format $(PY_SRC)
	@$(RUFF) check --fix $(PY_SRC)
# Run Ruff linter on DIR (defaults to holomotion/src)
lint-dir:
	@echo "Linting directory: $(DIR)"
	@$(RUFF) check $(DIR)
# Format code in DIR (defaults to holomotion/src), then auto-fix lint errors
format-dir:
	@echo "Formatting directory: $(DIR)"
	@$(RUFF) format $(DIR)
	@$(RUFF) check --fix $(DIR)
# Strict, read-only check (for CI): fails on lint violations or on any
# file that `ruff format` would rewrite.
check:
	@$(RUFF) check $(PY_SRC)
	@$(RUFF) format --check $(PY_SRC)
# Run all tests
test:
	$(PYTEST) $(TESTS)
# Run all tests with a coverage report for holomotion/src
test-cov:
	$(PYTEST) $(COV) $(TESTS)
================================================
FILE: NOTICE
================================================
=======================================================================
ASAP's MIT License
=======================================================================
Code derived from implementations in ASAP should mention its derivation
and reference the following license:
MIT License
Copyright (c) 2025 ASAP Team
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
=======================================================================
omomo_release's MIT License
=======================================================================
Code derived from implementations in omomo_release should mention its derivation
and reference the following license:
MIT License
Copyright (c) 2023 Jiaman Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
=======================================================================
NVIDIA License
=======================================================================
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
================================================
FILE: README.md
================================================
---
[](#)
[](#)
[](./LICENSE)
[](https://horizonrobotics.github.io/robot_lab/holomotion/)
[](https://huggingface.co/collections/HorizonRobotics/holomotion)
[](https://deepwiki.com/HorizonRobotics/HoloMotion)
[](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c)
# HoloMotion: A Foundation Model for Whole-Body Humanoid Control
## NEWS
- [2026.04.04] The v1.2 version of HoloMotion has been released. We provide pre-trained motion tracking and velocity tracking models that the community can deploy directly.
- [2026.01.06] The v1.1 version of HoloMotion has been released, representing a major step forward toward a fully engineered, stable, and reproducible humanoid motion intelligence system.
- [2025.11.05] The v1.0 version of HoloMotion has been released, and the WeChat user group is now open! Please scan the [QR Code](https://horizonrobotics.feishu.cn/docx/Xs3cdEI8bo1EZuxUfzjckTgKn2c) to join.
## Pre-trained Models
- Motion Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model)
- Velocity Tracking Model: [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model)
Please read the doc for [real-world deployment](docs/realworld_deployment.md) for more details on how to use the models.
## Introduction
HoloMotion is a foundation model for humanoid robotics, designed to deliver robust, real-time, and generalizable whole-body control.
Our framework provides an end-to-end solution, encompassing the entire workflow from data curation and motion retargeting to distributed model training, evaluation, and seamless deployment on physical hardware via ROS2. HoloMotion's modular architecture allows for flexible adaptation and extension, enabling researchers and developers to build and benchmark agents that can imitate, generalize, and master complex whole-body motions.
For those at the forefront of creating the next generation of humanoid robots, HoloMotion serves as a powerful, extensible, and open-source foundation for achieving whole-body control.
---
### 🛠️ Roadmap: Progress Toward Any Humanoid Control
We envision HoloMotion as a general-purpose foundation for humanoid motion and control. Its development is structured around four core generalization goals: Any Pose, Any Command, Any Terrain, and Any Embodiment. Each goal corresponds to a major version milestone.
| Version | Target Capability | Description |
| -------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------- |
| **v1.0** | 🔄 Any Pose | Achieve robust tracking and imitation of diverse, whole-body human motions, forming the core of the imitation learning capability. |
| **v2.0** | ⏳ Any Command | Enable language- and task-conditioned motion generation, allowing for goal-directed and interactive behaviors. |
| **v3.0** | ⏳ Any Terrain | Master adaptation to uneven, dynamic, and complex terrains, enhancing real-world operational robustness. |
| **v4.0** | ⏳ Any Embodiment | Generalize control policies across humanoids with varying morphologies and kinematics, achieving true embodiment-level abstraction. |
> Each stage builds on the previous one, moving from motion imitation to instruction following, terrain adaptation, and embodiment-level generalization.
## Pipeline Overview
```mermaid
flowchart LR
A["🔧 1. Environment Setup Dependencies & conda"]
subgraph dataFrame ["DATA"]
B["📊 2. Dataset Preparation Download & curate"]
C["🔄 3. Motion Retargeting Human to robot motion"]
B --> C
end
subgraph modelFrame ["TRAIN & EVAL"]
D["🧠 4. Model Training Train with HoloMotion"]
E["📈 5. Evaluation Test & export"]
D --> E
end
F["🚀 6. Deployment Deploy to robots"]
A --> dataFrame
dataFrame --> modelFrame
modelFrame --> F
classDef subgraphStyle fill:#f9f9f9,stroke:#333,stroke-width:2px,stroke-dasharray:5 5,rx:10,ry:10,font-size:16px,font-weight:bold
classDef nodeStyle fill:#e1f5fe,stroke:#0277bd,stroke-width:2px,rx:10,ry:10
class dataFrame,modelFrame subgraphStyle
class A,B,C,D,E,F nodeStyle
```
## Quick Start
### 🔧 1. Environment Setup [[Doc](docs/environment_setup.md)]
Set up your development and deployment environments using Conda. This initial step ensures all dependencies are correctly configured for both training and real-world execution.
If you only intend to use our pretrained models, you can skip the training environment setup and proceed directly to configure the deployment environment. See the [real-world deployment documentation](docs/realworld_deployment.md) for details.
### 📊 2. Dataset Preparation [[Doc](docs/smpl_data_curation.md)]
Acquire and process large-scale motion datasets. Our tools help you curate high-quality data by converting it to the AMASS-compatible SMPL format and filtering out anomalies using kinematic metrics.
### 🔄 3. Motion Retargeting [[Doc](docs/motion_retargeting.md)]
Translate human motion data into robot-specific kinematic data. Our pipeline leverages [GMR](https://github.com/YanjieZe/GMR) to map human movements onto your robot's morphology, producing optimized HDF5 datasets ready for high-speed, distributed training.
### 🧠 4. Model Training [[Doc](docs/train_motion_tracking.md)]
Train your foundation model using our reinforcement learning framework. HoloMotion supports versatile training tasks, including motion tracking and velocity tracking.
### 📈 5. Evaluation [[Doc](docs/evaluate_motion_tracking.md)]
Evaluate your trained policies in IsaacLab, visualize their performance, and export the trained models to ONNX format for seamless deployment.
### 🚀 6. Real-world Deployment [[Doc](docs/realworld_deployment.md)]
Our ROS2 package facilitates the deployment of the exported ONNX models, enabling real-time control on hardware like the Unitree G1.
## Join Us
We are hiring full-time engineers, new graduates, and interns who are excited about humanoid robots, motion control, and embodied intelligence.
Send your resume by scanning the **WeChat** QR code below to get in touch with us.
## Citation
```
@software{HoloMotion,
author = {Maiyue Chen and Kaihui Wang and Bo Zhang and Yi Ren and Zihao Zhu and Xihan Ma and Qijun Huang and Zhiyuan Yang and Yucheng Wang and Zhizhong Su},
title = {HoloMotion: A Foundation Model for Whole-Body Humanoid Control},
year = {2026},
month = apr,
version = {1.2.0},
url = {https://github.com/HorizonRobotics/HoloMotion},
license = {Apache-2.0}
}
```
## License
This project is released under the **[Apache 2.0](LICENSE)** license.
## Acknowledgements
This project is built upon and inspired by several outstanding open source projects:
- [GMR](https://github.com/YanjieZe/GMR)
- [BeyondMimic](https://github.com/HybridRobotics/whole_body_tracking/tree/dcecabd8c24c68f59d143fdf8e3a670f420c972d)
- [ASAP](https://github.com/LeCAR-Lab/ASAP)
- [Humanoidverse](https://github.com/LeCAR-Lab/HumanoidVerse)
- [PHC](https://github.com/ZhengyiLuo/PHC?tab=readme-ov-file)
- [ProtoMotion](https://github.com/NVlabs/ProtoMotions/tree/main/protomotions)
- [Mink](https://github.com/kevinzakka/mink)
- [PBHC](https://github.com/TeleHuman/PBHC)
================================================
FILE: assets/robots/unitree/G1/29dof/g1_29dof.xml
================================================
================================================
FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.urdf
================================================
================================================
FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0.xml
================================================
================================================
FILE: assets/robots/unitree/G1/29dof/g1_29dof_rev_1_0_s100.urdf
================================================
================================================
FILE: assets/robots/unitree/G1/29dof/scene_29dof.xml
================================================
================================================
FILE: assets/test_data/motion_retargeting/ACCAD/Male1Walking_c3d/Walk_B10_-_Walk_turn_left_45_stageii.npz
================================================
version https://git-lfs.github.com/spec/v1
oid sha256:738f96eb1767e281d78631ca697079adefb5f171d581acd622a27740f2503b4e
size 5876184
================================================
FILE: deploy.env
================================================
# HoloMotion deployment environment.
#
# Source this file (`source deploy.env`) so the exports land in the current
# shell. Because it is sourced, it must never call `exit`.
#
# Path lists are extended with ${VAR:+...} so that an initially empty variable
# does not gain an empty element. An empty element in LD_LIBRARY_PATH or
# LIBRARY_PATH is interpreted as "the current working directory".
if ! CONDA_BASE=$(conda info --base); then
  echo "WARNING: 'conda info --base' failed; is conda on PATH?" >&2
fi
export CONDA_BASE
export Deploy_CONDA_PREFIX="$CONDA_BASE/envs/holomotion_deploy"
export CUDA_HOME=$Deploy_CONDA_PREFIX
# Runtime loader path: appended, so pre-existing entries keep precedence.
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$Deploy_CONDA_PREFIX/lib/:$Deploy_CONDA_PREFIX/lib/stubs"
# Link-time path: prepended (stubs first, then lib), so the env's libraries win.
export LIBRARY_PATH="$Deploy_CONDA_PREFIX/lib/stubs:$Deploy_CONDA_PREFIX/lib${LIBRARY_PATH:+:$LIBRARY_PATH}"
export HYDRA_FULL_ERROR=1
printf '%s\n' \
  "--------------------------------" \
  "Deploy_CONDA_PREFIX: $Deploy_CONDA_PREFIX" \
  "CUDA_HOME: $CUDA_HOME" \
  "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" \
  "LIBRARY_PATH: $LIBRARY_PATH" \
  "HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR" \
  "--------------------------------"
================================================
FILE: deployment/deploy_environment.sh
================================================
#!/bin/bash
##############################################################################
# HoloMotion Environment Deployment Script
#
# This script sets up the complete environment for HoloMotion humanoid robot
# system deployment. It handles:
#   1. Conda environment creation with all dependencies (CUDA, PyTorch, etc.)
#   2. Special dependencies (unitree_sdk2_python)
#   3. ROS2 workspace compilation
#
# Prerequisites (all checked BEFORE any conda environment is touched):
#   - Anaconda/Miniconda installed (`conda` on PATH)
#   - ROS2 Humble installed at /opt/ros/humble/
#   - Unitree ROS2 SDK at ~/unitree_ros2/
#
# Usage:
#   chmod +x deploy_environment.sh
#   ./deploy_environment.sh [environment_name]
#
# Arguments:
#   environment_name: Optional. Name for the conda environment (default: holomotion_deploy)
#
# Environment variables:
#   HOLOMOTION_ENV_FILE: Optional. Explicit path to environment_deploy.yaml.
#     Otherwise the first existing of <repo>/environment_deploy.yaml and
#     <repo>/holomotion/environment_deploy.yaml is used.
#
# Notes:
#   - An existing conda environment with the same name is REMOVED and recreated.
#   - unitree_sdk2_python is cloned into ~/unitree_sdk2_python if missing.
#   - All repository paths are derived from this script's location
#     (<repo>/deployment), so the checkout directory may have any name.
#
# Examples:
#   ./deploy_environment.sh                # Uses default name 'holomotion_deploy'
#   ./deploy_environment.sh my_robot_env   # Uses custom name 'my_robot_env'
#
# Author: HoloMotion Team
##############################################################################
set -e # Exit on any error

# Conda environment to (re)create.
ENV_NAME="${1:-holomotion_deploy}"

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
ROS_WS_DIR="$SCRIPT_DIR/unitree_g1_ros2_29dof"
ROS_SETUP="/opt/ros/humble/setup.bash"
UNITREE_ROS2_SETUP="$HOME/unitree_ros2/setup.sh"
UNITREE_SDK_DIR="$HOME/unitree_sdk2_python"
UNITREE_SDK_URL="https://github.com/unitreerobotics/unitree_sdk2_python.git"

#######################################
# Print an error (first arg) plus optional hint lines to stderr and exit 1.
#######################################
die() {
  printf '❌ %s\n' "$1" >&2
  shift
  if (($# > 0)); then
    printf '   %s\n' "$@" >&2
  fi
  exit 1
}

#######################################
# Succeed if a conda environment named $1 exists (exact match on the name
# column of `conda env list`; comment lines start with '#').
#######################################
conda_env_exists() {
  conda env list | awk -v name="$1" '$1 == name { found = 1 } END { exit !found }'
}

#######################################
# Run a command inside the target conda environment.
#######################################
run_in_env() {
  conda run -n "$ENV_NAME" "$@"
}

#######################################
# Print the path of environment_deploy.yaml; return 1 if none is found.
#######################################
resolve_env_file() {
  local candidate
  if [[ -n "${HOLOMOTION_ENV_FILE:-}" ]]; then
    [[ -f "$HOLOMOTION_ENV_FILE" ]] || return 1
    printf '%s\n' "$HOLOMOTION_ENV_FILE"
    return 0
  fi
  for candidate in "$REPO_ROOT/environment_deploy.yaml" "$REPO_ROOT/holomotion/environment_deploy.yaml"; do
    if [[ -f "$candidate" ]]; then
      printf '%s\n' "$candidate"
      return 0
    fi
  done
  return 1
}

#######################################
# Fail fast, before the existing conda env is removed, if anything the later
# steps need is missing.
#######################################
check_prerequisites() {
  command -v conda >/dev/null 2>&1 ||
    die "conda not found on PATH" "Please install Anaconda/Miniconda first"
  [[ -f "$ROS_SETUP" ]] ||
    die "ROS2 Humble not found at /opt/ros/humble/" \
      "Please install ROS2 Humble first: https://docs.ros.org/en/humble/Installation.html"
  [[ -f "$UNITREE_ROS2_SETUP" ]] ||
    die "Unitree ROS2 SDK not found at ~/unitree_ros2/" "Please install Unitree ROS2 SDK first"
  [[ -d "$ROS_WS_DIR" ]] ||
    die "ROS2 workspace not found at $ROS_WS_DIR"
}

#######################################
# Step 1: (re)create the conda environment from environment_deploy.yaml.
#######################################
create_conda_env() {
  local env_file
  env_file=$(resolve_env_file) ||
    die "environment_deploy.yaml not found under $REPO_ROOT" \
      "Set HOLOMOTION_ENV_FILE=/path/to/environment_deploy.yaml"
  echo ""
  echo "📦 Step 1: Creating conda environment with all dependencies..."
  if conda_env_exists "$ENV_NAME"; then
    echo "⚠️  Environment '$ENV_NAME' already exists. Removing it..."
    conda env remove -n "$ENV_NAME" -y
  fi
  echo "🔧 Creating new environment from $env_file..."
  echo "   This will install: PyTorch (CUDA), NumPy, SciPy, ONNX Runtime, and all other dependencies..."
  # Same working directory as the previous layout-dependent version (the
  # checkout's parent). NOTE(review): only matters if the env file uses
  # relative paths.
  cd "$(dirname "$REPO_ROOT")"
  conda env create -f "$env_file" -n "$ENV_NAME"
  echo "✅ Conda environment with all dependencies created successfully!"
}

#######################################
# Step 2: clone (if needed) and pip-install unitree_sdk2_python in editable mode.
#######################################
install_unitree_sdk() {
  echo ""
  echo "📦 Step 2: Installing unitree_sdk2_python..."
  if [[ ! -d "$UNITREE_SDK_DIR" ]]; then
    echo "📥 Cloning unitree_sdk2_python repository..."
    git clone "$UNITREE_SDK_URL" "$UNITREE_SDK_DIR"
  fi
  echo "🔧 Installing unitree_sdk2_python in development mode..."
  cd "$UNITREE_SDK_DIR"
  run_in_env pip install -e .
  echo "✅ unitree_sdk2_python installed successfully!"
}

#######################################
# Leave every active conda env (including base) so colcon builds against the
# system Python that ROS2 Humble expects.
#######################################
deactivate_conda() {
  echo "🔧 Ensuring conda environment is completely deactivated..."
  # Defines the `conda` shell function required for `conda deactivate`.
  eval "$(conda shell.bash hook)"
  # Bounded so a `conda deactivate` that does not clear the variable cannot spin forever.
  for _ in {1..16}; do
    if [[ -z "${CONDA_DEFAULT_ENV:-}" ]]; then
      break
    fi
    echo "   Deactivating conda environment: $CONDA_DEFAULT_ENV"
    conda deactivate
  done
  if [[ -n "${CONDA_DEFAULT_ENV:-}" ]]; then
    die "Could not deactivate conda environment '$CONDA_DEFAULT_ENV'"
  fi
  echo "   ✅ Conda environment fully deactivated"
}

#######################################
# Step 3: clean-build the ROS2 workspace with colcon.
#######################################
build_ros_workspace() {
  echo ""
  echo "📦 Step 3: Setting up ROS2 workspace..."
  deactivate_conda
  echo "🔧 Compiling ROS2 workspace..."
  cd "$ROS_WS_DIR"
  echo "📁 Creating required directories..."
  mkdir -p src/models src/motion_data
  # Clean previous build
  rm -rf build install log
  # shellcheck source=/dev/null
  source "$ROS_SETUP"
  # shellcheck source=/dev/null
  source "$UNITREE_ROS2_SETUP"
  echo "🏗️  Building workspace with colcon..."
  colcon build
  echo "✅ ROS2 workspace compiled successfully!"
}

print_summary() {
  echo ""
  echo "🎉 Deployment completed successfully!"
  echo ""
  echo "📋 Summary of installed packages:"
  echo "   ✅ PyTorch 2.3.1 with CUDA 12.1 support"
  echo "   ✅ ONNX Runtime for neural network inference"
  echo "   ✅ SMPLX for humanoid motion processing"
  echo "   ✅ Scientific computing packages (NumPy, SciPy, etc.)"
  echo "   ✅ Unitree SDK2 Python bindings"
  echo "   ✅ ROS2 workspace compiled"
  echo ""
  echo "📋 To run the system:"
  echo "1. Activate the conda environment:"
  echo "   conda activate $ENV_NAME"
  echo ""
  echo "2. Launch the system:"
  echo "   cd $ROS_WS_DIR"
  echo "   bash launch_holomotion_29dof.sh"
  echo ""
  echo "✅ Environment '$ENV_NAME' setup complete!"
  echo "🚀 Ready for robot deployment!"
}

main() {
  echo "🚀 Starting HoloMotion Environment Deployment..."
  echo "📝 Environment name: $ENV_NAME"
  echo "📁 Repository root: $REPO_ROOT"
  echo "📁 Script directory: $SCRIPT_DIR"
  check_prerequisites
  create_conda_env
  install_unitree_sdk
  build_ros_workspace
  print_summary
}

main "$@"
================================================
FILE: deployment/holomotion_teleop/holomotion_teleop_node.py
================================================
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Single-process teleoperation pipeline.
This node reads raw Pico body tracking data from XRoboToolkit, converts it to
SMPL, applies GMR retargeting, and publishes a 65D observation vector to the
robot over ZMQ.
Data flow:
xrobotoolkit_sdk (body_poses 24x7)
-> body_poses_to_smpl_pose_trans
-> SMPL_Parser / humanoid_fk
-> GMR
-> latest_obs(65)
-> ZMQ PUB
Message format:
[topic_bytes][1280-byte JSON header][binary payload]
Default payload fields:
- latest_obs: (65,) float32
- frame_index: (1,) int64
- timestamp_realtime: (1,) float64
- timestamp_monotonic: (1,) float64
- timestamp_ns: (1,) int64
- pico_dt: (1,) float32
- pico_fps: (1,) float32
"""
from __future__ import annotations
import argparse
from collections import defaultdict
from dataclasses import dataclass
import json
import logging
import os
import subprocess
import sys
import threading
import time
import traceback
from types import SimpleNamespace
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import zmq
from scipy.spatial.transform import Rotation as R
# Directory holding this script; the GMR / SMPLSim checkouts are expected beside it.
FILE_DIR = os.path.dirname(os.path.abspath(__file__))
# Repository root: two levels above deployment/holomotion_teleop/.
HOLOMOTION_ROOT_DIR = os.path.abspath(os.path.join(FILE_DIR, os.pardir, os.pardir))
# SMPL body-model assets handed to SMPL_Parser.
SMPL_ASSET_DIR = os.path.join(HOLOMOTION_ROOT_DIR, "assets", "smpl")

# Make the script directory and the vendored GMR / SMPLSim checkouts importable.
# Each entry not already present is pushed to the front of sys.path, in this order.
for extra_path in [FILE_DIR] + [os.path.join(FILE_DIR, sub) for sub in ("GMR", "SMPLSim")]:
    if extra_path in sys.path:
        continue
    sys.path.insert(0, extra_path)
try:
import xrobotoolkit_sdk as xrt
except ImportError:
xrt = None
from third_party.GMR.general_motion_retargeting.motion_retarget import GeneralMotionRetargeting as GMR
from smpl_sim.smpllib.smpl_parser import SMPL_Parser
# When True, every GMR input frame is reflected across MIRROR_AXIS and its
# left/right body-part labels are swapped (see mirror_and_swap_gmr_input).
MIRROR_POSE = False
MIRROR_AXIS = "x"
# Fixed size, in bytes, of the NUL-padded JSON header in every outgoing ZMQ message.
HEADER_SIZE = 1280
# Topic prefix prepended to every outgoing ZMQ message.
OUT_TOPIC = b"obs65"
# GMR body-part keys whose values are exchanged when a pose is mirrored.
GMR_LR_SWAP_PAIRS = [
    ("left_hip", "right_hip"),
    ("left_knee", "right_knee"),
    ("left_foot", "right_foot"),
    ("left_shoulder", "right_shoulder"),
    ("left_elbow", "right_elbow"),
    ("left_wrist", "right_wrist"),
]
# Parent index of each of the 24 SMPL joints; -1 marks the root (joint 0).
SMPL_PARENTS_24 = np.array(
    [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21],
    dtype=np.int32,
)
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions in wxyz format.

    Args:
        matrix: tensor whose last two dimensions are 3x3, i.e. shape (..., 3, 3).

    Returns:
        Tensor of shape (..., 4) holding (w, x, y, z). No sign canonicalization
        is applied, so either q or -q may be returned for the same rotation.

    Raises:
        ValueError: if the last two dimensions are not 3x3.

    Method: four candidate quaternions are built, one per "pivot" component.
    For each batch element, the candidate whose pivot has the largest
    magnitude is kept, which avoids dividing by a near-zero value.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # For a proper rotation matrix, q_abs[..., k] = 2 * |component k| of (w, x, y, z),
    # computed from the diagonal only. _sqrt_positive_part maps negative
    # round-off to 0.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # Row k, divided by 2 * q_abs[..., k], yields the quaternion (up to sign),
    # using component k as the pivot.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # The 0.1 floor only guards rows that are discarded below. For a rotation
    # matrix, the squared q_abs values sum to 4, so the selected (largest)
    # q_abs is at least 1.
    floor = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))
    # Keep, per batch element, the candidate whose pivot magnitude is largest.
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :,
    ].reshape(batch_dim + (4,))
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Rodrigues' formula: axis-angle vectors -> rotation matrices.

    Input shape:  (..., 3)  -- direction is the axis, norm is the angle (rad).
    Output shape: (..., 3, 3)
    Vectors with norm below 1e-8 map exactly to the identity.
    """
    lead_shape = axis_angle.shape[:-1]
    vecs = axis_angle.reshape(-1, 3)
    n = vecs.shape[0]

    angle = torch.linalg.norm(vecs, dim=-1, keepdim=True)  # (n, 1)
    unit = vecs / angle.clamp(min=1e-8)
    ux, uy, uz = unit.unbind(-1)
    o = torch.zeros_like(ux)
    # Cross-product (skew-symmetric) matrix of the unit axis.
    skew = torch.stack((o, -uz, uy, uz, o, -ux, -uy, ux, o), dim=-1).reshape(-1, 3, 3)

    identity = torch.eye(3, dtype=vecs.dtype, device=vecs.device).unsqueeze(0).expand(n, -1, -1)
    cos_a = angle.cos()[..., None]  # (n, 1, 1)
    sin_a = angle.sin()[..., None]
    outer = unit.unsqueeze(-1) @ unit.unsqueeze(-2)

    rotation = cos_a * identity + (1.0 - cos_a) * outer + sin_a * skew
    near_zero = (angle < 1e-8)[..., None]  # (n, 1, 1)
    rotation = torch.where(near_zero, identity, rotation)
    return rotation.reshape(lead_shape + (3, 3))
class Humanoid_Batch_V2:
    """
    Single-frame SMPL kinematics helper used only by this teleop script.

    It implements just what the pipeline needs (global joint rotations and
    root-relative joint positions), so the much larger training/visualization
    kinematics module does not have to be imported.
    """

    def __init__(self, device: torch.device = torch.device("cpu")):
        # Stored as given; none of the methods below read it.
        self.device = device
        # Parent index of each of the 24 SMPL joints; -1 marks the root.
        self.smpl_24_parents = [
            -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8,
            9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21,
        ]

    @staticmethod
    def _relative_link_position(joints_world: torch.Tensor, root_pos: torch.Tensor) -> torch.Tensor:
        """Subtract `root_pos` from every row of `joints_world`."""
        root_row = root_pos.unsqueeze(0)
        return joints_world - root_row

    def _relative_link_pose(self, full_pose_aa: torch.Tensor) -> torch.Tensor:
        """
        Chain per-joint local axis-angle rotations into global rotation matrices.

        Args:
            full_pose_aa: one axis-angle vector per SMPL joint (24 rows expected).

        Returns:
            Global rotation matrices, one 3x3 matrix per joint.
        """
        n_joints = full_pose_aa.shape[0]
        n_expected = len(self.smpl_24_parents)
        assert n_joints == n_expected, f"Joint count mismatch: {n_joints} vs {n_expected}"
        local_mats = axis_angle_to_matrix(full_pose_aa)
        global_mats = torch.empty_like(local_mats)
        # Every parent index is smaller than its child's, so one forward pass suffices.
        for child in range(n_joints):
            parent = self.smpl_24_parents[child]
            if parent == -1:
                global_mats[child] = local_mats[child]
                continue
            global_mats[child] = global_mats[parent] @ local_mats[child]
        return global_mats

    def step_per_frame(
        self,
        full_pose_aa: torch.Tensor,
        root_pos: torch.Tensor,
        joints: torch.Tensor,
    ) -> SimpleNamespace:
        """
        Bundle one frame's kinematic quantities.

        Returns a namespace with:
            global_joints2root_pos:     joints[1:] offset by root_pos.
            global_joints_rotation_mat: global rotation matrix per joint.
            global_joints_position:     `joints`, unchanged.
        """
        return SimpleNamespace(
            global_joints2root_pos=self._relative_link_position(joints[1:, :], root_pos),
            global_joints_rotation_mat=self._relative_link_pose(full_pose_aa),
            global_joints_position=joints,
        )


humanoid_fk = Humanoid_Batch_V2()
@dataclass
class PicoToSmplConfig:
    """Options for body_poses_to_smpl_pose_trans."""

    # Quaternion layout of body_poses[:, 3:7]: False -> (x, y, z, w), True -> (w, x, y, z).
    quat_scalar_first: bool = False
    # Right-multiply every global joint rotation by a 180 degree rotation about y.
    apply_global_y_180: bool = True
    # Pre-rotate the root orientation and translation by `root_align_degrees`
    # about `root_align_axis` (the name reflects the default: +90 deg about x).
    apply_root_rx90: bool = True
    root_align_degrees: float = 90.0
    root_align_axis: str = "x"


def body_poses_to_smpl_pose_trans(
    body_poses: np.ndarray,
    parents: np.ndarray = SMPL_PARENTS_24,
    cfg: Optional[PicoToSmplConfig] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert one Pico body-tracking frame into SMPL local axis-angle pose + translation.

    Args:
        body_poses: (J, 7) per-joint global poses; each row is (px, py, pz, quat).
            The quaternion layout is selected by `cfg.quat_scalar_first`.
            J must equal len(parents) (24 by default).
        parents: parent index per joint; a negative value marks a root joint.
        cfg: conversion options; defaults to PicoToSmplConfig().

    Returns:
        (pose_aa, trans):
            pose_aa: float32 (J, 3) local rotation vectors.
            trans: float32 (3,) position of joint 0.
        Joint 0 is the one the optional root alignment is applied to.

    Raises:
        ValueError: if body_poses does not have shape (J, 7).
    """
    if cfg is None:
        cfg = PicoToSmplConfig()
    parents = np.asarray(parents).reshape(-1).astype(np.int64)
    num_joints = parents.shape[0]
    body_poses = np.asarray(body_poses, dtype=np.float32)
    if body_poses.shape != (num_joints, 7):
        raise ValueError(f"body_poses shape must be ({num_joints},7), got {body_poses.shape}")

    positions = body_poses[:, 0:3]
    quats = body_poses[:, 3:7]
    # scipy's default (version-independent) convention is scalar-last (x, y, z, w).
    quats_xyzw = quats[:, [1, 2, 3, 0]] if cfg.quat_scalar_first else quats
    global_rots = R.from_quat(quats_xyzw)
    if cfg.apply_global_y_180:
        global_rots = global_rots * R.from_euler("y", 180.0, degrees=True)

    # local = parent_global^-1 * child_global, vectorized over all joints.
    # Roots are first paired with themselves, then overwritten with their global rotation.
    root_idx = np.flatnonzero(parents < 0)
    parent_idx = parents.copy()
    parent_idx[root_idx] = root_idx
    pose_aa = (global_rots[parent_idx].inv() * global_rots).as_rotvec()
    if root_idx.size:
        pose_aa[root_idx] = global_rots[root_idx].as_rotvec()
    pose_aa = pose_aa.astype(np.float32)

    trans = positions[0].astype(np.float32)
    if cfg.apply_root_rx90:
        rot_align = R.from_euler(cfg.root_align_axis, cfg.root_align_degrees, degrees=True).as_matrix().astype(
            np.float32
        )
        root_matrix = R.from_rotvec(pose_aa[0]).as_matrix().astype(np.float32)
        pose_aa[0] = R.from_matrix(rot_align @ root_matrix).as_rotvec().astype(np.float32)
        trans = (rot_align @ trans).astype(np.float32)
    return pose_aa, trans
def _mirror_matrix(mirror_axis: str) -> np.ndarray:
if mirror_axis == "x":
return np.diag([-1.0, 1.0, 1.0]).astype(np.float32)
if mirror_axis == "y":
return np.diag([1.0, -1.0, 1.0]).astype(np.float32)
if mirror_axis == "z":
return np.diag([1.0, 1.0, -1.0]).astype(np.float32)
raise ValueError(f"mirror_axis must be one of x/y/z, got {mirror_axis}")
def safe_normalize_quat_wxyz(q: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """
    Normalize a 4-element quaternion to unit length as float32.

    Any input with 4 elements is accepted (it is flattened). If the norm is
    below `eps`, the identity quaternion (1, 0, 0, 0) is returned instead.
    """
    vec = np.asarray(q, dtype=np.float32).reshape(4,)
    magnitude = float(np.linalg.norm(vec))
    identity = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
    return identity if magnitude < eps else (vec / magnitude).astype(np.float32)
def mirror_pos_and_quat_wxyz(pos: np.ndarray, quat_wxyz: np.ndarray, mirror_axis: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Reflect a position and a wxyz orientation across the plane normal to `mirror_axis`.

    With M = _mirror_matrix(mirror_axis), the position maps to M @ p and the
    rotation matrix to M @ Rot @ M, which is again a proper rotation.

    Returns:
        (mirrored_pos, mirrored_quat_wxyz), both float32. The quaternion is
        re-normalized.
    """
    point = np.asarray(pos, dtype=np.float32).reshape(3,)
    unit_q = safe_normalize_quat_wxyz(quat_wxyz)
    reflect = _mirror_matrix(mirror_axis)
    mirrored_point = (reflect @ point).astype(np.float32)
    # scipy expects scalar-last (x, y, z, w) quaternions.
    rot = R.from_quat(unit_q[[1, 2, 3, 0]]).as_matrix().astype(np.float32)
    mirrored_rot = (reflect @ rot @ reflect).astype(np.float32)
    x, y, z, w = R.from_matrix(mirrored_rot).as_quat().astype(np.float32)
    return mirrored_point, safe_normalize_quat_wxyz(np.array([w, x, y, z], dtype=np.float32))
def mirror_and_swap_gmr_input(gmr_input: Dict[str, Any], mirror_axis: str = "x") -> Dict[str, Any]:
    """
    Mirror every (position, wxyz quaternion) entry and swap left/right body parts.

    Each value of `gmr_input` is reflected with mirror_pos_and_quat_wxyz. Then,
    for every pair in GMR_LR_SWAP_PAIRS where both keys are present, the two
    values are exchanged. A new dict is returned; the input is not modified.
    """
    result: Dict[str, Any] = {
        name: mirror_pos_and_quat_wxyz(position, orientation, mirror_axis)
        for name, (position, orientation) in gmr_input.items()
    }
    for left_key, right_key in GMR_LR_SWAP_PAIRS:
        if left_key not in result or right_key not in result:
            continue
        result[left_key], result[right_key] = result[right_key], result[left_key]
    return result
def pack_numpy_message(payload: dict, topic: bytes = OUT_TOPIC, version: int = 1) -> bytes:
    """
    Serialize a dict of numpy arrays into one ZMQ frame.

    Layout: ``topic | JSON header NUL-padded to HEADER_SIZE bytes | raw array bytes``.
    The header lists each field's name, dtype tag and shape, in payload order.
    All array data is written little-endian and C-ordered, as the header's
    ``"endian": "le"`` advertises.

    Values that are not np.ndarray are skipped. Arrays whose dtype has no wire
    tag (e.g. float16, int16, uint16) are converted to float32.

    Raises:
        ValueError: if the encoded header exceeds HEADER_SIZE bytes.
    """
    # Keyed by (kind, itemsize), so a non-native byte order (e.g. '>f8') is still
    # recognized instead of falling through to the lossy float32 conversion.
    type_tags = {
        ("f", 4): "f32",
        ("f", 8): "f64",
        ("i", 4): "i32",
        ("i", 8): "i64",
        ("u", 1): "u8",
        ("b", 1): "bool",
    }
    fields = []
    binary_data = []
    for key, value in payload.items():
        if not isinstance(value, np.ndarray):
            continue
        dtype_str = type_tags.get((value.dtype.kind, value.dtype.itemsize))
        if dtype_str is None:
            dtype_str = "f32"
            value = value.astype(np.float32)
        little_endian = value.dtype.newbyteorder("<")
        if value.dtype != little_endian:
            value = value.astype(little_endian)
        if not value.flags["C_CONTIGUOUS"]:
            value = np.ascontiguousarray(value)
        fields.append({"name": key, "dtype": dtype_str, "shape": list(value.shape)})
        binary_data.append(value.tobytes())
    header = {"v": version, "endian": "le", "count": 1, "fields": fields}
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    if len(header_bytes) > HEADER_SIZE:
        raise ValueError(f"Header too large: {len(header_bytes)} > {HEADER_SIZE}")
    return topic + header_bytes.ljust(HEADER_SIZE, b"\x00") + b"".join(binary_data)
class PicoReader:
    """
    Background thread that polls the XRoboToolkit SDK (module-level ``xrt``)
    for body-tracking frames.

    Only the newest frame is kept. get_latest() returns it as a dict with keys
    body_poses_np, timestamp_realtime, timestamp_monotonic, timestamp_ns, dt
    and fps, or None before the first frame arrives.
    """

    def __init__(self):
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._last_stamp_ns = None
        self._fps_ema = 0.0  # exponential moving average of the device frame rate
        self._latest = None
        self._lock = threading.Lock()

    def start(self):
        self._thread.start()

    def stop(self):
        """Ask the polling thread to exit and wait up to 1 s; safe before start()."""
        self._stop.set()
        if self._thread.is_alive():
            self._thread.join(timeout=1.0)

    def get_latest(self):
        with self._lock:
            return self._latest

    def _run(self):
        last_report = time.time()
        while not self._stop.is_set():
            try:
                last_report = self._poll_once(last_report)
            except Exception as exc:
                # Keep polling through SDK errors. If this thread died, consumers
                # would silently keep re-using the last frame forever.
                print(f"[PicoReader] read error: {exc}")
                time.sleep(0.01)

    def _poll_once(self, last_report: float) -> float:
        """Run one polling iteration; return the (possibly updated) last report time."""
        if not xrt.is_body_data_available():
            time.sleep(0.001)
            return last_report
        stamp_ns = xrt.get_time_stamp_ns()
        prev_stamp_ns = self._last_stamp_ns
        if prev_stamp_ns is not None and stamp_ns == prev_stamp_ns:
            # Same device frame as last time: yield briefly and poll again.
            time.sleep(0.000001)
            return last_report
        device_dt = ((stamp_ns - prev_stamp_ns) * 1e-9) if prev_stamp_ns is not None else 0.0
        if device_dt > 0.0:
            inst_fps = 1.0 / device_dt
            self._fps_ema = inst_fps if self._fps_ema == 0.0 else (0.9 * self._fps_ema + 0.1 * inst_fps)
        self._last_stamp_ns = stamp_ns
        t_realtime = time.time()
        t_monotonic = time.monotonic()
        body_poses_np = np.asarray(xrt.get_body_joints_pose(), dtype=np.float32)
        if body_poses_np.shape != (24, 7):
            print(f"[PicoReader] WARNING: unexpected body_poses shape: {body_poses_np.shape}")
        sample = {
            "body_poses_np": body_poses_np,
            "timestamp_realtime": t_realtime,
            "timestamp_monotonic": t_monotonic,
            "timestamp_ns": int(stamp_ns),
            "dt": float(device_dt),
            "fps": float(self._fps_ema),
        }
        with self._lock:
            self._latest = sample
        now = time.time()
        if now - last_report >= 5.0:
            print(
                f"[PicoReader] shape={body_poses_np.shape}, "
                f"dt_ts={device_dt * 1000.0:.2f} ms, fps={self._fps_ema:.2f}"
            )
            last_report = now
        return last_report
class ZmqObsSender:
    """
    ZMQ PUB socket that ships observation packets built by pack_numpy_message.

    The send high-water mark is 1 and CONFLATE is enabled by default (when
    pyzmq exposes it), so a slow subscriber sees the newest packet rather than
    a backlog. `logger` must provide an ``info(str)`` method.
    """

    def __init__(self, uri: str, logger, topic: bytes = OUT_TOPIC, mode: str = "bind", conflate: bool = True):
        if mode not in ("bind", "connect"):
            raise ValueError("mode must be 'bind' or 'connect'")
        self.logger = logger
        self.topic = topic
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.PUB)
        try:
            self._socket.setsockopt(zmq.SNDHWM, 1)
            if conflate and hasattr(zmq, "CONFLATE"):
                self._socket.setsockopt(zmq.CONFLATE, 1)
            if mode == "bind":
                self._socket.bind(uri)
            else:
                self._socket.connect(uri)
        except Exception:
            # e.g. "address already in use": release the socket/context instead of leaking them.
            self._socket.close(0)
            self._context.term()
            raise
        self._last_send_time = None
        self._send_intervals = []  # seconds between consecutive sends
        self._frame_count = 0
        self.logger.info(f"[ZMQOut] sender ready: mode={mode}, uri={uri}, topic={topic.decode('utf-8')}")

    def send(self, latest_obs: np.ndarray, frame_index: int, sample_meta: dict):
        """
        Pack one observation with its frame index and Pico timing metadata, then publish it.

        `sample_meta` must contain timestamp_realtime, timestamp_monotonic,
        timestamp_ns, dt and fps (the dict produced by PicoReader).
        """
        payload = {
            "latest_obs": np.asarray(latest_obs, dtype=np.float32),
            "frame_index": np.array([frame_index], dtype=np.int64),
            "timestamp_realtime": np.array([sample_meta["timestamp_realtime"]], dtype=np.float64),
            "timestamp_monotonic": np.array([sample_meta["timestamp_monotonic"]], dtype=np.float64),
            "timestamp_ns": np.array([sample_meta["timestamp_ns"]], dtype=np.int64),
            "pico_dt": np.array([sample_meta["dt"]], dtype=np.float32),
            "pico_fps": np.array([sample_meta["fps"]], dtype=np.float32),
        }
        packet = pack_numpy_message(payload, topic=self.topic)
        self._socket.send(packet)

        # Monotonic clock: wall-clock steps must not distort the reported rate.
        now = time.monotonic()
        if self._last_send_time is not None:
            dt = now - self._last_send_time
            if dt > 0:
                self._send_intervals.append(dt)
        self._last_send_time = now
        self._frame_count += 1
        if self._frame_count >= 50:
            if self._send_intervals:
                # Sends per elapsed second. Unlike a mean of 1/dt, this is not inflated by jitter.
                avg_freq = len(self._send_intervals) / sum(self._send_intervals)
                self.logger.info(f"Average ZMQ send rate: {avg_freq:.2f} Hz")
            self._send_intervals.clear()
            self._frame_count = 0

    def stop(self):
        """Close the socket without lingering and terminate the context."""
        self._socket.close(0)
        self._context.term()
        self.logger.info("🛑 ZMQ obs sender stopped")
class VRNodeXRTPicoGMRZmqOut:
    """
    Teleoperation node: Pico body frames -> SMPL -> GMR retargeting -> 65-D obs over ZMQ.

    Construction loads GMR and the SMPL model, starts the PicoReader thread,
    opens the ZMQ sender and launches the fixed-rate processing loop. Call
    stop() to shut everything down and, if `save_obs_path` was given, write
    the emitted observations to disk.

    Each observation is float32 [dof_pos, dof_vel, root_pos(3), root_rot(4)].
    save_observations labels the columns dof_pos(29), dof_vel(29),
    root_pos(3), root_rot_wxyz(4). Frames that fail publish_data validation
    are dropped instead of sent.
    """

    # SMPL joint index feeding each GMR body-part key.
    _SMPL_TO_GMR = {
        "pelvis": 0,
        "spine3": 9,
        "left_hip": 1,
        "right_hip": 2,
        "left_knee": 4,
        "right_knee": 5,
        "left_foot": 10,
        "right_foot": 11,
        "left_shoulder": 16,
        "right_shoulder": 17,
        "left_elbow": 18,
        "right_elbow": 19,
        "left_wrist": 20,
        "right_wrist": 21,
    }

    def __init__(
        self,
        robot_zmq_uri: str,
        robot_zmq_mode: str = "bind",
        loop_hz: float = 55.0,
        timing_log_every: int = 100,
        save_obs_path: str = "",
    ):
        self.device = "cpu"
        logging.getLogger("websockets").setLevel(logging.WARNING)
        self.info(f"✅ VRNodeXRTPicoGMRZmqOut running on device={self.device}")
        self.info("starting xrt pico -> gmr -> robot zmq node")
        self.gmr = GMR(src_human="smplx", tgt_robot="unitree_g1")
        self.smpl_parser = SMPL_Parser(model_path=SMPL_ASSET_DIR, gender="neutral")
        if hasattr(self.smpl_parser, "to"):
            self.smpl_parser = self.smpl_parser.to(self.device)
        self.betas = torch.zeros(1, 10, device=self.device)
        self.gmr_input_data: Dict[str, Any] = {}
        self.prev_dof_pos = None
        self.lasttime = None
        self.timing_log_every = max(1, timing_log_every)
        self.save_obs_path = save_obs_path
        self.mirror_pose = MIRROR_POSE
        self.mirror_axis = MIRROR_AXIS
        self.tick_count = 0
        self.frame_index = 0
        self.timing_sums_ms = defaultdict(float)
        self.saved_obs = []
        self.latest_sample = None
        self._loop_stop: Optional[threading.Event] = None
        self._loop_thread: Optional[threading.Thread] = None
        self.reader = PicoReader()
        self.reader.start()
        try:
            self.sender = ZmqObsSender(uri=robot_zmq_uri, logger=self, mode=robot_zmq_mode)
            self.start_loop(loop_hz)
        except Exception:
            # Do not leave the reader thread or the socket behind if setup fails half-way.
            self.reader.stop()
            if hasattr(self, "sender"):
                self.sender.stop()
            raise

    def info(self, msg): print(f"[INFO] {msg}")
    def error(self, msg): print(f"[ERROR] {msg}")
    def warning(self, msg): print(f"[WARN] {msg}")
    def debug(self, msg): print(f"[DEBUG] {msg}")

    def _accumulate_timing(self, name: str, start_time: float) -> float:
        """Add the ms elapsed since `start_time` (perf_counter) to stage `name`; return it."""
        elapsed_ms = (time.perf_counter() - start_time) * 1000.0
        self.timing_sums_ms[name] += elapsed_ms
        return elapsed_ms

    def _maybe_log_timing(self):
        """Every `timing_log_every` ticks, log per-stage averages and reset the sums."""
        if self.tick_count <= 0 or self.tick_count % self.timing_log_every != 0:
            return
        avg_parts = []
        for key in ("body_poses_to_smpl", "smpl_to_joints", "gmr_retarget", "postprocess_send", "tick_total"):
            if key in self.timing_sums_ms:
                avg_ms = self.timing_sums_ms[key] / self.timing_log_every
                avg_parts.append(f"{key}={avg_ms:.2f}ms")
        if avg_parts:
            self.info("[Timing] " + ", ".join(avg_parts))
        self.timing_sums_ms.clear()

    def process_smpl_pose_trans_to_gmr_input(self, smpl_pose_aa, smpl_trans) -> Dict[str, Any]:
        """
        Run SMPL forward kinematics and build GMR's input dict.

        Returns:
            {gmr_body_part: (position, wxyz quaternion)} as numpy arrays,
            optionally mirrored with left/right swapped when `self.mirror_pose` is set.
        """
        stage_start = time.perf_counter()
        if not isinstance(smpl_pose_aa, torch.Tensor):
            smpl_pose_aa = torch.tensor(smpl_pose_aa, dtype=torch.float32)
        if not isinstance(smpl_trans, torch.Tensor):
            smpl_trans = torch.tensor(smpl_trans, dtype=torch.float32)
        pose = smpl_pose_aa.to(self.device, dtype=torch.float32)
        trans = smpl_trans.to(self.device, dtype=torch.float32)
        if pose.ndim == 2:
            pose = pose.unsqueeze(0)
        if trans.ndim == 1:
            trans = trans.unsqueeze(0)
        # Pure inference: skip autograd bookkeeping on every frame.
        with torch.no_grad():
            verts, joints = self.smpl_parser.get_joints_verts(pose, self.betas, trans)
            # joints[..., 2] -= verts[0, :, 2].min().item()
            pose = pose.squeeze(0)
            trans = trans.squeeze(0)
            joints = joints.squeeze(0)
            motion_state = humanoid_fk.step_per_frame(pose, trans, joints)
            global_joints_position = motion_state.global_joints_position
            global_joints_qua_wxyz = matrix_to_quaternion(motion_state.global_joints_rotation_mat)
        gmr_input_data: Dict[str, Any] = {}
        for name, idx in self._SMPL_TO_GMR.items():
            pos = global_joints_position[idx].detach().cpu().numpy()
            quat = global_joints_qua_wxyz[idx].detach().cpu().numpy()
            gmr_input_data[name] = (pos, quat)
        if self.mirror_pose:
            gmr_input_data = mirror_and_swap_gmr_input(gmr_input_data, mirror_axis=self.mirror_axis)
        self._accumulate_timing("smpl_to_joints", stage_start)
        return gmr_input_data

    def process_xrt_frame_to_gmr_input(self, sample: dict):
        """Convert one PicoReader sample into `self.gmr_input_data`."""
        body_poses = np.asarray(sample["body_poses_np"], dtype=np.float32)
        if body_poses.shape != (24, 7):
            raise ValueError(f"[XRT] body_poses_np must have shape (24,7), got {body_poses.shape}")
        stage_start = time.perf_counter()
        pose_aa, trans = body_poses_to_smpl_pose_trans(
            body_poses,
            cfg=PicoToSmplConfig(
                apply_global_y_180=True,
                apply_root_rx90=True,
                root_align_axis="x",
                root_align_degrees=90.0,
            ),
        )
        self._accumulate_timing("body_poses_to_smpl", stage_start)
        self.gmr_input_data = self.process_smpl_pose_trans_to_gmr_input(pose_aa, trans)

    def process_gmr_output(self):
        """
        Retarget `self.gmr_input_data`, build the observation and send it if valid.

        qpos from GMR is [root_pos(3), root_rot(4), dof_pos...]. dof_vel is a
        finite difference against the previous frame (zeros on the first frame).

        Returns:
            The observation vector, whether or not it passed validation.
        """
        stage_start = time.perf_counter()
        qpos = self.gmr.retarget(self.gmr_input_data)
        self._accumulate_timing("gmr_retarget", stage_start)
        stage_start = time.perf_counter()
        root_pos = qpos[:3]
        root_rot = qpos[3:7]
        dof_pos = qpos[7:]
        # Monotonic clock: a wall-clock step (NTP) would otherwise yield a tiny or
        # negative dt, clamped to 1e-6, and a huge velocity spike sent to the robot.
        now = time.monotonic()
        delta_time = 1 / 50 if self.lasttime is None else (now - self.lasttime)
        self.lasttime = now
        if self.prev_dof_pos is None:
            dof_vel = np.zeros_like(dof_pos, dtype=np.float32)
        else:
            dof_vel = (dof_pos - self.prev_dof_pos) / max(delta_time, 1e-6)
        self.prev_dof_pos = dof_pos.copy()
        latest_obs = np.concatenate([dof_pos, dof_vel, root_pos, root_rot], axis=0).astype(np.float32)
        if self.publish_data(latest_obs):
            self.sender.send(latest_obs, self.frame_index, self.latest_sample)
            if self.save_obs_path:
                # Only buffer when saving was requested; otherwise this list would grow forever.
                self.saved_obs.append(latest_obs.copy())
            self.frame_index += 1
        else:
            # Drop the frame and the velocity history, so a bad pose cannot leak
            # into the next frame's finite difference.
            self.prev_dof_pos = None
        self._accumulate_timing("postprocess_send", stage_start)
        return latest_obs

    def publish_data(self, motion_state: np.ndarray) -> bool:
        """Validate an outgoing observation; return True if it may be sent."""
        if motion_state.size != 65:
            self.error(f"Output dim {motion_state.size} != expected 65")
            return False
        if not np.isfinite(motion_state).all():
            self.error("NaN/Inf detected in observation; frame dropped")
            return False
        return True

    def save_observations(self):
        """Write buffered observations to `save_obs_path` (.npy, else compressed .npz)."""
        if not self.save_obs_path:
            return
        if len(self.saved_obs) == 0:
            self.warning(f"[SaveObs] no observations to save: {self.save_obs_path}")
            return
        obs_array = np.stack(self.saved_obs, axis=0).astype(np.float32)
        save_dir = os.path.dirname(self.save_obs_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        if self.save_obs_path.endswith(".npy"):
            np.save(self.save_obs_path, obs_array)
        else:
            np.savez_compressed(
                self.save_obs_path,
                latest_obs=obs_array,
                # Plain unicode array (not dtype=object), so np.load works without allow_pickle=True.
                columns=np.array(["dof_pos(29)", "dof_vel(29)", "root_pos(3)", "root_rot_wxyz(4)"]),
            )
        self.info(f"[SaveObs] saved {obs_array.shape[0]} frames to {self.save_obs_path}")

    def start_loop(self, hz=50):
        """
        Start a daemon thread that calls _tick() at `hz` until stop() is called.

        Raises:
            ValueError: if `hz` is not positive.
        """
        if hz <= 0:
            raise ValueError(f"loop frequency must be positive, got {hz}")
        self.info(f"Starting main loop at {hz} Hz")
        interval = 1.0 / hz
        stop_event = threading.Event()
        self._loop_stop = stop_event

        def loop():
            next_time = time.monotonic()
            while not stop_event.is_set():
                self._tick()
                next_time += interval
                sleep_time = next_time - time.monotonic()
                if sleep_time > 0:
                    # Event.wait so stop() does not have to wait out a full period.
                    stop_event.wait(sleep_time)
                else:
                    # Overran the period: resynchronize instead of bursting to catch up.
                    next_time = time.monotonic()

        self._loop_thread = threading.Thread(target=loop, daemon=True)
        self._loop_thread.start()

    def _tick(self):
        """Process the latest Pico sample (if any) and send one observation."""
        tick_start = time.perf_counter()
        sample = self.reader.get_latest()
        if sample is not None:
            try:
                self.latest_sample = sample
                self.process_xrt_frame_to_gmr_input(sample)
                self.process_gmr_output()
            except Exception as exc:
                self.error(f"[tick_error] {exc}")
                self.error(traceback.format_exc())
                return
        elif self.prev_dof_pos is not None:
            try:
                self.process_gmr_output()
            except Exception as exc:
                self.error(f"[tick_error] {exc}")
                self.error(traceback.format_exc())
                return
        self.tick_count += 1
        self._accumulate_timing("tick_total", tick_start)
        self._maybe_log_timing()

    def stop(self):
        """Stop the loop, reader and sender, save observations, then close the SDK."""
        loop_stop = getattr(self, "_loop_stop", None)
        if loop_stop is not None:
            loop_stop.set()
        loop_thread = getattr(self, "_loop_thread", None)
        if loop_thread is not None and loop_thread.is_alive():
            # Let an in-flight tick finish before its socket is closed.
            loop_thread.join(timeout=1.0)
        self.reader.stop()
        self.sender.stop()
        self.save_observations()
        try:
            if xrt is not None and hasattr(xrt, "close"):
                xrt.close()
        except Exception as exc:
            # Best-effort shutdown: report, but do not fail the stop sequence.
            self.warning(f"xrt.close() failed: {exc}")
def init_xrt(start_service: bool = True):
    """
    Initialize the XRoboToolkit SDK and block until body-tracking data is available.

    Args:
        start_service: if True, first launch /opt/apps/roboticsservice/runService.sh
            in the background (the process is not waited on).

    Raises:
        ImportError: if xrobotoolkit_sdk could not be imported.
    """
    if xrt is None:
        raise ImportError("XRoboToolkit SDK not available. Install xrobotoolkit_sdk first.")
    if start_service:
        subprocess.Popen(["bash", "/opt/apps/roboticsservice/runService.sh"])
    xrt.init()
    print("Waiting for body tracking data...")
    while True:
        if xrt.is_body_data_available():
            return
        print("waiting for body data...")
        time.sleep(1)
def main():
    """CLI entry point: parse arguments, initialise XRT, run the node until Ctrl+C.

    ``--hz`` must be positive. A non-positive or NaN value is rejected by the
    parser here, instead of failing later inside the node.
    """
    parser = argparse.ArgumentParser(description="XRT Pico -> GMR -> robot ZMQ(65D)")
    parser.add_argument("--robot-zmq-uri", default="tcp://*:6001", help="Robot-side ZMQ uri for 65D obs output")
    parser.add_argument("--robot-zmq-mode", default="bind", choices=["bind", "connect"])
    parser.add_argument("--hz", type=float, default=55.0, help="Main loop frequency / publish cap")
    parser.add_argument("--timing-log-every", type=int, default=200, help="Print average stage timing every N ticks")
    parser.add_argument("--save-obs-path", type=str, default="", help="Optional path to save emitted 65D observations")
    parser.add_argument("--skip-start-service", action="store_true", help="Do not auto-run /opt/apps/roboticsservice/runService.sh")
    args = parser.parse_args()
    # `not x > 0` also rejects NaN, which argparse's float() accepts.
    if not args.hz > 0:
        parser.error(f"--hz must be positive, got {args.hz}")
    init_xrt(start_service=not args.skip_start_service)
    node = VRNodeXRTPicoGMRZmqOut(
        robot_zmq_uri=args.robot_zmq_uri,
        robot_zmq_mode=args.robot_zmq_mode,
        loop_hz=args.hz,
        timing_log_every=args.timing_log_every,
        save_obs_path=args.save_obs_path,
    )
    try:
        # The work happens off the main thread; just idle until Ctrl+C.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        node.stop()
        print("🛑 Program terminated by user.")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
================================================
FILE: deployment/holomotion_teleop/holomotion_teleop_setup.md
================================================
# Holomotion Teleop
Single-process pipeline for:
`PICO / XRoboToolkit -> SMPL conversion -> GMR retargeting -> robot ZMQ`
## Prerequisites
Before setting up the Python environment, install XRoboToolkit PC Service manually.
1. On Ubuntu 22.04, download the XRoboToolkit PC Service `.deb` package, or build it from source.
```bash
sudo dpkg -i XRoboToolkit_PC_Service_1.0.0_ubuntu_22.04_amd64.deb
```
## Environment Setup
```bash
cd /path/to/holomotion_teleop
bash setup_holomotion_teleop_x86_ubuntu2204.sh
```
This script will:
- create the Conda environment `holomotion_teleop`
- automatically clone and install `GMR` and `SMPLSim`
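- install runtime dependencies (`pyzmq`, `open3d`, `chumpy`) and pin `numpy==1.23.5` for `chumpy` compatibility; `torch` is not installed directly by the script, so make sure it is available (for example via the GMR / SMPLSim dependencies)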
- build and install `xrobotoolkit_sdk` from source
Optional environment variables:
```bash
ENV_NAME=holomotion_teleop
PYTHON_VERSION=3.10
INSTALL_APT_DEPS=auto
THIRD_PARTY_DIR=/path/to/third_party
GMR_SOURCE_DIR=/path/to/GMR
SMPLSIM_SOURCE_DIR=/path/to/SMPLSim
XRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind
```
- `INSTALL_APT_DEPS=auto`: only runs apt installation if required build tools are missing
- `INSTALL_APT_DEPS=0` (the script's default): skip apt installation entirely if your machine already has the tools or apt is unusable
- `INSTALL_APT_DEPS=1`: force the apt installation step
- `THIRD_PARTY_DIR`: default directory used for auto-cloned third-party repositories
- `GMR_SOURCE_DIR` / `SMPLSIM_SOURCE_DIR`: point to external source checkouts; if omitted, the script auto-clones them
## Input and Output
### Input
The script reads raw body tracking data directly from `xrobotoolkit_sdk.get_body_joints_pose()`:
- shape: `(24, 7)`
- row format: `[x, y, z, qx, qy, qz, qw]`
### Output
The robot-side ZMQ payload contains `latest_obs` as `float32[65]`:
1. `dof_pos[29]`
2. `dof_vel[29]`
3. `root_pos[3]`
4. `root_rot_wxyz[4]`
Additional metadata is included in the same payload:
- `frame_index`
- `timestamp_realtime`
- `timestamp_monotonic`
- `timestamp_ns`
- `pico_dt`
- `pico_fps`
## Next Steps
Before running teleoperation on the real robot, make sure the operators are already familiar with the offline `.npz` motion-performance workflow and the robot's basic mode-switching behavior. Teleoperation should not be the first time the team tests motion-mode entry on hardware.
### Real Robot Workflow
Use the following checklist when running the teleoperation stack on the real robot.
#### 1. Hardware and Network
Required hardware:
- PICO 4 / PICO 4 Pro headset
- 2 PICO controllers
- 2 PICO motion trackers attached to the ankles
- One workstation running `holomotion_teleop_node.py`
- One robot computer running the policy / control stack
- A low-latency Wi-Fi network shared by the PICO headset and the workstation
Make sure the robot, the workstation and the PICO headset are on the same Wi-Fi network. Low network latency is important for stable teleoperation. The PICO-side setup steps below follow the XRoboToolkit / PICO workflow described in the [GR00T VR Teleop Setup (PICO)](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/vr_teleop_setup.html).
#### 2. Install and Configure PICO
1. Install the XRoboToolkit PICO app on the headset.
- Enable Developer Mode on the headset.
- Open the browser on PICO and download the XRoboToolkit PICO APK.
- Install the APK from the downloads page and confirm it appears in the app library.
2. Pair the two PICO motion trackers.
- Attach one tracker to each ankle.
- Open the motion tracker settings on the headset.
- Unpair any old trackers first, then pair both trackers again.
3. Calibrate the motion trackers on the headset.
- Follow the standing calibration step.
- Then look down at the foot trackers so the headset cameras can detect them.
4. Connect the headset to the workstation.
- Confirm the headset and workstation are on the same Wi-Fi network.
- Open the XRoboToolkit app on the headset.
- Enter the workstation IP address into the PC Service field.
- Verify the status shows a successful connection.
5. In XRoboToolkit, enable the required streaming options.
- Enable `Head` and `Controller` tracking.
- Set `Pico Motion Tracker` to `Full body`.
- Enable the `Send` option for data/control streaming.
#### 3. Configure the Robot-Side Policy
Before starting the robot-side policy, update the robot config file:
`HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml`
Recommended settings:
- `enable_teleop_reference: true`
- `require_vr_data_for_motion: true`
- `latest_obs_zmq_uri: "tcp://<workstation_ip>:6001"`
Replace `<workstation_ip>` with the actual IP address of the workstation that runs `holomotion_teleop_node.py`.
This ensures the robot waits for live VR data before switching into motion mode and connects to the correct ZMQ publisher endpoint.
#### 4. Launch Order
Start the system in the following order:
1. Start the robot control / policy stack on the robot computer.
2. Wait until the control policy is fully initialized, then press `Start` to move the robot into the default pose.
3. Start XRoboToolkit on the PICO headset and confirm that body-tracking data is being streamed.
4. Start the teleoperation node on the workstation:
```bash
conda activate holomotion_teleop
cd /path/to/holomotion_teleop
python holomotion_teleop_node.py
```
If needed, pass explicit ZMQ arguments such as:
```bash
python holomotion_teleop_node.py \
--robot-zmq-uri tcp://*:6001 \
--robot-zmq-mode bind \
--hz 50
```
5. After the robot-side policy is receiving live teleoperation data, perform the runtime mode sequence:
- press `A` to enter walking / velocity mode
- press `B` to enter teleoperation motion mode
- press `Y` whenever you want to leave teleoperation and return to walking mode
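##### Optional Arguments
- `--robot-zmq-uri`: robot-side ZMQ endpoint for the 65D output (default: `tcp://*:6001`)
- `--robot-zmq-mode`: `bind` or `connect` (default: `bind`)
- `--hz`: main loop frequency / processing cap (default: `55`)
- `--timing-log-every`: print average stage timing every N ticks (default: `200`)
- `--save-obs-path`: save emitted 65D observations on exit as `.npy` or `.npz`
- `--skip-start-service`: do not auto-run `/opt/apps/roboticsservice/runService.sh` at startup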
#### 5. Runtime Check
Before enabling motion on the robot:
- confirm XRoboToolkit PC Service is running
- confirm the PICO headset is connected to the workstation
- confirm `holomotion_teleop_node.py` is publishing ZMQ data
- confirm the robot-side policy is using the correct workstation IP in `latest_obs_zmq_uri`
- confirm the robot-side config keeps `enable_teleop_reference: true`
- confirm the robot-side config keeps `require_vr_data_for_motion: true`
- confirm the team has already validated the offline `.npz` motion-performance pipeline before attempting live teleoperation
Once the ZMQ stream is stable, enable the robot policy and switch into motion mode.
================================================
FILE: deployment/holomotion_teleop/setup_holomotion_teleop_x86_ubuntu2204.sh
================================================
#!/usr/bin/env bash
set -euo pipefail
# One-click setup script for the holomotion teleoperation environment.
#
# This script automates the manually verified workflow (same order as main):
# 0. (optional) install apt build tools, see INSTALL_APT_DEPS
# 1. create/activate conda env
# 2. clone/install GMR
# 3. clone/build/install XRoboToolkit pybind SDK
# 4. install runtime Python dependencies
# 5. clone/install SMPLSim
# 6. install compatibility Python dependencies (chumpy, pinned numpy)
#
# Usage:
# bash setup_holomotion_teleop_x86_ubuntu2204.sh
#
# Optional env vars:
# ENV_NAME=holomotion_teleop
# PYTHON_VERSION=3.10
# INSTALL_APT_DEPS=0 # default disabled; set to 1 only if you need apt
# THIRD_PARTY_DIR=/path/to/third_party
# GMR_SOURCE_DIR=/path/to/GMR
# SMPLSIM_SOURCE_DIR=/path/to/SMPLSim
# XRT_PYBIND_REPO_DIR=/path/to/XRoboToolkit-PC-Service-Pybind
# GMR_REPO_URL / SMPLSIM_REPO_URL / XRT_PYBIND_REPO_URL /
# XRT_PC_SERVICE_REPO_URL override the git URLs used for auto-cloning.
ENV_NAME="${ENV_NAME:-holomotion_teleop}"
PYTHON_VERSION="${PYTHON_VERSION:-3.10}"
INSTALL_APT_DEPS="${INSTALL_APT_DEPS:-0}"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="${SCRIPT_DIR}"
THIRD_PARTY_DIR="${THIRD_PARTY_DIR:-$PROJECT_ROOT/third_party}"
GMR_REPO_URL="${GMR_REPO_URL:-https://github.com/YanjieZe/GMR.git}"
SMPLSIM_REPO_URL="${SMPLSIM_REPO_URL:-https://github.com/ZhengyiLuo/SMPLSim.git}"
XRT_PYBIND_REPO_URL="${XRT_PYBIND_REPO_URL:-https://github.com/YanjieZe/XRoboToolkit-PC-Service-Pybind.git}"
XRT_PC_SERVICE_REPO_URL="${XRT_PC_SERVICE_REPO_URL:-https://github.com/XR-Robotics/XRoboToolkit-PC-Service.git}"
GMR_SOURCE_DIR="${GMR_SOURCE_DIR:-$THIRD_PARTY_DIR/GMR}"
SMPLSIM_SOURCE_DIR="${SMPLSIM_SOURCE_DIR:-$THIRD_PARTY_DIR/SMPLSim}"
XRT_PYBIND_REPO_DIR="${XRT_PYBIND_REPO_DIR:-$THIRD_PARTY_DIR/XRoboToolkit-PC-Service-Pybind}"
info() {
  # Informational progress message, written to stdout.
  printf '[INFO] %s\n' "$*"
}
warn() {
  # Non-fatal warning, written to stderr.
  printf '[WARN] %s\n' "$*" >&2
}
error() {
  # Fatal error: report on stderr and terminate the script with status 1.
  printf '[ERROR] %s\n' "$*" >&2
  exit 1
}
#######################################
# Abort via error() unless a command is resolvable by `command -v`.
# Arguments: $1 - command name; $2 - optional hint appended to the message
# Returns:   0 when the command exists
#######################################
require_command() {
  local cmd=$1 hint=${2:-}
  command -v "$cmd" >/dev/null 2>&1 && return 0
  local msg="$cmd not found."
  [[ -z "$hint" ]] || msg+=" $hint"
  error "$msg"
}
#######################################
# Run a command with `nounset` temporarily disabled.
# Some conda activation/deactivation hooks are not compatible with `set -u`
# and may reference unset variables such as SETVARS_CALL.
# Arguments: the command and its arguments
# Returns:   the command's exit status; `set -u` is re-enabled afterwards
#######################################
run_conda_relaxed() {
  local rc
  set +u
  "$@"
  rc=$?
  set -u
  return "$rc"
}
# Log the effective configuration, one "label: value" line per setting.
show_env_summary() {
  local pair label name
  for pair in \
    "project root|PROJECT_ROOT" \
    "env name|ENV_NAME" \
    "python version|PYTHON_VERSION" \
    "install apt deps|INSTALL_APT_DEPS" \
    "third party dir|THIRD_PARTY_DIR" \
    "gmr source dir|GMR_SOURCE_DIR" \
    "smplsim source dir|SMPLSIM_SOURCE_DIR" \
    "xrt pybind dir|XRT_PYBIND_REPO_DIR"; do
    label=${pair%%|*}
    name=${pair#*|}
    info "${label}: ${!name}"
  done
}
#######################################
# Require Linux; warn (but continue) when not on Ubuntu 22.04.
# Globals: sources /etc/os-release, which defines its variables
#          (e.g. ID, VERSION_ID, PRETTY_NAME) in the global scope.
#######################################
check_platform() {
  [[ "$(uname -s)" == "Linux" ]] || error "This setup script currently supports Linux only."
  # Without os-release metadata there is nothing more to check.
  [[ -f /etc/os-release ]] || return 0
  # shellcheck disable=SC1091
  source /etc/os-release
  info "detected OS: ${PRETTY_NAME:-unknown}"
  if [[ "${ID:-}" == "ubuntu" && "${VERSION_ID:-}" == "22.04" ]]; then
    return 0
  fi
  warn "This script is primarily tested on Ubuntu 22.04. Continuing anyway."
}
#######################################
# Check whether any required build tool is absent from PATH.
# Returns: 0 when at least one of gcc/g++/make/git/cmake is missing,
#          1 when all are available, so `if apt_deps_missing; then` reads
#          as the name suggests. (The previous version returned the inverse.)
#######################################
apt_deps_missing() {
  local tool
  for tool in gcc g++ make git cmake; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      return 0
    fi
  done
  return 1
}
#######################################
# Install build tools with apt according to INSTALL_APT_DEPS:
#   0/false/no -> skip apt entirely
#   auto       -> run apt only when apt_deps_missing reports a missing tool
#   1/true/yes -> always run apt
# Any other value aborts via error().
# Globals: INSTALL_APT_DEPS (read)
#######################################
install_apt_deps_if_needed() {
  case "$INSTALL_APT_DEPS" in
    0|false|False|FALSE|no|NO)
      info "Skipping apt dependency installation because INSTALL_APT_DEPS=$INSTALL_APT_DEPS"
      info "This matches the manually verified workflow and avoids unrelated apt source failures"
      return
      ;;
    auto|Auto|AUTO)
      if ! apt_deps_missing; then
        info "Required build tools already present; skipping apt because INSTALL_APT_DEPS=$INSTALL_APT_DEPS"
        return
      fi
      info "Some build tools are missing; installing them with apt because INSTALL_APT_DEPS=$INSTALL_APT_DEPS"
      ;;
    1|true|True|TRUE|yes|YES)
      ;;
    *)
      error "Unsupported INSTALL_APT_DEPS value: $INSTALL_APT_DEPS (expected 1, 0 or auto)"
      ;;
  esac
  require_command sudo "Install sudo or run the equivalent apt commands manually."
  require_command apt-get "This script needs apt-get to install build tools."
  info "Installing apt packages needed for build"
  if ! sudo apt-get update; then
    error "apt-get update failed. Common causes: broken apt sources, third-party repository timeouts, or proxy/network issues."
  fi
  if ! sudo apt-get install -y build-essential git cmake; then
    cat >&2 <<'EOF'
[ERROR] apt package installation failed.
Try one of the following:
1. sudo apt --fix-broken install
2. disable broken third-party apt repositories temporarily
3. rerun with INSTALL_APT_DEPS=0 if gcc/g++/make/git/cmake already exist
EOF
    exit 1
  fi
}
#######################################
# Create the conda env ENV_NAME (python=PYTHON_VERSION) if absent, then
# activate it in the current shell.
# Globals: ENV_NAME, PYTHON_VERSION (read)
#######################################
setup_conda_env() {
  require_command conda "Please install Miniconda or Anaconda first."
  # Source conda's shell integration script before using `conda activate` below.
  # shellcheck disable=SC1091
  source "$(conda info --base)/etc/profile.d/conda.sh"
  # Exact match against the first column of `conda env list`.
  if conda env list | awk '{print $1}' | grep -Fx "$ENV_NAME" >/dev/null 2>&1; then
    info "Conda env already exists: $ENV_NAME"
  else
    info "Creating conda env: $ENV_NAME"
    run_conda_relaxed conda create -n "$ENV_NAME" "python=$PYTHON_VERSION" -y
  fi
  run_conda_relaxed conda activate "$ENV_NAME"
}
#######################################
# Ensure a source tree exists at repo_dir, cloning repo_url only if needed.
# Arguments: $1 - target directory; $2 - git URL; $3 - display name
# Behaviour:
#   - `.git` present (directory or file)  -> reuse the checkout
#   - non-empty directory without `.git`  -> warn and reuse as-is
#   - otherwise                           -> git clone into repo_dir
#######################################
clone_repo_if_missing() {
  local repo_dir="$1"
  local repo_url="$2"
  local repo_name="$3"
  # In a normal clone `.git` is a directory, but in submodules and
  # `git worktree` checkouts it is a plain file, so test for existence.
  if [[ -e "$repo_dir/.git" ]]; then
    info "Using existing $repo_name checkout at $repo_dir"
    return 0
  fi
  # `git clone` refuses a non-empty destination; reuse such a tree (e.g. an
  # unpacked source archive passed via *_SOURCE_DIR) instead of failing.
  if [[ -d "$repo_dir" && -n "$(find "$repo_dir" -mindepth 1 -maxdepth 1 -print -quit)" ]]; then
    warn "$repo_name directory $repo_dir exists but is not a git checkout; using it as-is"
    return 0
  fi
  info "Cloning $repo_name"
  mkdir -p "$(dirname "$repo_dir")"
  git clone "$repo_url" "$repo_dir"
}
# Clone GMR if needed and pip-install it in editable (-e) mode.
install_gmr() {
  local src=$GMR_SOURCE_DIR
  clone_repo_if_missing "$src" "$GMR_REPO_URL" "GMR"
  info "Installing GMR in editable mode"
  python -m pip install -e "$src"
}
#######################################
# Build PXREARobotSDK from the PC Service sources and install the
# xrobotoolkit_sdk Python bindings from the pybind repository.
# Globals: XRT_PYBIND_REPO_DIR, XRT_PYBIND_REPO_URL, XRT_PC_SERVICE_REPO_URL
# Side effects: writes include/ and lib/ inside XRT_PYBIND_REPO_DIR; safe to rerun.
#######################################
build_xrt_python_sdk() {
  clone_repo_if_missing "$XRT_PYBIND_REPO_DIR" "$XRT_PYBIND_REPO_URL" "XRoboToolkit pybind repository"
  pushd "$XRT_PYBIND_REPO_DIR" >/dev/null
  mkdir -p tmp
  clone_repo_if_missing "tmp/XRoboToolkit-PC-Service" "$XRT_PC_SERVICE_REPO_URL" "XRoboToolkit PC Service source"
  local sdk_dir="tmp/XRoboToolkit-PC-Service/RoboticsService/PXREARobotSDK"
  info "Building PXREARobotSDK"
  pushd "$sdk_dir" >/dev/null
  bash build.sh
  popd >/dev/null
  mkdir -p lib include
  cp "$sdk_dir/PXREARobotSDK.h" include/
  # Copy into the parent include/ so reruns refresh include/nlohmann in place.
  # Targeting include/nlohmann/ would create include/nlohmann/nlohmann once
  # the first run has created the directory.
  cp -r "$sdk_dir/nlohmann" include/
  cp "$sdk_dir/build/libPXREARobotSDK.so" lib/
  info "Installing pybind11 into conda env"
  run_conda_relaxed conda install -y -c conda-forge pybind11
  info "Reinstalling xrobotoolkit_sdk"
  # Ignore failure: the package may not be installed yet.
  python -m pip uninstall -y xrobotoolkit_sdk || true
  python setup.py install
  popd >/dev/null
}
# Clone SMPLSim if needed and pip-install it in editable (-e) mode.
install_smplsim() {
  local src=$SMPLSIM_SOURCE_DIR
  clone_repo_if_missing "$src" "$SMPLSIM_REPO_URL" "SMPLSim"
  info "Installing SMPLSim in editable mode"
  python -m pip install -e "$src"
}
# Upgrade pip/setuptools/wheel, then install runtime packages one at a time.
install_runtime_python_deps() {
  info "Upgrading pip toolchain"
  python -m pip install --upgrade pip setuptools wheel
  info "Installing runtime Python packages"
  local pkg
  for pkg in pyzmq open3d; do
    python -m pip install "$pkg"
  done
}
# Install chumpy, then pin numpy for chumpy compatibility.
install_compat_python_deps() {
  local -r numpy_pin="numpy==1.23.5"
  info "Installing compatibility packages"
  python -m pip install chumpy
  # Runs after chumpy so this pin replaces whatever numpy was installed before.
  info "Pinning numpy for chumpy compatibility"
  python -m pip install --upgrade "$numpy_pin"
}
# Print post-install instructions and an example teleop command to stdout.
print_next_steps() {
  printf '\n'
  info "Environment setup complete"
  printf '\n'
  info "Manual prerequisite:"
  printf '%s\n' \
    " Install XRoboToolkit PC Service manually from the Ubuntu 22.04 .deb package." \
    " Launch xrobotoolkit-pc-service before teleoperation." \
    ""
  info "Activate with:"
  printf ' conda activate %s\n\n' "$ENV_NAME"
  info "Example command:"
  printf '%s\n' \
    " python \"${PROJECT_ROOT}/holomotion_teleop_node.py\" \\" \
    " --robot-zmq-uri tcp://*:6001 \\" \
    " --robot-zmq-mode bind \\" \
    " --hz 50 \\" \
    " --timing-log-every 250"
}
# Run every setup step in order; under `set -e` the first failure aborts.
main() {
  local -a steps=(
    check_platform
    show_env_summary
    install_apt_deps_if_needed
    setup_conda_env
    install_gmr
    build_xrt_python_sdk
    install_runtime_python_deps
    install_smplsim
    install_compat_python_deps
    print_next_steps
  )
  local step
  for step in "${steps[@]}"; do
    "$step"
  done
}
# Entry point; arguments are forwarded, although main() does not read them.
main "$@"
================================================
FILE: deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof.sh
================================================
#!/bin/bash
##############################################################################
# HoloMotion Deployment Launch Script
#
# This script sets up the complete environment and launches the HoloMotion
# humanoid robot control system for the Unitree G1 robot. It handles:
# 1. ROS2 environment setup and workspace building
# 2. Conda environment configuration for GPU/CUDA support
# 3. Library path configuration for proper linking
# 4. Launch of the complete HoloMotion control pipeline
#
# The script changes into its own directory (the colcon workspace root), so
# it can be invoked from any working directory.
#
# Prerequisites:
# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/
# - Conda environment 'holomotion_deploy' with required packages
# - Network interface configured for robot communication
# - Proper permissions for robot hardware access
#
# Usage:
# ./launch_holomotion_29dof.sh [--record]
# --record: Enable topic recording (optional, disabled by default)
#
# Author: HoloMotion Team
# License: See project LICENSE file
##############################################################################
# Default values
ENABLE_RECORDING=false
# Parse command line arguments
while [[ $# -gt 0 ]]; do
  case $1 in
    --record)
      ENABLE_RECORDING=true
      shift
      ;;
    -h|--help)
      echo "Usage: $0 [--record]"
      echo " --record: Enable topic recording (optional, disabled by default)"
      exit 0
      ;;
    *)
      echo "Unknown option $1"
      echo "Usage: $0 [--record]"
      echo " --record: Enable topic recording (optional, disabled by default)"
      exit 1
      ;;
  esac
done
# Every relative path below (build/, install/, log/, ../../deploy.env) refers
# to this script's directory. Without this cd, running the script from
# elsewhere would delete that directory's build/ install/ log/ and try to
# build the wrong tree.
cd "$(dirname "${BASH_SOURCE[0]}")" || { echo "Cannot cd to the script directory" >&2; exit 1; }
echo "Starting HoloMotion 29DOF..."
echo "Recording enabled: $ENABLE_RECORDING"
# Clean rebuild; retry with sudo if root-owned artifacts block removal.
rm -rf build/ install/ log/ 2>/dev/null || sudo rm -rf build/ install/ log/
# Deactivate every conda environment (including base) before building.
source ~/miniconda3/bin/activate
while [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do
  conda deactivate
done
source /opt/ros/humble/setup.sh
source ~/unitree_ros2/setup.sh
# install/ was just removed, so launching after a failed build cannot work.
if ! colcon build; then
  echo "colcon build failed; not launching" >&2
  exit 1
fi
source install/setup.bash
source ../../deploy.env
# Launch with recording parameter
ros2 launch humanoid_control holomotion_29dof_launch.py "enable_recording:=$ENABLE_RECORDING"
================================================
FILE: deployment/unitree_g1_ros2_29dof/launch_holomotion_29dof_docker.sh
================================================
#!/bin/bash
##############################################################################
# HoloMotion Deployment Launch Script
#
# This script sets up the complete environment and launches the HoloMotion
# humanoid robot control system for the Unitree G1 robot. It handles:
# 1. ROS2 environment setup and workspace building
# 2. Conda environment configuration for GPU/CUDA support
# 3. Library path configuration for proper linking
# 4. Launch of the complete HoloMotion control pipeline
#
# The script changes into its own directory (the colcon workspace root), so
# it can be invoked from any working directory.
#
# Prerequisites:
# - Unitree ROS2 SDK properly installed at ~/unitree_ros2/
# - Conda environment 'holomotion_deploy' with required packages
# - Network interface configured for robot communication
# - Proper permissions for robot hardware access
#
# Usage:
# ./launch_holomotion_29dof_docker.sh [--record]
# --record: Enable topic recording (optional, disabled by default)
#
# Author: HoloMotion Team
# License: See project LICENSE file
##############################################################################
# Default values
ENABLE_RECORDING=false
# Parse command line arguments
while [[ $# -gt 0 ]]; do
  case $1 in
    --record)
      ENABLE_RECORDING=true
      shift
      ;;
    -h|--help)
      echo "Usage: $0 [--record]"
      echo " --record: Enable topic recording (optional, disabled by default)"
      exit 0
      ;;
    *)
      echo "Unknown option $1"
      echo "Usage: $0 [--record]"
      echo " --record: Enable topic recording (optional, disabled by default)"
      exit 1
      ;;
  esac
done
# Every relative path below (build/, install/, log/, ../../deploy.env) refers
# to this script's directory; never rm -rf or build in the caller's cwd.
cd "$(dirname "${BASH_SOURCE[0]}")" || { echo "Cannot cd to the script directory" >&2; exit 1; }
echo "Starting HoloMotion 29DOF Docker..."
echo "Recording enabled: $ENABLE_RECORDING"
# Deactivate every conda environment (including base) before building.
source /root/miniconda3/etc/profile.d/conda.sh
while [[ ${CONDA_SHLVL:-0} -gt 0 ]]; do
  conda deactivate
done
rm -rf build/ install/ log/
source /opt/ros/humble/setup.sh
source /root/unitree_ros2/setup.sh
# install/ was just removed, so launching after a failed build cannot work.
if ! colcon build; then
  echo "colcon build failed; not launching" >&2
  exit 1
fi
source install/setup.bash
# Configure conda environment paths for CUDA and library linking
# NOTE: Update this path to match your actual conda environment location
# The inherited values are appended only when non-empty: a trailing ':' would
# add an empty entry, which the dynamic loader treats as the current directory.
export CYCLONEDDS_HOME=/root/cyclonedds/install
export CMAKE_PREFIX_PATH="${CYCLONEDDS_HOME}${CMAKE_PREFIX_PATH:+:$CMAKE_PREFIX_PATH}"
source ../../deploy.env
export LD_LIBRARY_PATH="/host_gpu:/cuda_base:/usr/lib/aarch64-linux-gnu/tegra:/usr/lib/aarch64-linux-gnu:/usr/local/cuda/lib64:/lib/aarch64-linux-gnu/${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
# Launch with recording parameter
ros2 launch humanoid_control holomotion_29dof_launch.py "enable_recording:=$ENABLE_RECORDING"
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/CMakeLists.txt
================================================
# Build and install rules for the humanoid_control ament package: the C++
# humanoid_control executable, the humanoid_policy Python package and the
# models/, config/, motion_data/ and launch/ directories.
cmake_minimum_required(VERSION 3.8)
project(humanoid_control)
# Default to C99
if(NOT CMAKE_C_STANDARD)
set(CMAKE_C_STANDARD 99)
endif()
# Default to C++17
if(NOT CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 17)
endif()
# Extra warnings for GCC/Clang builds.
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic)
endif()
include_directories(include include/common include/nlohmann)
link_directories(src)
# Packages passed to ament_target_dependencies() for the executable below.
set(
DEPENDENCY_LIST
unitree_go
unitree_hg
unitree_api
rclcpp
std_msgs
rosbag2_cpp
yaml-cpp
)
# find dependencies
find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(unitree_go REQUIRED)
find_package(unitree_hg REQUIRED)
find_package(unitree_api REQUIRED)
find_package(rclcpp REQUIRED)
find_package(std_msgs REQUIRED)
find_package(rosbag2_cpp REQUIRED)
find_package(yaml-cpp REQUIRED)
# Main control executable
add_executable(
humanoid_control
src/main_node.cpp
src/common/motor_crc_hg.cpp
src/common/wireless_controller.cpp
)
ament_target_dependencies(humanoid_control ${DEPENDENCY_LIST})
# Install Python modules
ament_python_install_package(humanoid_policy)
# Install Python scripts as executables
install(PROGRAMS
humanoid_policy/policy_node_29dof.py
DESTINATION lib/${PROJECT_NAME}
RENAME policy_node_29dof
)
# Install your models directory
install(DIRECTORY
models/
DESTINATION share/${PROJECT_NAME}/models
)
# Install the C++ executable
install(TARGETS
humanoid_control
DESTINATION lib/${PROJECT_NAME})
# Install the config/ directory
install(DIRECTORY
config/
DESTINATION share/${PROJECT_NAME}/config
)
# Install the motion_data/ directory
install(DIRECTORY
motion_data/
DESTINATION share/${PROJECT_NAME}/motion_data
)
# Install launch files
install(
DIRECTORY launch/
DESTINATION share/${PROJECT_NAME}
)
if(BUILD_TESTING)
find_package(ament_lint_auto REQUIRED)
ament_lint_auto_find_test_dependencies()
endif()
ament_package()
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml
================================================
device: "cuda"
policy_freq: 50 # Hz
control_freq: 500 # Hz
lowstate_topic: "/lowstate"
action_topic: "/humanoid/action"
# walking policy
velocity_tracking_model_folder: "velocity_tracking_model"
# motion policy
motion_tracking_model_folder: "motion_tracking_model"
# motion data
motion_clip_dir: "motion_data"
cpu_affinity_main: ""
cpu_affinity_zmq_sub: ""
# VR / ZMQ: the robot acts as a SUB socket and receives latest_obs from the sender.
vr:
enable_teleop_reference: false
latest_obs_zmq_uri: "tcp://192.168.124.29:6001"
latest_obs_zmq_topic: "obs65"
latest_obs_zmq_mode: "connect"
latest_obs_zmq_conflate: true
zmq_jitter_delay_frames: 5
max_data_age: 0.6
require_vr_data_for_motion: false
timing_debug_enabled: false
timing_debug_log_interval_sec: 5.0
timing_debug_log_per_loop: false
complete_dof_order:
- left_hip_pitch_joint
- left_hip_roll_joint
- left_hip_yaw_joint
- left_knee_joint
- left_ankle_pitch_joint
- left_ankle_roll_joint
- right_hip_pitch_joint
- right_hip_roll_joint
- right_hip_yaw_joint
- right_knee_joint
- right_ankle_pitch_joint
- right_ankle_roll_joint
- waist_yaw_joint
- waist_roll_joint
- waist_pitch_joint
- left_shoulder_pitch_joint
- left_shoulder_roll_joint
- left_shoulder_yaw_joint
- left_elbow_joint
- left_wrist_roll_joint
- left_wrist_pitch_joint
- left_wrist_yaw_joint
- right_shoulder_pitch_joint
- right_shoulder_roll_joint
- right_shoulder_yaw_joint
- right_elbow_joint
- right_wrist_roll_joint
- right_wrist_pitch_joint
- right_wrist_yaw_joint
policy_dof_order:
- left_hip_pitch_joint
- left_hip_roll_joint
- left_hip_yaw_joint
- left_knee_joint
- left_ankle_pitch_joint
- left_ankle_roll_joint
- right_hip_pitch_joint
- right_hip_roll_joint
- right_hip_yaw_joint
- right_knee_joint
- right_ankle_pitch_joint
- right_ankle_roll_joint
- waist_yaw_joint
- waist_roll_joint
- waist_pitch_joint
- left_shoulder_pitch_joint
- left_shoulder_roll_joint
- left_shoulder_yaw_joint
- left_elbow_joint
- left_wrist_roll_joint
- left_wrist_pitch_joint
- left_wrist_yaw_joint
- right_shoulder_pitch_joint
- right_shoulder_roll_joint
- right_shoulder_yaw_joint
- right_elbow_joint
- right_wrist_roll_joint
- right_wrist_pitch_joint
- right_wrist_yaw_joint
dof2motor_idx_mapping:
# https://support.unitree.com/home/zh/G1_developer/about_G1
left_hip_pitch_joint: 0
left_hip_roll_joint: 1
left_hip_yaw_joint: 2
left_knee_joint: 3
left_ankle_pitch_joint: 4
left_ankle_roll_joint: 5
right_hip_pitch_joint: 6
right_hip_roll_joint: 7
right_hip_yaw_joint: 8
right_knee_joint: 9
right_ankle_pitch_joint: 10
right_ankle_roll_joint: 11
waist_yaw_joint: 12
waist_roll_joint: 13
waist_pitch_joint: 14
left_shoulder_pitch_joint: 15
left_shoulder_roll_joint: 16
left_shoulder_yaw_joint: 17
left_elbow_joint: 18
left_wrist_roll_joint: 19
left_wrist_pitch_joint: 20
left_wrist_yaw_joint: 21
right_shoulder_pitch_joint: 22
right_shoulder_roll_joint: 23
right_shoulder_yaw_joint: 24
right_elbow_joint: 25
right_wrist_roll_joint: 26
right_wrist_pitch_joint: 27
right_wrist_yaw_joint: 28
default_joint_angles:
# Left leg joints (indices 0-5)
left_hip_pitch_joint: -0.312
left_hip_roll_joint: 0.0
left_hip_yaw_joint: 0.0
left_knee_joint: 0.669
left_ankle_pitch_joint: -0.33
left_ankle_roll_joint: 0.0
# Right leg joints (indices 6-11)
right_hip_pitch_joint: -0.312
right_hip_roll_joint: 0.0
right_hip_yaw_joint: 0.0
right_knee_joint: 0.669
right_ankle_pitch_joint: -0.33
right_ankle_roll_joint: 0.0
# Waist joints (indices 12-14)
waist_yaw_joint: 0.0
waist_roll_joint: 0.0
waist_pitch_joint: 0.2
# Left arm joints (indices 15-21)
left_shoulder_pitch_joint: 0.2
left_shoulder_roll_joint: 0.2
left_shoulder_yaw_joint: 0.0
left_elbow_joint: 0.6
left_wrist_roll_joint: 0.0
left_wrist_pitch_joint: 0.0
left_wrist_yaw_joint: 0.0
# Right arm joints (indices 22-28)
right_shoulder_pitch_joint: 0.2
right_shoulder_roll_joint: -0.2
right_shoulder_yaw_joint: 0.0
right_elbow_joint: 0.6
right_wrist_roll_joint: 0.0
right_wrist_pitch_joint: 0.0
right_wrist_yaw_joint: 0.0
# Joint limits
joint_limits:
position:
# Left leg joints
left_hip_pitch_joint: [-2.5307, 2.8798]
left_hip_roll_joint: [-0.5236, 2.9671]
left_hip_yaw_joint: [-2.7576, 2.7576]
left_knee_joint: [-0.087267, 2.8798]
left_ankle_pitch_joint: [-0.87267, 0.5236]
left_ankle_roll_joint: [-0.2618, 0.2618]
# Right leg joints
right_hip_pitch_joint: [-2.5307, 2.8798]
right_hip_roll_joint: [-2.9671, 0.5236]
right_hip_yaw_joint: [-2.7576, 2.7576]
right_knee_joint: [-0.087267, 2.8798]
right_ankle_pitch_joint: [-0.87267, 0.5236]
right_ankle_roll_joint: [-0.2618, 0.2618]
# Waist joints
waist_yaw_joint: [-2.618, 2.618]
waist_roll_joint: [-0.52, 0.52]
waist_pitch_joint: [-0.52, 0.52]
# Left arm joints
left_shoulder_pitch_joint: [-3.0892, 2.6704]
left_shoulder_roll_joint: [-1.5882, 2.2515]
left_shoulder_yaw_joint: [-2.618, 2.618]
left_elbow_joint: [-1.0472, 2.0944]
left_wrist_roll_joint: [-1.972222054, 1.972222054]
left_wrist_pitch_joint: [-1.614429558, 1.614429558]
left_wrist_yaw_joint: [-1.614429558, 1.614429558]
# Right arm joints
right_shoulder_pitch_joint: [-3.0892, 2.6704]
right_shoulder_roll_joint: [-2.2515, 1.5882]
right_shoulder_yaw_joint: [-2.618, 2.618]
right_elbow_joint: [-1.0472, 2.0944]
right_wrist_roll_joint: [-1.972222054, 1.972222054]
right_wrist_pitch_joint: [-1.614429558, 1.614429558]
right_wrist_yaw_joint: [-1.614429558, 1.614429558]
velocity:
# Left leg joints
left_hip_pitch_joint: 32
left_hip_roll_joint: 20
left_hip_yaw_joint: 32
left_knee_joint: 20
left_ankle_pitch_joint: 30
left_ankle_roll_joint: 30
# Right leg joints
right_hip_pitch_joint: 32
right_hip_roll_joint: 20
right_hip_yaw_joint: 32
right_knee_joint: 20
right_ankle_pitch_joint: 30
right_ankle_roll_joint: 30
# Waist joints
waist_yaw_joint: 32
waist_roll_joint: 30
waist_pitch_joint: 30
# Left arm joints
left_shoulder_pitch_joint: 37
left_shoulder_roll_joint: 37
left_shoulder_yaw_joint: 37
left_elbow_joint: 37
left_wrist_roll_joint: 37
left_wrist_pitch_joint: 22
left_wrist_yaw_joint: 22
# Right arm joints
right_shoulder_pitch_joint: 37
right_shoulder_roll_joint: 37
right_shoulder_yaw_joint: 37
right_elbow_joint: 37
right_wrist_roll_joint: 37
right_wrist_pitch_joint: 22
right_wrist_yaw_joint: 22
effort:
# Left leg joints
left_hip_pitch_joint: 88
left_hip_roll_joint: 139
left_hip_yaw_joint: 88
left_knee_joint: 139
left_ankle_pitch_joint: 35
left_ankle_roll_joint: 35
# Right leg joints
right_hip_pitch_joint: 88
right_hip_roll_joint: 139
right_hip_yaw_joint: 88
right_knee_joint: 139
right_ankle_pitch_joint: 35
right_ankle_roll_joint: 35
# Waist joints
waist_yaw_joint: 88
waist_roll_joint: 35
waist_pitch_joint: 35
# Left arm joints
left_shoulder_pitch_joint: 25
left_shoulder_roll_joint: 25
left_shoulder_yaw_joint: 25
left_elbow_joint: 25
left_wrist_roll_joint: 25
left_wrist_pitch_joint: 5
left_wrist_yaw_joint: 5
# Right arm joints
right_shoulder_pitch_joint: 25
right_shoulder_roll_joint: 25
right_shoulder_yaw_joint: 25
right_elbow_joint: 25
right_wrist_roll_joint: 25
right_wrist_pitch_joint: 5
right_wrist_yaw_joint: 5
limit_scales:
position: 2.0 # Multiplier on the position limits above (2.0 = 2x nominal; a +50% range would be 1.5)
velocity: 2.0
effort: 2.0
# move to default position
# joint_names and default_position are auto-generated from complete_dof_order and default_joint_angles
# Only kp and kd arrays need to be specified here
kp:
[ 350.0, 200.0, 200.0, 300.0, 300.0, 150.0,
350.0, 200.0, 200.0, 300.0, 300.0, 150.0,
200.0, 200.0, 200.0,
40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0,
40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0 ]
kd:
[ 5.0, 5.0, 5.0, 10.0, 5.0, 5.0,
5.0, 5.0, 5.0, 10.0, 5.0, 5.0,
5.0, 5.0, 5.0,
3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 ]
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/__init__.py
================================================
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/holomotion_fk_root_only.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import logging
import time
from typing import Callable, Dict, Sequence
import numpy as np
import torch
def _xyzw_to_wxyz(q: np.ndarray) -> np.ndarray:
return np.concatenate([q[..., 3:4], q[..., 0:3]], axis=-1)
def _wxyz_to_xyzw(q: np.ndarray) -> np.ndarray:
return np.concatenate([q[..., 1:4], q[..., 0:1]], axis=-1)
def _quat_conjugate_wxyz(q: np.ndarray) -> np.ndarray:
out = np.array(q, copy=True)
out[..., 1:4] *= -1.0
return out
def _quat_mul_wxyz(q1: np.ndarray, q2: np.ndarray) -> np.ndarray:
w1 = q1[..., 0]
x1 = q1[..., 1]
y1 = q1[..., 2]
z1 = q1[..., 3]
w2 = q2[..., 0]
x2 = q2[..., 1]
y2 = q2[..., 2]
z2 = q2[..., 3]
return np.stack(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
],
axis=-1,
)
def _standardize_quaternion_wxyz(q: np.ndarray) -> np.ndarray:
return np.where(q[..., 0:1] < 0.0, -q, q)
def _axis_angle_from_wxyz(q: np.ndarray) -> np.ndarray:
q = _standardize_quaternion_wxyz(q)
q = q / np.linalg.norm(q, axis=-1, keepdims=True).clip(min=1.0e-9)
quat_w = q[..., 0]
quat_xyz = q[..., 1:4]
mag = np.linalg.norm(quat_xyz, axis=-1)
half_angle = np.arctan2(mag, quat_w)
angle = 2.0 * half_angle
use_taylor = np.abs(angle) <= 1.0e-6
angle_safe = np.where(use_taylor, 1.0, angle)
sin_half_over_angle = np.where(
use_taylor,
0.5 - angle * angle / 48.0,
np.sin(half_angle) / angle_safe,
)
return quat_xyz / sin_half_over_angle[..., None]
def _grad_t(x: np.ndarray, dt: float) -> np.ndarray:
if dt <= 0.0:
raise ValueError(f"Invalid dt: {dt}")
if x.shape[1] < 2:
return np.zeros_like(x)
grad = np.empty_like(x)
inv_dt = 1.0 / dt
grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt
grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt
if x.shape[1] > 2:
grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt)
return grad
class HoloMotionFKRootOnly(torch.nn.Module):
    """Root-only online FK.

    This lightweight variant is intended for policy-time VR reference building
    when only the root body pose/velocity are consumed by observation terms.
    It exposes a single body named ``"root"`` and computes:

    * root linear velocity by finite differences along time (central
      differences inside the window, one-sided at the two ends),
    * root angular velocity from consecutive root rotations (forward
      difference; the last frame's angular velocity is left at zero),
    * optional Gaussian smoothing along time of both velocities.

    DoF positions are passed through and DoF velocities are returned as zeros.
    Every returned tensor owns its memory and never aliases the caller's
    input tensors/arrays, whatever the target device/dtype.
    """

    def __init__(
        self,
        dof_names: Sequence[str],
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.float32,
        timing_logger_enabled: bool = False,
        timing_log_interval_sec: float = 5.0,
        timing_log_per_call: bool = False,
        timing_name: str = "HoloMotionFKRootOnly",
        timing_log_fn: Callable[[str], None] | None = None,
    ) -> None:
        """Create the module.

        Args:
            dof_names: Joint names; their count validates ``dof_pos``.
            device: Device of the returned tensors.
            dtype: dtype of the returned tensors. Computation runs in float64
                numpy when this is ``torch.float64``, otherwise in float32.
            timing_logger_enabled: Accumulate per-stage timings and log
                aggregates every ``timing_log_interval_sec`` seconds.
            timing_log_interval_sec: Minimum seconds between aggregate logs.
            timing_log_per_call: Also log each call's timings (when enabled).
            timing_name: Prefix used in timing log lines.
            timing_log_fn: Optional sink for timing lines; defaults to this
                module's ``logging`` logger at INFO level.

        Raises:
            ValueError: if ``dof_names`` is empty.
        """
        super().__init__()
        self.body_names = ["root"]
        self.dof_names = list(dof_names)
        self.num_bodies = 1
        self.num_dof = len(self.dof_names)
        if self.num_dof <= 0:
            raise ValueError("dof_names must not be empty")
        self._device = torch.device(device)
        self._dtype = dtype
        # numpy compute precision: float64 only when float64 output is requested.
        if self._dtype == torch.float64:
            self._np_dtype = np.float64
        else:
            self._np_dtype = np.float32
        self._timing_logger_enabled = bool(timing_logger_enabled)
        self._timing_log_interval_sec = float(timing_log_interval_sec)
        self._timing_log_per_call = bool(timing_log_per_call)
        self._timing_name = str(timing_name)
        self._timing_logger = logging.getLogger(__name__)
        self._timing_log_fn = timing_log_fn
        # time.monotonic() value of the last aggregate log (None before the first sample).
        self._timing_last_log_time = None
        self._timing_count = 0
        self._timing_sum_ms = {}
        self._timing_max_ms = {}
        self.last_timing_ms = {}
        # Normalized Gaussian kernels keyed by (sigma, numpy dtype string).
        self._gaussian_kernel_cache: Dict[tuple[float, str], np.ndarray] = {}

    def set_timing_logger(
        self,
        enabled: bool,
        interval_sec: float | None = None,
        per_call: bool | None = None,
        log_fn: Callable[[str], None] | None = None,
    ) -> None:
        """Reconfigure timing logging; ``None`` arguments keep current values."""
        self._timing_logger_enabled = bool(enabled)
        if interval_sec is not None:
            self._timing_log_interval_sec = float(interval_sec)
        if per_call is not None:
            self._timing_log_per_call = bool(per_call)
        if log_fn is not None:
            self._timing_log_fn = log_fn

    def _timing_ms(self, t0: float) -> float:
        """Milliseconds elapsed since the ``time.perf_counter()`` value ``t0``."""
        return (time.perf_counter() - t0) * 1000.0

    def _to_numpy(self, x: torch.Tensor | np.ndarray) -> np.ndarray:
        """Convert input to a CPU numpy array of the compute dtype.

        May return a view sharing memory with ``x`` (no copy when the dtype
        already matches); callers must not write into the result.
        """
        if isinstance(x, np.ndarray):
            return np.asarray(x, dtype=self._np_dtype)
        if not isinstance(x, torch.Tensor):
            return np.asarray(x, dtype=self._np_dtype)
        if x.device.type != "cpu" or x.dtype != self._dtype:
            x = x.detach().to(device="cpu", dtype=self._dtype)
        else:
            x = x.detach()
        return x.numpy()

    def _to_output_tensor(self, x: np.ndarray, copy: bool = False) -> torch.Tensor:
        """Convert a numpy array to a tensor on the configured device/dtype.

        Args:
            x: Array to convert.
            copy: Force a private copy. ``torch.from_numpy`` is zero-copy, so
                without this an array that views caller input would be
                returned aliased on the CPU/same-dtype path.
        """
        if copy:
            arr = np.array(x, order="C", copy=True)
        else:
            arr = np.ascontiguousarray(x)
        tensor = torch.from_numpy(arr)
        if self._device.type != "cpu" or tensor.dtype != self._dtype:
            tensor = tensor.to(device=self._device, dtype=self._dtype)
        return tensor

    def _get_gaussian_kernel(self, sigma: float) -> np.ndarray | None:
        """Return a normalized 1-D Gaussian kernel, or None when ``sigma <= 0``.

        The radius is ``int(4 * sigma + 0.5)`` samples. Kernels are cached per
        (sigma, dtype).
        """
        if sigma <= 0.0:
            return None
        key = (float(sigma), np.dtype(self._np_dtype).str)
        kernel = self._gaussian_kernel_cache.get(key, None)
        if kernel is not None:
            return kernel
        radius = int(4.0 * sigma + 0.5)
        kernel_x = np.arange(-radius, radius + 1, dtype=self._np_dtype)
        kernel = np.exp(-0.5 * np.square(kernel_x / sigma)).astype(
            self._np_dtype, copy=False
        )
        kernel /= kernel.sum(dtype=self._np_dtype)
        self._gaussian_kernel_cache[key] = kernel
        return kernel

    def _gaussian_filter_time(self, x: np.ndarray, kernel: np.ndarray | None) -> np.ndarray:
        """Convolve (B, T, C) data with ``kernel`` along axis 1.

        Edges are handled by repeating the first/last sample. Returns ``x``
        unchanged when there is no kernel or fewer than two time steps.
        """
        if kernel is None or x.shape[1] < 2:
            return x
        radius = kernel.shape[0] // 2
        padded = np.pad(x, ((0, 0), (radius, radius), (0, 0)), mode="edge")
        # windows: (B, T, C, K); contract K against the (symmetric) kernel.
        windows = np.lib.stride_tricks.sliding_window_view(
            padded, window_shape=kernel.shape[0], axis=1
        )
        return np.tensordot(windows, kernel, axes=([-1], [0])).astype(
            x.dtype, copy=False
        )

    def _log_timing_message(self, message: str) -> None:
        """Send a timing line to the custom sink or to the module logger."""
        if self._timing_log_fn is not None:
            self._timing_log_fn(message)
        else:
            self._timing_logger.info(message)

    def _record_timing(self, sample: Dict[str, float]) -> None:
        """Store the latest timings and, if enabled, accumulate and log them.

        Aggregates (mean/max per stage) are logged at most once per
        ``timing_log_interval_sec`` and then reset. The interval is measured
        with ``time.monotonic()`` so wall-clock adjustments (e.g. NTP) cannot
        suppress or flood the aggregate logs.
        """
        self.last_timing_ms = dict(sample)
        if not self._timing_logger_enabled:
            return
        self._timing_count += 1
        for key, value in sample.items():
            v = float(value)
            self._timing_sum_ms[key] = self._timing_sum_ms.get(key, 0.0) + v
            self._timing_max_ms[key] = max(self._timing_max_ms.get(key, v), v)
        if self._timing_log_per_call:
            self._log_timing_message(
                (
                    f"[{self._timing_name}][Timing] "
                    f"total={sample['total_ms']:.3f}ms "
                    f"input={sample['input_ms']:.3f}ms "
                    f"quat={sample['quat_ms']:.3f}ms "
                    f"linvel={sample['linvel_ms']:.3f}ms "
                    f"angvel={sample['angvel_ms']:.3f}ms "
                    f"smooth={sample['smooth_ms']:.3f}ms "
                    f"output={sample['output_ms']:.3f}ms"
                )
            )
        now = time.monotonic()
        if self._timing_last_log_time is None:
            self._timing_last_log_time = now
            return
        if now - self._timing_last_log_time < self._timing_log_interval_sec:
            return
        if self._timing_count == 0:
            self._timing_last_log_time = now
            return
        keys = [
            "total_ms",
            "input_ms",
            "quat_ms",
            "linvel_ms",
            "angvel_ms",
            "smooth_ms",
            "output_ms",
        ]
        self._log_timing_message(
            f"[{self._timing_name}][Timing-Agg] "
            + " ".join(
                f"{key}=mean:{self._timing_sum_ms.get(key, 0.0) / self._timing_count:.3f}ms/"
                f"max:{self._timing_max_ms.get(key, 0.0):.3f}ms"
                for key in keys
            )
            + f" n={self._timing_count}"
        )
        self._timing_count = 0
        self._timing_sum_ms.clear()
        self._timing_max_ms.clear()
        self._timing_last_log_time = now

    @torch.inference_mode()
    def forward(
        self,
        root_pos: torch.Tensor,
        root_quat: torch.Tensor,
        dof_pos: torch.Tensor,
        fps: float,
        quat_format: str = "xyzw",
        sub_batch_size: int = 64,
        vel_smoothing_sigma: float = 2.0,
        compute_velocity: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Build root-only FK outputs for a batch of clips.

        Args:
            root_pos: (B, T, 3) root positions.
            root_quat: (B, T, 4) root rotations in ``quat_format`` order.
            dof_pos: (B, T, num_dof) joint positions.
            fps: Frame rate of the T axis; must be > 0.
            quat_format: ``"xyzw"`` or ``"wxyz"``.
            sub_batch_size: Ignored; kept for call-site compatibility.
            vel_smoothing_sigma: Gaussian sigma (in frames) for velocity
                smoothing; <= 0 disables smoothing.
            compute_velocity: Ignored; velocities are always computed.

        Returns:
            Dict with ``global_translation`` (B, T, 1, 3),
            ``global_rotation_quat`` (B, T, 1, 4, xyzw order),
            ``global_velocity`` (B, T, 1, 3),
            ``global_angular_velocity`` (B, T, 1, 3, axis-angle per second),
            ``dof_pos`` (B, T, num_dof) and ``dof_vel`` (zeros, same shape).
            None of these tensors share memory with the inputs.

        Raises:
            ValueError: on non-positive fps, inconsistent shapes or an
                unsupported ``quat_format``.
        """
        t_total = time.perf_counter()
        del sub_batch_size
        del compute_velocity  # kept for call-site compatibility
        if fps <= 0.0:
            raise ValueError(f"Invalid fps: {fps}")
        if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:
            raise ValueError("Inputs must be (B, T, ...)")
        if (
            root_pos.shape[:2] != root_quat.shape[:2]
            or root_pos.shape[:2] != dof_pos.shape[:2]
        ):
            raise ValueError("Mismatched batch/time shapes among inputs")
        if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:
            raise ValueError(
                "root_pos must be (B,T,3) and root_quat must be (B,T,4)"
            )
        if dof_pos.shape[-1] != self.num_dof:
            raise ValueError(
                f"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}"
            )

        t_input = time.perf_counter()
        # These may be zero-copy views of the caller's inputs (see _to_numpy).
        root_pos_np = self._to_numpy(root_pos)
        root_quat_np = self._to_numpy(root_quat)
        dof_pos_np = self._to_numpy(dof_pos)
        input_ms = self._timing_ms(t_input)

        t_quat = time.perf_counter()
        if quat_format == "xyzw":
            root_quat_xyzw_np = root_quat_np
            root_quat_wxyz_np = _xyzw_to_wxyz(root_quat_np)
        elif quat_format == "wxyz":
            root_quat_wxyz_np = root_quat_np
            root_quat_xyzw_np = _wxyz_to_xyzw(root_quat_np)
        else:
            raise ValueError(f"Unsupported quat_format: {quat_format}")
        quat_ms = self._timing_ms(t_quat)

        dt = 1.0 / fps
        kernel = self._get_gaussian_kernel(float(vel_smoothing_sigma))

        t_linvel = time.perf_counter()
        root_vel_np = _grad_t(root_pos_np, dt)
        linvel_ms = self._timing_ms(t_linvel)

        t_angvel = time.perf_counter()
        root_angvel_np = np.zeros_like(root_pos_np)
        if root_quat_wxyz_np.shape[1] >= 2:
            # Relative rotation between consecutive frames, as axis-angle / dt.
            q1 = root_quat_wxyz_np[:, 1:]
            q0_inv = _quat_conjugate_wxyz(root_quat_wxyz_np[:, :-1])
            q_rel = _quat_mul_wxyz(q1, q0_inv)
            root_angvel_np[:, :-1] = _axis_angle_from_wxyz(q_rel) / dt
        angvel_ms = self._timing_ms(t_angvel)

        t_smooth = time.perf_counter()
        if kernel is not None and root_pos_np.shape[1] >= 2:
            # Smooth linear and angular velocity together in one pass.
            vel_and_ang_np = np.concatenate([root_vel_np, root_angvel_np], axis=-1)
            vel_and_ang_np = self._gaussian_filter_time(vel_and_ang_np, kernel)
            root_vel_np = vel_and_ang_np[..., :3]
            root_angvel_np = vel_and_ang_np[..., 3:6]
        smooth_ms = self._timing_ms(t_smooth)

        t_output = time.perf_counter()
        out = {
            # Pass-through arrays may view caller memory: force private copies.
            "global_translation": self._to_output_tensor(
                root_pos_np[:, :, None, :], copy=True
            ),
            "global_rotation_quat": self._to_output_tensor(
                root_quat_xyzw_np[:, :, None, :], copy=True
            ),
            "global_velocity": self._to_output_tensor(root_vel_np[:, :, None, :]),
            "global_angular_velocity": self._to_output_tensor(
                root_angvel_np[:, :, None, :]
            ),
            "dof_pos": self._to_output_tensor(dof_pos_np, copy=True),
            "dof_vel": self._to_output_tensor(np.zeros_like(dof_pos_np)),
        }
        output_ms = self._timing_ms(t_output)
        self._record_timing(
            {
                "total_ms": self._timing_ms(t_total),
                "input_ms": input_ms,
                "quat_ms": quat_ms,
                "linvel_ms": linvel_ms,
                "angvel_ms": angvel_ms,
                "smooth_ms": smooth_ms,
                "output_ms": output_ms,
            }
        )
        return out
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/__init__.py
================================================
from .obs_builder import PolicyObsBuilder, get_gravity_orientation
# Public names re-exported by the obs_builder package.
__all__ = [
    "PolicyObsBuilder",
    "get_gravity_orientation",
]
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/obs_builder/obs_builder.py
================================================
import numpy as np
import torch
from typing import Dict, List, Sequence, Any, Optional
def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:
    """Compute the gravity projection for an orientation quaternion.

    Args:
        quaternion: Array-like ordered [w, x, y, z]; only the first four
            entries are read.

    Returns:
        float32 array of shape (3,). The identity orientation gives [0, 0, -1].
    """
    qw, qx, qy, qz = (float(quaternion[k]) for k in range(4))
    return np.array(
        [
            2.0 * (-qz * qx + qw * qy),
            -2.0 * (qz * qy + qw * qx),
            1.0 - 2.0 * (qw * qw + qz * qz),
        ],
        dtype=np.float32,
    )
class _CircularBuffer:
"""History buffer for batched tensor data (batch==1 in our eval/deploy).
Stores history in oldest->newest order when accessed via .buffer.
"""
def __init__(self, max_len: int, feat_dim: int):
if max_len < 1:
raise ValueError(f"max_len must be >= 1, got {max_len}")
self._max_len = int(max_len)
self._feat_dim = int(feat_dim)
self._pointer = -1
self._num_pushes = 0
self._buffer: torch.Tensor = torch.zeros(
(self._max_len, 1, self._feat_dim),
dtype=torch.float32,
device="cpu",
)
@property
def buffer(self) -> torch.Tensor:
"""Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1."""
if self._num_pushes == 0:
raise RuntimeError(
"Attempting to read from an empty history buffer."
)
# roll such that oldest is at index=0 along the history axis
rolled = torch.roll(
self._buffer, shifts=self._max_len - self._pointer - 1, dims=0
)
return torch.transpose(rolled, 0, 1) # [1, max_len, feat]
def append(self, data: torch.Tensor) -> None:
"""Append one step: data shape [1, feat_dim] on the configured device."""
if (
data.ndim != 2
or data.shape[0] != 1
or data.shape[1] != self._feat_dim
):
raise ValueError(
f"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}"
)
self._pointer = (self._pointer + 1) % self._max_len
self._buffer[self._pointer] = data
if self._num_pushes == 0:
# duplicate first push across entire history for warm start
self._buffer[:] = data
self._num_pushes += 1
class PolicyObsBuilder:
    """Builds flattened policy observations with temporal history.

    Designed to be shared between MuJoCo sim2sim evaluation and ROS2
    deployment. Each configured term ``name`` is produced by calling
    ``evaluator._get_obs_<name>()``. The result is flattened to 1-D float32
    and multiplied by the term's ``scale`` (default 1.0).

    Terms with ``history_length > 0`` are kept in a per-term history buffer
    and contribute ``history_length * feat_dim`` values (oldest first). Other
    terms contribute only their current value. Outputs are concatenated in
    configuration order, so the vector size is sum_i(context_len_i * feat_i).

    This class has no mode-specific logic. Whether it serves motion tracking
    or velocity tracking is decided entirely by the configured terms and the
    evaluator that provides them.
    """

    def __init__(
        self,
        dof_names_onnx: Sequence[str],
        default_angles_onnx: np.ndarray,
        evaluator: Optional[Any] = None,
        obs_policy_cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            dof_names_onnx: Joint names in policy (ONNX) order.
            default_angles_onnx: Per-joint default angles; array-like whose
                first dimension equals ``len(dof_names_onnx)``.
            evaluator: Object exposing ``_get_obs_<term>()`` methods.
            obs_policy_cfg: Mapping with an ``atomic_obs_list`` entry: a list
                of ``{term_name: term_cfg}`` mappings. ``None`` (or a null
                list / null term config) means "no terms" / "default config".

        Raises:
            ValueError: if the default angles do not match the number of joints.
        """
        self.dof_names_onnx: List[str] = list(dof_names_onnx)
        self.num_actions: int = len(self.dof_names_onnx)
        self.evaluator = evaluator
        self.obs_policy_cfg = obs_policy_cfg
        # np.array always copies, so the builder never aliases the caller's array.
        default_angles = np.array(default_angles_onnx, dtype=np.float32)
        if default_angles.ndim == 0 or default_angles.shape[0] != self.num_actions:
            raise ValueError(
                "default_angles_onnx length must match num actions"
            )
        self.default_angles_onnx = default_angles
        self.default_angles_dict: Dict[str, float] = {
            name: float(self.default_angles_onnx[idx])
            for idx, name in enumerate(self.dof_names_onnx)
        }
        # Build the observation schema from config if provided.
        self.term_specs: List[Dict[str, Any]] = []
        atomic_obs_list: Sequence[Any] = []
        if self.obs_policy_cfg is not None:
            atomic_obs_list = self.obs_policy_cfg["atomic_obs_list"] or []
        for term_entry in atomic_obs_list:
            for name, cfg in term_entry.items():
                # A term listed without options (``name:`` in YAML) has cfg None.
                spec = {**(cfg or {})}
                spec["name"] = name
                self.term_specs.append(spec)
        # Buffers are created lazily after the first dimension inference.
        self._buffers: Dict[str, _CircularBuffer] = {}

    def reset(self) -> None:
        """Clear all history so the next step re-seeds every buffer."""
        for buf in self._buffers.values():
            buf._pointer = -1
            buf._num_pushes = 0
            buf._buffer.zero_()

    def _compute_term(
        self,
        name: str,
    ) -> np.ndarray:
        """Return ``evaluator._get_obs_<name>()`` as a flat float32 array.

        Raises:
            ValueError: if there is no evaluator or it lacks that method.
        """
        # Prefer evaluator-provided methods; no legacy fallbacks.
        if self.evaluator is not None:
            meth = getattr(self.evaluator, f"_get_obs_{name}", None)
            if callable(meth):
                out = meth()
                return np.asarray(out, dtype=np.float32).reshape(-1)
        raise ValueError(
            f"Unknown observation term '{name}' or evaluator method missing."
        )

    def build_policy_obs(self) -> np.ndarray:
        """Append one step using evaluator-provided observation terms and return flattened obs."""
        # Compute per-term outputs.
        values: Dict[str, np.ndarray] = {}
        for spec in self.term_specs:
            name = spec["name"]
            scale = float(spec.get("scale", 1.0))
            values[name] = self._compute_term(name) * scale
        # Lazily initialize buffers with inferred feature dims.
        if len(self._buffers) == 0:
            for spec in self.term_specs:
                name = spec["name"]
                hist_len = int(spec.get("history_length", 0))
                if hist_len <= 0:
                    continue
                feat_dim = int(values[name].reshape(-1).shape[0])
                self._buffers[name] = _CircularBuffer(
                    hist_len,
                    feat_dim,
                )
        # Append the current step to buffers (terms without history are skipped).
        for spec in self.term_specs:
            name = spec["name"]
            if name in self._buffers:
                item = torch.as_tensor(
                    values[name],
                    dtype=torch.float32,
                    device="cpu",
                )[None, :]
                self._buffers[name].append(item)
        # Assemble the flat vector in term order (history flattened oldest-first).
        flat_list: List[np.ndarray] = []
        for spec in self.term_specs:
            name = spec["name"]
            if name in self._buffers:
                buf = self._buffers[name].buffer[0]  # [hist, feat]
                arr = buf.reshape(-1).detach().cpu().numpy()
                flat_list.append(arr.astype(np.float32))
            else:
                # No history -> use the computed value directly.
                flat_list.append(values[name].reshape(-1).astype(np.float32))
        if len(flat_list) == 0:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(flat_list, axis=0).astype(np.float32)
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/policy_node_29dof.py
================================================
#! /your_dir/miniconda3/envs/holomotion_deploy/bin/python
"""
HoloMotion Policy Node
This module implements the main policy execution node for the HoloMotion humanoid robot system using ZMQ latest_obs transport.
It handles neural network policy inference, motion sequence management, remote controller input,
and robot state coordination for humanoid behaviors including velocity tracking and motion tracking.
The policy node serves as the high-level decision maker that:
- Processes sensor observations and builds state representations
- Executes trained neural network policies for motion generation (velocity tracking and motion tracking)
- Manages multiple motion sequences (motion clips) loaded from offline files
- Handles remote controller input for motion selection
- Coordinates with the main control node for safe operation
Key Features:
- Dual policy support: velocity tracking and motion tracking
- Offline motion file loading (.npz format)
- Runtime policy switching with button controls
- Separate hyperparameters (kps, kds, action_scale, default_angles) for each model
Author: HoloMotion Team
License: See project LICENSE file
"""
import os
import torch
import time
import json
import threading
from collections import deque
import easydict
import numpy as np
import onnx
import onnxruntime
import rclpy
import zmq
import yaml
from ament_index_python.packages import get_package_share_directory
from omegaconf import OmegaConf
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Float32MultiArray, String
from unitree_hg.msg import LowState
from humanoid_policy.obs_builder import PolicyObsBuilder
from humanoid_policy.utils.remote_controller_filter import KeyMap, RemoteController
from humanoid_policy.holomotion_fk_root_only import HoloMotionFKRootOnly
def _parse_cpu_affinity_str(s):
"""Parse '0,1' or '2' -> [0,1] or [2]. Empty/invalid -> []."""
s = str(s).strip()
if not s:
return []
out = []
for x in s.split(","):
x = x.strip()
if x.isdigit():
out.append(int(x))
return out
def set_thread_cpu_affinity(cpu_ids):
    """Pin the calling thread to the given CPU core IDs (Linux only).

    Args:
        cpu_ids: Iterable of CPU indices, e.g. [0] or [0, 1]. Negative
            indices are ignored.

    Returns:
        True if the affinity was applied. False when the input is empty or
        has no usable ids, the platform lacks ``os.sched_setaffinity``, or
        the kernel rejects the mask (e.g. no listed CPU exists).
    """
    if not cpu_ids:
        return False
    # On Linux, sched_setaffinity(0, ...) acts on the *calling thread* (see
    # sched_setaffinity(2)). This is what pthread_setaffinity_np(pthread_self())
    # did, but without hand-built ctypes cpu_set_t structs or a hard
    # dependency on glibc's "libc.so.6".
    set_affinity = getattr(os, "sched_setaffinity", None)
    if set_affinity is None:
        return False
    try:
        cpus = {cpu for cpu in (int(c) for c in cpu_ids) if cpu >= 0}
        if not cpus:
            return False
        set_affinity(0, cpus)
    except Exception:
        # Best effort: affinity is an optimization, never fatal for callers.
        return False
    return True
HEADER_SIZE = 1280
DEFAULT_ZMQ_TOPIC = b"obs65"
_DTYPE_BY_NAME = {
"f32": np.float32,
"f64": np.float64,
"i32": np.int32,
"i64": np.int64,
"u8": np.uint8,
"bool": np.bool_,
}
def _decode_zmq_topic(topic_value) -> bytes:
if isinstance(topic_value, bytes):
return topic_value
return str(topic_value).encode("utf-8")
def _coerce_config_bool(value, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, (bool, np.bool_)):
return bool(value)
if isinstance(value, str):
value = value.strip().lower()
if value in {"1", "true", "yes", "y", "on"}:
return True
if value in {"0", "false", "no", "n", "off", ""}:
return False
return bool(value)
def _infer_onnx_dim(dim, default: int = 1) -> int:
if isinstance(dim, int) and dim > 0:
return dim
return int(default)
def _infer_numpy_dtype_from_onnx_type(type_str: str):
type_str = str(type_str).lower()
if "float16" in type_str:
return np.float16
if "float64" in type_str or "double" in type_str:
return np.float64
if "int64" in type_str:
return np.int64
if "int32" in type_str:
return np.int32
if "bool" in type_str:
return np.bool_
return np.float32
def unpack_numpy_message(packet: bytes, expected_topic: bytes | None = None) -> dict:
    """Decode a latest_obs ZMQ packet into a dict of named numpy arrays.

    Layout parsed here: an optional ``expected_topic`` prefix, then a
    ``HEADER_SIZE``-byte UTF-8 JSON header padded with NUL bytes, then the
    raw little-endian field payloads concatenated in header order. The
    header is an object with a ``fields`` list of
    ``{"name", "dtype", "shape"}`` entries; ``dtype`` is a key of
    ``_DTYPE_BY_NAME``.

    Packets arrive over the network, so the header is validated: it must be
    a JSON object and shapes must be non-negative. A negative dimension
    would otherwise give a negative byte count and move the read offset
    backwards, silently misparsing every following field.

    Returns:
        Mapping of field name to an owned (copied) array of the declared shape.

    Raises:
        ValueError: on a topic mismatch, a short/empty/invalid header, an
            unsupported dtype, a negative shape or a truncated payload.
    """
    if expected_topic is not None:
        if not packet.startswith(expected_topic):
            raise ValueError("ZMQ packet topic prefix mismatch")
        packet = packet[len(expected_topic) :]
    if len(packet) < HEADER_SIZE:
        raise ValueError(f"ZMQ packet too short: {len(packet)} < {HEADER_SIZE}")
    header_bytes = packet[:HEADER_SIZE].rstrip(b"\x00")
    if not header_bytes:
        raise ValueError("ZMQ packet has empty header")
    header = json.loads(header_bytes.decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("ZMQ packet header must be a JSON object")
    payload = memoryview(packet)[HEADER_SIZE:]
    result = {}
    offset = 0
    for field in header.get("fields", []):
        name = str(field["name"])
        dtype_name = str(field["dtype"])
        shape = tuple(int(x) for x in field.get("shape", []))
        if any(dim < 0 for dim in shape):
            raise ValueError(
                f"ZMQ packet field '{name}' has negative shape: {shape}"
            )
        if dtype_name not in _DTYPE_BY_NAME:
            raise ValueError(f"Unsupported dtype in ZMQ packet: {dtype_name}")
        dtype = np.dtype(_DTYPE_BY_NAME[dtype_name]).newbyteorder("<")
        # A scalar field (empty shape) still carries exactly one element.
        count = int(np.prod(shape, dtype=np.int64)) if len(shape) > 0 else 1
        nbytes = count * dtype.itemsize
        end = offset + nbytes
        if end > len(payload):
            raise ValueError(
                f"ZMQ packet field '{name}' exceeds payload size: end={end}, payload={len(payload)}"
            )
        arr = np.frombuffer(payload[offset:end], dtype=dtype, count=count)
        if len(shape) > 0:
            arr = arr.reshape(shape)
        else:
            arr = arr.reshape(())
        # Copy so the result does not keep the receive buffer alive.
        result[name] = np.array(arr, copy=True)
        offset = end
    return result
class LatestObsBuffer:
"""Thread-safe buffer for delayed latest_obs access."""
def __init__(self, max_queue_size: int = 20):
self._lock = threading.Lock()
self._data = None
self._timestamp = None
self._sender_timestamp = None
self._frame_index = None
self._data_queue = deque(maxlen=max_queue_size)
self._timestamp_queue = deque(maxlen=max_queue_size)
self._sender_timestamp_queue = deque(maxlen=max_queue_size)
self._frame_index_queue = deque(maxlen=max_queue_size)
def set(
self,
arr: np.ndarray,
sender_timestamp: float | None = None,
frame_index: int | None = None,
):
with self._lock:
current_time = time.time()
arr_copy = np.asarray(arr, dtype=np.float32).copy()
self._data = arr_copy
self._timestamp = current_time
self._sender_timestamp = sender_timestamp
self._frame_index = frame_index
self._data_queue.append(arr_copy)
self._timestamp_queue.append(current_time)
self._sender_timestamp_queue.append(sender_timestamp)
self._frame_index_queue.append(frame_index)
def get_with_age_and_delay(self, max_age: float = 0.1, delay_steps: int = 0):
"""Return a delayed frame and report whether it is stale."""
with self._lock:
if len(self._data_queue) == 0:
if self._data is None or self._timestamp is None:
return None, None, True, None, None
current_time = time.time()
age = current_time - self._timestamp
return (
self._data,
self._timestamp,
age > max_age,
self._frame_index,
self._sender_timestamp,
)
if delay_steps < 0:
delay_steps = 0
idx = len(self._data_queue) - 1 - delay_steps
if idx < 0:
idx = 0
data = self._data_queue[idx]
ts = self._timestamp_queue[idx]
frame_index = self._frame_index_queue[idx]
sender_timestamp = self._sender_timestamp_queue[idx]
current_time = time.time()
age = current_time - ts
is_stale = age > max_age
return data, ts, is_stale, frame_index, sender_timestamp
def get_queue_stats(self):
with self._lock:
if len(self._data_queue) < 2:
return {"queue_size": len(self._data_queue), "avg_interval": None}
intervals = []
for i in range(1, len(self._timestamp_queue)):
interval = self._timestamp_queue[i] - self._timestamp_queue[i - 1]
intervals.append(interval)
avg_interval = float(np.mean(intervals)) if intervals else None
return {
"queue_size": len(self._data_queue),
"avg_interval": avg_interval,
"expected_freq": 1.0 / avg_interval if avg_interval and avg_interval > 0 else None,
}
class ZmqLatestObsSubscriber:
"""Background ZMQ SUB receiver for latest_obs packets."""
def __init__(
self,
uri: str,
topic: bytes,
buffer: LatestObsBuffer,
logger,
mode: str = "connect",
cpu_affinity=None,
conflate: bool = True,
):
self.uri = uri
self.topic = topic
self.buffer = buffer
self.logger = logger
self.mode = str(mode).strip().lower()
self.cpu_affinity = cpu_affinity or []
self.conflate = bool(conflate)
self._thread = None
self._stop_event = threading.Event()
self._context = None
self._socket = None
self._poller = None
self._recv_count = 0
def _process_packet(self, packet: bytes):
payload = unpack_numpy_message(packet, expected_topic=self.topic)
latest_obs = payload.get("latest_obs", None)
if latest_obs is None:
raise ValueError("ZMQ packet missing latest_obs field")
frame_index = payload.get("frame_index", None)
if frame_index is not None:
frame_index = int(np.asarray(frame_index).reshape(-1)[0])
sender_timestamp = payload.get("timestamp_realtime", None)
if sender_timestamp is not None:
sender_timestamp = float(np.asarray(sender_timestamp).reshape(-1)[0])
self.buffer.set(
np.asarray(latest_obs, dtype=np.float32),
sender_timestamp=sender_timestamp,
frame_index=frame_index,
)
self._recv_count += 1
if self._recv_count == 1:
self.logger.info(
f"[ZMQ] first latest_obs packet received from {self.uri}, "
f"topic={self.topic.decode('utf-8', errors='ignore')}"
)
def _run(self):
if self.cpu_affinity and set_thread_cpu_affinity(self.cpu_affinity):
self.logger.info(f"[ZMQ] subscriber thread pinned to CPUs {self.cpu_affinity}")
self._context = zmq.Context()
self._socket = self._context.socket(zmq.SUB)
self._socket.setsockopt(zmq.RCVHWM, 1)
self._socket.setsockopt(zmq.SUBSCRIBE, self.topic)
if self.conflate and hasattr(zmq, "CONFLATE"):
self._socket.setsockopt(zmq.CONFLATE, 1)
if self.mode == "bind":
self._socket.bind(self.uri)
elif self.mode == "connect":
self._socket.connect(self.uri)
else:
raise ValueError("latest_obs_zmq_mode must be 'bind' or 'connect'")
self._poller = zmq.Poller()
self._poller.register(self._socket, zmq.POLLIN)
self.logger.info(
f"[ZMQ] latest_obs subscriber ready: mode={self.mode}, uri={self.uri}, "
f"topic={self.topic.decode('utf-8', errors='ignore')}, conflate={self.conflate}"
)
try:
while not self._stop_event.is_set():
events = dict(self._poller.poll(50))
if self._socket not in events:
continue
try:
packet = self._socket.recv(flags=zmq.NOBLOCK)
except zmq.Again:
continue
self._process_packet(packet)
except Exception as exc:
if not self._stop_event.is_set():
self.logger.error(f"[ZMQ] subscriber loop failed: {exc}")
finally:
try:
if self._poller is not None and self._socket is not None:
self._poller.unregister(self._socket)
except Exception:
pass
try:
if self._socket is not None:
self._socket.close(0)
except Exception:
pass
try:
if self._context is not None:
self._context.term()
except Exception:
pass
self._socket = None
self._context = None
self._poller = None
def start(self):
if self._thread is not None:
return
self._stop_event.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self.logger.info("[ZMQ] subscriber thread started")
def stop(self):
self._stop_event.set()
if self._thread:
self._thread.join(timeout=2.0)
self._thread = None
self.logger.info("[ZMQ] subscriber thread stopped")
class HoloMotionPolicyNode(Node):
"""Main policy execution node for HoloMotion humanoid robot control with dual policy support.
This node implements the high-level control logic for a humanoid robot capable of
performing both velocity tracking and motion sequence execution. It supports two
neural network policies and allows runtime switching between them.
Key Features:
- Dual neural network policy inference (velocity + motion) using ONNX Runtime
- Runtime policy switching with A/B/Y button controls
- Velocity tracking mode with joystick control
- Motion tracking mode with motion clip sequence selection
- Safety-aware state machine with motion prerequisites
- Real-time observation processing and action generation
Policy Control:
- A button: Enable policy (defaults to velocity mode)
- B button: Switch from velocity to motion mode
- Y button: Switch from motion back to velocity mode
Input Controls:
- Motion mode: B button (for mode switch)
- Velocity mode: Y button (for mode switch) + Joystick +UP/DOWN/LEFT/RIGHT (for motion selection)
State Machine:
- ZERO_TORQUE: Initial safe state, waiting for activation
- MOVE_TO_DEFAULT: Ready state, allows policy operations
- Policy execution with mode switching
- Emergency stop handling
"""
def __init__(self):
    """Initialize the policy node's configuration, ROS2 interfaces and state.

    Sets up:
    - configuration loaded from the YAML file named by the ``config_path``
      ROS parameter,
    - ROS2 subscribers, publishers and timers (``_setup_*`` helpers),
    - default state for the dual (velocity / motion) policy pipeline,
    - the background ZMQ latest_obs subscriber for the teleop reference.

    Teleop/ZMQ options come from ROS parameters and may be overridden by the
    YAML ``vr`` section. Boolean options are normalized with
    ``_coerce_config_bool``, so string values such as "false" or "0" are
    honored instead of being treated as truthy non-empty strings by
    ``bool()``.

    The node starts with the policy disabled and ``robot_state_ready``
    False until MOVE_TO_DEFAULT is received.
    """
    super().__init__("policy_node")

    # Get config path from ROS parameter.
    config_path = self.declare_parameter("config_path", "").value
    with open(config_path, "r", encoding="utf-8") as config_file:
        self.config_yaml = easydict.EasyDict(yaml.safe_load(config_file))

    # Read policy frequency from config, default to 50 Hz if not specified.
    policy_freq = self.config_yaml.get("policy_freq", 50)
    self.dt = 1.0 / policy_freq
    self.get_logger().info(f"Policy frequency set to: {policy_freq} Hz (dt = {self.dt:.4f} s)")

    # Initialize basic parameters - will be updated after config loading.
    self.actions_dim = 29  # Default value, will be updated from config
    self.real_dof_names = []  # Will be loaded from config
    self.current_motion_clip_index = 0  # Current motion clip index

    # Button state tracking for preventing multiple triggers.
    self.last_button_states = {
        KeyMap.up: 0,
        KeyMap.down: 0,
        KeyMap.left: 0,
        KeyMap.right: 0,
        KeyMap.A: 0,
        KeyMap.B: 0,
        KeyMap.Y: 0,
    }

    # Safety check related flags.
    self.policy_enabled = False  # Controls whether policy is enabled
    # Robot state related flags.
    self.robot_state_ready = False  # Marks whether MOVE_TO_DEFAULT state is received, allowing key operations

    self._setup_subscribers()
    self._setup_publishers()
    self._setup_timers()

    # Initialize variables for dual policy.
    self.velocity_policy_session = None
    self.motion_policy_session = None
    self.use_kv_cache = False
    self.motion_kv_cache = None
    self.motion_kv_input_name = None
    self.motion_kv_output_name = None
    self.motion_step_idx_input_name = None
    self.current_policy_mode = "velocity"
    self.velocity_config = None
    self.motion_config = None
    self.motion_frame_idx = 0
    self.ref_dof_pos = None
    self.ref_dof_vel = None
    self.ref_raw_bodylink_pos = None
    self.ref_raw_bodylink_rot = None
    self.n_motion_frames = 0

    # External (ZMQ teleop) reference state.
    self.external_latest_obs = None
    self.external_obs_received = False
    self.last_external_obs_time = None
    self._latest_sender_timestamp = None
    self.latest_obs_flag = False
    self.latest_obs_expected_dim = 65
    self.external_fut_dof_pos_queue = None
    self.external_fut_dof_vel_queue = None
    self.external_fut_root_pos_queue = None
    self.external_fut_root_rot_queue = None
    self.external_fut_frame_idx_queue = None
    self._prev_external_dof_pos = None
    self._prev_external_dof_vel = None
    self._prev_external_root_pos = None
    self._prev_external_root_rot = None
    self._prev_external_frame_idx = None
    self.max_data_age = 0.6
    self.stale_data_warning_count = 0
    self.last_poll_time = None
    self._last_vr_status_log_time = None

    # ROS parameters (the YAML ``vr`` section below may override them).
    self.latest_obs_zmq_uri = self.declare_parameter(
        "latest_obs_zmq_uri", "tcp://192.168.124.29:6001"
    ).value
    self.latest_obs_zmq_topic = self.declare_parameter(
        "latest_obs_zmq_topic", DEFAULT_ZMQ_TOPIC.decode("utf-8")
    ).value
    self.latest_obs_zmq_mode = self.declare_parameter(
        "latest_obs_zmq_mode", "connect"
    ).value
    self.latest_obs_zmq_conflate = _coerce_config_bool(
        self.declare_parameter("latest_obs_zmq_conflate", True).value, True
    )
    self.zmq_jitter_delay_frames = self.declare_parameter(
        "zmq_jitter_delay_frames", 5
    ).value
    self.require_vr_data_for_motion = _coerce_config_bool(
        self.declare_parameter("require_vr_data_for_motion", True).value, True
    )
    self.enable_teleop_reference = _coerce_config_bool(
        self.declare_parameter("enable_teleop_reference", True).value, True
    )
    self._cpu_affinity_main_str = self.declare_parameter(
        "cpu_affinity_main", ""
    ).value
    self._cpu_affinity_zmq_sub_str = self.declare_parameter(
        "cpu_affinity_zmq_sub", ""
    ).value
    self.timing_debug_enabled = _coerce_config_bool(
        self.declare_parameter("timing_debug_enabled", True).value, True
    )
    self.timing_debug_log_interval_sec = self.declare_parameter(
        "timing_debug_log_interval_sec", 5.0
    ).value
    self.timing_debug_log_per_loop = _coerce_config_bool(
        self.declare_parameter("timing_debug_log_per_loop", False).value, False
    )
    self._timing_debug_last_log_time = None
    self._timing_debug_samples = deque(maxlen=500)
    self._root_only_fk_keybody_warned = False

    _vr = getattr(self.config_yaml, "vr", None) or {}
    if _vr:
        self.latest_obs_zmq_uri = str(_vr.get("latest_obs_zmq_uri", self.latest_obs_zmq_uri))
        self.latest_obs_zmq_topic = str(
            _vr.get("latest_obs_zmq_topic", self.latest_obs_zmq_topic)
        )
        self.latest_obs_zmq_mode = str(
            _vr.get("latest_obs_zmq_mode", self.latest_obs_zmq_mode)
        )
        # Boolean YAML values go through _coerce_config_bool: bool("false") is True.
        self.latest_obs_zmq_conflate = _coerce_config_bool(
            _vr.get("latest_obs_zmq_conflate", self.latest_obs_zmq_conflate)
        )
        self.zmq_jitter_delay_frames = int(
            _vr.get("zmq_jitter_delay_frames", self.zmq_jitter_delay_frames)
        )
        self.max_data_age = float(_vr.get("max_data_age", self.max_data_age))
        self.require_vr_data_for_motion = _coerce_config_bool(
            _vr.get("require_vr_data_for_motion", self.require_vr_data_for_motion)
        )
        self.enable_teleop_reference = _coerce_config_bool(
            _vr.get("enable_teleop_reference", self.enable_teleop_reference)
        )
        self.timing_debug_enabled = _coerce_config_bool(
            _vr.get("timing_debug_enabled", self.timing_debug_enabled)
        )
        self.timing_debug_log_interval_sec = float(
            _vr.get(
                "timing_debug_log_interval_sec",
                self.timing_debug_log_interval_sec,
            )
        )
        self.timing_debug_log_per_loop = _coerce_config_bool(
            _vr.get("timing_debug_log_per_loop", self.timing_debug_log_per_loop)
        )

    # CPU affinity strings may also be given at the top level of the YAML config.
    self._cpu_affinity_main_str = str(
        getattr(self.config_yaml, "cpu_affinity_main", self._cpu_affinity_main_str)
    )
    self._cpu_affinity_zmq_sub_str = str(
        getattr(
            self.config_yaml,
            "cpu_affinity_zmq_sub",
            self._cpu_affinity_zmq_sub_str,
        )
    )

    self._ros_latest_obs_buffer = None
    self._npz_replay_frame_index = None
    self._external_seen_frames = 0
    self._vr_ready_logged = False
    self._latest_obs_buffer = LatestObsBuffer()
    self._latest_obs_zmq_topic_bytes = _decode_zmq_topic(self.latest_obs_zmq_topic)
    if str(self.latest_obs_zmq_mode).strip().lower() == "connect":
        uri_str = str(self.latest_obs_zmq_uri)
        if "*" in uri_str or "0.0.0.0" in uri_str:
            self.get_logger().warn(
                "[ZMQ] connect mode requires a concrete peer address. "
                "Do not use '*' or '0.0.0.0'; use the sender IP instead, "
                "for example tcp://192.168.124.29:6001."
            )
    zmq_cpu_affinity = _parse_cpu_affinity_str(self._cpu_affinity_zmq_sub_str)
    self._zmq_subscriber = ZmqLatestObsSubscriber(
        uri=self.latest_obs_zmq_uri,
        topic=self._latest_obs_zmq_topic_bytes,
        mode=self.latest_obs_zmq_mode,
        conflate=bool(self.latest_obs_zmq_conflate),
        buffer=self._latest_obs_buffer,
        logger=self.get_logger(),
        cpu_affinity=zmq_cpu_affinity if zmq_cpu_affinity else None,
    )
    self._zmq_subscriber.start()
    self.get_logger().info(
        f"ZMQ latest_obs subscriber started: mode={self.latest_obs_zmq_mode}, "
        f"uri={self.latest_obs_zmq_uri}, topic={self.latest_obs_zmq_topic}, "
        f"jitter_delay={self.zmq_jitter_delay_frames}"
    )

    # Per-joint policy parameters (ONNX joint order); placeholders until a model loads.
    self.dof_names_ref_motion = []
    self.num_actions = 29
    self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32)
    self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.target_dof_pos_onnx = self.default_angles_onnx.copy()
    self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self._lowstate_msg = None
    self.target_dof_pos_real = None
    self.motion_in_progress = False
    self._keybody_indices_by_term_name = {}
    self.fk = None
    self.fk_initialized = False
    self.motion_action_ema_filter_enabled = False
    self.motion_action_ema_filter_alpha = 1.0
    self._motion_filtered_actions_onnx = None
def _is_vr_ready_for_motion(self) -> bool:
"""Return whether the ZMQ reference stream is ready for motion mode."""
if not getattr(self, "enable_teleop_reference", True):
return False
if not (
getattr(self, "external_obs_received", False)
and getattr(self, "external_latest_obs", None) is not None
):
return False
n_fut = int(getattr(self, "n_fut_frames", 0) or 0)
if n_fut <= 0:
return True
delay = int(getattr(self, "zmq_jitter_delay_frames", 0) or 0)
needed = n_fut + max(delay, 0) + 1
return int(getattr(self, "_external_seen_frames", 0)) >= needed
def _init_keybody_indices_cache(self):
if self.motion_config is None:
raise ValueError("motion_config is not loaded; cannot init keybody index cache")
atomic_list = self._get_policy_atomic_obs_list(self.motion_config)["atomic_obs_list"]
body_names = [str(name) for name in self.motion_config.robot.body_names]
body_name_to_idx = {body_name: idx for idx, body_name in enumerate(body_names)}
cache = {}
for term_dict in atomic_list:
term_name = str(list(term_dict.keys())[0])
term_cfg = term_dict[term_name]
params = {}
if isinstance(term_cfg, dict):
params = term_cfg.get("params", {}) or {}
if not isinstance(params, dict):
raise ValueError(
f"Observation term '{term_name}' params must be a dict, got {type(params)}"
)
needs_keybody = ("keybody" in term_name) or ("keybody_names" in params)
if not needs_keybody:
continue
keybody_names = params.get("keybody_names", None)
if keybody_names is None:
keybody_idxs = np.arange(len(body_names), dtype=np.int64)
else:
keybody_names = [str(name) for name in keybody_names]
missing_names = [
name for name in keybody_names if name not in body_name_to_idx
]
if len(missing_names) > 0:
raise ValueError(
f"Unknown keybody_names in '{term_name}': {missing_names}. "
f"Available body names: {body_names}"
)
keybody_idxs = np.asarray(
[body_name_to_idx[name] for name in keybody_names],
dtype=np.int64,
)
cache[term_name] = keybody_idxs
self._keybody_indices_by_term_name = cache
def _get_policy_atomic_obs_list(self, config):
    """Resolve the atomic obs list used to build the ONNX policy input.
    Aligns with MuJoCo sim2sim eval ordering by honoring modules.actor.obs_schema
    when available, to guarantee the policy input term order matches training/export.

    Actor terms come from ``config.obs.obs_groups.policy`` when present.
    Otherwise they come from the ``actor_*`` terms of
    ``config.obs.obs_groups.unified``. Each term config is converted to a
    plain dict.

    When ``config.modules.actor.obs_schema`` yields any terms, the result
    follows the schema order. Terms are collected from every nested
    ``"terms"`` list, in traversal order. Atomic terms not named in the schema
    are omitted. A schema entry is matched first by its full
    ``"<group>/<term>"`` key, then by its last ``/`` component. Without a
    schema, the obs-group order is returned unchanged.

    Args:
        config: model config, read via ``.get`` and attribute access
            (OmegaConf in this node).

    Returns:
        dict: ``{"atomic_obs_list": [{term_name: term_cfg_dict}, ...]}``.

    Raises:
        ValueError: if ``obs`` or ``obs_groups`` is missing, the obs layout is
            unsupported, ``unified`` has no ``actor_*`` terms, a term config
            is not a mapping, or a schema term is ambiguous or unknown.
    """
    def _to_plain_obs_cfg(cfg):
        # Normalize a single term config (OmegaConf / None / mapping) to a dict.
        if OmegaConf.is_config(cfg):
            plain_cfg = OmegaConf.to_container(cfg, resolve=True)
        elif cfg is None:
            plain_cfg = {}
        else:
            plain_cfg = dict(cfg)
        if plain_cfg is None:
            plain_cfg = {}
        if not isinstance(plain_cfg, dict):
            raise ValueError(
                f"Observation term config must be a mapping, got {type(plain_cfg)}"
            )
        return plain_cfg
    def _get_actor_atomic_obs_entries():
        # Returns (group_name, term_name, plain_term_cfg) tuples in config order.
        obs_cfg = config.get("obs", None)
        if obs_cfg is None:
            raise ValueError("Missing config.obs for policy obs")
        obs_groups = obs_cfg.get("obs_groups", None)
        if obs_groups is None:
            raise ValueError("Missing config.obs.obs_groups for policy obs")
        # A dedicated "policy" group takes precedence over "unified".
        if obs_groups.get("policy", None) is not None:
            entries = []
            for term_dict in obs_groups.policy.atomic_obs_list:
                term_name = str(list(term_dict.keys())[0])
                entries.append(
                    (
                        "policy",
                        term_name,
                        _to_plain_obs_cfg(term_dict[term_name]),
                    )
                )
            return entries
        if obs_groups.get("unified", None) is not None:
            # In the unified layout only the actor_* terms feed the policy.
            entries = []
            for term_dict in obs_groups.unified.atomic_obs_list:
                term_name = str(list(term_dict.keys())[0])
                if term_name.startswith("actor_"):
                    entries.append(
                        (
                            "unified",
                            term_name,
                            _to_plain_obs_cfg(term_dict[term_name]),
                        )
                    )
            if not entries:
                raise ValueError(
                    "obs_groups.unified found but contains no actor_* terms."
                )
            return entries
        raise ValueError(
            "Unsupported obs config : expected obs_groups.policy or obs_groups.unified."
        )
    def _get_actor_obs_schema_terms():
        # Flatten modules.actor.obs_schema into an ordered list of term keys;
        # an absent or non-dict schema yields [] (meaning "no reordering").
        modules_cfg = config.get("modules", None)
        if modules_cfg is None:
            return []
        actor_cfg = modules_cfg.get("actor", None)
        if actor_cfg is None:
            return []
        obs_schema = actor_cfg.get("obs_schema", None)
        if obs_schema is None:
            return []
        if OmegaConf.is_config(obs_schema):
            obs_schema_plain = OmegaConf.to_container(obs_schema, resolve=True)
        else:
            obs_schema_plain = obs_schema
        if not isinstance(obs_schema_plain, dict):
            return []
        ordered_terms = []
        def _collect_terms(node):
            # Depth-first walk; a dict with a "terms" list contributes those
            # terms and is not descended into further.
            if node is None:
                return
            if isinstance(node, dict):
                if "terms" in node and isinstance(node["terms"], list):
                    ordered_terms.extend(str(term) for term in node["terms"])
                    return
                for v in node.values():
                    _collect_terms(v)
                return
            if isinstance(node, list):
                for v in node:
                    _collect_terms(v)
            return
        _collect_terms(obs_schema_plain)
        return ordered_terms
    actor_atomic_entries = _get_actor_atomic_obs_entries()
    schema_terms = _get_actor_obs_schema_terms()
    if len(schema_terms) == 0:
        # No schema: keep the obs-group order as-is.
        return {
            "atomic_obs_list": [
                {term_name: term_cfg}
                for _, term_name, term_cfg in actor_atomic_entries
            ]
        }
    # Lookup tables for schema matching: exact "<group>/<term>" keys, and
    # bare term names (leaf keys). A leaf key seen more than once is recorded
    # as ambiguous and may only be matched through its full key.
    by_full_key = {}
    by_leaf_key = {}
    ambiguous_leaf_keys = set()
    for group_name, term_name, term_cfg in actor_atomic_entries:
        full_key = f"{group_name}/{term_name}"
        by_full_key[full_key] = (term_name, term_cfg)
        if term_name in by_leaf_key:
            ambiguous_leaf_keys.add(term_name)
        else:
            by_leaf_key[term_name] = (term_name, term_cfg)
    ordered_atomic_list = []
    for schema_term in schema_terms:
        schema_term_key = str(schema_term)
        if schema_term_key in by_full_key:
            term_name, term_cfg = by_full_key[schema_term_key]
            ordered_atomic_list.append({term_name: term_cfg})
            continue
        leaf_key = schema_term_key.split("/")[-1]
        if leaf_key in ambiguous_leaf_keys:
            raise ValueError(
                f"Ambiguous obs_schema term '{schema_term_key}': "
                f"multiple atomic obs share leaf key '{leaf_key}'."
            )
        if leaf_key not in by_leaf_key:
            available = sorted(list(by_leaf_key.keys()))
            raise ValueError(
                f"obs_schema term '{schema_term_key}' not found in atomic_obs_list. "
                f"Available terms: {available}"
            )
        term_name, term_cfg = by_leaf_key[leaf_key]
        ordered_atomic_list.append({term_name: term_cfg})
    return {"atomic_obs_list": ordered_atomic_list}
def _find_actor_place_holder_ndim(self):
n_dim = 0
atomic_list = self._get_policy_atomic_obs_list(self.motion_config)[
"atomic_obs_list"
]
for obs_dict in atomic_list:
name = str(list(obs_dict.keys())[0])
if name == "place_holder" or name == "actor_place_holder":
cfg = obs_dict[name]
params = cfg.get("params", {}) if isinstance(cfg, dict) else {}
n_dim = int(params.get("n_dim", 0))
return n_dim
def _init_obs_buffers(self):
    """Create the per-policy observation builders and the VR future-frame buffers.

    The velocity and motion ``PolicyObsBuilder`` each receive their own
    model's ``dof_names_onnx`` / ``default_angles_onnx`` and resolved atomic
    obs list. When ``n_fut_frames`` is set and positive, the future-frame
    queues and FK sequence buffers are allocated. The torch tensors are built
    with ``torch.from_numpy``, so they share memory with the numpy arrays.
    Otherwise all of those attributes are set to None. The active
    ``obs_builder`` starts as the velocity builder.
    """
    self.velocity_obs_builder = PolicyObsBuilder(
        dof_names_onnx=self.velocity_dof_names_onnx,
        default_angles_onnx=self.velocity_default_angles_onnx,
        evaluator=self,
        obs_policy_cfg=self._get_policy_atomic_obs_list(self.velocity_config),
    )
    self.motion_obs_builder = PolicyObsBuilder(
        dof_names_onnx=self.motion_dof_names_onnx,
        default_angles_onnx=self.motion_default_angles_onnx,
        evaluator=self,
        obs_policy_cfg=self._get_policy_atomic_obs_list(self.motion_config),
    )
    use_future = hasattr(self, "n_fut_frames") and int(self.n_fut_frames) > 0
    if not use_future:
        for attr_name in (
            "external_fut_dof_pos_queue",
            "external_fut_dof_vel_queue",
            "external_fut_root_pos_queue",
            "external_fut_root_rot_queue",
            "external_fut_frame_idx_queue",
            "_fk_root_pos_seq_np",
            "_fk_root_rot_seq_np",
            "_fk_dof_pos_seq_np",
            "_fk_root_pos_seq_tensor",
            "_fk_root_rot_seq_tensor",
            "_fk_dof_pos_seq_tensor",
        ):
            setattr(self, attr_name, None)
    else:
        horizon = int(self.n_fut_frames)
        n_dof = self.num_actions
        self.external_fut_dof_pos_queue = np.zeros((horizon, n_dof), dtype=np.float32)
        self.external_fut_dof_vel_queue = np.zeros((horizon, n_dof), dtype=np.float32)
        self.external_fut_root_pos_queue = np.zeros((horizon, 3), dtype=np.float32)
        self.external_fut_root_rot_queue = np.zeros((horizon, 4), dtype=np.float32)
        # FK sequences hold horizon + 1 frames (presumably current + future).
        seq_len = horizon + 1
        self._fk_root_pos_seq_np = np.zeros((1, seq_len, 3), dtype=np.float32)
        self._fk_root_rot_seq_np = np.zeros((1, seq_len, 4), dtype=np.float32)
        self._fk_dof_pos_seq_np = np.zeros((1, seq_len, n_dof), dtype=np.float32)
        # Zero-copy tensor views: writes to the numpy buffers are visible here.
        self._fk_root_pos_seq_tensor = torch.from_numpy(self._fk_root_pos_seq_np)
        self._fk_root_rot_seq_tensor = torch.from_numpy(self._fk_root_rot_seq_np)
        self._fk_dof_pos_seq_tensor = torch.from_numpy(self._fk_dof_pos_seq_np)
        self.external_fut_frame_idx_queue = np.full((horizon,), -1, dtype=np.int32)
        self.get_logger().info(
            f"Initialized VR future frame queues: n_fut_frames={horizon}, num_actions={self.num_actions}"
        )
    self.obs_builder = self.velocity_obs_builder
def _reset_counter(self):
"""Reset motion timing counters to start of sequence."""
self.motion_frame_idx = 0
self.motion_step_idx = 0
if self.use_kv_cache and self.motion_kv_cache is not None:
self.motion_kv_cache.fill(0)
def _switch_to_velocity_mode(self, reason: str = ""):
"""Switch to velocity tracking mode and clear action cache.
Uses velocity model's default_angles_onnx to ensure correct initialization.
Also publishes velocity model's control parameters (kps/kds).
"""
self.current_policy_mode = "velocity"
self.latest_obs_flag = False
self.motion_in_progress = False
self._fk_vr_out = None
self._use_fk_vr = False
self._reset_motion_action_ema_filter()
self._reset_counter()
self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)
# Use velocity model's default angles
self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy()
# Publish velocity model's control parameters
self._publish_control_params()
if reason:
self.get_logger().info(f"Switched to velocity tracking mode ({reason})")
else:
self.get_logger().info("Switched to velocity tracking mode")
def _is_button_pressed(self, button_key):
"""Check if button was just pressed (rising edge detection)."""
current_state = self.remote_controller.button[button_key]
last_state = self.last_button_states[button_key]
# Update the last state
self.last_button_states[button_key] = current_state
# Return True only on rising edge (0 -> 1)
return current_state == 1 and last_state == 0
def load_policy(self):
    """Load the velocity and motion policy models with ONNX Runtime.

    For each of ``config_yaml.velocity_tracking_model_folder`` and
    ``config_yaml.motion_tracking_model_folder``, the first ``*.onnx`` file
    (by sorted filename) in ``<humanoid_control share>/models/<folder>/exported``
    is loaded. CUDA (device 0) is preferred and CPU is the fallback. A warning
    is logged when several ONNX files are present.

    For the motion policy, the inputs/outputs are classified by name to find
    the obs input, an optional KV cache (``past_key_values`` in,
    ``present_key_values`` out) and an optional integer step-index input. The
    KV cache buffer and context lengths are then initialized.

    Raises:
        FileNotFoundError: if an ``exported`` folder is missing or has no
            ``.onnx`` file.
    """
    self.get_logger().info("Loading dual policies...")
    providers = [
        (
            "CUDAExecutionProvider",
            {
                "device_id": 0,
            },
        ),
        "CPUExecutionProvider",
    ]
    onnx_threads = int(self.config_yaml.get("onnx_intra_op_threads", 2))
    sess_options = onnxruntime.SessionOptions()
    sess_options.intra_op_num_threads = onnx_threads
    sess_options.inter_op_num_threads = 1

    def _resolve_onnx_path(model_folder):
        # Locate <share>/models/<model_folder>/exported/<first *.onnx>.
        # os.listdir() order is arbitrary, so sort the candidates to make the
        # chosen model deterministic, and report when the choice is ambiguous.
        exported_dir = os.path.join(
            get_package_share_directory("humanoid_control"),
            "models",
            model_folder,
            "exported",
        )
        onnx_files = sorted(f for f in os.listdir(exported_dir) if f.endswith(".onnx"))
        if not onnx_files:
            raise FileNotFoundError(f"No ONNX files found in {exported_dir}")
        if len(onnx_files) > 1:
            self.get_logger().warn(
                f"Multiple ONNX files found in {exported_dir}: {onnx_files}; "
                f"using {onnx_files[0]}"
            )
        return os.path.join(exported_dir, onnx_files[0])

    # Velocity policy.
    velocity_onnx_path = _resolve_onnx_path(self.config_yaml.velocity_tracking_model_folder)
    self.get_logger().info(f"Loading velocity policy from {velocity_onnx_path}")
    self.velocity_policy_session = onnxruntime.InferenceSession(
        str(velocity_onnx_path), sess_options=sess_options, providers=providers
    )
    self.get_logger().info(
        f"Velocity policy loaded successfully using: "
        f"{self.velocity_policy_session.get_providers()}"
    )
    # Motion policy.
    motion_onnx_path = _resolve_onnx_path(self.config_yaml.motion_tracking_model_folder)
    self.get_logger().info(f"Loading motion policy from {motion_onnx_path}")
    self.motion_policy_session = onnxruntime.InferenceSession(
        str(motion_onnx_path), sess_options=sess_options, providers=providers
    )
    self.get_logger().info(
        f"Motion policy loaded successfully using: "
        f"{self.motion_policy_session.get_providers()}"
    )
    # Default input/output names (first of each); the motion names are
    # refined below by name-based classification.
    self.velocity_input_name = self.velocity_policy_session.get_inputs()[0].name
    self.velocity_output_name = self.velocity_policy_session.get_outputs()[0].name
    self.motion_input_name = self.motion_policy_session.get_inputs()[0].name
    self.motion_output_name = self.motion_policy_session.get_outputs()[0].name
    self.get_logger().info(
        f"Velocity policy - Input: {self.velocity_input_name}, "
        f"Output: {self.velocity_output_name}"
    )
    self.get_logger().info(
        f"Motion policy - Input: {self.motion_input_name}, "
        f"Output: {self.motion_output_name}"
    )
    # Store ONNX paths for metadata reading
    self.velocity_onnx_path = velocity_onnx_path
    self.motion_onnx_path = motion_onnx_path
    self.get_logger().info("Initializing KV-Cache for Motion Policy...")
    self.motion_kv_input_name = None
    self.motion_kv_output_name = None
    self.motion_kv_shape = None
    self.motion_step_idx_input_name = None
    self.motion_kv_dtype = np.float32
    for node in self.motion_policy_session.get_inputs():
        name = node.name
        shape = node.shape
        node_type = node.type
        self.get_logger().info(f"Motion policy input: name={name}, shape={shape}, type={node_type}")
        if "obs" in name:
            self.motion_input_name = name
        elif "past_key_values" in name:
            self.motion_kv_input_name = name
            self.motion_kv_shape = shape
            if isinstance(node_type, str) and "float16" in node_type:
                self.motion_kv_dtype = np.float16
        elif "step_idx" in name or "current_pos" in name:
            self.motion_step_idx_input_name = name
        elif (
            # Fallback: the first other int64 input is taken as the step index.
            self.motion_step_idx_input_name is None
            and isinstance(node_type, str)
            and "int64" in node_type
            and name not in (self.motion_input_name, self.motion_kv_input_name)
        ):
            self.motion_step_idx_input_name = name
    motion_outputs = self.motion_policy_session.get_outputs()
    action_output_name = None
    kv_output_name = None
    for node in motion_outputs:
        self.get_logger().info(f"Motion policy output: name={node.name}, shape={node.shape}, type={node.type}")
        if "present_key_values" in node.name:
            kv_output_name = node.name
        elif "actions" in node.name:
            action_output_name = node.name
    if action_output_name is None:
        # No output named "actions": use the first non-KV output.
        for node in motion_outputs:
            if kv_output_name is not None and node.name == kv_output_name:
                continue
            action_output_name = node.name
            break
    if action_output_name is None:
        action_output_name = motion_outputs[0].name
    self.motion_output_name = action_output_name
    self.motion_kv_output_name = kv_output_name
    if self.motion_kv_input_name is not None and self.motion_kv_output_name is None:
        self.get_logger().warn(
            "Motion policy has past_key_values input but no present_key_values output was found. "
            "KV cache will not update and transformer performance will degrade."
        )
    if self.motion_kv_input_name and self.motion_kv_shape:
        # Dynamic (non-int) dims are allocated as 1.
        shape = [d if isinstance(d, int) else 1 for d in self.motion_kv_shape]
        self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)
        # Assumes axis 3 of the KV tensor is the context length (TODO confirm
        # against the exporter's KV layout).
        self.motion_model_context_len = int(shape[3]) if len(shape) > 3 else 0
        self.motion_max_context_len = int(
            self.motion_config.get("algo", {})
            .get("config", {})
            .get("num_steps_per_env", 0)
        )
        if self.motion_max_context_len > 0 and self.motion_model_context_len > 0:
            self.motion_effective_context_len = min(
                self.motion_max_context_len, self.motion_model_context_len
            )
        else:
            self.motion_effective_context_len = self.motion_model_context_len
        self.use_kv_cache = True
        self.get_logger().info(
            f"KV-Cache initialized with shape {shape} "
            f"(model_ctx={self.motion_model_context_len}, "
            f"effective_ctx={self.motion_effective_context_len})"
        )
    else:
        self.use_kv_cache = False
        self.motion_kv_cache = None
        self.motion_model_context_len = 0
        self.motion_effective_context_len = 0
        self.get_logger().warn("No KV-Cache inputs found in Motion Policy model!")
    self.get_logger().info("Dual policies loaded successfully")
def load_model_config(self):
    """Load ``config.yaml`` for the velocity and motion models.

    Each config is looked up in
    ``<humanoid_control share>/models/<*_tracking_model_folder>``. After the
    motion config is loaded, this also:

    - reads the motion action EMA filter settings;
    - computes the actor placeholder width;
    - reads ``obs.n_fut_frames``;
    - finds the index of ``torso_link`` in ``robot.body_names``.

    Raises:
        FileNotFoundError: if no candidate config file exists for a model.
    """
    config_names = ["config.yaml"]

    def _find_config(model_folder):
        # First existing candidate in <share>/models/<model_folder>.
        config_dir = os.path.join(
            get_package_share_directory("humanoid_control"),
            "models",
            model_folder,
        )
        for candidate in (os.path.join(config_dir, n) for n in config_names):
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"No config file found in {config_dir}. Tried: {config_names}"
        )

    velocity_config_path = _find_config(self.config_yaml.velocity_tracking_model_folder)
    self.get_logger().info(
        f"Loading velocity model config from {velocity_config_path}"
    )
    self.velocity_config = OmegaConf.load(velocity_config_path)

    motion_config_path = _find_config(self.config_yaml.motion_tracking_model_folder)
    self.get_logger().info(f"Loading motion model config from {motion_config_path}")
    self.motion_config = OmegaConf.load(motion_config_path)

    # Values derived from the motion config.
    self._load_motion_action_ema_filter_cfg()
    self.actor_place_holder_ndim = self._find_actor_place_holder_ndim()
    self.n_fut_frames = int(self.motion_config.obs.n_fut_frames)
    self.torso_body_idx = self.motion_config.robot.body_names.index("torso_link")
    self.get_logger().info("Both model configs loaded successfully")
def _load_motion_action_ema_filter_cfg(self) -> None:
    """Read the motion-policy action EMA filter settings from the motion config.

    Reads ``robot.actuators.ema_filter_enabled`` and
    ``robot.actuators.ema_filter_alpha``. If either key (or a parent section)
    is missing or null, the filter is disabled with ``alpha=1.0``. Otherwise
    both values are parsed and validated *before* they are stored. An invalid
    alpha therefore leaves the previous filter state untouched instead of a
    half-applied enabled flag with an out-of-range alpha.

    Raises:
        ValueError: if ``ema_filter_alpha`` is outside ``[0, 1]`` (NaN
            included) or is a non-numeric string.
    """
    # `or {}` tolerates `robot: null` / `actuators: null` in the YAML.
    robot_cfg = self.motion_config.get("robot", {}) or {}
    actuator_cfg = robot_cfg.get("actuators", {}) or {}
    enabled_raw = actuator_cfg.get("ema_filter_enabled", None)
    alpha_raw = actuator_cfg.get("ema_filter_alpha", None)
    if enabled_raw is None or alpha_raw is None:
        self.motion_action_ema_filter_enabled = False
        self.motion_action_ema_filter_alpha = 1.0
        self.get_logger().info(
            "[Motion EMA] ema_filter_enabled/ema_filter_alpha not found in motion config; EMA disabled."
        )
        return
    enabled = _coerce_config_bool(enabled_raw, default=False)
    alpha = float(alpha_raw)
    # Written as a negated range check so NaN is rejected too.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(
            "motion_config robot.actuators.ema_filter_alpha must be within [0, 1], "
            f"got {alpha}."
        )
    self.motion_action_ema_filter_enabled = enabled
    self.motion_action_ema_filter_alpha = alpha
    self.get_logger().info(
        "[Motion EMA] Loaded from motion config: "
        f"enabled={self.motion_action_ema_filter_enabled}, "
        f"alpha={self.motion_action_ema_filter_alpha:.4f}"
    )
def _reset_motion_action_ema_filter(self) -> None:
self._motion_filtered_actions_onnx = None
def _apply_motion_action_ema_filter(
self, raw_actions: np.ndarray
) -> np.ndarray:
raw_actions = np.asarray(raw_actions, dtype=np.float32).reshape(-1)
if not self.motion_action_ema_filter_enabled:
return raw_actions.copy()
if self._motion_filtered_actions_onnx is None:
self._motion_filtered_actions_onnx = raw_actions.copy()
return self._motion_filtered_actions_onnx.copy()
alpha = float(self.motion_action_ema_filter_alpha)
filtered_actions = (
alpha * raw_actions
+ (1.0 - alpha) * self._motion_filtered_actions_onnx
).astype(np.float32, copy=False)
self._motion_filtered_actions_onnx = filtered_actions.copy()
return self._motion_filtered_actions_onnx.copy()
def _build_dummy_input_from_onnx_node(self, node, fallback_last_dim: int | None = None):
    """Build a zero-filled numpy array shaped and typed like an ONNX input node.

    Each declared dim is resolved with ``_infer_onnx_dim(dim, default=1)``; a
    missing or empty shape becomes ``[1]``. For inputs of rank >= 2 whose
    declared last dim is not a positive int, ``fallback_last_dim`` (when
    given) replaces the inferred last dim. The dtype comes from
    ``_infer_numpy_dtype_from_onnx_type`` applied to the node's type string
    (``"tensor(float)"`` when absent).
    """
    declared = list(getattr(node, "shape", []) or []) or [1]
    dims = [_infer_onnx_dim(d, default=1) for d in declared]
    last_declared = declared[-1]
    needs_fallback = (
        fallback_last_dim is not None
        and len(dims) >= 2
        and not (isinstance(last_declared, int) and last_declared > 0)
    )
    if needs_fallback:
        dims[-1] = int(fallback_last_dim)
    np_dtype = _infer_numpy_dtype_from_onnx_type(getattr(node, "type", "tensor(float)"))
    return np.zeros(dims, dtype=np_dtype)
def _warmup_motion_policy(self, num_iters: int = 2) -> None:
    """Run the motion ONNX session a few times on dummy inputs to warm it up.

    Each iteration feeds:

    - a zero observation shaped from the ONNX obs input, with the motion obs
      builder's current obs length as fallback for a dynamic last dim;
    - a private zero KV cache when the model uses one;
    - the iteration number as the step-index input, when one exists.

    Only the private KV copy is threaded between iterations. Afterwards
    ``self.motion_kv_cache`` is zeroed and ``self.motion_step_idx`` is reset
    to 0. Warmup is best-effort: any exception is logged as a warning and the
    same reset is applied.

    Args:
        num_iters: number of inference calls; values below 1 are clamped to 1.
    """
    if self.motion_policy_session is None:
        return
    try:
        input_nodes = {node.name: node for node in self.motion_policy_session.get_inputs()}
        obs_node = input_nodes.get(self.motion_input_name, None)
        if obs_node is None:
            raise ValueError(
                f"Motion warmup failed to find obs input '{self.motion_input_name}'."
            )
        # Best-effort: the builder's current obs length resolves a dynamic
        # last dim of the ONNX obs input; any failure leaves it as None.
        motion_obs_dim = None
        try:
            motion_obs_dim = int(
                self.motion_obs_builder.build_policy_obs().shape[0]
            )
        except Exception:
            motion_obs_dim = None
        obs_dummy = self._build_dummy_input_from_onnx_node(
            obs_node, fallback_last_dim=motion_obs_dim
        )
        output_names = [self.motion_output_name]
        if self.motion_kv_output_name:
            output_names.append(self.motion_kv_output_name)
        # Warm up against a private KV buffer so the live cache is untouched.
        local_kv_cache = None
        if self.use_kv_cache and self.motion_kv_input_name is not None:
            if self.motion_kv_cache is not None:
                local_kv_cache = np.zeros_like(self.motion_kv_cache)
            else:
                shape = [
                    _infer_onnx_dim(dim, default=1)
                    for dim in (self.motion_kv_shape or [])
                ]
                local_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)
        for warmup_step in range(max(1, int(num_iters))):
            input_feed = {self.motion_input_name: obs_dummy}
            if self.use_kv_cache and self.motion_kv_input_name is not None:
                input_feed[self.motion_kv_input_name] = local_kv_cache
            if self.motion_step_idx_input_name is not None:
                # Match the step input's declared dtype (int64 by default).
                step_node = input_nodes.get(self.motion_step_idx_input_name, None)
                step_dtype = np.int64
                if step_node is not None:
                    step_dtype = _infer_numpy_dtype_from_onnx_type(
                        getattr(step_node, "type", "tensor(int64)")
                    )
                input_feed[self.motion_step_idx_input_name] = np.array(
                    [warmup_step], dtype=step_dtype
                )
            warmup_output = self.motion_policy_session.run(output_names, input_feed)
            # Carry the updated KV state (second requested output) forward.
            if (
                local_kv_cache is not None
                and self.motion_kv_output_name
                and len(warmup_output) > 1
            ):
                local_kv_cache = warmup_output[1]
        # Leave the live inference state clean after warmup.
        if self.motion_kv_cache is not None:
            self.motion_kv_cache.fill(0)
        self.motion_step_idx = 0
        self.get_logger().info(
            f"[Warmup] Motion policy warmup completed ({max(1, int(num_iters))} iterations, KV cache kept clean)."
        )
    except Exception as exc:
        # Warmup is optional: reset state and continue without it.
        if self.motion_kv_cache is not None:
            self.motion_kv_cache.fill(0)
        self.motion_step_idx = 0
        self.get_logger().warn(f"[Warmup] Motion policy warmup skipped: {exc}")
def update_config_parameters(self):
    """Derive action dimensions and DOF names from the loaded model configs.

    Warns when the velocity and motion configs disagree on
    ``robot.actions_dim`` or ``robot.dof_names``. The velocity config is the
    source of truth for ``actions_dim``, ``real_dof_names``,
    ``dof_names_ref_motion`` and ``num_actions``. The per-action buffers are
    then reallocated to that size.
    """
    velocity_robot = self.velocity_config.get("robot", {})
    motion_robot = self.motion_config.get("robot", {})
    velocity_actions_dim = velocity_robot.get("actions_dim", 29)
    motion_actions_dim = motion_robot.get("actions_dim", 29)
    velocity_dof_names = velocity_robot.get("dof_names", [])
    motion_dof_names = motion_robot.get("dof_names", [])
    if velocity_actions_dim != motion_actions_dim:
        self.get_logger().warn(
            f"Different actions_dim: velocity={velocity_actions_dim}, "
            f"motion={motion_actions_dim}"
        )
    if velocity_dof_names != motion_dof_names:
        self.get_logger().warn("Different dof_names between models")
        self.get_logger().warn(f"Velocity dof_names: {len(velocity_dof_names)} items")
        self.get_logger().warn(f"Motion dof_names: {len(motion_dof_names)} items")
    # The velocity model's config defines the shared basic parameters.
    self.actions_dim = velocity_actions_dim
    self.real_dof_names = velocity_dof_names
    self.dof_names_ref_motion = list(self.velocity_config.robot.dof_names)
    self.num_actions = len(self.dof_names_ref_motion)
    n = self.num_actions
    self.action_scale_onnx = np.ones(n, dtype=np.float32)
    for attr_name in ("kps_onnx", "kds_onnx", "default_angles_onnx"):
        setattr(self, attr_name, np.zeros(n, dtype=np.float32))
    self.target_dof_pos_onnx = self.default_angles_onnx.copy()
    self.actions_onnx = np.zeros(n, dtype=np.float32)
    self.get_logger().info(
        f"Updated config parameters: actions_dim={self.actions_dim}, "
        f"dof_names={len(self.real_dof_names)}"
    )
def load_motion_data(self):
    """Load every ``.npz`` motion clip from ``config_yaml.motion_clip_dir``.

    The directory is resolved relative to the ``humanoid_control`` share
    directory. Files are loaded in sorted order into ``self.all_motion_data``,
    with their names in ``self.motion_file_names``. The first clip is then
    activated via ``_load_current_motion``. A missing directory or no
    ``.npz`` files only logs a warning and returns.

    Raises:
        KeyError: if a clip lacks one of the required ``ref_*`` arrays.
    """
    motion_clips_dir = os.path.join(
        get_package_share_directory("humanoid_control"),
        self.config_yaml.motion_clip_dir,
    )
    self.get_logger().info(f"Looking for motion clip data in: {motion_clips_dir}")
    self.get_logger().info(f"Directory exists: {os.path.exists(motion_clips_dir)}")
    if not os.path.exists(motion_clips_dir):
        self.get_logger().warn(f"Motion clips directory not found: {motion_clips_dir}")
        return
    # Only collect .npz files, in a deterministic order.
    motion_clip_files = sorted(
        f for f in os.listdir(motion_clips_dir) if f.endswith(".npz")
    )
    self.get_logger().info(
        f"Found {len(motion_clip_files)} motion clip files (.npz): {motion_clip_files}"
    )
    if not motion_clip_files:
        self.get_logger().warn(
            f"No motion clip files (.npz) found in directory: {motion_clips_dir}"
        )
        return
    self.all_motion_data = []
    self.motion_file_names = []
    for motion_clip_file in motion_clip_files:
        motion_path = os.path.join(motion_clips_dir, motion_clip_file)
        # Use NpzFile as a context manager so its file handle is closed;
        # dict() reads every array eagerly before the file is closed.
        # NOTE(review): allow_pickle=True lets a crafted .npz run code on
        # load. Only load trusted clips, or switch to allow_pickle=False if
        # no clip stores object arrays.
        with np.load(motion_path, allow_pickle=True) as npz_file:
            motion_data_dict = dict(npz_file)
        self.all_motion_data.append(
            {
                "dof_pos": motion_data_dict["ref_dof_pos"],
                "dof_vel": motion_data_dict["ref_dof_vel"],
                "global_translation": motion_data_dict[
                    "ref_global_translation"
                ],
                "global_rotation_quat": motion_data_dict[
                    "ref_global_rotation_quat"
                ],
                "global_velocity": motion_data_dict["ref_global_velocity"],
                "global_angular_velocity": motion_data_dict["ref_global_angular_velocity"],
                "n_frames": motion_data_dict["ref_dof_pos"].shape[0],
            }
        )
        self.motion_file_names.append(motion_clip_file)
    if not self.all_motion_data:
        self.get_logger().error("Failed to load any motion clip files")
        return
    # Initialize with the first motion clip
    self.current_motion_clip_index = 0
    self._load_current_motion()
    self.get_logger().info(f"Loaded {len(self.all_motion_data)} motion clips successfully")
    self.get_logger().info(
        f"Current motion clip: {self.motion_file_names[self.current_motion_clip_index]}"
    )
def _load_current_motion(self):
"""Load the current selected motion clip data."""
if not self.all_motion_data:
return
self.motion_frame_idx = 0
current_motion = self.all_motion_data[self.current_motion_clip_index]
self.ref_dof_pos = current_motion["dof_pos"]
self.ref_dof_vel = current_motion["dof_vel"]
self.ref_raw_bodylink_pos = current_motion["global_translation"]
self.ref_raw_bodylink_rot = current_motion["global_rotation_quat"]
self.ref_global_velocity = current_motion["global_velocity"]
self.ref_global_angular_velocity = current_motion["global_angular_velocity"]
self.n_motion_frames = current_motion["n_frames"]
if self.ref_dof_pos is None or self.ref_dof_vel is None:
raise ValueError("Motion clip is missing ref_dof_pos/ref_dof_vel arrays")
if self.ref_raw_bodylink_pos is None or self.ref_raw_bodylink_rot is None:
raise ValueError(
"Motion clip is missing ref_global_translation/ref_global_rotation_quat arrays"
)
if int(self.ref_dof_pos.shape[1]) != int(len(self.dof_names_ref_motion)):
raise ValueError(
"ref_dof_pos DOF dimension mismatch: "
f"ref_dof_pos.shape[1]={int(self.ref_dof_pos.shape[1])} "
f"but len(dof_names_ref_motion)={int(len(self.dof_names_ref_motion))}"
)
if int(self.ref_raw_bodylink_pos.shape[1]) != int(
len(self.motion_config.robot.body_names)
):
raise ValueError(
"ref_global_translation body dimension mismatch: "
f"ref_raw_bodylink_pos.shape[1]={int(self.ref_raw_bodylink_pos.shape[1])} "
f"but len(motion_config.robot.body_names)={int(len(self.motion_config.robot.body_names))}"
)
self.motion_in_progress = True
self.get_logger().info(
f"Loaded motion clip {self.current_motion_clip_index}: "
f"{self.motion_file_names[self.current_motion_clip_index]} ({self.n_motion_frames} frames)"
)
def _setup_subscribers(self):
    """Create the remote controller and the ROS2 input subscriptions.

    Subscriptions, all with ``QoSProfile(depth=10)``:

    - ``config_yaml.lowstate_topic`` (LowState) -> ``_low_state_callback``
    - ``/robot_state`` (String) -> ``_robot_state_callback``
    - ``latest_obs_ros`` (Float32MultiArray) -> ``_latest_obs_ros_callback``
    """
    self.remote_controller = RemoteController()

    def _subscribe(msg_type, topic, callback):
        return self.create_subscription(msg_type, topic, callback, QoSProfile(depth=10))

    self.low_state_sub = _subscribe(
        LowState, self.config_yaml.lowstate_topic, self._low_state_callback
    )
    self.robot_state_sub = _subscribe(String, "/robot_state", self._robot_state_callback)
    self.latest_obs_ros_sub = _subscribe(
        Float32MultiArray, "latest_obs_ros", self._latest_obs_ros_callback
    )
def _latest_obs_ros_callback(self, msg: Float32MultiArray):
    """Buffer a replayed ``latest_obs_ros`` message for offline validation.

    A 66-element payload is read as ``[frame_idx, obs(65)]`` and stored as
    ``(int(frame_idx), obs)``. Any other payload with at least 65 elements is
    stored as ``(None, first_65_values)``. Shorter payloads are ignored.
    """
    payload = np.asarray(msg.data, dtype=np.float32)
    count = payload.size
    if count == 66:
        self._ros_latest_obs_buffer = (int(payload[0]), payload[1:66])
    elif count >= 65:
        self._ros_latest_obs_buffer = (None, payload[:65])
def _setup_publishers(self):
    """Create the ROS2 output publishers, all with ``QoSProfile(depth=10)``.

    - ``action_pub``: Float32MultiArray on ``config_yaml.action_topic``
    - ``kps_pub`` / ``kds_pub``: Float32MultiArray on ``/humanoid/kps`` and
      ``/humanoid/kds``
    - ``policy_mode_pub``: String on ``policy_mode``
    - ``latest_obs_pub``: Float32MultiArray on ``latest_obs``
    """
    def _advertise(msg_type, topic):
        return self.create_publisher(msg_type, topic, QoSProfile(depth=10))

    self.action_pub = _advertise(Float32MultiArray, self.config_yaml.action_topic)
    self.kps_pub = _advertise(Float32MultiArray, "/humanoid/kps")
    self.kds_pub = _advertise(Float32MultiArray, "/humanoid/kds")
    self.policy_mode_pub = _advertise(String, "policy_mode")
    self.latest_obs_pub = _advertise(Float32MultiArray, "latest_obs")
def _setup_timers(self):
"""Set up ROS2 timer for main execution loop."""
# Create a one-time timer to call setup after ROS2 initialization
self.create_timer(0.1, self._delayed_setup)
self.create_timer(self.dt, self.run)
def _delayed_setup(self):
"""Call setup after ROS2 initialization is complete."""
if not hasattr(self, '_setup_completed'):
self.get_logger().info("Starting policy node setup...")
try:
self.setup()
self._setup_completed = True
self.get_logger().info("Policy node setup completed successfully")
except Exception as e:
self.get_logger().error(f"Policy node setup failed: {e}")
# Cancel the timer to avoid repeated attempts
return
def _robot_state_callback(self, msg: String):
    """Track whether the main control node reports a state that permits button operations.

    ``MOVE_TO_DEFAULT`` sets ``self.robot_state_ready`` to True.
    ``ZERO_TORQUE`` and ``EMERGENCY_STOP`` set it to False. Any other state
    (e.g. ``POLICY``) leaves the flag unchanged.

    Args:
        msg: String message whose ``data`` is the robot state name.
    """
    robot_state = msg.data
    if robot_state == "MOVE_TO_DEFAULT":
        self.robot_state_ready = True
    elif robot_state in ("ZERO_TORQUE", "EMERGENCY_STOP"):
        self.robot_state_ready = False
# =========== Properties ===========
@property
def robot_root_rot_quat_wxyz(self):
    """IMU orientation quaternion from the latest LowState, as a float32 array.

    Components are taken in the order of ``imu_state.quaternion``. Raises
    AttributeError while no LowState has been received (``_lowstate_msg`` is None).
    """
    imu = self._lowstate_msg.imu_state
    return np.array(imu.quaternion, dtype=np.float32)
@property
def robot_root_ang_vel(self):
    """IMU gyroscope reading from the latest LowState, as a float32 array."""
    imu = self._lowstate_msg.imu_state
    return np.array(imu.gyroscope, dtype=np.float32)
@property
def robot_dof_pos_by_name(self):
    """Map each of the first ``actions_dim`` real DOF names to its measured position ``q``.

    Returns an empty dict until a LowState message has been received.
    """
    msg = self._lowstate_msg
    if msg is None:
        return {}
    positions = {}
    for i in range(self.actions_dim):
        dof_name = self.real_dof_names[i]
        positions[dof_name] = float(msg.motor_state[i].q)
    return positions
@property
def robot_dof_vel_by_name(self):
    """Map each of the first ``actions_dim`` real DOF names to its measured velocity ``dq``.

    Returns an empty dict until a LowState message has been received.
    """
    msg = self._lowstate_msg
    if msg is None:
        return {}
    velocities = {}
    for i in range(self.actions_dim):
        dof_name = self.real_dof_names[i]
        velocities[dof_name] = float(msg.motor_state[i].dq)
    return velocities
@property
def ref_motion_frame_idx(self):
    """Current motion frame index, clamped so it never passes the clip's last frame."""
    last_frame = self.n_motion_frames - 1
    return min(self.motion_frame_idx, last_frame)
@property
def ref_dof_pos_raw(self):
    """Reference joint positions for the current step, in reference-motion order.

    Without external obs (``latest_obs_flag`` False), the loaded clip frame is
    used. With future-frame queues, the previous external frame is used when
    set, otherwise the head of the future queue. Otherwise columns ``[:29]``
    of the latest external obs are used, falling back to the clip frame when
    none has arrived.
    """
    if self.latest_obs_flag:
        if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:
            prev = self._prev_external_dof_pos
            return self.external_fut_dof_pos_queue[0] if prev is None else prev
        if self.external_latest_obs is not None:
            return self.external_latest_obs[0, :29]
    return self.ref_dof_pos[self.ref_motion_frame_idx]
@property
def ref_dof_vel_raw(self):
    """Reference joint velocities for the current step, in reference order.

    Source priority mirrors ``ref_dof_pos_raw``:

    * teleop inactive (``latest_obs_flag`` false): the clip frame
      ``ref_dof_vel[ref_motion_frame_idx]``;
    * teleop with a future queue: the held previous external velocity if
      set, otherwise the head of ``external_fut_dof_vel_queue``;
    * otherwise columns 29:58 of ``external_latest_obs``, or the clip frame
      if no external observation has arrived.

    The queue branch is gated on the DOF-position queue, as in
    ``ref_dof_pos_raw``. It also requires the velocity queue to exist;
    otherwise it falls through to the later sources instead of indexing
    ``None``.
    """
    if not self.latest_obs_flag:
        return self.ref_dof_vel[self.ref_motion_frame_idx]
    if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:
        if self._prev_external_dof_vel is not None:
            return self._prev_external_dof_vel
        vel_queue = getattr(self, "external_fut_dof_vel_queue", None)
        if vel_queue is not None:
            return vel_queue[0]
    if self.external_latest_obs is None:
        return self.ref_dof_vel[self.ref_motion_frame_idx]
    return self.external_latest_obs[0, 29:58]
@property
def ref_dof_pos_onnx_order(self):
    """Current reference joint positions permuted into the ONNX policy's joint order via ``ref_to_onnx``."""
    return self.ref_dof_pos_raw[self.ref_to_onnx]
@property
def ref_dof_vel_onnx_order(self):
    """Current reference joint velocities permuted into the ONNX policy's joint order via ``ref_to_onnx``."""
    return self.ref_dof_vel_raw[self.ref_to_onnx]
@property
def ref_root_pos_raw(self):
    """Reference root position for the current step, as float32 (3,).

    Without teleop, this is the root body of the current clip frame. With
    teleop and a future queue, it is the held previous external root
    position if set, otherwise the queue head. Otherwise it is columns 58:61
    of ``external_latest_obs``, or zeros if nothing has arrived yet.
    """
    if self.latest_obs_flag:
        if self.n_fut_frames > 0 and self.external_fut_root_pos_queue is not None:
            held = self._prev_external_root_pos
            source = self.external_fut_root_pos_queue[0] if held is None else held
            return source.astype(np.float32)
        latest = self.external_latest_obs
        if latest is None:
            return np.zeros(3, dtype=np.float32)
        return latest[0, 58:61].astype(np.float32)
    frame = self.ref_motion_frame_idx
    return np.asarray(self.ref_raw_bodylink_pos[frame, self.root_body_idx], dtype=np.float32)
@property
def root_body_idx(self):
    """Index of the root body in the per-body reference/FK arrays (always 0)."""
    return 0
@property
def last_valid_ref_motion_frame_idx(self):
    """Index of the final frame of the loaded reference clip (``n_motion_frames - 1``)."""
    return self.n_motion_frames - 1
# =========== Policy Observation Methods ===========
def _xyzw_to_wxyz(self, q_xyzw: np.ndarray) -> np.ndarray:
"""Convert quaternions from xyzw to wxyz order."""
q_xyzw = np.asarray(q_xyzw, dtype=np.float32)
if q_xyzw.shape[-1] != 4:
raise ValueError(f"_xyzw_to_wxyz expects (...,4) but got shape {q_xyzw.shape}")
# q_xyzw[..., 0:3] -> xyz, q_xyzw[..., 3:4] -> w
w = q_xyzw[..., 3:4]
xyz = q_xyzw[..., 0:3]
return np.concatenate([w, xyz], axis=-1)
def _standardize_quaternion_wxyz(self, q_wxyz: np.ndarray) -> np.ndarray:
"""Standardize quaternion sign so that w >= 0."""
q_wxyz = np.asarray(q_wxyz, dtype=np.float32)
if q_wxyz.shape[-1] != 4:
raise ValueError(f"_standardize_quaternion_wxyz expects (...,4) but got shape {q_wxyz.shape}")
mask = q_wxyz[..., 0:1] < 0.0
q_wxyz = np.where(mask, -q_wxyz, q_wxyz)
return q_wxyz
def _quat_rotate_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray:
q_wxyz = np.asarray(q_wxyz, dtype=np.float32)
v = np.asarray(v, dtype=np.float32)
qvec = q_wxyz[..., 1:4]
w = q_wxyz[..., 0:1]
t = 2.0 * np.cross(qvec, v)
return v + w * t + np.cross(qvec, t)
def _quat_rotate_inv_wxyz(self, q_wxyz: np.ndarray, v: np.ndarray) -> np.ndarray:
q_wxyz = np.asarray(q_wxyz, dtype=np.float32)
n = int(np.prod(q_wxyz.shape[:-1])) if q_wxyz.ndim > 1 else 1
q_conj = self._q_conj_buffer[:n].reshape(q_wxyz.shape)
q_conj[..., 0] = q_wxyz[..., 0]
q_conj[..., 1:4] = -q_wxyz[..., 1:4]
return self._quat_rotate_wxyz(q_conj, v)
def _quat_rotate_inv_wxyz_single(
self, q_wxyz: np.ndarray, v: np.ndarray, out: np.ndarray
) -> np.ndarray:
"""Rotate one 3D vector by the inverse quaternion into a preallocated output."""
q_conj = self._q_conj_buffer[0]
q_conj[0] = q_wxyz[0]
q_conj[1] = -q_wxyz[1]
q_conj[2] = -q_wxyz[2]
q_conj[3] = -q_wxyz[3]
qvec = q_conj[1:4]
w = q_conj[0]
self._cross_t_buffer[:] = np.cross(qvec, v)
self._cross_t_buffer *= 2.0
out[:] = v + w * self._cross_t_buffer
self._cross_t_buffer[:] = np.cross(qvec, self._cross_t_buffer)
out += self._cross_t_buffer
return out
def _get_future_frame_indices(self) -> np.ndarray:
frame_idx = self.ref_motion_frame_idx
last_valid = self.last_valid_ref_motion_frame_idx
np.minimum(
frame_idx + self._future_frame_offsets,
last_valid,
out=self._future_frame_indices_buffer,
)
return self._future_frame_indices_buffer
def _cache_fk_vr_for_obs(self):
    """Cache FK outputs used repeatedly during observation construction.

    When teleop latest-obs mode is active and ``self._fk_vr_out`` (a dict of
    arrays) exists, this copies the root-body velocity, angular velocity and
    rotation for frame 0 and for frames 1..T (T = ``n_fut_frames_int``) into
    preallocated buffers, along with all-body translations. It then
    precomputes the root-frame future velocities via
    ``_fill_vr_base_linvel_angvel_fut``. ``_use_fk_vr`` records whether these
    caches are valid; the ``_get_obs_ref_*`` getters check it.

    NOTE(review): arrays are indexed ``[0, frame, body, ...]``. This
    presumably matches the (batch=1, frames, ...) layout filled by
    ``_prepare_vr_fk_tensors`` (frame 0 = current, 1..T = future). Confirm
    against the FK module.
    """
    fk = getattr(self, "_fk_vr_out", None)
    if not getattr(self, "latest_obs_flag", False) or fk is None:
        # No usable FK result: the observation getters use their other sources.
        self._use_fk_vr = False
        return
    self._use_fk_vr = True
    T = self.n_fut_frames_int
    rb = self.root_body_idx
    # Current frame (index 0), root body.
    np.copyto(self._fk_vel_0_root, fk["global_velocity"][0, 0, rb])
    np.copyto(self._fk_angvel_0_root, fk["global_angular_velocity"][0, 0, rb])
    np.copyto(self._fk_quat_0_root, fk["global_rotation_quat"][0, 0, rb])
    # FK quaternions are treated as xyzw (component 3 is w): reorder to wxyz
    # and flip the sign so that w >= 0.
    self._fk_quat_0_root_wxyz[0] = self._fk_quat_0_root[3]
    self._fk_quat_0_root_wxyz[1:4] = self._fk_quat_0_root[:3]
    if self._fk_quat_0_root_wxyz[0] < 0.0:
        self._fk_quat_0_root_wxyz *= -1.0
    trans_0 = fk["global_translation"][0, 0]
    # Allocate the all-body translation cache lazily / when its shape changes.
    if self._fk_trans_0 is None or self._fk_trans_0.shape != trans_0.shape:
        self._fk_trans_0 = np.empty_like(trans_0)
    np.copyto(self._fk_trans_0, trans_0)
    if T > 0:
        # Future frames 1..T, root body.
        np.copyto(self._fk_vel_fut[:T], fk["global_velocity"][0, 1 : 1 + T, rb])
        np.copyto(self._fk_angvel_fut[:T], fk["global_angular_velocity"][0, 1 : 1 + T, rb])
        np.copyto(self._fk_quat_fut[:T], fk["global_rotation_quat"][0, 1 : 1 + T, rb])
        self._fk_quat_fut_wxyz[:T, 0] = self._fk_quat_fut[:T, 3]
        self._fk_quat_fut_wxyz[:T, 1:4] = self._fk_quat_fut[:T, :3]
        neg = self._fk_quat_fut_wxyz[:T, 0] < 0.0
        # Masked in-place update on the [:T] view writes through to the buffer.
        self._fk_quat_fut_wxyz[:T][neg] *= -1.0
        trans_fut = fk["global_translation"][0, 1 : 1 + T]
        if self._fk_trans_fut is None or self._fk_trans_fut.shape != trans_fut.shape:
            self._fk_trans_fut = np.empty_like(trans_fut)
        np.copyto(self._fk_trans_fut, trans_fut)
    self._fill_vr_base_linvel_angvel_fut()
def _fill_vr_base_linvel_angvel_fut(self):
"""Rotate future linear and angular velocity buffers in one pass."""
T = self.n_fut_frames_int
if T <= 0:
return
vel_T6 = self._vel_fut_T6[:T]
vel_T6[:, :3] = self._fk_vel_fut[:T]
vel_T6[:, 3:6] = self._fk_angvel_fut[:T]
q = self._fk_quat_fut_wxyz[:T]
q_conj = self._q_conj_buffer[:T].reshape(T, 4)
q_conj[:, 0] = q[:, 0]
q_conj[:, 1:4] = -q[:, 1:4]
qvec = q_conj[:, 1:4]
w = q_conj[:, 0:1]
rt = self._rot_t_buffer[:T]
rc = self._rot_cross_buffer[:T]
rt[:] = np.cross(qvec, vel_T6[:, :3])
rt *= 2.0
rc[:] = np.cross(qvec, rt)
self._base_linvel_fut_buffer[:T] = vel_T6[:, :3] + w * rt + rc
rt[:] = np.cross(qvec, vel_T6[:, 3:6])
rt *= 2.0
rc[:] = np.cross(qvec, rt)
self._base_angvel_fut_buffer[:T] = vel_T6[:, 3:6] + w * rt + rc
def _prepare_vr_fk_tensors(
self,
cur_root_pos: np.ndarray,
cur_root_rot: np.ndarray,
cur_dof_pos: np.ndarray,
n_fut: int,
):
"""Fill preallocated FK input buffers and return torch views without reallocation."""
if (
n_fut <= 0
or self._fk_root_pos_seq_np is None
or self._fk_root_rot_seq_np is None
or self._fk_dof_pos_seq_np is None
):
raise ValueError("VR FK sequence buffers are not initialized")
np.copyto(self._fk_root_pos_seq_np[0, 0], cur_root_pos)
np.copyto(self._fk_root_rot_seq_np[0, 0], cur_root_rot)
np.copyto(self._fk_dof_pos_seq_np[0, 0], cur_dof_pos)
np.copyto(
self._fk_root_pos_seq_np[0, 1 : 1 + n_fut],
self.external_fut_root_pos_queue[:n_fut],
)
np.copyto(
self._fk_root_rot_seq_np[0, 1 : 1 + n_fut],
self.external_fut_root_rot_queue[:n_fut],
)
np.copyto(
self._fk_dof_pos_seq_np[0, 1 : 1 + n_fut],
self.external_fut_dof_pos_queue[:n_fut],
)
return (
self._fk_root_pos_seq_tensor,
self._fk_root_rot_seq_tensor,
self._fk_dof_pos_seq_tensor,
)
def _get_future_root_quat_wxyz(self) -> np.ndarray:
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable; future_root_quat_wxyz will return zeros."
)
return self._future_root_quat_wxyz_buffer
fut_idx = self._get_future_frame_indices()
q_root_xyzw = np.asarray(
self.ref_raw_bodylink_rot[fut_idx, self.root_body_idx],
dtype=np.float32,
)
q_root_wxyz = self._future_root_quat_wxyz_buffer
q_root_wxyz[:, 0] = q_root_xyzw[:, 3]
q_root_wxyz[:, 1] = q_root_xyzw[:, 0]
q_root_wxyz[:, 2] = q_root_xyzw[:, 1]
q_root_wxyz[:, 3] = q_root_xyzw[:, 2]
neg_mask = q_root_wxyz[:, 0] < 0.0
q_root_wxyz[neg_mask] *= -1.0
return self._future_root_quat_wxyz_buffer
def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray:
keybody_idxs = self._keybody_indices_by_term_name.get(term_name, None)
if keybody_idxs is None:
raise ValueError(
f"Keybody indices for term '{term_name}' were not cached. "
"Ensure the term exists in motion policy obs and cache is initialized."
)
return keybody_idxs
# Actor-prefixed aliases: each ``_get_obs_actor_<term>`` simply forwards to
# ``_get_obs_<term>``. Presumably the observation builder resolves getters by
# the ``actor_<term>`` names used in the policy config (cf. the
# "actor_ref_keybody_rel_pos_*" term names used below) -- verify against the
# caller before renaming any of them.
def _get_obs_actor_velocity_command(self):
    return self._get_obs_velocity_command()
def _get_obs_actor_projected_gravity(self):
    return self._get_obs_projected_gravity()
def _get_obs_actor_rel_robot_root_ang_vel(self):
    return self._get_obs_rel_robot_root_ang_vel()
def _get_obs_actor_dof_pos(self):
    return self._get_obs_dof_pos()
def _get_obs_actor_dof_vel(self):
    return self._get_obs_dof_vel()
def _get_obs_actor_last_action(self):
    return self._get_obs_last_action()
def _get_obs_actor_ref_gravity_projection_cur(self):
    return self._get_obs_ref_gravity_projection_cur()
def _get_obs_actor_ref_gravity_projection_fut(self):
    return self._get_obs_ref_gravity_projection_fut()
def _get_obs_actor_ref_base_linvel_cur(self):
    return self._get_obs_ref_base_linvel_cur()
def _get_obs_actor_ref_base_linvel_fut(self):
    return self._get_obs_ref_base_linvel_fut()
def _get_obs_actor_ref_base_angvel_cur(self):
    return self._get_obs_ref_base_angvel_cur()
def _get_obs_actor_ref_base_angvel_fut(self):
    return self._get_obs_ref_base_angvel_fut()
def _get_obs_actor_ref_dof_pos_cur(self):
    return self._get_obs_ref_dof_pos_cur()
def _get_obs_actor_ref_dof_pos_fut(self):
    return self._get_obs_ref_dof_pos_fut()
def _get_obs_actor_ref_root_height_cur(self):
    return self._get_obs_ref_root_height_cur()
def _get_obs_actor_ref_root_height_fut(self):
    return self._get_obs_ref_root_height_fut()
def _get_obs_actor_ref_keybody_rel_pos_cur(self):
    return self._get_obs_ref_keybody_rel_pos_cur()
def _get_obs_actor_ref_keybody_rel_pos_fut(self):
    return self._get_obs_ref_keybody_rel_pos_fut()
def _get_obs_velocity_command(self):
"""Get velocity command observation (reuses pre-allocated array)."""
self._velocity_cmd_obs[1] = self.vx
self._velocity_cmd_obs[2] = self.vy
self._velocity_cmd_obs[3] = self.vyaw
self._velocity_cmd_obs[0] = float(
np.linalg.norm(self._velocity_cmd_obs[1:4]) > 0.1
)
return self._velocity_cmd_obs
def _get_obs_projected_gravity(self):
    """Projected gravity computed by ``get_gravity_orientation`` from the robot IMU quaternion."""
    return get_gravity_orientation(self.robot_root_rot_quat_wxyz)
def _get_obs_rel_robot_root_ang_vel(self):
    """Robot root angular velocity: the raw IMU gyroscope reading (float32)."""
    return self.robot_root_ang_vel
def _get_obs_dof_pos(self):
"""Get DOF position observation (pre-allocated buffer + index lookup, no dict/list)."""
if self._lowstate_msg is None:
return self._dof_pos_obs_buffer[: len(self.motion_dof_names_onnx)]
if self.current_policy_mode == "motion":
buf = self._dof_pos_obs_buffer
ms = self._lowstate_msg.motor_state
def_angles = self.motion_default_angles_onnx
for i, ri in enumerate(self.motion_dof_real_indices):
buf[i] = ms[ri].q - def_angles[i]
return buf[: len(self.motion_dof_names_onnx)]
def_angles = self.velocity_default_angles_onnx
for i, ri in enumerate(self.velocity_dof_real_indices):
self._dof_pos_obs_buffer[i] = (
self._lowstate_msg.motor_state[ri].q - def_angles[i]
)
return self._dof_pos_obs_buffer[: len(self.velocity_dof_names_onnx)]
def _get_obs_dof_vel(self):
"""Get DOF velocity observation (pre-allocated buffer + index lookup, no dict/list)."""
if self._lowstate_msg is None:
return self._dof_vel_obs_buffer[: len(self.motion_dof_names_onnx)]
if self.current_policy_mode == "motion":
buf = self._dof_vel_obs_buffer
ms = self._lowstate_msg.motor_state
for i, ri in enumerate(self.motion_dof_real_indices):
buf[i] = ms[ri].dq
return buf[: len(self.motion_dof_names_onnx)]
for i, ri in enumerate(self.velocity_dof_real_indices):
self._dof_vel_obs_buffer[i] = self._lowstate_msg.motor_state[ri].dq
return self._dof_vel_obs_buffer[: len(self.velocity_dof_names_onnx)]
def _get_obs_last_action(self):
    """Copy of the previous policy action in ONNX order (copied so callers cannot mutate ``actions_onnx``)."""
    return self.actions_onnx.copy()
def _get_obs_ref_motion_states(self):
return np.concatenate(
[self.ref_dof_pos_onnx_order, self.ref_dof_vel_onnx_order]
)
def _get_obs_ref_dof_pos_fut(self):
"""Get future DOF position observation (reuses pre-allocated buffer)."""
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if getattr(self, "latest_obs_flag", False):
if (
getattr(self, "external_fut_dof_pos_queue", None) is not None
and self.external_fut_dof_pos_queue.shape[0] >= T
):
pos_fut = self._pos_fut_buffer
pos_fut[:, :] = self.external_fut_dof_pos_queue[:T].T
pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T]
return pos_fut_onnx.reshape(-1).astype(np.float32)
return np.zeros(self.num_actions * T, dtype=np.float32)
if not hasattr(self, "ref_dof_pos") or self.ref_dof_pos is None:
self.get_logger().warn(
"[VR] ref_dof_pos is unavailable and latest_obs is not active; returning zeros for ref_dof_pos_fut."
)
return np.zeros(self.num_actions * T, dtype=np.float32)
fut_idx = self._get_future_frame_indices()
pos_fut = self._pos_fut_buffer
pos_fut[:, :] = self.ref_dof_pos[fut_idx].T
# Reorder to ONNX and flatten per training layout
pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T]
return pos_fut_onnx.reshape(-1).astype(np.float32)
def _get_obs_ref_root_height_fut(self):
"""Get future root height observation (reuses pre-allocated buffer)."""
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if self.latest_obs_flag and self.external_fut_root_pos_queue is not None:
root_pos_fut = self.external_fut_root_pos_queue[:, 2].astype(np.float32)
return root_pos_fut.reshape(-1)
if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_height_fut."
)
return np.zeros(T, dtype=np.float32)
fut_idx = self._get_future_frame_indices()
h_fut = self._h_fut_buffer
h_fut[0, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, 2]
return h_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_root_pos_fut(self):
"""Get future root position observation (reuses pre-allocated buffer)."""
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if self.latest_obs_flag and self.external_fut_root_pos_queue is not None:
pos_fut = self.external_fut_root_pos_queue.astype(np.float32)
return pos_fut.reshape(-1).astype(np.float32)
if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_root_pos_fut."
)
return np.zeros(3 * T, dtype=np.float32)
fut_idx = self._get_future_frame_indices()
pos_fut = self._root_pos_fut_buffer
pos_fut[:, :] = self.ref_raw_bodylink_pos[fut_idx, self.root_body_idx, :]
return pos_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_dof_pos_cur(self):
    """Current reference joint positions in ONNX joint order."""
    return self.ref_dof_pos_onnx_order
def _get_obs_ref_dof_vel_cur(self):
    """Current reference joint velocities in ONNX joint order."""
    return self.ref_dof_vel_onnx_order
def _get_obs_ref_root_height_cur(self):
if not self.latest_obs_flag:
return self.ref_raw_bodylink_pos[
self.ref_motion_frame_idx, self.root_body_idx, 2
]
return float(self.ref_root_pos_raw[2])
def _get_obs_ref_root_pos_cur(self):
    """Current reference root position as a fresh float32 copy (``astype`` always copies)."""
    return self.ref_root_pos_raw.astype(np.float32)
def _get_obs_ref_gravity_projection_cur(self):
    """Projected gravity for the current reference root orientation.

    The root quaternion is taken from, in order: the cached FK root
    quaternion (VR + FK), columns 61:65 of ``external_latest_obs`` (wxyz,
    sign-standardized), or the current clip frame's root rotation (xyzw,
    converted and standardized). The result is passed through
    ``get_gravity_orientation``. Returns zeros(3) with a warning when no
    source is available.
    """
    if getattr(self, "_use_fk_vr", False):
        return get_gravity_orientation(self._fk_quat_0_root_wxyz)
    teleop_active = getattr(self, "latest_obs_flag", False)
    latest = getattr(self, "external_latest_obs", None)
    if teleop_active and latest is not None:
        root_quat = self._standardize_quaternion_wxyz(latest[0, 61:65].astype(np.float32))
        return get_gravity_orientation(root_quat)
    clip_rot = getattr(self, "ref_raw_bodylink_rot", None)
    if clip_rot is None:
        self.get_logger().warn(
            "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for gravity_projection_cur."
        )
        return np.zeros(3, dtype=np.float32)
    root_xyzw = clip_rot[self.ref_motion_frame_idx, self.root_body_idx]
    root_quat = self._standardize_quaternion_wxyz(self._xyzw_to_wxyz(root_xyzw))
    return get_gravity_orientation(root_quat)
def _get_obs_ref_gravity_projection_fut(self):
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if getattr(self, "_use_fk_vr", False):
q_root_wxyz = self._fk_quat_fut_wxyz[:T]
gravity_fut = self._gravity_fut_buffer
qw = q_root_wxyz[:, 0]
qx = q_root_wxyz[:, 1]
qy = q_root_wxyz[:, 2]
qz = q_root_wxyz[:, 3]
gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy)
gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx)
gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz)
return gravity_fut.reshape(-1).astype(np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_gravity_projection_fut."
)
return np.zeros(3 * T, dtype=np.float32)
q_root_wxyz = self._get_future_root_quat_wxyz()
gravity_fut = self._gravity_fut_buffer
qw = q_root_wxyz[:, 0]
qx = q_root_wxyz[:, 1]
qy = q_root_wxyz[:, 2]
qz = q_root_wxyz[:, 3]
gravity_fut[:, 0] = 2.0 * (-qz * qx + qw * qy)
gravity_fut[:, 1] = -2.0 * (qz * qy + qw * qx)
gravity_fut[:, 2] = 1.0 - 2.0 * (qw * qw + qz * qz)
return gravity_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_base_linvel_cur(self):
if getattr(self, "_use_fk_vr", False):
self._quat_rotate_inv_wxyz_single(
self._fk_quat_0_root_wxyz, self._fk_vel_0_root, self._rotated_3vec_buffer
)
return self._rotated_3vec_buffer
if getattr(self, "latest_obs_flag", False) and getattr(
self, "external_latest_obs", None
) is not None:
return np.zeros(3, dtype=np.float32)
if not hasattr(self, "ref_global_velocity") or self.ref_global_velocity is None:
self.get_logger().warn(
"[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur."
)
return np.zeros(3, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_cur."
)
return np.zeros(3, dtype=np.float32)
q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx]
q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)
q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)
v_root_w = np.asarray(
self.ref_global_velocity[self.ref_motion_frame_idx, self.root_body_idx],
dtype=np.float32,
)
v_root = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w)
return np.asarray(v_root, dtype=np.float32).reshape(3)
def _get_obs_ref_base_linvel_fut(self):
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if getattr(self, "_use_fk_vr", False):
return self._base_linvel_fut_buffer[:T].reshape(-1).astype(np.float32)
if not hasattr(self, "ref_global_velocity") or self.ref_global_velocity is None:
self.get_logger().warn(
"[VR] ref_global_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut."
)
return np.zeros(3 * T, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_linvel_fut."
)
return np.zeros(3 * T, dtype=np.float32)
fut_idx = self._get_future_frame_indices()
q_root_wxyz = self._get_future_root_quat_wxyz()
v_root_w = np.asarray(
self.ref_global_velocity[fut_idx, self.root_body_idx],
dtype=np.float32,
)
base_linvel_fut = self._base_linvel_fut_buffer
base_linvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, v_root_w)
return base_linvel_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_base_angvel_cur(self):
if getattr(self, "_use_fk_vr", False):
self._quat_rotate_inv_wxyz_single(
self._fk_quat_0_root_wxyz,
self._fk_angvel_0_root,
self._rotated_angvel_cur_buffer,
)
return self._rotated_angvel_cur_buffer
if getattr(self, "latest_obs_flag", False) and getattr(
self, "external_latest_obs", None
) is not None:
return np.zeros(3, dtype=np.float32)
if not hasattr(self, "ref_global_angular_velocity") or self.ref_global_angular_velocity is None:
self.get_logger().warn(
"[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur."
)
return np.zeros(3, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_cur."
)
return np.zeros(3, dtype=np.float32)
q_root_xyzw = self.ref_raw_bodylink_rot[self.ref_motion_frame_idx, self.root_body_idx]
q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)
q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)
w_root_w = np.asarray(
self.ref_global_angular_velocity[self.ref_motion_frame_idx, self.root_body_idx],
dtype=np.float32,
)
w_root = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w)
return np.asarray(w_root, dtype=np.float32).reshape(3)
def _get_obs_ref_base_angvel_fut(self):
T = self.n_fut_frames_int
if T <= 0:
return np.zeros(0, dtype=np.float32)
if getattr(self, "_use_fk_vr", False):
return self._base_angvel_fut_buffer[:T].reshape(-1).astype(np.float32)
if not hasattr(self, "ref_global_angular_velocity") or self.ref_global_angular_velocity is None:
self.get_logger().warn(
"[VR] ref_global_angular_velocity is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut."
)
return np.zeros(3 * T, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_base_angvel_fut."
)
return np.zeros(3 * T, dtype=np.float32)
fut_idx = self._get_future_frame_indices()
q_root_wxyz = self._get_future_root_quat_wxyz()
w_root_w = np.asarray(
self.ref_global_angular_velocity[fut_idx, self.root_body_idx],
dtype=np.float32,
)
base_angvel_fut = self._base_angvel_fut_buffer
base_angvel_fut[:, :] = self._quat_rotate_inv_wxyz(q_root_wxyz, w_root_w)
return base_angvel_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_keybody_rel_pos_cur(self):
if getattr(self, "_use_fk_vr", False) and self._fk_trans_0 is not None:
keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur")
n_keybodies = int(keybody_idxs.shape[0])
if n_keybodies == 0:
return np.zeros(0, dtype=np.float32)
if not self._root_only_fk_has_required_keybodies(keybody_idxs):
return np.zeros(3 * n_keybodies, dtype=np.float32)
root_pos = self._fk_trans_0[self.root_body_idx]
keybody_pos = self._fk_trans_0[keybody_idxs]
rel_pos_w = keybody_pos - root_pos[None, :]
rel_pos_root = self._quat_rotate_inv_wxyz(self._fk_quat_0_root_wxyz, rel_pos_w)
return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1)
if getattr(self, "latest_obs_flag", False) and getattr(
self, "external_latest_obs", None
) is not None:
keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur")
n_keybodies = int(keybody_idxs.shape[0])
if n_keybodies == 0:
return np.zeros(0, dtype=np.float32)
return np.zeros(3 * n_keybodies, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur."
)
keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur")
n_keybodies = int(keybody_idxs.shape[0])
if n_keybodies == 0:
return np.zeros(0, dtype=np.float32)
return np.zeros(3 * n_keybodies, dtype=np.float32)
if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
self.get_logger().warn(
"[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_cur."
)
keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur")
n_keybodies = int(keybody_idxs.shape[0])
if n_keybodies == 0:
return np.zeros(0, dtype=np.float32)
return np.zeros(3 * n_keybodies, dtype=np.float32)
keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_cur")
n_keybodies = int(keybody_idxs.shape[0])
if n_keybodies == 0:
return np.zeros(0, dtype=np.float32)
frame_idx = self.ref_motion_frame_idx
ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[frame_idx], dtype=np.float32)
ref_root_global_pos = ref_body_global_pos[self.root_body_idx]
q_root_xyzw = self.ref_raw_bodylink_rot[frame_idx, self.root_body_idx]
q_root_wxyz = self._xyzw_to_wxyz(q_root_xyzw)
q_root_wxyz = self._standardize_quaternion_wxyz(q_root_wxyz)
rel_pos_w = ref_body_global_pos[keybody_idxs] - ref_root_global_pos[None, :]
rel_pos_root = self._quat_rotate_inv_wxyz(q_root_wxyz, rel_pos_w)
return np.asarray(rel_pos_root, dtype=np.float32).reshape(-1)
def _get_obs_ref_keybody_rel_pos_fut(self):
    """Future reference key-body positions relative to the root, in each frame's root frame.

    The output is flattened from (T, K, 3) to (3*K*T,) float32, where
    T = ``n_fut_frames_int`` and K is the number of key bodies cached under
    ``actor_ref_keybody_rel_pos_fut``. With VR + FK, cached FK translations
    and quaternions are used; this yields zeros if the root-only FK lacks
    the required bodies. Otherwise clip frames from
    ``_get_future_frame_indices`` are used. Missing clip arrays log a
    warning and yield zeros.

    NOTE(review): the scratch buffers are only reallocated when K changes;
    their first dimension is assumed to already equal T. Confirm where they
    are created.
    """
    T = self.n_fut_frames_int
    if T <= 0:
        return np.zeros(0, dtype=np.float32)
    if getattr(self, "_use_fk_vr", False) and self._fk_trans_fut is not None:
        keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_fut")
        n_keybodies = int(keybody_idxs.shape[0])
        if n_keybodies == 0:
            return np.zeros((T, 0), dtype=np.float32).reshape(-1)
        if not self._root_only_fk_has_required_keybodies(keybody_idxs):
            return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)
        ref_body = self._fk_trans_fut[:T] # (T, num_bodies, 3)
        ref_root = ref_body[:, self.root_body_idx, :] # (T, 3)
        # Resize the output and world-frame scratch buffers if K changed.
        if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies:
            self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)
            self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)
        elif (
            self._keybody_rel_pos_w_buffer is None
            or self._keybody_rel_pos_w_buffer.shape[0] < T
            or self._keybody_rel_pos_w_buffer.shape[1] != n_keybodies
        ):
            self._keybody_rel_pos_w_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)
        rel_pos_fut = self._keybody_rel_pos_fut_buffer
        # World-frame offsets key body - root, written into the scratch buffer.
        np.subtract(
            ref_body[:, keybody_idxs, :],
            ref_root[:, None, :],
            out=self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :],
        )
        # Per-frame inverse root rotation, broadcast over the K key bodies.
        rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz(
            self._fk_quat_fut_wxyz[:T, None, :],
            self._keybody_rel_pos_w_buffer[:T, :n_keybodies, :],
        )
        return rel_pos_fut.reshape(-1).astype(np.float32)
    keybody_idxs = self._get_ref_keybody_indices("actor_ref_keybody_rel_pos_fut")
    n_keybodies = int(keybody_idxs.shape[0])
    if not hasattr(self, "ref_raw_bodylink_pos") or self.ref_raw_bodylink_pos is None:
        self.get_logger().warn(
            "[VR] ref_raw_bodylink_pos is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut."
        )
        if n_keybodies == 0:
            return np.zeros((T, 0), dtype=np.float32).reshape(-1)
        return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)
    if not hasattr(self, "ref_raw_bodylink_rot") or self.ref_raw_bodylink_rot is None:
        self.get_logger().warn(
            "[VR] ref_raw_bodylink_rot is unavailable and latest_obs is not active; returning zeros for ref_keybody_rel_pos_fut."
        )
        if n_keybodies == 0:
            return np.zeros((T, 0), dtype=np.float32).reshape(-1)
        return np.zeros((T, n_keybodies, 3), dtype=np.float32).reshape(-1)
    if n_keybodies == 0:
        return np.zeros((T, 0), dtype=np.float32).reshape(-1)
    # Motion-clip path: gather future frames, then rotate root-relative offsets.
    fut_idx = self._get_future_frame_indices()
    q_root_wxyz = self._get_future_root_quat_wxyz()
    ref_body_global_pos = np.asarray(self.ref_raw_bodylink_pos[fut_idx], dtype=np.float32)
    ref_root_global_pos = ref_body_global_pos[:, self.root_body_idx, :]
    rel_pos_w = ref_body_global_pos[:, keybody_idxs, :] - ref_root_global_pos[:, None, :]
    if self._keybody_rel_pos_fut_buffer.shape[1] != n_keybodies:
        self._keybody_rel_pos_fut_buffer = np.zeros((T, n_keybodies, 3), dtype=np.float32)
    rel_pos_fut = self._keybody_rel_pos_fut_buffer
    rel_pos_fut[:, :, :] = self._quat_rotate_inv_wxyz(q_root_wxyz[:, None, :], rel_pos_w)
    return rel_pos_fut.reshape(-1).astype(np.float32)
def _get_obs_place_holder(self):
    """Zero-filled float32 placeholder of length ``actor_place_holder_ndim``."""
    return np.zeros(self.actor_place_holder_ndim, dtype=np.float32)
# =========== End of Policy Observation Methods ===========
def _warmup_fk_for_vr(self):
"""Run one FK warmup step when entering VR motion mode."""
try:
if (
getattr(self, "fk", None) is None
or not getattr(self, "fk_initialized", False)
):
return
if getattr(self, "external_latest_obs", None) is None:
return
if getattr(self, "external_fut_dof_pos_queue", None) is None:
return
n_fut = int(getattr(self, "n_fut_frames", 0))
if (
n_fut <= 0
or self.external_fut_root_pos_queue is None
or self.external_fut_root_rot_queue is None
):
return
latest = self.external_latest_obs[0]
cur_root_pos = latest[58:61]
cur_root_rot = latest[61:65]
cur_dof_pos = latest[0:29]
root_pos_tensor, root_rot_tensor, dof_pos_tensor = (
self._prepare_vr_fk_tensors(
cur_root_pos=cur_root_pos,
cur_root_rot=cur_root_rot,
cur_dof_pos=cur_dof_pos,
n_fut=n_fut,
)
)
fk_out = self.fk(
root_pos=root_pos_tensor,
root_quat=root_rot_tensor,
dof_pos=dof_pos_tensor,
fps=float(1.0 / self.dt),
quat_format="wxyz",
vel_smoothing_sigma=0.0,
compute_velocity=False,
)
self._fk_vr_out = {
k: v.detach().cpu().numpy() for k, v in fk_out.items()
}
except Exception as e:
self.get_logger().warn(f"[VR] FK warmup failed, fallback to zeros: {e}")
def _low_state_callback(self, ls_msg: LowState):
    """Process low-level robot state and remote controller input.

    Main callback that handles:
    - Remote controller input processing
    - Motion selection based on button presses
    - Safety state checking
    - Velocity command extraction

    Motion Button Mapping:
    - A button: Enable policy (defaults to velocity mode)
    - B button: Switch from velocity to motion mode
    - Y button: Switch from motion back to velocity mode
    - UP/DOWN/LEFT/RIGHT: Motion clip selection (only in velocity tracking mode)

    The motion policy's KV cache is allocated lazily on the first motion
    inference step, so it may still be ``None`` when A or B is pressed. It is
    only cleared when it already exists.

    Args:
        ls_msg: LowState message containing robot sensor data and remote controller input
    """
    self._lowstate_msg = ls_msg
    self.remote_controller.set(ls_msg.wireless_remote)

    # A button: enable the policy; it always (re)starts in velocity mode.
    if self._is_button_pressed(KeyMap.A) and self.robot_state_ready:
        self.policy_enabled = True
        self.current_policy_mode = "velocity"  # Default to velocity mode
        self.latest_obs_flag = False
        self._reset_motion_action_ema_filter()
        self._reset_counter()
        if getattr(self, "use_kv_cache", False):
            kv_cache = getattr(self, "motion_kv_cache", None)
            if kv_cache is not None:
                kv_cache.fill(0)
            self.motion_step_idx = 0
        # Initialize with velocity model's default angles
        self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)
        self.target_dof_pos_onnx = self.velocity_default_angles_onnx.copy()
        # Publish velocity model's control parameters (kps/kds)
        self._publish_control_params()
        self.get_logger().info(
            f"Policy enabled in {self.current_policy_mode} tracking mode"
        )

    # B button: Switch to motion tracking mode (only when policy is enabled)
    if (
        self._is_button_pressed(KeyMap.B)
        and self.robot_state_ready
        and self.policy_enabled
        and self.current_policy_mode == "velocity"  # Only allow switch from velocity mode
    ):
        # Read the teleop switch once so both checks below agree on it.
        teleop_enabled = getattr(self, "enable_teleop_reference", True)
        vr_data_available = bool(
            teleop_enabled
            and getattr(self, "external_obs_received", False)
            and getattr(self, "external_latest_obs", None) is not None
        )
        vr_ready = self._is_vr_ready_for_motion()
        if teleop_enabled and self.require_vr_data_for_motion and not vr_ready:
            self.get_logger().warn(
                "require_vr_data_for_motion=True but the VR queue is not ready yet; staying in velocity mode."
            )
        else:
            # Don't automatically switch to next motion clip - keep current selection
            if hasattr(self, "all_motion_data") and self.all_motion_data:
                # Load the current motion clip data (don't change current_motion_clip_index)
                self._load_current_motion()
            self.current_policy_mode = "motion"
            self._reset_motion_action_ema_filter()
            self._reset_counter()
            if getattr(self, "use_kv_cache", False):
                kv_cache = getattr(self, "motion_kv_cache", None)
                if kv_cache is not None:
                    kv_cache.fill(0)
                self.get_logger().info("Motion KV-Cache reset.")
                self.motion_step_idx = 0
                self.get_logger().info("Motion Step Index reset to 0.")
            # Clear any pending actions to prevent conflicts between policies
            # Use motion model's default angles
            self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)
            self.target_dof_pos_onnx = self.motion_default_angles_onnx.copy()
            # Publish motion model's control parameters (kps/kds)
            self._publish_control_params()
            self.latest_obs_flag = bool(vr_data_available)
            source_mode = "ZMQ latest_obs" if self.latest_obs_flag else "offline motion"
            self.get_logger().info(
                f"Switched to motion tracking mode ({source_mode}) - motion clip index: {self.current_motion_clip_index}"
            )
            if self.latest_obs_flag:
                self.get_logger().info("[VR] Reference trajectory source: ZMQ latest_obs")
                self._warmup_fk_for_vr()
            self.motion_in_progress = True

    # Y button: return from motion mode to velocity mode.
    if (
        self._is_button_pressed(KeyMap.Y)
        and self.robot_state_ready
        and self.policy_enabled
        and self.current_policy_mode == "motion"  # Only allow switch from motion mode
    ):
        self._switch_to_velocity_mode()

    # Get velocity commands only in velocity tracking mode
    if self.current_policy_mode == "velocity":
        self.vx, self.vy, self.vyaw = self.remote_controller.get_velocity_commands()
    else:
        # In motion tracking mode, ignore joystick input
        self.vx, self.vy, self.vyaw = 0.0, 0.0, 0.0

    # Handle motion clip selection in velocity tracking mode (UP/DOWN/LEFT/RIGHT)
    if (
        self.current_policy_mode == "velocity"
        and self.policy_enabled
        and self.robot_state_ready
    ):
        # Buttons are queried in this fixed order; only the first pressed one acts.
        if self._is_button_pressed(KeyMap.up):
            # Switch to previous motion clip (wraps around)
            if hasattr(self, "all_motion_data") and self.all_motion_data:
                self.current_motion_clip_index = (
                    self.current_motion_clip_index - 1
                ) % len(self.all_motion_data)
                self.get_logger().info(
                    f"Selected previous motion clip: "
                    f"{self.motion_file_names[self.current_motion_clip_index]}"
                )
        elif self._is_button_pressed(KeyMap.down):
            # Switch to next motion clip (wraps around)
            if hasattr(self, "all_motion_data") and self.all_motion_data:
                self.current_motion_clip_index = (
                    self.current_motion_clip_index + 1
                ) % len(self.all_motion_data)
                self.get_logger().info(
                    f"Selected next motion clip: "
                    f"{self.motion_file_names[self.current_motion_clip_index]}"
                )
        elif self._is_button_pressed(KeyMap.left):
            # Select first motion clip
            if hasattr(self, "all_motion_data") and self.all_motion_data:
                self.current_motion_clip_index = 0
                self.get_logger().info(
                    f"Selected first motion clip: "
                    f"{self.motion_file_names[self.current_motion_clip_index]}"
                )
        elif self._is_button_pressed(KeyMap.right):
            # Select last motion clip
            if hasattr(self, "all_motion_data") and self.all_motion_data:
                self.current_motion_clip_index = len(self.all_motion_data) - 1
                self.get_logger().info(
                    f"Selected last motion clip: "
                    f"{self.motion_file_names[self.current_motion_clip_index]}"
                )
def run(self):
    """Main execution loop for policy inference and action publication.

    One tick does the following:
    1. Ingest pending reference data (the ROS-side buffer, then ZMQ).
    2. In motion mode, log VR stream status every 5 s.
    3. Log once when the VR queue becomes ready.
    4. Republish latest_obs.
    5. Run one policy step and warn about slow steps.
    6. Record timing diagnostics.

    Does nothing until ``_setup_completed`` is truthy.
    """
    # Only run if setup is completed
    if not hasattr(self, '_setup_completed') or not self._setup_completed:
        return
    t_loop_start = time.perf_counter()
    now = time.time()
    t_io = time.perf_counter()
    # Consume (and clear) a latest_obs frame handed over via the ROS buffer.
    buf = getattr(self, "_ros_latest_obs_buffer", None)
    if buf is not None:
        self._ros_latest_obs_buffer = None
        frame_idx, obs_arr = buf
        if frame_idx is not None:
            self._npz_replay_frame_index = frame_idx
        self._store_external_latest_obs(obs_arr[None, :])
    self._poll_zmq_latest_obs()
    if getattr(self, "current_policy_mode", None) == "motion":
        if self._last_vr_status_log_time is None:
            self._last_vr_status_log_time = now
        elif now - self._last_vr_status_log_time >= 5.0:
            vr_available = bool(
                getattr(self, "external_obs_received", False)
                and getattr(self, "external_latest_obs", None) is not None
            )
            queue_stats = self._latest_obs_buffer.get_queue_stats()
            freq = queue_stats.get("expected_freq")
            if vr_available:
                # Format the frequency separately. In `"a" f"b" if freq else f"c"`
                # the implicit literal concatenation binds tighter than the
                # conditional, which dropped the "[VR-STATUS]" prefix whenever
                # the frequency was unknown.
                freq_text = f"{freq:.1f}Hz" if freq else "unknown"
                self.get_logger().info(
                    "[VR-STATUS] ZMQ latest_obs streaming | "
                    f"buffer_size={queue_stats['queue_size']} "
                    f"expected_freq={freq_text}"
                )
            else:
                self.get_logger().warn(
                    "[VR-STATUS] No new ZMQ latest_obs received in the last 5 seconds; using offline reference or the last buffered VR state."
                )
            self._last_vr_status_log_time = now
    # One-shot log once the VR queue has enough frames for motion mode.
    if (
        getattr(self, "require_vr_data_for_motion", False)
        and getattr(self, "policy_enabled", False)
        and not getattr(self, "_vr_ready_logged", False)
        and self._is_vr_ready_for_motion()
    ):
        self.get_logger().info(
            f"[VR] VR queue is ready for motion mode (seen_frames={int(getattr(self, '_external_seen_frames', 0))}, "
            f"n_fut={int(getattr(self, 'n_fut_frames', 0) or 0)}, "
            f"delay={int(getattr(self, 'zmq_jitter_delay_frames', 0) or 0)})"
        )
        self._vr_ready_logged = True
    self._publish_latest_obs()
    io_ms = self._timing_ms(t_io)
    policy_timing = self._run_without_profiling()
    _run_elapsed = 0.0
    if policy_timing is not None:
        _run_elapsed = float(policy_timing.get("policy_total_ms", 0.0)) / 1000.0
    # Steps over 0.5 s are treated as the one-off cold start, not as latency.
    if (
        getattr(self, "current_policy_mode", None) == "motion"
        and getattr(self, "latest_obs_flag", False)
        and _run_elapsed > 0.5
        and not getattr(self, "_vr_cold_start_logged", False)
    ):
        self._vr_cold_start_logged = True
        self.get_logger().info(
            "[VR] The first motion step is a cold start (FK/ONNX initialization) and may take about 1 second."
        )
    # Rate-limited warning (1st, then every 50th) for steps >15% over budget.
    if (
        getattr(self, "current_policy_mode", None) == "motion"
        and getattr(self, "latest_obs_flag", False)
        and _run_elapsed > 1.15 * self.dt
        and _run_elapsed <= 0.5
    ):
        self._policy_slow_count = getattr(self, "_policy_slow_count", 0) + 1
        if self._policy_slow_count == 1 or self._policy_slow_count % 50 == 0:
            self.get_logger().warn(
                f"[VR] Policy step latency {_run_elapsed*1000:.1f} ms exceeds the target {self.dt*1000:.1f} ms. "
                f"Estimated /humanoid/action rate: {1.0/_run_elapsed:.1f} Hz (target {1.0/self.dt:.0f} Hz). "
                "The main bottleneck is usually FK or ONNX inference; if the system settles near 30 Hz, consider setting policy_freq to 30."
            )
    if policy_timing is not None:
        sample = dict(policy_timing)
        sample["io_ms"] = io_ms
        sample["loop_total_ms"] = self._timing_ms(t_loop_start)
        self._record_timing_sample(sample)
def _read_onnx_metadata(self, onnx_model_path: str) -> dict:
    """Load an ONNX model and decode its deployment metadata.

    The model's ``metadata_props`` must hold comma-separated strings under the
    keys ``action_scale``, ``joint_stiffness``, ``joint_damping``,
    ``default_joint_pos`` and ``joint_names``. A missing key raises
    ``KeyError``, and empty items (e.g. from a trailing comma) are ignored.

    Returns:
        dict with float32 arrays ``action_scale``, ``kps`` (from
        ``joint_stiffness``), ``kds`` (from ``joint_damping``) and
        ``default_joint_pos``, plus the list of strings ``joint_names``.
    """
    proto = onnx.load(str(onnx_model_path))
    props = {}
    for entry in proto.metadata_props:
        props[entry.key] = entry.value

    def _split_csv(text: str):
        return [item for item in text.split(",") if item != ""]

    def _to_float32(text: str):
        return np.array([float(item) for item in _split_csv(text)], dtype=np.float32)

    return {
        "action_scale": _to_float32(props["action_scale"]),
        "kps": _to_float32(props["joint_stiffness"]),
        "kds": _to_float32(props["joint_damping"]),
        "default_joint_pos": _to_float32(props["default_joint_pos"]),
        "joint_names": _split_csv(props["joint_names"]),
    }
def _store_external_latest_obs(self, arr: np.ndarray):
"""Store latest_obs and maintain the current/future frame queues."""
if arr.ndim == 1:
arr = arr[None, :]
if arr.shape[1] < self.latest_obs_expected_dim:
self.get_logger().warn(
f"Received latest_obs dim={arr.shape[1]}, expected >= {self.latest_obs_expected_dim}"
)
return
clipped = arr[:, : self.latest_obs_expected_dim].astype(np.float32, copy=False)
current_time = time.time()
self.external_latest_obs = clipped
self.external_obs_received = True
self.last_external_obs_time = current_time
self._external_seen_frames = int(getattr(self, "_external_seen_frames", 0)) + 1
latest_root_pos = clipped[0, 58:61]
latest_root_rot = clipped[0, 61:65]
latest_dof_pos = clipped[0, :29]
latest_dof_vel = clipped[0, 29:58]
if self.n_fut_frames > 0 and self.external_fut_dof_pos_queue is not None:
raw_idx = getattr(self, "_npz_replay_frame_index", None)
try:
latest_frame_idx = int(raw_idx) if raw_idx is not None else -1
except Exception:
latest_frame_idx = -1
if self._prev_external_dof_pos is None:
self._prev_external_dof_pos = np.empty_like(self.external_fut_dof_pos_queue[0])
self._prev_external_dof_vel = np.empty_like(self.external_fut_dof_vel_queue[0])
self._prev_external_root_pos = np.empty_like(self.external_fut_root_pos_queue[0])
if self.external_fut_root_rot_queue is not None:
self._prev_external_root_rot = np.empty_like(
self.external_fut_root_rot_queue[0]
)
np.copyto(self._prev_external_dof_pos, self.external_fut_dof_pos_queue[0])
np.copyto(self._prev_external_dof_vel, self.external_fut_dof_vel_queue[0])
np.copyto(self._prev_external_root_pos, self.external_fut_root_pos_queue[0])
if self.external_fut_root_rot_queue is not None:
np.copyto(self._prev_external_root_rot, self.external_fut_root_rot_queue[0])
if self.external_fut_frame_idx_queue is not None:
try:
self._prev_external_frame_idx = int(self.external_fut_frame_idx_queue[0])
except Exception:
self._prev_external_frame_idx = -1
self.external_fut_dof_pos_queue[:-1] = self.external_fut_dof_pos_queue[1:]
self.external_fut_dof_pos_queue[-1] = latest_dof_pos
self.external_fut_dof_vel_queue[:-1] = self.external_fut_dof_vel_queue[1:]
self.external_fut_dof_vel_queue[-1] = latest_dof_vel
self.external_fut_root_pos_queue[:-1] = self.external_fut_root_pos_queue[1:]
self.external_fut_root_pos_queue[-1] = latest_root_pos
if self.external_fut_root_rot_queue is not None:
self.external_fut_root_rot_queue[:-1] = self.external_fut_root_rot_queue[1:]
self.external_fut_root_rot_queue[-1] = latest_root_rot
if self.external_fut_frame_idx_queue is not None:
self.external_fut_frame_idx_queue[:-1] = self.external_fut_frame_idx_queue[1:]
self.external_fut_frame_idx_queue[-1] = latest_frame_idx
def _poll_zmq_latest_obs(self):
"""Poll the ZMQ latest_obs buffer with stale-data checks and delay."""
current_time = time.time()
data, timestamp, is_stale, frame_index, sender_timestamp = self._latest_obs_buffer.get_with_age_and_delay(
max_age=self.max_data_age,
delay_steps=int(getattr(self, "zmq_jitter_delay_frames", 0)),
)
if data is None:
return
if frame_index is not None:
self._npz_replay_frame_index = int(frame_index)
self._latest_sender_timestamp = sender_timestamp
if is_stale:
self.stale_data_warning_count += 1
if self.stale_data_warning_count % 50 == 0:
age_ms = (current_time - timestamp) * 1000.0
self.get_logger().warn(
f"ZMQ latest_obs is stale: age={age_ms:.1f}ms "
f"(threshold={self.max_data_age*1000:.1f}ms), "
f"stale_count={self.stale_data_warning_count}"
)
queue_stats = self._latest_obs_buffer.get_queue_stats()
if queue_stats.get("expected_freq"):
self.get_logger().warn(
f"latest_obs buffer: size={queue_stats['queue_size']}, "
f"avg_interval={queue_stats['avg_interval']*1000:.1f}ms, "
f"expected_freq={queue_stats['expected_freq']:.1f}Hz"
)
else:
if self.stale_data_warning_count > 0:
self.stale_data_warning_count = 0
if self.last_poll_time is not None:
poll_interval = current_time - self.last_poll_time
if poll_interval > 0.03:
self.get_logger().debug(
f"Policy poll interval {poll_interval*1000:.1f}ms (>30ms)"
)
self.last_poll_time = current_time
self._store_external_latest_obs(np.asarray(data, dtype=np.float32))
if (
getattr(self, "enable_teleop_reference", True)
and getattr(self, "require_vr_data_for_motion", False)
and not getattr(self, "latest_obs_flag", False)
and self._is_vr_ready_for_motion()
):
self.latest_obs_flag = True
if not getattr(self, "_vr_fk_started_logged", False):
self.get_logger().info(
"[VR] ZMQ data is ready; the main thread will build the reference trajectory from live ZMQ input."
)
self._vr_fk_started_logged = True
def _publish_latest_obs(self):
"""Publish the latest_obs topic for debugging or reuse."""
if self.external_latest_obs is None:
return
try:
msg = Float32MultiArray()
msg.data = self.external_latest_obs[0].tolist()
self.latest_obs_pub.publish(msg)
except Exception as e:
self.get_logger().error(f"Failed to publish latest_obs: {e}")
def _apply_onnx_metadata(self):
"""Apply PD/scale/defaults from ONNX metadata as authoritative values.
Load separate metadata for velocity and motion models."""
# Load velocity model metadata
velocity_meta = self._read_onnx_metadata(self.velocity_onnx_path)
self.velocity_dof_names_onnx = velocity_meta["joint_names"]
self.velocity_action_scale_onnx = velocity_meta["action_scale"].astype(np.float32)
self.velocity_kps_onnx = velocity_meta["kps"].astype(np.float32)
self.velocity_kds_onnx = velocity_meta["kds"].astype(np.float32)
self.velocity_default_angles_onnx = velocity_meta["default_joint_pos"].astype(np.float32)
# Load motion model metadata
motion_meta = self._read_onnx_metadata(self.motion_onnx_path)
self.motion_dof_names_onnx = motion_meta["joint_names"]
self.motion_action_scale_onnx = motion_meta["action_scale"].astype(np.float32)
self.motion_kps_onnx = motion_meta["kps"].astype(np.float32)
self.motion_kds_onnx = motion_meta["kds"].astype(np.float32)
self.motion_default_angles_onnx = motion_meta["default_joint_pos"].astype(np.float32)
# Use velocity model metadata as default (for backward compatibility)
self.dof_names_onnx = self.velocity_dof_names_onnx
self.action_scale_onnx = self.velocity_action_scale_onnx
self.kps_onnx = self.velocity_kps_onnx
self.kds_onnx = self.velocity_kds_onnx
self.default_angles_onnx = self.velocity_default_angles_onnx
self.default_angles_dict = {
name: float(self.default_angles_onnx[idx])
for idx, name in enumerate(self.dof_names_onnx)
}
def _build_dof_mappings(self):
# Map ONNX <-> MJCF for control
# Check if all ONNX names exist in real_dof_names (use velocity as reference)
missing_names = [name for name in self.velocity_dof_names_onnx if name not in self.real_dof_names]
if missing_names:
self.get_logger().warn(f"Missing names in real_dof_names: {missing_names}")
# Build mappings for velocity model
self.velocity_onnx_to_real = [
self.velocity_dof_names_onnx.index(name) for name in self.real_dof_names
]
self.velocity_kps_real = self.velocity_kps_onnx[self.velocity_onnx_to_real].astype(np.float32)
self.velocity_kds_real = self.velocity_kds_onnx[self.velocity_onnx_to_real].astype(np.float32)
# Build mappings for motion model
self.motion_onnx_to_real = [
self.motion_dof_names_onnx.index(name) for name in self.real_dof_names
]
self.motion_kps_real = self.motion_kps_onnx[self.motion_onnx_to_real].astype(np.float32)
self.motion_kds_real = self.motion_kds_onnx[self.motion_onnx_to_real].astype(np.float32)
# Use velocity model mappings as default (for backward compatibility)
self.onnx_to_real = self.velocity_onnx_to_real
self.kps_real = self.velocity_kps_real
self.kds_real = self.velocity_kds_real
self.default_angles_mu = self.velocity_default_angles_onnx[self.velocity_onnx_to_real].astype(np.float32)
self.action_scale_mu = self.velocity_action_scale_onnx[self.velocity_onnx_to_real].astype(np.float32)
# Build ref_to_onnx mapping (for motion model)
self.ref_to_onnx = [
self.dof_names_ref_motion.index(name) for name in self.motion_dof_names_onnx
]
# Pre-compute default angles dictionaries for efficient observation building
self.velocity_default_angles_dict = {
name: float(self.velocity_default_angles_onnx[idx])
for idx, name in enumerate(self.velocity_dof_names_onnx)
}
self.motion_default_angles_dict = {
name: float(self.motion_default_angles_onnx[idx])
for idx, name in enumerate(self.motion_dof_names_onnx)
}
# Pre-compute dof_names_onnx arrays for each mode (avoid repeated selection)
self.velocity_dof_names_onnx_array = np.array(self.velocity_dof_names_onnx)
self.motion_dof_names_onnx_array = np.array(self.motion_dof_names_onnx)
self.motion_dof_real_indices = [
self.real_dof_names.index(n) for n in self.motion_dof_names_onnx
]
self.velocity_dof_real_indices = [
self.real_dof_names.index(n) for n in self.velocity_dof_names_onnx
]
n_dof = max(len(self.motion_dof_names_onnx), len(self.velocity_dof_names_onnx))
self._dof_pos_obs_buffer = np.zeros(n_dof, dtype=np.float32)
self._dof_vel_obs_buffer = np.zeros(n_dof, dtype=np.float32)
# Pre-allocate arrays for future frame observations
if hasattr(self, "n_fut_frames") and self.n_fut_frames is not None:
self.n_fut_frames_int = int(self.n_fut_frames)
if self.n_fut_frames_int > 0:
self._pos_fut_buffer = np.zeros(
(len(self.dof_names_ref_motion), self.n_fut_frames_int), dtype=np.float32
)
self._h_fut_buffer = np.zeros((1, self.n_fut_frames_int), dtype=np.float32)
self._root_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)
else:
self.n_fut_frames_int = 0
else:
self.n_fut_frames_int = 0
self._future_frame_offsets = np.arange(1, self.n_fut_frames_int + 1, dtype=np.int64)
self._future_frame_indices_buffer = np.zeros(self.n_fut_frames_int, dtype=np.int64)
self._future_root_quat_wxyz_buffer = np.zeros((self.n_fut_frames_int, 4), dtype=np.float32)
self._gravity_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)
self._base_linvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)
self._base_angvel_fut_buffer = np.zeros((self.n_fut_frames_int, 3), dtype=np.float32)
self._keybody_rel_pos_fut_buffer = np.zeros((self.n_fut_frames_int, 0, 3), dtype=np.float32)
self._keybody_rel_pos_w_buffer = None
max_t = max(1, self.n_fut_frames_int)
self._vel_fut_T6 = np.zeros((max_t, 6), dtype=np.float32)
self._rot_t_buffer = np.zeros((max_t, 3), dtype=np.float32)
self._rot_cross_buffer = np.zeros((max_t, 3), dtype=np.float32)
self._use_fk_vr = False
self._fk_vel_0_root = np.zeros(3, dtype=np.float32)
self._fk_angvel_0_root = np.zeros(3, dtype=np.float32)
self._fk_quat_0_root = np.zeros(4, dtype=np.float32)
self._fk_trans_0 = None
max_t = max(1, self.n_fut_frames_int)
self._fk_vel_fut = np.zeros((max_t, 3), dtype=np.float32)
self._fk_angvel_fut = np.zeros((max_t, 3), dtype=np.float32)
self._fk_quat_fut = np.zeros((max_t, 4), dtype=np.float32)
self._fk_trans_fut = None
self._q_conj_buffer = np.zeros((max_t + 1, 4), dtype=np.float32)
self._rotated_3vec_buffer = np.zeros(3, dtype=np.float32)
self._rotated_angvel_cur_buffer = np.zeros(3, dtype=np.float32)
self._cross_t_buffer = np.zeros(3, dtype=np.float32)
self._fk_quat_0_root_wxyz = np.zeros(4, dtype=np.float32)
self._fk_quat_fut_wxyz = np.zeros((max_t, 4), dtype=np.float32)
# Pre-allocate velocity command observation array
self._velocity_cmd_obs = np.zeros(4, dtype=np.float32)
# Publish kps and kds parameters (use velocity as default)
self._publish_control_params()
def _publish_control_params(self):
    """Publish the PD gains (kps, then kds) matching the active policy mode.

    Called during initialization and on every mode switch so the control node
    always uses gains for the running policy. The motion gains are used when
    ``current_policy_mode == "motion"``, the velocity gains otherwise. Any
    failure is logged, not raised.
    """
    try:
        in_motion = self.current_policy_mode == "motion"
        gains_kp = self.motion_kps_real if in_motion else self.velocity_kps_real
        gains_kd = self.motion_kds_real if in_motion else self.velocity_kds_real
        for publisher_attr, gains in (("kps_pub", gains_kp), ("kds_pub", gains_kd)):
            gains_msg = Float32MultiArray()
            gains_msg.data = gains.tolist()
            getattr(self, publisher_attr).publish(gains_msg)
        self.get_logger().info(
            f"Published control parameters ({self.current_policy_mode} mode): "
            f"kps={len(gains_kp)}, kds={len(gains_kd)}"
        )
    except Exception as e:
        self.get_logger().error(f"Failed to publish control parameters: {e}")
def _publish_policy_mode(self):
    """Publish "<mode>_enabled" or "<mode>_disabled" on ``policy_mode_pub``.

    Failures are logged, not raised.
    """
    try:
        status_msg = String()
        mode = self.current_policy_mode
        state = "enabled" if self.policy_enabled else "disabled"
        status_msg.data = f"{mode}_{state}"
        self.policy_mode_pub.publish(status_msg)
    except Exception as e:
        self.get_logger().error(f"Failed to publish policy mode: {e}")
def _timing_ms(self, t0: float) -> float:
return (time.perf_counter() - t0) * 1000.0
def _record_timing_sample(self, sample: dict):
if not getattr(self, "timing_debug_enabled", False):
return
self._timing_debug_samples.append(sample)
if getattr(self, "timing_debug_log_per_loop", False):
self.get_logger().info(
"[Timing] "
f"loop_total={sample['loop_total_ms']:.2f}ms "
f"io={sample['io_ms']:.2f}ms "
f"policy_total={sample['policy_total_ms']:.2f}ms "
f"fk={sample['fk_ms']:.2f}ms "
f"obs={sample['obs_ms']:.2f}ms "
f"onnx={sample['onnx_ms']:.2f}ms "
f"post={sample['post_ms']:.2f}ms"
)
now = time.time()
last = getattr(self, "_timing_debug_last_log_time", None)
interval = float(getattr(self, "timing_debug_log_interval_sec", 5.0))
if last is None:
self._timing_debug_last_log_time = now
return
if now - last < interval:
return
if len(self._timing_debug_samples) == 0:
self._timing_debug_last_log_time = now
return
keys = [
"loop_total_ms",
"io_ms",
"policy_total_ms",
"fk_ms",
"obs_ms",
"onnx_ms",
"post_ms",
]
stats = {}
for key in keys:
vals = np.array(
[float(s.get(key, 0.0)) for s in self._timing_debug_samples],
dtype=np.float64,
)
stats[key] = (float(np.mean(vals)), float(np.max(vals)))
self.get_logger().info(
"[Timing-Agg] "
+ " ".join(
f"{key}=mean:{stats[key][0]:.2f}ms/max:{stats[key][1]:.2f}ms"
for key in keys
)
+ f" n={len(self._timing_debug_samples)}"
)
self._timing_debug_samples.clear()
self._timing_debug_last_log_time = now
def _root_only_fk_has_required_keybodies(self, keybody_idxs: np.ndarray) -> bool:
if keybody_idxs.size == 0:
return True
available_bodies = 0 if self._fk_trans_0 is None else int(self._fk_trans_0.shape[0])
if available_bodies <= int(np.max(keybody_idxs)):
if not self._root_only_fk_keybody_warned:
self.get_logger().warn(
"[RootOnlyFK] FK output only contains root body, but obs schema still "
"requests non-root keybody positions. Returning zeros for keybody obs."
)
self._root_only_fk_keybody_warned = True
return False
return True
def _run_without_profiling(self):
    """Run one policy step: reference update, observation, ONNX inference, publish.

    Returns ``None`` without running inference when any of these hold:
    - No LowState has been received yet.
    - The policy is disabled.
    - The live VR reference is older than ``max_data_age``; this also
      switches back to velocity mode.
    - Offline motion data (``n_motion_frames`` / ``ref_dof_pos``) is missing.

    Otherwise returns a dict of per-stage wall-clock durations in
    milliseconds: ``fk_ms``, ``obs_ms``, ``onnx_ms``, ``post_ms`` and
    ``policy_total_ms``.
    """
    if self._lowstate_msg is None or not self.policy_enabled:
        return None
    timing_info = {
        "policy_total_ms": 0.0,
        "fk_ms": 0.0,
        "obs_ms": 0.0,
        "onnx_ms": 0.0,
        "post_ms": 0.0,
    }
    _t_policy_start = time.perf_counter()
    if self.current_policy_mode == "motion":
        # Safety: with a live VR reference, fall back to velocity mode when the
        # newest stored frame is too old (or none was ever stored).
        if self.latest_obs_flag:
            current_time = time.time()
            if self.last_external_obs_time is None:
                data_age = float("inf")
            else:
                data_age = current_time - self.last_external_obs_time
            if data_age > self.max_data_age:
                self.get_logger().warn(
                    f"ZMQ latest_obs is stale: age={data_age*1000:.1f}ms > {self.max_data_age*1000:.1f}ms; "
                    "switching to velocity tracking mode for safety."
                )
                self._switch_to_velocity_mode(reason="VR latest_obs stale")
                return None
        # The offline reference needs loaded motion arrays.
        if not self.latest_obs_flag and (
            not hasattr(self, "n_motion_frames") or not hasattr(self, "ref_dof_pos")
        ):
            self.get_logger().warn("Motion data not loaded, skipping policy execution")
            return None
        # Live VR reference: run root-only FK over the current frame plus the
        # queued future frames. On any failure _fk_vr_out is cleared.
        if (
            self.latest_obs_flag
            and self.fk is not None
            and self.external_fut_dof_pos_queue is not None
        ):
            try:
                n_fut = int(getattr(self, "n_fut_frames", 0))
                if (
                    n_fut > 0
                    and self.external_fut_root_pos_queue is not None
                    and self.external_fut_root_rot_queue is not None
                ):
                    t_fk = time.perf_counter()
                    cur_root_pos = self.ref_root_pos_raw.astype(np.float32)
                    # Prefer the snapshot of the rotation most recently
                    # shifted out of the queue; fall back to the queue head.
                    cur_root_rot = (
                        self._prev_external_root_rot
                        if self._prev_external_root_rot is not None
                        else self.external_fut_root_rot_queue[0].astype(np.float32)
                    )
                    cur_dof_pos = self.ref_dof_pos_raw.astype(np.float32)
                    root_pos_tensor, root_rot_tensor, dof_pos_tensor = (
                        self._prepare_vr_fk_tensors(
                            cur_root_pos=cur_root_pos,
                            cur_root_rot=cur_root_rot,
                            cur_dof_pos=cur_dof_pos,
                            n_fut=n_fut,
                        )
                    )
                    fk_out = self.fk(
                        root_pos=root_pos_tensor,
                        root_quat=root_rot_tensor,
                        dof_pos=dof_pos_tensor,
                        fps=float(1.0 / self.dt),
                        quat_format="wxyz",
                        vel_smoothing_sigma=0.0,
                        compute_velocity=False,
                    )
                    # Cache FK outputs as numpy for the observation builder.
                    self._fk_vr_out = {
                        k: v.detach().cpu().numpy() for k, v in fk_out.items()
                    }
                    timing_info["fk_ms"] = self._timing_ms(t_fk)
                else:
                    self._fk_vr_out = None
            except Exception as e:
                self.get_logger().error(
                    f"VR FK computation failed; falling back to offline reference: {e}"
                )
                self._fk_vr_out = None
        self.obs_builder = self.motion_obs_builder
        # Use motion model metadata
        current_action_scale = self.motion_action_scale_onnx
        current_default_angles = self.motion_default_angles_onnx
        current_onnx_to_real = self.motion_onnx_to_real
    else:  # velocity mode
        self.obs_builder = self.velocity_obs_builder
        # Use velocity model metadata
        current_action_scale = self.velocity_action_scale_onnx
        current_default_angles = self.velocity_default_angles_onnx
        current_onnx_to_real = self.velocity_onnx_to_real
    t_obs = time.perf_counter()
    if self.current_policy_mode == "motion":
        self._cache_fk_vr_for_obs()
    # Policy input is a single-row float32 batch.
    policy_obs_np = self.obs_builder.build_policy_obs()[None, :].astype(
        np.float32, copy=False
    )
    timing_info["obs_ms"] = self._timing_ms(t_obs)
    # Run ONNX inference with the appropriate policy session and correct input/output names
    t_onnx = time.perf_counter()
    if self.current_policy_mode == "velocity":
        input_feed = {self.velocity_input_name: policy_obs_np}
        onnx_output = self.velocity_policy_session.run([self.velocity_output_name], input_feed)
    else:  # motion mode
        if self.use_kv_cache:
            # Lazily allocate the KV cache; non-int dims (presumably dynamic
            # axes) are replaced by 1.
            if self.motion_kv_cache is None:
                shape = [
                    d if isinstance(d, int) else 1
                    for d in self.motion_kv_shape
                ]
                self.motion_kv_cache = np.zeros(shape, dtype=self.motion_kv_dtype)
            # if (
            #     self.motion_effective_context_len > 0
            #     and self.motion_step_idx > 0
            #     and self.motion_step_idx % self.motion_effective_context_len == 0
            # ):
            #     self.motion_kv_cache.fill(0.0)
            input_feed = {
                self.motion_input_name: policy_obs_np,
                self.motion_kv_input_name: self.motion_kv_cache,
            }
            if self.motion_step_idx_input_name is not None:
                step_idx = self.motion_step_idx
                # if self.motion_effective_context_len > 0:
                #     step_idx = (
                #         self.motion_step_idx
                #         % self.motion_effective_context_len
                #     )
                input_feed[self.motion_step_idx_input_name] = np.array(
                    [step_idx], dtype=np.int64
                )
            output_names = [self.motion_output_name]
            if self.motion_kv_output_name:
                output_names.append(self.motion_kv_output_name)
            onnx_output = self.motion_policy_session.run(
                output_names, input_feed
            )
            # Carry the updated cache (second output) into the next step.
            if len(onnx_output) > 1:
                self.motion_kv_cache = onnx_output[1]
            self.motion_step_idx += 1
        else:
            input_feed = {self.motion_input_name: policy_obs_np}
            onnx_output = self.motion_policy_session.run(
                [self.motion_output_name], input_feed
            )
    timing_info["onnx_ms"] = self._timing_ms(t_onnx)
    t_post = time.perf_counter()
    raw_actions_onnx = np.asarray(onnx_output[0], dtype=np.float32).reshape(-1)
    # Motion actions are EMA-smoothed; velocity actions are used as-is.
    if self.current_policy_mode == "motion":
        self.actions_onnx = self._apply_motion_action_ema_filter(raw_actions_onnx)
    else:
        self.actions_onnx = raw_actions_onnx.copy()
    # Use the appropriate metadata based on current policy mode
    self.target_dof_pos_onnx = (
        self.actions_onnx * current_action_scale + current_default_angles
    )
    # Reorder ONNX-ordered targets into robot (real) joint order.
    self.target_dof_pos_real = self.target_dof_pos_onnx[current_onnx_to_real]
    # Action processing and publishing
    self._process_and_publish_actions()
    if self.current_policy_mode == "motion":
        if (
            not getattr(self, "latest_obs_flag", False)
            and self.motion_frame_idx >= self.n_motion_frames
            and self.motion_in_progress
        ):
            self.get_logger().info("Motion action completed (offline reference)")
            self.motion_in_progress = False
    # Publish policy mode status
    self._publish_policy_mode()
    timing_info["post_ms"] = self._timing_ms(t_post)
    timing_info["policy_total_ms"] = self._timing_ms(_t_policy_start)
    return timing_info
def _process_and_publish_actions(self):
"""Process and publish action commands."""
if self.target_dof_pos_real is not None:
action_msg = Float32MultiArray()
action_msg.data = self.target_dof_pos_real.tolist()
# Check for NaN values
if np.isnan(self.target_dof_pos_real).any():
self.get_logger().error("Action contains NaN values")
self.action_pub.publish(action_msg)
self.motion_frame_idx += 1
def setup(self):
    """Set up the evaluator by loading all required components.

    Optionally pins the calling thread to the CPUs parsed from
    ``_cpu_affinity_main_str``. Every initialization step then runs in a
    fixed order. The order matters because later steps read attributes set
    by earlier ones; for example, ``_build_dof_mappings`` uses the joint
    names set by ``_apply_onnx_metadata``.
    """
    cpu_list = _parse_cpu_affinity_str(
        getattr(self, "_cpu_affinity_main_str", "") or ""
    )
    if cpu_list and set_thread_cpu_affinity(cpu_list):
        self.get_logger().info(f"[Policy] main thread pinned to CPUs {cpu_list}")
    init_steps = (
        self.load_model_config,  # config must be loaded first
        self.update_config_parameters,  # then parameters derived from it
        self._init_fk,  # FK for online VR reference reconstruction
        self.load_policy,
        self._apply_onnx_metadata,
        self._init_obs_buffers,
        self._build_dof_mappings,
        self._warmup_motion_policy,
        self._init_keybody_indices_cache,
        self.load_motion_data,  # always loaded: both modes are supported
    )
    for init_step in init_steps:
        init_step()
    self.get_logger().info("Synchronous root-only policy setup completed")
def _init_fk(self):
"""Initialize lightweight root-only FK for synchronous VR reference updates."""
try:
self.get_logger().info(
"Initializing root-only FK (no URDF, sync main-thread mode)"
)
self.fk = HoloMotionFKRootOnly(
dof_names=self.dof_names_ref_motion,
device="cpu",
timing_logger_enabled=True,
timing_log_interval_sec=5.0,
timing_log_per_call=False,
timing_name="FKRootOnlyVR",
timing_log_fn=self.get_logger().info,
)
try:
ndof = len(self.fk.dof_names)
root_pos_dummy = torch.zeros((1, 4, 3), dtype=torch.float32)
root_quat_dummy = torch.zeros((1, 4, 4), dtype=torch.float32)
root_quat_dummy[..., 0] = 1.0
dof_pos_dummy = torch.zeros((1, 4, ndof), dtype=torch.float32)
_ = self.fk(
root_pos=root_pos_dummy,
root_quat=root_quat_dummy,
dof_pos=dof_pos_dummy,
fps=float(1.0 / self.dt),
quat_format="wxyz",
vel_smoothing_sigma=0.0,
compute_velocity=False,
)
self.get_logger().info("[FK] Root-only warmup completed (B=1,T=4)")
except Exception as e_dummy:
self.get_logger().warn(f"[FK] Root-only warmup failed (ignored): {e_dummy}")
self.fk_initialized = True
self.get_logger().info(
f"Root-only FK initialized successfully with {len(self.fk.dof_names)} dofs"
)
except Exception as e:
self.get_logger().error(f"Failed to initialize root-only FK: {e}")
self.fk = None
self.fk_initialized = False
def destroy_node(self):
    """Stop the optional ZMQ subscriber, then destroy the underlying ROS node.

    Stopping the subscriber is best-effort. A failure is logged as a warning
    instead of being silently discarded, and ``super().destroy_node()`` always
    runs, even if logging itself fails.
    """
    try:
        subscriber = getattr(self, "_zmq_subscriber", None)
        if subscriber is not None:
            subscriber.stop()
    except Exception as e:
        # Teardown must continue, but the failure should be visible.
        self.get_logger().warn(f"[Policy] Failed to stop ZMQ subscriber: {e}")
    finally:
        super().destroy_node()
def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:
    """Project gravity into the body frame described by a quaternion.

    Args:
        quaternion: Array-like [w, x, y, z].

    Returns:
        np.ndarray of shape (3,), dtype float32, with the gravity projection.
    """
    w, x, y, z = (float(quaternion[k]) for k in range(4))
    return np.array(
        [
            2.0 * (w * y - z * x),
            -2.0 * (z * y + w * x),
            1.0 - 2.0 * (w * w + z * z),
        ],
        dtype=np.float32,
    )
def main():
    """Main entry point for the policy node.

    Spins the node until interrupted. Afterwards it always destroys the node,
    so ``destroy_node`` cleanup such as stopping the ZMQ subscriber runs, and
    it shuts rclpy down if nothing else has already done so.
    """
    rclpy.init()
    policy_node = None
    try:
        policy_node = HoloMotionPolicyNode()
        rclpy.spin(policy_node)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the node; exit quietly.
        pass
    finally:
        if policy_node is not None:
            policy_node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/__init__.py
================================================
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/command_helper.py
================================================
from unitree_sdk2py.idl.unitree_go.msg.dds_ import LowCmd_ as LowCmdGo
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as LowCmdHG
from typing import Union
class MotorMode:
    """Integer constants selecting the motor control mode."""

    # Series control for pitch/roll joints
    PR = 0
    # Parallel control for A/B joints
    AB = 1
def create_damping_cmd(cmd: Union[LowCmdGo, LowCmdHG]):
    """Turn every motor command in ``cmd`` into a damping-only command.

    Position target, velocity target, kp and feed-forward torque are zeroed
    and kd is set to 8. ``cmd`` is modified in place.
    """
    for motor in cmd.motor_cmd:
        motor.q = 0
        # The velocity-target field of unitree MotorCmd_ is ``dq``; writing
        # ``qd`` only creates an unused attribute and leaves dq untouched.
        motor.dq = 0
        motor.kp = 0
        motor.kd = 8
        motor.tau = 0
def create_zero_cmd(cmd: Union[LowCmdGo, LowCmdHG]):
    """Zero every motor command in ``cmd`` (targets, gains and torque), in place."""
    for motor in cmd.motor_cmd:
        motor.q = 0
        # ``dq`` (not ``qd``) is the velocity-target field of unitree MotorCmd_.
        motor.dq = 0
        motor.kp = 0
        motor.kd = 0
        motor.tau = 0
def init_cmd_hg(cmd: LowCmdHG, mode_machine: int, mode_pr: int):
    """Initialise an ``unitree_hg`` LowCmd in place.

    Sets the machine / PR mode fields, sets every motor to mode 1, and zeroes
    all targets, gains and torques.
    """
    cmd.mode_machine = mode_machine
    cmd.mode_pr = mode_pr
    for motor in cmd.motor_cmd:
        motor.mode = 1
        motor.q = 0
        # ``dq`` (not ``qd``) is the velocity-target field of unitree MotorCmd_.
        motor.dq = 0
        motor.kp = 0
        motor.kd = 0
        motor.tau = 0
def init_cmd_go(cmd: LowCmdGo, weak_motor: list):
    """Initialise an ``unitree_go`` LowCmd in place.

    Writes the fixed header bytes, level flag and gpio. Motors whose index is
    in ``weak_motor`` get mode 1; all others get mode 0x0A. Every motor gets the
    PosStopF / VelStopF position and velocity targets, with zero gains and
    torque.
    """
    cmd.head[0] = 0xFE
    cmd.head[1] = 0xEF
    cmd.level_flag = 0xFF
    cmd.gpio = 0
    # Sentinel targets; presumably interpreted by the firmware as "no
    # position/velocity target" (TODO confirm against unitree docs).
    PosStopF = 2.146e9
    VelStopF = 16000.0
    for i, motor in enumerate(cmd.motor_cmd):
        motor.mode = 1 if i in weak_motor else 0x0A
        motor.q = PosStopF
        # ``dq`` (not ``qd``) is the velocity-target field of unitree MotorCmd_.
        motor.dq = VelStopF
        motor.kp = 0
        motor.kd = 0
        motor.tau = 0
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/maths.py
================================================
import torch
import numpy as np
import random
import os
@torch.jit.script
def normalize(x, eps: float = 1e-9):
    # L2-normalise along the last dimension; the norm is floored at eps so
    # zero vectors stay finite.
    length = x.norm(p=2, dim=-1)
    length = length.clamp(min=eps, max=None)
    return x / length.unsqueeze(-1)
@torch.jit.script
def torch_rand_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    # Uniform samples between lower and upper with a 2-D shape on `device`.
    span = upper - lower
    return span * torch.rand(*shape, device=device) + lower
@torch.jit.script
def copysign(a, b):
    # type: (float, Tensor) -> Tensor
    # |a| repeated once per leading row of b, carrying the sign of b
    # (zero where b is zero).
    magnitude = torch.abs(
        torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
    )
    return magnitude * torch.sign(b)
def set_seed(seed, torch_deterministic=False):
    """Seed python, numpy and torch (CPU + CUDA) RNGs and configure cuDNN.

    A seed of -1 means "pick one": 42 in deterministic mode, otherwise a random
    value in [0, 10000). Returns the seed actually used.
    """
    if seed == -1:
        seed = 42 if torch_deterministic else np.random.randint(0, 10000)
    print("Setting seed: {}".format(seed))

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if not torch_deterministic:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        return seed

    # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    return seed
def to_torch(x, dtype=torch.float, device="cuda:0", requires_grad=False):
    """Build a new tensor from ``x`` (defaults: float32 on ``cuda:0``)."""
    options = dict(dtype=dtype, device=device, requires_grad=requires_grad)
    return torch.tensor(x, **options)
@torch.compile
def quat_mul_legacy(
    a: torch.Tensor, b: torch.Tensor, w_last: bool = True
) -> torch.Tensor:
    """Multiply two quaternions (Hamilton product a * b).

    Args:
        a (torch.Tensor): (..., 4) quaternion.
        b (torch.Tensor): (..., 4) quaternion. Must have exactly the same
            shape as ``a`` (no broadcasting; enforced by an assert).
        w_last (bool): Whether the scalar part w is the last element.
            If True, format is [x, y, z, w]; if False, format is [w, x, y, z].

    Returns:
        torch.Tensor: (..., 4) quaternion result of a * b, in the same layout
        as the inputs.
    """
    assert a.shape == b.shape
    shape = a.shape
    # Flatten leading dims so components can be sliced as columns.
    a = a.reshape(-1, 4)
    b = b.reshape(-1, 4)
    if w_last:
        # Format: [x, y, z, w]
        x1, y1, z1, w1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]
        x2, y2, z2, w2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    else:
        # Format: [w, x, y, z]
        w1, x1, y1, z1 = a[:, 0], a[:, 1], a[:, 2], a[:, 3]
        w2, x2, y2, z2 = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    # Factored Hamilton product: 8 component products instead of the 16 of the
    # textbook form. The shared terms ww/yy/zz/xx/qq are reused below, so the
    # exact grouping matters for bit-for-bit results.
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)
    # Re-pack in the caller's layout and restore the original leading dims.
    if w_last:
        quat = torch.stack([x, y, z, w], dim=-1).view(shape)
    else:
        quat = torch.stack([w, x, y, z], dim=-1).view(shape)
    return quat
@torch.jit.script
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Multiply two quaternions together (Hamilton product q1 * q2).

    Only the (w, x, y, z) layout is supported; this function takes no
    ``w_last`` flag (see ``quat_mul_legacy`` for the xyzw variant).

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).

    Returns:
        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).

    Raises:
        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
    """
    # check input is correct (no broadcasting is performed)
    if q1.shape != q2.shape:
        msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
        raise ValueError(msg)
    # reshape to (N, 4) for multiplication
    shape = q1.shape
    q1 = q1.reshape(-1, 4)
    q2 = q2.reshape(-1, 4)
    # extract components from quaternions
    w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
    w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
    # perform multiplication using the factored form (8 component products
    # instead of 16); intermediate terms are shared between the outputs
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)
    # restore the caller's leading dimensions
    return torch.stack([w, x, y, z], dim=-1).view(shape)
@torch.jit.script
def normalize(x, eps: float = 1e-9):
    # Unit-length copy of x along its last dimension (norm floored at eps).
    denom = x.norm(p=2, dim=-1).clamp(min=eps, max=None)
    return x / denom.unsqueeze(-1)
@torch.jit.script
def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate vectors by quaternions.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z), with the same shape as ``vec``.
    """
    out_shape = vec.shape
    q = quat.reshape(-1, 4)
    v = vec.reshape(-1, 3)
    w = q[:, 0:1]
    u = q[:, 1:]
    # v' = v + w*t + u x t, with t = 2 (u x v)
    t = u.cross(v, dim=-1) * 2
    rotated = v + w * t + u.cross(t, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate vectors by the inverse of (unit) quaternions.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z), with the same shape as ``vec``.
    """
    out_shape = vec.shape
    q = quat.reshape(-1, 4)
    v = vec.reshape(-1, 3)
    w = q[:, 0:1]
    u = q[:, 1:]
    # Same as quat_apply but with the sign of the scalar term flipped.
    t = u.cross(v, dim=-1) * 2
    rotated = v - w * t + u.cross(t, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_rotate(q, v):
    # Rotate v (N, 3) by q (N, 4) given in (x, y, z, w) order (scalar last).
    n = q.shape[0]
    w = q[:, -1]
    xyz = q[:, :3]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(xyz, v, dim=-1) * w.unsqueeze(-1) * 2.0
    dot = torch.bmm(xyz.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = xyz * dot * 2.0
    return term_a + term_b + term_c
# @torch.jit.script
def quat_rotate_inverse(q, v):
    """Rotate v (N, 3) by the inverse of unit q (N, 4) in (x, y, z, w) order."""
    n = q.shape[0]
    w = q[:, -1]
    xyz = q[:, :3]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(xyz, v, dim=-1) * w.unsqueeze(-1) * 2.0
    dot = torch.bmm(xyz.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = xyz * dot * 2.0
    # Inverse rotation: the cross-product term changes sign.
    return term_a - term_b + term_c
@torch.jit.script
def quat_conjugate(a):
    # Conjugate of (w, x, y, z) quaternions: negate the vector part.
    # The input shape is preserved.
    flat = a.reshape(-1, 4)
    conj = torch.cat((flat[:, 0:1], -flat[:, 1:]), dim=-1)
    return conj.view(a.shape)
@torch.jit.script
def quat_unit(a):
    # Quaternions are normalised like any other vector: divide by the L2 norm.
    unit_quat = normalize(a)
    return unit_quat
@torch.jit.script
def quat_from_angle_axis(angle, axis):
    # Build (x, y, z, w) quaternions from angles (rad) and (not necessarily
    # unit) rotation axes.
    half = (angle / 2).unsqueeze(-1)
    vector_part = normalize(axis) * half.sin()
    scalar_part = half.cos()
    return quat_unit(torch.cat([vector_part, scalar_part], dim=-1))
@torch.jit.script
def normalize_angle(x):
    # Wrap angles into [-pi, pi] via atan2(sin, cos).
    s = torch.sin(x)
    c = torch.cos(x)
    return torch.atan2(s, c)
@torch.jit.script
def get_basis_vector(q, v):
    # Alias of quat_rotate: rotate v by the (x, y, z, w) quaternion q.
    rotated = quat_rotate(q, v)
    return rotated
def get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3):
    """Construct arguments to `Vec` according to axis index.

    Returns an ``n_dims``-long list that is zero except for ``value`` at
    ``axis_idx``; element 0 is then overwritten with ``x_value``.
    """
    params = np.zeros((n_dims,))
    assert axis_idx < n_dims, (
        "the axis dim should be within the vector dimensions"
    )
    params[axis_idx] = value
    params[0] = x_value
    return list(params.astype(dtype))
# @torch.jit.script
# def copysign(a, b):
# a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
# return torch.abs(a) * torch.sign(b)
@torch.jit.script
def get_euler_xyz(q: torch.Tensor) -> tuple:
    # Roll/pitch/yaw of (N, 4) quaternions in (x, y, z, w) order, each
    # wrapped to [0, 2*pi).
    x = q[:, 0]
    y = q[:, 1]
    z = q[:, 2]
    w = q[:, 3]
    # roll: rotation about x
    roll = torch.atan2(2.0 * (w * x + y * z), w * w - x * x - y * y + z * z)
    # pitch: rotation about y; saturate to +-pi/2 where asin leaves its domain
    sinp = 2.0 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sinp) >= 1,
        copysign(torch.tensor(np.pi / 2.0, device=sinp.device), sinp),
        torch.asin(sinp),
    )
    # yaw: rotation about z
    yaw = torch.atan2(2.0 * (w * z + x * y), w * w + x * x - y * y - z * z)
    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    # Compose roll/pitch/yaw (rad) into (..., 4) quaternions in (x, y, z, w) order.
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)
    qx = cy * sr * cp - sy * cr * sp
    qy = cy * cr * sp + sy * sr * cp
    qz = sy * cr * cp - cy * sr * sp
    qw = cy * cr * cp + sy * sr * sp
    return torch.stack([qx, qy, qz, qw], dim=-1)
def torch_rand_float(lower, upper, shape, device):
    """Uniform random samples between ``lower`` and ``upper`` of ``shape`` on ``device``."""
    width = upper - lower
    return width * torch.rand(*shape, device=device) + lower
# @torch.jit.script
@torch.compile
def torch_random_dir_2(shape, device):
    """Random 2-D unit directions (cos a, sin a) with a drawn uniformly in (-pi, pi)."""
    theta = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
    return torch.stack([theta.cos(), theta.sin()], dim=-1)
@torch.jit.script
def tensor_clamp(t, min_t, max_t):
    # Element-wise clamp with tensor bounds: cap at max_t, then floor at min_t.
    capped = torch.min(t, max_t)
    return torch.max(capped, min_t)
@torch.jit.script
def scale(x, lower, upper):
    # Affinely map x from [-1, 1] onto [lower, upper].
    span = upper - lower
    return 0.5 * (x + 1.0) * span + lower
@torch.jit.script
def unscale(x, lower, upper):
    # Affinely map x from [lower, upper] onto [-1, 1] (inverse of `scale`).
    numerator = 2.0 * x - upper - lower
    return numerator / (upper - lower)
def unscale_np(x, lower, upper):
    """NumPy/scalar twin of `unscale`: map x from [lower, upper] onto [-1, 1]."""
    numerator = 2.0 * x - upper - lower
    return numerator / (upper - lower)
@torch.jit.script
def quat_to_angle_axis(q):
    # Axis-angle from unit (x, y, z, w) quaternions. Where sin(theta/2) is
    # at most 1e-5 (near-identity), angle is 0 and axis is (0, 0, 1).
    min_theta = 1e-5
    w = q[..., 3]
    sin_theta = torch.sqrt(1 - w * w)
    angle = normalize_angle(2 * torch.acos(w))
    axis = q[..., 0:3] / sin_theta.unsqueeze(-1)
    valid = torch.abs(sin_theta) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    # Exponential map (rotation vector): each axis scaled by its angle.
    return axis * angle.unsqueeze(-1)
@torch.jit.script
def quat_to_exp_map(q):
    # Rotation vector of unit (x, y, z, w) quaternions, via axis-angle.
    angle, axis = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(angle, axis)
@torch.jit.script
def slerp(q0, q1, t):
    # Spherical linear interpolation from q0 (t=0) to q1 (t=1).
    dot = torch.sum(q0 * q1, dim=-1)
    # q and -q encode the same rotation; flip q1 where needed for the short arc.
    flip = dot < 0
    q1 = q1.clone()
    q1[flip] = -q1[flip]
    cos_half = torch.unsqueeze(torch.abs(dot), dim=-1)
    half = torch.acos(cos_half)
    sin_half = torch.sqrt(1.0 - cos_half * cos_half)
    weight0 = torch.sin((1 - t) * half) / sin_half
    weight1 = torch.sin(t * half) / sin_half
    blended = weight0 * q0 + weight1 * q1
    # Tiny arc (sin_half ~ 0): fall back to a plain average ...
    blended = torch.where(torch.abs(sin_half) < 0.001, 0.5 * q0 + 0.5 * q1, blended)
    # ... and identical inputs: return q0 unchanged.
    return torch.where(torch.abs(cos_half) >= 1, q0, blended)
@torch.jit.script
def my_quat_rotate(q, v):
    # Rotate v (N, 3) by q (N, 4) given in (x, y, z, w) order.
    rows = q.shape[0]
    qw = q[:, -1]
    qv = q[:, :3]
    scaled_v = v * (2.0 * qw**2 - 1.0).unsqueeze(-1)
    cross_term = torch.cross(qv, v, dim=-1) * qw.unsqueeze(-1) * 2.0
    proj = torch.bmm(qv.view(rows, 1, 3), v.view(rows, 3, 1)).squeeze(-1)
    return scaled_v + cross_term + qv * proj * 2.0
@torch.jit.script
def calc_heading(q):
    # Heading angle on the xy plane of unit (x, y, z, w) quaternions:
    # the direction the rotated +x axis points to.
    x_axis = torch.zeros_like(q[..., 0:3])
    x_axis[..., 0] = 1
    forward = my_quat_rotate(q, x_axis)
    return torch.atan2(forward[..., 1], forward[..., 0])
@torch.jit.script
def calc_heading_quat(q):
    # Pure-yaw quaternion (rotation about +z) with the same heading as q.
    # q must be a unit quaternion in (x, y, z, w) order.
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(calc_heading(q), z_axis)
@torch.jit.script
def calc_heading_quat_inv(q):
    # Inverse of calc_heading_quat: yaw quaternion rotating by minus the
    # heading of unit (x, y, z, w) quaternion q.
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(-calc_heading(q), z_axis)
@torch.compile
def axis_angle_from_quat(
    quat: torch.Tensor,
    w_last: bool = True,
) -> torch.Tensor:
    """Compute axis-angle (log map) vector from a quaternion.

    Args:
        quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].
        w_last (bool): Whether the scalar part w is the last element.

    Returns:
        torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi].

    Notes:
        - The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability.
        - Uses a second-order Taylor expansion of sin(angle/2)/angle for tiny angles to avoid dividing by ~0.
    """
    # Handle different quaternion formats
    if w_last:
        # Quaternion is [q_x, q_y, q_z, q_w]
        quat_w_orig = quat[..., -1:]
    else:
        # Quaternion is [q_w, q_x, q_y, q_z]
        quat_w_orig = quat[..., 0:1]
    # Flip q -> -q where w < 0 (same rotation) so the half-angle below lies in [0, pi/2]
    quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0))
    # Ensure unit quaternion for stability
    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min(
        1.0e-9
    )
    # Recompute quat_xyz and quat_w after potential sign flip
    if w_last:
        quat_w = quat[..., -1:]
        quat_xyz = quat[..., :3]
    else:
        quat_w = quat[..., 0:1]
        quat_xyz = quat[..., 1:4]
    mag = torch.linalg.norm(quat_xyz, dim=-1)
    half_angle = torch.atan2(mag, quat_w.squeeze(-1))
    angle = 2.0 * half_angle
    # check whether to apply Taylor approximation
    use_taylor = angle.abs() <= 1.0e-6
    # torch.where evaluates BOTH branches, so the exact branch must itself stay
    # finite when angle == 0; otherwise inf/NaN (and NaN gradients) would leak
    # through even though that branch is masked out.
    # Taylor series: sin(a/2) / a ~= 1/2 - a^2 / 48
    sin_half_angles_over_angles_approx = 0.5 - angle * angle / 48
    # Replace tiny angles with 1 before dividing so the exact branch never divides by ~0.
    angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle)
    sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe
    sin_half_angles_over_angles = torch.where(
        use_taylor,
        sin_half_angles_over_angles_approx,
        sin_half_angles_over_angles_exact,
    )
    # |xyz| = sin(angle/2), so xyz / (sin(angle/2)/angle) = unit_axis * angle
    return quat_xyz / sin_half_angles_over_angles[..., None]
@torch.jit.script
def quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Computes the inverse of a quaternion: conj(q) / |q|^2.

    Args:
        q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
        eps: Lower bound for |q|^2 to avoid division by zero. Defaults to 1e-9.

    Returns:
        The inverse quaternion in (w, x, y, z). Shape is (N, 4).
    """
    norm_sq = q.pow(2).sum(dim=-1, keepdim=True).clamp(min=eps)
    return quat_conjugate(q) / norm_sq
# --------------------- WXYZ helpers (torch) ---------------------
def xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor:
    """Reorder (..., 4) quaternions from XYZW to WXYZ (returns a new tensor)."""
    scalar = q[..., 3:4]
    vector = q[..., 0:3]
    return torch.cat((scalar, vector), dim=-1)
def wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor:
    """Reorder (..., 4) quaternions from WXYZ to XYZW (returns a new tensor)."""
    vector = q[..., 1:4]
    scalar = q[..., 0:1]
    return torch.cat((vector, scalar), dim=-1)
@torch.compile
def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """
    Hamilton product in WXYZ layout using fused implementation.

    Args:
        q1 (torch.Tensor): (..., 4) WXYZ.
        q2 (torch.Tensor): (..., 4) WXYZ, same shape as ``q1``.

    Returns:
        torch.Tensor: (..., 4) WXYZ.
    """
    # ``quat_mul`` in this module already works in (w, x, y, z) layout and has
    # no ``w_last`` parameter; passing ``w_last=False`` made every call fail.
    return quat_mul(q1, q2)
def subtract_frame_transforms(
    t01: torch.Tensor,
    q01: torch.Tensor,
    t02: torch.Tensor = None,
    q02: torch.Tensor = None,
):
    r"""Express frame 2 relative to frame 1, both given relative to frame 0.

    Computes :math:`T_{12} = T_{01}^{-1} \times T_{02}`, where :math:`T_{AB}`
    is the homogeneous transformation matrix from frame A to B.

    Args:
        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
        t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).
            Defaults to None, in which case the position is assumed to be zero.
        q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
            Defaults to None, in which case the orientation is assumed to be identity.

    Returns:
        A tuple (t12, q12): position (N, 3) and orientation (N, 4) of frame 2
        w.r.t. frame 1.
    """
    q10 = quat_inv(q01)
    q12 = q10 if q02 is None else quat_mul(q10, q02)
    offset = -t01 if t02 is None else t02 - t01
    t12 = quat_apply(q10, offset)
    return t12, q12
@torch.compile
def quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor:
    """
    Scale (..., 4) WXYZ quaternions to unit length; the norm is floored at
    1e-9 so an all-zero quaternion stays zero instead of becoming NaN.
    """
    norm = torch.linalg.norm(q_wxyz, dim=-1, keepdim=True)
    return q_wxyz / norm.clamp_min(1.0e-9)
@torch.jit.script
def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).

    Returns:
        Rotation matrices. The shape is (..., 3, 3).

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    # Scaling by 2 / |q|^2 keeps the formula valid for non-unit quaternions.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    entries = torch.stack(
        (
            1 - two_s * (y * y + z * z),
            two_s * (x * y - z * w),
            two_s * (x * z + y * w),
            two_s * (x * y + z * w),
            1 - two_s * (x * x + z * z),
            two_s * (y * z - x * w),
            two_s * (x * z - y * w),
            two_s * (y * z + x * w),
            1 - two_s * (x * x + y * y),
        ),
        -1,
    )
    return entries.reshape(quaternions.shape[:-1] + (3, 3))
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/motor_crc.py
================================================
import struct
import numpy as np
from ctypes import Structure, c_uint8, c_float, c_uint32, Array
def crc32_core(data_array, length):
CRC32 = 0xFFFFFFFF
dwPolynomial = 0x04C11DB7
for i in range(length):
data = data_array[i]
for bit in range(32): # Process all 32 bits
if (CRC32 >> 31) & 1: # Check MSB before shift
CRC32 = ((CRC32 << 1) & 0xFFFFFFFF) ^ dwPolynomial
else:
CRC32 = (CRC32 << 1) & 0xFFFFFFFF
if (data >> (31 - bit)) & 1: # Match C++ bit processing order
CRC32 ^= dwPolynomial
return CRC32
def calc_crc(cmd) -> int:
"""Calculate CRC for LowCmd message"""
buffer = bytearray()
# Pack header (mode_pr, mode_machine + 2 padding)
buffer.extend(struct.pack("> i
# 读取原始值
lx_raw = struct.unpack("f", data[4:8])[0]
rx_raw = struct.unpack("f", data[8:12])[0]
ry_raw = struct.unpack("f", data[12:16])[0]
ly_raw = struct.unpack("f", data[20:24])[0]
# 应用滤波和死区
self.lx = self.apply_filter_and_deadzone(lx_raw, self.lx_prev)
self.ly = self.apply_filter_and_deadzone(ly_raw, self.ly_prev)
self.rx = self.apply_filter_and_deadzone(rx_raw, self.rx_prev)
self.ry = self.apply_filter_and_deadzone(ry_raw, self.ry_prev)
# 更新前一次的值
self.lx_prev = self.lx
self.ly_prev = self.ly
self.rx_prev = self.rx
self.ry_prev = self.ry
def get_velocity_commands(self):
    """
    Convert the filtered joystick positions into velocity commands.

    Returns:
        tuple: (vx, vy, vyaw)
            - vx: forward/backward speed (m/s) from the left stick's vertical axis (ly),
              never below -0.5
            - vy: lateral speed (m/s) from the left stick's horizontal axis (lx), sign-flipped
            - vyaw: yaw rate (rad/s) from the right stick's horizontal axis (rx), sign-flipped
        Any component whose magnitude is below its threshold is reported as 0.0.
    """
    # Forward/backward from the left stick's y axis, with reverse speed capped
    # at 0.5. NOTE(review): the original author noted ly may need negating
    # (pushing forward can read negative); no negation is applied here.
    vx = max(self.ly * self.max_linear_speed_x, -0.5)
    # Lateral from the left stick's x axis; the sign flip depends on the
    # coordinate convention.
    vy = -self.lx * self.max_linear_speed_y
    # Yaw rate from the right stick's x axis; the sign flip depends on the
    # coordinate convention.
    vyaw = -self.rx * self.max_angular_speed

    # Dead band: suppress small commands.
    vx = 0.0 if abs(vx) < self.velocity_threshold_x else vx
    vy = 0.0 if abs(vy) < self.velocity_threshold_y else vy
    vyaw = 0.0 if abs(vyaw) < self.velocity_threshold_yaw else vyaw
    return vx, vy, vyaw
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotation_helper.py
================================================
import numpy as np
from scipy.spatial.transform import Rotation as R
def get_gravity_orientation(quaternion):
    """Project gravity into the body frame of a [w, x, y, z] quaternion.

    Returns a float64 numpy array of shape (3,).
    """
    w, x, y, z = quaternion[0], quaternion[1], quaternion[2], quaternion[3]
    components = (
        2 * (-z * x + w * y),
        -2 * (z * y + w * x),
        1 - 2 * (w * w + z * z),
    )
    result = np.zeros(3)
    for idx, value in enumerate(components):
        result[idx] = value
    return result
def transform_imu_data(waist_yaw, waist_yaw_omega, imu_quat, imu_omega):
    """Re-express a torso IMU reading in the pelvis frame.

    Args:
        waist_yaw: Waist yaw angle in radians (scipy ``from_euler`` default).
        waist_yaw_omega: Waist yaw rate, subtracted from the z component.
        imu_quat: Torso orientation as [w, x, y, z].
        imu_omega: Torso angular velocity, either a (3,) vector or a 2-D array
            such as (1, 3), of which only the first row is used.

    Returns:
        (pelvis_quat, pelvis_omega): quaternion as [w, x, y, z] and a (3,)
        angular-velocity vector.
    """
    rz_waist = R.from_euler("z", waist_yaw).as_matrix()
    # scipy expects scalar-last (x, y, z, w) quaternions
    r_torso = R.from_quat(
        [imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]
    ).as_matrix()
    r_pelvis = np.dot(r_torso, rz_waist.T)
    # Previously ``imu_omega[0]`` was used unconditionally, which turns a plain
    # (3,) vector into a scalar and silently produces a 3x3 "velocity".
    omega = np.asarray(imu_omega)
    if omega.ndim > 1:
        omega = omega[0]
    w = np.dot(rz_waist, omega) - np.array([0, 0, waist_yaw_omega])
    return R.from_matrix(r_pelvis).as_quat()[[3, 0, 1, 2]], w
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/humanoid_policy/utils/rotations.py
================================================
import torch
from torch import Tensor
import torch.nn.functional as F
from humanoid_policy.utils.maths import (
normalize,
copysign,
)
from typing import Tuple
import numpy as np
from typing import List, Optional
@torch.jit.script
def quat_unit(a):
    # Rescale quaternions to unit length (delegates to maths.normalize).
    unit_quat = normalize(a)
    return unit_quat
@torch.jit.script
def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:
    """Rotate vectors ``b`` (..., 3) by quaternions ``a`` (..., 4).

    ``w_last`` selects (x, y, z, w) vs (w, x, y, z) layout for ``a``.
    The result has the shape of ``b``.
    """
    out_shape = b.shape
    quats = a.reshape(-1, 4)
    vecs = b.reshape(-1, 3)
    scalar = quats[:, 3:] if w_last else quats[:, :1]
    vector = quats[:, :3] if w_last else quats[:, 1:]
    twice_cross = vector.cross(vecs, dim=-1) * 2
    rotated = vecs + scalar * twice_cross + vector.cross(twice_cross, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor:
    """Rotate ``vec`` by the yaw (heading) part of ``quat`` only.

    The x and y components of each quaternion are zeroed and the result is
    renormalised, leaving a rotation about the z axis. Which indices hold x
    and y depends on the layout: 0:2 for (x, y, z, w) and 1:3 for (w, x, y, z).
    The previous code always zeroed 0:2, which for ``w_last=False`` wiped
    w and x instead of x and y.
    """
    quat_yaw = quat.clone().view(-1, 4)
    if w_last:
        quat_yaw[:, 0:2] = 0.0
    else:
        quat_yaw[:, 1:3] = 0.0
    quat_yaw = normalize(quat_yaw)
    return quat_apply(quat_yaw, vec, w_last)
@torch.jit.script
def wrap_to_pi(angles):
    """Wrap angles (radians) into the interval (-pi, pi].

    Returns a new tensor. The previous in-place ``%=`` / ``-=`` (which are
    in-place ops on tensors in TorchScript) also overwrote the caller's
    ``angles`` tensor as a side effect.
    """
    wrapped = angles % (2 * np.pi)
    wrapped = wrapped - 2 * np.pi * (wrapped > np.pi)
    return wrapped
@torch.jit.script
def quat_conjugate(a: Tensor, w_last: bool) -> Tensor:
    # Negate the vector part of each quaternion; shape is preserved.
    flat = a.reshape(-1, 4)
    if w_last:
        conj = torch.cat((-flat[:, :3], flat[:, -1:]), dim=-1)
    else:
        conj = torch.cat((flat[:, 0:1], -flat[:, 1:]), dim=-1)
    return conj.view(a.shape)
@torch.jit.script
def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:
    """Rotate vectors ``b`` by quaternions ``a`` (layout chosen by ``w_last``).

    Note: this module defines ``quat_apply`` twice; this later definition is
    the one bound at import time.
    """
    shape = b.shape
    q = a.reshape(-1, 4)
    v = b.reshape(-1, 3)
    if w_last:
        u = q[:, :3]
        s = q[:, 3:]
    else:
        u = q[:, 1:]
        s = q[:, :1]
    t = u.cross(v, dim=-1) * 2
    return (v + s * t + u.cross(t, dim=-1)).view(shape)
@torch.jit.script
def quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    # Rotate v (N, 3) by q (N, 4); w_last picks xyzw vs wxyz layout.
    n = q.shape[0]
    if w_last:
        w = q[:, -1]
        xyz = q[:, :3]
    else:
        w = q[:, 0]
        xyz = q[:, 1:]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(xyz, v, dim=-1) * w.unsqueeze(-1) * 2.0
    dot = torch.bmm(xyz.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = xyz * dot * 2.0
    return term_a + term_b + term_c
@torch.jit.script
def quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    # Rotate v (N, 3) by the inverse of unit q (N, 4); w_last picks the layout.
    n = q.shape[0]
    if w_last:
        w = q[:, -1]
        xyz = q[:, :3]
    else:
        w = q[:, 0]
        xyz = q[:, 1:]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(xyz, v, dim=-1) * w.unsqueeze(-1) * 2.0
    dot = torch.bmm(xyz.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = xyz * dot * 2.0
    # Inverse rotation: the cross-product term changes sign.
    return term_a - term_b + term_c
@torch.jit.script
def quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]:
    """
    The (angle, axis) representation of the rotation. The axis is normalized to unit length.
    The angle is guaranteed to be between [0, pi].

    The input ``x`` is not modified.
    """
    if w_last:
        w = x[..., -1]
        axis = x[..., :3]
    else:
        w = x[..., 0]
        axis = x[..., 1:]
    s = 2 * (w**2) - 1
    angle = s.clamp(-1, 1).arccos()  # clamp guards against |s| > 1 from rounding
    # ``axis`` is a view into ``x``: divide out-of-place. The former ``/=``
    # wrote the normalised axis back into the caller's quaternion tensor.
    axis = axis / axis.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)
    return angle, axis
@torch.jit.script
def quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor:
    # Unit quaternions for rotations of `angle` rad about (normalised) `axis`,
    # emitted in xyzw or wxyz order according to w_last.
    half = (angle / 2).unsqueeze(-1)
    xyz = normalize(axis) * half.sin()
    w = half.cos()
    parts = [xyz, w] if w_last else [w, xyz]
    return quat_unit(torch.cat(parts, dim=-1))
@torch.jit.script
def vec_to_heading(h_vec):
    # Planar heading angle atan2(y, x) from the first two components.
    return torch.atan2(h_vec[..., 1], h_vec[..., 0])
@torch.jit.script
def heading_to_quat(h_theta, w_last: bool):
    # Yaw-only quaternions (rotation about +z) for the given heading angles.
    z_axis = torch.zeros(h_theta.shape + [3], device=h_theta.device)
    z_axis[..., 2] = 1
    return quat_from_angle_axis(h_theta, z_axis, w_last=w_last)
@torch.jit.script
def quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor:
    # Basis vector e_axis rotated by each quaternion in q (N, 4).
    basis = torch.zeros(q.shape[0], 3, device=q.device)
    basis[:, axis] = 1
    return quat_rotate(q, basis, w_last)
@torch.jit.script
def normalize_angle(x):
    # Map angles into [-pi, pi] using atan2 of their sine and cosine.
    sine = torch.sin(x)
    cosine = torch.cos(x)
    return torch.atan2(sine, cosine)
@torch.jit.script
def get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    # Alias of quat_rotate: rotate v by q in the layout selected by w_last.
    rotated = quat_rotate(q, v, w_last)
    return rotated
@torch.jit.script
def quat_to_angle_axis(q):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    # Axis-angle from unit (x, y, z, w) quaternions. Near-identity rotations
    # (sin(theta/2) <= 1e-5) get angle 0 and axis (0, 0, 1).
    # NOTE: the original author flagged this as possibly problematic; e.g. a
    # |w| slightly above 1 makes sqrt/acos return NaN.
    min_theta = 1e-5
    w = q[..., 3]
    sin_theta = torch.sqrt(1 - w * w)
    angle = normalize_angle(2 * torch.acos(w))
    axis = q[..., 0:3] / sin_theta.unsqueeze(-1)
    valid = torch.abs(sin_theta) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
@torch.jit.script
def slerp(q0, q1, t):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    # Spherical linear interpolation from q0 (t=0) towards q1 (t=1).
    dot = torch.sum(q0 * q1, dim=-1)
    # Take the short arc: flip q1 where the quaternions point apart.
    flip = dot < 0
    q1 = q1.clone()
    q1[flip] = -q1[flip]
    cos_half = torch.unsqueeze(torch.abs(dot), dim=-1)
    half = torch.acos(cos_half)
    sin_half = torch.sqrt(1.0 - cos_half * cos_half)
    weight_a = torch.sin((1 - t) * half) / sin_half
    weight_b = torch.sin(t * half) / sin_half
    result = weight_a * q0 + weight_b * q1
    # Nearly parallel inputs: use a plain average instead of dividing by ~0.
    result = torch.where(torch.abs(sin_half) < 0.001, 0.5 * q0 + 0.5 * q1, result)
    # Identical inputs: return q0.
    return torch.where(torch.abs(cos_half) >= 1, q0, result)
@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    # type: (Tensor, Tensor) -> Tensor
    # Exponential map (rotation vector) = axis scaled by its angle.
    return axis * angle.unsqueeze(-1)
@torch.jit.script
def my_quat_rotate(q, v):
    # Rotate v (N, 3) by q (N, 4) in (x, y, z, w) order.
    rows = q.shape[0]
    qw = q[:, -1]
    qv = q[:, :3]
    scaled_v = v * (2.0 * qw**2 - 1.0).unsqueeze(-1)
    cross_term = torch.cross(qv, v, dim=-1) * qw.unsqueeze(-1) * 2.0
    proj = torch.bmm(qv.view(rows, 1, 3), v.view(rows, 3, 1)).squeeze(-1)
    return scaled_v + cross_term + qv * proj * 2.0
@torch.jit.script
def calc_heading(q):
    # type: (Tensor) -> Tensor
    # Heading on the xy plane of unit (x, y, z, w) quaternions: the direction
    # the rotated +x axis points to.
    x_axis = torch.zeros_like(q[..., 0:3])
    x_axis[..., 0] = 1
    heading_dir = my_quat_rotate(q, x_axis)
    return torch.atan2(heading_dir[..., 1], heading_dir[..., 0])
@torch.jit.script
def quat_to_exp_map(q):
    # type: (Tensor) -> Tensor
    # Rotation vector of unit (x, y, z, w) quaternions, via axis-angle.
    angle, axis = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(angle, axis)
@torch.jit.script
def calc_heading_quat(q, w_last):
    # type: (Tensor, bool) -> Tensor
    # Yaw-only quaternion with the heading of q. The heading itself is read by
    # calc_heading, which treats q as (x, y, z, w); w_last only sets the
    # output layout.
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(calc_heading(q), z_axis, w_last=w_last)
@torch.jit.script
def calc_heading_quat_inv(q, w_last):
    # type: (Tensor, bool) -> Tensor
    # Yaw-only quaternion undoing the heading of q (rotation by -heading).
    # calc_heading treats q as (x, y, z, w); w_last only sets the output layout.
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(-calc_heading(q), z_axis, w_last=w_last)
@torch.jit.script
def quat_inverse(x, w_last):
    # type: (Tensor, bool) -> Tensor
    # Inverse rotation as the conjugate; equals the true inverse only for
    # unit quaternions (no division by |q|^2 is performed).
    return quat_conjugate(x, w_last=w_last)
@torch.jit.script
def get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]:
    """Roll, pitch, yaw of (N, 4) quaternions, each wrapped to [0, 2*pi).

    ``w_last`` selects (x, y, z, w) vs (w, x, y, z) component order.
    """
    if w_last:
        x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    else:
        w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # roll (x-axis rotation)
    roll = torch.atan2(2.0 * (w * x + y * z), w * w - x * x - y * y + z * z)
    # pitch (y-axis rotation); asin is undefined past +-1, so saturate there
    sinp = 2.0 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)
    )
    # yaw (z-axis rotation)
    yaw = torch.atan2(2.0 * (w * z + x * y), w * w + x * x - y * y - z * z)
    return roll % (2 * np.pi), pitch % (2 * np.pi), yaw % (2 * np.pi)
# @torch.jit.script  (left unscripted on purpose)
def get_euler_xyz_in_tensor(q):
    """
    Convert xyzw quaternions to a stacked (roll, pitch, yaw) tensor.

    Args:
        q: quaternions indexed as ``q[:, k]`` (a 2-D batch) in xyzw layout.

    Returns:
        Tensor of shape ``(N, 3)`` holding roll, pitch, yaw in that order.
        Unlike ``get_euler_xyz`` the angles are NOT wrapped into [0, 2*pi).
    """
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    roll = torch.atan2(2.0 * (w * x + y * z), w * w - x * x - y * y + z * z)

    sinp = 2.0 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sinp) >= 1, copysign(np.pi / 2.0, sinp), torch.asin(sinp)
    )

    yaw = torch.atan2(2.0 * (w * z + x * y), w * w + x * x - y * y - z * z)

    return torch.stack((roll, pitch, yaw), dim=-1)
@torch.jit.script
def quat_pos(x):
    """
    Flip each xyzw quaternion whose real part (index 3) is negative.

    ``q`` and ``-q`` encode the same rotation; the result is ``sign * x`` with
    ``sign = -1`` where ``w < 0`` and ``+1`` elsewhere.
    """
    negative_w = (x[..., 3:] < 0).float()
    sign = 1 - 2 * negative_w
    return sign * x
@torch.jit.script
def is_valid_quat(q):
    """
    Return True when every quaternion in ``q`` has (approximately) unit norm.

    Components are read as xyzw from the last axis; the check uses
    ``Tensor.allclose`` against ones with default tolerances.
    """
    w = q[..., 3]
    norm_sq = w * w + q[..., 0] * q[..., 0] + q[..., 1] * q[..., 1] + q[..., 2] * q[..., 2]
    return norm_sq.allclose(torch.ones_like(w))
@torch.jit.script
def quat_normalize(q):
    """
    Canonicalize a (possibly non-unit) quaternion into a 3D rotation.

    First makes the real part non-negative with ``quat_pos``, then scales the
    result with ``quat_unit``.
    """
    positive_w = quat_pos(q)
    return quat_unit(positive_w)
@torch.jit.script
def quat_mul(a, b, w_last: bool):
    """
    Quaternion product ``a * b`` (Hamilton convention, e.g. i * j = k).

    Args:
        a: quaternions; flattened with ``reshape(-1, 4)``, so each group of
            four consecutive values is one quaternion.
        b: quaternions with exactly the same shape as ``a``; no broadcasting
            is performed (enforced by the assert).
        w_last: True for (x, y, z, w) layout, False for (w, x, y, z).

    Returns:
        Tensor with the original shape of ``a`` and the same component layout.
    """
    assert a.shape == b.shape
    shape = a.shape
    # Flatten all batch dimensions so each component below is a 1-D tensor.
    a = a.reshape(-1, 4)
    b = b.reshape(-1, 4)
    if w_last:
        x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
        x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    else:
        w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
        w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    # Factored form of the product: 8 multiplications (plus one scaling by
    # 0.5) instead of the textbook 16. The intermediate terms (ww, yy, zz,
    # xx, qq) have no standalone geometric meaning; keep the exact expression
    # order, since reordering changes floating-point rounding.
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)
    # Re-pack in the caller's layout and restore the original batch shape.
    if w_last:
        quat = torch.stack([x, y, z, w], dim=-1).view(shape)
    else:
        quat = torch.stack([w, x, y, z], dim=-1).view(shape)
    return quat
@torch.jit.script
def quat_mul_norm(x, y, w_last):
    # type: (Tensor, Tensor, bool) -> Tensor
    """
    Compose two sets of 3D rotations and rescale the product to unit length.

    Returns ``quat_unit(quat_mul(x, y, w_last))``. Unlike ``quat_normalize``,
    ``quat_pos`` is not applied, so the sign of the real part is left as
    produced by ``quat_mul``.

    Note: ``quat_mul`` asserts that ``x`` and ``y`` have identical shapes, so
    the inputs are not broadcast.
    """
    # Single definition on purpose: an earlier duplicate of this function
    # (built on quat_normalize) was dead code, because this later def rebound
    # the module-level name. Its docstrings also contained invalid "\*" escape
    # sequences, which raise SyntaxWarning on recent Python versions.
    return quat_unit(quat_mul(x, y, w_last))
@torch.jit.script
def quat_identity(shape: List[int]):
    """
    Build identity rotations in xyzw layout (x = y = z = 0, w = 1).

    Args:
        shape: batch shape; the result has shape ``shape + [4]`` and is
            allocated with the default device and dtype.
    """
    vector_part = torch.zeros(shape + [3])
    real_part = torch.ones(shape + [1])
    return quat_normalize(torch.cat([vector_part, real_part], dim=-1))
@torch.jit.script
def quat_identity_like(x):
    """
    Construct identity rotations matching the batch shape, device and dtype of ``x``.

    Args:
        x: tensor whose last axis is the quaternion axis; the result has
            shape ``x.shape[:-1] + (4,)``.

    Returns:
        Identity quaternions (xyzw layout, as built by ``quat_identity``) on
        ``x.device`` with ``x.dtype``.
    """
    # list(...) keeps this valid when TorchScript is disabled (PYTORCH_JIT=0):
    # eagerly, x.shape[:-1] is a torch.Size (a tuple), and quat_identity builds
    # its sizes with ``shape + [n]``, which raises TypeError for a tuple.
    identity = quat_identity(list(x.shape[:-1]))
    # quat_identity allocates with the default device/dtype; match x so the
    # result can be combined with x directly (the "_like" contract).
    return identity.to(device=x.device, dtype=x.dtype)
@torch.jit.script
def transform_from_rotation_translation(
    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None
):
    """
    Pack a rotation quaternion and a 3D translation into a 7-wide transform.

    At most one of ``r`` / ``t`` may be None; the missing part is filled with
    the identity rotation or a zero translation for the same batch shape.

    Args:
        r: rotation quaternions of shape ``(..., 4)`` (xyzw, as produced by
            ``quat_identity``), or None.
        t: translations of shape ``(..., 3)``, or None.

    Returns:
        Tensor of shape ``(..., 7)``: ``[r, t]`` concatenated on the last axis.
    """
    assert r is not None or t is not None, (
        "rotation and translation can't be all None"
    )
    if r is None:
        assert t is not None
        # Batch shape of t excludes its xyz axis; the old code passed the full
        # shape and produced a (..., 3, 4) "identity" that could not be
        # concatenated with t.
        r = quat_identity(list(t.shape[:-1])).to(device=t.device, dtype=t.dtype)
    if t is None:
        # Likewise drop r's quaternion axis; the old code built (..., 4, 3).
        t = torch.zeros(list(r.shape[:-1]) + [3], device=r.device, dtype=r.dtype)
    return torch.cat([r, t], dim=-1)
@torch.jit.script
def transform_rotation(x):
    """Return the rotation quaternion stored in channels 0..3 of a transform."""
    rotation = x[..., 0:4]
    return rotation
@torch.jit.script
def transform_translation(x):
    """Return the translation stored after the 4 quaternion channels of a transform."""
    translation = x[..., 4:]
    return translation
@torch.jit.script
def transform_mul(x, y):
    """
    Compose two rigid transforms: apply ``y`` first, then ``x``.

    Rotation:    normalized product of the two rotations (xyzw layout).
    Translation: ``x``'s rotation applied to ``y``'s translation, plus ``x``'s translation.
    """
    rot_x = transform_rotation(x)
    rot_y = transform_rotation(y)
    rotation = quat_mul_norm(rot_x, rot_y, w_last=True)
    translation = quat_rotate(
        rot_x, transform_translation(y), w_last=True
    ) + transform_translation(x)
    return transform_from_rotation_translation(r=rotation, t=translation)
##################################### FROM PHC rotation_conversions.py #####################################
@torch.jit.script
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternions (real part first, wxyz) to rotation matrices.

    The quaternions need not be unit length: the factor ``2 / |q|^2``
    normalizes them implicitly.

    Args:
        quaternions: tensor of shape (..., 4).
    Returns:
        Rotation matrices as a tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    s = 2.0 / (quaternions * quaternions).sum(-1)
    row0 = torch.stack(
        (1 - s * (y * y + z * z), s * (x * y - z * w), s * (x * z + y * w)), -1
    )
    row1 = torch.stack(
        (s * (x * y + z * w), 1 - s * (x * x + z * z), s * (y * z - x * w)), -1
    )
    row2 = torch.stack(
        (s * (x * z - y * w), s * (y * z + x * w), 1 - s * (x * x + y * y)), -1
    )
    return torch.stack((row0, row1, row2), -2)
@torch.jit.script
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert axis-angle rotation vectors to quaternions (real part first).

    Args:
        axis_angle: tensor of shape (..., 3); its direction is the rotation
            axis and its norm the counter-clockwise angle in radians.
    Returns:
        Quaternions in wxyz order, shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    near_zero = angles.abs() < 1e-6
    regular = ~near_zero

    # scale = sin(theta / 2) / theta, the factor applied to the axis-angle vector
    scale = torch.empty_like(angles)
    scale[regular] = torch.sin(half_angles[regular]) / angles[regular]
    # Near theta = 0 use the Taylor series to avoid 0/0:
    # sin(x/2) ~ x/2 - (x/2)^3/6  =>  sin(x/2)/x ~ 1/2 - x^2/48
    tiny = angles[near_zero]
    scale[near_zero] = 0.5 - (tiny * tiny) / 48

    real_part = torch.cos(half_angles)
    return torch.cat([real_part, axis_angle * scale], dim=-1)
# @torch.jit.script  (left unscripted; also works on numpy arrays)
def wxyz_to_xyzw(quat):
    """Reorder quaternion components from (w, x, y, z) to (x, y, z, w)."""
    xyzw_order = [1, 2, 3, 0]
    return quat[..., xyzw_order]
# @torch.jit.script  (left unscripted; also works on numpy arrays)
def xyzw_to_wxyz(quat):
    """Reorder quaternion components from (x, y, z, w) to (w, x, y, z)."""
    wxyz_order = [3, 0, 1, 2]
    return quat[..., wxyz_order]
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to quaternions in (w, x, y, z) order.

    Builds four candidate quaternions, each computed by dividing by a
    different component magnitude, and keeps the best-conditioned one (the
    one with the largest denominator). Not TorchScript-compiled.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    Raises:
        ValueError: if the last two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    # Row-major unpack of the 3x3 entries: mRC = matrix[..., R, C].
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # |w|, |x|, |y|, |z| up to a factor of 2 (the sqrt of 4*q_k^2), from the
    # diagonal; _sqrt_positive_part clamps negatives caused by round-off.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
            ),
            torch.stack(
                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
            ),
            torch.stack(
                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
            ),
            torch.stack(
                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
            ),
        ],
        dim=-2,
    )
    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    # F is presumably torch.nn.functional (imported elsewhere in this file); the
    # one-hot mask selects one candidate row per batch element.
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :, # pyre-ignore[16]
    ].reshape(batch_dim + (4,))
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Element-wise ``sqrt(max(0, x))`` with a zero subgradient where ``x <= 0``.

    Only strictly positive entries pass through ``torch.sqrt``; every other
    entry stays exactly zero, so no inf/NaN gradient appears at ``x == 0``.
    """
    is_positive = x > 0
    result = torch.zeros_like(x)
    result[is_positive] = torch.sqrt(x[is_positive])
    return result
def quat_w_first(rot):
    """Move the last component (w of an xyzw quaternion) to the front."""
    last = rot[..., [-1]]
    rest = rot[..., :-1]
    return torch.cat([last, rest], -1)
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    """
    Build quaternions from roll/pitch/yaw angles (radians).

    Returns:
        Tensor stacked on a new last axis in (x, y, z, w) order.
    """
    # half-angle cosines / sines
    cr = torch.cos(roll * 0.5)
    sr = torch.sin(roll * 0.5)
    cp = torch.cos(pitch * 0.5)
    sp = torch.sin(pitch * 0.5)
    cy = torch.cos(yaw * 0.5)
    sy = torch.sin(yaw * 0.5)

    qx = cy * sr * cp - sy * cr * sp
    qy = cy * cr * sp + sy * sr * cp
    qz = sy * cr * cp - cy * sr * sp
    qw = cy * cr * cp + sy * sr * sp
    return torch.stack([qx, qy, qz, qw], dim=-1)
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc.h
================================================
/*****************************************************************
Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.
******************************************************************/
#ifndef _MOTOR_CRC_H_
#define _MOTOR_CRC_H_
// CRC helper for unitree_go low-level commands.
// get_crc() copies a unitree_go::msg::LowCmd field by field into the raw
// LowCmd struct below and runs crc32_core() over it as 32-bit words, so the
// member order and sizes of these structs are part of the CRC contract:
// do not reorder or resize members.
#include
#include
#include "rclcpp/rclcpp.hpp"
#include "unitree_go/msg/low_cmd.hpp"
#include "unitree_go/msg/motor_cmd.hpp"
#include "unitree_go/msg/bms_cmd.hpp"
// Level-flag values (presumably for LowCmd::levelFlag -- confirm against the
// Unitree documentation). TRIGERLEVEL is misspelled but kept for compatibility.
constexpr int HIGHLEVEL = 0xee;
constexpr int LOWLEVEL = 0xff;
constexpr int TRIGERLEVEL = 0xf0;
// Sentinel command values; presumably the "no position / velocity target"
// markers used by the Unitree examples -- confirm before relying on them.
constexpr double PosStopF = (2.146E+9f);
constexpr double VelStopF = (16000.0f);
// joint index
// Leg joint indices; FR/FL/RR/RL appear to be front-right/front-left/
// rear-right/rear-left legs with three joints each (_0.._2) -- confirm.
constexpr int FR_0 = 0;
constexpr int FR_1 = 1;
constexpr int FR_2 = 2;
constexpr int FL_0 = 3;
constexpr int FL_1 = 4;
constexpr int FL_2 = 5;
constexpr int RR_0 = 6;
constexpr int RR_1 = 7;
constexpr int RR_2 = 8;
constexpr int RL_0 = 9;
constexpr int RL_1 = 10;
constexpr int RL_2 = 11;
// Raw mirror of unitree_go BmsCmd (see get_crc for the copied sizes).
typedef struct
{
uint8_t off; // off 0xA5
std::array reserve;
} BmsCmd;
// Raw mirror of one unitree_go MotorCmd entry.
typedef struct
{
uint8_t mode; // desired working mode
float q; // desired angle (unit: radian)
float dq; // desired velocity (unit: radian/second)
float tau; // desired output torque (unit: N.m)
float Kp; // desired position stiffness (unit: N.m/rad )
float Kd; // desired velocity stiffness (unit: N.m/(rad/s) )
std::array reserve;
} MotorCmd; // motor control
// Raw mirror of unitree_go LowCmd. `crc` must stay the last member: get_crc
// hashes (sizeof(LowCmd) / 4) - 1 words, i.e. everything before it.
typedef struct
{
std::array head;
uint8_t levelFlag;
uint8_t frameReserve;
std::array SN;
std::array version;
uint16_t bandWidth;
std::array motorCmd;
BmsCmd bms;
std::array wirelessRemote;
std::array led;
std::array fan;
uint8_t gpio;
uint32_t reserve;
uint32_t crc;
} LowCmd;
// CRC over `len` 32-bit words starting at `ptr` (implemented in motor_crc.cpp).
uint32_t crc32_core(uint32_t* ptr, uint32_t len);
// Computes the CRC of `msg` and stores it in msg.crc.
void get_crc(unitree_go::msg::LowCmd& msg);
#endif
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/include/common/motor_crc_hg.h
================================================
/*****************************************************************
Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.
******************************************************************/
// NOTE(review): this header uses the same include guard (_MOTOR_CRC_H_) as
// motor_crc.h. In a translation unit that includes both, whichever comes
// second is silently skipped. Both headers also typedef different MotorCmd /
// LowCmd structs, so they cannot coexist in one TU as written; if that is ever
// needed, give this header its own guard AND rename/namespace the structs.
#ifndef _MOTOR_CRC_H_
#define _MOTOR_CRC_H_
// CRC helper for unitree_hg (G1) low-level commands.
// get_crc() copies a unitree_hg::msg::LowCmd field by field into the raw
// LowCmd struct below and runs crc32_core() over it as 32-bit words, so the
// member order and sizes here are part of the CRC contract.
#include
#include
#include "rclcpp/rclcpp.hpp"
#include "unitree_hg/msg/low_cmd.hpp"
#include "unitree_hg/msg/motor_cmd.hpp"
// Raw mirror of one unitree_hg MotorCmd entry.
typedef struct
{
uint8_t mode; // desired working mode
float q; // desired angle (unit: radian)
float dq; // desired velocity (unit: radian/second)
float tau; // desired output torque (unit: N.m)
float Kp; // desired position stiffness (unit: N.m/rad )
float Kd; // desired velocity stiffness (unit: N.m/(rad/s) )
uint32_t reserve = 0;
} MotorCmd; // motor control
// Raw mirror of unitree_hg LowCmd (get_crc fills 35 motor entries). `crc`
// must stay last: get_crc hashes (sizeof(LowCmd) / 4) - 1 words before it.
typedef struct
{
uint8_t modePr;
uint8_t modeMachine;
std::array motorCmd;
std::array reserve;
uint32_t crc;
} LowCmd;
// CRC over `len` 32-bit words starting at `ptr` (implemented in motor_crc_hg.cpp).
uint32_t crc32_core(uint32_t *ptr, uint32_t len);
// Computes the CRC of `msg` and stores it in msg.crc.
void get_crc(unitree_hg::msg::LowCmd &msg);
#endif
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/include/common/ros2_sport_client.h
================================================
#ifndef _ROS2_SPORT_CLIENT_
#define _ROS2_SPORT_CLIENT_

#include <cstdint>
#include <vector>

#include "nlohmann/json.hpp"
#include "unitree_api/msg/request.hpp"

// Declarations in this header are byte-packed. push/pop keeps that packing
// local: a bare `#pragma pack(1)` stays in effect for every header and
// declaration that follows the #include in the including translation unit,
// silently changing their layout (and possibly breaking raw-memory users such
// as CRC code that hashes structs).
#pragma pack(push, 1)

// Sport-service API identifiers written to req.header.identity.api_id.
const int32_t ROBOT_SPORT_API_ID_DAMP = 1001;
const int32_t ROBOT_SPORT_API_ID_BALANCESTAND = 1002;
const int32_t ROBOT_SPORT_API_ID_STOPMOVE = 1003;
const int32_t ROBOT_SPORT_API_ID_STANDUP = 1004;
const int32_t ROBOT_SPORT_API_ID_STANDDOWN = 1005;
const int32_t ROBOT_SPORT_API_ID_RECOVERYSTAND = 1006;
const int32_t ROBOT_SPORT_API_ID_EULER = 1007;
const int32_t ROBOT_SPORT_API_ID_MOVE = 1008;
const int32_t ROBOT_SPORT_API_ID_SIT = 1009;
const int32_t ROBOT_SPORT_API_ID_RISESIT = 1010;
const int32_t ROBOT_SPORT_API_ID_SWITCHGAIT = 1011;
const int32_t ROBOT_SPORT_API_ID_TRIGGER = 1012;
const int32_t ROBOT_SPORT_API_ID_BODYHEIGHT = 1013;
const int32_t ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT = 1014;
const int32_t ROBOT_SPORT_API_ID_SPEEDLEVEL = 1015;
const int32_t ROBOT_SPORT_API_ID_HELLO = 1016;
const int32_t ROBOT_SPORT_API_ID_STRETCH = 1017;
const int32_t ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW = 1018;
const int32_t ROBOT_SPORT_API_ID_CONTINUOUSGAIT = 1019;
const int32_t ROBOT_SPORT_API_ID_CONTENT = 1020;
const int32_t ROBOT_SPORT_API_ID_WALLOW = 1021;
const int32_t ROBOT_SPORT_API_ID_DANCE1 = 1022;
const int32_t ROBOT_SPORT_API_ID_DANCE2 = 1023;
const int32_t ROBOT_SPORT_API_ID_GETBODYHEIGHT = 1024;
const int32_t ROBOT_SPORT_API_ID_GETFOOTRAISEHEIGHT = 1025;
const int32_t ROBOT_SPORT_API_ID_GETSPEEDLEVEL = 1026;
const int32_t ROBOT_SPORT_API_ID_SWITCHJOYSTICK = 1027;
const int32_t ROBOT_SPORT_API_ID_POSE = 1028;
const int32_t ROBOT_SPORT_API_ID_SCRAPE = 1029;
const int32_t ROBOT_SPORT_API_ID_FRONTFLIP = 1030;
const int32_t ROBOT_SPORT_API_ID_FRONTJUMP = 1031;
const int32_t ROBOT_SPORT_API_ID_FRONTPOUNCE = 1032;

// One waypoint of a TrajectoryFollow request (serialized to JSON by
// SportClient::TrajectoryFollow).
typedef struct
{
  float timeFromStart;
  float x;
  float y;
  float yaw;
  float vx;
  float vy;
  float vyaw;
} PathPoint;

// Request builders for the sport-mode API. Each method sets
// req.header.identity.api_id to the matching ROBOT_SPORT_API_ID_* value and,
// for APIs that take arguments, writes them to req.parameter as a JSON string.
// Nothing is sent by these methods; the caller publishes `req`.
class SportClient
{
public:
  /** @brief Damp. @api 1001 */
  void Damp(unitree_api::msg::Request &req);

  /** @brief BalanceStand. @api 1002 */
  void BalanceStand(unitree_api::msg::Request &req);

  /** @brief StopMove. @api 1003 */
  void StopMove(unitree_api::msg::Request &req);

  /** @brief StandUp. @api 1004 */
  void StandUp(unitree_api::msg::Request &req);

  /** @brief StandDown. @api 1005 */
  void StandDown(unitree_api::msg::Request &req);

  /** @brief RecoveryStand. @api 1006 */
  void RecoveryStand(unitree_api::msg::Request &req);

  /** @brief Euler: parameter {"x": roll, "y": pitch, "z": yaw}. @api 1007 */
  void Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw);

  /** @brief Move: parameter {"x": vx, "y": vy, "z": vyaw}. @api 1008 */
  void Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw);

  /** @brief Sit. @api 1009 */
  void Sit(unitree_api::msg::Request &req);

  /** @brief RiseSit. @api 1010 */
  void RiseSit(unitree_api::msg::Request &req);

  /** @brief SwitchGait: parameter {"data": d}. @api 1011 */
  void SwitchGait(unitree_api::msg::Request &req, int d);

  /** @brief Trigger. @api 1012 */
  void Trigger(unitree_api::msg::Request &req);

  /** @brief BodyHeight: parameter {"data": height}. @api 1013 */
  void BodyHeight(unitree_api::msg::Request &req, float height);

  /** @brief FootRaiseHeight: parameter {"data": height}. @api 1014 */
  void FootRaiseHeight(unitree_api::msg::Request &req, float height);

  /** @brief SpeedLevel: parameter {"data": level}. @api 1015 */
  void SpeedLevel(unitree_api::msg::Request &req, int level);

  /** @brief Hello. @api 1016 */
  void Hello(unitree_api::msg::Request &req);

  /** @brief Stretch. @api 1017 */
  void Stretch(unitree_api::msg::Request &req);

  /** @brief TrajectoryFollow: parameter is a JSON array of waypoints. @api 1018 */
  void TrajectoryFollow(unitree_api::msg::Request &req, std::vector<PathPoint> &path);

  /** @brief SwitchJoystick: parameter {"data": flag}. @api 1027 */
  void SwitchJoystick(unitree_api::msg::Request &req, bool flag);

  /** @brief ContinuousGait: parameter {"data": flag}. @api 1019 */
  void ContinuousGait(unitree_api::msg::Request &req, bool flag);

  /** @brief Wallow. @api 1021 */
  void Wallow(unitree_api::msg::Request &req);

  /** @brief Content. @api 1020 */
  void Content(unitree_api::msg::Request &req);

  /** @brief Pose: parameter {"data": flag}. @api 1028 */
  void Pose(unitree_api::msg::Request &req, bool flag);

  /** @brief Scrape. @api 1029 */
  void Scrape(unitree_api::msg::Request &req);

  /** @brief FrontFlip. @api 1030 */
  void FrontFlip(unitree_api::msg::Request &req);

  /** @brief FrontJump. @api 1031 */
  void FrontJump(unitree_api::msg::Request &req);

  /** @brief FrontPounce. @api 1032 */
  void FrontPounce(unitree_api::msg::Request &req);

  /** @brief Dance1. @api 1022 */
  void Dance1(unitree_api::msg::Request &req);

  /** @brief Dance2. @api 1023 */
  void Dance2(unitree_api::msg::Request &req);
};

#pragma pack(pop)

#endif
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/include/common/wireless_controller.h
================================================
#pragma once

#include <array>
#include <cstdint>

#include "unitree_go/msg/wireless_controller.hpp"

// Button indices into RemoteController::button. Each value is also the bit
// position of that button inside the 16-bit key word of the raw remote
// payload. The values are defined in wireless_controller.cpp.
class KeyMap {
public:
  static const int R1;
  static const int L1;
  static const int start;
  static const int select;
  static const int R2;
  static const int L2;
  static const int F1;
  static const int F2;
  static const int A;
  static const int B;
  static const int X;
  static const int Y;
  static const int up;
  static const int right;
  static const int down;
  static const int left;
};

// Decoded state of the wireless remote: four analog stick axes and sixteen
// button flags (each 0 or 1, the corresponding bit of the key word).
class RemoteController {
public:
  // All axes and buttons start at zero.
  RemoteController();

  // Decode the raw 40-byte wireless_remote payload.
  void set(const std::array<uint8_t, 40>& data);

  // Decode from a unitree_go WirelessController message (kept for compatibility).
  void set(const unitree_go::msg::WirelessController::SharedPtr msg);

  // Stick axes.
  double lx;
  double ly;
  double rx;
  double ry;
  // Button flags, indexed by KeyMap.
  int button[16];
};
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py
================================================
"""
HoloMotion ROS2 Launch Configuration
This module defines the ROS2 launch configuration for the HoloMotion humanoid robot control system.
It sets up a complete robotics pipeline including robot control, motion policy execution, and data recording
for the Unitree G1 humanoid robot.
The launch file coordinates three main components:
1. Main control node (C++) - Handles low-level robot control and communication
2. Policy node (Python) - Executes motion policies and high-level decision making
3. Recording process (optional) - Runs `ros2 bag record` to capture commands and robot state for analysis
Key Features:
- Configures network interface for robot communication
- Sets up CycloneDDS middleware with specific network interface
- Launches coordinated multi-node system with shared configuration
- Optionally records /lowcmd, /lowstate and /humanoid/action to timestamped MCAP bags (enable_recording:=true)
Author: HoloMotion Team
License: See project LICENSE file
"""
from datetime import datetime
import os
from launch import LaunchDescription
from launch.actions import SetEnvironmentVariable, DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
from ament_index_python.packages import get_package_share_directory
from launch.actions import ExecuteProcess
from launch.conditions import IfCondition
def generate_launch_description():
    """
    Generate the launch description for the HoloMotion humanoid control system.

    The returned description contains, in order:

    1. The ``enable_recording`` launch argument (default ``"false"``).
    2. ``CYCLONEDDS_URI`` set to an inline CycloneDDS configuration that binds
       DDS traffic to the ``eth0`` network interface.
    3. Main control node (C++ executable ``humanoid_control``, node name
       ``main_node``).
    4. Policy node (executable ``policy_node_29dof``, node name
       ``policy_node``). If the ``Deploy_CONDA_PREFIX`` environment variable is
       set, the node is run with ``$Deploy_CONDA_PREFIX/bin/python``;
       otherwise the default interpreter is used.
    5. ``ros2 bag record`` (MCAP storage) of ``/lowcmd``, ``/lowstate`` and
       ``/humanoid/action`` into ``./bag_record/<timestamp>_<config name>``,
       started only when ``enable_recording`` is true.

    Both nodes receive the ``config_path`` parameter pointing at
    ``<share>/humanoid_control/config/g1_29dof_holomotion.yaml``.

    Returns:
        LaunchDescription: the actions listed above.

    Raises:
        ament_index_python.packages.PackageNotFoundError: if the
            ``humanoid_control`` package cannot be found (not built/sourced).
        FileNotFoundError: if the configuration file is not installed.

    Example:
        Launch without recording (default):
            $ ros2 launch humanoid_control holomotion_29dof_launch.py
        Launch with recording enabled:
            $ ros2 launch humanoid_control holomotion_29dof_launch.py enable_recording:=true
        Or using the shell script:
            $ ./launch_holomotion_29dof.sh --record
    """
    # Declare launch arguments
    enable_recording_arg = DeclareLaunchArgument(
        "enable_recording",
        default_value="false",
        description="Enable topic recording (true/false)",
    )

    network_interface = "eth0"
    config_name = "g1_29dof_holomotion.yaml"
    pkg_dir = get_package_share_directory("humanoid_control")
    config_file = os.path.join(pkg_dir, "config", config_name)
    # Fail at launch time with a clear message instead of inside each node.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"HoloMotion config file not found: {config_file}")

    # CYCLONEDDS_URI must be a config file URI or inline XML configuration; a
    # bare interface name is neither. NetworkInterfaceAddress is accepted by
    # older and (as a deprecated alias) newer CycloneDDS releases.
    cyclonedds_uri = (
        "<CycloneDDS><Domain><General>"
        f"<NetworkInterfaceAddress>{network_interface}</NetworkInterfaceAddress>"
        "</General></Domain></CycloneDDS>"
    )

    # Allow overriding the python interpreter via env var (set by the shell
    # script). When it is absent, fall back to the default interpreter instead
    # of crashing with KeyError.
    conda_prefix = os.environ.get("Deploy_CONDA_PREFIX")
    if conda_prefix:
        python_executable = os.path.join(conda_prefix, "bin", "python")
        print(f"Using Python executable: {python_executable}")
    else:
        python_executable = None  # Node's default: no prefix
        print("Deploy_CONDA_PREFIX is not set; using the default Python interpreter")

    # Bag directory name is fixed when the description is generated.
    bag_output = (
        "./bag_record/"
        + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        + "_"
        + config_name.split(".")[0]
    )

    return LaunchDescription(
        [
            # Declare launch arguments
            enable_recording_arg,
            # Must precede the nodes so their processes inherit it.
            SetEnvironmentVariable(
                name="CYCLONEDDS_URI",
                value=cyclonedds_uri,
            ),
            # Main control node (C++)
            Node(
                package="humanoid_control",
                executable="humanoid_control",
                name="main_node",
                parameters=[{"config_path": config_file}],
                output="screen",
            ),
            # Policy node (Python)
            Node(
                package="humanoid_control",
                executable="policy_node_29dof",
                name="policy_node",
                parameters=[{"config_path": config_file}],
                output="screen",
                prefix=python_executable,
            ),
            # Recording process (conditional)
            ExecuteProcess(
                cmd=[
                    "ros2",
                    "bag",
                    "record",
                    "--storage",
                    "mcap",
                    "-o",
                    bag_output,
                    "/lowcmd",
                    "/lowstate",
                    "/humanoid/action",
                ],
                output="screen",
                condition=IfCondition(LaunchConfiguration("enable_recording")),
            ),
        ]
    )
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/models/.gitkeep
================================================
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/motion_data/.gitkeep
================================================
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/package.xml
================================================
humanoid_control
0.0.0
Humanoid locomotion control package from Horizon Robotics
unitree
Apache-2.0
ament_cmake
ament_cmake_python
rclcpp
sensor_msgs
unitree_hg
rclpy
python3-numpy
python3-torch
python3-yaml
ros2launch
ament_copyright
ament_flake8
ament_pep257
python3-pytest
ament_cmake
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/resource/humanoid_control
================================================
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/setup.cfg
================================================
[develop]
script_dir=$base/lib/humanoid_control
[install]
install_scripts=$base/lib/humanoid_control
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/setup.py
================================================
from setuptools import setup, find_packages
import os
package_name = "humanoid_control"


def _collect_data_files(dir_names):
    """
    Return setuptools ``data_files`` entries for every directory under ``dir_names``.

    Each walked directory becomes one ``(share/<package>/<dir>, [files])``
    entry, mirroring the source layout. Missing top-level directories are
    skipped.
    """
    entries = []
    for top in dir_names:
        if not os.path.exists(top):  # Only process if directory exists
            continue
        for root, _subdirs, filenames in os.walk(top):
            entries.append(
                (
                    os.path.join("share", package_name, root),
                    [os.path.join(root, name) for name in filenames],
                )
            )
    return entries


# ament resource marker + package manifest, then config/launch/model files.
data_files = [
    (
        "share/ament_index/resource_index/packages",
        ["resource/" + package_name],
    ),
    ("share/" + package_name, ["package.xml"]),
] + _collect_data_files(["config", "launch", "models"])

# NOTE(review): the console_scripts entry below points at
# humanoid_policy.policy_node_performance, but the repository tree lists only
# humanoid_policy/policy_node_29dof.py (the launch file runs
# "policy_node_29dof"). package.xml also declares ament_cmake as the build
# type, so this setup.py may be unused -- confirm before relying on it.
setup(
    name=package_name,
    version="0.0.1",
    packages=find_packages(),
    data_files=data_files,
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="Horizon Robotics",
    maintainer_email="maiyue01.chen@horizon.auto",
    description="Humanoid locomotion control package from Horizon Robotics",
    license="Apache License 2.0",
    tests_require=["pytest"],
    entry_points={
        "console_scripts": [
            "policy_node_performance = humanoid_policy.policy_node_performance:main",
        ],
    },
)
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc.cpp
================================================
#include "motor_crc.h"
// Compute the CRC of a unitree_go LowCmd message and store it in msg.crc.
// The message is first copied field by field into the raw LowCmd struct
// (motor_crc.h) so the CRC runs over a fixed memory layout; the memcpy byte
// counts below must match the raw struct's array sizes.
void get_crc(unitree_go::msg::LowCmd& msg)
{
// Value-initialized so padding/unused bytes hash as zero.
LowCmd raw{};
memcpy(&raw.head[0], &msg.head[0], 2);
raw.levelFlag=msg.level_flag;
raw.frameReserve=msg.frame_reserve;
memcpy(&raw.SN[0],&msg.sn[0], 8);
memcpy(&raw.version[0], &msg.version[0], 8);
raw.bandWidth=msg.bandwidth;
// 20 motor command slots are copied.
for(int i = 0; i<20; i++)
{
raw.motorCmd[i].mode=msg.motor_cmd[i].mode;
raw.motorCmd[i].q=msg.motor_cmd[i].q;
raw.motorCmd[i].dq=msg.motor_cmd[i].dq;
raw.motorCmd[i].tau=msg.motor_cmd[i].tau;
raw.motorCmd[i].Kp=msg.motor_cmd[i].kp;
raw.motorCmd[i].Kd=msg.motor_cmd[i].kd;
memcpy(&raw.motorCmd[i].reserve[0], &msg.motor_cmd[i].reserve[0], 12);
}
raw.bms.off=msg.bms_cmd.off;
memcpy(&raw.bms.reserve[0],&msg.bms_cmd.reserve[0], 3);
memcpy(&raw.wirelessRemote[0], &msg.wireless_remote[0], 40);
memcpy(&raw.led[0], &msg.led[0], 12); // go2
memcpy(&raw.fan[0], &msg.fan[0], 2);
raw.gpio=msg.gpio; // go2
raw.reserve=msg.reserve;
// Hash every 32-bit word except the trailing crc member itself.
raw.crc=crc32_core((uint32_t *)&raw, (sizeof(LowCmd)>>2)-1);
msg.crc=raw.crc;
}
// Bit-serial CRC over `len` 32-bit words starting at `ptr`.
// Parameters (all visible below): polynomial 0x04C11DB7, initial value
// 0xFFFFFFFF, each word processed most-significant bit first, no input/output
// reflection and no final XOR. len == 0 returns the initial value.
uint32_t crc32_core(uint32_t* ptr, uint32_t len)
{
  const uint32_t kPolynomial = 0x04c11db7u;
  uint32_t crc = 0xFFFFFFFFu;
  for (uint32_t i = 0; i < len; i++)
  {
    const uint32_t data = ptr[i];
    // Unsigned mask: the former `1 << 31` shifted into the sign bit of a
    // signed int (undefined before C++20 rules). Iterates 32 times, MSB first.
    for (uint32_t mask = 0x80000000u; mask != 0; mask >>= 1)
    {
      if (crc & 0x80000000u)
      {
        crc <<= 1;
        crc ^= kPolynomial;
      }
      else
      {
        crc <<= 1;
      }
      if (data & mask)
      {
        crc ^= kPolynomial;
      }
    }
  }
  return crc;
}
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/motor_crc_hg.cpp
================================================
#include "motor_crc_hg.h"
// Compute the CRC of a unitree_hg LowCmd message and store it in msg.crc.
// The message is copied field by field into the raw LowCmd struct
// (motor_crc_hg.h) and crc32_core() runs over every 32-bit word of it except
// the trailing crc member.
void get_crc(unitree_hg::msg::LowCmd &msg)
{
  // Value-initialized so padding/unused bytes hash as zero.
  LowCmd raw{};
  raw.modePr = msg.mode_pr;
  raw.modeMachine = msg.mode_machine;
  for (int i = 0; i < 35; i++)
  {
    raw.motorCmd[i].mode = msg.motor_cmd[i].mode;
    raw.motorCmd[i].q = msg.motor_cmd[i].q;
    raw.motorCmd[i].dq = msg.motor_cmd[i].dq;
    raw.motorCmd[i].tau = msg.motor_cmd[i].tau;
    raw.motorCmd[i].Kp = msg.motor_cmd[i].kp;
    raw.motorCmd[i].Kd = msg.motor_cmd[i].kd;
    raw.motorCmd[i].reserve = msg.motor_cmd[i].reserve;
  }
  // Copy the whole reserved block. The previous fixed 4-byte copy covered
  // only the first element if `reserve` holds 32-bit words, so non-zero data
  // in the remaining elements was left out of the CRC. Copying the overlap of
  // both buffers can never overrun either side.
  const size_t reserve_bytes = sizeof(raw.reserve) < sizeof(msg.reserve)
                                   ? sizeof(raw.reserve)
                                   : sizeof(msg.reserve);
  memcpy(&raw.reserve[0], &msg.reserve[0], reserve_bytes);
  raw.crc = crc32_core((uint32_t *)&raw, (sizeof(LowCmd) >> 2) - 1);
  msg.crc = raw.crc;
}
// Bit-serial CRC over `len` 32-bit words starting at `ptr`.
// Parameters (all visible below): polynomial 0x04C11DB7, initial value
// 0xFFFFFFFF, each word processed most-significant bit first, no input/output
// reflection and no final XOR. len == 0 returns the initial value.
// NOTE(review): motor_crc.cpp defines an identical external crc32_core; if
// both sources are ever linked into one target this is a duplicate symbol.
uint32_t crc32_core(uint32_t *ptr, uint32_t len)
{
  const uint32_t kPolynomial = 0x04c11db7u;
  uint32_t crc = 0xFFFFFFFFu;
  for (uint32_t i = 0; i < len; i++)
  {
    const uint32_t data = ptr[i];
    // Unsigned mask: the former `1 << 31` shifted into the sign bit of a
    // signed int (undefined before C++20 rules). Iterates 32 times, MSB first.
    for (uint32_t mask = 0x80000000u; mask != 0; mask >>= 1)
    {
      if (crc & 0x80000000u)
      {
        crc <<= 1;
        crc ^= kPolynomial;
      }
      else
      {
        crc <<= 1;
      }
      if (data & mask)
      {
        crc ^= kPolynomial;
      }
    }
  }
  return crc;
}
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/ros2_sport_client.cpp
================================================
#include "ros2_sport_client.h"
// --- Parameterless sport APIs: only the api_id is set. ---

void SportClient::Damp(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_DAMP; }

void SportClient::BalanceStand(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_BALANCESTAND; }

void SportClient::StopMove(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STOPMOVE; }

void SportClient::StandUp(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDUP; }

void SportClient::StandDown(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STANDDOWN; }

void SportClient::RecoveryStand(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_RECOVERYSTAND; }

// --- APIs with a JSON parameter payload. ---

// Body attitude: {"x": roll, "y": pitch, "z": yaw}.
void SportClient::Euler(unitree_api::msg::Request &req, float roll, float pitch, float yaw)
{
  req.header.identity.api_id = ROBOT_SPORT_API_ID_EULER;
  nlohmann::json params;
  params["x"] = roll;
  params["y"] = pitch;
  params["z"] = yaw;
  req.parameter = params.dump();
}

// Velocity command: {"x": vx, "y": vy, "z": vyaw}.
void SportClient::Move(unitree_api::msg::Request &req, float vx, float vy, float vyaw)
{
  req.header.identity.api_id = ROBOT_SPORT_API_ID_MOVE;
  nlohmann::json params;
  params["x"] = vx;
  params["y"] = vy;
  params["z"] = vyaw;
  req.parameter = params.dump();
}

void SportClient::Sit(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_SIT; }

void SportClient::RiseSit(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_RISESIT; }

// Gait selection: {"data": d}.
void SportClient::SwitchGait(unitree_api::msg::Request &req, int d)
{
  nlohmann::json params;
  params["data"] = d;
  req.parameter = params.dump();
  req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHGAIT;
}

void SportClient::Trigger(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_TRIGGER; }

// {"data": height}.
void SportClient::BodyHeight(unitree_api::msg::Request &req, float height)
{
  req.header.identity.api_id = ROBOT_SPORT_API_ID_BODYHEIGHT;
  nlohmann::json params;
  params["data"] = height;
  req.parameter = params.dump();
}

// {"data": height}.
void SportClient::FootRaiseHeight(unitree_api::msg::Request &req, float height)
{
  req.header.identity.api_id = ROBOT_SPORT_API_ID_FOOTRAISEHEIGHT;
  nlohmann::json params;
  params["data"] = height;
  req.parameter = params.dump();
}

// {"data": level}.
void SportClient::SpeedLevel(unitree_api::msg::Request &req, int level)
{
  req.header.identity.api_id = ROBOT_SPORT_API_ID_SPEEDLEVEL;
  nlohmann::json params;
  params["data"] = level;
  req.parameter = params.dump();
}

void SportClient::Hello(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_HELLO; }

void SportClient::Stretch(unitree_api::msg::Request &req) { req.header.identity.api_id = ROBOT_SPORT_API_ID_STRETCH; }
void SportClient::TrajectoryFollow(unitree_api::msg::Request &req, std::vector &path)
{
nlohmann::json js_path;
req.header.identity.api_id = ROBOT_SPORT_API_ID_TRAJECTORYFOLLOW;
for (int i = 0; i < 30; i++)
{
nlohmann::json js_point;
js_point["t_from_start"] = path[i].timeFromStart;
js_point["x"] = path[i].x;
js_point["y"] = path[i].y;
js_point["yaw"] = path[i].yaw;
js_point["vx"] = path[i].vx;
js_point["vy"] = path[i].vy;
js_point["vyaw"] = path[i].vyaw;
js_path.push_back(js_point);
}
req.parameter =js_path.dump();
}
void SportClient::SwitchJoystick(unitree_api::msg::Request &req, bool flag)
{
nlohmann::json js;
js["data"] = flag;
req.parameter = js.dump();
req.header.identity.api_id = ROBOT_SPORT_API_ID_SWITCHJOYSTICK;
}
void SportClient::ContinuousGait(unitree_api::msg::Request &req, bool flag)
{
nlohmann::json js;
js["data"] = flag;
req.parameter = js.dump();
req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTINUOUSGAIT;
}
void SportClient::Wallow(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_WALLOW;
}
void SportClient::Content(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_CONTENT;
}
void SportClient::Pose(unitree_api::msg::Request &req, bool flag)
{
nlohmann::json js;
js["data"] = flag;
req.parameter = js.dump();
req.header.identity.api_id = ROBOT_SPORT_API_ID_POSE;
}
void SportClient::Scrape(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_SCRAPE;
}
void SportClient::FrontFlip(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTFLIP;
}
void SportClient::FrontJump(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTJUMP;
}
void SportClient::FrontPounce(unitree_api::msg::Request &req)
{
req.header.identity.api_id = ROBOT_SPORT_API_ID_FRONTPOUNCE;
}
// Fill `req` for the Dance1 action (API id only, no JSON parameter).
void SportClient::Dance1(unitree_api::msg::Request &req) {
  auto &identity = req.header.identity;
  identity.api_id = ROBOT_SPORT_API_ID_DANCE1;
}
// Fill `req` for the Dance2 action (API id only, no JSON parameter).
void SportClient::Dance2(unitree_api::msg::Request &req) {
  auto &identity = req.header.identity;
  identity.api_id = ROBOT_SPORT_API_ID_DANCE2;
}
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/common/wireless_controller.cpp
================================================
#include "common/wireless_controller.h"
#include
// Out-of-class definitions of KeyMap's static constants. Each value is the
// bit position of that button in the 16-bit key mask; RemoteController::set()
// stores bit i into button[i], so callers test button[KeyMap::X].
const int KeyMap::R1 = 0;
const int KeyMap::L1 = 1;
const int KeyMap::start = 2;
const int KeyMap::select = 3;
const int KeyMap::R2 = 4;
const int KeyMap::L2 = 5;
const int KeyMap::F1 = 6;
const int KeyMap::F2 = 7;
const int KeyMap::A = 8;
const int KeyMap::B = 9;
const int KeyMap::X = 10;
const int KeyMap::Y = 11;
// D-pad directions occupy the top four bits.
const int KeyMap::up = 12;
const int KeyMap::right = 13;
const int KeyMap::down = 14;
const int KeyMap::left = 15;
// RemoteController method implementations.
//
// Default state: all four stick axes at 0 and all 16 button flags cleared.
RemoteController::RemoteController() {
  lx = ly = rx = ry = 0;
  std::fill_n(button, 16, 0);
}
void RemoteController::set(const std::array &data) {
// Debug print raw bytes
// printf("Raw data bytes: ");
// for (int i = 0; i < 40; i++) {
// printf("%02x ", data[i]);
// }
// printf("\n");
// Extract keys from bytes 2-3
uint16_t keys = (data[3] << 8) | data[2];
// printf("Keys value: 0x%04x\n", keys);
for (int i = 0; i < 16; i++) {
button[i] = (keys & (1 << i)) >> i;
}
// Extract and print floats before memcpy
float lx_temp, rx_temp, ry_temp, ly_temp;
std::memcpy(&lx_temp, &data[4], 4); // bytes 4-7
std::memcpy(&rx_temp, &data[8], 4); // bytes 8-11
std::memcpy(&ry_temp, &data[12], 4); // bytes 12-15
std::memcpy(&ly_temp, &data[20], 4); // bytes 20-23
// printf("Values before assignment: lx=%f, ly=%f, rx=%f, ry=%f\n", lx_temp,
// ly_temp, rx_temp, ry_temp);
// Assign to class members
lx = lx_temp;
rx = rx_temp;
ry = ry_temp;
ly = ly_temp;
// printf("Values after assignment: lx=%f, ly=%f, rx=%f, ry=%f\n", lx, ly,
// rx,
// ry);
}
// Update from an already-decoded WirelessController message: unpack the
// 16-bit key mask into button[0..15] (bit i -> button[i]) and copy the four
// stick axes.
void RemoteController::set(
    const unitree_go::msg::WirelessController::SharedPtr msg) {
  const uint16_t mask = msg->keys;
  for (int bit = 0; bit < 16; ++bit) {
    button[bit] = (mask >> bit) & 1;
  }
  lx = msg->lx;
  ly = msg->ly;
  rx = msg->rx;
  ry = msg->ry;
}
================================================
FILE: deployment/unitree_g1_ros2_29dof/src/src/main_node.cpp
================================================
/**
 * This example demonstrates how to use ROS 2 to send low-level motor commands
 * to the 29-DoF Unitree G1 robot.
 **/
#include "common/motor_crc_hg.h"
#include "common/wireless_controller.h"
#include "rclcpp/rclcpp.hpp"
#include "unitree_go/msg/wireless_controller.hpp"
#include "unitree_hg/msg/low_cmd.hpp"
#include "unitree_hg/msg/low_state.hpp"
#include "unitree_hg/msg/motor_cmd.hpp"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include <stdexcept>
#define INFO_IMU 0 // Debug switch: set 1 to log IMU states (NOTE(review): not referenced in the code shown here)
#define INFO_MOTOR 0 // Debug switch: set 1 to log motor states (NOTE(review): not referenced in the code shown here)
// Value written to LowCmd.mode_pr; the controller always uses PR.
// NOTE(review): presumably selects the ankle control convention — confirm against Unitree hg docs.
enum PRorAB { PR = 0, AB = 1 };
using std::placeholders::_1;
// Number of motor entries read from LowState and commanded in LowCmd.
const int G1_NUM_MOTOR = 29;
// States of the controller's remote-driven state machine (see Control()).
enum class RobotState { ZERO_TORQUE, MOVE_TO_DEFAULT, EMERGENCY_STOP, POLICY };
// Emergency stop runs in two phases: DAMPING (kp = 0, damping only) for a
// fixed duration, then DISABLE (motors off and node shutdown).
enum class EmergencyStopPhase { DAMPING, DISABLE };
// humanoid_controller: ROS 2 node that publishes Unitree hg LowCmd messages for
// the G1 (G1_NUM_MOTOR motors), driven by a remote-controlled state machine:
//   ZERO_TORQUE     - all gains zero; start button -> MOVE_TO_DEFAULT
//   MOVE_TO_DEFAULT - interpolate to the default pose over duration_ seconds;
//                     A button -> POLICY once policy gains have arrived and
//                     the lower-body joints are close to the default pose
//   POLICY          - track /humanoid/action targets with policy kp/kd
//   EMERGENCY_STOP  - select button or an error: damp, then disable and exit
// L1 returns to ZERO_TORQUE from any non-emergency state.
class humanoid_controller : public rclcpp::Node {
public:
  // Reads the "config_path" parameter, loads the YAML config, then creates the
  // subscriptions, publishers and the periodic control timer.
  //
  // Throws std::runtime_error if the config does not yield a positive control
  // period. Starting a timer with an undefined period is unsafe on a robot.
  humanoid_controller() : Node("humanoid_controller") {
    RCLCPP_INFO(this->get_logger(), "Using main_node !!!");
    // Get config path from ROS parameter
    std::string config_path =
        this->declare_parameter("config_path", "");
    RCLCPP_INFO(this->get_logger(), "Config file path: %s",
                config_path.c_str());
    // Load configuration
    loadConfig(config_path);
    // loadConfig() logs and swallows YAML errors, so control_freq may never
    // have been read. Fail fast instead of running with a garbage period.
    if (timer_dt <= 0) {
      RCLCPP_FATAL(this->get_logger(),
                   "Invalid control period (%d ms) derived from config %s; "
                   "check control_freq",
                   timer_dt, config_path.c_str());
      throw std::runtime_error(
          "humanoid_controller: invalid or missing control_freq");
    }
    RCLCPP_INFO(this->get_logger(),
                "Entered ZERO_TORQUE state, press start to switch to "
                "MOVE_TO_DEFAULT state, press A to switch to POLICY state, "
                "press select to emergency stop. Waiting for start signal...");
    lowstate_subscriber_ = this->create_subscription(
        "/lowstate", 10,
        std::bind(&humanoid_controller::LowStateHandler, this, _1));
    policy_action_subscriber_ =
        this->create_subscription(
            "/humanoid/action", 10,
            std::bind(&humanoid_controller::PolicyActionHandler, this, _1));
    // Add subscribers for kps and kds parameters from policy node
    kps_subscriber_ =
        this->create_subscription(
            "/humanoid/kps", 10,
            std::bind(&humanoid_controller::KpsHandler, this, _1));
    kds_subscriber_ =
        this->create_subscription(
            "/humanoid/kds", 10,
            std::bind(&humanoid_controller::KdsHandler, this, _1));
    lowcmd_publisher_ =
        this->create_publisher("/lowcmd", 10);
    robot_state_publisher_ =
        this->create_publisher("/robot_state", 10);
    timer_ =
        this->create_wall_timer(std::chrono::milliseconds(timer_dt),
                                std::bind(&humanoid_controller::Control, this));
    time_ = 0;
    duration_ = 3; // 3 s
  }

private:
  std::map dof2motor_idx;
  std::map default_dof_pos;
  std::map kps;
  std::map kds;
  std::vector complete_dof_order;
  std::vector policy_dof_order;
  // Policy-subscribed control parameters
  std::vector policy_kps_data;
  std::vector policy_kds_data;
  bool kps_received_ = false;
  bool kds_received_ = false;
  RemoteController remote_controller;
  std::map target_dof_pos;
  std::vector policy_action_data;
  // Optional arrays from YAML for start (MOVE_TO_DEFAULT) behavior
  bool has_joint_arrays_ = false;
  std::vector joint_names_array_;
  std::vector default_position_array_;
  std::vector kp_array_;
  std::vector kd_array_;
  std::map move_to_default_kps;
  std::map move_to_default_kds;
  RobotState current_state_ = RobotState::ZERO_TORQUE;
  bool should_shutdown_ = false;
  // Add safety limit parameters using existing structure in YAML
  std::map> joint_position_limits; // min, max
  std::map joint_velocity_limits;
  std::map joint_effort_limits;
  // Scaling coefficients for limits
  double position_limit_scale = 1.0;
  double velocity_limit_scale = 1.0;
  double effort_limit_scale = 1.0;
  EmergencyStopPhase emergency_stop_phase_ = EmergencyStopPhase::DAMPING;
  double emergency_stop_time_ = 0.0;
  double emergency_damping_duration_ = 2.0; // seconds of damping before disabling

  // True when idx addresses an entry of motor[] (and a commanded motor).
  static bool isValidMotorIndex(int idx) {
    return idx >= 0 && idx < G1_NUM_MOTOR;
  }

  // Number of leading motor indices that have a name in complete_dof_order.
  // Loops that index complete_dof_order by motor index stop here to stay in
  // range when the configured order is shorter than G1_NUM_MOTOR.
  int numIndexedJoints() const {
    return std::min(G1_NUM_MOTOR, static_cast<int>(complete_dof_order.size()));
  }

  // Expected PD torque for a joint using the policy gains (kps/kds):
  // kp * (q_des - q) + kd * (0 - dq); the desired velocity is always 0.
  double calculateExpectedTorque(const std::string& dof_name, double q_des, double q, double dq) {
    double kp = kps[dof_name];
    double kd = kds[dof_name];
    // dq_des is assumed to be 0 in your control scheme
    return kp * (q_des - q) + kd * (0.0 - dq);
  }

  // Returns the policy (kp, kd) for dof_name, scaled down together when the
  // expected torque would exceed effort_limit * effort_limit_scale.
  std::pair limitTorque(const std::string& dof_name, double q_des, double q, double dq) {
    double kp = kps[dof_name];
    double kd = kds[dof_name];
    // Calculate expected torque
    double expected_torque = calculateExpectedTorque(dof_name, q_des, q, dq);
    double abs_expected_torque = std::abs(expected_torque);
    // Check if torque would exceed limit
    if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) {
      double max_torque = joint_effort_limits[dof_name] * effort_limit_scale;
      if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) {
        // Scale both kp and kd by the same factor to preserve damping characteristics
        double scale_factor = max_torque / abs_expected_torque;
        return std::make_pair(kp * scale_factor, kd * scale_factor);
      }
    }
    // If no scaling needed, return original values
    return std::make_pair(kp, kd);
  }

  // Same as limitTorque() but for caller-supplied gains. Returns
  // (custom_kp, custom_kd), scaled down together when the expected torque
  // would exceed the joint's scaled effort limit.
  std::pair limitTorqueWithCustomGains(
      const std::string& dof_name,
      double q_des,
      double q,
      double dq,
      double custom_kp,
      double custom_kd) {
    // Calculate expected torque
    double expected_torque = custom_kp * (q_des - q) + custom_kd * (0.0 - dq);
    double abs_expected_torque = std::abs(expected_torque);
    // Check if torque would exceed limit
    if (joint_effort_limits.find(dof_name) != joint_effort_limits.end()) {
      double max_torque = joint_effort_limits[dof_name] * effort_limit_scale;
      if (abs_expected_torque > max_torque && abs_expected_torque > 1e-6) {
        // Scale both kp and kd by the same factor to preserve damping characteristics
        double scale_factor = max_torque / abs_expected_torque;
        return std::make_pair(custom_kp * scale_factor, custom_kd * scale_factor);
      }
    }
    // If no scaling needed, return original values
    return std::make_pair(custom_kp, custom_kd);
  }

  // Loads mappings, default pose, DOF orders, control frequency, joint limits
  // and optional MOVE_TO_DEFAULT arrays from the YAML file at config_path.
  // YAML errors are logged and swallowed, so later sections may remain
  // unloaded. The constructor checks that a valid control period was set.
  void loadConfig(const std::string &config_path) {
    try {
      YAML::Node config = YAML::LoadFile(config_path);
      // Load motor indices
      auto indices = config["dof2motor_idx_mapping"];
      for (const auto &it : indices) {
        dof2motor_idx[it.first.as()] = it.second.as();
      }
      // Load default angles
      auto angles = config["default_joint_angles"];
      for (const auto &it : angles) {
        default_dof_pos[it.first.as()] = it.second.as();
      }
      // Set target dof pos to default dof pos
      for (const auto &it : default_dof_pos) {
        target_dof_pos[it.first] = it.second;
      }
      // Note: kps and kds are now received from policy node via ROS topics
      // No longer loading from config file to avoid conflicts
      // Load dof order
      for (const auto &it : config["complete_dof_order"]) {
        complete_dof_order.push_back(it.as());
      }
      for (const auto &it : config["policy_dof_order"]) {
        policy_dof_order.push_back(it.as());
      }
      // Policy joints without a motor mapping are skipped by
      // SendPolicyCommand(); warn once here so the misconfiguration is visible.
      for (const auto &name : policy_dof_order) {
        if (dof2motor_idx.find(name) == dof2motor_idx.end()) {
          RCLCPP_WARN(this->get_logger(),
                      "Policy joint %s has no entry in dof2motor_idx_mapping; it will not be commanded",
                      name.c_str());
        }
      }
      // Load control frequency
      control_freq_ = config["control_freq"].as();
      if (control_freq_ > 0.0) {
        control_dt_ = 1.0 / control_freq_;
        timer_dt = static_cast(control_dt_ * 1000);
      } else {
        RCLCPP_ERROR(this->get_logger(), "control_freq must be positive, got %f",
                     control_freq_);
      }
      RCLCPP_INFO(this->get_logger(), "Control frequency set to: %f Hz",
                  control_freq_);
      // Load joint limits
      auto pos_limits = config["joint_limits"]["position"];
      for (const auto &it : pos_limits) {
        std::string dof_name = it.first.as();
        auto limits = it.second.as>();
        joint_position_limits[dof_name] = std::make_pair(limits[0], limits[1]);
      }
      auto vel_limits = config["joint_limits"]["velocity"];
      for (const auto &it : vel_limits) {
        joint_velocity_limits[it.first.as()] = it.second.as();
      }
      auto effort_limits = config["joint_limits"]["effort"];
      for (const auto &it : effort_limits) {
        joint_effort_limits[it.first.as()] = it.second.as();
      }
      // Load joint limits scaling coefficients (optional, default to 1.0)
      position_limit_scale = config["limit_scales"]["position"].as(1.0);
      velocity_limit_scale = config["limit_scales"]["velocity"].as(1.0);
      effort_limit_scale = config["limit_scales"]["effort"].as(1.0);
      RCLCPP_INFO(this->get_logger(), "Joint limit scales - Position: %f, Velocity: %f, Effort: %f",
                  position_limit_scale, velocity_limit_scale, effort_limit_scale);
      // Optional: arrays for joint configuration on Start
      // If kp and kd arrays are provided, use them with joint names and positions
      // Auto-generate joint_names and default_position from complete_dof_order and default_joint_angles if not provided
      if (config["kp"] && config["kd"]) {
        joint_names_array_.clear();
        default_position_array_.clear();
        kp_array_.clear();
        kd_array_.clear();
        // Auto-generate joint_names and default_position from existing config if not explicitly provided
        if (config["joint_names"] && config["default_position"]) {
          // Use explicitly provided arrays
          for (const auto &it : config["joint_names"]) {
            joint_names_array_.push_back(it.as());
          }
          for (const auto &it : config["default_position"]) {
            default_position_array_.push_back(it.as());
          }
        } else {
          // Auto-generate from complete_dof_order and default_joint_angles
          for (const auto &dof_name : complete_dof_order) {
            joint_names_array_.push_back(dof_name);
            if (default_dof_pos.find(dof_name) != default_dof_pos.end()) {
              default_position_array_.push_back(default_dof_pos[dof_name]);
            } else {
              RCLCPP_WARN(this->get_logger(), "Default position not found for joint %s, using 0.0", dof_name.c_str());
              default_position_array_.push_back(0.0);
            }
          }
          RCLCPP_INFO(this->get_logger(), "Auto-generated joint_names and default_position from complete_dof_order and default_joint_angles");
        }
        // Load kp and kd arrays
        for (const auto &it : config["kp"]) {
          kp_array_.push_back(it.as());
        }
        for (const auto &it : config["kd"]) {
          kd_array_.push_back(it.as());
        }
        // Basic validation
        if (joint_names_array_.size() == default_position_array_.size() &&
            joint_names_array_.size() == kp_array_.size() &&
            joint_names_array_.size() == kd_array_.size()) {
          has_joint_arrays_ = true;
          // Store MoveToDefault-specific kps/kds and default positions
          for (size_t i = 0; i < joint_names_array_.size(); ++i) {
            const std::string &name = joint_names_array_[i];
            double pos = default_position_array_[i];
            double kp_v = kp_array_[i];
            double kd_v = kd_array_[i];
            default_dof_pos[name] = pos;
            // Store MoveToDefault kp/kd
            move_to_default_kps[name] = kp_v;
            move_to_default_kds[name] = kd_v;
          }
          RCLCPP_INFO(this->get_logger(), "Using joint arrays for Start behavior (size: %zu)", joint_names_array_.size());
        } else {
          RCLCPP_WARN(this->get_logger(), "joint_names/default_position/kp/kd size mismatch; ignoring arrays");
        }
      }
    } catch (const YAML::Exception &e) {
      RCLCPP_ERROR(this->get_logger(), "Error parsing config file: %s",
                   e.what());
    }
  }

  // Timer callback (every timer_dt ms): handles remote-controller state
  // transitions, then builds, CRCs and publishes one LowCmd for the current
  // state. In EMERGENCY_STOP it damps for emergency_damping_duration_
  // seconds, publishes a final all-disabled command, and shuts ROS down.
  void Control() {
    // First check if we're already in emergency stop
    if (current_state_ == RobotState::EMERGENCY_STOP) {
      emergency_stop_time_ += control_dt_;
      if (emergency_stop_phase_ == EmergencyStopPhase::DAMPING) {
        SendDampedEmergencyStop();
        if (emergency_stop_time_ >= emergency_damping_duration_) {
          emergency_stop_phase_ = EmergencyStopPhase::DISABLE;
          RCLCPP_INFO(this->get_logger(), "Damping complete, disabling motors");
        }
      } else {
        SendFinalEmergencyStop();
        // Actually send the disable command before shutting down; previously
        // it was built but never published.
        get_crc(low_command);
        lowcmd_publisher_->publish(low_command);
        if (timer_) {
          timer_->cancel();
        }
        rclcpp::shutdown();
        return;
      }
      get_crc(low_command);
      lowcmd_publisher_->publish(low_command);
      return; // Exit early, ignore all other commands
    }
    // If not in emergency stop, check for emergency stop command first
    if (remote_controller.button[KeyMap::select] == 1) {
      current_state_ = RobotState::EMERGENCY_STOP;
      should_shutdown_ = true;
      publishRobotState();
      return;
    }
    // Process other commands only if not in emergency stop
    if (remote_controller.button[KeyMap::L1] == 1 &&
        current_state_ != RobotState::ZERO_TORQUE) {
      RCLCPP_INFO(this->get_logger(), "Switching to ZERO_TORQUE state");
      current_state_ = RobotState::ZERO_TORQUE;
      publishRobotState();
    }
    // Start button only works in ZERO_TORQUE state
    if (remote_controller.button[KeyMap::start] == 1) {
      if (current_state_ == RobotState::ZERO_TORQUE) {
        RCLCPP_INFO(this->get_logger(), "Switching to MOVE_TO_DEFAULT state");
        current_state_ = RobotState::MOVE_TO_DEFAULT;
        time_ = 0.0;
        publishRobotState();
      } else {
        RCLCPP_INFO(this->get_logger(),
                    "Start button only works in ZERO_TORQUE state. Current state: %d",
                    static_cast(current_state_));
      }
    }
    // A button only works in MOVE_TO_DEFAULT state
    if (remote_controller.button[KeyMap::A] == 1) {
      if (current_state_ != RobotState::MOVE_TO_DEFAULT) {
        RCLCPP_INFO(this->get_logger(),
                    "A button only works in MOVE_TO_DEFAULT state. Current state: %d",
                    static_cast(current_state_));
      } else if (!kps_received_ || !kds_received_) {
        // Refuse the transition but do NOT return: the MOVE_TO_DEFAULT
        // command below must still be published on this tick.
        RCLCPP_ERROR(this->get_logger(),
                     "Cannot switch to POLICY state. Control parameters not received from policy node! kps_received: %s, kds_received: %s",
                     kps_received_ ? "true" : "false",
                     kds_received_ ? "true" : "false");
      } else {
        // Check lower body joint positions before allowing transition
        bool positions_ok = true;
        std::stringstream deviation_msg;
        const double position_threshold = 0.4;
        // List of lower body joints to check
        std::vector lower_body_joints = {
            "left_hip_yaw", "left_hip_roll", "left_hip_pitch", "left_knee", "left_ankle_pitch", "left_ankle_roll",
            "right_hip_yaw", "right_hip_roll", "right_hip_pitch", "right_knee", "right_ankle_pitch", "right_ankle_roll"
        };
        // Assumes complete_dof_order lists joints in motor-index order.
        for (int i = 0; i < numIndexedJoints(); ++i) {
          const std::string &dof_name = complete_dof_order[i];
          // Skip if not a lower body joint
          if (std::find(lower_body_joints.begin(), lower_body_joints.end(), dof_name) == lower_body_joints.end()) {
            continue;
          }
          double current_pos = motor[i].q;
          double default_pos = default_dof_pos[dof_name];
          double diff = std::abs(current_pos - default_pos);
          if (diff > position_threshold) {
            positions_ok = false;
            deviation_msg << dof_name << "(" << diff << "), ";
          }
        }
        if (positions_ok) {
          RCLCPP_INFO(this->get_logger(), "Switching to POLICY state");
          current_state_ = RobotState::POLICY;
          time_ = 0.0;
          publishRobotState();
        } else {
          RCLCPP_WARN(this->get_logger(),
                      "Cannot switch to POLICY state. Lower body joints with large deviations: %s",
                      deviation_msg.str().c_str());
        }
      }
    }
    // Normal state machine logic
    switch (current_state_) {
      case RobotState::ZERO_TORQUE:
        SendZeroTorqueCommand();
        get_crc(low_command);
        lowcmd_publisher_->publish(low_command);
        break;
      case RobotState::MOVE_TO_DEFAULT:
        SendDefaultPositionCommand();
        get_crc(low_command);
        lowcmd_publisher_->publish(low_command);
        break;
      case RobotState::POLICY:
        SendPolicyCommand();
        get_crc(low_command);
        lowcmd_publisher_->publish(low_command);
        break;
      case RobotState::EMERGENCY_STOP:
        // Reached only if SendPolicyCommand() just triggered an emergency
        // stop; the next tick handles it at the top of this function.
        break;
    }
    // Publish current robot state
    publishRobotState();
  }

  // Fill low_command with enabled motors and all targets/gains/torques zero.
  void SendZeroTorqueCommand() {
    low_command.mode_pr = mode_;
    low_command.mode_machine = mode_machine;
    for (int i = 0; i < G1_NUM_MOTOR; ++i) {
      low_command.motor_cmd[i].mode = 1; // Enable
      low_command.motor_cmd[i].q = 0.0;
      low_command.motor_cmd[i].dq = 0.0;
      low_command.motor_cmd[i].kp = 0.0;
      low_command.motor_cmd[i].kd = 0.0;
      low_command.motor_cmd[i].tau = 0.0;
    }
  }

  // Advance time_ and fill low_command with a target that blends the current
  // measured position into the default pose (ratio = time_ / duration_,
  // clamped to [0, 1]). Gains are torque-limited via
  // limitTorqueWithCustomGains().
  void SendDefaultPositionCommand() {
    time_ += control_dt_;
    low_command.mode_pr = mode_;
    low_command.mode_machine = mode_machine;
    // Print kp/kd values on first execution
    static bool first_move_to_default = true;
    if (first_move_to_default) {
      RCLCPP_INFO(this->get_logger(), "=== First MOVE_TO_DEFAULT execution ===");
      first_move_to_default = false;
    }
    const double ratio = clamp(time_ / duration_, 0.0, 1.0);
    if (has_joint_arrays_) {
      // Use provided arrays and dof2motor mapping to command motors
      for (size_t j = 0; j < joint_names_array_.size(); ++j) {
        const std::string &dof_name = joint_names_array_[j];
        auto idx_it = dof2motor_idx.find(dof_name);
        if (idx_it == dof2motor_idx.end()) {
          continue; // skip unknown names
        }
        const int motor_idx = idx_it->second;
        if (!isValidMotorIndex(motor_idx)) {
          continue; // mapping points outside the motor arrays
        }
        double target_final = default_position_array_[j];
        double target_pos = (1. - ratio) * motor[motor_idx].q + ratio * target_final;
        // Current state
        double current_pos = motor[motor_idx].q;
        double current_vel = motor[motor_idx].dq;
        // Use MoveToDefault specialized kp/kd
        double kp_to_use = move_to_default_kps[dof_name];
        double kd_to_use = move_to_default_kds[dof_name];
        // Print kp/kd values on first execution
        if (time_ <= control_dt_ * 2) { // Print for first few iterations
          RCLCPP_INFO(this->get_logger(), "MoveToDefault - %s: kp=%.2f, kd=%.2f",
                      dof_name.c_str(), kp_to_use, kd_to_use);
        }
        // Apply torque limiting with MoveToDefault gains
        auto [limited_kp, limited_kd] = limitTorqueWithCustomGains(
            dof_name, target_pos, current_pos, current_vel, kp_to_use, kd_to_use);
        low_command.motor_cmd[motor_idx].mode = 1;
        low_command.motor_cmd[motor_idx].tau = 0.0;
        low_command.motor_cmd[motor_idx].q = target_pos;
        low_command.motor_cmd[motor_idx].dq = 0.0;
        low_command.motor_cmd[motor_idx].kp = limited_kp;
        low_command.motor_cmd[motor_idx].kd = limited_kd;
      }
    } else {
      // Fall back to map-driven order with default MoveToDefault gains
      // Use default kp/kd values for MoveToDefault since policy kps/kds are not available yet
      const double default_move_kp = 50.0; // Default stiffness for MoveToDefault
      const double default_move_kd = 5.0;  // Default damping for MoveToDefault
      // Print default kp/kd values on first execution
      if (time_ <= control_dt_ * 2) { // Print for first few iterations
        RCLCPP_INFO(this->get_logger(), "MoveToDefault (fallback) - Using default kp=%.2f, kd=%.2f",
                    default_move_kp, default_move_kd);
      }
      // Assumes complete_dof_order lists joints in motor-index order; bounded
      // so a short list cannot be indexed out of range.
      for (int i = 0; i < numIndexedJoints(); ++i) {
        const std::string &dof_name = complete_dof_order[i];
        double target_pos = (1. - ratio) * motor[i].q + ratio * default_dof_pos[dof_name];
        // Current state
        double current_pos = motor[i].q;
        double current_vel = motor[i].dq;
        // Use default MoveToDefault gains with torque limiting
        auto [limited_kp, limited_kd] = limitTorqueWithCustomGains(
            dof_name, target_pos, current_pos, current_vel, default_move_kp, default_move_kd);
        low_command.motor_cmd[i].mode = 1;
        low_command.motor_cmd[i].tau = 0.0;
        low_command.motor_cmd[i].q = target_pos;
        low_command.motor_cmd[i].dq = 0.0;
        low_command.motor_cmd[i].kp = limited_kp;
        low_command.motor_cmd[i].kd = limited_kd;
      }
    }
  }

  // Advance time_ and fill low_command with target_dof_pos using the policy
  // gains (no torque limiting). Triggers an emergency stop if the gains have
  // not been received. Joints without a valid motor mapping are skipped;
  // previously operator[] silently mapped them to motor 0.
  void SendPolicyCommand() {
    time_ += control_dt_;
    low_command.mode_pr = mode_;
    low_command.mode_machine = mode_machine;
    // Print kp/kd values on first execution
    static bool first_policy_command = true;
    if (first_policy_command) {
      RCLCPP_INFO(this->get_logger(), "=== First POLICY command execution ===");
      first_policy_command = false;
    }
    // Check if kps and kds parameters have been received from policy node
    if (!kps_received_ || !kds_received_) {
      RCLCPP_ERROR(this->get_logger(),
                   "Policy control parameters not received! kps_received: %s, kds_received: %s",
                   kps_received_ ? "true" : "false",
                   kds_received_ ? "true" : "false");
      RCLCPP_ERROR(this->get_logger(), "Cannot execute POLICY commands without control parameters. Triggering emergency stop.");
      current_state_ = RobotState::EMERGENCY_STOP;
      should_shutdown_ = true;
      publishRobotState();
      return;
    }
    for (const auto &pair : target_dof_pos) {
      const std::string &dof_name = pair.first;
      const double &target_pos = pair.second;
      auto idx_it = dof2motor_idx.find(dof_name);
      if (idx_it == dof2motor_idx.end() || !isValidMotorIndex(idx_it->second)) {
        continue; // never command an unmapped joint onto a wrong motor
      }
      const int motor_idx = idx_it->second;
      // Get policy kp/kd values
      double policy_kp = kps[dof_name];
      double policy_kd = kds[dof_name];
      // Print kp/kd values on first execution
      if (time_ <= control_dt_ * 2) { // Print for first few iterations
        RCLCPP_INFO(this->get_logger(), "Policy - %s: kp=%.2f, kd=%.2f",
                    dof_name.c_str(), policy_kp, policy_kd);
      }
      // Use policy kp/kd values directly without torque limiting
      low_command.motor_cmd[motor_idx].mode = 1;
      low_command.motor_cmd[motor_idx].tau = 0.0;
      low_command.motor_cmd[motor_idx].q = target_pos;
      low_command.motor_cmd[motor_idx].dq = 0.0;
      low_command.motor_cmd[motor_idx].kp = policy_kp;
      low_command.motor_cmd[motor_idx].kd = policy_kd;
    }
  }

  // Emergency phase 1: keep motors enabled with kp = 0 and a fixed damping
  // gain, so joints are only slowed down, not driven to a position.
  void SendDampedEmergencyStop() {
    low_command.mode_pr = mode_;
    low_command.mode_machine = mode_machine;
    // Use default damping value for emergency stop since kds may not be available
    const double default_emergency_kd = 10.0; // Higher damping for faster stopping
    for (int i = 0; i < G1_NUM_MOTOR; ++i) {
      low_command.motor_cmd[i].mode = 1;                  // Keep enabled
      low_command.motor_cmd[i].q = motor[i].q;            // Current position
      low_command.motor_cmd[i].dq = 0.0;                  // Target zero velocity
      low_command.motor_cmd[i].kp = 0.0;                  // No position control
      low_command.motor_cmd[i].kd = default_emergency_kd; // Use default damping
      low_command.motor_cmd[i].tau = 0.0;
    }
  }

  // Emergency phase 2: disable every motor (mode 0) with all fields zeroed.
  void SendFinalEmergencyStop() {
    low_command.mode_pr = mode_;
    low_command.mode_machine = mode_machine;
    for (int i = 0; i < G1_NUM_MOTOR; ++i) {
      low_command.motor_cmd[i].mode = 0; // Disable
      low_command.motor_cmd[i].q = 0.0;
      low_command.motor_cmd[i].dq = 0.0;
      low_command.motor_cmd[i].kp = 0.0;
      low_command.motor_cmd[i].kd = 0.0;
      low_command.motor_cmd[i].tau = 0.0;
    }
  }

  // /lowstate callback: caches mode_machine, the IMU and motor states, and
  // decodes the embedded wireless-remote bytes.
  // NOTE: feedback is not checked against joint limits here; limits are
  // enforced by clamping commanded targets in PolicyActionHandler().
  void LowStateHandler(unitree_hg::msg::LowState::SharedPtr message) {
    mode_machine = (int)message->mode_machine;
    imu = message->imu_state;
    for (int i = 0; i < G1_NUM_MOTOR; i++) {
      motor[i] = message->motor_state[i];
    }
    remote_controller.set(message->wireless_remote);
  }

  // /humanoid/action callback: stores one target per policy_dof_order joint,
  // clamped to the position limits scaled about their midpoint. A size
  // mismatch triggers an emergency stop.
  void PolicyActionHandler(
      const std_msgs::msg::Float32MultiArray::SharedPtr message) {
    policy_action_data = message->data;
    // Check if message size matches expected size
    if (policy_action_data.size() != policy_dof_order.size()) {
      RCLCPP_ERROR(this->get_logger(),
                   "Policy action data size mismatch: got %zu, expected %zu",
                   policy_action_data.size(), policy_dof_order.size());
      current_state_ = RobotState::EMERGENCY_STOP;
      should_shutdown_ = true;
      publishRobotState();
      return;
    }
    // set target dof pos
    for (size_t i = 0; i < policy_dof_order.size(); i++) {
      const auto &dof_name = policy_dof_order[i];
      double calculated_pos = policy_action_data[i];
      // Check if the target position is within joint limits (with scaling)
      if (joint_position_limits.find(dof_name) != joint_position_limits.end()) {
        // Calculate the middle point of the range
        double mid_pos = (joint_position_limits[dof_name].first + joint_position_limits[dof_name].second) / 2.0;
        // Calculate the half-range and scale it
        double half_range = (joint_position_limits[dof_name].second - joint_position_limits[dof_name].first) / 2.0;
        double scaled_half_range = half_range * position_limit_scale;
        // Calculate scaled min and max by expanding from midpoint
        double min_pos = mid_pos - scaled_half_range;
        double max_pos = mid_pos + scaled_half_range;
        if (calculated_pos < min_pos || calculated_pos > max_pos) {
          // Clamp the position to within limits
          calculated_pos = std::clamp(calculated_pos, min_pos, max_pos);
        }
      }
      // Set the target position (clamped to safe values if needed)
      target_dof_pos[dof_name] = calculated_pos;
    }
  }

  // /humanoid/kps callback: stores policy stiffness per policy_dof_order joint.
  // A size mismatch triggers an emergency stop; the gains are only marked as
  // received after the size check passes.
  void KpsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) {
    policy_kps_data = message->data;
    // Check if message size matches expected size
    if (policy_kps_data.size() != policy_dof_order.size()) {
      RCLCPP_ERROR(this->get_logger(),
                   "Policy kps data size mismatch: got %zu, expected %zu",
                   policy_kps_data.size(), policy_dof_order.size());
      current_state_ = RobotState::EMERGENCY_STOP;
      should_shutdown_ = true;
      publishRobotState();
      return;
    }
    // Update kps map with policy data
    for (size_t i = 0; i < policy_dof_order.size(); i++) {
      const auto &dof_name = policy_dof_order[i];
      kps[dof_name] = policy_kps_data[i];
    }
    kps_received_ = true;
    RCLCPP_INFO(this->get_logger(), "Received kps parameters from policy node (size: %zu)", policy_kps_data.size());
  }

  // /humanoid/kds callback: stores policy damping per policy_dof_order joint.
  // A size mismatch triggers an emergency stop; the gains are only marked as
  // received after the size check passes.
  void KdsHandler(const std_msgs::msg::Float32MultiArray::SharedPtr message) {
    policy_kds_data = message->data;
    // Check if message size matches expected size
    if (policy_kds_data.size() != policy_dof_order.size()) {
      RCLCPP_ERROR(this->get_logger(),
                   "Policy kds data size mismatch: got %zu, expected %zu",
                   policy_kds_data.size(), policy_dof_order.size());
      current_state_ = RobotState::EMERGENCY_STOP;
      should_shutdown_ = true;
      publishRobotState();
      return;
    }
    // Update kds map with policy data
    for (size_t i = 0; i < policy_dof_order.size(); i++) {
      const auto &dof_name = policy_dof_order[i];
      kds[dof_name] = policy_kds_data[i];
    }
    kds_received_ = true;
    RCLCPP_INFO(this->get_logger(), "Received kds parameters from policy node (size: %zu)", policy_kds_data.size());
  }

  // Clamp value into [low, high].
  double clamp(double value, double low, double high) {
    if (value < low)
      return low;
    if (value > high)
      return high;
    return value;
  }

  // Name published on /robot_state for each state.
  std::string robotStateToString(RobotState state) {
    switch (state) {
      case RobotState::ZERO_TORQUE:
        return "ZERO_TORQUE";
      case RobotState::MOVE_TO_DEFAULT:
        return "MOVE_TO_DEFAULT";
      case RobotState::EMERGENCY_STOP:
        return "EMERGENCY_STOP";
      case RobotState::POLICY:
        return "POLICY";
      default:
        return "UNKNOWN";
    }
  }

  // Publish the current state name on /robot_state.
  void publishRobotState() {
    std_msgs::msg::String state_msg;
    state_msg.data = robotStateToString(current_state_);
    robot_state_publisher_->publish(state_msg);
  }

  rclcpp::TimerBase::SharedPtr timer_; // ROS2 timer
  rclcpp::Publisher::SharedPtr
      lowcmd_publisher_; // ROS2 Publisher
  rclcpp::Subscription::SharedPtr
      lowstate_subscriber_; // ROS2 Subscriber
  rclcpp::Subscription::SharedPtr
      policy_action_subscriber_;
  rclcpp::Subscription::SharedPtr
      kps_subscriber_;
  rclcpp::Subscription::SharedPtr
      kds_subscriber_;
  rclcpp::Publisher::SharedPtr robot_state_publisher_;
  unitree_hg::msg::LowCmd low_command; // Unitree hg lowcmd message
  unitree_hg::msg::IMUState imu;       // Unitree hg IMU message
  unitree_hg::msg::MotorState
      motor[G1_NUM_MOTOR]; // Unitree hg motor state message
  // Timing fields default to 0 so a config failure is detectable in the
  // constructor instead of leaving them uninitialized.
  double control_freq_ = 0.0;
  double control_dt_ = 0.0;
  int timer_dt = 0;
  double time_ = 0.0; // Running time count
  double duration_ = 3.0;
  PRorAB mode_ = PRorAB::PR;
  int mode_machine = 0; // overwritten by each /lowstate message
  RemoteController wireless_remote_; // NOTE(review): unused in this class
}; // End of humanoid_controller class
// Entry point: initialize ROS 2, run the controller node until ROS shuts
// down (the node's emergency stop also calls rclcpp::shutdown()), then clean
// up.
int main(int argc, char **argv) {
  rclcpp::init(argc, argv);
  auto controller_node = std::make_shared();
  rclcpp::spin(controller_node);
  rclcpp::shutdown();
  return 0;
}
================================================
FILE: deployment/unitree_g1_ros2_29dof/start_container.sh
================================================
#!/bin/bash
# Start (or restart) the HoloMotion deployment container on a Jetson Orin.
#
# Removes any previous "holomotion_orin_deploy" container, prompts for the
# local HoloMotion repository path, and starts an interactive container with
# that repository mounted at /home/unitree/holomotion. The host's CUDA 11.4
# and cuDNN 8.6 libraries are bind-mounted read-only.
#
# Usage: bash start_container.sh   (prompts for the repository path)

readonly container_name="holomotion_orin_deploy"

# Remove the previous container. Use sudo like `docker run` below, so both
# work for the same user. A missing container is expected, so its error
# message is silenced.
sudo docker kill "$container_name" >/dev/null 2>&1 || true
sudo docker rm "$container_name" >/dev/null 2>&1 || true
echo "Old $container_name container removed !"

# Initialize variable as empty
holomotion_repo_path=""
# Loop until the user provides a non-empty string. Abort on end-of-input
# (Ctrl-D or no terminal) instead of looping forever.
while [[ -z "$holomotion_repo_path" ]]; do
  if ! read -r -p "Please enter the holomotion local repository path: " holomotion_repo_path \
    && [[ -z "$holomotion_repo_path" ]]; then
    echo "Error: no input received." >&2
    exit 1
  fi
  if [[ -z "$holomotion_repo_path" ]]; then
    echo "Input cannot be empty."
  fi
done

# `read` does not perform tilde expansion; handle "~" and "~/..." by hand.
case "$holomotion_repo_path" in
  "~") holomotion_repo_path="$HOME" ;;
  "~/"*) holomotion_repo_path="$HOME/${holomotion_repo_path#"~/"}" ;;
esac

# Validate the directory exists before running Docker
if [[ ! -d "$holomotion_repo_path" ]]; then
  echo "Error: Directory '$holomotion_repo_path' does not exist." >&2
  exit 1
fi
# `docker -v` needs an absolute host path; a relative one would be parsed as
# a named volume. Resolve it, which also normalizes "." and ".." segments.
holomotion_repo_path="$(cd -- "$holomotion_repo_path" && pwd)" || exit 1

echo "Mounting path: $holomotion_repo_path"
sudo docker run -it \
  --name "$container_name" \
  --runtime nvidia \
  --gpus all \
  --privileged \
  --network host \
  -e "ACCEPT_EULA=Y" \
  -v "$holomotion_repo_path:/home/unitree/holomotion" \
  -v "/usr/local/cuda-11.4/targets/aarch64-linux/lib:/cuda_base:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn.so.8.6.0:/host_gpu/libcudnn.so.8.6.0:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.6.0:/host_gpu/libcudnn_ops_infer.so.8.6.0:ro" \
  -v "/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.6.0:/host_gpu/libcudnn_cnn_infer.so.8.6.0:ro" \
  horizonrobotics/holomotion:orin_foxy_jp5.1_humble_deploy_zmq_20260319 \
  bash -c "ln -sf /host_gpu/libcudnn.so.8.6.0 /host_gpu/libcudnn.so.8 && \
  ln -sf /host_gpu/libcudnn_ops_infer.so.8.6.0 /host_gpu/libcudnn_ops_infer.so.8 && \
  ln -sf /host_gpu/libcudnn_cnn_infer.so.8.6.0 /host_gpu/libcudnn_cnn_infer.so.8 && \
  source /root/miniconda3/bin/activate && conda activate holomotion_deploy && exec bash"
================================================
FILE: docs/environment_setup.md
================================================
# Environment Setup
## Step 1: Setup Conda
This project uses conda to manage Python environments. We recommend using [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install#linux-installer).
**For users in China:** Configure the conda mirror following [TUNA](https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/) for faster downloads.
## Step 2: Setup Third-party Dependencies
### 2.1 Download SMPL/SMPLX Models
We use SMPL/SMPLX models to retarget mocap data into robot motion data. Register your account and download the models from:
- [SMPL](https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.1.0.zip)
- [SMPLX](https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip)
Place both zip files (`SMPL_python_v.1.1.0.zip` and `models_smplx_v1_1.zip`) in the `thirdparties/` folder, then extract:
```shell
mkdir thirdparties/smpl_models
unzip thirdparties/SMPL_python_v.1.1.0.zip -d thirdparties/smpl_models/
unzip thirdparties/models_smplx_v1_1.zip -d thirdparties/smpl_models/
```
The resulting file structure for smpl models would be:
```shell
thirdparties/
├── smpl_models
├── models
└── SMPL_python_v.1.1.0
```
### 2.2 Pull Submodules
After cloning this repository, run the following command to get all submodule dependencies:
```shell
git submodule update --init --recursive
```
### 2.3 Create Asset Symlinks
This project uses symbolic links to connect robot and SMPL assets from submodules to the main `assets` directory. Symlinks are created automatically when you clone the repository.
### 2.4 Verify Third-party File Structure
After completing the above steps, your file structure should look like this:
```shell
thirdparties/
├── HoloMotion_assets
├── GMR
├── smplx
├── joints2smpl
├── omomo_release
├── smpl_models
├── SMPLSim
├── unitree_ros
└── unitree_ros2
```
## Step 3: Create the Conda Environment
Create the conda environments named `holomotion_train` and `holomotion_deploy`:
```shell
conda env create -f environments/environment_train_isaaclab_cu118.yaml
# for newer GPUs like RTX 5090, use environment_train_isaaclab_cu128.yaml
conda env create -f environments/environment_deploy.yaml
```
Install smplx and GMR into the conda environment:
```shell
cd thirdparties
conda activate holomotion_train
pip install -e ./smplx
# use --no-deps to avoid pulling GMR's dependencies
pip install -e ./GMR --no-deps
```
## Step 4: Configure the Environment Variables
HoloMotion uses `train.env` and `deploy.env` files to export environment variables in the shell entry scripts. Make sure the `Train_CONDA_PREFIX` and `Deploy_CONDA_PREFIX` variables in `train.env` and `deploy.env` are set up correctly. You can source these files manually and check the output in the shell.
Take the `train.env` for example:
```shell
source train.env
```
These `.env` files will be sourced in the shell scripts (in `holomotion/scripts`) to correctly find and utilize your conda environments.
================================================
FILE: docs/evaluate_motion_tracking.md
================================================
## Evaluate the Motion Tracking Model
After training for a while and saving model checkpoints, run the evaluation pipeline to assess your model's performance both visually and quantitatively. The evaluation pipeline also exports the model for later deployment.
**Overall Workflow:**
```mermaid
flowchart LR
A[Trained Checkpoints]
B[HDF5 Database]
C[Evaluation Config]
D[Evaluation Entry]
E[Offline Evaluation]
F[Calculate Metrics]
G[MuJoCo Visualization]
A --> D
B --> D
C --> D
D --> E
E --> F
E --> G
classDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5
classDef normal fill:#c9d9f5, rx:10, ry:10
class A,B dashed
class C,D,E,F,G normal
```
### 1 Offline Evaluation
```bash
bash ./holomotion/scripts/evaluation/eval_motion_tracking.sh
```
Before running it, update the evaluation script by setting `checkpoint_path` (e.g., `logs/HoloMotion/your_checkpoint_dir/model_1000.pt`) and `eval_h5_dataset_path`.
### 2 Calculate Metrics
Process the `.npz` files generated in the previous step and convert them into a final quantitative JSON metrics report:
```bash
bash ./holomotion/scripts/evaluation/calc_offline_eval_metrics.sh
```
- `npz_dir`: Path to the folder containing `.npz` result files.
- `dataset_suffix`: Evaluation dataset name, used to distinguish results from different datasets.
### 3 MuJoCo Visualization
Generate video outputs from the `.npz` result files to validate motion tracking quality, setting `motion_npz_root` to the evaluation npz folder. Note that to properly visualize the recorded robot data, you should set `+key_prefix="robot_"`.
```bash
bash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh
```
- `motion_npz_root`: Path to the folder containing `.npz` result files.
- **Output**: `video_rendering/{motion_name}.mp4` files in the directories of the corresponding `.npz` result files.
### 4 Export Trained Model to ONNX
To deploy our policy to real-world robots, we need to convert the PyTorch module into ONNX format, which is supported by most inference frameworks.
After running the evaluation script, the `.onnx` file will be generated and saved to the checkpoint directory:
```
logs/HoloMotion/your_checkpoint_dir/
├── config.yaml
├── exported
│ └── model_10000.onnx
└── model_10000.pt
```
================================================
FILE: docs/holomotion_motion_file_spec.md
================================================
## HoloMotion NPZ Format — Keys and Values
This document lists the exact keys saved in a HoloMotion NPZ and their value types/shapes.
- Prefix policy
- ref\_\*: reference motion (source-of-truth produced by preprocessing)
- ft\_ref\_\*: filtered reference motion (post-filtering; never overwrites ref\_\*)
- robot\_\*: robot states (only present in offline evaluation exports)
- Legacy (no prefix): kept only for backward compatibility; new files should use ref\_\*
- metadata
- type: JSON string
- fields:
- motion_key: str
- raw_motion_key: str
- motion_fps: float
- num_frames: int
- wallclock_len: float (seconds, approx (num_frames - 1) / motion_fps)
- num_dofs: int
- num_bodies: int
- clip_length: int (original clip length in frames)
- valid_prefix_len: int (contiguous valid frames from the start)
- ref_dof_pos
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; reference motion)
- ref_dof_vel
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; reference motion)
- ref_global_translation
- dtype: float32
- shape: [T, num_bodies, 3] (meters, world frame; reference motion)
- ref_global_rotation_quat
- dtype: float32
- shape: [T, num_bodies, 4] (quaternion XYZW, world frame; reference motion)
- ref_global_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (m/s, world frame; reference motion)
- ref_global_angular_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (rad/s, world frame; reference motion)
- ft_ref_dof_pos
- dtype: float32
- shape: [T, num_dofs] (filtered reference motion)
- ft_ref_dof_vel
- dtype: float32
- shape: [T, num_dofs] (derived from filtered positions)
- ft_ref_global_translation
- dtype: float32
- shape: [T, num_bodies, 3] (filtered reference motion)
- ft_ref_global_rotation_quat
- dtype: float32
- shape: [T, num_bodies, 4] (filtered, normalized XYZW)
- ft_ref_global_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (derived from filtered positions)
- ft_ref_global_angular_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (derived from filtered quaternions)
- robot_dof_pos
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; robot)
- robot_dof_vel
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; robot)
- robot_global_translation
- dtype: float32
- shape: [T, num_bodies, 3] (meters, world frame; robot)
- robot_global_rotation_quat
- dtype: float32
- shape: [T, num_bodies, 4] (quaternion XYZW, world frame; robot)
- robot_global_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (m/s, world frame; robot)
- robot_global_angular_velocity
- dtype: float32
- shape: [T, num_bodies, 3] (rad/s, world frame; robot)
- dof_pos (deprecated legacy key)
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; ref or robot)
- deprecated
- dof_vels (deprecated legacy key)
- dtype: float32
- shape: [T, num_dofs] (URDF joint order; ref or robot)
- deprecated
- global_translation (deprecated legacy key)
- dtype: float32
- shape: [T, num_bodies, 3] (meters, world frame; ref or robot)
- deprecated
- global_rotation_quat (deprecated legacy key)
- dtype: float32
- shape: [T, num_bodies, 4] (quaternion XYZW, world frame; ref or robot)
- deprecated
- global_velocity (deprecated legacy key)
- dtype: float32
- shape: [T, num_bodies, 3] (m/s, world frame; ref or robot)
- deprecated
- global_angular_velocity (deprecated legacy key)
- dtype: float32
- shape: [T, num_bodies, 3] (rad/s, world frame; ref or robot)
- deprecated
Notes:
- T == num_frames from metadata.
- All arrays are float32.
================================================
FILE: docs/motion_retargeting.md
================================================
# Motion Retargeting
Transform human motion data into robot-compatible joint trajectories for subsequent training. We use [GMR](https://github.com/YanjieZe/GMR) for retargeting.
## Prerequisites
Before running the motion retargeting pipeline, ensure you have:
### 1. Environment Setup
Please make sure smplx and GMR are properly installed according to the [environment setup doc](./environment_setup.md).
### 2. Data Preparation
Place your AMASS motion data in `assets/test_data/motion_retargeting/{dataset_name}`, or modify `amass_dir` in `holomotion/scripts/motion_retargeting/*.sh`. Please double-check that all related paths in the `.sh` and `.yaml` files are correct.
### 3. Model Preparation
Place the SMPL-X models under the following path:
```
thirdparties/
└── GMR/
    └── assets/
        └── body_models/
            └── smplx/
                ├── SMPLX_FEMALE.npz
                ├── SMPLX_FEMALE.pkl
                ├── SMPLX_MALE.npz
                ├── SMPLX_MALE.pkl
                ├── SMPLX_NEUTRAL.npz
                └── SMPLX_NEUTRAL.pkl
```
### 4. Path Verification
Check data paths in the configuration scripts:
- `holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh`
Before using GMR, we recommend running `bash ./holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh` first; this patch helps reduce singular solutions to some extent.
## Quick Start
### 1. Motion Retargeting
```bash
bash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh
```
> Reminder: if you encounter CUDA errors, change `device = "cuda:0"` to `device = "cpu"` in `smplx_to_robot_dataset.py`.
After GMR retargeting, convert the dataset into the HoloMotion-compatible npz format by running:
```bash
bash ./holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh
```
### 2. Motion Visualization
Generate video outputs to validate retargeting quality of the HoloMotion npz files:
```bash
bash ./holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh
```
**Output**: `video_rendering/{motion_name}.mp4` files in the retargeted data directories
### 3. Pack to HDF5 for Training
After retargeting, we need to pack the npz files into a compact HDF5 database:
```bash
bash ./holomotion/scripts/motion_retargeting/pack_hdf5_dataset.sh
```
================================================
FILE: docs/mujoco_sim2sim.md
================================================
# Sim2Sim Verification
After generating the ONNX file in the evaluation stage, you can verify your Isaac-trained policy in another simulator, such as MuJoCo, before deploying it to the real robot.
The entry script is `holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh`; set the following variables before running it:
- `robot_xml_path`: The scene mjcf .xml file for the robot
- `ONNX_PATH`: The exported ONNX model file
- `motion_npz_path`: The npz file containing the reference motion
================================================
FILE: docs/realworld_deployment.md
================================================
# HoloMotion Deployment Guide
This guide describes how to set up the deployment environment and run the trained policy on a physical Unitree G1 robot.
## Robot Configuration for 29 DOF
The 29 DOF configuration includes:
- 12 leg joints (6 per leg)
- 3 waist joints (yaw, roll, pitch)
- 14 arm joints (7 per arm)
---
## Deployment Options
This guide provides two deployment methods:
| Deployment Method | Target Platform |
| ----------------------------------------------- | ----------------- |
| [Laptop Deployment](#laptop-deployment) | Laptop/Desktop PC |
| [PC2 Docker Deployment](#pc2-docker-deployment) | G1 Robot's PC2 |
Choose the appropriate method based on your setup:
- **For laptop/desktop deployment**: Follow the [Laptop Deployment](#laptop-deployment) section
- **For PC2 on robot hardware**: Follow the [PC2 Docker Deployment](#pc2-docker-deployment) section
### ⚠️ Important Safety Notice
> **For safety reasons, it is strongly recommended to remove the dexterous hands before running the policy.**
---
## Laptop Deployment
### Quick Environment Setup
#### Prerequisites
Ensure the following are installed before proceeding:
- Anaconda or Miniconda
- ROS 2 Humble installed at `/opt/ros/humble`
- MCAP for efficient ROS 2 data recording
- Unitree ROS 2 SDK installed at `~/unitree_ros2/`
#### One-Click Deployment
```bash
cd /deployment
chmod +x deploy_environment.sh
./deploy_environment.sh
```
This script will:
- Create a new conda environment (with CUDA support if available)
- Install Python packages from `requirements/requirements_deploy.txt`
- Install Unitree SDK Python bindings
- Build the ROS 2 workspace under `unitree_g1_ros2_29dof/`
---
### Deploy on Physical G1 Robot (Laptop)
### Setup Overview
The deployment process consists of two types of steps:
| **One-Time Setup** (per computer) | **Every Run** (each time you use the robot) |
| --------------------------------- | ------------------------------------------- |
| Step 1: Network Configuration | Step 3: Power On & Initialize Robot |
| Step 2: Launch Script Setup | Step 4: Launch Policy Controller |
> **Note**: Once you complete Steps 1-2, you only need to do Steps 3-4 for each robot session!
### Step 1: Connect and Configure Network
#### Prerequisites for Network Setup:
1. **Power on the robot** and wait for it to fully boot
2. **Use an Ethernet cable** to connect your PC to the robot's LAN port
3. **Ensure both devices are powered on** during configuration
#### Network Configuration:
Configure your PC's network interface with the following static IP settings:
- **Static IP**: `192.168.123.222`
- **Netmask**: `255.255.255.0`
- **Gateway**: (leave empty)
#### Automatic Configuration Script:
You can use the following script to configure it automatically (use command `nmcli con show` to check your actual connection name):
`set_static_ip.sh`:
```bash
#!/bin/bash
# Replace with your actual connection name (use `nmcli con show` to check)
CON_NAME="Wired connection 1"
IP_ADDRESS="192.168.123.222"
NETMASK="24"
GATEWAY=""
nmcli con modify "$CON_NAME" ipv4.addresses "$IP_ADDRESS/$NETMASK"
nmcli con modify "$CON_NAME" ipv4.method manual
if [ -n "$GATEWAY" ]; then
nmcli con modify "$CON_NAME" ipv4.gateway "$GATEWAY"
fi
nmcli con modify "$CON_NAME" ipv4.dns ""
nmcli con down "$CON_NAME" && nmcli con up "$CON_NAME"
```
---
### Step 2: Prepare Launch Script
#### Configure Network Interface:
1. **Check your network interface name** (while connected to the robot):
```bash
ifconfig
```
Look for the interface connected to the robot (e.g., `enxf8e43ba00afd`, `eth0`, `enp0s31f6`)
2. **Update the launch configuration**:
```bash
# Edit the launch file
nano /deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py
```
Find and update the `network_interface` parameter with your actual interface name.
---
### Step 3: Power On and Initialize the Robot
> **Do this every time** you want to run the robot.
#### Robot Initialization Sequence for 29 DOF:
1. **Power on the robot** - Start the robot in the **hanging position**
2. **Wait for zero torque mode** - The robot will automatically enter zero torque mode (joints feel loose)
3. **Connect your computer** - Use the same Ethernet cable to connect to the robot's LAN port
4. **Enter debugging mode** - On the remote controller, press `L2 + R2` simultaneously. Note: the new deployment automatically enters this mode on startup, so manual entry is usually not required.
---
### Step 4: Launch the Policy Controller
#### Preflight Checklist
Before running, ensure the following are ready.
- Model folders configured in `g1_29dof_holomotion.yaml` exist
- `motion_tracking_model_folder`: under `src/models/`
- `velocity_tracking_model_folder`: under `src/models/`
- Motion data directory exists and contains .npz files (retargeted results)
- `motion_clip_dir`: under `src/motion_data/`
- Config file path used by launch is correct
#### Motion Reference Source
Motion tracking supports two reference sources:
- **Offline motion mode**: the robot executes the selected `.npz` motion clip from `motion_clip_dir`.
- **Online teleoperation mode**: the robot uses live `latest_obs` data streamed from the teleoperation workstation / VR pipeline. See [Holomotion teleop setup](../deployment/holomotion_teleop/holomotion_teleop_setup.md) for Pico / XRoboToolkit, ZMQ publishing, and launch order on the workstation.
The mode is selected by YAML settings in `g1_29dof_holomotion.yaml`:
- **Offline motion mode**
- `vr.enable_teleop_reference: false`
- `vr.require_vr_data_for_motion: false`
- Result: even if ZMQ data is still arriving, the robot ignores it and motion tracking uses offline `.npz` clips only.
- **Online teleoperation mode**
- `vr.enable_teleop_reference: true`
- `vr.require_vr_data_for_motion: true`
- `vr.latest_obs_zmq_uri: "tcp://:6001"`
- Result: motion mode waits for live teleoperation data before entering and uses the incoming `latest_obs` stream as the motion reference.
If you want teleoperation to be available but not mandatory, you can also use:
- `vr.enable_teleop_reference: true`
- `vr.require_vr_data_for_motion: false`
In that configuration, motion mode can still start without waiting for VR readiness, but live teleoperation data may be used when available at mode entry.
#### One-click start
```bash
cd /deployment/unitree_g1_ros2_29dof
bash launch_holomotion_29dof.sh
```
> **Success indicator**: On startup, the robot joints should remain in zero torque state and feel free to move.
#### Motion Control Modes
The 29 DOF robot operates in two main modes:
| Mode | How to Enter | Controls | Switch |
| ------- | --------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------- | ------------------ |
| Velocity tracking | 1) Press Start to stand up, then press A; 2) from motion tracking, press Y | Left stick: move (vx, vy); Right stick: rotate (yaw); D-Pad: select motion clip (Left=first, Right=last, Up=prev, Down=next) | B: enter motion tracking |
| Motion tracking | Press B | Executes selected motion clip or online teleoperation automatically | Y: back to velocity tracking |
#### Control Flow
Here is the robot control flowchart for 29 DOF:
```mermaid
flowchart TD
subgraph prepPhase ["Setup Phase"]
direction TB
A[Set Robot to Hanging Position] --> B[Power On and Zero Torque Mode]
B --> C[L2+R2: Enter Debug Mode]
C --> D[Launch Program]
end
%% Main flow
prepPhase --> E[Start: Stand Up]
E --> F[Lower Robot to Ground]
F --> G[A: Enter Velocity tracking Mode]
%% Velocity tracking mode controls
G --> H[Velocity tracking Mode]
H --> H1[Left Stick: Move]
H --> H2[Right Stick: Rotate]
H --> H3[D-Pad: Select Motion Clip]
H --> H4[B: Enter Motion tracking Mode]
%% Motion tracking mode
H4 --> I[Motion tracking Mode]
I --> I1[Execute Motion Clip or Teleoperation]
I --> I2[Y: Back to Velocity tracking]
%% Mode switching
I2 --> H
%% Emergency stop
D --> N[Select: Emergency Stop]
E --> N
F --> N
G --> N
H --> N
I --> N
N --> O[Close Program]
classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef control fill:#f3e5f5,stroke:#4a148c,stroke-width:2px
classDef velocityTracking fill:#e8f5e8,stroke:#1b5e20,stroke-width:2px
classDef motionTracking fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef emergency fill:#ffebee,stroke:#b71c1c,stroke-width:2px,stroke-dasharray: 5 5
classDef preparationFrame fill:#f9f9f9,stroke:#666,stroke-width:2px,stroke-dasharray: 5 5
class A,O startEnd
class B,C,D,E,F,G control
class H,H1,H2,H3,H4 velocityTracking
class I,I1,I2 motionTracking
class N emergency
class prepPhase preparationFrame
```
#### Configuration Files (used by Step 4)
**System Configuration**
- File: `HoloMotion/deployment/unitree_g1_ros2_29dof/src/config/g1_29dof_holomotion.yaml`
- Key parameters:
- `motion_tracking_model_folder`: motion tracking model folder under `models/`
- `velocity_tracking_model_folder`: velocity tracking model folder under `models/`
- `motion_clip_dir`: motion clip data folder under `src/motion_data/`
- `vr.enable_teleop_reference`: enable or disable live teleoperation reference
- `vr.require_vr_data_for_motion`: whether motion mode must wait for live teleoperation data
- `vr.latest_obs_zmq_uri`: teleoperation ZMQ endpoint used in online mode
**Pre-trained Models**
We provide pre-trained motion tracking and velocity tracking models that you can download and use:
- **Motion Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_motion_tracking_model)
- **Velocity Tracking Model**: Download from [Hugging Face](https://huggingface.co/HorizonRobotics/HoloMotion_v1.2/tree/main/holomotion_v1.2_velocity_tracking_model)
To use the velocity tracking model:
1. Download the `holomotion_v1.2_velocity_tracking_model` folder from the Hugging Face repository
2. Place the downloaded folder under `models/` (e.g., `models/holomotion_v1.2_velocity_tracking_model/`)
3. Update `velocity_tracking_model_folder` in the `g1_29dof_holomotion.yaml` to point to this folder
**Adding New Motion Tracking Models**
1. Create a new folder under `models/` based on the following example model folder structure (e.g., `models/your_model_dir_name/`)
2. Update `motion_tracking_model_folder` in the `g1_29dof_holomotion.yaml`
3. Ensure the motion clip data files are in the `motion_clip_dir`
Example model folder structure (motion model):
```bash
HoloMotion/deployment/unitree_g1_ros2_29dof/src/models/your_model_dir_name
├── config.yaml
└── exported
    └── your_model_name.onnx
```
---
### Safety Notice
This deployment is intended for demonstration only. It is not a production-grade control system. Do not interfere with the robot during operation. If unexpected behavior occurs, exit control immediately via the controller or keyboard to ensure safety.
To stop the control process, press `Select` or use `Ctrl+C` in the terminal.
---
## PC2 Docker Deployment
### Setup Overview
The deployment process consists of two types of steps:
| **One-Time Setup** (per PC2) | **Every Run** (each time you use the robot) |
| ----------------------------- | -------------------------------------------- |
| Step 1: Configure Docker | Step 4: Start Docker Container |
| Step 2: Load Docker Image | Step 5: Power On & Initialize Robot |
| Step 3: Configure Launch File Network Interface | Step 6: Launch Policy Controller |
> **Note**: Once you complete Steps 1-3, you only need to do Steps 4-6 for each robot session!
### System Requirements
- **Platform**: NVIDIA Jetson Orin
- **JetPack**: 5.1
- **Ubuntu**: 20.04
- **ROS 2**: Foxy
- **Docker**: Installed with NVIDIA Container Runtime support
### Step 1: Configure Docker for NVIDIA Runtime
Modify `/etc/docker/daemon.json`:
```json
{
"runtimes": {
"nvidia": {
"path": "nvidia-container-runtime",
"runtimeArgs": []
}
},
"default-runtime": "nvidia"
}
```
Restart Docker and verify:
```bash
sudo systemctl restart docker
sudo docker info | grep -i runtime
```
### Step 2: Load Docker Image
Pull the image from Docker Hub with:
```bash
docker pull horizonrobotics/holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319
```
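Or, if you already have the image locally, tag it so that it matches the name used above. Note that `docker tag` takes two arguments (`docker tag SOURCE_IMAGE TARGET_IMAGE`), so supply both your local image name (or ID) and the target name: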
```bash
docker tag holomotion:orin_foxy_jp5.1_docker_humble_deploy_zmq_20260319
```
### Step 3: Configure Launch File Network Interface
1. **Check your network interface name on the robot**:
```bash
ifconfig
```
Look for the interface with IP `192.168.123.164`. The interface is typically `eth0`.
2. **Update the launch configuration** if your interface is not `eth0`:
```bash
nano /deployment/unitree_g1_ros2_29dof/src/launch/holomotion_29dof_launch.py
```
Find the `network_interface` assignment (around line 103) and update it:
```python
network_interface = "eth0" # Change to your actual interface name
```
### Step 4: Start Docker Container
> **Important**: Before running Docker commands, ensure your user is added to the docker group. If you encounter permission errors, add your user to the docker group:
```bash
sudo usermod -aG docker $USER
```
After adding your user to the docker group, you need to log out and log back in (or restart your session) for the changes to take effect. Verify with:
```bash
groups
```
You should see `docker` in the list of groups.
> **Important**: You need to run this step **every time** you want to use the robot. The script will automatically remove any existing container and start a fresh one.
```bash
cd /deployment/unitree_g1_ros2_29dof
bash start_container.sh
```
**When prompted, enter the holomotion repository path:**
- The script will ask: `Please enter the holomotion local repository path:`
- Enter the full path to your holomotion repository, for example:
- `/home/unitree/HoloMotion` (if the repository is at this location)
- Or the actual path where your holomotion repository is located
### Step 5: Power On and Initialize Robot
> **Do this every time** before launching the policy controller.
1. Put the robot in hanging position
2. Wait for zero torque mode
3. Press `L2 + R2` on remote controller for debug mode
### Step 6: Launch Policy Controller
> **Do this every time** you want to run the robot (after Steps 4 and 5).
**Preflight Checklist**
Before running, ensure the following are ready.
- Model folders configured in `g1_29dof_holomotion.yaml` exist
- `motion_tracking_model_folder`: under `src/models/`
- `velocity_tracking_model_folder`: under `src/models/`
- Motion data directory exists and contains .npz files (retargeted results)
- `motion_clip_dir`: under `src/motion_data/`
- Config file path used by launch is correct
- Motion reference source is configured as intended (see [Motion Reference Source](#motion-reference-source))
**Pre-trained Models**
You can download and use the pre-trained motion tracking and velocity tracking models. Refer to the [Pre-trained Models](#configuration-files-used-by-step-4) section above for general instructions.
> **Note**: The model folder should be placed in your local repository before starting the Docker container, as the repository is mounted into the container.
**One-click start in the docker**
```bash
cd /home/unitree/holomotion/deployment/unitree_g1_ros2_29dof
bash launch_holomotion_29dof_docker.sh
```
> **Note**: The control flow is the same as described in the [Control Flow](#control-flow) section above.
---
### Safety Notice
This deployment is intended for demonstration only. It is not a production-grade control system. Do not interfere with the robot during operation. If unexpected behavior occurs, exit control immediately via the controller or keyboard to ensure safety.
To stop the control process:
- Press `Select` on the remote controller, or
- Use `Ctrl+C` in the terminal (inside Docker container)
================================================
FILE: docs/smpl_data_curation.md
================================================
# Dataset Preparation Guide
This guide describes the workflow and setup for preparing datasets to train the motion tracking model.
We use **AMASS-compatible SMPL-format motion capture data** as the training input.
---
## Overview
The dataset preparation pipeline has the following steps:
1. **Download datasets**
- To train with diverse and rich motion data, you first need to collect raw motion capture datasets from various sources.
- Then place all downloaded datasets under the `data/raw_datasets/` directory, keeping their original structure.
2. **Convert datasets to AMASS format**
- To ensure that all motion data is compatible with the AMASS-style .npz format used by the training pipeline, you need to convert the raw datasets.
- Then run the conversion script to generate `.npz` files under `data/amass_compatible_datasets/`.
3. **Filter datasets**
- To improve data quality by removing abnormal, noisy, or unwanted motion samples, you can optionally run the filtering step.
- Then run the filtering script to generate filtered .yaml files under holomotion/config/data_curation/.
4. **Visualize Prepared Data**
- Use the included visualization utility to preview and inspect the generated AMASS-compatible `.npz` motion files.
- Quickly check for anomalies or errors before training.
5. **Generate Motion from Monocular Video**
- You can also generate SMPL-format motion capture files **directly from monocular RGB videos** using GVHMR.
- This allows you to create training data or test the model with real-world video footage.
- The pipeline is described in the [Video2SMPL Instructions](#video2smpl-instructions) section below.
### Directory Structure After Full Setup
```
data/
├── raw_datasets/
│ ├── humanact12/
│ ├── OMOMO/
│ ├── MotionX/
│ └── ZJU_Mocap/
├── amass_compatible_datasets/
│ ├── amass/
│ │ ├── ACCAD/
│ │ ├── BioMotionLab_NTroje/
│ │ ├── ...
│ ├── humanact12/
│ ├── OMOMO/
│ ├── MotionX/
│ └── ZJU_Mocap/
├── dataset_labels/
│ ├── humanact12.jsonl
│ ├── OMOMO.jsonl
│ ├── MotionX.jsonl
│ ├── ZJU_Mocap.jsonl
│ ├── amass.jsonl
```
---
## Step-by-Step Instructions
### 1. Download Datasets
Download and extract the datasets into the `data/` folder as follows:
- `data/amass_compatible_datasets/amass/` (required)
- [AMASS dataset](https://amass.is.tue.mpg.de/download.php) — choose **SMPL-X G** format.
- `data/raw_datasets/humanact12/` (optional)
- [HumanAct12](https://github.com/EricGuo5513/action-to-motion?tab=readme-ov-file)
- `data/raw_datasets/OMOMO/` (optional)
- [OMOMO dataset](https://github.com/lijiaman/omomo_release?tab=readme-ov-file)
- `data/raw_datasets/MotionX/` (optional)
- [MotionX dataset](https://github.com/IDEA-Research/Motion-X)
- `data/raw_datasets/ZJU_Mocap/` (optional)
- [EasyMocap](https://github.com/zju3dv/EasyMocap)
---
### 2. Convert Optional Datasets to AMASS Format (optional)
Skip this step if you only use the AMASS dataset.
Steps:
1. Initialize Submodules:
Some datasets require external repositories or models for proper conversion. These are included as **submodules** under:
```
thirdparties/
```
Initialize them with:
```bash
git submodule update --init --recursive
```
If you need to modify or update these submodules, refer to their individual `README` files.
2. Download SMPL Model:
- Download `SMPL_NEUTRAL.npz` from the [SMPL official website](https://smpl.is.tue.mpg.de/download.php).
- Place the file into:
```
./assets/smpl/
```
3. Modify `thirdparties/joints2smpl/src/customloss.py`:
Before running the pipeline, make sure to modify the `body_fitting_loss_3d` function in `thirdparties/joints2smpl/src/customloss.py` to include the following change:
```python
joint3d_loss = (joint_loss_weight ** 2) * joint3d_loss_part.sum(dim=-1)
```
4. Modify `thirdparties/joints2smpl/src/smplify.py`:
Next, ensure the following modification in the `__call__` function of `SMPLify3D` inside `thirdparties/joints2smpl/src/smplify.py`:
```python
init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach()
```
5. Start Conversion:
Run the provided script to convert all available datasets to AMASS `.npz` files:
```bash
bash holomotion/scripts/data_curation/convert_to_amass.sh
```
This script reads from `data/raw_datasets/{dataset}/` and writes to `data/amass_compatible_datasets/{dataset}/`.
---
### 3. Filter Datasets (optional)
Skip this step if you prefer to use all available data for training.
#### Why filter?
The raw datasets may contain motions that are irrelevant, undesirable, or of poor quality for training. This step helps improve the overall quality and consistency of your dataset.
#### Filtering criteria
The filtering process excludes samples based on the following rules:
- **Upstairs/Downstairs motion:**
Paths containing keywords like stairs, staircase, upstairs, downstairs, or motions with large upward/downward Z translation and velocity are excluded.
- **Sitting motion:**
Paths containing sitting/Sitting keywords, or frames that match a reference sitting pose, are excluded.
- **Known abnormal datasets:**
Samples from known abnormal datasets, such as AIST, are excluded.
- **Unrealistic velocity:**
Motions whose mean velocity exceeds a threshold (default: 100.0) are excluded.
#### How to run
You can use the `-l` option to specify which datasets to filter (space-separated list).
Run the filtering script to identify and exclude abnormal or unwanted samples:
```bash
bash holomotion/scripts/data_curation/filter_smpl_data.sh -l "amass humanact12 OMOMO MotionX ZJU_Mocap"
```
The output `.yaml` files will be placed in `holomotion/config/data_curation/`.
---
## Notes
- Paths are relative to the project root.
- The AMASS dataset must be manually requested from their website.
- Dataset conversion and filtering may take time depending on your hardware.
---
This guide assumes that you only need the basic configuration to run the complete pipeline. For further customization, refer to the relevant scripts in the repository and optional steps in the documentation.
## Video2SMPL Instructions
### 1. Environment setup
Create a new conda environment named `gvhmr` by following the official installation guide: [GVHMR setup doc](../thirdparties/GVHMR/docs/INSTALL.md)
> Reminder: if you encounter missing-module errors for `hmr4d` at runtime, install it separately in the `gvhmr` environment:
```bash
pip install hmr4d
```
Rename the SMPL model files as follows: `basicmodel_{GENDER}_lbs_*.pkl` → `SMPL_{GENDER}.pkl`
Place the SMPL and SMPL-X model files into the directory structure below:
```
thirdparties/GVHMR/inputs/checkpoints/
├── body_models/smplx
│ └── SMPLX_{GENDER}.npz
└── body_models/smpl
└── SMPL_{GENDER}.pkl
```
### 2. Video to SMPL motion data
Confirm that all input videos have a frame rate of 30 FPS to avoid motion acceleration or deceleration.
```bash
bash ./holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh
```
> Reminder: Set the directory in the .sh file to an absolute path.
### 3. Visualize generated SMPL motion data
Visualize the SMPL motion sequences generated by GVHMR for inspection and debugging.
```bash
bash ./holomotion/scripts/data_curation/visualize_smpl_gvhmr.sh
```
### 4. Convert SMPL data to SMPLX
Motion data produced by GVHMR is in SMPL format.
Use `./thirdparties/GMR/scripts/smpl_to_smplx.py` to convert it to SMPL-X format for retargeting.
================================================
FILE: docs/train_motion_tracking.md
================================================
## Train the Motion Tracking Model
After completing motion retargeting, you can train a motion tracking model with HoloMotion using the following process.
**Overall Workflow:**
```mermaid
flowchart LR
A[Motion Retargeting] --> B[HDF5 Database]
B --> C[Training Config]
C --> D[Training Entry]
D --> E[Distributed PPO Training]
classDef dashed stroke-dasharray: 5 5, rx:10, ry:10, fill:#c9d9f5
classDef normal fill:#c9d9f5, rx:10, ry:10
class A dashed
class B,C,D,E normal
```
### 1. Train the Motion Tracking Model
The training entry point is `holomotion/src/training/train.py`, which uses the training config to start distributed training across multiple GPUs.
#### 2.1 Prepare the Training Config
Use the demo config at `holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking.yaml` as a template. Key configuration groups to modify (configs are located in the `holomotion/config/` directory):
- **`/algo`**: Algorithm settings (PPO) and network configurations
- **`/robot`**: Robot-specific config including DOF, body links, and control parameters
- **`/env`**: Environment settings including motion sampling and curriculum learning
- **`/env/observations`**: Observation dimensions, noise, and scaling for the policy
- **`/env/rewards`**: Reward function definitions
- **`/env/domain_randomization`**: Domain randomization settings (start with `NO_domain_rand`)
- **`/env/terrain`**: Terrain configuration
- **`/modules`**: The policy network modules definitions
```yaml
# @package _global_
defaults:
- /training: train_base
- /algo: ppo
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: motion_tracking
- /env/terminations: termination_motion_tracking
- /env/observations: motion_tracking/obs_motion_tracking_tf-moe
- /env/rewards: motion_tracking/rew_motion_tracking
- /env/domain_randomization: domain_rand_medium
- /env/terrain: isaaclab_plane
- /modules: motion_tracking/motion_tracking_tf-moe
project_name: HoloMotion
```
#### 2.2 Train your Policy
Review and modify the training script at `holomotion/scripts/training/train_motion_tracking.sh`. Ensure that `config_name` matches your training config and that the motion database directory is set correctly.
Start training by running:
```shell
bash holomotion/scripts/training/train_motion_tracking.sh
# or
bash holomotion/scripts/training/train_velocity_tracking.sh
```
Note that IsaacLab requires an internet connection to pull assets from NVIDIA's cloud storage. If training gets stuck at scene creation, the cloud-hosted assets are most likely unreachable; enabling your proxy and retrying usually solves the issue.
### Training Tips
#### How can I reduce GPU memory usage?
Training requires significant GPU memory. Reduce `num_envs` if your GPU has limited VRAM. This lowers the memory needed for both rollouts and PPO updates, at the risk of a significantly less stable policy optimization process.
#### How do I start multiple training sessions?
To run multiple training sessions at the same time, explicitly add the `--main_process_port=port_number` option to the training entry bash script to avoid port conflicts in the Accelerate backend. Note that `port_number` **cannot** be `0`.
If you would like to run training on a specific GPU, just modify the GPU id in the `export CUDA_VISIBLE_DEVICES="X"` statement.
#### How do I set the save/log intervals?
You may want more or less frequent logging and checkpoint saving. You can change these intervals by adding the following options:
- `algo.config.save_interval=X` : The checkpoint will be saved every `X` learning iterations.
- `algo.config.log_interval=Y`: The logging information will be displayed every `Y` learning iterations.
#### Where are checkpoints saved?
By default, model checkpoints are saved into a folder named `logs/HoloMotion`. You can change this path by explicitly setting `project_name=X`, which saves the checkpoints into the `logs/X` directory instead.
#### How do I resume training from a checkpoint?
To resume training from a saved checkpoint, locate it in the log directory and add an option like this: `checkpoint=logs/HoloMotion/20250728_214414-train_unitree_g1_21dof_teacher/model_X.pt`
================================================
FILE: environments/environment_deploy.yaml
================================================
name: holomotion_deploy
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
# Python runtime
- python=3.10
# PyTorch with CUDA support
- pytorch-cuda=12.1
- cudnn>=9
- cudatoolkit>=11.7,<12
- pytorch==2.3.1
- torchvision==0.18.1
- torchaudio==2.3.1
# Scientific computing packages (via conda for better compatibility)
- numpy==1.24.3
- scipy
- matplotlib
- pandas
# System utilities and development tools
- sshpass=1.06
- git
- curl
- wget
- pyyaml
- easydict
- joblib
# Basic Python package management
- pip
- setuptools
- wheel
# Install additional packages via pip
- pip:
- -r environments/requirements_deploy.txt
================================================
FILE: environments/environment_train_isaaclab_cu118.yaml
================================================
name: holomotion_train
channels:
- conda-forge
dependencies:
- python=3.11
- pip
- mesalib
- pip:
- -r requirements_torch_cu118.txt
- -r requirements_base.txt
- -e ../thirdparties/SMPLSim
- -e ../
================================================
FILE: environments/environment_train_isaaclab_cu128.yaml
================================================
name: holomotion_train
channels:
- conda-forge
dependencies:
- python=3.11
- pip
- mesalib
- pip:
- -r requirements_torch_cu128.txt
- -r requirements_base.txt
- -e ../thirdparties/SMPLSim
- -e ../
================================================
FILE: environments/requirements_base.txt
================================================
isaacsim[all,extscache]==5.0.0
isaaclab[isaacsim,all]==2.2.0
setuptools
wheel
numpy==1.26.0
smplx==0.1.28
hydra-core==1.3.2
easydict
tqdm
open3d
lxml
ray
ipdb
joblib
scipy
jupyter
loguru
tensorboard
mujoco
mink
dm_control
loop_rate_limiters
qpsolvers[quadprog,proxqp]
accelerate
tabulate
matplotlib
pandas
termcolor
rich
pytorch-tcn
einops
onnxruntime-gpu
onnx
pre-commit
ruff
pytest
imageio>=2.9
imageio-ffmpeg
opencv-python
natsort
psutil
redis[hiredis]
chumpy
pyvirtualdisplay
pynput
xxhash
h5py>=3.8
pygame
tensordict==0.11.0
pytorch_kinematics
onnxscript
# human_body_prior dependency
human-body-prior
================================================
FILE: environments/requirements_deploy.txt
================================================
# Machine Learning Runtime
onnxruntime-gpu
# SMPL/SMPLX support
smplx==0.1.28
# Configuration management
hydra-core==1.3.2
omegaconf
# Progress and logging
tqdm
loguru
termcolor
rich
# Data processing
lmdb
einops
# Protobuf (specific version for compatibility)
protobuf==3.20.3
onnx
# Development tools
ipdb
mujoco
pygame
# Note: The following packages are installed via conda in environment_deploy.yaml:
# - torch, torchvision, torchaudio (with CUDA support)
# - numpy, scipy, matplotlib, pandas
# - pyyaml, easydict, joblib
# - system utilities (git, curl, wget, sshpass)
================================================
FILE: environments/requirements_torch_cu118.txt
================================================
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://pypi.nvidia.com
torch==2.7.0
torchvision==0.22.0
torchaudio==2.7.0
================================================
FILE: environments/requirements_torch_cu128.txt
================================================
--extra-index-url https://download.pytorch.org/whl/cu128
--extra-index-url https://pypi.nvidia.com
torch==2.10.0+cu128
torchvision==0.25.0+cu128
torchaudio==2.10.0+cu128
================================================
FILE: environments/requirements_torch_cu130.txt
================================================
--extra-index-url https://download.pytorch.org/whl/cu130
--extra-index-url https://pypi.nvidia.com
torch==2.10.0+cu130
torchvision==0.25.0+cu130
torchaudio==2.10.0+cu130
================================================
FILE: holomotion/config/algo/ppo.yaml
================================================
# @package _global_
algo:
_target_: holomotion.src.algo.ppo.PPO
_recursive_: false
config:
# --- General Settings ---
enable_online_eval: false
num_learning_iterations: 10001
log_interval: 5
save_interval: 500
export_policy: true
onnx_name_suffix: null
use_kv_cache: true
eval_interval: null
load_optimizer: true
headless: ${headless}
# ---
# --- Accelerate Settings ---
mixed_precision: null # "fp16", "bf16", or null. Use "bf16" for A100/H100, "fp16" for older GPUs
dynamo_backend: "inductor" # "inductor", "aot_eager", "cudagraphs", or null. Enables automatic model compilation during prepare()
# ---
# --- PPO Related Settings ---
init_at_random_ep_len: true
num_steps_per_env: 32
num_learning_epochs: 3
num_mini_batches: 4
clip_param: 0.2
gamma: 0.99
lam: 0.95
value_loss_coef: 1.0
entropy_coef: 5.0e-3
anneal_entropy: false
zero_entropy_point: 1.0
max_grad_norm: 1.0
use_clipped_value_loss: true
desired_kl: 0.01
init_noise_std: 1.0
# --- Optimizer Settings ---
optimizer_type: AdamW # Options: "AdamW", "Adam"
schedule: adaptive
actor_learning_rate: 3.0e-4
critic_learning_rate: 5.0e-4
adaptive_lr:
adapt_critic: false
lr_scaler: 1.2
kl_high_factor: 2.0
kl_low_factor: 0.5
min_learning_rate: 1.0e-6
max_learning_rate: 1.0
distributed_update:
mode: scalable # Options: "legacy", "scalable"
lr_scale:
mode: sqrt_world_size # Options: "none", "sqrt_world_size", "linear_world_size"
reference_world_size: 1
max_scale: null
kl_early_stop:
enabled: true
signal: window_mean # Shared windowed KL control signal
window_size: 3
factor: 1.8
min_updates: 1
# Distributed training settings
normalize_advantage_per_mini_batch: false # Use global advantage norm for DDP
global_advantage_norm: true # Sync advantages across all ranks
# --- Sampling Strategy ---
sampling_strategy: uniform
curriculum:
p_a_ratio: 0.5
ema_alpha_signal: 0.2
ema_alpha_rel_improve: 0.2
relative_eps: 1.0e-6
dump_whole_window_scores_json: false
dump_whole_window_scores_every_swaps: 10
weighted_bin:
bin_regex_patterns: []
dump_sampled_keys: false
dump_sampled_keys_interval: 1000
# --- Module Settings ---
module_dict:
actor: ${modules.actor}
critic: ${modules.critic}
symmetry_loss:
enabled: false
coef: 0.1
dof_sign_by_name: ${robot.dof_sign_by_name}
================================================
FILE: holomotion/config/algo/ppo_tf.yaml
================================================
# @package _global_
defaults:
- ppo
algo:
_target_: holomotion.src.algo.ppo_tf.PPOTF
config:
num_steps_per_env: 32
kl_coef: 0.0
schedule: adaptive
actor_learning_rate: 3.0e-5
critic_learning_rate: 5.0e-5
num_learning_epochs: 3
num_mini_batches: 24
clip_param: 0.2
entropy_coef: 5.0e-3
desired_kl: 0.01
noise_std_type: log
fix_sigma: false
init_noise_std: 1.0
min_sigma: 0.01
max_sigma: 1.2
aux_state_pred:
enabled: true
w_keybody_contact: 1.0e-2
w_base_lin_vel: 1.0e-2
w_ref_keybody_rel_pos: 1.0e-1
w_robot_keybody_rel_pos: 1.0e-1
min_std: 0.01
max_std: 2.0
keybody_contact_names:
- left_hip_pitch_link
- right_hip_pitch_link
- left_knee_link
- right_knee_link
- left_ankle_roll_link
- right_ankle_roll_link
- left_elbow_link
- right_elbow_link
- left_wrist_yaw_link
- right_wrist_yaw_link
keybody_rel_pos_names:
- left_knee_link
- right_knee_link
- left_ankle_roll_link
- right_ankle_roll_link
- left_elbow_link
- right_elbow_link
- left_wrist_yaw_link
- right_wrist_yaw_link
dead_expert_margin_to_topk:
enabled: true
weight: 10.0
aux_router_command_recon:
enabled: false
weight: 0.0
hidden_dim: 0
term_prefix: actor_ref_
aux_router_switch_penalty:
enabled: false
weight: 0.0
router_expert_orthogonal:
enabled: false
weight: 0.0
min_active_usage: 1.0e-3
eps: 1.0e-8
selected_expert_margin_to_unselected:
enabled: false
weight: 0.0
target: 0.0
moe_router:
routing_score_fn: softmax
routing_scale: 1.0
use_dynamic_bias: false
bias_update_rate: 0.001
expert_bias_clip: 0.0
================================================
FILE: holomotion/config/data_curation/joints2smpl.yaml
================================================
================================================
FILE: holomotion/config/data_curation/smplify_base.yaml
================================================
================================================
FILE: holomotion/config/env/domain_randomization/NO_domain_rand.yaml
================================================
# @package _global_
domain_rand:
action_delay:
enabled: false
erfi:
enabled: false
motion_init_perturb:
root_pose_perturb_range:
x: [0.0, 0.0]
y: [0.0, 0.0]
z: [0.0, 0.0]
roll: [0.0, 0.0]
pitch: [0.0, 0.0]
yaw: [0.0, 0.0]
root_vel_perturb_range:
x: [0.0, 0.0]
y: [0.0, 0.0]
z: [0.0, 0.0]
roll: [0.0, 0.0]
pitch: [0.0, 0.0]
yaw: [0.0, 0.0]
dof_pos_perturb_range: [0.0, 0.0]
dof_vel_perturb_range: [0.0, 0.0]
obs_noise:
actor_ref_gravity_projection_cur:
n_min: 0
n_max: 0
actor_ref_gravity_projection_fut:
n_min: 0
n_max: 0
actor_ref_base_linvel_cur:
n_min: 0
n_max: 0
n_min_z: 0
n_max_z: 0
actor_ref_base_linvel_fut:
n_min: 0
n_max: 0
n_min_z: 0
n_max_z: 0
actor_ref_base_angvel_cur:
n_min: 0
n_max: 0
n_min_z: 0
n_max_z: 0
actor_ref_base_angvel_fut:
n_min: 0
n_max: 0
n_min_z: 0
n_max_z: 0
actor_ref_dof_pos_cur:
n_min: 0
n_max: 0
actor_ref_dof_pos_fut:
n_min: 0
n_max: 0
actor_ref_root_height_cur:
n_min: 0
n_max: 0
actor_ref_root_height_fut:
n_min: 0
n_max: 0
actor_ref_keybody_rel_pos_cur:
n_min: 0
n_max: 0
actor_ref_keybody_rel_pos_fut:
n_min: 0
n_max: 0
actor_projected_gravity:
n_min: 0.0
n_max: 0.0
actor_rel_robot_root_ang_vel:
n_min: 0.0
n_max: 0.0
actor_dof_pos:
n_min: 0.0
n_max: 0.0
actor_dof_vel:
n_min: 0.0
n_max: 0.0
================================================
FILE: holomotion/config/env/domain_randomization/domain_rand_medium.yaml
================================================
# @package _global_
domain_rand:
action_delay:
enabled: true
min_delay: 0
max_delay: 2
erfi:
enabled: false
rfi_probability: 0.5
rfi_lim: 0.1
randomize_rfi_lim: true
rfi_lim_range: [0.5, 1.5]
rao_lim: 0.1
obs_noise:
actor_ref_gravity_projection_cur:
n_min: -0.1
n_max: 0.1
actor_ref_gravity_projection_fut:
n_min: -0.1
n_max: 0.1
actor_ref_base_linvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_linvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_angvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_base_angvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_dof_pos_cur:
n_min: -0.05
n_max: 0.05
actor_ref_dof_pos_fut:
n_min: -0.05
n_max: 0.05
actor_ref_root_height_cur:
n_min: -0.1
n_max: 0.1
actor_ref_root_height_fut:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_cur:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_fut:
n_min: -0.1
n_max: 0.1
actor_projected_gravity:
n_min: -0.1
n_max: 0.1
actor_rel_robot_root_ang_vel:
n_min: -0.2
n_max: 0.2
actor_dof_pos:
n_min: -0.01
n_max: 0.01
actor_dof_vel:
n_min: -0.5
n_max: 0.5
motion_init_perturb:
root_pose_perturb_range:
x: [-0.05, 0.05]
y: [-0.05, 0.05]
z: [-0.01, 0.01]
roll: [-0.1, 0.1]
pitch: [-0.1, 0.1]
yaw: [-0.2, 0.2]
root_vel_perturb_range:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.5, 0.5]
pitch: [-0.5, 0.5]
yaw: [-0.2, 0.2]
dof_pos_perturb_range: [-0.1, 0.1]
dof_vel_perturb_range: [0.0, 0.0]
default_dof_pos_bias:
mode: startup
params:
joint_names: [".*"]
pos_distribution_params: [-0.01, 0.01]
operation: add
distribution: uniform
rigid_body_com:
mode: startup
params:
body_names: torso_link
com_range:
x: [-0.075, 0.075]
y: [-0.1, 0.1]
z: [-0.1, 0.1]
randomize_mass:
mode: startup
params:
body_names:
- "pelvis"
- "torso_link"
mass_range: [-1.0, 2.0]
rigid_body_material:
mode: startup
params:
body_names: ".*"
static_friction_range: [0.3, 1.6]
dynamic_friction_range: [0.3, 1.2]
restitution_range: [0.0, 0.5]
num_buckets: 64
push_by_setting_velocity:
mode: interval
interval_range_s: [1.0, 3.0]
params:
velocity_range:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.52, 0.52]
pitch: [-0.52, 0.52]
yaw: [-0.78, 0.78]
randomize_actuator_gains:
mode: startup
params:
asset_name: robot
body_names: ".*"
stiffness_distribution_params: [0.9, 1.1]
damping_distribution_params: [0.9, 1.1]
operation: scale
distribution: uniform
================================================
FILE: holomotion/config/env/domain_randomization/domain_rand_small.yaml
================================================
# @package _global_
domain_rand:
action_delay:
enabled: false
min_delay: 0
max_delay: 0
erfi:
enabled: false
rfi_probability: 0.5
rfi_lim: 0.1
randomize_rfi_lim: true
rfi_lim_range: [0.5, 1.5]
rao_lim: 0.1
obs_noise:
actor_ref_gravity_projection_cur:
n_min: -0.1
n_max: 0.1
actor_ref_gravity_projection_fut:
n_min: -0.1
n_max: 0.1
actor_ref_base_linvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_linvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_angvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_base_angvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_dof_pos_cur:
n_min: -0.05
n_max: 0.05
actor_ref_dof_pos_fut:
n_min: -0.05
n_max: 0.05
actor_ref_root_height_cur:
n_min: -0.1
n_max: 0.1
actor_ref_root_height_fut:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_cur:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_fut:
n_min: -0.1
n_max: 0.1
actor_projected_gravity:
n_min: -0.1
n_max: 0.1
actor_rel_robot_root_ang_vel:
n_min: -0.2
n_max: 0.2
actor_dof_pos:
n_min: -0.01
n_max: 0.01
actor_dof_vel:
n_min: -0.5
n_max: 0.5
motion_init_perturb:
root_pose_perturb_range:
x: [-0.05, 0.05]
y: [-0.05, 0.05]
z: [-0.01, 0.01]
roll: [-0.1, 0.1]
pitch: [-0.1, 0.1]
yaw: [-0.2, 0.2]
root_vel_perturb_range:
x: [-0.3, 0.3]
y: [-0.3, 0.3]
z: [-0.1, 0.1]
roll: [-0.3, 0.3]
pitch: [-0.3, 0.3]
yaw: [-0.4, 0.4]
dof_pos_perturb_range: [-0.1, 0.1]
dof_vel_perturb_range: [0.0, 0.0]
default_dof_pos_bias:
mode: startup
params:
joint_names: [".*"]
pos_distribution_params: [-0.01, 0.01]
operation: add
distribution: uniform
rigid_body_com:
mode: startup
params:
body_names: torso_link
com_range:
x: [-0.025, 0.025]
y: [-0.05, 0.05]
z: [-0.05, 0.05]
randomize_mass:
mode: startup
params:
body_names:
- "pelvis"
- "torso_link"
mass_range: [-1.0, 2.0]
rigid_body_material:
mode: startup
params:
body_names: ".*"
static_friction_range: [0.3, 1.6]
dynamic_friction_range: [0.3, 1.2]
restitution_range: [0.0, 0.5]
num_buckets: 64
push_by_setting_velocity:
mode: interval
interval_range_s: [1.0, 3.0]
params:
velocity_range:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.52, 0.52]
pitch: [-0.52, 0.52]
yaw: [-0.78, 0.78]
randomize_actuator_gains:
mode: startup
params:
asset_name: robot
body_names: ".*"
stiffness_distribution_params: [0.9, 1.1]
damping_distribution_params: [0.9, 1.1]
operation: scale
distribution: uniform
================================================
FILE: holomotion/config/env/domain_randomization/domain_rand_strong.yaml
================================================
# @package _global_
domain_rand:
action_delay:
enabled: true
min_delay: 0
max_delay: 4
erfi:
enabled: true
rfi_probability: 0.5
rfi_lim: 0.1
randomize_rfi_lim: true
rfi_lim_range: [0.5, 1.5]
rao_lim: 0.1
obs_noise:
actor_ref_gravity_projection_cur:
n_min: -0.1
n_max: 0.1
actor_ref_gravity_projection_fut:
n_min: -0.1
n_max: 0.1
actor_ref_base_linvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_linvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.05
n_max_z: 0.05
actor_ref_base_angvel_cur:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_base_angvel_fut:
n_min: -0.1
n_max: 0.1
n_min_z: -0.1
n_max_z: 0.1
actor_ref_dof_pos_cur:
n_min: -0.05
n_max: 0.05
actor_ref_dof_pos_fut:
n_min: -0.05
n_max: 0.05
actor_ref_root_height_cur:
n_min: -0.1
n_max: 0.1
actor_ref_root_height_fut:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_cur:
n_min: -0.1
n_max: 0.1
actor_ref_keybody_rel_pos_fut:
n_min: -0.1
n_max: 0.1
actor_projected_gravity:
n_min: -0.1
n_max: 0.1
actor_rel_robot_root_ang_vel:
n_min: -0.2
n_max: 0.2
actor_dof_pos:
n_min: -0.01
n_max: 0.01
actor_dof_vel:
n_min: -0.5
n_max: 0.5
motion_init_perturb:
root_pose_perturb_range:
x: [-0.05, 0.05]
y: [-0.05, 0.05]
z: [-0.01, 0.01]
roll: [-0.1, 0.1]
pitch: [-0.1, 0.1]
yaw: [-0.2, 0.2]
root_vel_perturb_range:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.5, 0.5]
pitch: [-0.5, 0.5]
yaw: [-0.2, 0.2]
dof_pos_perturb_range: [-0.1, 0.1]
dof_vel_perturb_range: [0.0, 0.0]
default_dof_pos_bias:
mode: startup
params:
joint_names: [".*"]
pos_distribution_params: [-0.01, 0.01]
operation: add
distribution: uniform
rigid_body_com:
mode: startup
params:
body_names: torso_link
com_range:
x: [-0.075, 0.075]
y: [-0.1, 0.1]
z: [-0.1, 0.1]
randomize_mass:
mode: startup
params:
body_names:
- "pelvis"
- "torso_link"
mass_range: [-1.0, 2.0]
rigid_body_material:
mode: startup
params:
body_names: ".*"
static_friction_range: [0.3, 1.6]
dynamic_friction_range: [0.3, 1.2]
restitution_range: [0.0, 0.5]
num_buckets: 64
push_by_setting_velocity:
mode: interval
interval_range_s: [1.0, 3.0]
params:
velocity_range:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.52, 0.52]
pitch: [-0.52, 0.52]
yaw: [-0.78, 0.78]
randomize_actuator_gains:
mode: startup
params:
asset_name: robot
body_names: ".*"
stiffness_distribution_params: [0.9, 1.1]
damping_distribution_params: [0.9, 1.1]
operation: scale
distribution: uniform
================================================
FILE: holomotion/config/env/motion_tracking.yaml
================================================
# @package _global_
env:
_target_: holomotion.src.env.motion_tracking.MotionTrackingEnv
_recursive_: False
config:
experiment_name: ${experiment_name}
num_envs: ${num_envs}
env_spacing: 2.5 # meters
replicate_physics: true
headless: ${headless}
num_processes: ${num_processes}
main_process: ${main_process}
process_id: ${process_id}
ckpt_dir: null
disable_ref_viz: false
eval_log_dir: null
save_rendering_dir: null
robot: ${robot}
domain_rand: ${domain_rand}
rewards: ${rewards}
terrain: ${terrain}
obs: ${obs}
terminations: ${terminations}
simulation:
episode_length_s: 10 # Long episodes for fluid motion-based termination
sim_freq: 200
control_decimation: 4
physx:
bounce_threshold_velocity: 0.5
gpu_max_rigid_patch_count: 327680 # 10 * 2**15
scene:
terrain: ${terrain}
lighting:
distant_light_intensity: 3000.0
dome_light_intensity: 1000.0
contact_sensor:
history_length: 3
force_threshold: 10.0
track_air_time: true
debug_vis: false
actions:
dof_pos:
type: joint_position
params:
asset_name: robot
joint_names:
- ".*"
use_default_offset: true
scale: ${robot.actuators.action_scale}
commands:
ref_motion:
type: MotionCommandCfg
params:
command_obs_name: bydmmc_ref_motion
motion_lib_cfg: ${robot.motion}
urdf_dof_names: ${robot.dof_names}
urdf_body_names: ${robot.body_names}
arm_dof_names: ${robot.arm_dof_names}
waist_dof_names: ${robot.waist_dof_names}
leg_dof_names: ${robot.leg_dof_names}
arm_body_names: ${robot.arm_body_names}
torso_body_names: ${robot.torso_body_names}
leg_body_names: ${robot.leg_body_names}
anchor_bodylink_name: ${robot.anchor_body}
asset_name: robot
debug_vis: true
root_pose_perturb_range: ${domain_rand.motion_init_perturb.root_pose_perturb_range}
root_vel_perturb_range: ${domain_rand.motion_init_perturb.root_vel_perturb_range}
dof_pos_perturb_range: ${domain_rand.motion_init_perturb.dof_pos_perturb_range}
dof_vel_perturb_range: ${domain_rand.motion_init_perturb.dof_vel_perturb_range}
resample_time_interval_s: 100
n_fut_frames: ${obs.n_fut_frames}
target_fps: 50
normalization:
clip_actions: 100.0
clip_observations: 100.0
resample_motion_when_training: True
curriculum:
enabled: false
robot_friction_completion_rate:
enabled: True
func: robot_friction_range_by_completion_rate
params:
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
static_friction_target: [0.3, 1.6]
dynamic_friction_target: [0.3, 1.2]
body_names: ".*"
restitution_range: [0.0, 0.5]
num_buckets: 64
rigid_body_com_completion_rate:
enabled: True
func: rigid_body_com_by_completion_rate
params:
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
state_prefix: "_cr_curr"
asset_name: "robot"
body_names: "torso_link"
com_range_target:
x: [-0.025, 0.025]
y: [-0.05, 0.05]
z: [-0.05, 0.05]
default_dof_pos_bias_completion_rate:
enabled: True
func: default_dof_pos_bias_by_completion_rate
params:
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
state_prefix: "_cr_curr"
joint_names:
- ".*"
pos_distribution_params_target: [-0.01, 0.01]
operation: add
distribution: uniform
push_by_setting_velocity_completion_rate:
enabled: True
func: isaaclab_mdp.modify_term_cfg
params:
address: "events.push_by_setting_velocity.params"
modify_fn: push_by_setting_velocity_range_by_completion_rate
modify_params:
num_updates: 3
cr_thresholds: [0.20, 0.30, 0.40]
velocity_range_target:
x: [-0.5, 0.5]
y: [-0.5, 0.5]
z: [-0.2, 0.2]
roll: [-0.52, 0.52]
pitch: [-0.52, 0.52]
yaw: [-0.78, 0.78]
randomize_actuator_gains_completion_rate:
enabled: True
func: randomize_actuator_gains_by_completion_rate
params:
num_updates: 3
cr_thresholds: [0.20, 0.30, 0.40]
asset_name: "robot"
body_names: ".*"
stiffness_distribution_params_target: [0.9, 1.1]
damping_distribution_params_target: [0.9, 1.1]
operation: scale
distribution: uniform
action_rate_l2_completion_rate:
enabled: true
func: reward_term_weight_by_completion_rate
params:
reward_term_name: "action_rate_l2"
final_weight: -0.1
start_scale: 0.1
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
joint_pos_limits_completion_rate:
enabled: true
func: reward_term_weight_by_completion_rate
params:
reward_term_name: "joint_pos_limits"
final_weight: -10.0
start_scale: 0.1
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
undesired_contacts_completion_rate:
enabled: true
func: reward_term_weight_by_completion_rate
params:
reward_term_name: "undesired_contacts"
final_weight: -0.1
start_scale: 0.1
num_updates: 5
cr_thresholds: [0.10, 0.20, 0.28, 0.34, 0.40]
================================================
FILE: holomotion/config/env/observations/motion_tracking/obs_motion_tracking_mlp.yaml
================================================
# @package _global_
obs:
context_length: 32
n_fut_frames: 10
target_fps: 50
actor_obs_prefix: "ref_"
critic_obs_prefix: "ref_"
obs_groups:
unified:
atomic_obs_list:
- actor_ref_gravity_projection_cur:
func: ref_gravity_projection_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max}
- actor_ref_gravity_projection_fut:
func: ref_gravity_projection_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max}
# Reference base linear velocity
- actor_ref_base_linvel_cur:
func: ref_base_linvel_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z}
- actor_ref_base_linvel_fut:
func: ref_base_linvel_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z}
- actor_ref_base_angvel_cur:
func: ref_base_angvel_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z}
- actor_ref_base_angvel_fut:
func: ref_base_angvel_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z}
- actor_ref_dof_pos_cur:
func: ref_dof_pos_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max}
- actor_ref_dof_pos_fut:
func: ref_dof_pos_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max}
- actor_ref_motion_filter_cutoff_hz:
func: ref_motion_filter_cutoff_hz
- actor_ref_root_height_cur:
func: ref_root_height_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max}
- actor_ref_root_height_fut:
func: ref_root_height_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max}
- actor_ref_keybody_rel_pos_cur:
func: ref_keybody_rel_pos_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
keybody_names:
- "left_knee_link"
- "right_knee_link"
- "left_ankle_roll_link"
- "right_ankle_roll_link"
- "left_elbow_link"
- "right_elbow_link"
- "left_wrist_yaw_link"
- "right_wrist_yaw_link"
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max}
- actor_ref_keybody_rel_pos_fut:
func: ref_keybody_rel_pos_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
keybody_names:
- "left_knee_link"
- "right_knee_link"
- "left_ankle_roll_link"
- "right_ankle_roll_link"
- "left_elbow_link"
- "right_elbow_link"
- "left_wrist_yaw_link"
- "right_wrist_yaw_link"
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max}
- actor_projected_gravity:
func: projected_gravity
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min}
n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max}
- actor_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min}
n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max}
- actor_dof_pos:
func: dof_pos
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min}
n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max}
- actor_dof_vel:
func: dof_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min}
n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max}
- actor_last_action:
func: last_action
history_length: ${obs.context_length}
flatten_history_dim: false
- critic_ref_dof_pos_cur:
func: ref_dof_pos_cur
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_dof_pos_fut:
func: ref_dof_pos_fut
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_root_height_fut:
func: ref_root_height_fut
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_root_height_cur:
func: ref_root_height_cur
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_global_anchor_diff:
func: global_anchor_diff
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_pos:
func: ref_motion_cur_heading_aligned_root_pos
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_pos:
func: ref_motion_fut_heading_aligned_root_pos
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_rot6d:
func: ref_motion_cur_heading_aligned_root_rot6d
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_rot6d:
func: ref_motion_fut_heading_aligned_root_rot6d
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_lin_vel:
func: ref_motion_cur_heading_aligned_root_lin_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_lin_vel:
func: ref_motion_fut_heading_aligned_root_lin_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_ang_vel:
func: ref_motion_cur_heading_aligned_root_ang_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_ang_vel:
func: ref_motion_fut_heading_aligned_root_ang_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_rel_robot_root_lin_vel:
func: rel_robot_root_lin_vel
- critic_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
- critic_global_robot_bodylink_lin_vel_flat:
func: global_robot_bodylink_lin_vel_flat
- critic_global_robot_bodylink_ang_vel_flat:
func: global_robot_bodylink_ang_vel_flat
- critic_root_rel_robot_bodylink_pos_flat:
func: root_rel_robot_bodylink_pos_flat
- critic_root_rel_robot_bodylink_rot_mat_flat:
func: root_rel_robot_bodylink_rot_mat_flat
- critic_dof_pos:
func: dof_pos
- critic_dof_vel:
func: dof_vel
- critic_last_action:
func: last_action
enable_corruption: true
concatenate_terms: false
================================================
FILE: holomotion/config/env/observations/motion_tracking/obs_motion_tracking_tf-moe.yaml
================================================
# @package _global_
obs:
context_length: 1
n_fut_frames: 10
target_fps: 50
actor_obs_prefix: "ref_"
critic_obs_prefix: "ref_"
obs_groups:
unified:
atomic_obs_list:
- actor_ref_gravity_projection_cur:
func: ref_gravity_projection_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_cur.n_max}
- actor_ref_gravity_projection_fut:
func: ref_gravity_projection_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_gravity_projection_fut.n_max}
# Reference base linear velocity
- actor_ref_base_linvel_cur:
func: ref_base_linvel_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_cur.n_max_z}
- actor_ref_base_linvel_fut:
func: ref_base_linvel_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_linvel_fut.n_max_z}
- actor_ref_base_angvel_cur:
func: ref_base_angvel_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_cur.n_max_z}
- actor_ref_base_angvel_fut:
func: ref_base_angvel_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max}
n_min_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_min_z}
n_max_z: ${domain_rand.obs_noise.actor_ref_base_angvel_fut.n_max_z}
- actor_ref_dof_pos_cur:
func: ref_dof_pos_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_cur.n_max}
- actor_ref_dof_pos_fut:
func: ref_dof_pos_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_dof_pos_fut.n_max}
- actor_ref_motion_filter_cutoff_hz:
func: ref_motion_filter_cutoff_hz
- actor_ref_root_height_cur:
func: ref_root_height_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_root_height_cur.n_max}
- actor_ref_root_height_fut:
func: ref_root_height_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_root_height_fut.n_max}
- actor_ref_keybody_rel_pos_cur:
func: ref_keybody_rel_pos_cur
history_length: ${obs.context_length}
flatten_history_dim: false
params:
ref_prefix: ${obs.actor_obs_prefix}
keybody_names:
- "left_knee_link"
- "right_knee_link"
- "left_ankle_roll_link"
- "right_ankle_roll_link"
- "left_elbow_link"
- "right_elbow_link"
- "left_wrist_yaw_link"
- "right_wrist_yaw_link"
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_cur.n_max}
- actor_ref_keybody_rel_pos_fut:
func: ref_keybody_rel_pos_fut
params:
ref_prefix: ${obs.actor_obs_prefix}
keybody_names:
- "left_knee_link"
- "right_knee_link"
- "left_ankle_roll_link"
- "right_ankle_roll_link"
- "left_elbow_link"
- "right_elbow_link"
- "left_wrist_yaw_link"
- "right_wrist_yaw_link"
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_min}
n_max: ${domain_rand.obs_noise.actor_ref_keybody_rel_pos_fut.n_max}
- actor_projected_gravity:
func: projected_gravity
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_projected_gravity.n_min}
n_max: ${domain_rand.obs_noise.actor_projected_gravity.n_max}
- actor_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_min}
n_max: ${domain_rand.obs_noise.actor_rel_robot_root_ang_vel.n_max}
- actor_dof_pos:
func: dof_pos
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_dof_pos.n_min}
n_max: ${domain_rand.obs_noise.actor_dof_pos.n_max}
- actor_dof_vel:
func: dof_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: ${domain_rand.obs_noise.actor_dof_vel.n_min}
n_max: ${domain_rand.obs_noise.actor_dof_vel.n_max}
- actor_last_action:
func: last_action
history_length: ${obs.context_length}
flatten_history_dim: false
- critic_ref_dof_pos_cur:
func: ref_dof_pos_cur
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_dof_pos_fut:
func: ref_dof_pos_fut
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_root_height_fut:
func: ref_root_height_fut
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_root_height_cur:
func: ref_root_height_cur
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_global_anchor_diff:
func: global_anchor_diff
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_pos:
func: ref_motion_cur_heading_aligned_root_pos
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_pos:
func: ref_motion_fut_heading_aligned_root_pos
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_rot6d:
func: ref_motion_cur_heading_aligned_root_rot6d
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_rot6d:
func: ref_motion_fut_heading_aligned_root_rot6d
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_lin_vel:
func: ref_motion_cur_heading_aligned_root_lin_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_lin_vel:
func: ref_motion_fut_heading_aligned_root_lin_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_cur_heading_aligned_root_ang_vel:
func: ref_motion_cur_heading_aligned_root_ang_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_ref_motion_fut_heading_aligned_root_ang_vel:
func: ref_motion_fut_heading_aligned_root_ang_vel
params:
ref_prefix: ${obs.critic_obs_prefix}
- critic_rel_robot_root_lin_vel:
func: rel_robot_root_lin_vel
- critic_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
- critic_global_robot_bodylink_lin_vel_flat:
func: global_robot_bodylink_lin_vel_flat
- critic_global_robot_bodylink_ang_vel_flat:
func: global_robot_bodylink_ang_vel_flat
- critic_root_rel_robot_bodylink_pos_flat:
func: root_rel_robot_bodylink_pos_flat
- critic_root_rel_robot_bodylink_rot_mat_flat:
func: root_rel_robot_bodylink_rot_mat_flat
- critic_dof_pos:
func: dof_pos
- critic_dof_vel:
func: dof_vel
- critic_last_action:
func: last_action
enable_corruption: true
concatenate_terms: false
================================================
FILE: holomotion/config/env/observations/velocity_tracking/obs_velocity_tracking.yaml
================================================
# @package _global_
obs:
context_length: 8
n_fut_frames: 0
target_fps: 50
obs_groups:
unified:
atomic_obs_list:
# Actor terms (sequence-style; flatten at serializer)
- actor_velocity_command:
func: velocity_command
history_length: ${obs.context_length}
flatten_history_dim: false
mirror_func: mirror_velocity_command
mirror_config: {}
- actor_projected_gravity:
func: projected_gravity
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: -0.1
n_max: 0.1
mirror_func: mirror_vec3
mirror_config: {}
- actor_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: -0.2
n_max: 0.2
mirror_func: mirror_axial_vec3
mirror_config: {}
- actor_dof_pos:
func: dof_pos
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: -0.01
n_max: 0.01
mirror_func: mirror_dof
mirror_config: {}
- actor_dof_vel:
func: dof_vel
history_length: ${obs.context_length}
flatten_history_dim: false
noise:
type: AdditiveUniformNoiseCfg
params:
n_min: -0.5
n_max: 0.5
mirror_func: mirror_dof
mirror_config: {}
- actor_last_action:
func: last_action
history_length: ${obs.context_length}
flatten_history_dim: false
mirror_func: mirror_dof
mirror_config: {}
# Critic terms
- critic_velocity_command:
func: velocity_command
- critic_rel_robot_root_lin_vel:
func: rel_robot_root_lin_vel
- critic_rel_robot_root_ang_vel:
func: rel_robot_root_ang_vel
- critic_root_rel_robot_bodylink_pos_flat:
func: root_rel_robot_bodylink_pos_flat
params:
keybody_names: ${robot.key_bodies}
- critic_root_rel_robot_bodylink_rot_mat_flat:
func: root_rel_robot_bodylink_rot_mat_flat
params:
keybody_names: ${robot.key_bodies}
- critic_dof_pos:
func: dof_pos
- critic_dof_vel:
func: dof_vel
- critic_last_action:
func: last_action
enable_corruption: true
concatenate_terms: false
================================================
FILE: holomotion/config/env/rewards/motion_tracking/rew_motion_tracking.yaml
================================================
# @package _global_
rewards:
_config:
reward_prefix: "ref_"
is_alive:
weight: 0.5
params: {}
root_pos_xy_tracking_exp:
weight: 1.0
params:
std: 0.2
ref_prefix: ${rewards._config.reward_prefix}
root_rot_tracking_exp:
weight: 0.5
params:
std: 0.4
ref_prefix: ${rewards._config.reward_prefix}
root_rel_keybodylink_pos_tracking_l2_exp:
weight: 1.0
params:
keybody_names: ${robot.key_bodies}
std: 0.3
ref_prefix: ${rewards._config.reward_prefix}
root_rel_keybodylink_rot_tracking_l2_exp:
weight: 2.0
params:
keybody_names: ${robot.key_bodies}
std: 0.4
ref_prefix: ${rewards._config.reward_prefix}
global_keybodylink_lin_vel_tracking_l2_exp:
weight: 1.0
params:
keybody_names: ${robot.key_bodies}
std: 1.0
ref_prefix: ${rewards._config.reward_prefix}
global_keybodylink_ang_vel_tracking_l2_exp:
weight: 1.0
params:
keybody_names: ${robot.key_bodies}
std: 3.14
ref_prefix: ${rewards._config.reward_prefix}
action_rate_l2:
weight: -0.1
params: {}
# joint_acc_l2:
# weight: -1.0e-6
# params: {}
joint_pos_limits:
weight: -10.0
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names:
- ".*"
undesired_contacts:
weight: -0.1
params:
threshold: 1.0
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names:
- ${robot.undesired_contacts_regrex}
================================================
FILE: holomotion/config/env/rewards/velocity_tracking/rew_velocity_tracking.yaml
================================================
# @package _global_
rewards:
stand_still_action_rate:
weight: -1.0
params:
command_name: base_velocity
feet_contact_without_cmd:
weight: 1.0
params:
command_name: base_velocity
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names:
- ".*ankle_roll.*"
track_stand_still_exp:
weight: 5.0
params:
command_name: base_velocity
std: 0.2
track_lin_vel_xy_heading_aligned_frame_exp:
weight: 3.0
params:
std: 0.5
command_name: base_velocity
track_ang_vel_z_heading_aligned_frame_exp:
weight: 3.0
params:
std: 0.5
command_name: base_velocity
is_alive:
weight: 0.15
params: {}
lin_vel_z_l2:
weight: -1.0
params: {}
ang_vel_xy_l2:
weight: -5.0e-2
params: {}
joint_acc_l2:
weight: -1.0e-6
params: {}
action_rate_l2:
weight: -1.0e-1
params: {}
joint_pos_limits:
weight: -5.0
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names: [".*"]
feet_air_time_v4:
weight: 1.0
params:
threshold: 0.5
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names: [".*ankle_roll.*"]
command_name: base_velocity
fly:
weight: -1.0
params:
threshold: 1.0
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names: [".*ankle_roll.*"]
feet_too_near:
weight: -10.0
threshold: 0.2
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names:
- .*ankle_roll.*
joint_deviation_l1_arms:
weight: -0.3
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names:
- .*_hip_roll.*
- .*waist_roll.*
- .*waist_pitch.*
- .*_shoulder_roll.*
- .*_shoulder_yaw.*
- .*_wrist.*
joint_deviation_l1_legs_yaw:
weight: -0.15
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names:
- .*waist_yaw.*
- .*_hip_yaw.*
- .*_elbow.*
- .*_ankle.*
joint_deviation_l1_legs:
weight: -0.02
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
joint_names:
- .*_shoulder_pitch.*
- .*_hip_pitch.*
- .*_knee.*
flat_orientation_l2:
weight: -5.0
params: {}
base_height_l2:
weight: -10.0
params:
target_height: 0.78
feet_slide:
weight: -1.0
params:
asset_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: robot
body_names: [".*ankle_roll.*"]
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names: [".*ankle_roll.*"]
undesired_contacts:
weight: -1.0
params:
threshold: 1.0
sensor_cfg:
_target_: isaaclab.managers.scene_entity_cfg.SceneEntityCfg
name: contact_forces
body_names:
- ${robot.undesired_contacts_regrex}
torso_xy_ang_vel_l2_penalty:
weight: -1.0
params: {}
torso_upright_l2_penalty:
weight: -1.0
params: {}
================================================
FILE: holomotion/config/env/terminations/NO_termination.yaml
================================================
# @package _global_
terminations: {}
================================================
FILE: holomotion/config/env/terminations/termination_motion_tracking.yaml
================================================
# @package _global_
terminations:
time_out:
time_out: true
ref_gravity_projection_far:
params:
threshold: 0.8
ref_prefix: ${rewards._config.reward_prefix}
keybody_ref_z_far:
params:
threshold: 0.25
ref_prefix: ${rewards._config.reward_prefix}
keybody_names:
- pelvis
- left_ankle_roll_link
- right_ankle_roll_link
- left_wrist_yaw_link
- right_wrist_yaw_link
# keybody_ref_pos_far:
# params:
# threshold: 0.5
# ref_prefix: ${rewards._config.reward_prefix}
# keybody_names:
# - pelvis
================================================
FILE: holomotion/config/env/terminations/termination_velocity_tracking.yaml
================================================
# @package _global_
terminations:
time_out:
time_out: true
root_height_below_minimum:
params:
minimum_height: 0.2
bad_orientation:
params:
limit_angle: 0.8
================================================
FILE: holomotion/config/env/terrain/isaaclab_plane.yaml
================================================
# @package _global_
# Simple flat terrain generated via the IsaacLab TerrainGenerator.
# Uses a single `plane` sub-terrain (1x1 grid) as a flat patch.
terrain:
terrain_type: generator
prim_path: /World/ground
static_friction: 1.0
dynamic_friction: 1.0
restitution: 0.0
friction_combine_mode: multiply
restitution_combine_mode: multiply
debug_vis: false
max_init_terrain_level: 0
# Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.
# When false, env origins are placed on a regular grid as in the default importer.
random_spawn: true
# Keep random spawn points away from terrain edges to avoid spawning onto the outer border.
random_spawn_margin: 2.0
# TerrainGeneratorCfg parameters.
generator:
num_rows: 1
num_cols: 1
size: [10.0, 10.0]
border_width: 1000.0
horizontal_scale: 0.1
vertical_scale: 0.005
slope_threshold: null
difficulty_range: [0.0, 0.0]
color_scheme: height
sub_terrains:
plane:
type: plane
proportion: 1.0
# Offline visual material configuration (PreviewSurface, no MDL/Nucleus).
visual_material:
type: color
diffuse_color: [0.25, 0.25, 0.25]
metallic: 0.0
roughness: 0.5
================================================
FILE: holomotion/config/env/terrain/isaaclab_rough.yaml
================================================
# @package _global_
# Rough height-field terrain for locomotion training.
# Uses random-uniform height field to create continuous noise-like terrain.
terrain:
terrain_type: generator
prim_path: /World/ground
static_friction: 1.0
dynamic_friction: 1.0
restitution: 0.0
friction_combine_mode: multiply
restitution_combine_mode: multiply
debug_vis: false
max_init_terrain_level: 4
# Randomize spawn position within each sub-terrain (recommended for locomotion).
random_spawn: true
random_spawn_margin: 4.0
# TerrainGeneratorCfg parameters.
generator:
num_rows: 4 # Number of terrain rows (difficulty levels)
num_cols: 4 # Number of terrain columns (types)
size: [20.0, 20.0] # Size of each sub-terrain in meters [length, width]
border_width: 1000.0 # Border around terrain in meters
horizontal_scale: 0.1 # Resolution in x-y plane
vertical_scale: 0.005 # Height resolution
slope_threshold: null # Slopes above this become vertical (null disables slope correction)
difficulty_range: [0.0, 1.0] # Min and max difficulty
color_scheme: height # Color the terrain by height ("none" would leave it uncolored)
sub_terrains:
rough:
type: random_uniform
proportion: 1.0
noise_range: [0.0, 0.04]
noise_step: 0.05
downsampled_scale: 1.0
# Offline visual material configuration (PreviewSurface, no MDL/Nucleus).
visual_material:
type: color
diffuse_color: [0.25, 0.25, 0.25]
metallic: 0.0
roughness: 0.5
================================================
FILE: holomotion/config/env/velocity_tracking.yaml
================================================
# @package _global_
env:
_target_: holomotion.src.env.velocity_tracking.VelocityTrackingEnv
_recursive_: False
config:
experiment_name: ${experiment_name}
num_envs: ${num_envs}
env_spacing: 2.5
replicate_physics: true
headless: ${headless}
num_processes: ${num_processes}
main_process: ${main_process}
process_id: ${process_id}
ckpt_dir: null
disable_ref_viz: false
eval_log_dir: null
save_rendering_dir: null
robot: ${robot}
domain_rand: ${domain_rand}
rewards: ${rewards}
terrain: ${terrain}
obs: ${obs}
terminations: ${terminations}
simulation:
episode_length_s: 20
sim_freq: 200
control_decimation: 4
physx:
bounce_threshold_velocity: 0.5
gpu_max_rigid_patch_count: 327680
scene:
terrain: ${terrain}
lighting:
distant_light_intensity: 3000.0
dome_light_intensity: 1000.0
contact_sensor:
history_length: 3
force_threshold: 10.0
track_air_time: true
debug_vis: false
actions:
dof_pos:
type: joint_position
params:
asset_name: robot
joint_names:
- ".*"
use_default_offset: true
scale: ${robot.actuators.action_scale}
commands:
base_velocity:
type: HoloMotionUniformVelocityCommandCfg
params:
asset_name: robot
resampling_time_range: [3, 10.0]
rel_standing_envs: 0.20
rel_yaw_envs: 0.30 # effective probability of a yaw-only command is 0.3 * (1 - 0.2) = 0.24
rel_heading_envs: 1.0
heading_command: false
heading_control_stiffness: 0.5
debug_vis: true
ranges:
lin_vel_x: [-0.6, 1.0]
lin_vel_y: [-0.5, 0.5]
ang_vel_z: [-1.0, 1.0]
heading: [-3.14, 3.14]
# limit_ranges:
# lin_vel_x: [-0.5, 1.0]
# lin_vel_y: [-0.3, 0.3]
# ang_vel_z: [-0.2, 0.2]
# heading: [-3.14, 3.14]
================================================
FILE: holomotion/config/evaluation/eval_isaaclab.yaml
================================================
# @package _global_
defaults:
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: motion_tracking
- /env/terrain: isaaclab_plane
- /env/terminations: NO_termination
- /env/domain_randomization: NO_domain_rand
project_name: ???
experiment_name: ???
num_envs: ???
headless: ???
motion_h5_path: null
checkpoint: null
log_dir: null
ckpt_pt_names: null
num_processes: ???
main_process: ???
process_id: ???
timestamp: ${now:%Y%m%d_%H%M%S}
base_dir: logs
experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}
save_dir: ${experiment_dir}/.hydra
output_dir: ${experiment_dir}/output
experiment_save_dir: ???
export_policy: false
export_only: false
dump_npzs: false
calc_per_clip_metrics: false
generate_report: false
dof_mode: "23"
obs:
critic_obs_prefix: "ref_"
rewards:
_config:
reward_prefix: "ref_"
algo:
config:
dynamo_backend: null
sampling_strategy: uniform
seed: 114514
env:
config:
seed: 42
simulation:
episode_length_s: 36000
robot:
motion:
backend: "hdf5_v2"
cache_max_num_clips: ${num_envs}
train_hdf5_roots: ${robot.motion.val_hdf5_roots}
val_hdf5_roots: ${motion_h5_path}
max_frame_length: 10000 # ~200 s at 50 fps
min_frame_length: 1
# handpicked_motion_names: ${handpicked_motion_names}
world_frame_normalization: false
dataloader:
num_workers: 2
pin_memory: true
persistent_workers: false
prefetch_factor: 1
timeout: 600
batch_progress_bar: true
terrain:
terrain_type: generator
prim_path: /World/ground
static_friction: 1.0
dynamic_friction: 1.0
restitution: 0.0
friction_combine_mode: multiply
restitution_combine_mode: multiply
debug_vis: false
max_init_terrain_level: 0
# Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.
# When false, env origins are placed on a regular grid as in the default importer.
random_spawn: true
# Keep random spawn points away from terrain edges to avoid spawning onto the outer border.
random_spawn_margin: 2.0
# TerrainGeneratorCfg parameters.
generator:
num_rows: 1
num_cols: 1
size: [10.0, 10.0]
border_width: 1000.0
horizontal_scale: 0.1
vertical_scale: 0.005
slope_threshold: null
difficulty_range: [0.0, 0.0]
color_scheme: height
sub_terrains:
plane:
type: plane
proportion: 1.0
# Offline visual material configuration (PreviewSurface, no MDL/Nucleus).
visual_material:
type: color
diffuse_color: [0.25, 0.25, 0.25]
metallic: 0.0
roughness: 0.5
================================================
FILE: holomotion/config/evaluation/eval_mujoco_sim2sim.yaml
================================================
# @package _global_
defaults:
- _self_
# Evaluation toggles
headless: false # true to run without GUI
record_video: false # true to save MP4 recordings
video_width: 1280
video_height: 720
video_fps: 30
camera_tracking: true # true to make camera follow robot root body
camera_height_offset:
0.3 # small offset (meters) above robot root for camera lookat point
# NOTE: This offsets where camera LOOKS AT, not camera position
# Use small values (0.2-0.5m) for proper framing, not large values
camera_distance:
4.0 # camera distance from lookat point (meters)
# Larger values = camera further away, smaller values = closer
camera_azimuth: 150.0 # default viewer/offscreen azimuth (deg), side-ish view
camera_elevation: -20.0 # default viewer/offscreen elevation (deg), slight downward angle
# Offline evaluation pipeline (dataset mode)
motion_npz_dir: null
dump_npzs: true
dump_onnx_io_npy: false
calc_per_clip_metrics: false
generate_report: false
metric_calculation: "per_clip" # "per_clip" (Macro) or "per_frame" (Micro)
dof_mode: "23" # "29" for full DoF, "23" for reduced DoF
failure_pos_err_thresh_m: 0.25
ray_actors_per_gpu: 16 # persistent Ray actors per GPU for batch eval
ray_multi_ckpt_mode: "split" # "split" or "per_checkpoint" for multi-ONNX eval
ckpt_onnx_root_dir: null # optional ONNX root directory for multi-checkpoint eval
ckpt_onnx_names: null # optional list of ONNX file names to evaluate
ray_parallel_metrics_postprocess: true # parallelize per-checkpoint metrics/report/export with Ray when evaluating multiple ckpts
ray_metrics_postprocess_num_cpus: 24 # Ray resource accounting per checkpoint-postprocess task (0 = don't reserve CPUs)
metrics_threadpool_max_workers: 24 # per-checkpoint ThreadPoolExecutor workers inside metrics.py (null => auto=min(num_files, 24))
# Termination / scheduling
max_policy_steps: 0 # 0 = unlimited; used in headless if no motion
policy_action_delay_step: 0 # max random action delay in 50 Hz policy steps; 0 disables delay
action_delay_type: "episode" # "episode" samples once per reset, "step" re-samples every policy step
unitree_viewer_dt: 0.0167 # ~60 Hz viewer sync
unitree_domain_id: 0
unitree_interface: "lo"
unitree_use_joystick: false
unitree_joystick_type: "xbox"
unitree_print_scene_information: false
# Debug options
debug_anchor_obs: false # when true, dump anchor pose/obs debug CSV for sim2sim
debug_anchor_obs_interval: 50 # log every N policy steps (>=1)
use_isaac_root_alignment: true # align free root state to IsaacLab reference root at frame 0
isaac_action_playback: false # when true, use per-frame actions recorded from IsaacLab instead of ONNX policy
robot_xml_path: ???
use_gpu: true
================================================
FILE: holomotion/config/evaluation/eval_velocity_tracking.yaml
================================================
# @package _global_
defaults:
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: velocity_tracking
- /env/terrain: isaaclab_plane
- /env/terminations: NO_termination
- /env/domain_randomization: NO_domain_rand
project_name: ???
experiment_name: ???
num_envs: ???
headless: ???
motion_h5_path: null
checkpoint: null
num_processes: ???
main_process: ???
process_id: ???
timestamp: ${now:%Y%m%d_%H%M%S}
base_dir: logs
experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}
save_dir: ${experiment_dir}/.hydra
output_dir: ${experiment_dir}/output
experiment_save_dir: ???
dump_npzs: false
export_policy: true
algo:
config:
dynamo_backend: null
# Video recording for offline evaluation (env.render() -> MP4 at target_fps)
record_video: false # enable MP4 recording during offline evaluation
env:
config:
simulation:
episode_length_s: 3600
robot:
motion:
cache_max_num_clips: ${num_envs}
max_frame_length: 10000
min_frame_length: 0
hdf5_root: ${motion_h5_path}
val_hdf5_root: ${motion_h5_path}
dataloader:
num_workers: 0
# terrain:
# terrain_type: usd
# usd_path: assets/isaac/4.1/Isaac/Environments/Grid/gridroom_black.usd
# usd_path: assets/isaac/4.1/Isaac/Environments/Terrains/rough_plane.usd
terrain:
terrain_type: generator
prim_path: /World/ground
static_friction: 1.0
dynamic_friction: 1.0
restitution: 0.0
friction_combine_mode: multiply
restitution_combine_mode: multiply
debug_vis: false
max_init_terrain_level: 0
# Use RandomSpawnTerrainImporter for optional random XY spawn inside the plane patch.
# When false, env origins are placed on a regular grid as in the default importer.
random_spawn: true
# TerrainGeneratorCfg parameters.
generator:
num_rows: 1
num_cols: 1
size: [8.0, 8.0]
border_width: 10.0
horizontal_scale: 0.1
vertical_scale: 0.005
slope_threshold: null
difficulty_range: [0.0, 0.0]
color_scheme: height
sub_terrains:
plane:
type: random_uniform
proportion: 1.0
noise_range: [0.0, 0.0]
noise_step: 0.25
downsampled_scale: 0.5
# Offline visual material configuration (PreviewSurface, no MDL/Nucleus).
visual_material:
type: color
diffuse_color: [0.25, 0.25, 0.25]
metallic: 0.0
roughness: 0.5
================================================
FILE: holomotion/config/modules/motion_tracking/motion_tracking_mlp.yaml
================================================
# @package _global_
modules:
actor:
type: MLP
hidden_norm: none
layer_config:
hidden_dims:
- 2048
- 1024
- 512
- 256
activation: SiLU
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
# Observation schema for motion tracking, from the actor's perspective.
obs_schema:
flattened_obs:
seq_len: ${obs.context_length}
terms:
- unified/actor_ref_gravity_projection_cur
- unified/actor_ref_base_linvel_cur
- unified/actor_ref_base_angvel_cur
- unified/actor_ref_dof_pos_cur
- unified/actor_ref_root_height_cur
- unified/actor_projected_gravity
- unified/actor_rel_robot_root_ang_vel
- unified/actor_dof_pos
- unified/actor_dof_vel
- unified/actor_last_action
flattened_obs_fut:
seq_len: ${obs.n_fut_frames}
terms:
- unified/actor_ref_dof_pos_fut
- unified/actor_ref_root_height_fut
- unified/actor_ref_gravity_projection_fut
- unified/actor_ref_base_linvel_fut
- unified/actor_ref_base_angvel_fut
output_dim: robot_action_dim
critic:
type: MLP
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
hidden_norm: rmsnorm
layer_config:
hidden_dims:
- 2048
- 2048
- 2048
- 2048
activation: SiLU
obs_schema:
flattened_obs:
seq_len: 1
terms:
- unified/critic_ref_dof_pos_cur
- unified/critic_global_anchor_diff
- unified/critic_ref_motion_cur_heading_aligned_root_pos
- unified/critic_ref_motion_cur_heading_aligned_root_rot6d
- unified/critic_ref_motion_cur_heading_aligned_root_lin_vel
- unified/critic_ref_motion_cur_heading_aligned_root_ang_vel
- unified/critic_rel_robot_root_lin_vel
- unified/critic_rel_robot_root_ang_vel
- unified/critic_global_robot_bodylink_lin_vel_flat
- unified/critic_global_robot_bodylink_ang_vel_flat
- unified/critic_root_rel_robot_bodylink_pos_flat
- unified/critic_root_rel_robot_bodylink_rot_mat_flat
- unified/critic_dof_pos
- unified/critic_dof_vel
- unified/critic_last_action
flattened_obs_fut:
seq_len: ${obs.n_fut_frames}
terms:
- unified/critic_ref_dof_pos_fut
- unified/critic_ref_root_height_fut
- unified/critic_ref_motion_fut_heading_aligned_root_pos
- unified/critic_ref_motion_fut_heading_aligned_root_rot6d
- unified/critic_ref_motion_fut_heading_aligned_root_lin_vel
- unified/critic_ref_motion_fut_heading_aligned_root_ang_vel
output_dim: 1
================================================
FILE: holomotion/config/modules/motion_tracking/motion_tracking_tf-moe.yaml
================================================
# @package _global_
modules:
actor:
type: ReferenceRoutedGroupedMoETransformerPolicy
use_checkpointing: false # Gradient checkpointing; when enabled, significantly reduces GPU memory (GRAM) usage
# MoE-specific hyperparameters
num_fine_experts: 16
num_shared_experts: 1
top_k: 2
moe_loss_coef: 0.0
routing_score_fn: ${algo.config.moe_router.routing_score_fn}
routing_scale: ${algo.config.moe_router.routing_scale}
use_dynamic_bias: ${algo.config.moe_router.use_dynamic_bias}
bias_update_rate: ${algo.config.moe_router.bias_update_rate}
expert_bias_clip: ${algo.config.moe_router.expert_bias_clip}
# Transformer hyperparameters - smaller model for stability
obs_embed_mlp_hidden: 2048
router_embed_mlp_hidden: 2048
d_model: 512
n_heads: 8
n_kv_heads: 4
use_gated_attn: true
n_layers: 3
ff_mult: 2.0
ff_mult_dense: 4
attn_dropout: 0.0
mlp_dropout: 0.0
max_ctx_len: 32
# Auxiliary dynamics prediction weights (0.0 = disabled)
aux_sys_id_weight: 0.0
aux_dynamics_weight: 0.0
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
# Observation schema for motion tracking, from the actor's perspective.
obs_schema:
flattened_obs:
seq_len: ${obs.context_length}
terms:
- unified/actor_ref_gravity_projection_cur
- unified/actor_ref_base_linvel_cur
- unified/actor_ref_base_angvel_cur
- unified/actor_ref_dof_pos_cur
- unified/actor_ref_root_height_cur
- unified/actor_projected_gravity
- unified/actor_rel_robot_root_ang_vel
- unified/actor_dof_pos
- unified/actor_dof_vel
- unified/actor_last_action
flattened_obs_fut:
seq_len: ${obs.n_fut_frames}
terms:
- unified/actor_ref_dof_pos_fut
- unified/actor_ref_root_height_fut
- unified/actor_ref_gravity_projection_fut
- unified/actor_ref_base_linvel_fut
- unified/actor_ref_base_angvel_fut
output_dim: robot_action_dim
critic:
type: MLP
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
hidden_norm: rmsnorm
layer_config:
hidden_dims:
- 2048
- 2048
- 2048
- 2048
activation: SiLU
obs_schema:
flattened_obs:
seq_len: 1
terms:
- unified/critic_ref_dof_pos_cur
- unified/critic_global_anchor_diff
- unified/critic_ref_motion_cur_heading_aligned_root_pos
- unified/critic_ref_motion_cur_heading_aligned_root_rot6d
- unified/critic_ref_motion_cur_heading_aligned_root_lin_vel
- unified/critic_ref_motion_cur_heading_aligned_root_ang_vel
- unified/critic_rel_robot_root_lin_vel
- unified/critic_rel_robot_root_ang_vel
- unified/critic_global_robot_bodylink_lin_vel_flat
- unified/critic_global_robot_bodylink_ang_vel_flat
- unified/critic_root_rel_robot_bodylink_pos_flat
- unified/critic_root_rel_robot_bodylink_rot_mat_flat
- unified/critic_dof_pos
- unified/critic_dof_vel
- unified/critic_last_action
flattened_obs_fut:
seq_len: ${obs.n_fut_frames}
terms:
- unified/critic_ref_dof_pos_fut
- unified/critic_ref_root_height_fut
- unified/critic_ref_motion_fut_heading_aligned_root_pos
- unified/critic_ref_motion_fut_heading_aligned_root_rot6d
- unified/critic_ref_motion_fut_heading_aligned_root_lin_vel
- unified/critic_ref_motion_fut_heading_aligned_root_ang_vel
output_dim: 1
================================================
FILE: holomotion/config/modules/velocity_tracking/velocity_tracking_mlp.yaml
================================================
# @package _global_
modules:
actor:
type: MLP
fix_sigma: false
noise_std_type: scalar
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
hidden_norm: rmsnorm
layer_config:
hidden_dims:
- 512
- 512
- 512
activation: SiLU
obs_schema:
flattened_obs:
seq_len: ${obs.context_length}
terms:
- unified/actor_velocity_command
- unified/actor_projected_gravity
- unified/actor_rel_robot_root_ang_vel
- unified/actor_dof_pos
- unified/actor_dof_vel
- unified/actor_last_action
output_dim: robot_action_dim
critic:
type: MLP
obs_norm:
enabled: true
epsilon: 1.0e-8 # Reduced for better stability in DDP
update_method: ema # ema or cumulative
ema_momentum: 1.0e-4
update_at_train: true
update_at_eval: false
enable_clipping: true # Enable clipping for DDP stability
clip_range: 10.0 # Reduced clip range for better stability
sync_interval_steps: 4 # Periodically sync obs normalizers across ranks during rollout
hidden_norm: rmsnorm
layer_config:
hidden_dims:
- 512
- 512
- 512
activation: SiLU
obs_schema:
flattened_obs:
seq_len: 1
terms:
- unified/critic_velocity_command
- unified/critic_rel_robot_root_lin_vel
- unified/critic_rel_robot_root_ang_vel
- unified/critic_root_rel_robot_bodylink_pos_flat
- unified/critic_root_rel_robot_bodylink_rot_mat_flat
- unified/critic_dof_pos
- unified/critic_dof_vel
- unified/critic_last_action
output_dim: 1
================================================
FILE: holomotion/config/motion_retargeting/gmr_to_holomotion.yaml
================================================
# @package _global_
defaults:
- _self_
hydra:
job:
chdir: false
io:
src_dir: ???
robot_config: ???
out_root: ???
ref_dir: holomotion/src/motion_retargeting/utils
processing:
target_fps: 50
fast_interpolate: true
skip_existing: true
debug_mode: false
ray:
num_workers: 0
ray_address: ""
naming:
emit_prefixed: true
emit_legacy: false
preprocess:
# Available stages:
# ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys',
# 'slicing','apply_butterworth_filter','add_padding','tagging']
# Empty list [] means no preprocessing stages applied
pipeline: []
slicing:
window_size: 500
overlap: 50
filtering:
type: butterworth
butter_cutoff_hz: 3.0
butter_order: 4
padding:
# Robot config path for FK and default joint angles
# If empty, uses io.robot_config
robot_config_path: ${io.robot_config}
# Duration of stand-still padding before/after motion (seconds)
stand_still_time: 1.0
# Duration of transition between default pose and motion (seconds)
transition_time: 1.5
tagging:
# When empty, tags are written to <io.out_root>/kinematic_tags.json
output_json_path: ""
================================================
FILE: holomotion/config/motion_retargeting/holomotion_preprocess.yaml
================================================
# @package _global_
defaults:
- _self_
hydra:
job:
chdir: false
io:
src_root: ???
out_root: ???
preprocess:
# Available stages:
# ['filename_as_motionkey','legacy_to_ref_keys','add_legacy_keys',
# 'slicing','apply_butterworth_filter','add_padding','tagging']
pipeline: []
slicing:
window_size: 500
overlap: 50
filtering:
type: butterworth
butter_cutoff_hz: 3.0
butter_order: 4
padding:
# Robot config path for FK and default joint angles
robot_config_path: ""
# Duration of stand-still padding before/after motion (seconds)
stand_still_time: 1.0
# Duration of transition between default pose and motion (seconds)
transition_time: 1.0
tagging:
# When empty, tags are written to <io.out_root>/kinematic_tags.json
output_json_path: ""
ray:
enabled: true
num_workers: 2 # 0 = use all available CPUs
ray_address: "" # empty = local; otherwise connect to existing Ray cluster
================================================
FILE: holomotion/config/motion_retargeting/kinematic_filter.yaml
================================================
# @package _global_
defaults:
- _self_
hydra:
job:
chdir: false
io:
dataset_root: "" # absolute path to dataset root containing kinematic_tags.json
filtering:
output_yaml: "" # Optional; when empty, defaults to <io.dataset_root>/excluded_kinematic_motion_names.yaml
schema:
across: union
thresholds:
kinematic_features.root_linear_speed.max: { op: ">", value: 10.0 }
kinematic_features.root_angular_speed.max: { op: ">", value: 20.0 }
kinematic_features.root_delta_z.max: { op: ">", value: 2.0 }
kinematic_features.jerk.max: { op: ">", value: 2000.0 }
================================================
FILE: holomotion/config/motion_retargeting/pack_hdf5_database.yaml
================================================
# @package _global_
defaults:
- _self_
- /robot: unitree/G1/29dof/29dof_training_isaaclab
hydra:
job:
chdir: false
# IO
precomputed_npz_root: ???
hdf5_root: ???
# Runtime
# Tuning guidance for distributed JuiceFS training with millions of clips (the values set below are more conservative defaults):
# - chunks_t: Larger chunks (e.g. 2048-4096) reduce metadata overhead and improve sequential read performance on JuiceFS
# Balance: larger chunks = better sequential I/O, but too large wastes memory
chunks_t: 1024
# - compression: lzf provides fast decompression suitable for training workloads
# Alternatives: gzip (better compression, slower), none (fastest but largest files)
compression: lzf # lzf|gzip|none
# - shard_target_gb: Larger shards (5-10 GB) reduce shard count and metadata overhead for millions of clips
# Balance: fewer shards = less metadata overhead, better for distributed access; more shards = better parallelism
# For millions of clips, consider 5-10 GB to reduce shard count while keeping good parallelism (default here: 1.0 GB)
shard_target_gb: 1.0
# - num_jobs: Parallel workers for packing process (should match available CPU cores)
num_jobs: 16
debug_local_mode: false
================================================
FILE: holomotion/config/motion_retargeting/pack_hdf5_v2.yaml
================================================
# @package _global_
defaults:
- _self_
- /robot: unitree/G1/29dof/29dof_training_isaaclab
hydra:
job:
chdir: false
# IO
holomotion_npz_root: ???
hdf5_root: ???
# Runtime
chunks_t: 1024
compression: lzf # lzf|gzip|none
shard_target_gb: 1.0
shard_target_mode: h5_filesize # h5_filesize|npz_filesize|uncompressed_nbytes
num_jobs: 16
debug_local_mode: false
================================================
FILE: holomotion/config/motion_retargeting/unitree_G1_29dof_retargeting.yaml
================================================
robot:
humanoid_type: unitree/G1/29dof
asset:
smpl_dir: "assets/smpl"
assetRoot: "./"
assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml"
training_mjcfName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml"
video_dir: ${motion_npz_root}/video_rendering
skip_frames: 3 # Frame stride: 1 -> 30 Hz, 2 -> 15 Hz, 3 -> 10 Hz (i.e. 30 Hz / skip_frames)
show_markers: False
max_workers: 12
================================================
FILE: holomotion/config/mujoco_eval/sim2sim.yaml
================================================
# @package _group_
defaults:
- _self_
enabled: false
model_type: "holomotion"
# Evaluation toggles
headless: false
record_video: false
video_width: 1280
video_height: 720
video_fps: 30
camera_tracking: true
camera_height_offset: 0.3
camera_distance: 4.0
camera_azimuth: 150.0
camera_elevation: -20.0
# Input/output
robot_xml_path: null
motion_npz_dir: null
motion_npz_path: null
ckpt_onnx_path: null
ckpt_onnx_root_dir: null
ckpt_onnx_names: null
# Offline evaluation pipeline
dump_npzs: true
calc_per_clip_metrics: false
generate_report: false
metric_calculation: "per_clip"
dof_mode: "23"
failure_pos_err_thresh_m: 0.25
ray_actors_per_gpu: 16
ray_multi_ckpt_mode: "split"
# Runtime
use_gpu: true
================================================
FILE: holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml
================================================
# @package _global_
robot:
humanoid_type: unitree/G1/29dof
dof_obs_size: 29
actions_dim: 29
num_bodies: 30
num_extend_bodies: 0
undesired_contacts_regrex: "^(?!left_ankle_roll_link$)(?!right_ankle_roll_link$)(?!left_wrist_yaw_link$)(?!right_wrist_yaw_link$).+$"
torso_name: "torso_link"
anchor_body: "torso_link"
key_bodies:
- "pelvis"
- "left_hip_roll_link"
- "left_knee_link"
- "left_ankle_pitch_link"
- "right_hip_roll_link"
- "right_knee_link"
- "right_ankle_pitch_link"
- "torso_link"
- "left_shoulder_roll_link"
- "left_elbow_link"
- "left_wrist_yaw_link"
- "right_shoulder_roll_link"
- "right_elbow_link"
- "right_wrist_yaw_link"
key_dofs:
- "left_knee_joint"
- "right_knee_joint"
- "left_elbow_joint"
- "right_elbow_joint"
dof_names:
- "left_hip_pitch_joint"
- "left_hip_roll_joint"
- "left_hip_yaw_joint"
- "left_knee_joint"
- "left_ankle_pitch_joint"
- "left_ankle_roll_joint"
- "right_hip_pitch_joint"
- "right_hip_roll_joint"
- "right_hip_yaw_joint"
- "right_knee_joint"
- "right_ankle_pitch_joint"
- "right_ankle_roll_joint"
- "waist_yaw_joint"
- "waist_roll_joint"
- "waist_pitch_joint"
- "left_shoulder_pitch_joint"
- "left_shoulder_roll_joint"
- "left_shoulder_yaw_joint"
- "left_elbow_joint"
- "left_wrist_roll_joint"
- "left_wrist_pitch_joint"
- "left_wrist_yaw_joint"
- "right_shoulder_pitch_joint"
- "right_shoulder_roll_joint"
- "right_shoulder_yaw_joint"
- "right_elbow_joint"
- "right_wrist_roll_joint"
- "right_wrist_pitch_joint"
- "right_wrist_yaw_joint"
# ========== Unified DOF Groupings ==========
# Main anatomical groupings for DOF
arm_dof_names:
- "left_shoulder_pitch_joint"
- "left_shoulder_roll_joint"
- "left_shoulder_yaw_joint"
- "left_elbow_joint"
- "left_wrist_roll_joint"
- "left_wrist_pitch_joint"
- "left_wrist_yaw_joint"
- "right_shoulder_pitch_joint"
- "right_shoulder_roll_joint"
- "right_shoulder_yaw_joint"
- "right_elbow_joint"
- "right_wrist_roll_joint"
- "right_wrist_pitch_joint"
- "right_wrist_yaw_joint"
waist_dof_names:
- "waist_yaw_joint"
- "waist_roll_joint"
- "waist_pitch_joint"
leg_dof_names:
- "left_hip_pitch_joint"
- "left_hip_roll_joint"
- "left_hip_yaw_joint"
- "left_knee_joint"
- "left_ankle_pitch_joint"
- "left_ankle_roll_joint"
- "right_hip_pitch_joint"
- "right_hip_roll_joint"
- "right_hip_yaw_joint"
- "right_knee_joint"
- "right_ankle_pitch_joint"
- "right_ankle_roll_joint"
# Side-specific groupings for DOF
left_arm_dof_names:
- "left_shoulder_pitch_joint"
- "left_shoulder_roll_joint"
- "left_shoulder_yaw_joint"
- "left_elbow_joint"
- "left_wrist_roll_joint"
- "left_wrist_pitch_joint"
- "left_wrist_yaw_joint"
right_arm_dof_names:
- "right_shoulder_pitch_joint"
- "right_shoulder_roll_joint"
- "right_shoulder_yaw_joint"
- "right_elbow_joint"
- "right_wrist_roll_joint"
- "right_wrist_pitch_joint"
- "right_wrist_yaw_joint"
left_leg_dof_names:
- "left_hip_pitch_joint"
- "left_hip_roll_joint"
- "left_hip_yaw_joint"
- "left_knee_joint"
- "left_ankle_pitch_joint"
- "left_ankle_roll_joint"
right_leg_dof_names:
- "right_hip_pitch_joint"
- "right_hip_roll_joint"
- "right_hip_yaw_joint"
- "right_knee_joint"
- "right_ankle_pitch_joint"
- "right_ankle_roll_joint"
# Combined groupings for DOF (for backward compatibility and common usage)
upper_body_dof_names: ${robot.arm_dof_names} # Alias for arm_dof_names
lower_body_dof_names:
- "left_hip_pitch_joint"
- "left_hip_roll_joint"
- "left_hip_yaw_joint"
- "left_knee_joint"
- "left_ankle_pitch_joint"
- "left_ankle_roll_joint"
- "right_hip_pitch_joint"
- "right_hip_roll_joint"
- "right_hip_yaw_joint"
- "right_knee_joint"
- "right_ankle_pitch_joint"
- "right_ankle_roll_joint"
- "waist_yaw_joint"
- "waist_roll_joint"
- "waist_pitch_joint"
# ========== Unified Body Groupings ==========
# Main anatomical groupings for bodies
arm_body_names:
- "left_shoulder_pitch_link"
- "left_shoulder_roll_link"
- "left_shoulder_yaw_link"
- "left_elbow_link"
- "left_wrist_roll_link"
- "left_wrist_pitch_link"
- "left_wrist_yaw_link"
- "right_shoulder_pitch_link"
- "right_shoulder_roll_link"
- "right_shoulder_yaw_link"
- "right_elbow_link"
- "right_wrist_roll_link"
- "right_wrist_pitch_link"
- "right_wrist_yaw_link"
head_hand_bodies:
- "torso_link"
- "left_wrist_yaw_link"
- "right_wrist_yaw_link"
torso_body_names:
- "waist_yaw_link"
- "waist_roll_link"
- "torso_link"
leg_body_names:
- "left_hip_pitch_link"
- "left_hip_roll_link"
- "left_hip_yaw_link"
- "left_knee_link"
- "left_ankle_pitch_link"
- "left_ankle_roll_link"
- "right_hip_pitch_link"
- "right_hip_roll_link"
- "right_hip_yaw_link"
- "right_knee_link"
- "right_ankle_pitch_link"
- "right_ankle_roll_link"
# Side-specific groupings for bodies
left_arm_body_names:
- "left_shoulder_pitch_link"
- "left_shoulder_roll_link"
- "left_shoulder_yaw_link"
- "left_elbow_link"
- "left_wrist_roll_link"
- "left_wrist_pitch_link"
- "left_wrist_yaw_link"
right_arm_body_names:
- "right_shoulder_pitch_link"
- "right_shoulder_roll_link"
- "right_shoulder_yaw_link"
- "right_elbow_link"
- "right_wrist_roll_link"
- "right_wrist_pitch_link"
- "right_wrist_yaw_link"
left_leg_body_names:
- "left_hip_pitch_link"
- "left_hip_roll_link"
- "left_hip_yaw_link"
- "left_knee_link"
- "left_ankle_pitch_link"
- "left_ankle_roll_link"
right_leg_body_names:
- "right_hip_pitch_link"
- "right_hip_roll_link"
- "right_hip_yaw_link"
- "right_knee_link"
- "right_ankle_pitch_link"
- "right_ankle_roll_link"
body_names:
- "pelvis"
- "left_hip_pitch_link"
- "left_hip_roll_link"
- "left_hip_yaw_link"
- "left_knee_link"
- "left_ankle_pitch_link"
- "left_ankle_roll_link"
- "right_hip_pitch_link"
- "right_hip_roll_link"
- "right_hip_yaw_link"
- "right_knee_link"
- "right_ankle_pitch_link"
- "right_ankle_roll_link"
- "waist_yaw_link"
- "waist_roll_link"
- "torso_link"
- "left_shoulder_pitch_link"
- "left_shoulder_roll_link"
- "left_shoulder_yaw_link"
- "left_elbow_link"
- "left_wrist_roll_link"
- "left_wrist_pitch_link"
- "left_wrist_yaw_link"
- "right_shoulder_pitch_link"
- "right_shoulder_roll_link"
- "right_shoulder_yaw_link"
- "right_elbow_link"
- "right_wrist_roll_link"
- "right_wrist_pitch_link"
- "right_wrist_yaw_link"
init_state:
pos: [0.0, 0.0, 0.8] # x,y,z [m]
rot: [0.0, 0.929, 0.341, 0.298] # x,y,z,w [quat]
lin_vel: [0.0, 0.0, 0.0] # x,y,z [m/s]
ang_vel: [0.0, 0.0, 0.0] # x,y,z [rad/s]
default_joint_angles: # = target angles [rad] when action = 0.0
left_hip_pitch_joint: -0.312
left_hip_roll_joint: 0.0
left_hip_yaw_joint: 0.0
left_knee_joint: 0.669
left_ankle_pitch_joint: -0.363
left_ankle_roll_joint: 0.0
right_hip_pitch_joint: -0.312
right_hip_roll_joint: 0.0
right_hip_yaw_joint: 0.0
right_knee_joint: 0.669
right_ankle_pitch_joint: -0.363
right_ankle_roll_joint: 0.0
waist_yaw_joint: 0.
waist_roll_joint: 0.
waist_pitch_joint: 0.1
left_shoulder_pitch_joint: 0.2
left_shoulder_roll_joint: 0.2
left_shoulder_yaw_joint: 0.0
left_elbow_joint: 0.6
left_wrist_roll_joint: 0.0
left_wrist_pitch_joint: 0.0
left_wrist_yaw_joint: 0.0
right_shoulder_pitch_joint: 0.2
right_shoulder_roll_joint: -0.2
right_shoulder_yaw_joint: 0.0
right_elbow_joint: 0.6
right_wrist_roll_joint: 0.0
right_wrist_pitch_joint: 0.0
right_wrist_yaw_joint: 0.0
actuators:
actuator_type: unitree_erfi # implicit, unitree, or unitree_erfi
ema_filter_enabled: false
ema_filter_alpha: 1.0
all_joints:
joint_names_expr:
- ".*_hip_yaw_joint"
- ".*_hip_roll_joint"
- ".*_hip_pitch_joint"
- ".*_knee_joint"
- ".*_ankle_pitch_joint"
- ".*_ankle_roll_joint"
- "waist_yaw_joint"
- "waist_roll_joint"
- "waist_pitch_joint"
- ".*_shoulder_pitch_joint"
- ".*_shoulder_roll_joint"
- ".*_shoulder_yaw_joint"
- ".*_elbow_joint"
- ".*_wrist_roll_joint"
- ".*_wrist_pitch_joint"
- ".*_wrist_yaw_joint"
effort_limit_sim:
".*_hip_yaw_joint": 88.0
".*_hip_roll_joint": 139.0
".*_hip_pitch_joint": 88.0
".*_knee_joint": 139.0
".*_ankle_pitch_joint": 50.0
".*_ankle_roll_joint": 50.0
"waist_yaw_joint": 88.0
"waist_roll_joint": 50.0
"waist_pitch_joint": 50.0
".*_shoulder_pitch_joint": 25.0
".*_shoulder_roll_joint": 25.0
".*_shoulder_yaw_joint": 25.0
".*_elbow_joint": 25.0
".*_wrist_roll_joint": 25.0
".*_wrist_pitch_joint": 5.0
".*_wrist_yaw_joint": 5.0
velocity_limit_sim:
".*_hip_yaw_joint": 32.0
".*_hip_roll_joint": 20.0
".*_hip_pitch_joint": 32.0
".*_knee_joint": 20.0
".*_ankle_pitch_joint": 37.0
".*_ankle_roll_joint": 37.0
"waist_yaw_joint": 32.0
"waist_roll_joint": 37.0
"waist_pitch_joint": 37.0
".*_shoulder_pitch_joint": 37.0
".*_shoulder_roll_joint": 37.0
".*_shoulder_yaw_joint": 37.0
".*_elbow_joint": 37.0
".*_wrist_roll_joint": 37.0
".*_wrist_pitch_joint": 22.0
".*_wrist_yaw_joint": 22.0
stiffness:
".*_hip_pitch_joint": 40.17923847
".*_hip_roll_joint": 99.09842778
".*_hip_yaw_joint": 40.17923847
".*_knee_joint": 99.09842778
".*_ankle_pitch_joint": 28.50124620
".*_ankle_roll_joint": 28.50124620
"waist_yaw_joint": 40.17923847
"waist_roll_joint": 28.50124620
"waist_pitch_joint": 28.50124620
".*_shoulder_pitch_joint": 14.25062310
".*_shoulder_roll_joint": 14.25062310
".*_shoulder_yaw_joint": 14.25062310
".*_elbow_joint": 14.25062310
".*_wrist_roll_joint": 14.25062309787429
".*_wrist_pitch_joint": 16.77832748089279
".*_wrist_yaw_joint": 16.77832748089279
damping:
".*_hip_pitch_joint": 2.55788977
".*_hip_roll_joint": 6.30880185
".*_hip_yaw_joint": 2.55788977
".*_knee_joint": 6.30880185
".*_ankle_pitch_joint": 1.81444569
".*_ankle_roll_joint": 1.81444569
"waist_yaw_joint": 2.55788977
"waist_roll_joint": 1.81444569
"waist_pitch_joint": 1.81444569
".*_shoulder_pitch_joint": 0.90722284
".*_shoulder_roll_joint": 0.90722284
".*_shoulder_yaw_joint": 0.90722284
".*_elbow_joint": 0.90722284
".*_wrist_roll_joint": 0.907222843292423
".*_wrist_pitch_joint": 1.06814150219
".*_wrist_yaw_joint": 1.06814150219
armature:
".*_hip_pitch_joint": 0.010177520
".*_hip_roll_joint": 0.025101925
".*_hip_yaw_joint": 0.010177520
".*_knee_joint": 0.025101925
".*_ankle_pitch_joint": 0.007219450
".*_ankle_roll_joint": 0.007219450
"waist_yaw_joint": 0.010177520
"waist_roll_joint": 0.007219450
"waist_pitch_joint": 0.007219450
".*_shoulder_pitch_joint": 0.003609725
".*_shoulder_roll_joint": 0.003609725
".*_shoulder_yaw_joint": 0.003609725
".*_elbow_joint": 0.003609725
".*_wrist_roll_joint": 0.003609725
".*_wrist_pitch_joint": 0.00425
".*_wrist_yaw_joint": 0.00425
action_scale:
".*_hip_pitch_joint": 0.548
".*_hip_roll_joint": 0.351
".*_hip_yaw_joint": 0.548
".*_knee_joint": 0.351
".*_ankle_pitch_joint": 0.439
".*_ankle_roll_joint": 0.439
"waist_yaw_joint": 0.548
"waist_roll_joint": 0.439
"waist_pitch_joint": 0.439
".*_shoulder_pitch_joint": 0.439
".*_shoulder_roll_joint": 0.439
".*_shoulder_yaw_joint": 0.439
".*_elbow_joint": 0.439
".*_wrist_roll_joint": 0.439
".*_wrist_pitch_joint": 0.075
".*_wrist_yaw_joint": 0.075
dof_sign_by_name:
left_hip_pitch_joint: 1.0
left_hip_roll_joint: -1.0
left_hip_yaw_joint: -1.0
left_knee_joint: 1.0
left_ankle_pitch_joint: 1.0
left_ankle_roll_joint: -1.0
right_hip_pitch_joint: 1.0
right_hip_roll_joint: -1.0
right_hip_yaw_joint: -1.0
right_knee_joint: 1.0
right_ankle_pitch_joint: 1.0
right_ankle_roll_joint: -1.0
waist_yaw_joint: -1.0
waist_roll_joint: -1.0
waist_pitch_joint: 1.0
left_shoulder_pitch_joint: 1.0
left_shoulder_roll_joint: -1.0
left_shoulder_yaw_joint: -1.0
left_elbow_joint: 1.0
left_wrist_roll_joint: -1.0
left_wrist_pitch_joint: 1.0
left_wrist_yaw_joint: -1.0
right_shoulder_pitch_joint: 1.0
right_shoulder_roll_joint: -1.0
right_shoulder_yaw_joint: -1.0
right_elbow_joint: 1.0
right_wrist_roll_joint: -1.0
right_wrist_pitch_joint: 1.0
right_wrist_yaw_joint: -1.0
asset:
collapse_fixed_joints: True
replace_cylinder_with_capsule: True
flip_visual_attachments: False
max_angular_velocity: 1000.
max_linear_velocity: 1000.
density: 0.001
angular_damping: 0.
linear_damping: 0.
asset_root: "./"
urdf_file: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.urdf"
assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml"
fix_base_link: false
force_usd_conversion: true
extend_config: []
motion:
asset:
assetRoot: "./"
assetFileName: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0.xml"
sampling_strategy: ${algo.config.sampling_strategy}
weighted_bin: ${algo.config.weighted_bin}
curriculum: ${algo.config.curriculum}
dump_sampled_motion_keys: false
dump_sampled_motion_keys_interval: 1
dump_sampled_motion_keys_dir: "sampled_motion_cache_keys"
max_frame_length: 300 # 6s
min_frame_length: 50 # 1s
handpicked_motion_names: null
excluded_motion_names: null
world_frame_normalization: true
backend: "hdf5_v2" # hdf5, hdf5_v2
train_hdf5_roots: ${train_hdf5_roots}
val_hdf5_roots: ${train_hdf5_roots}
dof_names: ${robot.dof_names}
body_names: ${robot.body_names}
key_bodies: ${robot.key_bodies}
extend_config: ${robot.extend_config}
dataloader:
num_workers: 2
prefetch_factor: 1
pin_memory: true
persistent_workers: false
timeout: 600
fk_robot_file_path: ${robot.asset.urdf_file}
fk_vel_smoothing_sigma: 2.0
online_filter:
enabled: false
butter_order: 4
butter_cutoff_hz_pool: []
cache:
batch_progress_bar: false
max_num_clips: ${num_envs} # Batch size for motion clips
device: "cuda" # "cuda" or "cpu"; cuda stages on GPU
swap_interval_steps: ${robot.motion.max_frame_length} # Swap cache every N steps
allowed_prefixes:
- "ref_"
- "ft_ref_"
================================================
FILE: holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab_s100.yaml
================================================
# @package _global_
defaults:
- unitree/G1/29dof/29dof_training_isaaclab
robot:
asset:
urdf_file: "assets/robots/${robot.humanoid_type}/g1_29dof_rev_1_0_s100.urdf"
================================================
FILE: holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_mlp.yaml
================================================
# @package _global_
defaults:
- /training: train_base
- /algo: ppo
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: motion_tracking
- /env/terminations: termination_motion_tracking
- /env/observations: motion_tracking/obs_motion_tracking_mlp
- /env/rewards: motion_tracking/rew_motion_tracking
- /env/domain_randomization: domain_rand_medium
- /env/terrain: isaaclab_rough
- /modules: motion_tracking/motion_tracking_mlp
project_name: HoloMotionMotrackV1.2
# checkpoint: ???
train_hdf5_roots:
- data/h5v2_datasets/AMASS_test
================================================
FILE: holomotion/config/training/motion_tracking/train_g1_29dof_motion_tracking_tf-moe.yaml
================================================
# @package _global_
defaults:
- /training: train_base
- /algo: ppo_tf
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: motion_tracking
- /env/terminations: termination_motion_tracking
- /env/observations: motion_tracking/obs_motion_tracking_tf-moe
- /env/rewards: motion_tracking/rew_motion_tracking
- /env/domain_randomization: domain_rand_medium
- /env/terrain: isaaclab_rough
- /modules: motion_tracking/motion_tracking_tf-moe
project_name: HoloMotionMotrackV1.2
# checkpoint: ???
train_hdf5_roots:
- data/h5v2_datasets/AMASS_test
================================================
FILE: holomotion/config/training/train_base.yaml
================================================
# @package _global_
defaults:
- _self_
- /mujoco_eval: sim2sim
project_name: ???
experiment_name: ???
num_envs: ???
headless: ???
motion_h5_path: ???
checkpoint: null
num_processes: ???
main_process: ???
process_id: ???
timestamp: ${now:%Y%m%d_%H%M%S}
base_dir: logs
experiment_dir: ${base_dir}/${project_name}/${timestamp}-${experiment_name}
save_dir: ${experiment_dir}/.hydra
output_dir: ${experiment_dir}/output
experiment_save_dir: ???
================================================
FILE: holomotion/config/training/velocity_tracking/train_g1_29dof_velocity_tracking_mlp.yaml
================================================
# @package _global_
defaults:
- /training: train_base
- /algo: ppo
- /robot: unitree/G1/29dof/29dof_training_isaaclab
- /env: velocity_tracking
- /env/terminations: termination_velocity_tracking
- /env/observations: velocity_tracking/obs_velocity_tracking
- /env/rewards: velocity_tracking/rew_velocity_tracking
- /env/domain_randomization: domain_rand_medium
- /env/terrain: isaaclab_rough
- /modules: velocity_tracking/velocity_tracking_mlp
project_name: HoloMotionVelocityTrackingG1
# checkpoint: ???
train_hdf5_roots:
- /horizon-bucket/robot_lab/users/maiyue01.chen/h5_datasets/h5_unitree_walk_20160119
env:
config:
commands:
base_velocity:
params:
resampling_time_range: [3.0, 6.0]
rel_standing_envs: 0.2
rel_yaw_envs: 0.2
heading_command: false
heading_control_stiffness: 0.5
ranges:
lin_vel_x: [-1.0, 1.0]
lin_vel_y: [-0.5, 0.5]
ang_vel_z: [-1.0, 1.0]
heading: [-3.14, 3.14]
algo:
config:
symmetry_loss:
enabled: true
coef: 0.1
robot:
init_state:
default_joint_angles:
waist_pitch_joint: 0.0
================================================
FILE: holomotion/scripts/data_curation/convert_to_amass.sh
================================================
# Convert raw motion datasets with holomotion/src/data_curation/data_smplify.py,
# using the training environment's Python interpreter.
#
# Usage: convert_to_amass.sh [DATA_ROOT]
#   DATA_ROOT  Root directory of the raw datasets (default: ./data/raw_datasets).
#
# train.env and the Python script are resolved relative to the current
# directory, so launch from the directory that contains both train.env and
# holomotion/ (presumably the repository root). train.env must define
# Train_CONDA_PREFIX.

# Fail loudly instead of silently falling back to "/bin/python" when the
# environment file is missing or does not define the interpreter prefix.
[[ -f train.env ]] || { printf 'error: train.env not found in %s\n' "$PWD" >&2; exit 1; }
source train.env
: "${Train_CONDA_PREFIX:?Train_CONDA_PREFIX is not set (check train.env)}"

# Default raw-data location; a non-empty first argument overrides it.
DATA_ROOT="${1:-./data/raw_datasets}"

"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/data_curation/data_smplify.py \
  --data_root "$DATA_ROOT"
================================================
FILE: holomotion/scripts/data_curation/filter_smpl_data.sh
================================================
# Label the SMPL datasets, then run the per-dataset filtering step.
#
# Usage: filter_smpl_data.sh [-l "name1 name2 ..."]
#   -l  Space-separated dataset names
#       (default: humanact12 MotionX OMOMO ZJU_Mocap amass).
#
# label_data.py runs once over the whole list. Then filter.py runs once per
# name with --json_path ./data/dataset_labels/<name>.jsonl and
# --yaml_path ./holomotion/config/data_curation/<name>_excluded.yaml.
# Relative paths assume the repository root as the working directory;
# train.env must define Train_CONDA_PREFIX.
#
# Exits non-zero if labeling fails (filtering is then skipped) or if
# filtering fails for any dataset (remaining datasets are still processed).

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

usage() {
  echo "Usage: $0 [-l \"file1 file2 ...\"]" >&2
  exit 1
}

[[ -f train.env ]] || die "train.env not found in $PWD"
source train.env
: "${Train_CONDA_PREFIX:?Train_CONDA_PREFIX is not set (check train.env)}"
python_bin="${Train_CONDA_PREFIX}/bin/python"

# Default list of dataset names.
default_jsonl_list=("humanact12" "MotionX" "OMOMO" "ZJU_Mocap" "amass")
jsonl_list=("${default_jsonl_list[@]}")

# Parse command-line options.
while getopts "l:" opt; do
  case "$opt" in
    l)
      # User-supplied dataset list, split on spaces.
      IFS=' ' read -r -a jsonl_list <<<"$OPTARG"
      ;;
    *)
      usage
      ;;
  esac
done

# `-l ""` would leave nothing to process.
((${#jsonl_list[@]} > 0)) || usage

echo "Running label_data.py first..."
"$python_bin" ./holomotion/src/data_curation/filter/label_data.py \
  --jsonl_list "${jsonl_list[@]}" ||
  die "label_data.py failed; skipping the filtering step"
echo "label_data.py finished."
echo "=============================="

failed=()
for json in "${jsonl_list[@]}"; do
  echo "Processing $json"
  # The amass dataset sits one directory deeper than the others.
  if [[ "$json" == "amass" ]]; then
    parent_folder="./data/amass_compatible_datasets/amass"
  else
    parent_folder="./data/amass_compatible_datasets"
  fi
  # Build the per-dataset input/output paths.
  json_path="./data/dataset_labels/${json}.jsonl"
  yaml_path="./holomotion/config/data_curation/${json}_excluded.yaml"
  # Run the filter; record failures instead of reporting success blindly.
  if "$python_bin" ./holomotion/src/data_curation/filter/filter.py \
    --parent_folder "$parent_folder" \
    --json_path "$json_path" \
    --yaml_path "$yaml_path"; then
    echo "Finished $json"
  else
    echo "Failed $json" >&2
    failed+=("$json")
  fi
  echo "-----------------------"
done

if ((${#failed[@]} > 0)); then
  die "filter.py failed for: ${failed[*]}"
fi
echo "All done"
================================================
FILE: holomotion/scripts/data_curation/video_to_smpl_gvhmr.sh
================================================
# Run GVHMR over every video under video_folder_root, then copy each result's
# smpl.npz into out_dir as <result_subdir>_smpl.npz.
# Must be started from the repository root (see the cd below).
# The "holomotion_abs_path/..." values are placeholders for ABSOLUTE paths:
# the script cd's into thirdparties/GVHMR/ before using them, so relative
# paths would resolve against that directory instead of the repo root.
if ! CONDA_BASE=$(conda info --base); then
  echo "Failed to locate the conda installation (conda info --base)." >&2
  exit 1
fi
export CONDA_BASE
export Train_CONDA_PREFIX="$CONDA_BASE/envs/gvhmr"
video_folder_root="holomotion_abs_path/data/video_data"
npz_data_root="holomotion_abs_path/data/gvhmr_converted/gvhmr_result"
out_dir="holomotion_abs_path/data/gvhmr_converted/collected_smpl"

# Run from the GVHMR checkout; the HoloMotion driver is reached via ../../.
# Abort if the cd fails, otherwise ../../ would point somewhere else entirely.
cd thirdparties/GVHMR/ || {
  echo "Cannot cd into thirdparties/GVHMR/; run from the repository root." >&2
  exit 1
}
"$Train_CONDA_PREFIX/bin/python" ../../holomotion/src/data_curation/video_to_smpl_gvhmr.py \
  --folder="${video_folder_root}" \
  --output_root="${npz_data_root}" \
  -s

mkdir -p "${out_dir}"
# Flatten <npz_data_root>/<name>/smpl.npz into <out_dir>/<name>_smpl.npz.
for subdir in "${npz_data_root}"/*; do
  # Also skips the literal unmatched glob when npz_data_root is empty/missing.
  if [[ ! -d "${subdir}" ]]; then
    continue
  fi
  sub_name=$(basename "${subdir}")
  src_npz="${subdir}/smpl.npz"
  if [[ ! -f "${src_npz}" ]]; then
    echo "[SKIP] ${sub_name}: smpl.npz not found"
    continue
  fi
  dst_npz="${out_dir}/${sub_name}_smpl.npz"
  cp -f "${src_npz}" "${dst_npz}"
  echo "[COPY] ${src_npz} -> ${dst_npz}"
done
================================================
FILE: holomotion/scripts/data_curation/visualize_smpl_npz.sh
================================================
# Launch the SMPL .npz visualizer with the interpreter of the "gvhmr" conda env.
# NOTE(review): the script path is relative (../../holomotion/...), so this
# appears to expect being run from thirdparties/GVHMR/ — confirm intended cwd.
CONDA_BASE=$(conda info --base)
export CONDA_BASE
Train_CONDA_PREFIX="$CONDA_BASE/envs/gvhmr"
export Train_CONDA_PREFIX

"${Train_CONDA_PREFIX}/bin/python" ../../holomotion/src/data_curation/visualize_smpl_npz.py
================================================
FILE: holomotion/scripts/evaluation/calc_offline_eval_metrics.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Directory of .npz files handed to metrics.py (placeholder value).
npz_dir="your_npz_dir"
dataset_suffix="HoloMotion_eval"
metric_calculation="per_clip" # Options: "per_clip" or "per_frame"
dof_mode="23"                 # Options: "29" for full DoF, "23" for upper body only

metrics_args=(
  --npz_dir="${npz_dir}"
  --dataset_suffix="${dataset_suffix}"
  --failure_pos_err_thresh_m=0.25
  --metric_calculation="${metric_calculation}"
  --dof_mode="${dof_mode}"
)
"${Train_CONDA_PREFIX}/bin/python" holomotion/src/evaluation/metrics.py "${metrics_args[@]}"
================================================
FILE: holomotion/scripts/evaluation/eval_motion_tracking.sh
================================================
#!/bin/bash
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env
export CUDA_VISIBLE_DEVICES="0"

HEADLESS=true
CONFIG_NAME="eval_isaaclab"
# Placeholder checkpoint; point this at a real training run.
CKPT_PATH="logs/HoloMotionMotrackV1.2/your_log_dir/model_xxx.pt"
# Hydra list literal with the evaluation dataset path(s).
eval_h5_dataset_path="['data/h5v2_datasets/lafan1']"
num_envs=4

# Hydra overrides, one per array element.
eval_overrides=(
  "headless=${HEADLESS}"
  "num_envs=${num_envs}"
  export_policy=true
  dump_npzs=true
  calc_per_clip_metrics=true
  generate_report=true
  "motion_h5_path=${eval_h5_dataset_path}"
  +use_kv_cache=true
  export_only=false
  "checkpoint=${CKPT_PATH}"
  project_name="HoloMotionMoTrack"
)

"${Train_CONDA_PREFIX}/bin/accelerate" launch \
  holomotion/src/evaluation/eval_motion_tracking_single.py \
  --config-name="evaluation/${CONFIG_NAME}" \
  "${eval_overrides[@]}"
================================================
FILE: holomotion/scripts/evaluation/eval_mujoco_sim2sim.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env
export CUDA_VISIBLE_DEVICES="0"

export HEADLESS=false
# Choose the MuJoCo GL backend and video recording from HEADLESS.
# Compare as a string rather than executing "$HEADLESS" as a command
# (an empty value would otherwise count as true).
if [[ "${HEADLESS}" == "true" ]]; then
  export MUJOCO_GL="osmesa"
  export RECORD_VIDEO=true
else
  export MUJOCO_GL="egl"
  export RECORD_VIDEO=false
fi

# Can be overridden from the environment, e.g. `model_type=foo ./script.sh`.
model_type="${model_type:-holomotion}"
robot_xml_path="assets/robots/unitree/G1/29dof/scene_29dof.xml"
ONNX_PATH="your_onnx_model.onnx"
# Exported so Hydra can read it through the oc.env resolver used below.
export motion_npz_path="your_npz.npz"

# shellcheck disable=SC2016 # '${oc.env:...}' is a Hydra interpolation, not shell.
"${Train_CONDA_PREFIX}/bin/python" holomotion/src/evaluation/eval_mujoco_sim2sim.py \
  record_video="${RECORD_VIDEO}" \
  headless="${HEADLESS}" \
  camera_tracking=true \
  camera_distance=7.0 \
  +model_type="${model_type}" \
  use_gpu=true \
  dump_npzs=true \
  dump_onnx_io_npy=false \
  calc_per_clip_metrics=true \
  generate_report=true \
  ray_actors_per_gpu=12 \
  policy_action_delay_step=0 \
  action_delay_type=step \
  +ckpt_onnx_path="$ONNX_PATH" \
  +motion_npz_path='${oc.env:motion_npz_path}' \
  robot_xml_path="${robot_xml_path}"
================================================
FILE: holomotion/scripts/evaluation/eval_velocity_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env
export CUDA_VISIBLE_DEVICES="0"

config_name="eval_velocity_tracking"
num_envs=1
# Placeholder checkpoint; point it at a trained velocity-tracking run.
ckpt_path="logs/HoloMotionVelocityTracking/xxxxx-train_g1_29dof_velocity_tracking/model_xxx.pt"

# The "[3,5]" override is single-quoted: unquoted, "...=[3,5]" is a bracket
# glob and would be replaced by a matching filename in the current directory.
"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/evaluation/eval_velocity_tracking.py \
  --config-name="evaluation/${config_name}" \
  project_name="HoloMotionVelocityTracking" \
  num_envs="${num_envs}" \
  headless=false \
  experiment_name="${config_name}" \
  checkpoint="${ckpt_path}" \
  '+env.config.commands.base_velocity.params.resampling_time_range=[3,5]'
================================================
FILE: holomotion/scripts/evaluation/mean_process_5metrics.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import csv
import glob
import json
import os
import re
import numpy as np
import pandas as pd
from tabulate import tabulate
# Metric JSON key -> column header shown in the markdown/CLI tables.
# The key order here is the canonical metric order for the whole script.
COLUMN_MAPPING = {
    "mpjpe_g": "Global Bodylink Pos Err",
    "mpjpe_l": "Local Bodylink Pos Err",
    "whole_body_joints_dist": "Dof Position Err",
    "root_vel_error": "Root Vel Err",
    "root_r_error": "Root Roll Err",
    "root_p_error": "Root Pitch Err",
    "root_y_error": "Root Yaw Err",
    "root_height_error": "Root Height Err",
    "mean_dof_vel": "Mean Dof Vel",
    "mean_dof_acc": "Mean Dof Acc",
    "mean_dof_torque": "Mean Dof Torque",
    "mean_action_rate": "Mean Action Rate",
    "success": "Success Rate",
    "mean_torque_jump_norm": "Mean Torque Jump Norm",
    "p95_torque_jump_norm": "P95 Torque Jump Norm",
    "mean_torque_jump_ratio": "Mean Torque Jump Ratio",
    "p95_torque_jump_ratio": "P95 Torque Jump Ratio",
}

# Metric keys collected from each per-clip record, in display order.
# Derived from COLUMN_MAPPING so the two can never drift apart.
METRICS = list(COLUMN_MAPPING)
def get_dataset_name(motion_key):
    """Infer the source dataset name from a motion key.

    Patterns are tried in order:
      1. legacy keys containing ``clips_<name>_``
      2. v1.1 eval keys containing ``v1.1_eval_<name>_``
      3. otherwise, the text before the first underscore.

    Args:
        motion_key: Motion identifier taken from a per-clip metrics record.

    Returns:
        The inferred dataset name, or ``"Unknown"`` if ``motion_key`` is not
        a string.
    """
    if not isinstance(motion_key, str):
        return "Unknown"
    match_old = re.search(r"clips_([a-zA-Z0-9]+)_", motion_key)
    if match_old:
        return match_old.group(1)
    # Escape the dot so only the literal "v1.1" tag matches (not e.g. "v1x1").
    match_new = re.search(r"v1\.1_eval_([a-zA-Z0-9]+)_", motion_key)
    if match_new:
        return match_new.group(1)
    return motion_key.split("_")[0]
def process_data(folder_path):
    """Aggregate per-clip metrics from every JSON file in a directory.

    Each ``*.json`` file whose name does not contain ``batch_`` is treated
    as one method; its file stem becomes the ``Method`` column. Accepted
    layouts are ``{"per_clip": [...]}``, a bare list of records, or a single
    record dict that has a ``motion_key``. Files in any other layout, and
    entries that are not dicts with a ``motion_key``, are skipped.

    Args:
        folder_path: Directory to scan; ``~`` is expanded.

    Returns:
        ``(final_mean, final_median)`` DataFrames with columns ``Method``,
        ``Dataset`` and every key in ``METRICS``: one row per
        (method, dataset), plus a ``"Total (Macro)"`` row per method equal
        to the unweighted mean of that method's dataset rows (for the median
        table, the mean of the per-dataset medians).

    Raises:
        FileNotFoundError: no eligible JSON file exists in ``folder_path``.
        ValueError: no usable per-clip record was found.
    """
    folder_path = os.path.expanduser(folder_path)
    # Sorted so files are processed in a deterministic order.
    json_files = sorted(
        path
        for path in glob.glob(os.path.join(folder_path, "*.json"))
        if "batch_" not in os.path.basename(path)
    )
    if not json_files:
        raise FileNotFoundError(
            f"No .json files found in directory: {folder_path}"
        )

    all_records = []
    for file_path in json_files:
        model_name = os.path.splitext(os.path.basename(file_path))[0]
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Accept the supported metric-file layouts; skip anything else.
        if isinstance(data, dict) and "per_clip" in data:
            clips_data = data["per_clip"]
        elif isinstance(data, list):
            clips_data = data
        elif isinstance(data, dict) and "motion_key" in data:
            clips_data = [data]
        else:
            continue
        for entry in clips_data:
            # Non-dict entries (numbers, None, strings) would make the `in`
            # test raise or do a substring match; skip them explicitly.
            if not isinstance(entry, dict) or "motion_key" not in entry:
                continue
            dataset_name = get_dataset_name(entry["motion_key"])
            record = {"Method": model_name, "Dataset": dataset_name}
            for metric in METRICS:
                val = entry.get(metric, None)
                if val is not None:
                    record[metric] = val
            all_records.append(record)

    if not all_records:
        raise ValueError(
            f"No valid per-clip metric records extracted from: {folder_path}"
        )

    df = pd.DataFrame(all_records)
    # Guarantee every metric column exists (NaN when never reported).
    df = df.reindex(columns=["Method", "Dataset", *METRICS])
    grouped_ds = df.groupby(["Method", "Dataset"])[METRICS]
    df_mean_ds = grouped_ds.mean().reset_index()
    df_median_ds = grouped_ds.median().reset_index()
    # Macro-mean: every dataset weighs the same regardless of clip count.
    df_mean_total = (
        df_mean_ds.groupby(["Method"])[METRICS].mean().reset_index()
    )
    # Macro "median" total: mean of the per-dataset medians.
    df_median_total = (
        df_median_ds.groupby(["Method"])[METRICS].mean().reset_index()
    )
    df_mean_total["Dataset"] = "Total (Macro)"
    df_median_total["Dataset"] = "Total (Macro)"
    final_mean = pd.concat([df_mean_ds, df_mean_total], ignore_index=True)
    final_median = pd.concat(
        [df_median_ds, df_median_total], ignore_index=True
    )
    return final_mean, final_median
def highlight_best(val, best_val):
    """Format a metric cell, bolding it (markdown) when it equals the best.

    Args:
        val: Cell value. ``None``/NaN are returned unchanged as ``str(val)``.
        best_val: Best value of the column (max or min, chosen by caller).

    Returns:
        ``val`` formatted with 4 decimals, wrapped in ``**...**`` when
        ``np.isclose(val, best_val, atol=1e-6)``.
    """
    if val is None or pd.isna(val):
        return str(val)
    val_float = float(val)
    formatted_val = f"{val_float:.4f}"
    # A NaN best value (all-NaN column) is never "close", so nothing is bolded.
    if np.isclose(val_float, float(best_val), atol=1e-6):
        return f"**{formatted_val}**"
    return formatted_val
def generate_report(
    df,
    folder_path,
    file_name="result_table_mean.md",
    title="Evaluation Results (Mean)",
):
    """Write a markdown report containing one metric table per dataset.

    In each dataset table, every metric cell is formatted by
    ``highlight_best`` against the column's best value: the max for
    ``success`` and the min for every other metric. Sections are ordered
    alphabetically by dataset, with ``"Total (Macro)"`` last.

    Args:
        df: Aggregated metrics with ``Method``, ``Dataset`` and metric
            columns (as returned by ``process_data``).
        folder_path: Directory that receives the report.
        file_name: Name of the markdown file inside ``folder_path``.
        title: Top-level markdown heading.

    Returns:
        Absolute path of the written markdown file.
    """
    total_key = "Total (Macro)"
    present = df["Dataset"].unique().tolist()
    # Alphabetical dataset sections; the macro total always comes last.
    dataset_order = sorted(name for name in present if name != total_key)
    if total_key in present:
        dataset_order.append(total_key)

    parts = [
        f"# {title}\n\n",
        (
            "> **Note:** 'Total (Macro)' represents the **Macro-Average**, "
            "calculated as the arithmetic mean of the scores across all datasets, "
            "treating each dataset equally regardless of sample size.\n\n"
        ),
    ]
    for ds_name in dataset_order:
        table = df[df["Dataset"] == ds_name].copy()
        for metric in METRICS:
            if metric not in table.columns:
                continue
            column = table[metric]
            best_val = column.max() if metric == "success" else column.min()
            table[metric] = column.apply(
                lambda x, best_val=best_val: highlight_best(x, best_val)
            )
        table = table.drop(columns=["Dataset"]).rename(columns=COLUMN_MAPPING)
        # Put the method name in the first column.
        ordered_cols = list(table.columns)
        if "Method" in ordered_cols:
            ordered_cols.insert(0, ordered_cols.pop(ordered_cols.index("Method")))
        table = table[ordered_cols]
        parts.append(f"### Dataset: {ds_name}\n")
        parts.append(table.to_markdown(index=False) + "\n\n")

    out_md = os.path.join(folder_path, file_name)
    with open(out_md, "w", encoding="utf-8") as f:
        f.write("".join(parts))
    return os.path.abspath(out_md)
def _format_metric_values_for_cli(sub_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``sub_df`` whose metric columns are display strings.

    Each present ``METRICS`` column is rendered with 4 decimals; missing
    values become ``"nan"``. Other columns are left untouched.
    """

    def _as_text(value):
        if pd.isna(value):
            return "nan"
        return f"{float(value):.4f}"

    cli_df = sub_df.copy()
    present_metrics = [name for name in METRICS if name in cli_df.columns]
    for name in present_metrics:
        cli_df[name] = cli_df[name].apply(_as_text)
    return cli_df
def _print_cli_tables(df: pd.DataFrame, title: str, folder_path: str) -> None:
    """Print the dataset-wise metric table and persist it in ``folder_path``.

    Rows are ordered by dataset (alphabetical, ``"Total (Macro)"`` last) and
    then by method. The same table is written as TSV to
    ``sub_dataset_macro_mean_metrics.tsv`` (overwritten), printed to stdout,
    and appended to ``metric.log``.

    When ``df`` holds more than one method, a ``Method`` column follows
    ``Dataset`` so rows of different models stay distinguishable. With a
    single method the layout is ``Dataset`` plus metric columns, as before.
    """
    total_key = "Total (Macro)"
    all_datasets = df["Dataset"].unique().tolist()
    dataset_order = sorted([d for d in all_datasets if d != total_key])
    if total_key in all_datasets:
        dataset_order.append(total_key)

    merged_df = df.copy()
    # Ordered categorical so sorting follows dataset_order, not plain text.
    merged_df["Dataset"] = pd.Categorical(
        merged_df["Dataset"], categories=dataset_order, ordered=True
    )
    merged_df = merged_df.sort_values(
        by=["Dataset", "Method"], kind="stable"
    ).reset_index(drop=True)
    merged_df["Dataset"] = merged_df["Dataset"].astype(str)
    show_method = merged_df["Method"].nunique() > 1
    merged_df = _format_metric_values_for_cli(merged_df)
    merged_df.rename(columns=COLUMN_MAPPING, inplace=True)

    metric_display_cols = [
        COLUMN_MAPPING[m] for m in METRICS if COLUMN_MAPPING[m] in merged_df
    ]
    id_cols = ["Dataset", "Method"] if show_method else ["Dataset"]
    merged_df = merged_df[id_cols + metric_display_cols]

    output_tsv_path = os.path.join(
        folder_path, "sub_dataset_macro_mean_metrics.tsv"
    )
    with open(output_tsv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(merged_df.columns.tolist())
        writer.writerows(merged_df.values.tolist())

    table_str = tabulate(
        merged_df.values.tolist(),
        headers=merged_df.columns.tolist(),
        tablefmt="simple_outline",
        colalign=("left",) * len(merged_df.columns),
    )
    block = (
        "\n"
        + "=" * 80
        + f"\n{title}\n"
        + "=" * 80
        + f"\n\n{table_str}\n"
        + "=" * 80
        + "\n"
    )
    print(block)
    metric_log_path = os.path.join(folder_path, "metric.log")
    with open(metric_log_path, "a", encoding="utf-8") as file:
        file.write(block)
def generate_macro_mean_report_from_json_dir(folder_path: str) -> str:
    """Build macro-mean reports from a directory of per-clip metric JSONs.

    Aggregates the directory with ``process_data`` and uses the mean table
    only. It writes ``result_table_macro_mean.md`` into ``folder_path``,
    then prints the dataset-wise table via ``_print_cli_tables``.

    Args:
        folder_path: Directory containing the metric ``*.json`` files.

    Returns:
        Absolute path of the markdown report.
    """
    aggregated = process_data(folder_path)
    mean_table = aggregated[0]
    markdown_path = generate_report(
        mean_table,
        folder_path,
        "result_table_macro_mean.md",
        "Evaluation Results (Macro-Averaging Mean)",
    )
    _print_cli_tables(
        mean_table,
        "DATASET-WISE METRICS (MACRO-AVERAGING MEAN)",
        folder_path,
    )
    return markdown_path
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Build macro-averaged metric reports from a directory of "
            "per-clip metric JSON files."
        )
    )
    # Help text reads "path to the json folder". The argument is required:
    # without it, process_data() would crash on os.path.expanduser(None).
    parser.add_argument("--dir", type=str, required=True, help="json文件夹路径")
    args = parser.parse_args()
    out_md = generate_macro_mean_report_from_json_dir(args.dir)
    # Prints "Report generated: <path>".
    print(f"报告已生成: {out_md}")
================================================
FILE: holomotion/scripts/evaluation/multi_model_metrics_analysis.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Directory of metric JSON files passed to the report script as --json_dir.
metrics_json_dir="logs/Holomotion/metrics_output_dataset"
report_script="holomotion/src/evaluation/multi_model_metrics_report.py"

"${Train_CONDA_PREFIX}/bin/python" "${report_script}" --json_dir="${metrics_json_dir}"
================================================
FILE: holomotion/scripts/motion_retargeting/apply_gmr_motion_retarget_patch.sh
================================================
#!/usr/bin/env bash
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally adapted from the [GMR] repository:
# https://github.com/YanjieZe/GMR/blob/master/general_motion_retargeting/motion_retarget.py
#
# Rewrites GeneralMotionRetargeting in GMR's motion_retarget.py in place:
# __init__, setup_retarget_configuration and retarget are replaced, and the
# retarget replacement also adds helper methods.
# Usage: apply_gmr_motion_retarget_patch.sh [TARGET_FILE]
#   TARGET_FILE defaults to thirdparties/GMR/general_motion_retargeting/
#   motion_retarget.py under the current directory (run from the repo root).
# If every patch marker is already present, the embedded script prints a
# message and exits without modifying the file.
set -euo pipefail
REPO_ROOT="$(pwd)"
TARGET_FILE="${1:-$REPO_ROOT/thirdparties/GMR/general_motion_retargeting/motion_retarget.py}"
if [[ ! -f "$TARGET_FILE" ]]; then
echo "Target file not found: $TARGET_FILE" >&2
exit 1
fi
# The patcher runs under whatever `python` is on PATH. It only needs the
# standard library, but ast's end_lineno requires Python 3.8+. The quoted
# 'PY' delimiter stops the shell from expanding anything in the script.
python - "$TARGET_FILE" <<'PY'
from pathlib import Path
import ast
import sys
import textwrap
# Substrings that only appear in an already-patched file; when all are
# found, the script exits early without rewriting anything.
PATCH_MARKERS = (
"self.first_frame_damping = max(float(damping), 2.0)",
"self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)",
"def _solve_task_group(",
)
# Replacement source for GeneralMotionRetargeting.__init__ (re-indented by
# indent_block before insertion).
# NOTE(review): in this template the `damping` default is 5e-1, but its
# inline comment says "change from 1e-1 to 1e-2" — confirm the intended value.
PATCHED_INIT = """
def __init__(
self,
src_human: str,
tgt_robot: str,
actual_human_height: float = None,
solver: str="daqp", # change from "quadprog" to "daqp".
damping: float=5e-1, # change from 1e-1 to 1e-2.
verbose: bool=True,
use_velocity_limit: bool=False,
) -> None:
# load the robot model
self.xml_file = str(ROBOT_XML_DICT[tgt_robot])
if verbose:
print("Use robot model: ", self.xml_file)
self.model = mj.MjModel.from_xml_path(self.xml_file)
# Print DoF names in order
print("[GMR] Robot Degrees of Freedom (DoF) names and their order:")
self.robot_dof_names = {}
for i in range(self.model.nv): # 'nv' is the number of DoFs
dof_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, self.model.dof_jntid[i])
self.robot_dof_names[dof_name] = i
if verbose:
print(f"DoF {i}: {dof_name}")
print("[GMR] Robot Body names and their IDs:")
self.robot_body_names = {}
for i in range(self.model.nbody): # 'nbody' is the number of bodies
body_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, i)
self.robot_body_names[body_name] = i
if verbose:
print(f"Body ID {i}: {body_name}")
print("[GMR] Robot Motor (Actuator) names and their IDs:")
self.robot_motor_names = {}
for i in range(self.model.nu): # 'nu' is the number of actuators (motors)
motor_name = mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_ACTUATOR, i)
self.robot_motor_names[motor_name] = i
if verbose:
print(f"Motor ID {i}: {motor_name}")
# Load the IK config
with open(IK_CONFIG_DICT[src_human][tgt_robot]) as f:
ik_config = json.load(f)
if verbose:
print("Use IK config: ", IK_CONFIG_DICT[src_human][tgt_robot])
# compute the scale ratio based on given human height and the assumption in the IK config
if actual_human_height is not None:
ratio = actual_human_height / ik_config["human_height_assumption"]
else:
ratio = 1.0
# adjust the human scale table
for key in ik_config["human_scale_table"].keys():
ik_config["human_scale_table"][key] = ik_config["human_scale_table"][key] * ratio
# used for retargeting
self.ik_match_table1 = ik_config["ik_match_table1"]
self.ik_match_table2 = ik_config["ik_match_table2"]
self.human_root_name = ik_config["human_root_name"]
self.robot_root_name = ik_config["robot_root_name"]
self.use_ik_match_table1 = ik_config["use_ik_match_table1"]
self.use_ik_match_table2 = ik_config["use_ik_match_table2"]
self.human_scale_table = ik_config["human_scale_table"]
self.ground = ik_config["ground_height"] * np.array([0, 0, 1])
self.max_iter = 10
self.solver = solver
self.damping = damping
self.first_frame_damping = max(float(damping), 2.0)
self.first_frame_max_iter = max(int(self.max_iter), 10)
self._is_first_frame = True
self.human_body_to_task1 = {}
self.human_body_to_task2 = {}
self.pos_offsets1 = {}
self.rot_offsets1 = {}
self.pos_offsets2 = {}
self.rot_offsets2 = {}
self._arm_task_original_orientation_costs = {}
self._first_frame_arm_orientation_cost = 1.0
self.task_errors1 = {}
self.task_errors2 = {}
self.ik_limits = [mink.ConfigurationLimit(self.model)]
if use_velocity_limit:
VELOCITY_LIMITS = {k: 3*np.pi for k in self.robot_motor_names.keys()}
self.ik_limits.append(mink.VelocityLimit(self.model, VELOCITY_LIMITS))
self.setup_retarget_configuration()
self.ground_offset = 0.0
"""
# Replacement for setup_retarget_configuration. It adds a posture task and a
# previous-frame posture task targeting the model's default qpos, and records
# the original orientation cost of each arm task.
PATCHED_SETUP = """
def setup_retarget_configuration(self):
self.configuration = mink.Configuration(self.model)
self._default_qpos = self.configuration.data.qpos.copy()
self.posture_task = mink.PostureTask(self.model, cost=1e-2)
self.posture_task.set_target(self._default_qpos)
self.prev_posture_task = mink.PostureTask(self.model, cost=1e-3)
self.prev_posture_task.set_target(self._default_qpos)
self.tasks1 = []
self.tasks2 = []
for frame_name, entry in self.ik_match_table1.items():
body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry
if pos_weight != 0 or rot_weight != 0:
task = mink.FrameTask(
frame_name=frame_name,
frame_type="body",
position_cost=pos_weight,
orientation_cost=rot_weight,
lm_damping=1,
)
self.human_body_to_task1[body_name] = task
self.pos_offsets1[body_name] = np.array(pos_offset) - self.ground
self.rot_offsets1[body_name] = R.from_quat(
rot_offset, scalar_first=True
)
self.tasks1.append(task)
self.task_errors1[task] = []
if self._is_arm_body(body_name):
self._arm_task_original_orientation_costs[task] = float(
rot_weight
)
for frame_name, entry in self.ik_match_table2.items():
body_name, pos_weight, rot_weight, pos_offset, rot_offset = entry
if pos_weight != 0 or rot_weight != 0:
task = mink.FrameTask(
frame_name=frame_name,
frame_type="body",
position_cost=pos_weight,
orientation_cost=rot_weight,
lm_damping=1,
)
self.human_body_to_task2[body_name] = task
self.pos_offsets2[body_name] = np.array(pos_offset) - self.ground
self.rot_offsets2[body_name] = R.from_quat(
rot_offset, scalar_first=True
)
self.tasks2.append(task)
self.task_errors2[task] = []
if self._is_arm_body(body_name):
self._arm_task_original_orientation_costs[task] = float(
rot_weight
)
"""
# Replaces retarget. It also inserts the helpers _is_arm_body,
# _set_first_frame_arm_task_costs and _solve_task_group, which the patched
# setup/retarget code calls.
PATCHED_RETARGET_BLOCK = """
@staticmethod
def _is_arm_body(body_name):
return any(
token in body_name
for token in (
"left_shoulder",
"right_shoulder",
"left_elbow",
"right_elbow",
"left_wrist",
"right_wrist",
)
)
def _set_first_frame_arm_task_costs(self, enabled):
for task, original_orientation_cost in (
self._arm_task_original_orientation_costs.items()
):
orientation_cost = (
self._first_frame_arm_orientation_cost
if enabled
else original_orientation_cost
)
task.set_orientation_cost(orientation_cost)
def _solve_task_group(
self,
tasks,
error_fn,
*,
damping,
max_iter,
include_posture,
include_prev_posture,
):
solve_tasks = list(tasks)
if include_posture:
solve_tasks.append(self.posture_task)
if include_prev_posture:
solve_tasks.append(self.prev_posture_task)
curr_error = error_fn()
dt = self.configuration.model.opt.timestep
vel = mink.solve_ik(
self.configuration,
solve_tasks,
dt,
self.solver,
damping,
limits=self.ik_limits,
)
self.configuration.integrate_inplace(vel, dt)
next_error = error_fn()
num_iter = 0
while curr_error - next_error > 0.001 and num_iter < max_iter:
curr_error = next_error
dt = self.configuration.model.opt.timestep
vel = mink.solve_ik(
self.configuration,
solve_tasks,
dt,
self.solver,
damping,
limits=self.ik_limits,
)
self.configuration.integrate_inplace(vel, dt)
next_error = error_fn()
num_iter += 1
def retarget(self, human_data, offset_to_ground=False):
prev_q = self.configuration.data.qpos.copy()
# Update the task targets
self.update_targets(human_data, offset_to_ground)
include_posture = self._is_first_frame
include_prev_posture = True
solve_damping = (
self.first_frame_damping if self._is_first_frame else self.damping
)
solve_max_iter = (
self.first_frame_max_iter if self._is_first_frame else self.max_iter
)
self.prev_posture_task.set_target(prev_q)
if self._is_first_frame:
self._set_first_frame_arm_task_costs(True)
if self.use_ik_match_table1:
self._solve_task_group(
self.tasks1,
self.error1,
damping=solve_damping,
max_iter=solve_max_iter,
include_posture=include_posture,
include_prev_posture=include_prev_posture,
)
if self.use_ik_match_table2:
self._solve_task_group(
self.tasks2,
self.error2,
damping=solve_damping,
max_iter=solve_max_iter,
include_posture=include_posture,
include_prev_posture=include_prev_posture,
)
if self._is_first_frame:
self._set_first_frame_arm_task_costs(False)
self._is_first_frame = False
return self.configuration.data.qpos.copy()
"""
# Dedent a template, strip leading/trailing blank lines, and prefix each
# non-empty line with `indent`; the result always ends with a newline.
def indent_block(src: str, indent: str = " ") -> str:
body = textwrap.dedent(src).strip("\n")
return "\n".join(indent + line if line else "" for line in body.splitlines()) + "\n"
# Return the top-level class named `class_name`, or exit with an error.
def find_class(module: ast.Module, class_name: str) -> ast.ClassDef:
for node in module.body:
if isinstance(node, ast.ClassDef) and node.name == class_name:
return node
raise SystemExit(f"Class {class_name!r} not found in target file.")
# Return the (non-async) method defined directly in `class_node`, or exit.
def find_method(class_node: ast.ClassDef, method_name: str) -> ast.FunctionDef:
for node in class_node.body:
if isinstance(node, ast.FunctionDef) and node.name == method_name:
return node
raise SystemExit(f"Method {method_name!r} not found in class {class_node.name}.")
# Replace source lines node.lineno..node.end_lineno (1-based, inclusive)
# with the re-indented replacement.
# NOTE(review): decorator lines above the `def` lie outside this span and are
# kept; the targeted methods are assumed to be undecorated.
def apply_replacement(lines, node: ast.AST, replacement: str):
start = node.lineno - 1
end = node.end_lineno
return lines[:start] + [indent_block(replacement)] + lines[end:]
path = Path(sys.argv[1])
text = path.read_text(encoding="utf-8")
if all(marker in text for marker in PATCH_MARKERS):
print(f"Patch already present: {path}")
raise SystemExit(0)
module = ast.parse(text)
class_node = find_class(module, "GeneralMotionRetargeting")
init_node = find_method(class_node, "__init__")
setup_node = find_method(class_node, "setup_retarget_configuration")
retarget_node = find_method(class_node, "retarget")
lines = text.splitlines(keepends=True)
# Replace bottom-up (descending line order) so the line numbers of nodes
# that have not been replaced yet stay valid.
for node, replacement in sorted(
[
(retarget_node, PATCHED_RETARGET_BLOCK),
(setup_node, PATCHED_SETUP),
(init_node, PATCHED_INIT),
],
key=lambda item: item[0].lineno,
reverse=True,
):
lines = apply_replacement(lines, node, replacement)
new_text = "".join(lines)
# Re-parse before writing: a SyntaxError aborts here and leaves the file as is.
ast.parse(new_text)
path.write_text(new_text, encoding="utf-8")
print(f"Patch applied to: {path}")
PY
================================================
FILE: holomotion/scripts/motion_retargeting/pack_hdf5_v2.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env
# Hide every GPU from the packing process (CPU only).
export CUDA_VISIBLE_DEVICES=""

# Hydra list of source .npz roots, the output directory, and the robot config.
holomotion_npz_root='["data/holomotion_retargeted/AMASS_test"]'
hdf5_root="data/h5v2_datasets/AMASS_test"
robot_config="unitree/G1/29dof/29dof_training_isaaclab"

"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/motion_retargeting/pack_hdf5_v2.py \
  "robot=${robot_config}" \
  "holomotion_npz_root=${holomotion_npz_root}" \
  "hdf5_root=${hdf5_root}"
================================================
FILE: holomotion/scripts/motion_retargeting/run_holomotion_preprocessing.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Placeholders: input directory of HoloMotion .npz files and output directory.
holo_src_dir="src_holomotion_npz_dir"
holo_tgt_dir="output_holomotion_npz_dir"
# Hydra list of preprocessing step names for holomotion_preprocess.py.
pipeline="['filename_as_motionkey','legacy_to_ref_keys','tagging']"
robot_config="holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml"

# Overrides are quoted: unquoted, "preprocess.pipeline=['a','b']" is a bracket
# glob and would be replaced by any matching filename in the current directory.
"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/motion_retargeting/holomotion_preprocess.py \
  "padding.robot_config_path=${robot_config}" \
  "io.src_root=${holo_src_dir}" \
  "io.out_root=${holo_tgt_dir}" \
  "preprocess.pipeline=${pipeline}" \
  ray.enabled=true \
  padding.stand_still_time=20.0 \
  ray.num_workers=2
================================================
FILE: holomotion/scripts/motion_retargeting/run_kinematic_filter.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Processed HoloMotion dataset directory that the kinematic filter runs on.
dataset_root="data/holomotion_retargeted/processed_datasets/AMASS_test"
filter_script="holomotion/src/motion_retargeting/kinematic_filter.py"

"${Train_CONDA_PREFIX}/bin/python" "${filter_script}" "io.dataset_root=${dataset_root}"
================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_bvh.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Input BVH motion-capture folder and output folder for GMR robot motions.
bvh_src_dir="data/lafan1_bvh"
gmr_tgt_dir="data/gmr_retargeted/lafan1/"

# Retarget BVH motion capture to the Unitree G1 with GMR.
# `mkdir -p` is a no-op when the output directory already exists.
mkdir -p -- "${gmr_tgt_dir}" || {
  echo "Cannot create output directory: ${gmr_tgt_dir}" >&2
  exit 1
}
"${Train_CONDA_PREFIX}/bin/python" \
  thirdparties/GMR/scripts/bvh_to_robot_dataset.py \
  --src_folder "${bvh_src_dir}/" \
  --tgt_folder "${gmr_tgt_dir}/" \
  --robot unitree_g1
================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_smplx.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

# Input SMPL-X motions and the output folder for GMR robot motions.
smplx_src_dir="assets/test_data/motion_retargeting/"
gmr_tgt_dir="data/gmr_retargeted/AMASS_test/"

# `mkdir -p` already tolerates an existing directory, so no -d pre-check.
mkdir -p "${gmr_tgt_dir}"

retarget_args=(
  --src_folder="${smplx_src_dir}/"
  --tgt_folder="${gmr_tgt_dir}/"
  --num_cpus=16
  --robot=unitree_g1
)
"${Train_CONDA_PREFIX}/bin/python" thirdparties/GMR/scripts/smplx_to_robot_dataset.py \
  "${retarget_args[@]}"
================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_retargeting_gmr_to_holomotion.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env

dir_name="AMASS_test"
# GMR retargeting output (this script's input) and the HoloMotion output.
gmr_tgt_dir="data/gmr_retargeted/${dir_name}"
holo_retargeted_dir="data/holomotion_retargeted/processed_datasets/${dir_name}"
robot_cfg="holomotion/config/robot/unitree/G1/29dof/29dof_training_isaaclab.yaml"
# Hydra list of preprocessing step names for gmr_to_holomotion.py.
preprocess_pipeline="['filename_as_motionkey','legacy_to_ref_keys','slicing','add_padding','tagging']"

# Overrides are quoted: unquoted, "preprocess.pipeline=['a','b']" is a bracket
# glob and would be replaced by any matching filename in the current directory.
"${Train_CONDA_PREFIX}/bin/python" \
  holomotion/src/motion_retargeting/gmr_to_holomotion.py \
  "io.robot_config=${robot_cfg}" \
  "io.src_dir=${gmr_tgt_dir}" \
  "io.out_root=${holo_retargeted_dir}" \
  processing.target_fps=50 \
  "preprocess.pipeline=${preprocess_pipeline}" \
  ray.num_workers=16
================================================
FILE: holomotion/scripts/motion_retargeting/run_motion_viz_mujoco.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
source train.env
# Render with MuJoCo's OSMesa (software, offscreen) backend.
export MUJOCO_GL="osmesa"

motion_npz_root="path_to_your_npz_dir"
# Exported so Hydra can resolve it via the oc.env interpolation below;
# "all" presumably selects every motion under motion_npz_root (TODO confirm).
export motion_name="all"

# shellcheck disable=SC2016 # '${oc.env:...}' is a Hydra interpolation, not shell.
viz_overrides=(
  +key_prefix="robot_"
  +draw_ref_body_spheres=true
  +ref_key_prefix="ref_"
  "+motion_npz_root=${motion_npz_root}"
  skip_frames=6
  max_workers=11
  +motion_name='${oc.env:motion_name}'
)
"${Train_CONDA_PREFIX}/bin/python" holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py \
  "${viz_overrides[@]}"
================================================
FILE: holomotion/scripts/training/train_motion_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Launch motion-tracking training through `accelerate launch`.
#
# Edit CUDA_VISIBLE_DEVICES / config_name / num_envs below. More than one
# visible GPU enables `--multi_gpu`. The launcher runs in the background so
# that SIGINT/SIGTERM received by this script can be forwarded to it; the
# script exits with the launcher's exit status.
if [[ ! -f train.env ]]; then
  echo "error: train.env not found; run this script from the repository root" >&2
  exit 1
fi
source train.env
: "${Train_CONDA_PREFIX:?Train_CONDA_PREFIX must be set (see train.env)}"

export CUDA_VISIBLE_DEVICES=0
IFS=',' read -r -a visible_gpus <<<"${CUDA_VISIBLE_DEVICES}"
if (( ${#visible_gpus[@]} > 1 )); then
  USE_MULTI_GPU=true
else
  USE_MULTI_GPU=false
fi

config_name="train_g1_29dof_motion_tracking_mlp"
# config_name="train_g1_29dof_motion_tracking_tf-moe"
num_envs=4096
COMMON_ARGS=(
  "holomotion/src/training/train.py"
  "--config-name=training/motion_tracking/${config_name}"
  "num_envs=${num_envs}"
  "headless=true"
  "experiment_name=${config_name}"
)
LAUNCH_ARGS=()
if [[ "${USE_MULTI_GPU}" == "true" ]]; then
  LAUNCH_ARGS+=(--multi_gpu)
fi

TRAIN_PID=""

#######################################
# Stop the background training launcher and exit.
# Globals:   TRAIN_PID (read)
# Arguments: $1 - exit code for this script (130 = SIGINT, 143 = SIGTERM)
#######################################
cleanup() {
  local -r exit_code=$1
  trap - SIGINT SIGTERM
  if [[ -n "${TRAIN_PID}" ]] && kill -0 "${TRAIN_PID}" 2>/dev/null; then
    echo "Stopping training (pid ${TRAIN_PID})..." >&2
    # The launcher is not guaranteed to forward signals to the process it
    # spawned, so signal its direct children as well.
    pkill -TERM -P "${TRAIN_PID}" 2>/dev/null || true
    kill -TERM "${TRAIN_PID}" 2>/dev/null || true
    wait "${TRAIN_PID}" 2>/dev/null || true
  fi
  exit "${exit_code}"
}
trap 'cleanup 130' SIGINT
trap 'cleanup 143' SIGTERM

"${Train_CONDA_PREFIX}/bin/accelerate" launch \
  "${LAUNCH_ARGS[@]}" \
  "${COMMON_ARGS[@]}" &
TRAIN_PID=$!
# `wait` returns as soon as a trapped signal arrives, letting cleanup run.
wait "${TRAIN_PID}"
train_status=$?
trap - SIGINT SIGTERM
exit "${train_status}"
================================================
FILE: holomotion/scripts/training/train_velocity_tracking.sh
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# Launch velocity-tracking training through `accelerate launch`.
#
# Edit CUDA_VISIBLE_DEVICES / config_name / num_envs below. `--multi_gpu` is
# passed only when more than one GPU is visible. The launcher runs in the
# background so SIGINT/SIGTERM received by this script can be forwarded to
# it; the script exits with the launcher's exit status.
if [[ ! -f train.env ]]; then
  echo "error: train.env not found; run this script from the repository root" >&2
  exit 1
fi
source train.env
: "${Train_CONDA_PREFIX:?Train_CONDA_PREFIX must be set (see train.env)}"

export CUDA_VISIBLE_DEVICES=0
config_name="train_g1_29dof_velocity_tracking_mlp"
num_envs=4096
COMMON_ARGS=(
  "holomotion/src/training/train.py"
  "--config-name=training/velocity_tracking/${config_name}"
  "experiment_name=${config_name}"
  "num_envs=${num_envs}"
  "headless=true"
)

IFS=',' read -r -a visible_gpus <<<"${CUDA_VISIBLE_DEVICES}"
LAUNCH_ARGS=()
# Multi-GPU only when several devices are visible (a single GPU must not
# be launched with --multi_gpu).
if (( ${#visible_gpus[@]} > 1 )); then
  LAUNCH_ARGS+=(--multi_gpu)
fi

TRAIN_PID=""

#######################################
# Stop the background training launcher and exit.
# Globals:   TRAIN_PID (read)
# Arguments: $1 - exit code for this script (130 = SIGINT, 143 = SIGTERM)
#######################################
cleanup() {
  local -r exit_code=$1
  trap - SIGINT SIGTERM
  if [[ -n "${TRAIN_PID}" ]] && kill -0 "${TRAIN_PID}" 2>/dev/null; then
    echo "Stopping training (pid ${TRAIN_PID})..." >&2
    # The launcher is not guaranteed to forward signals to the process it
    # spawned, so signal its direct children as well.
    pkill -TERM -P "${TRAIN_PID}" 2>/dev/null || true
    kill -TERM "${TRAIN_PID}" 2>/dev/null || true
    wait "${TRAIN_PID}" 2>/dev/null || true
  fi
  exit "${exit_code}"
}
trap 'cleanup 130' SIGINT
trap 'cleanup 143' SIGTERM

"${Train_CONDA_PREFIX}/bin/accelerate" launch \
  "${LAUNCH_ARGS[@]}" \
  "${COMMON_ARGS[@]}" &
TRAIN_PID=$!
# `wait` returns as soon as a trapped signal arrives, letting cleanup run.
wait "${TRAIN_PID}"
train_status=$?
trap - SIGINT SIGTERM
exit "${train_status}"
================================================
FILE: holomotion/src/algo/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/algo/algo_base.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import random
import statistics
import sys
import time
from collections import deque
from typing import Any, Dict
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import (
ProjectConfiguration,
TorchDynamoPlugin,
load_checkpoint_in_model,
load_state_dict,
)
from hydra.utils import get_class
from loguru import logger
from tensordict import TensorDict
from holomotion.src.algo.algo_utils import AlgoLogger
class BaseOnpolicyRL:
    """Base class for on-policy RL algorithms in HoloMotion.

    The constructor builds, in order: the Accelerate ``Accelerator`` and the
    loguru sinks, an ``AlgoLogger``, the IsaacLab app and environment,
    run-length settings, RNG seeds, episode-statistics buffers, optional
    algorithm components and finally the models/optimizer. ``learn()`` then
    alternates ``rollout_policy()`` and ``update()``.

    Subclasses must implement ``_setup_models_and_optimizer``,
    ``_build_transition``, ``update``, ``load`` and ``save``, and must either
    override ``_build_storage`` or assign ``self.storage`` themselves. The
    shared loop also reads ``self.actor``, ``self.critic``, ``self.gamma``
    and ``self.lam``, which subclasses are expected to set.
    """

    def __init__(
        self,
        env_config,
        config,
        log_dir=None,
        headless: bool = True,
        is_offline_eval: bool = False,
    ) -> None:
        """Build the accelerator, environment, buffers and models.

        Args:
            env_config: Environment config; ``_target_`` names the env class
                and ``config`` is forwarded to its constructor.
            config: Algorithm config (``save_interval``, ``log_interval``,
                ``num_steps_per_env``, ``num_learning_iterations`` plus
                optional precision / compile / seed / video settings).
            log_dir: Directory for checkpoints, logs and TensorBoard files.
                Created unless ``is_offline_eval`` is True.
            headless: Run IsaacSim without a GUI.
            is_offline_eval: Evaluation-only mode: observation-normalizer
                statistics are not updated and log files use
                ``offline_eval*`` names.
        """
        self.config = config
        self.env_config = env_config
        self.log_dir = log_dir
        self.headless = headless
        self.is_offline_eval = is_offline_eval
        self._setup_accelerator()
        self.algo_logger = AlgoLogger(
            self.accelerator,
            self.log_dir,
            is_main_process=self.is_main_process,
        )
        self._setup_environment()
        self._setup_configs()
        self._setup_seeding()
        self._setup_data_buffers()
        self._setup_algo_components()
        self._setup_models_and_optimizer()

    def _setup_accelerator(self) -> None:
        """Create the Accelerator, pin the CUDA device and configure logging.

        Reads ``mixed_precision`` ("fp16"/"bf16"), ``dynamo_backend``
        ("inductor"/"aot_eager"/"cudagraphs"; disabled when the environment
        variable ``TORCH_COMPILE_DISABLE=1``) and the ``dynamo_*`` options
        from ``self.config``. Sets ``accelerator``, ``device``,
        ``local_rank``, ``is_main_process``, the rank/world-size attributes
        and ``is_distributed``.
        """
        if not self.is_offline_eval:
            os.makedirs(self.log_dir, exist_ok=True)

        accelerator_kwargs = {}
        mixed_precision = self.config.get("mixed_precision", None)
        if mixed_precision in ("fp16", "bf16"):
            accelerator_kwargs["mixed_precision"] = mixed_precision
        # Report only what is actually requested from Accelerate; any other
        # configured value (e.g. "no") runs in full fp32.
        used_precision = accelerator_kwargs.get("mixed_precision", "fp32")

        dynamo_backend = self.config.get("dynamo_backend", None)
        if os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1":
            dynamo_backend = None
        if dynamo_backend not in ("inductor", "aot_eager", "cudagraphs"):
            # Unset/unsupported backends are never applied, so do not report
            # them in the tracker config or logs either.
            dynamo_backend = None
        dynamo_dynamic = bool(self.config.get("dynamo_dynamic", True))
        if dynamo_backend is not None:
            dynamo_fullgraph = bool(self.config.get("dynamo_fullgraph", False))
            dynamo_mode = self.config.get("dynamo_mode", "default")
            accelerator_kwargs["dynamo_plugin"] = TorchDynamoPlugin(
                backend=str(dynamo_backend),
                mode=str(dynamo_mode),
                fullgraph=bool(dynamo_fullgraph),
                dynamic=bool(dynamo_dynamic),
            )

        accelerator_kwargs["log_with"] = "tensorboard"
        project_config = ProjectConfiguration(
            project_dir=self.log_dir,
            logging_dir=self.log_dir,
        )
        accelerator_kwargs["project_config"] = project_config
        self.accelerator = Accelerator(**accelerator_kwargs)

        self.local_rank = getattr(
            self.accelerator, "local_process_index", None
        )
        if self.local_rank is None:
            self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.device = self.accelerator.device
        if torch.cuda.is_available() and self.device.type == "cuda":
            dev_index = self.device.index
            if dev_index is None:
                # Bare "cuda" device: pin to this process's local rank.
                dev_index = int(self.local_rank)
                self.device = torch.device("cuda", dev_index)
            else:
                dev_index = int(dev_index)
            torch.cuda.set_device(dev_index)
        self.is_main_process = self.accelerator.is_main_process

        self.accelerator.init_trackers(
            project_name="holomotion",
            config={
                "precision": used_precision,
                "dynamo_backend": dynamo_backend if dynamo_backend else "none",
                "dynamo_dynamic": dynamo_dynamic if dynamo_backend else False,
            },
        )
        self._release_cuda_cache()

        # Replace loguru's default sink: one file per rank, plus stdout and a
        # combined log file on the main process.
        logger.remove()
        log_level = os.environ.get("LOGURU_LEVEL", "INFO").upper()
        if self.log_dir:
            rank_log_file_name = (
                "offline_eval_rank" if self.is_offline_eval else "run_rank"
            )
            logger.add(
                os.path.join(
                    self.log_dir,
                    f"{rank_log_file_name}_{int(self.accelerator.process_index):04d}.log",
                ),
                level=log_level,
                colorize=False,
            )
        if self.is_main_process:
            logger.add(
                sys.stdout,
                level=log_level,
                colorize=True,
            )
            # Guarded so a missing log_dir never reaches os.path.join(None, ...).
            if self.log_dir:
                log_file_name = (
                    "offline_eval.log" if self.is_offline_eval else "run.log"
                )
                logger.add(
                    os.path.join(self.log_dir, log_file_name),
                    level=log_level,
                    colorize=False,
                )
        logger.info(
            f"Accelerator initialized with precision: {used_precision}"
        )
        if dynamo_backend:
            logger.info(f"Accelerator dynamo_backend: {dynamo_backend}")
        logger.info(f"TensorBoard logging enabled at: {self.log_dir}")

        self.process_rank = self.accelerator.process_index
        self.gpu_world_size = self.accelerator.num_processes
        self.gpu_global_rank = self.accelerator.process_index
        self.is_distributed = self.gpu_world_size > 1

        # Diagnostics for launcher / environment mismatches.
        env_rank = os.environ.get("RANK", "unset")
        env_local_rank = os.environ.get("LOCAL_RANK", "unset")
        env_world_size = os.environ.get("WORLD_SIZE", "unset")
        env_local_world_size = os.environ.get("LOCAL_WORLD_SIZE", "unset")
        env_node_rank = os.environ.get(
            "NODE_RANK", os.environ.get("MACHINE_RANK", "unset")
        )
        env_master_addr = os.environ.get("MASTER_ADDR", "unset")
        env_master_port = os.environ.get("MASTER_PORT", "unset")
        env_cuda_visible_devices = os.environ.get(
            "CUDA_VISIBLE_DEVICES", "unset"
        )
        cuda_device_count = (
            int(torch.cuda.device_count()) if torch.cuda.is_available() else 0
        )
        logger.info(
            "[Accelerate setup] "
            f"distributed_type={self.accelerator.distributed_type}, "
            f"num_processes={int(self.accelerator.num_processes)}, "
            f"process_index={int(self.accelerator.process_index)}, "
            f"local_process_index={int(self.local_rank)}, "
            f"is_main_process={bool(self.accelerator.is_main_process)}"
        )
        logger.info(
            "[Accelerate env] "
            f"RANK={env_rank}, LOCAL_RANK={env_local_rank}, "
            f"WORLD_SIZE={env_world_size}, "
            f"LOCAL_WORLD_SIZE={env_local_world_size}, "
            f"NODE_RANK={env_node_rank}, MASTER_ADDR={env_master_addr}, "
            f"MASTER_PORT={env_master_port}"
        )
        logger.info(
            "[Accelerate cuda] "
            f"CUDA_VISIBLE_DEVICES={env_cuda_visible_devices}, "
            f"torch_cuda_device_count={cuda_device_count}, "
            f"selected_device={self.device}"
        )

    def _setup_environment(self) -> None:
        """Setup IsaacLab AppLauncher and environment instance."""
        # Device string from accelerator (handles distributed training)
        device_str = str(self.device)
        # Delayed import to ensure Accelerate is fully initialized before IsaacLab
        from isaaclab.app import AppLauncher

        # Stagger IsaacSim AppLauncher initialization across distributed ranks.
        # Use the local rank so each node staggers independently.
        if self.is_distributed:
            self.accelerator.wait_for_everyone()
            base_delay_s = float(
                os.environ.get("HOLOMOTION_ISAAC_STAGGER_SEC", "5.0")
            )
            local_rank = int(self.local_rank)
            delay_s = base_delay_s * float(local_rank)
            if delay_s > 0.0:
                logger.info(
                    f"[Global Rank {self.gpu_global_rank}, Local Rank {local_rank}] "
                    f"Sleeping {delay_s:.1f}s before IsaacSim AppLauncher init"
                )
                time.sleep(delay_s)

        # Enable cameras only when needed:
        # - headless & recording: True (offscreen rendering)
        # - headless & not recording: False (maximize performance)
        # - with GUI: True
        _record_video = bool(self.config.get("record_video", False))
        enable_cameras = _record_video or (not self.headless)
        # Explicitly disable Omniverse multi-GPU rendering to avoid per-process
        # MGPU context creation across all visible GPUs.
        kit_args_str = (
            "--/renderer/multiGpu/enabled=false "
            "--/renderer/multiGpu/autoEnable=false "
            "--/renderer/multiGpu/maxGpuCount=1"
        )
        app_launcher_flags = {
            "headless": self.headless,
            "enable_cameras": enable_cameras,
            "video": _record_video,
            "device": device_str,
            "kit_args": kit_args_str,
        }
        self._sim_app_launcher = AppLauncher(**app_launcher_flags)
        self._sim_app = self._sim_app_launcher.app
        logger.info(
            f"AppLauncher initialized with flags: {app_launcher_flags}"
        )

        env_class = get_class(self.env_config._target_)
        render_mode = "rgb_array" if _record_video else None
        self.env = env_class(
            config=self.env_config.config,
            device=device_str,
            headless=self.headless,
            log_dir=self.log_dir,
            accelerator=self.accelerator,
            render_mode=render_mode,
        )
        _ = self.env.reset_all()
        logger.info(f"Environment initialized with render_mode: {render_mode}")

    def _setup_configs(self) -> None:
        """Cache env/algorithm sizes and look up the first command term.

        For a ``ref_motion`` command term, also hands it this process's
        distributed rank/world size and the dumping directory.
        """
        self.num_envs: int = self.env.config.num_envs
        self.num_privileged_obs = 0
        self.num_actions = self.env.config.robot.actions_dim
        self.command_name = list(self.env.config.commands.keys())[0]
        self.command_term = self.env._env.command_manager.get_term(
            self.command_name
        )
        if self.command_name == "ref_motion":
            self.command_term.set_runtime_distributed_context(
                process_id=int(self.accelerator.process_index),
                num_processes=int(self.accelerator.num_processes),
            )
            self.command_term.setup_dumping_dir(self.log_dir)
        self.save_interval = self.config.save_interval
        self.log_interval = self.config.log_interval
        self.num_steps_per_env = self.config.num_steps_per_env
        self.num_learning_iterations = self.config.num_learning_iterations
        self.total_learning_iterations = int(self.num_learning_iterations)

    def _setup_seeding(self) -> None:
        """Seed Python, NumPy, torch and the env with a per-rank seed.

        For ``ref_motion`` the seed comes from the command term config and
        ``base_seed`` is derived from it; otherwise ``config.seed`` (or the
        current time) is the base and the process rank is added.
        """
        if self.command_name == "ref_motion":
            self.seed = int(self.command_term.cfg.seed)
            self.base_seed = int(self.seed - int(self.process_rank))
        else:
            self.base_seed = int(self.config.get("seed", int(time.time())))
            self.seed = int(self.base_seed + int(self.process_rank))
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
        self.env.seed(self.seed)
        if self.command_name == "ref_motion":
            self.command_term.set_motion_cache_seed(
                self.seed, reinitialize=False
            )

    def _setup_data_buffers(self) -> None:
        """Initialise counters, episode-stat buffers and rollout state."""
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0
        self.start_time = 0
        self.stop_time = 0
        self.collection_time = 0
        self.learn_time = 0
        self.ep_infos = []
        # Rolling windows of the last 100 finished episodes.
        self.rewbuffer = deque(maxlen=100)
        self.lenbuffer = deque(maxlen=100)
        self.cur_reward_sum = torch.zeros(
            self.env.num_envs,
            dtype=torch.float,
            device=self.device,
        )
        self.cur_episode_length = torch.zeros(
            self.env.num_envs,
            dtype=torch.float,
            device=self.device,
        )
        self.storage = None
        self.transition_td = None
        self._last_rollout_dones = None
        self._last_rollout_actions = None

    def _setup_algo_components(self) -> None:
        """Hook for algorithm-specific components (AMP, DAgger, PULSE)."""
        return

    def _setup_models_and_optimizer(self) -> None:
        """Create actor/critic/optimizer. Must be implemented by subclasses."""
        raise NotImplementedError(
            "Subclasses must implement _setup_models_and_optimizer."
        )

    def _build_storage(self, obs_td: TensorDict) -> Any:
        """Hook for custom RolloutStorage. Override for specialized storage; default no-op."""
        return None

    def _post_env_step_hook(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        time_outs: torch.Tensor,
        infos: Dict[str, Any],
    ) -> None:
        """Hook after each env step for auxiliary data collection.

        By default feeds the raw rewards to the ``ref_motion`` command term's
        curriculum reward accumulators.
        """
        if self.command_name != "ref_motion":
            return
        motion_term = self.env._env.command_manager.get_term("ref_motion")
        if motion_term is None:
            return
        motion_term.update_curriculum_reward_accumulators(rewards)

    def _post_update_hook(self, loss_dict: Dict[str, Any]) -> None:
        """Hook after each PPO update for auxiliary losses or logging."""
        return

    def _extra_checkpoint_state(self) -> Dict[str, Any]:
        """Additional state to save in checkpoints."""
        return {}

    def _load_extra_checkpoint_state(
        self, loaded_dict: Dict[str, Any]
    ) -> None:
        """Load additional checkpoint state if present."""
        return

    def _build_transition(
        self,
        obs_td: TensorDict,
        actor_out: TensorDict,
        critic_out: TensorDict,
    ):
        """Pack one step's policy outputs into a transition. Subclass hook."""
        raise NotImplementedError(
            "Subclasses must implement _build_transition."
        )

    def _post_iteration_hook(self, it: int) -> None:
        """Hook called at the end of every learning iteration."""
        return

    def _post_training_hook(self) -> None:
        """Hook called once after the training loop finishes."""
        return

    def _release_cuda_cache(self) -> None:
        """Return cached CUDA allocator memory when running on a GPU."""
        if torch.cuda.is_available() and self.device.type == "cuda":
            torch.cuda.empty_cache()

    def _get_additional_log_metrics(self) -> Dict[str, Any]:
        """Extra metrics merged into each iteration's log. Subclass hook."""
        return {}

    def train_mode(self) -> None:
        """Put actor and critic into training mode."""
        self.actor.train()
        self.critic.train()

    def _ensure_storage(self, obs_td: TensorDict) -> None:
        """Lazily build rollout storage from the first observation.

        Raises:
            RuntimeError: if neither ``_build_storage`` nor the subclass
                provided a storage object.
        """
        if self.storage is not None:
            return
        self.storage = self._build_storage(obs_td)
        if self.storage is None:
            raise RuntimeError(
                "Storage is not initialized. Override _build_storage() or initialize self.storage in subclass."
            )

    def _reset_rollout_forward_state(self) -> None:
        """Hook for algorithm-specific rollout state reset."""
        return

    def _rollout_forward(
        self,
        obs_td: TensorDict,
        *,
        actor_mode: str = "sampling",
        collect_transition: bool = True,
        track_episode_stats: bool = True,
    ) -> TensorDict:
        """Run the policy for one env step and return the next observation.

        Args:
            obs_td: Current observations.
            actor_mode: ``mode`` forwarded to the actor.
            collect_transition: Also evaluate the critic, build a transition
                and append it to storage via ``process_env_step``.
            track_episode_stats: Update per-episode return/length stats.

        Returns:
            The next observation wrapped as a TensorDict on ``self.device``.
        """
        # Normalizer statistics are frozen during offline evaluation.
        update_obs_norm = not self.is_offline_eval
        with self.accelerator.autocast():
            actor_out: TensorDict = self.actor(
                obs_td,
                actions=None,
                mode=actor_mode,
                update_obs_norm=update_obs_norm,
            )
            critic_out: TensorDict | None = None
            if collect_transition:
                critic_out = self.critic(
                    obs_td, update_obs_norm=update_obs_norm
                )
        if collect_transition:
            self.transition_td = self._build_transition(
                obs_td,
                actor_out,
                critic_out,
            )
        actions = actor_out.get("actions")
        self._last_rollout_actions = actions
        obs_dict, rewards, dones, time_outs, infos = self.env.step(actions)
        next_obs_td = self._wrap_obs_dict(obs_dict)
        dones = dones.to(self.device)
        self._last_rollout_dones = dones
        if collect_transition:
            rewards = rewards.to(self.device)
            time_outs = time_outs.to(self.device)
            self.process_env_step(rewards, dones, time_outs, infos)
        if track_episode_stats:
            rewards_for_stats = rewards.to(self.device)
            self._track_episode_stats(rewards_for_stats, dones, infos)
        return next_obs_td

    def _track_episode_stats(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        infos: Dict[str, Any],
    ) -> None:
        """Accumulate per-env return/length and record finished episodes.

        On the main process, numeric entries of ``infos["log"]`` are copied
        to CPU and appended to ``self.ep_infos``. Returns and lengths of
        envs whose ``dones > 0`` are pushed into the rolling buffers and then
        reset to zero.
        """
        log_info = infos.get("log")
        if self.is_main_process and isinstance(log_info, dict):
            cpu_log_info: Dict[str, torch.Tensor] = {}
            for key, value in log_info.items():
                cpu_value = self._log_value_to_cpu_tensor(value)
                if cpu_value is not None and cpu_value.numel() > 0:
                    cpu_log_info[key] = cpu_value
            if len(cpu_log_info) > 0:
                self.ep_infos.append(cpu_log_info)
        self.cur_reward_sum += rewards
        self.cur_episode_length += 1
        done_ids = (dones > 0).nonzero(as_tuple=False)
        self.rewbuffer.extend(
            self.cur_reward_sum[done_ids][:, 0].cpu().numpy().tolist()
        )
        self.lenbuffer.extend(
            self.cur_episode_length[done_ids][:, 0].cpu().numpy().tolist()
        )
        self.cur_reward_sum[done_ids] = 0
        self.cur_episode_length[done_ids] = 0

    def _compute_returns(self, obs_td: TensorDict) -> None:
        """Bootstrap from the critic's last values and compute returns.

        Advantages are not normalized here; when ``global_advantage_norm``
        is set they are normalized per command, across processes when
        distributed.
        """
        update_obs_norm = not self.is_offline_eval
        with self.accelerator.autocast():
            last_values = (
                self.critic(obs_td, update_obs_norm=update_obs_norm)
                .get("values")
                .detach()
            )
        self.storage.compute_returns(
            last_values,
            self.gamma,
            self.lam,
            normalize_advantage=False,
        )
        if getattr(self, "global_advantage_norm", False):
            accelerator = self.accelerator if self.is_distributed else None
            self.storage.normalize_advantages_global_by_command(
                command_name=self.command_name,
                accelerator=accelerator,
                eps=1.0e-8,
            )

    def rollout_policy(self, obs_td: TensorDict) -> TensorDict:
        """Collect one rollout with current policy and compute returns.

        Actor and critic run in eval mode without gradients; their previous
        training flags are restored afterwards, even if the rollout raises.

        Returns:
            The observation following the last collected step.
        """
        actor_was_training = self.actor.training
        critic_was_training = self.critic.training
        self.actor.eval()
        self.critic.eval()
        try:
            with torch.no_grad():
                self._reset_rollout_forward_state()
                for _ in range(self.num_steps_per_env):
                    obs_td = self._rollout_forward(obs_td)
                self._compute_returns(obs_td)
        finally:
            if actor_was_training:
                self.actor.train()
            if critic_was_training:
                self.critic.train()
        return obs_td

    def learn(self):
        """Main learning loop with runner logic shared across on-policy algorithms.

        Runs ``num_learning_iterations`` iterations starting at
        ``current_learning_iteration``. Logging and periodic checkpoints
        (``model_<it>.pt``) happen on the main process only; a final
        checkpoint is written after the loop.
        """
        obs_dict = self.env.reset_all()[0]
        obs_td = self._wrap_obs_dict(obs_dict)
        self._ensure_storage(obs_td)
        self.train_mode()
        start_it = self.current_learning_iteration
        total_it = start_it + int(self.num_learning_iterations)
        self.total_learning_iterations = total_it
        self.accelerator.wait_for_everyone()
        if self.is_main_process:
            logger.info(
                f"Starting training for {self.num_learning_iterations} iterations "
                f"from iteration {self.current_learning_iteration}"
            )
        for it in range(start_it, total_it):
            self.current_learning_iteration = it
            start = time.time()
            obs_td = self.rollout_policy(obs_td)
            stop = time.time()
            collection_time = stop - start
            start = stop
            loss_dict = self.update()
            stop = time.time()
            learn_time = stop - start
            if self.is_main_process and it % self.log_interval == 0:
                self._log_iteration(
                    it=it,
                    loss_dict=loss_dict,
                    collection_time=collection_time,
                    learn_time=learn_time,
                )
            if self.is_main_process and it % self.save_interval == 0:
                self.save(
                    os.path.join(
                        self.log_dir,
                        f"model_{self.current_learning_iteration}.pt",
                    )
                )
            self._release_cuda_cache()
            self._post_iteration_hook(it)
            self.ep_infos.clear()
        self.accelerator.wait_for_everyone()
        final_checkpoint_path = os.path.join(
            self.log_dir, f"model_{self.current_learning_iteration}.pt"
        )
        if self.is_main_process:
            self.save(final_checkpoint_path)
        self._release_cuda_cache()
        self._post_training_hook()
        if self.log_dir:
            self.accelerator.wait_for_everyone()
            self.accelerator.end_training()
        if self.is_main_process:
            logger.info(
                f"Training completed. Model saved to {self.log_dir}"
            )

    def process_env_step(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        time_outs: torch.Tensor,
        infos: Dict[str, Any],
    ) -> None:
        """Process env step results and append to storage.

        Args:
            rewards: Per-env rewards from ``env.step``; reshaped to [N, 1].
            dones: Per-env done flags from ``env.step``; reshaped to [N, 1].
            time_outs: [N] time-out flags (env step output).
            infos: Environment info dictionary.
        """
        raw_rewards = rewards.clone().view(-1, 1)
        rewards = raw_rewards.clone()
        dones = dones.view(-1, 1)
        # Bootstrapping on time outs: add gamma * V(s) for truncated episodes.
        rewards += self.gamma * (
            self.transition_td.values * time_outs[:, None]
        )
        self.transition_td.rewards = rewards
        self.transition_td.dones = dones.to(dtype=torch.bool)
        self.storage.add(self.transition_td)
        # Hooks see the un-bootstrapped rewards.
        self._post_env_step_hook(raw_rewards, dones, time_outs, infos)
        self.transition_td = None

    def _wrap_obs_dict(self, obs_dict: dict) -> TensorDict:
        """Wrap env obs dict into a native nested TensorDict on device."""
        return TensorDict.from_dict(
            obs_dict,
            batch_size=[self.env.num_envs],
            device=self.device,
        )

    @staticmethod
    def _clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Remove '_orig_mod.' prefix from torch.compile wrapped models.

        Args:
            state_dict: State dict that may contain '_orig_mod.' prefixed keys

        Returns:
            Cleaned state dict with prefixes removed
        """
        cleaned_dict = {}
        prefix = "_orig_mod."
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            new_k = k[prefix_len:] if k.startswith(prefix) else k
            cleaned_dict[new_k] = v
        return cleaned_dict

    def _load_model_state(self, model, state_dict, *, strict: bool = True):
        """Load a state dict into a (possibly compiled) model safely.

        - Always unwrap Accelerate wrappers first.
        - If the model is a compiled OptimizedModule (has ``_orig_mod``),
          load into the original module and strip any ``_orig_mod.`` prefixes
          from the incoming state dict for robustness.
        """
        target = self.accelerator.unwrap_model(model)
        cleaned = self._clean_state_dict(state_dict)
        if hasattr(target, "_orig_mod"):
            target._orig_mod.load_state_dict(cleaned, strict=strict)
        else:
            target.load_state_dict(cleaned, strict=strict)

    def _resolve_model_file_path(self, ckpt_path: str, model_name: str) -> str:
        """Resolve per-model Accelerate checkpoint directory from *.pt path.

        ``<dir>/model_10.pt`` with ``model_name="actor"`` resolves to
        ``<dir>/model_10/actor``. Only a trailing ``.pt`` suffix is removed,
        so other occurrences of ".pt" in the path (e.g. ``*.pth``) are kept.

        Raises:
            FileNotFoundError: if the resolved directory does not exist.
        """
        suffix = ".pt"
        base_path = (
            ckpt_path[: -len(suffix)] if ckpt_path.endswith(suffix) else ckpt_path
        )
        model_path = os.path.join(base_path, model_name)
        if not os.path.isdir(model_path):
            raise FileNotFoundError(
                f"Missing accelerate checkpoint directory for {model_name}: "
                f"{model_path}"
            )
        return model_path

    def _load_accelerate_model(
        self, model, model_path: str, *, strict: bool = True
    ) -> None:
        """Load model params from Accelerate checkpoint directory/file.

        A directory is searched for ``model.safetensors`` then
        ``pytorch_model.bin``; if neither exists the directory is handed to
        ``load_checkpoint_in_model`` as-is.
        """
        checkpoint_path = model_path
        if os.path.isdir(model_path):
            safetensors_path = os.path.join(model_path, "model.safetensors")
            pytorch_bin_path = os.path.join(model_path, "pytorch_model.bin")
            if os.path.isfile(safetensors_path):
                checkpoint_path = safetensors_path
            elif os.path.isfile(pytorch_bin_path):
                checkpoint_path = pytorch_bin_path
            else:
                target = self.accelerator.unwrap_model(model)
                load_checkpoint_in_model(target, model_path, strict=strict)
                return
        state_dict = load_state_dict(checkpoint_path)
        self._load_model_state(model, state_dict, strict=strict)

    def _aggregate_episode_log_metrics(
        self,
    ) -> Dict[str, float]:
        """Average each ``ep_infos`` key over all of its collected elements.

        Keys without a "/" are namespaced as ``Episode/<key>``.
        """
        metrics: Dict[str, float] = {}
        if len(self.ep_infos) == 0:
            return metrics
        metric_sums: Dict[str, float] = {}
        metric_counts: Dict[str, int] = {}
        for ep_info in self.ep_infos:
            for key, value in ep_info.items():
                cpu_value = self._log_value_to_cpu_tensor(value)
                if cpu_value is None or cpu_value.numel() == 0:
                    continue
                metric_sums[key] = metric_sums.get(key, 0.0) + float(
                    cpu_value.sum().item()
                )
                metric_counts[key] = metric_counts.get(key, 0) + int(
                    cpu_value.numel()
                )
        for key, total in metric_sums.items():
            count = metric_counts.get(key, 0)
            if count <= 0:
                continue
            mean_value = total / float(count)
            metric_key = key if "/" in key else f"Episode/{key}"
            metrics[metric_key] = mean_value
        return metrics

    @staticmethod
    def _log_value_to_cpu_tensor(value: Any) -> torch.Tensor | None:
        """Convert a loggable value to a flat float32 CPU tensor.

        Accepts tensors, NumPy arrays, NumPy integer/float/bool scalars and
        Python numbers; returns None for anything else.
        """
        if isinstance(value, torch.Tensor):
            tensor = value.detach()
            if tensor.ndim == 0:
                tensor = tensor.unsqueeze(0)
            return tensor.to(device="cpu", dtype=torch.float32).reshape(-1)
        if isinstance(value, np.ndarray):
            return torch.as_tensor(value, dtype=torch.float32).reshape(-1)
        if isinstance(value, (int, float)):
            return torch.tensor([float(value)], dtype=torch.float32)
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            # NumPy scalars such as np.float32 / np.int64 are neither
            # ndarrays nor Python numbers, so handle them explicitly.
            return torch.tensor([float(value)], dtype=torch.float32)
        return None

    def _log_iteration(
        self,
        *,
        it: int,
        loss_dict: Dict[str, Any],
        collection_time: float,
        learn_time: float,
        synced_mean_reward: float | None = None,
        synced_mean_episode_length: float | None = None,
    ) -> None:
        """Assemble and emit one iteration's metrics through ``algo_logger``.

        ``synced_mean_*`` (when both are given) take precedence over the
        local rolling reward/length buffers.
        """
        if not self.log_dir:
            return
        world_size = max(1, int(self.gpu_world_size))
        # Environment steps per second across all processes.
        fps = int(
            self.num_steps_per_env
            * self.num_envs
            * world_size
            / max(collection_time + learn_time, 1.0e-8)
        )
        total_learning_iterations = int(
            getattr(
                self,
                "total_learning_iterations",
                self.current_learning_iteration
                + int(self.num_learning_iterations),
            )
        )
        iteration_metrics: Dict[str, Any] = {
            "0-Train/iteration": int(it),
            "0-Train/iterations_total": total_learning_iterations,
        }
        for key, value in loss_dict.items():
            if value is None:
                continue
            scalar = float(value)
            iteration_metrics[f"Loss/{key}"] = scalar
        iteration_metrics.update(
            {
                "1-Perf/total_fps": float(fps),
                "1-Perf/collection_time": float(collection_time),
                "1-Perf/learning_time": float(learn_time),
            }
        )
        if (
            synced_mean_reward is not None
            and synced_mean_episode_length is not None
        ):
            iteration_metrics["0-Train/mean_reward"] = float(
                synced_mean_reward
            )
            iteration_metrics["0-Train/mean_episode_length"] = float(
                synced_mean_episode_length
            )
        elif len(self.rewbuffer) > 0:
            mean_reward = float(statistics.mean(self.rewbuffer))
            mean_episode_length = float(statistics.mean(self.lenbuffer))
            iteration_metrics["0-Train/mean_reward"] = mean_reward
            iteration_metrics["0-Train/mean_episode_length"] = (
                mean_episode_length
            )
        iteration_metrics.update(self._aggregate_episode_log_metrics())
        iteration_metrics.update(self._get_additional_log_metrics())
        self.algo_logger.log_iteration(
            step=it,
            total_learning_iterations=total_learning_iterations,
            metrics=iteration_metrics,
        )

    def load(self, ckpt_path):
        """Restore state from ``ckpt_path``. Must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement load().")

    def save(self, path, infos=None):
        """Write a checkpoint to ``path``. Must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement save().")
================================================
FILE: holomotion/src/algo/algo_utils.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from collections.abc import Mapping
from typing import Any, Generator
import torch
import torch.nn as nn
from loguru import logger
from tabulate import tabulate
from tensordict import TensorDict, tensorclass
class AlgoLogger:
def __init__(
self,
accelerator,
log_dir: str | None,
*,
is_main_process: bool,
) -> None:
self.accelerator = accelerator
self.log_dir = log_dir
self.is_main_process = bool(is_main_process)
@staticmethod
def _is_scalar_metric(value: Any) -> bool:
if isinstance(value, (int, float)):
return True
if isinstance(value, torch.Tensor):
return value.numel() == 1
return False
@staticmethod
def _to_scalar(value: Any) -> float:
if isinstance(value, torch.Tensor):
return float(value.item())
return float(value)
@staticmethod
def _format_console_value(value: Any) -> str:
if isinstance(value, (int, float)):
value_f = float(value)
abs_value = abs(value_f)
if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4):
return f"{value_f:.4e}"
return f"{value_f:.4f}"
if isinstance(value, torch.Tensor) and value.numel() == 1:
value_f = float(value.item())
abs_value = abs(value_f)
if abs_value > 0.0 and (abs_value < 1.0e-4 or abs_value >= 1.0e4):
return f"{value_f:.4e}"
return f"{value_f:.4f}"
return str(value)
def _build_console_log(
self,
*,
step: int,
total_learning_iterations: int | None,
console_metrics: Mapping[str, Any],
) -> str:
if total_learning_iterations is None:
title = f"TRAINING LOG - Iteration {step}"
else:
title = (
f"TRAINING LOG - Iteration {step}/{total_learning_iterations}"
)
table_data = [
[key, str(console_metrics[key])]
for key in sorted(console_metrics.keys())
]
log_lines = [
"\n" + "=" * 80,
title,
"=" * 80,
tabulate(
table_data,
headers=["Metric", "Value"],
tablefmt="simple_outline",
colalign=("left", "left"),
disable_numparse=True,
),
"=" * 80,
f"Logging Directory: {os.path.abspath(self.log_dir)}",
"=" * 80 + "\n",
]
return "\n".join(log_lines)
def log_iteration(
self,
*,
step: int,
metrics: Mapping[str, Any],
total_learning_iterations: int | None = None,
) -> None:
if not self.log_dir or not self.is_main_process:
return
tensorboard_metrics: dict[str, float] = {}
for key in sorted(metrics.keys()):
value = metrics[key]
if value is None or not self._is_scalar_metric(value):
continue
tensorboard_metrics[key] = self._to_scalar(value)
if len(tensorboard_metrics) > 0:
self.accelerator.log(tensorboard_metrics, step=int(step))
console_metrics = {
key: self._format_console_value(value)
for key, value in metrics.items()
if value is not None
}
console_log = self._build_console_log(
step=step,
total_learning_iterations=total_learning_iterations,
console_metrics=console_metrics,
)
logger.info(console_log)
@tensorclass(shadow=True)
class PpoTransition:
    """PPO rollout transition tensorclass (one step or one minibatch).

    Batch axes:
    - N: num_envs (per-step)
    - B: minibatch_size (for minibatches)

    Shapes (batch dims = [N] or [B]):
    - obs: TensorDict with leaf tensors [*, ...]
    - actions, teacher_actions, mu, sigma: [*, A]
    - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1]

    All float tensors are float32. `dones` is bool.

    ``FIELD_SPECS`` lists the same fields as the annotations below with
    their per-sample shape and dtype. NOTE(review): presumably consumed by
    the rollout storage to preallocate buffers; confirm against the storage
    implementation.
    """

    # Per-field spec: per-sample shape (batch dims excluded) and dtype.
    # "A" is a symbolic token for the action dimension; "kind": "obs" marks
    # the nested observation TensorDict, which has no shape/dtype entry.
    FIELD_SPECS = {
        "obs": {"kind": "obs"},
        "actions": {"shape": ("A",), "dtype": torch.float32},
        "teacher_actions": {"shape": ("A",), "dtype": torch.float32},
        "mu": {"shape": ("A",), "dtype": torch.float32},
        "sigma": {"shape": ("A",), "dtype": torch.float32},
        "actions_log_prob": {"shape": (1,), "dtype": torch.float32},
        "values": {"shape": (1,), "dtype": torch.float32},
        "rewards": {"shape": (1,), "dtype": torch.float32},
        "dones": {"shape": (1,), "dtype": torch.bool},
        "returns": {"shape": (1,), "dtype": torch.float32},
        "advantages": {"shape": (1,), "dtype": torch.float32},
    }

    # Tensorclass fields; names match the FIELD_SPECS keys above.
    obs: TensorDict
    actions: torch.Tensor
    teacher_actions: torch.Tensor
    mu: torch.Tensor  # action distribution mean (per docstring axes: [*, A])
    sigma: torch.Tensor  # action distribution std-dev, [*, A] — TODO confirm std vs. log-std
    actions_log_prob: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor
    dones: torch.Tensor
    returns: torch.Tensor
    advantages: torch.Tensor
@tensorclass(shadow=True)
class PpoVelocityTransition:
    """PPO rollout transition tensorclass with velocity-command targets.

    Same fields as the plain PPO transition plus ``velocity_commands``.

    Batch axes:
    - N: num_envs (per-step)
    - B: minibatch_size (for minibatches)

    Shapes (batch dims = [N] or [B]):
    - obs: TensorDict with leaf tensors [*, ...]
    - actions, teacher_actions, mu, sigma: [*, A]
    - actions_log_prob, values, rewards, returns, advantages, dones: [*, 1]
    - velocity_commands: [*, 4]

    All float tensors are float32. `dones` is bool.

    ``RolloutStorage.normalize_advantages_global_by_move_mask`` treats
    channel 0 of ``velocity_commands`` as a move/static mask (thresholded,
    default 0.5) when normalizing advantages per group.
    """
    # Allocation schema (per-sample trailing shape + dtype) for each field;
    # "A" expands to the action shape at allocation time.
    FIELD_SPECS = {
        "obs": {"kind": "obs"},
        "actions": {"shape": ("A",), "dtype": torch.float32},
        "teacher_actions": {"shape": ("A",), "dtype": torch.float32},
        "mu": {"shape": ("A",), "dtype": torch.float32},
        "sigma": {"shape": ("A",), "dtype": torch.float32},
        "actions_log_prob": {"shape": (1,), "dtype": torch.float32},
        "values": {"shape": (1,), "dtype": torch.float32},
        "rewards": {"shape": (1,), "dtype": torch.float32},
        "dones": {"shape": (1,), "dtype": torch.bool},
        "returns": {"shape": (1,), "dtype": torch.float32},
        "advantages": {"shape": (1,), "dtype": torch.float32},
        "velocity_commands": {"shape": (4,), "dtype": torch.float32},
    }
    # Annotated tensorclass data fields (names match the FIELD_SPECS keys).
    obs: TensorDict
    actions: torch.Tensor
    teacher_actions: torch.Tensor
    mu: torch.Tensor
    sigma: torch.Tensor
    actions_log_prob: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor
    dones: torch.Tensor
    returns: torch.Tensor
    advantages: torch.Tensor
    velocity_commands: torch.Tensor
@tensorclass(shadow=True)
class PpoAuxTransition:
    """PPO transition with auxiliary state-prediction supervision targets.

    Carries the same PPO fields as the plain transition (obs, actions,
    teacher_actions, mu, sigma, actions_log_prob, values, rewards, dones,
    returns, advantages) plus ground-truth targets prefixed ``gt_``.
    Per-sample shapes (after the batch dim):

    - gt_base_lin_vel_b: [3]
    - gt_root_height_rel_terrain: [1]
    - gt_keybody_contacts: [C]
    - gt_ref_keybody_rel_pos, gt_robot_keybody_rel_pos: [K, 3]
    - gt_denoise_ref_root_lin_vel, gt_denoise_ref_root_ang_vel: [3]
    - gt_denoise_ref_dof_pos: [A]

    All fields are declared float32 except `dones` (bool).
    """
    # Sizes substituted for the symbolic "C" and "K" dims of FIELD_SPECS
    # when RolloutStorage allocates buffers. Both default to 0 (zero-width
    # buffers). NOTE(review): presumably callers override these with the
    # real contact / key-body counts before allocating storage -- confirm.
    SHAPE_TOKENS = {"C": 0, "K": 0}
    # Allocation schema (per-sample trailing shape + dtype) for each field;
    # "A" expands to the action shape, "C"/"K" come from SHAPE_TOKENS.
    FIELD_SPECS = {
        "obs": {"kind": "obs"},
        "actions": {"shape": ("A",), "dtype": torch.float32},
        "teacher_actions": {"shape": ("A",), "dtype": torch.float32},
        "mu": {"shape": ("A",), "dtype": torch.float32},
        "sigma": {"shape": ("A",), "dtype": torch.float32},
        "actions_log_prob": {"shape": (1,), "dtype": torch.float32},
        "values": {"shape": (1,), "dtype": torch.float32},
        "rewards": {"shape": (1,), "dtype": torch.float32},
        "dones": {"shape": (1,), "dtype": torch.bool},
        "returns": {"shape": (1,), "dtype": torch.float32},
        "advantages": {"shape": (1,), "dtype": torch.float32},
        "gt_base_lin_vel_b": {"shape": (3,), "dtype": torch.float32},
        "gt_root_height_rel_terrain": {"shape": (1,), "dtype": torch.float32},
        "gt_keybody_contacts": {"shape": ("C",), "dtype": torch.float32},
        "gt_ref_keybody_rel_pos": {
            "shape": ("K", 3),
            "dtype": torch.float32,
        },
        "gt_robot_keybody_rel_pos": {
            "shape": ("K", 3),
            "dtype": torch.float32,
        },
        "gt_denoise_ref_root_lin_vel": {
            "shape": (3,),
            "dtype": torch.float32,
        },
        "gt_denoise_ref_root_ang_vel": {
            "shape": (3,),
            "dtype": torch.float32,
        },
        "gt_denoise_ref_dof_pos": {
            "shape": ("A",),
            "dtype": torch.float32,
        },
    }
    # Annotated tensorclass data fields (names match the FIELD_SPECS keys).
    obs: TensorDict
    actions: torch.Tensor
    teacher_actions: torch.Tensor
    mu: torch.Tensor
    sigma: torch.Tensor
    actions_log_prob: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor
    dones: torch.Tensor
    returns: torch.Tensor
    advantages: torch.Tensor
    gt_base_lin_vel_b: torch.Tensor
    gt_root_height_rel_terrain: torch.Tensor
    gt_keybody_contacts: torch.Tensor
    gt_ref_keybody_rel_pos: torch.Tensor
    gt_robot_keybody_rel_pos: torch.Tensor
    gt_denoise_ref_root_lin_vel: torch.Tensor
    gt_denoise_ref_root_ang_vel: torch.Tensor
    gt_denoise_ref_dof_pos: torch.Tensor
class RolloutStorage(nn.Module):
    """Rollout storage as a single TensorDict buffer with batch size [T, N].

    ``T`` is ``num_transitions_per_env`` (time) and ``N`` is ``num_envs``.
    Every field listed in the transition class' ``FIELD_SPECS`` is
    pre-allocated once; ``add`` writes one time step per call and
    ``iter_minibatches`` reads the flattened ``[T * N]`` buffer back as
    transition objects.
    """

    def __init__(
        self,
        num_envs,
        num_transitions_per_env,
        obs_template: TensorDict,
        actions_shape,
        device="cpu",
        command_name: str | None = None,
        transition_cls: type[PpoTransition] = PpoTransition,
    ):
        """Allocate the rollout buffer.

        Args:
            num_envs: Number of parallel environments ``N``.
            num_transitions_per_env: Rollout length ``T``.
            obs_template: TensorDict used only for its leaf keys, their
                per-sample shape (``value.shape[1:]``) and dtype.
            actions_shape: Iterable of ints substituted for the ``"A"``
                token in ``FIELD_SPECS`` shapes.
            device: Device on which every buffer is allocated.
            command_name: Stored as ``self.command_name``; the methods of
                this class do not read it.
            transition_cls: Tensorclass whose ``FIELD_SPECS`` (and optional
                ``SHAPE_TOKENS``) drive allocation, and which ``add`` /
                ``iter_minibatches`` accept / produce.
        """
        super().__init__()
        self.device = device
        self.num_transitions_per_env = num_transitions_per_env
        self.num_envs = num_envs
        self.command_name = command_name
        # Floating-point data is always stored as float32; ``dones`` as bool.
        self._float_dtype = torch.float32
        self._dones_dtype = torch.bool
        self._transition_cls = transition_cls
        obs_template = obs_template.to(self.device)
        self.data = TensorDict(
            {},
            batch_size=[num_transitions_per_env, num_envs],
            device=self.device,
        )
        self._allocate_from_transition(
            obs_template=obs_template,
            actions_shape=actions_shape,
        )
        # Index of the next time step that ``add`` will write.
        self.step = 0

    def _resolve_shape(self, spec_shape, actions_shape) -> tuple:
        """Turn a symbolic FIELD_SPECS shape into a tuple of ints.

        ``"A"`` expands to all dims of ``actions_shape``; other string
        tokens are looked up in the transition class' ``SHAPE_TOKENS``;
        everything else is converted with ``int``. ``None`` yields ``()``.
        """
        if spec_shape is None:
            return tuple()
        resolved = []
        shape_tokens = getattr(self._transition_cls, "SHAPE_TOKENS", {})
        for dim in spec_shape:
            if dim == "A":
                resolved.extend(tuple(actions_shape))
            elif isinstance(dim, str) and dim in shape_tokens:
                resolved.append(int(shape_tokens[dim]))
            else:
                resolved.append(int(dim))
        return tuple(resolved)

    def _allocate_from_transition(
        self,
        *,
        obs_template: TensorDict,
        actions_shape,
    ) -> None:
        """Pre-allocate one ``[T, N, ...]`` buffer per FIELD_SPECS entry.

        The ``{"kind": "obs"}`` entry creates one buffer per tensor leaf of
        ``obs_template`` under the ``"obs"`` prefix (floating leaves are
        stored as float32, other dtypes are kept). Other entries use their
        resolved ``shape`` and ``dtype`` (float32 when unspecified).

        Raises:
            ValueError: if the transition class has no dict FIELD_SPECS.
        """
        specs = getattr(self._transition_cls, "FIELD_SPECS", None)
        if not isinstance(specs, dict):
            raise ValueError(
                "Transition class must define FIELD_SPECS for allocation."
            )
        for name, spec in specs.items():
            if spec.get("kind") == "obs":
                leaf_keys = obs_template.keys(
                    include_nested=True, leaves_only=True
                )
                for key in leaf_keys:
                    value = obs_template.get(key)
                    if not torch.is_tensor(value):
                        continue
                    dtype = (
                        self._float_dtype
                        if torch.is_floating_point(value)
                        else value.dtype
                    )
                    key_tuple = key if isinstance(key, tuple) else (key,)
                    self.data.set(
                        ("obs",) + key_tuple,
                        torch.empty(
                            (self.num_transitions_per_env, self.num_envs)
                            + tuple(value.shape[1:]),
                            device=self.device,
                            dtype=dtype,
                        ),
                    )
                continue
            shape_spec = spec.get("shape")
            dtype = spec.get("dtype", self._float_dtype)
            resolved = self._resolve_shape(shape_spec, actions_shape)
            self.data.set(
                name,
                torch.empty(
                    (self.num_transitions_per_env, self.num_envs) + resolved,
                    device=self.device,
                    dtype=dtype,
                ),
            )

    def _to_storage_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Move ``tensor`` to the storage device; cast floats to float32.

        Raises:
            TypeError: if ``tensor`` is not a tensor.
        """
        if not torch.is_tensor(tensor):
            raise TypeError("Expected a tensor for RolloutStorage update.")
        if tensor.device != self.device:
            tensor = tensor.to(self.device)
        if (
            torch.is_floating_point(tensor)
            and tensor.dtype != self._float_dtype
        ):
            tensor = tensor.to(dtype=self._float_dtype)
        return tensor

    def add(self, transition: PpoTransition) -> None:
        """Write one ``[N]``-batched transition into time step ``self.step``.

        Raises:
            OverflowError: if all ``T`` steps are already filled.
            TypeError: if ``transition`` is not of the configured class.
            ValueError: if the transition batch size is not ``[num_envs]``.
        """
        if self.step >= self.num_transitions_per_env:
            raise OverflowError("Rollout buffer overflow!")
        if not isinstance(transition, self._transition_cls):
            raise TypeError(
                "Transition must match the RolloutStorage transition class."
            )
        if transition.batch_size is None or len(transition.batch_size) < 1:
            raise ValueError("Transition must have batch size [N].")
        if int(transition.batch_size[0]) != int(self.num_envs):
            raise ValueError(
                f"Transition batch size {transition.batch_size} "
                f"does not match num_envs={self.num_envs}."
            )
        td = transition.to_tensordict()
        td = td.apply(self._to_storage_tensor, inplace=False)
        # ``_to_storage_tensor`` leaves non-float dtypes alone, so coerce
        # the episode-termination flags to the bool storage dtype here.
        if "dones" in td.keys():
            dones = td.get("dones")
            if torch.is_tensor(dones) and dones.dtype != self._dones_dtype:
                td.set("dones", dones.to(dtype=self._dones_dtype))
        # In-place write into the pre-allocated row for this time step.
        self.data[self.step].update_(td)
        self.step += 1

    def clear(self) -> None:
        """Reset the write cursor; buffers are reused, not zeroed."""
        self.step = 0

    def compute_returns(
        self,
        last_values,
        gamma,
        lam,
        normalize_advantage: bool = False,
    ) -> None:
        """Fill ``returns`` and ``advantages`` with GAE(gamma, lam).

        For t = T-1 .. 0, with ``m_t = 1 - dones[t]``:
            delta_t = r_t + m_t * gamma * V_{t+1} - V_t
            A_t     = delta_t + m_t * gamma * lam * A_{t+1}
            R_t     = A_t + V_t
        where ``V_T = last_values``. NOTE(review): ``last_values`` is
        presumably the critic value of the post-rollout observation with
        shape ``[N, 1]``; confirm against the caller.

        Args:
            normalize_advantage: If True, standardize advantages over the
                whole local buffer (Bessel-corrected std, floored at 1e-8).
        """
        advantage = 0
        rewards = self.data["rewards"]
        values = self.data["values"]
        dones = self.data["dones"]
        returns = self.data["returns"]
        advantages = self.data["advantages"]
        for step in reversed(range(self.num_transitions_per_env)):
            if step == self.num_transitions_per_env - 1:
                next_values = last_values
            else:
                next_values = values[step + 1]
            next_is_not_terminal = 1.0 - dones[step].float()
            delta = (
                rewards[step]
                + next_is_not_terminal * gamma * next_values
                - values[step]
            )
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            returns[step] = advantage + values[step]
        advantages.copy_(returns - values)
        if normalize_advantage:
            flat = advantages.view(-1)
            mean = flat.mean()
            std = flat.std().clamp_min(1.0e-8)
            advantages.copy_((advantages - mean) / std)

    @torch.no_grad()
    def normalize_advantages_global(
        self,
        *,
        accelerator=None,
        eps: float = 1.0e-8,
    ) -> None:
        """Global advantage normalization over the full rollout buffer.

        This normalizes `self.data["advantages"]` in-place using mean/std over
        all `[T * N]` samples. If `accelerator` is provided, the moments are
        aggregated across processes via `accelerator.reduce(..., reduction="sum")`.
        The variance (population form) is floored at ``eps`` before the
        square root.
        """
        advantages = self.data["advantages"]
        advantages_flat = advantages.view(-1).float()
        count = torch.tensor(
            [advantages_flat.numel()], device=self.device, dtype=torch.float32
        )
        sum_local = advantages_flat.sum()
        sqsum_local = (advantages_flat * advantages_flat).sum()
        if accelerator is not None and int(accelerator.num_processes) > 1:
            count_g = accelerator.reduce(count, reduction="sum")
            sum_g = accelerator.reduce(sum_local, reduction="sum")
            sqsum_g = accelerator.reduce(sqsum_local, reduction="sum")
        else:
            count_g = count
            sum_g = sum_local
            sqsum_g = sqsum_local
        mean = sum_g / count_g
        var = (sqsum_g / count_g) - mean * mean
        std = torch.sqrt(var.clamp_min(eps))
        advantages.copy_((advantages - mean) / std)

    @torch.no_grad()
    def normalize_advantages_global_by_move_mask(
        self,
        *,
        accelerator=None,
        eps: float = 1.0e-8,
        move_threshold: float = 0.5,
    ) -> None:
        """Global advantage normalization split by move vs static (velocity commands).

        Samples with ``velocity_commands[..., 0] > move_threshold`` form the
        "move" group and the rest the "static" group; each group is
        standardized with its own global mean/std. A group that is empty on
        every process falls back to the whole-buffer statistics.

        Assumes:
        - `advantages`: [T, N, 1]
        - `velocity_commands`: [T, N, 4], where channel 0 is move_mask in {0,1}.

        In a multi-process run every rank issues the same fixed sequence of
        ``accelerator.reduce`` calls (3 whole-buffer + 3 per group), even if
        it holds no local samples of a group, so the collectives stay
        matched across ranks.

        Raises:
            ValueError: if ``velocity_commands`` is not stored.
        """
        velocity_commands = self.data.get("velocity_commands", None)
        if velocity_commands is None:
            raise ValueError(
                "velocity_commands is required for global advantage normalization by move mask."
            )
        advantages = self.data["advantages"]
        advantages_flat = advantages.view(-1).float()
        vel_flat = velocity_commands.view(-1, int(velocity_commands.shape[-1]))
        move_mask = vel_flat[:, 0] > float(move_threshold)
        static_mask = ~move_mask
        distributed = (
            accelerator is not None and int(accelerator.num_processes) > 1
        )
        count_all = torch.tensor(
            [advantages_flat.numel()], device=self.device, dtype=torch.float32
        )
        sum_all_local = advantages_flat.sum()
        sqsum_all_local = (advantages_flat * advantages_flat).sum()
        if distributed:
            count_all_g = accelerator.reduce(count_all, reduction="sum")
            sum_all_g = accelerator.reduce(sum_all_local, reduction="sum")
            sqsum_all_g = accelerator.reduce(sqsum_all_local, reduction="sum")
        else:
            count_all_g = count_all
            sum_all_g = sum_all_local
            sqsum_all_g = sqsum_all_local
        mean_all = sum_all_g / count_all_g
        var_all = (sqsum_all_g / count_all_g) - mean_all * mean_all
        std_all = torch.sqrt(var_all.clamp_min(eps))

        def _group_stats(
            mask: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No early exit on an empty *local* mask: skipping the reduces
            # below on just this rank would leave the other ranks blocked
            # in (or mismatched with) their collectives. An empty local
            # mask simply contributes zeros to the global sums.
            mask_f = mask.to(dtype=torch.float32)
            count_local = mask_f.sum()
            sum_local = (advantages_flat * mask_f).sum()
            sqsum_local = (advantages_flat * advantages_flat * mask_f).sum()
            if distributed:
                count_g = accelerator.reduce(count_local, reduction="sum")
                sum_g = accelerator.reduce(sum_local, reduction="sum")
                sqsum_g = accelerator.reduce(sqsum_local, reduction="sum")
            else:
                count_g = count_local
                sum_g = sum_local
                sqsum_g = sqsum_local
            if float(count_g.item()) <= 0.0:
                # Group empty on every process: use whole-buffer statistics.
                return mean_all, std_all
            mean = sum_g / count_g
            var = (sqsum_g / count_g) - mean * mean
            std = torch.sqrt(var.clamp_min(eps))
            return mean, std

        move_mean, move_std = _group_stats(move_mask)
        static_mean, static_std = _group_stats(static_mask)
        advantages_norm = advantages_flat.clone()
        if bool(move_mask.any().item()):
            advantages_norm[move_mask] = (
                advantages_flat[move_mask] - move_mean
            ) / move_std
        if bool(static_mask.any().item()):
            advantages_norm[static_mask] = (
                advantages_flat[static_mask] - static_mean
            ) / static_std
        self.data["advantages"].copy_(advantages_norm.view_as(advantages))

    @torch.no_grad()
    def normalize_advantages_global_by_command(
        self,
        *,
        command_name: str | None,
        accelerator=None,
        eps: float = 1.0e-8,
        move_threshold: float = 0.5,
    ) -> None:
        """Dispatch global advantage normalization based on command type/name.

        ``"base_velocity"`` with stored ``velocity_commands`` uses the
        move/static split (``move_threshold`` is forwarded); everything else
        uses plain whole-buffer normalization.
        """
        if (
            command_name == "base_velocity"
            and self.data.get("velocity_commands", None) is not None
        ):
            self.normalize_advantages_global_by_move_mask(
                accelerator=accelerator,
                eps=eps,
                move_threshold=move_threshold,
            )
            return
        self.normalize_advantages_global(accelerator=accelerator, eps=eps)

    def iter_minibatches(
        self,
        num_mini_batches: int,
        num_epochs: int,
    ) -> Generator[PpoTransition, None, None]:
        """Yield shuffled minibatches of the flattened ``[T * N]`` buffer.

        A single random permutation is drawn per call and reused for every
        epoch. When ``T * N`` is not divisible by the effective minibatch
        count, the leftover samples at the end of the permutation are
        dropped.

        Raises:
            RuntimeError: if the buffer has not been completely filled.
        """
        if self.step != self.num_transitions_per_env:
            raise RuntimeError(
                f"RolloutStorage buffer not full: step={self.step}, "
                f"expected={self.num_transitions_per_env}. "
                "This would mix stale entries from a previous rollout."
            )
        batch_size = self.num_envs * self.num_transitions_per_env
        (
            effective_num_mini_batches,
            mini_batch_size,
        ) = self.resolve_mini_batch_partition(batch_size, num_mini_batches)
        indices = torch.randperm(
            batch_size,
            requires_grad=False,
            device=self.device,
        )[: effective_num_mini_batches * mini_batch_size]
        flat = self.data.flatten(0, 1)  # [T * N, ...]
        for _ in range(num_epochs):
            for i in range(effective_num_mini_batches):
                start = i * mini_batch_size
                end = (i + 1) * mini_batch_size
                batch_idx = indices[start:end]
                batch = flat[batch_idx]
                yield self._transition_cls.from_tensordict(batch)

    @staticmethod
    def resolve_mini_batch_partition(
        batch_size: int,
        num_mini_batches: int,
    ) -> tuple[int, int]:
        """Return ``(effective_num_mini_batches, mini_batch_size)``.

        The minibatch count is clamped to ``[1, batch_size]`` so that each
        minibatch holds at least one sample.

        Raises:
            RuntimeError: if ``batch_size <= 0``.
        """
        if batch_size <= 0:
            raise RuntimeError(
                "RolloutStorage minibatch partition requires batch_size > 0."
            )
        effective_num_mini_batches = max(
            1, min(int(num_mini_batches), int(batch_size))
        )
        mini_batch_size = max(1, batch_size // effective_num_mini_batches)
        return effective_num_mini_batches, mini_batch_size
================================================
FILE: holomotion/src/algo/ppo.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import inspect
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from loguru import logger
from tabulate import tabulate
from tqdm import tqdm
import imageio
from omegaconf import OmegaConf
from holomotion.src.algo.algo_base import BaseOnpolicyRL
from holomotion.src.algo.algo_utils import (
PpoTransition,
PpoVelocityTransition,
RolloutStorage,
)
from holomotion.src.utils.onnx_export import (
export_policy_to_onnx as export_policy_to_onnx_common,
)
from tensordict import TensorDict
def _checkpoint_state_to_cpu(value):
if isinstance(value, torch.Tensor):
return value.detach().cpu()
if isinstance(value, dict):
return {k: _checkpoint_state_to_cpu(v) for k, v in value.items()}
if isinstance(value, list):
return [_checkpoint_state_to_cpu(v) for v in value]
if isinstance(value, tuple):
return tuple(_checkpoint_state_to_cpu(v) for v in value)
return value
class PPO(BaseOnpolicyRL):
def _setup_configs(self):
    """Read and validate PPO hyper-parameters from ``self.config``.

    Runs the base-class setup first, then stores optimizer, minibatching,
    entropy-schedule, adaptive-LR, KL-early-stop, advantage-normalization,
    distributed LR-scaling, motion-sampling and symmetry-loss settings as
    attributes on ``self``. Statement order matters: later steps read
    attributes set by earlier ones.

    Raises:
        ValueError: when a validated setting is out of range or an
            unsupported mode / signal / sampling strategy is requested.
    """
    super()._setup_configs()
    # May be None; adaptive LR and KL early stop both check for that.
    self.desired_kl = self.config.desired_kl
    self.schedule = self.config.schedule
    # Per-network learning rates fall back to the shared ``learning_rate``
    # (default 3e-4).
    self.actor_learning_rate = self.config.get(
        "actor_learning_rate", self.config.get("learning_rate", 3e-4)
    )
    self.critic_learning_rate = self.config.get(
        "critic_learning_rate", self.config.get("learning_rate", 3e-4)
    )
    # Unscaled learning rates; the distributed LR scale factor is applied
    # to these further below.
    self.base_actor_learning_rate = float(self.actor_learning_rate)
    self.base_critic_learning_rate = float(self.critic_learning_rate)
    self.actor_beta1 = self.config.get("actor_beta1", 0.9)
    self.actor_beta2 = self.config.get("actor_beta2", 0.999)
    self.critic_beta1 = self.config.get("critic_beta1", 0.9)
    self.critic_beta2 = self.config.get("critic_beta2", 0.999)
    self.optimizer_type = self.config.optimizer_type
    self.clip_param = self.config.clip_param
    self.num_learning_epochs = int(self.config.num_learning_epochs)
    self.configured_num_mini_batches = int(self.config.num_mini_batches)
    if self.configured_num_mini_batches < 1:
        raise ValueError("num_mini_batches must be >= 1.")
    distributed_update_cfg = self.config.get("distributed_update", {})
    # "legacy" multiplies the minibatch count by the GPU world size when
    # distributed (see _resolve_num_mini_batches); "scalable" does not.
    self.distributed_update_mode = str(
        distributed_update_cfg.get("mode", "legacy")
    ).lower()
    if self.distributed_update_mode not in {"legacy", "scalable"}:
        raise ValueError(
            "distributed_update.mode must be one of "
            "{'legacy', 'scalable'}."
        )
    # Must follow the assignment of distributed_update_mode, which
    # _resolve_num_mini_batches reads.
    self.requested_num_mini_batches = self._resolve_num_mini_batches(
        self.configured_num_mini_batches
    )
    self.num_mini_batches = self.requested_num_mini_batches
    self.gamma = self.config.gamma
    self.lam = self.config.lam
    self.value_loss_coef = self.config.value_loss_coef
    # Entropy bonus: optionally annealed linearly to zero at the fraction
    # ``zero_entropy_point`` of training (see _get_effective_entropy_coef).
    self.initial_entropy_coef = float(self.config.entropy_coef)
    self.anneal_entropy = bool(self.config.get("anneal_entropy", False))
    self.zero_entropy_point = float(
        self.config.get("zero_entropy_point", 1.0)
    )
    self._validate_entropy_schedule_config(
        initial_entropy_coef=self.initial_entropy_coef,
        anneal_entropy=self.anneal_entropy,
        zero_entropy_point=self.zero_entropy_point,
    )
    self.entropy_coef = self.initial_entropy_coef
    self.max_grad_norm = self.config.max_grad_norm
    self.use_clipped_value_loss = self.config.use_clipped_value_loss
    # Adaptive LR: the LR is divided / multiplied by ``lr_scaler`` when the
    # KL signal is above desired_kl * kl_high_factor / below
    # desired_kl * kl_low_factor, clamped to [min_lr, max_lr]
    # (see _apply_adaptive_lr).
    adaptive_lr_cfg = self.config.get("adaptive_lr", {})
    self.adaptive_lr_adapt_critic = bool(
        adaptive_lr_cfg.get("adapt_critic", False)
    )
    self.adaptive_lr_factor = float(adaptive_lr_cfg.get("lr_scaler", 1.2))
    self.adaptive_lr_kl_high_factor = float(
        adaptive_lr_cfg.get("kl_high_factor", 2.0)
    )
    self.adaptive_lr_kl_low_factor = float(
        adaptive_lr_cfg.get("kl_low_factor", 0.5)
    )
    self.adaptive_lr_min = float(
        adaptive_lr_cfg.get("min_learning_rate", 1.0e-6)
    )
    self.adaptive_lr_max = float(
        adaptive_lr_cfg.get("max_learning_rate", 1.0)
    )
    if self.adaptive_lr_factor <= 1.0:
        raise ValueError("adaptive_lr.lr_scaler must be > 1.")
    if self.adaptive_lr_kl_high_factor <= 0.0:
        raise ValueError("adaptive_lr.kl_high_factor must be > 0.")
    if self.adaptive_lr_kl_low_factor <= 0.0:
        raise ValueError("adaptive_lr.kl_low_factor must be > 0.")
    if self.adaptive_lr_min <= 0.0:
        raise ValueError("adaptive_lr.min_learning_rate must be > 0.")
    if self.adaptive_lr_max < self.adaptive_lr_min:
        raise ValueError(
            "adaptive_lr.max_learning_rate must be >= "
            "adaptive_lr.min_learning_rate."
        )
    # KL-based early stop of the update loop; only the windowed-mean
    # signal is supported and it needs a desired_kl target.
    kl_early_stop_cfg = distributed_update_cfg.get("kl_early_stop", {})
    self.kl_early_stop_enabled = bool(
        kl_early_stop_cfg.get("enabled", False)
    )
    kl_signal_mode = str(
        kl_early_stop_cfg.get("signal", "window_mean")
    ).lower()
    if kl_signal_mode != "window_mean":
        raise ValueError(
            "Only distributed_update.kl_early_stop.signal='window_mean' "
            "is supported."
        )
    self.kl_early_stop_window_size = int(
        kl_early_stop_cfg.get("window_size", 3)
    )
    self.kl_early_stop_factor = float(kl_early_stop_cfg.get("factor", 2.0))
    self.kl_early_stop_min_updates = int(
        kl_early_stop_cfg.get("min_updates", 1)
    )
    if self.kl_early_stop_window_size < 1:
        raise ValueError(
            "distributed_update.kl_early_stop.window_size must be >= 1."
        )
    if self.kl_early_stop_factor <= 0.0:
        raise ValueError(
            "distributed_update.kl_early_stop.factor must be > 0."
        )
    if self.kl_early_stop_min_updates < 1:
        raise ValueError(
            "distributed_update.kl_early_stop.min_updates must be >= 1."
        )
    if self.kl_early_stop_enabled and self.desired_kl is None:
        raise ValueError(
            "distributed_update.kl_early_stop requires desired_kl to be set."
        )
    self.global_advantage_norm = bool(
        self.config.get("global_advantage_norm", True)
    )
    self.normalize_advantage_per_mini_batch = bool(
        self.config.get("normalize_advantage_per_mini_batch", False)
    )
    # Scale the base learning rates for the runtime world size; this
    # overrides the actor/critic learning rates read above.
    self.distributed_lr_scale_factor = self._compute_lr_scale_factor(
        distributed_update_cfg.get("lr_scale", {})
    )
    self.actor_learning_rate = (
        self.base_actor_learning_rate * self.distributed_lr_scale_factor
    )
    self.critic_learning_rate = (
        self.base_critic_learning_rate * self.distributed_lr_scale_factor
    )
    # Initial values for the update diagnostics. The effective minibatch
    # count starts equal to the requested count. NOTE(review): presumably
    # refreshed by the update step, which is not visible here.
    self._last_update_metrics = {
        "0-Train/configured_num_mini_batches": float(
            self.configured_num_mini_batches
        ),
        "0-Train/requested_num_mini_batches": float(
            self.requested_num_mini_batches
        ),
        "0-Train/effective_num_mini_batches": float(
            self.requested_num_mini_batches
        ),
        "0-Train/mini_batch_size_per_rank": 0.0,
        "0-Train/num_updates_executed": 0.0,
        "0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor),
        "0-Train/scalable_distributed_update": float(
            self.distributed_update_mode == "scalable"
        ),
        "0-Train/kl_windowed": 0.0,
        "0-Train/kl_stop_triggered": 0.0,
        "0-Train/kl_stop_analytic": 0.0,
    }
    self._offline_evaluating: bool = False
    # Motion sampling strategy (default "uniform"), validated against the
    # supported set.
    motion_cfg = self.env_config.config.robot.motion
    sampling_strategy_cfg = motion_cfg.get("sampling_strategy", None)
    if sampling_strategy_cfg is None:
        sampling_strategy = "uniform"
    else:
        sampling_strategy = str(sampling_strategy_cfg).lower()
    valid_strategies = {"uniform", "weighted_bin", "curriculum"}
    if sampling_strategy not in valid_strategies:
        raise ValueError(
            f"Invalid sampling_strategy '{sampling_strategy}'. "
            f"Expected one of {sorted(valid_strategies)}."
        )
    self.sampling_strategy: str = sampling_strategy
    self.weighted_bin_cfg = dict(motion_cfg.get("weighted_bin", {}))
    # Symmetry loss: mirror permutation/signs and observation mirror map
    # are built only if _symmetry_loss_active() holds, which also requires
    # self.command_name == "base_velocity". NOTE(review): command_name is
    # not set in this method -- presumably by the base class; confirm.
    sym_cfg = self.config.get("symmetry_loss", {})
    self.symmetry_loss_enabled = bool(sym_cfg.get("enabled", False))
    self.symmetry_loss_coef = float(sym_cfg.get("coef", 0.0))
    self._sym_dof_perm: torch.Tensor | None = None
    self._sym_dof_sign: torch.Tensor | None = None
    self._obs_mirror_map: dict[str, callable] = {}
    if self._symmetry_loss_active():
        self._setup_symmetry()
def _resolve_num_mini_batches(self, base_num_mini_batches: int) -> int:
if self.distributed_update_mode == "legacy" and self.is_distributed:
return max(1, base_num_mini_batches * int(self.gpu_world_size))
return max(1, base_num_mini_batches)
def _compute_lr_scale_factor(self, lr_scale_cfg) -> float:
scale_mode = str(lr_scale_cfg.get("mode", "none")).lower()
if scale_mode not in {
"none",
"sqrt_world_size",
"linear_world_size",
}:
raise ValueError(
"distributed_update.lr_scale.mode must be one of "
"{'none', 'sqrt_world_size', 'linear_world_size'}."
)
reference_world_size = float(
lr_scale_cfg.get("reference_world_size", 1)
)
if reference_world_size <= 0.0:
raise ValueError(
"distributed_update.lr_scale.reference_world_size must be > 0."
)
runtime_world_size = float(
self.gpu_world_size if self.is_distributed else 1
)
world_ratio = runtime_world_size / reference_world_size
if scale_mode == "none":
scale = 1.0
elif scale_mode == "sqrt_world_size":
scale = math.sqrt(world_ratio)
else:
scale = world_ratio
max_scale = lr_scale_cfg.get("max_scale", None)
if max_scale is not None:
max_scale = float(max_scale)
if max_scale <= 0.0:
raise ValueError(
"distributed_update.lr_scale.max_scale must be > 0 when set."
)
scale = min(scale, max_scale)
return float(scale)
def _symmetry_loss_active(self) -> bool:
return bool(
getattr(self, "command_name", None) == "base_velocity"
and getattr(self, "symmetry_loss_enabled", False)
and float(getattr(self, "symmetry_loss_coef", 0.0)) > 0.0
)
@staticmethod
def _omega_or_obj_to_dict(value):
if value is None:
return {}
if OmegaConf.is_config(value):
return OmegaConf.to_container(value, resolve=True)
if isinstance(value, dict):
return value
if hasattr(value, "__dict__"):
return vars(value)
return {}
def _setup_symmetry(self) -> None:
robot_asset = self.env._env.scene["robot"]
joint_names = list(getattr(robot_asset, "joint_names", []))
if len(joint_names) != int(self.num_actions):
raise ValueError(
"symmetry_loss requires simulator joint_names to match "
f"num_actions, got {len(joint_names)} vs {self.num_actions}."
)
name_to_idx = {name: idx for idx, name in enumerate(joint_names)}
perm: list[int] = []
for name in joint_names:
if name.startswith("left_"):
mirror_name = "right_" + name[len("left_") :]
elif name.startswith("right_"):
mirror_name = "left_" + name[len("right_") :]
else:
mirror_name = name
perm.append(int(name_to_idx.get(mirror_name, name_to_idx[name])))
sym_cfg = self._omega_or_obj_to_dict(
self.config.get("symmetry_loss", {})
)
sign_by_name = sym_cfg.get("dof_sign_by_name", None)
if not sign_by_name:
robot_cfg = self._omega_or_obj_to_dict(
getattr(
getattr(self.env_config, "config", None), "robot", None
)
)
sign_by_name = robot_cfg.get("dof_sign_by_name", None)
sign_by_name = self._omega_or_obj_to_dict(sign_by_name)
if len(sign_by_name) == 0:
raise ValueError(
"symmetry_loss requires dof_sign_by_name in algo or robot config."
)
sign = [float(sign_by_name.get(name, 1.0)) for name in joint_names]
self._sym_dof_perm = torch.tensor(
perm, device=self.device, dtype=torch.long
)
self._sym_dof_sign = torch.tensor(
sign, device=self.device, dtype=torch.float32
)
self._build_obs_mirror_map()
def _extract_obs_mirror_metadata(self) -> dict[str, dict]:
obs_cfg = getattr(
getattr(self.env_config, "config", None), "obs", None
)
obs_root = self._omega_or_obj_to_dict(obs_cfg)
obs_groups = obs_root.get("obs_groups", {})
metadata: dict[str, dict] = {}
for group_name, group_cfg in obs_groups.items():
if not isinstance(group_cfg, dict):
continue
for term_entry in group_cfg.get("atomic_obs_list", []):
if not isinstance(term_entry, dict):
continue
for term_name, term_cfg in term_entry.items():
term_cfg = self._omega_or_obj_to_dict(term_cfg)
mirror_func = term_cfg.get("mirror_func", None)
if not mirror_func:
continue
metadata[f"{group_name}/{term_name}"] = {
"mirror_func": str(mirror_func),
"mirror_config": self._omega_or_obj_to_dict(
term_cfg.get("mirror_config", {})
),
}
return metadata
def _get_actor_schema_terms(self) -> set[str]:
module_dict = self._omega_or_obj_to_dict(
self.config.get("module_dict", {})
)
actor_cfg = self._omega_or_obj_to_dict(module_dict.get("actor", {}))
actor_schema = self._omega_or_obj_to_dict(
actor_cfg.get("obs_schema", {})
)
actor_terms: set[str] = set()
for seq_cfg in actor_schema.values():
if not isinstance(seq_cfg, dict):
continue
for term in seq_cfg.get("terms", []):
actor_terms.add(str(term))
return actor_terms
def _build_obs_mirror_map(self) -> None:
    """(Re)build ``self._obs_mirror_map`` for the actor's observation terms.

    Maps each actor-schema term that has mirror metadata (key
    ``"<group>/<term>"``) to a callable applying the named
    ``MirrorFunctions`` helper. Supported ``mirror_func`` values:
    ``mirror_dof`` (uses ``_sym_dof_perm`` / ``_sym_dof_sign`` moved to the
    input's device, with signs cast to the input dtype), ``mirror_vec3``,
    ``mirror_axial_vec3`` and ``mirror_velocity_command``; other values are
    skipped. The map stays empty until the DOF permutation and signs exist.
    """
    from holomotion.src.env.isaaclab_components.isaaclab_observation import (
        MirrorFunctions,
    )

    self._obs_mirror_map = {}
    dof_perm, dof_sign = self._sym_dof_perm, self._sym_dof_sign
    if dof_perm is None or dof_sign is None:
        return

    def _mirror_dof(x, perm=dof_perm, sign=dof_sign):
        return MirrorFunctions.mirror_dof(
            x,
            perm=perm.to(device=x.device, dtype=torch.long),
            sign=sign.to(device=x.device, dtype=x.dtype),
        )

    # Lambdas defer the MirrorFunctions attribute lookup to call time.
    handlers = {
        "mirror_dof": _mirror_dof,
        "mirror_vec3": lambda x: MirrorFunctions.mirror_vec3(x),
        "mirror_axial_vec3": lambda x: MirrorFunctions.mirror_axial_vec3(x),
        "mirror_velocity_command": (
            lambda x: MirrorFunctions.mirror_velocity_command(x)
        ),
    }
    term_meta = self._extract_obs_mirror_metadata()
    for term in self._get_actor_schema_terms():
        meta = term_meta.get(term)
        if meta is None:
            continue
        handler = handlers.get(str(meta.get("mirror_func", "")))
        if handler is not None:
            self._obs_mirror_map[term] = handler
@staticmethod
def _td_key_to_path(key) -> str:
if isinstance(key, tuple):
return "/".join(str(part) for part in key)
return str(key)
def _mirror_actor_obs(self, obs_td: TensorDict) -> TensorDict:
    """Return a left/right mirrored copy of an actor observation TensorDict.

    Each leaf whose ``/``-joined key path has an entry in
    ``self._obs_mirror_map`` is replaced by that function's output; other
    leaves are placed in the result unchanged. The input itself is
    returned when the symmetry loss is inactive, the input is not a
    TensorDict, or no mirror functions are registered.
    """
    if not self._symmetry_loss_active():
        return obs_td
    if not isinstance(obs_td, TensorDict):
        return obs_td
    if len(getattr(self, "_obs_mirror_map", {})) == 0:
        return obs_td

    out = TensorDict(
        {},
        batch_size=list(obs_td.batch_size),
        device=obs_td.device,
    )
    for leaf_key in obs_td.keys(include_nested=True, leaves_only=True):
        path_key = leaf_key if isinstance(leaf_key, tuple) else (leaf_key,)
        leaf = obs_td.get(path_key)
        fn = self._obs_mirror_map.get(self._td_key_to_path(path_key))
        out.set(path_key, leaf if fn is None else fn(leaf))
    return out
def _mirror_env_action(self, actions: torch.Tensor) -> torch.Tensor:
    """Mirror an action tensor left/right using the DOF permutation/signs.

    Returns ``actions`` unchanged when the symmetry loss is inactive.

    Raises:
        RuntimeError: if active but the permutation/signs were never built.
    """
    from holomotion.src.env.isaaclab_components.isaaclab_observation import (
        MirrorFunctions,
    )

    if self._symmetry_loss_active():
        perm, sign = self._sym_dof_perm, self._sym_dof_sign
        if perm is None or sign is None:
            raise RuntimeError(
                "Symmetry DOF permutation/signs are not initialized."
            )
        return MirrorFunctions.mirror_action(
            actions,
            perm=perm.to(device=actions.device, dtype=torch.long),
            sign=sign.to(device=actions.device, dtype=actions.dtype),
        )
    return actions
def _compute_analytic_kl(
self,
old_mu: torch.Tensor,
old_sigma: torch.Tensor,
new_mu: torch.Tensor,
new_sigma: torch.Tensor,
weight: torch.Tensor | None = None,
) -> float:
with torch.no_grad():
kl_vec = torch.sum(
torch.log((new_sigma + 1.0e-8) / (old_sigma + 1.0e-8))
+ (torch.square(old_sigma) + torch.square(old_mu - new_mu))
/ (2.0 * torch.square(new_sigma) + 1.0e-8)
- 0.5,
dim=-1,
)
if weight is None:
kl_sum = kl_vec.sum()
kl_count = torch.tensor(
float(kl_vec.numel()),
device=self.device,
dtype=torch.float32,
)
else:
kl_weight = weight.to(dtype=torch.float32)
kl_sum = (kl_vec * kl_weight).sum()
kl_count = kl_weight.sum()
if self.is_distributed:
kl_sum = self.accelerator.reduce(kl_sum, reduction="sum")
kl_count = self.accelerator.reduce(kl_count, reduction="sum")
kl_mean = kl_sum / kl_count.clamp_min(1.0)
return float(kl_mean.item())
def _compute_clip_fraction(
self,
ratio: torch.Tensor,
weight: torch.Tensor | None = None,
) -> float:
with torch.no_grad():
clipped = (
(ratio < (1.0 - self.clip_param))
| (ratio > (1.0 + self.clip_param))
).to(torch.float32)
if weight is None:
clip_sum = clipped.sum()
clip_count = torch.tensor(
float(clipped.numel()),
device=self.device,
dtype=torch.float32,
)
else:
clip_weight = weight.to(dtype=torch.float32)
clip_sum = (clipped * clip_weight).sum()
clip_count = clip_weight.sum()
if self.is_distributed:
clip_sum = self.accelerator.reduce(clip_sum, reduction="sum")
clip_count = self.accelerator.reduce(
clip_count, reduction="sum"
)
clip_fraction = clip_sum / clip_count.clamp_min(1.0)
return float(clip_fraction.item())
def _compute_explained_variance(
self,
target: torch.Tensor,
prediction: torch.Tensor,
weight: torch.Tensor | None = None,
) -> float:
with torch.no_grad():
target_f = target.float()
prediction_f = prediction.float()
residual = target_f - prediction_f
if weight is None:
weight_f = torch.ones_like(target_f, dtype=torch.float32)
else:
weight_f = weight.to(dtype=torch.float32)
count = weight_f.sum()
target_sum = (target_f * weight_f).sum()
target_sq_sum = (target_f.square() * weight_f).sum()
residual_sum = (residual * weight_f).sum()
residual_sq_sum = (residual.square() * weight_f).sum()
if self.is_distributed:
count = self.accelerator.reduce(count, reduction="sum")
target_sum = self.accelerator.reduce(
target_sum, reduction="sum"
)
target_sq_sum = self.accelerator.reduce(
target_sq_sum, reduction="sum"
)
residual_sum = self.accelerator.reduce(
residual_sum, reduction="sum"
)
residual_sq_sum = self.accelerator.reduce(
residual_sq_sum, reduction="sum"
)
denom = count.clamp_min(1.0)
target_mean = target_sum / denom
residual_mean = residual_sum / denom
target_var = target_sq_sum / denom - target_mean.square()
residual_var = residual_sq_sum / denom - residual_mean.square()
if float(target_var.item()) <= 1.0e-8:
return 0.0
explained_variance = 1.0 - residual_var / target_var
return float(explained_variance.item())
def _set_optimizer_learning_rates(self) -> None:
for param_group in self.actor_optimizer.param_groups:
param_group["lr"] = self.actor_learning_rate
for param_group in self.critic_optimizer.param_groups:
param_group["lr"] = self.critic_learning_rate
@staticmethod
def _validate_entropy_schedule_config(
*,
initial_entropy_coef: float,
anneal_entropy: bool,
zero_entropy_point: float,
) -> None:
if float(initial_entropy_coef) < 0.0:
raise ValueError("entropy_coef must be >= 0.")
if anneal_entropy and not (0.0 < float(zero_entropy_point) <= 1.0):
raise ValueError(
"zero_entropy_point must be in (0.0, 1.0] when "
"anneal_entropy is enabled."
)
def _get_effective_entropy_coef(self) -> float:
if self.initial_entropy_coef <= 0.0 or not self.anneal_entropy:
return float(self.initial_entropy_coef)
total_learning_iterations = int(
getattr(
self,
"total_learning_iterations",
self.current_learning_iteration
+ int(self.num_learning_iterations),
)
)
total_learning_iterations = max(1, total_learning_iterations)
zero_entropy_iteration = float(self.zero_entropy_point) * float(
total_learning_iterations
)
anneal_scale = max(
0.0,
1.0
- float(self.current_learning_iteration) / zero_entropy_iteration,
)
return float(self.initial_entropy_coef * anneal_scale)
def _apply_adaptive_lr(self, kl_signal: float | None) -> None:
if (
self.desired_kl is None
or self.schedule != "adaptive"
or kl_signal is None
):
return
if kl_signal > self.desired_kl * self.adaptive_lr_kl_high_factor:
self.actor_learning_rate = max(
self.adaptive_lr_min,
self.actor_learning_rate / self.adaptive_lr_factor,
)
if self.adaptive_lr_adapt_critic:
self.critic_learning_rate = max(
self.adaptive_lr_min,
self.critic_learning_rate / self.adaptive_lr_factor,
)
elif (
kl_signal > 0.0
and kl_signal < self.desired_kl * self.adaptive_lr_kl_low_factor
):
self.actor_learning_rate = min(
self.adaptive_lr_max,
self.actor_learning_rate * self.adaptive_lr_factor,
)
if self.adaptive_lr_adapt_critic:
self.critic_learning_rate = min(
self.adaptive_lr_max,
self.critic_learning_rate * self.adaptive_lr_factor,
)
self._set_optimizer_learning_rates()
def _compute_windowed_kl_signal(
self, recent_analytic_kls: list[float]
) -> float | None:
if len(recent_analytic_kls) < self.kl_early_stop_window_size:
return None
window = recent_analytic_kls[-self.kl_early_stop_window_size :]
return float(sum(window) / len(window))
def _should_early_stop_for_kl(
self,
kl_signal: float | None,
num_kl_measurements: int,
) -> bool:
if not self.kl_early_stop_enabled or self.desired_kl is None:
return False
if kl_signal is None:
return False
required_measurements = max(
self.kl_early_stop_min_updates, self.kl_early_stop_window_size
)
if num_kl_measurements < required_measurements:
return False
return kl_signal > self.desired_kl * self.kl_early_stop_factor
def _setup_data_buffers(self):
    """Extend base buffer setup by choosing the transition container.

    Velocity-command tasks (``command_name == "base_velocity"``) use
    ``PpoVelocityTransition``; every other command uses ``PpoTransition``.
    The cached transition TensorDict starts out empty.
    """
    super()._setup_data_buffers()
    is_velocity_task = self.command_name == "base_velocity"
    self.use_velocity_transition: bool = is_velocity_task
    if is_velocity_task:
        self.transition_cls = PpoVelocityTransition
    else:
        self.transition_cls = PpoTransition
    self.transition_td: PpoTransition | PpoVelocityTransition | None = None
def _build_optimizer_kwargs(self, optimizer_class: type) -> dict:
if self.optimizer_type != "AdamW":
return {}
signature = inspect.signature(optimizer_class.__init__)
parameters = signature.parameters
use_fused = bool(
self.config.get(
"adamw_use_fused", bool(self.device.type == "cuda")
)
)
use_foreach = bool(self.config.get("adamw_use_foreach", True))
if (
use_fused
and ("fused" in parameters)
and (self.device.type == "cuda")
):
return {"fused": True}
if use_foreach and ("foreach" in parameters):
return {"foreach": True}
return {}
def _setup_models_and_optimizer(self):
    """Build the PPO actor/critic, their optimizers, and prepare them.

    All environments are reset to obtain a sample observation. It is
    wrapped into a TensorDict and passed to ``PPOActor`` / ``PPOCritic`` as
    ``obs_example`` for schema-based assembly. Module settings come from
    ``config.module_dict.actor`` / ``.critic``; ``obs_schema`` and ``type``
    are optional, and ``type`` defaults to "MLP".

    The actor and critic each get their own optimizer. The class is
    ``getattr(optim, self.optimizer_type)``, and each optimizer has its own
    learning rate and betas. Models and optimizers are then passed
    together through ``self.accelerator.prepare``.

    Side effects: resets all environments and sets ``self.actor_type``,
    ``self.critic_type``, ``self.actor``, ``self.critic``,
    ``self.actor_optimizer`` and ``self.critic_optimizer``.
    """
    from holomotion.src.modules.agent_modules import PPOActor, PPOCritic

    # Build sample TensorDict for schema-based assembly
    sample_obs_dict = self.env.reset_all()[0]
    sample_td = self._wrap_obs_dict(sample_obs_dict)
    actor_cfg = OmegaConf.to_container(
        self.config.module_dict.actor, resolve=True
    )
    critic_cfg = OmegaConf.to_container(
        self.config.module_dict.critic, resolve=True
    )
    self.actor_type = actor_cfg.get("type", "MLP")
    self.critic_type = critic_cfg.get("type", "MLP")
    actor_schema = actor_cfg.get("obs_schema", None)
    critic_schema = critic_cfg.get("obs_schema", None)
    self.actor = PPOActor(
        obs_schema=actor_schema,
        module_config_dict=actor_cfg,
        num_actions=self.num_actions,
        init_noise_std=self.config.init_noise_std,
        obs_example=sample_td,
    ).to(self.device)
    self.critic = PPOCritic(
        obs_schema=critic_schema,
        module_config_dict=critic_cfg,
        obs_example=sample_td,
    ).to(self.device)
    # Architecture / parameter-count logging is done on the main process
    # only, before accelerator.prepare() wraps the modules.
    if self.is_main_process:
        actor = self.accelerator.unwrap_model(self.actor)
        critic = self.accelerator.unwrap_model(self.critic)
        logger.info("Actor (TensorDict module):\n{!r}", actor)
        logger.info(
            "Actor keys: in_keys={} out_keys={}",
            list(actor.in_keys),
            list(actor.out_keys),
        )
        logger.info("Actor core nn module:\n{!r}", actor.actor_module)
        logger.info("Critic (TensorDict module):\n{!r}", critic)
        logger.info(
            "Critic keys: in_keys={} out_keys={}",
            list(critic.in_keys),
            list(critic.out_keys),
        )
        logger.info("Critic core nn module:\n{!r}", critic.critic_module)
        # Log actor and critic parameter counts (in millions)
        actor_params = sum(p.numel() for p in self.actor.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        params_table = [
            ["Actor", f"{actor_params / 1.0e6:.3f}"],
            ["Critic", f"{critic_params / 1.0e6:.3f}"],
            ["Total", f"{(actor_params + critic_params) / 1.0e6:.3f}"],
        ]
        logger.info(
            "Model Summary:\n"
            + tabulate(
                params_table,
                headers=["Model", "Params (M)"],
                tablefmt="simple_outline",
            )
        )
    # Separate optimizers so actor and critic can use different learning
    # rates / betas. Extra kwargs may select a fused/foreach AdamW.
    optimizer_class = getattr(optim, self.optimizer_type)
    optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class)
    self.actor_optimizer = optimizer_class(
        self.actor.parameters(),
        lr=self.actor_learning_rate,
        betas=(self.actor_beta1, self.actor_beta2),
        **optimizer_kwargs,
    )
    self.critic_optimizer = optimizer_class(
        self.critic.parameters(),
        lr=self.critic_learning_rate,
        betas=(self.critic_beta1, self.critic_beta2),
        **optimizer_kwargs,
    )
    dynamo_backend = self.config.get("dynamo_backend", None)
    if dynamo_backend and self.is_main_process:
        logger.info(
            f"Models will be compiled with dynamo_backend='{dynamo_backend}' "
            "during accelerator.prepare()"
        )
    # Optimizers are created before prepare() and the prepared objects
    # replace the originals.
    (
        self.actor,
        self.critic,
        self.actor_optimizer,
        self.critic_optimizer,
    ) = self.accelerator.prepare(
        self.actor,
        self.critic,
        self.actor_optimizer,
        self.critic_optimizer,
    )
def _build_storage(self, obs_td: TensorDict):
    """Create the rollout buffer (``num_envs`` x ``num_steps_per_env``).

    ``obs_td`` serves as the observation template. Actions are stored
    with shape ``[num_actions]``, and ``self.transition_cls`` selects the
    per-step record type.
    """
    storage_options = dict(
        obs_template=obs_td,
        actions_shape=[self.num_actions],
        device=self.device,
        command_name=self.command_name,
        transition_cls=self.transition_cls,
    )
    return RolloutStorage(
        self.num_envs, self.num_steps_per_env, **storage_options
    )
def _build_transition(
    self,
    obs_td: TensorDict,
    actor_out: TensorDict,
    critic_out: TensorDict,
):
    """Assemble one rollout-step transition from the actor/critic outputs.

    Policy outputs (actions, mu, sigma, log-prob, values) are detached.
    The log-prob gets a trailing singleton dimension. Rewards, returns,
    advantages and dones are zero ``(num_envs, 1)`` placeholders.
    ``teacher_actions`` is zeros shaped like the actions. For velocity
    tasks a ``velocity_commands`` entry is added: a moving flag
    (|cmd[:3]| > 0.1) followed by the first three command components.
    """
    env_count = self.num_envs
    dev = self.device

    def _zero_column(dtype):
        # Fresh (num_envs, 1) zero tensor on the rollout device.
        return torch.zeros(env_count, 1, device=dev, dtype=dtype)

    actions = actor_out.get("actions")
    log_prob = actor_out.get("actions_log_prob")
    float_zeros = _zero_column(torch.float32)
    fields = dict(
        obs=obs_td,
        actions=actions.detach(),
        teacher_actions=torch.zeros_like(actions),
        mu=actor_out.get("mu").detach(),
        sigma=actor_out.get("sigma").detach(),
        actions_log_prob=log_prob[..., None].detach(),
        values=critic_out.get("values").detach(),
        rewards=float_zeros.clone(),
        dones=_zero_column(torch.bool),
        returns=float_zeros.clone(),
        advantages=float_zeros.clone(),
        batch_size=[env_count],
        device=dev,
    )
    if self.use_velocity_transition:
        import isaaclab.envs.mdp as isaaclab_mdp

        command = isaaclab_mdp.generated_commands(
            self.env._env, command_name="base_velocity"
        )
        # Keep only the first three command components.
        command = command[..., :3]
        moving = (command.norm(dim=-1) > 0.1).to(dtype=command.dtype)
        fields["velocity_commands"] = torch.cat(
            [moving.unsqueeze(-1), command], dim=-1
        ).detach()
    return self.transition_cls(**fields)
def _post_iteration_hook(self, it: int) -> None:
if self.command_name == "ref_motion":
motion_cmd = self.env._env.command_manager.get_term("ref_motion")
motion_cmd.apply_cache_swap_if_pending_barrier(
accelerator=self.accelerator
)
def _post_training_hook(self) -> None:
if self.command_name == "ref_motion":
motion_cmd = self.env._env.command_manager.get_term("ref_motion")
if motion_cmd is not None:
motion_cmd.close()
def _get_mean_policy_std(self) -> torch.Tensor:
base_actor = self.accelerator.unwrap_model(self.actor)
if hasattr(base_actor, "std"):
return base_actor.std.mean()
if hasattr(base_actor, "log_std"):
return torch.exp(base_actor.log_std).mean()
return torch.tensor(0.0, device=self.device)
def _maybe_override_loaded_actor_sigma(self) -> None:
if not bool(self.config.get("override_sigma", False)):
return
sigma_override = self.config.get("sigma_override", None)
if sigma_override is None:
raise ValueError(
"config.override_sigma is enabled but config.sigma_override is not set."
)
actor_unwrapped = self.accelerator.unwrap_model(self.actor)
orig_mod = getattr(actor_unwrapped, "_orig_mod", None)
if orig_mod is not None:
actor_unwrapped = orig_mod
override_sigma = getattr(actor_unwrapped, "override_sigma", None)
if override_sigma is None:
raise AttributeError(
f"{type(actor_unwrapped).__name__} does not implement override_sigma()."
)
override_sigma(sigma_override)
if self.is_main_process:
logger.info(
"Reapplied sigma override after checkpoint load: {}",
sigma_override,
)
def _get_additional_log_metrics(self) -> dict[str, float]:
"""Build auxiliary training/cache metrics."""
iteration_metrics = {}
if "actor_learning_rate" in self.__dict__:
iteration_metrics["0-Train/actor_learning_rate"] = float(
self.actor_learning_rate
)
if "critic_learning_rate" in self.__dict__:
iteration_metrics["0-Train/critic_learning_rate"] = float(
self.critic_learning_rate
)
if "initial_entropy_coef" in self.__dict__:
iteration_metrics["0-Train/entropy_coef_effective"] = float(
self._get_effective_entropy_coef()
)
if "_last_update_metrics" in self.__dict__:
iteration_metrics.update(self._last_update_metrics)
mean_std = self._get_mean_policy_std()
iteration_metrics["0-Train/mean_noise_std"] = float(mean_std.item())
if self.command_name != "ref_motion":
return iteration_metrics
motion_cmd = self.env._env.command_manager.get_term("ref_motion")
cache = motion_cmd._motion_cache
iteration_metrics["1-Perf/Cache/swap_index"] = float(cache.swap_index)
pool_stats = cache.cache_curriculum_pool_statistics()
if pool_stats is not None:
core_cache_metric_keys = {
"prioritized_pool_size": "1-Perf/Cache/prioritized_pool_size",
"prioritized_pool_mean_score": "1-Perf/Cache/prioritized_pool_mean_score",
"uniform_pool_mean_score": "1-Perf/Cache/uniform_pool_mean_score",
"entered_prioritized_pool_count": "1-Perf/Cache/entered_prioritized_pool_count",
"exited_prioritized_pool_count": "1-Perf/Cache/exited_prioritized_pool_count",
}
for src_key, dst_key in core_cache_metric_keys.items():
if src_key in pool_stats:
iteration_metrics[dst_key] = float(pool_stats[src_key])
return iteration_metrics
def update(self):
    """Run the PPO optimisation phase over the current rollout storage.

    Iterates the minibatches produced by
    ``self.storage.iter_minibatches(effective_num_mini_batches,
    self.num_learning_epochs)``. For each minibatch it computes:
    - the clipped surrogate loss;
    - the value loss, clipped when ``use_clipped_value_loss``;
    - an entropy bonus when the effective entropy coefficient is positive;
    - an optional symmetry loss.
    It then steps the actor and critic optimizers separately.

    When ``desired_kl`` is set, the analytic KL between rollout-time and
    current policy is measured per minibatch. Its windowed mean drives KL
    early stopping and, when ``schedule == "adaptive"``, learning-rate
    adaptation.

    Side effects: rewrites ``self._last_update_metrics``, clears
    ``self.storage`` and calls ``self._post_update_hook(loss_out)``.

    Returns:
        dict[str, float]: losses/diagnostics averaged over the executed
        updates (and mean-reduced across processes when distributed).
    """
    mean_value_loss = 0.0
    mean_surrogate_loss = 0.0
    mean_entropy = 0.0
    mean_kl_analytic = 0.0
    mean_symmetry_loss = 0.0
    # Measured on the stored rollout values, i.e. before any update.
    critic_explained_variance = self._compute_explained_variance(
        target=self.storage.data["returns"],
        prediction=self.storage.data["values"],
    )
    batch_size = int(
        self.storage.num_envs * self.storage.num_transitions_per_env
    )
    (
        effective_num_mini_batches,
        mini_batch_size,
    ) = RolloutStorage.resolve_mini_batch_partition(
        batch_size, self.num_mini_batches
    )
    # Seed the diagnostics dict; the per-update entries are overwritten
    # after the minibatch loop.
    self._last_update_metrics = {
        "0-Train/configured_num_mini_batches": float(
            self.configured_num_mini_batches
        ),
        "0-Train/requested_num_mini_batches": float(
            self.requested_num_mini_batches
        ),
        "0-Train/effective_num_mini_batches": float(
            effective_num_mini_batches
        ),
        "0-Train/mini_batch_size_per_rank": float(mini_batch_size),
        "0-Train/num_updates_executed": 0.0,
        "0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor),
        "0-Train/scalable_distributed_update": float(
            self.distributed_update_mode == "scalable"
        ),
        "0-Train/kl_windowed": 0.0,
        "0-Train/kl_stop_triggered": 0.0,
        "0-Train/kl_stop_analytic": 0.0,
        "0-Train/kl_analytic_batch_last": 0.0,
        "0-Train/kl_analytic_batch_max": 0.0,
        "0-Train/clip_fraction_batch_mean": 0.0,
        "0-Train/clip_fraction_batch_last": 0.0,
    }
    entropy_coef = self._get_effective_entropy_coef()
    generator = self.storage.iter_minibatches(
        effective_num_mini_batches,
        self.num_learning_epochs,
    )
    measure_analytic_kl = self.desired_kl is not None
    num_updates = 0
    num_kl_measurements = 0
    kl_stop_triggered = False
    kl_stop_analytic = 0.0
    kl_windowed = None
    recent_analytic_kls: list[float] = []
    kl_analytic_batch_last = 0.0
    kl_analytic_batch_max = 0.0
    clip_fraction_batch_mean = 0.0
    clip_fraction_batch_last = 0.0
    for batch in generator:
        obs_batch = batch.obs
        actions_batch = batch.actions
        target_values_batch = batch.values
        advantages_batch = batch.advantages
        returns_batch = batch.returns
        old_actions_log_prob_batch = batch.actions_log_prob
        old_mu_batch = batch.mu
        old_sigma_batch = batch.sigma
        # Re-evaluate the stored actions under the current policy; the
        # observation normalizers are frozen during the update.
        with self.accelerator.autocast():
            actor_out = self.actor(
                obs_batch,
                actions=actions_batch,
                mode="logp",
                update_obs_norm=False,
            )
            critic_out = self.critic(obs_batch, update_obs_norm=False)
        actions_log_prob_batch = actor_out.get("actions_log_prob")
        mu_batch = actor_out.get("mu")
        sigma_batch = actor_out.get("sigma")
        entropy_batch = actor_out.get("entropy")
        value_pred = critic_out.get("values")
        symmetry_loss = None
        if self._symmetry_loss_active():
            # Run the actor on mirrored observations, map its actions back
            # with _mirror_env_action, and penalise their MSE to mu.
            mirrored_obs_batch = self._mirror_actor_obs(obs_batch)
            mirrored_actor_out = self.actor(
                mirrored_obs_batch,
                actions=None,
                mode="inference",
                update_obs_norm=False,
            )
            mirrored_actions = mirrored_actor_out.get("actions")
            mirrored_actions_back = self._mirror_env_action(
                mirrored_actions
            )
            symmetry_loss = F.mse_loss(
                mu_batch.float(),
                mirrored_actions_back.float(),
            )
        # Plain aliases: no value/return normalisation is applied here.
        value_batch = value_pred
        returns_batch_norm = returns_batch
        target_values_batch_norm = target_values_batch
        analytic_kl = None
        if measure_analytic_kl:
            analytic_kl = self._compute_analytic_kl(
                old_mu=old_mu_batch.float(),
                old_sigma=old_sigma_batch.float(),
                new_mu=mu_batch.float(),
                new_sigma=sigma_batch.float(),
            )
            mean_kl_analytic += analytic_kl
            num_kl_measurements += 1
            kl_analytic_batch_last = analytic_kl
            kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl)
            # Keep only the most recent window of KL measurements.
            recent_analytic_kls.append(analytic_kl)
            if len(recent_analytic_kls) > self.kl_early_stop_window_size:
                recent_analytic_kls.pop(0)
            kl_windowed = self._compute_windowed_kl_signal(
                recent_analytic_kls
            )
            # Stop before this minibatch's gradient step is applied.
            # NOTE(review): this break is taken per process; unless
            # _compute_analytic_kl returns a cross-rank-consistent value,
            # ranks could leave the loop at different minibatches — confirm.
            if self._should_early_stop_for_kl(
                kl_windowed, num_kl_measurements
            ):
                kl_stop_triggered = True
                kl_stop_analytic = analytic_kl
                break
        # PPO probability ratio pi_new / pi_old.
        ratio = torch.exp(
            actions_log_prob_batch
            - torch.squeeze(old_actions_log_prob_batch).float()
        )
        clip_fraction = self._compute_clip_fraction(ratio)
        clip_fraction_batch_mean += clip_fraction
        clip_fraction_batch_last = clip_fraction
        # Clipped surrogate objective, written as a loss to minimise.
        surrogate = -torch.squeeze(advantages_batch) * ratio
        surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
            ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
        )
        surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
        if self.use_clipped_value_loss:
            # Value change relative to the rollout value is clipped with the
            # same clip_param as the policy ratio.
            value_clipped = target_values_batch_norm + (
                value_batch - target_values_batch_norm
            ).clamp(-self.clip_param, self.clip_param)
            value_losses = (value_batch - returns_batch_norm).pow(2)
            value_losses_clipped = (
                value_clipped - returns_batch_norm
            ).pow(2)
            value_loss = torch.max(
                value_losses, value_losses_clipped
            ).mean()
        else:
            value_loss = (returns_batch_norm - value_batch).pow(2).mean()
        actor_loss = surrogate_loss
        critic_loss = self.value_loss_coef * value_loss
        if entropy_coef > 0.0:
            entropy_loss = entropy_batch.mean()
            actor_loss = actor_loss - entropy_coef * entropy_loss
        if symmetry_loss is not None:
            actor_loss = (
                actor_loss + self.symmetry_loss_coef * symmetry_loss
            )
        # Actor and critic are optimised by separate optimizers with
        # separate losses and separate gradient clipping.
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        self.accelerator.backward(actor_loss)
        self.accelerator.backward(critic_loss)
        if self.max_grad_norm is not None:
            self.accelerator.clip_grad_norm_(
                self.actor.parameters(),
                self.max_grad_norm,
            )
            self.accelerator.clip_grad_norm_(
                self.critic.parameters(),
                self.max_grad_norm,
            )
        self.actor_optimizer.step()
        self.critic_optimizer.step()
        num_updates += 1
        mean_value_loss += float(value_loss.item())
        mean_surrogate_loss += float(surrogate_loss.item())
        mean_entropy += float(entropy_batch.mean().item())
        if symmetry_loss is not None:
            mean_symmetry_loss += float(symmetry_loss.item())
    # Averages over executed updates. The KL mean also includes the
    # measurement that triggered an early stop.
    denom = max(1, num_updates)
    mean_value_loss /= denom
    mean_surrogate_loss /= denom
    mean_entropy /= denom
    mean_kl_analytic /= max(1, num_kl_measurements)
    mean_symmetry_loss /= denom
    clip_fraction_batch_mean /= denom
    if self.schedule == "adaptive":
        self._apply_adaptive_lr(kl_windowed)
    self._last_update_metrics["0-Train/num_updates_executed"] = float(
        num_updates
    )
    self._last_update_metrics["0-Train/kl_windowed"] = float(
        kl_windowed or 0.0
    )
    self._last_update_metrics["0-Train/kl_stop_triggered"] = float(
        kl_stop_triggered
    )
    self._last_update_metrics["0-Train/kl_stop_analytic"] = float(
        kl_stop_analytic
    )
    self._last_update_metrics["0-Train/kl_analytic_batch_last"] = float(
        kl_analytic_batch_last
    )
    self._last_update_metrics["0-Train/kl_analytic_batch_max"] = float(
        kl_analytic_batch_max
    )
    self._last_update_metrics["0-Train/clip_fraction_batch_mean"] = float(
        clip_fraction_batch_mean
    )
    self._last_update_metrics["0-Train/clip_fraction_batch_last"] = float(
        clip_fraction_batch_last
    )
    self.storage.clear()
    loss_out = {
        "value_function": mean_value_loss,
        "critic_explained_variance": critic_explained_variance,
        "surrogate": mean_surrogate_loss,
        "entropy": mean_entropy,
        "kl_analytic": mean_kl_analytic,
    }
    if self._symmetry_loss_active():
        loss_out["symmetry_loss"] = mean_symmetry_loss
    # Reduce losses across processes for consistent logging on rank 0
    if self.is_distributed:
        reduced_out = {}
        for k, v in loss_out.items():
            if v is None:
                reduced_out[k] = None
                continue
            t = torch.tensor(v, device=self.device, dtype=torch.float32)
            reduced_t = self.accelerator.reduce(t, reduction="mean")
            reduced_out[k] = float(reduced_t.item())
        loss_out = reduced_out
    self._post_update_hook(loss_out)
    return loss_out
def load(self, ckpt_path):
    """Restore actor/critic weights, optimizer state and bookkeeping.

    Args:
        ckpt_path: Path to the ``.pt`` checkpoint written by ``save``, or
            ``None`` to skip loading.

    Model weights are loaded with ``strict=True`` from the paths returned
    by ``_resolve_model_file_path``. Optimizer states are restored unless
    ``self.is_offline_eval`` is truthy. An optimizer keeps its fresh state,
    and a warning is logged on the main process, when either:
    - its entry is missing from the checkpoint, or
    - ``_restore_optimizer_state`` rejects the stored state.

    Returns:
        The checkpoint's ``"infos"`` entry (``None`` if absent), or ``None``
        when ``ckpt_path`` is ``None``.
    """
    if ckpt_path is None:
        return None
    if self.is_main_process:
        logger.info(f"Loading checkpoint from {ckpt_path}")
    actor_model_path = self._resolve_model_file_path(ckpt_path, "actor")
    critic_model_path = self._resolve_model_file_path(ckpt_path, "critic")
    self._load_accelerate_model(self.actor, actor_model_path, strict=True)
    self._load_accelerate_model(
        self.critic, critic_model_path, strict=True
    )
    loaded_dict = torch.load(ckpt_path, map_location=self.device)
    if not getattr(self, "is_offline_eval", False):
        optimizer_entries = (
            (self.actor_optimizer, "actor_optimizer_state_dict", "actor"),
            (self.critic_optimizer, "critic_optimizer_state_dict", "critic"),
        )
        for optimizer, state_key, optimizer_name in optimizer_entries:
            state_dict = loaded_dict.get(state_key)
            if state_dict is None:
                # Checkpoints without optimizer state (e.g. produced by a
                # different pipeline) are skipped rather than raising
                # KeyError, mirroring the incompatible-state handling.
                if self.is_main_process:
                    logger.warning(
                        "Skipping {} optimizer state restore from checkpoint: "
                        "missing '{}' entry",
                        optimizer_name,
                        state_key,
                    )
                continue
            self._restore_optimizer_state(
                optimizer,
                state_dict,
                optimizer_name=optimizer_name,
            )
    elif self.is_main_process:
        logger.info(
            "Skipping optimizer state restore during offline evaluation."
        )
    self.current_learning_iteration = loaded_dict.get("iter", 0)
    self._maybe_override_loaded_actor_sigma()
    self._load_extra_checkpoint_state(loaded_dict)
    return loaded_dict.get("infos", None)
def _restore_optimizer_state(
self,
optimizer,
loaded_state_dict,
*,
optimizer_name: str,
) -> bool:
compatible, reason = self._optimizer_state_is_compatible(
optimizer, loaded_state_dict
)
if not compatible:
if self.is_main_process:
logger.warning(
"Skipping {} optimizer state restore from checkpoint: {}",
optimizer_name,
reason,
)
return False
try:
optimizer.load_state_dict(loaded_state_dict)
except ValueError as exc:
if self.is_main_process:
logger.warning(
"Skipping {} optimizer state restore from checkpoint: {}",
optimizer_name,
exc,
)
return False
return True
def _optimizer_state_is_compatible(
self, optimizer, loaded_state_dict
) -> tuple[bool, str | None]:
current_state_dict = optimizer.state_dict()
current_groups = current_state_dict.get("param_groups")
loaded_groups = loaded_state_dict.get("param_groups")
if not isinstance(current_groups, list) or not isinstance(
loaded_groups, list
):
return True, None
if len(current_groups) != len(loaded_groups):
return (
False,
"param group count mismatch "
f"(current={len(current_groups)}, loaded={len(loaded_groups)})",
)
for group_idx, (current_group, loaded_group) in enumerate(
zip(current_groups, loaded_groups)
):
current_param_count = len(current_group.get("params", []))
loaded_param_count = len(loaded_group.get("params", []))
if current_param_count != loaded_param_count:
return (
False,
"param group size mismatch for group "
f"{group_idx} (current={current_param_count}, "
f"loaded={loaded_param_count})",
)
return True, None
def save(self, path, infos=None):
    """Write a checkpoint (main process only).

    Writes the following:
    - actor and critic weights via ``accelerator.save_model`` into
      ``<base>/actor`` and ``<base>/critic``;
    - a torch file at ``path`` holding both optimizer states, the current
      iteration, ``infos`` and any extra trainer state (moved to CPU);
    - an ONNX export, when ``config.export_policy`` is set.
    """
    if not self.is_main_process:
        return
    logger.info(f"Saving checkpoint to {path}")
    # NOTE(review): str.replace removes every ".pt" occurrence in the path,
    # not just the suffix. Kept as-is so the layout matches whatever
    # _resolve_model_file_path expects on load.
    base_path = path.replace(".pt", "")
    os.makedirs(os.path.dirname(base_path) or ".", exist_ok=True)
    for sub_dir, module in (("actor", self.actor), ("critic", self.critic)):
        self.accelerator.save_model(module, os.path.join(base_path, sub_dir))
    checkpoint = {
        "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
        "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
        "iter": self.current_learning_iteration,
        "infos": infos,
    }
    checkpoint.update(self._extra_checkpoint_state())
    torch.save(_checkpoint_state_to_cpu(checkpoint), path)
    if not bool(self.config.get("export_policy", False)):
        return
    export_policy_to_onnx_common(
        self,
        path,
        onnx_name_suffix=self.config.get("onnx_name_suffix", None),
        use_kv_cache=bool(self.config.get("use_kv_cache", True)),
    )
def offline_evaluate_policy(self, dump_npzs: bool = False):
"""Dump NPZs (no metrics) from validation cache using ref_motion command.
- Iterates validation batches; env i -> clip i (deterministic) starting at frame 0.
- Collect robot and reference sequences each step and save one NPZ per clip.
- NPZ conforms to holomotion_retargeted format keys.
- Optionally records viewport MP4(s) aligned with target_fps and rollout length.
"""
ckpt_path = self.config.checkpoint
n_fut_frames = self.env.config.commands.ref_motion.params.get(
"n_fut_frames", 8
)
# log_dir is already set to checkpoint directory in eval script
model_name = os.path.basename(ckpt_path).replace(".pt", "")
# Eval modes (freeze normalizers if enabled)
self.actor.eval()
self.critic.eval()
# Require ref_motion command and simple cache backend
command_name = list(self.env.config.commands.keys())[0]
if command_name != "ref_motion":
logger.warning(
"Offline evaluation only supported for ref_motion command"
)
return {}
motion_cmd = self.env._env.command_manager.get_term("ref_motion")
cache = getattr(motion_cmd, "_motion_cache", None)
if cache is None:
logger.error(
"Offline evaluation requires hdf5_simple cache backend (no LMDB support)"
)
return {}
self._offline_evaluating = True
# Evaluation flag and cache batch-size adjustment (ensure batch_size == num_envs)
motion_cmd._is_evaluating = True
num_envs = self.env.num_envs
try:
if getattr(cache, "_batch_size", None) != num_envs:
from holomotion.src.training.h5_dataloader import (
MotionClipBatchCache,
)
cache = MotionClipBatchCache(
train_dataset=cache._datasets["train"],
val_dataset=cache._datasets["val"],
batch_size=num_envs,
stage_device=getattr(cache, "_stage_device", None),
num_workers=getattr(cache, "_num_workers", 0),
prefetch_factor=getattr(cache, "_prefetch_factor", None),
pin_memory=getattr(cache, "_pin_memory", True),
persistent_workers=getattr(
cache, "_persistent_workers", False
),
sampler_rank=getattr(cache, "_sampler_rank", 0),
sampler_world_size=getattr(
cache, "_sampler_world_size", 1
),
allowed_prefixes=getattr(cache, "_allowed_prefixes", None),
swap_interval_steps=getattr(
cache, "swap_interval_steps", None
),
force_timeout_on_swap=getattr(
cache, "force_timeout_on_swap", True
),
seed=getattr(cache, "_seed", None),
loader_timeout=getattr(cache, "_loader_timeout", 0.0),
)
motion_cmd._motion_cache = cache
except Exception as e:
logger.warning(
f"Offline eval: failed to rebuild cache to batch_size={num_envs}: {e}"
)
# Derive HDF5 dataset base name (from validation dataset root) for output naming
dataset_suffix = None
val_dataset = cache._datasets["val"]
dataset_root = None
if hasattr(val_dataset, "hdf5_root"):
dataset_root = str(val_dataset.hdf5_root).rstrip(os.sep)
elif hasattr(val_dataset, "ts_roots"):
ts_roots = getattr(val_dataset, "ts_roots")
if ts_roots:
dataset_root = str(ts_roots[0]).rstrip(os.sep)
if dataset_root:
dataset_suffix = os.path.basename(dataset_root)
# Output directory (respect existing log_dir derived from checkpoint)
suffix = f"isaaclab_eval_output_{model_name}"
if dataset_suffix is not None:
suffix = f"{suffix}_{dataset_suffix}"
output_dir = os.path.join(self.log_dir, suffix)
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving evaluation outputs to: {output_dir}")
# Switch to validation cache and iterate all batches
if hasattr(cache, "set_mode"):
cache.set_mode("val")
# Determine policy/video FPS from command config (align wallclock time)
motion_fps = int(getattr(motion_cmd.cfg, "target_fps", 50))
total_batches = int(getattr(cache, "num_batches", 1))
with torch.no_grad():
for batch_idx in tqdm(
range(total_batches), desc="Evaluating batches"
):
if batch_idx > 0:
cache.advance()
# Reset envs first, then apply deterministic mapping on the active cache batch
_ = self.env.reset_all()
if hasattr(motion_cmd, "setup_offline_eval_deterministic"):
motion_cmd.setup_offline_eval_deterministic(
apply_pending_swap=False
)
self._reset_rollout_forward_state()
# Read current batch metadata AFTER reset + setup
current = getattr(cache, "current_batch", None)
if current is None or not hasattr(current, "motion_keys"):
logger.warning(
"Current cache batch missing motion_keys; skipping batch"
)
continue
motion_keys = list(current.motion_keys)
raw_motion_keys = list(
getattr(current, "raw_motion_keys", current.motion_keys)
)
# Determine active env count for this batch
clip_count = int(cache.clip_count)
active_count = min(num_envs, clip_count)
if active_count > 0:
active_ids = torch.arange(
active_count,
dtype=torch.long,
device=self.device,
)
motion_cmd.force_realign_offline_eval_no_perturb(
active_ids
)
# Recompute observations after deterministic setup
obs_mgr = self.env._env.observation_manager
if active_count > 0:
obs_mgr.reset(active_ids)
obs_dict = obs_mgr.compute(update_history=True)
else:
obs_dict = obs_mgr.compute(update_history=True)
obs = self._wrap_obs_dict(obs_dict)
# Map env -> motion_key for active envs
env_motion_keys = {
int(i): motion_keys[int(i)] for i in range(active_count)
}
env_raw_motion_keys = {
int(i): raw_motion_keys[int(i)]
for i in range(active_count)
}
# Prepare per-env collectors
env_has_done = torch.zeros(
num_envs, dtype=torch.bool, device=self.device
)
episode_lengths = torch.zeros(
num_envs, dtype=torch.long, device=self.device
)
active_mask = torch.zeros(
num_envs, dtype=torch.bool, device=self.device
)
if active_count > 0:
active_mask[:active_count] = True
# Reference collectors (URDF order)
ref_dof_pos = [[] for _ in range(active_count)]
ref_dof_vel = [[] for _ in range(active_count)]
ref_body_pos = [[] for _ in range(active_count)]
ref_body_rot_xyzw = [[] for _ in range(active_count)]
ref_body_vel = [[] for _ in range(active_count)]
ref_body_ang_vel = [[] for _ in range(active_count)]
# Robot collectors (URDF order)
robot_dof_pos = [[] for _ in range(active_count)]
robot_dof_vel = [[] for _ in range(active_count)]
robot_body_pos = [[] for _ in range(active_count)]
robot_body_rot_xyzw = [[] for _ in range(active_count)]
robot_body_vel = [[] for _ in range(active_count)]
robot_body_ang_vel = [[] for _ in range(active_count)]
robot_dof_acc = [[] for _ in range(active_count)]
robot_dof_torque = [[] for _ in range(active_count)]
robot_action_rate = [[] for _ in range(active_count)]
prev_robot_dof_vel = [None for _ in range(active_count)]
prev_robot_actions = [None for _ in range(active_count)]
step_dt = float(self.env._env.step_dt)
# Per-env bookkeeping
clip_lengths_np = (
current.lengths.detach().cpu().numpy()
if hasattr(current, "lengths")
else np.array(
[getattr(cache, "max_frame_length", 1000)]
* active_count
)
)
# Persist an explicit mapping file for verification
try:
mapping_records = []
for i in range(active_count):
mapping_records.append(
{
"env_id": int(i),
"motion_key": env_motion_keys[int(i)],
"raw_motion_key": env_raw_motion_keys[int(i)],
"clip_length": int(clip_lengths_np[int(i)]),
}
)
mapping_path = os.path.join(
output_dir, f"batch_{batch_idx:04d}_mapping.json"
)
with open(mapping_path, "w") as f:
json.dump(mapping_records, f, indent=2)
except Exception:
pass
env_frame_counts = [0 for _ in range(active_count)]
encountered_done = [False for _ in range(active_count)]
valid_masks = [[] for _ in range(active_count)]
def _sanitize_key(key: str) -> str:
return (
key.replace("/", "+")
.replace(" ", "_")
.replace("\\", "+")
)
def _get_out_path(idx: int) -> str:
out_name = f"{_sanitize_key(env_motion_keys[idx])}.npz"
return os.path.join(output_dir, out_name)
def _save_env_npz(idx: int):
if idx >= active_count:
return
# Total collected frames
total_len = int(min(env_frame_counts[idx], max_steps))
if total_len <= 0:
return
# Compute contiguous valid prefix length and slice_len
vm = valid_masks[idx][:total_len]
valid_prefix_len = 0
for b in vm:
if b:
valid_prefix_len += 1
else:
break
clip_len = int(clip_lengths_np[idx])
slice_len = int(min(valid_prefix_len, clip_len, total_len))
if slice_len <= 0:
return
# Reference arrays (sliced)
ref_dof_pos_arr = np.stack(
ref_dof_pos[idx][:slice_len], axis=0
).astype(np.float32)
ref_dof_vel_arr = np.stack(
ref_dof_vel[idx][:slice_len], axis=0
).astype(np.float32)
ref_body_pos_arr = np.stack(
ref_body_pos[idx][:slice_len], axis=0
).astype(np.float32)
ref_body_rot_xyzw_arr = np.stack(
ref_body_rot_xyzw[idx][:slice_len], axis=0
).astype(np.float32)
ref_body_vel_arr = np.stack(
ref_body_vel[idx][:slice_len], axis=0
).astype(np.float32)
ref_body_ang_vel_arr = np.stack(
ref_body_ang_vel[idx][:slice_len], axis=0
).astype(np.float32)
# Robot arrays (sliced)
robot_dof_pos_arr = np.stack(
robot_dof_pos[idx][:slice_len], axis=0
).astype(np.float32)
robot_dof_vel_arr = np.stack(
robot_dof_vel[idx][:slice_len], axis=0
).astype(np.float32)
robot_dof_acc_arr = np.stack(
robot_dof_acc[idx][:slice_len], axis=0
).astype(np.float32)
robot_dof_torque_arr = np.stack(
robot_dof_torque[idx][:slice_len], axis=0
).astype(np.float32)
robot_action_rate_arr = np.asarray(
robot_action_rate[idx][:slice_len], dtype=np.float32
)
robot_body_pos_arr = np.stack(
robot_body_pos[idx][:slice_len], axis=0
).astype(np.float32)
robot_body_rot_xyzw_arr = np.stack(
robot_body_rot_xyzw[idx][:slice_len], axis=0
).astype(np.float32)
robot_body_vel_arr = np.stack(
robot_body_vel[idx][:slice_len], axis=0
).astype(np.float32)
robot_body_ang_vel_arr = np.stack(
robot_body_ang_vel[idx][:slice_len], axis=0
).astype(np.float32)
# Metadata
motion_fps = int(getattr(motion_cmd.cfg, "target_fps", 50))
num_dofs = int(ref_dof_pos_arr.shape[1])
num_bodies = int(ref_body_pos_arr.shape[1])
wallclock_len = (
float(slice_len - 1) / float(motion_fps)
if motion_fps > 0 and slice_len > 0
else 0.0
)
meta = {
"motion_key": env_motion_keys[idx],
"raw_motion_key": env_raw_motion_keys[idx],
"motion_fps": float(motion_fps),
"num_frames": int(slice_len),
"wallclock_len": float(wallclock_len),
"num_dofs": int(num_dofs),
"num_bodies": int(num_bodies),
"clip_length": int(clip_lengths_np[idx]),
"valid_prefix_len": int(valid_prefix_len),
}
# Output filename: flattened motion_key
out_path = _get_out_path(idx)
np.savez_compressed(
out_path,
metadata=json.dumps(meta),
robot_dof_pos=robot_dof_pos_arr,
robot_dof_vel=robot_dof_vel_arr,
robot_dof_acc=robot_dof_acc_arr,
robot_dof_torque=robot_dof_torque_arr,
robot_action_rate=robot_action_rate_arr,
robot_global_translation=robot_body_pos_arr,
robot_global_rotation_quat=robot_body_rot_xyzw_arr,
robot_global_velocity=robot_body_vel_arr,
robot_global_angular_velocity=robot_body_ang_vel_arr,
ref_dof_pos=ref_dof_pos_arr,
ref_dof_vel=ref_dof_vel_arr,
ref_global_translation=ref_body_pos_arr,
ref_global_rotation_quat=ref_body_rot_xyzw_arr,
ref_global_velocity=ref_body_vel_arr,
ref_global_angular_velocity=ref_body_ang_vel_arr,
)
max_steps = int(
getattr(cache, "max_frame_length", 1000)
) # decide the max_length to evaluate
for rollout_step in tqdm(
range(max_steps), desc="Rollout steps"
):
# PRE-STEP: collect states for all active envs
active = [i for i in range(active_count)]
if len(active) > 0:
# Reference step tensors (URDF order)
ref_dp = (
motion_cmd.get_ref_motion_dof_pos_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
ref_dv = (
motion_cmd.get_ref_motion_dof_vel_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
ref_bp = (
motion_cmd.get_ref_motion_bodylink_global_pos_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
ref_br = (
motion_cmd.get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
ref_bv = (
motion_cmd.get_ref_motion_bodylink_global_lin_vel_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
ref_bav = (
motion_cmd.get_ref_motion_bodylink_global_ang_vel_cur_urdf_order()
.detach()
.cpu()
.numpy()
)
# Robot step tensors (URDF order)
rob_dp = (
motion_cmd.robot_dof_pos_cur_urdf_order.detach()
.cpu()
.numpy()
)
rob_dv = (
motion_cmd.robot_dof_vel_cur_urdf_order.detach()
.cpu()
.numpy()
)
rob_bp = (
motion_cmd.robot_bodylink_global_pos_cur_urdf_order.detach()
.cpu()
.numpy()
)
rob_br = (
motion_cmd.robot_bodylink_global_rot_xyzw_cur_urdf_order.detach()
.cpu()
.numpy()
)
rob_bv = (
motion_cmd.robot_bodylink_global_lin_vel_cur_urdf_order.detach()
.cpu()
.numpy()
)
rob_bav = (
motion_cmd.robot_bodylink_global_ang_vel_cur_urdf_order.detach()
.cpu()
.numpy()
)
for idx in active:
if prev_robot_dof_vel[idx] is None:
dof_acc_cur = np.zeros_like(
rob_dv[idx], dtype=np.float32
)
else:
dof_acc_cur = (
rob_dv[idx] - prev_robot_dof_vel[idx]
) / step_dt
prev_robot_dof_vel[idx] = rob_dv[idx].copy()
ref_dof_pos[idx].append(ref_dp[idx])
ref_dof_vel[idx].append(ref_dv[idx])
ref_body_pos[idx].append(ref_bp[idx])
ref_body_rot_xyzw[idx].append(ref_br[idx])
ref_body_vel[idx].append(ref_bv[idx])
ref_body_ang_vel[idx].append(ref_bav[idx])
robot_dof_pos[idx].append(rob_dp[idx])
robot_dof_vel[idx].append(rob_dv[idx])
robot_dof_acc[idx].append(
dof_acc_cur.astype(np.float32)
)
robot_body_pos[idx].append(rob_bp[idx])
robot_body_rot_xyzw[idx].append(rob_br[idx])
robot_body_vel[idx].append(rob_bv[idx])
robot_body_ang_vel[idx].append(rob_bav[idx])
# Record valid mask for current frame (before step)
clip_limit = int(clip_lengths_np[idx])
valid_now = (
(idx < active_count)
and (not encountered_done[idx])
and (
env_frame_counts[idx]
< clip_limit - n_fut_frames
)
)
valid_masks[idx].append(bool(valid_now))
# Increment local frame counter
env_frame_counts[idx] += 1
# No mid-rollout finalize; we defer to end using valid masks
# Inference and step (advance sim)
obs = self._rollout_forward(
obs,
actor_mode="inference",
collect_transition=False,
track_episode_stats=False,
)
dones = self._last_rollout_dones
if dones is None:
raise RuntimeError(
"Rollout forward did not return dones during offline evaluation."
)
actions_step = self._last_rollout_actions
if actions_step is None:
raise RuntimeError(
"Rollout forward did not return actions during offline evaluation."
)
actions_np = actions_step.detach().cpu().numpy()
torque_urdf = (
motion_cmd.robot.data.applied_torque[
..., motion_cmd.sim2urdf_dof_idx
]
.detach()
.cpu()
.numpy()
)
for idx in range(active_count):
if prev_robot_actions[idx] is None:
action_rate_cur = 0.0
else:
action_rate_cur = float(
np.linalg.norm(
actions_np[idx] - prev_robot_actions[idx]
)
/ step_dt
)
prev_robot_actions[idx] = actions_np[idx].copy()
robot_action_rate[idx].append(
np.float32(action_rate_cur)
)
robot_dof_torque[idx].append(
torque_urdf[idx].astype(np.float32)
)
# Handle RL dones (first-done policy): mark done for future frames
step_dones = (
dones.bool().reshape(-1).detach().cpu().numpy()
)
for idx in range(min(active_count, len(step_dones))):
if step_dones[idx] and not encountered_done[idx]:
encountered_done[idx] = True
if rollout_step == max_steps - 1:
# End of rollout: save once per env with full rollout arrays + valid_mask
if dump_npzs and active_count > 0:
out_path_to_last_idx = {}
for idx in range(active_count):
out_path_to_last_idx[_get_out_path(idx)] = idx
save_indices = list(out_path_to_last_idx.values())
max_npz_save_workers = max(
1, min(16, len(save_indices))
)
with ThreadPoolExecutor(
max_workers=max_npz_save_workers
) as executor:
futures = [
executor.submit(_save_env_npz, idx)
for idx in save_indices
]
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Saving NPZs",
):
future.result()
break
logger.info(
f"Offline evaluation complete: saved clips to {output_dir}"
)
return {"output_dir": output_dir}
================================================
FILE: holomotion/src/algo/ppo_tf.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Generator
import torch
import torch.nn.functional as F
import torch.optim as optim
from holomotion.src.algo.algo_utils import PpoAuxTransition
from holomotion.src.algo.ppo import PPO
from holomotion.src.modules.agent_modules import (
PPOCondTFActor,
PPOCritic,
PPOTFActor,
PPOTFRefRouterActor,
PPOTFRefRouterSeqActor,
PPOTFRefRouterV3Actor,
TensorDictAssembler,
)
from holomotion.src.modules.network_modules import GroupedMoEBlock
from loguru import logger
from omegaconf import OmegaConf
from tabulate import tabulate
from tensordict import TensorDict
class PPOTF(PPO):
"""Transformer-policy PPO with TensorDict rollout and sequence update."""
@staticmethod
def _select_actor_wrapper_cls(actor_cfg: dict):
actor_type = str(actor_cfg.get("type", ""))
use_future_cross_attn = bool(
actor_cfg.get("use_future_cross_attn", False)
)
if actor_type == "ReferenceRoutedGroupedMoETransformerPolicy":
if use_future_cross_attn:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicy does not "
"support use_future_cross_attn=True."
)
return PPOTFRefRouterActor
if actor_type == "ReferenceRoutedGroupedMoETransformerPolicyV2":
if use_future_cross_attn:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 does not "
"support use_future_cross_attn=True."
)
return PPOTFRefRouterSeqActor
if actor_type == "ReferenceRoutedGroupedMoETransformerPolicyV3":
if use_future_cross_attn:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 does not "
"support use_future_cross_attn=True."
)
return PPOTFRefRouterV3Actor
if use_future_cross_attn:
return PPOCondTFActor
return PPOTFActor
@staticmethod
def _summarize_moe_layer_stats(moe_layers) -> dict[str, float | None]:
if len(moe_layers) == 0:
return {
"moe_active_expert_ratio": None,
"moe_max_expert_frac": None,
"moe_least_expert_frac": None,
"moe_dead_expert_ratio": None,
"moe_expert_count_cv": None,
"moe_selected_expert_margin_to_unselected": None,
}
def _mean_attr(attr_name: str) -> float:
values = torch.stack(
[
getattr(layer, attr_name).to(torch.float32)
for layer in moe_layers
]
)
return float(values.mean().item())
return {
"moe_active_expert_ratio": _mean_attr("last_active_expert_ratio"),
"moe_max_expert_frac": _mean_attr("last_max_expert_frac"),
"moe_least_expert_frac": _mean_attr("last_min_expert_frac"),
"moe_dead_expert_ratio": _mean_attr("last_dead_expert_ratio"),
"moe_expert_count_cv": _mean_attr("last_expert_count_cv"),
"moe_selected_expert_margin_to_unselected": _mean_attr(
"last_selected_expert_margin_to_unselected"
),
}
def _setup_configs(self):
    """Parse and validate PPOTF-specific configuration on top of PPO's.

    After the base-class setup, three groups are processed in order:

    1. ``aux_state_pred``: auxiliary state-prediction heads
       (``_parse_aux_state_pred_configs``).
    2. Router/MoE regularizers: ``aux_router_command_recon``,
       ``aux_router_switch_penalty``, ``aux_router_future_recon``,
       ``dead_expert_margin_to_topk``, ``router_expert_orthogonal`` and
       ``selected_expert_margin_to_unselected``
       (``_parse_router_aux_configs``).
    3. Compatibility of the enabled auxiliary losses with the configured
       actor type (``_check_actor_aux_support``).

    Raises:
        ValueError: On any out-of-range value or unsupported combination.
    """
    super()._setup_configs()
    self._parse_aux_state_pred_configs()
    self._parse_router_aux_configs()
    self._check_actor_aux_support()


def _get_config_section(self, name: str):
    """Return ``self.config[name]``, or ``{}`` when it is missing or null.

    ``self.config.get(name, {})`` returns ``None`` for a section written as
    ``name: null`` (e.g. an empty YAML key). The subsequent ``.get`` calls
    would then fail with an ``AttributeError``.
    """
    section = self.config.get(name, None)
    return {} if section is None else section


def _parse_aux_state_pred_configs(self):
    """Read and validate the ``aux_state_pred`` config section.

    Sets the ``aux_state_pred_*`` weights and options and the derived
    ``use_aux_*`` switches. It also writes the contact/keybody body counts
    into the ``"C"``/``"K"`` entries of ``PpoAuxTransition.SHAPE_TOKENS``.
    That dict is class-level state, so the change is visible to every
    ``PpoAuxTransition`` user.

    Raises:
        ValueError: If a std bound, Huber beta, weight or raycast option is
            out of range, if a positive contact/keybody weight has no body
            names, or if auxiliary prediction is enabled for a command
            other than ``"ref_motion"``.
    """
    aux_cfg = self._get_config_section("aux_state_pred")
    self.use_aux_state_pred: bool = bool(aux_cfg.get("enabled", False))
    self.aux_state_pred_w_base_lin_vel = float(
        aux_cfg.get("w_base_lin_vel", 0.0)
    )
    self.aux_state_pred_w_root_height = float(
        aux_cfg.get("w_root_height", 0.0)
    )
    self.aux_state_pred_w_keybody_contact = float(
        aux_cfg.get("w_keybody_contact", 0.0)
    )
    self.aux_state_pred_w_ref_keybody_rel_pos = float(
        aux_cfg.get("w_ref_keybody_rel_pos", 0.0)
    )
    self.aux_state_pred_w_robot_keybody_rel_pos = float(
        aux_cfg.get("w_robot_keybody_rel_pos", 0.0)
    )
    self.aux_state_pred_w_denoise_ref_root_lin_vel = float(
        aux_cfg.get("w_denoise_ref_root_lin_vel", 0.0)
    )
    self.aux_state_pred_w_denoise_ref_root_ang_vel = float(
        aux_cfg.get("w_denoise_ref_root_ang_vel", 0.0)
    )
    self.aux_state_pred_w_denoise_ref_dof_pos = float(
        aux_cfg.get("w_denoise_ref_dof_pos", 0.0)
    )
    # `or []` also tolerates an explicit null list in the config.
    self.aux_state_pred_keybody_contact_names = [
        str(name)
        for name in (aux_cfg.get("keybody_contact_names", None) or [])
    ]
    self.aux_state_pred_keybody_rel_pos_names = [
        str(name)
        for name in (aux_cfg.get("keybody_rel_pos_names", None) or [])
    ]
    self.aux_state_pred_num_contact_bodies = int(
        len(self.aux_state_pred_keybody_contact_names)
    )
    self.aux_state_pred_num_keybody_bodies = int(
        len(self.aux_state_pred_keybody_rel_pos_names)
    )
    self.use_aux_root_height = bool(
        self.use_aux_state_pred and self.aux_state_pred_w_root_height > 0.0
    )
    self.use_aux_denoise_ref_root_lin_vel = bool(
        self.use_aux_state_pred
        and self.aux_state_pred_w_denoise_ref_root_lin_vel > 0.0
    )
    self.use_aux_denoise_ref_root_ang_vel = bool(
        self.use_aux_state_pred
        and self.aux_state_pred_w_denoise_ref_root_ang_vel > 0.0
    )
    self.use_aux_denoise_ref_dof_pos = bool(
        self.use_aux_state_pred
        and self.aux_state_pred_w_denoise_ref_dof_pos > 0.0
    )
    self.aux_state_pred_min_std = float(aux_cfg.get("min_std", 1.0e-3))
    self.aux_state_pred_max_std = float(aux_cfg.get("max_std", 5.0))
    self.aux_denoise_residual_huber_beta = float(
        aux_cfg.get("denoise_residual_huber_beta", 0.1)
    )
    self.aux_state_pred_raycast_z_offset = float(
        aux_cfg.get("raycast_z_offset", 1.0)
    )
    self.aux_state_pred_raycast_max_dist = float(
        aux_cfg.get("raycast_max_dist", 20.0)
    )

    if self.aux_state_pred_min_std <= 0.0:
        raise ValueError("aux_state_pred.min_std must be > 0.")
    if self.aux_state_pred_max_std <= self.aux_state_pred_min_std:
        raise ValueError(
            "aux_state_pred.max_std must be > aux_state_pred.min_std."
        )
    if self.aux_denoise_residual_huber_beta <= 0.0:
        raise ValueError(
            "aux_state_pred.denoise_residual_huber_beta must be > 0."
        )
    # Validated in the same order as they were parsed above.
    aux_weights = (
        ("w_base_lin_vel", self.aux_state_pred_w_base_lin_vel),
        ("w_root_height", self.aux_state_pred_w_root_height),
        ("w_keybody_contact", self.aux_state_pred_w_keybody_contact),
        ("w_ref_keybody_rel_pos", self.aux_state_pred_w_ref_keybody_rel_pos),
        (
            "w_robot_keybody_rel_pos",
            self.aux_state_pred_w_robot_keybody_rel_pos,
        ),
        (
            "w_denoise_ref_root_lin_vel",
            self.aux_state_pred_w_denoise_ref_root_lin_vel,
        ),
        (
            "w_denoise_ref_root_ang_vel",
            self.aux_state_pred_w_denoise_ref_root_ang_vel,
        ),
        ("w_denoise_ref_dof_pos", self.aux_state_pred_w_denoise_ref_dof_pos),
    )
    for weight_name, weight_value in aux_weights:
        if weight_value < 0.0:
            raise ValueError(f"aux_state_pred.{weight_name} must be >= 0.")
    # Raycast options only matter when the root-height head is active.
    if self.use_aux_root_height:
        if self.aux_state_pred_raycast_max_dist <= 0.0:
            raise ValueError(
                "aux_state_pred.raycast_max_dist must be > 0."
            )
        if self.aux_state_pred_raycast_z_offset < 0.0:
            raise ValueError(
                "aux_state_pred.raycast_z_offset must be >= 0."
            )
    if (
        self.aux_state_pred_w_keybody_contact > 0.0
        and self.aux_state_pred_num_contact_bodies == 0
    ):
        raise ValueError(
            "aux_state_pred.w_keybody_contact > 0 requires "
            "aux_state_pred.keybody_contact_names to be non-empty."
        )
    if (
        self.aux_state_pred_w_ref_keybody_rel_pos > 0.0
        or self.aux_state_pred_w_robot_keybody_rel_pos > 0.0
    ) and self.aux_state_pred_num_keybody_bodies == 0:
        raise ValueError(
            "aux_state_pred keybody position weights > 0 require "
            "aux_state_pred.keybody_rel_pos_names to be non-empty."
        )
    if self.use_aux_state_pred and self.command_name != "ref_motion":
        raise ValueError(
            "aux_state_pred is only supported for PPOTF motion tracking "
            "(command_name='ref_motion')."
        )
    # Publish the per-body dimensions used by PpoAuxTransition shape specs.
    PpoAuxTransition.SHAPE_TOKENS["C"] = (
        self.aux_state_pred_num_contact_bodies
    )
    PpoAuxTransition.SHAPE_TOKENS["K"] = (
        self.aux_state_pred_num_keybody_bodies
    )


def _parse_router_aux_configs(self):
    """Read and validate the router/MoE auxiliary-loss config sections.

    Covers ``aux_router_command_recon``, ``aux_router_switch_penalty``,
    ``aux_router_future_recon``, ``dead_expert_margin_to_topk``,
    ``router_expert_orthogonal`` and
    ``selected_expert_margin_to_unselected``. Logs a warning for each
    regularizer that is enabled with a zero weight, because it would then
    have no effect.

    Raises:
        ValueError: If the switch-penalty metric is unknown, any weight,
            usage threshold or target is negative, any eps/beta is not
            positive, orthogonal regularization is enabled without the
            dead-expert margin loss, or a command-based router loss is
            enabled for a command other than ``"ref_motion"``.
    """
    aux_cmd_cfg = self._get_config_section("aux_router_command_recon")
    self.use_aux_router_command_recon: bool = bool(
        aux_cmd_cfg.get("enabled", False)
    )
    self.aux_router_command_recon_weight = float(
        aux_cmd_cfg.get("weight", 0.0)
    )
    self.aux_router_command_recon_hidden_dim = int(
        aux_cmd_cfg.get("hidden_dim", 0)
    )
    self.aux_router_command_recon_term_prefix = str(
        aux_cmd_cfg.get("term_prefix", "actor_ref_")
    )

    aux_switch_cfg = self._get_config_section("aux_router_switch_penalty")
    self.use_aux_router_switch_penalty = bool(
        aux_switch_cfg.get("enabled", False)
    )
    self.aux_router_switch_penalty_weight = float(
        aux_switch_cfg.get("weight", 0.0)
    )
    self.aux_router_switch_penalty_metric = str(
        aux_switch_cfg.get("metric", "js")
    ).lower()
    self.aux_router_switch_penalty_beta = float(
        aux_switch_cfg.get("beta", 1.0)
    )

    aux_router_future_cfg = self._get_config_section(
        "aux_router_future_recon"
    )
    self.use_aux_router_future_recon = bool(
        aux_router_future_cfg.get("enabled", False)
    )
    self.aux_router_future_recon_weight = float(
        aux_router_future_cfg.get("weight", 0.0)
    )
    self.aux_router_future_recon_hidden_dim = int(
        aux_router_future_cfg.get("hidden_dim", 0)
    )
    self.aux_router_future_recon_huber_beta = float(
        aux_router_future_cfg.get("huber_beta", 1.0)
    )

    dead_margin_cfg = self._get_config_section("dead_expert_margin_to_topk")
    self.use_dead_expert_margin_to_topk = bool(
        dead_margin_cfg.get("enabled", False)
    )
    self.dead_expert_margin_to_topk_weight = float(
        dead_margin_cfg.get("weight", 0.0)
    )

    orth_cfg = self._get_config_section("router_expert_orthogonal")
    self.use_router_expert_orthogonal = bool(
        orth_cfg.get("enabled", False)
    )
    self.router_expert_orthogonal_weight = float(
        orth_cfg.get("weight", 0.0)
    )
    self.router_expert_orthogonal_min_active_usage = float(
        orth_cfg.get("min_active_usage", 1.0e-3)
    )
    self.router_expert_orthogonal_eps = float(orth_cfg.get("eps", 1.0e-8))

    selected_margin_cfg = self._get_config_section(
        "selected_expert_margin_to_unselected"
    )
    self.use_selected_expert_margin_to_unselected = bool(
        selected_margin_cfg.get("enabled", False)
    )
    self.selected_expert_margin_to_unselected_weight = float(
        selected_margin_cfg.get("weight", 0.0)
    )
    self.selected_expert_margin_to_unselected_target = float(
        selected_margin_cfg.get("target", 0.0)
    )

    if self.aux_router_switch_penalty_metric not in {
        "js",
        "normed_smooth_l1",
    }:
        raise ValueError(
            "aux_router_switch_penalty.metric must be one of "
            "{'js', 'normed_smooth_l1'}, got "
            f"{self.aux_router_switch_penalty_metric!r}."
        )
    non_negative_options = (
        ("aux_router_command_recon.weight", self.aux_router_command_recon_weight),
        ("aux_router_future_recon.weight", self.aux_router_future_recon_weight),
        (
            "aux_router_switch_penalty.weight",
            self.aux_router_switch_penalty_weight,
        ),
        (
            "dead_expert_margin_to_topk.weight",
            self.dead_expert_margin_to_topk_weight,
        ),
        (
            "router_expert_orthogonal.weight",
            self.router_expert_orthogonal_weight,
        ),
        (
            "router_expert_orthogonal.min_active_usage",
            self.router_expert_orthogonal_min_active_usage,
        ),
        (
            "selected_expert_margin_to_unselected.weight",
            self.selected_expert_margin_to_unselected_weight,
        ),
    )
    for option_name, option_value in non_negative_options:
        if option_value < 0.0:
            raise ValueError(f"{option_name} must be >= 0.")
    positive_options = (
        ("router_expert_orthogonal.eps", self.router_expert_orthogonal_eps),
        ("aux_router_switch_penalty.beta", self.aux_router_switch_penalty_beta),
        (
            "aux_router_future_recon.huber_beta",
            self.aux_router_future_recon_huber_beta,
        ),
    )
    for option_name, option_value in positive_options:
        if option_value <= 0.0:
            raise ValueError(f"{option_name} must be > 0.")
    if self.selected_expert_margin_to_unselected_target < 0.0:
        raise ValueError(
            "selected_expert_margin_to_unselected.target must be >= 0."
        )
    # Orthogonality is only applied to active experts. It relies on the
    # dead-expert margin loss to keep experts active under sparse top-k.
    if (
        self.use_router_expert_orthogonal
        and not self.use_dead_expert_margin_to_topk
    ):
        raise ValueError(
            "router_expert_orthogonal.enabled=True requires "
            "dead_expert_margin_to_topk.enabled=True in sparse top-k MoE."
        )
    zero_weight_checks = (
        (
            "dead_expert_margin_to_topk",
            self.use_dead_expert_margin_to_topk,
            self.dead_expert_margin_to_topk_weight,
            "dead-expert margin loss",
        ),
        (
            "router_expert_orthogonal",
            self.use_router_expert_orthogonal,
            self.router_expert_orthogonal_weight,
            "orthogonal regularization",
        ),
        (
            "selected_expert_margin_to_unselected",
            self.use_selected_expert_margin_to_unselected,
            self.selected_expert_margin_to_unselected_weight,
            "selected-expert margin loss",
        ),
        (
            "aux_router_switch_penalty",
            self.use_aux_router_switch_penalty,
            self.aux_router_switch_penalty_weight,
            "router switch penalty",
        ),
        (
            "aux_router_future_recon",
            self.use_aux_router_future_recon,
            self.aux_router_future_recon_weight,
            "future reconstruction loss",
        ),
        (
            "aux_router_command_recon",
            self.use_aux_router_command_recon,
            self.aux_router_command_recon_weight,
            "command reconstruction loss",
        ),
    )
    for section_name, enabled, weight, effect in zero_weight_checks:
        if enabled and weight == 0.0:
            logger.warning(
                f"{section_name}.enabled=True but weight=0.0; "
                f"{effect} will have no effect."
            )
    if (
        self.use_aux_router_command_recon
        or self.use_aux_router_switch_penalty
        or self.use_aux_router_future_recon
    ) and self.command_name != "ref_motion":
        raise ValueError(
            "aux_router_command_recon, aux_router_future_recon, and "
            "aux_router_switch_penalty are "
            "only supported for PPOTF motion tracking "
            "(command_name='ref_motion')."
        )


def _check_actor_aux_support(self):
    """Reset router-reconstruction state and check actor-type support.

    ``ReferenceRoutedGroupedMoETransformerPolicyV2``/``V3`` actors reject
    ``aux_router_command_recon``. They also reject the ``aux_state_pred``
    weights for root height and reference denoising. All other actor types
    reject ``aux_router_future_recon``.

    Raises:
        ValueError: If an enabled auxiliary loss is unsupported by the
            configured actor type.
    """
    self.aux_command_router_num_moe_layers = 0
    self.aux_command_router_num_fine_experts = 0
    self.aux_router_command_recon_assembler: TensorDictAssembler | None = (
        None
    )
    actor_cfg = self._get_config_section("module_dict").get("actor", None)
    if actor_cfg is None:
        actor_cfg = {}
    actor_type = str(actor_cfg.get("type", ""))
    if actor_type not in {
        "ReferenceRoutedGroupedMoETransformerPolicyV2",
        "ReferenceRoutedGroupedMoETransformerPolicyV3",
    }:
        if self.use_aux_router_future_recon:
            raise ValueError(
                "aux_router_future_recon requires "
                "ReferenceRoutedGroupedMoETransformerPolicyV2 or V3."
            )
        return
    if self.use_aux_router_command_recon:
        raise ValueError(
            f"{actor_type} does not support aux_router_command_recon."
        )
    unsupported_aux_weights = {
        "w_root_height": self.aux_state_pred_w_root_height,
        "w_denoise_ref_root_lin_vel": self.aux_state_pred_w_denoise_ref_root_lin_vel,
        "w_denoise_ref_root_ang_vel": self.aux_state_pred_w_denoise_ref_root_ang_vel,
        "w_denoise_ref_dof_pos": self.aux_state_pred_w_denoise_ref_dof_pos,
    }
    enabled_unsupported = [
        name
        for name, value in unsupported_aux_weights.items()
        if float(value) > 0.0
    ]
    if enabled_unsupported:
        raise ValueError(
            f"{actor_type} only supports "
            "aux_state_pred weights for base_lin_vel, keybody_contact, "
            "ref_keybody_rel_pos, and robot_keybody_rel_pos. Unsupported "
            "weights: " + ", ".join(enabled_unsupported)
        )
@staticmethod
def _unwrap_obs_schema(schema: dict | None) -> dict | None:
if schema is None:
return None
has_terms = any(
isinstance(v, dict) and ("terms" in v) for v in schema.values()
)
if has_terms:
return schema
if len(schema) == 1:
only_value = next(iter(schema.values()))
if isinstance(only_value, dict):
return only_value
return schema
@staticmethod
def _schema_term_leaf_name(term: str) -> str:
return str(term).split("/")[-1]
@classmethod
def _is_aux_command_term(cls, term: str, term_prefix: str) -> bool:
return cls._schema_term_leaf_name(term).startswith(term_prefix)
@classmethod
def _build_aux_router_command_recon_schema(
cls, actor_schema: dict, term_prefix: str
) -> dict:
command_schema = {}
for group_name, seq_cfg in actor_schema.items():
terms = [
str(term)
for term in seq_cfg.get("terms", [])
if cls._is_aux_command_term(str(term), term_prefix)
]
if len(terms) == 0:
continue
next_seq_cfg = dict(seq_cfg)
next_seq_cfg["terms"] = terms
command_schema[group_name] = next_seq_cfg
if len(command_schema) == 0:
raise ValueError(
"aux_router_command_recon could not find any actor command terms in "
f"obs_schema with prefix '{term_prefix}'."
)
return command_schema
@staticmethod
def _masked_aux_keybody_mse(
pred: torch.Tensor,
target: torch.Tensor,
valid_tok: torch.Tensor,
) -> torch.Tensor:
if pred.shape != target.shape:
raise ValueError(
"pred and target must have the same shape for keybody MSE, "
f"got {tuple(pred.shape)} and {tuple(target.shape)}."
)
if pred.ndim != 4:
raise ValueError(
"Keybody MSE expects [B, T, K, 3] tensors, "
f"got pred with shape {tuple(pred.shape)}."
)
per_token_mse = torch.square(pred - target).mean(dim=(-1, -2))
valid_tok = valid_tok.to(per_token_mse.dtype)
if valid_tok.shape != per_token_mse.shape:
raise ValueError(
"valid_tok must match per-token keybody MSE shape, "
f"got {tuple(valid_tok.shape)} and "
f"{tuple(per_token_mse.shape)}."
)
valid_count = valid_tok.sum().clamp_min(1.0)
return (per_token_mse * valid_tok).sum() / valid_count
@staticmethod
def _masked_aux_mse(
pred: torch.Tensor,
target: torch.Tensor,
valid_tok: torch.Tensor,
) -> torch.Tensor:
if pred.shape != target.shape:
raise ValueError(
"pred and target must share the same shape for auxiliary MSE, "
f"got {tuple(pred.shape)} and {tuple(target.shape)}."
)
if pred.ndim < 3:
raise ValueError(
"Auxiliary MSE expects tensors with shape [B, T, ...], "
f"got {tuple(pred.shape)}."
)
reduce_dims = tuple(range(2, pred.ndim))
per_token_mse = torch.square(pred - target).mean(dim=reduce_dims)
valid_tok = valid_tok.to(per_token_mse.dtype)
if valid_tok.shape != per_token_mse.shape:
raise ValueError(
"valid_tok must match per-token auxiliary MSE shape, got "
f"{tuple(valid_tok.shape)} and {tuple(per_token_mse.shape)}."
)
valid_count = valid_tok.sum().clamp_min(1.0)
return (per_token_mse * valid_tok).sum() / valid_count
@staticmethod
def _masked_adjacent_router_js(
*,
router_features: torch.Tensor,
valid_tok: torch.Tensor,
num_moe_layers: int,
num_fine_experts: int,
) -> torch.Tensor:
if router_features.ndim != 3:
raise ValueError(
"router_features must have shape [B, T, L*E], got "
f"{tuple(router_features.shape)}."
)
if valid_tok.ndim != 2:
raise ValueError(
"valid_tok must have shape [B, T], got "
f"{tuple(valid_tok.shape)}."
)
if num_moe_layers <= 0 or num_fine_experts <= 0:
raise ValueError(
"num_moe_layers and num_fine_experts must be positive, got "
f"{num_moe_layers} and {num_fine_experts}."
)
bsz, seq_len, feat_dim = router_features.shape
expected_dim = num_moe_layers * num_fine_experts
if feat_dim != expected_dim:
raise ValueError(
"router_features last dim must equal num_moe_layers * "
"num_fine_experts, got "
f"{feat_dim} vs {expected_dim}."
)
if valid_tok.shape != (bsz, seq_len):
raise ValueError(
"valid_tok shape mismatch for router temporal loss: expected "
f"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}."
)
if seq_len <= 1:
return router_features.new_zeros(())
router_p = router_features.reshape(
bsz, seq_len, num_moe_layers, num_fine_experts
).to(torch.float32)
prev_p = router_p[:, :-1]
curr_p = router_p[:, 1:]
mix_p = 0.5 * (prev_p + curr_p)
eps = 1.0e-20
prev_safe = prev_p.clamp_min(eps)
curr_safe = curr_p.clamp_min(eps)
mix_safe = mix_p.clamp_min(eps)
kl_prev = (prev_p * (torch.log(prev_safe) - torch.log(mix_safe))).sum(
dim=-1
)
kl_curr = (curr_p * (torch.log(curr_safe) - torch.log(mix_safe))).sum(
dim=-1
)
js = 0.5 * (kl_prev + kl_curr)
adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to(js.dtype)
valid_count = adjacent_valid.sum().clamp_min(1.0) * float(
num_moe_layers
)
return (js * adjacent_valid.unsqueeze(-1)).sum() / valid_count
@staticmethod
def _masked_adjacent_router_normed_smooth_l1(
*,
router_temporal_features: torch.Tensor,
valid_tok: torch.Tensor,
num_moe_layers: int,
num_fine_experts: int,
beta: float = 1.0,
) -> torch.Tensor:
if router_temporal_features.ndim != 3:
raise ValueError(
"router_temporal_features must have shape [B, T, L*E], got "
f"{tuple(router_temporal_features.shape)}."
)
if valid_tok.ndim != 2:
raise ValueError(
"valid_tok must have shape [B, T], got "
f"{tuple(valid_tok.shape)}."
)
if num_moe_layers <= 0 or num_fine_experts <= 0:
raise ValueError(
"num_moe_layers and num_fine_experts must be positive, got "
f"{num_moe_layers} and {num_fine_experts}."
)
if beta <= 0.0:
raise ValueError(
f"beta must be positive for SmoothL1, got {beta}."
)
bsz, seq_len, feat_dim = router_temporal_features.shape
expected_dim = num_moe_layers * num_fine_experts
if feat_dim != expected_dim:
raise ValueError(
"router_temporal_features last dim must equal "
"num_moe_layers * num_fine_experts, got "
f"{feat_dim} vs {expected_dim}."
)
if valid_tok.shape != (bsz, seq_len):
raise ValueError(
"valid_tok shape mismatch for router temporal loss: expected "
f"{(bsz, seq_len)}, got {tuple(valid_tok.shape)}."
)
if seq_len <= 1:
return router_temporal_features.new_zeros(())
router_logits = router_temporal_features.reshape(
bsz, seq_len, num_moe_layers, num_fine_experts
).to(torch.float32)
router_logits = router_logits - router_logits.mean(
dim=-1, keepdim=True
)
router_logits = F.normalize(router_logits, p=2.0, dim=-1, eps=1.0e-5)
prev_logits = router_logits[:, :-1]
curr_logits = router_logits[:, 1:]
smooth_l1 = F.smooth_l1_loss(
curr_logits,
prev_logits,
reduction="none",
beta=beta,
).mean(dim=(-1, -2))
adjacent_valid = (valid_tok[:, :-1] * valid_tok[:, 1:]).to(
smooth_l1.dtype
)
valid_count = adjacent_valid.sum().clamp_min(1.0)
return (smooth_l1 * adjacent_valid).sum() / valid_count
@staticmethod
def _masked_aux_gaussian_nll(
*,
loc: torch.Tensor,
log_std: torch.Tensor,
target: torch.Tensor,
valid_tok: torch.Tensor,
min_std: float,
max_std: float,
) -> tuple[torch.Tensor, torch.Tensor]:
if loc.shape != log_std.shape or loc.shape != target.shape:
raise ValueError(
"loc, log_std, and target must share the same shape for "
"Gaussian aux loss, got "
f"{tuple(loc.shape)}, {tuple(log_std.shape)}, "
f"{tuple(target.shape)}."
)
if loc.ndim < 3:
raise ValueError(
"Gaussian aux loss expects tensors with shape [B, T, ...], "
f"got {tuple(loc.shape)}."
)
per_elem_std = torch.clamp(
torch.exp(log_std),
min=float(min_std),
max=float(max_std),
)
reduce_dims = tuple(range(2, loc.ndim))
per_token_nll = 0.5 * (
torch.square((target - loc) / per_elem_std)
+ 2.0 * torch.log(per_elem_std + 1.0e-8)
).sum(dim=reduce_dims)
valid_tok = valid_tok.to(per_token_nll.dtype)
if valid_tok.shape != per_token_nll.shape:
raise ValueError(
"valid_tok must match per-token Gaussian loss shape, got "
f"{tuple(valid_tok.shape)} and {tuple(per_token_nll.shape)}."
)
valid_count = valid_tok.sum().clamp_min(1.0)
loss = (per_token_nll * valid_tok).sum() / valid_count
per_token_std = per_elem_std.reshape(
per_elem_std.shape[0], per_elem_std.shape[1], -1
).mean(dim=-1)
mean_std = (per_token_std * valid_tok).sum() / valid_count
return loss, mean_std
@staticmethod
def _masked_aux_huber(
*,
pred: torch.Tensor,
target: torch.Tensor,
valid_tok: torch.Tensor,
beta: float,
) -> torch.Tensor:
if pred.shape != target.shape:
raise ValueError(
"pred and target must share the same shape for Huber aux loss, "
f"got {tuple(pred.shape)} and {tuple(target.shape)}."
)
if pred.ndim < 3:
raise ValueError(
"Huber aux loss expects tensors with shape [B, T, ...], "
f"got {tuple(pred.shape)}."
)
per_elem = F.smooth_l1_loss(pred, target, reduction="none", beta=beta)
reduce_dims = tuple(range(2, pred.ndim))
per_token = per_elem.mean(dim=reduce_dims)
valid_tok = valid_tok.to(per_token.dtype)
if valid_tok.shape != per_token.shape:
raise ValueError(
"valid_tok must match per-token Huber loss shape, got "
f"{tuple(valid_tok.shape)} and {tuple(per_token.shape)}."
)
valid_count = valid_tok.sum().clamp_min(1.0)
return (per_token * valid_tok).sum() / valid_count
def _compute_aux_router_future_recon_loss(
self,
*,
actor_wrapper: PPOTFActor,
actor_out: TensorDict,
obs_b: TensorDict,
valid_tok: torch.Tensor,
) -> torch.Tensor:
future_assembler = actor_wrapper.aux_router_future_recon_assembler
if future_assembler is None:
raise ValueError(
"aux_router_future_recon is enabled but future assembler was "
"not initialized on the actor wrapper."
)
aux_router_future_recon_pred = actor_out.get("aux_router_future_recon")
bsz, seq_len = int(obs_b.batch_size[0]), int(obs_b.batch_size[1])
future_target = future_assembler(obs_b.flatten(0, 1)).reshape(
bsz, seq_len, -1
)
normalized_future_target = actor_wrapper.actor_module.normalize_aux_router_future_recon_target(
future_target
).to(aux_router_future_recon_pred.dtype)
return self._masked_aux_huber(
pred=aux_router_future_recon_pred,
target=normalized_future_target,
valid_tok=valid_tok,
beta=self.aux_router_future_recon_huber_beta,
)
def _compute_routed_expert_orthogonal_loss(
    self,
    moe_layer: GroupedMoEBlock,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Orthogonality penalty between down-projections of active experts.

    An expert is "active" when its ``last_routed_expert_usage`` exceeds
    ``router_expert_orthogonal_min_active_usage``. The penalty is computed
    as follows:
      1. Each active expert's ``down_proj`` slice is flattened.
      2. The slice is L2-normalized with ``router_expert_orthogonal_eps``.
      3. The pairwise Gram (cosine-similarity) matrix is formed.
      4. The penalty is the sum of squared off-diagonal entries.

    Returns:
        ``(orth_loss, active_count, mean_offdiag_similarity)``.
        ``orth_loss`` is in ``dtype``. ``active_count`` is the float32
        number of active experts. ``mean_offdiag_similarity`` is the float32
        mean absolute off-diagonal similarity. Loss and similarity are
        zeros when fewer than two experts are active.
    """

    def _empty_result(count: torch.Tensor):
        # Zero loss / zero similarity, but still report how many were active.
        return (
            torch.zeros((), device=device, dtype=dtype),
            count,
            torch.zeros((), device=device, dtype=torch.float32),
        )

    usage = moe_layer.last_routed_expert_usage.to(
        device=device, dtype=torch.float32
    )
    threshold = float(self.router_expert_orthogonal_min_active_usage)
    active_idx = torch.nonzero(usage > threshold, as_tuple=False).squeeze(-1)
    num_active = active_idx.numel()
    active_count = torch.tensor(
        float(num_active), device=device, dtype=torch.float32
    )
    if num_active < 2:
        return _empty_result(active_count)

    flat_experts = (
        moe_layer.down_proj.index_select(0, active_idx)
        .reshape(num_active, -1)
        .to(device=device, dtype=torch.float32)
    )
    unit_experts = F.normalize(
        flat_experts,
        p=2.0,
        dim=-1,
        eps=float(self.router_expert_orthogonal_eps),
    )
    gram = unit_experts @ unit_experts.transpose(0, 1)
    not_diagonal = ~torch.eye(
        gram.shape[0], dtype=torch.bool, device=gram.device
    )
    offdiag = gram.masked_select(not_diagonal)
    if offdiag.numel() == 0:
        return _empty_result(active_count)
    return (
        offdiag.square().sum().to(dtype),
        active_count,
        offdiag.abs().mean(),
    )
@staticmethod
def _root_relative_body_pos_from_mixed_position_frames(
*,
body_pos_w: torch.Tensor,
root_pos_env: torch.Tensor,
root_quat_w: torch.Tensor,
env_origins: torch.Tensor,
) -> torch.Tensor:
"""Convert world-frame body positions using an env-frame root pose.
In IsaacLab, `isaaclab_mdp.root_pos_w(env)` is already in the
environment frame (simulator world minus `env.scene.env_origins`),
while `robot.data.body_pos_w` stays in simulator-world coordinates.
"""
if body_pos_w.ndim != 3 or body_pos_w.shape[-1] != 3:
raise ValueError(
"body_pos_w must have shape [B, N, 3], "
f"got {tuple(body_pos_w.shape)}."
)
if root_pos_env.ndim != 2 or root_pos_env.shape[-1] != 3:
raise ValueError(
"root_pos_env must have shape [B, 3], "
f"got {tuple(root_pos_env.shape)}."
)
if root_quat_w.ndim != 2 or root_quat_w.shape[-1] != 4:
raise ValueError(
"root_quat_w must have shape [B, 4], "
f"got {tuple(root_quat_w.shape)}."
)
if env_origins.ndim != 2 or env_origins.shape[-1] != 3:
raise ValueError(
"env_origins must have shape [B, 3], "
f"got {tuple(env_origins.shape)}."
)
if body_pos_w.shape[0] != root_pos_env.shape[0]:
raise ValueError(
"Batch size mismatch between body_pos_w and root_pos_env: "
f"{body_pos_w.shape[0]} vs {root_pos_env.shape[0]}."
)
if body_pos_w.shape[0] != root_quat_w.shape[0]:
raise ValueError(
"Batch size mismatch between body_pos_w and root_quat_w: "
f"{body_pos_w.shape[0]} vs {root_quat_w.shape[0]}."
)
if body_pos_w.shape[0] != env_origins.shape[0]:
raise ValueError(
"Batch size mismatch between body_pos_w and env_origins: "
f"{body_pos_w.shape[0]} vs {env_origins.shape[0]}."
)
body_pos_env = body_pos_w - env_origins[:, None, :]
rel_pos_env = body_pos_env - root_pos_env[:, None, :]
quat_vec = root_quat_w[:, None, 1:].expand_as(rel_pos_env)
quat_real = root_quat_w[:, None, :1].expand(
-1, rel_pos_env.shape[1], -1
)
t = 2.0 * torch.cross(quat_vec, rel_pos_env, dim=-1)
return rel_pos_env - quat_real * t + torch.cross(quat_vec, t, dim=-1)
def _setup_models_and_optimizer(self):
    """Build actor/critic, their optimizers, and rollout KV-cache state.

    Steps:
      1. Reset all envs once to obtain a sample observation TensorDict,
         used to infer flattened / sequence input dimensions.
      2. Resolve the actor/critic module configs. Sigma settings and the
         future-mask settings are read into the actor config, and the
         auxiliary-objective sub-configs are forwarded to it.
      3. Instantiate the actor wrapper chosen by
         ``_select_actor_wrapper_cls`` and a ``PPOCritic``.
      4. Create optimizers (AdamW gets decay / no-decay parameter groups)
         and pass models + optimizers through ``accelerator.prepare``.
      5. Reset the actor KV cache (when supported) and rollout bookkeeping.

    Raises:
        ValueError: If the actor ``obs_schema`` is missing, if
            ``use_future_cross_attn`` is set without the required schema
            entries, or if ``aux_router_switch_penalty`` is enabled while
            the actor has no GroupedMoEBlock.
    """

    def _as_plain_dict(cfg_node) -> dict:
        # Plain dicts are shallow-copied and OmegaConf nodes are resolved to
        # containers. An explicit `key: null` in the config (None) becomes
        # an empty dict instead of crashing in OmegaConf.to_container.
        if cfg_node is None:
            return {}
        if isinstance(cfg_node, dict):
            return dict(cfg_node)
        return OmegaConf.to_container(cfg_node, resolve=True)

    sample_obs_dict = self.env.reset_all()[0]
    sample_td = self._wrap_obs_dict(sample_obs_dict)
    actor_cfg = OmegaConf.to_container(
        self.config.module_dict.actor, resolve=True
    )
    critic_cfg = OmegaConf.to_container(
        self.config.module_dict.critic, resolve=True
    )
    # Action-noise parameterization comes from the top-level config.
    actor_cfg["noise_std_type"] = getattr(
        self.config, "noise_std_type", "log"
    )
    actor_cfg["min_sigma"] = getattr(self.config, "min_sigma", 0.1)
    actor_cfg["max_sigma"] = getattr(self.config, "max_sigma", 1.5)
    actor_cfg["fix_sigma"] = getattr(self.config, "fix_sigma", False)
    self._future_mask_prob = float(actor_cfg.get("future_mask_prob", 0.0))
    self._future_mask_mode = str(
        actor_cfg.get("future_mask_mode", "random_suffix")
    ).lower()

    # Auxiliary-objective sub-configs, as mutable plain dicts.
    actor_cfg["aux_state_pred"] = _as_plain_dict(
        self.config.get("aux_state_pred", {})
    )
    actor_aux_cmd_cfg = _as_plain_dict(
        self.config.get("aux_router_command_recon", {})
    )
    actor_aux_switch_cfg = _as_plain_dict(
        self.config.get("aux_router_switch_penalty", {})
    )
    actor_aux_router_future_cfg = _as_plain_dict(
        self.config.get("aux_router_future_recon", {})
    )
    actor_dead_margin_cfg = _as_plain_dict(
        self.config.get("dead_expert_margin_to_topk", {})
    )
    actor_selected_margin_cfg = _as_plain_dict(
        self.config.get("selected_expert_margin_to_unselected", {})
    )

    actor_schema = self._unwrap_obs_schema(
        actor_cfg.get("obs_schema", None)
    )
    critic_schema = self._unwrap_obs_schema(
        critic_cfg.get("obs_schema", None)
    )
    if actor_schema is None:
        raise ValueError(
            "PPOTF requires actor obs_schema to infer flattened obs dim."
        )
    if self.use_aux_router_command_recon:
        # The recon head's output width is inferred from the sample obs
        # using a schema derived from the actor schema and the configured
        # term prefix.
        aux_command_schema = self._build_aux_router_command_recon_schema(
            actor_schema, self.aux_router_command_recon_term_prefix
        )
        self.aux_router_command_recon_assembler = TensorDictAssembler(
            aux_command_schema, output_mode="flat"
        )
        actor_aux_cmd_cfg["output_dim"] = int(
            self.aux_router_command_recon_assembler.infer_output_dim(
                sample_td
            )
        )
        if self.aux_router_command_recon_hidden_dim > 0:
            actor_aux_cmd_cfg["hidden_dim"] = (
                self.aux_router_command_recon_hidden_dim
            )
    actor_cfg["aux_router_command_recon"] = actor_aux_cmd_cfg
    actor_cfg["aux_router_future_recon"] = actor_aux_router_future_cfg
    actor_cfg["aux_router_switch_penalty"] = actor_aux_switch_cfg
    actor_cfg["dead_expert_margin_to_topk"] = actor_dead_margin_cfg
    actor_cfg["selected_expert_margin_to_unselected"] = (
        actor_selected_margin_cfg
    )
    actor_obs_dim = int(
        TensorDictAssembler(
            actor_schema, output_mode="flat"
        ).infer_output_dim(sample_td)
    )
    use_future_cross_attn = bool(
        actor_cfg.get("use_future_cross_attn", False)
    )
    actor_cls = self._select_actor_wrapper_cls(actor_cfg)
    if use_future_cross_attn:
        # Current state and future reference are split into two schemas:
        # a flat state vector plus a token sequence for cross-attention.
        if "flattened_obs" not in actor_schema:
            raise ValueError(
                "use_future_cross_attn=True requires "
                "actor obs_schema.flattened_obs."
            )
        if "flattened_obs_fut" not in actor_schema:
            raise ValueError(
                "use_future_cross_attn=True requires "
                "actor obs_schema.flattened_obs_fut."
            )
        state_schema = {"flattened_obs": actor_schema["flattened_obs"]}
        future_schema = {
            "flattened_obs_fut": actor_schema["flattened_obs_fut"]
        }
        state_obs_dim = int(
            TensorDictAssembler(
                state_schema, output_mode="flat"
            ).infer_output_dim(sample_td)
        )
        future_asm = TensorDictAssembler(future_schema, output_mode="seq")
        future_token_dim = int(future_asm.infer_output_dim(sample_td))
        future_seq_len = int(future_asm.seq_len)
        actor_cfg["state_obs_dim"] = state_obs_dim
        actor_cfg["future_token_dim"] = future_token_dim
        actor_cfg["future_seq_len"] = future_seq_len
        actor_cfg["input_dim_override"] = state_obs_dim
    else:
        actor_cfg["input_dim_override"] = actor_obs_dim
    self.actor = actor_cls(
        obs_schema=actor_schema,
        module_config_dict=actor_cfg,
        num_actions=self.num_actions,
        init_noise_std=self.config.init_noise_std,
        obs_example=sample_td,
    ).to(self.device)
    actor_module_unwrapped = self.actor.actor_module
    self.aux_command_router_num_moe_layers = int(
        getattr(actor_module_unwrapped, "_num_moe_layers", 0)
    )
    self.aux_command_router_num_fine_experts = int(
        getattr(actor_module_unwrapped, "num_fine_experts", 0)
    )
    if (
        self.use_aux_router_switch_penalty
        and self.aux_command_router_num_moe_layers <= 0
    ):
        raise ValueError(
            "aux_router_switch_penalty requires at least one "
            "GroupedMoEBlock."
        )
    self.critic = PPOCritic(
        obs_schema=critic_schema,
        module_config_dict=critic_cfg,
        obs_example=sample_td,
    ).to(self.device)
    if self.is_main_process:
        actor = self.accelerator.unwrap_model(self.actor)
        critic = self.accelerator.unwrap_model(self.critic)
        logger.info("Actor (TensorDict module):\n{!r}", actor)
        logger.info(
            "Actor keys: in_keys={} out_keys={}",
            list(actor.in_keys),
            list(actor.out_keys),
        )
        logger.info("Actor core nn module:\n{!r}", actor.actor_module)
        logger.info("Critic (TensorDict module):\n{!r}", critic)
        logger.info(
            "Critic keys: in_keys={} out_keys={}",
            list(critic.in_keys),
            list(critic.out_keys),
        )
        logger.info("Critic core nn module:\n{!r}", critic.critic_module)
        actor_params = sum(p.numel() for p in self.actor.parameters())
        critic_params = sum(p.numel() for p in self.critic.parameters())
        params_table = [
            ["Actor(Transformer)", f"{actor_params / 1.0e6:.3f}"],
            ["Critic", f"{critic_params / 1.0e6:.3f}"],
            ["Total", f"{(actor_params + critic_params) / 1.0e6:.3f}"],
        ]
        logger.info(
            "Model Summary:\n"
            + tabulate(
                params_table,
                headers=["Model", "Params (M)"],
                tablefmt="simple_outline",
            )
        )
    optimizer_class = getattr(optim, self.optimizer_type)
    optimizer_kwargs = self._build_optimizer_kwargs(optimizer_class)
    if self.optimizer_type == "AdamW":
        # Exempt vectors, log-std, biases and norm parameters from decay.
        decay_params = []
        non_decay_params = []
        for name, p in self.actor.named_parameters():
            if not p.requires_grad:
                continue
            if (
                p.ndim < 2
                or ("log_std" in name)
                or ("bias" in name)
                or ("norm" in name)
            ):
                non_decay_params.append(p)
            else:
                decay_params.append(p)
        self.actor_optimizer = optimizer_class(
            [
                {"params": decay_params, "weight_decay": 0.01},
                {"params": non_decay_params, "weight_decay": 0.0},
            ],
            lr=self.actor_learning_rate,
            betas=(self.actor_beta1, self.actor_beta2),
            **optimizer_kwargs,
        )
    else:
        self.actor_optimizer = optimizer_class(
            self.actor.parameters(),
            lr=self.actor_learning_rate,
            betas=(self.actor_beta1, self.actor_beta2),
            **optimizer_kwargs,
        )
    self.critic_optimizer = optimizer_class(
        self.critic.parameters(),
        lr=self.critic_learning_rate,
        betas=(self.critic_beta1, self.critic_beta2),
        **optimizer_kwargs,
    )
    (
        self.actor,
        self.critic,
        self.actor_optimizer,
        self.critic_optimizer,
    ) = self.accelerator.prepare(
        self.actor,
        self.critic,
        self.actor_optimizer,
        self.critic_optimizer,
    )
    actor_for_kv = self.accelerator.unwrap_model(self.actor)
    if hasattr(actor_for_kv, "reset_kv_cache"):
        actor_for_kv.reset_kv_cache(self.env.num_envs, self.device)
    # Per-env flags for actor caches that must be cleared on the next step.
    self._kv_reset_pending = torch.zeros(
        self.env.num_envs, dtype=torch.bool, device=self.device
    )
    self._rollout_future_masks = None
    self._rollout_step_idx = 0
def _setup_data_buffers(self):
    """Set up rollout buffers plus the sensors used for aux state targets.

    After the base-class setup, the cached aux sensor handles are reset.
    Only when ``use_aux_state_pred`` is enabled does the method also:
      * switch ``transition_cls`` to ``PpoAuxTransition``;
      * reconfigure the ``height_scanner`` RayCaster (max distance, world
        ray alignment, z-only offset) when root-height prediction is on;
      * resolve contact-sensor indices for keybody contact targets;
      * resolve robot body indices for keybody relative-position targets.

    Raises:
        ValueError: If aux state prediction is combined with velocity
            tracking, a required sensor is missing, or a configured body
            name cannot be found.
    """

    def _indices_of(wanted, available, missing_template):
        # Map each requested name to its index in `available`, failing on
        # the first name that is absent.
        indices = []
        for name in wanted:
            if name not in available:
                raise ValueError(missing_template.format(name))
            indices.append(available.index(name))
        return torch.tensor(indices, dtype=torch.long, device=self.device)

    super()._setup_data_buffers()
    self._aux_height_scanner = None
    self._aux_contact_sensor = None
    self._aux_contact_body_ids = None
    self._aux_keybody_body_ids = None
    if not self.use_aux_state_pred:
        return
    if self.use_velocity_transition:
        raise ValueError(
            "aux_state_pred is not supported with velocity "
            "tracking in PPOTF."
        )
    self.transition_cls = PpoAuxTransition

    if self.use_aux_root_height:
        sensors = self.env._env.scene.sensors
        if "height_scanner" not in sensors:
            raise ValueError(
                "aux_state_pred requires a RayCaster sensor "
                "named 'height_scanner' "
                "in env.scene.sensors."
            )
        scanner = sensors["height_scanner"]
        scanner.cfg.max_distance = self.aux_state_pred_raycast_max_dist
        scanner.cfg.ray_alignment = "world"
        z_offset = self.aux_state_pred_raycast_z_offset
        scanner.cfg.offset.pos = (0.0, 0.0, z_offset)
        if scanner.is_initialized:
            # Already-built rays are patched directly; only the z component
            # is overwritten (x/y of ray_starts are left as they were).
            scanner.ray_starts[..., 2] = z_offset
        self._aux_height_scanner = scanner

    if self.aux_state_pred_num_contact_bodies > 0:
        sensors = self.env._env.scene.sensors
        if "contact_forces" not in sensors:
            raise ValueError(
                "aux_state_pred.keybody_contact_names requires "
                "a ContactSensor "
                "named 'contact_forces' in env.scene.sensors."
            )
        contact_sensor = sensors["contact_forces"]
        contact_ids = _indices_of(
            self.aux_state_pred_keybody_contact_names,
            list(contact_sensor.body_names),
            "Body '{}' not found in contact sensor body_names.",
        )
        self._aux_contact_sensor = contact_sensor
        self._aux_contact_body_ids = contact_ids

    if self.aux_state_pred_num_keybody_bodies > 0:
        self._aux_keybody_body_ids = _indices_of(
            self.aux_state_pred_keybody_rel_pos_names,
            list(self.env._env.scene["robot"].body_names),
            "Body '{}' not found in robot body_names.",
        )
def _build_transition(
    self,
    obs_td: TensorDict,
    actor_out: TensorDict,
    critic_out: TensorDict,
):
    """Assemble one rollout transition, including aux-state targets.

    Without ``use_aux_state_pred`` this defers to the base class. Otherwise
    it builds a ``self.transition_cls`` instance (``PpoAuxTransition`` as
    configured by ``_setup_data_buffers``). The instance carries the PPO
    fields plus ground-truth targets for the enabled aux heads;
    rewards/dones/returns/advantages are zero placeholders here. The
    targets are:
      * ``gt_base_lin_vel_b``: ``isaaclab_mdp.base_lin_vel``.
      * ``gt_root_height_rel_terrain``: root height above the terrain point
        hit by the first ``height_scanner`` ray, both measured in
        simulator-world z (zeros when disabled).
      * ``gt_keybody_contacts``: 1.0 where ``current_contact_time > 0``.
      * ``gt_ref_keybody_rel_pos`` / ``gt_robot_keybody_rel_pos``: the
        reference and robot keybody positions relative to the root.
      * ``gt_denoise_ref_*``: filtered (``ft_ref_``) minus raw (``ref_``)
        reference root lin/ang velocity and DoF positions (zeros when
        disabled).

    Raises:
        RuntimeError: If a required aux sensor / body-id cache was not set
            up, or the filtered reference tensors are unavailable.
        ValueError: If the denoise DoF target does not match the action
            dimension.
    """
    if not self.use_aux_state_pred:
        return super()._build_transition(obs_td, actor_out, critic_out)
    import isaaclab.envs.mdp as isaaclab_mdp

    actions = actor_out.get("actions")
    actions_log_prob = actor_out.get("actions_log_prob")
    mu = actor_out.get("mu")
    sigma = actor_out.get("sigma")
    values = critic_out.get("values")
    zero_scalar = torch.zeros(
        self.num_envs,
        1,
        device=self.device,
        dtype=torch.float32,
    )
    zero_scalar_bool = torch.zeros(
        self.num_envs,
        1,
        device=self.device,
        dtype=torch.bool,
    )
    gt_base_lin_vel_b = isaaclab_mdp.base_lin_vel(self.env._env)

    if self.use_aux_root_height:
        if self._aux_height_scanner is None:
            raise RuntimeError(
                "Aux state prediction expected "
                "_aux_height_scanner to be initialized."
            )
        # Frame note: `isaaclab_mdp.root_pos_w` is expressed in the env
        # frame (simulator world minus `scene.env_origins`), whereas
        # `ray_hits_w` and `env_origins` are simulator-world. Lift the root
        # z back to world so both heights share a frame; otherwise envs
        # whose origin has non-zero z get a target offset by that height.
        env_origin_z = self.env._env.scene.env_origins[:, 2:3]
        root_z_w = (
            isaaclab_mdp.root_pos_w(self.env._env)[:, 2:3] + env_origin_z
        )
        terrain_z = self._aux_height_scanner.data.ray_hits_w[:, 0, 2:3]
        # Non-finite hits (presumably rays that missed) fall back to the
        # env origin height.
        terrain_z = torch.where(
            torch.isfinite(terrain_z), terrain_z, env_origin_z
        )
        gt_root_height_rel_terrain = root_z_w - terrain_z
    else:
        gt_root_height_rel_terrain = torch.zeros(
            self.num_envs, 1, device=self.device, dtype=torch.float32
        )

    if self.aux_state_pred_num_contact_bodies > 0:
        if (
            self._aux_contact_sensor is None
            or self._aux_contact_body_ids is None
        ):
            raise RuntimeError(
                "Aux keybody contact prediction expects contact sensor "
                "and body ids to be initialized."
            )
        contact_time = self._aux_contact_sensor.data.current_contact_time[
            :, self._aux_contact_body_ids
        ]
        gt_keybody_contacts = (contact_time > 0.0).to(torch.float32)
    else:
        gt_keybody_contacts = torch.zeros(
            self.num_envs, 0, device=self.device, dtype=torch.float32
        )

    command = self.env._env.command_manager.get_term(self.command_name)
    if self.aux_state_pred_num_keybody_bodies > 0:
        if self._aux_keybody_body_ids is None:
            raise RuntimeError(
                "Aux keybody position prediction expects body "
                "ids to be initialized."
            )
        # Both the ref-motion command and robot asset expose bodies in
        # simulator order, so the cached robot body indices align here.
        gt_ref_keybody_rel_pos = (
            command.get_ref_motion_bodylink_rel_pos_cur()[
                :, self._aux_keybody_body_ids, :
            ]
        )
        robot = self.env._env.scene["robot"]
        robot_keybody_global_pos = robot.data.body_pos_w[
            :, self._aux_keybody_body_ids, :
        ]
        env_origins = self.env._env.scene.env_origins
        # Env-frame root position (see the frame note above).
        root_pos_env = isaaclab_mdp.root_pos_w(self.env._env)
        root_quat_w = isaaclab_mdp.root_quat_w(self.env._env)
        gt_robot_keybody_rel_pos = (
            self._root_relative_body_pos_from_mixed_position_frames(
                body_pos_w=robot_keybody_global_pos,
                root_pos_env=root_pos_env,
                root_quat_w=root_quat_w,
                env_origins=env_origins,
            )
        )
    else:
        gt_ref_keybody_rel_pos = torch.zeros(
            self.num_envs, 0, 3, device=self.device, dtype=torch.float32
        )
        gt_robot_keybody_rel_pos = torch.zeros(
            self.num_envs, 0, 3, device=self.device, dtype=torch.float32
        )

    gt_denoise_ref_root_lin_vel = torch.zeros(
        self.num_envs, 3, device=self.device, dtype=torch.float32
    )
    gt_denoise_ref_root_ang_vel = torch.zeros(
        self.num_envs, 3, device=self.device, dtype=torch.float32
    )
    gt_denoise_ref_dof_pos = torch.zeros(
        self.num_envs,
        actions.shape[-1],
        device=self.device,
        dtype=torch.float32,
    )
    if (
        self.use_aux_denoise_ref_root_lin_vel
        or self.use_aux_denoise_ref_root_ang_vel
        or self.use_aux_denoise_ref_dof_pos
    ):
        try:
            if self.use_aux_denoise_ref_root_lin_vel:
                gt_denoise_ref_root_lin_vel = (
                    command.get_ref_motion_base_linvel_cur(
                        prefix="ft_ref_"
                    )
                    - command.get_ref_motion_base_linvel_cur(prefix="ref_")
                )
            if self.use_aux_denoise_ref_root_ang_vel:
                gt_denoise_ref_root_ang_vel = (
                    command.get_ref_motion_base_angvel_cur(
                        prefix="ft_ref_"
                    )
                    - command.get_ref_motion_base_angvel_cur(prefix="ref_")
                )
            if self.use_aux_denoise_ref_dof_pos:
                gt_denoise_ref_dof_pos = (
                    command.get_ref_motion_dof_pos_cur(prefix="ft_ref_")
                    - command.get_ref_motion_dof_pos_cur(prefix="ref_")
                )
                expected_shape = (self.num_envs, actions.shape[-1])
                if tuple(gt_denoise_ref_dof_pos.shape) != expected_shape:
                    raise ValueError(
                        "gt_denoise_ref_dof_pos must match the action-aligned "
                        "DoF shape "
                        f"{expected_shape}, got "
                        f"{tuple(gt_denoise_ref_dof_pos.shape)}."
                    )
        except KeyError as exc:
            raise RuntimeError(
                "Filtered reference tensors are unavailable for "
                "aux_denoise_* targets. Enable online filtering or "
                "materialize ft_ref_* tensors in the motion cache."
            ) from exc

    return self.transition_cls(
        obs=obs_td,
        actions=actions.detach(),
        teacher_actions=torch.zeros_like(actions),
        mu=mu.detach(),
        sigma=sigma.detach(),
        actions_log_prob=actions_log_prob[..., None].detach(),
        values=values.detach(),
        rewards=zero_scalar.clone(),
        dones=zero_scalar_bool,
        returns=zero_scalar.clone(),
        advantages=zero_scalar.clone(),
        gt_base_lin_vel_b=gt_base_lin_vel_b.detach(),
        gt_root_height_rel_terrain=gt_root_height_rel_terrain.detach(),
        gt_keybody_contacts=gt_keybody_contacts.detach(),
        gt_ref_keybody_rel_pos=gt_ref_keybody_rel_pos.detach(),
        gt_robot_keybody_rel_pos=gt_robot_keybody_rel_pos.detach(),
        gt_denoise_ref_root_lin_vel=gt_denoise_ref_root_lin_vel.detach(),
        gt_denoise_ref_root_ang_vel=gt_denoise_ref_root_ang_vel.detach(),
        gt_denoise_ref_dof_pos=gt_denoise_ref_dof_pos.detach(),
        batch_size=[self.num_envs],
        device=self.device,
    )
def _build_storage(self, obs_td: TensorDict):
    """Create rollout storage, adding a ``future_mask`` entry when needed.

    If the actor module uses future cross-attention, a non-recursive clone
    of ``obs_td`` gets an all-True ``future_mask`` of shape
    ``[num_envs, future_seq_len]`` before the base class builds storage
    from it. Presumably this lets storage reserve a slot for the mask that
    ``_rollout_forward`` later sets. The caller's TensorDict itself is not
    modified. Without cross-attention, ``obs_td`` is passed through as-is.

    Raises:
        ValueError: If cross-attention is enabled but ``future_seq_len`` is
            not positive.
    """
    policy = self.accelerator.unwrap_model(self.actor).actor_module
    if not bool(getattr(policy, "use_future_cross_attn", False)):
        return super()._build_storage(obs_td)
    future_len = int(getattr(policy, "future_seq_len", 0))
    if future_len <= 0:
        raise ValueError(
            "future_seq_len must be positive when "
            "use_future_cross_attn=True"
        )
    template_td = obs_td.clone(recurse=False)
    all_visible = torch.ones(
        self.env.num_envs,
        future_len,
        dtype=torch.bool,
        device=self.device,
    )
    template_td.set("future_mask", all_visible)
    return super()._build_storage(template_td)
def _sample_iteration_future_masks(self) -> torch.Tensor | None:
actor_for_kv = self.accelerator.unwrap_model(self.actor)
actor_policy = actor_for_kv.actor_module
if not bool(getattr(actor_policy, "use_future_cross_attn", False)):
return None
n_fut = int(getattr(actor_policy, "future_seq_len", 0))
if n_fut <= 0:
raise ValueError(
"future_seq_len must be positive when "
"use_future_cross_attn=True"
)
if self._future_mask_mode != "random_suffix":
raise ValueError(
"Unsupported future_mask_mode: "
f"{self._future_mask_mode}. "
"Expected 'random_suffix'."
)
num_steps = int(self.num_steps_per_env)
num_envs = int(self.env.num_envs)
keep = torch.ones(
num_steps,
num_envs,
n_fut,
dtype=torch.bool,
device=self.device,
)
if bool(getattr(self, "_offline_evaluating", False)):
return keep
if self._future_mask_prob <= 0.0:
return keep
apply_mask = (
torch.rand(num_steps, num_envs, device=self.device)
< self._future_mask_prob
)
keep_len = torch.randint(
1,
n_fut + 1,
(num_steps, num_envs),
device=self.device,
)
full_len = torch.full(
(num_steps, num_envs),
n_fut,
dtype=torch.long,
device=self.device,
)
keep_len = torch.where(apply_mask, keep_len, full_len)
token_idx = torch.arange(n_fut, device=self.device, dtype=torch.long)[
None, None, :
]
return token_idx < keep_len[:, :, None]
def _reset_rollout_forward_state(self) -> None:
actor_for_kv = self.accelerator.unwrap_model(self.actor)
actor_for_kv.clear_env_cache(None)
actor_policy = actor_for_kv.actor_module
actor_policy.reset_routing_stats()
actor_policy.set_collect_routing_stats(True)
self._kv_reset_pending.zero_()
self._rollout_future_masks = self._sample_iteration_future_masks()
self._rollout_step_idx = 0
def _rollout_forward(
    self,
    obs_td: TensorDict,
    *,
    actor_mode: str = "sampling",
    collect_transition: bool = True,
    track_episode_stats: bool = True,
) -> TensorDict:
    """Run one rollout step with future masking and per-env cache resets.

    Before the base-class step, two things happen:
      * when collecting transitions with pre-sampled future masks, the
        current step's mask is attached to a non-recursive clone of
        ``obs_td`` under ``"future_mask"``;
      * the actor cache is cleared for every env flagged in
        ``_kv_reset_pending``, and those flags are cleared.

    After the step, the step index advances when masks are in use. When
    not collecting transitions, ``_last_rollout_dones`` is OR-ed into
    ``_kv_reset_pending`` instead. When collecting, ``process_env_step``
    performs that flagging.

    Raises:
        RuntimeError: If more steps are taken than masks were pre-sampled.
    """
    if collect_transition and self._rollout_future_masks is not None:
        step = self._rollout_step_idx
        if step >= int(self._rollout_future_masks.shape[0]):
            raise RuntimeError(
                "Rollout future-mask step index exceeded "
                "pre-sampled mask length."
            )
        obs_td = obs_td.clone(recurse=False)
        obs_td.set("future_mask", self._rollout_future_masks[step])

    wrapper = self.accelerator.unwrap_model(self.actor)
    pending = self._kv_reset_pending
    if torch.any(pending):
        reset_ids = torch.nonzero(pending).squeeze(-1)
        if reset_ids.numel() > 0:
            wrapper.clear_env_cache(reset_ids)
            pending[reset_ids] = False

    next_obs_td = super()._rollout_forward(
        obs_td,
        actor_mode=actor_mode,
        collect_transition=collect_transition,
        track_episode_stats=track_episode_stats,
    )

    if collect_transition:
        if self._rollout_future_masks is not None:
            self._rollout_step_idx += 1
    else:
        last_dones = self._last_rollout_dones
        if last_dones is not None:
            self._kv_reset_pending |= (
                last_dones.view(-1).to(torch.bool).to(self.device)
            )
    return next_obs_td
def process_env_step(
    self,
    rewards: torch.Tensor,
    dones: torch.Tensor,
    time_outs: torch.Tensor,
    infos: dict,
) -> None:
    """Run base-class step processing, then flag finished envs for reset.

    Envs whose ``dones`` entry is truthy are OR-ed into
    ``_kv_reset_pending``, so ``_rollout_forward`` clears their actor cache
    on the next step. Flagging is skipped while that buffer is absent or
    ``None``, e.g. before ``_setup_models_and_optimizer`` has created it.
    """
    super().process_env_step(rewards, dones, time_outs, infos)
    if getattr(self, "_kv_reset_pending", None) is None:
        return
    finished = dones.view(-1).to(torch.bool).to(self.device)
    self._kv_reset_pending |= finished
@staticmethod
def _build_episode_causal_mask(dones_seq: torch.Tensor) -> torch.Tensor:
"""Build [N, T, T] mask: causal and within the same episode segment."""
n, t, _ = dones_seq.shape
device = dones_seq.device
dones = dones_seq.squeeze(-1).to(torch.long)
seg = torch.cumsum(dones, dim=1) - dones
same = seg[:, :, None] == seg[:, None, :]
causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=device))
return same & causal
@staticmethod
def _resolve_sequence_batch_partition(
num_envs: int,
num_mini_batches: int,
) -> tuple[int, int]:
if num_envs <= 0:
raise RuntimeError(
"PPOTF sequence batching requires at least one "
"environment on each rank."
)
effective_num_mini_batches = max(
1, min(int(num_mini_batches), int(num_envs))
)
mini_batch_envs = max(
1,
(num_envs + effective_num_mini_batches - 1)
// effective_num_mini_batches,
)
return effective_num_mini_batches, mini_batch_envs
def _sequence_batches(
self, num_mini_batches: int, num_epochs: int
) -> Generator[tuple, None, None]:
data = self.storage.data
obs_seq = data["obs"].transpose(0, 1)
actions_seq = data["actions"].transpose(0, 1)
values_seq = data["values"].transpose(0, 1)
rewards_seq = data["rewards"].transpose(0, 1)
returns_seq = data["returns"].transpose(0, 1)
adv_seq = data["advantages"].transpose(0, 1)
old_logp_seq = data["actions_log_prob"].transpose(0, 1)
old_mu_seq = data["mu"].transpose(0, 1)
old_sigma_seq = data["sigma"].transpose(0, 1)
dones_seq = data["dones"].transpose(0, 1)
gt_base_lin_vel_seq = None
gt_root_height_seq = None
gt_keybody_contact_seq = None
gt_ref_keybody_rel_pos_seq = None
gt_robot_keybody_rel_pos_seq = None
gt_denoise_ref_root_lin_vel_seq = None
gt_denoise_ref_root_ang_vel_seq = None
gt_denoise_ref_dof_pos_seq = None
if self.use_aux_state_pred:
gt_base_lin_vel_seq = data["gt_base_lin_vel_b"].transpose(0, 1)
gt_root_height_seq = data["gt_root_height_rel_terrain"].transpose(
0, 1
)
gt_keybody_contact_seq = data["gt_keybody_contacts"].transpose(
0, 1
)
gt_ref_keybody_rel_pos_seq = data[
"gt_ref_keybody_rel_pos"
].transpose(0, 1)
gt_robot_keybody_rel_pos_seq = data[
"gt_robot_keybody_rel_pos"
].transpose(0, 1)
gt_denoise_ref_root_lin_vel_seq = data[
"gt_denoise_ref_root_lin_vel"
].transpose(0, 1)
gt_denoise_ref_root_ang_vel_seq = data[
"gt_denoise_ref_root_ang_vel"
].transpose(0, 1)
gt_denoise_ref_dof_pos_seq = data[
"gt_denoise_ref_dof_pos"
].transpose(0, 1)
num_envs = int(actions_seq.shape[0])
if num_envs <= 0:
raise RuntimeError(
"PPOTF sequence batching requires at least one "
"environment on each rank, "
f"got num_envs={num_envs}."
)
num_mini_batches, mb_env = self._resolve_sequence_batch_partition(
num_envs, num_mini_batches
)
env_indices = torch.randperm(num_envs, device=self.device)
for _ in range(num_epochs):
for i in range(num_mini_batches):
start = i * mb_env
if start >= num_envs:
break
end = min(num_envs, (i + 1) * mb_env)
idx = env_indices[start:end]
obs_b = obs_seq[idx]
actions_b = actions_seq[idx]
values_b = values_seq[idx]
rewards_b = rewards_seq[idx]
returns_b = returns_seq[idx]
adv_b = adv_seq[idx]
old_logp_b = old_logp_seq[idx]
old_mu_b = old_mu_seq[idx]
old_sigma_b = old_sigma_seq[idx]
dones_b = dones_seq[idx]
gt_base_lin_vel_b = (
gt_base_lin_vel_seq[idx]
if gt_base_lin_vel_seq is not None
else None
)
gt_root_height_b = (
gt_root_height_seq[idx]
if gt_root_height_seq is not None
else None
)
gt_keybody_contact_b = (
gt_keybody_contact_seq[idx]
if gt_keybody_contact_seq is not None
else None
)
gt_ref_keybody_rel_pos_b = (
gt_ref_keybody_rel_pos_seq[idx]
if gt_ref_keybody_rel_pos_seq is not None
else None
)
gt_robot_keybody_rel_pos_b = (
gt_robot_keybody_rel_pos_seq[idx]
if gt_robot_keybody_rel_pos_seq is not None
else None
)
gt_denoise_ref_root_lin_vel_b = (
gt_denoise_ref_root_lin_vel_seq[idx]
if gt_denoise_ref_root_lin_vel_seq is not None
else None
)
gt_denoise_ref_root_ang_vel_b = (
gt_denoise_ref_root_ang_vel_seq[idx]
if gt_denoise_ref_root_ang_vel_seq is not None
else None
)
gt_denoise_ref_dof_pos_b = (
gt_denoise_ref_dof_pos_seq[idx]
if gt_denoise_ref_dof_pos_seq is not None
else None
)
attn_mask = self._build_episode_causal_mask(dones_b)
yield (
obs_b,
actions_b,
values_b,
adv_b,
returns_b,
rewards_b,
old_logp_b,
old_mu_b,
old_sigma_b,
attn_mask,
gt_base_lin_vel_b,
gt_root_height_b,
gt_keybody_contact_b,
gt_ref_keybody_rel_pos_b,
gt_robot_keybody_rel_pos_b,
gt_denoise_ref_root_lin_vel_b,
gt_denoise_ref_root_ang_vel_b,
gt_denoise_ref_dof_pos_b,
)
def update(self):
actor_unwrapped = self.accelerator.unwrap_model(self.actor)
actor_policy = actor_unwrapped.actor_module
actor_policy.set_collect_routing_stats(False)
mean_value_loss = 0.0
mean_surrogate_loss = 0.0
mean_entropy = 0.0
mean_kl_token = 0.0
mean_kl_loss = 0.0
mean_kl_analytic = 0.0
critic_explained_variance = self._compute_explained_variance(
target=self.storage.data["returns"],
prediction=self.storage.data["values"],
)
mean_aux_base_lin_vel_nll = 0.0
mean_aux_root_height_nll = 0.0
mean_aux_base_lin_vel_std = 0.0
mean_aux_root_height_std = 0.0
mean_aux_keybody_contact_bce = 0.0
mean_aux_keybody_contact_acc = 0.0
mean_aux_ref_keybody_rel_pos_mse = 0.0
mean_aux_robot_keybody_rel_pos_mse = 0.0
mean_aux_denoise_ref_root_lin_vel_huber = 0.0
mean_aux_denoise_ref_root_ang_vel_huber = 0.0
mean_aux_denoise_ref_dof_pos_huber = 0.0
mean_aux_router_command_recon_mse = 0.0
mean_aux_router_future_recon_huber = 0.0
mean_aux_router_switch_penalty_js = 0.0
mean_dead_expert_margin_to_topk_loss = 0.0
mean_router_expert_orthogonal_loss = 0.0
mean_selected_expert_margin_to_unselected_loss = 0.0
moe_layers = [
layer
for layer in actor_policy.layers
if isinstance(layer, GroupedMoEBlock)
]
(
effective_num_mini_batches,
mini_batch_envs,
) = self._resolve_sequence_batch_partition(
self.storage.num_envs, self.num_mini_batches
)
self._last_update_metrics = {
"0-Train/configured_num_mini_batches": float(
self.configured_num_mini_batches
),
"0-Train/requested_num_mini_batches": float(
self.requested_num_mini_batches
),
"0-Train/effective_num_mini_batches": float(
effective_num_mini_batches
),
"0-Train/mini_batch_size_per_rank": float(
mini_batch_envs * self.num_steps_per_env
),
"0-Train/mini_batch_num_envs_per_rank": float(mini_batch_envs),
"0-Train/num_updates_executed": 0.0,
"0-Train/lr_scale_factor": float(self.distributed_lr_scale_factor),
"0-Train/scalable_distributed_update": float(
self.distributed_update_mode == "scalable"
),
"0-Train/kl_windowed": 0.0,
"0-Train/kl_stop_triggered": 0.0,
"0-Train/kl_stop_analytic": 0.0,
"0-Train/kl_analytic_batch_last": 0.0,
"0-Train/kl_analytic_batch_max": 0.0,
"0-Train/clip_fraction_batch_mean": 0.0,
"0-Train/clip_fraction_batch_last": 0.0,
}
entropy_coef = self._get_effective_entropy_coef()
generator = self._sequence_batches(
effective_num_mini_batches,
self.num_learning_epochs,
)
measure_analytic_kl = self.desired_kl is not None
normalize_per_mb = bool(self.normalize_advantage_per_mini_batch)
num_updates = 0
num_kl_measurements = 0
kl_stop_triggered = False
kl_stop_analytic = 0.0
kl_windowed = None
recent_analytic_kls: list[float] = []
kl_analytic_batch_last = 0.0
kl_analytic_batch_max = 0.0
clip_fraction_batch_mean = 0.0
clip_fraction_batch_last = 0.0
for (
obs_b,
actions_b,
target_values_b,
advantages_b,
returns_b,
_rewards_b,
old_logp_b,
old_mu_b,
old_sigma_b,
attn_mask_b,
gt_base_lin_vel_b,
gt_root_height_b,
gt_keybody_contact_b,
gt_ref_keybody_rel_pos_b,
gt_robot_keybody_rel_pos_b,
gt_denoise_ref_root_lin_vel_b,
gt_denoise_ref_root_ang_vel_b,
gt_denoise_ref_dof_pos_b,
) in generator:
valid_tok = attn_mask_b.diagonal(dim1=1, dim2=2).to(torch.float32)
valid_count = valid_tok.sum().clamp_min(1.0)
if normalize_per_mb:
with torch.no_grad():
flat = advantages_b.view(-1).float()
if self.global_advantage_norm and self.is_distributed:
count = torch.tensor(
[flat.numel()],
device=self.device,
dtype=torch.float32,
)
sum_g = self.accelerator.reduce(
flat.sum(), reduction="sum"
)
sqsum_g = self.accelerator.reduce(
(flat * flat).sum(), reduction="sum"
)
count_g = self.accelerator.reduce(
count, reduction="sum"
)
mean = sum_g / count_g
var = (sqsum_g / count_g) - mean * mean
std = torch.sqrt(var.clamp_min(1.0e-8))
else:
mean = flat.mean()
std = flat.std().clamp_min(1.0e-8)
advantages_b = (advantages_b - mean) / std
b, t = int(obs_b.batch_size[0]), int(obs_b.batch_size[1])
critic_obs_flat = obs_b.flatten(0, 1)
with self.accelerator.autocast():
actor_out = self.actor(
obs_b,
actions=actions_b,
mode="sequence_logp",
attn_mask=attn_mask_b,
update_obs_norm=False,
)
critic_out = self.critic(
critic_obs_flat, update_obs_norm=False
)
logp_new_b = actor_out.get("actions_log_prob")
mu_b = actor_out.get("mu")
sigma_b = actor_out.get("sigma")
entropy_b = actor_out.get("entropy")
v_pred_flat = critic_out.get("values")
value_batch = v_pred_flat.reshape(b, t, -1)
returns_batch_norm = returns_b
target_values_batch_norm = target_values_b
analytic_kl = None
if measure_analytic_kl:
analytic_kl = self._compute_analytic_kl(
old_mu=old_mu_b.float(),
old_sigma=old_sigma_b.float(),
new_mu=mu_b.float(),
new_sigma=sigma_b.float(),
weight=valid_tok,
)
mean_kl_analytic += analytic_kl
num_kl_measurements += 1
kl_analytic_batch_last = analytic_kl
kl_analytic_batch_max = max(kl_analytic_batch_max, analytic_kl)
recent_analytic_kls.append(analytic_kl)
if len(recent_analytic_kls) > self.kl_early_stop_window_size:
recent_analytic_kls.pop(0)
kl_windowed = self._compute_windowed_kl_signal(
recent_analytic_kls
)
if self._should_early_stop_for_kl(
kl_windowed, num_kl_measurements
):
kl_stop_triggered = True
kl_stop_analytic = analytic_kl
break
logp_new = logp_new_b.squeeze(-1).float()
logp_old = old_logp_b.squeeze(-1).float()
ratio = torch.exp(logp_new - logp_old)
clip_fraction = self._compute_clip_fraction(
ratio, weight=valid_tok
)
clip_fraction_batch_mean += clip_fraction
clip_fraction_batch_last = clip_fraction
adv = advantages_b.squeeze(-1)
s1 = ratio * adv
s2 = (
torch.clamp(
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
)
* adv
)
surrogate_loss = (
-torch.min(s1, s2) * valid_tok
).sum() / valid_count
if self.use_clipped_value_loss:
value_clipped = target_values_batch_norm + (
value_batch - target_values_batch_norm
).clamp(-self.clip_param, self.clip_param)
value_losses = (value_batch - returns_batch_norm).pow(2)
value_losses_clipped = (
value_clipped - returns_batch_norm
).pow(2)
v_max = torch.max(value_losses, value_losses_clipped).squeeze(
-1
)
value_loss = (v_max * valid_tok).sum() / valid_count
else:
v_err = (returns_batch_norm - value_batch).pow(2).squeeze(-1)
value_loss = (v_err * valid_tok).sum() / valid_count
actor_loss = surrogate_loss
critic_loss = self.value_loss_coef * value_loss
aux_base_lin_vel_loss = None
aux_root_height_loss = None
aux_base_lin_vel_std = None
aux_root_height_std = None
aux_keybody_contact_loss = None
aux_keybody_contact_acc = None
aux_ref_keybody_rel_pos_loss = None
aux_robot_keybody_rel_pos_loss = None
aux_denoise_ref_root_lin_vel_loss = None
aux_denoise_ref_root_ang_vel_loss = None
aux_denoise_ref_dof_pos_loss = None
aux_router_command_recon_loss = None
aux_router_future_recon_loss = None
aux_router_switch_penalty_loss = None
dead_expert_margin_to_topk_loss = None
router_expert_orthogonal_loss = None
selected_expert_margin_to_unselected_loss = None
if self.use_aux_state_pred:
aux_base_lin_vel_loc = actor_out.get("aux_base_lin_vel_loc")
aux_base_lin_vel_log_std = actor_out.get(
"aux_base_lin_vel_log_std"
)
aux_base_lin_vel_std = torch.clamp(
torch.exp(aux_base_lin_vel_log_std),
min=self.aux_state_pred_min_std,
max=self.aux_state_pred_max_std,
)
aux_base_lin_vel_nll = 0.5 * (
torch.square(
(gt_base_lin_vel_b - aux_base_lin_vel_loc)
/ aux_base_lin_vel_std
)
+ 2.0 * torch.log(aux_base_lin_vel_std + 1.0e-8)
).sum(dim=-1)
aux_base_lin_vel_loss = (
aux_base_lin_vel_nll * valid_tok
).sum() / valid_count
actor_loss = (
actor_loss
+ self.aux_state_pred_w_base_lin_vel
* aux_base_lin_vel_loss
)
aux_root_height_loc = actor_out.get("aux_root_height_loc")
aux_root_height_log_std = actor_out.get(
"aux_root_height_log_std"
)
if self.use_aux_root_height and gt_root_height_b is not None:
aux_root_height_std = torch.clamp(
torch.exp(aux_root_height_log_std),
min=self.aux_state_pred_min_std,
max=self.aux_state_pred_max_std,
)
aux_root_height_nll = 0.5 * (
torch.square(
(gt_root_height_b - aux_root_height_loc)
/ aux_root_height_std
)
+ 2.0 * torch.log(aux_root_height_std + 1.0e-8)
).sum(dim=-1)
aux_root_height_loss = (
aux_root_height_nll * valid_tok
).sum() / valid_count
actor_loss = (
actor_loss
+ self.aux_state_pred_w_root_height
* aux_root_height_loss
)
else:
actor_loss = actor_loss + 0.0 * (
aux_root_height_loc.sum()
+ aux_root_height_log_std.sum()
)
if (
self.aux_state_pred_num_contact_bodies > 0
and gt_keybody_contact_b is not None
):
aux_keybody_contact_logits = actor_out.get(
"aux_keybody_contact_logits"
)
contact_bce = F.binary_cross_entropy_with_logits(
aux_keybody_contact_logits,
gt_keybody_contact_b,
reduction="none",
).mean(dim=-1)
aux_keybody_contact_loss = (
contact_bce * valid_tok
).sum() / valid_count
actor_loss = (
actor_loss
+ self.aux_state_pred_w_keybody_contact
* aux_keybody_contact_loss
)
contact_pred = (aux_keybody_contact_logits > 0.0).to(
gt_keybody_contact_b.dtype
)
contact_acc_tok = (
(contact_pred == gt_keybody_contact_b)
.to(torch.float32)
.mean(dim=-1)
)
aux_keybody_contact_acc = (
contact_acc_tok * valid_tok
).sum() / valid_count
aux_ref_keybody_rel_pos = actor_out.get(
"aux_ref_keybody_rel_pos"
)
aux_robot_keybody_rel_pos = actor_out.get(
"aux_robot_keybody_rel_pos"
)
if (
self.aux_state_pred_num_keybody_bodies > 0
and gt_ref_keybody_rel_pos_b is not None
):
aux_ref_keybody_rel_pos_loss = (
self._masked_aux_keybody_mse(
aux_ref_keybody_rel_pos,
gt_ref_keybody_rel_pos_b,
valid_tok,
)
)
actor_loss = (
actor_loss
+ self.aux_state_pred_w_ref_keybody_rel_pos
* aux_ref_keybody_rel_pos_loss
)
elif aux_ref_keybody_rel_pos.numel() > 0:
actor_loss = (
actor_loss + 0.0 * aux_ref_keybody_rel_pos.sum()
)
if (
self.aux_state_pred_num_keybody_bodies > 0
and gt_robot_keybody_rel_pos_b is not None
):
aux_robot_keybody_rel_pos_loss = (
self._masked_aux_keybody_mse(
aux_robot_keybody_rel_pos,
gt_robot_keybody_rel_pos_b,
valid_tok,
)
)
actor_loss = (
actor_loss
+ self.aux_state_pred_w_robot_keybody_rel_pos
* aux_robot_keybody_rel_pos_loss
)
elif aux_robot_keybody_rel_pos.numel() > 0:
actor_loss = (
actor_loss + 0.0 * aux_robot_keybody_rel_pos.sum()
)
if self.use_aux_denoise_ref_root_lin_vel:
aux_denoise_ref_root_lin_vel_residual = actor_out.get(
"aux_denoise_ref_root_lin_vel_residual"
)
aux_denoise_ref_root_lin_vel_loss = self._masked_aux_huber(
pred=aux_denoise_ref_root_lin_vel_residual,
target=gt_denoise_ref_root_lin_vel_b,
valid_tok=valid_tok,
beta=self.aux_denoise_residual_huber_beta,
)
actor_loss = (
actor_loss
+ self.aux_state_pred_w_denoise_ref_root_lin_vel
* aux_denoise_ref_root_lin_vel_loss
)
if self.use_aux_denoise_ref_root_ang_vel:
aux_denoise_ref_root_ang_vel_residual = actor_out.get(
"aux_denoise_ref_root_ang_vel_residual"
)
aux_denoise_ref_root_ang_vel_loss = self._masked_aux_huber(
pred=aux_denoise_ref_root_ang_vel_residual,
target=gt_denoise_ref_root_ang_vel_b,
valid_tok=valid_tok,
beta=self.aux_denoise_residual_huber_beta,
)
actor_loss = (
actor_loss
+ self.aux_state_pred_w_denoise_ref_root_ang_vel
* aux_denoise_ref_root_ang_vel_loss
)
if self.use_aux_denoise_ref_dof_pos:
aux_denoise_ref_dof_pos_residual = actor_out.get(
"aux_denoise_ref_dof_pos_residual"
)
aux_denoise_ref_dof_pos_loss = self._masked_aux_huber(
pred=aux_denoise_ref_dof_pos_residual,
target=gt_denoise_ref_dof_pos_b,
valid_tok=valid_tok,
beta=self.aux_denoise_residual_huber_beta,
)
actor_loss = (
actor_loss
+ self.aux_state_pred_w_denoise_ref_dof_pos
* aux_denoise_ref_dof_pos_loss
)
if self.use_aux_router_command_recon:
if self.aux_router_command_recon_assembler is None:
raise ValueError(
"aux_router_command_recon is enabled but command "
"assembler was not initialized."
)
aux_router_command_recon_pred = actor_out.get(
"aux_router_command_recon"
)
gt_aux_router_command_recon_b = (
self.aux_router_command_recon_assembler(
obs_b.flatten(0, 1)
).reshape(b, t, -1)
)
aux_router_command_recon_loss = self._masked_aux_mse(
aux_router_command_recon_pred,
gt_aux_router_command_recon_b,
valid_tok,
)
actor_loss = (
actor_loss
+ self.aux_router_command_recon_weight
* aux_router_command_recon_loss
)
if self.use_aux_router_future_recon:
aux_router_future_recon_loss = (
self._compute_aux_router_future_recon_loss(
actor_wrapper=actor_unwrapped,
actor_out=actor_out,
obs_b=obs_b,
valid_tok=valid_tok,
)
)
actor_loss = (
actor_loss
+ self.aux_router_future_recon_weight
* aux_router_future_recon_loss
)
if self.use_aux_router_switch_penalty:
if self.aux_router_switch_penalty_metric == "js":
aux_router_features = actor_out.get("router_features")
aux_router_switch_penalty_loss = self._masked_adjacent_router_js(
router_features=aux_router_features,
valid_tok=valid_tok,
num_moe_layers=self.aux_command_router_num_moe_layers,
num_fine_experts=self.aux_command_router_num_fine_experts,
)
else:
aux_router_temporal_features = actor_out.get(
"router_temporal_features"
)
aux_router_switch_penalty_loss = self._masked_adjacent_router_normed_smooth_l1(
router_temporal_features=aux_router_temporal_features,
valid_tok=valid_tok,
num_moe_layers=self.aux_command_router_num_moe_layers,
num_fine_experts=self.aux_command_router_num_fine_experts,
beta=self.aux_router_switch_penalty_beta,
)
aux_router_switch_penalty_loss = (
aux_router_switch_penalty_loss.to(actor_loss.dtype)
)
actor_loss = (
actor_loss
+ self.aux_router_switch_penalty_weight
* aux_router_switch_penalty_loss
)
if self.use_dead_expert_margin_to_topk and len(moe_layers) > 0:
margin_losses = [
layer.last_dead_expert_margin_to_topk_loss
for layer in moe_layers
if layer.last_dead_expert_margin_to_topk_loss is not None
]
if len(margin_losses) > 0:
dead_expert_margin_to_topk_loss = torch.stack(
[
loss.to(actor_loss.device, dtype=actor_loss.dtype)
for loss in margin_losses
]
).mean()
actor_loss = (
actor_loss
+ self.dead_expert_margin_to_topk_weight
* dead_expert_margin_to_topk_loss
)
if self.use_router_expert_orthogonal and len(moe_layers) > 0:
orth_losses = []
for layer in moe_layers:
layer_orth_loss, _, _ = (
self._compute_routed_expert_orthogonal_loss(
layer,
dtype=actor_loss.dtype,
device=actor_loss.device,
)
)
orth_losses.append(layer_orth_loss)
if len(orth_losses) > 0:
router_expert_orthogonal_loss = torch.stack(
orth_losses
).mean()
actor_loss = (
actor_loss
+ self.router_expert_orthogonal_weight
* router_expert_orthogonal_loss
)
if (
self.use_selected_expert_margin_to_unselected
and len(moe_layers) > 0
):
selected_margin_losses = [
layer.last_selected_expert_margin_to_unselected_loss
for layer in moe_layers
if layer.last_selected_expert_margin_to_unselected_loss
is not None
]
if len(selected_margin_losses) > 0:
selected_expert_margin_to_unselected_loss = torch.stack(
[
loss.to(actor_loss.device, dtype=actor_loss.dtype)
for loss in selected_margin_losses
]
).mean()
actor_loss = (
actor_loss
+ self.selected_expert_margin_to_unselected_weight
* selected_expert_margin_to_unselected_loss
)
kl_coef = float(
getattr(self.config, "kl_coef", self.desired_kl or 0.0) or 0.0
)
if kl_coef > 0.0:
delta_logp = logp_new - logp_old
kl_token = (
ratio.detach() * delta_logp * valid_tok
).sum() / valid_count
kl_loss = kl_coef * kl_token
actor_loss = actor_loss + kl_loss
mean_kl_token += float(kl_token.item())
mean_kl_loss += float(kl_loss.item())
if entropy_coef > 0.0:
ent_tok = entropy_b.squeeze(-1)
entropy_loss = (ent_tok * valid_tok).sum() / valid_count
actor_loss = actor_loss - entropy_coef * entropy_loss
self.actor_optimizer.zero_grad()
self.critic_optimizer.zero_grad()
self.accelerator.backward(actor_loss)
self.accelerator.backward(critic_loss)
if self.max_grad_norm is not None:
self.accelerator.clip_grad_norm_(
self.actor.parameters(), self.max_grad_norm
)
self.accelerator.clip_grad_norm_(
self.critic.parameters(), self.max_grad_norm
)
self.actor_optimizer.step()
self.critic_optimizer.step()
num_updates += 1
mean_value_loss += float(value_loss.item())
mean_surrogate_loss += float(surrogate_loss.item())
mean_entropy += float(entropy_b.mean().item())
if aux_base_lin_vel_loss is not None:
mean_aux_base_lin_vel_nll += float(
aux_base_lin_vel_loss.item()
)
if aux_root_height_loss is not None:
mean_aux_root_height_nll += float(aux_root_height_loss.item())
if aux_base_lin_vel_std is not None:
mean_aux_base_lin_vel_std += float(
aux_base_lin_vel_std.mean().item()
)
if aux_root_height_std is not None:
mean_aux_root_height_std += float(
aux_root_height_std.mean().item()
)
if aux_keybody_contact_loss is not None:
mean_aux_keybody_contact_bce += float(
aux_keybody_contact_loss.item()
)
if aux_keybody_contact_acc is not None:
mean_aux_keybody_contact_acc += float(
aux_keybody_contact_acc.item()
)
if aux_ref_keybody_rel_pos_loss is not None:
mean_aux_ref_keybody_rel_pos_mse += float(
aux_ref_keybody_rel_pos_loss.item()
)
if aux_robot_keybody_rel_pos_loss is not None:
mean_aux_robot_keybody_rel_pos_mse += float(
aux_robot_keybody_rel_pos_loss.item()
)
if aux_denoise_ref_root_lin_vel_loss is not None:
mean_aux_denoise_ref_root_lin_vel_huber += float(
aux_denoise_ref_root_lin_vel_loss.item()
)
if aux_denoise_ref_root_ang_vel_loss is not None:
mean_aux_denoise_ref_root_ang_vel_huber += float(
aux_denoise_ref_root_ang_vel_loss.item()
)
if aux_denoise_ref_dof_pos_loss is not None:
mean_aux_denoise_ref_dof_pos_huber += float(
aux_denoise_ref_dof_pos_loss.item()
)
if aux_router_command_recon_loss is not None:
mean_aux_router_command_recon_mse += float(
aux_router_command_recon_loss.item()
)
if aux_router_future_recon_loss is not None:
mean_aux_router_future_recon_huber += float(
aux_router_future_recon_loss.item()
)
if aux_router_switch_penalty_loss is not None:
mean_aux_router_switch_penalty_js += float(
aux_router_switch_penalty_loss.item()
)
if dead_expert_margin_to_topk_loss is not None:
mean_dead_expert_margin_to_topk_loss += float(
dead_expert_margin_to_topk_loss.item()
)
if router_expert_orthogonal_loss is not None:
mean_router_expert_orthogonal_loss += float(
router_expert_orthogonal_loss.item()
)
if selected_expert_margin_to_unselected_loss is not None:
mean_selected_expert_margin_to_unselected_loss += float(
selected_expert_margin_to_unselected_loss.item()
)
actor_policy.apply_dynamic_bias_update_from_stats()
denom = max(1, num_updates)
mean_value_loss /= denom
mean_surrogate_loss /= denom
mean_entropy /= denom
mean_kl_token /= denom
mean_kl_loss /= denom
mean_kl_analytic /= max(1, num_kl_measurements)
clip_fraction_batch_mean /= denom
if self.schedule == "adaptive":
self._apply_adaptive_lr(kl_windowed)
mean_aux_base_lin_vel_nll /= denom
mean_aux_root_height_nll /= denom
mean_aux_base_lin_vel_std /= denom
mean_aux_root_height_std /= denom
mean_aux_keybody_contact_bce /= denom
mean_aux_keybody_contact_acc /= denom
mean_aux_ref_keybody_rel_pos_mse /= denom
mean_aux_robot_keybody_rel_pos_mse /= denom
mean_aux_denoise_ref_root_lin_vel_huber /= denom
mean_aux_denoise_ref_root_ang_vel_huber /= denom
mean_aux_denoise_ref_dof_pos_huber /= denom
mean_aux_router_command_recon_mse /= denom
mean_aux_router_future_recon_huber /= denom
mean_aux_router_switch_penalty_js /= denom
mean_dead_expert_margin_to_topk_loss /= denom
mean_router_expert_orthogonal_loss /= denom
mean_selected_expert_margin_to_unselected_loss /= denom
self._last_update_metrics["0-Train/num_updates_executed"] = float(
num_updates
)
self._last_update_metrics["0-Train/kl_windowed"] = float(
kl_windowed or 0.0
)
self._last_update_metrics["0-Train/kl_stop_triggered"] = float(
kl_stop_triggered
)
self._last_update_metrics["0-Train/kl_stop_analytic"] = float(
kl_stop_analytic
)
self._last_update_metrics["0-Train/kl_analytic_batch_last"] = float(
kl_analytic_batch_last
)
self._last_update_metrics["0-Train/kl_analytic_batch_max"] = float(
kl_analytic_batch_max
)
self._last_update_metrics["0-Train/clip_fraction_batch_mean"] = float(
clip_fraction_batch_mean
)
self._last_update_metrics["0-Train/clip_fraction_batch_last"] = float(
clip_fraction_batch_last
)
moe_layers = [
layer
for layer in actor_unwrapped.actor_module.layers
if isinstance(layer, GroupedMoEBlock)
]
moe_active_expert_ratio = None
moe_max_expert_frac = None
moe_least_expert_frac = None
moe_dead_expert_ratio = None
moe_expert_count_cv = None
moe_selected_expert_margin_to_unselected = None
moe_last_router_js_step = None
moe_last_router_top1_switch_rate = None
if len(moe_layers) > 0:
moe_metrics = self._summarize_moe_layer_stats(moe_layers)
moe_active_expert_ratio = moe_metrics["moe_active_expert_ratio"]
moe_max_expert_frac = moe_metrics["moe_max_expert_frac"]
moe_least_expert_frac = moe_metrics["moe_least_expert_frac"]
moe_dead_expert_ratio = moe_metrics["moe_dead_expert_ratio"]
moe_expert_count_cv = moe_metrics["moe_expert_count_cv"]
moe_selected_expert_margin_to_unselected = moe_metrics[
"moe_selected_expert_margin_to_unselected"
]
router_shift_stats = actor_policy.get_last_moe_router_shift_stats()
js_sum = router_shift_stats["js_sum"]
js_count = router_shift_stats["js_count"]
top1_switch_sum = router_shift_stats["top1_switch_sum"]
top1_switch_count = router_shift_stats["top1_switch_count"]
if (
js_sum is not None
and js_count is not None
and top1_switch_sum is not None
and top1_switch_count is not None
):
js_sum = js_sum.detach().to(self.device, dtype=torch.float32)
js_count = js_count.detach().to(
self.device, dtype=torch.float32
)
top1_switch_sum = top1_switch_sum.detach().to(
self.device, dtype=torch.float32
)
top1_switch_count = top1_switch_count.detach().to(
self.device, dtype=torch.float32
)
if self.is_distributed:
js_sum = self.accelerator.reduce(js_sum, reduction="sum")
js_count = self.accelerator.reduce(
js_count, reduction="sum"
)
top1_switch_sum = self.accelerator.reduce(
top1_switch_sum, reduction="sum"
)
top1_switch_count = self.accelerator.reduce(
top1_switch_count, reduction="sum"
)
if float(js_count.item()) > 0.0:
moe_last_router_js_step = float((js_sum / js_count).item())
if float(top1_switch_count.item()) > 0.0:
moe_last_router_top1_switch_rate = float(
(top1_switch_sum / top1_switch_count).item()
)
self.storage.clear()
loss_out = {
"value_function": mean_value_loss,
"critic_explained_variance": critic_explained_variance,
"surrogate": mean_surrogate_loss,
"entropy": mean_entropy,
"kl_token": mean_kl_token,
"kl_loss": mean_kl_loss,
"kl_analytic": mean_kl_analytic,
"aux_base_lin_vel_nll": mean_aux_base_lin_vel_nll,
"aux_root_height_nll": mean_aux_root_height_nll,
"aux_base_lin_vel_std": mean_aux_base_lin_vel_std,
"aux_root_height_std": mean_aux_root_height_std,
"aux_keybody_contact_bce": mean_aux_keybody_contact_bce,
"aux_keybody_contact_acc": mean_aux_keybody_contact_acc,
"aux_ref_keybody_rel_pos_mse": mean_aux_ref_keybody_rel_pos_mse,
"aux_robot_keybody_rel_pos_mse": (
mean_aux_robot_keybody_rel_pos_mse
),
"aux_denoise_ref_root_lin_vel_huber": (
mean_aux_denoise_ref_root_lin_vel_huber
),
"aux_denoise_ref_root_ang_vel_huber": (
mean_aux_denoise_ref_root_ang_vel_huber
),
"aux_denoise_ref_dof_pos_huber": (
mean_aux_denoise_ref_dof_pos_huber
),
"aux_router_command_recon_mse": mean_aux_router_command_recon_mse,
"aux_router_future_recon_huber": (
mean_aux_router_future_recon_huber
),
"aux_router_switch_penalty_js": (
mean_aux_router_switch_penalty_js
),
"dead_expert_margin_to_topk": (
mean_dead_expert_margin_to_topk_loss
),
"router_expert_orthogonal": mean_router_expert_orthogonal_loss,
"selected_expert_margin_to_unselected": (
mean_selected_expert_margin_to_unselected_loss
),
"moe_active_expert_ratio": moe_active_expert_ratio,
"moe_max_expert_frac": moe_max_expert_frac,
"moe_least_expert_frac": moe_least_expert_frac,
"moe_dead_expert_ratio": moe_dead_expert_ratio,
"moe_expert_count_cv": moe_expert_count_cv,
"moe_selected_expert_margin_to_unselected": (
moe_selected_expert_margin_to_unselected
),
"moe_last_router_js_step": moe_last_router_js_step,
"moe_last_router_top1_switch_rate": (
moe_last_router_top1_switch_rate
),
}
if self.is_distributed:
reduced_out = {}
for k, v in loss_out.items():
if v is None:
reduced_out[k] = None
continue
t = torch.tensor(v, device=self.device, dtype=torch.float32)
reduced_t = self.accelerator.reduce(t, reduction="mean")
reduced_out[k] = float(reduced_t.item())
loss_out = reduced_out
self._post_update_hook(loss_out)
return loss_out
================================================
FILE: holomotion/src/data_curation/.gitignore
================================================
_generated/
================================================
FILE: holomotion/src/data_curation/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/data_curation/data_smplify.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
import argparse
import os
from smplify.smplify_humanact12 import humanact12_to_amass
from smplify.smplify_motionx import motionx_to_amass
from smplify.smplify_omomo import omomo_to_amass
from holomotion.holomotion.src.data_curation.smplify.smplify_zjumocap import (
zju_to_amass,
)
def ensure_dir(path):
    """Make sure the directory ``path`` exists, creating parents as needed.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of an exists-then-create
    check, so a concurrent creator of the same directory cannot make this
    call fail between the check and the creation.

    Args:
        path: The path of the dir.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def main():
    """Convert every supported raw motion-capture dataset to AMASS format.

    Command-line arguments:
        --data_root (required): directory containing one sub-directory per
            raw dataset (``MotionX``, ``ZJU_Mocap``, ``humanact12``,
            ``OMOMO``).
        --save_root (optional): output directory; defaults to
            ``./data/amass_compatible_datasets``.

    For each dataset whose directory exists under ``data_root``, the matching
    converter is called as ``func(data_dir, save_root/<name>)``. Missing
    datasets are skipped with a warning, and no output directory is created
    for them.

    Raises:
        SystemExit: If required command-line arguments are missing.

    Side Effects:
        Creates output directories and writes converted data files.
        Prints progress and warning messages to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Convert all datasets to AMASS format"
    )
    parser.add_argument(
        "--data_root",
        type=str,
        required=True,
        help="Path to the root directory of raw datasets",
    )
    parser.add_argument(
        "--save_root",
        type=str,
        default=None,
        # The help text previously advertised "data_root/smplx_data", which
        # is not the default actually used below.
        help=(
            "Path to save the unified data "
            "(default: ./data/amass_compatible_datasets)"
        ),
    )
    args = parser.parse_args()
    data_root = os.path.abspath(args.data_root)
    save_root = args.save_root or "./data/amass_compatible_datasets"
    print(f"Raw data root: {data_root}")
    print(f"Unified data will be saved to: {save_root}")
    ensure_dir(save_root)
    datasets = [
        ("MotionX", motionx_to_amass),
        ("ZJU_Mocap", zju_to_amass),
        ("humanact12", humanact12_to_amass),
        ("OMOMO", omomo_to_amass),
    ]
    for name, func in datasets:
        data_dir = os.path.join(data_root, name)
        if not os.path.exists(data_dir):
            print(f"Warning: {data_dir} does not exist. Skipping {name}.")
            continue
        # Only create the output directory for datasets that are actually
        # converted, so skipped datasets do not leave empty folders behind.
        save_dir = os.path.join(save_root, name)
        ensure_dir(save_dir)
        print(f"Processing {name}...")
        func(data_dir, save_dir)
        print(f"{name} done. Saved to {save_dir}.\n")
    print("All datasets processed.")


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/data_curation/filter/filter.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import json
import os
import numpy as np
def checksitpose(
    npz_path,
    ref_pose_path,
    threshold=0.75,
    frame_thresh=1,
    ref_frame_index=535,
) -> bool:
    """Check whether a motion sequence contains frames close to a sitting pose.

    The reference pose is frame ``ref_frame_index`` of the ``poses`` array
    stored in ``ref_pose_path``, truncated to its first 66 values. Every frame
    of ``npz_path`` is compared to it using only ``poses[3:36]`` (the
    lower-body joints). A frame counts as sitting-like when the Euclidean
    distance of that slice is below ``threshold``.

    Both ``.npz`` archives are opened as context managers so their file
    handles are closed after reading.

    Args:
        npz_path (str): Path to the .npz file of the motion sequence. It must
            contain a ``poses`` array with one row per frame.
        ref_pose_path (str): Path to the .npz file holding the reference
            sitting pose.
        threshold (float, optional): Euclidean distance threshold.
        frame_thresh (int, optional): Minimum number of matching frames.
            Values below 1 behave like 1.
        ref_frame_index (int, optional): Frame of the reference file used as
            the sitting pose. Defaults to 535.

    Returns:
        bool: True if at least ``max(1, frame_thresh)`` frames are within
        ``threshold`` of the reference pose. False when the reference pose
        cannot be loaded.

    Raises:
        Exception: Any error raised while loading ``npz_path`` propagates.
    """
    try:
        with np.load(ref_pose_path) as sitdata:
            sitpose = np.asarray(sitdata["poses"][ref_frame_index][:66])
    except Exception:
        # Best effort: without a usable reference pose nothing is flagged
        # as sitting.
        return False
    sitpose_down = sitpose[3:36]  # lower-body joints only

    with np.load(npz_path) as bdata:
        curposes = np.asarray(bdata["poses"])  # shape: (N, 165)

    dists = np.linalg.norm(curposes[:, 3:36] - sitpose_down, axis=-1)
    matches = int(np.count_nonzero(dists < threshold))
    # The original per-frame loop could only return True after at least one
    # match, so frame_thresh values below 1 are treated as 1.
    return matches >= max(1, frame_thresh)
def process_dataset(
    parent_folder,
    json_path,
    output_path,
    abnormal_path,
    sit_pose_reference,
    stair_keywords=None,
    sit_keywords=None,
    sit_threshold=0.75,
    frame_threshold=20,
    velocity_threshold=100.0,
):
    """Split a JSONL label file into normal and abnormal motion clips.

    Each non-empty line of ``json_path`` is a JSON object describing one clip:
    its ``path`` relative to ``parent_folder`` plus optional statistics
    (missing statistics default to 0). A line is copied to ``abnormal_path``
    when the first matching rule below applies:

    * its path was already flagged abnormal on an earlier line;
    * its path contains a stair keyword, a sitting keyword, or a blacklisted
      dataset name (``"aist"``);
    * ``mean_velocity`` exceeds ``velocity_threshold``;
    * ``max_up_z_velocity >= 0.6`` and ``max_z_translation > 0.7``;
    * ``max_down_z_velocity <= -0.7`` and ``min_z_translation < -0.7``;
    * :func:`checksitpose` finds sitting-like frames in the clip's file.

    All other lines are copied to ``output_path``. A line that raises while
    being processed is reported on stdout and written to neither file. A
    summary of the stair, sitting and velocity counts is printed at the end.

    Args:
        parent_folder: Root directory that ``path`` values are relative to.
        json_path: Input JSONL file.
        output_path: JSONL file that receives normal lines.
        abnormal_path: JSONL file that receives abnormal lines.
        sit_pose_reference: Reference-pose file passed to ``checksitpose``.
        stair_keywords: Path substrings marking stair motions (None uses the
            built-in list).
        sit_keywords: Path substrings marking sitting motions (None uses the
            built-in list).
        sit_threshold: Distance threshold forwarded to ``checksitpose``.
        frame_threshold: Frame-count threshold forwarded to ``checksitpose``.
        velocity_threshold: ``mean_velocity`` above this value is abnormal.
    """
    stair_keywords = stair_keywords or [
        "stairs",
        "staircase",
        "upstairs",
        "downstairs",
    ]
    sit_keywords = sit_keywords or ["sitting", "Sitting"]
    abnormal_dataset = ["aist"]
    # Tallies for the final report. Blacklisted-dataset hits and repeated
    # paths are written as abnormal but are not counted.
    counts = {"stairs": 0, "sit": 0, "velocity": 0}
    # Every path flagged abnormal so far. This set used to be cleared on most
    # branches, so a repeated path was only caught on the very next line.
    filtered_paths = set()
    with (
        open(json_path) as f_in,
        open(output_path, "w") as f_out_normal,
        open(abnormal_path, "w") as f_out_abnormal,
    ):
        for line in f_in:
            line = line.strip()
            if not line:
                continue
            try:
                content = json.loads(line)
                path = content.get("path", "")
                npz_path = os.path.join(parent_folder, path)
                if path in filtered_paths:
                    reason = "repeat"
                elif any(kw in path for kw in stair_keywords):
                    reason = "stairs"
                elif any(kw in path for kw in sit_keywords):
                    reason = "sit"
                elif any(kw in path for kw in abnormal_dataset):
                    reason = "dataset"
                elif content.get("mean_velocity", 0) > velocity_threshold:
                    reason = "velocity"
                elif (
                    content.get("max_up_z_velocity", 0) >= 0.6
                    and content.get("max_z_translation", 0) > 0.7
                ):
                    reason = "stairs"
                elif (
                    content.get("max_down_z_velocity", 0) <= -0.7
                    and content.get("min_z_translation", 0) < -0.7
                ):
                    reason = "stairs"
                elif checksitpose(
                    npz_path,
                    sit_pose_reference,
                    sit_threshold,
                    frame_threshold,
                ):
                    reason = "sit"
                else:
                    reason = None

                if reason is None:
                    # normal motion
                    f_out_normal.write(line + "\n")
                    continue
                f_out_abnormal.write(line + "\n")
                filtered_paths.add(path)
                if reason in counts:
                    counts[reason] += 1
            except Exception as e:
                print(f"Error processing line: {line}\nException: {e}")
    # Implicit concatenation instead of a backslash continuation, which used
    # to embed the next source line's indentation into the message.
    print(
        f"total abnormal data:upstairs {counts['stairs']}, "
        f"sitting {counts['sit']}, velocity {counts['velocity']}"
    )
def jsonl_to_yaml(jsonl_path, yaml_output_path):
    """Write the motion names referenced in a JSONL file as a YAML list.

    Each JSON object's ``path`` becomes a motion name in four steps:

    1. strip surrounding whitespace and leading ``/``;
    2. drop the file extension;
    3. replace ``/`` and ``\\`` with ``_``;
    4. prefix ``0-``.

    The unique names are written, sorted, as a YAML flow sequence with one
    entry per line. Entries are emitted as double-quoted scalars (JSON string
    syntax is valid YAML), so a ``,``, ``:`` or ``#`` inside a file name
    cannot split or truncate an entry.

    Blank lines, lines that are not valid JSON, lines that are not JSON
    objects, and entries without a non-empty string ``path`` are skipped.

    Args:
        jsonl_path: Input JSONL file.
        yaml_output_path: Destination YAML file (overwritten).
    """
    output_set = set()
    with open(jsonl_path) as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"skip json line: {line.strip()}")
                continue
            if not isinstance(data, dict):
                # e.g. a bare list or number; previously this crashed on
                # data.get(...).
                print(f"skip json line: {line.strip()}")
                continue
            path = data.get("path", "")
            if isinstance(path, str) and path:
                clean_path = os.path.splitext(path.strip().lstrip("/"))[0]
                new_name = "0-" + clean_path.replace("/", "_").replace(
                    "\\", "_"
                )
                output_set.add(new_name)
    with open(yaml_output_path, "w") as out:
        out.write("[\n")
        for item in sorted(output_set):
            out.write(f" {json.dumps(item)},\n")
        out.write("]\n")
    print(f"done, total {len(output_set)} items -> {yaml_output_path}")
if __name__ == "__main__":
    # Command-line entry point: label the clips listed in --json_path, then
    # write the abnormal ones as a YAML exclusion list to --yaml_path.
    parser = argparse.ArgumentParser(
        description="Filter AMASS dataset and save results."
    )
    parser.add_argument(
        "--parent_folder",
        type=str,
        default="./data/amass_compatible_datasets",
        help="Path to the parent folder of AMASS data",
    )
    parser.add_argument(
        "--json_path",
        type=str,
        default="./data/dataset_labels/OMOMO.jsonl",
        help="Path to the input JSONL file",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./data/dataset_labels/temp.jsonl",
        help="Path to save the filtered output JSONL",
    )
    parser.add_argument(
        "--abnormal_path",
        type=str,
        default="./data/dataset_labels/temp2.jsonl",
        help="Path to save abnormal data JSONL",
    )
    parser.add_argument(
        "--sit_pose_reference",
        type=str,
        default="./data/amass_compatible_datasets/amass/BioMotionLab_NTroje/rub062/0016_sitting2_poses.npz",
        help="Path to the reference sitting pose npz",
    )
    parser.add_argument(
        "--yaml_path",
        type=str,
        default="./holomotion/config/data_curation/base.yaml",
        help="Path to the excluded yaml file",
    )
    args = parser.parse_args()
    # Create the parent directory of every file written below. A bare file
    # name has an empty dirname, and os.makedirs("") would raise.
    for out_file in (args.output_path, args.abnormal_path, args.yaml_path):
        out_dir = os.path.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    process_dataset(
        parent_folder=args.parent_folder,
        json_path=args.json_path,
        output_path=args.output_path,
        abnormal_path=args.abnormal_path,
        sit_pose_reference=args.sit_pose_reference,
    )
    jsonl_to_yaml(args.abnormal_path, args.yaml_path)
================================================
FILE: holomotion/src/data_curation/filter/label_data.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
import argparse
import json
import os
import sys
import numpy as np
# Extend the module search path with the human_body_prior sources under
# omomo_release (presumably a vendored copy shipped with the OMOMO release).
# NOTE(review): the path is relative to the current working directory, so it
# only resolves when the script is launched from the repository root; confirm
# whether anything in this module actually imports from it.
sys.path.append(
    "./holomotion/src/data_curation/omomo_release/human_body_prior/src"
)
def calc_max_xy_translation(motion_data: dict):
    """Return how far the root drifts horizontally from its starting point.

    Args:
        motion_data: Mapping whose ``"trans"`` entry is a 2-D array of root
            positions, one row per frame; columns 0 and 1 are x and y.

    Returns:
        The largest Euclidean xy-distance between any frame's root position
        and the first frame's root position.
    """
    planar = motion_data["trans"][:, :2]
    displacement = planar - planar[0:1]
    distances = np.linalg.norm(displacement, axis=1)
    return np.max(distances)
def calc_max_z_translation(motion_data: dict):
    """Return the highest rise and lowest drop of the root relative to frame 0.

    Args:
        motion_data: Mapping whose ``"trans"`` entry is a 2-D array of root
            positions, one row per frame; column 2 is z.

    Returns:
        Tuple ``(max_z_translation, min_z_translation)`` of the root's z
        offset from the first frame.
    """
    heights = motion_data["trans"][:, 2]
    offsets = heights - heights[0:1]
    highest = np.max(offsets)
    lowest = np.min(offsets)
    return highest, lowest
def calc_max_velocity_scale(motion_data: dict, fps: float = 30):
    """Return the peak root speed of a motion clip.

    Per-frame speed is estimated by finite differences, ``|trans[i+1] -
    trans[i]| * fps``.

    Args:
        motion_data: Mapping whose ``"trans"`` entry holds root positions,
            one row per frame.
        fps: Frames per second of the clip.

    Returns:
        The maximum speed over all consecutive frame pairs. Returns 0.0 when
        the clip has fewer than two frames: no velocity can be estimated, and
        ``np.max`` of an empty array would raise ValueError (this mirrors the
        short-clip guard in ``calc_std_accel``).
    """
    root_trans_offset = np.asarray(motion_data["trans"])
    if root_trans_offset.shape[0] < 2:
        return 0.0
    est_root_vel = np.diff(root_trans_offset * fps, axis=0)
    root_vel_norm = np.linalg.norm(est_root_vel, axis=-1)
    return np.max(root_vel_norm)
def calc_mean_velocity_scale(motion_data: dict, fps: float = 30):
    """Return the average 3D root speed.

    Speed is estimated by finite differences of ``trans * fps`` between
    consecutive frames.
    """
    scaled_positions = motion_data["trans"] * fps
    speed = np.linalg.norm(np.diff(scaled_positions, axis=0), axis=-1)
    return np.mean(speed)
def calc_std_velocity_scale(motion_data: dict, fps: float = 30):
    """Return the (population) standard deviation of the 3D root speed.

    Speed is estimated by finite differences of ``trans * fps`` between
    consecutive frames.
    """
    scaled_positions = motion_data["trans"] * fps
    speed = np.linalg.norm(np.diff(scaled_positions, axis=0), axis=-1)
    return np.std(speed)
def calc_max_vxy_scale(motion_data: dict, fps: float = 30):
    """Summarize the horizontal (XY) root speed.

    Returns:
        Tuple ``(max, mean, std)`` of the per-step XY speed, where the
        velocity is the finite difference of ``trans * fps``.
    """
    step_velocity = np.diff(motion_data["trans"] * fps, axis=0)
    xy_speed = np.linalg.norm(step_velocity[:, :2], axis=-1)
    return np.max(xy_speed), np.mean(xy_speed), np.std(xy_speed)
def calc_std_accel(motion_data: dict, fps: float = 30.0) -> float:
    """Standard deviation of the root's horizontal (XY) acceleration magnitude.

    Velocity is the frame-to-frame difference of ``trans`` scaled by ``fps``;
    acceleration is the frame-to-frame difference of that velocity, again
    scaled by ``fps``.

    Args:
        motion_data: Mapping with a ``"trans"`` array of shape ``(T, 3)``.
        fps: Frames per second of the sequence.

    Returns:
        Standard deviation of the XY acceleration magnitudes, or ``0.0`` when
        there are fewer than three frames (two differences are required).
    """
    positions = motion_data["trans"]
    if positions.shape[0] < 3:
        return 0.0
    velocity = np.diff(positions, axis=0) * fps
    acceleration = np.diff(velocity, axis=0) * fps
    planar_magnitude = np.linalg.norm(acceleration[:, :2], axis=1)
    return np.std(planar_magnitude)
def calc_max_vz_scale(motion_data: dict, fps: float = 30):
    """Summarize the absolute vertical (Z) root speed.

    Returns:
        Tuple ``(max, mean, std)`` of ``|vz|``, where ``vz`` is the finite
        difference of ``trans * fps`` along axis 2.
    """
    step_velocity = np.diff(motion_data["trans"] * fps, axis=0)
    vertical_speed = np.abs(step_velocity[:, 2])
    return (
        np.max(vertical_speed),
        np.mean(vertical_speed),
        np.std(vertical_speed),
    )
def calc_vz_scale_with_direction(motion_data: dict, fps: float = 30):
    """Summarize the signed vertical (Z) root velocity.

    Returns:
        Tuple ``(max_up, max_down, mean, std)``:
        ``max_up`` is the largest positive ``vz`` (0.0 if none),
        ``max_down`` is the most negative ``vz`` (0.0 if none), and
        ``mean``/``std`` are taken over all ``vz`` values.
    """
    vz = np.diff(motion_data["trans"] * fps, axis=0)[:, 2]
    rising = vz[vz > 0]
    falling = vz[vz < 0]
    max_up_vz = np.max(rising) if rising.size else 0.0
    max_down_vz = np.min(falling) if falling.size else 0.0
    return max_up_vz, max_down_vz, np.mean(vz), np.std(vz)
def beyond_upper_dof_limits(
    motion_data: dict,
    upper_dof_mapping: dict,
    upper_dof_max_limits: dict,
):
    """Return True if any mapped DoF leaves its allowed range at any frame.

    Args:
        motion_data: Mapping with a ``"dof"`` array indexed ``[frame, dof]``.
        upper_dof_mapping: ``{joint_name: column index into motion_data["dof"]}``.
        upper_dof_max_limits: ``{joint_name: [lower, upper]}``.

    Returns:
        True as soon as one joint's trajectory dips below ``lower`` or rises
        above ``upper``; False otherwise.
    """
    for joint_name, column in upper_dof_mapping.items():
        trajectory = motion_data["dof"][:, column]
        highest = np.max(trajectory)
        lowest = np.min(trajectory)
        bounds = upper_dof_max_limits[joint_name]
        # min <= max, so checking the minimum against the lower bound and the
        # maximum against the upper bound covers every out-of-range case.
        if lowest < bounds[0] or highest > bounds[1]:
            return True
    return False
class HyperParams:
    """Constant thresholds for root-motion metrics and upper-body joint ranges.

    The ``upper_dof_*`` tables match the parameters of
    ``beyond_upper_dof_limits``. This class is not referenced elsewhere in
    this module; presumably a separate filtering step consumes it
    (TODO confirm).
    """

    # Upper bounds paired with the metrics computed by the calc_* helpers.
    # NOTE(review): units presumably metres (translation) and metres/second
    # (velocity) — confirm against the motion data convention.
    max_xy_translation: float = 2.0
    max_z_translation: float = 0.3
    max_velocity_scale: float = 1.0
    max_vxy_scale: float = 1.2
    max_vz_scale: float = 0.3
    # Column index into motion_data["dof"] for each upper-body joint checked.
    upper_dof_mapping: dict = {
        "left_shoulder_pitch_joint": 13,
        "left_shoulder_roll_joint": 14,
        "left_shoulder_yaw_joint": 15,
        "left_elbow_joint": 16,
        "right_shoulder_pitch_joint": 17,
        "right_shoulder_roll_joint": 18,
        "right_shoulder_yaw_joint": 19,
        "right_elbow_joint": 20,
    }
    # Allowed [lower, upper] range per joint; presumably radians — confirm.
    upper_dof_max_limits: dict = {
        "left_shoulder_pitch_joint": [-1.0, 1.0],
        "left_shoulder_roll_joint": [0.0, 0.5],
        "left_shoulder_yaw_joint": [-0.5, 0.5],
        "left_elbow_joint": [0.5, 1.3],
        "right_shoulder_pitch_joint": [-1.0, 1.0],
        "right_shoulder_roll_joint": [-0.5, 0.0],
        "right_shoulder_yaw_joint": [-0.5, 0.3],
        "right_elbow_joint": [0.5, 1.5],
    }
def _to_builtin(obj):
    """Recursively convert numpy arrays/scalars into JSON-serializable types."""
    if isinstance(obj, dict):
        return {k: _to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_builtin(i) for i in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    return obj


def _compute_metrics(data, fps) -> dict:
    """Compute all per-clip root-motion metrics, each rounded to 2 decimals.

    Keys are inserted in the order in which they appear in the JSONL records.
    Any exception raised by a calc_* helper propagates to the caller.
    """
    max_z, min_z = calc_max_z_translation(data)
    max_xy_v, mean_xy_v, std_xy_v = calc_max_vxy_scale(data, fps)
    max_up_z_v, max_down_z_v, mean_z_v, std_z_v = (
        calc_vz_scale_with_direction(data, fps)
    )
    raw = {
        "max_xy_translation": calc_max_xy_translation(data),
        "max_z_translation": max_z,
        "min_z_translation": min_z,
        "max_velocity": calc_max_velocity_scale(data, fps),
        "mean_velocity": calc_mean_velocity_scale(data, fps),
        "std_velocity": calc_std_velocity_scale(data, fps),
        "std_accel": calc_std_accel(data, fps),
        "max_xy_velocity": max_xy_v,
        "mean_xy_velocity": mean_xy_v,
        "std_xy_velocity": std_xy_v,
        "max_up_z_velocity": max_up_z_v,
        "max_down_z_velocity": max_down_z_v,
        "mean_z_velocity": mean_z_v,
        "std_z_velocity": std_z_v,
    }
    return {key: round(value, 2) for key, value in raw.items()}


def label_data_with_metrics(data_folder, jsonl_path: str, parent_folder: str):
    """Calculate root-motion metrics for every ``.npz`` clip and write a JSONL.

    Each output line is a JSON object with ``path`` (relative to
    ``parent_folder``) followed by the metrics from ``_compute_metrics``.
    Directories and files are visited in sorted order, so the output is
    reproducible.

    Clips without a ``mocap_frame_rate``/``mocap_framerate`` entry are
    skipped silently. Clips whose metrics fail to compute are reported on
    stderr and skipped; no partial record is written for them.

    Args:
        data_folder: Root directory searched recursively for ``.npz`` files.
        jsonl_path: Output path; must end with ``.jsonl``.
        parent_folder: Base directory used to relativize the stored paths.
    """
    assert jsonl_path.endswith(".jsonl")
    with open(jsonl_path, "w") as f_out:
        for root, dirs, files in os.walk(data_folder):
            dirs.sort()  # deterministic traversal order
            for file in sorted(files):
                if not file.endswith(".npz"):
                    continue
                npz_path = os.path.join(root, file)
                # Close the npz handle after each clip instead of relying on GC.
                with np.load(npz_path) as data:
                    fps = data.get("mocap_frame_rate")
                    if fps is None:
                        fps = data.get("mocap_framerate")
                    if fps is None:
                        continue
                    try:
                        metrics = _compute_metrics(data, fps)
                    except Exception as e:
                        print(f"Error: {npz_path}: {e}", file=sys.stderr)
                        continue
                content = {"path": os.path.relpath(npz_path, parent_folder)}
                content.update(metrics)
                f_out.write(json.dumps(_to_builtin(content)) + "\n")
    print(f"Annotated file saved to: {jsonl_path}")
def main():
    """Label each requested dataset and write ``<caption_folder>/<name>.jsonl``.

    The dataset named ``amass`` is read from ``--amass_folder``, and its
    stored paths are relative to that folder. Every other name is read from
    ``<other_folder>/<name>``, and its stored paths are relative to
    ``--other_folder``. All folder defaults match the previous hard-coded
    locations.
    """
    parser = argparse.ArgumentParser(
        description="Compute root-motion metrics for AMASS-compatible datasets."
    )
    parser.add_argument(
        "--jsonl_list",
        nargs="+",
        default=["humanact12", "MotionX", "OMOMO", "ZJU_Mocap", "amass"],
        help="Dataset names to label; one <name>.jsonl is written per name.",
    )
    parser.add_argument(
        "--amass_folder",
        default="./data/amass_compatible_datasets/amass",
        help="Root folder of the AMASS dataset.",
    )
    parser.add_argument(
        "--other_folder",
        default="./data/amass_compatible_datasets",
        help="Parent folder holding the non-AMASS datasets.",
    )
    parser.add_argument(
        "--caption_folder",
        default="./data/dataset_labels",
        help="Output folder for the generated .jsonl files.",
    )
    args = parser.parse_args()
    os.makedirs(args.caption_folder, exist_ok=True)
    for name in args.jsonl_list:
        jsonl_path = os.path.join(args.caption_folder, name + ".jsonl")
        if name == "amass":
            label_data_with_metrics(
                args.amass_folder, jsonl_path, args.amass_folder
            )
        else:
            label_data_with_metrics(
                os.path.join(args.other_folder, name),
                jsonl_path,
                args.other_folder,
            )
if __name__ == "__main__":
    # Script entry point: label the datasets selected on the command line.
    main()
================================================
FILE: holomotion/src/data_curation/smpl_npz_to_html.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple
import numpy as np
from scipy.spatial.transform import Rotation as R
# -----------------------------
# Defaults
# -----------------------------
# Template and output paths used when --template / --out are not given.
DEFAULT_TEMPLATE_PATH = Path("index_wooden_static.html")
DEFAULT_OUT_HTML = Path("vis.html")
# Default for --pose_joints: poses must have pose_joints * 3 columns
# (root followed by body joints, axis-angle).
POSE_JOINTS = 22
# World-frame correction applied to the root orientation and translation:
# Euler angles in degrees; lowercase axes mean extrinsic rotations in scipy.
EULER_FIX_DEG = (-90.0, 180.0, 0.0)
EULER_ORDER = "xyz"
# Empirical vertical offset (in meters) to align wooden_static visualization mesh
# with canonical SMPL coordinates (e.g., GVHMR pipelines).
WOODEN_SMPL_HEIGHT_OFFSET = 0.2
@dataclass(frozen=True)
class SmplSequence:
    """A minimal SMPL motion sequence loaded from npz.

    Instances are immutable (``frozen=True``). Shapes are not checked at
    construction; ``validate_sequence`` checks them before frames are built.
    """

    # (T, pose_joints * 3) axis-angle, root first; (T, 66) = root(3) + body(63)
    # with the default 22 joints.
    poses: np.ndarray
    trans: np.ndarray  # (T, 3) root translation
    betas: np.ndarray  # (B,) betas; flattened to 1-D before export
    fps: float  # frame rate of the sequence
    gender: str  # copied verbatim into every exported frame
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the npz -> HTML converter.

    ``--npz`` is required: without it there is nothing to render. Previously
    it defaulted to ``None`` and crashed later inside ``load_npz``.
    """
    ap = argparse.ArgumentParser(
        description="Generate vis.html from a SMPL npz using a HTML template."
    )
    ap.add_argument(
        "--npz", type=Path, required=True, help="Path to input .npz"
    )
    ap.add_argument(
        "--template",
        type=Path,
        default=DEFAULT_TEMPLATE_PATH,
        help="Path to HTML template",
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=DEFAULT_OUT_HTML,
        help="Path to output HTML",
    )
    ap.add_argument(
        "--pose_joints",
        type=int,
        default=POSE_JOINTS,
        help=f"Number of pose joints in poses (default: {POSE_JOINTS}).",
    )
    ap.add_argument(
        "--height_axis",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Axis index for height in Th (default: 1 for Y-up).",
    )
    ap.add_argument(
        "--height_offset",
        type=float,
        default=WOODEN_SMPL_HEIGHT_OFFSET,
        help=(
            "Subtract from Th height axis (Y-up), in meters. "
            "Default is an empirical offset to align wooden_static mesh "
            "with canonical SMPL coordinates (e.g., GVHMR)."
        ),
    )
    return ap.parse_args()
def euler_fix_rot(euler_deg=EULER_FIX_DEG, order=EULER_ORDER) -> R:
    """Build the world-frame correction rotation.

    The result is meant to be left-multiplied: ``R_new = R_fix * R_old``.
    ``order`` is lower-cased first, so scipy interprets the axes as extrinsic.
    """
    axes = order.lower()
    fix = R.from_euler(axes, euler_deg, degrees=True)
    return fix
def _require_key(data: np.lib.npyio.NpzFile, key: str) -> np.ndarray:
if key not in data:
raise KeyError(
f"Missing key '{key}' in npz. Available: {list(data.keys())}"
)
return data[key]
def load_npz(path: Path) -> SmplSequence:
    """Load a SMPL sequence from an AMASS-style npz file.

    Reads ``poses``, ``trans`` and ``betas`` (cast to float32), ``gender``,
    and the frame rate. The frame rate is taken from ``mocap_framerate``
    (AMASS spelling) or, if that is absent, from ``mocap_frame_rate`` (the
    spelling written by the smplify converters in this package). The file is
    closed before returning.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if a required key (or both frame-rate keys) is missing.
    """
    if not path.exists():
        raise FileNotFoundError(f"Missing {path}")
    with np.load(path, allow_pickle=False) as data:
        poses = _require_key(data, "poses").astype(np.float32)
        trans = _require_key(data, "trans").astype(np.float32)
        betas = _require_key(data, "betas").astype(np.float32)
        for fps_key in ("mocap_framerate", "mocap_frame_rate"):
            if fps_key in data:
                fps = float(np.asarray(data[fps_key]))
                break
        else:
            raise KeyError(
                "Missing frame-rate key ('mocap_framerate' or "
                f"'mocap_frame_rate') in npz. Available: {list(data.keys())}"
            )
        gender = str(np.asarray(_require_key(data, "gender")))
    return SmplSequence(
        poses=poses, trans=trans, betas=betas, fps=fps, gender=gender
    )
def validate_sequence(seq: SmplSequence, pose_joints: int) -> int:
    """Check that ``seq`` has consistent shapes and return its frame count.

    Raises:
        ValueError: if ``poses`` is not 2-D, ``trans`` is not ``(T, 3)``,
            ``poses`` does not have ``pose_joints * 3`` columns, or the two
            arrays disagree on the number of frames.
    """
    poses, trans = seq.poses, seq.trans
    if poses.ndim != 2:
        raise ValueError(f"poses must be 2D, got shape={poses.shape}")
    if not (trans.ndim == 2 and trans.shape[1] == 3):
        raise ValueError(f"trans must be (T,3), got shape={trans.shape}")
    num_frames = int(poses.shape[0])
    expected_cols = 3 * int(pose_joints)
    if poses.shape[1] != expected_cols:
        raise ValueError(
            f"unexpected poses shape: {poses.shape}, expected (T,{expected_cols})"
        )
    if trans.shape[0] != num_frames:
        raise ValueError(
            f"poses frames ({num_frames}) != trans frames ({trans.shape[0]})"
        )
    return num_frames
def build_smpl_frames(
    seq: SmplSequence,
    *,
    pose_joints: int,
    height_axis: int,
    height_offset: float,
) -> Tuple[list, int]:
    """
    Build frames in the format expected by index_wooden_static.html template.

    Each frame is a one-person list holding a dict with ``id``, ``gender``,
    ``Rh`` (corrected root rotation vector), ``Th`` (corrected root
    translation), ``poses`` (body axis-angles followed by 6 zero hand values)
    and ``shapes`` (flattened betas; the same list object is shared by all
    frames).

    Notes:
        height_offset is a visualization-only correction to compensate for the
        vertical origin mismatch between wooden_static mesh and canonical SMPL
        coordinates (e.g., GVHMR). Override via --height_offset if needed.

    Returns:
        ``(frames, T)`` where ``T`` is the number of frames.
    """
    num_frames = validate_sequence(seq, pose_joints)
    fix = euler_fix_rot()

    # Root orientation: apply the world correction on the left.
    root_rotation = R.from_rotvec(seq.poses[:, :3])
    root_rotvec = (fix * root_rotation).as_rotvec().astype(np.float32)

    # Root translation: rotate into the corrected world frame, then shift the
    # chosen height axis by the visualization offset.
    root_position = fix.apply(seq.trans).astype(np.float32)
    if height_offset != 0.0:
        root_position[:, int(height_axis)] -= float(height_offset)

    # Body pose (everything after the root) plus 6 zeros for the two hands.
    hand_padding = np.zeros((num_frames, 6), np.float32)
    body_with_hands = np.concatenate(
        [seq.poses[:, 3:], hand_padding], axis=1
    )
    betas_list = seq.betas.reshape(-1).tolist()

    frames = []
    for i in range(num_frames):
        person = {
            "id": 0,
            "gender": seq.gender,
            "Rh": [root_rotvec[i].tolist()],
            "Th": [root_position[i].tolist()],
            "poses": [body_with_hands[i].tolist()],
            "shapes": betas_list,
        }
        frames.append([person])
    return frames, num_frames
def render_html(template_path: Path, frames: list, T: int, fps: float) -> str:
    """Fill the HTML template with frame data and a caption.

    Replaces ``{{ smpl_data_json }}`` with the JSON-encoded ``frames`` and
    ``{{ caption_html }}`` with a frame-count / frame-rate caption.

    ``<``, ``>`` and ``&`` in the JSON are emitted as ``\\u003c``,
    ``\\u003e`` and ``\\u0026``. The result is still valid JSON/JavaScript
    with identical decoded values, but a string field read from the npz
    (e.g. ``gender``) can no longer close an enclosing ``<script>`` element.
    NOTE(review): assumes the placeholder sits inside a ``<script>`` block.
    """
    template = template_path.read_text(encoding="utf-8")
    smpl_data_json = (
        json.dumps(frames, ensure_ascii=False)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )
    # The previous literal here was an unterminated string (SyntaxError).
    # NOTE(review): any HTML wrapper markup intended for the caption is not
    # recoverable from this file; add it here if the template expects it.
    caption_html = f"Frames: {T} Framerate: {fps:.1f} fps\n"
    return template.replace("{{ smpl_data_json }}", smpl_data_json).replace(
        "{{ caption_html }}", caption_html
    )
def main(
    npz_path: Path,
    template_path: Path,
    out_html: Path,
    *,
    pose_joints: int,
    height_axis: int,
    height_offset: float,
) -> None:
    """Render the sequence in ``npz_path`` into ``out_html`` via the template.

    Raises:
        FileNotFoundError: if the template (checked first) or the npz is
            missing.
    """
    if not template_path.exists():
        raise FileNotFoundError(f"Missing {template_path}")
    sequence = load_npz(npz_path)
    frames, num_frames = build_smpl_frames(
        sequence,
        pose_joints=pose_joints,
        height_axis=height_axis,
        height_offset=height_offset,
    )
    rendered = render_html(template_path, frames, num_frames, sequence.fps)
    out_html.write_text(rendered, encoding="utf-8")
    print(f"[OK] wrote {out_html.resolve()}")
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI options to main().
    cli = parse_args()
    main(
        cli.npz,
        cli.template,
        cli.out,
        pose_joints=cli.pose_joints,
        height_axis=cli.height_axis,
        height_offset=cli.height_offset,
    )
================================================
FILE: holomotion/src/data_curation/smplify/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/data_curation/smplify/smplify_humanact12.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import random
import h5py
import numpy as np
import smplx
import torch
from scipy.spatial.transform import Rotation
from tqdm import tqdm
from thirdparties.joints2smpl.src import config
from thirdparties.joints2smpl.src.smplify import SMPLify3D
SMPL_MODEL_DIR = "./assets/smpl/"


def _pick_device() -> torch.device:
    """Choose the torch device used for SMPLify fitting.

    GPU index 1 (this script's historical choice) is kept when at least two
    GPUs are visible. With a single GPU, ``cuda:0`` is used instead of
    failing on a non-existent ``cuda:1``. Without CUDA, the CPU is used.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if torch.cuda.device_count() > 1:
        return torch.device("cuda:1")
    return torch.device("cuda:0")


device = _pick_device()
# Number of input joints per frame (one confidence weight each).
num_joints = 22
joint_category = "AMASS"
num_smplify_iters = 150
# When True, joints 7, 8, 10 and 11 (ankles/feet) get confidence 1.5.
fix_foot = False
def joints2smpl(file_name, data_dir, save_dir):
    """Fit SMPL parameters to a 3D joint sequence and save an AMASS-style npz.

    Processing steps:
        1. Load the joints from ``data_dir/file_name``.
        2. Shift them so the first joint of frame 0 sits at XY origin and the
           lowest joint touches z = 0.
        3. Fit a neutral SMPL model (not SMPL-X) with ``SMPLify3D``,
           starting from the SMPL mean pose/shape.
        4. Rotate the root orientation and translation by -100 degrees about
           x.
        5. Save to ``save_dir/<name before first '.'>.npz``.

    Saved keys:
        poses: Zero-padded to 165 columns.
        trans: Fitted joint 0 positions, rotated.
        betas: Mean over frames.
        gender: ``"neutral"``.
        jtr: Input joints with y/z swapped and x negated.
        mocap_frame_rate: 30.

    Args:
        file_name (str): Name of the input .npy joint file.
        data_dir (str): Directory containing input joint files.
        save_dir (str): Directory to save processed output files.

    Raises:
        ValueError: if ``joint_category`` is not ``"AMASS"``; no joint
            confidences are defined for other categories.
    """
    if joint_category != "AMASS":
        raise ValueError(
            f"Unsupported joint_category {joint_category!r}; only 'AMASS' "
            "has confidence weights defined."
        )

    input_joints = np.load(os.path.join(data_dir, file_name))
    input_joints = input_joints[:, :, [0, 1, 2]]  # amass stands on x, y
    # XY at origin (relative to the first joint of the first frame).
    input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]]
    # Put on floor: lowest joint over the whole clip ends at z = 0.
    floor_height = input_joints[:, :, 2].min()
    input_joints[:, :, 2] -= floor_height
    batch_size = input_joints.shape[0]

    smplmodel = smplx.create(
        SMPL_MODEL_DIR,
        model_type="smpl",
        gender="neutral",
        ext="npz",
        batch_size=batch_size,
    ).to(device)

    # Initial guess: the SMPL mean pose/shape. The HDF5 handle is closed as
    # soon as both arrays have been read into memory.
    with h5py.File(config.SMPL_MEAN_FILE, "r") as mean_file:
        mean_pose = torch.from_numpy(mean_file["pose"][:])
        mean_shape = torch.from_numpy(mean_file["shape"][:])
    init_mean_pose = (
        mean_pose.unsqueeze(0).repeat(batch_size, 1).float().to(device)
    )
    init_mean_shape = (
        mean_shape.unsqueeze(0).repeat(batch_size, 1).float().to(device)
    )
    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)

    smplify = SMPLify3D(
        smplxmodel=smplmodel,
        batch_size=batch_size,
        joints_category=joint_category,
        num_iters=num_smplify_iters,
        device=device,
    )
    keypoints_3d = torch.Tensor(input_joints).to(device).float()

    confidence_input = torch.ones(num_joints)
    if fix_foot:
        # Up-weight the foot and ankle joints.
        confidence_input[[7, 8, 10, 11]] = 1.5

    _, _, new_opt_pose, new_opt_betas, _, _ = smplify(
        init_mean_pose.detach(),
        init_mean_shape.detach(),
        cam_trans_zero.detach(),
        keypoints_3d,
        conf_3d=confidence_input.to(device),
    )

    poses = new_opt_pose.detach().cpu().numpy()
    betas = new_opt_betas.mean(axis=0).detach().cpu().numpy()
    trans = keypoints_3d[:, 0].detach().cpu().numpy()

    # Zero-pad the pose vector to 165 columns.
    pad_dim = 165 - poses.shape[-1]
    if pad_dim > 0:
        pad_array = np.zeros((*poses.shape[:-1], pad_dim), dtype=poses.dtype)
        poses = np.concatenate([poses, pad_array], axis=-1)

    # Rotate root orientation and translation by -100 degrees about x.
    rx_minus_100 = Rotation.from_euler("x", -100, degrees=True).as_matrix()
    root_mat = Rotation.from_rotvec(poses[:, :3]).as_matrix()
    poses[:, :3] = Rotation.from_matrix(rx_minus_100 @ root_mat).as_rotvec()
    trans_rotated = (rx_minus_100 @ trans.T).T

    jtr = input_joints[:, :, [0, 2, 1]]  # jts stands on x, z
    jtr[..., 0] *= -1

    param = {
        "poses": poses,
        "trans": trans_rotated,
        "betas": betas,
        "gender": "neutral",
        "jtr": jtr,
        "mocap_frame_rate": 30,
    }
    out_name = file_name.split(".")[0] + ".npz"
    print(out_name)
    np.savez_compressed(os.path.join(save_dir, out_name), **param)
def humanact12_to_amass(data_dir, save_dir):
    """Convert HumanAct12 dataset to AMASS-compatible format.

    Only ``.npy`` files in ``data_dir`` are processed, in random order
    (presumably so that concurrent runs over the same folder spread their
    work — TODO confirm). A file is skipped when its output
    ``<name before first '.'>.npz`` already exists in ``save_dir``.

    Args:
        data_dir (str): Directory containing HumanAct12 .npy joint files.
        save_dir (str): Directory to save processed AMASS .npz files.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_list = [f for f in os.listdir(data_dir) if f.endswith(".npy")]
    random.shuffle(file_list)
    for file_name in tqdm(file_list):
        # Must match the output name written by joints2smpl; the previous
        # check looked for the .npy name, so it never detected existing output.
        out_path = os.path.join(save_dir, file_name.split(".")[0] + ".npz")
        if os.path.exists(out_path):
            print(f"{out_path} already exists")
            continue
        joints2smpl(file_name, data_dir, save_dir)
if __name__ == "__main__":
    # Previously the script only assigned empty placeholder paths and never
    # ran the conversion; take the paths from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Fit SMPL to HumanAct12 joint files and save AMASS-style npz."
    )
    parser.add_argument(
        "--data_dir", required=True, help="Directory of HumanAct12 .npy files."
    )
    parser.add_argument(
        "--save_dir", required=True, help="Output directory for .npz files."
    )
    args = parser.parse_args()
    humanact12_to_amass(args.data_dir, args.save_dir)
================================================
FILE: holomotion/src/data_curation/smplify/smplify_motionx.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import numpy as np
from scipy.spatial.transform import Rotation
def _motionx_array_to_amass(motion):
    """Convert one MotionX ``(T, 322)`` array into an AMASS-style dict.

    Columns used:
        ``[:, :156]``: pose axis-angles, root first.
        ``[:, 309:312]``: root translation.
        ``[0, 312:]``: betas.

    The input array is not modified.
    """
    poses = motion[:, :156]
    num_frames, width = poses.shape
    # Zero-pad the pose vector to 165 columns.
    pad = np.zeros((num_frames, 165 - width), dtype=poses.dtype)
    poses = np.concatenate([poses, pad], axis=1)

    # Rotate the root orientation by +90 degrees about x.
    rotate_matrix = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    root_mat = Rotation.from_rotvec(poses[:, :3]).as_matrix()
    poses[:, :3] = Rotation.from_matrix(rotate_matrix @ root_mat).as_rotvec()

    # Translation: flip z, swap y/z, then apply the same root rotation.
    # Copy so the in-place sign flip does not write back into `motion`.
    trans = motion[:, 309:312].copy()
    trans[:, 2] = trans[:, 2] * (-1)
    swap_yz = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
    trans = np.dot(trans, swap_yz)
    trans = (rotate_matrix @ trans.T).T

    return {
        "poses": poses,
        "trans": trans,
        "betas": motion[0, 312:],
        "gender": "neutral",
        "mocap_frame_rate": 30,
    }


def motionx_to_amass(src_root, dst_root):
    """Convert MotionX format motion data to AMASS format.

    Args:
        src_root (str): Source directory containing MotionX .npy files.
        dst_root (str): Destination directory for processed AMASS .npz files.

    Side effects:
        - Creates a directory structure under ``dst_root`` mirroring
          ``src_root``.
        - Writes one compressed .npz per source .npy; files with other
          extensions are skipped.
        - Prints the path of every written file.

    Processed data contains:
        poses: [T, 165] float array of joint rotations (root first)
        trans: [T, 3] float array of root translations
        betas: [10] float array of shape parameters
        gender: str (always "neutral")
        mocap_frame_rate: int (always 30)
    """
    os.makedirs(dst_root, exist_ok=True)
    for root, _, files in os.walk(src_root):
        for file in files:
            if not file.endswith(".npy"):
                continue
            src_file_path = os.path.join(root, file)
            amass_data = _motionx_array_to_amass(np.load(src_file_path))
            relative_path = os.path.relpath(src_file_path, src_root)
            file_name = os.path.join(
                dst_root, os.path.splitext(relative_path)[0] + ".npz"
            )
            os.makedirs(os.path.dirname(file_name), exist_ok=True)
            print(file_name)
            np.savez_compressed(file_name, **amass_data)
if __name__ == "__main__":
    # Default input/output locations, relative to the current working directory.
    src_root = "./data/smplx_322"
    dst_root = "./data/smplx_data/MotionX"
    motionx_to_amass(src_root, dst_root)
================================================
FILE: holomotion/src/data_curation/smplify/smplify_omomo.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# -----------------------------------------------------------------------------
# Portions of this file are derived from omomo_release (https://github.com/lijiaman/omomo_release).
# The original omomo_release code is licensed under the MIT license.
# -----------------------------------------------------------------------------
import os
import numpy as np
import pytorch3d.transforms as transforms
import torch
from torch.utils import data
from thirdparties.omomo_release.manip.data.hand_foot_dataset import (
HandFootManipDataset,
quat_ik_torch,
)
class MyHandFootManipDataset(HandFootManipDataset):
    """Modified dataset class for hand-foot manipulation tasks.

    This class overrides the __getitem__ method. Each window is returned at
    its actual length (the ``paded_*`` locals are not actually padded). The
    un-normalized joint positions are returned alongside the normalized ones.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the dataset instance by forwarding all arguments.

        This constructor ensures proper initialization
        of the HandFootManipDataset parent class.
        All parameters and keyword arguments are passed through unchanged.

        Args:
            *args: Variable length argument list for parent class
            **kwargs: Arbitrary keyword arguments for parent class
        """
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Retrieve and process a data sample by index.

        Motion data is returned without padding.

        Args:
            index (int): Index of the sample to retrieve

        Returns:
            dict with keys ``motion`` (normalized joint positions + the
            trailing rotation columns), ``ori_motion`` (same, un-normalized
            positions), ``obj_bps``, ``obj_com_pos``, ``obj_rot_mat``,
            ``obj_scale``, ``obj_trans``, ``obj_bottom_rot_mat``,
            ``obj_bottom_scale``, ``obj_bottom_trans`` (fall back to the
            object's own values unless the object is a mop or vacuum),
            ``betas``, ``gender``, ``seq_name``, ``obj_name``, ``seq_len``
            and ``trans2joint``.

        Reference:
            https://github.com/lijiaman/omomo_release/blob/main/manip/data/hand_foot_dataset.py
        """
        # index = 0 # For debug
        data_input = self.window_data_dict[index]["motion"]
        data_input = torch.from_numpy(data_input).float()
        seq_name = self.window_data_dict[index]["seq_name"]
        # The object name is the second "_"-separated token of the sequence name.
        object_name = seq_name.split("_")[1]
        trans2joint = self.window_data_dict[index]["trans2joint"]
        # BPS features are stored per window; with object splits the file is
        # keyed by the window's original index instead of the dataset index.
        if self.use_object_splits:
            ori_w_idx = self.window_data_dict[index]["ori_w_idx"]
            obj_bps_npy_path = os.path.join(
                self.dest_obj_bps_npy_folder,
                seq_name + "_" + str(ori_w_idx) + ".npy",
            )
        else:
            obj_bps_npy_path = os.path.join(
                self.dest_obj_bps_npy_folder,
                seq_name + "_" + str(index) + ".npy",
            )
        obj_bps_data = np.load(obj_bps_npy_path)  # T X N X 3
        obj_bps_data = torch.from_numpy(obj_bps_data)
        num_joints = 24
        normalized_jpos = self.normalize_jpos_min_max(
            data_input[:, : num_joints * 3].reshape(-1, num_joints, 3)
        )  # T X 24 X 3
        # Columns from 2 * 24 * 3 = 144 onward; presumably global 6D joint
        # rotations, T X (22*6) — TODO confirm against the parent data layout.
        global_joint_rot = data_input[:, 2 * num_joints * 3 :]
        new_data_input = torch.cat(
            (normalized_jpos.reshape(-1, num_joints * 3), global_joint_rot),
            dim=1,
        )
        ori_data_input = torch.cat(
            (data_input[:, : num_joints * 3], global_joint_rot), dim=1
        )
        # No padding is applied: every tensor keeps the window's actual length.
        actual_steps = new_data_input.shape[0]
        paded_new_data_input = new_data_input
        paded_ori_data_input = ori_data_input
        paded_obj_bps = obj_bps_data.reshape(actual_steps, -1)
        paded_obj_com_pos = torch.from_numpy(
            self.window_data_dict[index]["window_obj_com_pos"]
        ).float()
        paded_obj_rot_mat = torch.from_numpy(
            self.window_data_dict[index]["obj_rot_mat"]
        ).float()
        paded_obj_scale = torch.from_numpy(
            self.window_data_dict[index]["obj_scale"]
        ).float()
        paded_obj_trans = torch.from_numpy(
            self.window_data_dict[index]["obj_trans"]
        ).float()
        # Only mop and vacuum carry separate "bottom" part transforms.
        if object_name in ["mop", "vacuum"]:
            paded_obj_bottom_rot_mat = torch.from_numpy(
                self.window_data_dict[index]["obj_bottom_rot_mat"]
            ).float()
            paded_obj_bottom_scale = torch.from_numpy(
                self.window_data_dict[index]["obj_bottom_scale"]
            ).float()
            paded_obj_bottom_trans = (
                torch.from_numpy(
                    self.window_data_dict[index]["obj_bottom_trans"]
                )
                .float()
                .squeeze(-1)
            )
        data_input_dict = {}
        data_input_dict["motion"] = paded_new_data_input
        data_input_dict["ori_motion"] = paded_ori_data_input
        data_input_dict["obj_bps"] = paded_obj_bps
        data_input_dict["obj_com_pos"] = paded_obj_com_pos
        data_input_dict["obj_rot_mat"] = paded_obj_rot_mat
        data_input_dict["obj_scale"] = paded_obj_scale
        data_input_dict["obj_trans"] = paded_obj_trans
        if object_name in ["mop", "vacuum"]:
            data_input_dict["obj_bottom_rot_mat"] = paded_obj_bottom_rot_mat
            data_input_dict["obj_bottom_scale"] = paded_obj_bottom_scale
            data_input_dict["obj_bottom_trans"] = paded_obj_bottom_trans
        else:
            # Objects without a separate bottom part reuse their own transform.
            data_input_dict["obj_bottom_rot_mat"] = paded_obj_rot_mat
            data_input_dict["obj_bottom_scale"] = paded_obj_scale
            data_input_dict["obj_bottom_trans"] = paded_obj_trans
        data_input_dict["betas"] = self.window_data_dict[index]["betas"]
        data_input_dict["gender"] = str(self.window_data_dict[index]["gender"])
        data_input_dict["seq_name"] = seq_name
        data_input_dict["obj_name"] = seq_name.split("_")[1]
        data_input_dict["seq_len"] = actual_steps
        data_input_dict["trans2joint"] = trans2joint
        return data_input_dict
def run_smplx_model(root_trans, aa_rot_rep, betas, gender, fname):
    """Pack motion parameters into an AMASS-style npz file.

    No SMPL-X forward pass is run; the inputs are only reshaped, padded and
    saved.

    Args:
        root_trans (torch.Tensor): Root translations; reshaped to
            ``(BS*T, 3)``.
        aa_rot_rep (torch.Tensor): Axis-angle local rotations
            ``[BS, T, J, 3]`` with ``J == 22`` (body only; 30 zero hand
            joints are appended) or ``J == 52`` (body + hands).
        betas (torch.Tensor): Shape parameters ``[BS, 16]``; only the first
            row is saved. A 1-D ``[16]`` tensor is also accepted.
        gender (list): Accepted for interface compatibility but ignored;
            ``"neutral"`` is always written.
        fname (str): Output filename/path for the .npz file.

    Output npz file contains:
        poses: [BS*T, 165] float array; columns are root(3) body(63)
            hands(90) followed by 9 zero columns
        trans: [BS*T, 3] float array of translations
        betas: [16] float array of shape parameters (from first sample)
        gender: str (always "neutral")
        mocap_frame_rate: int (always 30)

    Raises:
        ValueError: if ``J`` is neither 22 nor 52.
    """
    bs, num_steps, num_joints, _ = aa_rot_rep.shape
    if num_joints not in (22, 52):
        raise ValueError(
            f"aa_rot_rep must have 22 or 52 joints, got {num_joints}"
        )
    if num_joints == 22:
        hand_pad = torch.zeros(
            bs,
            num_steps,
            30,
            3,
            device=aa_rot_rep.device,
            dtype=aa_rot_rep.dtype,
        )
        aa_rot_rep = torch.cat((aa_rot_rep, hand_pad), dim=2)  # BS X T X 52 X 3

    aa_flat = aa_rot_rep.reshape(bs * num_steps, 52, 3)
    smpl_root_orient = aa_flat[:, 0, :]  # (BS*T) X 3
    smpl_pose_body = aa_flat[:, 1:22, :].reshape(-1, 63)  # (BS*T) X 63
    smpl_pose_hand = aa_flat[:, 22:, :].reshape(-1, 90)  # (BS*T) X 90
    poses = torch.cat(
        [smpl_root_orient, smpl_pose_body, smpl_pose_hand], dim=-1
    )  # (BS*T) X 156
    # Zero-pad to the 165-column pose vector.
    poses = torch.cat(
        [poses, poses.new_zeros(poses.shape[0], 165 - poses.shape[-1])],
        dim=-1,
    )

    first_betas = betas.reshape(-1, betas.shape[-1])[0]
    amass_data = {
        "poses": poses.detach().cpu().numpy(),
        "trans": root_trans.reshape(-1, 3).detach().cpu().numpy(),
        "betas": first_betas.detach().cpu().numpy(),
        "gender": "neutral",
        "mocap_frame_rate": 30,
    }
    np.savez_compressed(fname, **amass_data)
def process_dataset(dl, dataset, target_folder, split_name: str):
    """Convert every sequence yielded by ``dl`` into an AMASS-style npz.

    For each batch, joint positions are de-normalized with ``dataset``. The
    global 6D rotations are converted to local axis-angles via
    ``quat_ik_torch``. The root translation is recovered as root joint
    position + ``trans2joint``. Computation runs on CUDA when available and
    otherwise on the CPU.

    Args:
        dl (DataLoader): PyTorch DataLoader providing batched data.
        dataset (Dataset): Source dataset object (used for de-normalization).
        target_folder (str): Folder the .npz files are written to.
        split_name (str): Name of the data split being processed.

    Output files:
        Saved as: {target_folder}/{split_name}_{object_name}_{index}.npz
        where object_name is the second "_"-separated token of the sequence
        name and index is a counter that increments across all batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_joints = 24
    index = 0
    for data_dict in dl:
        motion = data_dict["motion"].to(device)
        num_seq = motion.shape[0]
        print(f"Processing {split_name}, num_seq: {num_seq}")
        # First 24 * 3 columns: normalized global joint positions.
        global_jpos = dataset.de_normalize_jpos_min_max(
            motion[:, :, : num_joints * 3].reshape(-1, num_joints, 3)
        ).reshape(num_seq, -1, num_joints, 3)
        global_root_jpos = global_jpos[:, :, 0, :].clone()
        # Last 22 * 6 columns: global joint rotations in 6D representation.
        global_rot_mat = transforms.rotation_6d_to_matrix(
            motion[:, :, -22 * 6 :].reshape(num_seq, -1, 22, 6)
        )
        trans2joint = data_dict["trans2joint"].to(device)
        for idx in range(num_seq):
            local_rot_mat = quat_ik_torch(global_rot_mat[idx])
            local_rot_aa = transforms.matrix_to_axis_angle(local_rot_mat)
            root_trans = global_root_jpos[idx] + trans2joint[idx : idx + 1]
            curr_seq_name = data_dict["seq_name"][idx]
            object_name = curr_seq_name.split("_")[1]
            fname = os.path.join(
                target_folder, f"{split_name}_{object_name}_{index}.npz"
            )
            print(fname)
            run_smplx_model(
                root_trans[None],
                local_rot_aa[None],
                data_dict["betas"][idx].to(device),
                [data_dict["gender"][idx]],
                fname,
            )
            index += 1
def omomo_to_amass(data_root_folder, target_folder):
    """Export the OMOMO train and val splits as AMASS-compatible SMPL-X files.

    Both splits use 120-frame windows with object-based splits, are iterated
    one sequence per batch in dataset order, and are written by
    ``process_dataset`` (train first, then val).

    Args:
        data_root_folder (str): Root directory of the OMOMO dataset.
        target_folder (str): Directory receiving the generated files.
    """
    split_settings = dict(
        data_root_folder=data_root_folder,
        window=120,
        use_object_splits=True,
    )
    train_dataset = MyHandFootManipDataset(train=True, **split_settings)
    val_dataset = MyHandFootManipDataset(train=False, **split_settings)

    def _one_by_one(ds):
        # Single-sequence batches in a fixed order, loaded in-process.
        return data.DataLoader(
            ds, batch_size=1, shuffle=False, pin_memory=True, num_workers=0
        )

    val_dl = _one_by_one(val_dataset)
    train_dl = _one_by_one(train_dataset)

    for loader, ds, split in (
        (train_dl, train_dataset, "train"),
        (val_dl, val_dataset, "val"),
    ):
        process_dataset(loader, ds, target_folder, split)
if __name__ == "__main__":
    # Fill in the OMOMO root directory and the output directory before use.
    data_root_folder = ""
    target_folder = ""
    omomo_to_amass(
        data_root_folder=data_root_folder, target_folder=target_folder
    )
================================================
FILE: holomotion/src/data_curation/smplify/smplify_zjumocap.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import numpy as np
from scipy.spatial.transform import Rotation
from tqdm import tqdm
def zju_single_to_amass(
    param_dir, out_path, gender="neutral", fps=30, rotate=False
):
    """Convert per-frame ZJU-MoCap ``.npy`` files into one AMASS-style ``.npz``.

    Every ``<frame>.npy`` in ``param_dir`` holds a pickled dict; this code
    reads ``"poses"``, ``"Rh"``, ``"Th"`` and ``"shapes"`` from it (assumed
    shapes (1, 72), (1, 3), (1, 3) and (1, 10) — TODO confirm against the
    data). Frames are ordered by the integer value of their file stem.

    The saved ``poses`` array is (N, 72) float32: the (optionally rotated)
    root orientation, ``poses[:, 3:66]`` and ``poses[:, 66:72]`` of each
    frame. ``betas`` come from the first frame only.

    Args:
        param_dir: Folder containing 0.npy, 1.npy, ....
        out_path: Output .npz path.
        gender: Gender to assign ('neutral', 'male', 'female').
        fps: Mocap frame rate stored as ``mocap_frame_rate``.
        rotate: If True, rotate root orientation and translation by -180
            degrees about the y axis.

    Raises:
        ValueError: If ``param_dir`` contains no ``.npy`` files.
    """
    frame_files = sorted(
        (f for f in os.listdir(param_dir) if f.endswith(".npy")),
        key=lambda name: int(os.path.splitext(name)[0]),
    )
    if not frame_files:
        raise ValueError(f"No .npy frame files found in {param_dir}")

    # Identity when rotate is False, so one code path handles both cases.
    align_rot = Rotation.from_euler(
        "y", -180 if rotate else 0, degrees=True
    ).as_matrix()

    pose_list = []
    trans_list = []
    shape_list = []
    for fname in tqdm(frame_files, desc="Processing frames"):
        frame = np.load(
            os.path.join(param_dir, fname), allow_pickle=True
        ).item()
        smpl_pose = frame["poses"]  # (1, 72)
        root_mat = Rotation.from_rotvec(frame["Rh"]).as_matrix()
        global_orient = Rotation.from_matrix(align_rot @ root_mat).as_rotvec()
        body_pose = smpl_pose[:, 3:66]
        hand_pose = smpl_pose[:, 66:72]
        full_pose = np.concatenate(
            [global_orient, body_pose, hand_pose], axis=1
        )  # (1, 72)
        pose_list.append(full_pose[0])  # (72,)
        trans_list.append(frame["Th"][0])  # (3,)
        shape_list.append(frame["shapes"][0])  # (10,)

    poses = np.stack(pose_list, axis=0).astype(np.float32)  # (N, 72)
    trans = np.stack(trans_list, axis=0).astype(np.float32)  # (N, 3)
    trans_rotated = (align_rot @ trans.T).T
    betas = shape_list[0].astype(np.float32)  # (10,) same for all frames

    # Save as AMASS-style npz
    np.savez_compressed(
        out_path,
        poses=poses,
        trans=trans_rotated,
        betas=betas,
        gender=gender,
        mocap_frame_rate=fps,
    )
    print(f"Saved AMASS-style file to: {out_path}")
    print(f"Total frames: {poses.shape[0]}")
def zju_to_amass(input_dir, output_dir):
    """Convert every subject folder of a ZJU dataset into AMASS-style files.

    Sub-directories of ``input_dir`` are processed in sorted order. For each
    subject the ``new_params`` folder is preferred, falling back to
    ``params``; subjects with neither are skipped. Each result is written to
    ``{output_dir}/{subject}.npz``.

    Args:
        input_dir: Path to ZJU dataset root folder.
        output_dir: Path to save AMASS-format .npz files.
    """
    os.makedirs(output_dir, exist_ok=True)
    subjects = sorted(
        entry
        for entry in os.listdir(input_dir)
        if os.path.isdir(os.path.join(input_dir, entry))
    )
    for subject in subjects:
        subject_root = os.path.join(input_dir, subject)
        for folder_name in ("new_params", "params"):
            candidate = os.path.join(subject_root, folder_name)
            if os.path.isdir(candidate):
                print(f"[{subject}] Using {folder_name}")
                break
        else:
            print(f"[{subject}] No params found, skipping")
            continue
        zju_single_to_amass(candidate, os.path.join(output_dir, f"{subject}.npz"))
    print(f"All subjects processed. Output saved to {output_dir}")
# Example usage:
#   python smplify_zjumocap.py <zju_root> <output_dir>
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert ZJU-MoCap subjects to AMASS-style .npz files."
    )
    parser.add_argument("input_dir", help="ZJU dataset root folder.")
    parser.add_argument(
        "output_dir", help="Folder for the AMASS-style .npz files."
    )
    cli_args = parser.parse_args()
    # zju_to_amass's parameters are named input_dir and output_dir.
    zju_to_amass(input_dir=cli_args.input_dir, output_dir=cli_args.output_dir)
================================================
FILE: holomotion/src/data_curation/templates/index_wooden_static.html
================================================
Motion Visualization
{{ caption_html }}
Loading...
1.0x
================================================
FILE: holomotion/src/data_curation/video_to_smpl_gvhmr.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally copied from the [PBHC] repository:
# https://github.com/TeleHuman/PBHC
# Modifications have been made to fit the needs of this project.
import cv2
import torch
import pytorch_lightning as pl
import numpy as np
import argparse
from hmr4d.utils.pylogger import Log
import hydra
from hydra import initialize_config_module, compose
from pathlib import Path
from pytorch3d.transforms import quaternion_to_matrix
from hmr4d.configs import register_store_gvhmr
from hmr4d.utils.video_io_utils import (
get_video_lwh,
read_video_np,
save_video,
merge_videos_horizontal,
get_writer,
get_video_reader,
)
from hmr4d.utils.vis.cv2_utils import (
draw_bbx_xyxy_on_image_batch,
draw_coco17_skeleton_batch,
)
from hmr4d.utils.preproc import Tracker, Extractor, VitPoseExtractor, SimpleVO
from hmr4d.utils.geo.hmr_cam import (
get_bbx_xys_from_xyxy,
estimate_K,
convert_K_to_K4,
create_camera_sensor,
)
from hmr4d.utils.geo_transform import compute_cam_angvel
from hmr4d.model.gvhmr.gvhmr_pl_demo import DemoPL
from hmr4d.utils.net_utils import detach_to_cpu, to_cuda
from hmr4d.utils.smplx_utils import make_smplx
from hmr4d.utils.vis.renderer import (
Renderer,
get_global_cameras_static,
get_ground_params_from_points,
)
from tqdm import tqdm
from hmr4d.utils.geo_transform import apply_T_on_points, compute_T_ayfz2ay
from einops import einsum, rearrange
import shutil
import subprocess
from scipy.spatial.transform import Rotation as sRot
# Constant Rate Factor (quality setting) passed to the ffmpeg/libx264
# transcode and to get_writer for every video this script writes.
# Lower means higher quality: 0 is lossless for 8-bit libx264, 17-18 is
# roughly visually lossless, 23 is libx264's default, and each +6 roughly
# halves the file size (FFmpeg H.264 encoding guide).
CRF = 23
def get_video_fps(video_path: Path) -> float:
    """Return the frame rate OpenCV reports for ``video_path``.

    Raises:
        RuntimeError: If the video cannot be opened, or the reported frame
            rate is missing (OpenCV returns 0), non-positive or non-finite.
    """
    import math

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture even when opening or querying fails.
        cap.release()
    # NaN/inf would slip through a plain "<= 0" comparison.
    if fps is None or not math.isfinite(fps) or fps <= 1e-6:
        raise RuntimeError(f"Failed to read FPS from video: {video_path}")
    return float(fps)
def is_close_fps(a: float, b: float, tol: float = 0.02) -> bool:
    """Return True when frame rates ``a`` and ``b`` differ by at most ``tol``."""
    difference = a - b
    return -tol <= difference <= tol
def transcode_to_30fps_cfr(src: Path, dst: Path, crf: int) -> None:
    """Re-encode ``src`` into ``dst`` at constant 30 fps using libx264.

    The parent directory of ``dst`` is created when missing, an existing
    ``dst`` is overwritten (``-y``) and the audio stream is copied as-is.
    A failing ffmpeg run raises ``subprocess.CalledProcessError``.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    video_opts = [
        "-vf", "fps=30",
        "-vsync", "cfr",
        "-c:v", "libx264",
        "-crf", str(crf),
        "-preset", "medium",
    ]
    audio_opts = ["-c:a", "copy"]
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", str(src), *video_opts, *audio_opts]
    ffmpeg_cmd.append(str(dst))
    subprocess.run(ffmpeg_cmd, check=True)
def _prepared_video_is_stale(dst_len, src_len, src_fps, copied, tol=2):
    """Return True if a prepared video with ``dst_len`` frames must be rebuilt.

    A copied (~30 fps) video must match the source frame count exactly. A
    transcoded video is resampled to 30 fps, so it is compared with the
    resampled count ``round(src_len * 30 / src_fps)``, allowing ``tol``
    frames of slack for timestamp rounding.
    """
    if copied:
        return dst_len != src_len
    expected_len = int(round(src_len * 30.0 / src_fps))
    return abs(dst_len - expected_len) > tol


def parse_args_to_cfg(args=None):
    """Compose the hydra ``demo`` config and prepare a 30 fps working video.

    Args:
        args: Optional namespace with ``video``, ``output_root``,
            ``static_cam``, ``use_dpvo``, ``f_mm`` and ``verbose``. When
            None they are parsed from the command line.

    Returns:
        The composed config. As side effects the output and preprocess
        directories are created, and ``cfg.video_path`` is regenerated when
        missing or stale: copied if the input is ~30 fps, otherwise
        transcoded to constant 30 fps.
    """
    # Put all args to cfg
    if args is None:
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--video", type=str, default="inputs/demo/dance_3.mp4"
        )
        parser.add_argument(
            "--output_root",
            type=str,
            default=None,
            help="by default to outputs/demo",
        )
        parser.add_argument(
            "-s",
            "--static_cam",
            action="store_true",
            help="If true, skip DPVO",
        )
        parser.add_argument(
            "--use_dpvo",
            action="store_true",
            help="If true, use DPVO. By default not using DPVO.",
        )
        parser.add_argument(
            "--f_mm",
            type=int,
            default=None,
            help="Focal length of fullframe camera in mm. Leave it as None to use default values. "
            "For iPhone 15p, the [0.5x, 1x, 2x, 3x] lens have typical values [13, 24, 48, 77]. "
            "If the camera zoom in a lot, you can try 135, 200 or even larger values.",
        )
        parser.add_argument(
            "--verbose",
            action="store_true",
            help="If true, draw intermediate results",
        )
        args = parser.parse_args()

    # Input
    video_path = Path(args.video)
    assert video_path.exists(), f"Video not found at {video_path}"
    length, width, height = get_video_lwh(video_path)
    Log.info(f"[Input]: {video_path}")
    Log.info(f"(L, W, H) = ({length}, {width}, {height})")

    # Cfg
    with initialize_config_module(
        version_base="1.3", config_module="hmr4d.configs"
    ):
        overrides = [
            f"video_name='{video_path.stem}'",
            f"static_cam={args.static_cam}",
            f"verbose={args.verbose}",
            f"use_dpvo={args.use_dpvo}",
        ]
        if args.f_mm is not None:
            overrides.append(f"f_mm={args.f_mm}")
        # Allow to change output root
        if args.output_root is not None:
            overrides.append(f"output_root='{args.output_root}'")
        register_store_gvhmr()
        cfg = compose(config_name="demo", overrides=overrides)

    # Output
    Log.info(f"[Output Dir]: {cfg.output_dir}")
    Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.preprocess_dir).mkdir(parents=True, exist_ok=True)

    # Copy raw-input-video to video_path
    Log.info(f"[Prepare Video] {video_path} -> {cfg.video_path}")
    src_len = length
    src_fps = get_video_fps(video_path)
    Log.info(f"[Input FPS]: {src_fps:.4f}")
    copy_as_is = is_close_fps(src_fps, 30.0)
    dst_path = Path(cfg.video_path)
    # Compare with the frame count the copy/transcode should produce.
    # Comparing a transcoded video with the *source* count never matches
    # for non-30fps inputs and would re-transcode on every run.
    need_regen = (not dst_path.exists()) or _prepared_video_is_stale(
        get_video_lwh(dst_path)[0], src_len, src_fps, copy_as_is
    )
    if need_regen:
        if copy_as_is:
            Log.info("[FPS OK] ~30fps, copy without re-encoding.")
            dst_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(video_path, dst_path)
        else:
            Log.info("[FPS CONVERT] transcoding to 30fps with constant speed.")
            transcode_to_30fps_cfr(video_path, dst_path, CRF)
    return cfg
@torch.no_grad()
def run_preprocess(cfg):
    """Compute and cache the per-video inputs GVHMR needs.

    Each stage writes its result to a file in ``cfg.paths`` and is skipped
    when that file already exists:

    * ``paths.bbx``: person track (``bbx_xyxy``) and enlarged boxes
      (``bbx_xys``).
    * ``paths.vitpose``: ViTPose 2D keypoints.
    * ``paths.vit_features``: per-frame image features.
    * ``paths.slam``: camera motion from SimpleVO or DPVO; only computed
      when ``cfg.static_cam`` is False.

    With ``cfg.verbose`` set, box and keypoint overlay videos are also saved.
    """
    Log.info("[Preprocess] Start!")
    tic = Log.time()
    video_path = cfg.video_path
    paths = cfg.paths
    static_cam = cfg.static_cam
    verbose = cfg.verbose

    # Get bbx tracking result
    if not Path(paths.bbx).exists():
        tracker = Tracker()
        bbx_xyxy = tracker.get_one_track(video_path).float()  # (L, 4)
        bbx_xys = get_bbx_xys_from_xyxy(
            bbx_xyxy, base_enlarge=1.2
        ).float()  # (L, 3) apply aspect ratio and enlarge
        torch.save({"bbx_xyxy": bbx_xyxy, "bbx_xys": bbx_xys}, paths.bbx)
        del tracker
    else:
        bbx_xys = torch.load(paths.bbx)["bbx_xys"]
        Log.info(f"[Preprocess] bbx (xyxy, xys) from {paths.bbx}")
    if verbose:
        video = read_video_np(video_path)
        bbx_xyxy = torch.load(paths.bbx)["bbx_xyxy"]
        video_overlay = draw_bbx_xyxy_on_image_batch(bbx_xyxy, video)
        save_video(video_overlay, cfg.paths.bbx_xyxy_video_overlay)

    # Get VitPose
    if not Path(paths.vitpose).exists():
        vitpose_extractor = VitPoseExtractor()
        vitpose = vitpose_extractor.extract(video_path, bbx_xys)
        torch.save(vitpose, paths.vitpose)
        del vitpose_extractor
    else:
        vitpose = torch.load(paths.vitpose)
        Log.info(f"[Preprocess] vitpose from {paths.vitpose}")
    if verbose:
        video = read_video_np(video_path)
        video_overlay = draw_coco17_skeleton_batch(video, vitpose, 0.5)
        save_video(video_overlay, paths.vitpose_video_overlay)

    # Get vit features
    if not Path(paths.vit_features).exists():
        extractor = Extractor()
        vit_features = extractor.extract_video_features(video_path, bbx_xys)
        torch.save(vit_features, paths.vit_features)
        del extractor
    else:
        Log.info(f"[Preprocess] vit_features from {paths.vit_features}")

    # Get visual odometry results
    if not static_cam:  # use slam to get cam rotation
        if not Path(paths.slam).exists():
            if not cfg.use_dpvo:
                simple_vo = SimpleVO(
                    cfg.video_path,
                    scale=0.5,
                    step=8,
                    method="sift",
                    f_mm=cfg.f_mm,
                )
                vo_results = simple_vo.compute()  # (L, 4, 4), numpy
                torch.save(vo_results, paths.slam)
            else:  # DPVO
                from hmr4d.utils.preproc.slam import SLAMModel

                length, width, height = get_video_lwh(cfg.video_path)
                K_fullimg = estimate_K(width, height)
                intrinsics = convert_K_to_K4(K_fullimg)
                slam = SLAMModel(
                    video_path,
                    width,
                    height,
                    intrinsics,
                    buffer=4000,
                    resize=0.5,
                )
                # Track until track() returns a falsy value; the context
                # manager closes the progress bar even if tracking raises.
                with tqdm(total=length, desc="DPVO") as bar:
                    while slam.track():
                        bar.update()
                slam_results = slam.process()  # (L, 7), numpy
                torch.save(slam_results, paths.slam)
        else:
            Log.info(f"[Preprocess] slam results from {paths.slam}")

    Log.info(f"[Preprocess] End. Time elapsed: {Log.time() - tic:.2f}s")
def load_data_dict(cfg):
    """Assemble the model input dict from the cached preprocessing results.

    Camera rotations are identity matrices for a static camera. Otherwise
    they come from the cached SLAM trajectory: DPVO quaternions or SimpleVO
    4x4 poses. Intrinsics come from ``cfg.f_mm`` when given, else from
    ``estimate_K``; either way they are repeated once per frame.
    """
    n_frames, img_w, img_h = get_video_lwh(cfg.video_path)

    if cfg.static_cam:
        R_w2c = torch.eye(3).repeat(n_frames, 1, 1)
    else:
        traj = torch.load(cfg.paths.slam)
        if cfg.use_dpvo:
            # quaternion_to_matrix expects real-part-first quaternions;
            # columns [6, 3, 4, 5] presumably reorder DPVO's (qx, qy, qz, qw).
            quat_wxyz = torch.from_numpy(traj[:, [6, 3, 4, 5]])
            R_w2c = quaternion_to_matrix(quat_wxyz).mT
        else:
            # SimpleVO: rotation block of each (4, 4) pose.
            R_w2c = torch.from_numpy(traj[:, :3, :3])

    if cfg.f_mm is None:
        K_single = estimate_K(img_w, img_h)
    else:
        K_single = create_camera_sensor(img_w, img_h, cfg.f_mm)[2]
    K_fullimg = K_single.repeat(n_frames, 1, 1)

    bbx_cache = torch.load(cfg.paths.bbx)
    return {
        "length": torch.tensor(n_frames),
        "bbx_xys": bbx_cache["bbx_xys"],
        "kp2d": torch.load(cfg.paths.vitpose),
        "K_fullimg": K_fullimg,
        "cam_angvel": compute_cam_angvel(R_w2c),
        "f_imgseq": torch.load(cfg.paths.vit_features),
    }
def save_npz(pred, save_path):
    """Write global SMPL predictions as an AMASS-style ``smpl.npz``.

    Only the parent directory of ``save_path`` is used; the file is always
    named ``smpl.npz`` and the directory is created if needed. The root
    orientation and the translation are re-expressed with one fixed rotation
    built from XYZ Euler angles (pi/2, 0, pi) radians.

    Args:
        pred: Dict of tensors with ``transl``, ``global_orient``,
            ``body_pose`` and ``betas`` (only ``betas[0]`` is saved).
        save_path: Any path inside the target directory.
    """
    target_dir = Path(save_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)

    align = sRot.from_euler(
        "xyz", np.array([np.pi / 2, 0, np.pi]), degrees=False
    )

    pose_parts = [pred[key].detach().cpu() for key in ("global_orient", "body_pose")]
    full_pose = torch.cat(pose_parts, dim=1)
    # Left-compose the alignment with each frame's root rotation.
    root_rotvec = (align * sRot.from_rotvec(full_pose[:, :3].numpy())).as_rotvec()
    full_pose[:, :3] = torch.from_numpy(root_rotvec)

    # Row-vector translations: t @ R^T == (R @ t^T)^T.
    align_t = torch.tensor(align.as_matrix().T, dtype=torch.float32)
    transl = pred["transl"].detach().cpu() @ align_t

    npz_file = target_dir / "smpl.npz"
    Log.info(f"npz_path {npz_file}")
    np.savez(
        str(npz_file),
        betas=pred["betas"][0].detach().cpu().numpy(),
        gender="neutral",
        poses=full_pose.numpy(),
        trans=transl.numpy(),
        mocap_framerate=30.0,
    )
def render_incam(cfg):
    """Render the predicted in-camera SMPL mesh over the input video.

    Does nothing if ``cfg.paths.incam_video`` already exists. Otherwise the
    SMPL-X parameters in ``pred["smpl_params_incam"]`` are converted to SMPL
    vertices and rendered onto every frame using the first frame's
    ``K_fullimg``. If rendering fails, the partial output is removed;
    otherwise its existence would make later runs skip rendering.
    """
    incam_video_path = Path(cfg.paths.incam_video)
    if incam_video_path.exists():
        Log.info(f"[Render Incam] Video already exists at {incam_video_path}")
        return

    pred = torch.load(cfg.paths.hmr4d_results)
    smplx = make_smplx("supermotion").cuda()
    smplx2smpl = torch.load(
        "hmr4d/utils/body_model/smplx2smpl_sparse.pt"
    ).cuda()
    faces_smpl = make_smplx("smpl").faces

    # smpl
    smplx_out = smplx(**to_cuda(pred["smpl_params_incam"]))
    verts_incam = torch.stack(
        [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]
    )

    # -- rendering code -- #
    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    K = pred["K_fullimg"][0]
    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)

    reader = get_video_reader(video_path)  # (F, H, W, 3), uint8, numpy
    writer = get_writer(incam_video_path, fps=30, crf=CRF)
    completed = False
    try:
        for i, img_raw in tqdm(
            enumerate(reader), total=length, desc="Rendering Incam"
        ):
            img = renderer.render_mesh(
                verts_incam[i].cuda(), img_raw, [0.8, 0.8, 0.8]
            )
            writer.write_frame(img)
        completed = True
    finally:
        writer.close()
        reader.close()
        if not completed and incam_video_path.exists():
            incam_video_path.unlink()
def render_global(cfg):
    """Save the global SMPL params as ``smpl.npz`` and render the global view.

    The ``.npz`` is always (re)written next to ``cfg.paths.global_video``.
    Rendering is skipped when that video already exists. Otherwise the SMPL
    mesh is moved to the XZ origin, placed on the ground and turned to face
    Z, then rendered from static global cameras over a ground plane. If
    rendering fails, the partial video is removed so later runs retry it.
    """
    global_video_path = Path(cfg.paths.global_video)
    # Always save NPZ regardless of whether the video already exists
    pred = torch.load(cfg.paths.hmr4d_results)
    save_npz(pred["smpl_params_global"], save_path=global_video_path)
    if global_video_path.exists():
        Log.info(
            f"[Render Global] Video already exists at {global_video_path}"
        )
        return

    debug_cam = False
    smplx = make_smplx("supermotion").cuda()
    smplx2smpl = torch.load(
        "hmr4d/utils/body_model/smplx2smpl_sparse.pt"
    ).cuda()
    faces_smpl = make_smplx("smpl").faces
    J_regressor = torch.load(
        "hmr4d/utils/body_model/smpl_neutral_J_regressor.pt"
    ).cuda()

    # smpl
    smplx_out = smplx(**to_cuda(pred["smpl_params_global"]))
    pred_ay_verts = torch.stack(
        [torch.matmul(smplx2smpl, v_) for v_ in smplx_out.vertices]
    )

    def move_to_start_point_face_z(verts):
        "XZ to origin, Start from the ground, Face-Z"
        # position
        verts = verts.clone()  # (L, V, 3)
        offset = einsum(J_regressor, verts[0], "j v, v i -> j i")[0]  # (3)
        offset[1] = verts[:, :, [1]].min()
        verts = verts - offset
        # face direction
        T_ay2ayfz = compute_T_ayfz2ay(
            einsum(J_regressor, verts[[0]], "j v, l v i -> l j i"),
            inverse=True,
        )
        verts = apply_T_on_points(verts, T_ay2ayfz)
        return verts

    verts_glob = move_to_start_point_face_z(pred_ay_verts)
    joints_glob = einsum(
        J_regressor, verts_glob, "j v, l v i -> l j i"
    )  # (L, J, 3)
    global_R, global_T, global_lights = get_global_cameras_static(
        verts_glob.cpu(),
        beta=2.0,
        cam_height_degree=20,
        target_center_height=1.0,
    )

    # -- rendering code -- #
    video_path = cfg.video_path
    length, width, height = get_video_lwh(video_path)
    _, _, K = create_camera_sensor(width, height, 24)  # render as 24mm lens
    renderer = Renderer(width, height, device="cuda", faces=faces_smpl, K=K)

    # -- render mesh -- #
    scale, cx, cz = get_ground_params_from_points(
        joints_glob[:, 0], verts_glob
    )
    renderer.set_ground(scale * 1.5, cx, cz)
    color = torch.ones(3).float().cuda() * 0.8
    render_length = length if not debug_cam else 8

    writer = get_writer(global_video_path, fps=30, crf=CRF)
    completed = False
    try:
        for i in tqdm(range(render_length), desc="Rendering Global"):
            cameras = renderer.create_camera(global_R[i], global_T[i])
            img = renderer.render_with_ground(
                verts_glob[[i]], color[None], cameras, global_lights
            )
            writer.write_frame(img)
        completed = True
    finally:
        writer.close()
        if not completed and global_video_path.exists():
            global_video_path.unlink()
def _video_namespace(top_args, video):
    """Build the per-video namespace that ``parse_args_to_cfg`` expects."""
    return argparse.Namespace(
        video=video,
        output_root=top_args.output_root,
        static_cam=top_args.static_cam,
        use_dpvo=top_args.use_dpvo,
        f_mm=top_args.f_mm,
        verbose=top_args.verbose,
    )


def _run_pipeline(cfg):
    """Preprocess, predict (unless cached), render and merge one video."""
    paths = cfg.paths
    Log.info(f"[GPU]: {torch.cuda.get_device_name()}")
    Log.info(f"[GPU]: {torch.cuda.get_device_properties('cuda')}")

    # ===== Preprocess and save to disk ===== #
    run_preprocess(cfg)
    data = load_data_dict(cfg)

    # ===== HMR4D ===== #
    if not Path(paths.hmr4d_results).exists():
        Log.info("[HMR4D] Predicting")
        model: DemoPL = hydra.utils.instantiate(cfg.model, _recursive_=False)
        model.load_pretrained_model(cfg.ckpt_path)
        model = model.eval().cuda()
        tic = Log.sync_time()
        pred = model.predict(data, static_cam=cfg.static_cam)
        pred = detach_to_cpu(pred)
        data_time = data["length"] / 30
        Log.info(
            f"[HMR4D] Elapsed: {Log.sync_time() - tic:.2f}s for data-length={data_time:.1f}s"
        )
        torch.save(pred, paths.hmr4d_results)

    # ===== Render ===== #
    render_incam(cfg)
    render_global(cfg)
    if not Path(paths.incam_global_horiz_video).exists():
        Log.info("[Merge Videos]")
        merge_videos_horizontal(
            [paths.incam_video, paths.global_video],
            paths.incam_global_horiz_video,
        )


if __name__ == "__main__":
    # Top-level parser to support folder batch mode
    top_parser = argparse.ArgumentParser()
    top_parser.add_argument("--video", type=str, default=None)
    top_parser.add_argument("--folder", "-f", type=str, default=None)
    top_parser.add_argument("--output_root", "-d", type=str, default=None)
    top_parser.add_argument("-s", "--static_cam", action="store_true")
    top_parser.add_argument("--use_dpvo", action="store_true")
    top_parser.add_argument("--f_mm", type=int, default=None)
    top_parser.add_argument("--verbose", action="store_true")
    top_args = top_parser.parse_args()

    # Batch mode: process every .mp4/.MP4 in --folder, continuing past
    # failures; the exit status is 1 if any video failed.
    if top_args.folder is not None:
        folder = Path(top_args.folder)
        # De-duplicate: on case-insensitive platforms (e.g. Windows) both
        # patterns can match the same file.
        mp4_paths = sorted(
            set(folder.glob("*.mp4")) | set(folder.glob("*.MP4"))
        )
        Log.info(f"Found {len(mp4_paths)} .mp4 files in {folder}")
        failed = []
        for mp4_path in tqdm(mp4_paths):
            try:
                cfg = parse_args_to_cfg(
                    _video_namespace(top_args, str(mp4_path))
                )
                _run_pipeline(cfg)
            except Exception as e:
                Log.error(f"Failed on {mp4_path}: {e}")
                failed.append(mp4_path)
        if failed:
            Log.error(f"{len(failed)} of {len(mp4_paths)} videos failed")
        raise SystemExit(1 if failed else 0)

    # Single video mode
    if top_args.video is None:
        top_parser.error("Must provide --video or --folder")
    cfg = parse_args_to_cfg(_video_namespace(top_args, top_args.video))
    _run_pipeline(cfg)
================================================
FILE: holomotion/src/data_curation/vison_mocap/joints2smpl.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import sys
sys.path.append("../../")
sys.path.append("../../thirdparties/joints2smpl/src")
import h5py
import numpy as np
import smplx
import torch
from scipy.spatial.transform import Rotation
from thirdparties.joints2smpl.src import config
from thirdparties.joints2smpl.src.smplify import SMPLify3D
# SMPLify runs on the first CUDA device when one is available, else on CPU.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
# Number of input joints fitted (22 for the "AMASS" joint category).
num_joints = 22
joint_category = "AMASS"
num_smplify_iters = 300
# When True, the foot and ankle joints get a higher fitting confidence.
fix_foot = False
def joints2smpl(input_joints, save_name):
    """Fit SMPL to a 3D joint sequence and save it as an AMASS-style ``.npz``.

    The joints are shifted so that joint 0 of the first frame sits at the
    XY origin and the lowest z over the whole sequence becomes 0. SMPLify3D
    is then run from the SMPL mean pose/shape. The fitted root orientation
    and the root-joint translation are rotated by +90 degrees about x before
    saving. The saved ``jtr`` holds the preprocessed joints with y/z swapped
    and x negated.

    Args:
        input_joints: Joint positions indexed ``[frame, joint, xyz]``;
            presumably (T, num_joints, 3) with z up — TODO confirm. The
            caller's array is not modified (the first fancy-index copies it).
        save_name: Output ``.npz`` path.

    Raises:
        ValueError: If the module-level ``joint_category`` is not "AMASS".

    Note:
        This function depends on the `joints2smpl` repository. To use it
        properly, you need to manually modify parts of the internal
        `joints2smpl` repository code to ensure compatibility.
    """
    # Validate before any expensive setup; without this, an unsupported
    # category would leave confidence_input undefined further down.
    if joint_category != "AMASS":
        raise ValueError(f"Unsupported joint category: {joint_category}")
    confidence_input = torch.ones(num_joints)
    if fix_foot:
        # Up-weight the foot and ankle joints.
        confidence_input[[7, 8, 10, 11]] = 1.5

    input_joints = input_joints[:, :, [0, 1, 2]]  # amass stands on x, y
    # XY at origin
    input_joints[..., [0, 1]] -= input_joints[0, 0, [0, 1]]
    # Put on Floor
    floor_height = input_joints[:, :, 2].min()
    input_joints[:, :, 2] -= floor_height
    batch_size = input_joints.shape[0]

    smplmodel = smplx.create(
        config.SMPL_MODEL_DIR,
        model_type="smpl",
        gender="neutral",
        ext="npz",
        batch_size=batch_size,
    ).to(device)

    # Load the SMPL mean pose/shape as the initialization; the context
    # manager closes the HDF5 file once the arrays have been read.
    with h5py.File(config.SMPL_MEAN_FILE, "r") as mean_file:
        mean_pose = mean_file["pose"][:]
        mean_shape = mean_file["shape"][:]
    init_mean_pose = (
        torch.from_numpy(mean_pose)
        .unsqueeze(0)
        .repeat(batch_size, 1)
        .float()
        .to(device)
    )
    init_mean_shape = (
        torch.from_numpy(mean_shape)
        .unsqueeze(0)
        .repeat(batch_size, 1)
        .float()
        .to(device)
    )
    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)

    # Initialize SMPLify
    smplify = SMPLify3D(
        smplxmodel=smplmodel,
        batch_size=batch_size,
        joints_category=joint_category,
        num_iters=num_smplify_iters,
        device=device,
    )
    keypoints_3d = torch.Tensor(input_joints).to(device).float()

    _, _, new_opt_pose, new_opt_betas, _, _ = smplify(
        init_mean_pose.detach(),
        init_mean_shape.detach(),
        cam_trans_zero.detach(),
        keypoints_3d,
        conf_3d=confidence_input.to(device),
    )

    poses = new_opt_pose.detach().cpu().numpy()
    betas = new_opt_betas.mean(axis=0).detach().cpu().numpy()
    trans = keypoints_3d[:, 0].detach().cpu().numpy()

    # +90 degrees about x (the z angle is 0).
    rx_90 = Rotation.from_euler("xz", [90, 0], degrees=True).as_matrix()
    root_mat = Rotation.from_rotvec(poses[:, :3]).as_matrix()
    poses[:, :3] = Rotation.from_matrix(rx_90 @ root_mat).as_rotvec()
    trans_rotated = (rx_90 @ trans.T).T

    input_joints = input_joints[:, :, [0, 2, 1]]  # jts stands on x, z
    input_joints[..., 0] *= -1

    # Zero-pad the pose vector to the 165-D layout used by the saved files.
    target_dim = 165
    if poses.shape[1] < target_dim:
        poses_padding = np.zeros((poses.shape[0], target_dim))
        poses_padding[:, : poses.shape[1]] = poses
    else:
        poses_padding = poses

    param = {
        "poses": poses_padding,
        "trans": trans_rotated,
        "betas": betas,
        "gender": "neutral",
        "jtr": input_joints,
        "mocap_frame_rate": 30,
    }
    np.savez_compressed(save_name, **param)
    print(f"successfully save file:{save_name}")
================================================
FILE: holomotion/src/data_curation/visualize_smpl_npz.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import http.server
import os
import socket
import socketserver
import subprocess
import sys
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import webview
# -----------------------------
# UI Shell
# -----------------------------
# HTML page loaded into the pywebview window. UIAPI drives it through
# evaluate_js, calling setBusy(...), setStatus(...) and showViewer(url), so
# the real page is expected to define those JS functions (the markup is not
# fully visible in this extract — verify against the file).
SHELL_HTML = r"""
SMPL NPZ Viewer
Load NPZ
Select an NPZ file…
"""
# -----------------------------
# Config
# -----------------------------
@dataclass(frozen=True)
class AppConfig:
    """Immutable viewer settings; built from CLI arguments by build_config."""

    root: Path  # directory served over HTTP; cwd for the generator script
    port: int  # local HTTP port of the static asset server
    smpl_npz_to_html: Path  # generator script path
    template: Path  # HTML template handed to the generator
    out_html: Path  # generated page; UIAPI builds its URL relative to root
    window_title: str  # pywebview window title
    width: int  # window width
    height: int  # window height
    auto_pick: bool  # open the file picker once when the window loads
    debug: bool  # pywebview debug/devtools switch
def parse_args() -> argparse.Namespace:
    """Parse the viewer's command-line options from ``sys.argv``.

    ``--no-auto-pick`` sets ``no_auto_pick`` to True (default False), and
    ``build_config`` turns it into ``auto_pick = not no_auto_pick``. It
    previously used ``store_false``, which inverted the flag: auto-pick was
    off by default and passing the flag turned it on.
    """
    ap = argparse.ArgumentParser(description="SMPL NPZ viewer UI.")
    ap.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Local HTTP port for serving assets.",
    )
    ap.add_argument(
        "--smpl_npz_to_html",
        type=Path,
        default=Path("smpl_npz_to_html.py"),
        help="Path to smpl_npz_to_html.py",
    )
    ap.add_argument(
        "--template",
        type=Path,
        default=Path("templates/index_wooden_static.html"),
        help="HTML template path",
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path("_generated/vis.html"),
        help="Output vis.html path",
    )
    ap.add_argument(
        "--title", type=str, default="NPZ Viewer", help="Window title"
    )
    ap.add_argument("--width", type=int, default=800, help="Window width")
    ap.add_argument("--height", type=int, default=600, help="Window height")
    ap.add_argument(
        "--no-auto-pick",
        action="store_true",
        help="Do not auto-open file picker at startup",
    )
    ap.add_argument(
        "--debug", action="store_true", help="Enable pywebview debug/devtools"
    )
    return ap.parse_args()
# -----------------------------
# Utilities
# -----------------------------
def js_escape(s: str) -> str:
    """Escape ``s`` for embedding inside a single-quoted JavaScript string.

    Backslashes and single quotes are escaped, and so are line terminators
    (CR, LF, U+2028, U+2029). A raw newline inside a quoted JS literal is a
    syntax error, e.g. when a multi-line exception message is passed to
    ``evaluate_js``.
    """
    # Backslashes first, so the escapes added below are not doubled.
    return (
        s.replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\u2028", "\\u2028")
        .replace("\u2029", "\\u2029")
    )
def ensure_exists(path: Path, what: str) -> None:
    """Raise ``FileNotFoundError`` mentioning ``what`` when ``path`` is absent."""
    if path.exists():
        return
    raise FileNotFoundError(f"Missing {what}: {path}")
def is_port_available(port: int) -> bool:
    """Return True if a TCP socket can be bound to ``127.0.0.1:port`` now.

    The probe socket enables ``SO_REUSEADDR`` before binding and is always
    closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind(("127.0.0.1", port))
        except OSError:
            return False
        return True
    finally:
        probe.close()
# -----------------------------
# Core: server + generator + UI API
# -----------------------------
class StaticServer:
    """Serve the files under ``root`` on ``127.0.0.1:port`` from a daemon thread."""

    def __init__(self, root: Path, port: int):
        self.root = root
        self.port = port
        self._thread: Optional[threading.Thread] = None
        self._httpd: Optional[socketserver.TCPServer] = None

    def start(self) -> None:
        """Bind the server, then serve requests in a background daemon thread.

        The socket is bound here, in the caller's thread, so a busy port
        raises ``OSError`` from ``start()`` instead of killing the background
        thread unnoticed. Files are served from ``root`` through the
        handler's ``directory`` argument rather than ``os.chdir``, which
        would change the working directory of the whole process.
        """
        import functools

        handler = functools.partial(
            http.server.SimpleHTTPRequestHandler, directory=str(self.root)
        )
        self._httpd = socketserver.TCPServer(
            ("127.0.0.1", self.port), handler
        )
        self._thread = threading.Thread(
            target=self._httpd.serve_forever, daemon=True
        )
        self._thread.start()

    def stop(self) -> None:
        """Stop serving and close the listening socket; no-op if not started."""
        if self._httpd is None:
            return
        self._httpd.shutdown()
        self._httpd.server_close()
        self._httpd = None
class MakeVisRunner:
    """Build the viewer HTML for an ``.npz`` by invoking smpl_npz_to_html.py."""

    def __init__(
        self,
        root: Path,
        smpl_npz_to_html: Path,
        template: Path,
        out_html: Path,
    ):
        self.root = root
        self.smpl_npz_to_html = smpl_npz_to_html
        self.template = template
        self.out_html = out_html

    def run(self, npz_path: Path) -> None:
        """Generate ``out_html`` from ``npz_path``.

        The generator script, the template and the input file are checked
        first (``FileNotFoundError`` if missing). The output directory is
        created, then the script runs under the current interpreter with
        ``root`` as working directory; a non-zero exit raises
        ``subprocess.CalledProcessError``.
        """
        for required, label in (
            (self.smpl_npz_to_html, "smpl_npz_to_html.py"),
            (self.template, "template html"),
            (npz_path, "npz file"),
        ):
            ensure_exists(required, label)
        self.out_html.parent.mkdir(parents=True, exist_ok=True)

        generator_cmd = [sys.executable, str(self.smpl_npz_to_html)]
        generator_cmd += ["--npz", str(npz_path)]
        generator_cmd += ["--template", str(self.template)]
        generator_cmd += ["--out", str(self.out_html)]
        subprocess.check_call(generator_cmd, cwd=str(self.root))
def pick_npz_dialog(window) -> Optional[Path]:
    """Show a single-file open dialog and return the chosen path, or None.

    Uses ``webview.FileDialog.OPEN`` when the installed pywebview provides
    it, otherwise the deprecated ``webview.OPEN_DIALOG``. Only a missing
    enum triggers the fallback. Errors raised by the dialog itself now
    propagate instead of being retried with the deprecated constant.
    """
    file_types = ("NPZ files (*.npz)", "All files (*.*)")
    try:
        dialog_open = webview.FileDialog.OPEN  # type: ignore[attr-defined]
    except AttributeError:
        dialog_open = webview.OPEN_DIALOG
    paths = window.create_file_dialog(
        dialog_open, allow_multiple=False, file_types=file_types
    )
    # A cancelled dialog yields None or an empty sequence.
    return Path(paths[0]) if paths else None
class UIAPI:
    """Actions for the shell page: pick an NPZ, build its viewer and show it."""

    def __init__(self, window, cfg: AppConfig, runner: MakeVisRunner):
        self.window = window
        self.cfg = cfg
        self.runner = runner
        # True from the start of a pick until generation finishes, so
        # overlapping requests (button clicks, auto-pick) are ignored.
        self._busy = False
        self._auto_done = False

    def pick_and_generate(self) -> None:
        """Ask for an NPZ, then generate and display its viewer in a worker thread."""
        if self._busy:
            return
        # Claim the flag before the dialog and before starting the worker;
        # previously it was only set inside the worker, leaving a window in
        # which a second call could start another generation.
        self._busy = True
        started = False
        try:
            npz = pick_npz_dialog(self.window)
            if npz is None:
                return
            safe_name = js_escape(npz.name)
            self.window.evaluate_js(
                f"setBusy(true); setStatus('Generating: {safe_name}');"
            )
            threading.Thread(
                target=self._generate, args=(npz, safe_name), daemon=True
            ).start()
            started = True
        finally:
            if not started:
                self._busy = False

    def _generate(self, npz: Path, safe_name: str) -> None:
        """Worker: run the generator and report success or failure to the page."""
        try:
            self.runner.run(npz)
            rel = self.cfg.out_html.relative_to(self.cfg.root).as_posix()
            self.window.evaluate_js(
                f"setBusy(false); setStatus('Loaded: {safe_name}'); "
                f"showViewer('http://127.0.0.1:{self.cfg.port}/{js_escape(rel)}');"
            )
        except Exception as e:
            msg = js_escape(str(e))
            self.window.evaluate_js(
                f"setBusy(false); setStatus('Failed: {msg}');"
            )
        finally:
            self._busy = False

    def auto_pick_once(self) -> None:
        """On the first call only, run ``pick_and_generate`` if auto-pick is on."""
        # Hooked to window.events.loaded; ensure it runs once.
        if getattr(self, "_auto_done", False):
            return
        self._auto_done = True
        if self.cfg.auto_pick:
            self.pick_and_generate()
# -----------------------------
# Entrypoint
# -----------------------------
def build_config(args: argparse.Namespace) -> AppConfig:
    """Turn parsed CLI arguments into an ``AppConfig``.

    Relative paths are interpreted against the directory containing this
    script. Every path (relative or absolute) is fully resolved so it is
    directly comparable with the resolved ``root``; for example, ``out_html``
    is later made relative to ``root`` to build the viewer URL.

    Args:
        args: Namespace produced by ``parse_args``. Path arguments may be
            ``Path`` objects or plain strings.

    Returns:
        The populated ``AppConfig``.
    """
    root = Path(__file__).resolve().parent

    def _resolve(p) -> Path:
        # Accept str or Path; anchor relative paths at the script directory.
        p = Path(p)
        return (p if p.is_absolute() else root / p).resolve()

    return AppConfig(
        root=root,
        port=int(args.port),
        smpl_npz_to_html=_resolve(args.smpl_npz_to_html),
        template=_resolve(args.template),
        out_html=_resolve(args.out),
        window_title=str(args.title),
        width=int(args.width),
        height=int(args.height),
        auto_pick=not bool(args.no_auto_pick),
        debug=bool(args.debug),
    )
def main() -> None:
    """Serve ``cfg.root`` over HTTP and open the NPZ viewer window.

    Raises:
        RuntimeError: If the configured port is already taken.
    """
    cfg = build_config(parse_args())
    port = cfg.port
    if not is_port_available(port):
        raise RuntimeError(f"Port {port} is already in use. Try --port 8001")

    # Keep a reference to the server for the lifetime of the window.
    static_server = StaticServer(cfg.root, port)
    static_server.start()

    vis_runner = MakeVisRunner(
        cfg.root, cfg.smpl_npz_to_html, cfg.template, cfg.out_html
    )
    main_window = webview.create_window(
        cfg.window_title,
        html=SHELL_HTML,
        width=cfg.width,
        height=cfg.height,
    )
    ui_api = UIAPI(main_window, cfg, vis_runner)
    main_window.expose(ui_api.pick_and_generate)

    def _on_loaded():
        # Run the optional first pick off the GUI event thread.
        threading.Thread(target=ui_api.auto_pick_once, daemon=True).start()

    main_window.events.loaded += _on_loaded
    webview.start(debug=cfg.debug)


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/env/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/env/isaaclab_components/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from holomotion.src.env.isaaclab_components.isaaclab_actions import (
build_actions_config,
ActionsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_scene import (
build_scene_config,
MotionTrackingSceneCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_simulator import (
build_simulator_config,
)
from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (
build_motion_tracking_commands_config,
MoTrack_CommandsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_rewards import (
build_rewards_config,
RewardsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_observation import (
build_observations_config,
ObservationsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_termination import (
build_terminations_config,
TerminationsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_domain_rand import (
build_domain_rand_config,
EventsCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_curriculum import (
build_curriculum_config,
CurriculumCfg,
)
from holomotion.src.env.isaaclab_components.isaaclab_velocity_tracking_command import (
build_velocity_commands_config,
VelTrack_CommandsCfg,
)
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_actions.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from isaaclab.utils import configclass
import isaaclab.envs.mdp as mdp
class ActionFunctions:
    """Factories that build IsaacLab joint-action term configs.

    Every factory treats ``joint_names=None`` as "all joints" (``[".*"]``).
    """

    @staticmethod
    def _joint_names_or_all(joint_names):
        """Return ``joint_names``, or a match-all regex list if it is None."""
        return [".*"] if joint_names is None else joint_names

    @staticmethod
    def joint_position_action(
        asset_name: str = "robot",
        joint_names: list[str] | None = None,
        use_default_offset: bool = True,
        scale: float = 1.0,
    ) -> mdp.JointPositionActionCfg:
        """Build a joint position control action term."""
        return mdp.JointPositionActionCfg(
            asset_name=asset_name,
            joint_names=ActionFunctions._joint_names_or_all(joint_names),
            use_default_offset=use_default_offset,
            scale=scale,
        )

    @staticmethod
    def joint_velocity_action(
        asset_name: str = "robot",
        joint_names: list[str] | None = None,
        scale: float = 1.0,
    ) -> mdp.JointVelocityActionCfg:
        """Build a joint velocity control action term."""
        return mdp.JointVelocityActionCfg(
            asset_name=asset_name,
            joint_names=ActionFunctions._joint_names_or_all(joint_names),
            scale=scale,
        )

    @staticmethod
    def joint_effort_action(
        asset_name: str = "robot",
        joint_names: list[str] | None = None,
        scale: float = 1.0,
    ) -> mdp.JointEffortActionCfg:
        """Build a joint effort control action term."""
        return mdp.JointEffortActionCfg(
            asset_name=asset_name,
            joint_names=ActionFunctions._joint_names_or_all(joint_names),
            scale=scale,
        )
@configclass
class ActionsCfg:
    """Container for action terms.

    Declared empty on purpose: ``build_actions_config`` attaches one attribute
    per configured action term at runtime.
    """
def build_actions_config(actions_config_dict: dict) -> ActionsCfg:
    """Build an IsaacLab-compatible ``ActionsCfg`` from a config mapping.

    Args:
        actions_config_dict: Mapping ``{term_name: {"type": ..., "params": ...}}``.
            ``type`` must be ``joint_position``, ``joint_velocity`` or
            ``joint_effort``. ``params`` is optional and may be null; it is
            forwarded as keyword arguments to the matching ``ActionFunctions``
            factory. ``None`` is treated as an empty mapping.

    Returns:
        An ``ActionsCfg`` with one attribute per configured action term.

    Raises:
        ValueError: If an action ``type`` is not supported.
    """
    factories = {
        "joint_position": ActionFunctions.joint_position_action,
        "joint_velocity": ActionFunctions.joint_velocity_action,
        "joint_effort": ActionFunctions.joint_effort_action,
    }
    actions_cfg = ActionsCfg()
    for action_name, action_config in (actions_config_dict or {}).items():
        action_type = action_config["type"]
        factory = factories.get(action_type)
        if factory is None:
            raise ValueError(f"Unknown action type: {action_type}")
        # A YAML ``params:`` key without a value loads as None; treat it as {}.
        params = action_config.get("params") or {}
        setattr(actions_cfg, action_name, factory(**params))
    return actions_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_curriculum.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from isaaclab.envs import ManagerBasedRLEnv
import torch
from typing import Sequence
from isaaclab.managers import CurriculumTermCfg
from isaaclab.utils import configclass
import isaaclab.envs.mdp as isaaclab_mdp
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Any, Callable, Dict
from loguru import logger
from .isaaclab_domain_rand import DomainRandFunctions
def _completion_rate_curriculum_get_level(
    env,
    *,
    term_tag: str = "default",
    metric_key: str = "Metrics/ref_motion/Task/Completion_Rate",
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    apply_on_startup: bool = True,
    startup_level: int = 0,
    state_prefix: str = "_cr_curr",
):
    """Read and (maybe) advance a completion-rate driven curriculum level.

    All state is kept as attributes on the unwrapped env, namespaced by
    ``state_prefix``, so every curriculum term sharing a prefix also shares
    one level. Each call advances the level by at most one.

    An upgrade from level ``L`` to ``L + 1`` happens when all of these hold:
    ``L < num_updates``; ``metric_key`` is present in ``env.metrics`` (a
    dict); the cooldown since the last upgrade has elapsed (only checked when
    ``cooldown_steps > 0``); at least ``min_steps_per_level`` steps were
    spent at ``L``; and the metric is ``>= cr_thresholds[L]``. The index is
    clamped to the last threshold, and an empty threshold list means 1.0.

    Args:
        env: The (possibly wrapped) environment.
        term_tag: Identifies the calling term in its per-level "applied" flag.
        metric_key: Key of the completion-rate metric in ``env.metrics``.
        num_updates: Maximum level.
        cr_thresholds: Completion-rate threshold required to reach each level.
        min_steps_per_level: Minimum steps at a level before upgrading.
        cooldown_steps: Minimum steps between two upgrades (0 disables).
        apply_on_startup: On the first call, initialize the level to
            ``startup_level`` (clamped to ``[0, num_updates]``).
        startup_level: Initial level used when ``apply_on_startup`` is set.
        state_prefix: Prefix for all state attributes stored on the env.

    Returns:
        ``(level, stats, changed, need_apply)`` where ``stats`` is
        ``(completion_rate, step)`` or ``None`` if the metric is unavailable,
        ``changed`` tells whether the level was initialized/upgraded in this
        call, and ``need_apply`` is True when the caller should (re)apply its
        effect for the current level. This function only initializes the
        per-tag "applied" flag to False; callers set it to True after applying.

    NOTE(review): with ``apply_on_startup=False`` the level starts at -1 until
    the first upgrade; callers then see a negative level.
    """
    base_env = getattr(env, "unwrapped", env)
    level_key = f"{state_prefix}_level"
    startup_key = f"{state_prefix}_startup_applied"
    last_up_key = f"{state_prefix}_last_upgrade_step"
    level_start_step_key = f"{state_prefix}_level_start_step"
    # Lazily create the shared state the first time any term uses this prefix.
    if not hasattr(base_env, level_key):
        setattr(base_env, level_key, -1)
    if not hasattr(base_env, startup_key):
        setattr(base_env, startup_key, False)
    if not hasattr(base_env, last_up_key):
        # Far in the past so the very first upgrade is never cooldown-blocked.
        setattr(base_env, last_up_key, -(10**18))
    if not hasattr(base_env, level_start_step_key):
        setattr(base_env, level_start_step_key, 0)
    step = int(
        getattr(
            base_env,
            "common_step_counter",
            getattr(env, "common_step_counter", 0),
        )
    )

    def _get_completion_stats():
        # (metric value as float, current step), or None if not logged yet.
        metrics = getattr(base_env, "metrics", None)
        if isinstance(metrics, dict) and metric_key in metrics:
            val = metrics[metric_key]
            val = float(val.item()) if hasattr(val, "item") else float(val)
            return val, step
        return None

    def _thr_for_next(next_level: int) -> float:
        # Threshold to enter ``next_level``: cr_thresholds[next_level - 1],
        # clamped into the list; an empty list makes upgrades need CR >= 1.
        if not cr_thresholds:
            return 1.0
        idx = max(0, min(next_level - 1, len(cr_thresholds) - 1))
        return float(cr_thresholds[idx])

    stats = _get_completion_stats()
    cur_level = int(getattr(base_env, level_key))
    changed = False
    # -------- startup init --------
    # Runs once per prefix; it never lowers a level that is already higher.
    if apply_on_startup and not bool(getattr(base_env, startup_key)):
        init_level = int(max(0, min(int(startup_level), int(num_updates))))
        setattr(base_env, level_key, max(cur_level, init_level))
        setattr(base_env, startup_key, True)
        setattr(base_env, last_up_key, step)
        setattr(base_env, level_start_step_key, step)
        cur_level = int(getattr(base_env, level_key))
        changed = True
    # -------- level upgrade --------
    if cur_level < int(num_updates):
        if stats is not None:
            cr_val, _ = stats
            level_start_step = int(getattr(base_env, level_start_step_key))
            stayed_steps = int(step - level_start_step)
            cooldown_ok = True
            if int(cooldown_steps) > 0:
                last_up = int(getattr(base_env, last_up_key))
                cooldown_ok = (step - last_up) >= int(cooldown_steps)
            if cooldown_ok and stayed_steps >= int(min_steps_per_level):
                next_level = min(cur_level + 1, int(num_updates))
                thr = _thr_for_next(next_level)
                if float(cr_val) >= float(thr):
                    setattr(base_env, level_key, next_level)
                    setattr(base_env, last_up_key, step)
                    setattr(base_env, level_start_step_key, step)
                    cur_level = next_level
                    changed = True
    # Per-(tag, level) flag: lets each term sharing the prefix apply its
    # effect once per level, even when another term performed the upgrade.
    applied_key = (
        f"{state_prefix}_applied_{str(term_tag)}_level_{int(cur_level)}"
    )
    if not hasattr(base_env, applied_key):
        setattr(base_env, applied_key, False)
    already_applied = bool(getattr(base_env, applied_key))
    need_apply = bool(changed) or (not already_applied)
    return int(cur_level), stats, bool(changed), bool(need_apply)
def lin_vel_cmd_levels(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    reward_term_name: str = "track_lin_vel_xy",
    delta: float = 0.1,
    reward_threshold: float = 0.8,
) -> torch.Tensor:
    """Widen the ``base_velocity`` linear-velocity ranges when tracking is good.

    Only acts when ``env.common_step_counter`` is a multiple of
    ``env.max_episode_length``. At that point, the mean episode sum of
    ``reward_term_name`` over ``env_ids`` is divided by
    ``env.max_episode_length_s`` and compared with
    ``reward_threshold * weight``. If it is greater, both ``lin_vel_x`` and
    ``lin_vel_y`` ranges are widened by ``delta`` on each side, clamped to the
    command's ``limit_ranges``.

    Args:
        env: The RL environment.
        env_ids: Environments whose episode sums are averaged.
        reward_term_name: Reward term used as the progress signal.
        delta: Amount added to each side of the ranges per upgrade.
        reward_threshold: Fraction of the reward weight that must be exceeded.

    Returns:
        The (possibly updated) upper bound of ``lin_vel_x`` as a tensor.
    """
    command_term = env.command_manager.get_term("base_velocity")
    ranges = command_term.cfg.ranges
    # The reward is only needed at episode boundaries, so compute it there.
    if env.common_step_counter % env.max_episode_length == 0:
        reward_term = env.reward_manager.get_term_cfg(reward_term_name)
        reward = (
            torch.mean(
                env.reward_manager._episode_sums[reward_term_name][env_ids]
            )
            / env.max_episode_length_s
        )
        if reward > reward_term.weight * reward_threshold:
            limit_ranges = command_term.cfg.limit_ranges
            delta_command = torch.tensor([-delta, delta], device=env.device)
            ranges.lin_vel_x = torch.clamp(
                torch.tensor(ranges.lin_vel_x, device=env.device)
                + delta_command,
                limit_ranges.lin_vel_x[0],
                limit_ranges.lin_vel_x[1],
            ).tolist()
            ranges.lin_vel_y = torch.clamp(
                torch.tensor(ranges.lin_vel_y, device=env.device)
                + delta_command,
                limit_ranges.lin_vel_y[0],
                limit_ranges.lin_vel_y[1],
            ).tolist()
    return torch.tensor(ranges.lin_vel_x[1], device=env.device)
def ang_vel_cmd_levels(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    reward_term_name: str = "track_ang_vel_z",
    delta: float = 0.1,
    reward_threshold: float = 0.8,
) -> torch.Tensor:
    """Widen the ``base_velocity`` yaw-rate range when tracking is good.

    Only acts when ``env.common_step_counter`` is a multiple of
    ``env.max_episode_length``. At that point, the mean episode sum of
    ``reward_term_name`` over ``env_ids`` is divided by
    ``env.max_episode_length_s`` and compared with
    ``reward_threshold * weight``. If it is greater, the ``ang_vel_z`` range
    is widened by ``delta`` on each side, clamped to the command's
    ``limit_ranges``.

    Args:
        env: The RL environment.
        env_ids: Environments whose episode sums are averaged.
        reward_term_name: Reward term used as the progress signal.
        delta: Amount added to each side of the range per upgrade.
        reward_threshold: Fraction of the reward weight that must be exceeded.

    Returns:
        The (possibly updated) upper bound of ``ang_vel_z`` as a tensor.
    """
    command_term = env.command_manager.get_term("base_velocity")
    ranges = command_term.cfg.ranges
    # The reward is only needed at episode boundaries, so compute it there.
    if env.common_step_counter % env.max_episode_length == 0:
        reward_term = env.reward_manager.get_term_cfg(reward_term_name)
        reward = (
            torch.mean(
                env.reward_manager._episode_sums[reward_term_name][env_ids]
            )
            / env.max_episode_length_s
        )
        if reward > reward_term.weight * reward_threshold:
            limit_ranges = command_term.cfg.limit_ranges
            delta_command = torch.tensor([-delta, delta], device=env.device)
            ranges.ang_vel_z = torch.clamp(
                torch.tensor(ranges.ang_vel_z, device=env.device)
                + delta_command,
                limit_ranges.ang_vel_z[0],
                limit_ranges.ang_vel_z[1],
            ).tolist()
    return torch.tensor(ranges.ang_vel_z[1], device=env.device)
def robot_friction_range_by_completion_rate(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    *,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
    static_friction_target=(0.3, 1.6),
    dynamic_friction_target=(0.3, 1.2),
    enforce_dynamic_le_static: bool = True,
    asset_name: str = "robot",
    body_names: str = ".*",
    restitution_range=(0.0, 0.5),
    num_buckets: int = 64,
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
):
    """Curriculum term: grow the robot friction randomization with the level.

    Each target range is widened around an anchor point (its
    ``anchor_quantile``) by a band factor. The band goes from
    ``min_expand_frac`` at level 0 to 1 (the full target) at ``num_updates``.
    Whenever a (re)apply is due, the material randomization is run for all
    envs (``env_ids`` is not used). With ``enforce_dynamic_le_static``, the
    dynamic range is capped by the static upper bound.

    Returns:
        The current curriculum level as a float.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag="fric",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return float(level)

    max_level = int(num_updates)
    clamped_level = int(max(0, min(int(level), max_level)))
    progress = 1.0 if max_level <= 0 else clamped_level / float(max_level)
    floor_frac = float(min_expand_frac)
    band = floor_frac + (1.0 - floor_frac) * float(max(progress, 0.0))
    q = float(max(0.0, min(1.0, anchor_quantile)))

    def _widen(target):
        # Scale the distance from the anchor to each target bound by ``band``.
        lo_t, hi_t = map(float, target)
        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)
        anchor = lo_t + (hi_t - lo_t) * q
        lo = anchor - (anchor - lo_t) * band
        hi = anchor + (hi_t - anchor) * band
        return min(lo, hi), max(lo, hi)

    static_lo, static_hi = _widen(static_friction_target)
    dyn_lo, dyn_hi = _widen(dynamic_friction_target)
    if enforce_dynamic_le_static:
        dyn_hi = min(dyn_hi, static_hi)
        dyn_lo = min(dyn_lo, dyn_hi)

    DomainRandFunctions._get_dr_rigid_body_material(
        env=env,
        env_ids=None,
        asset_name=asset_name,
        body_names=body_names,
        static_friction_range=(float(static_lo), float(static_hi)),
        dynamic_friction_range=(float(dyn_lo), float(dyn_hi)),
        restitution_range=tuple(restitution_range),
        num_buckets=int(num_buckets),
    )
    setattr(base_env, f"{state_prefix}_applied_fric_level_{int(level)}", True)
    return float(level)
def rigid_body_com_by_completion_rate(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    *,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
    asset_name: str = "robot",
    body_names: str = "torso_link",
    com_range_target: dict = {
        "x": (-0.025, 0.025),
        "y": (-0.05, 0.05),
        "z": (-0.05, 0.05),
    },
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
):
    """Curriculum term: grow the center-of-mass randomization with the level.

    Every per-axis target range in ``com_range_target`` is widened around its
    ``anchor_quantile`` by a band factor. The band goes from
    ``min_expand_frac`` at level 0 to 1 at ``num_updates``. The COM
    randomization is then applied to all envs (``env_ids`` is not used).
    ``com_range_target`` is only read, never mutated.

    Returns:
        The current curriculum level as a float.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag="com",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return float(level)

    max_level = int(num_updates)
    clamped_level = int(max(0, min(int(level), max_level)))
    progress = 1.0 if max_level <= 0 else clamped_level / float(max_level)
    floor_frac = float(min_expand_frac)
    band = floor_frac + (1.0 - floor_frac) * float(max(progress, 0.0))
    q = float(max(0.0, min(1.0, anchor_quantile)))

    def _widen(bounds):
        # Scale the distance from the anchor to each target bound by ``band``.
        lo_t, hi_t = bounds
        lo_t, hi_t = float(lo_t), float(hi_t)
        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)
        anchor = lo_t + (hi_t - lo_t) * q
        lo = anchor - (anchor - lo_t) * band
        hi = anchor + (hi_t - anchor) * band
        return (float(min(lo, hi)), float(max(lo, hi)))

    com_range = {
        axis: _widen(bounds) for axis, bounds in com_range_target.items()
    }
    DomainRandFunctions._get_dr_rigid_body_com(
        env=env,
        env_ids=None,
        com_range=com_range,
        asset_name=asset_name,
        body_names=body_names,
    )
    setattr(base_env, f"{state_prefix}_applied_com_level_{int(level)}", True)
    return float(level)
def default_dof_pos_bias_by_completion_rate(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    *,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
    asset_name: str = "robot",
    joint_names: list[str] = (".*"),
    pos_distribution_params_target: tuple[float, float] = (-0.01, 0.01),
    operation: str = "add",
    distribution: str = "uniform",
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
):
    """Curriculum term: grow the default joint-position bias with the level.

    The target bias range is widened around its ``anchor_quantile`` by a band
    factor. The band goes from ``min_expand_frac`` at level 0 to 1 at
    ``num_updates``. The bias is then applied to all envs (``env_ids`` is not
    used). Note that the default ``joint_names`` ``(".*")`` is the plain
    string ``".*"``, not a tuple.

    Returns:
        The current curriculum level as a float.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag="dof",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return float(level)

    target_lo, target_hi = map(float, pos_distribution_params_target)
    target_lo, target_hi = min(target_lo, target_hi), max(target_lo, target_hi)

    max_level = int(num_updates)
    clamped_level = int(max(0, min(int(level), max_level)))
    progress = 1.0 if max_level <= 0 else clamped_level / float(max_level)
    floor_frac = float(min_expand_frac)
    band = floor_frac + (1.0 - floor_frac) * float(max(progress, 0.0))

    q = float(max(0.0, min(1.0, anchor_quantile)))
    anchor = target_lo + (target_hi - target_lo) * q
    bias_lo = anchor - (anchor - target_lo) * band
    bias_hi = anchor + (target_hi - anchor) * band
    bias_lo, bias_hi = float(min(bias_lo, bias_hi)), float(max(bias_lo, bias_hi))

    DomainRandFunctions._get_dr_default_dof_pos_bias(
        env=env,
        env_ids=None,
        asset_name=asset_name,
        joint_names=joint_names,
        pos_distribution_params=(bias_lo, bias_hi),
        operation=operation,
        distribution=distribution,
    )
    setattr(base_env, f"{state_prefix}_applied_dof_level_{int(level)}", True)
    return float(level)
def push_by_setting_velocity_range_by_completion_rate(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    old_value,
    *,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
    velocity_range_target: dict = {
        "x": (-0.5, 0.5),
        "y": (-0.5, 0.5),
        "z": (-0.2, 0.2),
        "roll": (-0.52, 0.52),
        "pitch": (-0.52, 0.52),
        "yaw": (-0.78, 0.78),
    },
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
):
    """``modify_fn`` for the push event: grow ``velocity_range`` with the level.

    Every per-axis target range is widened around its ``anchor_quantile`` by
    a band factor. The band goes from ``min_expand_frac`` at level 0 to 1 at
    ``num_updates``. The result is stored as ``[lo, hi]`` lists under
    ``velocity_range``.

    How the params value is updated depends on its type:

    * a ``dict`` is copied, so ``old_value`` itself is left untouched;
    * other objects supporting item assignment are updated in place;
    * anything else gets a ``velocity_range`` attribute set.

    Returns:
        The updated params, or ``isaaclab_mdp.modify_term_cfg.NO_CHANGE``
        when nothing has to be applied.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag="push",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return isaaclab_mdp.modify_term_cfg.NO_CHANGE

    max_level = int(num_updates)
    clamped_level = int(max(0, min(int(level), max_level)))
    progress = 1.0 if max_level <= 0 else clamped_level / float(max_level)
    floor_frac = float(min_expand_frac)
    band = floor_frac + (1.0 - floor_frac) * float(max(progress, 0.0))
    q = float(max(0.0, min(1.0, anchor_quantile)))

    def _widen(bounds):
        # Scale the distance from the anchor to each target bound by ``band``.
        lo_t, hi_t = bounds
        lo_t, hi_t = float(lo_t), float(hi_t)
        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)
        anchor = lo_t + (hi_t - lo_t) * q
        lo = anchor - (anchor - lo_t) * band
        hi = anchor + (hi_t - anchor) * band
        return [float(min(lo, hi)), float(max(lo, hi))]

    velocity_range = {
        axis: _widen(bounds) for axis, bounds in velocity_range_target.items()
    }

    if isinstance(old_value, dict):
        new_params = dict(old_value)
        new_params["velocity_range"] = velocity_range
    elif hasattr(old_value, "__setitem__"):
        new_params = old_value
        new_params["velocity_range"] = velocity_range
    else:
        new_params = old_value
        setattr(new_params, "velocity_range", velocity_range)

    setattr(base_env, f"{state_prefix}_applied_push_level_{int(level)}", True)
    return new_params
def randomize_actuator_gains_by_completion_rate(
    env: ManagerBasedRLEnv,
    env_ids: Sequence[int],
    *,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
    asset_name: str = "robot",
    body_names: str = ".*",
    stiffness_distribution_params_target: tuple[float, float] = (0.9, 1.1),
    damping_distribution_params_target: tuple[float, float] = (0.9, 1.1),
    operation: str = "scale",
    distribution: str = "uniform",
    anchor_quantile: float = 0.5,
    min_expand_frac: float = 0.0,
):
    """Curriculum term: grow the actuator stiffness/damping randomization.

    Both target ranges are widened around their ``anchor_quantile`` by a
    band factor. The band goes from ``min_expand_frac`` at level 0 to 1 at
    ``num_updates``. The gain randomization is then applied to all envs
    (``env_ids`` is not used).

    Returns:
        The current curriculum level as a float.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag="gains",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return float(level)

    max_level = int(num_updates)
    clamped_level = int(max(0, min(int(level), max_level)))
    progress = 1.0 if max_level <= 0 else clamped_level / float(max_level)
    floor_frac = float(min_expand_frac)
    band = floor_frac + (1.0 - floor_frac) * float(max(progress, 0.0))
    q = float(max(0.0, min(1.0, anchor_quantile)))

    def _widen(target):
        # Scale the distance from the anchor to each target bound by ``band``.
        lo_t, hi_t = map(float, target)
        lo_t, hi_t = min(lo_t, hi_t), max(lo_t, hi_t)
        anchor = lo_t + (hi_t - lo_t) * q
        lo = anchor - (anchor - lo_t) * band
        hi = anchor + (hi_t - anchor) * band
        return float(min(lo, hi)), float(max(lo, hi))

    stiffness_range = _widen(stiffness_distribution_params_target)
    damping_range = _widen(damping_distribution_params_target)

    DomainRandFunctions._get_dr_randomize_actuator_gains(
        env=env,
        env_ids=None,
        asset_name=asset_name,
        body_names=body_names,
        stiffness_distribution_params=stiffness_range,
        damping_distribution_params=damping_range,
        operation=operation,
        distribution=distribution,
    )
    setattr(base_env, f"{state_prefix}_applied_gains_level_{int(level)}", True)
    return float(level)
def reward_term_weight_by_completion_rate(
    env,
    env_ids,
    *,
    reward_term_name: str,
    final_weight: float,
    start_scale: float = 0.1,
    num_updates: int = 5,
    cr_thresholds=(0.10, 0.20, 0.28, 0.34, 0.40),
    min_steps_per_level: int = 300,
    cooldown_steps: int = 0,
    state_prefix: str = "_cr_curr",
):
    """Curriculum term: ramp a reward term's weight with the curriculum level.

    The weight is interpolated linearly from ``final_weight * start_scale`` at
    level 0 to ``final_weight`` at ``num_updates``. Progress is clamped to
    ``[0, 1]``, so an out-of-range level, such as the -1 "uninitialized"
    level, cannot overshoot either end. The reward manager is only touched
    when a (re)apply is due.

    Returns:
        The current curriculum level as a float.
    """
    base_env = getattr(env, "unwrapped", env)
    level, _stats, _changed, need_apply = _completion_rate_curriculum_get_level(
        env,
        term_tag=f"reward_{reward_term_name}",
        num_updates=num_updates,
        cr_thresholds=cr_thresholds,
        min_steps_per_level=min_steps_per_level,
        cooldown_steps=cooldown_steps,
        state_prefix=state_prefix,
    )
    if not need_apply:
        return float(level)

    if num_updates <= 0:
        progress = 1.0
    else:
        progress = min(max(float(level) / float(num_updates), 0.0), 1.0)
    start_weight = float(final_weight) * float(start_scale)
    new_weight = start_weight + progress * (float(final_weight) - start_weight)

    reward_cfg = env.reward_manager.get_term_cfg(reward_term_name)
    reward_cfg.weight = float(new_weight)
    env.reward_manager.set_term_cfg(reward_term_name, reward_cfg)
    # Record the applied weight (for inspection/logging) and mark this level
    # as applied for this reward term.
    setattr(
        base_env,
        f"{state_prefix}_reward_weight_{reward_term_name}",
        float(new_weight),
    )
    setattr(
        base_env,
        f"{state_prefix}_applied_reward_{reward_term_name}_level_{int(level)}",
        True,
    )
    return float(level)
@configclass
class CurriculumCfg:
    """Container for curriculum terms.

    Declared empty on purpose: ``build_curriculum_config`` attaches one
    ``CurriculumTermCfg`` attribute per enabled term at runtime.
    """
def build_curriculum_config(curriculum_config_dict: dict) -> CurriculumCfg:
"""
Build IsaacLab-compatible CurriculumCfg from a config dictionary.
"""
if isinstance(curriculum_config_dict, (DictConfig, ListConfig)):
curriculum_config_dict = OmegaConf.to_container(
curriculum_config_dict, resolve=True
)
curriculum_cfg = CurriculumCfg()
cfg_dict: Dict[str, Any] = dict(curriculum_config_dict or {})
def _resolve_callable(name: Any) -> Callable:
if callable(name):
return name
if isinstance(name, str) and name.startswith("isaaclab_mdp."):
name = name.split(".", 1)[1]
fn = globals().get(name)
if callable(fn):
return fn
fn = getattr(isaaclab_mdp, name, None)
if callable(fn):
return fn
if hasattr(isaaclab_mdp, "curriculums"):
fn = getattr(isaaclab_mdp.curriculums, name, None)
if callable(fn):
return fn
raise ValueError(f"Unknown curriculum function: {name}")
def _normalize_modify_params(x: Any) -> Any:
if isinstance(x, list):
# many configs express tuples as YAML lists
return tuple(_normalize_modify_params(v) for v in x)
if isinstance(x, dict):
return {k: _normalize_modify_params(v) for k, v in x.items()}
return x
def _fix_params(params: Dict[str, Any]) -> Dict[str, Any]:
params = dict(params or {})
if "modify_fn" in params and isinstance(
params["modify_fn"], (str, Callable)
):
params["modify_fn"] = _resolve_callable(params["modify_fn"])
if "modify_params" in params and isinstance(
params["modify_params"], dict
):
params["modify_params"] = _normalize_modify_params(
params["modify_params"]
)
return params
global_enabled = cfg_dict.pop("enabled", True)
if not global_enabled:
return curriculum_cfg
for term_name, term_cfg in cfg_dict.items():
if term_cfg is None:
term_cfg = {}
if isinstance(term_cfg, bool):
if not term_cfg:
continue
term_cfg = {}
if not isinstance(term_cfg, dict):
raise TypeError(
f"[build_curriculum_config] term '{term_name}' must be a dict/bool/None, got {type(term_cfg)}"
)
if not term_cfg.get("enabled", True):
continue
func_field = term_cfg.get("func", None)
if func_field is None:
func = _resolve_callable(term_name)
else:
func = _resolve_callable(func_field)
params = _fix_params(term_cfg.get("params", {}) or {})
setattr(
curriculum_cfg,
term_name,
CurriculumTermCfg(func=func, params=params),
)
return curriculum_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_domain_rand.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
from typing import Literal
import isaaclab.utils.math as math_utils
from isaaclab.assets import Articulation
import isaaclab.envs.mdp as isaaclab_mdp
from isaaclab.envs.mdp.events import _randomize_prop_by_op
from isaaclab.managers import SceneEntityCfg, EventTermCfg
from isaaclab.utils import configclass
from isaaclab.envs import ManagerBasedEnv
from isaaclab.managers import EventTermCfg
class DomainRandFunctions:
    """Domain-randomization event implementations.

    ``build_domain_rand_config`` looks up ``_get_dr_<event_name>`` here and
    wraps it in an ``EventTermCfg``. The completion-rate curriculum terms
    also call these helpers directly with ``env_ids=None``.
    """

    @staticmethod
    def _get_dr_default_dof_pos_bias(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor | None,
        asset_name: str = "robot",
        joint_names: list[str] = (".*"),
        pos_distribution_params: tuple[float, float] | None = None,
        operation: Literal["add", "scale", "abs"] = "abs",
        distribution: Literal[
            "uniform", "log_uniform", "gaussian"
        ] = "uniform",
        action_term_name: str = "dof_pos",
    ):
        """Randomize the default joint positions and the action offset.

        The unbiased defaults are captured the first time this runs, and every
        call samples relative to them. Repeated calls (per-reset events or
        curriculum re-application) therefore replace the bias instead of
        accumulating it. ``asset.data.default_joint_pos_nominal`` holds the
        unbiased defaults of env 0.

        Args:
            env: The environment.
            env_ids: Envs to randomize; ``None`` means all envs.
            asset_name: Scene entity holding the articulation.
            joint_names: Joint name pattern(s) to randomize.
            pos_distribution_params: Distribution parameters for the bias;
                ``None`` only records the nominal defaults.
            operation: How the sample is combined with the unbiased default.
            distribution: Sampling distribution.
            action_term_name: Action term whose ``_offset`` mirrors the new
                default joint positions.
        """
        asset_cfg = SceneEntityCfg(asset_name, joint_names=joint_names)
        asset_cfg.resolve(env.scene)
        asset: Articulation = env.scene[asset_name]
        data = asset.data
        unbiased = getattr(data, "_holomotion_unbiased_default_joint_pos", None)
        if unbiased is None:
            # First call: snapshot the pristine defaults before any bias.
            unbiased = torch.clone(data.default_joint_pos)
            data._holomotion_unbiased_default_joint_pos = unbiased
            data.default_joint_pos_nominal = torch.clone(unbiased[0])
        if env_ids is None:
            env_ids = torch.arange(env.scene.num_envs, device=asset.device)
        if isinstance(asset_cfg.joint_ids, slice):
            joint_ids = asset_cfg.joint_ids
        else:
            joint_ids = torch.tensor(
                asset_cfg.joint_ids,
                dtype=torch.long,
                device=asset.device,
            )
        if pos_distribution_params is None:
            return
        pos = _randomize_prop_by_op(
            unbiased.to(asset.device).clone(),
            pos_distribution_params,
            env_ids,
            joint_ids,
            operation=operation,
            distribution=distribution,
        )[env_ids][:, joint_ids]
        # Two index tensors need broadcasting to select the (env, joint) grid.
        row_ids = (
            env_ids[:, None] if isinstance(joint_ids, torch.Tensor) else env_ids
        )
        data.default_joint_pos[row_ids, joint_ids] = pos
        env.action_manager.get_term(action_term_name)._offset[
            row_ids, joint_ids
        ] = pos

    @staticmethod
    def _get_dr_rigid_body_com(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor | None,
        com_range: dict[str, tuple[float, float]],
        asset_name: str = "robot",
        body_names: str = "torso_link",
    ):
        """Randomize the COM of ``body_names`` via IsaacLab's event."""
        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)
        asset_cfg.resolve(env.scene)
        return isaaclab_mdp.events.randomize_rigid_body_com(
            env,
            env_ids,
            com_range,
            asset_cfg,
        )

    @staticmethod
    def _get_dr_rigid_body_material(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor | None,
        asset_name: str = "robot",
        body_names: str = ".*",
        static_friction_range: tuple[float, float] | None = None,
        dynamic_friction_range: tuple[float, float] | None = None,
        restitution_range: tuple[float, float] | None = None,
        num_buckets: int = 64,
    ):
        """Randomize rigid-body materials via IsaacLab's material term.

        ``randomize_rigid_body_material`` is a term class: it is constructed
        from an ``EventTermCfg`` and then called with the same params.
        """
        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)
        asset_cfg.resolve(env.scene)
        event_cfg = EventTermCfg(
            func=isaaclab_mdp.events.randomize_rigid_body_material,
            params={
                "asset_cfg": asset_cfg,
                "static_friction_range": static_friction_range,
                "dynamic_friction_range": dynamic_friction_range,
                "restitution_range": restitution_range,
                "num_buckets": num_buckets,
            },
        )
        material_randomizer = (
            isaaclab_mdp.events.randomize_rigid_body_material(event_cfg, env)
        )
        return material_randomizer(env, env_ids, **event_cfg.params)

    @staticmethod
    def _get_dr_push_by_setting_velocity(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        velocity_range: dict[str, tuple[float, float]],
    ):
        """Push the root by setting a random velocity (IsaacLab event)."""
        return isaaclab_mdp.events.push_by_setting_velocity(
            env,
            env_ids,
            velocity_range,
        )

    @staticmethod
    def _get_dr_randomize_actuator_gains(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        asset_name: str = "robot",
        body_names: str = ".*",
        stiffness_distribution_params: tuple[float, float] | None = None,
        damping_distribution_params: tuple[float, float] | None = None,
        operation: Literal["add", "scale", "abs"] = "abs",
        distribution: Literal[
            "uniform", "log_uniform", "gaussian"
        ] = "uniform",
        joint_names: str | list[str] | None = None,
    ):
        """Randomize actuator stiffness/damping via IsaacLab's event.

        NOTE(review): IsaacLab selects actuated joints through
        ``asset_cfg.joint_ids``, so ``body_names`` does not restrict which
        gains change. Pass ``joint_names`` to limit the randomization to a
        subset of joints; when it is ``None`` all joints are used (the
        previous behavior).
        """
        entity_kwargs = {"body_names": body_names}
        if joint_names is not None:
            entity_kwargs["joint_names"] = joint_names
        asset_cfg = SceneEntityCfg(asset_name, **entity_kwargs)
        asset_cfg.resolve(env.scene)
        return isaaclab_mdp.events.randomize_actuator_gains(
            env,
            env_ids,
            asset_cfg,
            stiffness_distribution_params,
            damping_distribution_params,
            operation=operation,
            distribution=distribution,
        )

    @staticmethod
    def _get_dr_randomize_mass(
        env: ManagerBasedEnv,
        env_ids: torch.Tensor,
        asset_name: str = "robot",
        body_names: str = ".*",
        mass_range: tuple[float, float] | None = None,
    ):
        """Add a random mass offset (``operation="add"``) to ``body_names``."""
        asset_cfg = SceneEntityCfg(asset_name, body_names=body_names)
        asset_cfg.resolve(env.scene)
        return isaaclab_mdp.events.randomize_rigid_body_mass(
            env,
            env_ids,
            mass_distribution_params=mass_range,
            asset_cfg=asset_cfg,
            operation="add",
        )
@configclass
class EventsCfg:
    """Container for event terms.

    Starts empty; ``build_domain_rand_config`` attaches one ``EventTermCfg``
    attribute per configured domain-randomization event.
    """
def build_domain_rand_config(domain_rand_config_dict: dict) -> EventsCfg:
    """Build IsaacLab-compatible EventsCfg from a config dictionary.

    Only entries whose value is a ``dict`` containing a ``"mode"`` key are
    treated as events. Each one becomes an ``EventTermCfg`` attribute on the
    returned ``EventsCfg``, named after the entry. Its ``func`` is
    ``DomainRandFunctions._get_dr_<entry name>``, and the entry's items are
    forwarded as ``EventTermCfg`` keyword arguments. All other entries are
    skipped, so non-event settings can live under ``domain_rand`` (e.g. for
    Hydra interpolation) without reaching the event builder.

    Raises:
        AttributeError: If an event entry has no matching ``_get_dr_*``
            function on ``DomainRandFunctions``.
    """
    events_cfg = EventsCfg()
    event_entries = (
        (entry_name, entry_spec)
        for entry_name, entry_spec in domain_rand_config_dict.items()
        if isinstance(entry_spec, dict) and "mode" in entry_spec
    )
    for entry_name, entry_spec in event_entries:
        handler_name = f"_get_dr_{entry_name}"
        try:
            handler = getattr(DomainRandFunctions, handler_name)
        except AttributeError as exc:
            raise AttributeError(
                f"Unknown domain randomization event '{entry_name}'"
            ) from exc
        setattr(events_cfg, entry_name, EventTermCfg(func=handler, **entry_spec))
    return events_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_motion_tracking_command.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from dataclasses import MISSING
from typing import Sequence
import time
import json
from collections import defaultdict
from typing import Dict, List, Optional
import numpy as np
from tqdm import tqdm
from scipy.spatial.transform import Rotation as sRot
import isaaclab.envs.mdp as mdp
import isaaclab.sim as sim_utils
import isaaclab.utils.math as isaaclab_math
import torch
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg
from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg
from isaaclab.envs.mdp.actions import JointEffortActionCfg
from isaaclab.managers import (
ActionTermCfg,
CommandTerm,
CommandTermCfg,
EventTermCfg as EventTerm,
ObservationGroupCfg,
ObservationGroupCfg as ObsGroup,
ObservationTermCfg,
ObservationTermCfg as ObsTerm,
RewardTermCfg,
TerminationTermCfg,
)
from isaaclab.markers import (
VisualizationMarkers,
VisualizationMarkersCfg,
)
from holomotion.src.training.h5_dataloader import (
Hdf5MotionDataset,
Hdf5RootDofDataset,
MotionClipBatchCache,
build_motion_datasets_from_cfg,
)
import os
from isaaclab.markers.config import SPHERE_MARKER_CFG
from isaaclab.sim import PreviewSurfaceCfg
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns
from isaaclab.sim import PhysxCfg, SimulationCfg
from isaaclab.terrains import TerrainImporterCfg
from isaaclab.utils import configclass
from isaaclab.utils.noise import AdditiveUniformNoiseCfg as Unoise
from omegaconf import OmegaConf
from holomotion.src.utils.isaac_utils.rotations import (
calc_heading_quat_inv,
get_euler_xyz,
my_quat_rotate,
quat_inverse,
quat_mul,
quat_rotate,
quat_rotate_inverse,
quaternion_to_matrix,
wrap_to_pi,
wxyz_to_xyzw,
xyzw_to_wxyz,
)
from holomotion.src.utils.reference_prefix import (
resolve_reference_tensor_key,
)
from loguru import logger
class RefMotionCommand(CommandTerm):
cfg: CommandTermCfg
def __init__(
    self,
    cfg,
    env: ManagerBasedRLEnv,
):
    """Create the reference-motion command term.

    Reads ``is_evaluating``, ``process_id`` and ``num_processes`` from
    ``cfg``. It then builds, in order, the robot handle, the per-env buffers
    and the motion cache.

    Args:
        cfg: Command term configuration.
        env: The owning Isaac Lab RL environment.
    """
    super().__init__(cfg, env)
    self._env = env
    self._is_evaluating = self.cfg.is_evaluating
    # Initial distributed context; set_runtime_distributed_context() can
    # override it later. The world size is clamped to at least 1.
    self._runtime_process_id = int(self.cfg.process_id)
    self._runtime_num_processes = max(1, int(self.cfg.num_processes))
    # Order matters: _init_motion_lib() finishes by assigning clips to envs,
    # which writes into buffers allocated by _init_buffers() (e.g.
    # _start_frame_indices).
    for init_step in (
        self._init_robot_handle,
        self._init_buffers,
        self._init_motion_lib,
    ):
        init_step()
def _init_tracking_config(self, config):
self.log_dict_holomotion = {}
self.log_dict_nonreduced_holomotion = {}
self.log_dict_nonreduced = {}
self.log_dict = {}
if "head_hand_bodies" in config:
self.motion_tracking_id = [
self.robot.body_names.index(link)
for link in config.head_hand_bodies
]
if "leg_body_names" in config:
self.lower_body_id = [
self.robot.body_names.index(link)
for link in config.leg_body_names
]
if "arm_body_names" in config:
self.upper_body_id = [
self.robot.body_names.index(link)
for link in config.arm_body_names
]
if "leg_dof_names" in config:
self.lower_body_joint_ids = [
config.dof_names.index(link) for link in config.leg_dof_names
]
if "arm_dof_names" in config:
self.upper_body_joint_ids = [
config.dof_names.index(link) for link in config.arm_dof_names
]
if "waist_dof_names" in config:
self.waist_dof_indices = [
config.dof_names.index(link) for link in config.waist_dof_names
]
@staticmethod
def _amp_filter_names_by_prefix(
names: Sequence[str], prefix: str, keywords: Sequence[str]
) -> list[str]:
return [
name
for name in names
if name.startswith(prefix) and any(key in name for key in keywords)
]
@staticmethod
def _amp_pick_first_name(
names: Sequence[str], patterns: Sequence[str]
) -> str | None:
for pattern in patterns:
for name in names:
if pattern in name:
return name
return None
def _resolve_motion_cache_stage_device(
self, cache_cfg: Dict[str, object]
) -> Optional[torch.device]:
raw_stage_device = cache_cfg.get("device", "cuda")
if isinstance(raw_stage_device, torch.device):
if raw_stage_device.type == "cpu":
return None
if raw_stage_device.type != "cuda":
raise ValueError(
f"Unsupported motion cache device: {raw_stage_device}"
)
if raw_stage_device.index is not None:
return raw_stage_device
if not torch.cuda.is_available():
return None
local_rank_env = os.environ.get("LOCAL_RANK")
if local_rank_env is not None:
local_rank = int(local_rank_env)
device_count = int(torch.cuda.device_count())
if 0 <= local_rank < device_count:
return torch.device("cuda", local_rank)
return torch.device("cuda", int(torch.cuda.current_device()))
stage_device = str(raw_stage_device).strip().lower()
if stage_device in ("none", "cpu"):
return None
if stage_device == "cuda":
if isinstance(self.device, torch.device):
if self.device.type == "cuda":
return self.device
return None
device_str = str(self.device).strip().lower()
if device_str.startswith("cuda"):
return torch.device(device_str)
if not torch.cuda.is_available():
return None
local_rank_env = os.environ.get("LOCAL_RANK")
if local_rank_env is not None:
local_rank = int(local_rank_env)
device_count = int(torch.cuda.device_count())
if 0 <= local_rank < device_count:
return torch.device("cuda", local_rank)
return torch.device("cuda", int(torch.cuda.current_device()))
if stage_device.startswith("cuda:"):
return torch.device(stage_device)
raise ValueError(
f"Unsupported motion cache device config: {raw_stage_device}"
)
def _init_motion_lib(self):
    """Build the motion-clip cache described by ``self.cfg.motion_lib_cfg``.

    The config is wrapped with ``OmegaConf.create`` and kept on
    ``self.mcfg``. Supported ``backend`` values (lower-cased, default
    ``"hdf5"``):

    * ``"hdf5"`` / ``"hdf5_simple"``: build ``Hdf5MotionDataset`` objects
      from ``<root>/manifest.json``. Train roots come from
      ``train_hdf5_roots`` or ``hdf5_root``; optional val roots come from
      ``val_hdf5_roots`` or ``val_hdf5_root``.
    * ``"hdf5_v2"``: delegate dataset construction to
      ``build_motion_datasets_from_cfg``, which also returns extra
      ``MotionClipBatchCache`` keyword arguments.

    Both backends store a ``MotionClipBatchCache`` in
    ``self._motion_cache`` and set ``self._motion_lib`` to ``None``. The
    ``sampling_strategy`` option (default ``"uniform"``) can then enable
    weighted-bin or curriculum sampling on the cache. Finally each env gets
    an initial clip assignment via ``_init_per_env_cache``.

    Raises:
        ValueError: ``hdf5_root`` missing (hdf5 backends), empty training
            dataset, unknown backend, or unknown sampling strategy.
        FileNotFoundError: A train/val ``manifest.json`` does not exist
            (hdf5 / hdf5_simple backends only).
    """
    mcfg = OmegaConf.create(self.cfg.motion_lib_cfg)
    self.mcfg = mcfg
    backend = str(mcfg.get("backend", "hdf5")).lower()
    # Defined up front so the attribute exists even if a later step raises.
    self._motion_cache = None
    if backend in ("hdf5", "hdf5_simple"):
        # Support multi-root configuration while keeping single-root
        # behavior fully backward compatible.
        train_hdf5_roots = mcfg.get("train_hdf5_roots", None)
        val_hdf5_roots = mcfg.get("val_hdf5_roots", None)
        if train_hdf5_roots:
            train_roots = [str(r) for r in train_hdf5_roots]
        else:
            hdf5_root = mcfg.get("hdf5_root")
            if hdf5_root is None:
                raise ValueError("hdf5_root is required")
            train_roots = [str(hdf5_root)]
        val_hdf5_root = mcfg.get("val_hdf5_root", None)
        # A single val root equal to the first train root is treated as
        # "no separate validation set" (val_roots stays None).
        if val_hdf5_roots:
            val_roots = [str(r) for r in val_hdf5_roots]
        elif val_hdf5_root is not None and str(val_hdf5_root) != str(
            train_roots[0]
        ):
            val_roots = [str(val_hdf5_root)]
        else:
            val_roots = None
        train_manifest_paths = [
            os.path.join(root, "manifest.json") for root in train_roots
        ]
        for mp in train_manifest_paths:
            if not os.path.exists(mp):
                raise FileNotFoundError(
                    f"HDF5 manifest not found at {mp}. "
                    "Please set robot.motion.hdf5_root/train_hdf5_roots to "
                    "the correct path!"
                )
        max_frame_length = int(mcfg.get("max_frame_length", 500))
        min_frame_length = int(mcfg.get("min_frame_length", 1))
        world_frame_norm = bool(
            mcfg.get("world_frame_normalization", True)
        )
        cache_cfg = mcfg.get("cache", {})
        allowed_prefixes = cache_cfg.get(
            "allowed_prefixes",
            ["ref_", "ft_ref_"],
        )
        if len(train_manifest_paths) == 1:
            logger.info(
                f"Loading HDF5 training dataset from {train_manifest_paths[0]}"
            )
        else:
            logger.info(
                f"Loading HDF5 training dataset from manifests: "
                f"{train_manifest_paths}"
            )
        # A single manifest is passed as a plain path; several as a list.
        train_dataset = Hdf5MotionDataset(
            manifest_path=train_manifest_paths
            if len(train_manifest_paths) > 1
            else train_manifest_paths[0],
            max_frame_length=max_frame_length,
            min_window_length=min_frame_length,
            handpicked_motion_names=mcfg.get(
                "handpicked_motion_names", None
            ),
            excluded_motion_names=mcfg.get("excluded_motion_names", None),
            world_frame_normalization=world_frame_norm,
            allowed_prefixes=allowed_prefixes,
        )
        if len(train_dataset) == 0:
            raise ValueError(
                "Training dataset is empty. Check that all manifests "
                "contain valid clips with length "
                f">= {min_frame_length}"
            )
        logger.info(f"Loaded {len(train_dataset)} training motion windows")
        # The statistics below are for logging only.
        train_num_clips = len(train_dataset.clips)
        train_total_frames = sum(
            int(meta.get("length", 0))
            for meta in train_dataset.clips.values()
        )
        fps_used = int(self.cfg.target_fps)
        train_duration_s = (
            float(train_total_frames) / float(fps_used)
            if fps_used > 0
            else 0.0
        )
        if len(train_roots) == 1:
            logger.info(
                f"Train dataset: root={train_roots[0]}, "
                f"manifest={train_manifest_paths[0]}"
            )
        else:
            logger.info(
                f"Train dataset: roots={train_roots}, "
                f"manifests={train_manifest_paths}"
            )
        logger.info(
            f"Train clips={train_num_clips}, frames={train_total_frames}, "
            f"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps"
        )
        excluded_names = mcfg.get("excluded_motion_names", None)
        if excluded_names:
            # Report how much of the training set the exclusion list
            # matches (exact clip-key match in this backend).
            excluded_set = set(excluded_names)
            excluded_clip_keys = [
                k for k in train_dataset.clips.keys() if k in excluded_set
            ]
            excluded_num_clips = len(excluded_clip_keys)
            excluded_total_frames = sum(
                int(train_dataset.clips[k].get("length", 0))
                for k in excluded_clip_keys
            )
            excluded_duration_s = (
                float(excluded_total_frames) / float(fps_used)
                if fps_used > 0
                else 0.0
            )
            left_num_clips = max(0, train_num_clips - excluded_num_clips)
            left_total_frames = max(
                0, train_total_frames - excluded_total_frames
            )
            left_duration_s = (
                float(left_total_frames) / float(fps_used)
                if fps_used > 0
                else 0.0
            )
            logger.info(
                f"Excluded (by name): clips={excluded_num_clips}, "
                f"frames={excluded_total_frames}, "
                f"duration={excluded_duration_s / 3600:.2f}h"
            )
            logger.info(
                f"Remaining after exclusion: clips={left_num_clips}, "
                f"frames={left_total_frames}, "
                f"duration={left_duration_s / 3600:.2f}h"
            )
        # None means the cache falls back to the training dataset for
        # validation (see the log message in the else-branch).
        val_dataset = None
        if val_roots is not None:
            val_manifest_paths = [
                os.path.join(root, "manifest.json") for root in val_roots
            ]
            for mp in val_manifest_paths:
                if not os.path.exists(mp):
                    raise FileNotFoundError(
                        f"HDF5 validation manifest not found at {mp}. "
                        "Please set robot.motion.val_hdf5_root/"
                        "val_hdf5_roots to the correct path!"
                    )
            if len(val_manifest_paths) == 1:
                logger.info(
                    f"Loading HDF5 validation dataset from {val_manifest_paths[0]}"
                )
            else:
                logger.info(
                    "Loading HDF5 validation dataset from manifests: "
                    f"{val_manifest_paths}"
                )
            val_dataset = Hdf5MotionDataset(
                manifest_path=val_manifest_paths
                if len(val_manifest_paths) > 1
                else val_manifest_paths[0],
                max_frame_length=max_frame_length,
                min_window_length=min_frame_length,
                handpicked_motion_names=mcfg.get(
                    "handpicked_motion_names", None
                ),
                excluded_motion_names=mcfg.get(
                    "excluded_motion_names", None
                ),
                world_frame_normalization=world_frame_norm,
                allowed_prefixes=allowed_prefixes,
            )
            logger.info(
                f"Loaded {len(val_dataset)} validation motion windows"
            )
            val_num_clips = len(val_dataset.clips)
            val_total_frames = sum(
                int(meta.get("length", 0))
                for meta in val_dataset.clips.values()
            )
            val_duration_s = (
                float(val_total_frames) / float(fps_used)
                if fps_used > 0
                else 0.0
            )
            if len(val_roots) == 1:
                logger.info(
                    f"Val dataset: root={val_roots[0]}, "
                    f"manifest={val_manifest_paths[0]}"
                )
            else:
                logger.info(
                    f"Val dataset: roots={val_roots}, "
                    f"manifests={val_manifest_paths}"
                )
            # NOTE(review): val duration is logged with .1f while train
            # uses .2f; harmless but inconsistent.
            logger.info(
                f"Val clips={val_num_clips}, frames={val_total_frames}, "
                f"duration={val_duration_s / 3600:.1f}h @ {fps_used} fps"
            )
        else:
            logger.info(
                "Validation dataset: using training dataset "
                "(no separate val manifest found)"
            )
        dataloader_cfg = mcfg.get("dataloader", {})
        stage_device = self._resolve_motion_cache_stage_device(cache_cfg)
        # One cache per process; the sampler is sharded by process_id /
        # num_processes. Clips are swapped every swap_interval_steps
        # (default: max_frame_length).
        self._motion_cache = MotionClipBatchCache(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            batch_size=int(cache_cfg.get("max_num_clips", 1024)),
            stage_device=stage_device,
            num_workers=int(dataloader_cfg.get("num_workers", 4)),
            prefetch_factor=dataloader_cfg.get("prefetch_factor", None),
            pin_memory=bool(dataloader_cfg.get("pin_memory", True)),
            persistent_workers=bool(
                dataloader_cfg.get("persistent_workers", True)
            ),
            batch_progress_bar=bool(
                cache_cfg.get("batch_progress_bar", False)
            ),
            sampler_rank=int(self.cfg.process_id),
            sampler_world_size=int(self.cfg.num_processes),
            allowed_prefixes=allowed_prefixes,
            swap_interval_steps=int(
                cache_cfg.get("swap_interval_steps", max_frame_length)
            ),
            seed=int(self.cfg.seed),
            loader_timeout=float(dataloader_cfg.get("timeout", 0.0)),
        )
        cache = self._motion_cache
        logger.info(
            "DataLoader params: "
            f"batch_size={cache._batch_size}, "
            f"num_workers={cache._num_workers}, "
            f"prefetch_factor={cache._prefetch_factor}, "
            f"pin_memory={cache._pin_memory}, "
            f"persistent_workers={cache._persistent_workers}"
        )
        logger.info(
            "Sampler/Cache params: "
            f"rank={cache._sampler_rank}/{cache._sampler_world_size}, "
            f"device={cache._stage_device}, "
            f"swap_interval_steps={cache.swap_interval_steps}"
        )
        self._motion_lib = None
    elif backend == "hdf5_v2":
        max_frame_length = int(mcfg.get("max_frame_length", 500))
        min_frame_length = int(mcfg.get("min_frame_length", 1))
        world_frame_norm = bool(
            mcfg.get("world_frame_normalization", True)
        )
        cache_cfg = mcfg.get("cache", {})
        allowed_prefixes = cache_cfg.get(
            "allowed_prefixes",
            ["ref_", "ft_ref_"],
        )
        # In this backend the roots/manifest paths are only used for the
        # log messages below; datasets come from
        # build_motion_datasets_from_cfg.
        train_hdf5_roots = mcfg.get("train_hdf5_roots", None)
        if train_hdf5_roots:
            train_roots = [str(r) for r in train_hdf5_roots]
        else:
            hdf5_root = mcfg.get("hdf5_root", None)
            train_roots = [str(hdf5_root)] if hdf5_root is not None else []
        train_manifest_paths = [
            os.path.join(root, "manifest.json") for root in train_roots
        ]
        (
            train_dataset,
            val_dataset,
            cache_kwargs,
        ) = build_motion_datasets_from_cfg(
            motion_cfg=mcfg,
            max_frame_length=max_frame_length,
            min_window_length=min_frame_length,
            world_frame_normalization=world_frame_norm,
            handpicked_motion_names=mcfg.get(
                "handpicked_motion_names", None
            ),
            excluded_motion_names=mcfg.get("excluded_motion_names", None),
            allowed_prefixes=allowed_prefixes,
        )
        if len(train_dataset) == 0:
            raise ValueError(
                "Training dataset is empty. Check that all HDF5 v2 "
                "roots contain valid clips with length "
                f">= {min_frame_length}"
            )
        if len(train_manifest_paths) == 1:
            logger.info(
                f"Loading HDF5 v2 training dataset from {train_manifest_paths[0]}"
            )
        else:
            logger.info(
                "Loading HDF5 v2 training dataset from manifests: "
                f"{train_manifest_paths}"
            )
        fps_used = int(self.cfg.target_fps)
        logger.info(f"Loaded {len(train_dataset)} training motion windows")
        train_num_clips = len(train_dataset.clips)
        train_total_frames = sum(
            int(meta.get("length", 0))
            for meta in train_dataset.clips.values()
        )
        train_duration_s = (
            float(train_total_frames) / float(fps_used)
            if fps_used > 0
            else 0.0
        )
        logger.info(
            f"Train clips={train_num_clips}, frames={train_total_frames}, "
            f"duration={train_duration_s / 3600:.2f}h @ {fps_used} fps"
        )
        if len(train_roots) == 1:
            logger.info(
                f"Train dataset: root={train_roots[0]}, "
                f"manifest={train_manifest_paths[0]}"
            )
        elif len(train_roots) > 1:
            logger.info(
                f"Train dataset: roots={train_roots}, "
                f"manifests={train_manifest_paths}"
            )
        excluded_names = mcfg.get("excluded_motion_names", None)
        if excluded_names:
            excluded_set = set(excluded_names)
            excluded_clip_keys: List[str] = []
            if isinstance(train_dataset, Hdf5RootDofDataset):
                # This dataset type can match a clip under several alias
                # keys; a clip counts as excluded if any alias is listed.
                for key, meta in train_dataset.clips.items():
                    aliases = train_dataset._build_motion_key_aliases(
                        key, meta
                    )
                    if any(alias in excluded_set for alias in aliases):
                        excluded_clip_keys.append(key)
            else:
                excluded_clip_keys = [
                    k
                    for k in train_dataset.clips.keys()
                    if k in excluded_set
                ]
            excluded_num_clips = len(excluded_clip_keys)
            excluded_total_frames = sum(
                int(train_dataset.clips[k].get("length", 0))
                for k in excluded_clip_keys
            )
            excluded_duration_s = (
                float(excluded_total_frames) / float(fps_used)
                if fps_used > 0
                else 0.0
            )
            remaining_num_clips = train_num_clips - excluded_num_clips
            remaining_total_frames = (
                train_total_frames - excluded_total_frames
            )
            remaining_duration_s = train_duration_s - excluded_duration_s
            logger.info(
                "Excluded (by name): "
                f"clips={excluded_num_clips}, frames={excluded_total_frames}, "
                f"duration={excluded_duration_s / 3600:.2f}h"
            )
            logger.info(
                "Remaining after exclusion: "
                f"clips={remaining_num_clips}, frames={remaining_total_frames}, "
                f"duration={remaining_duration_s / 3600:.2f}h"
            )
        if val_dataset is None:
            logger.info(
                "Validation dataset: using training dataset "
                "(no separate val HDF5 v2 roots found)"
            )
        dataloader_cfg = mcfg.get("dataloader", {})
        stage_device = self._resolve_motion_cache_stage_device(cache_cfg)
        # Same cache construction as the hdf5 branch, plus backend-specific
        # kwargs returned by build_motion_datasets_from_cfg.
        self._motion_cache = MotionClipBatchCache(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            batch_size=int(cache_cfg.get("max_num_clips", 1024)),
            stage_device=stage_device,
            num_workers=int(dataloader_cfg.get("num_workers", 4)),
            prefetch_factor=dataloader_cfg.get("prefetch_factor", None),
            pin_memory=bool(dataloader_cfg.get("pin_memory", True)),
            persistent_workers=bool(
                dataloader_cfg.get("persistent_workers", True)
            ),
            batch_progress_bar=bool(
                cache_cfg.get("batch_progress_bar", False)
            ),
            sampler_rank=int(self.cfg.process_id),
            sampler_world_size=int(self.cfg.num_processes),
            allowed_prefixes=allowed_prefixes,
            swap_interval_steps=int(
                cache_cfg.get("swap_interval_steps", max_frame_length)
            ),
            seed=int(self.cfg.seed),
            loader_timeout=float(dataloader_cfg.get("timeout", 0.0)),
            **cache_kwargs,
        )
        cache = self._motion_cache
        logger.info(
            "DataLoader params: "
            f"batch_size={cache._batch_size}, "
            f"num_workers={cache._num_workers}, "
            f"prefetch_factor={cache._prefetch_factor}, "
            f"pin_memory={cache._pin_memory}, "
            f"persistent_workers={cache._persistent_workers}"
        )
        logger.info(
            "Sampler/Cache params: "
            f"rank={cache._sampler_rank}/{cache._sampler_world_size}, "
            f"device={cache._stage_device}, "
            f"swap_interval_steps={cache.swap_interval_steps}"
        )
        self._motion_lib = None
    else:
        raise ValueError(f"Unsupported motion backend: {backend}")
    sampling_strategy_cfg = mcfg.get("sampling_strategy", None)
    if sampling_strategy_cfg is None:
        sampling_strategy = "uniform"
    else:
        sampling_strategy = str(sampling_strategy_cfg).lower()
    if sampling_strategy == "weighted_bin":
        weighted_bin_cfg = mcfg.get("weighted_bin", {})
        self._motion_cache.enable_weighted_bin_sampling(
            cfg=dict(weighted_bin_cfg or {})
        )
    elif sampling_strategy == "curriculum":
        curriculum_cfg = dict(mcfg.get("curriculum", {}) or {})
        self._motion_cache.enable_cache_curriculum_sampling(
            cfg=curriculum_cfg
        )
    # NOTE(review): "curriculum" in this tuple is redundant (handled
    # above); effectively this rejects anything other than "uniform".
    elif sampling_strategy not in ("uniform", "curriculum"):
        raise ValueError(
            f"Invalid sampling_strategy '{sampling_strategy}'. "
            "Expected one of ['curriculum', 'uniform', 'weighted_bin']."
        )
    self._sampling_strategy = sampling_strategy
    self._init_per_env_cache()
def setup_dumping_dir(self, log_dir: str):
    """Configure on-disk dump locations under ``log_dir``.

    * With curriculum sampling, the cache's curriculum window scores are
      dumped to ``<log_dir>/cache_curriculum_window_scores``.
    * When the motion config sets ``dump_sampled_motion_keys``, per-swap
      JSON dumps of the sampled motion keys are enabled. They are written
      every ``dump_sampled_motion_keys_interval`` swaps (at least 1) into
      ``<log_dir>/<subdir>``, where ``subdir`` comes from
      ``dump_sampled_motion_keys_subdir`` (default
      ``"sampled_motion_cache_keys"``). That directory is created here.

    Args:
        log_dir: Base directory for all dumps.
    """
    mcfg = self.mcfg
    base_log_dir = str(log_dir)
    if self._sampling_strategy == "curriculum":
        curriculum_dump_dir = os.path.join(
            base_log_dir, "cache_curriculum_window_scores"
        )
        self._motion_cache.set_cache_curriculum_dump_dir(
            curriculum_dump_dir
        )
    self._dump_sampled_motion_keys_enabled = bool(
        mcfg.get("dump_sampled_motion_keys", False)
    )
    if not self._dump_sampled_motion_keys_enabled:
        return
    self._dump_sampled_motion_keys_interval = max(
        1, int(mcfg.get("dump_sampled_motion_keys_interval", 1))
    )
    # Subdirectory is configurable; an unset/empty value keeps the
    # historical default name.
    dump_subdir = str(
        mcfg.get("dump_sampled_motion_keys_subdir", None)
        or "sampled_motion_cache_keys"
    )
    self._dump_sampled_motion_keys_dir = os.path.join(
        base_log_dir, dump_subdir
    )
    os.makedirs(self._dump_sampled_motion_keys_dir, exist_ok=True)
    logger.info(
        f"Dumping sampled motion keys to {self._dump_sampled_motion_keys_dir}"
    )
def set_runtime_distributed_context(
    self, *, process_id: int, num_processes: int
) -> None:
    """Override the process rank and world size stored on this term.

    Only the stored values change (the world size is clamped to at least
    1). The already-built motion cache is not touched.
    """
    rank = int(process_id)
    self._runtime_process_id = rank
    world_size = int(num_processes)
    self._runtime_num_processes = world_size if world_size > 1 else 1
def set_motion_cache_seed(
    self, seed: int, *, reinitialize: bool = True
) -> None:
    """Reseed the motion cache and optionally redraw env assignments.

    Args:
        seed: New seed, converted with ``int``.
        reinitialize: Forwarded to the cache. When true, per-env clip
            assignments are also re-initialized afterwards.
    """
    motion_cache = self._motion_cache
    motion_cache.set_seed(int(seed), reinitialize=reinitialize)
    if not reinitialize:
        return
    self._init_per_env_cache()
def close(self) -> None:
    """Release motion cache resources for this command term.

    Safe to call more than once. Also safe on a partially constructed
    instance where ``_init_motion_lib`` never ran: a missing
    ``_motion_cache`` attribute is treated as "nothing to release".
    """
    motion_cache = getattr(self, "_motion_cache", None)
    if motion_cache is not None:
        motion_cache.close()
    self._motion_cache = None
def _init_per_env_cache(self):
"""Initialize per-env cache for motion tracking."""
self._clip_indices = torch.zeros(
self.num_envs, dtype=torch.long, device=self.device
)
self._frame_indices = torch.zeros(
self.num_envs, dtype=torch.long, device=self.device
)
self._swap_pending = False
self._swap_step_counter = 0
# Initial assignment
clip_idx, frame_idx = self._motion_cache.sample_env_assignments(
self.num_envs,
self.cfg.n_fut_frames,
self.device,
deterministic_start=(self._is_evaluating),
)
self._clip_indices[:] = clip_idx
self._frame_indices[:] = frame_idx
self._start_frame_indices[:] = frame_idx
self._reward_sum_since_assign[:] = 0.0
self._step_count_since_assign[:] = 0.0
self._update_ref_motion_state_from_cache()
def _maybe_dump_sampled_motion_keys(self) -> None:
if not self._dump_sampled_motion_keys_enabled:
return
swap_index = int(self._motion_cache.swap_index)
if swap_index <= 0:
return
if swap_index % self._dump_sampled_motion_keys_interval != 0:
return
current_batch = self._motion_cache.current_batch
window_indices = current_batch.window_indices.detach().cpu().tolist()
cache_scores = None
cache_selection_counts = None
cache_in_prioritized_pool = None
curriculum_state_step = None
score_bundle = (
self._motion_cache.cache_curriculum_scores_for_window_indices(
current_batch.window_indices
)
)
if score_bundle is not None:
score_tensor, state, version = score_bundle
cache_scores = score_tensor.detach().cpu().tolist()
cache_selection_counts = (
state["selection_count"].detach().cpu().tolist()
)
cache_in_prioritized_pool = (
state["in_prioritized_pool"].detach().cpu().tolist()
)
curriculum_state_step = int(version)
payload = {
"swap_index": swap_index,
"sampling_strategy": str(self._sampling_strategy),
"num_keys": int(len(current_batch.motion_keys)),
"motion_keys": list(current_batch.motion_keys),
"raw_motion_keys": list(current_batch.raw_motion_keys),
"window_indices": window_indices,
"cache_sampling_score": cache_scores,
"cache_sampling_count": cache_selection_counts,
"cache_in_prioritized_pool": cache_in_prioritized_pool,
"curriculum_state_step": curriculum_state_step,
}
file_name = (
f"sampled_motion_keys_rank_{self._runtime_process_id:04d}_swap_"
f"{swap_index:06d}.json"
)
output_path = os.path.join(
self._dump_sampled_motion_keys_dir, file_name
)
with open(output_path, "w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2)
handle.write("\n")
def _init_robot_handle(self):
    """Resolve the articulation handle and all name-to-index tables.

    Sets:
    * ``self.robot``: the scene entity named ``cfg.asset_name``, plus the
      anchor body index.
    * URDF/simulator name lists and permutation lists. ``urdf2sim_*`` lists,
      for each simulator-order name, its URDF index; ``sim2urdf_*`` lists,
      for each URDF-order name, its simulator index.
    * Arm/torso/leg dof and body index lists, in simulator indices.
    * Per-env world origins.
    * AMP-style left/right arm/leg dof indices and elbow/foot body indices,
      in both URDF and simulator orderings.

    Raises:
        ValueError: A configured name is missing from the list it is looked
            up in (``list.index``).
        KeyError: An AMP name is missing from the URDF or simulator maps.
    """
    self.robot: Articulation = self._env.scene[self.cfg.asset_name]
    self.anchor_bodylink_name = self.cfg.anchor_bodylink_name
    self.anchor_bodylink_idx = self.robot.body_names.index(
        self.anchor_bodylink_name
    )
    self.urdf_dof_names = self.cfg.urdf_dof_names
    self.urdf_body_names = self.cfg.urdf_body_names
    self.simulator_dof_names = self.robot.joint_names
    self.simulator_body_names = self.robot.body_names
    # For each simulator-order name, its URDF index: indexing URDF-ordered
    # data with these lists yields simulator order.
    self.urdf2sim_dof_idx = [
        self.urdf_dof_names.index(dof) for dof in self.simulator_dof_names
    ]
    self.urdf2sim_body_idx = [
        self.urdf_body_names.index(body)
        for body in self.simulator_body_names
    ]
    # The reverse mapping: for each URDF-order name, its simulator index.
    self.sim2urdf_dof_idx = [
        self.simulator_dof_names.index(dof) for dof in self.urdf_dof_names
    ]
    self.sim2urdf_body_idx = [
        self.simulator_body_names.index(body)
        for body in self.urdf_body_names
    ]
    self.arm_dof_indices = [
        self.simulator_dof_names.index(dof)
        for dof in self.cfg.arm_dof_names
    ]
    self.torso_dof_indices = [
        self.simulator_dof_names.index(dof)
        for dof in self.cfg.waist_dof_names
    ]
    self.leg_dof_indices = [
        self.simulator_dof_names.index(dof)
        for dof in self.cfg.leg_dof_names
    ]
    # Body indices for mpkpe metrics using unified naming
    self.arm_body_indices = [
        self.simulator_body_names.index(body)
        for body in self.cfg.arm_body_names
    ]
    self.torso_body_indices = [
        self.simulator_body_names.index(body)
        for body in self.cfg.torso_body_names
    ]
    self.leg_body_indices = [
        self.simulator_body_names.index(body)
        for body in self.cfg.leg_body_names
    ]
    # Per-env world origins (translation only)
    # Shape: [num_envs, 3] on the same device as the sim
    self._env_origins = self._env.scene.env_origins.to(self.device)
    # AMP-style observation indices (RSL reference alignment)
    urdf_dof_name_to_idx = {
        name: idx for idx, name in enumerate(self.urdf_dof_names)
    }
    sim_dof_name_to_idx = {
        name: idx for idx, name in enumerate(self.simulator_dof_names)
    }
    urdf_body_name_to_idx = {
        name: idx for idx, name in enumerate(self.urdf_body_names)
    }
    sim_body_name_to_idx = {
        name: idx for idx, name in enumerate(self.simulator_body_names)
    }
    # Explicit per-limb name lists from cfg take precedence; missing or
    # empty ones fall back to prefix + keyword filtering of URDF names.
    left_arm_dof_names = list(
        getattr(self.cfg, "left_arm_dof_names", []) or []
    )
    right_arm_dof_names = list(
        getattr(self.cfg, "right_arm_dof_names", []) or []
    )
    left_leg_dof_names = list(
        getattr(self.cfg, "left_leg_dof_names", []) or []
    )
    right_leg_dof_names = list(
        getattr(self.cfg, "right_leg_dof_names", []) or []
    )
    if not left_arm_dof_names:
        left_arm_dof_names = self._amp_filter_names_by_prefix(
            self.urdf_dof_names,
            "left_",
            ("shoulder", "elbow", "wrist"),
        )
    if not right_arm_dof_names:
        right_arm_dof_names = self._amp_filter_names_by_prefix(
            self.urdf_dof_names,
            "right_",
            ("shoulder", "elbow", "wrist"),
        )
    if not left_leg_dof_names:
        left_leg_dof_names = self._amp_filter_names_by_prefix(
            self.urdf_dof_names, "left_", ("hip", "knee", "ankle")
        )
    if not right_leg_dof_names:
        right_leg_dof_names = self._amp_filter_names_by_prefix(
            self.urdf_dof_names, "right_", ("hip", "knee", "ankle")
        )
    self._amp_left_arm_urdf_dof_idx = [
        urdf_dof_name_to_idx[name] for name in left_arm_dof_names
    ]
    self._amp_right_arm_urdf_dof_idx = [
        urdf_dof_name_to_idx[name] for name in right_arm_dof_names
    ]
    self._amp_left_leg_urdf_dof_idx = [
        urdf_dof_name_to_idx[name] for name in left_leg_dof_names
    ]
    self._amp_right_leg_urdf_dof_idx = [
        urdf_dof_name_to_idx[name] for name in right_leg_dof_names
    ]
    self._amp_left_arm_sim_dof_idx = [
        sim_dof_name_to_idx[name] for name in left_arm_dof_names
    ]
    self._amp_right_arm_sim_dof_idx = [
        sim_dof_name_to_idx[name] for name in right_arm_dof_names
    ]
    self._amp_left_leg_sim_dof_idx = [
        sim_dof_name_to_idx[name] for name in left_leg_dof_names
    ]
    self._amp_right_leg_sim_dof_idx = [
        sim_dof_name_to_idx[name] for name in right_leg_dof_names
    ]
    left_arm_body_names = list(
        getattr(self.cfg, "left_arm_body_names", []) or []
    )
    right_arm_body_names = list(
        getattr(self.cfg, "right_arm_body_names", []) or []
    )
    left_leg_body_names = list(
        getattr(self.cfg, "left_leg_body_names", []) or []
    )
    right_leg_body_names = list(
        getattr(self.cfg, "right_leg_body_names", []) or []
    )
    if not left_arm_body_names:
        left_arm_body_names = self._amp_filter_names_by_prefix(
            self.urdf_body_names, "left_", ("shoulder", "elbow", "wrist")
        )
    if not right_arm_body_names:
        right_arm_body_names = self._amp_filter_names_by_prefix(
            self.urdf_body_names, "right_", ("shoulder", "elbow", "wrist")
        )
    if not left_leg_body_names:
        left_leg_body_names = self._amp_filter_names_by_prefix(
            self.urdf_body_names, "left_", ("hip", "knee", "ankle")
        )
    if not right_leg_body_names:
        right_leg_body_names = self._amp_filter_names_by_prefix(
            self.urdf_body_names, "right_", ("hip", "knee", "ankle")
        )
    # Pick one representative elbow and foot body per side. Patterns are
    # tried in priority order (ankle_roll before ankle_pitch before any
    # ankle); the result is None if nothing matches.
    left_elbow_name = self._amp_pick_first_name(
        left_arm_body_names, ("left_elbow", "elbow")
    )
    right_elbow_name = self._amp_pick_first_name(
        right_arm_body_names, ("right_elbow", "elbow")
    )
    left_foot_name = self._amp_pick_first_name(
        left_leg_body_names,
        ("left_ankle_roll", "left_ankle_pitch", "left_ankle"),
    )
    right_foot_name = self._amp_pick_first_name(
        right_leg_body_names,
        ("right_ankle_roll", "right_ankle_pitch", "right_ankle"),
    )
    # Index is None whenever no matching body name was found above.
    self._amp_left_elbow_urdf_body_idx = (
        urdf_body_name_to_idx[left_elbow_name]
        if left_elbow_name is not None
        else None
    )
    self._amp_right_elbow_urdf_body_idx = (
        urdf_body_name_to_idx[right_elbow_name]
        if right_elbow_name is not None
        else None
    )
    self._amp_left_foot_urdf_body_idx = (
        urdf_body_name_to_idx[left_foot_name]
        if left_foot_name is not None
        else None
    )
    self._amp_right_foot_urdf_body_idx = (
        urdf_body_name_to_idx[right_foot_name]
        if right_foot_name is not None
        else None
    )
    self._amp_left_elbow_sim_body_idx = (
        sim_body_name_to_idx[left_elbow_name]
        if left_elbow_name is not None
        else None
    )
    self._amp_right_elbow_sim_body_idx = (
        sim_body_name_to_idx[right_elbow_name]
        if right_elbow_name is not None
        else None
    )
    self._amp_left_foot_sim_body_idx = (
        sim_body_name_to_idx[left_foot_name]
        if left_foot_name is not None
        else None
    )
    self._amp_right_foot_sim_body_idx = (
        sim_body_name_to_idx[right_foot_name]
        if right_foot_name is not None
        else None
    )
    # NOTE(review): presumably a local offset from the elbow body frame to
    # the hand (0.3 along -z); confirm against the AMP observation code.
    self._amp_left_hand_local_vec = torch.tensor(
        [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32
    )
    self._amp_right_hand_local_vec = torch.tensor(
        [0.0, 0.0, -0.3], device=self.device, dtype=torch.float32
    )
def _init_buffers(self):
self.metrics = {}
self.ref_motion_global_frame_ids = torch.zeros(
self.num_envs,
dtype=torch.long,
device=self.device,
)
# mark envs that timed out (frame id exceeded end frame) in current step
self._motion_end_mask = torch.zeros(
self.num_envs,
dtype=torch.bool,
device=self.device,
)
# counter for number of motion ends per environment
self.motion_end_counter = torch.zeros(
self.num_envs,
dtype=torch.long,
device=self.device,
)
# per-environment cached motion indices
self._cached_motion_ids = torch.zeros(
self.num_envs,
dtype=torch.long,
device=self.device,
)
# env -> cache row indirection (starts as identity mapping)
self._env_to_cache_row = torch.arange(
self.num_envs, dtype=torch.long, device=self.device
)
self._start_frame_indices = torch.zeros(
self.num_envs,
dtype=torch.long,
device=self.device,
)
self._reward_sum_since_assign = torch.zeros(
self.num_envs,
dtype=torch.float32,
device=self.device,
)
self._mpjpe_sum_since_assign = torch.zeros(
self.num_envs,
dtype=torch.float32,
device=self.device,
)
self._mpkpe_sum_since_assign = torch.zeros(
self.num_envs,
dtype=torch.float32,
device=self.device,
)
self._step_count_since_assign = torch.zeros(
self.num_envs,
dtype=torch.float32,
device=self.device,
)
self._completion_rate_sum_by_window: Dict[int, float] = {}
self._completion_rate_count_by_window: Dict[int, int] = {}
self._mpkpe_signal_sum_by_window: Dict[int, float] = {}
self._mpkpe_signal_count_by_window: Dict[int, int] = {}
self.pos_history_buffer = None
self.rot_history_buffer = None
self.ref_pos_history_buffer = None
self.current_accel = None
self.ref_body_accel = None
self.current_ang_accel = None # Placeholder for angular acceleration
self.metrics["Task/MPJPE_WholeBody"] = torch.zeros(
self.num_envs, device=self.device
)
self.metrics["Task/MPKPE_WholeBody"] = torch.zeros(
self.num_envs, device=self.device
)
def _record_completion_rate_for_envs(self, env_ids: torch.Tensor) -> None:
    """Fold the finished clip assignments of ``env_ids`` into per-window stats.

    Completion rate for each env:
        steps_since_assign / max(clip_len - cfg.n_fut_frames - start_frame, 1)

    The rate is clamped to [0, 1]. The mean MPKPE over the same span
    (sum / max(steps, 1)) is negated into an "mpkpe signal". Both values are
    summed and counted per cache window index. Afterwards the
    per-assignment accumulators of ``env_ids`` are zeroed. An empty
    ``env_ids`` is a no-op.
    """
    if env_ids.numel() == 0:
        return

    def _accumulate(sums, counts, key, value):
        # Insert on first sight, otherwise add to the running sum/count.
        if key in sums:
            sums[key] += value
            counts[key] += 1
        else:
            sums[key] = value
            counts[key] = 1

    clip_ids = self._clip_indices[env_ids]
    clip_lengths = self._motion_cache.lengths_for_indices(clip_ids)
    clip_windows = self._motion_cache.window_indices_for_indices(clip_ids)
    steps_taken = self._step_count_since_assign[env_ids]
    steps_available = (
        clip_lengths
        - int(self.cfg.n_fut_frames)
        - self._start_frame_indices[env_ids]
    ).clamp(min=1)
    completion = (steps_taken / steps_available.float()).clamp(
        min=0.0, max=1.0
    )
    mpkpe_mean = self._mpkpe_sum_since_assign[env_ids] / steps_taken.clamp(
        min=1.0
    )

    completion_list = completion.detach().cpu().tolist()
    mpkpe_list = mpkpe_mean.detach().cpu().tolist()
    window_list = clip_windows.detach().cpu().tolist()
    for pos, window_obj in enumerate(window_list):
        window = int(window_obj)
        _accumulate(
            self._completion_rate_sum_by_window,
            self._completion_rate_count_by_window,
            window,
            float(completion_list[pos]),
        )
        _accumulate(
            self._mpkpe_signal_sum_by_window,
            self._mpkpe_signal_count_by_window,
            window,
            -float(mpkpe_list[pos]),
        )

    for accumulator in (
        self._reward_sum_since_assign,
        self._mpjpe_sum_since_assign,
        self._mpkpe_sum_since_assign,
        self._step_count_since_assign,
    ):
        accumulator[env_ids] = 0.0
def _reset_window_curriculum_stats(self) -> None:
    """Discard all accumulated per-window curriculum statistics.

    Each of the four sum/count dictionaries is replaced by a fresh empty dict.
    """
    (
        self._completion_rate_sum_by_window,
        self._completion_rate_count_by_window,
        self._mpkpe_signal_sum_by_window,
        self._mpkpe_signal_count_by_window,
    ) = ({}, {}, {}, {})
def _build_window_curriculum_stats_from_current_batch(
    self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project per-window curriculum averages onto the current cache rows.

    Returns:
        A tuple ``(window_indices, mpkpe_signal_mean, completion_rate_mean,
        count)``, each with one entry per row of the current cache batch.
        A window with no recorded samples gets 0.0 for its means.
        ``count`` is the larger of the two per-window sample counts.
    """
    windows = self._motion_cache.current_batch.window_indices.detach().to(
        self.device, dtype=torch.long
    )
    n_rows = int(windows.numel())
    mpkpe_out = torch.zeros(n_rows, dtype=torch.float32, device=self.device)
    completion_out = torch.zeros(
        n_rows, dtype=torch.float32, device=self.device
    )
    count_out = torch.zeros(n_rows, dtype=torch.float32, device=self.device)

    for row, window_obj in enumerate(windows.detach().cpu().tolist()):
        key = int(window_obj)
        n_completion = int(self._completion_rate_count_by_window.get(key, 0))
        n_mpkpe = int(self._mpkpe_signal_count_by_window.get(key, 0))
        if n_completion > 0:
            total = float(self._completion_rate_sum_by_window[key])
            completion_out[row] = total / float(n_completion)
        if n_mpkpe > 0:
            total = float(self._mpkpe_signal_sum_by_window[key])
            mpkpe_out[row] = total / float(n_mpkpe)
        count_out[row] = float(max(n_completion, n_mpkpe))

    return windows, mpkpe_out, completion_out, count_out
def _update_cache_curriculum_state(
    self,
    *,
    accelerator,
    swap_index: int,
) -> None:
    """Push per-window curriculum statistics to the motion cache.

    The cache is only updated when ``self._sampling_strategy`` is
    ``"curriculum"``. With more than one accelerator process, each stats
    tensor is gathered across processes first. The gather order is windows,
    mpkpe, completion, counts. The per-window accumulators are always
    cleared at the end.

    Args:
        accelerator: Optional object exposing ``num_processes`` and
            ``gather``; ``None`` means single-process.
        swap_index: Swap index forwarded to the cache update.
    """
    if self._sampling_strategy == "curriculum":
        stats = self._build_window_curriculum_stats_from_current_batch()
        is_distributed = (
            accelerator is not None and int(accelerator.num_processes) > 1
        )
        if is_distributed:
            stats = tuple(accelerator.gather(part) for part in stats)
        windows, mpkpe_means, completion_means, counts = stats
        self._motion_cache.update_cache_curriculum(
            window_indices=windows,
            mpkpe_signal_means=mpkpe_means,
            completion_rate_means=completion_means,
            counts=counts,
            swap_index=int(swap_index),
        )
    self._reset_window_curriculum_stats()
def update_curriculum_reward_accumulators(
    self, rewards: torch.Tensor
) -> None:
    """Add one step of reward / MPJPE / MPKPE to the per-env accumulators.

    Only envs selected by ``_filter_env_ids_for_motion_task`` are updated.
    Their step counter is incremented by one. ``rewards`` is flattened to
    1-D and cast to float32 on ``self.device`` before indexing.
    """
    flat_rewards = rewards.view(-1).to(self.device, dtype=torch.float32)
    every_env = torch.arange(self.num_envs, dtype=torch.long, device=self.device)
    ids = self._filter_env_ids_for_motion_task(every_env)
    if ids.numel() == 0:
        return
    self._reward_sum_since_assign[ids] += flat_rewards[ids]
    metric_targets = (
        (self._mpjpe_sum_since_assign, "Task/MPJPE_WholeBody"),
        (self._mpkpe_sum_since_assign, "Task/MPKPE_WholeBody"),
    )
    for accumulator, metric_key in metric_targets:
        accumulator[ids] += self.metrics[metric_key][ids].to(
            dtype=torch.float32
        )
    self._step_count_since_assign[ids] += 1.0
@property
def command(
    self,
) -> torch.Tensor:
    """Current command observation.

    Dispatches to ``self._get_obs_<cfg.command_obs_name>()``.
    """
    obs_name = self.cfg.command_obs_name
    builder = getattr(self, f"_get_obs_{obs_name}")
    return builder()


@property
def command_fut(
    self,
) -> torch.Tensor:
    """Future command observation.

    Dispatches to ``self._get_obs_<cfg.command_obs_name>_fut()``.
    """
    obs_name = self.cfg.command_obs_name
    builder = getattr(self, f"_get_obs_{obs_name}_fut")
    return builder()
def reset(
    self,
    env_ids: Sequence[int] | None = None,
) -> dict[str, float]:
    """Reset motion bookkeeping for ``env_ids`` and resample their clips.

    Cache swaps are deliberately not applied here. They are deferred to the
    rollout barrier (``apply_cache_swap_if_pending_barrier``). Only the
    requested envs that pass ``_filter_env_ids_for_motion_task`` are
    resampled.

    Args:
        env_ids: Indices of the environments to reset. ``None`` means all
            ``self.num_envs`` environments.

    Returns:
        The extras dict returned by the parent ``reset``.
    """
    extras = super().reset(env_ids)
    if env_ids is None:
        # Materialize the full index range. The previous ``slice(None)``
        # placeholder was later passed to ``torch.tensor``, which cannot
        # convert a slice and raised TypeError on every full reset.
        env_ids = torch.arange(
            self.num_envs, device=self.device, dtype=torch.long
        )
    elif isinstance(env_ids, torch.Tensor):
        env_ids = env_ids.to(self.device)
    else:
        env_ids = torch.tensor(
            env_ids, device=self.device, dtype=torch.long
        )
    self._motion_end_mask[env_ids] = False
    self.motion_end_counter[env_ids] = 0
    # Do not apply cache swap inside per-env reset; defer to PPO barrier.
    # Always resample only the requested envs here.
    motion_ids = self._filter_env_ids_for_motion_task(env_ids.view(-1))
    self._resample_command(motion_ids, eval=self._is_evaluating)
    return extras
def apply_cache_swap_if_pending_barrier(self, accelerator=None) -> bool:
    """Apply a pending cache swap at a rollout barrier.

    If ``self._swap_pending`` is set, this method does the following:
    1. Folds the finished-assignment statistics of all motion envs into
       the per-window curriculum stats.
    2. Pushes those stats to the motion cache (gathered across processes
       when ``accelerator`` has more than one process).
    3. Advances the cache to its next batch.
    4. Reassigns every motion env a clip and start frame from that batch.
    5. Realigns the robot root and DoF state to the new reference.

    Args:
        accelerator: Optional distributed accelerator, forwarded to
            ``_update_cache_curriculum_state``.

    Returns:
        bool: True if a swap was applied, otherwise False.
    """
    if not getattr(self, "_swap_pending", False):
        return False
    all_ids = torch.arange(
        self.num_envs, dtype=torch.long, device=self.device
    )
    motion_ids = self._filter_env_ids_for_motion_task(all_ids)
    if motion_ids.numel() == 0:
        # No motion envs active under multi-task: keep ref motion inert.
        # The pending flag is cleared without advancing the cache.
        self._swap_pending = False
        self._swap_step_counter = 0
        return False
    # Must run before advance(): the stats refer to clips of the current
    # batch. This call also zeroes all four per-assignment accumulators.
    self._record_completion_rate_for_envs(motion_ids)
    next_swap_index = int(self._motion_cache.swap_index) + 1
    self._update_cache_curriculum_state(
        accelerator=accelerator,
        swap_index=next_swap_index,
    )
    # Advance cache and reset counters
    self._motion_cache.advance()
    self._maybe_dump_sampled_motion_keys()
    self._swap_pending = False
    self._swap_step_counter = 0
    # Reassign motion envs to the new cache batch
    clip_idx, frame_idx = self._motion_cache.sample_env_assignments(
        int(motion_ids.numel()),
        self.cfg.n_fut_frames,
        self.device,
        deterministic_start=(self._is_evaluating),
    )
    self._clip_indices[motion_ids] = clip_idx
    self._frame_indices[motion_ids] = frame_idx
    self._start_frame_indices[motion_ids] = frame_idx
    # NOTE(review): these two resets repeat what
    # _record_completion_rate_for_envs already did for motion_ids; harmless.
    self._reward_sum_since_assign[motion_ids] = 0.0
    self._step_count_since_assign[motion_ids] = 0.0
    self._update_ref_motion_state_from_cache(env_ids=motion_ids)
    # Realign robot states to the new reference
    self._align_root_to_ref(motion_ids)
    self._align_dof_to_ref(motion_ids)
    # Reset per-episode timeout bookkeeping for consistency
    self._motion_end_mask[motion_ids] = False
    self.motion_end_counter[motion_ids] = 0
    return True
def compute(self, dt: float):
all_ids = torch.arange(
self.num_envs, dtype=torch.long, device=self.device
)
motion_ids = self._filter_env_ids_for_motion_task(all_ids)
if motion_ids.numel() == 0:
return
self._update_metrics()
self._update_command()
def _update_ref_motion_state(self):
"""Update reference motion state (unified API)."""
return self._update_ref_motion_state_from_cache()
def _update_ref_motion_state_from_cache(
self, env_ids: torch.Tensor | None = None
):
"""Compatibility no-op for cache-backed reference access."""
del env_ids
return None
def _get_ref_state_array(
    self,
    base_key: str,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Gather a reference tensor from the current cache batch.

    Args:
        base_key: Base key in the motion cache (e.g. "dof_pos", "root_pos").
        prefix: Optional logical prefix (e.g. "", "ref_", "ft_ref_", "robot_").

    Returns:
        Tensor of shape ``[num_envs, 1 + n_fut_frames, ...]`` gathered for
        the envs' current clip/frame assignments.
    """
    cache = self._motion_cache
    resolved_key = resolve_reference_tensor_key(
        batch_tensors=cache.current_batch.tensors,
        base_key=base_key,
        prefix=prefix,
    )
    return cache.gather_tensor(
        resolved_key,
        clip_indices=self._clip_indices,
        frame_indices=self._frame_indices,
        n_future_frames=self.cfg.n_fut_frames,
    )


def get_ref_motion_filter_cutoff_hz_cur(self) -> torch.Tensor:
    """Current-frame per-clip filter cutoff (Hz) from the motion cache.

    If the cache lookup raises KeyError (datasets without this metadata),
    a zero tensor of shape ``[num_envs, 1]`` (float32) is returned so that
    env construction does not fail.
    """
    try:
        frames = self._get_ref_state_array("filter_cutoff_hz", prefix="")
    except KeyError:
        return torch.zeros(
            self.num_envs, 1, device=self.device, dtype=torch.float32
        )
    else:
        return frames[:, 0]
def _uniform_sample_ref_start_frames(self, env_ids: torch.Tensor):
    """Uniformly sample start frames within cached windows for env_ids.

    Each env draws from ``[start, end - 1 - n_fut_frames]``, so that the
    required future frames exist. If that upper bound is below ``start``,
    the only choice is ``start``. ``n_fut_frames`` defaults to 0 when the
    config lacks it. The result is written to
    ``self.ref_motion_global_frame_ids``.
    """
    if isinstance(env_ids, torch.Tensor):
        ids = env_ids.to(self.device).long()
    else:
        ids = torch.tensor(env_ids, device=self.device, dtype=torch.long)
    lo = self.ref_motion_global_start_frame_ids[ids]
    hi = self.ref_motion_global_end_frame_ids[ids]
    horizon = int(getattr(self.cfg, "n_fut_frames", 0))
    last_valid = torch.maximum(hi - 1 - horizon, lo)
    span = (last_valid - lo + 1).clamp(min=1)
    # floor(u * span) with u in [0, 1) yields an offset in [0, span - 1].
    u = torch.rand_like(lo, dtype=torch.float32)
    self.ref_motion_global_frame_ids[ids] = lo + torch.floor(
        u * span.float()
    ).long()
def get_ref_motion_dof_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference DoF positions (frames 1..), simulator DoF order."""
    future = self._get_ref_state_array("dof_pos", prefix)[:, 1:]
    return future[..., self.urdf2sim_dof_idx]


def _get_immediate_next_ref_state_array(
    self,
    base_key: str,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Return frame index 1 (the first future frame) of a reference tensor.

    Raises:
        ValueError: if the gathered tensor holds no future frame.
    """
    frames = self._get_ref_state_array(base_key, prefix)
    if frames.shape[1] >= 2:
        return frames[:, 1]
    raise ValueError(
        f"Immediate-next reference for '{base_key}' requires at least one future frame."
    )


def get_ref_motion_dof_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference DoF velocities (frames 1..), simulator DoF order."""
    future = self._get_ref_state_array("dof_vel", prefix)[:, 1:]
    return future[..., self.urdf2sim_dof_idx]


def get_ref_motion_root_global_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root positions shifted by each env's origin."""
    future = self._get_ref_state_array("root_pos", prefix)[:, 1:]
    origins = self._env_origins.unsqueeze(1)
    return future + origins


def get_ref_motion_root_global_rot_quat_xyzw_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root quaternions, xyzw layout (as stored)."""
    frames = self._get_ref_state_array("root_rot", prefix)
    return frames[:, 1:]


def get_ref_motion_root_global_rot_quat_wxyz_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root quaternions reordered from xyzw to wxyz."""
    xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_fut(prefix=prefix)
    return xyzw[..., [3, 0, 1, 2]]


def get_ref_motion_root_global_lin_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root linear velocities (frames 1..)."""
    frames = self._get_ref_state_array("root_vel", prefix)
    return frames[:, 1:]


def get_ref_motion_root_global_ang_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root angular velocities (frames 1..)."""
    frames = self._get_ref_state_array("root_ang_vel", prefix)
    return frames[:, 1:]


def get_ref_motion_bodylink_global_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body positions, simulator body order, plus origins."""
    future = self._get_ref_state_array("rg_pos", prefix)[:, 1:]
    reordered = future[..., self.urdf2sim_body_idx, :]
    return reordered + self._env_origins[:, None, None, :]
def get_ref_motion_bodylink_rel_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body positions in the reference root frame.

    The root-relative world offsets of every body are rotated by the
    inverse of the reference root orientation (wxyz). Result is ``[B, N, 3]``.
    """
    body_pos_w = self.get_ref_motion_bodylink_global_pos_cur(prefix=prefix)
    root_pos_w = self.get_ref_motion_root_global_pos_cur(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )
    offsets_w = body_pos_w - root_pos_w.unsqueeze(1)  # [B, N, 3]
    body_count = offsets_w.shape[1]
    root_quat_per_body = root_quat.unsqueeze(1).expand(-1, body_count, -1)
    return isaaclab_math.quat_apply_inverse(root_quat_per_body, offsets_w)


def get_ref_motion_bodylink_rel_pos_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body positions in each frame's reference root frame.

    Result is ``[B, T, N, 3]``.
    """
    body_pos_w = self.get_ref_motion_bodylink_global_pos_fut(prefix=prefix)
    root_pos_w = self.get_ref_motion_root_global_pos_fut(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )
    offsets_w = body_pos_w - root_pos_w.unsqueeze(2)  # [B, T, N, 3]
    body_count = offsets_w.shape[2]
    root_quat_per_body = root_quat.unsqueeze(2).expand(
        -1, -1, body_count, -1
    )
    return isaaclab_math.quat_apply_inverse(root_quat_per_body, offsets_w)
def get_ref_motion_bodylink_global_rot_xyzw_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body quaternions (xyzw), simulator body order."""
    future = self._get_ref_state_array("rb_rot", prefix)[:, 1:]
    return future[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_lin_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body linear velocities, simulator body order."""
    future = self._get_ref_state_array("body_vel", prefix)[:, 1:]
    return future[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_ang_vel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference body angular velocities, simulator body order."""
    future = self._get_ref_state_array("body_ang_vel", prefix)[:, 1:]
    return future[..., self.urdf2sim_body_idx, :]


def get_ref_motion_dof_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions, simulator DoF order."""
    current = self._get_ref_state_array("dof_pos", prefix)[:, 0]
    return current[..., self.urdf2sim_dof_idx]


def get_ref_motion_dof_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame (index 1) reference DoF positions, simulator DoF order."""
    upcoming = self._get_immediate_next_ref_state_array("dof_pos", prefix)
    return upcoming[..., self.urdf2sim_dof_idx]


def get_immediate_next_two_dof_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Reference DoF positions for frames 0 and 1, simulator DoF order.

    Despite the name, the two frames returned are the current frame
    (index 0) and the immediate next frame (index 1) of the gathered window.
    That is why only one future frame is required.

    Raises:
        ValueError: if ``cfg.n_fut_frames`` is smaller than 1.
    """
    if int(self.cfg.n_fut_frames) < 1:
        raise ValueError(
            "n_fut_frames must be at least 1 for immediate next two DoF positions."
        )
    window = self._get_ref_state_array("dof_pos", prefix)
    return window[:, :2][..., self.urdf2sim_dof_idx]


def get_ref_motion_dof_pos_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions in stored (URDF) order."""
    frames = self._get_ref_state_array("dof_pos", prefix)
    return frames[:, 0]
def get_ref_motion_cur_heading_aligned_root_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root position relative to the robot root.

    The world-frame offset (reference root minus robot root) is expressed
    in the robot's yaw-only heading frame. Result is ``[B, 3]``.
    """
    robot_data = self.robot.data
    robot_root_pos = robot_data.root_pos_w
    heading_wxyz = isaaclab_math.yaw_quat(robot_data.root_quat_w)
    offset_w = (
        self.get_ref_motion_root_global_pos_cur(prefix=prefix) - robot_root_pos
    )
    return isaaclab_math.quat_apply_inverse(heading_wxyz, offset_w)


def get_ref_motion_fut_heading_aligned_root_pos(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root positions relative to the current robot root.

    Every future offset is expressed in the robot's current yaw-only
    heading frame. Result is ``[B, T, 3]``.
    """
    robot_data = self.robot.data
    robot_root_pos = robot_data.root_pos_w  # [B, 3]
    heading_wxyz = isaaclab_math.yaw_quat(robot_data.root_quat_w)  # [B, 4]
    ref_pos_fut = self.get_ref_motion_root_global_pos_fut(
        prefix=prefix
    )  # [B, T, 3]
    frame_count = ref_pos_fut.shape[1]
    offset_w = ref_pos_fut - robot_root_pos.unsqueeze(1)
    heading_per_frame = heading_wxyz.unsqueeze(1).expand(-1, frame_count, -1)
    return isaaclab_math.quat_apply_inverse(heading_per_frame, offset_w)


def get_ref_motion_cur_heading_aligned_root_rot6d(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root rotation (rot6d) in heading-aligned frame.

    The reference root quaternion is left-multiplied by the inverse of the
    robot's yaw-only heading. The first two columns of the resulting
    rotation matrix are then flattened.

    Returns:
        torch.Tensor: [B, 6]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    heading_inv_wxyz = isaaclab_math.quat_inv(heading_wxyz)
    ref_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )  # [B, 4] wxyz
    rel_quat = isaaclab_math.quat_mul(heading_inv_wxyz, ref_quat)
    rot_mat = isaaclab_math.matrix_from_quat(rel_quat)
    batch = ref_quat.shape[0]
    return rot_mat[..., :2].reshape(batch, 6)


def get_ref_motion_fut_heading_aligned_root_rot6d(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root rotations (rot6d) in heading-aligned frame.

    Every future frame uses the robot's current yaw-only heading.

    Returns:
        torch.Tensor: [B, T, 6]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    heading_inv_wxyz = isaaclab_math.quat_inv(heading_wxyz)
    ref_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )  # [B, T, 4] wxyz
    batch, frame_count, _ = ref_quat_fut.shape
    heading_inv_per_frame = heading_inv_wxyz.unsqueeze(1).expand(
        -1, frame_count, -1
    )
    rel_quat_fut = isaaclab_math.quat_mul(heading_inv_per_frame, ref_quat_fut)
    rot_mat_fut = isaaclab_math.matrix_from_quat(rel_quat_fut)
    return rot_mat_fut[..., :2].reshape(batch, frame_count, 6)


def get_ref_motion_cur_heading_aligned_root_lin_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root linear velocity in heading-aligned frame.

    Returns: [B, 3]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(prefix=prefix)
    return isaaclab_math.quat_apply_inverse(heading_wxyz, lin_vel_w)


def get_ref_motion_fut_heading_aligned_root_lin_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root linear velocity in heading-aligned frame.

    Returns: [B, T, 3]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(prefix=prefix)
    frame_count = lin_vel_w_fut.shape[1]
    heading_per_frame = heading_wxyz.unsqueeze(1).expand(-1, frame_count, -1)
    return isaaclab_math.quat_apply_inverse(heading_per_frame, lin_vel_w_fut)


def get_ref_motion_cur_heading_aligned_root_ang_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root angular velocity in heading-aligned frame.

    Returns: [B, 3]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(prefix=prefix)
    return isaaclab_math.quat_apply_inverse(heading_wxyz, ang_vel_w)


def get_ref_motion_fut_heading_aligned_root_ang_vel(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference root angular velocity in heading-aligned frame.

    Returns: [B, T, 3]
    """
    heading_wxyz = isaaclab_math.yaw_quat(self.robot.data.root_quat_w)
    ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(prefix=prefix)
    frame_count = ang_vel_w_fut.shape[1]
    heading_per_frame = heading_wxyz.unsqueeze(1).expand(-1, frame_count, -1)
    return isaaclab_math.quat_apply_inverse(heading_per_frame, ang_vel_w_fut)
@property
def robot_dof_pos_cur_urdf_order(self):
    """Simulated robot joint positions reordered into URDF DoF order."""
    order = self.sim2urdf_dof_idx
    return self.robot.data.joint_pos[..., order]


def get_ref_motion_dof_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF velocities, simulator DoF order."""
    current = self._get_ref_state_array("dof_vel", prefix)[:, 0]
    return current[..., self.urdf2sim_dof_idx]


def get_ref_motion_dof_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame (index 1) reference DoF velocities, simulator DoF order."""
    upcoming = self._get_immediate_next_ref_state_array("dof_vel", prefix)
    return upcoming[..., self.urdf2sim_dof_idx]


@property
def robot_dof_vel_cur_urdf_order(self):
    """Simulated robot joint velocities reordered into URDF DoF order."""
    order = self.sim2urdf_dof_idx
    return self.robot.data.joint_vel[..., order]


def get_ref_motion_dof_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF velocities in stored (URDF) order."""
    frames = self._get_ref_state_array("dof_vel", prefix)
    return frames[:, 0]


def get_ref_motion_root_global_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root position shifted by each env's origin."""
    current = self._get_ref_state_array("root_pos", prefix)[:, 0]
    return current + self._env_origins


def get_ref_motion_root_global_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference root position shifted by each env's origin."""
    upcoming = self._get_immediate_next_ref_state_array("root_pos", prefix)
    return upcoming + self._env_origins


def get_ref_motion_root_global_rot_quat_xyzw_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root quaternion, xyzw layout (as stored)."""
    frames = self._get_ref_state_array("root_rot", prefix)
    return frames[:, 0]


def get_ref_motion_root_global_rot_quat_xyzw_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference root quaternion, xyzw layout."""
    upcoming = self._get_immediate_next_ref_state_array("root_rot", prefix)
    return upcoming


def get_ref_motion_root_global_rot_quat_wxyz_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root quaternion reordered from xyzw to wxyz."""
    xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_cur(prefix=prefix)
    return xyzw[..., [3, 0, 1, 2]]


def get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference root quaternion reordered from xyzw to wxyz."""
    xyzw = self.get_ref_motion_root_global_rot_quat_xyzw_immediate_next(
        prefix=prefix
    )
    return xyzw[..., [3, 0, 1, 2]]


def get_ref_motion_root_global_lin_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root linear velocity (world frame as stored)."""
    frames = self._get_ref_state_array("root_vel", prefix)
    return frames[:, 0]


def get_ref_motion_root_global_lin_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference root linear velocity."""
    upcoming = self._get_immediate_next_ref_state_array("root_vel", prefix)
    return upcoming


@property
def ref_motion_root_global_lin_vel_cur(self) -> torch.Tensor:
    """Current reference root linear velocity using the default prefix."""
    lin_vel = self.get_ref_motion_root_global_lin_vel_cur()
    return lin_vel


def get_ref_motion_root_global_ang_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference root angular velocity."""
    frames = self._get_ref_state_array("root_ang_vel", prefix)
    return frames[:, 0]


def get_ref_motion_root_global_ang_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference root angular velocity."""
    upcoming = self._get_immediate_next_ref_state_array("root_ang_vel", prefix)
    return upcoming
def get_ref_motion_gravity_projection_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference gravity projected into reference root frame."""
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )  # [B, 4]
    return isaaclab_math.quat_apply_inverse(root_quat, gravity_w)


def get_ref_motion_gravity_projection_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Gravity projected into the next-frame reference root frame."""
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, gravity_w)


def get_ref_motion_gravity_projection_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference gravity projected into reference root frame.

    The same world gravity vector is broadcast over every future frame.
    Result is ``[B, T, 3]``.
    """
    gravity_w = self.robot.data.GRAVITY_VEC_W  # [B, 3]
    root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )  # [B, T, 4]
    frame_count = root_quat_fut.shape[1]
    gravity_per_frame = gravity_w.unsqueeze(1).expand(-1, frame_count, -1)
    return isaaclab_math.quat_apply_inverse(root_quat_fut, gravity_per_frame)


def get_ref_motion_base_linvel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference base linear velocity in reference root frame."""
    lin_vel_w = self.get_ref_motion_root_global_lin_vel_cur(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, lin_vel_w)


def get_ref_motion_base_linvel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference base linear velocity in its root frame."""
    lin_vel_w = self.get_ref_motion_root_global_lin_vel_immediate_next(
        prefix=prefix
    )
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, lin_vel_w)


def get_ref_motion_base_linvel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference base linear velocity in reference root frame.

    Result is ``[B, T, 3]``.
    """
    lin_vel_w_fut = self.get_ref_motion_root_global_lin_vel_fut(prefix=prefix)
    root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat_fut, lin_vel_w_fut)


def get_ref_motion_base_angvel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference base angular velocity in reference root frame."""
    ang_vel_w = self.get_ref_motion_root_global_ang_vel_cur(prefix=prefix)
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_cur(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, ang_vel_w)


def get_ref_motion_base_angvel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference base angular velocity in its root frame."""
    ang_vel_w = self.get_ref_motion_root_global_ang_vel_immediate_next(
        prefix=prefix
    )
    root_quat = self.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat, ang_vel_w)


def get_ref_motion_base_angvel_fut(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Future reference base angular velocity in reference root frame.

    Result is ``[B, T, 3]``.
    """
    ang_vel_w_fut = self.get_ref_motion_root_global_ang_vel_fut(prefix=prefix)
    root_quat_fut = self.get_ref_motion_root_global_rot_quat_wxyz_fut(
        prefix=prefix
    )
    return isaaclab_math.quat_apply_inverse(root_quat_fut, ang_vel_w_fut)
def get_ref_motion_bodylink_global_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body positions, simulator body order, plus origins."""
    current = self._get_ref_state_array("rg_pos", prefix)[:, 0]
    reordered = current[..., self.urdf2sim_body_idx, :]
    return reordered + self._env_origins.unsqueeze(1)


def get_ref_motion_bodylink_global_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference body positions, simulator body order, plus origins."""
    upcoming = self._get_immediate_next_ref_state_array("rg_pos", prefix)
    reordered = upcoming[..., self.urdf2sim_body_idx, :]
    return reordered + self._env_origins.unsqueeze(1)


def get_ref_motion_bodylink_global_pos_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body positions in stored (URDF) order, plus origins."""
    current = self._get_ref_state_array("rg_pos", prefix)[:, 0]
    return current + self._env_origins.unsqueeze(1)


def get_ref_motion_bodylink_global_rot_wxyz_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body quaternions reordered from xyzw to wxyz."""
    xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_cur(prefix=prefix)
    return xyzw[..., [3, 0, 1, 2]]


def get_ref_motion_bodylink_global_rot_xyzw_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body quaternions (xyzw), simulator body order."""
    current = self._get_ref_state_array("rb_rot", prefix)[:, 0]
    return current[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_rot_xyzw_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference body quaternions (xyzw), simulator body order."""
    upcoming = self._get_immediate_next_ref_state_array("rb_rot", prefix)
    return upcoming[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference body quaternions reordered from xyzw to wxyz."""
    xyzw = self.get_ref_motion_bodylink_global_rot_xyzw_immediate_next(
        prefix=prefix
    )
    return xyzw[..., [3, 0, 1, 2]]


def get_ref_motion_bodylink_global_rot_xyzw_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body quaternions (xyzw) in stored (URDF) order."""
    frames = self._get_ref_state_array("rb_rot", prefix)
    return frames[:, 0]


@property
def robot_bodylink_global_pos_cur_urdf_order(self):
    """Simulated body positions reordered into URDF body order."""
    order = self.sim2urdf_body_idx
    return self.robot.data.body_pos_w[:, order]


@property
def robot_bodylink_global_rot_wxyz_cur_urdf_order(self):
    """Simulated body quaternions (wxyz) reordered into URDF body order."""
    order = self.sim2urdf_body_idx
    return self.robot.data.body_quat_w[:, order]


@property
def robot_bodylink_global_rot_xyzw_cur_urdf_order(self):
    """Simulated body quaternions in URDF order, reordered wxyz -> xyzw."""
    wxyz = self.robot_bodylink_global_rot_wxyz_cur_urdf_order
    return wxyz[..., [1, 2, 3, 0]]


@property
def robot_bodylink_global_lin_vel_cur_urdf_order(self):
    """Simulated body linear velocities reordered into URDF body order."""
    order = self.sim2urdf_body_idx
    return self.robot.data.body_lin_vel_w[:, order]


@property
def robot_bodylink_global_ang_vel_cur_urdf_order(self):
    """Simulated body angular velocities reordered into URDF body order."""
    order = self.sim2urdf_body_idx
    return self.robot.data.body_ang_vel_w[:, order]


def get_ref_motion_bodylink_global_lin_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body linear velocities, simulator body order."""
    current = self._get_ref_state_array("body_vel", prefix)[:, 0]
    return current[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_lin_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference body linear velocities, simulator body order."""
    upcoming = self._get_immediate_next_ref_state_array("body_vel", prefix)
    return upcoming[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_lin_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body linear velocities in stored (URDF) order."""
    frames = self._get_ref_state_array("body_vel", prefix)
    return frames[:, 0]


def get_ref_motion_bodylink_global_ang_vel_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body angular velocities, simulator body order."""
    current = self._get_ref_state_array("body_ang_vel", prefix)[:, 0]
    return current[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_ang_vel_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Next-frame reference body angular velocities, simulator body order."""
    upcoming = self._get_immediate_next_ref_state_array("body_ang_vel", prefix)
    return upcoming[..., self.urdf2sim_body_idx, :]


def get_ref_motion_bodylink_global_ang_vel_cur_urdf_order(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current reference body angular velocities in stored (URDF) order."""
    frames = self._get_ref_state_array("body_ang_vel", prefix)
    return frames[:, 0]
def _build_amp_obs_from_ref_state(
    self, frame_idx: int, prefix: str = "ft_ref_"
) -> torch.Tensor:
    """Assemble an AMP observation from one frame of the reference motion.

    Layout, concatenated on the last dim:
      * joint positions of the right arm, left arm, right leg, left leg groups;
      * joint velocities in the same group order;
      * left-hand, right-hand, left-foot, right-foot positions expressed in
        the reference root frame.

    A hand position is its elbow body position plus ``_amp_*_hand_local_vec``
    rotated by the elbow orientation. Quaternions use the xyzw convention
    (``w_last=True``). Reference arrays are indexed with the ``*_urdf_*``
    indices.

    Args:
        frame_idx: Index along the time axis of the reference state arrays.
        prefix: Prefix forwarded to ``_get_ref_state_array``.

    Raises:
        ValueError: if any AMP DoF group is empty or any AMP body index is None.
    """
    dof_group_attrs = (
        "_amp_left_arm_urdf_dof_idx",
        "_amp_right_arm_urdf_dof_idx",
        "_amp_left_leg_urdf_dof_idx",
        "_amp_right_leg_urdf_dof_idx",
    )
    body_idx_attrs = (
        "_amp_left_elbow_urdf_body_idx",
        "_amp_right_elbow_urdf_body_idx",
        "_amp_left_foot_urdf_body_idx",
        "_amp_right_foot_urdf_body_idx",
    )
    # Generators keep the same lazy, left-to-right checking order.
    if not all(getattr(self, attr) for attr in dof_group_attrs) or any(
        getattr(self, attr) is None for attr in body_idx_attrs
    ):
        raise ValueError(
            "AMP obs indices are not initialized for ref motion."
        )

    def at_frame(key):
        # Slice a single time step out of a [B, T, ...] reference array.
        return self._get_ref_state_array(key, prefix)[:, frame_idx, ...]

    dof_pos = at_frame("dof_pos")
    dof_vel = at_frame("dof_vel")
    ordered_groups = (
        self._amp_right_arm_urdf_dof_idx,
        self._amp_left_arm_urdf_dof_idx,
        self._amp_right_leg_urdf_dof_idx,
        self._amp_left_leg_urdf_dof_idx,
    )
    joint_features = [dof_pos[:, grp] for grp in ordered_groups]
    joint_features += [dof_vel[:, grp] for grp in ordered_groups]

    root_pos = at_frame("root_pos")
    root_inv = quat_inverse(at_frame("root_rot"), w_last=True)
    rg_pos = at_frame("rg_pos")
    rb_rot = at_frame("rb_rot")
    batch = rg_pos.shape[0]

    def in_root_frame(world_pos):
        # World-frame point -> reference-root frame.
        return quat_rotate(root_inv, world_pos - root_pos, w_last=True)

    def hand_world(elbow_idx, local_vec):
        # Elbow position plus the hand offset rotated by the elbow orientation.
        offset = local_vec.expand(batch, -1)
        return rg_pos[:, elbow_idx, :] + quat_rotate(
            rb_rot[:, elbow_idx, :], offset, w_last=True
        )

    keypoints = [
        in_root_frame(
            hand_world(
                self._amp_left_elbow_urdf_body_idx,
                self._amp_left_hand_local_vec,
            )
        ),
        in_root_frame(
            hand_world(
                self._amp_right_elbow_urdf_body_idx,
                self._amp_right_hand_local_vec,
            )
        ),
        in_root_frame(rg_pos[:, self._amp_left_foot_urdf_body_idx, :]),
        in_root_frame(rg_pos[:, self._amp_right_foot_urdf_body_idx, :]),
    ]
    return torch.cat(joint_features + keypoints, dim=-1)
def get_ref_motion_amp_obs_cur(
    self, prefix: str = "ft_ref_"
) -> torch.Tensor:
    """AMP observation aligned with RSL reference (current frame).

    Thin wrapper: builds the AMP observation from time index 0 of the
    reference state arrays.
    """
    current_frame = 0
    return self._build_amp_obs_from_ref_state(current_frame, prefix=prefix)
@property
def motion_end_mask(self) -> torch.Tensor:
    """[B] bool: per-step "reference clip ended" mask.

    Returns the ``_motion_end_mask`` buffer as-is; there is no fallback
    computation. The buffer is updated inside ``_update_command`` via
    ``_resample_when_motion_end_cache``, which sets it for envs whose clip
    was resampled because it ran out of frames. An end event is therefore
    visible to anything that reads this property later in the same step.
    """
    return self._motion_end_mask
@property
def global_robot_anchor_pos_cur(self):
    """Simulated world position of the anchor body link (``anchor_bodylink_idx``)."""
    body_pos_w = self.robot.data.body_pos_w
    return body_pos_w[:, self.anchor_bodylink_idx]
def get_ref_motion_anchor_bodylink_global_pos_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference world position of the anchor body link."""
    all_bodies = self.get_ref_motion_bodylink_global_pos_cur(prefix=prefix)
    anchor = self.anchor_bodylink_idx
    return all_bodies[:, anchor]
def get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference world quaternion (wxyz) of the anchor body link."""
    all_bodies = self.get_ref_motion_bodylink_global_rot_wxyz_cur(prefix=prefix)
    anchor = self.anchor_bodylink_idx
    return all_bodies[:, anchor]
def get_ref_motion_anchor_bodylink_global_pos_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference world position of the anchor body link."""
    all_bodies = self.get_ref_motion_bodylink_global_pos_immediate_next(
        prefix=prefix
    )
    anchor = self.anchor_bodylink_idx
    return all_bodies[:, anchor]
def get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(
    self,
    prefix: str = "ref_",
) -> torch.Tensor:
    """Immediate-next reference world quaternion (wxyz) of the anchor body link."""
    all_bodies = self.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
        prefix=prefix
    )
    anchor = self.anchor_bodylink_idx
    return all_bodies[:, anchor]
def _get_obs_bydmmc_ref_motion(
    self,
    obs_prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions and velocities, flattened.

    Both come from time index 0 and are reordered with ``urdf2sim_dof_idx``.
    Returns ``[B, 2 * num_dofs]`` laid out as ``[dof_pos, dof_vel]``.
    """
    dof_order = self.urdf2sim_dof_idx
    cur_pos, cur_vel = (
        self._get_ref_state_array(key, obs_prefix)[:, 0, ...][..., dof_order]
        for key in ("dof_pos", "dof_vel")
    )
    batch = cur_pos.shape[0]
    return torch.cat(
        [cur_pos.reshape(batch, -1), cur_vel.reshape(batch, -1)], dim=-1
    )
def _get_obs_bydmmc_ref_motion_fut(
    self,
    obs_prefix: str = "ref_",
) -> torch.Tensor:
    """Future-frame reference DoF positions and velocities, flattened.

    Uses time indices ``1:`` (expected to be ``cfg.n_fut_frames`` frames),
    reorders DoFs with ``urdf2sim_dof_idx`` and interleaves per frame as
    ``[dof_pos_t, dof_vel_t]``. Returns ``[B, n_fut_frames * 2 * num_dofs]``.
    """
    dof_order = self.urdf2sim_dof_idx
    fut_pos = self._get_ref_state_array("dof_pos", obs_prefix)[:, 1:, ...][
        ..., dof_order
    ]
    fut_vel = self._get_ref_state_array("dof_vel", obs_prefix)[:, 1:, ...][
        ..., dof_order
    ]
    batch = fut_pos.shape[0]
    horizon = int(self.cfg.n_fut_frames)
    per_frame = torch.cat(
        [fut_pos.reshape(batch, horizon, -1), fut_vel.reshape(batch, horizon, -1)],
        dim=-1,
    )
    return per_frame.reshape(batch, -1)
def _get_obs_vr_ref_motion_states(
    self,
    obs_prefix: str = "ref_",
) -> torch.Tensor:
    """Current-frame reference DoF positions with zeroed velocity slots.

    Same layout as ``[dof_pos, dof_vel]`` but the velocity half is all zeros.
    Returns ``[B, 2 * num_dofs]``.
    """
    cur_pos = self._get_ref_state_array("dof_pos", obs_prefix)[:, 0, ...][
        ..., self.urdf2sim_dof_idx
    ]
    flat_pos = cur_pos.reshape(cur_pos.shape[0], -1)
    zero_vel = torch.zeros_like(flat_pos, device=flat_pos.device)
    return torch.cat([flat_pos, zero_vel], dim=-1)
def _get_obs_vr_ref_motion_fut(
    self,
    obs_prefix: str = "ref_",
) -> torch.Tensor:
    """Future-frame reference DoF positions with zeroed velocity slots.

    Per future frame the layout is ``[dof_pos_t, zeros]``. Returns
    ``[B, n_fut_frames * 2 * num_dofs]``.
    """
    fut_pos = self._get_ref_state_array("dof_pos", obs_prefix)[:, 1:, ...][
        ..., self.urdf2sim_dof_idx
    ]
    batch = fut_pos.shape[0]
    horizon = int(self.cfg.n_fut_frames)
    per_frame_pos = fut_pos.reshape(batch, horizon, -1)
    per_frame_zero = torch.zeros_like(per_frame_pos, device=per_frame_pos.device)
    stacked = torch.cat([per_frame_pos, per_frame_zero], dim=-1)
    return stacked.reshape(batch, -1)
def _get_obs_holomotion_rel_ref_motion_flat(
    self,
    obs_prefix: str = "ref_",
) -> torch.Tensor:
    """Flattened future reference-motion features, one group per future frame.

    Uses time indices ``1:`` of the reference state arrays. These must be
    exactly ``cfg.n_fut_frames`` frames (asserted). For every future frame
    the following are concatenated, in this order:

    * roll and pitch of the reference root after removing its heading
      (2 values, each wrapped with ``wrap_to_pi``);
    * reference root linear and angular velocity, rotated by the inverse of
      that frame's root rotation (3 + 3 values);
    * reference DoF positions and velocities reordered with
      ``urdf2sim_dof_idx`` (``num_dofs`` each);
    * body-link positions minus body 0 of the reordered body list, rotated
      by the inverse root rotation (``num_bodies * 3``);
    * body-link rotations relative to the root, encoded as the first two
      columns of their rotation matrices (``num_bodies * 6``).

    All quaternions are handled in xyzw layout (``w_last=True``).

    Args:
        obs_prefix: Prefix forwarded to ``_get_ref_state_array``.

    Returns:
        Tensor ``[num_envs, n_fut_frames * (2 + 3 + 3 + 2 * num_dofs
        + 9 * num_bodies)]``.

    Raises:
        AssertionError: if the batch size is not ``self.num_envs`` or the
            number of future frames is not ``cfg.n_fut_frames``.
    """
    # Gather all needed arrays with obs prefix
    # (time slice 1: = future frames; bodies/DoFs reordered URDF -> sim).
    fut_rg_pos = self._get_ref_state_array("rg_pos", obs_prefix)[
        :, 1:, ...
    ][..., self.urdf2sim_body_idx, :]
    fut_rb_rot_xyzw = self._get_ref_state_array("rb_rot", obs_prefix)[
        :, 1:, ...
    ][..., self.urdf2sim_body_idx, :]
    fut_root_rot_xyzw = self._get_ref_state_array("root_rot", obs_prefix)[
        :, 1:, ...
    ]
    fut_root_lin_vel = self._get_ref_state_array("root_vel", obs_prefix)[
        :, 1:, ...
    ]
    fut_root_ang_vel = self._get_ref_state_array(
        "root_ang_vel", obs_prefix
    )[:, 1:, ...]
    fut_dof_pos = self._get_ref_state_array("dof_pos", obs_prefix)[
        :, 1:, ...
    ][..., self.urdf2sim_dof_idx]
    fut_dof_vel = self._get_ref_state_array("dof_vel", obs_prefix)[
        :, 1:, ...
    ][..., self.urdf2sim_dof_idx]
    num_envs, num_fut_timesteps, num_bodies, _ = fut_rg_pos.shape
    assert num_envs == self.num_envs
    assert num_fut_timesteps == self.cfg.n_fut_frames
    fut_ref_root_rot_quat = fut_root_rot_xyzw  # [B, T, 4]
    fut_ref_root_rot_quat_inv = quat_inverse(
        fut_ref_root_rot_quat, w_last=True
    )  # [B, T, 4]
    # Root quaternion broadcast to every body, flattened to [B*T*num_bodies, 4].
    fut_ref_root_rot_quat_body_flat = (
        fut_ref_root_rot_quat[:, :, None, :]
        .repeat(1, 1, num_bodies, 1)
        .reshape(-1, 4)
    )
    fut_ref_root_rot_quat_body_flat_inv = quat_inverse(
        fut_ref_root_rot_quat_body_flat, w_last=True
    )
    # Strip the heading (yaw) from the root rotation, keep roll/pitch only.
    ref_fut_heading_quat_inv = calc_heading_quat_inv(
        fut_root_rot_xyzw.reshape(-1, 4),
        w_last=True,
    )  # [B*T, 4]
    ref_fut_quat_rp = quat_mul(
        ref_fut_heading_quat_inv,
        fut_root_rot_xyzw.reshape(-1, 4),
        w_last=True,
    )  # [B*T, 4]
    ref_fut_roll, ref_fut_pitch, _ = get_euler_xyz(
        ref_fut_quat_rp,
        w_last=True,
    )
    ref_fut_roll = wrap_to_pi(ref_fut_roll).reshape(
        num_envs, num_fut_timesteps, -1
    )  # [B, T, 1]
    ref_fut_pitch = wrap_to_pi(ref_fut_pitch).reshape(
        num_envs, num_fut_timesteps, -1
    )  # [B, T, 1]
    ref_fut_rp = torch.cat(
        [ref_fut_roll, ref_fut_pitch], dim=-1
    )  # [B, T, 2]
    ref_fut_rp_flat = ref_fut_rp.reshape(num_envs, -1)  # [B, T * 2]
    # ---
    # Root velocities expressed in each future frame's own root frame.
    fut_ref_root_quat_inv_fut_flat = fut_ref_root_rot_quat_inv.reshape(
        -1, 4
    )
    fut_ref_cur_root_rel_base_lin_vel = quat_rotate(
        fut_ref_root_quat_inv_fut_flat,  # [B*T, 4]
        fut_root_lin_vel.reshape(-1, 3),  # [B*T, 3]
        w_last=True,
    ).reshape(num_envs, -1)  # [B, num_fut_timesteps * 3]
    fut_ref_cur_root_rel_base_ang_vel = quat_rotate(
        fut_ref_root_quat_inv_fut_flat,  # [B*T, 4]
        fut_root_ang_vel.reshape(-1, 3),  # [B*T, 3]
        w_last=True,
    ).reshape(num_envs, -1)  # [B, num_fut_timesteps * 3]
    # ---
    # --- calculate the absolute DoF position and velocity ---
    fut_ref_dof_pos_flat = fut_dof_pos.reshape(num_envs, -1)
    fut_ref_dof_vel_flat = fut_dof_vel.reshape(num_envs, -1)
    # ---
    # --- calculate the future per frame bodylink position and rotation ---
    fut_ref_global_bodylink_pos = fut_rg_pos  # [B, T, num_bodies, 3]
    fut_ref_global_bodylink_rot = fut_rb_rot_xyzw  # [B, T, num_bodies, 4]
    # get root-relative bodylink position
    # (offset from body 0 of the sim-ordered list — presumably the root link;
    # NOTE(review): confirm body 0 is the pelvis/root in sim order).
    fut_ref_root_rel_bodylink_pos = quat_rotate(
        fut_ref_root_rot_quat_body_flat_inv,
        (
            fut_ref_global_bodylink_pos
            - fut_ref_global_bodylink_pos[:, :, 0:1, :]
        ).reshape(-1, 3),
        w_last=True,
    ).reshape(
        num_envs, num_fut_timesteps, num_bodies, -1
    )  # [B, num_fut_timesteps, num_bodies, 3]
    # get root-relative bodylink rotation
    fut_ref_root_rel_bodylink_rot = quat_mul(
        fut_ref_root_rot_quat_body_flat_inv,
        fut_ref_global_bodylink_rot.reshape(-1, 4),
        w_last=True,
    )
    # 6D rotation encoding: first two columns of each 3x3 rotation matrix.
    fut_ref_root_rel_bodylink_rot_mat = quaternion_to_matrix(
        fut_ref_root_rel_bodylink_rot,
        w_last=True,
    )[:, :, :2].reshape(
        num_envs, num_fut_timesteps, num_bodies, -1
    )  # [B, num_fut_timesteps, num_bodies, 6]
    # Regroup every feature per future frame before flattening.
    rel_fut_ref_motion_state_seq = torch.cat(
        [
            ref_fut_rp_flat.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, 2]
            fut_ref_cur_root_rel_base_lin_vel.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, 3]
            fut_ref_cur_root_rel_base_ang_vel.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, 3]
            fut_ref_dof_pos_flat.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, num_dofs]
            fut_ref_dof_vel_flat.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, num_dofs]
            fut_ref_root_rel_bodylink_pos.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, num_bodies*3]
            fut_ref_root_rel_bodylink_rot_mat.reshape(
                num_envs, num_fut_timesteps, -1
            ),  # [B, T, num_bodies*6]
        ],
        dim=-1,
    )  # [B, T, 2 + 3 + 3 + num_dofs * 2 + num_bodies * (3 + 6)]
    return rel_fut_ref_motion_state_seq.reshape(self.num_envs, -1)
def _resample_command(self, env_ids: Sequence[int], eval=False):
    """Resample command for specified environments.

    Only envs currently running the motion-tracking task are touched (see
    ``_filter_env_ids_for_motion_task``). Each selected env gets a fresh
    (clip, start frame) from the motion cache and its per-assignment
    reward/step accumulators are reset. Its reference state is then
    refreshed from the cache, and the robot root/DoF state is realigned to
    the reference via ``_align_root_to_ref`` / ``_align_dof_to_ref``.

    Args:
        env_ids: Envs to resample. Accepts a tensor, any sequence of ints,
            a ``slice`` over ``range(num_envs)`` (e.g. ``slice(None)`` for
            all envs), or ``None`` (all envs).
        eval: Request deterministic start frames. Start frames are also
            deterministic whenever ``self._is_evaluating`` is set.
    """
    # Normalize to a flat long tensor on self.device. Slices must be handled
    # before any tensor conversion: torch.tensor(slice) / len(slice) fail.
    if env_ids is None:
        env_ids = slice(None)
    if isinstance(env_ids, slice):
        idxs = torch.arange(
            self.num_envs, dtype=torch.long, device=self.device
        )[env_ids]
    elif isinstance(env_ids, torch.Tensor):
        idxs = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    else:
        idxs = torch.as_tensor(
            list(env_ids), dtype=torch.long, device=self.device
        ).view(-1)
    if idxs.numel() == 0:
        return
    idxs = self._filter_env_ids_for_motion_task(idxs)
    if idxs.numel() == 0:
        return
    self._record_completion_rate_for_envs(idxs)
    clip_idx, frame_idx = self._motion_cache.sample_env_assignments(
        int(idxs.numel()),
        self.cfg.n_fut_frames,
        self.device,
        deterministic_start=(eval or self._is_evaluating),
    )
    self._clip_indices[idxs] = clip_idx
    self._frame_indices[idxs] = frame_idx
    self._start_frame_indices[idxs] = frame_idx
    self._reward_sum_since_assign[idxs] = 0.0
    self._step_count_since_assign[idxs] = 0.0
    self._update_ref_motion_state_from_cache(env_ids=idxs)
    self._align_root_to_ref(idxs)
    self._align_dof_to_ref(idxs)
def _filter_env_ids_for_motion_task(
    self, env_ids: torch.Tensor
) -> torch.Tensor:
    """Keep only the env ids that are currently running motion tracking.

    Multi-task training may keep ``ref_motion`` registered for observation
    schemas. Motion-based state alignment, however, must not be applied to
    envs running other tasks (e.g. velocity tracking only).

    Behavior:
    - Empty input, or an env without ``holo_task_ids`` /
      ``holo_task_name_to_id``: ``env_ids`` is returned unchanged (legacy).
    - Task buffers present but no ``"motion_tracking"`` entry: empty result.
    - Otherwise: the ids (flattened, long, on ``self.device``) whose
      ``holo_task_ids`` entry equals the motion-tracking task id.
    """
    if env_ids.numel() == 0:
        return env_ids
    env = self._env
    task_ids = getattr(env, "holo_task_ids", None)
    name_to_id = getattr(env, "holo_task_name_to_id", None)
    if task_ids is None or name_to_id is None:
        return env_ids
    motion_tid = name_to_id.get("motion_tracking", None)
    if motion_tid is None:
        return env_ids[:0]
    flat_task_ids = task_ids.to(device=self.device, dtype=torch.long).view(-1)
    flat_env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    keep = flat_task_ids[flat_env_ids] == int(motion_tid)
    return flat_env_ids[keep]
def _align_root_to_ref(self, env_ids):
    """Write the current reference root state, randomly perturbed, to the sim.

    Only envs running the motion-tracking task are written. The reference
    root gets uniform offsets drawn from ``cfg.root_pose_perturb_range``:
    the translation is added, and the rotation is built from xyz Euler
    angles and left-multiplied onto the root quaternion. The reference
    linear/angular velocity gets uniform offsets from
    ``cfg.root_vel_perturb_range``. Axes missing from either dict use
    ``(0.0, 0.0)``, i.e. no offset.

    Args:
        env_ids: Tensor or sequence of env indices.
    """
    if isinstance(env_ids, torch.Tensor):
        env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    else:
        env_ids = torch.tensor(
            env_ids, device=self.device, dtype=torch.long
        )
    env_ids = self._filter_env_ids_for_motion_task(env_ids)
    if env_ids.numel() == 0:
        return

    ref_pos = self.get_ref_motion_root_global_pos_cur().clone()
    # Reference quaternion is xyzw; reorder to wxyz for the sim root state.
    ref_quat = self.get_ref_motion_root_global_rot_quat_xyzw_cur()[
        ..., [3, 0, 1, 2]
    ].clone()
    ref_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone()
    ref_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone()
    axes = ("x", "y", "z", "roll", "pitch", "yaw")
    n_envs = len(env_ids)

    def draw_offsets(range_cfg):
        # One uniform sample per (env, axis) between that axis' bounds.
        bounds = torch.tensor(
            [range_cfg.get(axis, (0.0, 0.0)) for axis in axes],
            device=self.device,
        )
        return isaaclab_math.sample_uniform(
            bounds[:, 0],
            bounds[:, 1],
            (n_envs, 6),
            device=self.device,
        )

    pose_offsets = draw_offsets(self.cfg.root_pose_perturb_range)
    ref_pos[env_ids] += pose_offsets[:, 0:3]
    ref_quat[env_ids] = isaaclab_math.quat_mul(
        isaaclab_math.quat_from_euler_xyz(
            pose_offsets[:, 3],
            pose_offsets[:, 4],
            pose_offsets[:, 5],
        ),
        ref_quat[env_ids],
    )

    vel_offsets = draw_offsets(self.cfg.root_vel_perturb_range)
    ref_lin_vel[env_ids] += vel_offsets[:, :3]
    ref_ang_vel[env_ids] += vel_offsets[:, 3:]

    root_state = torch.cat(
        [part[env_ids] for part in (ref_pos, ref_quat, ref_lin_vel, ref_ang_vel)],
        dim=-1,
    )
    self.robot.write_root_state_to_sim(root_state, env_ids=env_ids)
def _align_dof_to_ref(self, env_ids):
    """Write the current reference joint state, with position noise, to the sim.

    Only envs running the motion-tracking task are written. Uniform noise
    from ``cfg.dof_pos_perturb_range`` is added to the reference joint
    positions, which are then clipped to the robot's soft joint limits.
    Reference joint velocities are written unchanged.
    NOTE(review): ``cfg.dof_vel_perturb_range`` is not read here — confirm
    that unperturbed velocities are intended.

    Args:
        env_ids: Tensor or sequence of env indices.
    """
    if isinstance(env_ids, torch.Tensor):
        env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    else:
        env_ids = torch.tensor(
            env_ids, device=self.device, dtype=torch.long
        )
    env_ids = self._filter_env_ids_for_motion_task(env_ids)
    if env_ids.numel() == 0:
        return

    ref_dof_pos = self.get_ref_motion_dof_pos_cur().clone()
    ref_dof_vel = self.get_ref_motion_dof_vel_cur().clone()
    # Noise is drawn for every env; only the env_ids rows are written below.
    ref_dof_pos += isaaclab_math.sample_uniform(
        *self.cfg.dof_pos_perturb_range,
        ref_dof_pos.shape,
        ref_dof_pos.device,
    )
    limits = self.robot.data.soft_joint_pos_limits[env_ids]
    lower, upper = limits[:, :, 0], limits[:, :, 1]
    ref_dof_pos[env_ids] = torch.clip(ref_dof_pos[env_ids], lower, upper)
    self.robot.write_joint_state_to_sim(
        ref_dof_pos[env_ids],
        ref_dof_vel[env_ids],
        env_ids=env_ids,
    )
def force_realign_root_state_to_ref_no_perturb(self, env_ids) -> None:
    """Write the current reference root state to the sim without any noise.

    Only envs running the motion-tracking task are written. The reference
    quaternion (xyzw) is reordered to wxyz before being concatenated into
    the ``[pos, quat, lin_vel, ang_vel]`` root state.
    """
    if isinstance(env_ids, torch.Tensor):
        env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    else:
        env_ids = torch.tensor(
            env_ids, device=self.device, dtype=torch.long
        )
    env_ids = self._filter_env_ids_for_motion_task(env_ids)
    if env_ids.numel() == 0:
        return
    ref_pos = self.get_ref_motion_root_global_pos_cur().clone()
    ref_quat_wxyz = self.get_ref_motion_root_global_rot_quat_xyzw_cur()[
        ..., [3, 0, 1, 2]
    ].clone()
    ref_lin_vel = self.get_ref_motion_root_global_lin_vel_cur().clone()
    ref_ang_vel = self.get_ref_motion_root_global_ang_vel_cur().clone()
    parts = (ref_pos, ref_quat_wxyz, ref_lin_vel, ref_ang_vel)
    root_state = torch.cat([p[env_ids] for p in parts], dim=-1)
    self.robot.write_root_state_to_sim(root_state, env_ids=env_ids)
def force_realign_dof_state_to_ref_no_perturb(self, env_ids) -> None:
    """Write the current reference joint state to the sim without noise.

    Only envs running the motion-tracking task are written. Joint positions
    are clipped to the robot's soft joint limits; velocities are unchanged.
    """
    if isinstance(env_ids, torch.Tensor):
        env_ids = env_ids.to(device=self.device, dtype=torch.long).view(-1)
    else:
        env_ids = torch.tensor(
            env_ids, device=self.device, dtype=torch.long
        )
    env_ids = self._filter_env_ids_for_motion_task(env_ids)
    if env_ids.numel() == 0:
        return
    ref_dof_pos = self.get_ref_motion_dof_pos_cur().clone()
    ref_dof_vel = self.get_ref_motion_dof_vel_cur().clone()
    limits = self.robot.data.soft_joint_pos_limits[env_ids]
    ref_dof_pos[env_ids] = torch.clip(
        ref_dof_pos[env_ids], limits[:, :, 0], limits[:, :, 1]
    )
    self.robot.write_joint_state_to_sim(
        ref_dof_pos[env_ids],
        ref_dof_vel[env_ids],
        env_ids=env_ids,
    )
def force_realign_offline_eval_no_perturb(self, env_ids) -> None:
    """Snap root and then joint state to the reference, noise-free (offline eval)."""
    for realign in (
        self.force_realign_root_state_to_ref_no_perturb,
        self.force_realign_dof_state_to_ref_no_perturb,
    ):
        realign(env_ids)
def _update_command(self):
    """Advance reference playback by one step for motion-tracking envs.

    Frame indices advance for motion-tracking envs whose
    ``episode_length_buf`` entry is non-zero (all of them if the env has no
    such buffer). Zero entries are presumably envs reset this step — TODO
    confirm. A cache swap is flagged once ``_swap_step_counter`` reaches
    ``_motion_cache.swap_interval_steps``. Finished clips are then
    resampled and the reference state is refreshed from the cache.
    """
    motion_ids = self._filter_env_ids_for_motion_task(
        torch.arange(self.num_envs, dtype=torch.long, device=self.device)
    )
    if motion_ids.numel() == 0:
        return
    episode_length_buf = getattr(self._env, "episode_length_buf", None)
    if episode_length_buf is None:
        advancing_ids = motion_ids
    else:
        advancing_ids = motion_ids[episode_length_buf[motion_ids] != 0]
    if advancing_ids.numel() > 0:
        self._frame_indices[advancing_ids] += 1
    self._swap_step_counter += 1
    if self._swap_step_counter >= self._motion_cache.swap_interval_steps:
        self._swap_pending = True
    # Resample when motion ends
    self._resample_when_motion_end_cache()
    self._update_ref_motion_state_from_cache()
def _resample_when_motion_end_cache(self):
    """Resample environments when motion ends (simple cache mode).

    A motion-tracking env needs a new clip once its frame index passes the
    last frame that still has ``cfg.n_fut_frames`` future frames available.
    Such envs get a fresh (clip, start frame) from the motion cache and have
    their per-assignment accumulators reset. Their reference state is
    refreshed, and the robot is realigned to it.

    ``_motion_end_mask`` is a per-step flag. It is cleared for every
    motion-tracking env on every call, including calls where nothing ends,
    so a ``True`` from an earlier step never leaks into later steps. It is
    then set only for envs resampled in this call, whose
    ``motion_end_counter`` is also incremented.
    """
    all_ids = torch.arange(
        self.num_envs, dtype=torch.long, device=self.device
    )
    motion_ids = self._filter_env_ids_for_motion_task(all_ids)
    if motion_ids.numel() == 0:
        return
    lengths = self._motion_cache.lengths_for_indices(self._clip_indices)
    # Last start frame that still leaves n_fut_frames future frames.
    max_valid_frame = torch.clamp(
        lengths - 1 - self.cfg.n_fut_frames, min=0
    )
    need_resample = (
        self._frame_indices[motion_ids] > max_valid_frame[motion_ids]
    )
    if not torch.any(need_resample):
        # Nothing ended this step: drop flags left over from earlier steps.
        self._motion_end_mask[motion_ids] = False
        return
    resample_ids = motion_ids[need_resample]
    # Resample these envs
    self._record_completion_rate_for_envs(resample_ids)
    clip_idx, frame_idx = self._motion_cache.sample_env_assignments(
        len(resample_ids),
        self.cfg.n_fut_frames,
        self.device,
        deterministic_start=self._is_evaluating,
    )
    self._clip_indices[resample_ids] = clip_idx
    self._frame_indices[resample_ids] = frame_idx
    self._start_frame_indices[resample_ids] = frame_idx
    self._reward_sum_since_assign[resample_ids] = 0.0
    self._step_count_since_assign[resample_ids] = 0.0
    # Realign robot state
    self._update_ref_motion_state_from_cache(env_ids=resample_ids)
    self._align_root_to_ref(resample_ids)
    self._align_dof_to_ref(resample_ids)
    # Mark motion end
    self._motion_end_mask[motion_ids] = False
    self._motion_end_mask[resample_ids] = True
    self.motion_end_counter[resample_ids] += 1
def _update_metrics(self):
    """Refresh MPJPE and MPKPE progress metrics, creating ``self.metrics`` if absent."""
    if not hasattr(self, "metrics"):
        self.metrics = {}
    for refresh in (self._update_mpjpe_metrics, self._update_mpkpe_metrics):
        refresh()
def _update_mpjpe_metrics(self):
    """Update MPJPE (Mean Per Joint Position Error) metrics.

    Computes per-env mean absolute joint-position error between the robot's
    current ``joint_pos`` and the immediate-next reference DoF positions.
    The error is reported for the whole body and for the arm / torso / leg
    joint index groups. Metric buffers are created lazily and then updated
    in place.
    """
    dof_err = torch.abs(
        self.robot.data.joint_pos - self.get_ref_motion_dof_pos_immediate_next()
    )  # [B, num_dofs]
    per_group_err = (
        ("Task/MPJPE_WholeBody", dof_err),
        ("Task/MPJPE_Arms", dof_err[:, self.arm_dof_indices]),
        ("Task/MPJPE_Waist", dof_err[:, self.torso_dof_indices]),
        ("Task/MPJPE_Legs", dof_err[:, self.leg_dof_indices]),
    )
    # Reduce everything first so a bad index group fails before any write.
    means = {name: torch.mean(err, dim=-1) for name, err in per_group_err}
    for name, value in means.items():
        if name not in self.metrics:
            self.metrics[name] = torch.zeros(
                self.num_envs, device=self.device
            )
        self.metrics[name][:] = value
def _update_mpkpe_metrics(self):
    """Update MPKPE (Mean Per Keybody Position Error) metrics.

    Computes per-env mean L2 distance between the robot's current world
    body positions and the immediate-next reference body positions. The
    distance is reported for the whole body and for the arm / torso / leg
    body index groups. Metric buffers are created lazily and then updated
    in place.
    """
    body_err = torch.norm(
        self.robot.data.body_pos_w
        - self.get_ref_motion_bodylink_global_pos_immediate_next(),
        dim=-1,
    )  # [B, num_bodies]
    per_group_err = (
        ("Task/MPKPE_WholeBody", body_err),
        ("Task/MPKPE_Arms", body_err[:, self.arm_body_indices]),
        ("Task/MPKPE_Waist", body_err[:, self.torso_body_indices]),
        ("Task/MPKPE_Legs", body_err[:, self.leg_body_indices]),
    )
    # Reduce everything first so a bad index group fails before any write.
    means = {name: torch.mean(err, dim=-1) for name, err in per_group_err}
    for name, value in means.items():
        if name not in self.metrics:
            self.metrics[name] = torch.zeros(
                self.num_envs, device=self.device
            )
        self.metrics[name][:] = value
# --- Pose-error getters for curriculum (WholeBody only) ---
def get_wholebody_mpjpe(
    self,
) -> torch.Tensor:
    """[B] latest whole-body MPJPE (mean absolute joint-position error).

    Returns the live ``"Task/MPJPE_WholeBody"`` metric buffer, or zeros if
    metrics have not been computed yet.
    """
    if hasattr(self, "metrics") and "Task/MPJPE_WholeBody" in self.metrics:
        return self.metrics["Task/MPJPE_WholeBody"]
    return torch.zeros(self.num_envs, device=self.device)
def get_wholebody_mpkpe(
    self,
) -> torch.Tensor:
    """[B] latest whole-body MPKPE (mean body-position L2 error).

    Returns the live ``"Task/MPKPE_WholeBody"`` metric buffer, or zeros if
    metrics have not been computed yet.
    """
    if hasattr(self, "metrics") and "Task/MPKPE_WholeBody" in self.metrics:
        return self.metrics["Task/MPKPE_WholeBody"]
    return torch.zeros(self.num_envs, device=self.device)
def get_current_motion_keys(
    self,
) -> list[str]:
    """Motion window keys for each env's currently assigned cached clip.

    Best effort: returns ``[]`` when the motion cache is missing, lacks
    ``motion_keys_for_indices``, or the lookup raises.
    """
    try:
        cache = getattr(self, "_motion_cache", None)
        lookup = getattr(cache, "motion_keys_for_indices", None)
        if lookup is not None:
            return lookup(self._clip_indices)
    except Exception:
        # Deliberately swallowed: callers treat keys as optional metadata.
        pass
    return []
def _set_debug_vis_impl(self, debug_vis: bool):
    """Enable or disable reference-body debug markers.

    Stores the flag in ``_debug_vis_enabled``; the markers themselves are
    created lazily by ``_debug_vis_callback``. Any markers that already
    exist are shown or hidden to match.
    """
    enabled = bool(debug_vis)
    self._debug_vis_enabled = enabled
    if hasattr(self, "ref_body_visualizers"):
        for marker in self.ref_body_visualizers:
            marker.set_visibility(enabled)
def setup_offline_eval_from_frame_zero(self):
    """Point every env at frame 0 of its current clip for offline evaluation.

    Clip assignments are left untouched. Only ``_frame_indices`` is zeroed
    before the reference state is recomputed.
    NOTE(review): this calls ``_update_ref_motion_state`` while the other
    paths here use ``_update_ref_motion_state_from_cache`` — confirm the
    former still exists.
    """
    self._frame_indices.zero_()
    self._update_ref_motion_state()
    logger.info(
        f"Offline evaluation setup complete: all {self.num_envs} "
        f"environments set to frame 0 references"
    )
def setup_offline_eval_deterministic(
    self, apply_pending_swap: bool = True
) -> None:
    """Deterministic multi-env setup for offline evaluation.

    - If requested and a cache swap is pending, advance the motion cache and
      reset the swap bookkeeping.
    - Map env ``i`` to cache row ``i`` for the first
      ``min(num_envs, clip_count)`` envs. Remaining envs get clip 0. All
      envs start at frame 0.
    - Refresh the reference state only; realigning the robot is left to the
      caller.
    """
    if apply_pending_swap and getattr(self, "_swap_pending", False):
        self._motion_cache.advance()
        self._swap_pending = False
        self._swap_step_counter = 0
    n_clips = int(self._motion_cache.clip_count)
    n_active = min(int(self.num_envs), n_clips)
    self._clip_indices[:] = 0
    self._frame_indices[:] = 0
    if n_active > 0:
        self._clip_indices[:n_active] = torch.arange(
            n_active, dtype=torch.long, device=self.device
        )
    self._update_ref_motion_state_from_cache()
def _debug_vis_callback(self, event):
    """Draw one marker per reference body link at its current position.

    Does nothing unless the robot is initialized, debug vis is enabled, and
    the motion cache plus clip/frame assignments exist. Markers are created
    lazily on first use, one ``VisualizationMarkers`` per reference body,
    from ``cfg.body_keypoint_visualizer_cfg``.
    """
    if not self.robot.is_initialized:
        return
    if not getattr(self, "_debug_vis_enabled", False):
        return
    cache_ready = (
        hasattr(self, "_motion_cache")
        and self._motion_cache is not None
        and hasattr(self, "_clip_indices")
        and hasattr(self, "_frame_indices")
    )
    if not cache_ready:
        return
    if not hasattr(self, "ref_body_visualizers"):
        self.ref_body_visualizers = []
        n_links = self.get_ref_motion_bodylink_global_pos_cur().shape[-2]
        # extend() appends one marker at a time, like the original loop.
        self.ref_body_visualizers.extend(
            VisualizationMarkers(
                self.cfg.body_keypoint_visualizer_cfg.replace(
                    prim_path=f"/Visuals/Command/ref_body_{i}"
                )
            )
            for i in range(n_links)
        )
    markers = self.ref_body_visualizers
    if len(markers) == 0:
        return
    ref_body_pos = self.get_ref_motion_bodylink_global_pos_cur()  # [B, N, 3]
    n_drawn = min(len(markers), ref_body_pos.shape[1])
    for i in range(n_drawn):
        markers[i].visualize(ref_body_pos[:, i])  # [B, 3]
@configclass
class MotionCommandCfg(CommandTermCfg):
    """Configuration for the motion command.

    Term config whose ``class_type`` is ``RefMotionCommand``. Fields left
    at ``MISSING`` must be supplied by the caller, e.g. through the
    ``params`` dict consumed by ``build_motion_tracking_commands_config``.
    """

    class_type: type = RefMotionCommand
    command_obs_name: str = MISSING
    # DoF / body names of the reference data; presumably used to build the
    # URDF <-> sim index maps (TODO confirm against RefMotionCommand.__init__).
    urdf_dof_names: list[str] = MISSING
    urdf_body_names: list[str] = MISSING
    # DOF name groupings for mpjpe metrics (using unified naming)
    arm_dof_names: list[str] = MISSING
    waist_dof_names: list[str] = MISSING
    leg_dof_names: list[str] = MISSING
    # Body name groupings for mpkpe metrics (using unified naming)
    arm_body_names: list[str] = MISSING
    torso_body_names: list[str] = MISSING
    leg_body_names: list[str] = MISSING
    motion_lib_cfg: dict = MISSING
    seed: int = MISSING
    process_id: int = MISSING
    num_processes: int = MISSING
    is_evaluating: bool = MISSING
    resample_time_interval_s: float = MISSING
    # Number of future reference frames (after the current one) used by
    # the *_fut observations and by the end-of-clip resampling check.
    n_fut_frames: int = MISSING
    target_fps: int = MISSING
    anchor_bodylink_name: str = "pelvis"
    asset_name: str = MISSING
    debug_vis: bool = False
    # Per-axis (low, high) uniform ranges keyed by "x", "y", "z", "roll",
    # "pitch", "yaw"; axes absent from the dict get no perturbation.
    root_pose_perturb_range: dict[str, tuple[float, float]] = {}
    root_vel_perturb_range: dict[str, tuple[float, float]] = {}
    # (low, high) uniform noise added to reference DoF positions on realign.
    dof_pos_perturb_range: tuple[float, float] = (-0.1, 0.1)
    # NOTE(review): not read by the DoF alignment methods of
    # RefMotionCommand visible alongside this config — confirm it is used.
    dof_vel_perturb_range: tuple[float, float] = (-1.0, 1.0)
    # Marker config for the per-body reference keypoints drawn in debug vis.
    body_keypoint_visualizer_cfg: VisualizationMarkersCfg = (
        SPHERE_MARKER_CFG.replace(prim_path="/Visuals/Command/ref_keypoint")
    )
    body_keypoint_visualizer_cfg.markers["sphere"].radius = 0.03
    body_keypoint_visualizer_cfg.markers[
        "sphere"
    ].visual_material = PreviewSurfaceCfg(
        diffuse_color=(0.0, 0.0, 1.0)  # blue
    )
    resampling_time_range: tuple[float, float] = (1.0, 1.0)
@configclass
class MoTrack_CommandsCfg:
    """Empty commands container; command terms are attached with ``setattr``."""
def build_motion_tracking_commands_config(command_config_dict: dict):
    """Build isaaclab-compatible CommandsCfg from a config dictionary.

    Args:
        command_config_dict: Mapping from command name to a spec with an
            optional ``"type"`` (default ``"MotionCommandCfg"``, currently
            the only supported type) and optional ``"params"`` (keyword
            arguments for that config class). Example::

                command_config_dict = {
                    "ref_motion": {
                        "type": "MotionCommandCfg",
                        "params": {
                            "command_obs_name": "bydmmc_ref_motion",
                            "motion_lib_cfg": {...},
                            "process_id": 0,
                            "num_processes": 1,
                            # ... other parameters
                        }
                    }
                }

    Returns:
        A ``MoTrack_CommandsCfg`` with one attribute per command name.

    Raises:
        ValueError: if a spec names an unknown command type.
    """
    cfg_classes = {"MotionCommandCfg": MotionCommandCfg}
    commands_cfg = MoTrack_CommandsCfg()
    for command_name, command_config in command_config_dict.items():
        command_type = command_config.get("type", "MotionCommandCfg")
        command_params = command_config.get("params", {})
        cfg_cls = cfg_classes.get(command_type)
        if cfg_cls is None:
            raise ValueError(f"Unknown command type: {command_type}")
        setattr(commands_cfg, command_name, cfg_cls(**command_params))
    return commands_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_observation.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import isaaclab.envs.mdp as isaaclab_mdp
import isaaclab.sim as sim_utils
from dataclasses import fields as dataclass_fields
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.assets import Articulation, ArticulationCfg, AssetBaseCfg
from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg
from isaaclab.managers import (
ActionTermCfg,
CommandTerm,
CommandTermCfg,
EventTermCfg as EventTerm,
ObservationGroupCfg,
ObservationGroupCfg as ObsGroup,
ObservationTermCfg,
ObservationTermCfg as ObsTerm,
RewardTermCfg,
SceneEntityCfg,
TerminationTermCfg,
)
import torch
from isaaclab.markers import (
VisualizationMarkers,
VisualizationMarkersCfg,
)
from isaaclab.markers.config import FRAME_MARKER_CFG
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns
from isaaclab.sim import PhysxCfg, SimulationCfg
from isaaclab.terrains import TerrainImporterCfg
from isaaclab.utils import configclass
import isaaclab.utils.math as isaaclab_math
import isaaclab.utils.noise as isaaclab_noise
from omegaconf import DictConfig, ListConfig, OmegaConf
from holomotion.src.env.isaaclab_components.isaaclab_utils import (
resolve_holo_config,
)
from holomotion.src.utils.frame_utils import (
positions_world_to_env_frame,
root_relative_positions_from_env_frame,
)
def _build_noise_cfg(noise_cfg):
    """Turn an observation-noise spec into an IsaacLab noise config object.

    Args:
        noise_cfg: Any value (resolved with ``resolve_holo_config``). Values
            other than a dict with a ``"type"`` key are returned as resolved.
            For such a dict, ``"type"`` names a class in
            ``isaaclab.utils.noise``. ``"params"`` holds its keyword
            arguments and may be omitted or ``None`` (meaning no arguments).
            If ``"n_min_z"`` and/or ``"n_max_z"`` are present, ``n_min`` and
            ``n_max`` become float32 tensors ``[base, base, z]`` so the third
            component can use a separate bound (z defaults to the base).

    Returns:
        The resolved value, or an instance of the requested noise class.

    Raises:
        AttributeError: if ``"type"`` is not a class in ``isaaclab.utils.noise``.
        KeyError: if ``n_min_z``/``n_max_z`` is given without both ``n_min``
            and ``n_max``.
    """
    noise_cfg = resolve_holo_config(noise_cfg)
    if not (isinstance(noise_cfg, dict) and "type" in noise_cfg):
        return noise_cfg
    noise_cls = getattr(isaaclab_noise, noise_cfg["type"])
    raw_params = noise_cfg.get("params", {})
    # ``params: null`` in a config file means "no keyword arguments";
    # previously this crashed with ``**None``.
    noise_params = {} if raw_params is None else resolve_holo_config(raw_params)
    if noise_params is None:
        noise_params = {}
    if not isinstance(noise_params, dict):
        return noise_cls(**noise_params)
    noise_params = dict(noise_params)
    if "n_min_z" in noise_params or "n_max_z" in noise_params:
        missing = [k for k in ("n_min", "n_max") if k not in noise_params]
        if missing:
            raise KeyError(
                f"n_min_z/n_max_z require base bounds; missing {missing}"
            )
        base_n_min = noise_params["n_min"]
        base_n_max = noise_params["n_max"]
        noise_params["n_min"] = torch.tensor(
            [base_n_min, base_n_min, noise_params.pop("n_min_z", base_n_min)],
            dtype=torch.float32,
        )
        noise_params["n_max"] = torch.tensor(
            [base_n_max, base_n_max, noise_params.pop("n_max_z", base_n_max)],
            dtype=torch.float32,
        )
    return noise_cls(**noise_params)
class MirrorFunctions:
"""Generic observation mirroring utilities."""
@staticmethod
def mirror_dof(
    x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
) -> torch.Tensor:
    """Mirror a DOF-aligned tensor ``[..., A]`` with a permutation and signs.

    Output channel ``i`` is ``x[..., perm[i]]`` times ``sign`` (broadcast
    over the last dim). ``perm`` is moved to ``x``'s device as ``long`` and
    ``sign`` to ``x``'s device and dtype.

    Raises:
        ValueError: if the last dim of ``x`` differs from ``perm.numel()``.
    """
    if x.shape[-1] != int(perm.numel()):
        raise ValueError(
            f"mirror_dof expected last dim {perm.numel()}, got {x.shape[-1]}"
        )
    # .to() returns the same tensor when nothing needs converting.
    index = perm.to(device=x.device, dtype=torch.long)
    signs = sign.to(device=x.device, dtype=x.dtype)
    gathered = x.index_select(x.dim() - 1, index)
    broadcast_shape = (1,) * (gathered.dim() - 1) + (signs.numel(),)
    return gathered * signs.view(broadcast_shape)
@staticmethod
def mirror_action(
actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
) -> torch.Tensor:
"""Mirror action tensor [..., A] in DOF space with permutation and sign."""
return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign)
@staticmethod
def mirror_vec3(x: torch.Tensor) -> torch.Tensor:
"""Mirror a true vector [..., 3] with sign [1, -1, 1]."""
if x.shape[-1] != 3:
raise ValueError(
f"mirror_vec3 expected last dim 3, got {x.shape[-1]}"
)
sign = torch.tensor(
[1.0, -1.0, 1.0], device=x.device, dtype=x.dtype
).view(*([1] * (x.ndim - 1)), 3)
return x * sign
@staticmethod
def mirror_axial_vec3(x: torch.Tensor) -> torch.Tensor:
"""Mirror an axial vector [..., 3] with sign [-1, 1, -1]."""
if x.shape[-1] != 3:
raise ValueError(
f"mirror_axial_vec3 expected last dim 3, got {x.shape[-1]}"
)
sign = torch.tensor(
[-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype
).view(*([1] * (x.ndim - 1)), 3)
return x * sign
@staticmethod
def mirror_velocity_command(x: torch.Tensor) -> torch.Tensor:
"""Mirror velocity command [..., 3] or [..., 4] preserving move_mask."""
last_dim = x.shape[-1]
if last_dim == 3:
sign = torch.tensor(
[1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
).view(*([1] * (x.ndim - 1)), 3)
return x * sign
if last_dim == 4:
sign = torch.tensor(
[1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
).view(*([1] * (x.ndim - 1)), 4)
return x * sign
raise ValueError(
f"mirror_velocity_command expected last dim 3 or 4, got {last_dim}"
)
class ObservationFunctions:
"""Atomic observation functions.
The most fundamental observation functions are defined here, aiming to
utilize the convenient functions from the IsaacLab APIs. For complex observation
composition patterns, we'll use the custom observation serializer.
"""
@staticmethod
def _get_body_indices(
robot: Articulation, keybody_names: list[str] | None
) -> list[int]:
"""Convert body names to indices.
Args:
robot: Robot articulation asset
keybody_names: List of body names. If None, returns all body indices.
Returns:
List of body indices corresponding to the given names
"""
if keybody_names is None:
return list(range(robot.num_bodies))
body_indices = []
for name in keybody_names:
if name not in robot.body_names:
raise ValueError(
f"Body '{name}' not found in robot.body_names: {robot.body_names}"
)
body_indices.append(robot.body_names.index(name))
return body_indices
@staticmethod
def _slice_future_frames(
tensor: torch.Tensor,
*,
num_frames: int | None,
obs_name: str,
) -> torch.Tensor:
if num_frames is None:
return tensor
num_frames = int(num_frames)
if num_frames <= 0:
raise ValueError(
f"{obs_name} num_frames must be positive, got {num_frames}."
)
if tensor.ndim < 2:
raise ValueError(
f"{obs_name} expected future tensor with ndim >= 2, got {tensor.ndim}."
)
if int(tensor.shape[1]) < num_frames:
raise ValueError(
f"{obs_name} requested {num_frames} future frames, but only "
f"{int(tensor.shape[1])} are available."
)
return tensor[:, :num_frames, ...]
# ------- Robot Head / mid360 States -------
@staticmethod
def _get_obs_head_pos_quat_vel(
    env: ManagerBasedRLEnv, robot_asset_name: str = "robot"
):
    """Head (mid360) features in torso frame with first-frame anchor.

    The head is modelled as a virtual sensor frame rigidly attached to
    ``torso_link`` by fixed extrinsics: translation ``rel_pos_t`` and a
    rotation of ``pitch`` radians about the torso Y axis.

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation.

    Returns:
        Tensor [B, 13]: pos(3), quat xyzw(4), lin_vel(3), ang_vel(3). All are
        first expressed in the torso frame and then re-expressed relative to
        the per-env anchor.

    Raises:
        RuntimeError: If ``body_names`` is None.
        ValueError: If ``torso_link`` is not one of the robot's bodies.

    Side effects:
        The first call creates ``env.head_anchor_set``, ``env.head_anchor_pos``
        and ``env.head_anchor_quat_wxyz``. An env's anchor is captured the
        first time its ``head_anchor_set`` flag is False. Nothing in this
        function clears the flag again, so anchors persist across episode
        resets unless other code resets them.

    NOTE(review): because the mount is rigid, ``pos_torso`` reduces (up to
    rounding) to the constant ``rel_pos_t``, and ``quat_torso_wxyz`` to the
    constant ``rel_quat_wxyz``. After anchoring, ``pos_rel`` is therefore ~0
    and the quaternion ~identity on every step. ``lin_torso`` reduces to the
    torso-origin linear velocity in torso axes, because the omega x r term
    added to ``lin_w`` is subtracted again. Confirm this is the intended
    feature.
    """
    robot_ptr = env.scene[robot_asset_name]
    body_names = robot_ptr.body_names
    if body_names is None:
        raise RuntimeError("robot.body_names is empty")
    try:
        torso_idx = body_names.index("torso_link")
    except ValueError:
        raise ValueError(
            f"'torso_link' not found in body_names: {body_names}"
        )
    B = env.num_envs
    device = env.device
    # Mid360 extrinsics relative to torso (rotation about Y by pitch)
    rel_pos_t = torch.tensor(
        [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device
    )
    pitch = torch.tensor(
        0.04014257279586953, dtype=torch.float, device=device
    )
    half = pitch * 0.5
    # WXYZ quaternion for a rotation of `pitch` about +Y: (cos(p/2), 0, sin(p/2), 0)
    rel_quat_wxyz = torch.stack(
        [
            torch.cos(half),
            torch.zeros_like(half),
            torch.sin(half),
            torch.zeros_like(half),
        ],
        dim=-1,
    )
    rel_quat_wxyz = rel_quat_wxyz.expand(B, -1)  # [B, 4]
    # World pose/vel from torso + extrinsics (WXYZ math)
    torso_pos_w = robot_ptr.data.body_pos_w[:, torso_idx, :]
    torso_quat_wxyz = robot_ptr.data.body_quat_w[:, torso_idx, :]
    torso_lin_w = robot_ptr.data.body_lin_vel_w[:, torso_idx, :]
    torso_ang_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]
    rel_pos = rel_pos_t.expand(B, -1)  # [B, 3]
    # Mount offset rotated into world axes
    r_world = isaaclab_math.quat_apply(torso_quat_wxyz, rel_pos)
    pos_w = torso_pos_w + r_world
    quat_wxyz = isaaclab_math.quat_mul(torso_quat_wxyz, rel_quat_wxyz)
    # Rigid-body velocity transfer: v_head = v_torso + omega x r
    lin_w = torso_lin_w + torch.cross(torso_ang_w, r_world, dim=-1)
    ang_w = torso_ang_w
    # Convert to torso frame (WXYZ math)
    rel_p = pos_w - torso_pos_w
    torso_inv_wxyz = isaaclab_math.quat_inv(torso_quat_wxyz)
    pos_torso = isaaclab_math.quat_apply(torso_inv_wxyz, rel_p)
    lin_torso = isaaclab_math.quat_apply(
        torso_inv_wxyz, lin_w - torch.cross(ang_w, rel_p, dim=-1)
    )
    ang_torso = isaaclab_math.quat_apply(torso_inv_wxyz, ang_w)
    quat_torso_wxyz = isaaclab_math.quat_mul(torso_inv_wxyz, quat_wxyz)
    # export quaternion as XYZW to match common obs format
    # NOTE(review): quat_torso_xyzw is not used below; the returned quaternion
    # is the anchored one (quat_rel_xyzw).
    quat_torso_xyzw = quat_torso_wxyz[..., [1, 2, 3, 0]]
    # First-frame anchor normalization (in torso frame)
    # Lazily allocate per-env anchor buffers on the env object itself.
    if not hasattr(env, "head_anchor_set"):
        env.head_anchor_set = torch.zeros(
            B, dtype=torch.bool, device=device
        )
        env.head_anchor_pos = torch.zeros(B, 3, device=device)
        env.head_anchor_quat_wxyz = torch.zeros(B, 4, device=device)
        env.head_anchor_quat_wxyz[:, 0] = 1.0  # identity W
    # Capture the anchor only for envs that do not have one yet.
    unset = ~env.head_anchor_set
    if unset.any():
        env.head_anchor_pos[unset] = pos_torso[unset]
        env.head_anchor_quat_wxyz[unset] = quat_torso_wxyz[unset]
        env.head_anchor_set[unset] = True
    # Re-express every quantity relative to the anchor pose.
    q0_inv = isaaclab_math.quat_inv(env.head_anchor_quat_wxyz)
    pos_rel = isaaclab_math.quat_apply(
        q0_inv, pos_torso - env.head_anchor_pos
    )
    lin_rel = isaaclab_math.quat_apply(q0_inv, lin_torso)
    ang_rel = isaaclab_math.quat_apply(q0_inv, ang_torso)
    quat_rel_wxyz = isaaclab_math.quat_mul(q0_inv, quat_torso_wxyz)
    quat_rel_xyzw = quat_rel_wxyz[..., [1, 2, 3, 0]]
    # [B, 3 + 4 + 3 + 3] = [B, 13]
    return torch.cat([pos_rel, quat_rel_xyzw, lin_rel, ang_rel], dim=-1)
@staticmethod
def _get_obs_rel_headlink_lin_vel(
    env: ManagerBasedRLEnv, robot_asset_name: str = "robot"
) -> torch.Tensor:  # [num_envs, 3]
    """Headlink relative linear velocity, expressed in the headlink's frame.

    Definitions:
      - Headlink: a virtual rigid sensor frame fixed to `torso_link` by the
        extrinsics below (translation `rel_pos_t`, rotation `rel_quat_wxyz`).
      - Relative linear velocity: v_head - v_torso_origin in the world frame.
        For a rigid mount this is exactly omega_torso x r_world, which is
        computed directly.
      - Expression frame: the instantaneous headlink frame.

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation.

    Returns:
        Tensor of shape [num_envs, 3]: headlink relative linear velocity in
        headlink frame.

    Raises:
        RuntimeError: If ``body_names`` is None.
        ValueError: If ``torso_link`` is not one of the robot's bodies.
    """
    robot_ptr = env.scene[robot_asset_name]
    body_names = robot_ptr.body_names
    if body_names is None:
        raise RuntimeError("robot.body_names is empty")
    try:
        torso_idx = body_names.index("torso_link")
    except ValueError:
        raise ValueError(
            f"'torso_link' not found in body_names: {body_names}"
        ) from None
    num_envs = env.num_envs
    device = env.device
    # Headlink extrinsics relative to torso: translation + rotation about Y (pitch)
    rel_pos_t = torch.tensor(
        [0.0002835, 0.00003, 0.41618], dtype=torch.float, device=device
    )  # [3]
    pitch = torch.tensor(
        0.04014257279586953, dtype=torch.float, device=device
    )
    half = pitch * 0.5
    # Quaternion (WXYZ) for rotation about Y by 'pitch'
    rel_quat_wxyz = torch.stack(
        [
            torch.cos(half),
            torch.zeros_like(half),
            torch.sin(half),
            torch.zeros_like(half),
        ],
        dim=-1,
    ).expand(num_envs, -1)  # [num_envs, 4]
    torso_quat_wxyz = robot_ptr.data.body_quat_w[:, torso_idx, :]  # [num_envs, 4]
    torso_ang_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]  # [num_envs, 3]
    # Mount offset in world axes
    r_world = isaaclab_math.quat_apply(
        torso_quat_wxyz, rel_pos_t.expand(num_envs, -1)
    )  # [num_envs, 3]
    # Rigid mount: v_head - v_torso == omega_torso x r_world exactly. Compute
    # it directly rather than as (v_torso + omega x r) - v_torso, which does
    # redundant work and loses float32 precision when |v_torso| >> |omega x r|.
    rel_lin_w = torch.cross(torso_ang_w, r_world, dim=-1)  # [num_envs, 3]
    # Re-express in headlink frame
    head_quat_wxyz = isaaclab_math.quat_mul(
        torso_quat_wxyz, rel_quat_wxyz
    )  # [num_envs, 4]
    head_inv_wxyz = isaaclab_math.quat_inv(head_quat_wxyz)  # [num_envs, 4]
    return isaaclab_math.quat_apply(head_inv_wxyz, rel_lin_w)  # [num_envs, 3]
@staticmethod
def _get_obs_rel_headlink_ang_vel(
    env: ManagerBasedRLEnv, robot_asset_name: str = "robot"
) -> torch.Tensor:  # [num_envs, 3]
    """Angular velocity of the headlink relative to the torso, in headlink axes.

    The headlink is a virtual sensor frame rigidly fixed to `torso_link` by a
    rotation of `pitch` radians about the torso Y axis. The relative angular
    velocity is omega_head - omega_torso, taken in world axes and then
    rotated into the headlink frame.

    With no neck joint, omega_head is omega_torso, so the result is zero. To
    support an articulated head, replace `head_ang_w` with that link's world
    angular velocity.

    Returns:
        Tensor [num_envs, 3].
    """
    robot_ptr = env.scene[robot_asset_name]
    names = robot_ptr.body_names
    if names is None:
        raise RuntimeError("robot.body_names is empty")
    torso_idx = names.index("torso_link")
    n_envs = env.num_envs

    # Mount rotation: (cos(p/2), 0, sin(p/2), 0) in WXYZ, one copy per env.
    half_pitch = (
        torch.tensor(0.04014257279586953, dtype=torch.float, device=env.device)
        * 0.5
    )
    mount_quat_wxyz = torch.stack(
        [
            torch.cos(half_pitch),
            torch.zeros_like(half_pitch),
            torch.sin(half_pitch),
            torch.zeros_like(half_pitch),
        ],
        dim=-1,
    ).expand(n_envs, -1)

    torso_quat_wxyz = robot_ptr.data.body_quat_w[:, torso_idx, :]
    torso_ang_w = robot_ptr.data.body_ang_vel_w[:, torso_idx, :]

    head_ang_w = torso_ang_w  # rigid mount
    world_to_head = isaaclab_math.quat_inv(
        isaaclab_math.quat_mul(torso_quat_wxyz, mount_quat_wxyz)
    )
    return isaaclab_math.quat_apply(world_to_head, head_ang_w - torso_ang_w)
# ------- Robot Root States -------
@staticmethod
def _get_obs_global_robot_root_pos(env: ManagerBasedRLEnv):
    """Root position expressed relative to each environment's origin.

    Delegates to ``isaaclab_mdp.root_pos_w``, which subtracts
    ``env.scene.env_origins``. The value is therefore env-local rather than
    the raw simulator-world position. Shape: [num_envs, 3].
    """
    root_pos_env = isaaclab_mdp.root_pos_w(env)
    return root_pos_env
@staticmethod
def _get_obs_global_robot_root_rot_wxyz(env: ManagerBasedRLEnv):
    """Root orientation as a scalar-first (w, x, y, z) quaternion, environment frame.

    Delegates to ``isaaclab_mdp.root_quat_w``. Shape: [num_envs, 4].
    """
    root_quat_wxyz = isaaclab_mdp.root_quat_w(env)
    return root_quat_wxyz
@staticmethod
def _get_obs_global_robot_root_rot_xyzw(env: ManagerBasedRLEnv):
    """Root orientation as a scalar-last (x, y, z, w) quaternion, environment frame.

    Reorders the output of ``_get_obs_global_robot_root_rot_wxyz``.
    Shape: [num_envs, 4].
    """
    quat_wxyz = ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)
    wxyz_to_xyzw = [1, 2, 3, 0]
    return quat_wxyz[..., wxyz_to_xyzw]
@staticmethod
def _get_obs_global_robot_root_rot_mat(env: ManagerBasedRLEnv):
    """Root orientation as a 6D rotation feature, environment frame.

    The 3x3 rotation matrix is truncated to its first two columns and
    flattened row-major: ``[R00, R01, R10, R11, R20, R21]``.

    Returns:
        Tensor of shape [num_envs, 6].
    """
    rot_mat = isaaclab_math.matrix_from_quat(
        ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)
    )  # [num_envs, 3, 3]
    # ``[..., :2]`` slices the last (column) axis, which yields
    # [num_envs, 3, 2]. Flatten the trailing two dims so the result has the
    # documented [num_envs, 6] shape.
    return rot_mat[..., :2].flatten(start_dim=-2)  # [num_envs, 6]
@staticmethod
def _get_obs_global_robot_root_lin_vel(env: ManagerBasedRLEnv):
    """Root linear velocity in the environment frame.

    Delegates to ``isaaclab_mdp.root_lin_vel_w``. Shape: [num_envs, 3].
    """
    root_lin_vel = isaaclab_mdp.root_lin_vel_w(env)
    return root_lin_vel
@staticmethod
def _get_obs_global_robot_root_ang_vel(env: ManagerBasedRLEnv):
    """Root angular velocity in the environment frame.

    Delegates to ``isaaclab_mdp.root_ang_vel_w``. Shape: [num_envs, 3].
    """
    root_ang_vel = isaaclab_mdp.root_ang_vel_w(env)
    return root_ang_vel
@staticmethod
def _get_obs_rel_robot_root_lin_vel(env: ManagerBasedRLEnv):
    """Root linear velocity expressed in the root (base) frame.

    Delegates to ``isaaclab_mdp.base_lin_vel``. Shape: [num_envs, 3].
    """
    base_lin_vel = isaaclab_mdp.base_lin_vel(env)
    return base_lin_vel
@staticmethod
def _get_obs_rel_robot_root_ang_vel(env: ManagerBasedRLEnv):
    """Root angular velocity expressed in the root (base) frame.

    Delegates to ``isaaclab_mdp.base_ang_vel``. Shape: [num_envs, 3].
    """
    base_ang_vel = isaaclab_mdp.base_ang_vel(env)
    return base_ang_vel
@staticmethod
def _get_obs_rel_anchor_lin_vel(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    anchor_bodylink_name: str = "torso_link",
):
    """Linear velocity of an anchor bodylink, expressed in that bodylink's frame.

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation.
        anchor_bodylink_name: Name of the anchor body (default ``torso_link``).

    Returns:
        Tensor [num_envs, 3].
    """
    anchor = [anchor_bodylink_name]
    anchor_quat_wxyz = ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
        env, robot_asset_name, anchor
    )  # [num_envs, 1, 4]
    anchor_lin_vel_w = ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(
        env, robot_asset_name, anchor
    )  # [num_envs, 1, 3]
    world_to_anchor = isaaclab_math.quat_inv(anchor_quat_wxyz)
    anchor_lin_vel_b = isaaclab_math.quat_apply(
        world_to_anchor, anchor_lin_vel_w
    )  # [num_envs, 1, 3]
    return anchor_lin_vel_b[:, 0]  # drop the single-body axis -> [num_envs, 3]
@staticmethod
def _get_obs_projected_gravity(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
) -> torch.Tensor:  # [num_envs, 3]
    """Gravity vector projected into the robot's root frame.

    Rotates the asset's world-frame gravity vector ``data.GRAVITY_VEC_W`` by
    the inverse of the same asset's root orientation.

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation. It is used for
            both the gravity vector and the root orientation.

    Returns:
        Tensor [num_envs, 3].
    """
    robot_ptr = env.scene[robot_asset_name]
    g_w: torch.Tensor = robot_ptr.data.GRAVITY_VEC_W  # [num_envs, 3]
    # Read the orientation from the same asset as the gravity vector. The
    # shared `_get_obs_global_robot_root_rot_wxyz(env)` helper takes no asset
    # name and would always query the default asset.
    root_quat_wxyz: torch.Tensor = robot_ptr.data.root_quat_w  # [num_envs, 4]
    projected_gravity: torch.Tensor = isaaclab_math.quat_apply_inverse(
        root_quat_wxyz, g_w
    )  # [num_envs, 3]
    return projected_gravity
@staticmethod
def _get_obs_global_robot_root_yaw(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
):
    """Yaw heading (radians) of the robot root in the environment frame.

    Reads ``data.heading_w`` of the scene asset ``robot_asset_name``.
    Shape: [num_envs].
    """
    robot_data = env.scene[robot_asset_name].data
    return robot_data.heading_w
# @torch.compile
@staticmethod
def _get_obs_robot_root_heading_aligned_quat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
):
    """Yaw-only quaternion (w, x, y, z) matching the robot's heading.

    Builds a rotation with zero roll and pitch and yaw equal to
    ``_get_obs_global_robot_root_yaw``.

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation.

    Returns:
        Tensor [num_envs, 4].
    """
    global_yaw = ObservationFunctions._get_obs_global_robot_root_yaw(
        env,
        robot_asset_name,
    )  # [num_envs, ]
    zero_roll = torch.zeros_like(global_yaw)
    zero_pitch = torch.zeros_like(global_yaw)
    # `quat_from_angle_axis` takes (angle, axis) and rejects roll/pitch/yaw
    # keywords. The Euler-angle constructor builds the intended yaw-only
    # rotation.
    heading_aligned_quat = isaaclab_math.quat_from_euler_xyz(
        zero_roll, zero_pitch, global_yaw
    )  # [num_envs, 4]
    return heading_aligned_quat  # [num_envs, 4]
# @torch.compile
@staticmethod
def _get_obs_rel_robot_root_roll_pitch(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
):
    """Robot's roll and pitch relative to its heading-aligned frame.

    The root orientation is expressed relative to a yaw-only frame that
    shares the robot's heading. The remaining rotation is decomposed into
    XYZ Euler angles, and each angle is wrapped to (-pi, pi].

    Args:
        env: Environment whose scene contains the robot asset.
        robot_asset_name: Scene key of the robot articulation. It is used for
            both the heading and the root orientation.

    Returns:
        Tensor [num_envs, 2] holding (roll, pitch).
    """
    robot_ptr = env.scene[robot_asset_name]
    # Same asset for heading and orientation. The shared root-rotation helper
    # takes no asset name and would always query the default asset.
    root_quat_wxyz = robot_ptr.data.root_quat_w  # [num_envs, 4]
    global_yaw = ObservationFunctions._get_obs_global_robot_root_yaw(
        env, robot_asset_name
    )  # [num_envs, ]
    zeros = torch.zeros_like(global_yaw)
    heading_aligned_quat = isaaclab_math.quat_from_euler_xyz(
        zeros, zeros, global_yaw
    )  # [num_envs, 4]
    robot_quat_in_heading_aligned_frame = isaaclab_math.quat_mul(
        isaaclab_math.quat_inv(heading_aligned_quat),
        root_quat_wxyz,
    )  # [num_envs, 4]
    # `euler_xyz_from_quat` is IsaacLab's Euler decomposition; depending on
    # the version its output range is [0, 2pi) or (-pi, pi], so wrap the
    # result explicitly for a version-independent range.
    rel_roll, rel_pitch, _ = isaaclab_math.euler_xyz_from_quat(
        robot_quat_in_heading_aligned_frame
    )
    return torch.stack(
        [isaaclab_math.wrap_to_pi(rel_roll), isaaclab_math.wrap_to_pi(rel_pitch)],
        dim=-1,
    )  # [num_envs, 2]
# ------- Robot Bodylink States -------
@staticmethod
def _get_obs_global_robot_bodylink_pos(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink positions relative to each environment's origin.

    ``data.body_pos_w`` is in simulator-world coordinates. It is shifted by
    ``env.scene.env_origins`` so it matches IsaacLab's env-frame root helpers.

    Returns:
        Tensor [num_envs, num_keybodies, 3] (all bodies when
        ``keybody_names`` is None).
    """
    robot_ptr = env.scene[robot_asset_name]
    body_ids = ObservationFunctions._get_body_indices(robot_ptr, keybody_names)
    body_pos_sim = robot_ptr.data.body_pos_w[:, body_ids]
    return positions_world_to_env_frame(body_pos_sim, env.scene.env_origins)
@staticmethod
def _get_obs_global_robot_bodylink_rot_wxyz(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink orientations as scalar-first (w, x, y, z) quaternions.

    Returns:
        Tensor [num_envs, num_keybodies, 4] taken from ``data.body_quat_w``.
    """
    robot_ptr = env.scene[robot_asset_name]
    body_ids = ObservationFunctions._get_body_indices(robot_ptr, keybody_names)
    return robot_ptr.data.body_quat_w[:, body_ids]
@staticmethod
def _get_obs_global_robot_bodylink_rot_xyzw(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink orientations as scalar-last (x, y, z, w) quaternions.

    Returns:
        Tensor [num_envs, num_keybodies, 4].
    """
    quat_wxyz = ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
        env, robot_asset_name, keybody_names
    )
    wxyz_to_xyzw = [1, 2, 3, 0]
    return quat_wxyz[..., wxyz_to_xyzw]
@staticmethod
def _get_obs_global_robot_bodylink_rot_mat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink orientations as 6D rotation features, environment frame.

    Each 3x3 rotation matrix is truncated to its first two columns and
    flattened row-major: ``[R00, R01, R10, R11, R20, R21]``.

    Returns:
        Tensor [num_envs, num_keybodies, 6].
    """
    keybody_global_rot_wxyz = (
        ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
            env,
            robot_asset_name,
            keybody_names,
        )
    )
    rot_mat = isaaclab_math.matrix_from_quat(
        keybody_global_rot_wxyz
    )  # [num_envs, num_keybodies, 3, 3]
    # `[..., :2]` keeps the first two columns ([..., 3, 2]). Flatten the
    # trailing dims to the documented 6D per-body feature.
    return rot_mat[..., :2].flatten(start_dim=-2)  # [num_envs, num_keybodies, 6]
@staticmethod
def _get_obs_global_robot_bodylink_lin_vel(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink linear velocities in the environment frame.

    Returns:
        Tensor [num_envs, num_keybodies, 3] taken from ``data.body_lin_vel_w``.
    """
    robot_ptr = env.scene[robot_asset_name]
    body_ids = ObservationFunctions._get_body_indices(robot_ptr, keybody_names)
    return robot_ptr.data.body_lin_vel_w[:, body_ids]
@staticmethod
def _get_obs_global_robot_bodylink_ang_vel(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
):
    """Bodylink angular velocities in the environment frame.

    Returns:
        Tensor [num_envs, num_keybodies, 3] taken from ``data.body_ang_vel_w``.
    """
    robot_ptr = env.scene[robot_asset_name]
    body_ids = ObservationFunctions._get_body_indices(robot_ptr, keybody_names)
    return robot_ptr.data.body_ang_vel_w[:, body_ids]
@staticmethod
def _get_obs_root_rel_robot_bodylink_pos(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 3]
    """Root-relative bodylink positions from environment-frame positions.

    Body and root state are both read from the asset ``robot_asset_name``.

    Returns:
        Tensor [num_envs, num_keybodies, 3].
    """
    robot_ptr = env.scene[robot_asset_name]
    keybody_global_pos: torch.Tensor = (
        ObservationFunctions._get_obs_global_robot_bodylink_pos(
            env, robot_asset_name, keybody_names
        )
    )  # [num_envs, num_keybodies, 3]
    # Root state must come from the same asset as the bodylinks. The shared
    # root helpers take only `env` and always query the default asset.
    global_root_pos: torch.Tensor = (
        robot_ptr.data.root_pos_w - env.scene.env_origins
    )  # [num_envs, 3], env frame
    root_global_rot_wxyz: torch.Tensor = robot_ptr.data.root_quat_w  # [num_envs, 4]
    return root_relative_positions_from_env_frame(
        body_pos_env=keybody_global_pos,
        root_pos_env=global_root_pos,
        root_quat_w=root_global_rot_wxyz,
    )
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_wxyz(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 4]
    """Bodylink orientations (w, x, y, z) relative to the robot's root frame.

    Computes ``q_root^-1 * q_body`` per body. Both quaternions are read from
    the asset ``robot_asset_name``.

    Returns:
        Tensor [num_envs, num_keybodies, 4].
    """
    robot_ptr = env.scene[robot_asset_name]
    keybody_global_rot: torch.Tensor = (
        ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
            env, robot_asset_name, keybody_names
        )
    )  # [num_envs, num_keybodies, 4]
    # Same asset as the bodylinks. The shared root helper takes only `env`
    # and always queries the default asset.
    root_global_rot_wxyz: torch.Tensor = robot_ptr.data.root_quat_w  # [num_envs, 4]
    root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(
        root_global_rot_wxyz
    )  # [num_envs, 4]
    num_bodies = keybody_global_rot.shape[1]
    # One root quaternion per body so both operands of quat_mul share a shape.
    rel_rot_root: torch.Tensor = isaaclab_math.quat_mul(
        root_inv_rot[..., None, :].expand(-1, num_bodies, -1),
        keybody_global_rot,
    )  # [num_envs, num_keybodies, 4]
    return rel_rot_root
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_xyzw(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 4]
    """Root-relative bodylink orientations as scalar-last (x, y, z, w) quaternions."""
    rel_quat_wxyz = ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(
        env, robot_asset_name, keybody_names
    )
    wxyz_to_xyzw = [1, 2, 3, 0]
    return rel_quat_wxyz[..., wxyz_to_xyzw]
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_mat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 6]
    """Root-relative bodylink orientations as 6D rotation features.

    Each root-relative 3x3 rotation matrix is truncated to its first two
    columns and flattened row-major: ``[R00, R01, R10, R11, R20, R21]``.

    Returns:
        Tensor [num_envs, num_keybodies, 6].
    """
    keybody_rel_rot_wxyz: torch.Tensor = (
        ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(
            env, robot_asset_name, keybody_names
        )
    )  # [num_envs, num_keybodies, 4]
    rot_mat = isaaclab_math.matrix_from_quat(
        keybody_rel_rot_wxyz
    )  # [num_envs, num_keybodies, 3, 3]
    # `[..., :2]` keeps the first two columns ([..., 3, 2]). Flatten the
    # trailing dims to the documented 6D per-body feature.
    return rot_mat[..., :2].flatten(start_dim=-2)  # [num_envs, num_keybodies, 6]
@staticmethod
def _get_obs_root_rel_robot_bodylink_lin_vel(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 3]
    """Bodylink linear velocities relative to the root, in the root frame.

    Computes ``R_root^T (v_body - v_root)`` per body. Body and root state are
    both read from the asset ``robot_asset_name``.

    Returns:
        Tensor [num_envs, num_keybodies, 3].
    """
    robot_ptr = env.scene[robot_asset_name]
    keybody_global_lin_vel: torch.Tensor = (
        ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(
            env, robot_asset_name, keybody_names
        )
    )  # [num_envs, num_keybodies, 3]
    # Same asset as the bodylinks. The shared root helpers take only `env`
    # and always query the default asset.
    root_global_lin_vel: torch.Tensor = robot_ptr.data.root_lin_vel_w  # [num_envs, 3]
    root_global_rot_wxyz: torch.Tensor = robot_ptr.data.root_quat_w  # [num_envs, 4]
    # Relative velocity in world frame
    rel_lin_vel_w = keybody_global_lin_vel - root_global_lin_vel.unsqueeze(1)
    num_bodies = rel_lin_vel_w.shape[1]
    # IsaacLab's quat_apply flattens its inputs to (-1, 4) / (-1, 3) and does
    # not broadcast a [N, 1, 4] quaternion over [N, K, 3] vectors. Expand to
    # one quaternion per body, as the root-relative rotation helper does for
    # quat_mul.
    root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(
        root_global_rot_wxyz
    )[:, None, :].expand(-1, num_bodies, -1)  # [num_envs, num_keybodies, 4]
    rel_lin_vel_root: torch.Tensor = isaaclab_math.quat_apply(
        root_inv_rot, rel_lin_vel_w
    )  # [num_envs, num_keybodies, 3]
    return rel_lin_vel_root
@staticmethod
def _get_obs_root_rel_robot_bodylink_ang_vel(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies, 3]
    """Bodylink angular velocities relative to the root, in the root frame.

    Computes ``R_root^T (omega_body - omega_root)`` per body. Body and root
    state are both read from the asset ``robot_asset_name``.

    Returns:
        Tensor [num_envs, num_keybodies, 3].
    """
    robot_ptr = env.scene[robot_asset_name]
    keybody_global_ang_vel: torch.Tensor = (
        ObservationFunctions._get_obs_global_robot_bodylink_ang_vel(
            env, robot_asset_name, keybody_names
        )
    )  # [num_envs, num_keybodies, 3]
    # Same asset as the bodylinks. The shared root helpers take only `env`
    # and always query the default asset.
    root_global_ang_vel: torch.Tensor = robot_ptr.data.root_ang_vel_w  # [num_envs, 3]
    root_global_rot_wxyz: torch.Tensor = robot_ptr.data.root_quat_w  # [num_envs, 4]
    # Relative angular velocity in world frame
    rel_ang_vel_w = keybody_global_ang_vel - root_global_ang_vel.unsqueeze(1)
    num_bodies = rel_ang_vel_w.shape[1]
    # IsaacLab's quat_apply flattens its inputs to (-1, 4) / (-1, 3) and does
    # not broadcast a [N, 1, 4] quaternion over [N, K, 3] vectors, so expand
    # to one quaternion per body.
    root_inv_rot: torch.Tensor = isaaclab_math.quat_inv(
        root_global_rot_wxyz
    )[:, None, :].expand(-1, num_bodies, -1)  # [num_envs, num_keybodies, 4]
    rel_ang_vel_root: torch.Tensor = isaaclab_math.quat_apply(
        root_inv_rot, rel_ang_vel_w
    )  # [num_envs, num_keybodies, 3]
    return rel_ang_vel_root
# ------- Flat Bodylink Observations -------
@staticmethod
def _get_obs_global_robot_bodylink_pos_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Environment-frame bodylink positions with body and xyz axes merged (body-major)."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_pos(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_global_robot_bodylink_rot_wxyz_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 4]
    """Environment-frame bodylink (w, x, y, z) quaternions, flattened body-major."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_global_robot_bodylink_rot_xyzw_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 4]
    """Environment-frame bodylink (x, y, z, w) quaternions, flattened body-major."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_rot_xyzw(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_global_robot_bodylink_rot_mat_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 6]
    """Environment-frame 6D bodylink rotation features, flattened body-major."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_rot_mat(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_global_robot_bodylink_lin_vel_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Environment-frame bodylink linear velocities, flattened body-major."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_lin_vel(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_global_robot_bodylink_ang_vel_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Environment-frame bodylink angular velocities, flattened body-major."""
    per_body = ObservationFunctions._get_obs_global_robot_bodylink_ang_vel(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_pos_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Root-relative bodylink positions, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_pos(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_wxyz_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 4]
    """Root-relative bodylink (w, x, y, z) quaternions, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_wxyz(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_xyzw_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 4]
    """Root-relative bodylink (x, y, z, w) quaternions, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_xyzw(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_rot_mat_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 6]
    """Root-relative 6D bodylink rotation features, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_rot_mat(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_lin_vel_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Root-relative bodylink linear velocities, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_lin_vel(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
@staticmethod
def _get_obs_root_rel_robot_bodylink_ang_vel_flat(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    keybody_names: list[str] | None = None,
) -> torch.Tensor:  # [num_envs, num_keybodies * 3]
    """Root-relative bodylink angular velocities, flattened body-major."""
    per_body = ObservationFunctions._get_obs_root_rel_robot_bodylink_ang_vel(
        env, robot_asset_name, keybody_names
    )
    num_envs = per_body.shape[0]
    return per_body.reshape(num_envs, -1)
# ------- Robot DoF States -------
@staticmethod
def _get_obs_dof_pos(env: ManagerBasedRLEnv):
    """Joint positions as offsets from the default joint angles.

    Delegates to ``isaaclab_mdp.joint_pos_rel``. Shape: [num_envs, num_dofs].
    """
    dof_pos_offset = isaaclab_mdp.joint_pos_rel(env)
    return dof_pos_offset
@staticmethod
def _get_obs_dof_vel(env: ManagerBasedRLEnv):
    """Joint velocities, as returned by ``isaaclab_mdp.joint_vel_rel``.

    Shape: [num_envs, num_dofs].
    """
    dof_vel = isaaclab_mdp.joint_vel_rel(env)
    return dof_vel
@staticmethod
def _get_obs_last_actions(env: ManagerBasedRLEnv):
    """Most recent policy action, via ``isaaclab_mdp.last_action``.

    Shape: [num_envs, num_actions].
    """
    previous_action = isaaclab_mdp.last_action(env)
    return previous_action
# ------- Reference Motion States -------
@staticmethod
def _get_obs_ref_motion_states(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Current reference-motion observation, dispatched via the command's schema.

    Calls ``_get_obs_<cfg.command_obs_name>(obs_prefix=ref_prefix)`` on the
    command term named ``ref_motion_command_name``.
    """
    ref_command = env.command_manager.get_term(ref_motion_command_name)
    getter = getattr(
        ref_command, f"_get_obs_{ref_command.cfg.command_obs_name}"
    )
    return getter(obs_prefix=ref_prefix)
@staticmethod
def _get_obs_ref_motion_states_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Future reference-motion observation, dispatched via the command's schema.

    Calls ``_get_obs_<cfg.command_obs_name>_fut(obs_prefix=ref_prefix)`` on
    the command term named ``ref_motion_command_name``.
    """
    ref_command = env.command_manager.get_term(ref_motion_command_name)
    getter = getattr(
        ref_command, f"_get_obs_{ref_command.cfg.command_obs_name}_fut"
    )
    return getter(obs_prefix=ref_prefix)
@staticmethod
def _get_obs_vr_ref_motion_states(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """VR-variant reference motion states, from the command's
    ``_get_obs_vr_ref_motion_states`` method."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    vr_states = ref_cmd._get_obs_vr_ref_motion_states(obs_prefix=ref_prefix)
    return vr_states
@staticmethod
def _get_obs_vr_ref_motion_states_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Future VR-variant reference motion states (flattened)."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    # NOTE(review): calls `_get_obs_vr_ref_motion_fut`, not
    # `_get_obs_vr_ref_motion_states_fut` (the `<name>_fut` convention used by
    # `_get_obs_ref_motion_states_fut`); confirm against the command class.
    vr_states_fut = ref_cmd._get_obs_vr_ref_motion_fut(obs_prefix=ref_prefix)
    return vr_states_fut
@staticmethod
def _get_obs_ref_dof_pos_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, num_dofs]
    """Current-frame reference joint positions, in simulator DoF order.

    ``ref_prefix`` is forwarded to the command getter as ``prefix``.
    """
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    dof_pos = ref_cmd.get_ref_motion_dof_pos_cur(prefix=ref_prefix)
    return dof_pos
@staticmethod
def _get_obs_immediate_next_two_dof_pos(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 2 * num_dofs]
    """Reference joint positions for the next two frames, in simulator DoF order."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    next_two = ref_cmd.get_immediate_next_two_dof_pos(prefix=ref_prefix)
    return next_two
@staticmethod
def _get_obs_ref_motion_cur_heading_aligned_root_pos(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Heading-aligned reference root position at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_pos = ref_cmd.get_ref_motion_cur_heading_aligned_root_pos(prefix=ref_prefix)
    return root_pos
@staticmethod
def _get_obs_ref_motion_fut_heading_aligned_root_pos(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, T, 3]
    """Heading-aligned reference root positions over the future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_pos_fut = ref_cmd.get_ref_motion_fut_heading_aligned_root_pos(prefix=ref_prefix)
    return root_pos_fut
@staticmethod
def _get_obs_ref_motion_cur_heading_aligned_root_rot6d(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 6]
    """Heading-aligned reference root orientation (6D rotation) at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_rot6d = ref_cmd.get_ref_motion_cur_heading_aligned_root_rot6d(prefix=ref_prefix)
    return root_rot6d
@staticmethod
def _get_obs_ref_motion_fut_heading_aligned_root_rot6d(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, T, 6]
    """Heading-aligned reference root orientations (6D rotation) over the future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_rot6d_fut = ref_cmd.get_ref_motion_fut_heading_aligned_root_rot6d(prefix=ref_prefix)
    return root_rot6d_fut
@staticmethod
def _get_obs_ref_motion_cur_heading_aligned_root_lin_vel(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Heading-aligned reference root linear velocity at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_lin_vel = ref_cmd.get_ref_motion_cur_heading_aligned_root_lin_vel(prefix=ref_prefix)
    return root_lin_vel
@staticmethod
def _get_obs_ref_motion_fut_heading_aligned_root_lin_vel(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, T, 3]
    """Heading-aligned reference root linear velocities over the future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_lin_vel_fut = ref_cmd.get_ref_motion_fut_heading_aligned_root_lin_vel(prefix=ref_prefix)
    return root_lin_vel_fut
@staticmethod
def _get_obs_ref_motion_cur_heading_aligned_root_ang_vel(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Heading-aligned reference root angular velocity at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_ang_vel = ref_cmd.get_ref_motion_cur_heading_aligned_root_ang_vel(prefix=ref_prefix)
    return root_ang_vel
@staticmethod
def _get_obs_ref_motion_fut_heading_aligned_root_ang_vel(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, T, 3]
    """Heading-aligned reference root angular velocities over the future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_ang_vel_fut = ref_cmd.get_ref_motion_fut_heading_aligned_root_ang_vel(prefix=ref_prefix)
    return root_ang_vel_fut
@staticmethod
def _get_obs_ref_dof_vel_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, num_dofs]
    """Current-frame reference joint velocities, in simulator DoF order."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    dof_vel = ref_cmd.get_ref_motion_dof_vel_cur(prefix=ref_prefix)
    return dof_vel
@staticmethod
def _get_obs_ref_motion_filter_cutoff_hz(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
) -> torch.Tensor:
    """Clip-level filter cutoff metadata for the current clip.

    Takes no ``ref_prefix``: the value does not depend on the prefix.
    """
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    cutoff_hz = ref_cmd.get_ref_motion_filter_cutoff_hz_cur()
    return cutoff_hz
@staticmethod
def _get_obs_ref_root_height_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 1]
    """Current reference root height above the env origin (world z minus origin z)."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    root_pos_w = ref_cmd.get_ref_motion_root_global_pos_cur(prefix=ref_prefix)  # [B, 3]
    origin_z = env.scene.env_origins[..., 2]
    return (root_pos_w[..., 2] - origin_z)[..., None]  # [B, 1]
@staticmethod
def _get_obs_ref_dof_pos_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, T, num_dofs]
    """Future reference DoF positions in simulator DoF order.

    The per-frame tensor is returned as ``[num_envs, T, num_dofs]`` and is
    NOT flattened over time. This differs from ``_get_obs_ref_dof_vel_fut``,
    which reshapes to ``[num_envs, T * num_dofs]``. The shape is kept as-is
    because downstream observation layouts depend on it.

    Args:
        env: Environment whose command manager holds the reference-motion term.
        ref_motion_command_name: Name of that command term.
        ref_prefix: Prefix forwarded to the command getter.
        num_frames: Optional horizon passed to ``_slice_future_frames``.

    Returns:
        Tensor of shape [num_envs, T, num_dofs].
    """
    command = env.command_manager.get_term(ref_motion_command_name)
    dof_pos_fut = command.get_ref_motion_dof_pos_fut(
        prefix=ref_prefix
    )  # [B, T, D(sim)]
    dof_pos_fut = ObservationFunctions._slice_future_frames(
        dof_pos_fut,
        num_frames=num_frames,
        obs_name="ref_dof_pos_fut",
    )
    return dof_pos_fut
@staticmethod
def _get_obs_ref_gravity_projection_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Reference projected-gravity vector at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    return ref_cmd.get_ref_motion_gravity_projection_cur(prefix=ref_prefix)
@staticmethod
def _get_obs_ref_gravity_projection_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, T, 3]
    """Reference projected-gravity vectors over the (optionally sliced) future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    full_horizon = ref_cmd.get_ref_motion_gravity_projection_fut(prefix=ref_prefix)
    return ObservationFunctions._slice_future_frames(
        full_horizon,
        num_frames=num_frames,
        obs_name="ref_gravity_projection_fut",
    )
@staticmethod
def _get_obs_ref_base_linvel_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Reference base linear velocity at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    return ref_cmd.get_ref_motion_base_linvel_cur(prefix=ref_prefix)
@staticmethod
def _get_obs_ref_base_linvel_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, T, 3]
    """Reference base linear velocities over the (optionally sliced) future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    full_horizon = ref_cmd.get_ref_motion_base_linvel_fut(prefix=ref_prefix)
    return ObservationFunctions._slice_future_frames(
        full_horizon,
        num_frames=num_frames,
        obs_name="ref_base_linvel_fut",
    )
@staticmethod
def _get_obs_ref_base_angvel_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:  # [num_envs, 3]
    """Reference base angular velocity at the current frame."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    return ref_cmd.get_ref_motion_base_angvel_cur(prefix=ref_prefix)
@staticmethod
def _get_obs_ref_keybody_rel_pos_cur(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    keybody_names: list[str] | None = None,
    robot_asset_name: str = "robot",
) -> torch.Tensor:
    """Reference root-relative bodylink positions at the current frame.

    Args:
        env: Environment whose command manager holds the reference-motion term.
        ref_motion_command_name: Name of that command term.
        ref_prefix: Prefix forwarded to the command getter.
        keybody_names: Bodies to keep. If None, all bodies are returned.
        robot_asset_name: Scene asset used to resolve ``keybody_names`` to
            indices (defaults to ``"robot"``, the previously hard-coded value).

    Returns:
        The shape depends on ``keybody_names``:
        - None: ``[num_envs, num_bodies, 3]``, unflattened.
        - given: ``[num_envs, num_keybodies * 3]``, flattened.
    """
    command = env.command_manager.get_term(ref_motion_command_name)
    ref_rel_pos = command.get_ref_motion_bodylink_rel_pos_cur(
        prefix=ref_prefix
    )  # [B, N, 3]
    if keybody_names is None:
        return ref_rel_pos
    keybody_idxs = ObservationFunctions._get_body_indices(
        env.scene[robot_asset_name], keybody_names
    )
    selected = ref_rel_pos[:, keybody_idxs, :]  # [B, K, 3]
    return selected.reshape(selected.shape[0], -1)  # [B, K * 3]
@staticmethod
def _get_obs_ref_keybody_rel_pos_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    keybody_names: list[str] | None = None,
    num_frames: int | None = None,
    robot_asset_name: str = "robot",
) -> torch.Tensor:
    """Future reference root-relative bodylink positions.

    Args:
        env: Environment whose command manager holds the reference-motion term.
        ref_motion_command_name: Name of that command term.
        ref_prefix: Prefix forwarded to the command getter.
        keybody_names: Bodies to keep. If None, all bodies are returned.
        num_frames: Optional horizon passed to ``_slice_future_frames``.
        robot_asset_name: Scene asset used to resolve ``keybody_names`` to
            indices (defaults to ``"robot"``, the previously hard-coded value).

    Returns:
        The shape depends on ``keybody_names``:
        - None: ``[num_envs, T, num_bodies, 3]``.
        - given: ``[num_envs, T, num_keybodies * 3]``, with bodies flattened
          per frame.
    """
    command = env.command_manager.get_term(ref_motion_command_name)
    ref_rel_pos_fut = command.get_ref_motion_bodylink_rel_pos_fut(
        prefix=ref_prefix
    )  # [B, T, N, 3]
    ref_rel_pos_fut = ObservationFunctions._slice_future_frames(
        ref_rel_pos_fut,
        num_frames=num_frames,
        obs_name="ref_keybody_rel_pos_fut",
    )
    if keybody_names is None:
        return ref_rel_pos_fut
    keybody_idxs = ObservationFunctions._get_body_indices(
        env.scene[robot_asset_name], keybody_names
    )
    selected = ref_rel_pos_fut[:, :, keybody_idxs, :]  # [B, T, K, 3]
    bs, t, _, _ = selected.shape
    return selected.reshape(bs, t, -1)  # [B, T, K * 3]
@staticmethod
def _get_obs_ref_base_angvel_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, T, 3]
    """Reference base angular velocities over the (optionally sliced) future horizon."""
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    full_horizon = ref_cmd.get_ref_motion_base_angvel_fut(prefix=ref_prefix)
    return ObservationFunctions._slice_future_frames(
        full_horizon,
        num_frames=num_frames,
        obs_name="ref_base_angvel_fut",
    )
@staticmethod
def _get_obs_ref_dof_vel_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, n_fut_frames * num_dofs]
    """Future reference DoF velocities in simulator DoF order, flattened over time.

    The ``[B, T, D]`` per-frame tensor is optionally sliced and then reshaped
    to ``[B, T * D]``.
    """
    ref_cmd = env.command_manager.get_term(ref_motion_command_name)
    per_frame = ObservationFunctions._slice_future_frames(
        ref_cmd.get_ref_motion_dof_vel_fut(prefix=ref_prefix),  # [B, T, D(sim)]
        num_frames=num_frames,
        obs_name="ref_dof_vel_fut",
    )
    num_envs, horizon, num_dofs = per_frame.shape
    return per_frame.reshape(num_envs, horizon * num_dofs)
@staticmethod
def _get_obs_ref_root_height_fut(
    env: ManagerBasedRLEnv,
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
    num_frames: int | None = None,
) -> torch.Tensor:  # [num_envs, T, 1]
    """Future reference root heights per frame: world z minus env-origin z.

    Args:
        env: Environment whose command manager holds the reference-motion term.
        ref_motion_command_name: Name of that command term.
        ref_prefix: Prefix forwarded to the command getter.
        num_frames: Optional horizon passed to ``_slice_future_frames``.

    Returns:
        Tensor of shape ``[num_envs, T, 1]``. The trailing singleton dim is
        kept, so this is not ``[num_envs, T]``.
    """
    command = env.command_manager.get_term(ref_motion_command_name)
    world_pos = command.get_ref_motion_root_global_pos_fut(
        prefix=ref_prefix
    )  # [B, T, 3]
    world_pos = ObservationFunctions._slice_future_frames(
        world_pos,
        num_frames=num_frames,
        obs_name="ref_root_height_fut",
    )
    # Broadcast each env's origin z across the time axis.
    heights = world_pos[..., 2] - env.scene.env_origins[:, None, 2]  # [B, T]
    return heights[..., None]  # [B, T, 1]
# @torch.compile
@staticmethod
def _get_obs_global_anchor_diff(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Current reference anchor pose expressed in the robot anchor frame.

    Returns:
        ``[num_envs, 9]``: the 3-D relative translation, followed by the
        first two columns of the relative rotation matrix (6 values).
    """
    pos_diff, rot_diff = ObservationFunctions._compute_global_anchor_frame_diff(
        env, robot_asset_name, ref_motion_command_name, ref_prefix
    )
    return torch.cat(
        [
            pos_diff,
            ObservationFunctions._anchor_rot_diff_to_rot6d(rot_diff, env.num_envs),
        ],
        dim=-1,
    )  # [num_envs, 9]

@staticmethod
def _get_obs_global_anchor_pos_diff(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Translation part of the reference anchor pose in the robot anchor frame.

    Returns:
        ``[num_envs, 3]``.
    """
    pos_diff, _ = ObservationFunctions._compute_global_anchor_frame_diff(
        env, robot_asset_name, ref_motion_command_name, ref_prefix
    )
    return pos_diff

@staticmethod
def _get_obs_global_anchor_rot_diff(
    env: ManagerBasedRLEnv,
    robot_asset_name: str = "robot",
    ref_motion_command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
):
    """Rotation part of the reference anchor pose in the robot anchor frame.

    Returns:
        ``[num_envs, 6]``: the first two columns of the relative rotation matrix.
    """
    _, rot_diff = ObservationFunctions._compute_global_anchor_frame_diff(
        env, robot_asset_name, ref_motion_command_name, ref_prefix
    )
    return ObservationFunctions._anchor_rot_diff_to_rot6d(rot_diff, env.num_envs)

@staticmethod
def _compute_global_anchor_frame_diff(
    env: ManagerBasedRLEnv,
    robot_asset_name: str,
    ref_motion_command_name: str,
    ref_prefix: str,
):
    """Shared helper: relative (position, wxyz quaternion) of the reference
    anchor with respect to the robot anchor, via
    ``isaaclab_math.subtract_frame_transforms``.

    The name deliberately avoids the ``_get_obs_`` prefix so that
    ``build_observations_config`` cannot resolve it as an observation term.
    """
    command = env.command_manager.get_term(ref_motion_command_name)
    # The reference anchor position is converted from world to env frame.
    # NOTE(review): this assumes `_get_obs_global_robot_bodylink_pos` returns
    # env-frame positions too; confirm, otherwise the diff is offset by
    # env origins.
    ref_anchor_pos_env = positions_world_to_env_frame(
        command.get_ref_motion_anchor_bodylink_global_pos_cur(prefix=ref_prefix),
        env.scene.env_origins,
    )
    ref_anchor_rot_wxyz = (
        command.get_ref_motion_anchor_bodylink_global_rot_wxyz_cur(
            prefix=ref_prefix
        )
    )
    robot_anchor_pos = ObservationFunctions._get_obs_global_robot_bodylink_pos(
        env, robot_asset_name, [command.anchor_bodylink_name]
    ).squeeze(1)
    robot_anchor_rot_wxyz = (
        ObservationFunctions._get_obs_global_robot_bodylink_rot_wxyz(
            env, robot_asset_name, [command.anchor_bodylink_name]
        ).squeeze(1)
    )
    return isaaclab_math.subtract_frame_transforms(
        t01=robot_anchor_pos,
        q01=robot_anchor_rot_wxyz,
        t02=ref_anchor_pos_env,
        q02=ref_anchor_rot_wxyz,
    )

@staticmethod
def _anchor_rot_diff_to_rot6d(rot_diff: torch.Tensor, num_envs: int) -> torch.Tensor:
    """Convert a wxyz quaternion to the first two rotation-matrix columns, flattened to [num_envs, 6]."""
    rot_diff_mat = isaaclab_math.matrix_from_quat(rot_diff)
    return rot_diff_mat[..., :2].reshape(num_envs, -1)
@staticmethod
def _get_obs_velocity_command(
    env: ManagerBasedRLEnv,
    command_name: str = "base_velocity",
    move_threshold: float = 0.1,
):
    """Velocity command prefixed with a binary mode (move/stand) mask.

    The command is read through ``isaaclab_mdp.generated_commands``. Some
    IsaacLab velocity commands append extra channels (e.g. heading). Only
    the first three (vx, vy, yaw_rate) are kept so the observation layout
    stays stable for velocity-tracking PPO.

    Args:
        env: The RL environment.
        command_name: Velocity command term to read (default ``"base_velocity"``).
        move_threshold: The mode mask is 1 when the L2 norm of the kept
            channels exceeds this value, otherwise 0 (default 0.1).

    Returns:
        ``[num_envs, 4]``: ``[mode_mask, vx, vy, yaw_rate]``.
    """
    velocity_command = isaaclab_mdp.generated_commands(
        env,
        command_name=command_name,
    )
    if velocity_command.shape[-1] > 3:
        velocity_command = velocity_command[..., :3]
    move_mask = (velocity_command.norm(dim=-1) > move_threshold).to(
        dtype=velocity_command.dtype
    )
    return torch.cat(
        [
            move_mask[..., None],
            velocity_command,
        ],
        dim=-1,
    )  # [num_envs, 4]
@staticmethod
def _get_obs_place_holder(env: ManagerBasedRLEnv, n_dim: int):
    """All-zero placeholder observation of shape [num_envs, n_dim] on ``env.device``."""
    shape = (env.num_envs, n_dim)
    return torch.zeros(shape, device=env.device)
@staticmethod
def _get_obs_ref_headling_aligned_vel_cmd(
    env: ManagerBasedRLEnv,
    ref_prefix: str = "ref_",
    ref_motion_command_name: str = "ref_motion",
    move_threshold: float = 0.1,
):
    """Velocity-command-style observation derived from the reference motion.

    Combines the heading-aligned reference root linear velocity (x, y) and
    angular velocity (z) into ``[vx, vy, yaw_rate]``. It then prepends a mode
    mask that is 1 when their L2 norm exceeds ``move_threshold``. The layout
    matches ``_get_obs_velocity_command``. The "headling" spelling in the name
    is kept because configs reference it.

    Args:
        env: The RL environment.
        ref_prefix: Prefix forwarded to the reference-motion getters.
        ref_motion_command_name: Reference-motion command term to read
            (default ``"ref_motion"``, previously implicit).
        move_threshold: Norm threshold for the mode mask (default 0.1).

    Returns:
        ``[num_envs, 4]``: ``[mode_mask, vx, vy, yaw_rate]``.
    """
    heading_aligned_lin_vel_xyz = (
        ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_lin_vel(
            env,
            ref_motion_command_name=ref_motion_command_name,
            ref_prefix=ref_prefix,
        )
    )
    heading_aligned_ang_vel_xyz = (
        ObservationFunctions._get_obs_ref_motion_cur_heading_aligned_root_ang_vel(
            env,
            ref_motion_command_name=ref_motion_command_name,
            ref_prefix=ref_prefix,
        )
    )
    heading_aligned_vel_cmd = torch.cat(
        [
            heading_aligned_lin_vel_xyz[:, :2],
            heading_aligned_ang_vel_xyz[:, 2:3],
        ],
        dim=-1,
    )
    move_mask = (heading_aligned_vel_cmd.norm(dim=-1) > move_threshold).to(
        dtype=heading_aligned_vel_cmd.dtype
    )
    return torch.cat(
        [
            move_mask[..., None],
            heading_aligned_vel_cmd,
        ],
        dim=-1,
    )
@staticmethod
def _get_obs_heading_aligned_root_ang_vel(env: ManagerBasedRLEnv):
    """Robot root angular velocity rotated into the yaw-only (heading) frame.

    The root orientation is reduced to its yaw component with
    ``isaaclab_math.yaw_quat``. The world-frame angular velocity is then
    rotated by that quaternion's inverse.
    """
    ang_vel_w = ObservationFunctions._get_obs_global_robot_root_ang_vel(env)
    root_quat_wxyz = ObservationFunctions._get_obs_global_robot_root_rot_wxyz(env)
    yaw_only_quat = isaaclab_math.yaw_quat(root_quat_wxyz)
    return isaaclab_math.quat_apply_inverse(yaw_only_quat, ang_vel_w)
@configclass
class ObservationsCfg:
    # Intentionally empty container. build_observations_config() instantiates
    # it and attaches one ObsGroup per configured group name via setattr().
    pass
def build_observations_config(obs_config_dict: dict):
    """Build an IsaacLab-compatible ``ObservationsCfg`` from a config mapping.

    Each top-level key becomes an ``ObsGroup`` attribute on the result.
    - Group-level keys other than ``atomic_obs_list`` are copied onto the
      group when the ``ObsGroup`` has an attribute of that name.
    - Every entry in ``atomic_obs_list`` is a ``{obs_name: params}`` mapping
      that becomes an ``ObsTerm`` on the group.

    Term functions are resolved as follows:
    - ``func`` (defaulting to the term name) is looked up first as
      ``ObservationFunctions._get_obs_<func>``, then on ``isaaclab_mdp``.
    - ``params`` and ``noise`` are resolved specially.
    - Any other ``ObservationTermCfg`` field present in the term params is
      forwarded verbatim.

    Args:
        obs_config_dict: Plain dict or OmegaConf container of groups.

    Returns:
        The populated ``ObservationsCfg``.

    Raises:
        ValueError: If a term's function cannot be resolved.
        KeyError: If a group has no ``atomic_obs_list``.
    """
    if isinstance(obs_config_dict, (DictConfig, ListConfig)):
        obs_config_dict = OmegaConf.to_container(obs_config_dict, resolve=True)
    obs_cfg = ObservationsCfg()
    obs_term_field_names = {
        field.name for field in dataclass_fields(ObservationTermCfg)
    }
    for group_name, group_cfg in obs_config_dict.items():
        group_cfg = resolve_holo_config(group_cfg)
        isaaclab_obs_group_cfg = ObsGroup()
        for key, value in group_cfg.items():
            if key == "atomic_obs_list":
                continue
            if hasattr(isaaclab_obs_group_cfg, key):
                setattr(isaaclab_obs_group_cfg, key, value)
        for obs_term_dict in group_cfg["atomic_obs_list"]:
            for obs_name, obs_params in obs_term_dict.items():
                # A bare YAML entry like `- base_ang_vel:` yields None; treat it
                # as "no options" rather than crashing on `.get`.
                if obs_params is None:
                    obs_params = {}
                obs_params = resolve_holo_config(obs_params)
                func_name = obs_params.get("func", obs_name)
                method_name = f"_get_obs_{func_name}"
                if hasattr(ObservationFunctions, method_name):
                    func = getattr(ObservationFunctions, method_name)
                elif hasattr(isaaclab_mdp, func_name):
                    func = getattr(isaaclab_mdp, func_name)
                else:
                    raise ValueError(
                        f"Unknown observation function: {func_name}"
                    )
                obs_term_kwargs = {"func": func}
                # `params: null` in YAML would otherwise be forwarded as None.
                params_cfg = obs_params.get("params")
                if params_cfg is None:
                    params_cfg = {}
                obs_term_kwargs["params"] = resolve_holo_config(params_cfg)
                noise_cfg = obs_params.get("noise")
                if noise_cfg is not None:
                    obs_term_kwargs["noise"] = _build_noise_cfg(noise_cfg)
                for field_name in obs_term_field_names:
                    if field_name in {"func", "params", "noise"}:
                        continue
                    if field_name in obs_params:
                        obs_term_kwargs[field_name] = obs_params[field_name]
                setattr(isaaclab_obs_group_cfg, obs_name, ObsTerm(**obs_term_kwargs))
        setattr(obs_cfg, group_name, isaaclab_obs_group_cfg)
    return obs_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_rewards.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
from isaaclab.assets import Articulation
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.managers import ManagerTermBase, RewardTermCfg, SceneEntityCfg
from isaaclab.sensors import ContactSensor
from isaaclab.utils import configclass
import isaaclab.utils.math as isaaclab_math
from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (
RefMotionCommand,
)
from holomotion.src.utils.frame_utils import (
positions_world_to_env_frame,
root_relative_positions_from_env_frame,
root_relative_positions_from_mixed_position_frames,
)
import isaaclab.envs.mdp as isaaclab_mdp
from hydra.utils import instantiate as hydra_instantiate
from omegaconf import DictConfig, ListConfig, OmegaConf
from loguru import logger
from holomotion.src.env.isaaclab_components.isaaclab_utils import (
_get_body_indices,
resolve_holo_config,
_get_dof_indices,
)
def _joint_ids_to_tensor(
joint_ids: slice | list[int] | tuple[int, ...] | torch.Tensor | None,
num_joints: int,
device: torch.device | str,
) -> torch.Tensor:
"""Convert joint indices to a dense tensor in articulation order."""
if joint_ids is None:
return torch.arange(num_joints, device=device, dtype=torch.long)
if isinstance(joint_ids, slice):
if joint_ids == slice(None):
return torch.arange(num_joints, device=device, dtype=torch.long)
return torch.arange(num_joints, device=device, dtype=torch.long)[
joint_ids
]
if isinstance(joint_ids, torch.Tensor):
return joint_ids.to(device=device, dtype=torch.long).flatten()
return torch.tensor(joint_ids, device=device, dtype=torch.long)
def _actuator_effort_limit_per_joint(
    effort_limit: torch.Tensor, num_actuator_joints: int
) -> torch.Tensor:
    """Reduce one actuator's effort-limit tensor to a 1-D per-joint vector.

    - Scalars are broadcast to every actuator joint.
    - 2-D limits (presumably [num_envs, num_joints]) must be identical across
      rows; the first row is used.
    - Other ranks are rejected.
    """
    if effort_limit.ndim == 0:
        effort_limit = effort_limit.expand(num_actuator_joints)
    elif effort_limit.ndim == 2:
        rows_differ = effort_limit.shape[0] > 1 and not torch.allclose(
            effort_limit,
            effort_limit[0].unsqueeze(0).expand_as(effort_limit),
        )
        if rows_differ:
            raise ValueError(
                "normed_torque_rate requires actuator effort limits to be static across envs."
            )
        effort_limit = effort_limit[0]
    elif effort_limit.ndim != 1:
        raise ValueError(
            "normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors."
        )
    if effort_limit.numel() != num_actuator_joints:
        raise ValueError(
            "normed_torque_rate found mismatched actuator joint indices and effort limits."
        )
    return effort_limit


def _select_effort_limit_vector(
    asset: Articulation,
    selected_joint_ids: torch.Tensor,
) -> torch.Tensor:
    """Gather effort limits for ``selected_joint_ids`` from the asset's actuators.

    Each actuator's limits are scattered into a full articulation-order
    vector, which is then indexed by ``selected_joint_ids``.

    Raises:
        ValueError: If an actuator limit varies across envs, has an
            unsupported rank, or mismatches its joint count. Also raised if a
            selected joint has no actuator limit, or its limit is non-finite
            or non-positive.
    """
    torque = asset.data.applied_torque
    num_joints = torque.shape[1]
    device, dtype = torque.device, torque.dtype

    limits = torch.zeros(num_joints, device=device, dtype=dtype)
    covered = torch.zeros(num_joints, device=device, dtype=torch.bool)

    for actuator in asset.actuators.values():
        joint_ids = _joint_ids_to_tensor(
            actuator.joint_indices, num_joints=num_joints, device=device
        )
        raw_limit = torch.as_tensor(actuator.effort_limit, device=device, dtype=dtype)
        limits[joint_ids] = _actuator_effort_limit_per_joint(raw_limit, joint_ids.numel())
        covered[joint_ids] = True

    selected_covered = covered[selected_joint_ids]
    if not torch.all(selected_covered):
        missing = selected_joint_ids[~selected_covered]
        raise ValueError(
            "normed_torque_rate could not resolve actuator effort limits for "
            f"joint ids {missing.tolist()}."
        )
    chosen = limits[selected_joint_ids]
    if not torch.all(torch.isfinite(chosen)):
        raise ValueError(
            "normed_torque_rate requires finite actuator effort limits for all selected joints."
        )
    if not torch.all(chosen > 0.0):
        raise ValueError(
            "normed_torque_rate requires strictly positive actuator effort limits for all selected joints."
        )
    return chosen
def key_dof_position_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    key_dofs: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward for tracking the immediate-next reference joint positions.

    reward = exp(-sum_j (q_j - q_ref_j)^2 / std^2), summed over ``key_dofs``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    dof_ids = _get_dof_indices(command.robot, key_dofs)
    target = command.get_ref_motion_dof_pos_immediate_next(prefix=ref_prefix)
    deviation = command.robot.data.joint_pos[:, dof_ids] - target[:, dof_ids]
    sq_err = deviation.square().sum(dim=-1)
    return torch.exp(-sq_err / std**2)
def key_dof_velocity_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    key_dofs: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward for tracking the immediate-next reference joint velocities.

    reward = exp(-sum_j (dq_j - dq_ref_j)^2 / std^2), summed over ``key_dofs``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    dof_ids = _get_dof_indices(command.robot, key_dofs)
    target = command.get_ref_motion_dof_vel_immediate_next(prefix=ref_prefix)
    deviation = command.robot.data.joint_vel[:, dof_ids] - target[:, dof_ids]
    sq_err = deviation.square().sum(dim=-1)
    return torch.exp(-sq_err / std**2)
def motion_global_anchor_position_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward on the squared distance between reference and robot anchors.

    Compares the immediate-next reference anchor position with the command's
    ``global_robot_anchor_pos_cur``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    target = command.get_ref_motion_anchor_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )
    current = command.global_robot_anchor_pos_cur
    sq_dist = (target - current).square().sum(dim=-1)
    return torch.exp(-sq_dist / std**2)
def motion_global_anchor_orientation_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward on the anchor orientation error.

    Uses the squared ``quat_error_magnitude`` between the immediate-next
    reference anchor quaternion and the robot's anchor bodylink quaternion.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    target_quat = command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(
        prefix=ref_prefix
    )
    robot_quat = command.robot.data.body_quat_w[:, command.anchor_bodylink_idx]
    angle = isaaclab_math.quat_error_magnitude(target_quat, robot_quat)
    return torch.exp(-(angle**2) / std**2)
def motion_relative_body_position_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward on keybody positions after re-anchoring the reference.

    The immediate-next reference bodies are transformed in two steps:
    - Translation: the reference root is placed at the robot anchor's xy,
      keeping the reference root's z.
    - Rotation: they are rotated by the yaw of (robot anchor * inverse
      reference root).
    They are then compared with the robot's world-frame body positions:
    ``exp(-mean_k ||p_ref_k - p_robot_k||^2 / std^2)``.

    NOTE(review): the reference side uses the *root* pose while the robot side
    uses ``command.anchor_bodylink_idx``. These coincide only when the anchor
    bodylink is the root; confirm this is intended.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    body_ids = _get_body_indices(command.robot, keybody_names)

    ref_root_pos = command.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )  # [B, 3]
    ref_root_quat = command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )  # [B, 4] (w,x,y,z)
    anchor_idx = command.anchor_bodylink_idx
    robot_anchor_pos = command.robot.data.body_pos_w[:, anchor_idx]  # [B, 3]
    robot_anchor_quat = command.robot.data.body_quat_w[:, anchor_idx]  # [B, 4]
    ref_bodies_w = command.get_ref_motion_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )[:, body_ids]  # [B, K, 3]

    k = len(body_ids)

    def tile_over_bodies(per_env: torch.Tensor) -> torch.Tensor:
        # [B, C] -> [B, K, C]; isaaclab_math quaternion ops need matching shapes.
        return per_env.unsqueeze(1).expand(-1, k, -1)

    ref_root_pos_k = tile_over_bodies(ref_root_pos)
    ref_root_quat_k = tile_over_bodies(ref_root_quat)
    robot_anchor_pos_k = tile_over_bodies(robot_anchor_pos)
    robot_anchor_quat_k = tile_over_bodies(robot_anchor_quat)

    # Translation target: robot anchor xy combined with reference root z.
    shift = torch.cat(
        [robot_anchor_pos_k[..., :2], ref_root_pos_k[..., 2:3]], dim=-1
    )
    yaw_delta = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_anchor_quat_k, isaaclab_math.quat_inv(ref_root_quat_k)
        )
    )
    ref_bodies_aligned = shift + isaaclab_math.quat_apply(
        yaw_delta, ref_bodies_w - ref_root_pos_k
    )

    robot_bodies = command.robot.data.body_pos_w[:, body_ids]
    sq_err = (ref_bodies_aligned - robot_bodies).square().sum(dim=-1)  # [B, K]
    return torch.exp(-sq_err.mean(-1) / std**2)
def root_rel_keybodylink_pos_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Exponential reward on root-relative keybody positions, computed in env frame.

    IsaacLab MDP root-position helpers return env-frame positions (world
    position minus ``env.scene.env_origins``). Reference positions are
    therefore shifted into the env frame before both sides are expressed
    relative to their own root. The reward is
    ``exp(-mean_k ||r_ref_k - r_robot_k||^2 / std^2)``.
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    body_ids = _get_body_indices(command.robot, keybody_names)
    origins = env.scene.env_origins

    ref_root_pos_env = positions_world_to_env_frame(
        command.get_ref_motion_root_global_pos_immediate_next(prefix=ref_prefix),
        origins,
    )  # [B, 3]
    ref_root_quat_w = command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )  # [B, 4] (w,x,y,z)
    robot_root_pos_env = isaaclab_mdp.root_pos_w(env)  # [B, 3]
    robot_root_quat_w = isaaclab_mdp.root_quat_w(env)  # [B, 4] (w,x,y,z)

    ref_bodies_w = command.get_ref_motion_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )[:, body_ids]
    ref_bodies_env = positions_world_to_env_frame(ref_bodies_w, origins)

    robot_rel = root_relative_positions_from_mixed_position_frames(
        body_pos_w=command.robot.data.body_pos_w[:, body_ids],
        root_pos_env=robot_root_pos_env,
        root_quat_w=robot_root_quat_w,
        env_origins=origins,
    )
    ref_rel = root_relative_positions_from_env_frame(
        body_pos_env=ref_bodies_env,
        root_pos_env=ref_root_pos_env,
        root_quat_w=ref_root_quat_w,
    )

    sq_err = (ref_rel - robot_rel).square().sum(dim=-1)  # [B, K]
    return torch.exp(-sq_err.mean(-1) / std**2)
def motion_relative_body_orientation_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward keybody orientations after yaw-aligning the reference to the robot.

    The heading correction is the yaw-only part of
    ``q_robot_anchor * q_ref_root^{-1}``. Here the robot anchor is the body at
    ``command.anchor_bodylink_idx`` and the reference anchor is the reference
    root orientation. The correction is applied to every selected reference
    body orientation, and the squared quaternion error magnitude to the robot
    body orientation is averaged over bodies.

    NOTE(review): this assumes the robot anchor body corresponds to the
    reference root; confirm against the command configuration.

    Returns:
        ``exp(-mean_over_bodies(angle_error**2) / std**2)`` per environment.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    n_bodies = len(idxs)

    ref_anchor = cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )  # [B, 4] (w,x,y,z)
    robot_anchor = cmd.robot.data.body_quat_w[
        :, cmd.anchor_bodylink_idx
    ]  # [B, 4] (w,x,y,z)
    ref_bodies = cmd.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
        prefix=ref_prefix
    )[:, idxs]  # [B, N, 4]

    # Broadcast both anchors over the body axis so quat ops see equal shapes.
    robot_anchor_b = robot_anchor.unsqueeze(1).expand(-1, n_bodies, -1)
    ref_anchor_b = ref_anchor.unsqueeze(1).expand(-1, n_bodies, -1)
    heading_fix = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_anchor_b, isaaclab_math.quat_inv(ref_anchor_b)
        )
    )

    ref_aligned = isaaclab_math.quat_mul(heading_fix, ref_bodies)
    robot_bodies = cmd.robot.data.body_quat_w[:, idxs]
    ang_err = isaaclab_math.quat_error_magnitude(
        ref_aligned, robot_bodies
    ).square()
    return torch.exp(-ang_err.mean(-1) / std**2)
def motion_global_body_linear_velocity_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching world-frame linear velocities of selected bodies.

    Reference and robot velocities are compared directly, with no frame
    alignment. The per-body squared error is averaged over bodies.

    Returns:
        ``exp(-mean_over_bodies(||v_ref - v_robot||**2) / std**2)`` per env.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    target = cmd.get_ref_motion_bodylink_global_lin_vel_immediate_next(
        prefix=ref_prefix
    )[:, idxs]
    actual = cmd.robot.data.body_lin_vel_w[:, idxs]
    per_body = (target - actual).square().sum(dim=-1)
    return torch.exp(-per_body.mean(-1) / std**2)
def motion_global_body_angular_velocity_error_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching world-frame angular velocities of selected bodies.

    Reference and robot angular velocities are compared directly, with no
    frame alignment. The per-body squared error is averaged over bodies.

    Returns:
        ``exp(-mean_over_bodies(||w_ref - w_robot||**2) / std**2)`` per env.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    target = cmd.get_ref_motion_bodylink_global_ang_vel_immediate_next(
        prefix=ref_prefix
    )[:, idxs]
    actual = cmd.robot.data.body_ang_vel_w[:, idxs]
    per_body = (target - actual).square().sum(dim=-1)
    return torch.exp(-per_body.mean(-1) / std**2)
def root_pos_xy_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward horizontal (x, y) agreement of reference and robot root positions.

    The reference root position and ``robot.data.root_pos_w`` are compared
    as-is, without subtracting environment origins.

    Returns:
        ``exp(-||xy_ref - xy_robot||**2 / std**2)`` per environment.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    ref_xy = cmd.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )[:, :2]
    robot_xy = cmd.robot.data.root_pos_w[:, :2]
    sq_dist = (ref_xy - robot_xy).square().sum(dim=-1)
    return torch.exp(-sq_dist / std**2)
def root_rot_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward agreement between reference and robot root orientations.

    Returns:
        ``exp(-angle**2 / std**2)`` per environment, where ``angle`` is
        ``isaaclab_math.quat_error_magnitude`` between the two root quats.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    target_quat = cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )
    angle = isaaclab_math.quat_error_magnitude(
        target_quat, isaaclab_mdp.root_quat_w(env)
    )
    return torch.exp(-angle.square() / std**2)
def root_pos_rel_z_tracking_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching root height between robot and reference.

    The z component of ``robot.data.root_pos_w`` is compared with the z
    component of the reference root position.

    Returns:
        ``exp(-(z_robot - z_ref)**2 / std**2)`` per environment.
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    z_robot = cmd.robot.data.root_pos_w[:, 2]
    z_ref = cmd.get_ref_motion_root_global_pos_immediate_next(
        prefix=ref_prefix
    )[:, 2]
    return torch.exp(-torch.square(z_robot - z_ref) / std**2)
def root_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching root linear velocity, each in its own root frame.

    The robot velocity is rotated by the inverse of the robot root
    orientation, and the reference velocity by the inverse of the reference
    root orientation. The local-frame vectors are then compared.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)

    def _in_root_frame(quat_w: torch.Tensor, vec_w: torch.Tensor):
        # Express a world-frame vector in the frame given by quat_w.
        return isaaclab_math.quat_apply(isaaclab_math.quat_inv(quat_w), vec_w)

    robot_local = _in_root_frame(
        isaaclab_mdp.root_quat_w(env), isaaclab_mdp.root_lin_vel_w(env)
    )  # [B, 3]
    ref_local = _in_root_frame(
        cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        ),
        cmd.get_ref_motion_root_global_lin_vel_immediate_next(
            prefix=ref_prefix
        ),
    )  # [B, 3]
    sq_err = (ref_local - robot_local).square().sum(dim=-1)
    return torch.exp(-sq_err / std**2)
def root_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching root angular velocity, each in its own root frame.

    The robot angular velocity is rotated by the inverse robot root
    orientation, and the reference one by the inverse reference root
    orientation, before the squared difference is taken.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)

    def _in_root_frame(quat_w: torch.Tensor, vec_w: torch.Tensor):
        # Express a world-frame vector in the frame given by quat_w.
        return isaaclab_math.quat_apply(isaaclab_math.quat_inv(quat_w), vec_w)

    robot_local = _in_root_frame(
        isaaclab_mdp.root_quat_w(env), isaaclab_mdp.root_ang_vel_w(env)
    )  # [B, 3]
    ref_local = _in_root_frame(
        cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        ),
        cmd.get_ref_motion_root_global_ang_vel_immediate_next(
            prefix=ref_prefix
        ),
    )  # [B, 3]
    sq_err = (ref_local - robot_local).square().sum(dim=-1)
    return torch.exp(-sq_err / std**2)
def root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody positions after yaw-aligning the reference.

    Steps:
      1. For robot and reference separately, subtract that entity's own root
         position from its keybody positions, giving root-relative vectors
         that keep the world-axis orientation.
      2. Rotate only the reference vectors by the yaw-only component of
         ``q_robot_root * q_ref_root^{-1}``. This turns the reference to the
         robot's heading. The robot vectors are NOT rotated into a heading
         frame; note that this differs from rotating each entity by its own
         inverse yaw whenever the roots have pitch/roll.
      3. Average the squared distances over keybodies.

    All positions are first converted into IsaacLab's environment frame
    (simulation world minus `env.scene.env_origins`), so robot root and body
    positions use the same translation convention.

    Returns:
        ``exp(-mean_over_bodies(squared_distance) / std**2)``, shape [B].
    """
    command: RefMotionCommand = env.command_manager.get_term(command_name)
    keybody_idxs = _get_body_indices(command.robot, keybody_names)
    num_bodies = len(keybody_idxs)

    # Root states in environment frame
    ref_root_pos = positions_world_to_env_frame(
        command.get_ref_motion_root_global_pos_immediate_next(
            prefix=ref_prefix
        ),
        env.scene.env_origins,
    )  # [B, 3]
    ref_root_quat = (
        command.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]
    robot_root_pos = isaaclab_mdp.root_pos_w(env)  # [B, 3], env frame
    robot_root_quat = isaaclab_mdp.root_quat_w(env)  # [B, 4]

    # Body positions in environment frame
    robot_body_pos = positions_world_to_env_frame(
        command.robot.data.body_pos_w[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]
    ref_body_pos = positions_world_to_env_frame(
        command.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )[:, keybody_idxs],
        env.scene.env_origins,
    )  # [B, N, 3]

    # Quaternion ops require matching shapes, so expand roots over bodies.
    ref_root_quat_exp = ref_root_quat[:, None, :].expand(-1, num_bodies, -1)
    robot_root_quat_exp = robot_root_quat[:, None, :].expand(
        -1, num_bodies, -1
    )
    # Yaw-only heading difference robot <- reference.
    delta_ori = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(
            robot_root_quat_exp, isaaclab_math.quat_inv(ref_root_quat_exp)
        )
    )  # [B, N, 4]

    # Root-relative vectors (position subtraction broadcasts over bodies).
    robot_rel = robot_body_pos - robot_root_pos[:, None, :]  # [B, N, 3]
    ref_rel = ref_body_pos - ref_root_pos[:, None, :]  # [B, N, 3]
    ref_rel_aligned = isaaclab_math.quat_apply(delta_ori, ref_rel)

    # Compare root-relative vectors with world-axis orientation.
    error = torch.sum(
        torch.square(ref_rel_aligned - robot_rel), dim=-1
    )  # [B, N]
    return torch.exp(-error.mean(-1) / std**2)
def root_rel_keybodylink_rot_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Reward matching keybody orientations expressed in each root frame.

    For robot and reference separately, ``q_rel = q_root^{-1} * q_body`` is
    formed. The squared quaternion error magnitude between the two relative
    orientations is averaged over keybodies.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    n_bodies = len(idxs)

    def _root_local(root_quat: torch.Tensor, body_quat: torch.Tensor):
        # q_root^{-1}, broadcast over the body axis, composed with q_body.
        root_inv = (
            isaaclab_math.quat_inv(root_quat)
            .unsqueeze(1)
            .expand(-1, n_bodies, -1)
        )
        return isaaclab_math.quat_mul(root_inv, body_quat)

    robot_local = _root_local(
        isaaclab_mdp.root_quat_w(env),
        cmd.robot.data.body_quat_w[:, idxs],
    )  # [B, N, 4]
    ref_local = _root_local(
        cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
            prefix=ref_prefix
        ),
        cmd.get_ref_motion_bodylink_global_rot_wxyz_immediate_next(
            prefix=ref_prefix
        )[:, idxs],
    )  # [B, N, 4]

    ang = isaaclab_math.quat_error_magnitude(ref_local, robot_local)
    return torch.exp(-ang.square().mean(-1) / std**2)
def root_rel_keybodylink_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track keybody linear velocities with motion_relative frame alignment.

    For robot and reference, the rigid-body-relative velocity of each keybody
    is ``v_body - v_root - w_root x (p_body - p_root)``. The reference
    result is rotated by the yaw-only part of ``q_robot_root *
    q_ref_root^{-1}`` and compared with the robot result in world axes.

    Positions used for the lever arms are first converted into IsaacLab's
    environment frame (simulation world minus `env.scene.env_origins`), so
    the translation convention matches `isaaclab_mdp.root_pos_w(env)`.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    origins = env.scene.env_origins

    # Robot root state (root position already in env frame).
    robot_root_pos = isaaclab_mdp.root_pos_w(env)  # [B, 3]
    robot_root_quat = isaaclab_mdp.root_quat_w(env)  # [B, 4]
    robot_root_lin = isaaclab_mdp.root_lin_vel_w(env)  # [B, 3]
    robot_root_ang = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]

    # Reference root state.
    ref_root_pos = positions_world_to_env_frame(
        cmd.get_ref_motion_root_global_pos_immediate_next(prefix=ref_prefix),
        origins,
    )  # [B, 3]
    ref_root_quat = cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )  # [B, 4]
    ref_root_lin = cmd.get_ref_motion_root_global_lin_vel_immediate_next(
        prefix=ref_prefix
    )  # [B, 3]
    ref_root_ang = cmd.get_ref_motion_root_global_ang_vel_immediate_next(
        prefix=ref_prefix
    )  # [B, 3]

    # Keybody states.
    robot_body_pos = positions_world_to_env_frame(
        cmd.robot.data.body_pos_w[:, idxs], origins
    )  # [B, N, 3]
    robot_body_lin = cmd.robot.data.body_lin_vel_w[:, idxs]  # [B, N, 3]
    ref_body_pos = positions_world_to_env_frame(
        cmd.get_ref_motion_bodylink_global_pos_immediate_next(
            prefix=ref_prefix
        )[:, idxs],
        origins,
    )  # [B, N, 3]
    ref_body_lin = cmd.get_ref_motion_bodylink_global_lin_vel_immediate_next(
        prefix=ref_prefix
    )[:, idxs]  # [B, N, 3]

    def _relative_velocity(body_pos, body_lin, root_pos, root_lin, root_ang):
        # v_rel = v_body - v_root - w_root x (p_body - p_root)
        lever = body_pos - root_pos.unsqueeze(1)
        swirl = torch.cross(root_ang.unsqueeze(1), lever, dim=-1)
        return body_lin - root_lin.unsqueeze(1) - swirl

    robot_v_rel = _relative_velocity(
        robot_body_pos, robot_body_lin,
        robot_root_pos, robot_root_lin, robot_root_ang,
    )  # [B, N, 3]
    ref_v_rel = _relative_velocity(
        ref_body_pos, ref_body_lin,
        ref_root_pos, ref_root_lin, ref_root_ang,
    )  # [B, N, 3]

    # Yaw-only heading correction applied to the reference velocities.
    n_bodies = len(idxs)
    robot_q = robot_root_quat.unsqueeze(1).expand(-1, n_bodies, -1)
    ref_q = ref_root_quat.unsqueeze(1).expand(-1, n_bodies, -1)
    heading = isaaclab_math.yaw_quat(
        isaaclab_math.quat_mul(robot_q, isaaclab_math.quat_inv(ref_q))
    )  # [B, N, 4]
    ref_v_aligned = isaaclab_math.quat_apply(heading, ref_v_rel)

    sq_err = (ref_v_aligned - robot_v_rel).square().sum(dim=-1)  # [B, N]
    return torch.exp(-sq_err.mean(-1) / std**2)
def root_rel_keybodylink_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track root-relative keybody angular velocities in root frames.

    For robot and reference separately, ``w_rel = w_body - w_root`` is
    computed in world axes and then rotated by the inverse of that entity's
    own root orientation. The squared difference is averaged over keybodies.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    n_bodies = len(idxs)

    robot_root_quat = isaaclab_mdp.root_quat_w(env)  # [B, 4]
    robot_root_ang = isaaclab_mdp.root_ang_vel_w(env)  # [B, 3]
    ref_root_quat = cmd.get_ref_motion_root_global_rot_quat_wxyz_immediate_next(
        prefix=ref_prefix
    )  # [B, 4]
    ref_root_ang = cmd.get_ref_motion_root_global_ang_vel_immediate_next(
        prefix=ref_prefix
    )  # [B, 3]

    robot_body_ang = cmd.robot.data.body_ang_vel_w[:, idxs]  # [B, N, 3]
    ref_body_ang = cmd.get_ref_motion_bodylink_global_ang_vel_immediate_next(
        prefix=ref_prefix
    )[:, idxs]  # [B, N, 3]

    def _to_root(root_quat: torch.Tensor, vec_w: torch.Tensor):
        # Rotate per-body world vectors by q_root^{-1} (broadcast over bodies).
        root_inv = (
            isaaclab_math.quat_inv(root_quat)
            .unsqueeze(1)
            .expand(-1, n_bodies, -1)
        )
        return isaaclab_math.quat_apply(root_inv, vec_w)

    robot_local = _to_root(
        robot_root_quat, robot_body_ang - robot_root_ang.unsqueeze(1)
    )  # [B, N, 3]
    ref_local = _to_root(
        ref_root_quat, ref_body_ang - ref_root_ang.unsqueeze(1)
    )  # [B, N, 3]

    sq_err = (ref_local - robot_local).square().sum(dim=-1)  # [B, N]
    return torch.exp(-sq_err.mean(-1) / std**2)
def global_keybodylink_lin_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track global keybody linear velocities.

    World-frame reference and robot velocities are compared directly, and
    the per-body squared error is averaged.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    diff = (
        cmd.get_ref_motion_bodylink_global_lin_vel_immediate_next(
            prefix=ref_prefix
        )[:, idxs]
        - cmd.robot.data.body_lin_vel_w[:, idxs]
    )  # [B, N, 3]
    per_body = diff.square().sum(dim=-1)  # [B, N]
    return torch.exp(-per_body.mean(-1) / std**2)
def global_keybodylink_ang_vel_tracking_l2_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Track global keybody angular velocities.

    World-frame reference and robot angular velocities are compared
    directly, and the per-body squared error is averaged.

    Returns: [B]
    """
    cmd: RefMotionCommand = env.command_manager.get_term(command_name)
    idxs = _get_body_indices(cmd.robot, keybody_names)
    diff = (
        cmd.get_ref_motion_bodylink_global_ang_vel_immediate_next(
            prefix=ref_prefix
        )[:, idxs]
        - cmd.robot.data.body_ang_vel_w[:, idxs]
    )  # [B, N, 3]
    per_body = diff.square().sum(dim=-1)  # [B, N]
    return torch.exp(-per_body.mean(-1) / std**2)
def feet_contact_time(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    threshold: float,
) -> torch.Tensor:
    """Count feet that lift off after a short contact phase.

    A selected foot contributes 1 when ``compute_first_air`` flags it and its
    ``last_contact_time`` is below ``threshold``.

    NOTE(review): ``env.physics_dt`` is passed as the second positional
    argument of ``compute_first_air``. It is presumably used as the
    comparison tolerance; confirm against the ContactSensor API.

    Returns:
        Per-environment count over ``sensor_cfg.body_ids``.
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    feet = sensor_cfg.body_ids
    just_lifted = sensor.compute_first_air(env.step_dt, env.physics_dt)[
        :, feet
    ]
    short_stance = sensor.data.last_contact_time[:, feet] < threshold
    return torch.sum(short_stance * just_lifted, dim=-1)
def track_lin_vel_xy_yaw_frame_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Track linear velocity (xy) in the gravity-aligned yaw frame using exponential kernel.

    The root linear velocity is rotated into the yaw-only (heading) frame of
    the root orientation. Its xy part is then compared with the first two
    channels of the named command. This mirrors the implementation in
    IsaacLab locomotion velocity MDP.

    Returns:
        ``exp(-||cmd_xy - v_xy||**2 / std**2)`` per environment.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    heading_quat = isaaclab_math.yaw_quat(robot.data.root_quat_w)
    vel_in_heading = isaaclab_math.quat_apply_inverse(
        heading_quat, robot.data.root_lin_vel_w[:, :3]
    )
    target_xy = env.command_manager.get_command(command_name)[:, :2]
    sq_err = (target_xy - vel_in_heading[:, :2]).square().sum(dim=1)
    return torch.exp(-sq_err / (std**2))
def feet_slide(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize horizontal foot speed while the foot is in contact.

    A foot counts as in contact when the norm of its net contact force
    exceeds 1.0 at any entry of the sensor's force history. The penalty is
    the sum, over selected feet, of the xy speed of each contacting foot.
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    force_hist = sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]
    peak_force = force_hist.norm(dim=-1).max(dim=1).values
    in_contact = peak_force > 1.0

    robot: Articulation = env.scene[asset_cfg.name]
    xy_speed = robot.data.body_lin_vel_w[:, asset_cfg.body_ids, :2].norm(
        dim=-1
    )
    return torch.sum(xy_speed * in_contact, dim=1)
def feet_slide_ang_vel(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize foot yaw spinning (angular velocity about z) while in contact.

    A foot counts as in contact when the norm of its net contact force
    exceeds 1.0 at any entry of the sensor's force history. For each
    contacting foot, the magnitude of the z component of its world-frame
    angular velocity is summed over the selected feet.

    Returns:
        Per-environment penalty (non-negative).
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    contacts = (
        contact_sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids, :]
        .norm(dim=-1)
        .max(dim=1)[0]
        > 1.0
    )
    asset: Articulation = env.scene[asset_cfg.name]
    # Keep only the z (yaw) component; the norm over that single-element slice
    # is its absolute value.
    body_ang_vel = asset.data.body_ang_vel_w[:, asset_cfg.body_ids, 2:3]
    reward = torch.sum(body_ang_vel.norm(dim=-1) * contacts, dim=1)
    return reward
def foot_clearance_reward(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg,
    target_height: float,
    std: float,
    tanh_mult: float,
    sensor_cfg: SceneEntityCfg,
) -> torch.Tensor:
    """Reward swinging feet that reach a target height, gated by foot speed.

    For each selected foot:
      * shortfall = ``max(target_height - z, 0)``, where z is the body's
        world-frame height. Feet at or above the target have zero shortfall
        and receive the full height term.
      * height term = ``exp(-shortfall**2 / std**2)``.
      * speed gate = ``tanh(tanh_mult * ||v_xy||)``, so feet with no
        horizontal motion contribute ~0.
      * swing mask: only feet whose ``current_contact_time`` is not > 0
        (i.e. not in contact) contribute.

    Returns:
        Per-environment sum over feet of height_term * speed_gate * swing.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]

    foot_z = asset.data.body_pos_w[:, asset_cfg.body_ids, 2]  # [B, N]
    # One-sided: only being below the target height is penalized.
    delta_z = torch.clamp(target_height - foot_z, min=0.0)
    foot_z_error = torch.square(delta_z)  # [B, N]

    # Only reward swinging feet (not in contact).
    is_contact = (
        contact_sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0
    )  # [B, N]
    is_swinging = ~is_contact

    # Gate by horizontal speed so the foot is actually moving.
    foot_horizontal_vel = torch.norm(
        asset.data.body_lin_vel_w[:, asset_cfg.body_ids, :2], dim=2
    )  # [B, N]
    velocity_gate = torch.tanh(tanh_mult * foot_horizontal_vel)  # [B, N]

    reward_per_foot = (
        torch.exp(-foot_z_error / std**2) * velocity_gate * is_swinging.float()
    )
    return torch.sum(reward_per_foot, dim=1)
def feet_gait(
    env: ManagerBasedRLEnv,
    period: float,
    offset: list[float],
    sensor_cfg: SceneEntityCfg,
    threshold: float = 0.5,
    command_name=None,
) -> torch.Tensor:
    """Reward feet whose contact state matches a periodic gait schedule.

    How it works:
      * The global phase is ``(episode_length_buf * step_dt) % period / period``.
      * Foot ``i`` has phase ``(global_phase + offset[i]) % 1.0`` and is
        scheduled for stance when that phase is below ``threshold``.
      * Per environment, the reward counts the feet whose measured contact
        (``current_contact_time > 0``) agrees with the schedule.
      * When ``command_name`` is given, the reward is zeroed wherever the
        command norm is not above 0.1.
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    in_contact = sensor.data.current_contact_time[:, sensor_cfg.body_ids] > 0

    elapsed = env.episode_length_buf * env.step_dt
    base_phase = (elapsed % period / period).unsqueeze(1)
    leg_phase = torch.cat(
        [(base_phase + off) % 1.0 for off in offset], dim=-1
    )

    reward = torch.zeros(env.num_envs, dtype=torch.float, device=env.device)
    for foot in range(len(sensor_cfg.body_ids)):
        scheduled_stance = leg_phase[:, foot] < threshold
        # XNOR: 1 when schedule and measured contact agree.
        reward += ~(scheduled_stance ^ in_contact[:, foot])

    if command_name is not None:
        command = env.command_manager.get_command(command_name)
        reward *= torch.norm(command, dim=1) > 0.1
    return reward
# Aliases of IsaacLab's L1 joint-deviation penalty. Separate names
# presumably let reward configs register the same function under several
# term names (one per joint group) -- confirm against the reward configs.
(
    joint_deviation_l1_arms,
    joint_deviation_l1_arms_roll,
    joint_deviation_l1_waists,
    joint_deviation_l1_legs,
    joint_deviation_l1_legs_yaw,
    joint_deviation_l1_stand_still,
) = (isaaclab_mdp.joint_deviation_l1,) * 6
def joint_deviation_l2(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize joint positions that deviate from the default one.

    Returns:
        Per-environment sum over ``asset_cfg.joint_ids`` of
        ``(q - q_default)**2``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    ids = asset_cfg.joint_ids
    offset = robot.data.joint_pos[:, ids] - robot.data.default_joint_pos[:, ids]
    return offset.square().sum(dim=1)
# Aliases of the squared joint-deviation penalty, presumably so reward
# configs can use one term name per joint group -- confirm against configs.
(
    joint_deviation_l2_arms_roll,
    joint_deviation_l2_arms,
    joint_deviation_l2_waists,
    joint_deviation_l2_legs,
    joint_deviation_l2_shoulder_roll,
    joint_deviation_l2_hip_roll,
) = (joint_deviation_l2,) * 6
def energy(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize the energy used by the robot's joints.

    Returns:
        Per-environment sum over selected joints of ``|qd| * |tau|``, where
        ``tau`` is ``applied_torque``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    ids = asset_cfg.joint_ids
    speed = robot.data.joint_vel[:, ids].abs()
    effort = robot.data.applied_torque[:, ids].abs()
    return (speed * effort).sum(dim=-1)
def positive_work(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize only the positive mechanical work (energy injected) by the joints.

    Joint power is ``tau * qd``. Negative (braking) power is clipped to zero
    before summing over the selected joints.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    ids = asset_cfg.joint_ids
    power = robot.data.applied_torque[:, ids] * robot.data.joint_vel[:, ids]
    return torch.relu(power).sum(dim=-1)
class normed_positive_work(ManagerTermBase):
    """Penalize positive joint work normalized by effort and velocity limits.

    The normalized power per selected joint is
    ``(tau / tau_limit) * (qd / qd_limit)``. Only positive (motoring)
    contributions are summed per environment. Reciprocal effort limits are
    cached and rebuilt whenever the asset name, the joint selection, or the
    torque buffer's device/dtype changes.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)
        # Lazily filled by _maybe_build_cache on the first call.
        self._asset_name: str | None = None
        self._joint_ids: torch.Tensor | None = None
        self._inv_effort_limit: torch.Tensor | None = None

    def _cache_matches(
        self,
        asset_name: str,
        joint_ids: torch.Tensor,
        torque: torch.Tensor,
    ) -> bool:
        """Return True when the cached reciprocal effort limits are reusable."""
        if self._asset_name != asset_name:
            return False
        if self._joint_ids is None or not torch.equal(
            self._joint_ids, joint_ids
        ):
            return False
        cached = self._inv_effort_limit
        if cached is None:
            return False
        return (
            cached.shape == (joint_ids.numel(),)
            and cached.device == torque.device
            and cached.dtype == torque.dtype
        )

    def _maybe_build_cache(
        self,
        env: ManagerBasedRLEnv,
        asset_cfg: SceneEntityCfg,
    ) -> Articulation:
        asset: Articulation = env.scene[asset_cfg.name]
        torque = asset.data.applied_torque
        joint_ids = _joint_ids_to_tensor(
            getattr(asset_cfg, "joint_ids", None),
            num_joints=torque.shape[1],
            device=torque.device,
        )
        if self._cache_matches(asset_cfg.name, joint_ids, torque):
            return asset

        effort_limit = _select_effort_limit_vector(asset, joint_ids)
        self._asset_name = asset_cfg.name
        self._joint_ids = joint_ids
        self._inv_effort_limit = effort_limit.reciprocal()
        return asset

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> torch.Tensor:
        asset = self._maybe_build_cache(env, asset_cfg)
        ids = self._joint_ids
        inv_tau_limit = self._inv_effort_limit
        assert ids is not None
        assert inv_tau_limit is not None

        tau = asset.data.applied_torque[:, ids]
        qd = asset.data.joint_vel[:, ids]
        qd_limit = asset.data.joint_vel_limits[:, ids]

        # Velocity limits are read every call (not cached) and must be
        # usable as divisors.
        limits_ok = torch.all(torch.isfinite(qd_limit)) and torch.all(
            qd_limit > 0.0
        )
        if not limits_ok:
            raise ValueError(
                "normed_positive_work requires finite, strictly positive "
                "joint velocity limits for all selected joints."
            )

        normalized_power = (tau * inv_tau_limit) * (qd / qd_limit)
        return torch.relu(normalized_power).sum(dim=-1)
class normed_torque_rate(ManagerTermBase):
    """Penalize joint torque-rate changes normalized by actuator effort limits.

    For each environment that does not need reseeding, the reward is
    ``sum_j ((tau_j - tau_prev_j) * (1 / tau_limit_j)) ** 2`` over the
    selected joints, where ``tau`` is ``asset.data.applied_torque``.

    An environment receives 0 for a call (and only records its torque) if:
      * it was marked via ``reset``;
      * its ``episode_length_buf`` is 0; or
      * the cache was just (re)built.

    After every call, the current torque becomes the stored previous torque
    for all environments.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)
        # Cache key (asset name + selected joint ids) and derived buffers;
        # populated lazily by _maybe_build_cache on the first __call__.
        self._asset_name: str | None = None
        self._joint_ids: torch.Tensor | None = None
        self._inv_effort_limit: torch.Tensor | None = None
        self._prev_applied_torque: torch.Tensor | None = None
        # Per-env flag: True means the stored previous torque is not valid,
        # so the next call yields 0 and just records the current torque.
        self._needs_reseed = torch.ones(
            self.num_envs, device=self.device, dtype=torch.bool
        )

    def reset(self, env_ids=None) -> None:
        """Mark environments for reseeding (all of them when env_ids is None).

        Accepts None, a slice, or any sequence/tensor of environment ids.
        """
        if env_ids is None:
            self._needs_reseed[:] = True
            return
        if isinstance(env_ids, slice):
            self._needs_reseed[env_ids] = True
            return
        # Normalize id sequences/tensors to a long tensor on the term device.
        env_ids_tensor = torch.as_tensor(
            env_ids, device=self.device, dtype=torch.long
        )
        self._needs_reseed[env_ids_tensor] = True

    def _maybe_build_cache(
        self,
        env: ManagerBasedRLEnv,
        asset_cfg: SceneEntityCfg,
    ) -> Articulation:
        """Return the asset, (re)building cached buffers when stale.

        The cache is rebuilt when any of these change: the asset name, the
        joint selection, or the shape/device/dtype of the previous-torque
        buffer relative to ``applied_torque``. A rebuild zeroes the
        previous-torque buffer and marks every environment for reseeding.
        """
        asset: Articulation = env.scene[asset_cfg.name]
        joint_ids = _joint_ids_to_tensor(
            getattr(asset_cfg, "joint_ids", None),
            num_joints=asset.data.applied_torque.shape[1],
            device=asset.data.applied_torque.device,
        )
        cache_needs_refresh = (
            self._asset_name != asset_cfg.name
            or self._joint_ids is None
            or not torch.equal(self._joint_ids, joint_ids)
            or self._prev_applied_torque is None
            or self._prev_applied_torque.shape
            != (env.num_envs, joint_ids.numel())
            or self._prev_applied_torque.device
            != asset.data.applied_torque.device
            or self._prev_applied_torque.dtype
            != asset.data.applied_torque.dtype
        )
        if not cache_needs_refresh:
            return asset
        effort_limit = _select_effort_limit_vector(asset, joint_ids)
        self._asset_name = asset_cfg.name
        self._joint_ids = joint_ids
        self._inv_effort_limit = effort_limit.reciprocal()
        self._prev_applied_torque = torch.zeros(
            env.num_envs,
            joint_ids.numel(),
            device=asset.data.applied_torque.device,
            dtype=asset.data.applied_torque.dtype,
        )
        # The zeroed previous torque is not a real measurement, so every env
        # must skip differencing on the next call.
        self._needs_reseed = torch.ones(
            env.num_envs,
            device=asset.data.applied_torque.device,
            dtype=torch.bool,
        )
        return asset

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> torch.Tensor:
        """Return the per-environment normalized torque-rate penalty."""
        asset = self._maybe_build_cache(env, asset_cfg)
        joint_ids = self._joint_ids
        inv_effort_limit = self._inv_effort_limit
        prev_applied_torque = self._prev_applied_torque
        assert joint_ids is not None
        assert inv_effort_limit is not None
        assert prev_applied_torque is not None
        current_torque = asset.data.applied_torque[:, joint_ids]
        # Environments without a valid previous torque keep a reward of 0.
        reward = torch.zeros(
            env.num_envs,
            device=current_torque.device,
            dtype=current_torque.dtype,
        )
        # Clone so the in-place OR below does not modify the stored flags.
        reseed_mask = self._needs_reseed.clone()
        if hasattr(env, "episode_length_buf"):
            # Treat the first step of an episode as having no valid history.
            reseed_mask |= env.episode_length_buf == 0
        active_mask = ~reseed_mask
        if torch.any(active_mask):
            delta = (
                current_torque[active_mask] - prev_applied_torque[active_mask]
            ) * inv_effort_limit
            reward[active_mask] = torch.sum(delta.square(), dim=1)
        # Record the current torque for every env (active and reseeded alike).
        prev_applied_torque.copy_(current_torque)
        self._needs_reseed[reseed_mask] = False
        return reward
def track_stand_still_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Track stand still joint position using exponential kernel when command velocity is low.

    The pose error is the sum over all joints of ``(q - q_default)**2``. The
    reward is ``exp(-error / std**2)``, kept only for environments whose
    command norm (first three channels) is below 0.1; others get 0.

    Returns: [B]
    """
    robot: Articulation = env.scene[asset_cfg.name]
    pose_err = (
        (robot.data.joint_pos - robot.data.default_joint_pos)
        .square()
        .sum(dim=1)
    )

    # Only (vx, vy, yaw_rate) count; extra channels such as heading, if the
    # command term exposes them, are dropped.
    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)[
        ..., :3
    ]
    standing = torch.norm(cmd, dim=1) < 0.1
    return torch.exp(-pose_err / std**2) * standing
def stand_still_joint_deviation_l1(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Penalize L1 joint deviation from default pose when command velocity is low.

    Returns ``sum(|q - q_default|)`` over all joints for environments whose
    command norm (first three channels) is below 0.1, and 0 elsewhere.
    Intended to be used with a negative weight.

    Returns: [B]
    """
    robot: Articulation = env.scene[asset_cfg.name]
    l1_err = (robot.data.joint_pos - robot.data.default_joint_pos).abs().sum(
        dim=1
    )
    # Keep only (vx, vy, yaw_rate) when the command has extra channels.
    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)[
        ..., :3
    ]
    standing = torch.norm(cmd, dim=1) < 0.1
    return l1_err * standing
def stand_still_action_rate(
    env: ManagerBasedRLEnv,
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Penalize action changes while the velocity command is near zero.

    Returns the sum of squared ``action - prev_action`` for environments
    whose command norm (first three channels) is below 0.1, and 0 elsewhere.
    """
    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)[
        ..., :3
    ]
    standing = torch.norm(cmd, dim=1) < 0.1
    step = env.action_manager.action - env.action_manager.prev_action
    return step.square().sum(dim=1) * standing
def stand_still_dof_vel_l2(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Penalize joint velocities while the velocity command is near zero.

    Returns the sum of squared joint velocities (all joints of the asset)
    for environments whose command norm (first three channels) is below 0.1,
    and 0 elsewhere.
    """
    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)[
        ..., :3
    ]
    standing = torch.norm(cmd, dim=1) < 0.1
    qd = env.scene[asset_cfg.name].data.joint_vel
    return qd.square().sum(dim=1) * standing
class action_acc(ManagerTermBase):
    """Penalize the change in action-rate using a stateful second difference.
    Per env the term caches the previous action ``a_{t-1}`` and the previous
    action rate ``r_{t-1} = a_{t-1} - a_{t-2}``, and returns
    ``sum((r_t - r_{t-1})^2)`` over action dims with ``r_t = a_t - a_{t-1}``.
    After a reset (``reset()`` or ``episode_length_buf == 0``) an env needs
    two further steps of history, so it yields 0 on the reseed step and on
    the step right after it.
    Returns: [B]
    """
    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        super().__init__(cfg, env)
        # Allocated lazily in `_maybe_build_cache` once the action is known.
        self._prev_action: torch.Tensor | None = None
        self._prev_action_rate: torch.Tensor | None = None
        # Per env: overwrite the cached previous action on the next call.
        self._needs_reseed = torch.ones(
            self.num_envs, device=self.device, dtype=torch.bool
        )
        # Per env: the cached previous action rate is not valid yet.
        self._needs_prev_rate = torch.ones(
            self.num_envs, device=self.device, dtype=torch.bool
        )
    def reset(self, env_ids=None) -> None:
        """Flag ``env_ids`` (all envs when None) for reseeding on the next call.
        ``env_ids`` may be None, a slice, or any index sequence accepted by
        ``torch.as_tensor``.
        """
        if env_ids is None:
            self._needs_reseed[:] = True
            self._needs_prev_rate[:] = True
            return
        if isinstance(env_ids, slice):
            self._needs_reseed[env_ids] = True
            self._needs_prev_rate[env_ids] = True
            return
        env_ids_tensor = torch.as_tensor(
            env_ids, device=self.device, dtype=torch.long
        )
        self._needs_reseed[env_ids_tensor] = True
        self._needs_prev_rate[env_ids_tensor] = True
    def _maybe_build_cache(
        self, env: ManagerBasedRLEnv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (prev_action, prev_action_rate) caches.
        They are (re)allocated as zeros, and every env is flagged for
        reseeding, whenever they are missing or no longer match the current
        action's shape, device, or dtype.
        """
        current_action = env.action_manager.action
        cache_needs_refresh = (
            self._prev_action is None
            or self._prev_action_rate is None
            or self._prev_action.shape != current_action.shape
            or self._prev_action.device != current_action.device
            or self._prev_action.dtype != current_action.dtype
            or self._prev_action_rate.shape != current_action.shape
            or self._prev_action_rate.device != current_action.device
            or self._prev_action_rate.dtype != current_action.dtype
        )
        if cache_needs_refresh:
            self._prev_action = torch.zeros_like(current_action)
            self._prev_action_rate = torch.zeros_like(current_action)
            self._needs_reseed = torch.ones(
                env.num_envs,
                device=current_action.device,
                dtype=torch.bool,
            )
            self._needs_prev_rate = torch.ones(
                env.num_envs,
                device=current_action.device,
                dtype=torch.bool,
            )
        assert self._prev_action is not None
        assert self._prev_action_rate is not None
        return self._prev_action, self._prev_action_rate
    def __call__(self, env: ManagerBasedRLEnv) -> torch.Tensor:
        current_action = env.action_manager.action
        prev_action, prev_action_rate = self._maybe_build_cache(env)
        reward = torch.zeros(
            env.num_envs,
            device=current_action.device,
            dtype=current_action.dtype,
        )
        # Envs flagged by reset() plus envs at the first step of an episode.
        reseed_mask = self._needs_reseed.clone()
        if hasattr(env, "episode_length_buf"):
            reseed_mask |= env.episode_length_buf == 0
        if torch.any(reseed_mask):
            prev_action[reseed_mask] = current_action[reseed_mask]
            # Boolean-mask indexing returns a copy, so an in-place `.zero_()`
            # on it would leave the cache untouched; assign through the mask.
            prev_action_rate[reseed_mask] = 0.0
            self._needs_prev_rate[reseed_mask] = True
        active_mask = ~reseed_mask
        if torch.any(active_mask):
            current_action_rate = (
                current_action[active_mask] - prev_action[active_mask]
            )
            # Positions (within the active subset) whose cached rate is valid.
            ready_mask = ~self._needs_prev_rate[active_mask]
            if torch.any(ready_mask):
                action_acc_value = (
                    current_action_rate[ready_mask]
                    - prev_action_rate[active_mask][ready_mask]
                )
                # Map active-subset positions back to global env indices.
                reward[
                    active_mask.nonzero(as_tuple=False).flatten()[ready_mask]
                ] = torch.sum(action_acc_value.square(), dim=1)
            prev_action[active_mask] = current_action[active_mask]
            prev_action_rate[active_mask] = current_action_rate
            self._needs_prev_rate[active_mask] = False
        self._needs_reseed[reseed_mask] = False
        return reward
# Alias so reward configs can also name this term ``action_acc_l2``.
action_acc_l2 = action_acc
def feet_stumble(
    env: ManagerBasedRLEnv, sensor_cfg: SceneEntityCfg
) -> torch.Tensor:
    """Flag envs in which any selected foot pushes against a vertical surface.
    A foot counts as stumbling when the norm of its horizontal contact force
    exceeds four times the magnitude of its vertical contact force.
    Returns: [B] float tensor of 0.0 / 1.0.
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    foot_forces = sensor.data.net_forces_w[:, sensor_cfg.body_ids, :]
    horizontal = torch.linalg.norm(foot_forces[..., :2], dim=2)
    vertical = foot_forces[..., 2].abs()
    return (horizontal > 4 * vertical).any(dim=1).float()
def feet_too_near(
    env: ManagerBasedRLEnv,
    threshold: float = 0.2,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Linear penalty when the first two selected bodies are closer than
    ``threshold`` in world space: ``max(threshold - distance, 0)``.
    Returns: [B]
    """
    feet_xyz = env.scene[asset_cfg.name].data.body_pos_w[
        :, asset_cfg.body_ids, :
    ]
    gap = (feet_xyz[:, 0] - feet_xyz[:, 1]).norm(dim=-1)
    return torch.clamp(threshold - gap, min=0)
def feet_contact_without_cmd(
    env: ManagerBasedRLEnv,
    sensor_cfg: SceneEntityCfg,
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Count feet in contact, rewarded only when the velocity command is ~0.
    A foot is in contact when its current contact time is positive. The gate
    uses the first three command channels: ``||cmd|| < 0.1``.
    Returns: [B]
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    contact_time = sensor.data.current_contact_time[:, sensor_cfg.body_ids]
    num_feet_in_contact = (contact_time > 0).sum(dim=-1).float()
    cmd = isaaclab_mdp.generated_commands(env, command_name=command_name)
    if cmd.shape[-1] > 3:
        cmd = cmd[..., :3]
    is_idle = torch.norm(cmd, dim=1) < 0.1
    return num_feet_in_contact * is_idle
def torso_xy_ang_vel_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Squared roll/pitch-rate of ``torso_link`` in the heading-aligned frame.
    The world-frame torso angular velocity is rotated by the inverse of a
    yaw-only quaternion built from ``heading_w`` (z-up, x along the heading),
    and the x/y components are penalized.
    Returns: [B]
    """
    robot = env.scene[asset_cfg.name]
    torso_ang_vel_w = robot.data.body_ang_vel_w[
        :, robot.body_names.index("torso_link"), :
    ]  # [B, 3]
    yaw = robot.data.heading_w  # [B]
    zeros = torch.zeros_like(yaw, device=env.device)
    yaw_quat = isaaclab_math.quat_from_euler_xyz(
        roll=zeros, pitch=zeros, yaw=yaw
    )  # [B, 4]
    ang_vel_h = isaaclab_math.quat_apply(
        isaaclab_math.quat_inv(yaw_quat), torso_ang_vel_w
    )  # [B, 3]
    return ang_vel_h[:, :2].square().sum(dim=-1)
def torso_upright_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize torso roll and positive pitch in the heading-aligned frame.
    The ``torso_link`` orientation is expressed relative to a yaw-only frame
    built from ``heading_w`` and converted to XYZ Euler angles. The penalty is
    ``(2 * roll)^2 + max(pitch, 0)^2``, so negative pitch is not penalized;
    see ``torso_upright_l2_penalty_v2`` for the symmetric variant.
    Returns: [B]
    """
    robot_ptr = env.scene[asset_cfg.name]
    torso_idx = robot_ptr.body_names.index("torso_link")
    torso_rot_quat_w = robot_ptr.data.body_quat_w[:, torso_idx, :]  # [B, 4]
    # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading.
    # Build yaw-only quaternion from stored heading_w (shape [B]).
    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]
    zero = torch.zeros_like(heading_yaw, device=env.device)
    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(
        roll=zero,
        pitch=zero,
        yaw=heading_yaw,
    )  # [B, 4]
    # Re-express the torso orientation in the heading-aligned frame.
    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)
    torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul(
        heading_inv_wxyz,
        torso_rot_quat_w,
    )  # [B, 4]
    roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h)
    # Depending on the Isaac Lab version, euler_xyz_from_quat may return angles
    # wrapped to [0, 2*pi). The sign test below needs signed angles, so map
    # them to (-pi, pi]; this is a no-op when they already are.
    roll = isaaclab_math.wrap_to_pi(roll)
    pitch = isaaclab_math.wrap_to_pi(pitch)
    # Keep only positive pitch; negative pitch contributes nothing.
    pitch = pitch * (pitch > 0.0)
    rollpitch = torch.stack([roll * 2.0, pitch], dim=-1)  # [B, 2]
    penalty: torch.Tensor = torch.sum(
        torch.square(rollpitch),
        dim=-1,
    )  # [B]
    return penalty
def torso_upright_l2_penalty_v2(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    target_pitch: float = 0.0,
    roll_scale: float = 2.0,
    pitch_scale: float = 1.0,
) -> torch.Tensor:
    """Penalize torso roll/pitch deviation in a heading-aligned frame (symmetric).
    Compared to `torso_upright_l2_penalty`, this version penalizes *both* forward
    and backward pitch w.r.t. `target_pitch`:
        (roll_scale * roll)^2 + (pitch_scale * (pitch - target_pitch))^2
    with roll/pitch taken from the ``torso_link`` orientation relative to a
    yaw-only frame built from ``heading_w``.
    Returns: [B]
    """
    robot_ptr = env.scene[asset_cfg.name]
    torso_idx = robot_ptr.body_names.index("torso_link")
    torso_rot_quat_w: torch.Tensor = robot_ptr.data.body_quat_w[
        :, torso_idx, :
    ]  # [B, 4]
    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]
    zero = torch.zeros_like(heading_yaw, device=env.device)
    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(
        roll=zero,
        pitch=zero,
        yaw=heading_yaw,
    )  # [B, 4]
    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)
    torso_rot_quat_h: torch.Tensor = isaaclab_math.quat_mul(
        heading_inv_wxyz,
        torso_rot_quat_w,
    )  # [B, 4]
    roll, pitch, _ = isaaclab_math.euler_xyz_from_quat(torso_rot_quat_h)  # [B]
    # Depending on the Isaac Lab version, euler_xyz_from_quat may return angles
    # wrapped to [0, 2*pi). Map them to (-pi, pi] so that small negative tilts
    # stay small and `pitch - target_pitch` is a signed error.
    roll = isaaclab_math.wrap_to_pi(roll)
    pitch = isaaclab_math.wrap_to_pi(pitch)
    roll_err: torch.Tensor = roll_scale * roll
    pitch_err: torch.Tensor = pitch_scale * (pitch - target_pitch)
    roll_pitch = torch.stack([roll_err, pitch_err], dim=-1)  # [B, 2]
    penalty: torch.Tensor = torch.sum(torch.square(roll_pitch), dim=-1)  # [B]
    return penalty
def stand_still_torso_upright_exp_v2(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    cmd_threshold: float = 0.1,
    target_pitch: float = 0.0,
    roll_scale: float = 2.0,
    pitch_scale: float = 1.0,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Exp-kernel torso-uprightness reward, active only for stand-still commands.
    value = exp(-penalty / std^2) when ||cmd|| <= cmd_threshold, else 0,
    where ``penalty`` comes from `torso_upright_l2_penalty_v2` with the same
    target/scale arguments.
    Returns: [B]
    """
    cmd_norm = torch.norm(env.command_manager.get_command(command_name), dim=1)
    is_standing = cmd_norm <= cmd_threshold
    upright_penalty = torso_upright_l2_penalty_v2(
        env,
        asset_cfg=asset_cfg,
        target_pitch=target_pitch,
        roll_scale=roll_scale,
        pitch_scale=pitch_scale,
    )
    kernel = torch.exp(-upright_penalty / std**2)
    return kernel * is_standing.to(dtype=kernel.dtype)
def torso_linacc_xy_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Penalize the horizontal (x, y) linear acceleration of ``torso_link``.
    The world-frame torso linear acceleration is rotated into the
    heading-aligned frame (z-up, x along the robot heading), and the squared
    magnitude of its x/y components is returned. The vertical component is
    excluded, as the term name states.
    Returns: [B]
    """
    robot_ptr = env.scene[asset_cfg.name]
    torso_idx = robot_ptr.body_names.index("torso_link")
    # World-frame torso linear acceleration: [B, 3]
    torso_linacc_w = robot_ptr.data.body_lin_acc_w[:, torso_idx, :]
    # Heading-aligned frame: z-up, x-forward, y-left, defined by robot yaw heading.
    # Build yaw-only quaternion from stored heading_w (shape [B]).
    heading_yaw: torch.Tensor = robot_ptr.data.heading_w  # [B]
    zero = torch.zeros_like(heading_yaw, device=env.device)
    heading_quat_wxyz: torch.Tensor = isaaclab_math.quat_from_euler_xyz(
        roll=zero,
        pitch=zero,
        yaw=heading_yaw,
    )  # [B, 4]
    # Re-express the torso linear acceleration in the heading-aligned frame.
    # A yaw-only rotation leaves the x/y norm unchanged, so this only matters
    # if the axes are ever weighted differently.
    heading_inv_wxyz: torch.Tensor = isaaclab_math.quat_inv(heading_quat_wxyz)
    torso_linacc_h: torch.Tensor = isaaclab_math.quat_apply(
        heading_inv_wxyz,
        torso_linacc_w,
    )  # [B, 3]
    # Penalize only the horizontal components (x, y).
    penalty: torch.Tensor = torch.sum(
        torch.square(torso_linacc_h[:, :2]),
        dim=-1,
    )  # [B]
    return penalty
def track_lin_vel_xy_heading_aligned_frame_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """
    Track linear velocity (xy) in the heading-aligned frame using exponential kernel.
    The root linear velocity is rotated into the yaw-only frame of the root
    orientation and compared with the commanded (vx, vy):
        exp(-||target_xy - v_xy||^2 / std^2)
    For stand-still envs (norm of the full command <= 0.1) the target is zero.
    Unlike the ``_v2`` variant, yaw-only commands keep their commanded
    (vx, vy) as the target and no extra weighting is applied.
    Returns: [B]
    """
    asset: Articulation = env.scene[asset_cfg.name]
    vel_yaw = isaaclab_math.quat_apply_inverse(
        isaaclab_math.yaw_quat(asset.data.root_quat_w),
        asset.data.root_lin_vel_w[:, :3],
    )
    command = env.command_manager.get_command(command_name)
    stand_still_envs = torch.norm(command, dim=1) <= 0.1
    # Only stand-still envs get a zero translation target here.
    tracking_targets = torch.where(
        stand_still_envs[:, None], 0.0, command[:, :2]
    )
    lin_vel_error = torch.sum(
        torch.square(tracking_targets - vel_yaw[:, :2]),
        dim=1,
    )
    return torch.exp(-lin_vel_error / std**2)
def track_lin_vel_xy_heading_aligned_frame_exp_v2(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Weighted exp-kernel tracking of (vx, vy) in the heading-aligned frame.
    Envs with a yaw-only command (||(vx, vy)|| < 0.1 and |yaw_rate| > 0.1) or
    a stand-still command (||cmd|| <= 0.1) track a zero translation target.
    The kernel is scaled by 2.0 for yaw-only envs, 11.0 for stand-still envs,
    and 1.0 otherwise.
    Returns: [B]
    """
    robot: Articulation = env.scene[asset_cfg.name]
    heading_vel = isaaclab_math.quat_apply_inverse(
        isaaclab_math.yaw_quat(robot.data.root_quat_w),
        robot.data.root_lin_vel_w[:, :3],
    )
    cmd = env.command_manager.get_command(command_name)
    yaw_only = (torch.norm(cmd[:, :2], dim=1) < 0.1) & (cmd[:, 2].abs() > 0.1)
    standing = torch.norm(cmd, dim=1) <= 0.1
    # Both yaw-only and stand-still envs should not translate.
    hold_position = standing | yaw_only
    target_xy = torch.where(hold_position[:, None], 0.0, cmd[:, :2])
    sq_err = (target_xy - heading_vel[:, :2]).square().sum(dim=1)
    weights = torch.where(yaw_only, 2.0, 1.0) + torch.where(
        standing, 10.0, 0.0
    )
    return weights * torch.exp(-sq_err / std**2)
def track_ang_vel_z_heading_aligned_frame_exp_v2(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Weighted exp-kernel tracking of the commanded yaw rate.
    The world-frame z angular velocity equals the heading-aligned one, so the
    root angular velocity is used directly. Stand-still envs (||cmd|| <= 0.1)
    track zero yaw rate. The kernel is scaled by 2.0 for yaw-only commands
    (||(vx, vy)|| < 0.1 and |yaw_rate| > 0.1), 11.0 for stand-still envs, and
    1.0 otherwise.
    Returns: [B]
    """
    robot: Articulation = env.scene[asset_cfg.name]
    cmd = env.command_manager.get_command(command_name)
    yaw_cmd = cmd[:, 2]
    yaw_only = (torch.norm(cmd[:, :2], dim=1) < 0.1) & (yaw_cmd.abs() > 0.1)
    standing = torch.norm(cmd, dim=1) <= 0.1
    target_yaw_rate = torch.where(standing, 0.0, yaw_cmd)
    sq_err = (target_yaw_rate - robot.data.root_ang_vel_w[:, 2]).square()
    weights = torch.where(yaw_only, 2.0, 1.0) + torch.where(
        standing, 10.0, 0.0
    )
    return weights * torch.exp(-sq_err / std**2)
def track_ang_vel_z_heading_aligned_frame_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """Exp-kernel tracking of the commanded yaw rate.
    The world-frame z angular velocity equals the heading-aligned one, so the
    root angular velocity is used directly. Stand-still envs (||cmd|| <= 0.1)
    track zero yaw rate.
    Returns: [B]
    """
    robot: Articulation = env.scene[asset_cfg.name]
    cmd = env.command_manager.get_command(command_name)
    standing = torch.norm(cmd, dim=1) <= 0.1
    target_yaw_rate = torch.where(standing, 0.0, cmd[:, 2])
    sq_err = (target_yaw_rate - robot.data.root_ang_vel_w[:, 2]).square()
    return torch.exp(-sq_err / std**2)
def smoothed_track_ang_vel_z_heading_aligned_frame_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> torch.Tensor:
    """
    Track angular velocity (z) in the heading-aligned frame using exponential kernel.
    The robot yaw rate is averaged over the history held in the
    ``"unified"`` observation group's ``"rew_heading_aligned_root_ang_vel"``
    term, presumably shaped [B, T, 3] (TODO confirm). The sum over T is
    divided by ``clamp(episode_length_buf, 1, T)``, i.e. by the number of
    samples collected so far this episode. ``asset_cfg`` is accepted for
    interface compatibility but not used.
    Returns: [B]
    """
    command = env.command_manager.get_command(command_name)
    # NOTE(review): compute() evaluates every observation group. If any term
    # keeps history buffers or applies noise, calling it from a reward may
    # advance/perturb that state; verify against the observation config.
    hist_robot_heading_aligned_ang_vel_z = env.observation_manager.compute()[
        "unified"
    ]["rew_heading_aligned_root_ang_vel"][..., 2]
    obs_window_len = hist_robot_heading_aligned_ang_vel_z.shape[1]
    # Clamp to >= 1 so an env with episode_length_buf == 0 does not divide by
    # zero (which would produce a NaN reward).
    smooth_window = torch.clamp(
        env.episode_length_buf, min=1, max=obs_window_len
    )
    smoothed_robot_heading_aligned_ang_vel_z = (
        hist_robot_heading_aligned_ang_vel_z.sum(dim=1) / smooth_window
    )
    ang_vel_error = torch.square(
        command[:, 2] - smoothed_robot_heading_aligned_ang_vel_z
    )
    return torch.exp(-ang_vel_error / std**2)
def feet_air_time(
    env: ManagerBasedRLEnv,
    threshold: float,
    sensor_cfg: SceneEntityCfg,
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Reward single-stance phases, capped at ``threshold`` seconds.
    When exactly one selected foot is in contact, the value is the minimum
    over the selected feet of the time spent in their current mode (contact
    time for the stance foot, air time for the swing foot), clamped to
    ``threshold``. Otherwise it is 0. The value is zeroed when
    ``||(vx, vy)|| + |yaw_rate| <= 0.1``.
    Returns: [B]
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    ids = sensor_cfg.body_ids
    air_time = sensor.data.current_air_time[:, ids]
    contact_time = sensor.data.current_contact_time[:, ids]
    in_contact = contact_time > 0.0
    mode_time = torch.where(in_contact, contact_time, air_time)
    single_stance = in_contact.int().sum(dim=1) == 1
    gated = torch.where(single_stance.unsqueeze(-1), mode_time, 0.0)
    reward = gated.min(dim=1).values.clamp(max=threshold)
    cmd = env.command_manager.get_command(command_name)
    moving = (torch.norm(cmd[:, :2], dim=1) + cmd[:, 2].abs()) > 0.1
    return reward * moving
def feet_air_time_v2(
    env: ManagerBasedRLEnv,
    threshold: float,
    sensor_cfg: SceneEntityCfg,
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Single-stance air/contact-time reward weighted by command type.
    The base value matches `feet_air_time`: the minimum current-mode time over
    the selected feet during single stance, clamped to ``threshold``. It is
    then scaled by 0.0 for stand-still commands (||cmd|| <= 0.1), by 6.0
    (1.0 + 5.0) for yaw-only commands (||(vx, vy)|| <= 0.1 and
    |yaw_rate| > 0.1), and by 1.0 otherwise.
    Returns: [B]
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    air_time = contact_sensor.data.current_air_time[:, sensor_cfg.body_ids]
    contact_time = contact_sensor.data.current_contact_time[
        :, sensor_cfg.body_ids
    ]
    in_contact = contact_time > 0.0
    in_mode_time = torch.where(in_contact, contact_time, air_time)
    single_stance = torch.sum(in_contact.int(), dim=1) == 1
    reward = torch.min(
        torch.where(single_stance.unsqueeze(-1), in_mode_time, 0.0), dim=1
    )[0]
    reward = torch.clamp(reward, max=threshold)
    command = env.command_manager.get_command(command_name)
    stand_still_envs_flag = torch.norm(command, dim=1) <= 0.1
    ang_z_only_mask = (torch.norm(command[:, :2], dim=1) <= 0.1) & (
        torch.abs(command[:, 2]) > 0.1
    )
    # Stand still: 0.0, yaw-only: 1.0 + 5.0 = 6.0, other: 1.0
    # (the two masks are mutually exclusive).
    # NOTE(review): confirm 6.0 is the intended yaw-only weight.
    reward_weights = torch.where(
        stand_still_envs_flag, 0.0, 1.0
    ) + torch.where(ang_z_only_mask, 5.0, 0.0)
    return reward * reward_weights
def feet_air_time_v3(
    env: ManagerBasedRLEnv,
    command_name: str,
    sensor_cfg: SceneEntityCfg,
    threshold: float,
) -> torch.Tensor:
    """Reward steps whose air time exceeds ``threshold``, weighted by command.
    On the step a selected foot lands (first contact), it contributes
    ``last_air_time - threshold``, which is negative for short steps. The sum
    is scaled by 0.0 for stand-still commands (||cmd|| <= 0.1), by 5.0
    (1.0 + 4.0) for yaw-only commands (||(vx, vy)|| <= 0.1 and
    |yaw_rate| > 0.1), and by 1.0 otherwise.
    Returns: [B]
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    just_landed = sensor.compute_first_contact(env.step_dt)[
        :, sensor_cfg.body_ids
    ]
    air_time = sensor.data.last_air_time[:, sensor_cfg.body_ids]
    reward = ((air_time - threshold) * just_landed).sum(dim=1)
    cmd = env.command_manager.get_command(command_name)
    standing = torch.norm(cmd, dim=1) <= 0.1
    yaw_only = (torch.norm(cmd[:, :2], dim=1) <= 0.1) & (cmd[:, 2].abs() > 0.1)
    weights = torch.where(standing, 0.0, 1.0) + torch.where(yaw_only, 4.0, 0.0)
    return reward * weights
def feet_air_time_v4(
    env: ManagerBasedRLEnv,
    command_name: str,
    sensor_cfg: SceneEntityCfg,
    threshold: float,
) -> torch.Tensor:
    """Reward steps whose air time exceeds ``threshold``.
    On the step a selected foot lands (first contact), it contributes
    ``last_air_time - threshold``, which is negative for short steps; the
    contributions are summed over the selected feet. The value is zero for
    stand-still commands (||cmd|| <= 0.1). Unlike `feet_air_time_v3`, there
    is no extra weight for yaw-only commands.
    Returns: [B]
    """
    # extract the used quantities (to enable type-hinting)
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    # compute the reward
    first_contact = contact_sensor.compute_first_contact(env.step_dt)[
        :, sensor_cfg.body_ids
    ]
    last_air_time = contact_sensor.data.last_air_time[:, sensor_cfg.body_ids]
    reward = torch.sum((last_air_time - threshold) * first_contact, dim=1)
    # no reward for stand still commands; all other commands weigh 1.0
    commands = env.command_manager.get_command(command_name)
    stand_still_envs = torch.norm(commands, dim=1) <= 0.1
    reward_weights = torch.where(stand_still_envs, 0.0, 1.0)
    return reward * reward_weights
def yaw_rate_only_movement_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Penalize world-frame root XY translation when (vx, vy) is commanded ~0.
    The gate is ``||(vx, vy)|| <= 0.1`` only. The yaw-rate command is not
    checked, so stand-still commands are penalized as well as yaw-only ones.
    The penalty is the squared magnitude of the x/y components of the root
    linear velocity in the world frame.
    Returns: [B]
    """
    # Velocity command; channels 0..2 are read as (vx, vy, yaw_rate).
    command = env.command_manager.get_command(command_name)
    # Gate on the translational command only (no yaw-rate condition).
    yaw_only_mask: torch.Tensor = (
        torch.norm(command[:, :2], dim=1) <= 0.1
    )  # [B]
    # Penalize global (world-frame) root linear velocity in x/y.
    asset: Articulation = env.scene[asset_cfg.name]
    root_lin_vel_w: torch.Tensor = asset.data.root_lin_vel_w  # [B, 3]
    penalty: torch.Tensor = torch.sum(
        torch.square(root_lin_vel_w[:, :2]),
        dim=1,
    )  # [B]
    return penalty * yaw_only_mask.to(dtype=penalty.dtype)
def fly(
    env: ManagerBasedRLEnv,
    threshold: float,
    sensor_cfg: SceneEntityCfg,
) -> torch.Tensor:
    """Flag envs in which none of the selected bodies is touching anything.
    A body counts as in contact when the peak norm of its net contact force
    over the sensor's force history exceeds ``threshold``.
    Returns: [B] bool tensor (True = airborne).
    """
    sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    force_hist = sensor.data.net_forces_w_history[:, :, sensor_cfg.body_ids]
    peak_force = torch.norm(force_hist, dim=-1).max(dim=1).values
    touching = peak_force > threshold
    return touching.sum(dim=-1) < 0.5
def stand_still_torso_lin_vel_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Squared world-frame ``torso_link`` linear speed under stand-still commands.
    Active only when ||cmd|| <= 0.1.
    Returns: [B]
    """
    robot = env.scene[asset_cfg.name]
    torso_vel = robot.data.body_lin_vel_w[
        :, robot.body_names.index("torso_link"), :
    ]
    penalty = torso_vel.square().sum(dim=-1)
    cmd_norm = torch.norm(env.command_manager.get_command(command_name), dim=1)
    return penalty * (cmd_norm <= 0.1).to(dtype=penalty.dtype)
def stand_still_torso_ang_vel_l2_penalty(
    env: ManagerBasedRLEnv,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Squared world-frame ``torso_link`` angular speed under stand-still commands.
    Active only when ||cmd|| <= 0.1.
    Returns: [B]
    """
    robot = env.scene[asset_cfg.name]
    torso_omega = robot.data.body_ang_vel_w[
        :, robot.body_names.index("torso_link"), :
    ]
    penalty = torso_omega.square().sum(dim=-1)
    cmd_norm = torch.norm(env.command_manager.get_command(command_name), dim=1)
    return penalty * (cmd_norm <= 0.1).to(dtype=penalty.dtype)
def stand_still_torso_lin_vel_exp(
    env: ManagerBasedRLEnv,
    std: float,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Exp-kernel reward for a motionless torso while commanded to stand.
    value = exp(-||v_torso||^2 / std^2) when ||cmd|| <= 0.1, else 0, with
    ``v_torso`` the world-frame linear velocity of ``torso_link``.
    Args:
        env: Environment instance
        std: Standard deviation for exponential kernel
        asset_cfg: Robot asset configuration
        command_name: Name of velocity command
    Returns:
        Reward tensor of shape [B]
    """
    robot = env.scene[asset_cfg.name]
    torso_vel = robot.data.body_lin_vel_w[
        :, robot.body_names.index("torso_link"), :
    ]
    kernel = torch.exp(-torso_vel.square().sum(dim=-1) / std**2)
    cmd_norm = torch.norm(env.command_manager.get_command(command_name), dim=1)
    return kernel * (cmd_norm <= 0.1).to(dtype=kernel.dtype)
def stand_still_torso_ang_vel_exp(
    env: ManagerBasedRLEnv,
    std: float,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Exp-kernel reward for a non-rotating torso while commanded to stand.
    value = exp(-||omega_torso||^2 / std^2) when ||cmd|| <= 0.1, else 0, with
    ``omega_torso`` the world-frame angular velocity of ``torso_link``.
    Args:
        env: Environment instance
        std: Standard deviation for exponential kernel
        asset_cfg: Robot asset configuration
        command_name: Name of velocity command
    Returns:
        Reward tensor of shape [B]
    """
    robot = env.scene[asset_cfg.name]
    torso_omega = robot.data.body_ang_vel_w[
        :, robot.body_names.index("torso_link"), :
    ]
    kernel = torch.exp(-torso_omega.square().sum(dim=-1) / std**2)
    cmd_norm = torch.norm(env.command_manager.get_command(command_name), dim=1)
    return kernel * (cmd_norm <= 0.1).to(dtype=kernel.dtype)
def yaw_rate_only_movement_exp(
    env: ManagerBasedRLEnv,
    std: float,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    command_name: str = "base_velocity",
) -> torch.Tensor:
    """Exp-kernel reward for staying in place when (vx, vy) is commanded ~0.
    value = exp(-||v_xy||^2 / std^2) when ||(vx, vy)|| <= 0.1, else 0, where
    ``v_xy`` is the world-frame root linear velocity in x/y. The yaw-rate
    command is not part of the gate.
    Args:
        env: Environment instance
        std: Standard deviation for exponential kernel
        asset_cfg: Robot asset configuration
        command_name: Name of velocity command
    Returns:
        Reward tensor of shape [B]
    """
    cmd = env.command_manager.get_command(command_name)
    no_translation_cmd = torch.norm(cmd[:, :2], dim=1) <= 0.1
    robot: Articulation = env.scene[asset_cfg.name]
    planar_speed_sq = robot.data.root_lin_vel_w[:, :2].square().sum(dim=1)
    kernel = torch.exp(-planar_speed_sq / std**2)
    return kernel * no_translation_cmd.to(dtype=kernel.dtype)
def yaw_rate_only_hip_yaw_usage_exp(
    env: ManagerBasedRLEnv,
    std: float,
    command_name: str = "base_velocity",
    hip_yaw_dofs: list[str] | None = None,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    lin_threshold: float = 0.1,
    yaw_threshold: float = 0.1,
    command_tanh_mult: float = 1.0,
) -> torch.Tensor:
    """Encourage hip-yaw joint motion while turning in place.
    Active only for yaw-only commands: ||(vx, vy)|| <= lin_threshold and
    |yaw_rate| > yaw_threshold. The value is
        (1 - exp(-mean(qd_hip_yaw^2) / std^2)) * tanh(command_tanh_mult * |cmd_yaw|)
    where ``qd_hip_yaw`` are the velocities of the joints named in
    ``hip_yaw_dofs``.
    Raises:
        ValueError: if ``hip_yaw_dofs`` is None.
    Returns: [B]
    """
    command: torch.Tensor = env.command_manager.get_command(command_name)
    yaw_rate_cmd = command[:, 2]
    turning_in_place = (
        torch.norm(command[:, :2], dim=1) <= lin_threshold
    ) & (yaw_rate_cmd.abs() > yaw_threshold)
    asset: Articulation = env.scene[asset_cfg.name]
    if hip_yaw_dofs is None:
        raise ValueError(
            "yaw_rate_only_hip_yaw_usage_exp requires hip_yaw_dofs (joint names in "
            f"robot.joint_names). Got hip_yaw_dofs=None. robot.joint_names={asset.joint_names}"
        )
    joint_ids = _get_dof_indices(asset, hip_yaw_dofs)
    mean_sq_vel = asset.data.joint_vel[:, joint_ids].square().mean(dim=-1)
    usage = 1.0 - torch.exp(-mean_sq_vel / std**2)
    reward = usage * torch.tanh(command_tanh_mult * yaw_rate_cmd.abs())
    return reward * turning_in_place.to(dtype=reward.dtype)
@configclass
class RewardsCfg:
    """Empty reward container; `build_rewards_config` attaches terms via setattr."""
class TaskGatedReward:
    """Callable wrapper to gate reward terms by task_id.
    The wrapped term is evaluated only when at least one env belongs to
    ``task_name`` (per ``env.holo_task_ids`` / ``env.holo_task_name_to_id``),
    and its output is zeroed for envs of other tasks. If the env exposes no
    task information, or the task is unknown, zeros are returned.
    Positional/keyword params may be supplied nested under the ``"args"`` and
    ``"kwargs"`` keys; they are appended/merged before calling the term.
    NOTE(review): ``self.func`` is called directly, so wrapping a
    ManagerTermBase subclass (e.g. a stateful class term) would construct it
    on every call; confirm only plain functions are wrapped.
    """
    def __init__(self, func, task_name: str):
        self.func = func
        self.task_name = task_name
        self.__name__ = f"TaskGatedReward[{task_name}]"
    def __call__(self, env: ManagerBasedRLEnv, *args, **kwargs):
        task_ids = getattr(env, "holo_task_ids", None)
        name_to_id = getattr(env, "holo_task_name_to_id", None)
        target_id = None
        if task_ids is not None and name_to_id is not None:
            target_id = name_to_id.get(self.task_name, None)
        if target_id is None:
            return torch.zeros(env.num_envs, device=env.device)
        in_task = task_ids == target_id
        if not torch.any(in_task):
            return torch.zeros(env.num_envs, device=env.device)
        nested_args = kwargs.pop("args", None)
        nested_kwargs = kwargs.pop("kwargs", None)
        if nested_args is not None:
            args = (*args, *nested_args)
        if nested_kwargs is not None:
            kwargs = {**kwargs, **nested_kwargs}
        value = self.func(env, *args, **kwargs)
        return value * in_task.to(device=value.device, dtype=value.dtype)
def _lookup_reward_func(reward_name: str):
    """Resolve a reward term by name: this module first, then isaaclab_mdp."""
    func = globals().get(reward_name, None)
    if func is None:
        func = getattr(isaaclab_mdp, reward_name, None)
    if func is None:
        raise ValueError(f"Unknown reward function: {reward_name}")
    return func
def _is_grouped_reward_layout(cfg: dict) -> bool:
    """Decide the layout from the first entry other than ``"_config"``.
    Flat if that entry is a dict with a ``"weight"`` key; grouped otherwise.
    An empty mapping counts as flat.
    """
    for key, value in cfg.items():
        if key == "_config":
            continue
        return not (isinstance(value, dict) and "weight" in value)
    return False
def build_rewards_config(reward_config_dict: dict):
    """Build a ``RewardsCfg`` from a reward configuration mapping.
    Two layouts are accepted:
    * flat (legacy): ``{term_name: {"weight": w, "params": {...}}, ...}``.
      Each term is set as attribute ``term_name``; a top-level ``"_config"``
      key is skipped.
    * grouped (multi-task): ``{task_name: {term_name: {...}}, ...}``. Each
      term is set as attribute ``"<task_name>.<term_name>"``. Terms outside
      the ``"common"`` group are wrapped in ``TaskGatedReward`` so they only
      contribute for envs of that task. Task keys starting with ``"_"`` are
      skipped.
    OmegaConf containers are converted to plain Python objects first.
    Raises:
        ValueError: unknown reward function name, or a non-dict task group.
    """
    if isinstance(reward_config_dict, (DictConfig, ListConfig)):
        reward_config_dict = OmegaConf.to_container(
            reward_config_dict, resolve=True
        )
    rewards_cfg = RewardsCfg()
    if not _is_grouped_reward_layout(reward_config_dict):
        for reward_name, reward_cfg in reward_config_dict.items():
            if reward_name == "_config":
                continue
            reward_cfg = resolve_holo_config(reward_cfg)
            base_params = resolve_holo_config(reward_cfg["params"])
            func = _lookup_reward_func(reward_name)
            setattr(
                rewards_cfg,
                reward_name,
                RewardTermCfg(
                    func=func,
                    weight=reward_cfg["weight"],
                    params=dict(base_params),
                ),
            )
        return rewards_cfg
    # Grouped: rewards: {task_name: {term: ...}}
    for task_name, task_group in reward_config_dict.items():
        if task_name.startswith("_"):
            continue
        if not isinstance(task_group, dict):
            raise ValueError(f"Expected dict for task group {task_name}")
        for reward_name, reward_cfg in task_group.items():
            reward_cfg = resolve_holo_config(reward_cfg)
            base_params = resolve_holo_config(reward_cfg["params"])
            func = _lookup_reward_func(reward_name)
            if task_name != "common":
                # TaskGatedReward.__call__(env, *args, **kwargs) unpacks the
                # real params from the "args"/"kwargs" keys at call time;
                # presumably this keeps Isaac Lab's term-signature check
                # satisfied (verify).
                func = TaskGatedReward(func, task_name)
                params = {"args": [], "kwargs": base_params}
            else:
                params = base_params
            setattr(
                rewards_cfg,
                f"{task_name}.{reward_name}",
                RewardTermCfg(
                    func=func,
                    weight=reward_cfg["weight"],
                    params=params,
                ),
            )
    return rewards_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_scene.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import copy
import os
import time
from dataclasses import MISSING
import isaaclab.sim as sim_utils
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.assets import ArticulationCfg, AssetBaseCfg
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.sensors import ContactSensorCfg, RayCasterCfg, patterns
from isaaclab.terrains import TerrainImporterCfg
from isaaclab.utils import configclass
from loguru import logger
from holomotion.src.env.isaaclab_components.isaaclab_terrain import (
build_terrain_config,
)
from holomotion.src.env.isaaclab_components.unitree_actuators import (
UnitreeActuator,
UnitreeActuatorCfg,
UnitreeErfiActuator,
UnitreeErfiActuatorCfg,
)
class SceneFunctions:
    """Collection of static builders for individual scene components."""

    @staticmethod
    def build_robot_config(
        config: dict,
        domain_rand_config: dict | None = None,
        main_process: bool = True,
        process_id: int = 0,
        num_processes: int = 1,
    ) -> ArticulationCfg:
        """Build robot articulation configuration.

        The URDF is converted to USD through ``sim_utils.UrdfFileCfg``.

        USD output directory:
        - In multi-process runs, each rank writes into ``usd/rank_<id>`` by
          default and always forces conversion, so ranks never share files.
        - Setting ``config.asset.unique_usd_per_rank`` to False shares a
          single ``usd`` directory. Then only the main process converts, and
          the other ranks poll (up to 60 s) for ``<urdf_basename>.usd`` to
          appear.

        Args:
            config: Robot configuration. It is accessed with attribute syntax
                (e.g. ``config.asset.urdf_file``) and ``.get``; presumably an
                OmegaConf DictConfig (TODO confirm).
            domain_rand_config: Domain-randomization config. It is only used
                when ``config.actuators.actuator_type`` is ``"unitree"`` or
                ``"unitree_erfi"``; ``None`` is treated as empty.
            main_process: Whether this is the main process (from compiled config)
            process_id: Process ID/rank (from compiled config)
            num_processes: Total number of processes (from compiled config)

        Returns:
            ArticulationCfg spawning the robot at ``{ENV_REGEX_NS}/Robot``.

        Raises:
            FileNotFoundError: If the URDF file does not exist.
        """
        urdf_path = config.asset.urdf_file
        init_pos = config.init_state.pos
        default_joint_positions = config.init_state.default_joint_angles
        root_link_name = config.get("root_name", "pelvis")
        prim_path = "{ENV_REGEX_NS}/Robot"

        actuator_type = config.actuators.get("actuator_type", "implicit")
        if actuator_type in {"unitree", "unitree_erfi"}:
            actuators = _build_unitree_actuator_cfg(
                config.actuators, domain_rand_config or {}
            )
        else:
            actuators = {
                "all_joints": ImplicitActuatorCfg(
                    **config.actuators.all_joints
                )
            }
        logger.info(f"Using {actuator_type} actuators")
        logger.info(f"Actuators: {actuators}")

        if not os.path.exists(urdf_path):
            raise FileNotFoundError(f"URDF file not found: {urdf_path}")

        # Configure USD output directory. Optionally isolate per rank to
        # avoid races between concurrent URDF->USD conversions. This used to
        # be hard-coded to True; it is now configurable with the same default.
        usd_base_dir = os.path.join(os.path.dirname(urdf_path), "usd")
        unique_usd_per_rank = bool(
            config.asset.get("unique_usd_per_rank", True)
        )
        if num_processes > 1 and unique_usd_per_rank:
            usd_dir = os.path.join(usd_base_dir, f"rank_{process_id}")
        else:
            usd_dir = usd_base_dir
        os.makedirs(usd_dir, exist_ok=True)
        logger.info(f"Using URDF path: {urdf_path}")
        logger.info(f"Using USD directory: {usd_dir}")

        force_usd_conversion = config.asset.get("force_usd_conversion", True)
        if num_processes > 1 and unique_usd_per_rank:
            # Ensure each rank generates its own USD into its isolated directory
            force_usd_conversion = True

        # Handle DDP
        if num_processes > 1:
            logger.info(
                f"[Process {process_id}/{num_processes}] Distributed training detected"
            )
            if unique_usd_per_rank:
                logger.info(
                    f"[Process {process_id}] Using per-rank USD dir; forcing USD conversion: {force_usd_conversion}"
                )
            elif main_process:
                # Only main process should convert USD to avoid file conflicts
                logger.info(
                    f"[Process {process_id}] Main process - Force USD conversion: {force_usd_conversion}"
                )
            else:
                logger.info(
                    f"[Process {process_id}] Non-main process - Skipping USD conversion, waiting for main process"
                )
                force_usd_conversion = False
                # Wait for USD files to be created by main process.
                # NOTE(review): if a USD file from an earlier run is present,
                # this returns immediately even while the main process is
                # still re-converting it.
                urdf_basename = os.path.splitext(os.path.basename(urdf_path))[
                    0
                ]
                expected_usd_file = os.path.join(
                    usd_dir, f"{urdf_basename}.usd"
                )
                logger.info(
                    f"[Process {process_id}] Waiting for main process to create USD files at {expected_usd_file}..."
                )
                max_wait = 60
                wait_interval = 1
                waited = 0
                while (
                    not os.path.exists(expected_usd_file)
                    and waited < max_wait
                ):
                    time.sleep(wait_interval)
                    waited += wait_interval
                if os.path.exists(expected_usd_file):
                    logger.info(
                        f"[Process {process_id}] USD file found, proceeding with loading"
                    )
                else:
                    logger.warning(
                        f"[Process {process_id}] USD file not found after {max_wait}s, proceeding anyway"
                    )
        else:
            logger.info(
                f"Single process training. Force USD conversion: {force_usd_conversion}"
            )

        articulation_cfg = ArticulationCfg(
            prim_path=prim_path,
            spawn=sim_utils.UrdfFileCfg(
                asset_path=os.path.abspath(urdf_path),
                usd_dir=os.path.abspath(usd_dir),
                force_usd_conversion=force_usd_conversion,
                fix_base=False,
                merge_fixed_joints=True,
                root_link_name=root_link_name,
                replace_cylinders_with_capsules=True,
                activate_contact_sensors=True,
                rigid_props=sim_utils.RigidBodyPropertiesCfg(
                    disable_gravity=False,
                    retain_accelerations=False,
                    linear_damping=0.0,
                    angular_damping=0.0,
                    max_linear_velocity=1000.0,
                    max_angular_velocity=1000.0,
                    max_depenetration_velocity=1.0,
                ),
                articulation_props=sim_utils.ArticulationRootPropertiesCfg(
                    enabled_self_collisions=True,
                    solver_position_iteration_count=8,
                    solver_velocity_iteration_count=4,
                ),
                # Drive gains in the converted USD are zeroed; the actuator
                # configs above supply the PD behavior instead.
                joint_drive=sim_utils.UrdfConverterCfg.JointDriveCfg(
                    gains=sim_utils.UrdfConverterCfg.JointDriveCfg.PDGainsCfg(
                        stiffness=0,
                        damping=0,
                    )
                ),
            ),
            init_state=ArticulationCfg.InitialStateCfg(
                pos=init_pos,
                joint_pos=default_joint_positions,
                joint_vel={".*": 0.0},
            ),
            soft_joint_pos_limit_factor=0.9,
            actuators=actuators,
        )
        return articulation_cfg

    @staticmethod
    def build_lighting_config(
        config: dict,
    ) -> tuple[AssetBaseCfg, AssetBaseCfg]:
        """Build lighting configuration.

        Args:
            config: Mapping with optional ``distant_light_intensity``
                (default 3000.0), ``dome_light_intensity`` (default 1000.0),
                ``distant_light_color`` and ``dome_light_color``.

        Returns:
            ``(distant_light, dome_light)`` assets at ``/World/light`` and
            ``/World/skyLight``.
        """
        distant_light_intensity = config.get("distant_light_intensity", 3000.0)
        dome_light_intensity = config.get("dome_light_intensity", 1000.0)
        distant_light_color = config.get(
            "distant_light_color", (0.75, 0.75, 0.75)
        )
        dome_light_color = config.get("dome_light_color", (0.13, 0.13, 0.13))
        light = AssetBaseCfg(
            prim_path="/World/light",
            spawn=sim_utils.DistantLightCfg(
                color=distant_light_color, intensity=distant_light_intensity
            ),
        )
        sky_light = AssetBaseCfg(
            prim_path="/World/skyLight",
            spawn=sim_utils.DomeLightCfg(
                color=dome_light_color, intensity=dome_light_intensity
            ),
        )
        return light, sky_light

    @staticmethod
    def build_contact_sensor_config(config: dict) -> ContactSensorCfg:
        """Build contact sensor configuration.

        Args:
            config: Mapping with optional ``prim_path`` (default: every prim
                under the robot), ``history_length`` (3),
                ``force_threshold`` (10.0), ``track_air_time`` (True) and
                ``debug_vis`` (False).

        Returns:
            The corresponding ``ContactSensorCfg``.
        """
        prim_path = config.get("prim_path", "{ENV_REGEX_NS}/Robot/.*")
        history_length = config.get("history_length", 3)
        force_threshold = config.get("force_threshold", 10.0)
        track_air_time = config.get("track_air_time", True)
        debug_vis = config.get("debug_vis", False)
        return ContactSensorCfg(
            prim_path=prim_path,
            history_length=history_length,
            track_air_time=track_air_time,
            force_threshold=force_threshold,
            debug_vis=debug_vis,
        )
@configclass
class MotionTrackingSceneCfg(InteractiveSceneCfg):
    """Scene configuration for motion tracking environment.

    Declares no fields of its own. ``build_scene_config`` attaches the robot,
    terrain, sensors and lights to an instance dynamically.
    """
def build_scene_config(
    scene_config_dict: dict,
    main_process: bool = True,
    process_id: int = 0,
    num_processes: int = 1,
) -> MotionTrackingSceneCfg:
    """Build IsaacLab-compatible scene configuration from config dictionary.

    Each optional section is only added when its key is present:
    ``robot``, ``terrain``, ``lighting`` and ``contact_sensor``.

    The robot-mounted ``height_scanner`` ray caster targets the terrain prim,
    so it is added only when BOTH ``robot`` and ``terrain`` are configured.
    Without a terrain section, ``scene_cfg.terrain`` would not exist.

    Args:
        scene_config_dict: Scene configuration dictionary
        main_process: Whether this is the main process (from compiled config)
        process_id: Process ID/rank (from compiled config)
        num_processes: Total number of processes (from compiled config)

    Returns:
        The populated ``MotionTrackingSceneCfg``.
    """
    scene_cfg = MotionTrackingSceneCfg()

    # Basic scene properties
    scene_cfg.num_envs = scene_config_dict.get("num_envs", MISSING)
    scene_cfg.env_spacing = scene_config_dict.get("env_spacing", 2.5)
    scene_cfg.replicate_physics = scene_config_dict.get(
        "replicate_physics", True
    )

    has_robot = "robot" in scene_config_dict
    has_terrain = "terrain" in scene_config_dict

    # Build robot configuration with process info
    if has_robot:
        scene_cfg.robot = SceneFunctions.build_robot_config(
            scene_config_dict["robot"],
            domain_rand_config=scene_config_dict.get("domain_rand", {}),
            main_process=main_process,
            process_id=process_id,
            num_processes=num_processes,
        )

    # Build terrain configuration
    if has_terrain:
        scene_cfg.terrain = build_terrain_config(
            scene_config_dict["terrain"],
            scene_env_spacing=scene_cfg.env_spacing,
        )

    # Single downward ray from 1 m above the robot root, cast against the
    # terrain mesh. It needs both the robot and the terrain prim path.
    if has_robot and has_terrain:
        scene_cfg.height_scanner = RayCasterCfg(
            prim_path="{ENV_REGEX_NS}/Robot",
            offset=RayCasterCfg.OffsetCfg(pos=(0.0, 0.0, 1.0)),
            ray_alignment="world",
            pattern_cfg=patterns.GridPatternCfg(
                resolution=1.0, size=(1.0e-3, 1.0e-3)
            ),
            debug_vis=False,
            mesh_prim_paths=[str(scene_cfg.terrain.prim_path)],
            max_distance=1.0e6,
        )

    # Build lighting configuration
    if "lighting" in scene_config_dict:
        light, sky_light = SceneFunctions.build_lighting_config(
            scene_config_dict["lighting"]
        )
        scene_cfg.light = light
        scene_cfg.sky_light = sky_light

    # Build contact sensor configuration
    if "contact_sensor" in scene_config_dict:
        scene_cfg.contact_forces = SceneFunctions.build_contact_sensor_config(
            scene_config_dict["contact_sensor"]
        )
    return scene_cfg
def _cfg_to_kwargs(cfg: object) -> dict:
return {
key: copy.deepcopy(value)
for key, value in vars(cfg).items()
if not key.startswith("_")
}
def _build_unitree_actuator_cfg(
    config: dict, domain_rand_config: dict
) -> dict[str, object]:
    """Build the ``{"all_joints": cfg}`` actuator mapping for Unitree motors.

    Starts from deep copies of the fields of
    ``unitree_actuator_config_hardcoded["all_joints"]`` and overrides:

    * ``min_delay`` / ``max_delay`` from ``domain_rand_config["action_delay"]``
      when that section has ``enabled: True``; both are 0 otherwise.
    * For ``config["actuator_type"] == "unitree_erfi"``, the EMA-filter
      settings from ``config`` and the ERFI settings from
      ``domain_rand_config["erfi"]``.

    A ``None`` ``domain_rand_config``, or an ``action_delay`` / ``erfi`` entry
    that is missing or null (e.g. a YAML key with no value), is treated as an
    empty section, i.e. disabled with defaults.

    Args:
        config: Actuator config section. Reads ``actuator_type`` (default
            ``"unitree"``), ``ema_filter_enabled`` and ``ema_filter_alpha``.
        domain_rand_config: Domain-randomization config.

    Returns:
        ``{"all_joints": cfg}``, where ``cfg`` is a ``UnitreeErfiActuatorCfg``
        (class_type ``UnitreeErfiActuator``) or a ``UnitreeActuatorCfg``
        (class_type ``UnitreeActuator``).
    """
    domain_rand_config = domain_rand_config or {}
    base_cfg = unitree_actuator_config_hardcoded["all_joints"]
    # Deep copies, so the module-level default config is never mutated.
    base_kwargs = _cfg_to_kwargs(base_cfg)

    # `or {}` also covers an explicit null section.
    action_delay_cfg = copy.deepcopy(
        domain_rand_config.get("action_delay") or {}
    )
    if action_delay_cfg.get("enabled", False):
        delay_kwargs = {
            "min_delay": int(action_delay_cfg.get("min_delay", 0)),
            "max_delay": int(action_delay_cfg.get("max_delay", 0)),
        }
    else:
        delay_kwargs = {"min_delay": 0, "max_delay": 0}

    if config.get("actuator_type", "unitree") == "unitree_erfi":
        erfi_cfg = copy.deepcopy(domain_rand_config.get("erfi") or {})
        actuator_filter_kwargs = {
            "ema_filter_enabled": bool(config.get("ema_filter_enabled", False)),
            "ema_filter_alpha": config.get("ema_filter_alpha", 1.0),
        }
        erfi_kwargs = {
            "erfi_enabled": bool(erfi_cfg.get("enabled", False)),
            "rfi_probability": erfi_cfg.get("rfi_probability", 0.5),
            "rfi_lim": erfi_cfg.get("rfi_lim", 0.1),
            "randomize_rfi_lim": erfi_cfg.get("randomize_rfi_lim", True),
            "rfi_lim_range": erfi_cfg.get("rfi_lim_range", (0.5, 1.5)),
            "rao_lim": erfi_cfg.get("rao_lim", 0.1),
        }
        actuator_kwargs = {
            **base_kwargs,
            **delay_kwargs,
            **actuator_filter_kwargs,
            **erfi_kwargs,
        }
        actuator_cfg = UnitreeErfiActuatorCfg(**actuator_kwargs)
        actuator_cfg.class_type = UnitreeErfiActuator
    else:
        actuator_kwargs = {**base_kwargs, **delay_kwargs}
        actuator_cfg = UnitreeActuatorCfg(**actuator_kwargs)
        actuator_cfg.class_type = UnitreeActuator
    return {"all_joints": actuator_cfg}
# Default Unitree actuator parameters, keyed by joint-name regex.
# The patterns (left/right ".*_" legs and arms plus three waist joints)
# presumably describe the Unitree G1 29-DoF robot; TODO confirm.
#
# _build_unitree_actuator_cfg deep-copies the fields of the "all_joints"
# entry and overrides min_delay / max_delay (and, for "unitree_erfi", the
# EMA-filter and ERFI fields). This object is read there, not mutated.
unitree_actuator_config_hardcoded = {
    "all_joints": UnitreeActuatorCfg(
        joint_names_expr=[
            ".*_hip_yaw_joint",
            ".*_hip_roll_joint",
            ".*_hip_pitch_joint",
            ".*_knee_joint",
            ".*_ankle_pitch_joint",
            ".*_ankle_roll_joint",
            "waist_roll_joint",
            "waist_pitch_joint",
            "waist_yaw_joint",
            ".*_shoulder_pitch_joint",
            ".*_shoulder_roll_joint",
            ".*_shoulder_yaw_joint",
            ".*_elbow_joint",
            ".*_wrist_roll_joint",
            ".*_wrist_pitch_joint",
            ".*_wrist_yaw_joint",
        ],
        # No action delay by default; overridden from domain_rand.action_delay.
        min_delay=0,
        max_delay=0,
        # Per-joint limits. Units are not stated here (presumably N*m and
        # rad/s; confirm against the actuator implementation).
        effort_limit={
            ".*_hip_yaw_joint": 88,
            ".*_hip_roll_joint": 139,
            ".*_hip_pitch_joint": 88,
            ".*_knee_joint": 139,
            ".*_ankle_pitch_joint": 50,
            ".*_ankle_roll_joint": 50,
            "waist_roll_joint": 50,
            "waist_pitch_joint": 50,
            "waist_yaw_joint": 88,
            ".*_shoulder_pitch_joint": 25,
            ".*_shoulder_roll_joint": 25,
            ".*_shoulder_yaw_joint": 25,
            ".*_elbow_joint": 25,
            ".*_wrist_roll_joint": 25,
            ".*_wrist_pitch_joint": 5,
            ".*_wrist_yaw_joint": 5,
        },
        velocity_limit={
            ".*_hip_yaw_joint": 32,
            ".*_hip_roll_joint": 20,
            ".*_hip_pitch_joint": 32,
            ".*_knee_joint": 20,
            ".*_ankle_pitch_joint": 37,
            ".*_ankle_roll_joint": 37,
            "waist_roll_joint": 37,
            "waist_pitch_joint": 37,
            "waist_yaw_joint": 32,
            ".*_shoulder_pitch_joint": 37,
            ".*_shoulder_roll_joint": 37,
            ".*_shoulder_yaw_joint": 37,
            ".*_elbow_joint": 37,
            ".*_wrist_roll_joint": 37,
            ".*_wrist_pitch_joint": 22,
            ".*_wrist_yaw_joint": 22,
        },
        # PD gains. Numerically, each stiffness equals
        # armature * (2*pi*10)**2 and each damping equals
        # 2 * 2 * armature * (2*pi*10). That corresponds to a 10 Hz natural
        # frequency with damping ratio 2 (derived from the values below;
        # confirm that this is the intended design). Keep stiffness, damping
        # and armature consistent when editing any of them.
        stiffness={
            ".*_hip_yaw_joint": 40.1792384737,
            ".*_hip_roll_joint": 99.0984277823,
            ".*_hip_pitch_joint": 40.1792384737,
            ".*_knee_joint": 99.0984277823,
            ".*_ankle_pitch_joint": 28.5012461974,
            ".*_ankle_roll_joint": 28.5012461974,
            "waist_roll_joint": 28.5012461974,
            "waist_pitch_joint": 28.5012461974,
            "waist_yaw_joint": 40.1792384737,
            ".*_shoulder_pitch_joint": 14.2506230987,
            ".*_shoulder_roll_joint": 14.2506230987,
            ".*_shoulder_yaw_joint": 14.2506230987,
            ".*_elbow_joint": 14.2506230987,
            ".*_wrist_roll_joint": 14.2506230987,
            ".*_wrist_pitch_joint": 16.7783274819,
            ".*_wrist_yaw_joint": 16.7783274819,
        },
        damping={
            ".*_hip_yaw_joint": 2.5578897651,
            ".*_hip_roll_joint": 6.30880185368,
            ".*_hip_pitch_joint": 2.5578897651,
            ".*_knee_joint": 6.30880185368,
            ".*_ankle_pitch_joint": 1.81444568664,
            ".*_ankle_roll_joint": 1.81444568664,
            "waist_roll_joint": 1.81444568664,
            "waist_pitch_joint": 1.81444568664,
            "waist_yaw_joint": 2.5578897651,
            ".*_shoulder_pitch_joint": 0.907222843318,
            ".*_shoulder_roll_joint": 0.907222843318,
            ".*_shoulder_yaw_joint": 0.907222843318,
            ".*_elbow_joint": 0.907222843318,
            ".*_wrist_roll_joint": 0.907222843318,
            ".*_wrist_pitch_joint": 1.06814150222,
            ".*_wrist_yaw_joint": 1.06814150222,
        },
        # Reflected rotor inertia per joint group; the gains above are
        # derived from these values.
        armature={
            ".*_hip_yaw_joint": 0.01017752,
            ".*_hip_roll_joint": 0.025101925,
            ".*_hip_pitch_joint": 0.01017752,
            ".*_knee_joint": 0.025101925,
            ".*_ankle_pitch_joint": 0.00721945,
            ".*_ankle_roll_joint": 0.00721945,
            "waist_roll_joint": 0.00721945,
            "waist_pitch_joint": 0.00721945,
            "waist_yaw_joint": 0.01017752,
            ".*_shoulder_pitch_joint": 0.003609725,
            ".*_shoulder_roll_joint": 0.003609725,
            ".*_shoulder_yaw_joint": 0.003609725,
            ".*_elbow_joint": 0.003609725,
            ".*_wrist_roll_joint": 0.003609725,
            ".*_wrist_pitch_joint": 0.00425,
            ".*_wrist_yaw_joint": 0.00425,
        },
        # Built-in joint friction terms are zeroed. NOTE(review): presumably
        # friction is modeled by the Fs/Fd parameters below instead; confirm.
        friction=0,
        dynamic_friction=0,
        viscous_friction=0,
        # X1/X2/Y1/Y2/Fs/Fd/Va are Unitree-specific parameters consumed by
        # UnitreeActuator (defined in unitree_actuators.py, not visible
        # here). NOTE(review): the values suggest X* are speed breakpoints
        # and Y* torque levels of a torque-speed curve (Y1 < Y2 <=
        # effort_limit), Fs/Fd static/dynamic friction (Fd = Fs / 10 here),
        # and Va a small velocity constant. TODO confirm and document.
        X1={
            ".*_hip_yaw_joint": 22.63,
            ".*_hip_roll_joint": 14.5,
            ".*_hip_pitch_joint": 22.63,
            ".*_knee_joint": 14.5,
            ".*_ankle_pitch_joint": 30.86,
            ".*_ankle_roll_joint": 30.86,
            "waist_roll_joint": 30.86,
            "waist_pitch_joint": 30.86,
            "waist_yaw_joint": 22.63,
            ".*_shoulder_pitch_joint": 30.86,
            ".*_shoulder_roll_joint": 30.86,
            ".*_shoulder_yaw_joint": 30.86,
            ".*_elbow_joint": 30.86,
            ".*_wrist_roll_joint": 30.86,
            ".*_wrist_pitch_joint": 15.3,
            ".*_wrist_yaw_joint": 15.3,
        },
        X2={
            ".*_hip_yaw_joint": 35.52,
            ".*_hip_roll_joint": 22.7,
            ".*_hip_pitch_joint": 35.52,
            ".*_knee_joint": 22.7,
            ".*_ankle_pitch_joint": 40.13,
            ".*_ankle_roll_joint": 40.13,
            "waist_roll_joint": 40.13,
            "waist_pitch_joint": 40.13,
            "waist_yaw_joint": 35.52,
            ".*_shoulder_pitch_joint": 40.13,
            ".*_shoulder_roll_joint": 40.13,
            ".*_shoulder_yaw_joint": 40.13,
            ".*_elbow_joint": 40.13,
            ".*_wrist_roll_joint": 40.13,
            ".*_wrist_pitch_joint": 24.76,
            ".*_wrist_yaw_joint": 24.76,
        },
        Y1={
            ".*_hip_yaw_joint": 71,
            ".*_hip_roll_joint": 111,
            ".*_hip_pitch_joint": 71,
            ".*_knee_joint": 111,
            ".*_ankle_pitch_joint": 24.8,
            ".*_ankle_roll_joint": 24.8,
            "waist_roll_joint": 24.8,
            "waist_pitch_joint": 24.8,
            "waist_yaw_joint": 71,
            ".*_shoulder_pitch_joint": 24.8,
            ".*_shoulder_roll_joint": 24.8,
            ".*_shoulder_yaw_joint": 24.8,
            ".*_elbow_joint": 24.8,
            ".*_wrist_roll_joint": 24.8,
            ".*_wrist_pitch_joint": 4.8,
            ".*_wrist_yaw_joint": 4.8,
        },
        Y2={
            ".*_hip_yaw_joint": 83.3,
            ".*_hip_roll_joint": 131,
            ".*_hip_pitch_joint": 83.3,
            ".*_knee_joint": 131,
            ".*_ankle_pitch_joint": 31.9,
            ".*_ankle_roll_joint": 31.9,
            "waist_roll_joint": 31.9,
            "waist_pitch_joint": 31.9,
            "waist_yaw_joint": 83.3,
            ".*_shoulder_pitch_joint": 31.9,
            ".*_shoulder_roll_joint": 31.9,
            ".*_shoulder_yaw_joint": 31.9,
            ".*_elbow_joint": 31.9,
            ".*_wrist_roll_joint": 31.9,
            ".*_wrist_pitch_joint": 8.6,
            ".*_wrist_yaw_joint": 8.6,
        },
        Fs={
            ".*_hip_yaw_joint": 1.6,
            ".*_hip_roll_joint": 2.4,
            ".*_hip_pitch_joint": 1.6,
            ".*_knee_joint": 2.4,
            ".*_ankle_pitch_joint": 0.6,
            ".*_ankle_roll_joint": 0.6,
            "waist_roll_joint": 0.6,
            "waist_pitch_joint": 0.6,
            "waist_yaw_joint": 1.6,
            ".*_shoulder_pitch_joint": 0.6,
            ".*_shoulder_roll_joint": 0.6,
            ".*_shoulder_yaw_joint": 0.6,
            ".*_elbow_joint": 0.6,
            ".*_wrist_roll_joint": 0.6,
            ".*_wrist_pitch_joint": 0.6,
            ".*_wrist_yaw_joint": 0.6,
        },
        Fd={
            ".*_hip_yaw_joint": 0.16,
            ".*_hip_roll_joint": 0.24,
            ".*_hip_pitch_joint": 0.16,
            ".*_knee_joint": 0.24,
            ".*_ankle_pitch_joint": 0.06,
            ".*_ankle_roll_joint": 0.06,
            "waist_roll_joint": 0.06,
            "waist_pitch_joint": 0.06,
            "waist_yaw_joint": 0.16,
            ".*_shoulder_pitch_joint": 0.06,
            ".*_shoulder_roll_joint": 0.06,
            ".*_shoulder_yaw_joint": 0.06,
            ".*_elbow_joint": 0.06,
            ".*_wrist_roll_joint": 0.06,
            ".*_wrist_pitch_joint": 0.06,
            ".*_wrist_yaw_joint": 0.06,
        },
        Va={
            ".*_hip_yaw_joint": 0.01,
            ".*_hip_roll_joint": 0.01,
            ".*_hip_pitch_joint": 0.01,
            ".*_knee_joint": 0.01,
            ".*_ankle_pitch_joint": 0.01,
            ".*_ankle_roll_joint": 0.01,
            "waist_roll_joint": 0.01,
            "waist_pitch_joint": 0.01,
            "waist_yaw_joint": 0.01,
            ".*_shoulder_pitch_joint": 0.01,
            ".*_shoulder_roll_joint": 0.01,
            ".*_shoulder_yaw_joint": 0.01,
            ".*_elbow_joint": 0.01,
            ".*_wrist_roll_joint": 0.01,
            ".*_wrist_pitch_joint": 0.01,
            ".*_wrist_yaw_joint": 0.01,
        },
    )
}
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_simulator.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from isaaclab.sim import SimulationCfg, PhysxCfg
def build_simulator_config(sim_config_dict: dict) -> SimulationCfg:
    """Build simulation configuration from config dictionary.

    Args:
        sim_config_dict: Mapping with optional keys:
            - ``policy_freq`` (default 50)
            - ``sim_freq`` (default 200)
            - ``device`` (default ``"cuda"``)
            - ``physx``: sub-mapping read for ``bounce_threshold_velocity``
              (default 0.2) and ``gpu_max_rigid_patch_count``
              (default ``10 * 2**15``).

    Returns:
        ``SimulationCfg`` with ``dt = 1 / sim_freq`` and
        ``render_interval = int(sim_freq / policy_freq)``, i.e. the number of
        physics steps per policy step.

    Raises:
        ValueError: If either frequency is not positive, or if ``sim_freq``
            is lower than ``policy_freq``. The interval would truncate to 0.

    A ``sim_freq`` that is not an integer multiple of ``policy_freq`` is
    accepted but triggers a warning, because the interval is truncated.
    """
    import math
    import warnings

    policy_freq = sim_config_dict.get("policy_freq", 50)
    sim_freq = sim_config_dict.get("sim_freq", 200)
    if policy_freq <= 0 or sim_freq <= 0:
        raise ValueError(
            "policy_freq and sim_freq must be positive, got "
            f"policy_freq={policy_freq}, sim_freq={sim_freq}"
        )

    decimation = int(sim_freq / policy_freq)
    if decimation < 1:
        raise ValueError(
            f"sim_freq ({sim_freq}) must be >= policy_freq ({policy_freq}); "
            "otherwise render_interval would be 0"
        )
    if not math.isclose(decimation * policy_freq, sim_freq):
        warnings.warn(
            f"sim_freq={sim_freq} is not an integer multiple of "
            f"policy_freq={policy_freq}; render_interval truncated to "
            f"{decimation}",
            stacklevel=2,
        )
    dt = 1.0 / sim_freq
    device = sim_config_dict.get("device", "cuda")

    # PhysX configuration
    physx_config = sim_config_dict.get("physx", {})
    physx = PhysxCfg(
        bounce_threshold_velocity=physx_config.get(
            "bounce_threshold_velocity", 0.2
        ),
        gpu_max_rigid_patch_count=physx_config.get(
            "gpu_max_rigid_patch_count", int(10 * 2**15)
        ),
    )
    return SimulationCfg(
        dt=dt,
        render_interval=decimation,
        physx=physx,
        device=device,
    )
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_termination.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import inspect
import isaaclab.envs.mdp as isaaclab_mdp
import isaaclab.utils.math as isaaclab_math
import torch
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.managers import TerminationTermCfg
from isaaclab.utils import configclass
from holomotion.src.env.isaaclab_components import (
isaaclab_motion_tracking_command as motion_tracking_command,
isaaclab_utils,
)
def _list_supported_terminations() -> list[str]:
    """Return the sorted names of every termination resolvable by name.

    The list is the union of:
    - public functions defined in this module, and
    - public callables exposed by ``isaaclab.envs.mdp.terminations``.
    """
    names: set[str] = set()

    # Functions defined here (imported functions are excluded by __module__).
    for attr_name, attr in globals().items():
        if not inspect.isfunction(attr):
            continue
        if attr.__module__ != __name__ or attr_name.startswith("_"):
            continue
        names.add(attr_name)

    native_module = isaaclab_mdp.terminations
    for attr_name in dir(native_module):
        if attr_name.startswith("_"):
            continue
        if callable(getattr(native_module, attr_name)):
            names.add(attr_name)

    return sorted(names)
def _resolve_termination_func(name: str):
    """Look up a termination function by name.

    A function defined in this module takes precedence. Otherwise the
    callable of that name in ``isaaclab.envs.mdp.terminations`` is used.

    Raises:
        ValueError: If neither source provides ``name``. The message lists
            all supported names.
    """
    local_candidate = globals().get(name)
    is_local_function = (
        inspect.isfunction(local_candidate)
        and local_candidate.__module__ == __name__
    )
    if is_local_function:
        return local_candidate

    native_candidate = getattr(isaaclab_mdp.terminations, name, None)
    if callable(native_candidate):
        return native_candidate

    supported = _list_supported_terminations()
    raise ValueError(
        f"Unknown termination function: {name}. Supported: {supported}"
    )
def global_bodylink_pos_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Any body link position deviates more than threshold (world frame).

    The error is the full 3-D Euclidean distance between reference and
    robot link positions.

    When ``isaaclab_utils._get_body_indices`` returns a non-empty index list
    for ``keybody_names``, only those links are checked; otherwise every
    link is checked.

    Returns:
        Boolean tensor with one entry per environment, True where any
        checked link is farther than ``threshold`` from its reference.
    """
    cmd: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    target_pos = cmd.get_ref_motion_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )  # [B, Nb, 3]
    actual_pos = cmd.robot.data.body_pos_w  # [B, Nb, 3]

    selected = isaaclab_utils._get_body_indices(cmd.robot, keybody_names)
    if selected is not None and len(selected) > 0:
        index = torch.as_tensor(
            selected, device=target_pos.device, dtype=torch.long
        )
        target_pos, actual_pos = target_pos[:, index], actual_pos[:, index]

    distance = (target_pos - actual_pos).norm(dim=-1)  # [B, Nb]
    return (distance > threshold).any(dim=-1)  # [B]
def anchor_ref_z_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Anchor link z difference exceeds threshold (world frame).

    Compares the last component of the reference anchor position with the
    last component of ``command.global_robot_anchor_pos_cur``.

    Returns:
        Boolean tensor, True where the absolute height gap is larger than
        ``threshold``.
    """
    cmd: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    ref_anchor_pos = cmd.get_ref_motion_anchor_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )
    robot_anchor_pos = cmd.global_robot_anchor_pos_cur
    height_gap = torch.abs(ref_anchor_pos[:, -1] - robot_anchor_pos[:, -1])
    return height_gap > threshold
def ref_gravity_projection_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    asset_name: str = "robot",
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Difference in projected gravity z-component exceeds threshold.

    World gravity (``env.scene[asset_name].data.GRAVITY_VEC_W``) is rotated
    with ``quat_apply_inverse`` into two frames:
    - the reference anchor body frame, and
    - the robot anchor body frame (``body_quat_w`` at
      ``command.anchor_bodylink_idx``).

    The term fires where the z-components of the two projections differ by
    more than ``threshold``.

    Both quaternions are in (w, x, y, z) order. That is the order used by
    IsaacLab's math utilities and named by the ``..._rot_wxyz_...``
    reference getter.

    Returns:
        Boolean tensor, one entry per environment.
    """
    command: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    g_w = env.scene[asset_name].data.GRAVITY_VEC_W  # [B, 3]

    # Reference anchor orientation (wxyz) from the motion cache.
    ref_anchor_quat_wxyz = (
        command.get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next(
            prefix=ref_prefix
        )
    )  # [B, 4]
    ref_projected_gravity_b = isaaclab_math.quat_apply_inverse(
        ref_anchor_quat_wxyz, g_w
    )  # [B, 3]

    # Robot anchor orientation (wxyz) from the simulator.
    robot_anchor_quat_wxyz = command.robot.data.body_quat_w[
        :, command.anchor_bodylink_idx
    ]  # [B, 4]
    robot_projected_gravity_b = isaaclab_math.quat_apply_inverse(
        robot_anchor_quat_wxyz, g_w
    )  # [B, 3]

    gravity_z_gap = (
        ref_projected_gravity_b[:, 2] - robot_projected_gravity_b[:, 2]
    ).abs()
    return gravity_z_gap > threshold
def keybody_ref_pos_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Any key body link 3-D position error exceeds threshold (world frame).

    The error is the full Euclidean distance, not only the z component; use
    ``keybody_ref_z_far`` for a height-only check. The computation is the
    same as ``global_bodylink_pos_far``, so this term delegates to it rather
    than duplicating the code.

    Returns:
        Boolean tensor with one entry per environment.
    """
    return global_bodylink_pos_far(
        env,
        threshold,
        command_name=command_name,
        keybody_names=keybody_names,
        ref_prefix=ref_prefix,
    )
def keybody_ref_z_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    command_name: str = "ref_motion",
    keybody_names: list[str] | None = None,
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Any key body link z difference exceeds threshold (world frame).

    Only the third coordinate of each link position is compared.

    When ``isaaclab_utils._get_body_indices`` returns a non-empty index list
    for ``keybody_names``, the check is restricted to those links; otherwise
    all links are checked.

    Returns:
        Boolean tensor with one entry per environment.
    """
    cmd: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    target_pos = cmd.get_ref_motion_bodylink_global_pos_immediate_next(
        prefix=ref_prefix
    )  # [B, Nb, 3]
    actual_pos = cmd.robot.data.body_pos_w  # [B, Nb, 3]

    selected = isaaclab_utils._get_body_indices(cmd.robot, keybody_names)
    if selected is not None and len(selected) > 0:
        index = torch.as_tensor(
            selected, device=target_pos.device, dtype=torch.long
        )
        target_pos, actual_pos = target_pos[:, index], actual_pos[:, index]

    height_gap = torch.abs(target_pos[..., 2] - actual_pos[..., 2])  # [B, Nb]
    return height_gap.gt(threshold).any(dim=-1)  # [B]
def wholebody_mpjpe_far(
    env: ManagerBasedRLEnv,
    threshold: float,
    command_name: str = "ref_motion",
    ref_prefix: str = "ref_",
) -> torch.Tensor:
    """Mean whole-body DOF position error exceeds threshold.

    The error is the mean, over the last dimension, of the absolute
    difference between ``robot.data.joint_pos`` and the reference DOF
    positions.
    """
    cmd: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    target_dof_pos = cmd.get_ref_motion_dof_pos_immediate_next(
        prefix=ref_prefix
    )
    actual_dof_pos = cmd.robot.data.joint_pos
    mean_abs_error = (actual_dof_pos - target_dof_pos).abs().mean(dim=-1)
    return mean_abs_error > threshold
def motion_end(
    env: ManagerBasedRLEnv,
    command_name: str = "ref_motion",
) -> torch.Tensor:
    """Terminate when reference motion frames exceed their end frames.

    Returns a boolean copy of the command's ``motion_end_mask``, so callers
    cannot mutate the command's own mask.
    """
    cmd: motion_tracking_command.RefMotionCommand = (
        env.command_manager.get_term(command_name)
    )
    return cmd.motion_end_mask.clone().bool()
@configclass
class TerminationsCfg:
    """Container for termination terms.

    Declares no fields; ``build_terminations_config`` sets one
    ``TerminationTermCfg`` attribute per configured termination.
    """
def build_terminations_config(
    termination_config_dict: dict,
) -> TerminationsCfg:
    """Create a ``TerminationsCfg`` with one term per config entry.

    Each key is resolved to a function by ``_resolve_termination_func``. The
    optional ``params`` sub-config becomes the term's params.

    A term is flagged as a time-out when its key is ``"time_out"`` or its
    config sets ``time_out`` truthy.
    """
    cfg = TerminationsCfg()
    for term_name, raw_term_cfg in termination_config_dict.items():
        term_cfg = isaaclab_utils.resolve_holo_config(raw_term_cfg)
        term_func = _resolve_termination_func(term_name)
        term_params = isaaclab_utils.resolve_holo_config(
            term_cfg.get("params", {})
        )
        is_time_out = term_name == "time_out" or term_cfg.get(
            "time_out", False
        )
        setattr(
            cfg,
            term_name,
            TerminationTermCfg(
                func=term_func,
                params=term_params,
                time_out=is_time_out,
            ),
        )
    return cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_terrain.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import isaaclab.sim as sim_utils
import isaaclab.terrains as terrain_gen
import numpy as np
import torch
from isaaclab.terrains import TerrainImporter, TerrainImporterCfg
from isaaclab.terrains.height_field import (
HfDiscreteObstaclesTerrainCfg,
HfPyramidSlopedTerrainCfg,
HfRandomUniformTerrainCfg,
HfTerrainBaseCfg,
)
from isaaclab.terrains.height_field.utils import height_field_to_mesh
from isaaclab.utils import configclass
from loguru import logger
def _convert_range_like_params(params: dict) -> dict:
"""Convert list values for common range/size keys to tuples.
This helps map Hydra YAML list values into IsaacLab config classes that
expect tuples (e.g. ``*_range``).
"""
converted = {}
for key, value in params.items():
if isinstance(value, list) and (
key.endswith("_range") or key in ("size", "difficulty_range")
):
converted[key] = tuple(value)
else:
converted[key] = value
return converted
@height_field_to_mesh
def plane_terrain(difficulty: float, cfg: HfTerrainBaseCfg) -> np.ndarray:
    """Return an all-zero height field, i.e. a perfectly flat patch.

    This is cheaper than ``random_uniform`` with a zero noise range.
    ``difficulty`` is ignored.

    Returns:
        ``int16`` zeros of shape
        ``(int(size[0] / horizontal_scale), int(size[1] / horizontal_scale))``.
    """
    rows = int(cfg.size[0] / cfg.horizontal_scale)
    cols = int(cfg.size[1] / cfg.horizontal_scale)
    return np.zeros(shape=(rows, cols), dtype=np.int16)
@configclass
class HfPlaneTerrainCfg(HfTerrainBaseCfg):
    """Height-field sub-terrain config whose generator is ``plane_terrain``."""

    function = plane_terrain
class RandomSpawnTerrainImporter(TerrainImporter):
    """Terrain importer that spawns robots randomly within each sub-terrain.

    Origins are first assigned with the curriculum scheme:
    - each env gets a random initial level (row), and
    - envs are split into contiguous, equal-sized groups across columns.

    Each origin is then shifted by a uniform random (x, y) offset inside
    its sub-terrain, kept at least ``cfg.random_spawn_margin`` meters from
    the edges (if set).

    ``update_env_origins`` resamples the offset whenever it reassigns an
    env's level.
    """

    # Cached by _compute_env_origins_curriculum. While the extents are None,
    # update_env_origins adds no random offset.
    _terrain_width: float | None = None
    _terrain_length: float | None = None
    _spawn_margin: float = 0.0

    def _sample_xy_offsets(
        self, count: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample ``count`` uniform (x, y) offsets inside the cached spawn box.

        Each axis is sampled from ``[-size/2 + margin, size/2 - margin]``,
        using the cached sub-terrain width/length and spawn margin. x is
        drawn before y.
        """
        x_min = -self._terrain_width / 2 + self._spawn_margin
        x_max = self._terrain_width / 2 - self._spawn_margin
        y_min = -self._terrain_length / 2 + self._spawn_margin
        y_max = self._terrain_length / 2 - self._spawn_margin
        x_offsets = torch.empty(count, device=self.device).uniform_(
            x_min, x_max
        )
        y_offsets = torch.empty(count, device=self.device).uniform_(
            y_min, y_max
        )
        return x_offsets, y_offsets

    def _compute_env_origins_curriculum(
        self, num_envs: int, origins: torch.Tensor
    ) -> torch.Tensor:
        """Compute env origins with random (x, y) positions.

        This overrides the default curriculum-based distribution to add random
        offsets within each sub-terrain's bounds. It also caches the
        sub-terrain extent and the effective margin for later resampling.

        A margin of at least half the smaller sub-terrain side is invalid; it
        is replaced by 0.0 with a warning.

        Args:
            num_envs: Number of environments.
            origins: Terrain origins tensor of shape (num_rows, num_cols, 3).

        Returns:
            Environment origins tensor of shape (num_envs, 3).

        Raises:
            ValueError: If ``cfg.terrain_generator`` is None.
        """
        num_rows, num_cols = origins.shape[:2]

        # Get sub-terrain size from terrain generator config
        if self.cfg.terrain_generator is None:
            raise ValueError(
                "terrain_generator config is required for random spawning"
            )
        sub_terrain_size = self.cfg.terrain_generator.size
        terrain_width, terrain_length = (
            sub_terrain_size[0],
            sub_terrain_size[1],
        )

        spawn_margin = max(
            0.0, float(getattr(self.cfg, "random_spawn_margin", 0.0))
        )
        # Clamp margin to avoid invalid sampling ranges.
        max_margin = 0.5 * min(float(terrain_width), float(terrain_length))
        if spawn_margin >= max_margin:
            logger.warning(
                f"random_spawn_margin={spawn_margin} is too large "
                f"for sub-terrain size={sub_terrain_size}. "
                "Clamping to 0.0."
            )
            spawn_margin = 0.0

        # Maximum initial level possible for the terrains
        if self.cfg.max_init_terrain_level is None:
            max_init_level = num_rows - 1
        else:
            max_init_level = min(self.cfg.max_init_terrain_level, num_rows - 1)
        # Store maximum terrain level possible
        self.max_terrain_level = num_rows

        # Default curriculum-based assignment
        self.terrain_levels = torch.randint(
            0, max_init_level + 1, (num_envs,), device=self.device
        )
        self.terrain_types = torch.div(
            torch.arange(num_envs, device=self.device),
            (num_envs / num_cols),
            rounding_mode="floor",
        ).to(torch.long)

        env_origins = torch.zeros(num_envs, 3, device=self.device)
        env_origins[:] = origins[self.terrain_levels, self.terrain_types]

        # Cache the spawn box (also used by update_env_origins), then add
        # random (x, y) offsets within each sub-terrain's bounds.
        self._terrain_width = terrain_width
        self._terrain_length = terrain_length
        self._spawn_margin = spawn_margin
        x_offsets, y_offsets = self._sample_xy_offsets(num_envs)
        env_origins[:, 0] += x_offsets
        env_origins[:, 1] += y_offsets
        return env_origins

    def update_env_origins(
        self,
        env_ids: torch.Tensor,
        move_up: torch.Tensor,
        move_down: torch.Tensor,
    ):
        """Update env origins when terrain levels change.

        Levels move up or down by one. Envs that go past the last level are
        sent to a random level, and levels below zero are clipped to 0.

        The selected envs are then placed at their sub-terrain origin plus a
        freshly sampled random (x, y) offset, if the spawn box has been
        cached.
        """
        # Check if grid-like spawning
        if self.terrain_origins is None:
            return
        # Update terrain level for the envs
        self.terrain_levels[env_ids] += 1 * move_up - 1 * move_down
        # Robots that solve the last level are sent to a random one
        # The minimum level is zero
        self.terrain_levels[env_ids] = torch.where(
            self.terrain_levels[env_ids] >= self.max_terrain_level,
            torch.randint_like(
                self.terrain_levels[env_ids], self.max_terrain_level
            ),
            torch.clip(self.terrain_levels[env_ids], 0),
        )
        # Update the env origins with terrain origins
        self.env_origins[env_ids] = self.terrain_origins[
            self.terrain_levels[env_ids], self.terrain_types[env_ids]
        ]

        # Add random (x, y) offsets within each sub-terrain's bounds
        if self._terrain_width is None or self._terrain_length is None:
            return
        x_offsets, y_offsets = self._sample_xy_offsets(len(env_ids))
        self.env_origins[env_ids, 0] += x_offsets
        self.env_origins[env_ids, 1] += y_offsets
def _build_terrain_physics_material(config: dict):
    """Build the rigid-body physics material shared by every terrain mode.

    Reads ``static_friction`` and ``dynamic_friction`` (default ``1.0``),
    ``restitution`` (default ``0.0``), ``friction_combine_mode`` and
    ``restitution_combine_mode`` (default ``"multiply"``) from ``config``.

    Returns:
        A ``sim_utils.RigidBodyMaterialCfg``.
    """
    return sim_utils.RigidBodyMaterialCfg(
        friction_combine_mode=config.get("friction_combine_mode", "multiply"),
        restitution_combine_mode=config.get(
            "restitution_combine_mode", "multiply"
        ),
        static_friction=config.get("static_friction", 1.0),
        dynamic_friction=config.get("dynamic_friction", 1.0),
        restitution=config.get("restitution", 0.0),
    )


def _build_sub_terrain_cfg(sub_name, sub_cfg_dict):
    """Translate one ``generator.sub_terrains`` entry into an ``Hf*TerrainCfg``.

    ``type`` selects the config class (default ``"random_uniform"``),
    ``proportion`` defaults to ``1.0`` and all other keys are forwarded
    after :func:`_convert_range_like_params`.

    Raises:
        ValueError: If the entry is not a dict or its ``type`` is unknown.
    """
    if not isinstance(sub_cfg_dict, dict):
        raise ValueError(
            f"Sub-terrain '{sub_name}' must be a dictionary with at "
            "least a 'type' field."
        )
    sub_type = sub_cfg_dict.get("type", "random_uniform")
    sub_proportion = sub_cfg_dict.get("proportion", 1.0)
    cfg_classes = {
        "random_uniform": HfRandomUniformTerrainCfg,
        "plane": HfPlaneTerrainCfg,
        "discrete_obstacles": HfDiscreteObstaclesTerrainCfg,
        "pyramid_sloped": HfPyramidSlopedTerrainCfg,
    }
    cfg_class = cfg_classes.get(sub_type)
    if cfg_class is None:
        raise ValueError(
            f"Unknown sub_terrains['{sub_name}'].type '{sub_type}'. "
            "Expected 'plane', 'random_uniform', 'discrete_obstacles',"
            " or 'pyramid_sloped'."
        )
    sub_params = _convert_range_like_params(
        {
            key: value
            for key, value in sub_cfg_dict.items()
            if key not in ("type", "proportion")
        }
    )
    return cfg_class(proportion=sub_proportion, **sub_params)


def _build_legacy_height_field_cfg(height_field_cfg_dict):
    """Translate the deprecated top-level ``height_field`` block.

    Only ``"random_uniform"`` (default) and ``"discrete_obstacles"`` are
    supported here; all keys except ``type`` are forwarded.

    Raises:
        ValueError: If ``type`` is not one of the two supported values.
    """
    hf_type = height_field_cfg_dict.get("type", "random_uniform")
    cfg_classes = {
        "random_uniform": HfRandomUniformTerrainCfg,
        "discrete_obstacles": HfDiscreteObstaclesTerrainCfg,
    }
    cfg_class = cfg_classes.get(hf_type)
    if cfg_class is None:
        raise ValueError(
            f"Unknown height_field.type '{hf_type}'. "
            "Expected 'random_uniform' or 'discrete_obstacles'."
        )
    hf_params = _convert_range_like_params(
        {
            key: value
            for key, value in height_field_cfg_dict.items()
            if key != "type"
        }
    )
    return cfg_class(**hf_params)


def _build_sub_terrains_cfg(config: dict, generator_cfg_dict: dict) -> dict:
    """Resolve the sub-terrain mapping passed to the terrain generator.

    Prefers ``generator.sub_terrains``. Falls back to the deprecated
    top-level ``height_field`` block (registered as ``"main"``) when
    ``sub_terrains`` is absent or an empty mapping.

    Raises:
        ValueError: If ``sub_terrains`` is not a mapping, one of its entries
            is invalid, or neither source of sub-terrains is provided.
    """
    sub_terrains_cfg_dict = generator_cfg_dict.get("sub_terrains")
    if sub_terrains_cfg_dict is not None:
        if not isinstance(sub_terrains_cfg_dict, dict):
            raise ValueError(
                "Expected 'generator.sub_terrains' to be a mapping from names "
                "to sub-terrain configs."
            )
        if sub_terrains_cfg_dict:
            return {
                sub_name: _build_sub_terrain_cfg(sub_name, sub_cfg_dict)
                for sub_name, sub_cfg_dict in sub_terrains_cfg_dict.items()
            }
        # An empty mapping would build a generator with no sub-terrains;
        # treat it as "not provided" so the legacy path (or its error) applies.

    height_field_cfg_dict = config.get("height_field")
    if height_field_cfg_dict is None:
        raise ValueError(
            "When 'terrain_type' is 'generator', either "
            "'generator.sub_terrains' or a 'height_field' dict must be "
            "provided in terrain config."
        )
    logger.warning(
        "Terrain config is using deprecated 'height_field' key. "
        "Please migrate to 'generator.sub_terrains' for multi-sub-terrain "
        "support."
    )
    return {"main": _build_legacy_height_field_cfg(height_field_cfg_dict)}


def _build_terrain_generator_cfg(generator_cfg_dict: dict, sub_terrains_cfg):
    """Build a ``TerrainGeneratorCfg`` from the ``generator`` config block.

    Only a fixed whitelist of keys is forwarded. Other keys (apart from
    ``sub_terrains``, which is handled separately) are ignored, and a warning
    is logged so that typos do not pass silently.
    """
    supported_keys = (
        "size",
        "border_width",
        "border_height",
        "num_rows",
        "num_cols",
        "horizontal_scale",
        "vertical_scale",
        "slope_threshold",
        "difficulty_range",
        "color_scheme",
        "curriculum",
        "seed",
        "use_cache",
        "cache_dir",
    )
    generator_params = _convert_range_like_params(
        {
            key: value
            for key, value in generator_cfg_dict.items()
            if key != "sub_terrains"
        }
    )
    ignored_keys = [key for key in generator_params if key not in supported_keys]
    if ignored_keys:
        logger.warning(
            f"Ignoring unsupported terrain generator keys {ignored_keys}; "
            f"supported keys are {list(supported_keys)}."
        )
    return terrain_gen.TerrainGeneratorCfg(
        **{
            key: value
            for key, value in generator_params.items()
            if key in supported_keys
        },
        sub_terrains=sub_terrains_cfg,
    )


def _resolve_local_mdl_path(mdl_path: str) -> str:
    """Resolve ``mdl_path`` to an absolute local path (existence not checked).

    Absolute paths are returned unchanged. A relative path is used as-is
    (made absolute) if it exists relative to the current directory.
    Otherwise it is joined onto the directory four levels above this file,
    presumably the repository root.
    """
    if os.path.isabs(mdl_path):
        return mdl_path
    if os.path.exists(mdl_path):
        return os.path.abspath(mdl_path)
    workspace_root = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../../..")
    )
    return os.path.join(workspace_root, mdl_path)


def _build_terrain_visual_material(visual_material_dict):
    """Build an offline visual material, or None for "no material".

    A ``None`` dict (key absent or explicitly null) and ``type="none"``
    both return None. See :func:`build_terrain_config` for the supported
    types and fallbacks.
    """
    if visual_material_dict is None:
        return None
    material_type = visual_material_dict.get("type", "color")

    if material_type == "color":
        diffuse_color_raw = visual_material_dict.get(
            "diffuse_color", (0.8, 0.8, 0.8)
        )
        # YAML yields lists; PreviewSurfaceCfg expects a tuple of floats.
        if isinstance(diffuse_color_raw, (list, tuple)):
            diffuse_color = tuple(float(x) for x in diffuse_color_raw)
        else:
            diffuse_color = diffuse_color_raw
        return sim_utils.PreviewSurfaceCfg(
            diffuse_color=diffuse_color,
            metallic=float(visual_material_dict.get("metallic", 0.0)),
            roughness=float(visual_material_dict.get("roughness", 0.5)),
        )

    if material_type == "none":
        # No visual material: rely on vertex colors (e.g. from the height map).
        return None

    if material_type == "mdl":
        mdl_path = visual_material_dict.get("mdl_path")
        if mdl_path is None:
            logger.warning(
                "visual_material type is 'mdl' but no mdl_path specified. "
                "Falling back to color material to avoid internet "
                "requirements."
            )
            return sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.5, 0.5))
        resolved_mdl_path = _resolve_local_mdl_path(mdl_path)
        if os.path.exists(resolved_mdl_path):
            return sim_utils.MdlFileCfg(mdl_path=resolved_mdl_path)
        logger.warning(
            f"MDL file not found at {resolved_mdl_path}. "
            "Falling back to color material to avoid internet "
            "requirements."
        )
        return sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.5, 0.5))

    logger.warning(
        f"Unknown visual_material type: {material_type}. "
        "Using default color material."
    )
    return sim_utils.PreviewSurfaceCfg(diffuse_color=(0.5, 0.5, 0.5))


def build_terrain_config(
    config: dict, scene_env_spacing: float = None
) -> TerrainImporterCfg:
    """Build a ``TerrainImporterCfg`` from a Hydra terrain config dict.

    Three ``terrain_type`` values are supported:

    * ``"generator"`` (default, preferred): IsaacLab terrain generator with
      height-field sub-terrains fully specified by the config.
    * ``"plane"`` (legacy): infinite plane; env spacing comes from
      ``scene_env_spacing`` (default ``2.5``).
    * ``"usd"`` (legacy): terrain loaded from the local file ``usd_path``.

    All paths are offline by construction. Visual materials (generator mode
    only) must reference local data:

    * ``visual_material.type="color"`` (default): ``PreviewSurfaceCfg`` from
      ``diffuse_color`` (default ``(0.8, 0.8, 0.8)``), ``metallic``
      (default ``0.0``) and ``roughness`` (default ``0.5``).
    * ``visual_material.type="none"`` or ``visual_material: null``: no
      material (vertex colors are used).
    * ``visual_material.type="mdl"``: ``MdlFileCfg`` for a local
      ``mdl_path``; never uses NVIDIA Nucleus. A missing ``mdl_path`` or
      file falls back to a neutral grey color.
    * Any other type logs a warning and uses the neutral grey color.

    Args:
        config: Terrain configuration dictionary with fields:

            * ``terrain_type``: ``"generator"``, ``"plane"`` or ``"usd"``.
            * ``prim_path``: default ``"/World/ground"``.
            * ``generator`` (required for ``"generator"``): parameters for
              ``TerrainGeneratorCfg``. Only ``size``, ``border_width``,
              ``border_height``, ``num_rows``, ``num_cols``,
              ``horizontal_scale``, ``vertical_scale``, ``slope_threshold``,
              ``difficulty_range``, ``color_scheme``, ``curriculum``,
              ``seed``, ``use_cache`` and ``cache_dir`` are forwarded. Other
              keys are ignored with a warning.
            * ``generator.sub_terrains`` (preferred): mapping from name to a
              dict with ``type`` (``"plane"``, ``"random_uniform"``,
              ``"discrete_obstacles"`` or ``"pyramid_sloped"``; default
              ``"random_uniform"``), ``proportion`` (default ``1.0``) and the
              remaining kwargs of the matching ``Hf*TerrainCfg``.
            * ``height_field`` (deprecated): used only when
              ``generator.sub_terrains`` is absent or empty. It describes a
              single sub-terrain of type ``"random_uniform"`` or
              ``"discrete_obstacles"``.
            * ``max_init_terrain_level``: default ``num_rows - 1``.
            * ``random_spawn`` (optional): if True, robots are spawned at
              random (x, y) positions within each sub-terrain's bounds via
              ``RandomSpawnTerrainImporter``.
            * ``random_spawn_margin`` (optional): keeps random spawn points at
              least this many meters away from sub-terrain edges (default 0).
            * ``visual_material`` (optional): offline material, see above.
            * ``static_friction`` and ``dynamic_friction`` (default ``1.0``).
            * ``restitution`` (default ``0.0``).
            * ``friction_combine_mode`` and ``restitution_combine_mode``
              (default ``"multiply"``).
            * ``debug_vis`` (default False).
        scene_env_spacing: Environment spacing from the scene config. Used
            only when ``terrain_type="plane"``.

    Returns:
        TerrainImporterCfg configured according to ``config``.

    Raises:
        ValueError: On an unknown ``terrain_type``, a missing ``usd_path``,
            ``generator`` or sub-terrain definition, or an unknown
            sub-terrain type.
    """
    prim_path = config.get("prim_path", "/World/ground")
    debug_vis = config.get("debug_vis", False)
    terrain_type = config.get("terrain_type", "generator")

    if terrain_type == "usd":
        usd_path = config.get("usd_path")
        if usd_path is None:
            raise ValueError(
                "'usd_path' must be specified for terrain_type 'usd'"
            )
        return TerrainImporterCfg(
            prim_path=prim_path,
            terrain_type="usd",
            usd_path=usd_path,
            collision_group=-1,
            physics_material=_build_terrain_physics_material(config),
            debug_vis=debug_vis,
        )

    if terrain_type == "plane":
        env_spacing = (
            scene_env_spacing if scene_env_spacing is not None else 2.5
        )
        return TerrainImporterCfg(
            prim_path=prim_path,
            terrain_type="plane",
            collision_group=-1,
            env_spacing=env_spacing,
            physics_material=_build_terrain_physics_material(config),
            debug_vis=debug_vis,
        )

    if terrain_type != "generator":
        raise ValueError(
            f"Unsupported terrain_type '{terrain_type}'. "
            "Expected 'generator', 'plane', or 'usd'."
        )

    generator_cfg_dict = config.get("generator")
    if generator_cfg_dict is None:
        raise ValueError(
            "When 'terrain_type' is 'generator', a 'generator' dict must be "
            "provided in terrain config."
        )

    sub_terrains_cfg = _build_sub_terrains_cfg(config, generator_cfg_dict)
    terrain_generator = _build_terrain_generator_cfg(
        generator_cfg_dict, sub_terrains_cfg
    )
    visual_material = _build_terrain_visual_material(
        config.get("visual_material")
    )

    # Random spawning within sub-terrains needs the custom importer class.
    random_spawn = config.get("random_spawn", False)
    terrain_cfg = TerrainImporterCfg(
        prim_path=prim_path,
        terrain_type="generator",
        terrain_generator=terrain_generator,
        max_init_terrain_level=config.get(
            "max_init_terrain_level",
            terrain_generator.num_rows - 1,
        ),
        collision_group=-1,
        visual_material=visual_material,
        physics_material=_build_terrain_physics_material(config),
        debug_vis=debug_vis,
        class_type=(
            RandomSpawnTerrainImporter if random_spawn else TerrainImporter
        ),
    )
    if random_spawn:
        # An explicit `random_spawn_margin: null` is treated as "no margin".
        spawn_margin = config.get("random_spawn_margin")
        terrain_cfg.random_spawn_margin = (
            float(spawn_margin) if spawn_margin is not None else 0.0
        )
    return terrain_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_utils.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
from isaaclab.assets import Articulation
from isaaclab.envs import ManagerBasedRLEnv
from isaaclab.managers import RewardTermCfg, SceneEntityCfg
from isaaclab.sensors import ContactSensor
from isaaclab.utils import configclass
import isaaclab.utils.math as isaaclab_math
from holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command import (
RefMotionCommand,
)
import isaaclab.envs.mdp as isaaclab_mdp
from hydra.utils import instantiate as hydra_instantiate
from omegaconf import DictConfig, ListConfig, OmegaConf
from loguru import logger
def _get_dof_indices(
    robot: Articulation,
    key_dofs: list[str] | None,
) -> list[int]:
    """Map joint (DOF) names to their indices in ``robot.joint_names``.

    Args:
        robot: Robot articulation asset; only ``robot.joint_names`` is read.
        key_dofs: Joint names to look up, in the desired output order.
            ``None`` selects every joint.

    Returns:
        Joint indices in the same order as ``key_dofs``, or
        ``[0, ..., len(robot.joint_names) - 1]`` when ``key_dofs`` is None.
        Never returns None.

    Raises:
        ValueError: If a requested name is not in ``robot.joint_names``.
    """
    joint_names = robot.joint_names
    if key_dofs is None:
        return list(range(len(joint_names)))
    dof_indices = []
    for name in key_dofs:
        if name not in joint_names:
            raise ValueError(
                f"DOF '{name}' not found in robot.joint_names: {joint_names}"
            )
        dof_indices.append(joint_names.index(name))
    return dof_indices
def _get_body_indices(
    robot: Articulation,
    keybody_names: list[str] | None,
) -> list[int]:
    """Map body names to their indices in ``robot.body_names``.

    Args:
        robot: Robot articulation asset; only ``robot.body_names`` is read.
        keybody_names: Body names to look up, in the desired output order.
            ``None`` selects every body.

    Returns:
        Body indices in the same order as ``keybody_names``, or
        ``[0, ..., len(robot.body_names) - 1]`` when ``keybody_names`` is
        None. Never returns None.

    Raises:
        ValueError: If a requested name is not in ``robot.body_names``.
    """
    body_names = robot.body_names
    if keybody_names is None:
        return list(range(len(body_names)))
    body_indices = []
    for name in keybody_names:
        if name not in body_names:
            raise ValueError(
                f"Body '{name}' not found in robot.body_names: {body_names}"
            )
        body_indices.append(body_names.index(name))
    return body_indices
def resolve_holo_config(value):
    """Recursively turn a Hydra/OmegaConf config into plain Python objects.

    * ``DictConfig`` / ``ListConfig`` are first converted to containers with
      interpolations resolved.
    * Lists are resolved element-wise into a new list.
    * Dicts with a ``_target_`` key are instantiated with Hydra. A
      non-callable result that has ``__dict__`` gets its attributes resolved
      in place.
    * Other dicts are resolved value-wise into a new dict.
    * Any other non-callable object with ``__dict__`` has its attributes
      resolved in place. Everything else is returned unchanged.
    """

    def _is_attribute_object(obj):
        return hasattr(obj, "__dict__") and not callable(obj)

    def _resolve_attributes_in_place(obj):
        for attr_name, attr_value in vars(obj).items():
            setattr(obj, attr_name, resolve_holo_config(attr_value))
        return obj

    if isinstance(value, (DictConfig, ListConfig)):
        value = OmegaConf.to_container(value, resolve=True)

    if isinstance(value, list):
        return [resolve_holo_config(element) for element in value]

    if isinstance(value, dict):
        if "_target_" not in value:
            return {
                key: resolve_holo_config(element)
                for key, element in value.items()
            }
        instantiated = hydra_instantiate(value)
        if _is_attribute_object(instantiated):
            return _resolve_attributes_in_place(instantiated)
        return instantiated

    if _is_attribute_object(value):
        return _resolve_attributes_in_place(value)
    return value
================================================
FILE: holomotion/src/env/isaaclab_components/isaaclab_velocity_tracking_command.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
from dataclasses import MISSING
from isaaclab.managers import CommandTermCfg
from isaaclab.utils import configclass
import torch
from isaaclab.assets import Articulation
from isaaclab.managers import CommandTerm
from isaaclab.markers import VisualizationMarkers
from isaaclab.envs import ManagerBasedEnv
import isaaclab.utils.math as math_utils
from isaaclab.markers import VisualizationMarkersCfg
from isaaclab.markers.config import (
BLUE_ARROW_X_MARKER_CFG,
GREEN_ARROW_X_MARKER_CFG,
)
from typing import Sequence
class HoloMotionUniformVelocityCommand(CommandTerm):
    r"""Command generator that samples SE(2) base-velocity commands uniformly.

    The command comprises a linear velocity in the x and y directions and an
    angular velocity around the z-axis, expressed in the robot's base frame.
    On every resample, ``lin_vel_x``, ``lin_vel_y`` and ``ang_vel_z`` are
    drawn uniformly from :attr:`cfg.ranges`. Each environment is then given
    three independently sampled flags:

    * heading env (probability :attr:`cfg.rel_heading_envs`; used only when
      :attr:`cfg.heading_command` is True). The yaw rate is recomputed in
      :meth:`_update_command` from the heading error instead of using the
      sampled value. Yaw-only envs are never heading envs.
    * yaw-only env (probability :attr:`cfg.rel_yaw_envs`). The linear
      velocity command is forced to zero and the sampled yaw rate is kept.
    * standing env (probability :attr:`cfg.rel_standing_envs`). The whole
      command is forced to zero, overriding the other two modes.

    For heading envs, the target heading is sampled from ``ranges.heading``.
    The angular velocity is a clipped proportional control on the heading
    error:

    .. math::

        \omega_z = \text{clip}\big(k \, \text{wrap\_to\_pi}(\theta_{\text{target}} - \theta_{\text{current}}),
        \omega_{\min}, \omega_{\max}\big)

    where :math:`k` is :attr:`cfg.heading_control_stiffness` and the clip
    bounds are ``cfg.ranges.ang_vel_z``.
    """

    cfg: HoloMotionUniformVelocityCommandCfg
    """The configuration of the command generator."""

    def __init__(
        self, cfg: HoloMotionUniformVelocityCommandCfg, env: ManagerBasedEnv
    ):
        """Initialize the command generator.

        Args:
            cfg: The configuration of the command generator.
            env: The environment.

        Raises:
            ValueError: If the heading command is active but the heading range
                is not provided, or if yaw-only commands are enabled with an
                ``ang_vel_z`` range whose minimum exceeds its maximum.
        """
        # initialize the base class
        super().__init__(cfg, env)

        # check configuration
        if self.cfg.heading_command and self.cfg.ranges.heading is None:
            raise ValueError(
                "The velocity command has heading commands active (heading_command=True) but the `ranges.heading`"
                " parameter is set to None."
            )
        if self.cfg.rel_yaw_envs > 0.0:
            # Yaw-only envs use the sampled ang_vel_z as their only motion
            # command, so reject an inverted range up front instead of
            # failing later at resample time.
            yaw_min, yaw_max = self.cfg.ranges.ang_vel_z
            if yaw_min > yaw_max:
                raise ValueError(
                    "The velocity command has yaw-only commands active "
                    f"(rel_yaw_envs={self.cfg.rel_yaw_envs}) but "
                    f"`ranges.ang_vel_z` has min > max: ({yaw_min}, {yaw_max})."
                )

        # obtain the robot asset
        # -- robot
        self.robot: Articulation = env.scene[cfg.asset_name]

        # create buffers to store the command
        # -- command: x vel, y vel, yaw vel
        self.vel_command_b = torch.zeros(self.num_envs, 3, device=self.device)
        # -- per-env target heading and command-mode flags
        self.heading_target = torch.zeros(self.num_envs, device=self.device)
        self.is_heading_env = torch.zeros(
            self.num_envs, dtype=torch.bool, device=self.device
        )
        self.is_standing_env = torch.zeros_like(self.is_heading_env)
        self.is_yaw_env = torch.zeros_like(self.is_heading_env)
        # -- metrics
        self.metrics["error_vel_xy"] = torch.zeros(
            self.num_envs, device=self.device
        )
        self.metrics["error_vel_yaw"] = torch.zeros(
            self.num_envs, device=self.device
        )

    def __str__(self) -> str:
        """Return a string representation of the command generator."""
        msg = "HoloMotionUniformVelocityCommand:\n"
        msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n"
        msg += f"\tResampling time range: {self.cfg.resampling_time_range}\n"
        msg += f"\tHeading command: {self.cfg.heading_command}\n"
        if self.cfg.heading_command:
            msg += f"\tHeading probability: {self.cfg.rel_heading_envs}\n"
        msg += f"\tStanding probability: {self.cfg.rel_standing_envs}\n"
        msg += f"\tYaw-only probability: {self.cfg.rel_yaw_envs}"
        return msg

    """
    Properties
    """

    @property
    def command(self) -> torch.Tensor:
        """The desired base velocity command in the base frame. Shape is (num_envs, 3)."""
        return self.vel_command_b

    """
    Implementation specific functions.
    """

    def _update_metrics(self):
        # Accumulate per-step tracking errors, normalised by the maximum
        # number of steps a single command can be active.
        max_command_time = self.cfg.resampling_time_range[1]
        max_command_step = max_command_time / self._env.step_dt
        # logs data
        self.metrics["error_vel_xy"] += (
            torch.norm(
                self.vel_command_b[:, :2]
                - self.robot.data.root_lin_vel_b[:, :2],
                dim=-1,
            )
            / max_command_step
        )
        self.metrics["error_vel_yaw"] += (
            torch.abs(
                self.vel_command_b[:, 2] - self.robot.data.root_ang_vel_b[:, 2]
            )
            / max_command_step
        )

    def _resample_command(self, env_ids: Sequence[int]):
        # sample velocity commands (one scratch buffer reused for every draw)
        r = torch.empty(len(env_ids), device=self.device)
        # -- linear velocity - x direction
        self.vel_command_b[env_ids, 0] = r.uniform_(*self.cfg.ranges.lin_vel_x)
        # -- linear velocity - y direction
        self.vel_command_b[env_ids, 1] = r.uniform_(*self.cfg.ranges.lin_vel_y)
        # -- ang vel yaw - rotation around z
        self.vel_command_b[env_ids, 2] = r.uniform_(*self.cfg.ranges.ang_vel_z)
        # heading target
        if self.cfg.heading_command:
            self.heading_target[env_ids] = r.uniform_(*self.cfg.ranges.heading)
            # update heading envs
            self.is_heading_env[env_ids] = (
                r.uniform_(0.0, 1.0) <= self.cfg.rel_heading_envs
            )
        self.is_yaw_env[env_ids] = (
            r.uniform_(0.0, 1.0) <= self.cfg.rel_yaw_envs
        )
        if self.cfg.heading_command:
            # yaw-only envs should follow directly sampled yaw commands (not heading control)
            self.is_heading_env[env_ids] &= ~self.is_yaw_env[env_ids]
        # update standing envs
        self.is_standing_env[env_ids] = (
            r.uniform_(0.0, 1.0) <= self.cfg.rel_standing_envs
        )

    def _update_command(self):
        """Post-processes the velocity command.

        Computes the yaw rate of heading envs from their heading error (when
        the heading_command flag is set). It then zeroes the linear command of
        yaw-only envs and the whole command of standing envs, in that order,
        so standing wins.
        """
        # Compute angular velocity from heading direction
        if self.cfg.heading_command:
            # resolve indices of heading envs
            env_ids = self.is_heading_env.nonzero(as_tuple=False).flatten()
            # compute angular velocity
            heading_error = math_utils.wrap_to_pi(
                self.heading_target[env_ids]
                - self.robot.data.heading_w[env_ids]
            )
            self.vel_command_b[env_ids, 2] = torch.clip(
                self.cfg.heading_control_stiffness * heading_error,
                min=self.cfg.ranges.ang_vel_z[0],
                max=self.cfg.ranges.ang_vel_z[1],
            )
        yaw_env_ids = self.is_yaw_env.nonzero(as_tuple=False).flatten()
        self.vel_command_b[yaw_env_ids, :2] = 0.0
        # Enforce standing (i.e., zero velocity command) for standing envs
        # TODO: check if conversion is needed
        standing_env_ids = self.is_standing_env.nonzero(
            as_tuple=False
        ).flatten()
        self.vel_command_b[standing_env_ids, :] = 0.0

    def _set_debug_vis_impl(self, debug_vis: bool):
        # set visibility of markers
        # note: parent only deals with callbacks. not their visibility
        if debug_vis:
            # create markers if necessary for the first time
            if not hasattr(self, "goal_vel_visualizer"):
                # -- goal
                self.goal_vel_visualizer = VisualizationMarkers(
                    self.cfg.goal_vel_visualizer_cfg
                )
                # -- current
                self.current_vel_visualizer = VisualizationMarkers(
                    self.cfg.current_vel_visualizer_cfg
                )
            # set their visibility to true
            self.goal_vel_visualizer.set_visibility(True)
            self.current_vel_visualizer.set_visibility(True)
        else:
            if hasattr(self, "goal_vel_visualizer"):
                self.goal_vel_visualizer.set_visibility(False)
                self.current_vel_visualizer.set_visibility(False)

    def _debug_vis_callback(self, event):
        # check if robot is initialized
        # note: this is needed in-case the robot is de-initialized. we can't access the data
        if not self.robot.is_initialized:
            return
        # get marker location
        # -- base state (markers drawn 0.5 m above the root)
        base_pos_w = self.robot.data.root_pos_w.clone()
        base_pos_w[:, 2] += 0.5
        # -- resolve the scales and quaternions
        vel_des_arrow_scale, vel_des_arrow_quat = (
            self._resolve_xy_velocity_to_arrow(self.command[:, :2])
        )
        vel_arrow_scale, vel_arrow_quat = self._resolve_xy_velocity_to_arrow(
            self.robot.data.root_lin_vel_b[:, :2]
        )
        # display markers
        self.goal_vel_visualizer.visualize(
            base_pos_w, vel_des_arrow_quat, vel_des_arrow_scale
        )
        self.current_vel_visualizer.visualize(
            base_pos_w, vel_arrow_quat, vel_arrow_scale
        )

    """
    Internal helpers.
    """

    def _resolve_xy_velocity_to_arrow(
        self, xy_velocity: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Converts the XY base velocity command to arrow direction rotation."""
        # obtain default scale of the marker
        default_scale = self.goal_vel_visualizer.cfg.markers["arrow"].scale
        # arrow-scale: length proportional to speed
        arrow_scale = torch.tensor(default_scale, device=self.device).repeat(
            xy_velocity.shape[0], 1
        )
        arrow_scale[:, 0] *= torch.linalg.norm(xy_velocity, dim=1) * 3.0
        # arrow-direction
        heading_angle = torch.atan2(xy_velocity[:, 1], xy_velocity[:, 0])
        zeros = torch.zeros_like(heading_angle)
        arrow_quat = math_utils.quat_from_euler_xyz(
            zeros, zeros, heading_angle
        )
        # convert everything back from base to world frame
        base_quat_w = self.robot.data.root_quat_w
        arrow_quat = math_utils.quat_mul(base_quat_w, arrow_quat)
        return arrow_scale, arrow_quat
@configclass
class HoloMotionUniformVelocityCommandCfg(CommandTermCfg):
    """Configuration for the uniform velocity command generator.

    Instantiates :class:`HoloMotionUniformVelocityCommand` (see
    ``class_type``). On every resample each environment is independently
    flagged as heading / yaw-only / standing with the ``rel_*``
    probabilities below; standing overrides the other modes.
    """

    # Command term class built from this config.
    class_type: type = HoloMotionUniformVelocityCommand

    asset_name: str = MISSING
    """Name of the asset in the environment for which the commands are generated."""

    heading_command: bool = False
    """Whether to use heading command or angular velocity command. Defaults to False.
    If True, the angular velocity command is computed from the heading error, where the
    target heading is sampled uniformly from provided range. Otherwise, the angular velocity
    command is sampled uniformly from provided range.
    """

    heading_control_stiffness: float = 1.0
    """Scale factor to convert the heading error to angular velocity command. Defaults to 1.0."""

    rel_standing_envs: float = 0.0
    """The sampled probability of environments that should be standing still. Defaults to 0.0."""

    rel_yaw_envs: float = 0.0
    """The sampled probability of environments that should receive yaw-only commands. Defaults to 0.0.
    For yaw-only environments, the command is post-processed to:
    - enforce vx=vy=0
    This is sampled independently from :attr:`rel_standing_envs`. If an environment is both yaw-only
    and standing, standing still overrides to zero command.
    """

    rel_heading_envs: float = 1.0
    """The sampled probability of environments where the robots follow the heading-based angular velocity command
    (the others follow the sampled angular velocity command). Defaults to 1.0.
    This parameter is only used if :attr:`heading_command` is True.
    """

    @configclass
    class Ranges:
        """Uniform distribution ranges for the velocity commands.

        Each field is a ``(min, max)`` pair passed to a uniform sampler.
        """

        lin_vel_x: tuple[float, float] = MISSING
        """Range for the linear-x velocity command (in m/s)."""

        lin_vel_y: tuple[float, float] = MISSING
        """Range for the linear-y velocity command (in m/s)."""

        ang_vel_z: tuple[float, float] = MISSING
        """Range for the angular-z velocity command (in rad/s)."""

        heading: tuple[float, float] | None = None
        """Range for the heading command (in rad). Defaults to None.
        This parameter is only used if :attr:`~HoloMotionUniformVelocityCommandCfg.heading_command` is True.
        """

    ranges: Ranges = MISSING
    """Distribution ranges for the velocity commands."""

    goal_vel_visualizer_cfg: VisualizationMarkersCfg = (
        GREEN_ARROW_X_MARKER_CFG.replace(
            prim_path="/Visuals/Command/velocity_goal"
        )
    )
    """The configuration for the goal velocity visualization marker. Defaults to GREEN_ARROW_X_MARKER_CFG."""

    current_vel_visualizer_cfg: VisualizationMarkersCfg = (
        BLUE_ARROW_X_MARKER_CFG.replace(
            prim_path="/Visuals/Command/velocity_current"
        )
    )
    """The configuration for the current velocity visualization marker. Defaults to BLUE_ARROW_X_MARKER_CFG."""

    # Set the scale of the visualization markers to (0.5, 0.5, 0.5)
    # These statements run once, at class-definition time, and edit the
    # class-level default marker configs defined just above.
    # NOTE(review): whether the shared GREEN/BLUE_ARROW_X_MARKER_CFG objects
    # are also affected depends on `.replace()` copying `markers`; confirm.
    goal_vel_visualizer_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5)
    current_vel_visualizer_cfg.markers["arrow"].scale = (0.5, 0.5, 0.5)
@configclass
class VelTrack_CommandsCfg:
    """Container for velocity command terms.

    Starts empty; :func:`build_velocity_commands_config` attaches one
    attribute per configured command term.
    """
def _convert_ranges_dict_to_object(
    ranges_dict: dict,
) -> HoloMotionUniformVelocityCommandCfg.Ranges:
    """Build a ``Ranges`` config object from a plain mapping.

    List or tuple values (e.g. YAML ``[min, max]``) are converted to tuples.
    ``None`` and any other values are passed through unchanged.
    """

    def _as_range(raw):
        return tuple(raw) if isinstance(raw, (list, tuple)) else raw

    return HoloMotionUniformVelocityCommandCfg.Ranges(
        **{field: _as_range(raw) for field, raw in ranges_dict.items()}
    )
def build_velocity_commands_config(
    command_config_dict: dict,
) -> VelTrack_CommandsCfg:
    """Build a ``VelTrack_CommandsCfg`` holding velocity command terms.

    Each entry of ``command_config_dict`` becomes an attribute of the returned
    config, named after its key. Expected format::

        {
            "base_velocity": {
                "type": "HoloMotionUniformVelocityCommandCfg",
                "params": { ... }  # kwargs for the command term cfg
            }
        }

    ``ranges`` and ``limit_ranges`` inside ``params`` may be plain dicts with
    keys such as ``lin_vel_x``, ``lin_vel_y``, ``ang_vel_z`` and ``heading``.
    They are converted to ``Ranges`` objects. A missing or ``None``
    ``params`` is treated as empty.

    Only ``"HoloMotionUniformVelocityCommandCfg"`` is currently supported.
    ``type`` defaults to ``"VelocityCommandCfg"``, which is not supported,
    so ``type`` must be given explicitly.

    Raises:
        ValueError: If an entry names an unsupported command type.
    """
    supported_types = {
        "HoloMotionUniformVelocityCommandCfg": HoloMotionUniformVelocityCommandCfg,
    }
    commands_cfg = VelTrack_CommandsCfg()
    for name, cfg in command_config_dict.items():
        command_type = cfg.get("type", "VelocityCommandCfg")
        # Shallow copy so the caller's config is not mutated below.
        params = dict(cfg.get("params") or {})
        for range_key in ("ranges", "limit_ranges"):
            if isinstance(params.get(range_key), dict):
                params[range_key] = _convert_ranges_dict_to_object(
                    params[range_key]
                )
        term_cfg_class = supported_types.get(command_type)
        if term_cfg_class is None:
            raise ValueError(
                f"Unknown velocity command type: {command_type}. "
                f"Supported types: {sorted(supported_types)}"
            )
        setattr(commands_cfg, name, term_cfg_class(**params))
    return commands_cfg
================================================
FILE: holomotion/src/env/isaaclab_components/unitree_actuators.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# This file is modified from the unitree_rl_lab repository:
# https://github.com/unitreerobotics/unitree_rl_lab
from __future__ import annotations
import json
import os
from pathlib import Path
import torch
from dataclasses import MISSING
from typing import Sequence
from isaaclab.actuators import DelayedPDActuator, DelayedPDActuatorCfg
from isaaclab.utils import configclass
from isaaclab.utils.types import ArticulationActions
from loguru import logger
class UnitreeActuator(DelayedPDActuator):
    r"""Unitree actuator class that implements a torque-speed curve for the actuators.

    The torque-speed curve is defined as follows:

            Torque Limit, N·m
                ^
    Y2──────────|
                |──────────────Y1
                |              │\
                |              │ \
                |              │  \
                |              |   \
    ------------+--------------|------> velocity: rad/s
                              X1   X2

    - Y1: Peak Torque Test (Torque and Speed in the Same Direction)
    - Y2: Peak Torque Test (Torque and Speed in the Opposite Direction)
    - X1: Maximum Speed at Full Torque (T-N Curve Knee Point)
    - X2: No-Load Speed Test
    - Fs: Static friction coefficient
    - Fd: Dynamic friction coefficient
    - Va: Velocity at which the friction is fully activated

    Unset parameters default to the following values:

    - Y1, X1, X2: 1e9 (effectively unlimited)
    - Y2: the parsed Y1 values
    - Fs, Fd: 0
    - Va: 0.01

    After the parent PD compute step, the friction torque
    ``Fs * tanh(v / Va) + Fd * v`` is subtracted from ``applied_effort``.
    The result is always sent as a pure effort command: position and
    velocity targets are cleared.
    """

    cfg: UnitreeActuatorCfg

    armature: torch.Tensor
    """The armature of the actuator joints. Shape is (num_envs, num_joints).
    armature = J2 + J1 * i2 ^ 2 + Jr * (i1 * i2) ^ 2
    """

    def __init__(self, cfg: UnitreeActuatorCfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)

        # Last joint velocity seen by `compute`, reused by `_clip_effort`.
        self._joint_vel = torch.zeros_like(self.computed_effort)
        self._effort_y1 = self._parse_joint_parameter(cfg.Y1, 1e9)
        # Default Y2 to the already-parsed per-joint Y1 tensor rather than the
        # raw cfg value, so a dict-valued (per-joint) or unset Y1 still works
        # as the fallback.
        self._effort_y2 = self._parse_joint_parameter(
            cfg.Y2, self._effort_y1.clone()
        )
        self._velocity_x1 = self._parse_joint_parameter(cfg.X1, 1e9)
        self._velocity_x2 = self._parse_joint_parameter(cfg.X2, 1e9)
        self._friction_static = self._parse_joint_parameter(cfg.Fs, 0.0)
        self._friction_dynamic = self._parse_joint_parameter(cfg.Fd, 0.0)
        self._activation_vel = self._parse_joint_parameter(cfg.Va, 0.01)

    def compute(
        self,
        control_action: ArticulationActions,
        joint_pos: torch.Tensor,
        joint_vel: torch.Tensor,
    ) -> ArticulationActions:
        # save current joint vel (read by _clip_effort during super().compute)
        self._joint_vel[:] = joint_vel
        # calculate the desired joint torques
        control_action = super().compute(control_action, joint_pos, joint_vel)

        # apply friction model on the torque
        self.applied_effort -= (
            self._friction_static
            * torch.tanh(joint_vel / self._activation_vel)
            + self._friction_dynamic * joint_vel
        )

        # Drive the joints purely by effort.
        control_action.joint_positions = None
        control_action.joint_velocities = None
        control_action.joint_efforts = self.applied_effort

        return control_action

    def _clip_effort(self, effort: torch.Tensor) -> torch.Tensor:
        """Clip ``effort`` to the direction- and speed-dependent torque limit."""
        # check if the effort is the same direction as the joint velocity
        same_direction = (self._joint_vel * effort) > 0
        max_effort = torch.where(
            same_direction, self._effort_y1, self._effort_y2
        )
        # check if the joint velocity is less than the max speed at full torque
        max_effort = torch.where(
            self._joint_vel.abs() < self._velocity_x1,
            max_effort,
            self._compute_effort_limit(max_effort),
        )
        return torch.clip(effort, -max_effort, max_effort)

    def _compute_effort_limit(self, max_effort):
        """Torque limit on the descending part of the curve.

        Falls linearly from ``max_effort`` at |v| = X1 to zero at |v| = X2,
        and is clipped at zero beyond X2.
        """
        # Keep the span strictly positive. X2 == X1 (e.g. both left at 1e9)
        # would otherwise give an infinite slope, and inf * 0 yields NaN limits
        # when |v| == X1 exactly. A non-positive span acts as an immediate drop.
        span = (self._velocity_x2 - self._velocity_x1).clamp(min=1e-6)
        k = -max_effort / span
        limit = k * (self._joint_vel.abs() - self._velocity_x1) + max_effort
        return limit.clip(min=0.0)
class UnitreeErfiActuator(UnitreeActuator):
"""Unitree actuator with per-env ERFI-50 torque perturbations.
On environment reset, each env is assigned either step-wise random force
injection (RFI) or episode-level random actuation offset (RAO). During
rollout, only the selected mode is applied for that env.
"""
cfg: UnitreeErfiActuatorCfg
def __init__(self, cfg: UnitreeErfiActuatorCfg, *args, **kwargs):
super().__init__(cfg, *args, **kwargs)
self._ema_filter_alpha = float(cfg.ema_filter_alpha)
if not 0.0 <= self._ema_filter_alpha <= 1.0:
raise ValueError(
"ema_filter_alpha must be within [0, 1], "
f"got {self._ema_filter_alpha}."
)
self._ema_filter_debug_dump_path = (
cfg.ema_filter_debug_dump_path
or os.environ.get("HOLOMOTION_EMA_FILTER_DEBUG_DUMP_PATH")
)
self._ema_filter_debug_stop_after_dump = self._parse_bool_env(
"HOLOMOTION_EMA_FILTER_DEBUG_STOP_AFTER_DUMP",
cfg.ema_filter_debug_stop_after_dump,
)
self._ema_filter_debug_dumped = False
self._ema_filter_state = torch.zeros_like(self.computed_effort)
self._ema_filter_initialized = torch.zeros(
self._num_envs, dtype=torch.bool, device=self._device
)
self._mode_is_rfi = torch.zeros(
self._num_envs, dtype=torch.bool, device=self._device
)
self._rfi_lim_scale = torch.ones_like(self.computed_effort)
self._rao_scale = torch.zeros_like(self.computed_effort)
def reset(self, env_ids: Sequence[int] | slice | None):
super().reset(env_ids)
env_ids_tensor = self._env_ids_to_tensor(env_ids)
if env_ids_tensor.numel() == 0:
return
if self.cfg.ema_filter_enabled:
self._ema_filter_state[env_ids_tensor] = 0.0
self._ema_filter_initialized[env_ids_tensor] = False
if not self.cfg.erfi_enabled:
self._mode_is_rfi[env_ids_tensor] = False
self._rfi_lim_scale[env_ids_tensor] = 1.0
self._rao_scale[env_ids_tensor] = 0.0
return
sampled_is_rfi = (
torch.rand(env_ids_tensor.numel(), device=self._device)
< self.cfg.rfi_probability
)
self._mode_is_rfi[env_ids_tensor] = sampled_is_rfi
if self.cfg.randomize_rfi_lim:
self._rfi_lim_scale[env_ids_tensor] = self._sample_uniform(
self.cfg.rfi_lim_range[0],
self.cfg.rfi_lim_range[1],
(env_ids_tensor.numel(), self.num_joints),
)
else:
self._rfi_lim_scale[env_ids_tensor] = 1.0
self._rao_scale[env_ids_tensor] = self._sample_uniform(
-self.cfg.rao_lim,
self.cfg.rao_lim,
(env_ids_tensor.numel(), self.num_joints),
)
rfi_env_ids = env_ids_tensor[sampled_is_rfi]
if rfi_env_ids.numel() > 0:
self._rao_scale[rfi_env_ids] = 0.0
def compute(
self,
control_action: ArticulationActions,
joint_pos: torch.Tensor,
joint_vel: torch.Tensor,
) -> ArticulationActions:
control_action = self._filter_joint_position_action(control_action)
if not self.cfg.erfi_enabled:
return super().compute(control_action, joint_pos, joint_vel)
if control_action.joint_efforts is None:
base_joint_efforts = torch.zeros_like(joint_pos)
else:
base_joint_efforts = control_action.joint_efforts.clone()
effort_limit = self.effort_limit.to(base_joint_efforts)
rfi_noise = self._sample_uniform(-1.0, 1.0, base_joint_efforts.shape)
rfi_term = (
rfi_noise * self.cfg.rfi_lim * self._rfi_lim_scale * effort_limit
)
rao_term = self._rao_scale * effort_limit
mode_is_rfi = self._mode_is_rfi.unsqueeze(-1)
control_action_with_erfi = ArticulationActions(
joint_positions=control_action.joint_positions,
joint_velocities=control_action.joint_velocities,
joint_efforts=base_joint_efforts
+ torch.where(mode_is_rfi, rfi_term, rao_term),
joint_indices=control_action.joint_indices,
)
return super().compute(control_action_with_erfi, joint_pos, joint_vel)
def _filter_joint_position_action(
self, control_action: ArticulationActions
) -> ArticulationActions:
if not self.cfg.ema_filter_enabled:
self._maybe_dump_ema_filter_debug_skip("ema_filter_disabled")
return control_action
if control_action.joint_positions is None:
self._maybe_dump_ema_filter_debug_skip("joint_positions_none")
return control_action
raw_joint_positions = control_action.joint_positions
previous_filtered_joint_positions = self._ema_filter_state.clone()
needs_init = ~self._ema_filter_initialized
filtered_joint_positions = raw_joint_positions.clone()
if torch.any(~needs_init):
filtered_joint_positions = torch.where(
needs_init.unsqueeze(-1),
raw_joint_positions,
self._ema_filter_alpha * raw_joint_positions
+ (1.0 - self._ema_filter_alpha) * self._ema_filter_state,
)
self._maybe_dump_ema_filter_debug_verification(
raw_joint_positions=raw_joint_positions,
filtered_joint_positions=filtered_joint_positions,
previous_filtered_joint_positions=previous_filtered_joint_positions,
needs_init=needs_init,
)
self._ema_filter_state[:] = filtered_joint_positions
self._ema_filter_initialized[:] = True
return ArticulationActions(
joint_positions=filtered_joint_positions,
joint_velocities=control_action.joint_velocities,
joint_efforts=control_action.joint_efforts,
joint_indices=control_action.joint_indices,
)
def _maybe_dump_ema_filter_debug_verification(
self,
raw_joint_positions: torch.Tensor,
filtered_joint_positions: torch.Tensor,
previous_filtered_joint_positions: torch.Tensor,
needs_init: torch.Tensor,
) -> None:
if (
self._ema_filter_debug_dumped
or not self._ema_filter_debug_dump_path
):
return
rank = os.environ.get("RANK")
if rank is not None and rank != "0":
return
initialized_env_ids = torch.nonzero(
~needs_init, as_tuple=False
).flatten()
if initialized_env_ids.numel() == 0:
return
env_idx = int(initialized_env_ids[0].item())
raw = raw_joint_positions[env_idx].detach().cpu()
prev = previous_filtered_joint_positions[env_idx].detach().cpu()
actual = filtered_joint_positions[env_idx].detach().cpu()
expected = (
self._ema_filter_alpha * raw
+ (1.0 - self._ema_filter_alpha) * prev
)
matched = torch.allclose(actual, expected, atol=1.0e-6, rtol=1.0e-6)
dump_path = Path(self._ema_filter_debug_dump_path)
dump_path.parent.mkdir(parents=True, exist_ok=True)
dump_path.write_text(
json.dumps(
{
"alpha": self._ema_filter_alpha,
"env_index": env_idx,
"matched": bool(matched),
"raw_joint_positions": raw.tolist(),
"previous_filtered_joint_positions": prev.tolist(),
"expected_filtered_joint_positions": expected.tolist(),
"actual_filtered_joint_positions": actual.tolist(),
"pid": os.getpid(),
"rank": rank or "0",
},
indent=2,
)
)
self._ema_filter_debug_dumped = True
logger.info("Wrote EMA verification dump to {}", dump_path)
if self._ema_filter_debug_stop_after_dump:
raise RuntimeError(f"EMA verification dump written to {dump_path}")
def _maybe_dump_ema_filter_debug_skip(self, reason: str) -> None:
self._maybe_dump_ema_filter_debug_payload(
{
"applied": False,
"reason": reason,
}
)
def _maybe_dump_ema_filter_debug_payload(self, payload: dict) -> None:
if (
self._ema_filter_debug_dumped
or not self._ema_filter_debug_dump_path
):
return
rank = os.environ.get("RANK")
if rank is not None and rank != "0":
return
dump_path = Path(self._ema_filter_debug_dump_path)
dump_path.parent.mkdir(parents=True, exist_ok=True)
dump_path.write_text(
json.dumps(
{
**payload,
"alpha": self._ema_filter_alpha,
"pid": os.getpid(),
"rank": rank or "0",
},
indent=2,
)
)
self._ema_filter_debug_dumped = True
logger.info("Wrote EMA verification dump to {}", dump_path)
if self._ema_filter_debug_stop_after_dump:
raise RuntimeError(f"EMA verification dump written to {dump_path}")
@staticmethod
def _parse_bool_env(name: str, default: bool) -> bool:
raw_value = os.environ.get(name)
if raw_value is None:
return bool(default)
return raw_value.strip().lower() in {"1", "true", "yes", "on"}
def _env_ids_to_tensor(
self, env_ids: Sequence[int] | slice | None
) -> torch.Tensor:
if env_ids is None or env_ids == slice(None):
return torch.arange(self._num_envs, device=self._device)
if isinstance(env_ids, torch.Tensor):
return env_ids.to(device=self._device, dtype=torch.long).flatten()
return torch.tensor(env_ids, device=self._device, dtype=torch.long)
def _sample_uniform(
self, low: float, high: float, shape: tuple[int, ...]
) -> torch.Tensor:
return torch.empty(shape, device=self._device).uniform_(low, high)
@configclass
class UnitreeActuatorCfg(DelayedPDActuatorCfg):
"""
Configuration for Unitree actuators.
"""
class_type: type = UnitreeActuator
X1: float = 1e9
"""Maximum Speed at Full Torque(T-N Curve Knee Point) Unit: rad/s"""
X2: float = 1e9
"""No-Load Speed Test Unit: rad/s"""
Y1: float = MISSING
"""Peak Torque Test(Torque and Speed in the Same Direction) Unit: N*m"""
Y2: float | None = None
"""Peak Torque Test(Torque and Speed in the Opposite Direction) Unit: N*m"""
Fs: float = 0.0
""" Static friction coefficient """
Fd: float = 0.0
""" Dynamic friction coefficient """
Va: float = 0.01
""" Velocity at which the friction is fully activated """
@configclass
class UnitreeErfiActuatorCfg(UnitreeActuatorCfg):
"""Configuration for Unitree actuators with ERFI-50 perturbations."""
class_type: type = UnitreeErfiActuator
erfi_enabled: bool = False
"""Whether ERFI perturbations are enabled for this actuator."""
ema_filter_enabled: bool = False
"""Whether to apply EMA filtering to incoming joint-position actions."""
ema_filter_alpha: float = 1.0
"""EMA mixing factor using filtered = alpha * raw + (1 - alpha) * prev."""
ema_filter_debug_dump_path: str | None = None
"""Optional JSON path for a one-shot EMA verification dump during runtime."""
ema_filter_debug_stop_after_dump: bool = False
"""Whether to stop execution after writing the EMA verification dump."""
rfi_probability: float = 0.5
"""Probability of assigning RFI to an environment on reset."""
rfi_lim: float = 0.1
"""Base RFI limit, expressed as a ratio of joint effort limits."""
randomize_rfi_lim: bool = True
"""Whether to randomize the per-episode RFI limit scale."""
rfi_lim_range: tuple[float, float] = (0.5, 1.5)
"""Multiplicative range for per-episode RFI scaling."""
rao_lim: float = 0.1
"""RAO limit, expressed as a ratio of joint effort limits."""
@configclass
class UnitreeActuatorCfg_M107_15(UnitreeActuatorCfg):
X1 = 14.0
X2 = 25.6
Y1 = 150.0
Y2 = 182.8
armature = 0.063259741
@configclass
class UnitreeActuatorCfg_M107_24(UnitreeActuatorCfg):
X1 = 8.8
X2 = 16
Y1 = 240
Y2 = 292.5
armature = 0.160478022
@configclass
class UnitreeActuatorCfg_Go2HV(UnitreeActuatorCfg):
X1 = 13.5
X2 = 30
Y1 = 20.2
Y2 = 23.4
@configclass
class UnitreeActuatorCfg_N7520_14p3(UnitreeActuatorCfg):
# Decimal point cannot be used as variable name, use `p` instead
X1 = 22.63
X2 = 35.52
Y1 = 71
Y2 = 83.3
Fs = 1.6
Fd = 0.16
"""
| rotor | 0.489e-4 kg·m²
| gear_1 | 0.098e-4 kg·m² | ratio | 4.5
| gear_2 | 0.533e-4 kg·m² | ratio | 48/22+1
"""
armature = 0.01017752
@configclass
class UnitreeActuatorCfg_N7520_22p5(UnitreeActuatorCfg):
# Decimal point cannot be used as variable name, use `p` instead
X1 = 14.5
X2 = 22.7
Y1 = 111.0
Y2 = 131.0
Fs = 2.4
Fd = 0.24
"""
| rotor | 0.489e-4 kg·m²
| gear_1 | 0.109e-4 kg·m² | ratio | 4.5
| gear_2 | 0.738e-4 kg·m² | ratio | 5.0
"""
armature = 0.025101925
@configclass
class UnitreeActuatorCfg_N5010_16(UnitreeActuatorCfg):
X1 = 27.0
X2 = 41.5
Y1 = 9.5
Y2 = 17.0
"""
| rotor | 0.084e-4 kg·m²
| gear_1 | 0.015e-4 kg·m² | ratio | 4
| gear_2 | 0.068e-4 kg·m² | ratio | 4
"""
armature = 0.0021812
@configclass
class UnitreeActuatorCfg_N5020_16(UnitreeActuatorCfg):
X1 = 30.86
X2 = 40.13
Y1 = 24.8
Y2 = 31.9
Fs = 0.6
Fd = 0.06
"""
| rotor | 0.139e-4 kg·m²
| gear_1 | 0.017e-4 kg·m² | ratio | 46/18+1
| gear_2 | 0.169e-4 kg·m² | ratio | 56/16+1
"""
armature = 0.003609725
@configclass
class UnitreeActuatorCfg_W4010_25(UnitreeActuatorCfg):
X1 = 15.3
X2 = 24.76
Y1 = 4.8
Y2 = 8.6
Fs = 0.6
Fd = 0.06
"""
| rotor | 0.068e-4 kg·m²
| gear_1 | | ratio | 5
| gear_2 | | ratio | 5
"""
armature = 0.00425
================================================
FILE: holomotion/src/env/motion_tracking.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
import time
import os
import yaml
from collections import deque
from functools import wraps
from easydict import EasyDict
import random
import numpy as np
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg
from isaaclab.sim import PhysxCfg, SimulationCfg
from isaaclab.utils import configclass
from isaaclab.utils.io import dump_yaml
from loguru import logger
from omegaconf import OmegaConf
from holomotion.src.env.isaaclab_components import (
ActionsCfg,
VelTrack_CommandsCfg,
MoTrack_CommandsCfg,
EventsCfg,
MotionTrackingSceneCfg,
ObservationsCfg,
RewardsCfg,
TerminationsCfg,
CurriculumCfg,
build_actions_config,
build_motion_tracking_commands_config,
build_velocity_commands_config,
build_domain_rand_config,
build_curriculum_config,
build_observations_config,
build_rewards_config,
build_scene_config,
build_terminations_config,
)
from holomotion.src.env.isaaclab_components.isaaclab_observation import (
ObservationFunctions,
)
from holomotion.src.env.isaaclab_components.isaaclab_utils import (
resolve_holo_config,
)
# from holomotion.src.modules.agent_modules import ObsSeqSerializer
import isaaclab.envs.mdp as isaaclab_mdp
from isaaclab.envs.mdp.events import _randomize_prop_by_op
from isaaclab.managers import SceneEntityCfg, EventTermCfg
from isaaclab.utils import configclass
from isaaclab.envs import ManagerBasedEnv
from isaaclab.managers import EventTermCfg
from isaaclab.managers import EventTermCfg as EventTerm
import isaaclab.utils.math as math_utils
from isaaclab.assets import Articulation
from isaaclab.envs.mdp.events import _randomize_prop_by_op
from isaaclab.managers import SceneEntityCfg
from typing import TYPE_CHECKING, Literal
def _joint_ids_to_tensor(
joint_ids: slice | list[int] | tuple[int, ...] | torch.Tensor | None,
num_joints: int,
device: torch.device | str,
) -> torch.Tensor:
if joint_ids is None:
return torch.arange(num_joints, device=device, dtype=torch.long)
if isinstance(joint_ids, slice):
if joint_ids == slice(None):
return torch.arange(num_joints, device=device, dtype=torch.long)
return torch.arange(num_joints, device=device, dtype=torch.long)[
joint_ids
]
if isinstance(joint_ids, torch.Tensor):
return joint_ids.to(device=device, dtype=torch.long).flatten()
return torch.tensor(joint_ids, device=device, dtype=torch.long)
def _select_effort_limit_vector(
asset: Articulation,
selected_joint_ids: torch.Tensor,
) -> torch.Tensor:
num_joints = asset.data.applied_torque.shape[1]
device = asset.data.applied_torque.device
dtype = asset.data.applied_torque.dtype
effort_limit_vec = torch.zeros(num_joints, device=device, dtype=dtype)
is_filled = torch.zeros(num_joints, device=device, dtype=torch.bool)
for actuator in asset.actuators.values():
actuator_joint_ids = _joint_ids_to_tensor(
actuator.joint_indices, num_joints=num_joints, device=device
)
actuator_effort_limit = torch.as_tensor(
actuator.effort_limit, device=device, dtype=dtype
)
if actuator_effort_limit.ndim == 0:
actuator_effort_limit = actuator_effort_limit.expand(
actuator_joint_ids.numel()
)
elif actuator_effort_limit.ndim == 2:
if actuator_effort_limit.shape[0] > 1:
reference = actuator_effort_limit[0].unsqueeze(0)
if not torch.allclose(
actuator_effort_limit,
reference.expand_as(actuator_effort_limit),
):
raise ValueError(
"normed_torque_rate requires actuator effort limits to be static across envs."
)
actuator_effort_limit = actuator_effort_limit[0]
elif actuator_effort_limit.ndim != 1:
raise ValueError(
"normed_torque_rate expects actuator effort limits to be scalar, 1-D, or 2-D tensors."
)
if actuator_effort_limit.numel() != actuator_joint_ids.numel():
raise ValueError(
"normed_torque_rate found mismatched actuator joint indices and effort limits."
)
effort_limit_vec[actuator_joint_ids] = actuator_effort_limit
is_filled[actuator_joint_ids] = True
if not torch.all(is_filled[selected_joint_ids]):
missing_joint_ids = selected_joint_ids[~is_filled[selected_joint_ids]]
raise ValueError(
"normed_torque_rate could not resolve actuator effort limits for "
f"joint ids {missing_joint_ids.tolist()}."
)
selected_effort_limits = effort_limit_vec[selected_joint_ids]
if not torch.all(torch.isfinite(selected_effort_limits)):
raise ValueError(
"normed_torque_rate requires finite actuator effort limits for all selected joints."
)
if not torch.all(selected_effort_limits > 0.0):
raise ValueError(
"normed_torque_rate requires strictly positive actuator effort limits for all selected joints."
)
return selected_effort_limits
class MotionTrackingEnv:
"""IsaacLab-based Motion Tracking Environment.
This environment integrates motion tracking capabilities with IsaacLab's
manager-based architecture, supporting curriculum learning, domain randomization,
and various termination conditions.
This is a wrapper class that handles Isaac Sim initialization and delegates
to an internal ManagerBasedRLEnv instance.
"""
def __init__(
self,
config,
device: torch.device = None,
log_dir: str = None,
render_mode: str | None = None,
headless: bool = True,
accelerator=None,
):
"""Initialize the Motion Tracking Environment.
Args:
config: Configuration for the environment
device: Device for tensor operations
log_dir: Logging directory
render_mode: Render mode for the environment
headless: Whether to run in headless mode
accelerator: Accelerator instance for distributed training (optional)
"""
self.config = config
self._device = device
self.accelerator = accelerator
self.log_dir = log_dir
self.headless = headless
self.init_done = False
self.is_evaluating = False
self.render_mode = render_mode
# self._init_motion_tracking_components()
self._init_isaaclab_env()
# self._init_serializers()
self._completion_total_queue = deque(maxlen=1000)
self._completion_success_queue = deque(maxlen=1000)
self.metrics = {}
self._robot_prev_joint_vel = None
self._robot_prev_applied_torque = None
self._robot_torque_rate_inv_effort_limit = None
self._robot_torque_rate_needs_reseed = None
@property
def num_envs(self):
return self._env.num_envs
@property
def device(self):
return self._env.device
def _init_isaaclab_env(self):
_device = self._device
curriculum = CurriculumCfg()
# Determine per-process seed if provided; else create a deterministic per-rank default
seed_val = getattr(self.config, "seed", None)
if seed_val is None:
if self.accelerator is not None:
pid = self.accelerator.process_index
else:
pid = int(self.config.get("process_id", 0))
seed_val = int(time.time()) + pid
_robot_config_dict = EasyDict(
OmegaConf.to_container(self.config.robot, resolve=True)
)
_terrain_config_dict = EasyDict(
OmegaConf.to_container(self.config.terrain, resolve=True)
)
_obs_config_dict = EasyDict(
OmegaConf.to_container(self.config.obs, resolve=True)
)
_rewards_config_dict = EasyDict(
OmegaConf.to_container(self.config.rewards, resolve=True)
)
_domain_rand_config_dict = (
EasyDict(
OmegaConf.to_container(
self.config.domain_rand,
resolve=True,
)
)
if self.config.domain_rand is not None
else {}
)
_terminations_config_dict = (
EasyDict(
OmegaConf.to_container(
self.config.terminations,
resolve=True,
)
)
if self.config.terminations is not None
else {}
)
_scene_config_dict = EasyDict(
OmegaConf.to_container(
self.config.scene,
resolve=True,
)
)
_commands_config_dict = OmegaConf.to_container(
self.config.commands,
resolve=True,
)
_simulation_config_dict = EasyDict(
OmegaConf.to_container(
self.config.simulation,
resolve=True,
)
)
_actions_config_dict = EasyDict(
OmegaConf.to_container(
self.config.actions,
resolve=True,
)
)
if getattr(self.config, "curriculum", None) is not None:
_curriculum_config_dict = EasyDict(
OmegaConf.to_container(self.config.curriculum, resolve=True)
)
else:
_curriculum_config_dict = {}
@configclass
class MotionTrackingEnvCfg(ManagerBasedRLEnvCfg):
seed: int = seed_val
scene_config_dict = {
"num_envs": self.config.num_envs,
"env_spacing": self.config.env_spacing,
"replicate_physics": self.config.replicate_physics,
"robot": _robot_config_dict,
"terrain": _terrain_config_dict,
"domain_rand": _domain_rand_config_dict,
"lighting": _scene_config_dict.lighting,
"contact_sensor": _scene_config_dict.contact_sensor,
}
decimation: int = _simulation_config_dict.control_decimation
episode_length_s: int = _simulation_config_dict.episode_length_s
sim_freq = _simulation_config_dict.sim_freq
dt = 1.0 / sim_freq
physx = PhysxCfg(
bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity,
gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count,
enable_stabilization=True,
)
if self.accelerator is not None:
main_process = self.accelerator.is_main_process
process_id = self.accelerator.process_index
num_processes = self.accelerator.num_processes
else:
main_process = self.config.get("main_process", True)
process_id = self.config.get("process_id", 0)
num_processes = self.config.get("num_processes", 1)
scene: MotionTrackingSceneCfg = build_scene_config(
scene_config_dict,
main_process=main_process,
process_id=process_id,
num_processes=num_processes,
)
sim: SimulationCfg = SimulationCfg(
dt=dt,
render_interval=decimation,
physx=physx,
device=_device,
enable_scene_query_support=True,
)
sim.physics_material = scene.terrain.physics_material
viewer: ViewerCfg = ViewerCfg(origin_type="world")
motion_cmds = {}
vel_cmds = {}
for k, v in _commands_config_dict.items():
if (
isinstance(v, dict)
and v.get("type", "") == "MotionCommandCfg"
):
motion_cmds[k] = v
else:
vel_cmds[k] = v
# Populate RefMotionCommand distributed params when present.
if "ref_motion" in motion_cmds:
if self.accelerator is not None:
cmd_process_id = self.accelerator.process_index
cmd_num_processes = self.accelerator.num_processes
else:
cmd_process_id = getattr(self.config, "process_id", 0)
cmd_num_processes = getattr(
self.config, "num_processes", 1
)
motion_cmds["ref_motion"]["params"].update(
{
"seed": int(seed_val),
"process_id": cmd_process_id,
"num_processes": cmd_num_processes,
"is_evaluating": self.is_evaluating,
}
)
# Build a unified commands cfg that may contain both motion and velocity terms.
if motion_cmds:
commands: MoTrack_CommandsCfg = (
build_motion_tracking_commands_config(motion_cmds)
)
else:
commands: MoTrack_CommandsCfg = MoTrack_CommandsCfg()
if vel_cmds:
vel_commands: VelTrack_CommandsCfg = (
build_velocity_commands_config(vel_cmds)
)
for name in vel_cmds.keys():
setattr(commands, name, getattr(vel_commands, name))
observations: ObservationsCfg = build_observations_config(
_obs_config_dict.obs_groups
)
rewards: RewardsCfg = build_rewards_config(_rewards_config_dict)
if _terminations_config_dict:
terminations: TerminationsCfg = build_terminations_config(
_terminations_config_dict
)
else:
terminations: TerminationsCfg = TerminationsCfg()
if _domain_rand_config_dict:
events: EventsCfg = build_domain_rand_config(
_domain_rand_config_dict
)
else:
events: EventsCfg = EventsCfg()
if "base_velocity" in vel_cmds:
events.reset_base = EventTerm(
func=isaaclab_mdp.reset_root_state_uniform,
mode="reset",
params={
"pose_range": {
"x": (-0.5, 0.5),
"y": (-0.5, 0.5),
"yaw": (-3.14, 3.14),
},
"velocity_range": {
"x": (0.0, 0.0),
"y": (0.0, 0.0),
"z": (0.0, 0.0),
"roll": (0.0, 0.0),
"pitch": (0.0, 0.0),
"yaw": (0.0, 0.0),
},
},
)
events.reset_robot_joints = EventTerm(
func=isaaclab_mdp.reset_joints_by_scale,
mode="reset",
params={
"position_range": (1.0, 1.0),
"velocity_range": (-1.0, 1.0),
},
)
curriculum: CurriculumCfg = build_curriculum_config(
_curriculum_config_dict
)
actions: ActionsCfg = build_actions_config(_actions_config_dict)
sim: SimulationCfg = SimulationCfg(
dt=dt,
render_interval=decimation,
physx=physx,
device=_device,
enable_scene_query_support=True,
)
sim.physx.gpu_max_rigid_patch_count = 10 * 2**15
sim.physx.enable_stabilization = True
sim.physics_material = scene.terrain.physics_material
isaaclab_env_cfg = MotionTrackingEnvCfg()
isaaclab_envconfig_dump_path = os.path.join(
self.log_dir, "isaaclab_env_cfg.yaml"
)
dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg)
self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode)
logger.info("IsaacLab environment initialized !")
return self._env
def _init_motion_tracking_components(self):
self.n_fut_frames = self.config.commands.ref_motion.params.n_fut_frames
self.target_fps = self.config.commands.ref_motion.params.target_fps
self._init_serializers()
def step(self, actor_state: dict):
obs_dict, rewards, terminated, time_outs, infos = self._env.step(
actor_state
)
# IsaacLab separates terminated vs time_outs, combine them for consistency
dones = terminated | time_outs
self._update_completion_rate_stats(terminated, time_outs, infos)
self._update_robot_metrics(infos)
return obs_dict, rewards, dones, time_outs, infos
def _update_robot_metrics(self, infos: dict) -> None:
"""Log robot low-level metrics (scalar means) for TensorBoard/console."""
if ("log" not in infos) or (not isinstance(infos["log"], dict)):
infos["log"] = {}
dt = float(self._env.step_dt)
action = self._env.action_manager.action # [B, A]
prev_action = self._env.action_manager.prev_action # [B, A]
action_rate = torch.norm(action - prev_action, dim=-1) / dt # [B]
robot = self._env.scene["robot"]
dof_vel = robot.data.joint_vel # [B, Nd]
dof_torque = robot.data.applied_torque # [B, Nd]
if self._robot_prev_joint_vel is None or (
self._robot_prev_joint_vel.shape != dof_vel.shape
):
self._robot_prev_joint_vel = dof_vel.clone()
dof_acc = (dof_vel - self._robot_prev_joint_vel) / dt # [B, Nd]
self._robot_prev_joint_vel = dof_vel.clone()
if self._robot_prev_applied_torque is None or (
self._robot_prev_applied_torque.shape != dof_torque.shape
):
joint_ids = torch.arange(
dof_torque.shape[1], device=dof_torque.device, dtype=torch.long
)
effort_limit = _select_effort_limit_vector(robot, joint_ids)
self._robot_torque_rate_inv_effort_limit = (
effort_limit.reciprocal()
)
self._robot_prev_applied_torque = torch.zeros_like(dof_torque)
self._robot_torque_rate_needs_reseed = torch.ones(
dof_torque.shape[0], device=dof_torque.device, dtype=torch.bool
)
normed_torque_rate = torch.zeros(
dof_torque.shape[0],
device=dof_torque.device,
dtype=dof_torque.dtype,
)
reseed_mask = self._robot_torque_rate_needs_reseed.clone()
if hasattr(self._env, "episode_length_buf"):
reseed_mask |= self._env.episode_length_buf == 0
active_mask = ~reseed_mask
if torch.any(active_mask):
delta = (
dof_torque[active_mask]
- self._robot_prev_applied_torque[active_mask]
) * self._robot_torque_rate_inv_effort_limit
normed_torque_rate[active_mask] = torch.sum(delta.square(), dim=1)
self._robot_prev_applied_torque.copy_(dof_torque)
self._robot_torque_rate_needs_reseed[reseed_mask] = False
dof_acc_norm = torch.norm(dof_acc, dim=-1) # [B]
dof_torque_norm = torch.norm(dof_torque, dim=-1) # [B]
energy = torch.sum(
torch.abs(dof_vel) * torch.abs(dof_torque), dim=-1
) # [B]
self.metrics["Robot/Action_Rate"] = action_rate.mean()
self.metrics["Robot/DOF_Acc"] = dof_acc_norm.mean()
self.metrics["Robot/DOF_Torque"] = dof_torque_norm.mean()
self.metrics["Robot/Energy"] = energy.mean()
self.metrics["Robot/Normed_Torque_Rate"] = normed_torque_rate.mean()
infos["log"]["Metrics/Robot/Action_Rate"] = self.metrics[
"Robot/Action_Rate"
]
infos["log"]["Metrics/Robot/DOF_Acc"] = self.metrics["Robot/DOF_Acc"]
infos["log"]["Metrics/Robot/DOF_Torque"] = self.metrics[
"Robot/DOF_Torque"
]
infos["log"]["Metrics/Robot/Energy"] = self.metrics["Robot/Energy"]
infos["log"]["Metrics/Robot/Normed_Torque_Rate"] = self.metrics[
"Robot/Normed_Torque_Rate"
]
def _update_completion_rate_stats(
self,
terminated: torch.Tensor,
time_outs: torch.Tensor,
infos: dict,
) -> None:
"""Log completion rate over recent done batches.
Definition:
- Completed: time_outs==True and terminated==False.
- Failed: terminated==True.
The rolling window stores per-step done counts (only when any done occurs).
"""
done_mask = (terminated | time_outs).reshape(-1).bool()
if torch.any(done_mask):
done_count = int(done_mask.sum().item())
completed_mask = (
time_outs.reshape(-1).bool()
& ~terminated.reshape(-1).bool()
& done_mask
)
completed_count = int(completed_mask.sum().item())
self._completion_total_queue.append(done_count)
self._completion_success_queue.append(completed_count)
denom = sum(self._completion_total_queue)
completion_rate = (
float(sum(self._completion_success_queue)) / float(denom)
if denom > 0
else 0.0
)
if ("log" not in infos) or (not isinstance(infos["log"], dict)):
infos["log"] = {}
infos["log"]["Metrics/ref_motion/Task/Completion_Rate"] = torch.tensor(
completion_rate, device=self.device, dtype=torch.float32
)
self.metrics["Metrics/ref_motion/Task/Completion_Rate"] = (
completion_rate
)
def reset_idx(self, env_ids: torch.Tensor):
return self._env.reset(env_ids=env_ids)
def reset_all(self):
env_ids = torch.arange(self.num_envs, device=self.device)
out = self._env.reset(env_ids=env_ids)
return out
def set_is_evaluating(self):
logger.info("Setting environment to evaluation mode")
self.is_evaluating = True
def seed(self, seed: int):
self._env.seed(seed)
================================================
FILE: holomotion/src/env/velocity_tracking.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import torch
import time
import os
import yaml
from collections import deque
from functools import wraps
from easydict import EasyDict
import random
import numpy as np
from isaaclab.actuators import ImplicitActuatorCfg
from isaaclab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg, ViewerCfg
from isaaclab.sim import PhysxCfg, SimulationCfg
from isaaclab.utils import configclass
from isaaclab.utils.io import dump_yaml
from loguru import logger
from omegaconf import OmegaConf
from holomotion.src.env.isaaclab_components import (
ActionsCfg,
VelTrack_CommandsCfg,
MoTrack_CommandsCfg,
EventsCfg,
MotionTrackingSceneCfg,
ObservationsCfg,
RewardsCfg,
TerminationsCfg,
CurriculumCfg,
build_actions_config,
build_motion_tracking_commands_config,
build_velocity_commands_config,
build_domain_rand_config,
build_curriculum_config,
build_observations_config,
build_rewards_config,
build_scene_config,
build_terminations_config,
)
from holomotion.src.env.isaaclab_components.isaaclab_observation import (
ObservationFunctions,
)
from holomotion.src.env.isaaclab_components.isaaclab_utils import (
resolve_holo_config,
)
import isaaclab.envs.mdp as isaaclab_mdp
from isaaclab.envs.mdp.events import _randomize_prop_by_op
from isaaclab.managers import SceneEntityCfg, EventTermCfg
from isaaclab.utils import configclass
from isaaclab.envs import ManagerBasedEnv
from isaaclab.managers import EventTermCfg
from isaaclab.managers import EventTermCfg as EventTerm
import isaaclab.utils.math as math_utils
from isaaclab.assets import Articulation
from isaaclab.envs.mdp.events import _randomize_prop_by_op
from isaaclab.managers import SceneEntityCfg
from typing import TYPE_CHECKING, Literal
class VelocityTrackingEnv:
    """IsaacLab-based velocity-tracking environment wrapper.

    Builds an IsaacLab ``ManagerBasedRLEnvCfg`` from the HoloMotion config
    (robot, terrain, scene, observations, rewards, terminations, domain
    randomization, velocity commands, actions and simulation settings) and
    delegates stepping/resetting to an internal ``ManagerBasedRLEnv``
    instance stored in ``self._env``. On top of the wrapped env it keeps a
    rolling completion-rate statistic that is reported via ``infos["log"]``.
    """

    def __init__(
        self,
        config,
        device: torch.device = None,
        log_dir: str = None,
        render_mode: str | None = None,
        headless: bool = True,
        accelerator=None,
    ):
        """Store the configuration and construct the IsaacLab environment.

        Args:
            config: HoloMotion environment config (OmegaConf). The
                sub-configs ``robot``, ``terrain``, ``obs``, ``rewards``,
                ``domain_rand``, ``terminations``, ``scene``, ``commands``,
                ``simulation`` and ``actions`` are read while building the
                IsaacLab config.
            device: Simulation device forwarded to ``SimulationCfg``.
            log_dir: Directory where the resolved IsaacLab config is dumped
                as ``isaaclab_env_cfg.yaml``.
                NOTE(review): used in ``os.path.join`` unconditionally, so it
                presumably must not be None.
            render_mode: Render mode forwarded to ``ManagerBasedRLEnv``.
            headless: Whether Isaac Sim runs without a viewer; together with
                ``render_mode`` it decides whether the base_velocity command
                debug visualization is kept.
            accelerator: Optional accelerator used to derive the process
                index/count (and the fallback seed) in distributed runs.
        """
        self.config = config
        self._device = device
        self.accelerator = accelerator
        self.log_dir = log_dir
        self.headless = headless
        self.init_done = False
        self.is_evaluating = False
        self.render_mode = render_mode
        # self._init_motion_tracking_components()
        self._init_isaaclab_env()
        # self._init_serializers()
        # Rolling windows over the last 1000 "done" batches: number of
        # finished envs and number of successfully completed envs.
        self._completion_total_queue = deque(maxlen=1000)
        self._completion_success_queue = deque(maxlen=1000)
        self.metrics = {}
        self._robot_prev_joint_vel = None

    @property
    def num_envs(self):
        """Number of parallel environments of the wrapped IsaacLab env."""
        return self._env.num_envs

    @property
    def device(self):
        """Device reported by the wrapped IsaacLab env."""
        return self._env.device

    def _init_isaaclab_env(self):
        """Build the IsaacLab env config and construct ``self._env``.

        Converts each HoloMotion sub-config to plain containers, declares a
        ``VelocityTrackingEnvCfg`` configclass whose class attributes close
        over those values, dumps the resolved config to
        ``<log_dir>/isaaclab_env_cfg.yaml`` and instantiates the
        ``ManagerBasedRLEnv``.

        Returns:
            The constructed ``ManagerBasedRLEnv`` (also stored as
            ``self._env``).
        """
        _device = self._device
        # curriculum = CurriculumCfg()
        # Use config.seed if provided; otherwise derive a per-rank seed from
        # the current wall-clock time plus the process index.
        # NOTE(review): the time-based fallback is NOT reproducible across
        # runs; set config.seed explicitly for deterministic experiments.
        seed_val = getattr(self.config, "seed", None)
        if seed_val is None:
            if self.accelerator is not None:
                pid = self.accelerator.process_index
            else:
                pid = int(self.config.get("process_id", 0))
            seed_val = int(time.time()) + pid
        # Resolve interpolations and convert sub-configs to attribute-access
        # dicts consumed by the config builders below.
        _robot_config_dict = EasyDict(
            OmegaConf.to_container(self.config.robot, resolve=True)
        )
        _terrain_config_dict = EasyDict(
            OmegaConf.to_container(self.config.terrain, resolve=True)
        )
        _obs_config_dict = EasyDict(
            OmegaConf.to_container(self.config.obs, resolve=True)
        )
        _rewards_config_dict = EasyDict(
            OmegaConf.to_container(self.config.rewards, resolve=True)
        )
        # domain_rand / terminations are optional; an empty dict means
        # "use the default (empty) IsaacLab config" further down.
        _domain_rand_config_dict = (
            EasyDict(
                OmegaConf.to_container(
                    self.config.domain_rand,
                    resolve=True,
                )
            )
            if self.config.domain_rand is not None
            else {}
        )
        _terminations_config_dict = (
            EasyDict(
                OmegaConf.to_container(
                    self.config.terminations,
                    resolve=True,
                )
            )
            if self.config.terminations is not None
            else {}
        )
        _scene_config_dict = EasyDict(
            OmegaConf.to_container(
                self.config.scene,
                resolve=True,
            )
        )
        # Kept as a plain dict (not EasyDict) so it can be edited below.
        _commands_config_dict = OmegaConf.to_container(
            self.config.commands,
            resolve=True,
        )
        # Headless + no rendering: disable base_velocity debug visualization.
        # In k8s headless runs, IsaacSim/IsaacLab command debug_vis may wedge
        # during/after simulation start (seen on velocity-tracking only).
        # Keep an escape hatch for debugging/video.
        allow_debug_vis = (not self.headless) or (self.render_mode is not None)
        force_debug_vis = bool(
            int(os.environ.get("HOLOMOTION_VELCMD_DEBUG_VIS", "0"))
        )
        if (
            (not allow_debug_vis)
            and (not force_debug_vis)
            and isinstance(_commands_config_dict, dict)
            and ("base_velocity" in _commands_config_dict)
        ):
            bv = _commands_config_dict.get("base_velocity", {})
            bv_params = bv.get("params", {})
            if isinstance(bv_params, dict) and bool(
                bv_params.get("debug_vis", False)
            ):
                bv_params["debug_vis"] = False
                bv["params"] = bv_params
                _commands_config_dict["base_velocity"] = bv
                logger.warning(
                    "Disabled base_velocity debug_vis for headless non-render runs. "
                    "Set HOLOMOTION_VELCMD_DEBUG_VIS=1 to force-enable."
                )
        _simulation_config_dict = EasyDict(
            OmegaConf.to_container(
                self.config.simulation,
                resolve=True,
            )
        )
        _actions_config_dict = EasyDict(
            OmegaConf.to_container(
                self.config.actions,
                resolve=True,
            )
        )

        # The class body runs immediately and reads the locals above (and
        # `self`) through the enclosing function scope.
        @configclass
        class VelocityTrackingEnvCfg(ManagerBasedRLEnvCfg):
            seed: int = seed_val
            scene_config_dict = {
                "num_envs": self.config.num_envs,
                "env_spacing": self.config.env_spacing,
                "replicate_physics": self.config.replicate_physics,
                "robot": _robot_config_dict,
                "terrain": _terrain_config_dict,
                "domain_rand": _domain_rand_config_dict,
                "lighting": _scene_config_dict.lighting,
                "contact_sensor": _scene_config_dict.contact_sensor,
            }
            decimation: int = _simulation_config_dict.control_decimation
            episode_length_s: int = _simulation_config_dict.episode_length_s
            sim_freq = _simulation_config_dict.sim_freq
            # Physics step = 1 / sim_freq.
            dt = 1.0 / sim_freq
            physx = PhysxCfg(
                bounce_threshold_velocity=_simulation_config_dict.physx.bounce_threshold_velocity,
                gpu_max_rigid_patch_count=_simulation_config_dict.physx.gpu_max_rigid_patch_count,
                enable_stabilization=True,
            )
            # Process topology: from the accelerator when present, otherwise
            # from config keys with single-process defaults.
            if self.accelerator is not None:
                main_process = self.accelerator.is_main_process
                process_id = self.accelerator.process_index
                num_processes = self.accelerator.num_processes
            else:
                main_process = self.config.get("main_process", True)
                process_id = self.config.get("process_id", 0)
                num_processes = self.config.get("num_processes", 1)
            scene: MotionTrackingSceneCfg = build_scene_config(
                scene_config_dict,
                main_process=main_process,
                process_id=process_id,
                num_processes=num_processes,
            )
            # NOTE(review): this `sim` is overwritten by the second `sim`
            # assignment further down; only that later one takes effect.
            sim: SimulationCfg = SimulationCfg(
                dt=dt,
                render_interval=decimation,
                physx=physx,
                device=_device,
                enable_scene_query_support=True,
            )
            sim.physics_material = scene.terrain.physics_material
            viewer: ViewerCfg = ViewerCfg(origin_type="world")
            # First key of the commands config.
            command_name = list(_commands_config_dict.keys())[0]
            commands: VelTrack_CommandsCfg = build_velocity_commands_config(
                _commands_config_dict
            )
            observations: ObservationsCfg = build_observations_config(
                _obs_config_dict.obs_groups
            )
            rewards: RewardsCfg = build_rewards_config(_rewards_config_dict)
            if _terminations_config_dict:
                terminations: TerminationsCfg = build_terminations_config(
                    _terminations_config_dict
                )
            else:
                terminations: TerminationsCfg = TerminationsCfg()
            if _domain_rand_config_dict:
                events: EventsCfg = build_domain_rand_config(
                    _domain_rand_config_dict
                )
            else:
                events: EventsCfg = EventsCfg()
            # Reset terms: root pose sampled uniformly (x/y in [-0.5, 0.5],
            # yaw in [-3.14, 3.14]) with zero root velocity.
            events.reset_base = EventTerm(
                func=isaaclab_mdp.reset_root_state_uniform,
                mode="reset",
                params={
                    "pose_range": {
                        "x": (-0.5, 0.5),
                        "y": (-0.5, 0.5),
                        "yaw": (-3.14, 3.14),
                    },
                    "velocity_range": {
                        "x": (0.0, 0.0),
                        "y": (0.0, 0.0),
                        "z": (0.0, 0.0),
                        "roll": (0.0, 0.0),
                        "pitch": (0.0, 0.0),
                        "yaw": (0.0, 0.0),
                    },
                },
            )
            # Joint reset with position scale fixed at 1.0 (presumably the
            # default joint positions) and velocity scale sampled in [-1, 1].
            events.reset_robot_joints = EventTerm(
                func=isaaclab_mdp.reset_joints_by_scale,
                mode="reset",
                params={
                    "position_range": (1.0, 1.0),
                    "velocity_range": (-1.0, 1.0),
                },
            )
            # curriculum: CurriculumCfg = build_curriculum_config(
            #     getattr(self.config, "curriculum", {})
            # )
            actions: ActionsCfg = build_actions_config(_actions_config_dict)
            # This is the `sim` that is actually used.
            # NOTE(review): the hard-coded gpu_max_rigid_patch_count below
            # overrides simulation.physx.gpu_max_rigid_patch_count from the
            # config that was passed into `physx` above.
            sim: SimulationCfg = SimulationCfg(
                dt=dt,
                render_interval=decimation,
                physx=physx,
                device=_device,
                enable_scene_query_support=True,
            )
            sim.physx.gpu_max_rigid_patch_count = 10 * 2**15
            sim.physx.enable_stabilization = True
            sim.physics_material = scene.terrain.physics_material

        isaaclab_env_cfg = VelocityTrackingEnvCfg()
        # Persist the fully resolved IsaacLab config next to the logs.
        isaaclab_envconfig_dump_path = os.path.join(
            self.log_dir, "isaaclab_env_cfg.yaml"
        )
        dump_yaml(isaaclab_envconfig_dump_path, isaaclab_env_cfg)
        logger.info(
            "Constructing IsaacLab ManagerBasedRLEnv (velocity_tracking) ..."
        )
        self._env = ManagerBasedRLEnv(isaaclab_env_cfg, self.render_mode)
        logger.info(
            "IsaacLab ManagerBasedRLEnv constructed (velocity_tracking)."
        )
        logger.info("IsaacLab environment initialized !")
        return self._env

    def _init_motion_tracking_components(self):
        """Legacy hook; its call in ``__init__`` is commented out.

        NOTE(review): ``_init_serializers`` is not defined in this class, so
        calling this raises AttributeError unless a subclass provides it.
        """
        self._init_serializers()

    def step(self, actor_state: dict):
        """Step the wrapped env and update completion-rate statistics.

        Args:
            actor_state: Action input forwarded unchanged to
                ``ManagerBasedRLEnv.step``.

        Returns:
            ``(obs_dict, rewards, dones, time_outs, infos)`` where
            ``dones = terminated | time_outs``.
        """
        obs_dict, rewards, terminated, time_outs, infos = self._env.step(
            actor_state
        )
        # IsaacLab separates terminated vs time_outs, combine them for consistency
        dones = terminated | time_outs
        self._update_completion_rate_stats(terminated, time_outs, infos)
        return obs_dict, rewards, dones, time_outs, infos

    def _update_completion_rate_stats(
        self,
        terminated: torch.Tensor,
        time_outs: torch.Tensor,
        infos: dict,
    ) -> None:
        """Log the completion rate over recent done batches.

        Definition:
            - Completed: time_outs==True and terminated==False.
            - Failed: terminated==True.

        Counts are appended to the rolling windows only on steps where at
        least one env is done. The rate (completed / done over the window,
        0.0 while the window is empty) is written to
        ``infos["log"]["Task/Completion_Rate"]`` on every call; ``infos["log"]``
        is created if missing or not a dict.
        """
        done_mask = (terminated | time_outs).reshape(-1).bool()
        if torch.any(done_mask):
            done_count = int(done_mask.sum().item())
            completed_mask = (
                time_outs.reshape(-1).bool()
                & ~terminated.reshape(-1).bool()
                & done_mask
            )
            completed_count = int(completed_mask.sum().item())
            self._completion_total_queue.append(done_count)
            self._completion_success_queue.append(completed_count)
        denom = sum(self._completion_total_queue)
        completion_rate = (
            float(sum(self._completion_success_queue)) / float(denom)
            if denom > 0
            else 0.0
        )
        if ("log" not in infos) or (not isinstance(infos["log"], dict)):
            infos["log"] = {}
        infos["log"]["Task/Completion_Rate"] = torch.tensor(
            completion_rate, device=self.device, dtype=torch.float32
        )

    def reset_idx(self, env_ids: torch.Tensor):
        """Reset the given environment indices of the wrapped env."""
        return self._env.reset(env_ids=env_ids)

    def reset_all(self):
        """Reset every environment (indices ``0..num_envs-1``)."""
        env_ids = torch.arange(self.num_envs, device=self.device)
        out = self._env.reset(env_ids=env_ids)
        return out

    def set_is_evaluating(self):
        """Flag the wrapper as being in evaluation mode."""
        logger.info("Setting environment to evaluation mode")
        self.is_evaluating = True

    def seed(self, seed: int):
        """Forward ``seed`` to the wrapped env's ``seed`` method."""
        self._env.seed(seed)
================================================
FILE: holomotion/src/evaluation/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/evaluation/eval_motion_tracking.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import argparse
import subprocess
from pathlib import Path
from typing import List, Optional, Tuple
from loguru import logger
def find_checkpoints_to_evaluate(
    eval_h5_dataset_path: str,
    root_dir: Path,
    target_checkpoints: Optional[List[str]],
    config_name: str,
) -> List[Tuple[str, str]]:
    """Collect the checkpoints that still need evaluation.

    Behavior:
        - If ``root_dir`` is a file, it is the single checkpoint to evaluate
          and is returned unconditionally (existing outputs are not checked).
        - Otherwise every immediate subdirectory of ``root_dir`` is treated
          as a model directory:
            * if ``target_checkpoints`` is non-empty, only those checkpoint
              stems are considered (e.g. ``['model_17500']``; a trailing
              ``.pt`` is accepted and duplicate names are ignored);
            * if it is ``None`` or empty, every ``model_*.pt`` file in the
              model directory is considered.
        - A checkpoint is skipped when its output directory
          ``isaaclab_eval_output_<stem>_<dataset name>`` already exists in
          the model directory.

    Args:
        eval_h5_dataset_path: Evaluation dataset path. Its final path
            component names the output directory (``dataset_unknown`` when
            empty).
        root_dir: A checkpoint file, or a directory of model directories.
        target_checkpoints: Optional checkpoint stems to restrict the search.
        config_name: Evaluation config name, returned as
            ``evaluation/<config_name>``.

    Returns:
        A list of ``(checkpoint_path, hydra_config_name)`` tuples sorted by
        checkpoint path.
    """
    # Every checkpoint uses the same Hydra evaluation config.
    cfg_name = f"evaluation/{config_name}"
    dataset_name = Path(eval_h5_dataset_path).name
    dataset_suffix = dataset_name if dataset_name else "dataset_unknown"

    if root_dir.is_file():
        return [(str(root_dir), cfg_name)]

    if not root_dir.is_dir():
        logger.error(
            f"Checkpoint root directory '{root_dir}' does not exist or is not a directory."
        )
        return []

    # Normalize "model_1.pt" -> "model_1" (otherwise it would be looked up
    # as "model_1.pt.pt") and drop duplicates while keeping order.
    target_stems: List[str] = []
    if target_checkpoints:
        logger.info(
            f"Searching for explicit target checkpoints: {target_checkpoints}"
        )
        target_stems = list(
            dict.fromkeys(
                name[: -len(".pt")] if name.endswith(".pt") else name
                for name in target_checkpoints
            )
        )

    checkpoints_to_evaluate: List[Tuple[str, str]] = []
    # Iterate over each model directory directly under root_dir
    for model_dir_path in root_dir.iterdir():
        if not model_dir_path.is_dir():
            continue
        if target_stems:
            candidate_files = [
                model_dir_path / f"{stem}.pt" for stem in target_stems
            ]
        else:
            candidate_files = sorted(model_dir_path.glob("model_*.pt"))
        for checkpoint_file in candidate_files:
            if not checkpoint_file.is_file():
                logger.debug(f"Target checkpoint not found: {checkpoint_file}")
                continue
            eval_out_dir = (
                model_dir_path
                / f"isaaclab_eval_output_{checkpoint_file.stem}_{dataset_suffix}"
            )
            if eval_out_dir.is_dir():
                logger.debug(
                    f"Skipping {checkpoint_file.name}, output exists."
                )
                continue
            checkpoints_to_evaluate.append((str(checkpoint_file), cfg_name))

    checkpoints_to_evaluate.sort(key=lambda item: item[0])
    return checkpoints_to_evaluate
def main(
    checkpoint_dir: str,
    target_checkpoints: Optional[List[str]],
    eval_h5_dataset_path: str,
    config_name: str,
    num_envs: str,
    single_eval_script: str = "holomotion/scripts/evaluation/eval_motion_tracking_single.sh",
) -> None:
    """Entry point for batch evaluation.

    Runs ``single_eval_script`` once per pending checkpoint found by
    ``find_checkpoints_to_evaluate``. A failing run is logged and the batch
    continues with the next checkpoint; a summary of failures is logged at
    the end.

    Args:
        checkpoint_dir: A checkpoint file, or a root directory whose
            immediate subdirectories contain ``model_*.pt`` checkpoints.
        target_checkpoints: Optional list of checkpoint stems to evaluate
            (e.g. ``['model_17500']``); ``None``/empty evaluates all.
        eval_h5_dataset_path: Evaluation dataset path, forwarded to the
            script.
        config_name: Evaluation config name (without the ``evaluation/``
            prefix).
        num_envs: Number of parallel environments, forwarded to the script.
        single_eval_script: Shell script that evaluates one checkpoint.
            Relative paths resolve against the current working directory.
    """
    root_path = Path(checkpoint_dir)
    checkpoints_to_evaluate = find_checkpoints_to_evaluate(
        eval_h5_dataset_path=eval_h5_dataset_path,
        root_dir=root_path,
        target_checkpoints=target_checkpoints,
        config_name=config_name,
    )
    if not checkpoints_to_evaluate:
        logger.warning(
            f"No pending evaluations found under '{checkpoint_dir}'."
        )
        return
    total = len(checkpoints_to_evaluate)
    logger.info(f"Found {total} checkpoints to evaluate.")

    failed_checkpoints: List[str] = []
    for i, (ckpt_path, cfg_name) in enumerate(checkpoints_to_evaluate, start=1):
        logger.info(f"[{i}/{total}] Evaluating: {cfg_name}/{ckpt_path}")
        command = [
            "bash",
            single_eval_script,
            ckpt_path,
            cfg_name,
            eval_h5_dataset_path,
            # subprocess arguments must be strings even if an int is passed.
            str(num_envs),
        ]
        completed = subprocess.run(command)
        # Previously the exit status was ignored, hiding failed evaluations.
        if completed.returncode != 0:
            logger.error(
                f"Evaluation failed (exit code {completed.returncode}) "
                f"for checkpoint: {ckpt_path}"
            )
            failed_checkpoints.append(ckpt_path)

    if failed_checkpoints:
        logger.error(
            f"{len(failed_checkpoints)}/{total} evaluations failed: "
            f"{failed_checkpoints}"
        )
def parse_args() -> argparse.Namespace:
    """Build and parse the command line of the batch evaluation script.

    All options are strings. ``--target_checkpoints`` takes zero or more
    checkpoint stems and is ``None`` when the flag is omitted; every other
    option is mandatory.
    """
    cli = argparse.ArgumentParser(description="motion-tracking evaluation.")
    cli.add_argument("--checkpoint_dir", type=str, required=True)
    cli.add_argument("--target_checkpoints", type=str, nargs="*", default=None)
    # Remaining flags share the same "required string" definition.
    for required_flag in (
        "--config_name",
        "--eval_h5_dataset_path",
        "--num_envs",
    ):
        cli.add_argument(required_flag, type=str, required=True)
    namespace = cli.parse_args()
    return namespace
if __name__ == "__main__":
    cli_args = parse_args()
    # The parsed option names map one-to-one onto main()'s keyword parameters.
    main(**vars(cli_args))
================================================
FILE: holomotion/src/evaluation/eval_motion_tracking_single.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import re
from pathlib import Path
import hydra
from hydra.utils import get_class
from loguru import logger
from omegaconf import ListConfig, OmegaConf
from holomotion.src.evaluation.metrics import run_evaluation
from holomotion.src.utils.config import compile_config
from holomotion.src.utils.onnx_export import export_policy_to_onnx
def load_training_config(
    checkpoint_path: str, eval_config: OmegaConf
) -> OmegaConf:
    """Layer the evaluation config on top of the run's training config.

    ``config.yaml`` is searched next to the checkpoint first, then one
    directory higher. When neither exists, ``eval_config`` is returned as-is.
    Otherwise the training config (with its ``eval_overrides`` applied, when
    present) receives the checkpoint path and is merged with ``eval_config``.
    ``terminations`` and ``domain_rand`` are then copied verbatim from
    ``eval_config``, and observation corruption is switched off for the
    ``policy``, ``critic`` and ``unified`` observation groups if they exist.

    Args:
        checkpoint_path: Path to the checkpoint file.
        eval_config: Full evaluation config (including command line overrides).

    Returns:
        The merged config, or ``eval_config`` when no training config is found.
    """
    checkpoint_dir = Path(checkpoint_path).parent
    config_path = checkpoint_dir / "config.yaml"
    if not config_path.exists():
        config_path = checkpoint_dir.parent / "config.yaml"
        if not config_path.exists():
            logger.warning(
                f"Training config not found at {config_path}, using evaluation config"
            )
            return eval_config

    logger.info(f"Loading training config from {config_path}")
    with open(config_path) as handle:
        train_config = OmegaConf.load(handle)

    eval_overrides = train_config.get("eval_overrides")
    if eval_overrides is not None:
        train_config = OmegaConf.merge(train_config, eval_overrides)

    # Point both the top-level and the algorithm config at this checkpoint.
    train_config.checkpoint = checkpoint_path
    train_config.algo.config.checkpoint = checkpoint_path

    config = OmegaConf.merge(train_config, eval_config)

    # Terminations and domain randomization always come from the eval config.
    env_cfg = config.env.config
    env_cfg.terminations = eval_config.env.config.terminations
    env_cfg.domain_rand = eval_config.env.config.domain_rand

    obs_groups = env_cfg.obs.obs_groups
    for group_name in ("policy", "critic", "unified"):
        if group_name in obs_groups:
            obs_groups[group_name].enable_corruption = False
    return config
def _infer_dataset_suffix(output_dir: str, checkpoint_path: str) -> str:
output_name = Path(output_dir).name
model_name = Path(checkpoint_path).stem
expected_prefix = f"isaaclab_eval_output_{model_name}_"
if output_name.startswith(expected_prefix):
return output_name[len(expected_prefix) :]
return output_name
def _checkpoint_sort_key(checkpoint_path: Path):
match = re.search(r"model_(\d+)\.pt$", checkpoint_path.name)
if match is not None:
return (0, int(match.group(1)), checkpoint_path.name)
return (1, checkpoint_path.name)
def _normalize_ckpt_pt_names(ckpt_pt_names) -> list[str]:
if ckpt_pt_names is None:
return []
if isinstance(ckpt_pt_names, ListConfig):
raw_names = list(ckpt_pt_names)
elif isinstance(ckpt_pt_names, (list, tuple)):
raw_names = list(ckpt_pt_names)
else:
raise TypeError(
f"ckpt_pt_names must be a list/tuple, got {type(ckpt_pt_names)}"
)
normalized_names = []
for name in raw_names:
name_str = str(name).strip()
if name_str == "":
continue
if not name_str.endswith(".pt"):
name_str = f"{name_str}.pt"
normalized_names.append(name_str)
return normalized_names
def _resolve_export_ckpt_paths(config: OmegaConf) -> list[Path]:
    """Choose the checkpoint files to export when ``export_only`` is set.

    The search directory is ``config.log_dir`` when it is set and non-blank,
    otherwise the parent directory of ``config.checkpoint``. When
    ``ckpt_pt_names`` names any files, exactly those are returned in the
    listed order; otherwise every ``*.pt`` file in the directory is returned,
    with ``model_<N>.pt`` files first in numeric order.

    Raises:
        ValueError: Neither ``log_dir`` nor ``checkpoint`` is set.
        NotADirectoryError: The resolved directory does not exist.
        FileNotFoundError: Requested names are missing, or no ``*.pt`` file
            exists in the directory.
    """

    def _is_blank(value) -> bool:
        return value is None or str(value).strip() == ""

    log_dir_value = config.get("log_dir", None)
    checkpoint_value = config.get("checkpoint", None)
    if not _is_blank(log_dir_value):
        log_dir = Path(str(log_dir_value))
    elif not _is_blank(checkpoint_value):
        log_dir = Path(str(checkpoint_value)).parent
    else:
        raise ValueError(
            "When export_only=true, set log_dir or checkpoint."
        )
    if not log_dir.is_dir():
        raise NotADirectoryError(
            f"log_dir does not exist or is not a directory: {log_dir}"
        )

    requested_names = _normalize_ckpt_pt_names(
        config.get("ckpt_pt_names", None)
    )
    if requested_names:
        candidates = [(name, log_dir / name) for name in requested_names]
        missing_names = [
            name for name, path in candidates if not path.is_file()
        ]
        if missing_names:
            raise FileNotFoundError(
                f"Missing checkpoints in log_dir={log_dir}: {missing_names}"
            )
        return [path for _, path in candidates]

    found_paths = sorted(log_dir.glob("*.pt"), key=_checkpoint_sort_key)
    if not found_paths:
        raise FileNotFoundError(
            f"No .pt checkpoints found in log_dir={log_dir}"
        )
    return found_paths
@hydra.main(
    config_path="../../config",
    config_name="evaluation/eval_isaaclab",
    version_base=None,
)
def main(config: OmegaConf):
    """Evaluate (or export) a trained motion-tracking policy.

    Two modes, selected by ``config.export_only``:

    * export-only: resolve checkpoints with ``_resolve_export_ckpt_paths``
      (``log_dir``/``checkpoint`` plus optional ``ckpt_pt_names``), load each
      checkpoint on every process and export it to ONNX on the main process.
    * evaluation (default): load ``config.checkpoint``, optionally export it
      to ONNX (``export_policy``, default True), run
      ``algo.offline_evaluate_policy`` and, on the main process, optionally
      compute per-clip metrics (``calc_per_clip_metrics``) and a summary
      report (``generate_report``). The per-clip failure threshold comes
      from ``config.failure_pos_err_thresh_m`` (default 0.25).

    Args:
        config: OmegaConf object containing the evaluation configuration.
            It is merged on top of the training config found next to the
            checkpoint.
    """
    export_only = bool(config.get("export_only", False))
    if export_only:
        checkpoint_paths = _resolve_export_ckpt_paths(config)
        config = load_training_config(str(checkpoint_paths[0]), config)
    else:
        if config.checkpoint is None:
            raise ValueError("Checkpoint path must be provided for evaluation")
        checkpoint_paths = [Path(str(config.checkpoint))]
        config = load_training_config(config.checkpoint, config)

    # Compile config without accelerator (PPO will create it)
    config = compile_config(config, accelerator=None)

    # Use checkpoint directory as log_dir for offline evaluation/export.
    log_dir = str(checkpoint_paths[0].parent)
    headless = config.headless

    # PPO creates Accelerator, AppLauncher, and environment internally
    algo_class = get_class(config.algo._target_)
    algo = algo_class(
        env_config=config.env,
        config=config.algo.config,
        log_dir=log_dir,
        headless=headless,
        is_offline_eval=True,
    )

    if (
        algo.accelerator.is_main_process
        and os.environ.get("TORCH_COMPILE_DISABLE", "0") != "1"
    ):
        logger.info(
            "Tip: If you encounter Triton/compilation errors during evaluation,"
        )
        logger.info(
            " set environment variable: export TORCH_COMPILE_DISABLE=1"
        )

    if algo.accelerator.is_main_process:
        with open(os.path.join(log_dir, "eval_config.yaml"), "w") as f:
            OmegaConf.save(config, f)

    if export_only:
        if algo.accelerator.is_main_process:
            logger.info(
                "Running export-only mode for "
                f"{len(checkpoint_paths)} checkpoints in {log_dir}"
            )
        onnx_name_suffix = config.get("onnx_name_suffix", None)
        use_kv_cache = config.get("use_kv_cache", True)
        for i, checkpoint_path in enumerate(checkpoint_paths, start=1):
            ckpt_path = str(checkpoint_path)
            if algo.accelerator.is_main_process:
                logger.info(
                    f"[{i}/{len(checkpoint_paths)}] Loading checkpoint: "
                    f"{ckpt_path}"
                )
            # Every process loads; only the main process exports.
            algo.load(ckpt_path)
            if algo.accelerator.is_main_process:
                onnx_path = export_policy_to_onnx(
                    algo,
                    ckpt_path,
                    onnx_name_suffix=onnx_name_suffix,
                    use_kv_cache=use_kv_cache,
                )
                logger.info(f"Successfully exported policy to: {onnx_path}")
            algo.accelerator.wait_for_everyone()
        if algo.accelerator.is_main_process:
            logger.info("Export-only mode completed successfully!")
        return

    if algo.accelerator.is_main_process:
        logger.info(f"Loading checkpoint for evaluation: {config.checkpoint}")
    algo.load(config.checkpoint)

    command_name = list(config.env.config.commands.keys())[0]
    if command_name == "ref_motion":
        # Reset once so the reference-motion command state is initialized
        # before export/evaluation.
        motion_cmd = algo.env._env.command_manager.get_term("ref_motion")
        algo.env._env.reset()
        motion_cmd._update_ref_motion_state()

    # Export ONNX if requested
    if config.get("export_policy", True):
        if algo.accelerator.is_main_process:
            onnx_name_suffix = config.get("onnx_name_suffix", None)
            onnx_path = export_policy_to_onnx(
                algo,
                config.checkpoint,
                onnx_name_suffix=onnx_name_suffix,
                use_kv_cache=config.get("use_kv_cache", True),
            )
            logger.info(f"Successfully exported policy to: {onnx_path}")
        algo.accelerator.wait_for_everyone()

    calc_per_clip_metrics = bool(config.get("calc_per_clip_metrics", False))
    generate_report = bool(config.get("generate_report", False))
    # Per-clip metrics are computed from dumped NPZs, so they force dumping.
    dump_npzs = bool(config.get("dump_npzs", False)) or calc_per_clip_metrics
    dof_mode = config.get("dof_mode", "29")
    # Previously hard-coded; now configurable with the same default.
    failure_pos_err_thresh_m = float(
        config.get("failure_pos_err_thresh_m", 0.25)
    )
    if (
        calc_per_clip_metrics
        and not bool(config.get("dump_npzs", False))
        and algo.accelerator.is_main_process
    ):
        logger.info(
            "calc_per_clip_metrics=true requires dumped NPZs; "
            "enabling dump_npzs automatically."
        )
    result = algo.offline_evaluate_policy(dump_npzs)
    algo.accelerator.wait_for_everyone()

    if algo.accelerator.is_main_process:
        logger.info("Evaluation completed successfully!")
        output_dir = (
            result.get("output_dir") if isinstance(result, dict) else None
        )
        if output_dir is not None:
            logger.info(f"NPZs saved to: {output_dir}")
        if calc_per_clip_metrics:
            if output_dir is None:
                logger.warning(
                    "Skipping per-clip metric calculation because "
                    "output_dir is unavailable."
                )
            else:
                dataset_suffix = _infer_dataset_suffix(
                    output_dir, config.checkpoint
                )
                run_evaluation(
                    npz_dir=output_dir,
                    dataset_suffix=dataset_suffix,
                    failure_pos_err_thresh_m=failure_pos_err_thresh_m,
                    dof_mode=dof_mode,
                )
                logger.info(
                    f"Finished per-clip metric calculation for: {output_dir}"
                )
        if generate_report:
            if output_dir is None:
                logger.warning(
                    "Skipping report generation because output_dir is unavailable."
                )
            else:
                from holomotion.scripts.evaluation import (
                    mean_process_5metrics,
                )

                report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir(
                    output_dir
                )
                logger.info(f"Generated metrics report at: {report_path}")
if __name__ == "__main__":
    # The @hydra.main decorator composes the config from the CLI and
    # passes it to main().
    main()
================================================
FILE: holomotion/src/evaluation/eval_mujoco_sim2sim.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
import csv
import shutil
import sys
import threading
import time
from collections import deque
from pathlib import Path
from threading import Thread
import cv2
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import onnx
import onnxruntime
import torch
from loguru import logger
from omegaconf import ListConfig, OmegaConf, open_dict
from tqdm import tqdm
import glob
import re
import ray
from holomotion.src.evaluation.metrics import run_evaluation
try:
    from horizon_tc_ui.hb_runtime import HBRuntime
except ImportError:
    # The optional Horizon toolchain runtime is missing. Bind the imported
    # name itself to None; previously only HB_ONNXRuntime was bound, which
    # left HBRuntime undefined. HB_ONNXRuntime is kept so any existing
    # references to it keep working.
    HBRuntime = None
    HB_ONNXRuntime = None
    logger.warning("HB_ONNXRuntime not available!")
ONNX_IO_DUMP_DIRNAME = "onnx_io_npy"
try:
import pynput.keyboard as pynput_kb
PYNPUT_AVAILABLE = True
except ImportError:
PYNPUT_AVAILABLE = False
if "headless" in sys.argv and "false" in sys.argv:
logger.warning("pynput not available, keyboard control disabled")
from holomotion.src.evaluation.obs import PolicyObsBuilder
from holomotion.src.utils.torch_utils import (
quat_apply,
quat_inv,
subtract_frame_transforms,
quat_normalize_wxyz,
matrix_from_quat,
xyzw_to_wxyz,
quat_mul,
quat_from_euler_xyz,
)
from holomotion.src.motion_retargeting.utils.rotation_conversions import (
standardize_quaternion,
)
# Default foot geom names, keyed by side.
DEFAULT_FEET_GEOM_NAMES = dict(
    left=["left_foot"],
    right=["right_foot"],
)
# Default foot body names, keyed by side.
DEFAULT_FEET_BODY_NAMES = dict(
    left=["left_ankle_roll_link"],
    right=["right_ankle_roll_link"],
)
def _coerce_config_bool(value, default: bool = False) -> bool:
"""Interpret config booleans without treating non-empty strings as truthy."""
if value is None:
return default
if isinstance(value, (bool, np.bool_)):
return bool(value)
if isinstance(value, str):
value = value.strip().lower()
if value in {"1", "true", "yes", "y", "on"}:
return True
if value in {"0", "false", "no", "n", "off", ""}:
return False
return bool(value)
class OffscreenRenderer:
    """Minimal offscreen renderer for MuJoCo frames.

    Owns a MuJoCo GL context, an abstract scene (``maxgeom=1000``), a free
    camera and a MuJoCo rendering context for a ``width`` x ``height`` RGB
    image. Call :meth:`close` to release the rendering and GL contexts.
    """

    def __init__(
        self,
        model,
        height: int,
        width: int,
        distance: float | None = None,
        azimuth: float | None = None,
        elevation: float | None = None,
    ):
        """Create the GL context, scene, camera and pixel buffer.

        Args:
            model: MuJoCo model to render.
            height: Output image height in pixels.
            width: Output image width in pixels.
            distance: Optional initial camera distance; ``None`` keeps the
                default free-camera distance.
            azimuth: Optional initial camera azimuth (default 60.0).
            elevation: Optional initial camera elevation (default -20.0).
        """
        self.model = model
        self.height = height
        self.width = width
        self._overlay_callback = None
        # Make the GL context current before creating the render context.
        self._gl_ctx = mujoco.GLContext(width, height)
        self._gl_ctx.make_current()
        self._scene = mujoco.MjvScene(model, maxgeom=1000)
        self._cam = mujoco.MjvCamera()
        self._opt = mujoco.MjvOption()
        mujoco.mjv_defaultFreeCamera(model, self._cam)
        self.set_align_view(
            distance=distance,
            azimuth=azimuth,
            elevation=elevation,
        )
        self._con = mujoco.MjrContext(
            model,
            mujoco.mjtFontScale.mjFONTSCALE_100,
        )
        # Reused read-back buffer (height, width, RGB).
        self._rgb = np.zeros((height, width, 3), dtype=np.uint8)
        self._viewport = mujoco.MjrRect(0, 0, width, height)

    def set_overlay_callback(self, callback) -> None:
        """Register a callback to draw custom geoms into the scene each frame.

        The callback receives the ``MjvScene`` after it has been updated and
        before it is rendered.
        """
        self._overlay_callback = callback

    def render(self, data) -> np.ndarray:
        """Render ``data`` and return a new (height, width, 3) uint8 image.

        The returned array is a copy: the internal read-back buffer is
        reused on every call, so returning a view would let the next
        ``render`` silently overwrite frames the caller has kept.
        """
        mujoco.mjv_updateScene(
            self.model,
            data,
            self._opt,
            None,
            self._cam,
            mujoco.mjtCatBit.mjCAT_ALL.value,
            self._scene,
        )
        if self._overlay_callback is not None:
            self._overlay_callback(self._scene)
        mujoco.mjr_render(self._viewport, self._scene, self._con)
        mujoco.mjr_readPixels(self._rgb, None, self._viewport, self._con)
        # Pixels are read bottom-up; flip vertically to get a top-down image.
        return np.flipud(self._rgb).copy()

    def set_align_view(
        self,
        lookat: np.ndarray | None = None,
        distance: float | None = None,
        azimuth: float | None = None,
        elevation: float | None = None,
    ):
        """Set the free camera to the 'align' preset view.

        Args:
            lookat: Optional lookat point [x, y, z]. If None, keeps the
                current lookat.
            distance: Optional camera distance from the lookat point. If
                None, keeps the current distance.
            azimuth: Optional azimuth; defaults to 60.0.
            elevation: Optional elevation; defaults to -20.0 (slight
                downward angle).
        """
        self._cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        if azimuth is None:
            self._cam.azimuth = 60.0  # Side view (looking along Y-axis)
        else:
            self._cam.azimuth = float(azimuth)
        if elevation is None:
            self._cam.elevation = -20.0  # Slight downward angle
        else:
            self._cam.elevation = float(elevation)
        if lookat is not None:
            self._cam.lookat = np.asarray(lookat, dtype=np.float32)
        if distance is not None:
            self._cam.distance = float(distance)

    def close(self):
        """Release the MuJoCo render context, then the GL context.

        The render context was previously never freed; it is released first
        because it depends on the GL context.
        """
        self._con.free()
        self._gl_ctx.free()
class VelocityKeyboardHandler:
    """Keyboard-driven ``[vx, vy, vyaw]`` command source.

    W/S raise/lower ``vx``, A/D raise/lower ``vy`` and J/L raise/lower
    ``vyaw``, each clipped to its limits; Space and X both zero the whole
    command. Key presses arrive on a pynput listener thread, so the command
    state is guarded by a lock.
    """

    def __init__(
        self,
        vx_increment: float = 0.1,
        vy_increment: float = 0.1,
        vyaw_increment: float = 0.05,
        vx_limits: tuple = (-0.5, 1.0),
        vy_limits: tuple = (-0.3, 0.3),
        vyaw_limits: tuple = (-0.5, 0.5),
    ):
        # Step applied per key press on each axis.
        self.vx_increment, self.vy_increment, self.vyaw_increment = (
            vx_increment,
            vy_increment,
            vyaw_increment,
        )
        # (min, max) clip bounds per axis, e.g. taken from the training config.
        self.vx_min, self.vx_max = vx_limits
        self.vy_min, self.vy_max = vy_limits
        self.vyaw_min, self.vyaw_max = vyaw_limits
        # The command starts at rest.
        self.vx = self.vy = self.vyaw = 0.0
        self._listener = None
        self._lock = threading.Lock()

    def start_listener(self):
        """Start the pynput keyboard listener thread (no-op without pynput)."""
        if not PYNPUT_AVAILABLE:
            logger.warning("pynput not available, keyboard control disabled")
            return

        def on_press(key):
            # Special keys (arrows, shift, ...) have no usable character.
            try:
                pressed_char = getattr(key, "char", None)
                if pressed_char:
                    self._handle_key(pressed_char)
            except AttributeError:
                pass

        self._listener = pynput_kb.Listener(on_press=on_press)
        self._listener.start()
        limit_summary = ", ".join(
            f"{axis}=[{low:.1f},{high:.1f}]"
            for axis, low, high in (
                ("vx", self.vx_min, self.vx_max),
                ("vy", self.vy_min, self.vy_max),
                ("vyaw", self.vyaw_min, self.vyaw_max),
            )
        )
        logger.info(
            "Keyboard listener started. Velocity limits: " + limit_summary
        )

    def stop_listener(self):
        """Stop the keyboard listener thread, if one is running."""
        listener = self._listener
        if listener is None:
            return
        listener.stop()
        self._listener = None

    def get_velocity_command(self) -> np.ndarray:
        """Return a float32 snapshot of the command ``[vx, vy, vyaw]``."""
        with self._lock:
            snapshot = (self.vx, self.vy, self.vyaw)
        return np.array(snapshot, dtype=np.float32)

    def _handle_key(self, char: str):
        """Apply one key press to the command; unknown keys are ignored."""
        with self._lock:
            # Uppercase label -> (attribute, signed step, lower, upper bound).
            axis_table = {
                "W": ("vx", self.vx_increment, self.vx_min, self.vx_max),
                "S": ("vx", -self.vx_increment, self.vx_min, self.vx_max),
                "A": ("vy", self.vy_increment, self.vy_min, self.vy_max),
                "D": ("vy", -self.vy_increment, self.vy_min, self.vy_max),
                "J": (
                    "vyaw",
                    self.vyaw_increment,
                    self.vyaw_min,
                    self.vyaw_max,
                ),
                "L": (
                    "vyaw",
                    -self.vyaw_increment,
                    self.vyaw_min,
                    self.vyaw_max,
                ),
            }
            # Accept exactly the upper- and lower-case form of each label.
            aliases = {}
            for label_key in axis_table:
                aliases[label_key] = label_key
                aliases[label_key.lower()] = label_key
            label = aliases.get(char)

            if label is not None:
                attr, step, low, high = axis_table[label]
                setattr(self, attr, np.clip(getattr(self, attr) + step, low, high))
                logger.info(
                    f"[{label}] vx={self.vx:.2f}, vy={self.vy:.2f}, vyaw={self.vyaw:.2f}"
                )
            elif char == " ":
                self.vx = self.vy = self.vyaw = 0.0
                logger.info("[Space] Command reset to zero")
            elif char in ("X", "x"):
                # Emergency brake.
                self.vx = self.vy = self.vyaw = 0.0
                logger.info("[X] Emergency stop - all velocities set to zero")
class MujocoEvaluator:
"""Class to handle MuJoCo simulation for policy evaluation."""
def __init__(self, config):
    """Initialize the MuJoCo evaluator.

    This only prepares evaluator state. The MuJoCo model/data (``self.m`` /
    ``self.d``) and the ONNX session (``self.policy_session``) start as
    ``None`` and are filled in later by other methods (e.g. ``load_policy``).

    Args:
        config: Configuration object with simulation parameters. It is
            accessed both by attribute and via ``.get``. It must provide
            ``robot.dof_names``, ``obs.n_fut_frames`` and
            ``algo.config.num_steps_per_env``; many other keys are optional
            and read with defaults.

    Raises:
        ValueError: If ``algo``, ``algo.config`` or
            ``algo.config.num_steps_per_env`` is missing, or if
            ``num_steps_per_env`` is not positive. Helper methods called
            here (action filter/delay config, observation config parsing)
            may raise ``ValueError`` as well.
        FileNotFoundError: Propagated from ``_detect_command_mode`` when a
            configured motion file path does not exist.
    """
    self.config = config
    # Initialize variables
    self.policy_session = None
    self.motion_encoding = None
    self.m = None # MuJoCo model
    self.d = None # MuJoCo data
    # Determine command mode from config ("motion_tracking" or
    # "velocity_tracking"; see _detect_command_mode).
    self.command_mode = self._detect_command_mode()
    # NOTE(review): the mode is only logged when "motion_npz_dir" is absent
    # from the top-level config; confirm this asymmetry is intended.
    if "motion_npz_dir" not in config:
        logger.info(f"Command mode: {self.command_mode}")
    # Motion data
    self.ref_dof_pos = None
    self.ref_dof_vel = None
    self.filter_cutoff_hz = None
    self.n_motion_frames = 0
    self.motion_frame_idx = 0
    # Velocity command (for velocity tracking mode)
    self.velocity_command = np.zeros(3, dtype=np.float32) # [vx, vy, vyaw]
    self.target_heading = 0.0 # Target heading for velocity tracking
    self.keyboard_handler = (
        None # Will be initialized if velocity_tracking
    )
    # Extract configuration parameters
    # Fixed rates: 200 Hz physics and 50 Hz policy, i.e. 4 physics steps per
    # policy step (control_decimation == policy_dt / simulation_dt).
    self.simulation_dt = 1 / 200
    self.policy_dt = 1 / 50
    self.control_decimation = 4
    self.dof_names_ref_motion = list(config.robot.dof_names)
    self.num_actions = len(self.dof_names_ref_motion)
    # Per-DOF arrays sized from config.robot.dof_names. These are neutral
    # placeholders, presumably overwritten once the ONNX policy metadata is
    # loaded; TODO confirm.
    self.action_scale_onnx = np.ones(self.num_actions, dtype=np.float32)
    self.kps_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.kds_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.default_angles_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.target_dof_pos_onnx = self.default_angles_onnx.copy()
    self.actions_onnx = np.zeros(self.num_actions, dtype=np.float32)
    self.n_fut_frames = int(config.obs.n_fut_frames)
    self.actor_place_holder_ndim = self._find_actor_place_holder_ndim()
    # KV-cache bookkeeping; disabled here and presumably configured when the
    # policy is loaded; TODO confirm.
    self.use_kv_cache = False
    self.policy_kv_cache = None
    self.policy_kv_input_name = None
    self.policy_kv_output_name = None
    self.policy_kv_shape = None
    self.policy_model_context_len = 0
    # The maximum policy context length is taken from the training rollout
    # length (algo.config.num_steps_per_env) and must be a positive int.
    algo_cfg = self.config.get("algo", None)
    if algo_cfg is None:
        raise ValueError("Missing config.algo for MuJoCo evaluation.")
    algo_config = algo_cfg.get("config", None)
    if algo_config is None:
        raise ValueError(
            "Missing config.algo.config for MuJoCo evaluation."
        )
    max_context_len_cfg = algo_config.get("num_steps_per_env", None)
    if max_context_len_cfg is None:
        raise ValueError(
            "Missing config.algo.config.num_steps_per_env for MuJoCo evaluation."
        )
    self.max_context_len = int(max_context_len_cfg)
    if self.max_context_len <= 0:
        raise ValueError(
            "config.algo.config.num_steps_per_env must be > 0, "
            f"got {self.max_context_len}"
        )
    self.policy_effective_context_len = 0
    self.counter = 0
    self.tau_hist = []
    # Latest Unitree lowstate message (populated when using Unitree bridge)
    # self._lowstate_msg = None
    # Desired target positions keyed by DOF name (updated after each policy step)
    self.target_dof_pos_by_name = {}
    # Video/recording related
    self._video_writer = None
    self._offscreen = None
    self._frame_interval = None
    self._last_frame_time = 0.0
    # Reference(global)->Simulation(global) rigid transform. Starts as the
    # identity and is computed lazily by _ensure_ref_to_sim_transform_rigid,
    # which sets _ref_to_sim_ready.
    self._ref_to_sim_ready = False
    self._ref_to_sim_q_wxyz = np.array(
        [1.0, 0.0, 0.0, 0.0], dtype=np.float32
    )
    self._ref_to_sim_t = np.zeros(3, dtype=np.float32)
    # NOTE(review): no reference/dataset body-index offset is stored here.
    # MuJoCo's world body (index 0) is instead dropped by the
    # robot_global_bodylink_* properties.
    # Robot state recording buffers for offline NPZ dumping
    self._robot_dof_pos_seq: list[np.ndarray] = []
    self._robot_dof_vel_seq: list[np.ndarray] = []
    self._robot_dof_acc_seq: list[np.ndarray] = []
    self._robot_dof_torque_seq: list[np.ndarray] = []
    self._robot_low_level_dof_torque_seq: list[np.ndarray] = []
    self._robot_low_level_foot_contact_seq: list[np.ndarray] = []
    self._robot_low_level_foot_normal_force_seq: list[np.ndarray] = []
    self._robot_low_level_foot_tangent_speed_seq: list[np.ndarray] = []
    self._robot_actions_seq: list[np.ndarray] = []
    self._robot_action_rate_seq: list[np.float32] = []
    self._robot_global_translation_seq: list[np.ndarray] = []
    self._robot_global_rotation_quat_seq: list[np.ndarray] = []
    self._robot_global_velocity_seq: list[np.ndarray] = []
    self._robot_global_angular_velocity_seq: list[np.ndarray] = []
    self._robot_moe_expert_indices_seq: list[np.ndarray] = []
    self._robot_moe_expert_logits_seq: list[np.ndarray] = []
    self._prev_recorded_dof_vel_ref: np.ndarray | None = None
    self._prev_actions_onnx: np.ndarray | None = None
    # Action EMA filter (only configurable for the "unitree_erfi" actuator).
    (
        self.action_ema_filter_enabled,
        self.action_ema_filter_alpha,
    ) = self._get_action_ema_filter_cfg()
    self._filtered_actions_onnx: np.ndarray | None = None
    # Action delay. The buffer created here is rebuilt right away by
    # _reset_action_delay_randomization(), which also samples the
    # per-episode delay when action_delay_type == "episode".
    (
        self.policy_action_delay_step,
        self.action_delay_type,
    ) = self._get_action_delay_cfg()
    self._policy_action_delay_buffer: deque[np.ndarray] = deque(
        maxlen=max(1, self.policy_action_delay_step + 1)
    )
    self._current_policy_action_delay_step = 0
    self._reset_action_delay_randomization()
    # Camera config (viewer + offscreen)
    # NOTE(review): bool() treats any non-empty string (e.g. "false") as
    # True; _coerce_config_bool is used elsewhere in this class for such flags.
    self._camera_tracking_enabled = bool(
        self.config.get("camera_tracking", True)
    )
    self._camera_height_offset = float(
        self.config.get("camera_height_offset", 0.3)
    )
    self._camera_distance = float(self.config.get("camera_distance", 4.0))
    self._camera_azimuth = float(self.config.get("camera_azimuth", 60.0))
    self._camera_elevation = float(
        self.config.get("camera_elevation", -20.0)
    )
    self._root_body_id = -1
    # Foot-contact logging state; populated by
    # _init_low_level_foot_contact_logging.
    self._foot_contact_logging_enabled = False
    self._foot_geom_id_groups: list[list[int]] = [[], []]
    self._foot_geom_id_to_side: dict[int, int] = {}
    self._prev_low_level_foot_geom_centers: np.ndarray | None = None
    self.dump_onnx_io_npy = bool(
        self.config.get("dump_onnx_io_npy", False)
    )
    self.policy_moe_layer_output_names: list[tuple[int, str, str]] = []
    self._reset_onnx_io_dump_buffers()
def _reset_onnx_io_dump_buffers(self):
self._onnx_io_input_names: list[str] = []
self._onnx_io_output_names: list[str] = []
self._onnx_io_inputs: dict[str, list[np.ndarray]] = {}
self._onnx_io_outputs: dict[str, list[np.ndarray]] = {}
def _get_action_ema_filter_cfg(self) -> tuple[bool, float]:
actuator_cfg = self.config.get("robot", {}).get("actuators", {})
actuator_type = actuator_cfg.get("actuator_type", "unitree")
if actuator_type != "unitree_erfi":
return False, 1.0
enabled = _coerce_config_bool(
actuator_cfg.get("ema_filter_enabled", False), default=False
)
alpha = float(actuator_cfg.get("ema_filter_alpha", 1.0))
if not 0.0 <= alpha <= 1.0:
raise ValueError(
"robot.actuators.ema_filter_alpha must be within [0, 1], "
f"got {alpha}."
)
return enabled, alpha
def _reset_action_ema_filter(self) -> None:
self._filtered_actions_onnx = None
def _apply_action_ema_filter(self, raw_actions: np.ndarray) -> np.ndarray:
raw_actions = np.asarray(raw_actions, dtype=np.float32)
if not self.action_ema_filter_enabled:
return raw_actions.copy()
if self._filtered_actions_onnx is None:
self._filtered_actions_onnx = raw_actions.copy()
return self._filtered_actions_onnx.copy()
# self.action_ema_filter_alpha = 0.7
filtered_actions = (
self.action_ema_filter_alpha * raw_actions
+ (1.0 - self.action_ema_filter_alpha)
* self._filtered_actions_onnx
).astype(np.float32, copy=False)
self._filtered_actions_onnx = filtered_actions.copy()
return self._filtered_actions_onnx.copy()
def _get_action_delay_cfg(self) -> tuple[int, str]:
max_delay_step = int(self.config.get("policy_action_delay_step", 0))
if max_delay_step < 0:
raise ValueError(
"policy_action_delay_step must be non-negative, "
f"got {max_delay_step}."
)
delay_type = (
str(self.config.get("action_delay_type", "episode"))
.strip()
.lower()
)
if delay_type not in {"step", "episode"}:
raise ValueError(
"action_delay_type must be one of {'step', 'episode'}, "
f"got {delay_type!r}."
)
return max_delay_step, delay_type
def _sample_policy_action_delay_step(self) -> int:
if self.policy_action_delay_step <= 0:
return 0
return int(np.random.randint(0, self.policy_action_delay_step + 1))
def _reset_action_delay_randomization(self) -> None:
self._policy_action_delay_buffer = deque(
maxlen=max(1, self.policy_action_delay_step + 1)
)
if self.policy_action_delay_step <= 0:
self._current_policy_action_delay_step = 0
return
if self.action_delay_type == "episode":
self._current_policy_action_delay_step = (
self._sample_policy_action_delay_step()
)
else:
self._current_policy_action_delay_step = 0
def _apply_action_delay(self, raw_actions: np.ndarray) -> np.ndarray:
raw_actions = np.asarray(raw_actions, dtype=np.float32)
if self.policy_action_delay_step <= 0:
return raw_actions.copy()
expected_buffer_len = max(1, self.policy_action_delay_step + 1)
if (
not hasattr(self, "_policy_action_delay_buffer")
or self._policy_action_delay_buffer.maxlen != expected_buffer_len
):
self._reset_action_delay_randomization()
if self.action_delay_type == "step":
self._current_policy_action_delay_step = (
self._sample_policy_action_delay_step()
)
self._policy_action_delay_buffer.append(raw_actions.copy())
if self._current_policy_action_delay_step >= len(
self._policy_action_delay_buffer
):
return self._policy_action_delay_buffer[-1].copy()
return self._policy_action_delay_buffer[
-1 - self._current_policy_action_delay_step
].copy()
@staticmethod
def _normalize_foot_geom_name_groups(raw_spec) -> list[list[str]]:
if raw_spec is None:
return [[], []]
if OmegaConf.is_config(raw_spec):
raw_spec = OmegaConf.to_container(raw_spec, resolve=True)
def coerce_names(value) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [value]
if isinstance(value, (list, tuple)):
return [str(name) for name in value if str(name)]
return []
if isinstance(raw_spec, dict):
return [
coerce_names(raw_spec.get("left", raw_spec.get("left_foot"))),
coerce_names(
raw_spec.get("right", raw_spec.get("right_foot"))
),
]
if isinstance(raw_spec, (list, tuple)) and len(raw_spec) == 2:
return [coerce_names(raw_spec[0]), coerce_names(raw_spec[1])]
logger.warning(
"Unsupported robot.feet_geom_names format. Ignoring configured "
"foot geom names."
)
return [[], []]
@staticmethod
def _normalize_foot_body_name_groups(raw_spec) -> list[list[str]]:
    """Normalize ``robot.feet_body_names`` into ``[left_names, right_names]``.

    Same accepted forms as the geom-name variant:
      * a mapping with ``left``/``right`` (or ``left_foot``/``right_foot``) keys;
      * a two-element sequence.
    ``None`` and unsupported formats (the latter with a warning) fall back to
    fresh copies of ``DEFAULT_FEET_BODY_NAMES``.
    """

    def default_groups() -> list[list[str]]:
        return [
            list(DEFAULT_FEET_BODY_NAMES["left"]),
            list(DEFAULT_FEET_BODY_NAMES["right"]),
        ]

    if raw_spec is None:
        return default_groups()
    spec = (
        OmegaConf.to_container(raw_spec, resolve=True)
        if OmegaConf.is_config(raw_spec)
        else raw_spec
    )

    def as_name_list(entry) -> list[str]:
        if isinstance(entry, str):
            return [entry]
        if isinstance(entry, (list, tuple)):
            names = []
            for item in entry:
                text = str(item)
                if text:
                    names.append(text)
            return names
        return []

    if isinstance(spec, dict):
        left = spec.get("left", spec.get("left_foot"))
        right = spec.get("right", spec.get("right_foot"))
        return [as_name_list(left), as_name_list(right)]
    if isinstance(spec, (list, tuple)) and len(spec) == 2:
        left, right = spec
        return [as_name_list(left), as_name_list(right)]
    logger.warning(
        "Unsupported robot.feet_body_names format. Falling back to "
        f"default foot bodies: {DEFAULT_FEET_BODY_NAMES}"
    )
    return default_groups()
def _resolve_foot_geom_ids_from_geom_names(
    self, foot_geom_name_groups: list[list[str]]
) -> list[list[int]]:
    """Map per-side foot geom names to MuJoCo geom ids.

    Args:
        foot_geom_name_groups: ``[left_names, right_names]``.

    Returns:
        ``[left_ids, right_ids]``. Names missing from the model are skipped
        with a warning. Duplicate ids within a side are removed (first
        occurrence wins), matching ``_resolve_foot_geom_ids_from_body_names``.
        This keeps a geom listed twice from being double-weighted when the
        per-foot geom centers are averaged.
    """
    foot_geom_id_groups: list[list[int]] = [[], []]
    for side_idx, geom_names in enumerate(foot_geom_name_groups):
        resolved_geom_ids: list[int] = []
        for geom_name in geom_names:
            geom_id = mujoco.mj_name2id(
                self.m, mujoco.mjtObj.mjOBJ_GEOM, geom_name
            )
            if geom_id == -1:
                logger.warning(
                    f"Foot geom '{geom_name}' was not found in the MuJoCo model."
                )
                continue
            resolved_geom_ids.append(int(geom_id))
        # Preserve order while removing duplicates.
        foot_geom_id_groups[side_idx] = list(dict.fromkeys(resolved_geom_ids))
    return foot_geom_id_groups
def _resolve_foot_geom_ids_from_body_names(
    self, foot_body_name_groups: list[list[str]]
) -> list[list[int]]:
    """Collect the geoms attached to each side's foot bodies.

    For every body, geoms that can collide (non-zero ``contype`` or
    ``conaffinity``) are preferred. If a body has none, all of its geoms are
    used. Unknown bodies and bodies without geoms are skipped with a warning.

    Args:
        foot_body_name_groups: ``[left_body_names, right_body_names]``.

    Returns:
        ``[left_geom_ids, right_geom_ids]``, de-duplicated in first-seen order.
    """
    geom_owner = np.asarray(self.m.geom_bodyid, dtype=np.int32)
    can_collide = (np.asarray(self.m.geom_contype, dtype=np.int32) != 0) | (
        np.asarray(self.m.geom_conaffinity, dtype=np.int32) != 0
    )
    groups: list[list[int]] = [[], []]
    for side, body_names in enumerate(foot_body_name_groups):
        collected: list[int] = []
        for body_name in body_names:
            body_id = mujoco.mj_name2id(
                self.m, mujoco.mjtObj.mjOBJ_BODY, body_name
            )
            if body_id == -1:
                logger.warning(
                    f"Foot body '{body_name}' was not found in the MuJoCo model."
                )
                continue
            attached = np.flatnonzero(geom_owner == int(body_id))
            if attached.size == 0:
                logger.warning(
                    f"Foot body '{body_name}' has no attached geoms."
                )
                continue
            colliding = attached[can_collide[attached]]
            chosen = colliding if colliding.size != 0 else attached
            collected.extend(chosen.astype(int).tolist())
        groups[side] = list(dict.fromkeys(collected))
    return groups
def _init_low_level_foot_contact_logging(self) -> None:
self._foot_geom_id_groups = [[], []]
self._foot_geom_id_to_side = {}
self._foot_contact_logging_enabled = False
self._prev_low_level_foot_geom_centers = None
foot_geom_name_groups = self._normalize_foot_geom_name_groups(
getattr(self.config.robot, "feet_geom_names", None)
)
foot_body_name_groups = self._normalize_foot_body_name_groups(
getattr(self.config.robot, "feet_body_names", None)
)
geom_name_groups = self._resolve_foot_geom_ids_from_geom_names(
foot_geom_name_groups
)
body_name_groups = self._resolve_foot_geom_ids_from_body_names(
foot_body_name_groups
)
for side_idx in range(2):
resolved_ids = (
geom_name_groups[side_idx]
if len(geom_name_groups[side_idx]) > 0
else body_name_groups[side_idx]
)
self._foot_geom_id_groups[side_idx] = list(resolved_ids)
for geom_id in resolved_ids:
self._foot_geom_id_to_side[int(geom_id)] = side_idx
if any(len(group) == 0 for group in self._foot_geom_id_groups):
logger.warning(
"Low-level foot contact logging is unavailable because one or "
"both foot geom groups could not be resolved. Contact metrics "
"will be written as NaN."
)
return
self._foot_contact_logging_enabled = True
def _record_low_level_foot_contact_sample(self) -> None:
    """Append one per-foot contact sample to the low-level recording buffers.

    Three length-2 float32 vectors are appended. They are indexed by side
    (0 = left, 1 = right, in the order of ``self._foot_geom_id_groups``):

    * contact flag: 1.0 if at least one MuJoCo contact involves geoms of
      exactly one foot side, else 0.0;
    * normal force: sum over those contacts of ``abs(contact_force[0])``
      from ``mujoco.mj_contactForce``, with index 0 taken as the normal
      component of the contact-frame force;
    * tangential speed: XY norm of the finite-difference velocity of each
      foot's mean geom position, using ``self.simulation_dt`` as the time
      step. It is 0.0 on the first sample after the stored previous center
      was reset.

    When foot-contact logging is disabled, all three vectors are NaN.
    """
    foot_contact = np.full((2,), np.nan, dtype=np.float32)
    foot_normal_force = np.full((2,), np.nan, dtype=np.float32)
    foot_tangent_speed = np.full((2,), np.nan, dtype=np.float32)
    if not self._foot_contact_logging_enabled:
        self._robot_low_level_foot_contact_seq.append(foot_contact)
        self._robot_low_level_foot_normal_force_seq.append(
            foot_normal_force
        )
        self._robot_low_level_foot_tangent_speed_seq.append(
            foot_tangent_speed
        )
        return
    # Mean world position of each foot's geoms, shape (2, 3).
    current_centers = np.zeros((2, 3), dtype=np.float32)
    for side_idx, geom_ids in enumerate(self._foot_geom_id_groups):
        current_centers[side_idx] = np.mean(
            self.d.geom_xpos[np.asarray(geom_ids, dtype=np.int32)],
            axis=0,
        ).astype(np.float32)
    if self._prev_low_level_foot_geom_centers is None:
        tangential_speed = np.zeros((2,), dtype=np.float32)
    else:
        # NOTE(review): assumes this runs once per physics step, so that
        # simulation_dt is the time elapsed between samples; confirm at the
        # call site.
        foot_velocity = (
            current_centers - self._prev_low_level_foot_geom_centers
        ) / np.float32(self.simulation_dt)
        tangential_speed = np.linalg.norm(
            foot_velocity[:, :2], axis=1
        ).astype(np.float32)
    self._prev_low_level_foot_geom_centers = current_centers.copy()
    foot_contact.fill(0.0)
    foot_normal_force.fill(0.0)
    foot_tangent_speed = tangential_speed
    # Reused output buffer for mj_contactForce (6 components).
    contact_force = np.zeros(6, dtype=np.float64)
    for contact_idx in range(int(self.d.ncon)):
        contact = self.d.contact[contact_idx]
        contact_sides = set()
        geom1 = int(contact.geom1)
        geom2 = int(contact.geom2)
        if geom1 in self._foot_geom_id_to_side:
            contact_sides.add(self._foot_geom_id_to_side[geom1])
        if geom2 in self._foot_geom_id_to_side:
            contact_sides.add(self._foot_geom_id_to_side[geom2])
        # Skip contacts that touch no foot, or touch both feet at once
        # (foot-on-foot contact is not attributed to either side).
        if len(contact_sides) != 1:
            continue
        side_idx = next(iter(contact_sides))
        foot_contact[side_idx] = 1.0
        mujoco.mj_contactForce(self.m, self.d, contact_idx, contact_force)
        foot_normal_force[side_idx] += np.float32(abs(contact_force[0]))
    self._robot_low_level_foot_contact_seq.append(foot_contact)
    self._robot_low_level_foot_normal_force_seq.append(foot_normal_force)
    self._robot_low_level_foot_tangent_speed_seq.append(foot_tangent_speed)
@staticmethod
def _flatten_single_step_output(values, *, dtype=None) -> np.ndarray:
arr = np.asarray(values, dtype=dtype)
if arr.ndim == 0:
raise ValueError(
"Expected at least 1D output for single-step ONNX routing dump."
)
return arr.reshape(-1, arr.shape[-1])[0]
def _discover_policy_moe_outputs(self) -> None:
self.policy_moe_layer_output_names: list[tuple[int, str, str]] = []
routing_outputs: dict[int, dict[str, str]] = {}
pattern = re.compile(r"^moe_layer_(\d+)_expert_(indices|logits)$")
for node in self.policy_session.get_outputs():
match = pattern.fullmatch(node.name)
if match is None:
continue
layer_idx = int(match.group(1))
kind = str(match.group(2))
routing_outputs.setdefault(layer_idx, {})[kind] = node.name
for layer_idx in sorted(routing_outputs):
layer_outputs = routing_outputs[layer_idx]
if "indices" not in layer_outputs or "logits" not in layer_outputs:
logger.warning(
"Skipping incomplete MoE routing outputs for layer "
f"{layer_idx}: {sorted(layer_outputs)}"
)
continue
self.policy_moe_layer_output_names.append(
(
layer_idx,
layer_outputs["indices"],
layer_outputs["logits"],
)
)
if self.policy_moe_layer_output_names:
logger.info(
"Detected MoE routing outputs for layers: "
f"{[layer_idx for layer_idx, _, _ in self.policy_moe_layer_output_names]}"
)
def _get_stacked_moe_routing_tensors(
self,
) -> tuple[np.ndarray | None, np.ndarray | None]:
indices_seq = getattr(self, "_robot_moe_expert_indices_seq", [])
logits_seq = getattr(self, "_robot_moe_expert_logits_seq", [])
if len(indices_seq) == 0 or len(logits_seq) == 0:
return None, None
return (
np.stack(indices_seq, axis=0).astype(np.int64),
np.stack(logits_seq, axis=0).astype(np.float32),
)
def _get_stacked_low_level_foot_contact_tensors(
self,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
contact_seq = getattr(self, "_robot_low_level_foot_contact_seq", [])
normal_force_seq = getattr(
self, "_robot_low_level_foot_normal_force_seq", []
)
tangent_speed_seq = getattr(
self, "_robot_low_level_foot_tangent_speed_seq", []
)
if contact_seq and normal_force_seq and tangent_speed_seq:
return (
np.stack(contact_seq, axis=0).astype(np.float32),
np.stack(normal_force_seq, axis=0).astype(np.float32),
np.stack(tangent_speed_seq, axis=0).astype(np.float32),
)
num_low_level_samples = len(
getattr(self, "_robot_low_level_dof_torque_seq", [])
)
if num_low_level_samples <= 0:
return None, None, None
nan_array = np.full((num_low_level_samples, 2), np.nan, np.float32)
return nan_array.copy(), nan_array.copy(), nan_array.copy()
def _record_onnx_io_frame(self, input_feed, output_names, onnx_output):
if not self._onnx_io_input_names:
self._onnx_io_input_names = list(input_feed.keys())
self._onnx_io_inputs = {
name: [] for name in self._onnx_io_input_names
}
if not self._onnx_io_output_names:
self._onnx_io_output_names = list(output_names)
self._onnx_io_outputs = {
name: [] for name in self._onnx_io_output_names
}
for name in self._onnx_io_input_names:
if name not in input_feed:
raise KeyError(f"Missing ONNX input tensor: {name}")
self._onnx_io_inputs[name].append(
np.array(input_feed[name], copy=True)
)
for name, value in zip(self._onnx_io_output_names, onnx_output):
self._onnx_io_outputs[name].append(np.array(value, copy=True))
@staticmethod
def _stack_onnx_io_frames(
frame_dict: dict[str, list[np.ndarray]],
) -> dict[str, np.ndarray]:
stacked: dict[str, np.ndarray] = {}
for name, frames in frame_dict.items():
if frames:
stacked[name] = np.stack(frames, axis=0)
else:
stacked[name] = np.empty((0,), dtype=np.float32)
return stacked
def save_onnx_io_dump(self, output_path, meta_info):
    """Write the recorded ONNX inputs/outputs to a pickled ``.npy`` file.

    The payload is a dict with tensor names, stacked per-step inputs/outputs,
    the source motion file (``source_npz``, falling back to ``source_file``)
    and the model path (``onnx_model``, falling back to ``model``). Load it
    with ``np.load(path, allow_pickle=True).item()``.
    """
    source = meta_info.get("source_npz", meta_info.get("source_file", ""))
    model = meta_info.get("onnx_model", meta_info.get("model", ""))
    payload = {
        "input_names": list(self._onnx_io_input_names),
        "output_names": list(self._onnx_io_output_names),
        "inputs": self._stack_onnx_io_frames(self._onnx_io_inputs),
        "outputs": self._stack_onnx_io_frames(self._onnx_io_outputs),
        "source_npz": source,
        "onnx_model": model,
    }
    np.save(output_path, payload, allow_pickle=True)
def _find_actor_place_holder_ndim(self):
n_dim = 0
for obs_dict in self._get_policy_atomic_obs_list():
name = str(list(obs_dict.keys())[0])
if name == "place_holder":
params = obs_dict["place_holder"].get("params", {})
n_dim = int(params.get("n_dim", 0))
if name == "actor_place_holder":
params = obs_dict["actor_place_holder"].get("params", {})
n_dim = int(params.get("n_dim", 0))
return n_dim
def _get_actor_obs_term_params(self, term_name: str) -> dict:
for obs_dict in self._get_policy_atomic_obs_list():
configured_name = str(list(obs_dict.keys())[0])
if configured_name != term_name:
continue
term_cfg = obs_dict[configured_name]
if not isinstance(term_cfg, dict):
return {}
params = term_cfg.get("params", {})
return dict(params) if isinstance(params, dict) else {}
return {}
def _get_ref_keybody_indices(self, term_name: str) -> np.ndarray:
params = self._get_actor_obs_term_params(term_name)
keybody_names = params.get("keybody_names", None)
body_names = [str(name) for name in self.config.robot.body_names]
if keybody_names is None:
return np.arange(len(body_names), dtype=np.int64)
keybody_names = [str(name) for name in keybody_names]
body_name_to_idx = {
body_name: idx for idx, body_name in enumerate(body_names)
}
missing_names = [
name for name in keybody_names if name not in body_name_to_idx
]
if len(missing_names) > 0:
raise ValueError(
f"Unknown keybody_names in '{term_name}': {missing_names}. "
f"Available body names: {body_names}"
)
return np.asarray(
[body_name_to_idx[name] for name in keybody_names],
dtype=np.int64,
)
@staticmethod
def _to_plain_obs_cfg(cfg):
    """Convert one observation-term config into a plain ``dict``.

    OmegaConf containers are resolved (interpolations expanded) via
    ``OmegaConf.to_container``. Anything else is passed through ``dict()``,
    so mappings and iterables of key/value pairs are accepted.

    Args:
        cfg: The term's config.

    Returns:
        dict: A plain, mutable copy of the config.

    Raises:
        ValueError: If ``cfg`` is not mapping-like. Examples are ``None``
            for a term declared without a body, a scalar, or an OmegaConf
            list. Non-mapping plain values used to escape as an opaque
            ``TypeError`` from ``dict()``; they now get this error.
    """
    if OmegaConf.is_config(cfg):
        plain_cfg = OmegaConf.to_container(cfg, resolve=True)
    else:
        try:
            plain_cfg = dict(cfg)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Observation term config must be a mapping, got {type(cfg)}"
            ) from exc
    if not isinstance(plain_cfg, dict):
        raise ValueError(
            f"Observation term config must be a mapping, got {type(plain_cfg)}"
        )
    return plain_cfg
def _get_actor_obs_schema_terms(self) -> list[str]:
modules_cfg = self.config.get("modules", None)
if modules_cfg is None:
return []
actor_cfg = modules_cfg.get("actor", None)
if actor_cfg is None:
return []
obs_schema = actor_cfg.get("obs_schema", None)
if obs_schema is None:
return []
ordered_terms: list[str] = []
for _, seq_cfg in obs_schema.items():
seq_terms = seq_cfg.get("terms", [])
ordered_terms.extend(str(term) for term in seq_terms)
return ordered_terms
def _get_actor_atomic_obs_entries(self) -> list[tuple[str, str, dict]]:
obs_cfg = self.config.get("obs", None)
if obs_cfg is None:
raise ValueError("Missing config.obs for MuJoCo sim2sim")
obs_groups = obs_cfg.get("obs_groups", None)
if obs_groups is None:
raise ValueError(
"Missing config.obs.obs_groups for MuJoCo sim2sim"
)
if obs_groups.get("policy", None) is not None:
entries: list[tuple[str, str, dict]] = []
for term_dict in obs_groups.policy.atomic_obs_list:
term_name = str(list(term_dict.keys())[0])
entries.append(
(
"policy",
term_name,
self._to_plain_obs_cfg(term_dict[term_name]),
)
)
return entries
if obs_groups.get("unified", None) is not None:
entries = []
for term_dict in obs_groups.unified.atomic_obs_list:
term_name = str(list(term_dict.keys())[0])
if term_name.startswith("critic_"):
continue
entries.append(
(
"unified",
term_name,
self._to_plain_obs_cfg(term_dict[term_name]),
)
)
if not entries:
raise ValueError(
"obs_groups.unified found but contains no non-critic terms."
)
return entries
raise ValueError(
"Unsupported obs config for MuJoCo sim2sim: expected obs_groups.policy or obs_groups.unified."
)
def _get_policy_atomic_obs_list(self):
"""Resolve the atomic obs list used to build the ONNX policy input.
Supports both legacy configs (obs_groups.policy) and PULSE-stage2 configs
that use a unified group (obs_groups.unified) with actor_/critic_ prefixes.
"""
actor_atomic_entries = self._get_actor_atomic_obs_entries()
schema_terms = self._get_actor_obs_schema_terms()
if len(schema_terms) == 0:
logger.warning(
"modules.actor.obs_schema is unavailable; using obs_groups actor term order for MuJoCo policy input."
)
return [
{term_name: cfg} for _, term_name, cfg in actor_atomic_entries
]
by_full_key: dict[str, tuple[str, dict]] = {}
by_leaf_key: dict[str, tuple[str, dict]] = {}
ambiguous_leaf_keys: set[str] = set()
for group_name, term_name, term_cfg in actor_atomic_entries:
full_key = f"{group_name}/{term_name}"
by_full_key[full_key] = (term_name, term_cfg)
if term_name in by_leaf_key:
ambiguous_leaf_keys.add(term_name)
else:
by_leaf_key[term_name] = (term_name, term_cfg)
ordered_atomic_list = []
for schema_term in schema_terms:
schema_term_key = str(schema_term)
if schema_term_key in by_full_key:
term_name, term_cfg = by_full_key[schema_term_key]
ordered_atomic_list.append({term_name: term_cfg})
continue
leaf_key = schema_term_key.split("/")[-1]
if leaf_key in ambiguous_leaf_keys:
raise ValueError(
"Actor obs_schema term "
f"'{schema_term}' is ambiguous by leaf key '{leaf_key}'. "
"Use explicit group/term hierarchy in obs_schema terms."
)
if leaf_key not in by_leaf_key:
raise ValueError(
"Actor obs_schema term "
f"'{schema_term}' is not present in obs_groups actor atomic obs list."
)
term_name, term_cfg = by_leaf_key[leaf_key]
ordered_atomic_list.append({term_name: term_cfg})
return ordered_atomic_list
# ----------------- Kinematics / velocities -----------------
def _body_origin_world_velocity(
    self, body_id: int
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the world-frame spatial velocity (v, w) of a body's frame origin.

    Uses the translational/rotational Jacobians from ``mujoco.mj_jacBody``
    multiplied by ``qvel``. The math is done in float64 (MuJoCo's native
    precision) and the results are cast to float32 at the end.

    Returns:
        tuple: ``(lin_vel_w[3], ang_vel_w[3])`` in world coordinates.
    """
    dof_count = self.m.nv
    jac_lin = np.zeros((3, dof_count), dtype=np.float64)
    jac_ang = np.zeros((3, dof_count), dtype=np.float64)
    mujoco.mj_jacBody(self.m, self.d, jac_lin, jac_ang, int(body_id))
    qvel = self.d.qvel
    linear = jac_lin @ qvel
    angular = jac_ang @ qvel
    return linear.astype(np.float32), angular.astype(np.float32)
# ----------------- Body name/id resolution -----------------
def _get_anchor_body_name(self) -> str:
if not hasattr(self, "anchor_body_name"):
self.anchor_body_name = str(
getattr(self.config.robot, "anchor_body", "pelvis")
)
logger.info(f"Anchor body name: {self.anchor_body_name}")
return self.anchor_body_name
def _get_torso_body_name(self) -> str:
if not hasattr(self, "torso_body_name"):
self.torso_body_name = str(
getattr(self.config.robot, "torso_name", "torso_link")
)
return self.torso_body_name
@property
def ref_motion_frame_idx(self):
    """Current reference-motion frame index (alias of ``motion_frame_idx``)."""
    current_frame = self.motion_frame_idx
    return current_frame
@property
def anchor_body_idx(self) -> int:
    """Index of ``config.robot.anchor_body`` within ``config.robot.body_names``."""
    robot_cfg = self.config.robot
    return robot_cfg.body_names.index(robot_cfg.anchor_body)
@property
def root_body_idx(self) -> int:
    """Index of the root body; always the first row of the body arrays."""
    root_index = 0
    return root_index
@property
def torso_body_idx(self) -> int:
    """Index of ``config.robot.torso_name`` within ``config.robot.body_names``."""
    robot_cfg = self.config.robot
    return robot_cfg.body_names.index(robot_cfg.torso_name)
@property
def robot_global_bodylink_pos(self):
    """World-frame positions of all robot body frame origins.

    MuJoCo keeps a static "world" body at index 0 that is not a physical
    link. It is sliced off so that row 0 is the root body (e.g. pelvis) and
    the body dimension matches the HoloMotion NPZ ``*_global_translation``
    arrays.

    Returns:
        np.ndarray: ``[n_bodies, 3]`` view of ``d.xpos`` in MuJoCo body
        order, world body excluded. The array shares memory with the
        simulation data, so copy it before storing.
    """
    body_origins = self.d.xpos
    return body_origins[1:]
@property
def robot_global_bodylink_rot(self):
    """World-frame orientations of all robot bodies as WXYZ quaternions.

    As with positions, MuJoCo's world body (index 0) is dropped so that row 0
    is the root body and the body dimension matches the HoloMotion NPZ
    ``*_global_rotation_quat`` arrays. The quaternions are passed through
    ``standardize_quaternion`` (presumably a sign canonicalization; confirm)
    before being returned.

    Returns:
        np.ndarray: float32 array of shape ``[n_bodies, 4]`` in MuJoCo body
        order, world body excluded.
    """
    body_quats = torch.as_tensor(
        self.d.xquat[1:], dtype=torch.float32, device="cpu"
    )
    return standardize_quaternion(body_quats).detach().cpu().numpy()
@property
def robot_global_bodylink_lin_vel(self):
    """World-frame linear velocities of all robot body frame origins.

    Calls ``mujoco.mj_objectVelocity`` for each body with:
      * ``mjOBJ_XBODY``: the regular body frame, whose origin is ``d.xpos``;
      * ``flg_local=0``: world-aligned orientation.
    It then keeps the translational half of the ``[rot, lin]`` 6-D result.
    ``mjOBJ_BODY`` would instead evaluate the velocity at the body's inertial
    frame (center of mass, ``d.xipos``). That point does not match the
    positions returned by ``robot_global_bodylink_pos``.

    The world body (ID 0) is excluded so the body dimension matches the NPZ
    ``*_global_velocity`` arrays.

    Returns:
        np.ndarray: ``[n_bodies, 3]`` float64 linear velocities in the MuJoCo
        world frame, ordered by body ID starting from the root body.
    """
    nbody = int(self.m.nbody)
    vel_6d = np.zeros((nbody, 6), dtype=np.float64)
    for bid in range(1, nbody):
        mujoco.mj_objectVelocity(
            self.m,
            self.d,
            mujoco.mjtObj.mjOBJ_XBODY,
            bid,
            vel_6d[bid],
            0,
        )
    # Layout is [angular(3), linear(3)]; keep the linear part.
    return vel_6d[1:, 3:6]
@property
def robot_global_bodylink_ang_vel(self):
    """World-frame angular velocities of all robot bodies.

    Calls ``mujoco.mj_objectVelocity`` with ``mjOBJ_BODY`` and
    ``flg_local=0`` (world-aligned orientation) and keeps the rotational
    half of the ``[rot, lin]`` 6-D result. A rigid body's angular velocity
    does not depend on the reference point, so this equals the angular
    velocity of the body frame. The world body (ID 0) is dropped so the body
    dimension matches the NPZ ``*_global_angular_velocity`` arrays.

    Returns:
        np.ndarray: ``[n_bodies, 3]`` float64 angular velocities in the MuJoCo
        world frame, ordered by body ID starting from the root body.
    """
    body_count = int(self.m.nbody)
    spatial = np.zeros((body_count, 6), dtype=np.float64)
    for body_id in range(1, body_count):
        mujoco.mj_objectVelocity(
            self.m,
            self.d,
            mujoco.mjtObj.mjOBJ_BODY,
            body_id,
            spatial[body_id],
            0,
        )
    return spatial[1:, :3]
@property
def robot_dof_pos(self):
    """Joint positions of the actuated DOFs.

    Uses ``actuator_qpos_indices`` when it has been set. Otherwise returns
    ``qpos[7:]``, which assumes a floating-base free joint occupies the first
    7 entries (3 position + 4 quaternion); confirm for other models.
    """
    qpos = self.d.qpos
    if not hasattr(self, "actuator_qpos_indices"):
        return qpos[7:]
    return qpos[self.actuator_qpos_indices]
@property
def robot_dof_vel(self):
    """Joint velocities of the actuated DOFs.

    Uses ``actuator_qvel_indices`` when it has been set. Otherwise returns
    ``qvel[6:]``, which assumes a floating-base free joint occupies the first
    6 entries (3 linear + 3 angular); confirm for other models.
    """
    qvel = self.d.qvel
    if not hasattr(self, "actuator_qvel_indices"):
        return qvel[6:]
    return qvel[self.actuator_qvel_indices]
# ----------------- Reference->Simulation alignment -----------------
def _ensure_ref_to_sim_transform_rigid(self):
    """Compute, once, the rigid Ref->Sim transform (yaw rotation + translation).

    The transform maps the reference **anchor body** pose at frame 0 onto the
    robot's current global anchor pose:
      - ``yaw(q_ref_to_sim * q_ref_anchor_0) = yaw(q_robot_anchor_0)``
      - ``t_ref_to_sim + R(q_ref_to_sim) @ t_ref_anchor_0 = t_robot_anchor_0``
    Only yaw is corrected in the rotation; the reference's roll/pitch are
    kept. The translation, however, is the full 3-D offset, so the anchor's
    height is aligned as well as its XY position.

    This absorbs arbitrary initialization offsets between the robot and the
    reference motion, so that subsequent reference globals can be expressed
    in the simulation world frame. Results are stored in
    ``self._ref_to_sim_q_wxyz`` / ``self._ref_to_sim_t`` and
    ``self._ref_to_sim_ready`` is set; later calls return immediately.

    Fallbacks:
      - no ``ref_global_translation``: identity transform;
      - no ``ref_global_rotation_quat_xyzw``: translation-only transform
        (rotation stays identity).

    The same ``anchor_body_idx`` indexes both the robot bodies and the
    reference arrays, which assumes the two share one body ordering.
    """
    if self._ref_to_sim_ready:
        return
    # If we don't have reference globals, fall back to identity transform.
    if getattr(self, "ref_global_translation", None) is None:
        self._ref_to_sim_q_wxyz = np.array(
            [1.0, 0.0, 0.0, 0.0], dtype=np.float32
        )
        self._ref_to_sim_t = np.zeros(3, dtype=np.float32)
        self._ref_to_sim_ready = True
        logger.info(
            "No reference global translations available; using identity Ref->Sim transform."
        )
        return
    # If rotations are missing, keep the previous translation-only semantics.
    if getattr(self, "ref_global_rotation_quat_xyzw", None) is None:
        t_robot = torch.as_tensor(
            self.robot_global_bodylink_pos[self.anchor_body_idx],
            dtype=torch.float32,
            device="cpu",
        )
        t_ref = torch.as_tensor(
            self.ref_global_translation[0, self.anchor_body_idx].astype(
                np.float32
            ),
            dtype=torch.float32,
            device="cpu",
        )
        # Plain 3-D offset: anchor at frame 0 lands exactly on the robot anchor.
        t_ref_to_sim = t_robot - t_ref
        self._ref_to_sim_q_wxyz = np.array(
            [1.0, 0.0, 0.0, 0.0], dtype=np.float32
        )
        self._ref_to_sim_t = t_ref_to_sim.detach().cpu().numpy()
        self._ref_to_sim_ready = True
        logger.info(
            "Reference rotations missing; initialized Ref->Sim as translation-only "
            f"transform. t={self._ref_to_sim_t}"
        )
        return
    # Anchor body index shared between robot globals and reference globals
    anchor_idx = self.anchor_body_idx
    # Robot anchor pose in simulation world frame (after initial state has been set)
    t_robot = torch.as_tensor(
        self.robot_global_bodylink_pos[anchor_idx],
        dtype=torch.float32,
        device="cpu",
    )
    q_robot_wxyz = torch.as_tensor(
        self.robot_global_bodylink_rot[anchor_idx],
        dtype=torch.float32,
        device="cpu",
    )
    # Reference anchor pose at frame 0 in NPZ global frame
    t_ref0 = torch.as_tensor(
        self.ref_global_translation[0, anchor_idx].astype(np.float32),
        dtype=torch.float32,
        device="cpu",
    )
    q_ref0_xyzw = torch.as_tensor(
        self.ref_global_rotation_quat_xyzw[0, anchor_idx].astype(
            np.float32
        ),
        dtype=torch.float32,
        device="cpu",
    )
    q_ref0_wxyz = xyzw_to_wxyz(q_ref0_xyzw)
    # Yaw-only rotation mapping: align reference yaw to robot yaw (keep roll/pitch from reference).
    # Yaw is read from the rotation matrices as atan2(R[1, 0], R[0, 0]).
    R_robot = matrix_from_quat(q_robot_wxyz)
    R_ref0 = matrix_from_quat(q_ref0_wxyz)
    yaw_robot = torch.atan2(R_robot[1, 0], R_robot[0, 0])
    yaw_ref0 = torch.atan2(R_ref0[1, 0], R_ref0[0, 0])
    yaw_delta = yaw_robot - yaw_ref0
    # quat_from_euler_xyz's result is treated as XYZW here (converted just
    # below); confirm against its definition.
    yaw_quat_xyzw = quat_from_euler_xyz(
        torch.tensor(0.0, dtype=torch.float32, device="cpu"),
        torch.tensor(0.0, dtype=torch.float32, device="cpu"),
        yaw_delta,
    )
    q_ref_to_sim = xyzw_to_wxyz(yaw_quat_xyzw)
    q_ref_to_sim = quat_normalize_wxyz(q_ref_to_sim)
    # Translation mapping: t_ref_to_sim + R(q_ref_to_sim) @ t_ref0 = t_robot
    t_ref0_in_sim = quat_apply(q_ref_to_sim, t_ref0)
    t_ref_to_sim = t_robot - t_ref0_in_sim
    self._ref_to_sim_q_wxyz = (
        q_ref_to_sim.detach().cpu().numpy().astype(np.float32)
    )
    self._ref_to_sim_t = (
        t_ref_to_sim.detach().cpu().numpy().astype(np.float32)
    )
    self._ref_to_sim_ready = True
    logger.info(
        "Initialized Ref->Sim rigid transform. "
        f"q={self._ref_to_sim_q_wxyz}, t={self._ref_to_sim_t}"
    )
def _detect_command_mode(self) -> str:
m_dir = self.config.get("motion_npz_dir") or self.config.get(
"eval", {}
).get("motion_npz_dir")
m_path = self.config.get("motion_npz_path") or self.config.get(
"eval", {}
).get("motion_npz_path")
if m_path is not None and not os.path.exists(m_path):
raise FileNotFoundError(f"Motion file not found: {m_path}")
if (m_dir and str(m_dir) != "") or (m_path and str(m_path) != ""):
return "motion_tracking"
return "velocity_tracking"
def _init_obs_buffers(self):
atomic_list = self._get_policy_atomic_obs_list()
obs_policy_cfg = {"atomic_obs_list": atomic_list}
self.obs_builder = PolicyObsBuilder(
dof_names_onnx=self.dof_names_onnx,
default_angles_onnx=self.default_angles_onnx,
evaluator=self,
obs_policy_cfg=obs_policy_cfg,
)
def load_policy(self):
    """Create the ONNX Runtime session for the policy and discover its I/O.

    Reads ``ckpt_onnx_path``, ``use_gpu`` and ``gpu_id`` from ``self.config``
    and ``self.max_context_len``. Sets ``policy_session``, the observation /
    KV-cache / step input names, the action and KV-cache output names and,
    when the model has a ``past_key_values`` input, a zero-filled KV cache
    together with the model and effective context lengths.
    """
    onnx_model_path = Path(self.config.ckpt_onnx_path)
    logger.info(f"Loading ONNX policy from {onnx_model_path}")
    providers = ["CPUExecutionProvider"]
    use_gpu = _coerce_config_bool(
        self.config.get("use_gpu", False), default=False
    )
    gpu_id = int(self.config.get("gpu_id", 0))
    available_providers = onnxruntime.get_available_providers()
    if use_gpu:
        if "CUDAExecutionProvider" in available_providers:
            cuda_options = {"device_id": gpu_id}
            if torch.cuda.is_available():
                # Run ONNX Runtime on torch's current CUDA stream.
                torch.cuda.set_device(gpu_id)
                cuda_options["user_compute_stream"] = str(
                    torch.cuda.current_stream().cuda_stream
                )
            providers = [
                ("CUDAExecutionProvider", cuda_options),
                "CPUExecutionProvider",
            ]
            logger.info(f"Using CUDAExecutionProvider with gpu_id={gpu_id}")
        else:
            logger.warning(
                "use_gpu=true but CUDAExecutionProvider is unavailable; "
                "falling back to CPUExecutionProvider."
            )
    sess_options = onnxruntime.SessionOptions()
    sess_options.intra_op_num_threads = 1
    sess_options.inter_op_num_threads = 1
    sess_options.log_severity_level = 3
    self.policy_session = onnxruntime.InferenceSession(
        str(onnx_model_path),
        sess_options=sess_options,
        providers=providers,
    )
    logger.info(
        f"ONNX Runtime session created successfully using: {self.policy_session.get_providers()}"
    )

    # Classify inputs by name: observation, KV cache, step/position index.
    logger.info("Initializing KV-Cache for Policy...")
    self.policy_input_name = None
    self.policy_kv_input_name = None
    self.policy_step_input_name = None
    self.policy_kv_shape = None
    unclassified_inputs = []
    for node in self.policy_session.get_inputs():
        name = node.name
        shape = node.shape
        logger.info(f"Model Input: Name={name}, Shape={shape}")
        if "obs" in name:
            self.policy_input_name = name
        elif "past_key_values" in name:
            self.policy_kv_input_name = name
            self.policy_kv_shape = shape
        elif "step" in name or "pos" in name:
            self.policy_step_input_name = name
        else:
            unclassified_inputs.append(name)
    if self.policy_input_name is None:
        # No input mentions "obs": use the first input not claimed as KV
        # cache or step index instead of a hard-coded name the model may
        # not have (which ONNX Runtime would reject at inference time).
        self.policy_input_name = (
            unclassified_inputs[0] if unclassified_inputs else "obs"
        )
        logger.warning(
            "No ONNX input name contains 'obs'; feeding observations to "
            f"'{self.policy_input_name}'."
        )

    outputs = self.policy_session.get_outputs()
    self.policy_output_name = outputs[0].name
    self.policy_kv_output_name = None
    for node in outputs:
        if "present_key_values" in node.name:
            self.policy_kv_output_name = node.name
    logger.info(
        f"Policy ONNX Input: {self.policy_input_name}, Output: {self.policy_output_name}"
    )
    self._discover_policy_moe_outputs()

    if self.policy_kv_input_name and self.policy_kv_shape:
        # Symbolic (non-int) dimensions are materialized as 1.
        shape = [
            d if isinstance(d, int) else 1 for d in self.policy_kv_shape
        ]
        self.policy_kv_cache = np.zeros(shape, dtype=np.float32)
        # Axis 3 of the cache shape is taken as the model's context length.
        self.policy_model_context_len = (
            int(shape[3]) if len(shape) > 3 else 0
        )
        if self.max_context_len > 0 and self.policy_model_context_len > 0:
            self.policy_effective_context_len = min(
                self.max_context_len, self.policy_model_context_len
            )
            logger.info(
                "Using context window from "
                f"algo.config.num_steps_per_env={self.max_context_len} "
                f"(model cache len={self.policy_model_context_len}, "
                f"effective={self.policy_effective_context_len})"
            )
        else:
            self.policy_effective_context_len = (
                self.policy_model_context_len
            )
        self.use_kv_cache = True
        logger.info(f"KV-Cache ENABLED. Shape: {shape}")
    else:
        self.use_kv_cache = False
        self.policy_kv_cache = None
        self.policy_model_context_len = 0
        self.policy_effective_context_len = 0
        logger.warning("KV-Cache NOT found in model inputs!")
        if self.max_context_len > 0:
            logger.warning(
                "algo.config.num_steps_per_env is set but KV-Cache is unavailable; "
                "ignoring context window limit."
            )
    logger.info("ONNX Policy loaded successfully")
def _read_onnx_metadata(self) -> dict:
"""Read model metadata from ONNX file and parse into Python types."""
onnx_model_path = Path(self.config.ckpt_onnx_path)
model = onnx.load(str(onnx_model_path))
meta = {p.key: p.value for p in model.metadata_props}
def _parse_floats(csv_str: str):
return np.array(
[float(x) for x in csv_str.split(",") if x != ""],
dtype=np.float32,
)
result = {}
result["action_scale"] = _parse_floats(meta["action_scale"])
result["kps"] = _parse_floats(meta["joint_stiffness"])
result["kds"] = _parse_floats(meta["joint_damping"])
result["default_joint_pos"] = _parse_floats(meta["default_joint_pos"])
result["joint_names"] = [
x for x in meta["joint_names"].split(",") if x != ""
]
# 打印解析后的元数据
logger.info("=== Loaded ONNX Metadata ===")
for key, value in result.items():
# 如果关节名称列表很长,进行格式化处理以保持整洁
if key == "joint_names":
logger.info(f"{key}: {', '.join(value)}")
else:
logger.info(f"{key}:\n{value}")
logger.info("============================")
return result
def _apply_onnx_metadata(self):
"""Apply PD/scale/defaults from ONNX metadata as authoritative values."""
meta = self._read_onnx_metadata()
self.dof_names_onnx = meta["joint_names"]
self.action_scale_onnx = meta["action_scale"].astype(np.float32)
self.kps_onnx = meta["kps"].astype(np.float32)
self.kds_onnx = meta["kds"].astype(np.float32)
self.default_angles_onnx = meta["default_joint_pos"].astype(np.float32)
def _build_dof_mappings(self):
# Map ONNX <-> MJCF for control
self.onnx_to_mu = [
self.dof_names_onnx.index(name) for name in self.mjcf_dof_names
]
self.mu_to_onnx = [
self.mjcf_dof_names.index(name) for name in self.dof_names_onnx
]
self.ref_to_onnx = [
self.dof_names_ref_motion.index(name)
for name in self.dof_names_onnx
]
# Map MuJoCo actuator DOF order -> reference DOF order used in motion npz
self.mu_to_ref = []
for mu_idx in range(len(self.mjcf_dof_names)):
onnx_idx = self.onnx_to_mu[mu_idx]
ref_idx = self.ref_to_onnx[onnx_idx]
self.mu_to_ref.append(ref_idx)
self.kps_mu = self.kps_onnx[self.onnx_to_mu].astype(np.float32)
self.kds_mu = self.kds_onnx[self.onnx_to_mu].astype(np.float32)
self.default_angles_mu = self.default_angles_onnx[
self.onnx_to_mu
].astype(np.float32)
self.action_scale_mu = self.action_scale_onnx[self.onnx_to_mu].astype(
np.float32
)
@staticmethod
def _normalize_filter_cutoff_hz(raw_values, num_frames: int) -> np.ndarray:
num_frames = max(int(num_frames), 0)
if num_frames == 0:
return np.zeros((0, 1), dtype=np.float32)
if raw_values is None:
return np.zeros((num_frames, 1), dtype=np.float32)
cutoff = np.asarray(raw_values, dtype=np.float32)
if cutoff.ndim == 0:
cutoff = np.full((num_frames, 1), float(cutoff), dtype=np.float32)
return cutoff
if cutoff.ndim == 1:
cutoff = cutoff[:, None]
else:
cutoff = cutoff.reshape(cutoff.shape[0], -1)[:, :1]
if cutoff.shape[0] == 0:
return np.zeros((num_frames, 1), dtype=np.float32)
if cutoff.shape[0] == 1 and num_frames > 1:
cutoff = np.repeat(cutoff, num_frames, axis=0)
elif cutoff.shape[0] < num_frames:
pad = np.repeat(
cutoff[-1:, :], num_frames - cutoff.shape[0], axis=0
)
cutoff = np.concatenate([cutoff, pad], axis=0)
elif cutoff.shape[0] > num_frames:
cutoff = cutoff[:num_frames]
return cutoff.astype(np.float32, copy=False)
def load_motion_data(self):
    """Load the reference motion from ``config["motion_npz_path"]``.

    If no path is configured, a warning is logged and the method returns
    without touching the reference-motion attributes.

    Joint data comes from the first available key pair:
    ``("ref_dof_pos", "ref_dof_vel")`` or the legacy ``("dof_pos",
    "dof_vels")``. The pair may sit at the top level of the archive or
    inside a single pickled dict entry.

    Optional body-level arrays and ``filter_cutoff_hz`` are read from the
    same places, with top-level entries taking precedence over nested ones.
    The body-level arrays are global translation, global rotation quaternion
    (stored into ``ref_global_rotation_quat_xyzw``), linear velocity and
    angular velocity. Missing optional arrays are set to ``None``.

    All loaded arrays are cast to float32 and truncated to the shortest
    frame count among them, which becomes ``self.n_motion_frames``.

    Raises:
        ValueError: if no joint position/velocity arrays can be found.
    """
    motion_npz_path = self.config.get("motion_npz_path", None)
    if motion_npz_path is None:
        logger.warning(
            "No motion_npz_path specified in config, using zero reference motion"
        )
        return
    logger.info(f"Loading motion data from {motion_npz_path}")
    naming_pairs = [
        ("ref_dof_pos", "ref_dof_vel"),
        ("dof_pos", "dof_vels"),  # backward compat
    ]
    # (attribute, candidate keys in priority order, label used in warnings)
    optional_specs = [
        (
            "ref_global_translation",
            ("ref_global_translation", "global_translation"),
            "Global translation",
        ),
        (
            "ref_global_rotation_quat_xyzw",
            ("ref_global_rotation_quat", "global_rotation_quat"),
            "Global rotation",
        ),
        (
            "ref_global_velocity",
            ("ref_global_velocity", "global_velocity"),
            "Global velocity",
        ),
        (
            "ref_global_angular_velocity",
            ("ref_global_angular_velocity", "global_angular_velocity"),
            "Global angular velocity",
        ),
    ]

    def _as_f32(value):
        return np.array(value).astype(np.float32)

    def _find_pair(container):
        for pos_k, vel_k in naming_pairs:
            if pos_k in container and vel_k in container:
                return pos_k, vel_k
        return None

    # NOTE(security): allow_pickle=True is required for the nested-dict
    # layout but lets the file run arbitrary pickled code; only load motion
    # files from trusted sources.
    with np.load(motion_npz_path, allow_pickle=True) as npz:
        keys = list(npz.keys())
        nested = None
        raw_filter_cutoff_hz = (
            _as_f32(npz["filter_cutoff_hz"])
            if "filter_cutoff_hz" in npz
            else None
        )
        pair = _find_pair(npz)
        if pair is not None:
            self.ref_dof_pos = _as_f32(npz[pair[0]])
            self.ref_dof_vel = _as_f32(npz[pair[1]])
        elif len(keys) == 1:
            # Single entry: expected to be a pickled dict of arrays.
            arr = npz[keys[0]]
            if getattr(arr, "dtype", None) != object:
                raise ValueError(
                    f"Single key '{keys[0]}' is not an object array. "
                    f"Available keys: {keys}"
                )
            obj = arr.item() if arr.size == 1 else arr
            if not isinstance(obj, dict):
                raise ValueError(
                    f"Single key '{keys[0]}' does not contain a dict. "
                    f"Type: {type(obj)}"
                )
            nested = obj
            if raw_filter_cutoff_hz is None and "filter_cutoff_hz" in obj:
                raw_filter_cutoff_hz = _as_f32(obj["filter_cutoff_hz"])
            pair = _find_pair(obj)
            if pair is None:
                raise ValueError(
                    f"Could not find dof_pos/dof_vel in nested dict. "
                    f"Available keys: {list(obj.keys())}"
                )
            self.ref_dof_pos = _as_f32(obj[pair[0]])
            self.ref_dof_vel = _as_f32(obj[pair[1]])
        else:
            raise ValueError(
                f"Could not find dof_pos/dof_vel arrays. Available keys: {keys}"
            )

        # Optional body-level arrays; top-level entries win over nested ones.
        sources = [npz] if nested is None else [npz, nested]
        for attr, candidates, _ in optional_specs:
            value = None
            for source in sources:
                key = next((k for k in candidates if k in source), None)
                if key is not None:
                    value = _as_f32(source[key])
                    break
            setattr(self, attr, value)

    # Reconcile frame counts: the shortest loaded array bounds the motion.
    n_frames = min(self.ref_dof_pos.shape[0], self.ref_dof_vel.shape[0])
    if self.ref_dof_pos.shape[0] != self.ref_dof_vel.shape[0]:
        logger.warning(f"Frame count mismatch, truncated to {n_frames} frames")
    for attr, _, label in optional_specs:
        arr = getattr(self, attr)
        if arr is not None and arr.shape[0] < n_frames:
            logger.warning(
                f"{label} shorter than motion frames "
                f"({arr.shape[0]} < {n_frames}), truncating motion."
            )
            n_frames = arr.shape[0]
    # Trim every per-frame array to the common length so a single frame
    # index is valid for all of them.
    self.n_motion_frames = n_frames
    self.ref_dof_pos = self.ref_dof_pos[:n_frames]
    self.ref_dof_vel = self.ref_dof_vel[:n_frames]
    for attr, _, _ in optional_specs:
        arr = getattr(self, attr)
        if arr is not None:
            setattr(self, attr, arr[:n_frames])
    self.filter_cutoff_hz = self._normalize_filter_cutoff_hz(
        raw_filter_cutoff_hz, self.n_motion_frames
    )
    logger.info(
        f"Loaded motion data with {self.n_motion_frames} frames and {self.ref_dof_pos.shape[1]} DOFs"
    )
def load_mujoco_model(self):
    """Load the robot MJCF into ``self.m``, allocate ``self.d`` and apply the
    simulation timestep.

    Raises:
        ValueError: if ``robot_xml_path`` is missing from the config.
    """
    xml_path = self.config.get("robot_xml_path", None)
    if xml_path is None:
        raise ValueError("robot_xml_path should be specified in config !!!")
    logger.info(f"Loading MuJoCo model from {xml_path}")
    model = mujoco.MjModel.from_xml_path(xml_path)
    self.m = model
    self.d = mujoco.MjData(model)
    model.opt.timestep = self.simulation_dt
    logger.info(
        f"MuJoCo model loaded with {model.nq} position DOFs and {model.nu} control DOFs"
    )
def _init_camera_config(self):
"""Initialize shared camera configuration for viewer and offscreen renderers."""
self._root_body_id = -1
if not self._camera_tracking_enabled:
logger.info("Camera tracking disabled")
return
# Prefer anchor body from robot config, then fall back to common root names
candidates: list[str] = []
anchor_name = self._get_anchor_body_name()
candidates.append(anchor_name)
for name in ["pelvis", "torso", "base_link", "trunk", "root"]:
if name not in candidates:
candidates.append(name)
for body_name in candidates:
bid = int(
mujoco.mj_name2id(self.m, mujoco.mjtObj.mjOBJ_BODY, body_name)
)
if bid != -1:
self._root_body_id = bid
break
if self._root_body_id != -1:
logger.info(
f"Camera tracking enabled for body '{body_name}' (ID={self._root_body_id}), "
f"lookat height offset: {self._camera_height_offset:.2f}m"
)
else:
logger.warning(
"Could not find robot root body for camera tracking; "
"viewer and offscreen cameras will not track the robot."
)
def _configure_viewer_camera(self, viewer):
"""Apply shared align-view parameters to the interactive viewer camera."""
mujoco.mjv_defaultFreeCamera(self.m, viewer.cam)
viewer.cam.azimuth = self._camera_azimuth
viewer.cam.elevation = self._camera_elevation
viewer.cam.distance = self._camera_distance
def _init_video_tools(self, tag: str):
"""Initialize video writer and offscreen renderer when recording is enabled."""
if not bool(self.config.get("record_video", False)):
return
width = int(self.config.get("video_width", 1280))
height = int(self.config.get("video_height", 720))
fps = float(self.config.get("video_fps", 30.0))
onnx_stem = os.path.splitext(
os.path.basename(self.config.ckpt_onnx_path)
)[0]
output_dir = os.path.join(
os.path.dirname(self.config.ckpt_onnx_path),
f"mujoco_output_{onnx_stem}",
)
os.makedirs(output_dir, exist_ok=True)
motion_npz_path = self.config.get("motion_npz_path", None)
if motion_npz_path is not None:
motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[
0
]
out_path = os.path.join(output_dir, f"{motion_stem}.mp4")
else:
out_path = os.path.join(output_dir, "velocity_tracking.mp4")
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
self._video_writer = cv2.VideoWriter(
out_path, fourcc, fps, (width, height)
)
self._offscreen = OffscreenRenderer(
self.m,
height,
width,
distance=self._camera_distance,
azimuth=self._camera_azimuth,
elevation=self._camera_elevation,
)
self._frame_interval = 1.0 / max(fps, 1.0)
self._last_frame_time = 0.0
if getattr(self, "ref_global_translation", None) is not None:
self._offscreen.set_overlay_callback(
lambda scene: self._draw_ref_body_spheres_to_scene(
scene, reset_ngeom=False
)
)
logger.info(f"Recording enabled. Writing to: {out_path}")
def _dump_robot_augmented_npz(self) -> None:
"""Copy original motion npz and append robot_* states, saved next to video output.
The output follows the holomotion offline-eval spec used in PPO:
- robot_dof_pos, robot_dof_vel: [T, num_dofs]
- robot_global_translation: [T, num_bodies, 3]
- robot_global_rotation_quat: [T, num_bodies, 4] (XYZW)
- robot_global_velocity: [T, num_bodies, 3]
- robot_global_angular_velocity: [T, num_bodies, 3]
"""
motion_npz_path = self.config.get("motion_npz_path", None)
if motion_npz_path is None:
return
if len(self._robot_dof_pos_seq) == 0:
return
# Stack recorded sequences
robot_dof_pos = np.stack(self._robot_dof_pos_seq, axis=0).astype(
np.float32
)
robot_dof_vel = np.stack(self._robot_dof_vel_seq, axis=0).astype(
np.float32
)
robot_dof_acc = np.stack(self._robot_dof_acc_seq, axis=0).astype(
np.float32
)
robot_dof_torque = np.stack(self._robot_dof_torque_seq, axis=0).astype(
np.float32
)
robot_low_level_dof_torque = None
if len(self._robot_low_level_dof_torque_seq) > 0:
robot_low_level_dof_torque = np.stack(
self._robot_low_level_dof_torque_seq, axis=0
).astype(np.float32)
(
robot_low_level_foot_contact,
robot_low_level_foot_normal_force,
robot_low_level_foot_tangent_speed,
) = self._get_stacked_low_level_foot_contact_tensors()
robot_actions = None
if len(getattr(self, "_robot_actions_seq", [])) > 0:
robot_actions = np.stack(self._robot_actions_seq, axis=0).astype(
np.float32
)
robot_action_rate = np.asarray(
self._robot_action_rate_seq, dtype=np.float32
)
robot_global_translation = np.stack(
self._robot_global_translation_seq, axis=0
).astype(np.float32)
robot_global_rotation_quat = np.stack(
self._robot_global_rotation_quat_seq, axis=0
).astype(np.float32)
robot_global_velocity = np.stack(
self._robot_global_velocity_seq, axis=0
).astype(np.float32)
robot_global_angular_velocity = np.stack(
self._robot_global_angular_velocity_seq, axis=0
).astype(np.float32)
robot_moe_expert_indices, robot_moe_expert_logits = (
self._get_stacked_moe_routing_tensors()
)
# Load original motion npz
with np.load(motion_npz_path, allow_pickle=True) as npz:
data_dict = {k: npz[k] for k in npz.files}
# Augment with robot_* arrays (override if already present)
data_dict["robot_dof_pos"] = robot_dof_pos
data_dict["robot_dof_vel"] = robot_dof_vel
data_dict["robot_dof_acc"] = robot_dof_acc
data_dict["robot_dof_torque"] = robot_dof_torque
if robot_low_level_dof_torque is not None:
data_dict["robot_low_level_dof_torque"] = (
robot_low_level_dof_torque
)
if robot_low_level_foot_contact is not None:
data_dict["robot_low_level_foot_contact"] = (
robot_low_level_foot_contact
)
if robot_low_level_foot_normal_force is not None:
data_dict["robot_low_level_foot_normal_force"] = (
robot_low_level_foot_normal_force
)
if robot_low_level_foot_tangent_speed is not None:
data_dict["robot_low_level_foot_tangent_speed"] = (
robot_low_level_foot_tangent_speed
)
if robot_actions is not None:
data_dict["robot_actions"] = robot_actions
data_dict["robot_low_level_torque_dt"] = np.array(
self.simulation_dt, dtype=np.float32
)
data_dict["robot_low_level_contact_dt"] = np.array(
self.simulation_dt, dtype=np.float32
)
data_dict["robot_action_rate"] = robot_action_rate
data_dict["robot_global_translation"] = robot_global_translation
data_dict["robot_global_rotation_quat"] = robot_global_rotation_quat
data_dict["robot_global_velocity"] = robot_global_velocity
data_dict["robot_global_angular_velocity"] = (
robot_global_angular_velocity
)
if robot_moe_expert_indices is not None:
data_dict["robot_moe_expert_indices"] = robot_moe_expert_indices
if robot_moe_expert_logits is not None:
data_dict["robot_moe_expert_logits"] = robot_moe_expert_logits
# Derive output directory consistent with video writer
onnx_stem = os.path.splitext(
os.path.basename(self.config.ckpt_onnx_path)
)[0]
output_dir = os.path.join(
os.path.dirname(self.config.ckpt_onnx_path),
f"mujoco_output_{onnx_stem}",
)
os.makedirs(output_dir, exist_ok=True)
motion_stem = os.path.splitext(os.path.basename(motion_npz_path))[0]
out_npz_path = os.path.join(output_dir, f"{motion_stem}_robot.npz")
np.savez_compressed(out_npz_path, **data_dict)
logger.info(
f"Robot-augmented motion npz saved to: {out_npz_path} "
f"(T={robot_dof_pos.shape[0]}, num_dofs={robot_dof_pos.shape[1]}, "
f"num_bodies={robot_global_translation.shape[1]})"
)
def _close_video_tools(self):
if self._video_writer is not None:
self._video_writer.release()
self._video_writer = None
if self._offscreen is not None:
self._offscreen.close()
self._offscreen = None
self._frame_interval = None
self._last_frame_time = 0.0
def _update_camera_lookat(self, cam):
"""Update camera lookat to track the robot root when tracking is enabled."""
if not self._camera_tracking_enabled:
return
if self._root_body_id == -1:
return
cam.lookat[:2] = self.d.xpos[self._root_body_id][:2]
cam.lookat[2] = (
self.d.xpos[self._root_body_id][2] + self._camera_height_offset
)
def _maybe_record_frame(self):
if self._video_writer is None or self._offscreen is None:
return
now = time.time()
if (
self._last_frame_time == 0.0
or (now - self._last_frame_time) >= self._frame_interval
):
# Update offscreen camera lookat to track robot (if enabled)
self._update_camera_lookat(self._offscreen._cam)
frame_rgb = self._offscreen.render(self.d)
# Convert RGB (MuJoCo) -> BGR (OpenCV) before writing
frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
self._video_writer.write(frame_bgr)
self._last_frame_time = now
def _apply_control(self, sleep: bool):
    """Advance the simulation by ``control_decimation`` physics steps under a
    joint-space PD law.

    In every sub-step, each actuator's torque is computed as
    ``tau = kp * (target_q - q) + kd * (0 - dq)`` with zero feed-forward.
    ``target_q`` comes from ``self.target_dof_pos_by_name`` and falls back
    to ``self.default_angles_mu``. The torque is clipped to the actuator's
    force range when one is known and written to ``self.d.ctrl``. Then
    ``mujoco.mj_step`` runs once.

    In motion-tracking mode with a loaded reference motion, each sub-step
    also appends the applied torques (reference-motion DOF order) to
    ``self._robot_low_level_dof_torque_seq`` and records a foot-contact
    sample.

    Args:
        sleep: if True, ``time.sleep(self.simulation_dt)`` after every
            sub-step.
    """
    for _ in range(self.control_decimation):
        # Low-level torques are only logged when evaluating against a
        # reference motion.
        record_low_level_torque = (
            self.command_mode == "motion_tracking"
            and self.ref_dof_pos is not None
        )
        if record_low_level_torque:
            torque_ref = np.zeros(
                len(self.dof_names_ref_motion), dtype=np.float32
            )
        # Joint state is re-read every sub-step (presumably reflecting the
        # MuJoCo state left by the previous mj_step — confirm).
        current_dof_pos = self.robot_dof_pos
        current_dof_vel = self.robot_dof_vel
        for name, act_idx in self.actuator_name_to_index.items():
            mu_idx = self.actuator_name_to_mu_idx[name]
            joint_name = self.mjcf_dof_names[mu_idx]
            # Joints without an explicit target hold their default angle.
            target_q = self.target_dof_pos_by_name.get(
                joint_name,
                float(self.default_angles_mu[mu_idx]),
            )
            target_dq = 0.0
            feedforward_tau = 0.0
            kp = self.kps_mu[mu_idx]
            kd = self.kds_mu[mu_idx]
            current_q = current_dof_pos[mu_idx]
            current_dq = current_dof_vel[mu_idx]
            tau = (
                feedforward_tau
                + kp * (target_q - current_q)
                + kd * (target_dq - current_dq)
            )
            if (
                act_idx in self.actuator_force_range
                and self.actuator_force_range[act_idx] is not None
            ):
                min_force, max_force = self.actuator_force_range[act_idx]
                tau = np.clip(tau, min_force, max_force)
            # NOTE(review): ctrl is indexed by mu_idx, not act_idx; this
            # assumes MuJoCo actuator order equals mjcf_dof_names order.
            self.d.ctrl[mu_idx] = tau
            if record_low_level_torque:
                torque_ref[self.mu_to_ref[mu_idx]] = np.float32(tau)
        mujoco.mj_step(self.m, self.d)
        if record_low_level_torque:
            self._robot_low_level_dof_torque_seq.append(torque_ref)
            self._record_low_level_foot_contact_sample()
        if sleep:
            time.sleep(self.simulation_dt)
def _compute_pd_torque_command_ref(self) -> np.ndarray:
current_dof_pos = self.robot_dof_pos
current_dof_vel = self.robot_dof_vel
num_mu_dofs = len(self.mjcf_dof_names)
torque_mu = np.zeros(num_mu_dofs, dtype=np.float32)
for name, act_idx in self.actuator_name_to_index.items():
mu_idx = self.actuator_name_to_mu_idx[name]
joint_name = self.mjcf_dof_names[mu_idx]
target_q = self.target_dof_pos_by_name.get(
joint_name,
float(self.default_angles_mu[mu_idx]),
)
target_dq = 0.0
feedforward_tau = 0.0
kp = self.kps_mu[mu_idx]
kd = self.kds_mu[mu_idx]
current_q = current_dof_pos[mu_idx]
current_dq = current_dof_vel[mu_idx]
tau = (
feedforward_tau
+ kp * (target_q - current_q)
+ kd * (target_dq - current_dq)
)
if (
act_idx in self.actuator_force_range
and self.actuator_force_range[act_idx] is not None
):
min_force, max_force = self.actuator_force_range[act_idx]
tau = np.clip(tau, min_force, max_force)
torque_mu[mu_idx] = np.float32(tau)
num_ref_dofs = len(self.dof_names_ref_motion)
torque_ref = np.zeros(num_ref_dofs, dtype=np.float32)
for mu_idx, ref_idx in enumerate(self.mu_to_ref):
torque_ref[ref_idx] = torque_mu[mu_idx]
return torque_ref
def _get_obs_ref_motion_states(self):
# [2 * num_actions] in ONNX order: [ref_pos, ref_vel]
if self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(2 * self.num_actions, dtype=np.float32)
frame_idx = self.motion_frame_idx
ref_pos_mu = self.ref_dof_pos[frame_idx]
ref_vel_mu = self.ref_dof_vel[frame_idx]
# Map URDF/Mu order -> ONNX order using precomputed indices
ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)
ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32)
return np.concatenate([ref_pos_onnx, ref_vel_onnx], axis=0).astype(
np.float32
)
def _get_obs_ref_motion_states_fut(self):
# [T, 2 * num_actions] flattened, ONNX order
T = int(self.n_fut_frames)
if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(0, dtype=np.float32)
N = int(self.num_actions)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
# Build future arrays in Mu order [N, T]
pos_fut = np.zeros(
(len(self.dof_names_ref_motion), T), dtype=np.float32
)
vel_fut = np.zeros(
(len(self.dof_names_ref_motion), T), dtype=np.float32
)
for i in range(T):
idx = frame_idx + i + 1
if idx < self.n_motion_frames:
pos_fut[:, i] = self.ref_dof_pos[idx]
vel_fut[:, i] = self.ref_dof_vel[idx]
else:
pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx]
vel_fut[:, i] = self.ref_dof_vel[last_valid_frame_idx]
# Reorder to ONNX and flatten per training layout
pos_fut_onnx = pos_fut[self.ref_to_onnx, :] # [N, T]
vel_fut_onnx = vel_fut[self.ref_to_onnx, :] # [N, T]
fut_concat = np.concatenate(
[pos_fut_onnx.T, vel_fut_onnx.T], axis=1
) # [T, 2N]
return fut_concat.reshape(-1).astype(np.float32)
def _get_obs_ref_dof_pos_fut(self):
# [T, 2 * num_actions] flattened, ONNX order
T = int(self.n_fut_frames)
if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(0, dtype=np.float32)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
# Build future arrays in Mu order [N, T]
pos_fut = np.zeros(
(len(self.dof_names_ref_motion), T), dtype=np.float32
)
for i in range(T):
idx = frame_idx + i + 1
if idx < self.n_motion_frames:
pos_fut[:, i] = self.ref_dof_pos[idx]
else:
pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx]
# Reorder to ONNX and flatten per training layout
pos_fut_onnx = pos_fut[self.ref_to_onnx, :].transpose(1, 0) # [N, T]
return pos_fut_onnx.reshape(-1).astype(np.float32)
def _get_obs_ref_root_height_fut(self):
T = int(self.n_fut_frames)
if (
T <= 0
or self.ref_dof_pos is None
or self.ref_dof_vel is None
or getattr(self, "ref_global_translation", None) is None
):
return np.zeros(0, dtype=np.float32)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
# Build future arrays in Mu order [N, T]
h_fut = np.zeros((1, T), dtype=np.float32)
for i in range(T):
idx = frame_idx + i + 1
if idx < self.n_motion_frames:
h_fut[:, i] = self.ref_global_translation[
idx, self.root_body_idx, 2
]
else:
h_fut[:, i] = self.ref_global_translation[
last_valid_frame_idx, self.root_body_idx, 2
]
return h_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_dof_pos_cur(self):
# [2 * num_actions] in ONNX order: [ref_pos, ref_vel]
if self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(2 * self.num_actions, dtype=np.float32)
ref_pos_mu = self.ref_dof_pos[self.motion_frame_idx]
# Map URDF/Mu order -> ONNX order using precomputed indices
ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)
return ref_pos_onnx
def _get_obs_ref_dof_vel_cur(self):
# [2 * num_actions] in ONNX order: [ref_pos, ref_vel]
if self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(2 * self.num_actions, dtype=np.float32)
ref_vel_mu = self.ref_dof_vel[self.motion_frame_idx]
# Map URDF/Mu order -> ONNX order using precomputed indices
ref_vel_onnx = ref_vel_mu[self.ref_to_onnx].astype(np.float32)
return ref_vel_onnx
def _get_obs_ref_motion_filter_cutoff_hz(self):
# cutoff = getattr(self, "filter_cutoff_hz", None)
cutoff = 1.0
if cutoff is None:
return np.float32(0.0)
cutoff_flat = np.asarray(cutoff, dtype=np.float32).reshape(-1)
if cutoff_flat.size == 0:
return np.float32(0.0)
frame_idx = min(
max(int(getattr(self, "motion_frame_idx", 0)), 0),
cutoff_flat.size - 1,
)
return np.float32(cutoff_flat[frame_idx])
def _get_obs_ref_root_height_cur(self):
if getattr(self, "ref_global_translation", None) is None:
return 0.0
return self.ref_global_translation[
self.motion_frame_idx, self.root_body_idx, 2
]
def _get_obs_ref_gravity_projection_cur(self):
if (
getattr(self, "ref_global_rotation_quat_xyzw", None) is None
or self.n_motion_frames <= 0
):
return np.zeros(3, dtype=np.float32)
q_root_xyzw = self.ref_global_rotation_quat_xyzw[
self.motion_frame_idx, self.root_body_idx
].astype(np.float32)
q_root_wxyz = xyzw_to_wxyz(
torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu")
)
q_root_wxyz = standardize_quaternion(q_root_wxyz)
g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device="cpu")
g_root = quat_apply(quat_inv(q_root_wxyz), g_w)
return g_root.detach().cpu().numpy().astype(np.float32)
def _get_obs_ref_gravity_projection_fut(self):
T = int(self.n_fut_frames)
if (
T <= 0
or getattr(self, "ref_global_rotation_quat_xyzw", None) is None
or self.n_motion_frames <= 0
):
return np.zeros(0, dtype=np.float32)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
g_w = torch.tensor([0.0, 0.0, -1.0], dtype=torch.float32, device="cpu")
gravity_fut = np.zeros((T, 3), dtype=np.float32)
for i in range(T):
idx = frame_idx + i + 1
if idx >= self.n_motion_frames:
idx = last_valid_frame_idx
q_root_xyzw = self.ref_global_rotation_quat_xyzw[
idx, self.root_body_idx
].astype(np.float32)
q_root_wxyz = xyzw_to_wxyz(
torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu")
)
q_root_wxyz = standardize_quaternion(q_root_wxyz)
gravity_fut[i] = (
quat_apply(quat_inv(q_root_wxyz), g_w)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)
return gravity_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_base_linvel_cur(self):
if (
getattr(self, "ref_global_rotation_quat_xyzw", None) is None
or getattr(self, "ref_global_velocity", None) is None
or self.n_motion_frames <= 0
):
return np.zeros(3, dtype=np.float32)
q_root_xyzw = self.ref_global_rotation_quat_xyzw[
self.motion_frame_idx, self.root_body_idx
].astype(np.float32)
q_root_wxyz = xyzw_to_wxyz(
torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu")
)
q_root_wxyz = standardize_quaternion(q_root_wxyz)
v_root_w = torch.as_tensor(
self.ref_global_velocity[
self.motion_frame_idx, self.root_body_idx
],
dtype=torch.float32,
device="cpu",
)
v_root = quat_apply(quat_inv(q_root_wxyz), v_root_w)
return v_root.detach().cpu().numpy().astype(np.float32)
def _get_obs_ref_base_linvel_fut(self):
T = int(self.n_fut_frames)
if (
T <= 0
or getattr(self, "ref_global_rotation_quat_xyzw", None) is None
or getattr(self, "ref_global_velocity", None) is None
or self.n_motion_frames <= 0
):
return np.zeros(0, dtype=np.float32)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
base_linvel_fut = np.zeros((T, 3), dtype=np.float32)
for i in range(T):
idx = frame_idx + i + 1
if idx >= self.n_motion_frames:
idx = last_valid_frame_idx
q_root_xyzw = self.ref_global_rotation_quat_xyzw[
idx, self.root_body_idx
].astype(np.float32)
q_root_wxyz = xyzw_to_wxyz(
torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu")
)
q_root_wxyz = standardize_quaternion(q_root_wxyz)
v_root_w = torch.as_tensor(
self.ref_global_velocity[idx, self.root_body_idx],
dtype=torch.float32,
device="cpu",
)
base_linvel_fut[i] = (
quat_apply(quat_inv(q_root_wxyz), v_root_w)
.detach()
.cpu()
.numpy()
.astype(np.float32)
)
return base_linvel_fut.reshape(-1).astype(np.float32)
def _get_obs_ref_base_angvel_cur(self):
    """Reference root angular velocity at the current motion frame, in the root frame.

    The world-frame root angular velocity of frame ``motion_frame_idx`` is
    rotated by the inverse of that frame's (standardized) root orientation.

    Returns:
        float32 array of shape [3]; zeros when the reference rotations or
        angular velocities are not loaded or the motion has no frames.
    """
    rot_xyzw = getattr(self, "ref_global_rotation_quat_xyzw", None)
    angvel_world = getattr(self, "ref_global_angular_velocity", None)
    if rot_xyzw is None or angvel_world is None or self.n_motion_frames <= 0:
        return np.zeros(3, dtype=np.float32)

    frame, root = self.motion_frame_idx, self.root_body_idx
    q_xyzw_t = torch.as_tensor(
        rot_xyzw[frame, root].astype(np.float32),
        dtype=torch.float32,
        device="cpu",
    )
    q_wxyz_t = standardize_quaternion(xyzw_to_wxyz(q_xyzw_t))
    w_world_t = torch.as_tensor(
        angvel_world[frame, root], dtype=torch.float32, device="cpu"
    )
    w_body_t = quat_apply(quat_inv(q_wxyz_t), w_world_t)
    return w_body_t.detach().cpu().numpy().astype(np.float32)
def _get_obs_ref_base_angvel_fut(self):
    """Reference root angular velocities of the next ``n_fut_frames`` frames,
    each expressed in that frame's own root frame.

    Future frames are ``motion_frame_idx + 1 .. motion_frame_idx + T``. Indices
    past the end of the clip are clamped to the last frame. All T frames are
    rotated in one batched call instead of a per-frame Python loop.

    Returns:
        float32 array of shape [T * 3] (frame-major). It is empty when
        T <= 0, when the reference rotations or angular velocities are not
        loaded, or when the motion has no frames.
    """
    T = int(self.n_fut_frames)
    if (
        T <= 0
        or getattr(self, "ref_global_rotation_quat_xyzw", None) is None
        or getattr(self, "ref_global_angular_velocity", None) is None
        or self.n_motion_frames <= 0
    ):
        return np.zeros(0, dtype=np.float32)
    last_valid_frame_idx = self.n_motion_frames - 1
    # [T] future frame indices, clamped to the final frame of the clip.
    fut_idxs = np.minimum(
        self.motion_frame_idx + np.arange(1, T + 1), last_valid_frame_idx
    )
    q_root_xyzw = self.ref_global_rotation_quat_xyzw[
        fut_idxs, self.root_body_idx
    ].astype(np.float32)  # [T, 4]
    q_root_wxyz = standardize_quaternion(
        xyzw_to_wxyz(
            torch.as_tensor(q_root_xyzw, dtype=torch.float32, device="cpu")
        )
    )
    w_root_w = torch.as_tensor(
        self.ref_global_angular_velocity[fut_idxs, self.root_body_idx],
        dtype=torch.float32,
        device="cpu",
    )  # [T, 3], world frame
    w_root = quat_apply(quat_inv(q_root_wxyz), w_root_w)  # [T, 3], root frame
    return w_root.detach().cpu().numpy().reshape(-1).astype(np.float32)
def _get_obs_ref_keybody_rel_pos_cur(self):
    """Current-frame key-body positions relative to the root, in the root frame.

    Key bodies are selected by ``_get_ref_keybody_indices`` for the
    ``"actor_ref_keybody_rel_pos_cur"`` term. Each key body's world offset from
    the root is rotated by the inverse root orientation.

    Returns:
        float32 array of shape [K * 3]. It is empty when no key bodies are
        configured. It is all zeros when the reference translations or
        rotations are not loaded or the motion has no frames.
    """
    keybody_idxs = self._get_ref_keybody_indices(
        "actor_ref_keybody_rel_pos_cur"
    )
    n_keybodies = int(keybody_idxs.shape[0])
    if not n_keybodies:
        return np.zeros(0, dtype=np.float32)

    translation = getattr(self, "ref_global_translation", None)
    rotation_xyzw = getattr(self, "ref_global_rotation_quat_xyzw", None)
    if translation is None or rotation_xyzw is None or self.n_motion_frames <= 0:
        return np.zeros(n_keybodies * 3, dtype=np.float32)

    frame, root = self.motion_frame_idx, self.root_body_idx
    body_pos_w = translation[frame].astype(np.float32)  # [B, 3], world
    # Key-body offsets from the root, still expressed along world axes.
    offsets_w = body_pos_w[keybody_idxs] - body_pos_w[root][None, :]  # [K, 3]

    q_root_wxyz = standardize_quaternion(
        xyzw_to_wxyz(
            torch.as_tensor(
                rotation_xyzw[frame, root].astype(np.float32),
                dtype=torch.float32,
                device="cpu",
            )
        )
    )
    q_root_inv = quat_inv(q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4))
    offsets_root = quat_apply(
        q_root_inv,
        torch.as_tensor(offsets_w, dtype=torch.float32, device="cpu"),
    )
    flat = offsets_root.detach().cpu().numpy().reshape(-1)
    return flat.astype(np.float32)
def _get_obs_ref_keybody_rel_pos_fut(self):
    """Future key-body positions relative to the root, each in its frame's root frame.

    Future frames are ``motion_frame_idx + 1 .. motion_frame_idx + T``. Indices
    past the end of the clip are clamped to the last frame.

    Returns:
        float32 array. Its shape depends on the case:
        - [T, K * 3] normally.
        - A 1-D empty array when T <= 0.
        - A [T, 0] array when no key bodies are configured.
        - All zeros with shape [T, K * 3] when the reference translations or
          rotations are not loaded or the motion has no frames.
    """
    T = int(self.n_fut_frames)
    if T <= 0:
        return np.zeros(0, dtype=np.float32)
    keybody_idxs = self._get_ref_keybody_indices(
        "actor_ref_keybody_rel_pos_fut"
    )
    n_keybodies = int(keybody_idxs.shape[0])
    if not n_keybodies:
        return np.zeros((T, 0), dtype=np.float32)

    translation = getattr(self, "ref_global_translation", None)
    rotation_xyzw = getattr(self, "ref_global_rotation_quat_xyzw", None)
    if translation is None or rotation_xyzw is None or self.n_motion_frames <= 0:
        return np.zeros((T, n_keybodies * 3), dtype=np.float32)

    root = self.root_body_idx
    last_frame = self.n_motion_frames - 1
    rel_pos_fut = np.zeros((T, n_keybodies, 3), dtype=np.float32)
    for step in range(T):
        frame = min(self.motion_frame_idx + step + 1, last_frame)
        body_pos_w = translation[frame].astype(np.float32)  # [B, 3], world
        offsets_w = body_pos_w[keybody_idxs] - body_pos_w[root][None, :]
        q_root_wxyz = standardize_quaternion(
            xyzw_to_wxyz(
                torch.as_tensor(
                    rotation_xyzw[frame, root].astype(np.float32),
                    dtype=torch.float32,
                    device="cpu",
                )
            )
        )
        q_root_inv = quat_inv(
            q_root_wxyz.unsqueeze(0).expand(n_keybodies, 4)
        )
        offsets_root = quat_apply(
            q_root_inv,
            torch.as_tensor(offsets_w, dtype=torch.float32, device="cpu"),
        )
        rel_pos_fut[step] = (
            offsets_root.detach().cpu().numpy().astype(np.float32)
        )
    return rel_pos_fut.reshape(T, -1).astype(np.float32)
def _get_obs_place_holder(self):
return np.zeros(self.actor_place_holder_ndim, dtype=np.float32)
def _get_obs_vr_ref_motion_states(self):
# [2 * num_actions] in ONNX order: [ref_pos, ref_vel]
if self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(2 * self.num_actions, dtype=np.float32)
frame_idx = self.motion_frame_idx
ref_pos_mu = self.ref_dof_pos[frame_idx]
# Map URDF/Mu order -> ONNX order using precomputed indices
ref_pos_onnx = ref_pos_mu[self.ref_to_onnx].astype(np.float32)
return np.concatenate(
[ref_pos_onnx, np.zeros_like(ref_pos_onnx)],
axis=0,
).astype(np.float32)
def _get_obs_vr_ref_motion_states_fut(self):
# [T, 2 * num_actions] flattened, ONNX order
T = int(self.n_fut_frames)
if T <= 0 or self.ref_dof_pos is None or self.ref_dof_vel is None:
return np.zeros(0, dtype=np.float32)
N = int(self.num_actions)
frame_idx = self.motion_frame_idx
last_valid_frame_idx = self.n_motion_frames - 1
# Build future arrays in Mu order [N, T]
pos_fut = np.zeros(
(len(self.dof_names_ref_motion), T), dtype=np.float32
)
for i in range(T):
idx = frame_idx + i + 1
if idx < self.n_motion_frames:
pos_fut[:, i] = self.ref_dof_pos[idx]
else:
pos_fut[:, i] = self.ref_dof_pos[last_valid_frame_idx]
# Reorder to ONNX and flatten per training layout
pos_fut_onnx = pos_fut[self.ref_to_onnx, :] # [N, T]
fut_concat = np.concatenate(
[pos_fut_onnx.T, np.zeros_like(pos_fut_onnx.T)], axis=1
) # [T, 2N]
return fut_concat.reshape(-1).astype(np.float32)
def _get_obs_rel_robot_root_ang_vel(self):
    """Robot root angular velocity rotated into the root body frame.

    Reads the WXYZ root orientation and the world-frame root angular velocity
    from the robot body-link buffers at ``root_body_idx``.

    Returns:
        float32 array of shape [3].
    """
    root = self.root_body_idx
    q_root_inv = quat_inv(
        torch.as_tensor(
            self.robot_global_bodylink_rot[root],
            dtype=torch.float32,
            device="cpu",
        )
    )
    w_world = torch.as_tensor(
        self.robot_global_bodylink_ang_vel[root],
        dtype=torch.float32,
        device="cpu",
    )
    w_body = quat_apply(q_root_inv, w_world)
    return w_body.detach().cpu().numpy().astype(np.float32)
def _get_obs_last_action(self):
return np.array(self.actions_onnx, dtype=np.float32).reshape(-1)
def _get_obs_velocity_command(self):
# Extended velocity command: [move_mask, vx, vy, vyaw]
if (
self.command_mode == "velocity_tracking"
and getattr(self, "keyboard_handler", None) is not None
):
cmd = np.asarray(
self.keyboard_handler.get_velocity_command(), dtype=np.float32
).reshape(3)
else:
cmd = np.zeros(3, dtype=np.float32)
out = np.zeros(4, dtype=np.float32)
out[1:] = cmd
out[0] = float(np.linalg.norm(cmd) > 0.1)
return out
def _get_obs_actor_ref_headling_aligned_vel_cmd(self):
    """Alias for ``_get_obs_velocity_command``."""
    # NOTE(review): "headling" looks like a typo of "heading"; the name is left
    # as-is because callers may look this method up by its exact name.
    return self._get_obs_velocity_command()
# ----------------- Actor term aliases (PULSE stage2 unified obs) -----------------
# Each method below delegates unchanged to the corresponding non-"actor_"
# getter, so the "actor_"-prefixed names resolve to the same implementation.
def _get_obs_actor_velocity_command(self):
    """Alias for ``_get_obs_velocity_command``."""
    return self._get_obs_velocity_command()
def _get_obs_actor_projected_gravity(self):
    """Alias for ``_get_obs_projected_gravity``."""
    return self._get_obs_projected_gravity()
def _get_obs_actor_rel_robot_root_ang_vel(self):
    """Alias for ``_get_obs_rel_robot_root_ang_vel``."""
    return self._get_obs_rel_robot_root_ang_vel()
def _get_obs_actor_dof_pos(self):
    """Alias for ``_get_obs_dof_pos``."""
    return self._get_obs_dof_pos()
def _get_obs_actor_dof_vel(self):
    """Alias for ``_get_obs_dof_vel``."""
    return self._get_obs_dof_vel()
def _get_obs_actor_last_action(self):
    """Alias for ``_get_obs_last_action``."""
    return self._get_obs_last_action()
def _get_obs_actor_place_holder(self):
    """Alias for ``_get_obs_place_holder``."""
    return self._get_obs_place_holder()
def _get_obs_actor_ref_dof_pos_fut(self):
    """Alias for ``_get_obs_ref_dof_pos_fut``."""
    return self._get_obs_ref_dof_pos_fut()
def _get_obs_actor_ref_dof_pos_cur(self):
    """Alias for ``_get_obs_ref_dof_pos_cur``."""
    return self._get_obs_ref_dof_pos_cur()
def _get_obs_actor_ref_motion_filter_cutoff_hz(self):
    """Alias for ``_get_obs_ref_motion_filter_cutoff_hz``."""
    return self._get_obs_ref_motion_filter_cutoff_hz()
def _get_obs_actor_ref_root_height_fut(self):
    """Alias for ``_get_obs_ref_root_height_fut``."""
    return self._get_obs_ref_root_height_fut()
def _get_obs_actor_ref_root_height_cur(self):
    """Alias for ``_get_obs_ref_root_height_cur``."""
    return self._get_obs_ref_root_height_cur()
def _get_obs_actor_ref_gravity_projection_cur(self):
    """Alias for ``_get_obs_ref_gravity_projection_cur``."""
    return self._get_obs_ref_gravity_projection_cur()
def _get_obs_actor_ref_gravity_projection_fut(self):
    """Alias for ``_get_obs_ref_gravity_projection_fut``."""
    return self._get_obs_ref_gravity_projection_fut()
def _get_obs_actor_ref_base_linvel_cur(self):
    """Alias for ``_get_obs_ref_base_linvel_cur``."""
    return self._get_obs_ref_base_linvel_cur()
def _get_obs_actor_ref_base_linvel_fut(self):
    """Alias for ``_get_obs_ref_base_linvel_fut``."""
    return self._get_obs_ref_base_linvel_fut()
def _get_obs_actor_ref_base_angvel_cur(self):
    """Alias for ``_get_obs_ref_base_angvel_cur``."""
    return self._get_obs_ref_base_angvel_cur()
def _get_obs_actor_ref_base_angvel_fut(self):
    """Alias for ``_get_obs_ref_base_angvel_fut``."""
    return self._get_obs_ref_base_angvel_fut()
def _get_obs_actor_ref_keybody_rel_pos_cur(self):
    """Alias for ``_get_obs_ref_keybody_rel_pos_cur``."""
    return self._get_obs_ref_keybody_rel_pos_cur()
def _get_obs_actor_ref_keybody_rel_pos_fut(self):
    """Alias for ``_get_obs_ref_keybody_rel_pos_fut``."""
    return self._get_obs_ref_keybody_rel_pos_fut()
def _get_obs_global_anchor_diff(self):
    """Pose of the reference anchor relative to the robot anchor.

    Both anchor poses are taken in the simulator world frame (the reference
    comes through the Ref->Sim transform). They are combined with
    ``subtract_frame_transforms`` using IsaacLab semantics: pose of the
    reference (frame 2) w.r.t. the robot (frame 1).

    Returns:
        float32 array of shape [9]: the relative position (3) followed by the
        first two columns of the relative rotation matrix, flattened (6).
        All zeros when the reference globals are unavailable.
    """
    self._ensure_ref_to_sim_transform_rigid()
    ref_pos_sim = self.ref_global_bodylink_pos
    ref_rot_sim = self.ref_global_bodylink_rot
    if ref_pos_sim is None or ref_rot_sim is None:
        return np.zeros(9, dtype=np.float32)

    anchor = self.anchor_body_idx

    def to_cpu_f32(value):
        return torch.as_tensor(value, dtype=torch.float32, device="cpu")

    p_diff_t, q_diff_wxyz_t = subtract_frame_transforms(
        t01=to_cpu_f32(self.robot_global_bodylink_pos[anchor]),
        q01=to_cpu_f32(self.robot_global_bodylink_rot[anchor]),
        t02=to_cpu_f32(ref_pos_sim[anchor]),
        q02=to_cpu_f32(ref_rot_sim[anchor]),
    )
    rot_diff_mat = matrix_from_quat(quat_normalize_wxyz(q_diff_wxyz_t))
    pieces = [p_diff_t.reshape(-1), rot_diff_mat[..., :2].reshape(-1)]
    return torch.cat(pieces, dim=-1).detach().cpu().numpy().astype(np.float32)
def _get_obs_global_anchor_pos_diff(self):
    """Position of the reference anchor expressed in the robot anchor frame.

    This is the translation part of ``subtract_frame_transforms`` applied to
    the robot anchor pose (frame 1) and the reference anchor pose in the
    simulator world frame (frame 2).

    Returns:
        float32 array of shape [3]; zeros when the reference globals are
        unavailable.
    """
    self._ensure_ref_to_sim_transform_rigid()
    ref_pos_sim = self.ref_global_bodylink_pos
    ref_rot_sim = self.ref_global_bodylink_rot
    if ref_pos_sim is None or ref_rot_sim is None:
        return np.zeros(3, dtype=np.float32)

    anchor = self.anchor_body_idx

    def to_cpu_f32(value):
        return torch.as_tensor(value, dtype=torch.float32, device="cpu")

    robot_pos_t = to_cpu_f32(self.robot_global_bodylink_pos[anchor])  # [3], world
    robot_rot_t = to_cpu_f32(self.robot_global_bodylink_rot[anchor])  # [4], wxyz
    ref_pos_t = to_cpu_f32(ref_pos_sim[anchor])
    ref_rot_t = to_cpu_f32(ref_rot_sim[anchor])
    pos_diff_anchor_t = subtract_frame_transforms(
        t01=robot_pos_t,
        q01=robot_rot_t,
        t02=ref_pos_t,
        q02=ref_rot_t,
    )[0]
    return pos_diff_anchor_t.detach().cpu().numpy().astype(np.float32)
def _get_obs_global_anchor_rot_diff(self):
    """Orientation of the reference anchor relative to the robot anchor.

    Both input quaternions and the resulting relative quaternion are passed
    through ``standardize_quaternion``. The relative rotation is then encoded
    by the first two columns of its rotation matrix.

    Returns:
        float32 array of shape [6]; zeros when the reference globals are
        unavailable.
    """
    self._ensure_ref_to_sim_transform_rigid()
    ref_pos_sim = self.ref_global_bodylink_pos
    ref_rot_sim = self.ref_global_bodylink_rot
    if ref_pos_sim is None or ref_rot_sim is None:
        return np.zeros(6, dtype=np.float32)

    anchor = self.anchor_body_idx

    def to_cpu_f32(value):
        return torch.as_tensor(value, dtype=torch.float32, device="cpu")

    robot_pos_t = to_cpu_f32(self.robot_global_bodylink_pos[anchor])
    robot_rot_t = standardize_quaternion(
        to_cpu_f32(self.robot_global_bodylink_rot[anchor])
    )
    ref_pos_t = to_cpu_f32(ref_pos_sim[anchor])
    ref_rot_t = standardize_quaternion(to_cpu_f32(ref_rot_sim[anchor]))
    q_diff_wxyz_t = subtract_frame_transforms(
        t01=robot_pos_t,
        q01=robot_rot_t,
        t02=ref_pos_t,
        q02=ref_rot_t,
    )[1]
    rot_diff_mat = matrix_from_quat(standardize_quaternion(q_diff_wxyz_t))
    first_two_cols = rot_diff_mat[..., :2].reshape(-1)
    return first_two_cols.detach().cpu().numpy().astype(np.float32)
def _get_obs_global_bodylink_translation(self) -> np.ndarray:
"""Global body translations in simulator/URDF order, flattened as [num_bodies * 3].
The body dimension excludes the MuJoCo world body and is assumed to match
the NPZ `*_global_translation` arrays (root at index 0).
"""
pos = self.robot_global_bodylink_pos.astype(np.float32) # [B, 3]
return pos.reshape(-1)
def _get_obs_global_bodylink_rotation_quat(self) -> np.ndarray:
"""Global body rotations as XYZW quaternions in simulator/URDF order, flattened [num_bodies * 4]."""
q_wxyz = self.robot_global_bodylink_rot # [B, 4] in w, x, y, z
q_xyzw = np.empty_like(q_wxyz, dtype=np.float32)
q_xyzw[..., 0] = q_wxyz[..., 1]
q_xyzw[..., 1] = q_wxyz[..., 2]
q_xyzw[..., 2] = q_wxyz[..., 3]
q_xyzw[..., 3] = q_wxyz[..., 0]
return q_xyzw.reshape(-1)
def _get_obs_global_bodylink_velocity(self) -> np.ndarray:
"""Global body linear velocities in world frame, flattened [num_bodies * 3]."""
lin_vel = self.robot_global_bodylink_lin_vel.astype(
np.float32
) # [B, 3]
return lin_vel.reshape(-1)
def _get_obs_global_bodylink_angular_velocity(self) -> np.ndarray:
"""Global body angular velocities in world frame, flattened [num_bodies * 3]."""
ang_vel = self.robot_global_bodylink_ang_vel.astype(
np.float32
) # [B, 3]
return ang_vel.reshape(-1)
@property
def ref_global_bodylink_pos(self) -> np.ndarray | None:
    """Reference body positions at ``ref_motion_frame_idx`` in the simulator world frame.

    Each position is rotated by ``_ref_to_sim_q_wxyz`` and shifted by
    ``_ref_to_sim_t``. This is the Ref->Sim rigid transform prepared by
    ``_ensure_ref_to_sim_transform_rigid``, which aligns the reference motion
    with the robot's world frame.

    Returns:
        float32 array of shape [num_bodies, 3], or None when reference
        translations are not loaded or the motion has no frames.
    """
    if (
        getattr(self, "ref_global_translation", None) is None
        or self.n_motion_frames <= 0
    ):
        return None
    self._ensure_ref_to_sim_transform_rigid()
    ref_pos_t = torch.as_tensor(
        self.ref_global_translation[self.ref_motion_frame_idx].astype(np.float32),
        dtype=torch.float32,
        device="cpu",
    )  # [B, 3]
    n_bodies = ref_pos_t.shape[0]
    rot_ref_to_sim = torch.as_tensor(
        self._ref_to_sim_q_wxyz, dtype=torch.float32, device="cpu"
    ).unsqueeze(0).expand(n_bodies, 4)
    shift_ref_to_sim = torch.as_tensor(
        self._ref_to_sim_t, dtype=torch.float32, device="cpu"
    )
    sim_pos_t = quat_apply(rot_ref_to_sim, ref_pos_t) + shift_ref_to_sim[None, :]
    return sim_pos_t.detach().cpu().numpy().astype(np.float32)
@property
def ref_global_bodylink_rot(self) -> np.ndarray | None:
    """Reference body orientations at ``ref_motion_frame_idx`` in the simulator world frame.

    Each standardized WXYZ reference quaternion is left-multiplied by the
    Ref->Sim rotation ``_ref_to_sim_q_wxyz``, then standardized again.

    Like ``ref_global_bodylink_pos``, this first calls
    ``_ensure_ref_to_sim_transform_rigid``. Without that call, a standalone
    access could read a missing or stale ``_ref_to_sim_q_wxyz``.

    Returns:
        float32 array of shape [num_bodies, 4] in WXYZ order, or None when
        reference rotations are not loaded or the motion has no frames.
    """
    if getattr(self, "ref_global_rotation_quat_xyzw", None) is None:
        return None
    if self.n_motion_frames <= 0:
        return None
    # Same guard as ref_global_bodylink_pos so both properties always use the
    # same, initialized Ref->Sim transform.
    self._ensure_ref_to_sim_transform_rigid()
    frame_idx = self.ref_motion_frame_idx
    ref_rot_xyzw = self.ref_global_rotation_quat_xyzw[frame_idx].astype(
        np.float32
    )  # [B, 4] in XYZW
    q_ref_xyzw_t = torch.as_tensor(
        ref_rot_xyzw, dtype=torch.float32, device="cpu"
    )
    q_ref_wxyz_t = standardize_quaternion(xyzw_to_wxyz(q_ref_xyzw_t))
    q_ref_to_sim = torch.as_tensor(
        self._ref_to_sim_q_wxyz, dtype=torch.float32, device="cpu"
    )
    q_ref_to_sim = q_ref_to_sim.unsqueeze(0).expand_as(q_ref_wxyz_t)
    q_ref_sim_wxyz_t = standardize_quaternion(
        quat_mul(q_ref_to_sim, q_ref_wxyz_t)
    )
    return q_ref_sim_wxyz_t.detach().cpu().numpy().astype(np.float32)
def _draw_ref_body_spheres_to_scene(
    self, scene, reset_ngeom: bool
) -> None:
    """Draw a red marker sphere at every reference body position into a MuJoCo scene.

    Spheres use RGBA (0.8, 0, 0, 1) and radius ``config["ref_marker_radius"]``
    (default 0.03). They are appended after the scene's existing geoms, and
    drawing stops once ``scene.maxgeom`` is reached.

    Args:
        scene: MuJoCo ``mjvScene`` to draw into (e.g. ``viewer.user_scn``).
        reset_ngeom: If True, clear the scene's geoms first. The clear also
            happens when no reference positions are available.
    """
    ref_positions_sim = self.ref_global_bodylink_pos
    if reset_ngeom:
        scene.ngeom = 0
    if ref_positions_sim is None:
        return
    radius = float(self.config.get("ref_marker_radius", 0.03))
    rgba = np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32)  # opaque red
    size = np.array([radius, 0.0, 0.0], dtype=np.float32)
    mat = np.eye(3, dtype=np.float32).reshape(-1)
    start = int(scene.ngeom)
    n_drawn = 0
    for pos in ref_positions_sim:
        geom_id = start + n_drawn
        if geom_id >= scene.maxgeom:
            break
        mujoco.mjv_initGeom(
            scene.geoms[geom_id],
            mujoco.mjtGeom.mjGEOM_SPHERE,
            size,
            pos.astype(np.float32),
            mat,
            rgba,
        )
        n_drawn += 1
    scene.ngeom = start + n_drawn
def _get_obs_rel_anchor_lin_vel(self):
    """Robot anchor linear velocity expressed in the anchor frame (IsaacLab semantics).

    Returns:
        float32 array of shape [3].
    """
    anchor = self.anchor_body_idx
    q_anchor_wxyz = torch.as_tensor(
        self.robot_global_bodylink_rot[anchor],
        dtype=torch.float32,
        device="cpu",
    )
    v_world = torch.as_tensor(
        self.robot_global_bodylink_lin_vel[anchor],
        dtype=torch.float32,
        device="cpu",
    )
    v_anchor = quat_apply(quat_inv(q_anchor_wxyz), v_world)
    return v_anchor.detach().cpu().numpy().astype(np.float32)
def _get_obs_projected_gravity(self):
q = torch.as_tensor(
self.robot_global_bodylink_rot[self.root_body_idx],
dtype=torch.float32,
)
qw, qx, qy, qz = q[0], q[1], q[2], q[3]
gravity_orientation = torch.zeros(3, dtype=torch.float32, device="cpu")
gravity_orientation[0] = 2.0 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2.0 * (qz * qy + qw * qx)
gravity_orientation[2] = 1.0 - 2.0 * (qw * qw + qz * qz)
return gravity_orientation.detach().cpu().numpy().astype(np.float32)
def _get_obs_dof_pos(self):
pos_mu = self.robot_dof_pos
pos_onnx = pos_mu[self.mu_to_onnx]
return (pos_onnx - self.default_angles_onnx.astype(np.float32)).astype(
np.float32
)
def _get_obs_dof_vel(self):
vel_mu = self.robot_dof_vel
vel_onnx = vel_mu[self.mu_to_onnx]
return vel_onnx.astype(np.float32)
def _record_robot_states(self) -> None:
"""Record current robot DOF and global body states for offline NPZ dumping.
- DOF states are stored in reference DOF order (config.robot.dof_names).
- Body states are stored in dataset/URDF order (config.robot.body_names).
"""
if self.command_mode != "motion_tracking":
return
if self.ref_dof_pos is None or self.n_motion_frames <= 0:
return
if len(self._robot_dof_pos_seq) >= self.n_motion_frames:
return
# Joint positions/velocities from Unitree lowstate in actuator (MuJoCo) order
pos_mu = self.robot_dof_pos
vel_mu = self.robot_dof_vel
# Map MuJoCo actuator order -> reference DOF order
num_dofs = len(self.dof_names_ref_motion)
pos_ref = np.zeros(num_dofs, dtype=np.float32)
vel_ref = np.zeros(num_dofs, dtype=np.float32)
for mu_idx, ref_idx in enumerate(self.mu_to_ref):
pos_ref[ref_idx] = pos_mu[mu_idx]
vel_ref[ref_idx] = vel_mu[mu_idx]
self._robot_dof_pos_seq.append(pos_ref)
self._robot_dof_vel_seq.append(vel_ref)
if self._prev_recorded_dof_vel_ref is None:
acc_ref = np.zeros_like(vel_ref, dtype=np.float32)
else:
acc_ref = (vel_ref - self._prev_recorded_dof_vel_ref) / np.float32(
self.policy_dt
)
self._prev_recorded_dof_vel_ref = vel_ref.copy()
self._robot_dof_acc_seq.append(acc_ref.astype(np.float32))
# Global bodylink states in dataset/URDF order
body_count = int(self.robot_global_bodylink_pos.shape[0])
trans = self._get_obs_global_bodylink_translation().reshape(
body_count, 3
)
rot = self._get_obs_global_bodylink_rotation_quat().reshape(
body_count, 4
)
vel = self._get_obs_global_bodylink_velocity().reshape(body_count, 3)
ang_vel = self._get_obs_global_bodylink_angular_velocity().reshape(
body_count, 3
)
self._robot_global_translation_seq.append(trans)
self._robot_global_rotation_quat_seq.append(rot)
self._robot_global_velocity_seq.append(vel)
self._robot_global_angular_velocity_seq.append(ang_vel)
def load_specific_motion(self, npz_path):
    """Load one reference motion clip from an NPZ file into this evaluator.

    Reads these keys:
    - ``ref_global_translation``
    - ``ref_global_rotation_quat`` (stored as XYZW)
    - ``ref_global_velocity``
    - ``ref_global_angular_velocity``
    - ``ref_dof_pos``
    - ``ref_dof_vel``

    It then sets ``n_motion_frames`` from the translation array and installs
    an identity Ref->Sim transform, marked ready.

    Args:
        npz_path: Path to the motion NPZ file.
    """
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which
    # can execute code; only load NPZ files from trusted sources.
    with np.load(npz_path, allow_pickle=True) as npz:
        self.ref_global_translation = npz["ref_global_translation"]
        self.ref_global_rotation_quat_xyzw = npz["ref_global_rotation_quat"]
        self.ref_global_velocity = npz["ref_global_velocity"]
        self.ref_global_angular_velocity = npz["ref_global_angular_velocity"]
        self.ref_dof_pos = npz["ref_dof_pos"]
        self.ref_dof_vel = npz["ref_dof_vel"]
        # Read but currently unused: normalization is disabled below.
        raw_filter_cutoff_hz = None
        if "filter_cutoff_hz" in npz:
            raw_filter_cutoff_hz = np.array(npz["filter_cutoff_hz"]).astype(
                np.float32
            )
    self.n_motion_frames = self.ref_global_translation.shape[0]
    # NOTE(review): the per-clip cutoff is ignored and a constant 1.0 is used;
    # the normalization call below was disabled. Confirm this is intentional.
    # self.filter_cutoff_hz = self._normalize_filter_cutoff_hz(
    #     raw_filter_cutoff_hz, self.n_motion_frames
    # )
    self.filter_cutoff_hz = 1.0
    # Identity Ref->Sim transform: the clip is used in its own world frame.
    self._ref_to_sim_q_wxyz = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
    self._ref_to_sim_t = np.zeros(3, dtype=np.float32)
    self._ref_to_sim_ready = True
    # NOTE(review): stored_full_ref_dof_pos / stored_full_ref_dof_vel (read by
    # reset_state_teleport) are not refreshed here; verify they cannot hold a
    # previous clip's data when this loader is used.
def reset_state_teleport(self):
    """Reset the MuJoCo state and every per-episode buffer.

    Counters are zeroed and MuJoCo data is reset to keyframe 0. Then:
    - If a complete reference motion is loaded, the floating base and the
      actuated joints are teleported to reference frame 0 (pose and
      velocity), and the joint targets are set to that frame's joint
      positions.
    - Otherwise, the joints are placed at the ONNX default angles with zero
      velocity.

    Afterwards the KV cache (if used), the action filters and all recording
    buffers are reset.
    """
    self.counter = 0
    self.motion_frame_idx = 0
    mujoco.mj_resetDataKeyframe(self.m, self.d, 0)
    has_ref_motion = (
        self.ref_dof_pos is not None
        and self.ref_dof_vel is not None
        and self.ref_global_translation is not None
        and self.ref_global_rotation_quat_xyzw is not None
        and self.ref_global_velocity is not None
        and self.ref_global_angular_velocity is not None
    )
    if has_ref_motion:
        # Use the configured root body, consistent with every other reference
        # lookup in this class (previously hard-coded to body 0).
        root = self.root_body_idx
        root_pos = self.ref_global_translation[0, root]  # (x, y, z), world
        root_rot = self.ref_global_rotation_quat_xyzw[0, root]  # XYZW
        root_vel = self.ref_global_velocity[0, root]  # world frame
        root_ang = self.ref_global_angular_velocity[0, root]  # world frame
        dof_pos = getattr(
            self, "stored_full_ref_dof_pos", self.ref_dof_pos
        )[0]
        dof_vel = getattr(
            self, "stored_full_ref_dof_vel", self.ref_dof_vel
        )[0]
        # MuJoCo stores the free-joint orientation as WXYZ.
        root_quat_wxyz = np.array(
            [root_rot[3], root_rot[0], root_rot[1], root_rot[2]],
            dtype=np.float64,
        )
        self.d.qpos[0:3] = root_pos
        self.d.qpos[3:7] = root_quat_wxyz
        self.d.qpos[self.actuator_qpos_indices] = dof_pos[self.mu_to_ref]
        # Free-joint qvel layout: [0:3] linear velocity in the world frame,
        # [3:6] angular velocity in the body-local frame. The reference angular
        # velocity is world-frame, so rotate it by the inverse root orientation.
        root_quat_inv = np.zeros(4)
        mujoco.mju_negQuat(root_quat_inv, root_quat_wxyz)
        root_ang_local = np.zeros(3)
        mujoco.mju_rotVecQuat(
            root_ang_local,
            np.asarray(root_ang, dtype=np.float64),
            root_quat_inv,
        )
        self.d.qvel[0:3] = root_vel
        self.d.qvel[3:6] = root_ang_local
        self.d.qvel[self.actuator_qvel_indices] = dof_vel[self.mu_to_ref]
        self.target_dof_pos_mu = dof_pos[self.mu_to_ref].astype(np.float32)
        logger.info(
            "Teleport reset initialized from reference frame 0 "
            "(root + dof pos/vel)"
        )
    else:
        self.d.qpos[self.actuator_qpos_indices] = self.default_angles_mu
        self.d.qvel[self.actuator_qvel_indices] = 0.0
        self.target_dof_pos_mu = self.default_angles_mu.astype(np.float32)
        logger.info(
            "Teleport reset initialized from ONNX default joint positions"
        )
    self.target_dof_pos_by_name = {
        self.mjcf_dof_names[i]: float(self.target_dof_pos_mu[i])
        for i in range(self.m.nu)
    }
    mujoco.mj_forward(self.m, self.d)
    if self.use_kv_cache and self.policy_kv_shape:
        # Symbolic (non-int) ONNX dims are replaced by 1.
        shape = [
            d if isinstance(d, int) else 1 for d in self.policy_kv_shape
        ]
        self.policy_kv_cache = np.zeros(shape, dtype=np.float32)
    # Per-episode recording buffers (dumped to NPZ at the end of a run).
    self._robot_dof_pos_seq = []
    self._robot_dof_vel_seq = []
    self._robot_dof_acc_seq = []
    self._robot_dof_torque_seq = []
    self._robot_low_level_dof_torque_seq = []
    self._robot_low_level_foot_contact_seq = []
    self._robot_low_level_foot_normal_force_seq = []
    self._robot_low_level_foot_tangent_speed_seq = []
    self._robot_actions_seq = []
    self._robot_action_rate_seq = []
    self._prev_recorded_dof_vel_ref = None
    self._prev_actions_onnx = None
    self._reset_action_ema_filter()
    self._reset_action_delay_randomization()
    self._prev_low_level_foot_geom_centers = None
    self._robot_global_translation_seq = []
    self._robot_global_rotation_quat_seq = []
    self._robot_global_velocity_seq = []
    self._robot_global_angular_velocity_seq = []
    self._robot_moe_expert_indices_seq = []
    self._robot_moe_expert_logits_seq = []
    self._reset_onnx_io_dump_buffers()
def save_batch_result(self, output_path, meta_info):
    """Write the recorded robot rollout and its reference motion to a compressed NPZ.

    ``meta_info`` is copied and completed with default low-level torque/contact
    time steps (``simulation_dt``, falling back to 1/200 s). It is then stored
    as a JSON string under ``"metadata"``. Optional entries are added only
    when data exists:
    - ``robot_actions``
    - ``robot_moe_expert_indices``
    - ``robot_moe_expert_logits``

    Args:
        output_path: Destination path passed to ``np.savez_compressed``.
        meta_info: Mapping of metadata fields to embed.
    """
    import json

    low_level_dt = getattr(self, "simulation_dt", 1.0 / 200.0)
    metadata = dict(meta_info)
    metadata.setdefault("robot_low_level_torque_dt", float(low_level_dt))
    metadata.setdefault("robot_low_level_contact_dt", float(low_level_dt))

    moe_indices, moe_logits = self._get_stacked_moe_routing_tensors()
    foot_contact, foot_normal_force, foot_tangent_speed = (
        self._get_stacked_low_level_foot_contact_tensors()
    )
    res = {
        "robot_dof_pos": np.stack(self._robot_dof_pos_seq),
        "robot_dof_vel": np.stack(self._robot_dof_vel_seq),
        "robot_dof_acc": np.stack(self._robot_dof_acc_seq),
        "robot_dof_torque": np.stack(self._robot_dof_torque_seq),
        "robot_low_level_dof_torque": np.stack(
            self._robot_low_level_dof_torque_seq
        ),
        "robot_low_level_foot_contact": foot_contact,
        "robot_low_level_foot_normal_force": foot_normal_force,
        "robot_low_level_foot_tangent_speed": foot_tangent_speed,
        "robot_low_level_torque_dt": np.array(low_level_dt, dtype=np.float32),
        "robot_low_level_contact_dt": np.array(low_level_dt, dtype=np.float32),
        "robot_action_rate": np.asarray(
            self._robot_action_rate_seq, dtype=np.float32
        ),
        "robot_global_translation": np.stack(
            self._robot_global_translation_seq
        ),
        "robot_global_rotation_quat": np.stack(
            self._robot_global_rotation_quat_seq
        ),
        "robot_global_velocity": np.stack(self._robot_global_velocity_seq),
        "robot_global_angular_velocity": np.stack(
            self._robot_global_angular_velocity_seq
        ),
        "ref_dof_pos": self.ref_dof_pos,
        "ref_dof_vel": self.ref_dof_vel,
        "ref_global_translation": self.ref_global_translation,
        "ref_global_rotation_quat": self.ref_global_rotation_quat_xyzw,
        "ref_global_velocity": self.ref_global_velocity,
        "ref_global_angular_velocity": self.ref_global_angular_velocity,
        "metadata": json.dumps(metadata),
    }
    recorded_actions = getattr(self, "_robot_actions_seq", [])
    if len(recorded_actions) > 0:
        res["robot_actions"] = np.stack(recorded_actions, axis=0).astype(
            np.float32
        )
    if moe_indices is not None:
        res["robot_moe_expert_indices"] = moe_indices
    if moe_logits is not None:
        res["robot_moe_expert_logits"] = moe_logits
    np.savez_compressed(output_path, **res)
def setup(self):
    """Set up the evaluator by loading all required components.

    Loads the MuJoCo model and the ONNX policy, then builds the DOF/actuator
    mappings, camera configuration and observation buffers, in that order.
    It also performs mode-specific setup:
    - ``velocity_tracking``: create a keyboard velocity handler.
    - ``motion_tracking``: load the motion at ``config["motion_npz_path"]``
      if that file exists.
    """
    init_steps = (
        "load_mujoco_model",
        "_init_low_level_foot_contact_logging",
        "_build_mjcf_dof_names",
        "load_policy",
        "_apply_onnx_metadata",
        "_build_actuator_qpos_indices",
        "_build_dof_mappings",
        "_build_actuator_name_map",
        "_build_actuator_force_range_map",
        "_init_camera_config",
        "_init_obs_buffers",
    )
    # Resolve each step lazily so later steps see state set by earlier ones.
    for step_name in init_steps:
        getattr(self, step_name)()

    if self.command_mode == "velocity_tracking":
        self.keyboard_handler = VelocityKeyboardHandler(
            vx_increment=0.1,
            vy_increment=0.05,
            vyaw_increment=0.05,
            vx_limits=(-0.5, 1.0),
            vy_limits=(-0.3, 0.3),
            vyaw_limits=(-0.5, 0.5),
        )
        logger.info(
            "Velocity tracking mode enabled. Keyboard controls:\n"
            "  W/S: forward/backward velocity\n"
            "  A/D: left/right velocity\n"
            "  J/L: turn left/right\n"
            "  Space/X: reset all\n"
            "  Keep terminal window focused for keyboard input"
        )
        return
    if self.command_mode == "motion_tracking":
        motion_path = self.config.get("motion_npz_path", "")
        if motion_path and os.path.isfile(motion_path):
            self.load_motion_data()
def _create_eval_progress_bar(self, desc: str, max_steps: int):
    """Create a tqdm progress bar for an evaluation run, or return None.

    With a reference motion loaded, the bar counts motion frames. Otherwise
    it counts policy steps when ``max_steps`` > 0. In an unbounded run it
    returns None.
    """
    if self.ref_dof_pos is not None:
        total, unit = self.n_motion_frames, "frame"
    elif max_steps > 0:
        total, unit = max_steps, "step"
    else:
        return None
    return tqdm(total=total, desc=desc, unit=unit)
def _advance_eval_frame(self, max_steps: int) -> bool:
if self.ref_dof_pos is not None:
if self.motion_frame_idx >= (self.n_motion_frames - 1):
return False
self.motion_frame_idx += 1
return True
if max_steps > 0 and self.counter >= max_steps:
return False
return True
def _run_eval_step(self, max_steps: int) -> bool:
self._update_policy()
self.counter += 1
self._apply_control(sleep=True)
if self._video_writer is not None:
self._maybe_record_frame()
return self._advance_eval_frame(max_steps)
def _build_mjcf_dof_names(self):
    """Build ``mjcf_dof_names``: the name of the joint driven by each actuator.

    The list is in actuator order, one entry per actuator (``m.nu``).
    """
    joint_obj = mujoco._enums.mjtObj.mjOBJ_JOINT
    self.mjcf_dof_names = [
        mujoco.mj_id2name(self.m, joint_obj, int(self.m.actuator_trnid[act][0]))
        for act in range(self.m.nu)
    ]
def _build_actuator_qpos_indices(self):
"""Build mapping from actuator index to qpos/qvel indices."""
self.actuator_qpos_indices = np.zeros(self.m.nu, dtype=np.int32)
self.actuator_qvel_indices = np.zeros(self.m.nu, dtype=np.int32)
for i in range(self.m.nu):
j_id = int(self.m.actuator_trnid[i, 0])
self.actuator_qpos_indices[i] = self.m.jnt_qposadr[j_id]
self.actuator_qvel_indices[i] = self.m.jnt_dofadr[j_id]
def _build_actuator_name_map(self):
    """Build actuator-name lookups.

    Sets two mappings:
    - ``actuator_name_to_index``: actuator name -> actuator id.
    - ``actuator_name_to_mu_idx``: actuator name -> position of its joint in
      ``mjcf_dof_names`` (first match).
    """
    actuator_obj = mujoco._enums.mjtObj.mjOBJ_ACTUATOR
    joint_obj = mujoco._enums.mjtObj.mjOBJ_JOINT
    by_index = {}
    by_mu_idx = {}
    for act in range(self.m.nu):
        act_name = mujoco.mj_id2name(self.m, actuator_obj, act)
        by_index[act_name] = act
        joint_name = mujoco.mj_id2name(
            self.m, joint_obj, int(self.m.actuator_trnid[act][0])
        )
        by_mu_idx[act_name] = self.mjcf_dof_names.index(joint_name)
    self.actuator_name_to_index = by_index
    self.actuator_name_to_mu_idx = by_mu_idx
def _build_actuator_force_range_map(self):
"""Build mapping from actuator index to joint actuator force range from XML."""
self.actuator_force_range = {}
for i in range(self.m.nu):
j_id = int(self.m.actuator_trnid[i][0])
has_limit = False
min_force = 0.0
max_force = 0.0
if j_id >= 0 and j_id < self.m.njnt:
if self.m.jnt_actfrclimited[j_id]:
min_force = float(self.m.jnt_actfrcrange[j_id][0])
max_force = float(self.m.jnt_actfrcrange[j_id][1])
if min_force != 0.0 or max_force != 0.0:
has_limit = True
if not has_limit:
if self.m.actuator_forcelimited[i]:
min_force = float(self.m.actuator_forcerange[i][0])
max_force = float(self.m.actuator_forcerange[i][1])
if min_force != 0.0 or max_force != 0.0:
has_limit = True
if has_limit:
self.actuator_force_range[i] = (min_force, max_force)
else:
self.actuator_force_range[i] = None
def run_simulation_unitree(self):
    """Run the evaluation with an interactive MuJoCo passive viewer.

    Two threads share ``locker``:
    - The simulation thread steps policy + physics via ``_run_eval_step``.
    - The viewer thread updates the camera, draws reference-body markers and
      syncs the viewer every ``unitree_viewer_dt`` seconds.

    The call blocks until the evaluation finishes or the viewer is closed.
    It then closes the progress bar, keyboard listener and video tools, and
    dumps the robot-augmented NPZ.

    If a simulation step raises, the stop event is still set and the viewer
    closed. This ensures the viewer thread and the joins below terminate
    instead of hanging on a window nobody will close.
    """
    self.counter = 0
    self.motion_frame_idx = 0
    self.reset_state_teleport()
    max_steps = int(self.config.get("max_policy_steps", 0))
    viewer_dt = float(self.config.get("unitree_viewer_dt", 1.0 / 60.0))
    viewer = mujoco.viewer.launch_passive(self.m, self.d)
    # Configure viewer camera to use shared align / tracking settings
    self._configure_viewer_camera(viewer)
    # Start keyboard listener for velocity tracking
    if (
        self.command_mode == "velocity_tracking"
        and self.keyboard_handler is not None
    ):
        self.keyboard_handler.start_listener()
    # Optional recording in viewer mode
    if bool(self.config.get("record_video", False)):
        self._init_video_tools(tag="viewer")
    pbar = self._create_eval_progress_bar("GUI eval", max_steps)
    locker = threading.Lock()
    stop_event = threading.Event()

    def simulation_thread():
        # True when this thread ends the session (evaluation finished or a
        # step raised); set to False when the user closed the viewer window.
        ended_by_sim = True
        try:
            while viewer.is_running() and not stop_event.is_set():
                with locker:
                    keep_running = self._run_eval_step(max_steps)
                    if pbar is not None:
                        pbar.update(1)
                if not keep_running:
                    break
            else:
                ended_by_sim = False
        finally:
            if ended_by_sim:
                stop_event.set()
                with locker:
                    viewer.close()

    def physics_viewer_thread():
        while viewer.is_running() and not stop_event.is_set():
            with locker:
                # Update camera lookat to track robot root (with small offset for framing)
                self._update_camera_lookat(viewer.cam)
                # Draw reference global bodylink positions as red marker spheres when available
                self._draw_ref_body_spheres_to_scene(
                    viewer.user_scn, reset_ngeom=True
                )
            viewer.sync()
            time.sleep(viewer_dt)

    viewer_thread = Thread(target=physics_viewer_thread)
    sim_thread = Thread(target=simulation_thread)
    viewer_thread.start()
    sim_thread.start()
    # Block until viewer closes
    viewer_thread.join()
    sim_thread.join()
    # Close progress bar
    if pbar is not None:
        pbar.close()
    # Stop keyboard listener
    if (
        self.command_mode == "velocity_tracking"
        and self.keyboard_handler is not None
    ):
        self.keyboard_handler.stop_listener()
    # Teardown recording
    self._close_video_tools()
    # Dump robot-augmented motion npz if motion tracking is enabled
    self._dump_robot_augmented_npz()
def run_simulation_unitree_headless(self):
    """Run simulation headless (no GUI) with optional video recording.

    Resets the simulation state, then calls ``self._run_eval_step`` until it
    returns False, advancing the optional progress bar once per step. The
    keyboard listener (velocity-tracking mode only), the progress bar and the
    video tools are always torn down, even if a step raises. The
    robot-augmented NPZ is dumped only after a successful run.
    """
    self.counter = 0
    self.motion_frame_idx = 0
    self.reset_state_teleport()
    max_steps = int(self.config.get("max_policy_steps", 0))
    # Start keyboard listener for velocity tracking (even in headless mode).
    # Remember whether it was started so teardown matches setup exactly.
    listener_started = (
        self.command_mode == "velocity_tracking"
        and self.keyboard_handler is not None
    )
    if listener_started:
        self.keyboard_handler.start_listener()
    pbar = None
    try:
        # Optional recording in headless mode
        if bool(self.config.get("record_video", False)):
            self._init_video_tools(tag="headless")
        pbar = self._create_eval_progress_bar("Headless eval", max_steps)
        running = True
        while running:
            running = self._run_eval_step(max_steps)
            if pbar is not None:
                pbar.update(1)
    finally:
        # Release UI/listener/recording resources on both normal exit and error.
        if pbar is not None:
            pbar.close()
        if listener_started:
            self.keyboard_handler.stop_listener()
        self._close_video_tools()
    # Dump robot-augmented motion npz if motion tracking is enabled
    self._dump_robot_augmented_npz()
def run_simulation(self):
    """Start the evaluation loop, choosing GUI or headless mode from config.

    ``config.headless`` (default False) selects
    ``run_simulation_unitree_headless``; otherwise the viewer-based
    ``run_simulation_unitree`` loop is used.
    """
    if not bool(self.config.get("headless", False)):
        self.run_simulation_unitree()
        return
    logger.info("Running MuJoCo sim2sim headless")
    self.run_simulation_unitree_headless()
def _update_policy(self):
# Record robot states once per policy step for offline NPZ dumping
self._record_robot_states()
latest_obs = self.obs_builder.build_policy_obs()
policy_obs_np = latest_obs[None, :]
input_feed = {}
input_feed[self.policy_input_name] = policy_obs_np
if self.use_kv_cache:
if self.policy_kv_cache is None:
shape = [
d if isinstance(d, int) else 1
for d in self.policy_kv_shape
]
self.policy_kv_cache = np.zeros(shape, dtype=np.float32)
# if (
# self.policy_effective_context_len > 0
# and self.counter > 0
# and self.counter % self.policy_effective_context_len == 0
# ):
# self.policy_kv_cache.fill(0.0)
input_feed[self.policy_kv_input_name] = self.policy_kv_cache
if self.policy_step_input_name is not None:
step_idx = self.counter
if self.use_kv_cache and self.policy_effective_context_len > 0:
step_idx = self.counter % self.policy_effective_context_len
step_tensor = np.array([step_idx], dtype=np.int64)
input_feed[self.policy_step_input_name] = step_tensor
if torch.cuda.is_available():
torch.cuda.synchronize()
output_names = [self.policy_output_name]
if self.use_kv_cache and self.policy_kv_output_name:
output_names.append(self.policy_kv_output_name)
for _, indices_name, logits_name in self.policy_moe_layer_output_names:
output_names.extend([indices_name, logits_name])
onnx_output = self.policy_session.run(output_names, input_feed)
if self.dump_onnx_io_npy:
self._record_onnx_io_frame(input_feed, output_names, onnx_output)
if torch.cuda.is_available():
torch.cuda.synchronize()
raw_actions_onnx = onnx_output[0].reshape(-1)
filtered_actions_onnx = self._apply_action_ema_filter(raw_actions_onnx)
self.actions_onnx = self._apply_action_delay(filtered_actions_onnx)
if self.use_kv_cache and len(onnx_output) > 1:
new_cache = onnx_output[1]
self.policy_kv_cache = new_cache
output_offset = 1 + int(
bool(self.use_kv_cache and self.policy_kv_output_name)
)
if self.policy_moe_layer_output_names:
step_indices = []
step_logits = []
for (
_layer_idx,
_indices_name,
_logits_name,
) in self.policy_moe_layer_output_names:
step_indices.append(
self._flatten_single_step_output(
onnx_output[output_offset],
dtype=np.int64,
)
)
output_offset += 1
step_logits.append(
self._flatten_single_step_output(
onnx_output[output_offset],
dtype=np.float32,
)
)
output_offset += 1
self._robot_moe_expert_indices_seq.append(
np.stack(step_indices, axis=0)
)
self._robot_moe_expert_logits_seq.append(
np.stack(step_logits, axis=0)
)
self.target_dof_pos_onnx = (
self.actions_onnx * self.action_scale_onnx
+ self.default_angles_onnx
)
self.target_dof_pos_mu = self.target_dof_pos_onnx[self.onnx_to_mu]
for i, dof_name in enumerate(self.mjcf_dof_names):
self.target_dof_pos_by_name[dof_name] = float(
self.target_dof_pos_mu[i]
)
if (
self.command_mode == "motion_tracking"
and self.ref_dof_pos is not None
and len(self._robot_action_rate_seq) < len(self._robot_dof_pos_seq)
):
self._robot_actions_seq.append(
self.actions_onnx.astype(np.float32).copy()
)
if self._prev_actions_onnx is None:
action_rate = np.float32(0.0)
else:
action_rate = np.float32(
np.linalg.norm(self.actions_onnx - self._prev_actions_onnx)
/ self.policy_dt
)
self._prev_actions_onnx = self.actions_onnx.copy()
self._robot_action_rate_seq.append(action_rate)
self._robot_dof_torque_seq.append(
self._compute_pd_torque_command_ref()
)
def _get_config_value(config_obj, key: str):
value = config_obj.get(key, None)
if value is None and config_obj.get("eval", None) is not None:
value = config_obj.eval.get(key, None)
return value
def _normalize_ckpt_name_list(ckpt_onnx_names):
if ckpt_onnx_names is None:
return []
if isinstance(ckpt_onnx_names, ListConfig):
raw_names = list(ckpt_onnx_names)
elif isinstance(ckpt_onnx_names, (list, tuple)):
raw_names = list(ckpt_onnx_names)
else:
raise TypeError(
"ckpt_onnx_names must be a list/tuple, "
f"got {type(ckpt_onnx_names)}"
)
normalized_names = []
for name in raw_names:
name_str = str(name).strip()
if name_str != "":
normalized_names.append(name_str)
return normalized_names
def _resolve_multi_ckpt_paths(ckpt_onnx_root_dir, ckpt_onnx_names):
root_dir_str = str(ckpt_onnx_root_dir).strip()
if root_dir_str == "":
raise ValueError("ckpt_onnx_root_dir cannot be empty")
root_dir = Path(root_dir_str)
if not root_dir.is_dir():
raise NotADirectoryError(
f"ckpt_onnx_root_dir does not exist or is not a directory: {root_dir}"
)
requested_names = _normalize_ckpt_name_list(ckpt_onnx_names)
if len(requested_names) == 0:
raise ValueError(
"ckpt_onnx_names is empty. Please provide checkpoint names "
'like ["model_1000.onnx", "model_2000.onnx"].'
)
discovered_paths = sorted(root_dir.rglob("*.onnx"))
if len(discovered_paths) == 0:
raise FileNotFoundError(
f"No .onnx files found under ckpt_onnx_root_dir={root_dir}"
)
paths_by_name = {}
for path in discovered_paths:
if path.name not in paths_by_name:
paths_by_name[path.name] = []
paths_by_name[path.name].append(path)
selected_paths = []
missing_names = []
for name in requested_names:
candidates = paths_by_name.get(name, [])
if len(candidates) == 0:
missing_names.append(name)
continue
if len(candidates) > 1:
logger.warning(
f"Found {len(candidates)} ONNX files named '{name}' under "
f"{root_dir}; selecting the first one: {candidates[0]}"
)
selected_paths.append(candidates[0])
if len(missing_names) > 0:
logger.warning(
"Some requested checkpoints were not found under "
f"{root_dir}: {missing_names}"
)
if len(selected_paths) == 0:
raise FileNotFoundError(
"None of the requested checkpoints were found under "
f"{root_dir}. Requested names: {requested_names}"
)
return selected_paths
def _resolve_eval_ckpt_paths(config_obj):
    """Return the list of ONNX checkpoint paths to evaluate.

    ``ckpt_onnx_root_dir`` (with ``ckpt_onnx_names``) takes precedence when it
    is non-blank. Otherwise ``ckpt_onnx_path`` must point to an existing file.

    Raises:
        ValueError: Neither source of checkpoints is configured.
        FileNotFoundError: ``ckpt_onnx_path`` is not an existing file.
    """
    root_dir = _get_config_value(config_obj, "ckpt_onnx_root_dir")
    has_root_dir = root_dir is not None and str(root_dir).strip() != ""
    if has_root_dir:
        return _resolve_multi_ckpt_paths(
            root_dir, _get_config_value(config_obj, "ckpt_onnx_names")
        )
    single_path = _get_config_value(config_obj, "ckpt_onnx_path")
    if single_path is None or not str(single_path).strip():
        raise ValueError(
            "No ONNX checkpoint is provided. Set ckpt_onnx_path, or set "
            "ckpt_onnx_root_dir + ckpt_onnx_names."
        )
    ckpt_path = Path(str(single_path))
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"ONNX checkpoint not found: {ckpt_path}")
    return [ckpt_path]
def _checkpoint_tag_from_path(ckpt_path: Path) -> str:
match = re.search(r"model_(\d+)", ckpt_path.name)
if match:
return f"model_{match.group(1)}"
return ckpt_path.stem
def _build_eval_output_dir(ckpt_path: Path, dataset_name: str) -> Path:
    """Return ``<ckpt grandparent>/mujoco_eval_output_<tag>_<dataset>``."""
    experiment_dir = ckpt_path.parent.parent
    tag = _checkpoint_tag_from_path(ckpt_path)
    return experiment_dir / f"mujoco_eval_output_{tag}_{dataset_name}"
def _build_onnx_io_dump_dir(output_dir: str | Path) -> Path:
    """Return the ONNX I/O dump subdirectory inside ``output_dir``."""
    return Path(output_dir).joinpath(ONNX_IO_DUMP_DIRNAME)
def _build_onnx_io_dump_path(
    output_dir: str | Path, source_file: str | Path
) -> Path:
    """Return ``<dump dir>/<source stem>_onnx_io.npy`` for a motion clip."""
    dump_dir = _build_onnx_io_dump_dir(output_dir)
    return dump_dir / (Path(source_file).stem + "_onnx_io.npy")
def _build_onnx_io_dump_readme_text() -> str:
return """# ONNX I/O 导出说明
本目录用于保存 MuJoCo sim2sim 评测过程中导出的 ONNX 输入输出数据。
## 文件组织
- 每个动作片段会生成一个 `.npy` 文件,文件名形如 `_onnx_io.npy`
- 每个 `.npy` 文件对应一个原始的动作片段 `.npz`
- 当前只支持默认的 `holomotion` / `MujocoEvaluator` 批量目录评测模式(`motion_npz_dir`)
## 读取方式
`.npy` 文件内部保存的是一个 Python `dict`,读取时需要开启 `allow_pickle=True`:
```python
import numpy as np
npy_path = "onnx_io_npy/example_clip_onnx_io.npy"
payload = np.load(npy_path, allow_pickle=True).item()
print(payload.keys())
print(payload["input_names"])
print(payload["output_names"])
print(payload["inputs"]["obs"].shape)
print(payload["outputs"]["action"].shape)
```
## 数据字段
- `input_names`: ONNX 实际输入张量名称列表
- `output_names`: ONNX 实际输出张量名称列表
- `inputs`: 按输入张量名称组织的字典,数组第 0 维是帧索引
- `outputs`: 按输出张量名称组织的字典,数组第 0 维是帧索引
- `source_npz`: 原始动作片段文件名
- `onnx_model`: 导出这些张量时使用的 ONNX 模型路径
## 说明
单个 `.npy` 文件只能保存一个顶层对象,因此这里使用 pickled dict 来同时保存输入名称、输出名称以及逐帧堆叠后的 numpy 数组。
如果某次导出未产生有效 ONNX I/O 数据,`inputs` 和 `outputs` 可能为空字典,读取时请先检查键是否存在。
"""
def write_onnx_io_dump_readme(output_dir: str | Path) -> Path:
    """Create ``output_dir`` if needed and write the ONNX I/O ``README.md``.

    Returns:
        Path of the written README file.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    readme = target_dir.joinpath("README.md")
    readme.write_text(_build_onnx_io_dump_readme_text(), encoding="utf-8")
    return readme
def _allocate_actor_counts(num_checkpoints: int, total_actors: int):
if num_checkpoints <= 0:
raise ValueError("num_checkpoints must be > 0")
if total_actors <= 0:
raise ValueError("total_actors must be > 0")
base = total_actors // num_checkpoints
rem = total_actors % num_checkpoints
return [base + (1 if i < rem else 0) for i in range(num_checkpoints)]
def _infer_step_from_ckpt_name(ckpt_name: str):
match = re.search(r"model_(\d+)", ckpt_name)
if match:
return int(match.group(1))
fallback = re.search(r"(\d+)", ckpt_name)
if fallback:
return int(fallback.group(1))
return None
def _read_total_macro_row(tsv_path: Path):
if not tsv_path.is_file():
return None
with open(tsv_path, "r", encoding="utf-8", newline="") as tsv_file:
reader = csv.DictReader(tsv_file, delimiter="\t")
for row in reader:
dataset_value = str(row.get("Dataset", "")).strip().lower()
if "total" in dataset_value and "macro" in dataset_value:
return row
return None
def _parse_total_metric_value(raw_value):
    """Return ``raw_value`` as a float, or None if it is blank or not numeric."""
    text = str(raw_value).strip()
    if text == "":
        return None
    try:
        return float(text)
    except ValueError:
        return None


def _collect_total_macro_entries(eval_targets):
    """Group each target's Total (Macro) TSV row by its output dir's parent."""
    rows_by_parent = {}
    for target in eval_targets:
        output_dir_path = Path(target["output_dir"])
        ckpt_path = target["ckpt_path"]
        tsv_path = output_dir_path / "sub_dataset_macro_mean_metrics.tsv"
        total_row = _read_total_macro_row(tsv_path)
        if total_row is None:
            logger.warning(
                "Skipping aggregated total metrics entry because "
                f"Total (Macro) row is unavailable: {tsv_path}"
            )
            continue
        rows_by_parent.setdefault(output_dir_path.parent, []).append(
            {
                "step": _infer_step_from_ckpt_name(ckpt_path.stem),
                "total_row": total_row,
                "ckpt_name": ckpt_path.stem,
            }
        )
    return rows_by_parent


def _write_total_metrics_tsv(output_path, entries, metric_columns):
    """Write one TSV row per checkpoint: step followed by its metric values."""
    with open(output_path, "w", encoding="utf-8", newline="") as out_file:
        writer = csv.writer(out_file, delimiter="\t", lineterminator="\n")
        writer.writerow(["step"] + metric_columns)
        for entry in entries:
            step_value = "" if entry["step"] is None else str(entry["step"])
            writer.writerow(
                [step_value]
                + [entry["total_row"].get(col, "") for col in metric_columns]
            )


def _plot_total_metric_trends(output_path, entries, plot_metric_columns):
    """Save a PDF grid with one metric-vs-step subplot per column.

    Returns the path of the written PDF.
    """
    import matplotlib.pyplot as plt

    ncols = 4
    nrows = (len(plot_metric_columns) + ncols - 1) // ncols
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(4.0 * ncols, 2.8 * nrows),
        squeeze=False,
    )
    try:
        for idx, metric_name in enumerate(plot_metric_columns):
            ax = axes[idx // ncols][idx % ncols]
            trend_pairs = []
            for entry in entries:
                if entry["step"] is None:
                    continue
                # Skip blank or non-numeric cells (e.g. "N/A") instead of
                # aborting the whole summary.
                metric_value = _parse_total_metric_value(
                    entry["total_row"].get(metric_name, "")
                )
                if metric_value is None:
                    continue
                trend_pairs.append((entry["step"], metric_value))
            if not trend_pairs:
                ax.text(
                    0.5,
                    0.5,
                    "No valid data",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
                ax.set_title(metric_name)
                ax.set_xticks([])
                ax.set_yticks([])
                ax.grid(False)
                continue
            trend_pairs.sort(key=lambda pair: pair[0])
            ax.plot(
                [pair[0] for pair in trend_pairs],
                [pair[1] for pair in trend_pairs],
                marker="o",
                linewidth=1.2,
            )
            ax.set_title(metric_name)
            ax.set_xlabel("step")
            ax.grid(True, alpha=0.3)
        # Hide unused cells in the last grid row.
        for idx in range(len(plot_metric_columns), nrows * ncols):
            axes[idx // ncols][idx % ncols].axis("off")
        fig.tight_layout()
        plot_path = output_path.with_name(
            f"{output_path.stem}_all_metric_trends.pdf"
        )
        fig.savefig(plot_path, format="pdf")
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(fig)
    return plot_path


def _write_total_macro_summary_table(
    eval_targets, job_log_dir: Path | None = None
):
    """Aggregate per-checkpoint Total (Macro) metrics into TSV + PDF summaries.

    For every parent directory of the targets' ``output_dir``, writes
    ``mujoco_model-<min>-<max>_total_metrics.tsv`` (one row per checkpoint,
    sorted by inferred step, with unparsable steps last) and, if there are
    metric columns besides ``Dataset``, an ``..._all_metric_trends.pdf`` plot.
    Targets without a Total (Macro) row are skipped with a warning.

    Args:
        eval_targets: Dicts with ``ckpt_path`` (Path) and ``output_dir``.
        job_log_dir: If given, the generated artifacts are also copied there.
    """
    rows_by_parent = _collect_total_macro_entries(eval_targets)
    for parent_dir, entries in rows_by_parent.items():
        entries.sort(
            key=lambda item: (
                item["step"] is None,
                item["step"] if item["step"] is not None else 0,
                item["ckpt_name"],
            )
        )
        metric_columns = list(entries[0]["total_row"].keys())
        available_steps = [
            entry["step"] for entry in entries if entry["step"] is not None
        ]
        if available_steps:
            step_range = f"{min(available_steps)}-{max(available_steps)}"
        else:
            step_range = "na-na"
        output_path = parent_dir / f"mujoco_model-{step_range}_total_metrics.tsv"
        _write_total_metrics_tsv(output_path, entries, metric_columns)
        logger.info(f"Saved aggregated total metrics table at: {output_path}")
        generated_artifacts = [output_path]
        plot_metric_columns = [
            col for col in metric_columns if col != "Dataset"
        ]
        if plot_metric_columns:
            plot_path = _plot_total_metric_trends(
                output_path, entries, plot_metric_columns
            )
            generated_artifacts.append(plot_path)
            logger.info(f"Saved combined metric trend plot at: {plot_path}")
        if job_log_dir is not None:
            for artifact_path in generated_artifacts:
                job_log_path = job_log_dir / artifact_path.name
                shutil.copy2(artifact_path, job_log_path)
                logger.info(f"Exported artifact to /job_log: {job_log_path}")
def process_config(override_config):
    """Build the full evaluation config by merging a base config with overrides.

    The base config depends on ``model_type``. ``gmt``, ``any2track`` and
    ``sonic`` use bundled files under ``holomotion/config/evaluation``. Any
    other type (default ``holomotion``) uses the training ``config.yaml``
    located two directory levels above the ONNX checkpoint. If only
    ``ckpt_onnx_root_dir`` + ``ckpt_onnx_names`` are set, the first resolved
    checkpoint is used as the anchor for finding that training config.

    Args:
        override_config: Evaluation overrides (OmegaConf / Hydra config).

    Returns:
        The merged and resolved config, with ``model_type`` set and
        ``ckpt_onnx_path`` filled in when the merged config left it blank.

    Raises:
        ValueError: Non-bundled model type without any ONNX checkpoint path.
    """

    def _is_blank(value):
        return value is None or str(value).strip() == ""

    ckpt_onnx_path = _get_config_value(override_config, "ckpt_onnx_path")
    ckpt_onnx_root_dir = _get_config_value(
        override_config, "ckpt_onnx_root_dir"
    )
    if _is_blank(ckpt_onnx_path) and not _is_blank(ckpt_onnx_root_dir):
        resolved_paths = _resolve_multi_ckpt_paths(
            ckpt_onnx_root_dir,
            _get_config_value(override_config, "ckpt_onnx_names"),
        )
        ckpt_onnx_path = str(resolved_paths[0])
        logger.info(
            "Using the first resolved checkpoint as config anchor: "
            f"{ckpt_onnx_path}"
        )
    model_type = override_config.get("model_type") or "holomotion"
    bundled_config_files = (
        ("gmt", "holomotion/config/evaluation/gmt_eval_mujoco_sim2sim.yaml"),
        (
            "any2track",
            "holomotion/config/evaluation/any2track_eval_mujoco_sim2sim.json",
        ),
        (
            "sonic",
            "holomotion/config/evaluation/sonic_eval_mujoco_sim2sim.yaml",
        ),
    )
    config_path = None
    for bundled_type, bundled_file in bundled_config_files:
        if model_type == bundled_type:
            config_path = Path(bundled_file)
            break
    if config_path is None:
        if _is_blank(ckpt_onnx_path):
            raise ValueError(
                "Cannot locate training config.yaml for model_type='holomotion' "
                "without an ONNX checkpoint path. Set ckpt_onnx_path, or set "
                "ckpt_onnx_root_dir + ckpt_onnx_names."
            )
        # Training config.yaml sits one level above the ONNX file's directory.
        config_path = Path(str(ckpt_onnx_path)).parent.parent / "config.yaml"
    logger.info(f"Loading training config file from {config_path}")
    # Support ${eval:'...'} interpolations during resolution.
    # NOTE(review): this resolver runs Python ``eval`` on expressions taken
    # from the loaded config files; only load configs from trusted sources.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", lambda expr: eval(expr))
    with open(config_path) as config_file:
        base_config = OmegaConf.load(config_file)
    config = OmegaConf.merge(base_config, override_config)
    with open_dict(config):
        config.model_type = model_type
    OmegaConf.resolve(config)
    if _is_blank(config.get("ckpt_onnx_path", None)) and not _is_blank(
        ckpt_onnx_path
    ):
        with open_dict(config):
            config.ckpt_onnx_path = str(ckpt_onnx_path)
    return config
def _create_ray_evaluator(config_dict, model_type):
    """Create evaluator from serializable config dict (used inside Ray actor).

    Args:
        config_dict: Plain (serializable) config, converted back to OmegaConf.
        model_type: ``"gmt"``, ``"any2track"``, ``"sonic"``; anything else
            selects ``MujocoEvaluator``.

    Returns:
        An evaluator instance constructed from the config.
    """
    from omegaconf import OmegaConf

    config = OmegaConf.create(config_dict)
    # Model-specific evaluators are imported lazily, so only the selected
    # module is loaded.
    if model_type == "gmt":
        from holomotion.src.evaluation.gmt_sim2sim import GMTEvaluator

        return GMTEvaluator(config)
    if model_type == "any2track":
        from holomotion.src.evaluation.any2track_sim2sim import (
            Any2TrackEvaluator,
        )

        return Any2TrackEvaluator(config)
    if model_type == "sonic":
        from holomotion.src.evaluation.sonic_mujoco_sim2sim import (
            SonicEvaluator,
        )

        return SonicEvaluator(config)
    return MujocoEvaluator(config)
def run_mujoco_sim2sim_eval(override_config: OmegaConf):
    """Entry point for MuJoCo sim2sim evaluation (batch or single run).

    Batch ("evaluation") mode is used when ``motion_npz_dir`` is an existing
    directory and ``motion_npz_path`` is unset or empty. In that mode:
      1. Every ``.npz`` under the directory is rolled out for each resolved
         checkpoint using Ray actors (only when ``dump_npzs`` or
         ``calc_per_clip_metrics`` is enabled).
      2. Per-checkpoint metrics and reports are optionally computed, either
         as parallel Ray tasks or sequentially.
      3. Summary TSVs are exported to ``/job_log`` when it is writable, and
         the aggregated Total (Macro) table/plot is written.
    Otherwise a single evaluator is set up and run in-process.

    Args:
        override_config: Hydra-composed evaluation config.
    """
    # Hydra changes the working directory; restore the launch directory so
    # relative paths in the config resolve as the user expects.
    os.chdir(hydra.utils.get_original_cwd())
    config = process_config(override_config)
    is_eval_mode = False
    dataset_dir = config.get("motion_npz_dir", None)
    specific_file = config.get("motion_npz_path", None)
    calc_per_clip_metrics = bool(config.get("calc_per_clip_metrics", False))
    generate_report = bool(config.get("generate_report", False))
    dump_npzs_cfg = bool(config.get("dump_npzs", False))
    dump_onnx_io_npy = bool(config.get("dump_onnx_io_npy", False))
    # Per-clip metrics are computed from dumped NPZs, so force dumping on.
    dump_npzs = dump_npzs_cfg or calc_per_clip_metrics
    if calc_per_clip_metrics and not dump_npzs_cfg:
        logger.info(
            "calc_per_clip_metrics=true requires dumped NPZs; "
            "enabling dump_npzs automatically."
        )
    if (
        dataset_dir
        and os.path.isdir(str(dataset_dir))
        and (not specific_file or str(specific_file) == "")
    ):
        is_eval_mode = True
    if is_eval_mode:
        logger.info(f"Mode: EVALUATION on directory: {dataset_dir}")
        # Replace all existing loguru sinks with a single stderr INFO sink.
        logger.remove()
        logger.add(sys.stderr, level="INFO")
        dataset_name = Path(dataset_dir).name
        ckpt_paths = _resolve_eval_ckpt_paths(config)
        logger.info(
            f"Resolved {len(ckpt_paths)} checkpoint(s) for evaluation."
        )
        for idx, ckpt_path in enumerate(ckpt_paths):
            logger.info(f" [{idx}] {ckpt_path}")
        # One output directory per checkpoint, derived from the checkpoint
        # location and the dataset directory name.
        eval_targets = []
        for ckpt_path in ckpt_paths:
            output_dir = _build_eval_output_dir(ckpt_path, dataset_name)
            eval_targets.append(
                {
                    "ckpt_path": ckpt_path,
                    "output_dir": str(output_dir),
                }
            )
        if dump_npzs:
            # --- Stage 1: roll out every clip for every checkpoint via Ray ---
            for target in eval_targets:
                os.makedirs(target["output_dir"], exist_ok=True)
                if dump_onnx_io_npy:
                    write_onnx_io_dump_readme(
                        _build_onnx_io_dump_dir(target["output_dir"])
                    )
            files = sorted(
                [
                    os.path.join(root, name)
                    for root, _, filenames in os.walk(
                        dataset_dir, followlinks=True
                    )
                    for name in filenames
                    if name.endswith(".npz")
                ]
            )
            logger.info(
                f"Found {len(files)} files for dataset_dir={dataset_dir}. "
                f"Will evaluate {len(eval_targets)} checkpoint(s)."
            )
            if len(files) == 0:
                logger.warning(
                    f"No NPZ files found under dataset_dir={dataset_dir}"
                )
            requested_use_gpu = _coerce_config_bool(
                config.get("use_gpu", True), default=True
            )
            num_available_gpus = 0
            if requested_use_gpu and torch.cuda.is_available():
                num_available_gpus = int(torch.cuda.device_count())
            if requested_use_gpu and num_available_gpus == 0:
                logger.warning(
                    "use_gpu=true but no CUDA device is detected; "
                    "Ray actors will run on CPU."
                )
            if num_available_gpus > 0:
                logger.info(
                    f"Detected {num_available_gpus} CUDA device(s). "
                    "Using Ray for batch evaluation."
                )
            ray_actors_per_gpu = int(config.get("ray_actors_per_gpu", 4))
            if ray_actors_per_gpu <= 0:
                raise ValueError("ray_actors_per_gpu must be > 0")
            # "split": the actor budget is divided across checkpoints;
            # "per_checkpoint": every checkpoint gets the full budget.
            ray_multi_ckpt_mode = str(
                config.get("ray_multi_ckpt_mode", "split")
            )
            if ray_multi_ckpt_mode not in ("split", "per_checkpoint"):
                raise ValueError(
                    "ray_multi_ckpt_mode must be one of: "
                    "'split', 'per_checkpoint'"
                )
            success_count = 0
            total_jobs = len(files) * len(eval_targets)
            if total_jobs > 0:
                base_config_dict = OmegaConf.to_container(config, resolve=True)
                base_config_dict.setdefault(
                    "ray_evaluator_module",
                    "holomotion.src.evaluation.eval_mujoco_sim2sim",
                )
                if not ray.is_initialized():
                    ray.init()
                from holomotion.src.evaluation.ray_evaluator_actor import (
                    RayEvaluatorActor,
                )
                if num_available_gpus > 0:
                    # Fractional GPU share so several actors fit on one GPU.
                    base_actor_count = num_available_gpus * ray_actors_per_gpu
                    gpus_per_actor = 1.0 / ray_actors_per_gpu
                    remote_actor = ray.remote(num_gpus=gpus_per_actor)(
                        RayEvaluatorActor
                    )
                else:
                    base_actor_count = max(1, ray_actors_per_gpu)
                    gpus_per_actor = 0.0
                    remote_actor = ray.remote(num_gpus=0)(RayEvaluatorActor)
                if ray_multi_ckpt_mode == "per_checkpoint":
                    actor_counts = [
                        base_actor_count for _ in range(len(eval_targets))
                    ]
                else:
                    actor_counts = _allocate_actor_counts(
                        len(eval_targets), base_actor_count
                    )
                    # More checkpoints than actors leaves some with zero.
                    if min(actor_counts) <= 0:
                        raise ValueError(
                            "Not enough actor budget to assign at least one actor "
                            "per checkpoint in split mode. Reduce checkpoint count, "
                            "increase ray_actors_per_gpu, or switch to "
                            "ray_multi_ckpt_mode=per_checkpoint."
                        )
                total_actor_count = sum(actor_counts)
                logger.info(
                    f"Ray: {total_actor_count} persistent actors "
                    f"({ray_actors_per_gpu} per GPU, {gpus_per_actor} GPU each)"
                )
                logger.info(
                    "Checkpoint actor allocation: "
                    + ", ".join(
                        [
                            f"{target['ckpt_path'].name}={actor_counts[idx]}"
                            for idx, target in enumerate(eval_targets)
                        ]
                    )
                )
                refs = []
                for target_idx, target in enumerate(eval_targets):
                    # Shallow copy so each checkpoint gets its own ONNX path.
                    target_config_dict = dict(base_config_dict)
                    target_config_dict["ckpt_onnx_path"] = str(
                        target["ckpt_path"]
                    )
                    num_actors = actor_counts[target_idx]
                    target_actors = [
                        remote_actor.remote(
                            target_config_dict, target["output_dir"]
                        )
                        for _ in range(num_actors)
                    ]
                    # Round-robin clips over this checkpoint's actors.
                    for file_idx, file_path in enumerate(files):
                        actor = target_actors[file_idx % len(target_actors)]
                        refs.append(actor.run_clip.remote(file_path))
                pbar = tqdm(
                    total=total_jobs,
                    desc="Batch Processing (all checkpoints)",
                    unit="job",
                    dynamic_ncols=True,
                )
                # Consume results as they finish to keep the progress bar live.
                while refs:
                    done, refs = ray.wait(refs, num_returns=1)
                    for ref in done:
                        status = ray.get(ref)
                        if status == "success":
                            success_count += 1
                        pbar.update(1)
                pbar.close()
            logger.info(
                f"Batch processing done. Success: {success_count}/{total_jobs}"
            )
        else:
            logger.info("Skipping NPZ dumping because dump_npzs=false.")
        # --- Stage 2: optional export location for summary artifacts ---
        job_log_dir = Path("/job_log")
        job_log_enabled = job_log_dir.is_dir() and os.access(
            str(job_log_dir), os.W_OK
        )
        if job_log_enabled:
            logger.info(
                f"Detected writable /job_log. Will copy summary TSVs to {job_log_dir}."
            )
        else:
            logger.info(
                "/job_log is unavailable or not writable. "
                "Skipping summary TSV export."
            )
        # --- Stage 3: metrics / reports for checkpoints that have outputs ---
        postprocess_targets = []
        for target in eval_targets:
            output_dir = target["output_dir"]
            output_dir_path = Path(output_dir)
            if not output_dir_path.is_dir():
                logger.warning(
                    f"Output directory does not exist, skipping post-processing: {output_dir}"
                )
                continue
            postprocess_targets.append(target)
        failure_pos_err_thresh_m = float(
            config.get("failure_pos_err_thresh_m", 0.25)
        )
        metric_calculation = str(config.get("metric_calculation", "per_clip"))
        dof_mode = str(config.get("dof_mode", "29"))
        # The newer key takes precedence; the older key is the fallback.
        ray_parallel_metrics = bool(
            config.get(
                "ray_parallel_metrics_postprocess",
                config.get("ray_parallel_metrics", True),
            )
        )
        metrics_threadpool_max_workers = config.get(
            "metrics_threadpool_max_workers", None
        )
        # Parallelize only when there are several checkpoints and some
        # post-processing work is actually requested.
        should_parallelize_metrics = (
            ray_parallel_metrics
            and len(postprocess_targets) > 1
            and (calc_per_clip_metrics or generate_report or job_log_enabled)
        )
        logger.info(
            "Metrics config: "
            f"ray_parallel_metrics_postprocess={ray_parallel_metrics}, "
            f"metrics_threadpool_max_workers={metrics_threadpool_max_workers}"
        )
        if should_parallelize_metrics:
            if not ray.is_initialized():
                ray.init()
            from holomotion.src.evaluation.ray_metrics_postprocess import (
                run_metrics_postprocess_job,
            )
            ray_metrics_num_cpus_cfg = config.get(
                "ray_metrics_postprocess_num_cpus",
                config.get("ray_metrics_num_cpus", None),
            )
            if ray_metrics_num_cpus_cfg is None:
                ray_metrics_num_cpus = 0.0
            else:
                ray_metrics_num_cpus = float(ray_metrics_num_cpus_cfg)
            if ray_metrics_num_cpus < 0.0:
                raise ValueError("ray_metrics_num_cpus must be >= 0")
            # One remote post-processing job per checkpoint.
            metric_refs = []
            for target in postprocess_targets:
                ckpt_path = target["ckpt_path"]
                metric_refs.append(
                    run_metrics_postprocess_job.options(
                        num_cpus=ray_metrics_num_cpus
                    ).remote(
                        output_dir=target["output_dir"],
                        dataset_name=dataset_name,
                        calc_per_clip_metrics=calc_per_clip_metrics,
                        failure_pos_err_thresh_m=failure_pos_err_thresh_m,
                        metric_calculation=metric_calculation,
                        dof_mode=dof_mode,
                        metrics_threadpool_max_workers=metrics_threadpool_max_workers,
                        generate_report=generate_report,
                        job_log_dir=str(job_log_dir)
                        if job_log_enabled
                        else None,
                        ckpt_stem=ckpt_path.stem,
                    )
                )
            pbar = tqdm(
                total=len(metric_refs),
                desc="Metrics post-processing (all checkpoints)",
                unit="ckpt",
                dynamic_ncols=True,
            )
            while metric_refs:
                done, metric_refs = ray.wait(metric_refs, num_returns=1)
                for ref in done:
                    # Each job returns a dict; empty strings mean "not produced".
                    result = ray.get(ref)
                    ckpt_stem = str(result.get("ckpt_stem", "")).strip()
                    if ckpt_stem == "":
                        ckpt_stem = "unknown"
                    if calc_per_clip_metrics:
                        logger.info(
                            f"Metric calculation finished for {ckpt_stem}."
                        )
                    report_path = str(result.get("report_path", "")).strip()
                    if report_path != "":
                        logger.info(
                            f"Generated metrics report for {ckpt_stem} at: {report_path}"
                        )
                    exported_tsv = str(
                        result.get("exported_summary_tsv", "")
                    ).strip()
                    if exported_tsv != "":
                        logger.info(f"Exported summary TSV to: {exported_tsv}")
                    pbar.update(1)
            pbar.close()
        else:
            # Sequential fallback: same steps as the Ray job, one checkpoint
            # at a time in this process.
            mean_process_5metrics = None
            if generate_report:
                from holomotion.scripts.evaluation import mean_process_5metrics
            for target in postprocess_targets:
                output_dir = target["output_dir"]
                output_dir_path = Path(output_dir)
                ckpt_path = target["ckpt_path"]
                if calc_per_clip_metrics:
                    logger.info(
                        "Starting metric calculation for "
                        f"{ckpt_path.name}: {output_dir}"
                    )
                    run_evaluation(
                        npz_dir=output_dir,
                        dataset_suffix=dataset_name,
                        failure_pos_err_thresh_m=failure_pos_err_thresh_m,
                        metric_calculation=metric_calculation,
                        dof_mode=dof_mode,
                        threadpool_max_workers=metrics_threadpool_max_workers,
                    )
                    logger.info(
                        f"Metric calculation finished for {ckpt_path.name}."
                    )
                if generate_report:
                    report_path = mean_process_5metrics.generate_macro_mean_report_from_json_dir(
                        output_dir
                    )
                    logger.info(
                        f"Generated metrics report for {ckpt_path.name} at: {report_path}"
                    )
                if job_log_enabled:
                    # Prefix with the checkpoint stem so exports from several
                    # checkpoints do not overwrite each other.
                    sub_dataset_tsv = (
                        output_dir_path / "sub_dataset_macro_mean_metrics.tsv"
                    )
                    if sub_dataset_tsv.is_file():
                        export_name = f"{ckpt_path.stem}_sub_dataset_macro_mean_metrics.tsv"
                        export_path = job_log_dir / export_name
                        shutil.copy2(sub_dataset_tsv, export_path)
                        logger.info(f"Exported summary TSV to: {export_path}")
                    else:
                        logger.warning(
                            "Summary TSV not found (skip export): "
                            f"{sub_dataset_tsv}"
                        )
        # --- Stage 4: cross-checkpoint Total (Macro) summary table + plot ---
        _write_total_macro_summary_table(
            eval_targets,
            job_log_dir=job_log_dir if job_log_enabled else None,
        )
    else:
        # Single-run mode (one motion / interactive viewer or headless).
        # NOTE(review): only "sonic" is special-cased here; "gmt" and
        # "any2track" fall through to MujocoEvaluator, unlike
        # _create_ray_evaluator. Confirm this is intended.
        if config.get("model_type", "holomotion") == "sonic":
            from holomotion.src.evaluation.sonic_mujoco_sim2sim import (
                SonicEvaluator,
            )
            evaluator = SonicEvaluator(config)
        else:
            evaluator = MujocoEvaluator(config)
        evaluator.setup()
        evaluator.run_simulation()
@hydra.main(
    config_path="../../config",
    config_name="evaluation/eval_mujoco_sim2sim",
    version_base=None,
)
def main(override_config: OmegaConf):
    """Hydra entry point: hand the composed config to the sim2sim evaluator."""
    return run_mujoco_sim2sim_eval(override_config)


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/evaluation/eval_velocity_tracking.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from pathlib import Path
import hydra
from hydra.utils import get_class
from loguru import logger
from omegaconf import OmegaConf
from holomotion.src.utils.config import compile_config
from holomotion.src.utils.onnx_export import export_policy_to_onnx
def load_training_config(
    checkpoint_path: str, eval_config: OmegaConf
) -> OmegaConf:
    """Load the training config stored next to a checkpoint and merge overrides.

    ``config.yaml`` is looked up in the checkpoint's directory first, then in
    its parent directory. If neither exists, ``eval_config`` is returned
    unchanged. Otherwise the training config, with its optional
    ``eval_overrides`` section applied, is merged with ``eval_config``. The
    training ``robot`` section is kept, and ``env.config.terminations`` /
    ``env.config.domain_rand`` are forced to the evaluation values.

    Args:
        checkpoint_path: Path to the checkpoint file.
        eval_config: Full evaluation config (including command line overrides).

    Returns:
        Merged config with training config as base.
    """
    checkpoint = Path(checkpoint_path)
    config_path = checkpoint.parent / "config.yaml"
    if not config_path.exists():
        config_path = checkpoint.parent.parent / "config.yaml"
    if not config_path.exists():
        logger.warning(
            "Training config not found at "
            f"{config_path}, using evaluation config"
        )
        return eval_config
    logger.info(f"Loading training config from {config_path}")
    with open(config_path) as file:
        train_config = OmegaConf.load(file)
    # Apply eval_overrides from training config if they exist
    if train_config.get("eval_overrides") is not None:
        train_config = OmegaConf.merge(
            train_config, train_config.eval_overrides
        )
    # Point both the top-level and the algorithm config at this checkpoint.
    train_config.checkpoint = checkpoint_path
    train_config.algo.config.checkpoint = checkpoint_path
    # For evaluation, merge eval_config into train_config
    config = OmegaConf.merge(train_config, eval_config)
    # For velocity tracking, always keep the robot configuration from training
    if hasattr(train_config, "robot"):
        config.robot = train_config.robot
    # Force terminations and domain randomization to the eval_config values.
    config.env.config.terminations = eval_config.env.config.terminations
    config.env.config.domain_rand = eval_config.env.config.domain_rand
    return config
@hydra.main(
    config_path="../../config",
    config_name="evaluation/eval_isaaclab",
    version_base=None,
)
def main(config: OmegaConf):
    """Evaluate the velocity tracking model.

    Merges the training config stored next to ``config.checkpoint``, builds
    the algorithm (which creates its own Accelerator, AppLauncher and
    environment), saves the effective config as ``eval_config.yaml`` beside
    the checkpoint, loads the checkpoint, optionally exports the policy to
    ONNX, and runs the velocity-tracking rollout.

    Args:
        config: OmegaConf object containing the evaluation configuration.

    Raises:
        ValueError: If no checkpoint path is provided.
    """
    if config.checkpoint is None:
        raise ValueError("Checkpoint path must be provided for evaluation")
    # Load training config first
    config = load_training_config(config.checkpoint, config)
    # Compile config without accelerator (PPO will create it)
    config = compile_config(config, accelerator=None)
    # Offline evaluation uses the checkpoint directory as its log dir; the
    # effective eval config is saved there as well.
    log_dir = os.path.dirname(config.checkpoint)
    headless = config.headless
    # PPO creates Accelerator, AppLauncher, and environment internally
    algo_class = get_class(config.algo._target_)
    algo = algo_class(
        env_config=config.env,
        config=config.algo.config,
        log_dir=log_dir,
        headless=headless,
        is_offline_eval=True,
    )
    if (
        algo.accelerator.is_main_process
        and os.environ.get("TORCH_COMPILE_DISABLE", "0") != "1"
    ):
        logger.info(
            "Tip: set TORCH_COMPILE_DISABLE=1 if Triton/compile errors occur"
        )
    if algo.accelerator.is_main_process:
        with open(os.path.join(log_dir, "eval_config.yaml"), "w") as f:
            OmegaConf.save(config, f)
        logger.info(f"Loading checkpoint for evaluation: {config.checkpoint}")
    # config.checkpoint is guaranteed to be set here: it was validated above
    # and os.path.dirname() on it already succeeded, so no fallback is needed.
    algo.load(config.checkpoint)
    # Export ONNX if requested
    if config.get("export_policy", True):
        if algo.accelerator.is_main_process:
            onnx_name_suffix = config.get("onnx_name_suffix", None)
            onnx_path = export_policy_to_onnx(
                algo,
                config.checkpoint,
                onnx_name_suffix=onnx_name_suffix,
                use_kv_cache=config.get("use_kv_cache", True),
            )
            logger.info(f"Successfully exported policy to: {onnx_path}")
        # Keep non-main ranks from racing ahead of the export.
        algo.accelerator.wait_for_everyone()
    # Run indefinite velocity tracking rollout for visualization
    algo.offline_evaluate_velocity_tracking()
    if algo.accelerator.is_main_process:
        logger.info("Velocity tracking evaluation completed!")


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/evaluation/find_worst_clips.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import math
from pathlib import Path
from typing import Dict, Any, List
# Metrics JSON written by the offline evaluation; a relative path is
# resolved against the current working directory.
JSON_INPUT_FILE = "logs/Holomotion/metrics_output_dataset/model_17500.json"
# Absolute location of the input file; the report is written next to it.
input_path = Path(JSON_INPUT_FILE).expanduser().resolve()
OUTPUT_JSON_FILE = str(input_path.parent / "bad_clips.json")
# Fraction of clips reported per metric (0.2 -> worst 20%).
WORST_PERCENTAGE = 0.2
# Metrics to rank, keyed by their field name in each per-clip record.
# A "direction" of "higher_is_worse" ranks descending; other values ascending.
METRICS_INFO: Dict[str, Dict[str, str]] = dict(
    whole_body_joints_dist=dict(
        name="Joint Angle Error (Whole Body Average)",
        unit="rad",
        direction="higher_is_worse",
    ),
)
def find_and_save_bad_clips(
    data: Dict[str, Any],
    metrics_info: Dict[str, Dict[str, str]],
    percentage: float,
    output_file: str,
) -> None:
    """Rank clips per metric and write the worst fraction to a JSON report.

    Args:
        data: Parsed metrics JSON; its ``"per_clip"`` entry is a list of
            dicts carrying ``"motion_key"`` plus one value per metric.
        metrics_info: Metric key -> info dict. Only metrics with a non-empty
            ``"direction"`` are ranked; ``"higher_is_worse"`` sorts
            descending, any other value sorts ascending.
        percentage: Fraction of all clips to report per metric. At least one
            clip is reported for every metric that has rankable values.
        output_file: Destination JSON path; missing parent directories are
            created.

    Clips whose metric value is missing (``None``), non-numeric, NaN or
    infinite are excluded from the ranking: NaN makes ``sorted`` return an
    arbitrary order and ``None`` raises ``TypeError`` when compared, and
    excluding them also keeps the written report valid JSON.
    """
    per_clip_data: List[Dict[str, Any]] = data.get("per_clip", [])
    if not per_clip_data:
        print("Error: 'per_clip' not found in JSON data.")
        return
    total_clips = len(per_clip_data)
    # The epsilon absorbs binary floating-point error: e.g. 100 * 0.07 is
    # 7.000000000000001, which would otherwise ceil to one extra clip.
    num_to_select = math.ceil(total_clips * percentage - 1e-9)
    # Always report at least one clip; this also prevents a non-positive
    # percentage from turning the slice below into "all but the last k".
    num_to_select = max(num_to_select, 1)
    bad_clips_report: Dict[str, List[Dict[str, Any]]] = {}
    for key, info in metrics_info.items():
        direction = info.get("direction")
        if not direction:
            continue
        sort_descending = direction == "higher_is_worse"
        clips_with_metric_value: List[Dict[str, Any]] = []
        num_unrankable = 0
        for clip in per_clip_data:
            if key not in clip or "motion_key" not in clip:
                continue
            value = clip[key]
            if not (isinstance(value, (int, float)) and math.isfinite(value)):
                num_unrankable += 1
                continue
            clips_with_metric_value.append(
                {"motion_key": clip["motion_key"], "value": value}
            )
        if num_unrankable:
            print(
                f"Warning: skipped {num_unrankable} clip(s) with missing or "
                f"non-finite values for metric '{key}'."
            )
        if not clips_with_metric_value:
            print(f"Warning: no values found for metric '{key}' in data.")
            continue
        sorted_clips = sorted(
            clips_with_metric_value,
            key=lambda x: x["value"],
            reverse=sort_descending,
        )
        bad_clips_report[key] = sorted_clips[:num_to_select]
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(bad_clips_report, f, indent=4, ensure_ascii=False)
    print(f"Saved bad-clips report to: {output_file}")
def main() -> None:
    """Load ``JSON_INPUT_FILE``, rank clips per metric and write the report.

    Prints an error and returns early when the input file does not exist.
    The ranking uses ``METRICS_INFO`` and ``WORST_PERCENTAGE``, and the
    report goes to ``OUTPUT_JSON_FILE``.
    """
    input_file = Path(JSON_INPUT_FILE)
    if not input_file.is_file():
        print(f"Error: JSON input file '{JSON_INPUT_FILE}' not found.")
        return
    with input_file.open("r", encoding="utf-8") as handle:
        metrics_data = json.load(handle)
    find_and_save_bad_clips(
        data=metrics_data,
        metrics_info=METRICS_INFO,
        percentage=WORST_PERCENTAGE,
        output_file=OUTPUT_JSON_FILE,
    )


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/evaluation/metrics.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from pathlib import Path
from typing import Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse
import csv
import json
import os
import re
from glob import glob
from zipfile import BadZipFile
import numpy as np
import pandas as pd
from loguru import logger
from scipy.signal import welch
from scipy.spatial.transform import Rotation as sRot
from tabulate import tabulate
from tqdm import tqdm
# Robot sample period [s] assumed when a dump carries no usable dt (50 Hz).
DEFAULT_ROBOT_CONTROL_DT: float = 1.0 / 50.0
# Floor on the torque magnitude used to normalise torque jumps (avoids /0).
TORQUE_JUMP_RATIO_EPS: float = 1e-6
# Fewest finite samples for which a Welch PSD estimate is attempted.
MIN_WELCH_SAMPLES: int = 8
# Width [s] of the rolling window used to find torque-jump bursts.
STABILITY_BURST_WINDOW_SECONDS: float = 0.5
# Width [s] of the window after each foot touchdown searched for peak force.
TOUCHDOWN_WINDOW_SECONDS: float = 0.05
# Body index whose angular velocity drives the torso roll/pitch metrics.
ROOT_BODY_INDEX: int = 0
# Probability floor applied before taking logs in the JS divergence.
PROBABILITY_EPS: float = 1e-12
def quat_inv(q):
    """Conjugate of xyzw quaternion(s) ``q`` along the last axis.

    The conjugate equals the inverse only for unit quaternions; no
    normalisation is performed here.
    """
    vector_part = -q[..., :3]
    scalar_part = q[..., 3:4]
    return np.concatenate([vector_part, scalar_part], axis=-1)
def quat_apply(q, v):
    """Rotate vectors ``v`` by xyzw quaternions ``q``.

    ``q`` has shape (T, 4) and ``v`` shape (T, N, 3); each row of ``q``
    rotates all N vectors of the matching row of ``v``. Uses the identity
    v' = v + w*t + cross(u, t) with t = 2*cross(u, v), which assumes unit
    quaternions. Returns float64 of shape (T, N, 3).
    """
    quats = np.asarray(q, dtype=np.float64)
    vecs = np.asarray(v, dtype=np.float64)
    # Insert a broadcast axis so one quaternion applies to all N vectors.
    imag = quats[:, np.newaxis, 0:3]
    real = quats[:, np.newaxis, 3:4]
    twice_cross = 2.0 * np.cross(imag, vecs, axis=-1)
    rotated = vecs + real * twice_cross + np.cross(imag, twice_cross, axis=-1)
    return rotated
def p_mpjpe(predicted: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Compute Procrustes-aligned MPJPE between predicted and ground truth.

    For every batch element (frame), ``predicted`` is aligned to ``target``
    with a rotation, a uniform scale and a translation obtained from an SVD
    of the normalised cross-covariance, and the per-point Euclidean error
    after alignment is returned.

    Args:
        predicted: Array of shape (B, N, D); B frames of N points. Callers in
            this file pass root-relative body positions with D = 3.
        target: Array with the same shape as ``predicted``.

    Returns:
        Array of shape (B, N) with per-point distances after alignment, in
        the units of the inputs. Frames whose SVD raises ``LinAlgError``
        come out as NaN. NOTE(review): presumably this is mainly hit when a
        frame's points all coincide, making the normalisation 0/0 -> NaN —
        confirm.

    Reference:
        This function is inspired by and partially adapted from the SMPLSim:
        https://github.com/ZhengyiLuo/SMPLSim/blob/0d672790a7672f28361d59dadd98ae2fc1b9685e/smpl_sim/smpllib/smpl_eval.py.
    """
    assert predicted.shape == target.shape
    # Centre both point sets on their per-frame centroids.
    mu_x = np.mean(target, axis=1, keepdims=True)
    mu_y = np.mean(predicted, axis=1, keepdims=True)
    x0 = target - mu_x
    y0 = predicted - mu_y
    # Normalise each frame to unit Frobenius norm; the scale is recovered
    # later from these norms and the singular values.
    norm_x = np.sqrt(np.sum(x0**2, axis=(1, 2), keepdims=True))
    norm_y = np.sqrt(np.sum(y0**2, axis=(1, 2), keepdims=True))
    x0 /= norm_x
    y0 /= norm_y
    # Per-frame cross-covariance, shape (B, D, D).
    h = np.matmul(x0.transpose(0, 2, 1), y0)
    # Per-frame SVD with graceful handling for non-convergence: mark those frames as NaN
    # (`jdim` is the coordinate dimension D, not the number of joints.)
    batch_size = int(h.shape[0])
    jdim = int(h.shape[1])
    u = np.empty((batch_size, jdim, jdim), dtype=h.dtype)
    s = np.empty((batch_size, jdim), dtype=h.dtype)
    vt = np.empty((batch_size, jdim, jdim), dtype=h.dtype)
    for i in range(batch_size):
        try:
            ui, si, vti = np.linalg.svd(h[i])
            u[i] = ui
            s[i] = si
            vt[i] = vti
        except np.linalg.LinAlgError:
            u[i].fill(np.nan)
            s[i].fill(np.nan)
            vt[i].fill(np.nan)
    v = vt.transpose(0, 2, 1)
    r = np.matmul(v, u.transpose(0, 2, 1))
    # Avoid improper rotations (reflections), i.e. rotations with det(R) = -1
    # NOTE: np.sign returns 0 when det(r) is exactly 0, which zeroes the last
    # column of v and the last singular value for that frame.
    sign_det_r = np.sign(np.expand_dims(np.linalg.det(r), axis=1))
    v[:, :, -1] *= sign_det_r
    s[:, -1] *= sign_det_r.flatten()
    r = np.matmul(v, u.transpose(0, 2, 1))  # Corrected rotation
    # Sum of (sign-corrected) singular values, shaped (B, 1, 1) for broadcasting.
    tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2)
    a = tr * norm_x / norm_y  # Scale
    t = mu_x - a * np.matmul(mu_y, r)  # Translation
    predicted_aligned = a * np.matmul(predicted, r) + t
    # Euclidean error over the last (coordinate) axis -> shape (B, N).
    return np.linalg.norm(
        predicted_aligned - target, axis=len(target.shape) - 1
    )
def _parse_clip_len_from_name(filename: str) -> Optional[int]:
"""Extract clip length from filename suffix '__start_XXX_len_N'."""
m = re.search(r"__start_\d+_len_(\d+)", os.path.basename(filename))
return int(m.group(1)) if m else None
def _parse_metadata_entry(raw_metadata) -> Dict[str, object]:
if raw_metadata is None:
return {}
parsed = raw_metadata
if isinstance(parsed, np.ndarray):
if parsed.shape != ():
return {}
parsed = parsed.item()
if isinstance(parsed, dict):
return parsed
if isinstance(parsed, bytes):
parsed = parsed.decode("utf-8")
if isinstance(parsed, str):
try:
obj = json.loads(parsed)
except json.JSONDecodeError:
return {}
return obj if isinstance(obj, dict) else {}
return {}
def _extract_robot_control_dt(
metadata: Dict[str, object], raw_data: Dict[str, np.ndarray]
) -> float:
if "robot_low_level_torque_dt" in raw_data:
raw_dt = np.asarray(raw_data["robot_low_level_torque_dt"]).item()
else:
raw_dt = metadata.get(
"robot_low_level_torque_dt",
metadata.get("robot_control_dt", DEFAULT_ROBOT_CONTROL_DT),
)
try:
robot_control_dt = float(raw_dt)
except (TypeError, ValueError):
return DEFAULT_ROBOT_CONTROL_DT
if not np.isfinite(robot_control_dt) or robot_control_dt <= 0.0:
return DEFAULT_ROBOT_CONTROL_DT
return robot_control_dt
def _extract_low_level_contact_dt(
metadata: Dict[str, object],
raw_data: Dict[str, np.ndarray],
robot_control_dt: float,
) -> float:
if "robot_low_level_contact_dt" in raw_data:
raw_dt = np.asarray(raw_data["robot_low_level_contact_dt"]).item()
else:
raw_dt = metadata.get(
"robot_low_level_contact_dt",
metadata.get(
"robot_low_level_torque_dt",
metadata.get("robot_control_dt", robot_control_dt),
),
)
try:
contact_dt = float(raw_dt)
except (TypeError, ValueError):
return robot_control_dt
if not np.isfinite(contact_dt) or contact_dt <= 0.0:
return robot_control_dt
return contact_dt
def _aggregate_sample_metric_to_frames(
sample_metric: np.ndarray, num_frames: int
) -> np.ndarray:
if int(sample_metric.shape[0]) == num_frames:
return sample_metric.astype(float, copy=False)
if num_frames <= 0:
return np.empty((0,), dtype=float)
aggregated = np.full((num_frames,), np.nan, dtype=float)
for frame_idx, chunk in enumerate(
np.array_split(sample_metric, num_frames)
):
if chunk.size == 0:
continue
if np.all(np.isnan(chunk)):
continue
aggregated[frame_idx] = float(np.nanmean(chunk))
return aggregated
def _compute_torque_jump_series(
torque_samples: np.ndarray, torque_dt: float
) -> tuple[np.ndarray, np.ndarray]:
num_samples = int(torque_samples.shape[0])
torque_jump_norm = np.full((num_samples,), np.nan, dtype=float)
torque_jump_ratio = np.full((num_samples,), np.nan, dtype=float)
if num_samples <= 1:
return torque_jump_norm, torque_jump_ratio
torque_mag = np.linalg.norm(torque_samples, axis=1)
torque_delta_norm = np.linalg.norm(
torque_samples[1:] - torque_samples[:-1], axis=1
)
torque_jump_norm[1:] = torque_delta_norm / torque_dt
torque_scale = np.maximum(
np.maximum(torque_mag[1:], torque_mag[:-1]), TORQUE_JUMP_RATIO_EPS
)
torque_jump_ratio[1:] = torque_delta_norm / torque_scale
return torque_jump_norm, torque_jump_ratio
def _safe_nanpercentile(values: np.ndarray, q: float) -> float:
arr = np.asarray(values, dtype=float).reshape(-1)
arr = arr[np.isfinite(arr)]
if arr.size == 0:
return float("nan")
return float(np.nanpercentile(arr, q))
def _safe_nanmean(values: np.ndarray) -> float:
arr = np.asarray(values, dtype=float).reshape(-1)
arr = arr[np.isfinite(arr)]
if arr.size == 0:
return float("nan")
return float(np.mean(arr))
def _safe_nanmedian(values: np.ndarray) -> float:
arr = np.asarray(values, dtype=float).reshape(-1)
arr = arr[np.isfinite(arr)]
if arr.size == 0:
return float("nan")
return float(np.median(arr))
def _compute_rolling_nanmean_max(
values: np.ndarray, window_size: int
) -> float:
arr = np.asarray(values, dtype=float).reshape(-1)
if arr.size == 0:
return float("nan")
if window_size <= 1:
return float(np.nanmax(arr))
best = float("nan")
max_start = int(arr.size) - int(window_size) + 1
if max_start <= 0:
if np.all(np.isnan(arr)):
return float("nan")
return float(np.nanmean(arr))
for start in range(max_start):
window = arr[start : start + window_size]
if np.all(np.isnan(window)):
continue
mean_value = float(np.nanmean(window))
if np.isnan(best) or mean_value > best:
best = mean_value
return best
def _integrate_psd_band(
frequencies: np.ndarray,
power_density: np.ndarray,
low_hz: float,
high_hz: float,
) -> float:
if (
not np.isfinite(low_hz)
or not np.isfinite(high_hz)
or high_hz <= low_hz
):
return float("nan")
band_mask = (frequencies >= low_hz) & (frequencies <= high_hz)
if not np.any(band_mask):
return float("nan")
band_freq = frequencies[band_mask]
band_power = power_density[band_mask]
if band_freq.size == 1:
return float(band_power[0])
return float(np.trapz(band_power, band_freq))
def _compute_psd_high_frequency_ratio(
    signal_values: np.ndarray,
    sample_dt: float,
    *,
    high_band_low_hz: float,
    band_high_hz: float,
    band_low_hz: float = 0.5,
) -> float:
    """Fraction of in-band signal power that lies in the high-frequency band.

    Non-finite samples are dropped. A Welch PSD (segment length up to 256,
    constant detrend) is integrated over [band_low_hz, top] and over
    [high_band_low_hz, top], where top = min(band_high_hz, 0.45 * fs) keeps
    the band below Nyquist. Returns high / total, or NaN when there are
    fewer than MIN_WELCH_SAMPLES finite samples, the band is empty, or
    either integral is unusable.
    """
    samples = np.asarray(signal_values, dtype=float).reshape(-1)
    samples = samples[np.isfinite(samples)]
    if samples.size < MIN_WELCH_SAMPLES:
        return float("nan")
    fs = 1.0 / float(sample_dt)
    band_top = min(float(band_high_hz), 0.45 * fs)
    band_bottom = float(band_low_hz)
    high_bottom = float(high_band_low_hz)
    if band_top <= max(band_bottom, high_bottom):
        return float("nan")
    freqs, psd = welch(
        samples,
        fs=fs,
        nperseg=min(int(samples.size), 256),
        detrend="constant",
        average="mean",
    )
    total = _integrate_psd_band(freqs, psd, band_bottom, band_top)
    high = _integrate_psd_band(freqs, psd, high_bottom, band_top)
    if np.isfinite(total) and total > 0.0 and np.isfinite(high):
        return float(high / total)
    return float("nan")
def _compute_torque_chatter_hf_ratio(
low_level_torque: np.ndarray, low_level_dt: float
) -> float:
torque_samples = np.asarray(low_level_torque, dtype=float)
if torque_samples.ndim != 2 or torque_samples.shape[0] < MIN_WELCH_SAMPLES:
return float("nan")
ratios = []
for joint_idx in range(int(torque_samples.shape[1])):
ratio = _compute_psd_high_frequency_ratio(
torque_samples[:, joint_idx],
low_level_dt,
high_band_low_hz=10.0,
band_high_hz=40.0,
)
if np.isfinite(ratio):
ratios.append(ratio)
if len(ratios) == 0:
return float("nan")
return float(np.mean(ratios))
def _compute_torso_roll_pitch_stability_metrics(
robot_global_angular_velocity: np.ndarray,
robot_control_dt: float,
) -> Dict[str, float]:
angular_velocity = np.asarray(robot_global_angular_velocity, dtype=float)
if angular_velocity.ndim != 3 or angular_velocity.shape[0] == 0:
return {
"torso_rp_hf_ratio": float("nan"),
"torso_rp_angacc_p95": float("nan"),
}
torso_roll_pitch_vel = angular_velocity[:, ROOT_BODY_INDEX, :2]
torso_roll_pitch_speed = np.linalg.norm(torso_roll_pitch_vel, axis=1)
hf_ratio = _compute_psd_high_frequency_ratio(
torso_roll_pitch_speed,
robot_control_dt,
high_band_low_hz=5.0,
band_high_hz=20.0,
)
if torso_roll_pitch_vel.shape[0] <= 1:
angacc_p95 = float("nan")
else:
roll_pitch_angacc = np.diff(torso_roll_pitch_vel, axis=0) / float(
robot_control_dt
)
roll_pitch_angacc_mag = np.linalg.norm(roll_pitch_angacc, axis=1)
angacc_p95 = _safe_nanpercentile(roll_pitch_angacc_mag, 95.0)
return {
"torso_rp_hf_ratio": hf_ratio,
"torso_rp_angacc_p95": angacc_p95,
}
def _compute_expert_switching_js_div(
robot_moe_expert_logits: np.ndarray | None,
) -> float:
if robot_moe_expert_logits is None:
return float("nan")
logits = np.asarray(robot_moe_expert_logits, dtype=float)
if logits.ndim != 3 or logits.shape[0] <= 1 or logits.shape[-1] <= 1:
return float("nan")
if not np.all(np.isfinite(logits)):
return float("nan")
shifted_logits = logits - np.max(logits, axis=-1, keepdims=True)
probs = np.exp(shifted_logits)
probs /= np.sum(probs, axis=-1, keepdims=True)
prev_probs = np.clip(probs[:-1], PROBABILITY_EPS, 1.0)
next_probs = np.clip(probs[1:], PROBABILITY_EPS, 1.0)
mixture = 0.5 * (prev_probs + next_probs)
kl_prev = np.sum(
prev_probs * (np.log(prev_probs) - np.log(mixture)), axis=-1
)
kl_next = np.sum(
next_probs * (np.log(next_probs) - np.log(mixture)), axis=-1
)
js_divergence = 0.5 * (kl_prev + kl_next) / np.log(2.0)
return _safe_nanmean(js_divergence)
def _compute_contact_stability_metrics(
foot_contact_samples: np.ndarray | None,
foot_normal_force_samples: np.ndarray | None,
foot_tangent_speed_samples: np.ndarray | None,
contact_dt: float,
) -> Dict[str, float]:
metrics = {
"foot_contact_toggle_rate": float("nan"),
"foot_impact_force_p95": float("nan"),
"stance_slip_speed_p95": float("nan"),
}
if (
foot_contact_samples is None
or foot_normal_force_samples is None
or foot_tangent_speed_samples is None
):
return metrics
contact = np.asarray(foot_contact_samples, dtype=float)
normal_force = np.asarray(foot_normal_force_samples, dtype=float)
tangent_speed = np.asarray(foot_tangent_speed_samples, dtype=float)
if (
contact.shape != normal_force.shape
or contact.shape != tangent_speed.shape
or contact.ndim != 2
or contact.shape[1] != 2
):
return metrics
finite_contact = np.isfinite(contact)
if not np.any(finite_contact):
return metrics
contact_binary = np.where(contact >= 0.5, 1.0, 0.0)
valid_pair_mask = finite_contact[1:] & finite_contact[:-1]
toggle_count = int(
np.sum(
np.abs(contact_binary[1:] - contact_binary[:-1]) * valid_pair_mask
)
)
clip_duration_seconds = float(contact.shape[0]) * float(contact_dt)
if clip_duration_seconds > 0.0:
metrics["foot_contact_toggle_rate"] = (
float(toggle_count) / clip_duration_seconds
)
touchdown_window = max(
1, int(round(TOUCHDOWN_WINDOW_SECONDS / float(contact_dt)))
)
touchdown_peaks = []
for foot_idx in range(2):
foot_contact = contact_binary[:, foot_idx]
foot_force = normal_force[:, foot_idx]
onset_mask = np.zeros_like(foot_contact, dtype=bool)
onset_mask[0] = foot_contact[0] >= 0.5
onset_mask[1:] = (foot_contact[1:] >= 0.5) & (foot_contact[:-1] < 0.5)
for onset_idx in np.flatnonzero(onset_mask):
window = foot_force[onset_idx : onset_idx + touchdown_window]
if window.size == 0 or np.all(~np.isfinite(window)):
continue
touchdown_peaks.append(float(np.nanmax(window)))
metrics["foot_impact_force_p95"] = _safe_nanpercentile(
np.asarray(touchdown_peaks, dtype=float), 95.0
)
stance_slip_mask = (contact_binary >= 0.5) & np.isfinite(tangent_speed)
if np.any(stance_slip_mask):
metrics["stance_slip_speed_p95"] = _safe_nanpercentile(
tangent_speed[stance_slip_mask], 95.0
)
return metrics
def _compute_clip_stability_summary(
data: Dict[str, np.ndarray],
robot_control_dt: float,
low_level_contact_dt: float,
) -> Dict[str, float]:
robot_low_level_dof_torque = (
np.asarray(data["robot_low_level_dof_torque"])
if "robot_low_level_dof_torque" in data
else None
)
if robot_low_level_dof_torque is None and "robot_dof_torque" in data:
robot_low_level_dof_torque = np.asarray(data["robot_dof_torque"])
if robot_low_level_dof_torque is None:
torque_chatter_hf_ratio = float("nan")
torque_jump_burst_max = float("nan")
else:
torque_chatter_hf_ratio = _compute_torque_chatter_hf_ratio(
robot_low_level_dof_torque, low_level_contact_dt
)
_, torque_jump_ratio = _compute_torque_jump_series(
robot_low_level_dof_torque, low_level_contact_dt
)
torque_jump_window = max(
1,
int(
round(
STABILITY_BURST_WINDOW_SECONDS
/ float(low_level_contact_dt)
)
),
)
torque_jump_burst_max = _compute_rolling_nanmean_max(
torque_jump_ratio[1:], torque_jump_window
)
torso_metrics = _compute_torso_roll_pitch_stability_metrics(
np.asarray(data["robot_global_angular_velocity"]),
robot_control_dt,
)
contact_metrics = _compute_contact_stability_metrics(
np.asarray(data["robot_low_level_foot_contact"])
if "robot_low_level_foot_contact" in data
else None,
np.asarray(data["robot_low_level_foot_normal_force"])
if "robot_low_level_foot_normal_force" in data
else None,
np.asarray(data["robot_low_level_foot_tangent_speed"])
if "robot_low_level_foot_tangent_speed" in data
else None,
low_level_contact_dt,
)
expert_switching_js_div = _compute_expert_switching_js_div(
np.asarray(data["robot_moe_expert_logits"])
if "robot_moe_expert_logits" in data
else None
)
return {
"torque_chatter_hf_ratio": torque_chatter_hf_ratio,
"torque_jump_burst_max": torque_jump_burst_max,
"expert_switching_js_div": expert_switching_js_div,
**torso_metrics,
**contact_metrics,
}
def _compute_clip_torque_jump_summary(
data: Dict[str, np.ndarray],
dof_mode: str,
torque_dt: float,
) -> Dict[str, float]:
robot_dof_torque = (
np.asarray(data["robot_dof_torque"])
if "robot_dof_torque" in data
else None
)
robot_low_level_dof_torque = (
np.asarray(data["robot_low_level_dof_torque"])
if "robot_low_level_dof_torque" in data
else None
)
if dof_mode == "23" and robot_dof_torque is not None:
total_dofs_in_file = int(robot_dof_torque.shape[1])
if total_dofs_in_file == 29:
idx_23_in_29_dof = list(range(19)) + list(range(22, 26))
robot_dof_torque = robot_dof_torque[:, idx_23_in_29_dof]
if (
robot_low_level_dof_torque is not None
and int(robot_low_level_dof_torque.shape[1])
== total_dofs_in_file
):
robot_low_level_dof_torque = robot_low_level_dof_torque[
:, idx_23_in_29_dof
]
chatter_torque = robot_low_level_dof_torque
if chatter_torque is None:
chatter_torque = robot_dof_torque
if chatter_torque is None or int(chatter_torque.shape[0]) <= 1:
return {
"mean_torque_jump_norm": float("nan"),
"p95_torque_jump_norm": float("nan"),
"mean_torque_jump_ratio": float("nan"),
"p95_torque_jump_ratio": float("nan"),
}
torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series(
chatter_torque, torque_dt
)
return {
"mean_torque_jump_norm": float(np.nanmean(torque_jump_norm)),
"p95_torque_jump_norm": float(
np.nanpercentile(torque_jump_norm[1:], 95)
),
"mean_torque_jump_ratio": float(np.nanmean(torque_jump_ratio)),
"p95_torque_jump_ratio": float(
np.nanpercentile(torque_jump_ratio[1:], 95)
),
}
def _per_frame_metrics_from_npz(
    motion_key: str,
    data: Dict[str, np.ndarray],
    dof_mode: str = "29",
    robot_control_dt: float = DEFAULT_ROBOT_CONTROL_DT,
) -> pd.DataFrame:
    """Compute per-frame metrics for a single motion clip from loaded npz arrays.

    Required keys in ``data`` (URDF order, quaternions xyzw):
        - ref_dof_pos, robot_dof_pos                            (T, D)
        - ref_global_translation, robot_global_translation      (T, J, 3)
        - ref_global_rotation_quat, robot_global_rotation_quat  (T, J, 4)
    Optional keys: robot_dof_vel, robot_dof_acc, robot_dof_torque,
    robot_low_level_dof_torque, robot_action_rate.

    Args:
        motion_key: Identifier copied into the ``motion_key`` column.
        data: Mapping from array name to array for one clip.
        dof_mode: "29" keeps all DOFs. "23" selects the 23-DOF subset when
            the file stores 29 DOFs (bodies are selected as root + DOF+1).
        robot_control_dt: Sample period [s] used to turn torque deltas into
            torque-jump rates.

    Returns:
        DataFrame with exactly T rows, one per frame. Metrics built from
        finite differences (velocity, acceleration and root-velocity errors)
        are NaN-padded at the start, so clips with only 1 or 2 frames also
        yield equal-length columns.

    Raises:
        KeyError: a required key is missing.
        AssertionError: reference and robot array shapes disagree.
        ValueError: an optional per-frame array has the wrong frame count.
    """
    # Required arrays
    jpos_gt = np.asarray(data["ref_global_translation"])  # (T, J, 3)
    jpos_pred = np.asarray(data["robot_global_translation"])  # (T, J, 3)
    rot_gt = np.asarray(data["ref_global_rotation_quat"])  # (T, J, 4) xyzw
    rot_pred = np.asarray(data["robot_global_rotation_quat"])  # (T, J, 4)
    dof_gt = np.asarray(data["ref_dof_pos"])  # (T, D)
    dof_pred = np.asarray(data["robot_dof_pos"])  # (T, D)
    # Optional arrays (None when absent from the dump)
    robot_dof_vel = (
        np.asarray(data["robot_dof_vel"]) if "robot_dof_vel" in data else None
    )
    robot_dof_acc = (
        np.asarray(data["robot_dof_acc"]) if "robot_dof_acc" in data else None
    )
    robot_dof_torque = (
        np.asarray(data["robot_dof_torque"])
        if "robot_dof_torque" in data
        else None
    )
    robot_low_level_dof_torque = (
        np.asarray(data["robot_low_level_dof_torque"])
        if "robot_low_level_dof_torque" in data
        else None
    )
    robot_action_rate = (
        np.asarray(data["robot_action_rate"])
        if "robot_action_rate" in data
        else None
    )
    total_dofs_in_file = int(dof_gt.shape[1])
    # 23-DOF subset of the 29-DOF layout (drops DOF indices 19-21 and 26-28).
    # Body index = DOF index + 1, with body 0 the root.
    IDX_23_IN_29_DOF = list(range(19)) + list(range(22, 26))
    IDX_23_IN_29_BODY = [0] + [i + 1 for i in IDX_23_IN_29_DOF]
    if dof_mode == "23":
        if total_dofs_in_file == 29:
            dof_gt = dof_gt[:, IDX_23_IN_29_DOF]
            dof_pred = dof_pred[:, IDX_23_IN_29_DOF]
            if (
                robot_dof_vel is not None
                and int(robot_dof_vel.shape[1]) == total_dofs_in_file
            ):
                robot_dof_vel = robot_dof_vel[:, IDX_23_IN_29_DOF]
            if (
                robot_dof_acc is not None
                and int(robot_dof_acc.shape[1]) == total_dofs_in_file
            ):
                robot_dof_acc = robot_dof_acc[:, IDX_23_IN_29_DOF]
            if (
                robot_dof_torque is not None
                and int(robot_dof_torque.shape[1]) == total_dofs_in_file
            ):
                robot_dof_torque = robot_dof_torque[:, IDX_23_IN_29_DOF]
            if (
                robot_low_level_dof_torque is not None
                and int(robot_low_level_dof_torque.shape[1])
                == total_dofs_in_file
            ):
                robot_low_level_dof_torque = robot_low_level_dof_torque[
                    :, IDX_23_IN_29_DOF
                ]
            jpos_gt = jpos_gt[:, IDX_23_IN_29_BODY, :]
            jpos_pred = jpos_pred[:, IDX_23_IN_29_BODY, :]
            rot_gt = rot_gt[:, IDX_23_IN_29_BODY, :]
            rot_pred = rot_pred[:, IDX_23_IN_29_BODY, :]
    assert jpos_gt.shape == jpos_pred.shape
    assert rot_gt.shape == rot_pred.shape
    assert dof_gt.shape == dof_pred.shape
    num_frames = int(jpos_gt.shape[0])
    # Global MPJPE [mm]
    mpjpe_g = (
        np.mean(np.linalg.norm(jpos_gt - jpos_pred, axis=2), axis=1) * 1000.0
    )
    # Failure-criterion signal [m]. Despite the column name, this is the
    # absolute height (z) error of body 0 (the root), not the maximum error
    # over all bodies; the full-body variant is kept commented out below.
    # per_joint_err = np.linalg.norm(jpos_pred - jpos_gt, axis=2)
    # frame_max_body_pos_err = np.max(per_joint_err, axis=1)
    frame_max_body_pos_err = np.abs(jpos_pred[:, 0, 2] - jpos_gt[:, 0, 2])
    # Localize by root (index 0)
    jpos_gt_local = jpos_gt - jpos_gt[:, [0]]
    jpos_pred_local = jpos_pred - jpos_pred[:, [0]]
    # Root-relative body positions expressed in each pose's own root frame
    ref_body_pos_root_rel = quat_apply(
        quat_inv(rot_gt[:, 0, :]),
        jpos_gt - jpos_gt[:, [0]],
    )
    robot_body_pos_root_rel = quat_apply(
        quat_inv(rot_pred[:, 0, :]),
        jpos_pred - jpos_pred[:, [0]],
    )
    # Root-frame MPJPE [mm]
    mpjpe_l = (
        np.mean(
            np.linalg.norm(
                robot_body_pos_root_rel - ref_body_pos_root_rel, axis=2
            ),
            axis=1,
        )
        * 1000.0
    )
    # Procrustes-aligned MPJPE [mm]
    pa_per_joint = p_mpjpe(jpos_pred_local, jpos_gt_local)
    mpjpe_pa = np.mean(pa_per_joint, axis=1) * 1000.0
    # Velocity/acceleration errors from positions (discrete frame diffs) [mm/frame],[mm/frame^2]
    vel_gt = jpos_gt[1:] - jpos_gt[:-1]
    vel_pred = jpos_pred[1:] - jpos_pred[:-1]
    vel_dist = (
        np.mean(np.linalg.norm(vel_pred - vel_gt, axis=2), axis=1) * 1000.0
    )
    acc_gt = jpos_gt[:-2] - 2 * jpos_gt[1:-1] + jpos_gt[2:]
    acc_pred = jpos_pred[:-2] - 2 * jpos_pred[1:-1] + jpos_pred[2:]
    accel_dist = (
        np.mean(np.linalg.norm(acc_pred - acc_gt, axis=2), axis=1) * 1000.0
    )
    # DOF angle errors [radians] — whole body average
    dof_err = np.abs(dof_pred - dof_gt)
    whole_body_joints_dist = np.mean(dof_err, axis=1)
    # Root orientation errors [radians] — handle zero-norm/invalid quaternions by NaN
    q_gt_root = rot_gt[:, 0, :]
    q_pred_root = rot_pred[:, 0, :]
    norms_gt = np.linalg.norm(q_gt_root, axis=1)
    norms_pred = np.linalg.norm(q_pred_root, axis=1)
    valid_mask = (
        (norms_gt > 0.0)
        & (norms_pred > 0.0)
        & np.isfinite(norms_gt)
        & np.isfinite(norms_pred)
    )
    root_r_error = np.full((num_frames,), np.nan, dtype=float)
    root_p_error = np.full((num_frames,), np.nan, dtype=float)
    root_y_error = np.full((num_frames,), np.nan, dtype=float)
    if np.any(valid_mask):
        q_gt_valid = q_gt_root[valid_mask] / norms_gt[valid_mask, None]
        q_pred_valid = q_pred_root[valid_mask] / norms_pred[valid_mask, None]
        rel_valid = sRot.from_quat(q_gt_valid).inv() * sRot.from_quat(
            q_pred_valid
        )
        euler_xyz = rel_valid.as_euler("xyz", degrees=False)
        root_r_error[valid_mask] = np.abs(euler_xyz[:, 0])
        root_p_error[valid_mask] = np.abs(euler_xyz[:, 1])
        root_y_error[valid_mask] = np.abs(euler_xyz[:, 2])
    # Root velocity error [m/frame]
    root_pos_gt = jpos_gt[:, 0, :]
    root_pos_pred = jpos_pred[:, 0, :]
    root_vel_err = np.linalg.norm(
        (root_pos_pred[1:] - root_pos_pred[:-1])
        - (root_pos_gt[1:] - root_pos_gt[:-1]),
        axis=1,
    )
    # Root height error [m]
    root_height_error = np.abs(root_pos_pred[:, 2] - root_pos_gt[:, 2])
    # Robot low-level magnitudes (optional). The "mean_" columns hold the L2
    # norm over the DOF axis for each frame.
    mean_dof_vel = np.full((num_frames,), np.nan, dtype=float)
    if robot_dof_vel is not None:
        if int(robot_dof_vel.shape[0]) != num_frames:
            raise ValueError(
                "robot_dof_vel frame length mismatch: "
                f"{robot_dof_vel.shape[0]} vs {num_frames}"
            )
        mean_dof_vel = np.linalg.norm(robot_dof_vel, axis=1)
    mean_dof_acc = np.full((num_frames,), np.nan, dtype=float)
    if robot_dof_acc is not None:
        if int(robot_dof_acc.shape[0]) != num_frames:
            raise ValueError(
                "robot_dof_acc frame length mismatch: "
                f"{robot_dof_acc.shape[0]} vs {num_frames}"
            )
        mean_dof_acc = np.linalg.norm(robot_dof_acc, axis=1)
    mean_dof_torque = np.full((num_frames,), np.nan, dtype=float)
    mean_torque_jump_norm = np.full((num_frames,), np.nan, dtype=float)
    mean_torque_jump_ratio = np.full((num_frames,), np.nan, dtype=float)
    if robot_dof_torque is not None:
        if int(robot_dof_torque.shape[0]) != num_frames:
            raise ValueError(
                "robot_dof_torque frame length mismatch: "
                f"{robot_dof_torque.shape[0]} vs {num_frames}"
            )
        mean_dof_torque = np.linalg.norm(robot_dof_torque, axis=1)
    # Torque jumps prefer the low-level stream; its samples are averaged
    # down to one value per frame.
    chatter_torque = robot_low_level_dof_torque
    if chatter_torque is None:
        chatter_torque = robot_dof_torque
    if chatter_torque is not None and int(chatter_torque.shape[0]) > 1:
        torque_jump_norm, torque_jump_ratio = _compute_torque_jump_series(
            chatter_torque, robot_control_dt
        )
        mean_torque_jump_norm = _aggregate_sample_metric_to_frames(
            torque_jump_norm, num_frames
        )
        mean_torque_jump_ratio = _aggregate_sample_metric_to_frames(
            torque_jump_ratio, num_frames
        )
    mean_action_rate = np.full((num_frames,), np.nan, dtype=float)
    if robot_action_rate is not None:
        flat_action_rate = robot_action_rate.reshape(-1)
        if int(flat_action_rate.shape[0]) != num_frames:
            raise ValueError(
                "robot_action_rate frame length mismatch: "
                f"{flat_action_rate.shape[0]} vs {num_frames}"
            )
        mean_action_rate = flat_action_rate

    # Frame DataFrame (align lengths by padding NaN at the start where needed)
    def pad_front(x: np.ndarray, length: int) -> np.ndarray:
        """Left-pad ``x`` with NaN so that it has ``length`` entries.

        Padding up to ``num_frames`` (rather than by a fixed 1 or 2 entries)
        keeps 1-frame clips, whose acceleration series is empty, and 0-frame
        clips from producing columns of unequal length.
        """
        pad = length - int(x.shape[0])
        if pad <= 0:
            return x
        return np.concatenate(
            [np.full((pad,), np.nan, dtype=float), x], axis=0
        )

    df = pd.DataFrame(
        {
            "motion_key": [motion_key] * num_frames,
            "frame_idx": np.arange(num_frames, dtype=int),
            "mpjpe_g": mpjpe_g,
            "mpjpe_l": mpjpe_l,
            "mpjpe_pa": mpjpe_pa,
            "vel_dist": pad_front(vel_dist, num_frames),
            "accel_dist": pad_front(accel_dist, num_frames),
            "frame_max_body_pos_err": frame_max_body_pos_err,
            "whole_body_joints_dist": whole_body_joints_dist,
            "root_r_error": root_r_error,
            "root_p_error": root_p_error,
            "root_y_error": root_y_error,
            "root_vel_error": pad_front(root_vel_err, num_frames),
            "root_height_error": root_height_error,
            "mean_dof_vel": mean_dof_vel,
            "mean_dof_acc": mean_dof_acc,
            "mean_dof_torque": mean_dof_torque,
            "mean_torque_jump_norm": mean_torque_jump_norm,
            "mean_torque_jump_ratio": mean_torque_jump_ratio,
            "mean_action_rate": mean_action_rate,
        }
    )
    return df
def offline_evaluate_dumped_npzs(
    npz_dir: str,
    output_json_path: str,
    failure_pos_err_thresh_m: float = 0.25,
    metric_calculation: str = "per_clip",
    dof_mode: str = "29",
    threadpool_max_workers: Optional[int] = None,
) -> Dict[str, dict]:
    """Evaluate dumped NPZs in `npz_dir` and write a JSON summary to `output_json_path`.

    The function produces dataset-wide averages and per-clip averages across
    frames. Besides the JSON summary it writes `metric.log`,
    `per_clip_metrics.csv` and `whole_dataset_metrics.tsv` into `npz_dir`.

    Args:
        npz_dir: Directory holding the dumped `*.npz` clips (created if missing).
        output_json_path: Destination path of the JSON summary.
        failure_pos_err_thresh_m: A clip counts as a success when the maximum
            of its `frame_max_body_pos_err` column is <= this value.
        metric_calculation: "per_frame" aggregates dataset statistics over all
            frames; any other value aggregates over per-clip summaries.
        dof_mode: Forwarded to the per-frame / per-clip metric helpers.
        threadpool_max_workers: Worker count for the per-file thread pool.
            None means min(num_files, 24); must be > 0 when given.

    Returns:
        The result dict that was written as JSON, or `{}` when no file could
        be processed.

    Raises:
        FileNotFoundError: No `*.npz` file exists in `npz_dir`.
        ValueError: `threadpool_max_workers` is <= 0.
    """
    npz_dir_abs = Path(npz_dir).resolve()
    os.makedirs(npz_dir_abs, exist_ok=True)
    # Add file handler for logging to metric.log. The handler id is kept so
    # the sink is removed when this call ends; otherwise every call (e.g. each
    # iteration of a batch loop) would leave its sink attached and all later
    # log output would also be appended to earlier directories' metric.log.
    metric_log_path = npz_dir_abs / "metric.log"
    log_sink_id = logger.add(
        str(metric_log_path),
        format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
        level="INFO",
    )
    try:
        logger.info(f"Input NPZ directory (absolute): {npz_dir_abs}")
        # Gather NPZ files
        files = sorted(glob(os.path.join(npz_dir_abs, "*.npz")))
        if len(files) == 0:
            raise FileNotFoundError(f"No NPZ files found in: {npz_dir_abs}")
        # Accumulate per-frame metrics
        frame_tables: List[pd.DataFrame] = []
        clip_meta: Dict[str, dict] = {}
        skipped_files_count = 0
        required_keys = [
            "ref_dof_pos",
            "ref_dof_vel",
            "ref_global_translation",
            "ref_global_rotation_quat",
            "ref_global_velocity",
            "ref_global_angular_velocity",
            "robot_dof_pos",
            "robot_dof_vel",
            "robot_global_translation",
            "robot_global_rotation_quat",
            "robot_global_velocity",
            "robot_global_angular_velocity",
        ]
        optional_keys = [
            "robot_dof_acc",
            "robot_dof_torque",
            "robot_low_level_dof_torque",
            "robot_low_level_torque_dt",
            "robot_low_level_foot_contact",
            "robot_low_level_foot_normal_force",
            "robot_low_level_foot_tangent_speed",
            "robot_low_level_contact_dt",
            "robot_action_rate",
            "robot_moe_expert_indices",
            "robot_moe_expert_logits",
        ]

        def _compute_metrics_from_file(fpath: str):
            """Compute the per-frame table and clip summary for one NPZ file.

            Returns `(fpath, df_frames, motion_key, clip_meta_entry, None)` on
            success, or `(fpath, None, None, None, error)` when one of the
            caught exception types is raised, so a bad file is skipped instead
            of aborting the whole evaluation.
            """
            try:
                # NOTE(review): allow_pickle=True is only safe for NPZ files
                # produced by our own dump step; do not point this at
                # untrusted data.
                with np.load(fpath, allow_pickle=True) as npz_data:
                    # Extract arrays and metadata
                    data = {k: npz_data[k] for k in required_keys}
                    for k in optional_keys:
                        if k in npz_data.files:
                            data[k] = npz_data[k]
                    metadata = _parse_metadata_entry(npz_data.get("metadata"))
                robot_control_dt = _extract_robot_control_dt(metadata, data)
                low_level_contact_dt = _extract_low_level_contact_dt(
                    metadata, data, robot_control_dt
                )
                motion_key = os.path.splitext(os.path.basename(fpath))[0]
                clip_len_from_name = _parse_clip_len_from_name(fpath)
                df_frames = _per_frame_metrics_from_npz(
                    motion_key=motion_key,
                    data=data,
                    dof_mode=dof_mode,
                    robot_control_dt=robot_control_dt,
                )
                chatter_summary = _compute_clip_torque_jump_summary(
                    data=data, dof_mode=dof_mode, torque_dt=robot_control_dt
                )
                stability_summary = _compute_clip_stability_summary(
                    data=data,
                    robot_control_dt=robot_control_dt,
                    low_level_contact_dt=low_level_contact_dt,
                )
                # Clip-level info and failure criterion (max body-link pos error > threshold)
                num_frames_clip = int(df_frames.shape[0])
                clip_length = int(
                    metadata.get(
                        "clip_length", clip_len_from_name or num_frames_clip
                    )
                )
                max_body_err = float(
                    np.nanmax(df_frames["frame_max_body_pos_err"].to_numpy())
                )
                success = 1.0 if max_body_err <= failure_pos_err_thresh_m else 0.0
                clip_meta_entry = {
                    "motion_key": motion_key,
                    "num_frames": num_frames_clip,
                    "clip_length": clip_length,
                    "success": success,
                    "max_body_pos_err": max_body_err,
                    "failure_threshold_m": float(failure_pos_err_thresh_m),
                    **chatter_summary,
                    **stability_summary,
                }
                return fpath, df_frames, motion_key, clip_meta_entry, None
            except (ValueError, KeyError, BadZipFile, EOFError, OSError) as e:
                return fpath, None, None, None, e

        if threadpool_max_workers is None:
            max_workers = max(1, min(len(files), 24))
            requested_workers = None
        else:
            requested_workers = int(threadpool_max_workers)
            if requested_workers <= 0:
                raise ValueError("threadpool_max_workers must be > 0")
            max_workers = min(requested_workers, len(files))
        if max_workers <= 0:
            max_workers = 1
        logger.info(
            f"Metric ThreadPoolExecutor max_workers={max_workers} "
            f"(requested={requested_workers}, num_npz_files={len(files)})"
        )
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(_compute_metrics_from_file, fpath): file_idx
                for file_idx, fpath in enumerate(files)
            }
            # Results are stored by file index so downstream order matches
            # the sorted file list, independent of completion order.
            processed_results = [None] * len(files)
            for future in tqdm(
                as_completed(futures.keys()),
                total=len(files),
                desc="Compute metrics from NPZs",
            ):
                processed_results[futures[future]] = future.result()
        for result in processed_results:
            (
                fpath,
                df_frames,
                motion_key,
                clip_meta_entry,
                file_error,
            ) = result
            if file_error is not None:
                logger.warning(f"\nCaught an error while processing file: {fpath}")
                logger.warning(f"Error type: {type(file_error).__name__}")
                logger.warning(f"Error message: {file_error}")
                logger.warning("This file will be SKIPPED.")
                skipped_files_count += 1
                continue
            frame_tables.append(df_frames)
            clip_meta[motion_key] = clip_meta_entry
        if skipped_files_count > 0:
            logger.info(
                f"\nFinished processing. Skipped a total of {skipped_files_count} files due to errors."
            )
        # If all files were skipped, there's nothing to process further.
        if not frame_tables:
            logger.error(
                "No valid NPZ files could be processed. Aborting evaluation."
            )
            return {}
        # Concatenate per-frame metrics
        all_frames = pd.concat(frame_tables, ignore_index=True)
        # Per-clip averages
        frame_metric_cols = [
            "mpjpe_g",
            "mpjpe_l",
            "whole_body_joints_dist",
            "root_vel_error",
            "root_r_error",
            "root_p_error",
            "root_y_error",
            "root_height_error",
            "mean_dof_vel",
            "mean_dof_acc",
            "mean_dof_torque",
            "mean_torque_jump_norm",
            "mean_torque_jump_ratio",
            "mean_action_rate",
        ]
        percentile_metric_cols = [
            "mean_torque_jump_norm",
            "mean_torque_jump_ratio",
        ]
        percentile_rename_map = {
            "mean_torque_jump_norm": "p95_torque_jump_norm",
            "mean_torque_jump_ratio": "p95_torque_jump_ratio",
        }
        metric_cols = frame_metric_cols + list(percentile_rename_map.values())
        clip_only_metric_cols = [
            "torque_chatter_hf_ratio",
            "torque_jump_burst_max",
            "expert_switching_js_div",
            "torso_rp_hf_ratio",
            "torso_rp_angacc_p95",
            "foot_contact_toggle_rate",
            "foot_impact_force_p95",
            "stance_slip_speed_p95",
        ]
        metric_cols += clip_only_metric_cols
        # Metric display configuration: metric_key -> (display_name, unit)
        metric_display_map = {
            "mpjpe_g": ("Global Bodylink Mean Position Error", "mm"),
            "mpjpe_l": ("Local Bodylink Mean Position Error", "mm"),
            "whole_body_joints_dist": ("DOF Position Error", "rad"),
            "root_vel_error": ("Root Velocity Error", "m/s"),
            "root_r_error": ("Root Roll Error", "rad"),
            "root_p_error": ("Root Pitch Error", "rad"),
            "root_y_error": ("Root Yaw Error", "rad"),
            "root_height_error": ("Root Height Error", "mm"),
            "mean_dof_vel": ("Mean DOF Velocity", "rad/s"),
            "mean_dof_acc": ("Mean DOF Acceleration", "rad/s^2"),
            "mean_dof_torque": ("Mean DOF Torque", "N*m"),
            "mean_torque_jump_norm": ("Mean Torque Jump Norm", "N*m/s"),
            "p95_torque_jump_norm": ("P95 Torque Jump Norm", "N*m/s"),
            "mean_torque_jump_ratio": ("Mean Torque Jump Ratio", "ratio"),
            "p95_torque_jump_ratio": ("P95 Torque Jump Ratio", "ratio"),
            "mean_action_rate": ("Mean Action Rate", "1/s"),
            "torque_chatter_hf_ratio": ("Torque Chatter HF Ratio", "ratio"),
            "torque_jump_burst_max": ("Torque Jump Burst Max", "ratio"),
            "expert_switching_js_div": ("Expert Switching JS Div", "bits"),
            "torso_rp_hf_ratio": ("Torso RP HF Ratio", "ratio"),
            "torso_rp_angacc_p95": ("Torso RP Angular Accel P95", "rad/s^2"),
            "foot_contact_toggle_rate": ("Foot Contact Toggle Rate", "1/s"),
            "foot_impact_force_p95": ("Foot Impact Force P95", "N"),
            "stance_slip_speed_p95": ("Stance Slip Speed P95", "m/s"),
        }
        per_clip_mean = (
            all_frames.groupby("motion_key")[frame_metric_cols]
            .mean(numeric_only=True)
            .reset_index()
        )
        per_clip_p95 = (
            all_frames.groupby("motion_key")[percentile_metric_cols]
            .quantile(0.95)
            .reset_index()
            .rename(columns=percentile_rename_map)
        )
        per_clip_summary = per_clip_mean.merge(
            per_clip_p95, on="motion_key", how="left"
        )
        # Torque-jump statistics are taken from the clip-level summary
        # (overriding the frame-derived values), as are the clip-only metrics.
        for metric_key in (
            "mean_torque_jump_norm",
            "p95_torque_jump_norm",
            "mean_torque_jump_ratio",
            "p95_torque_jump_ratio",
        ):
            per_clip_summary[metric_key] = per_clip_summary["motion_key"].map(
                {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta}
            )
        for metric_key in clip_only_metric_cols:
            per_clip_summary[metric_key] = per_clip_summary["motion_key"].map(
                {mk: clip_meta[mk].get(metric_key, np.nan) for mk in clip_meta}
            )
        # Merge with success flags
        per_clip_records = []
        for _, row in per_clip_summary.iterrows():
            mk = row["motion_key"]
            rec = {**row.to_dict(), **clip_meta.get(mk, {})}
            per_clip_records.append(rec)
        # Persist per-clip metrics as a tabular CSV for easier downstream analysis.
        per_clip_df = pd.DataFrame(per_clip_records)
        output_csv_path = str(npz_dir_abs / "per_clip_metrics.csv")
        per_clip_df.to_csv(output_csv_path, index=False)
        logger.info(f"Saved per-clip metrics CSV to: {output_csv_path}")
        dataset_means = {}
        dataset_medians = {}
        if metric_calculation == "per_frame":
            agg_source = all_frames
            agg_desc = "PER-FRAME"
        else:
            agg_source = per_clip_summary
            agg_desc = "PER-CLIP"
        for k in metric_cols:
            # Clip-only / percentile columns do not exist per frame; fall back
            # to the per-clip summary for them.
            if k in agg_source.columns:
                arr = agg_source[k].to_numpy()
            else:
                arr = per_clip_summary[k].to_numpy()
            dataset_means[k] = _safe_nanmean(arr)
            dataset_medians[k] = _safe_nanmedian(arr)
        success_rate = float(
            np.mean([clip_meta[mk]["success"] for mk in clip_meta])
            if len(clip_meta) > 0
            else 0.0
        )
        dataset_means["success_rate"] = success_rate
        # Compose result and write
        result = {
            "dataset": {
                "calculation_mode": metric_calculation,
                "mean": dataset_means,
                "median": dataset_medians,
                "success_rate": success_rate,
            },
            "num_clips": int(len(clip_meta)),
            "per_clip": per_clip_records,
        }
        with open(output_json_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        # Conversion factors for unit conversion (assuming 50Hz)
        frame_rate_hz = 50.0
        unit_conversions = {
            "root_height_error": 1000.0,  # m to mm
            "root_vel_error": frame_rate_hz,  # m/frame to m/s
        }

        def _fmt_value(v):
            """Format floats with 4 decimals; everything else via str()."""
            return f"{v:.4f}" if isinstance(v, float) else str(v)

        table_data = []
        # Iterate through metric_display_map to preserve order
        for key, (display_name, unit) in metric_display_map.items():
            if key not in dataset_means:
                continue
            val_mean = dataset_means[key]
            val_median = dataset_medians[key]
            # Apply unit conversion if needed
            if key in unit_conversions:
                factor = unit_conversions[key]
                val_mean = val_mean * factor
                val_median = val_median * factor
            table_data.append(
                [display_name, _fmt_value(val_mean), _fmt_value(val_median), unit]
            )
        table_headers = ["Metric", "Mean", "Median", "Unit"]
        output_tsv_path = str(npz_dir_abs / "whole_dataset_metrics.tsv")
        with open(output_tsv_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f, delimiter="\t", lineterminator="\n")
            writer.writerow(table_headers)
            writer.writerows(table_data)
        logger.info(f"Saved whole-dataset metrics TSV to: {output_tsv_path}")
        table_str = tabulate(
            table_data,
            headers=table_headers,
            tablefmt="simple_outline",
            colalign=("left", "left", "left", "left"),
        )
        logger.info(
            "\n"
            + "=" * 80
            + f"\nDATASET-WISE METRICS ({agg_desc})\n"
            + "=" * 80
            + f"\n\n{table_str}\n"
            + "=" * 80
            + "\n"
        )
        return result
    finally:
        # Detach this call's metric.log sink on every exit path (normal
        # return, early `return {}`, or an exception).
        logger.remove(log_sink_id)
def parse_ckpt_and_dataset_from_eval_dirname(
    eval_dir_name: str, dataset_suffix: str
):
    """Extract the checkpoint number from an evaluation output directory name.

    Expected shape: `<isaaclab|mujoco>_eval_output_<...>model_<N>[_]<suffix>`.

    Args:
        eval_dir_name: Directory base name (not a full path).
        dataset_suffix: Dataset suffix the name must end with (may be "").

    Returns:
        `(None, None)` if the prefix is not recognised or the name does not end
        with `dataset_suffix`; `(None, dataset_suffix)` if the part before the
        suffix does not end in `model_<digits>`; otherwise
        `(checkpoint_digits_as_str, dataset_suffix)`.
    """
    VALID_PREFIXES = ["isaaclab_eval_output_", "mujoco_eval_output_"]
    matched_prefix = next(
        (p for p in VALID_PREFIXES if eval_dir_name.startswith(p)), None
    )
    if matched_prefix is None:
        return None, None
    rest = eval_dir_name[len(matched_prefix) :]
    if not rest.endswith(dataset_suffix):
        return None, None
    # Use an explicit end index: `rest[: -len(suffix)]` would be `rest[:0]`
    # (an empty string) when the suffix is empty, because -0 == 0.
    model_part = rest[: len(rest) - len(dataset_suffix)]
    if model_part.endswith("_"):
        model_part = model_part[:-1]
    m = re.search(r"model_(\d+)$", model_part)
    if not m:
        return None, dataset_suffix
    return m.group(1), dataset_suffix
def run_evaluation(
    npz_dir: str,
    dataset_suffix: str,
    failure_pos_err_thresh_m: float = 0.25,
    metric_calculation: str = "per_clip",
    dof_mode: str = "29",
    threadpool_max_workers: Optional[int] = None,
):
    """Run offline evaluation for one eval directory or a whole tree of them.

    If `npz_dir` itself is an `isaaclab_eval_output_*` / `mujoco_eval_output_*`
    directory that contains `*.npz` files, only that directory is evaluated
    and its JSON is written into it as `<model>_<ckpt>_<dof_mode>dof.json`.
    Otherwise `npz_dir` is treated as a root: every directory matching
    `**/*eval_output_*_<dataset_suffix>` is evaluated and its JSON is written
    to `<npz_dir>/metrics_output_<dataset_suffix>/<model>_<ckpt>.json`, where
    `<model>` is the eval directory's parent directory name.

    Args:
        npz_dir (str): Single eval directory, or top-level directory containing
            model evaluation results (e.g. 'logs/test').
        dataset_suffix (str): Dataset suffix that eval directory names end with.
        failure_pos_err_thresh_m (float): The position error threshold in
            meters to determine a failure.
        metric_calculation (str): "per_clip" or "per_frame"; forwarded to
            `offline_evaluate_dumped_npzs`.
        dof_mode (str): "29" or "23"; forwarded to `offline_evaluate_dumped_npzs`.
        threadpool_max_workers (Optional[int]): Forwarded to
            `offline_evaluate_dumped_npzs`.

    Returns:
        None. Results are written to disk and logged.
    """
    root_path = Path(npz_dir)
    logger.info(f"Starting batch evaluation. Root directory: '{root_path}'")
    logger.info(
        f"Searching for directories matching pattern: '{dataset_suffix}'"
    )

    def has_npz_files(path: Path) -> bool:
        return path.is_dir() and any(path.glob("*.npz"))

    is_single_eval_dir = (
        root_path.is_dir()
        and (
            root_path.name.startswith("isaaclab_eval_output_")
            or root_path.name.startswith("mujoco_eval_output_")
        )
        and has_npz_files(root_path)
    )
    if is_single_eval_dir:
        output_path = root_path
    else:
        output_path = root_path / f"metrics_output_{dataset_suffix}"
    output_path.mkdir(parents=True, exist_ok=True)
    if is_single_eval_dir:
        logger.info(
            f"Detected '{root_path}' as a single evaluation directory. "
            "Running offline evaluation only for this directory."
        )
        model_name = root_path.parent.name
        ckpt_str, _ = parse_ckpt_and_dataset_from_eval_dirname(
            root_path.name, dataset_suffix
        )
        if ckpt_str is None:
            logger.warning(
                f"Could not parse checkpoint/dataset from directory name '{root_path.name}'. "
                "Using 'checkpoint_unknown' in output filename."
            )
            ckpt_str = "checkpoint_unknown"
        output_json_name = f"{model_name}_{ckpt_str}_{dof_mode}dof.json"
        output_json_path = output_path / output_json_name
        offline_evaluate_dumped_npzs(
            npz_dir=str(root_path),
            output_json_path=str(output_json_path),
            failure_pos_err_thresh_m=failure_pos_err_thresh_m,
            metric_calculation=metric_calculation,
            dof_mode=dof_mode,
            threadpool_max_workers=threadpool_max_workers,
        )
        logger.success(
            f"Finished single-directory evaluation: model='{model_name}', checkpoint={ckpt_str}"
        )
        return
    logger.info(
        f"Treating '{root_path}' as root directory for batch evaluation."
    )
    # Find all directories matching the evaluation output pattern.
    eval_dirs = sorted(
        p
        for p in root_path.glob(f"**/*eval_output_*_{dataset_suffix}")
        if p.is_dir()
    )
    if not eval_dirs:
        logger.error(
            f"No directories matching the pattern '{dataset_suffix}' found under '{root_path}'. "
            "Please check the path and pattern."
        )
        return
    all_results = []
    # Output file names already written in this run. An isaaclab_ and a
    # mujoco_ eval dir under the same parent (or two parents sharing a name)
    # would otherwise map to the same JSON and silently overwrite each other.
    used_output_names = set()
    # Process each found evaluation directory.
    for eval_dir in tqdm(eval_dirs, desc="Overall Progress"):
        # Extract model name from the parent directory.
        model_name = eval_dir.parent.name
        # Parse the checkpoint number from the directory name.
        ckpt_str, ds = parse_ckpt_and_dataset_from_eval_dirname(
            eval_dir.name, dataset_suffix
        )
        if ckpt_str is None:
            logger.warning(
                f"Could not parse ckpt/dataset from '{eval_dir.name}'. Skipping."
            )
            continue
        checkpoint = int(ckpt_str)
        logger.info(
            f"\n--- Processing: model='{model_name}', dataset='{ds}', checkpoint={checkpoint} ---"
        )
        # Construct a unique output JSON filename.
        output_json_name = f"{model_name}_{checkpoint}.json"
        if output_json_name in used_output_names:
            sim_prefix = eval_dir.name.split("_eval_output_", 1)[0]
            base_name = f"{model_name}_{checkpoint}_{sim_prefix}"
            output_json_name = f"{base_name}.json"
            dup_idx = 1
            while output_json_name in used_output_names:
                output_json_name = f"{base_name}_{dup_idx}.json"
                dup_idx += 1
            logger.warning(
                f"Output name for '{eval_dir}' collides with an earlier "
                f"evaluation; writing to '{output_json_name}' instead."
            )
        used_output_names.add(output_json_name)
        output_json_path = output_path / output_json_name
        # Call the evaluation function for the current directory.
        result = offline_evaluate_dumped_npzs(
            npz_dir=str(eval_dir),
            output_json_path=str(output_json_path),
            failure_pos_err_thresh_m=failure_pos_err_thresh_m,
            metric_calculation=metric_calculation,
            dof_mode=dof_mode,
            threadpool_max_workers=threadpool_max_workers,
        )
        if result and "dataset" in result:
            # Collect dataset-level average metrics for the final summary.
            flat_result = {
                "model": model_name,
                "checkpoint": checkpoint,
                **result["dataset"],
            }
            all_results.append(flat_result)
            logger.success(
                f"--- Finished processing: model='{model_name}', checkpoint={checkpoint} ---"
            )
        else:
            logger.error(
                f"--- Failed to process: model='{model_name}', checkpoint={checkpoint} ---"
            )
    if not all_results:
        logger.error(
            "No evaluations succeeded. Cannot generate a summary report."
        )
        return
    logger.info("\n" + "=" * 80)
    logger.info("Batch evaluation finished successfully.")
    logger.info(f"Total successful evaluations: {len(all_results)}")
    logger.info("=" * 80)
if __name__ == "__main__":
    # Command-line entry point. Each option's dest is spelled exactly like a
    # run_evaluation keyword argument, so the parsed namespace is forwarded
    # as keyword arguments in one go.
    cli = argparse.ArgumentParser()
    cli.add_argument("--npz_dir", type=str, required=True)
    cli.add_argument("--dataset_suffix", type=str, required=True)
    cli.add_argument("--failure_pos_err_thresh_m", type=float, default=0.25)
    cli.add_argument(
        "--metric_calculation",
        type=str,
        choices=["per_clip", "per_frame"],
        default="per_clip",
        help="Calculation mode for dataset metrics. 'per_clip' averages clip means (Macro). 'per_frame' averages all frames (Micro).",
    )
    cli.add_argument(
        "--dof_mode",
        type=str,
        choices=["29", "23"],
        default="29",
        help="Compute metrics for full 29 DoF or reduced 23 DoF (excluding hands).",
    )
    cli.add_argument(
        "--threadpool_max_workers",
        type=int,
        default=None,
        help="Max workers for per-NPZ ThreadPoolExecutor. "
        "Default: None (auto = min(num_files, 24)).",
    )
    run_evaluation(**vars(cli.parse_args()))
================================================
FILE: holomotion/src/evaluation/multi_model_metrics_report.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import argparse
import itertools
import json
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from scipy.stats import mannwhitneyu
import textwrap
# Per-clip metrics compared across models in the markdown report.
DEFAULT_METRICS_TO_ANALYZE = [
    "mpjpe_g",
    "mpjpe_l",
    "whole_body_joints_dist",
    "root_vel_error",
    "root_r_error",
    "root_p_error",
    "root_y_error",
    "root_height_error",
]
# Metrics drawn as radar-chart axes: the first four entries above.
RADAR_METRICS = DEFAULT_METRICS_TO_ANALYZE[:4]
# Radar axis label per metric; identity mapping, so labels are the metric keys.
DEFAULT_RADAR_MAPPING = dict(zip(RADAR_METRICS, RADAR_METRICS))
# Significance level used with the two-sided Mann-Whitney U tests.
DEFAULT_ALPHA = 0.05
class AnalysisReportGenerator:
    """Load per-clip JSON metrics, run analysis, generate plots + markdown report.

    Every `*.json` file in `json_dir` is treated as one model, named after the
    file stem, and must contain a non-empty `per_clip` list of records. Outputs
    (radar chart, per-metric KDE plots and a timestamped markdown report) are
    written to `plots_dir`.
    """

    def __init__(
        self,
        json_dir: str,
        plots_dir: str,
        dataset_name: str,
        metrics_to_analyze: List[str],
        radar_metric_mapping: Dict[str, str],
        metric_types_for_radar: Dict[str, str],
        alpha: float = DEFAULT_ALPHA,
        plot_quantile_cutoff: float = 0.99,
        kde_linewidth: float = 2.5,
        min_normalized_value: float = 0.2,
        radar_chart_filename: str = "radar_chart_comparison.png",
    ) -> None:
        """Store configuration; no files are read until `run()`.

        Args:
            json_dir: Directory with one evaluation JSON per model.
            plots_dir: Output directory for plots and the markdown report.
            dataset_name: Used in titles and the report file name.
            metrics_to_analyze: Per-clip columns compared in the report.
            radar_metric_mapping: metric column -> radar axis label.
            metric_types_for_radar: metric column -> "lower" or another value;
                "lower" (the default for missing keys) means smaller is better.
            alpha: Significance level for Mann-Whitney U tests.
            plot_quantile_cutoff: KDE plots drop values above this quantile.
            kde_linewidth: Line width of KDE curves.
            min_normalized_value: Radial floor of radar-chart values.
            radar_chart_filename: File name of the radar chart in `plots_dir`.
        """
        self.json_dir = Path(json_dir)
        self.plots_dir = Path(plots_dir)
        self.dataset_name = dataset_name
        self.metrics_to_analyze = metrics_to_analyze
        self.radar_metric_mapping = radar_metric_mapping.copy()
        self.metric_types_for_radar = metric_types_for_radar.copy()
        self.alpha = alpha
        self.plot_quantile_cutoff = plot_quantile_cutoff
        self.kde_linewidth = kde_linewidth
        self.min_normalized_value = min_normalized_value
        self.radar_chart_filename = radar_chart_filename
        self.df: Optional[pd.DataFrame] = None
        self.models: List[str] = []

    def run(self) -> None:
        """Load data, draw the radar chart, and write the markdown report."""
        self.plots_dir.mkdir(exist_ok=True, parents=True)
        self.df = self._load_and_prepare_data()
        if self.df is None or self.df.empty:
            logger.warning("No valid data loaded; aborting analysis.")
            return
        self.models = sorted(self.df["model"].unique().tolist())
        if len(self.models) < 1:
            logger.warning("No models found in data; aborting analysis.")
            return
        self._create_matplotlib_radar_chart()
        markdown_content = self._generate_markdown_report()
        if not markdown_content:
            logger.warning(
                "Markdown report is empty: comparisons need at least two models."
            )
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_md = (
            self.plots_dir / f"analysis_report_{self.dataset_name}_{ts}.md"
        )
        out_md.write_text(markdown_content, encoding="utf-8")
        logger.info(f"Markdown report written to: {out_md}")

    def _load_and_prepare_data(self) -> Optional[pd.DataFrame]:
        """Concatenate the `per_clip` records of every JSON file in `json_dir`.

        A `model` column (the JSON file stem) is added to each record. Files
        that are not valid JSON, are not evaluation outputs, or have an empty
        `per_clip` list are skipped with a warning.

        Returns:
            One row per clip per model, or None when nothing could be loaded.
        """
        if not self.json_dir.is_dir():
            logger.error(f"json_dir '{self.json_dir}' is not a directory.")
            return None
        json_files = list(self.json_dir.glob("*.json"))
        if not json_files:
            logger.error(f"No .json files found in '{self.json_dir}'.")
            return None
        all_clips: List[Dict[str, Any]] = []
        num_loaded_files = 0
        for jf in json_files:
            model_name = jf.stem
            try:
                data = json.loads(jf.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
                # One malformed file should not abort the whole report.
                logger.warning(f"Skipping unreadable JSON file '{jf.name}': {exc}")
                continue
            if not isinstance(data, dict) or "per_clip" not in data:
                logger.warning(
                    f"Skipping non-eval JSON file '{jf.name}' "
                    f"(top-level type={type(data)}, has_per_clip={'per_clip' in data if isinstance(data, dict) else False})."
                )
                continue
            per_clip = data.get("per_clip")
            if not per_clip:
                logger.warning(
                    f"File '{jf.name}' has empty 'per_clip'; skipping."
                )
                continue
            for clip in per_clip:
                clip["model"] = model_name
                all_clips.append(clip)
            num_loaded_files += 1
        if not all_clips:
            logger.error("No per_clip data found in any JSON files.")
            return None
        df = pd.DataFrame(all_clips)
        logger.info(
            f"Loaded {len(df)} clip records from {num_loaded_files} JSON files."
        )
        return df

    def _create_kde_plot(self, metric: str, save_path: Path) -> bool:
        """Save a per-model KDE plot of `metric` to `save_path`.

        Values above the `plot_quantile_cutoff` quantile are dropped first.

        Returns:
            True if a plot was written, False if the metric is missing or
            entirely null.
        """
        if self.df is None or metric not in self.df.columns:
            return False
        if self.df[metric].isnull().all():
            return False
        q_high = self.df[metric].quantile(self.plot_quantile_cutoff)
        df_filtered = self.df[self.df[metric] <= q_high]
        plt.style.use("seaborn-v0_8-whitegrid")
        fig, ax = plt.subplots(figsize=(12, 7))
        sns.kdeplot(
            data=df_filtered,
            x=metric,
            hue="model",
            hue_order=self.models,
            ax=ax,
            fill=False,
            common_norm=False,
            palette="tab10",
            linewidth=self.kde_linewidth,
        )
        ax.set_title(
            f'Error Distribution for "{metric}" on {self.dataset_name}',
            fontsize=16,
            weight="bold",
        )
        ax.set_xlabel(f"Error Value ({metric})", fontsize=12)
        ax.set_ylabel("Density", fontsize=12)
        ax.set_xlim(left=0)
        legend = ax.get_legend()
        if legend:
            legend.set_title("Models", prop={"size": 14, "weight": "bold"})
            for text in legend.get_texts():
                text.set_fontsize(14)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        return True

    def _create_matplotlib_radar_chart(self) -> None:
        """Draw a radar chart of per-model medians of the radar metrics.

        Medians are rounded to 2 decimals, then min-max normalised per metric
        into [min_normalized_value, 1] so that "better" is always further out
        (inverted for lower-is-better metrics). Metrics missing from the data
        are left out with a warning.
        """
        if self.df is None:
            return
        original_metrics = [
            m for m in self.radar_metric_mapping if m in self.df.columns
        ]
        missing_metrics = [
            m for m in self.radar_metric_mapping if m not in self.df.columns
        ]
        if missing_metrics:
            logger.warning(
                f"Radar metrics missing from data and skipped: {missing_metrics}"
            )
        if not original_metrics:
            logger.warning("No radar metrics available; skipping radar chart.")
            return
        raw_labels = [self.radar_metric_mapping[m] for m in original_metrics]
        display_labels = [
            textwrap.fill(label, width=20, break_long_words=False)
            for label in raw_labels
        ]
        num_metrics = len(original_metrics)
        median_df = self.df.groupby("model")[original_metrics].median()
        rounded_median_df = median_df.round(2)
        normalized_df = pd.DataFrame(
            index=self.models, columns=original_metrics, dtype=float
        )
        scale = 1.0 - self.min_normalized_value
        for metric in original_metrics:
            medians = rounded_median_df[metric].dropna()
            if medians.empty:
                normalized_df[metric] = self.min_normalized_value
                continue
            min_val, max_val = medians.min(), medians.max()
            # Avoid division by zero when all models share the same median.
            rng = max_val - min_val if max_val > min_val else 1.0
            for model in self.models:
                val = rounded_median_df.loc[model, metric]
                if pd.isna(val):
                    normalized_df.loc[model, metric] = (
                        self.min_normalized_value
                    )
                    continue
                lower_better = (
                    self.metric_types_for_radar.get(metric, "lower") == "lower"
                )
                if lower_better:
                    base = (max_val - val) / rng
                else:
                    base = (val - min_val) / rng
                norm = self.min_normalized_value + base * scale
                normalized_df.loc[model, metric] = norm
        angles = np.linspace(
            0, 2 * np.pi, num_metrics, endpoint=False
        ).tolist()
        # Repeat the first angle to close the polygon.
        angles += angles[:1]
        fig, ax = plt.subplots(
            figsize=(10, 10), subplot_kw=dict(projection="polar")
        )
        cmap = plt.get_cmap("tab10")
        colors = {m: cmap(i % 10) for i, m in enumerate(self.models)}
        for model in self.models:
            vals = normalized_df.loc[model].tolist()
            vals += vals[:1]
            ax.fill(angles, vals, color=colors[model], alpha=0.25)
            ax.plot(
                angles,
                vals,
                color=colors[model],
                linewidth=2.5,
                label=model,
                marker="o",
                markersize=7,
                markeredgecolor="white",
                markeredgewidth=1,
            )
        for j, metric in enumerate(original_metrics):
            angle = angles[j]
            # Models with identical rounded medians share one value label.
            groups: Dict[str, List[float]] = defaultdict(list)
            for model in self.models:
                orig_val = rounded_median_df.loc[model, metric]
                norm_val = normalized_df.loc[model, metric]
                groups[f"{orig_val:.2f}"].append(norm_val)
            for label_text, norm_vals in groups.items():
                avg_norm = float(np.mean(norm_vals))
                offset = 0.05
                ax.text(
                    angle,
                    avg_norm + offset,
                    label_text,
                    ha="center",
                    va="center",
                    color="black",
                    weight="bold",
                    fontsize=9,
                    bbox=dict(
                        boxstyle="square,pad=0.3",
                        fc="white",
                        ec="none",
                        alpha=0.8,
                    ),
                )
        ax.set_thetagrids(np.degrees(angles[:-1]), display_labels, fontsize=16)
        ax.tick_params(axis="x", pad=30)
        ax.set_rgrids([0.4, 0.6, 0.8, 1.0], labels=[])
        ax.set_ylim(0, 1.25)
        ax.spines["polar"].set_visible(False)
        ax.grid(color="grey", linestyle="--", linewidth=0.5)
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            legend_map = dict(zip(labels, handles))
            ordered_labels = [m for m in self.models if m in legend_map]
            ordered_handles = [legend_map[m] for m in ordered_labels]
            ax.legend(
                handles=ordered_handles,
                labels=ordered_labels,
                loc="upper center",
                bbox_to_anchor=(0.5, 1.15),
                ncol=len(ordered_handles),
                fontsize=14,
                frameon=False,
            )
        fig.suptitle(
            f"Model Comparison on {self.dataset_name} Dataset",
            fontsize=20,
            weight="bold",
            y=1.05,
        )
        save_path = self.plots_dir / self.radar_chart_filename
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"Radar chart saved to: {save_path}")

    def _generate_markdown_report(self) -> str:
        """Build the markdown report comparing models metric by metric.

        For each metric: a median/IQR table, pairwise two-sided Mann-Whitney U
        findings at `alpha`, and an embedded KDE plot (saved next to the
        report in `plots_dir`). Returns "" when fewer than two models exist.
        """
        if self.df is None or len(self.models) < 2:
            return ""
        parts: List[str] = [
            f"**Dataset**: {self.dataset_name}",
            f"**Models**: {', '.join(self.models)}",
            f"**Significance level (alpha)**: {self.alpha}",
            "### Pairwise metric comparisons and distributions",
        ]
        two_models = len(self.models) == 2
        if two_models:
            model1, model2 = self.models[0], self.models[1]
        for metric in self.metrics_to_analyze:
            if metric not in self.df.columns:
                continue
            p_value_str = ""
            if two_models:
                d1 = self.df.loc[self.df["model"] == model1, metric].dropna()
                d2 = self.df.loc[self.df["model"] == model2, metric].dropna()
                if not d1.empty and not d2.empty:
                    _, p_val = mannwhitneyu(d1, d2, alternative="two-sided")
                    p_value_str = f" (p = {p_val:.3g})"
            parts.append(f"#### Metric: `{metric}`{p_value_str}")
            metric_stats: List[Dict[str, Any]] = []
            for name in self.models:
                data = self.df.loc[self.df["model"] == name, metric].dropna()
                if data.empty:
                    continue
                metric_stats.append(
                    {
                        "Model": name,
                        "Median": data.median(),
                        "Q1 (25%)": data.quantile(0.25),
                        "Q3 (75%)": data.quantile(0.75),
                    }
                )
            if metric_stats:
                stats_df = (
                    pd.DataFrame(metric_stats)
                    .sort_values(by="Median")
                    .reset_index(drop=True)
                )
                parts.append(stats_df.to_markdown(index=False, floatfmt=".4f"))
            findings: List[str] = []
            lower_better = (
                self.metric_types_for_radar.get(metric, "lower") == "lower"
            )
            for m1, m2 in itertools.combinations(self.models, 2):
                d1 = self.df.loc[self.df["model"] == m1, metric].dropna()
                d2 = self.df.loc[self.df["model"] == m2, metric].dropna()
                if d1.empty or d2.empty:
                    continue
                _, p_val = mannwhitneyu(d1, d2, alternative="two-sided")
                if p_val >= self.alpha:
                    continue
                m1_med, m2_med = d1.median(), d2.median()
                better, worse = (m1, m2) if m1_med < m2_med else (m2, m1)
                if not lower_better:
                    better, worse = worse, better
                findings.append(
                    f"- **{better}** is significantly better than **{worse}** "
                    f"(p < {self.alpha})."
                )
            if findings:
                parts.append("\n".join(findings))
            else:
                parts.append(
                    "No statistically significant differences between models."
                )
            safe_metric = metric.replace(" ", "_")
            plot_filename = f"{safe_metric}.png"
            plot_path = self.plots_dir / plot_filename
            # The report is written into plots_dir as well, so a bare file
            # name is a valid relative image link.
            if self._create_kde_plot(metric, plot_path):
                plot_md = f"![{metric} distribution]({plot_filename})"
            else:
                plot_md = "_No data available to plot._"
            parts.append(f"##### Distribution plot\n{plot_md}")
        return "\n\n".join(parts)
def parse_args() -> argparse.Namespace:
    """Parse CLI options from sys.argv.

    Returns:
        Namespace with a single required `json_dir` string.
    """
    cli = argparse.ArgumentParser(
        description="Analyze per-clip JSON metrics, generate plots and markdown report."
    )
    cli.add_argument("--json_dir", type=str, required=True)
    namespace = cli.parse_args()
    return namespace
if __name__ == "__main__":
    cli_args = parse_args()
    resolved_json_dir = Path(cli_args.json_dir).resolve()
    # Directory names look like "metrics_output_<dataset>"; strip the prefix
    # (only when something remains after it) to recover the dataset name.
    metrics_prefix = "metrics_output_"
    dataset_name = resolved_json_dir.name  # e.g. "AMASS"
    if dataset_name.startswith(metrics_prefix) and len(dataset_name) > len(
        metrics_prefix
    ):
        dataset_name = dataset_name[len(metrics_prefix) :]
    plots_dir = resolved_json_dir / f"analysis_plots_{dataset_name}"
    analyzer = AnalysisReportGenerator(
        json_dir=cli_args.json_dir,
        plots_dir=str(plots_dir),
        dataset_name=dataset_name,
        metrics_to_analyze=DEFAULT_METRICS_TO_ANALYZE,
        radar_metric_mapping=DEFAULT_RADAR_MAPPING,
        # Every analysed metric is an error: lower is better.
        metric_types_for_radar=dict.fromkeys(DEFAULT_METRICS_TO_ANALYZE, "lower"),
        alpha=DEFAULT_ALPHA,
    )
    analyzer.run()
================================================
FILE: holomotion/src/evaluation/obs/__init__.py
================================================
from .obs_builder import PolicyObsBuilder, get_gravity_orientation
# Public names of the obs package, re-exported from .obs_builder.
__all__ = ["PolicyObsBuilder", "get_gravity_orientation"]
================================================
FILE: holomotion/src/evaluation/obs/obs_builder.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import numpy as np
import torch
from typing import Dict, List, Sequence, Any, Optional
def get_gravity_orientation(quaternion: np.ndarray) -> np.ndarray:
    """Compute the gravity projection for an orientation quaternion.

    For the identity quaternion the result is [0, 0, -1].

    Args:
        quaternion: Array-like in [w, x, y, z] order; entries 0..3 are read.

    Returns:
        float32 np.ndarray of shape (3,) representing gravity projection.
    """
    qw, qx, qy, qz = (float(quaternion[i]) for i in range(4))
    return np.array(
        [
            2.0 * (qw * qy - qz * qx),
            -2.0 * (qw * qx + qz * qy),
            1.0 - 2.0 * (qw * qw + qz * qz),
        ],
        dtype=np.float32,
    )
class _CircularBuffer:
"""History buffer for batched tensor data (batch==1 in our eval/deploy).
Stores history in oldest->newest order when accessed via .buffer.
"""
def __init__(self, max_len: int, feat_dim: int, device: str):
if max_len < 1:
raise ValueError(f"max_len must be >= 1, got {max_len}")
self._max_len = int(max_len)
self._feat_dim = int(feat_dim)
self._device = device
self._pointer = -1
self._num_pushes = 0
self._buffer: torch.Tensor = torch.zeros(
(self._max_len, 1, self._feat_dim),
dtype=torch.float32,
device="cpu",
)
@property
def buffer(self) -> torch.Tensor:
"""Tensor of shape [1, max_len, feat_dim], oldest->newest along dim=1."""
if self._num_pushes == 0:
raise RuntimeError(
"Attempting to read from an empty history buffer."
)
# roll such that oldest is at index=0 along the history axis
rolled = torch.roll(
self._buffer, shifts=self._max_len - self._pointer - 1, dims=0
)
return torch.transpose(rolled, 0, 1) # [1, max_len, feat]
def append(self, data: torch.Tensor) -> None:
"""Append one step: data shape [1, feat_dim] on the configured device."""
if (
data.ndim != 2
or data.shape[0] != 1
or data.shape[1] != self._feat_dim
):
raise ValueError(
f"Expected data with shape [1, {self._feat_dim}], got {tuple(data.shape)}"
)
self._pointer = (self._pointer + 1) % self._max_len
self._buffer[self._pointer] = data
if self._num_pushes == 0:
# duplicate first push across entire history for warm start
self._buffer[:] = data
self._num_pushes += 1
class PolicyObsBuilder:
    """Builds policy observations with temporal history.

    Designed to be shared between MuJoCo sim2sim evaluation and ROS2 deployment.
    Each configured term `<name>` is produced by calling
    `evaluator._get_obs_<name>()`, multiplied by the term's `scale`, and, when
    the term has `history_length > 0`, pushed into a per-term history buffer.
    The returned flat vector concatenates, in config order, either the full
    flattened history (`flatten=True`, default), the newest history step
    (`flatten=False`), or the current value for terms without history.

    The original design notes list two command modes:
    - "motion_tracking": uses reference motion states
    - "velocity_tracking": uses velocity commands [vx, vy, vyaw]
    (these are provided by the evaluator's term methods, not by this class).
    """

    def __init__(
        self,
        dof_names_onnx: Sequence[str],
        default_angles_onnx: np.ndarray,
        evaluator: Optional[Any] = None,
        obs_policy_cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Set up the term schema; history buffers are created lazily.

        Args:
            dof_names_onnx: Joint names in policy (ONNX) order.
            default_angles_onnx: Default joint angles, one per name.
            evaluator: Object exposing `_get_obs_<term>()` methods.
            obs_policy_cfg: Dict with an `atomic_obs_list` entry: a list of
                `{term_name: term_cfg}` dicts, where term_cfg may set `scale`,
                `history_length` and `flatten`. None means no terms.

        Raises:
            ValueError: `default_angles_onnx` length differs from the number
                of joint names.
        """
        self.dof_names_onnx: List[str] = list(dof_names_onnx)
        self.num_actions: int = len(self.dof_names_onnx)
        self.evaluator = evaluator
        self.obs_policy_cfg = obs_policy_cfg
        if default_angles_onnx.shape[0] != self.num_actions:
            raise ValueError(
                "default_angles_onnx length must match num actions"
            )
        self.default_angles_onnx = default_angles_onnx.astype(np.float32)
        self.default_angles_dict: Dict[str, float] = {
            name: float(self.default_angles_onnx[idx])
            for idx, name in enumerate(self.dof_names_onnx)
        }
        # Build observation schema from config if provided. The term config
        # dicts are copied so the caller's config is never mutated.
        atomic_obs_list = (
            []
            if self.obs_policy_cfg is None
            else self.obs_policy_cfg["atomic_obs_list"]
        )
        self.term_specs: List[Dict[str, Any]] = []
        for term_group in atomic_obs_list:
            for name, cfg in term_group.items():
                spec = {**cfg}
                spec["name"] = name
                self.term_specs.append(spec)
        # Buffers are created lazily after first dimension inference
        self._buffers: Dict[str, _CircularBuffer] = {}

    def reset(self) -> None:
        """Clear all history buffers; the next step re-warm-starts them."""
        for buf in self._buffers.values():
            buf._pointer = -1
            buf._num_pushes = 0
            buf._buffer.zero_()

    def _compute_term(
        self,
        name: str,
    ) -> np.ndarray:
        """Return `evaluator._get_obs_<name>()` as a flat float32 array.

        Raises:
            ValueError: No evaluator, or the evaluator lacks that method.
        """
        # Prefer evaluator-provided methods; no legacy fallbacks
        if self.evaluator is not None:
            meth = getattr(self.evaluator, f"_get_obs_{name}", None)
            if callable(meth):
                out = meth()
                return np.asarray(out, dtype=np.float32).reshape(-1)
        raise ValueError(
            f"Unknown observation term '{name}' or evaluator method missing."
        )

    def build_policy_obs(self) -> np.ndarray:
        """Append one step using evaluator-provided observation terms and return flattened obs."""
        # Compute per-term outputs
        values: Dict[str, np.ndarray] = {}
        for spec in self.term_specs:
            name = spec["name"]
            scale = spec.get("scale", 1.0)
            values[name] = self._compute_term(name) * scale
        # Lazily initialize buffers with inferred feature dims
        if len(self._buffers) == 0:
            for spec in self.term_specs:
                name = spec["name"]
                hist_len = int(spec.get("history_length", 0))
                if hist_len <= 0:
                    continue
                feat_dim = int(values[name].reshape(-1).shape[0])
                self._buffers[name] = _CircularBuffer(
                    hist_len, feat_dim, "cpu"
                )
        # Append current step to buffers (skip terms without history)
        for spec in self.term_specs:
            name = spec["name"]
            if name in self._buffers:
                item = torch.as_tensor(
                    values[name].reshape(1, -1),
                    dtype=torch.float32,
                    device="cpu",
                )
                self._buffers[name].append(item)
        # Assemble flat list according to term ordering and history flatten rules
        flat_list: List[np.ndarray] = []
        for spec in self.term_specs:
            name = spec["name"]
            flatten = bool(spec.get("flatten", True))
            if name in self._buffers:
                buf = self._buffers[name].buffer[0]  # [hist, feat]
                arr = (
                    buf.reshape(-1).detach().cpu().numpy()
                    if flatten
                    else buf[-1].detach().cpu().numpy()
                )
                flat_list.append(arr.astype(np.float32))
            else:
                # no history -> use computed value directly
                flat_list.append(values[name].reshape(-1).astype(np.float32))
        if len(flat_list) == 0:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(flat_list, axis=0).astype(np.float32)
================================================
FILE: holomotion/src/evaluation/ray_evaluator_actor.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Minimal Ray actor for batch eval. Lives in its own module so the class
can be pickled without pulling in torch/jit from eval_mujoco_sim2sim.
"""
import importlib
import os
import sys
import ray
from loguru import logger
class RayEvaluatorActor:
    """Persistent Ray actor: one evaluator (one ONNX session) per actor.

    Schedule with num_gpus=1/actors_per_gpu so that multiple actors share one GPU.
    Ray sets CUDA_VISIBLE_DEVICES so this actor sees a single GPU as device 0.
    """

    def __init__(self, config_dict, output_dir):
        """Create and set up this actor's evaluator.

        Args:
            config_dict: Evaluation config. ``model_type`` selects the
                evaluator flavour and falls back to ``"holomotion"`` when
                missing or falsy.
            output_dir: Directory that receives the ``*_eval.npz`` results.
        """
        # Keep actor logs quiet: only WARNING and above reach stderr.
        logger.remove()
        logger.add(sys.stderr, level="WARNING")
        self.output_dir = output_dir
        self.config_dict = config_dict
        # Resolved once so run_clip uses the same type the evaluator was built with.
        self.model_type = config_dict.get("model_type") or "holomotion"
        self.evaluator = _load_ray_evaluator(config_dict, self.model_type)
        self.evaluator.setup()
        if self.model_type == "gmt":
            self.evaluator.gmt_proprio_buf.clear()

    @staticmethod
    def _eval_save_name(fname):
        """Map ``clip.npz`` to ``clip_eval.npz``; other names are returned unchanged.

        Only a trailing ``.npz`` suffix is rewritten, so a ``.npz`` substring
        elsewhere in the name is left alone.
        """
        suffix = ".npz"
        if fname.endswith(suffix):
            return fname[: -len(suffix)] + "_eval.npz"
        return fname

    def run_clip(self, file_path):
        """Roll the evaluator through every frame of one clip and save results.

        Args:
            file_path: Path of the motion clip to evaluate.

        Returns:
            The string ``"success"`` once the result file has been written.
        """
        # Deferred import so pickling this class does not pull in eval_mujoco_sim2sim.
        from holomotion.src.evaluation.eval_mujoco_sim2sim import (
            _build_onnx_io_dump_dir,
            _build_onnx_io_dump_path,
        )

        fname = os.path.basename(file_path)
        save_path = os.path.join(self.output_dir, self._eval_save_name(fname))
        self.evaluator.load_specific_motion(file_path)
        self.evaluator.reset_state_teleport()
        for i in range(self.evaluator.n_motion_frames):
            self.evaluator.motion_frame_idx = i
            self.evaluator._update_policy()
            self.evaluator._apply_control(sleep=False)
            self.evaluator.counter += 1
        # Source/model are stored under two key names each; presumably for
        # different downstream consumers (TODO confirm before dropping either).
        meta = {
            "source_file": fname,
            "model": str(self.config_dict.get("ckpt_onnx_path", "")),
            "source_npz": fname,
            "onnx_model": str(self.config_dict.get("ckpt_onnx_path", "")),
        }
        self.evaluator.save_batch_result(save_path, meta)
        if bool(self.config_dict.get("dump_onnx_io_npy", False)) and (
            self.model_type == "holomotion"
        ):
            onnx_io_dir = _build_onnx_io_dump_dir(self.output_dir)
            os.makedirs(onnx_io_dir, exist_ok=True)
            self.evaluator.save_onnx_io_dump(
                _build_onnx_io_dump_path(self.output_dir, fname), meta
            )
        return "success"
def _load_ray_evaluator(config_dict, model_type):
module_name = config_dict.get(
"ray_evaluator_module",
"holomotion.src.evaluation.eval_mujoco_sim2sim",
)
factory_module = importlib.import_module(module_name)
return factory_module._create_ray_evaluator(config_dict, model_type)
================================================
FILE: holomotion/src/evaluation/ray_metrics_postprocess.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import shutil
import sys
from pathlib import Path
from typing import Any
import ray
from loguru import logger
@ray.remote
def run_metrics_postprocess_job(
    output_dir: str,
    dataset_name: str,
    calc_per_clip_metrics: bool,
    failure_pos_err_thresh_m: float,
    metric_calculation: str,
    dof_mode: str,
    metrics_threadpool_max_workers: int | None,
    generate_report: bool,
    job_log_dir: str | None,
    ckpt_stem: str,
) -> dict[str, Any]:
    """Ray task: compute metrics and a report for one checkpoint's eval output.

    Args:
        output_dir: Directory holding the per-clip eval results.
        dataset_name: Passed to ``run_evaluation`` as ``dataset_suffix``.
        calc_per_clip_metrics: Run ``run_evaluation`` over ``output_dir``.
        failure_pos_err_thresh_m, metric_calculation, dof_mode,
        metrics_threadpool_max_workers: Forwarded to ``run_evaluation``.
        generate_report: Build the macro-mean report from ``output_dir``.
        job_log_dir: If given, ``sub_dataset_macro_mean_metrics.tsv`` (when
            present in ``output_dir``) is copied there as
            ``<ckpt_stem>_sub_dataset_macro_mean_metrics.tsv``. The
            directory is created if needed.
        ckpt_stem: Checkpoint identifier echoed back and used in file names.

    Returns:
        Dict with ``ckpt_stem``, ``output_dir``, ``report_path`` and
        ``exported_summary_tsv``; the last two are ``""`` when not produced.
    """
    logger.remove()
    logger.add(sys.stderr, level="WARNING")
    if calc_per_clip_metrics:
        from holomotion.src.evaluation.metrics import run_evaluation

        run_evaluation(
            npz_dir=output_dir,
            dataset_suffix=dataset_name,
            failure_pos_err_thresh_m=failure_pos_err_thresh_m,
            metric_calculation=metric_calculation,
            dof_mode=dof_mode,
            threadpool_max_workers=metrics_threadpool_max_workers,
        )
    report_path = None
    if generate_report:
        from holomotion.scripts.evaluation import mean_process_5metrics

        report_path = (
            mean_process_5metrics.generate_macro_mean_report_from_json_dir(
                output_dir
            )
        )
    exported_summary_tsv = None
    if job_log_dir is not None:
        job_log_dir_path = Path(job_log_dir)
        sub_dataset_tsv = (
            Path(output_dir) / "sub_dataset_macro_mean_metrics.tsv"
        )
        if sub_dataset_tsv.is_file():
            # copy2 does not create parent directories; make sure it exists.
            job_log_dir_path.mkdir(parents=True, exist_ok=True)
            export_name = f"{ckpt_stem}_sub_dataset_macro_mean_metrics.tsv"
            export_path = job_log_dir_path / export_name
            shutil.copy2(sub_dataset_tsv, export_path)
            exported_summary_tsv = export_path
    return {
        "ckpt_stem": ckpt_stem,
        "output_dir": output_dir,
        "report_path": str(report_path) if report_path is not None else "",
        "exported_summary_tsv": str(exported_summary_tsv)
        if exported_summary_tsv is not None
        else "",
    }
================================================
FILE: holomotion/src/modules/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/modules/agent_modules.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import io
import copy
import math
from pathlib import Path
import holomotion.src.modules.network_modules as NM
import torch
import torch.nn as nn
import torch.nn.functional as F
from holomotion.src.modules.network_modules import EmpiricalNormalization
from loguru import logger
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from torch.distributions import Normal
def _module_device(module: nn.Module) -> torch.device:
for tensor in module.parameters():
return tensor.device
for tensor in module.buffers():
return tensor.device
return torch.device("cpu")
def _clone_module_for_cpu_export(module: nn.Module) -> nn.Module:
"""Clone a module for CPU-side export without mutating the live module."""
buffer = io.BytesIO()
# Keep the training module on-device; rank-local device hops during export
# can desynchronize DDP state and hang later collectives.
torch.save(module, buffer)
buffer.seek(0)
clone = torch.load(buffer, map_location="cpu", weights_only=False)
clone = clone.to("cpu")
clone.eval()
return clone
class TensorDictAssembler(torch.nn.Module):
    """Assemble schema-listed TensorDict terms into one tensor.

    ``schema_config`` maps a group name to ``{"seq_len": int, "terms": [...]}``
    (``seq_len`` defaults to 1). Each term is looked up in the input
    TensorDict; hierarchical keys such as ``"group/term"`` walk nested
    TensorDicts. A term must be ``[B, d]`` (allowed only when its group's
    ``seq_len`` is 1) or ``[B, seq_len, d]``.

    Output modes:
    - ``"flat"``: each term becomes ``[B, seq_len * d]``; all are
      concatenated to ``[B, D]``.
    - ``"seq"``: each term becomes ``[B, seq_len, d]``; all are concatenated
      on the last axis to ``[B, seq_len, D]``. All groups must share one
      ``seq_len``.

    ``output_dim`` is cached after the first successful forward pass.
    """

    def __init__(self, schema_config: dict, *, output_mode: str = "flat"):
        super().__init__()
        self.schema_config = schema_config
        self.output_mode = str(output_mode).lower()
        if self.output_mode not in ("flat", "seq"):
            raise ValueError(
                f"output_mode must be one of {{'flat','seq'}}, got {output_mode}"
            )
        self.seq_len_dict: dict[str, int] = {
            str(k): int(v.get("seq_len", 1)) for k, v in schema_config.items()
        }
        # Shared seq_len when every group agrees, otherwise None.
        _uniq_lens = sorted(set(self.seq_len_dict.values()))
        self.seq_len: int | None = (
            int(_uniq_lens[0]) if len(_uniq_lens) == 1 else None
        )
        if self.output_mode == "seq" and self.seq_len is None:
            raise ValueError(
                "TensorDictAssembler(output_mode='seq') requires a single unique seq_len "
                f"across schema groups, got seq_len_dict={self.seq_len_dict}"
            )
        self.output_dim: int | None = None

    @staticmethod
    def _get_from_data(data: TensorDict, key: str):
        """Return ``data[key]``, walking ``/``-separated nested keys; None if absent."""
        # Support hierarchical keys like "latent/z"
        if key in data.keys():
            return data.get(key)
        if "/" in key:
            current = data
            for p in key.split("/"):
                if isinstance(current, TensorDict) and p in current.keys():
                    current = current.get(p)
                else:
                    return None
            return current
        return None

    def _lookup_term(self, data: TensorDict, term: str) -> torch.Tensor:
        """Fetch ``term`` from ``data``, raising ``KeyError`` if it is missing."""
        tensor = self._get_from_data(data, term)
        if tensor is None:
            raise KeyError(
                f"Missing term '{term}' in TensorDict input for assembler. "
                "Use explicit hierarchical terms (e.g. 'group/term') "
                "for nested TensorDict keys."
            )
        return tensor

    @staticmethod
    def _check_batch_size(tensor: torch.Tensor, batch_size, term: str):
        """Return the batch size to carry forward; raise on a mismatch."""
        if batch_size is None:
            return tensor.shape[0]
        if tensor.shape[0] != batch_size:
            raise ValueError(
                f"Batch size mismatch for term '{term}': {tensor.shape[0]} vs {batch_size}"
            )
        return batch_size

    def _validate_to_seq(
        self,
        tensor: torch.Tensor,
        seq_len: int,
        term: str,
    ) -> torch.Tensor:
        """Return [B, seq_len, d] tensor."""
        if tensor.ndim == 2:
            # [B, d] treat as seq_len=1
            if seq_len != 1:
                raise ValueError(
                    f"Term '{term}' expected seq_len={seq_len} but tensor is 2D {tensor.shape}"
                )
            return tensor[:, None, :]
        if tensor.ndim == 3:
            if tensor.shape[1] != seq_len:
                raise ValueError(
                    f"Term '{term}' seq_len mismatch: expected {seq_len}, got {tensor.shape[1]}"
                )
            return tensor
        raise ValueError(
            f"Term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}"
        )

    def _validate_and_flatten(
        self,
        tensor: torch.Tensor,
        seq_len: int,
        term: str,
    ) -> torch.Tensor:
        """Return [B, seq_len * d]; valid 2D inputs are returned unchanged."""
        seq_tensor = self._validate_to_seq(tensor, seq_len, term)
        if tensor.ndim == 2:
            return tensor
        b, t, d = seq_tensor.shape
        return seq_tensor.reshape(b, t * d)

    def _assemble_flat(self, data: TensorDict) -> torch.Tensor:
        """Concatenate every term, flattened per sample, into [B, D]."""
        assembled = []
        batch_size = None
        for seq_cfg in self.schema_config.values():
            seq_len = int(seq_cfg.get("seq_len", 1))
            for term in seq_cfg.get("terms", []):
                flat = self._validate_and_flatten(
                    self._lookup_term(data, term), seq_len, term
                )
                batch_size = self._check_batch_size(flat, batch_size, term)
                assembled.append(flat)
        if not assembled:
            raise ValueError(
                "Assembler received an empty schema or no tensors found"
            )
        out = torch.cat(assembled, dim=-1)
        # Cache output_dim on first successful forward
        if self.output_dim is None:
            self.output_dim = int(out.shape[-1])
        return out

    def _assemble_seq(self, data: TensorDict) -> torch.Tensor:
        """Concatenate every term along features into [B, seq_len, D]."""
        assembled_seq = []
        batch_size = None
        seq_len_ref = None
        for seq_cfg in self.schema_config.values():
            seq_len = int(seq_cfg.get("seq_len", 1))
            if seq_len_ref is None:
                seq_len_ref = seq_len
            elif seq_len != seq_len_ref:
                raise ValueError(
                    "TensorDictAssembler(output_mode='seq') requires consistent seq_len "
                    f"across schema groups, got {seq_len_ref} vs {seq_len}"
                )
            for term in seq_cfg.get("terms", []):
                seq_tensor = self._validate_to_seq(
                    self._lookup_term(data, term), seq_len, term
                )
                batch_size = self._check_batch_size(
                    seq_tensor, batch_size, term
                )
                assembled_seq.append(seq_tensor)
        if not assembled_seq:
            raise ValueError(
                "Assembler received an empty schema or no tensors found"
            )
        out = torch.cat(assembled_seq, dim=-1)
        # Expose seq_len and output_dim for sequence assembly
        if self.seq_len is None:
            self.seq_len = int(out.shape[1])
        if self.output_dim is None:
            self.output_dim = int(out.shape[-1])
        return out

    def forward(self, data: TensorDict) -> torch.Tensor:
        """Assemble ``data`` according to ``output_mode``.

        Raises:
            TypeError: If ``data`` is not a TensorDict.
            KeyError: If a schema term is missing from ``data``.
            ValueError: On shape, seq_len or batch-size mismatches, or an
                empty schema.
        """
        if not isinstance(data, TensorDict):
            raise TypeError("TensorDictAssembler expects TensorDict input.")
        if self.output_mode == "flat":
            return self._assemble_flat(data)
        return self._assemble_seq(data)

    @torch.inference_mode()
    def infer_output_dim(self, sample: TensorDict) -> int:
        """Run a dry forward pass to populate output_dim without grads."""
        if self.output_dim is not None:
            return int(self.output_dim)
        _ = self.forward(sample)
        return self.output_dim
class PPOActorOnnxModule(nn.Module):
    """Export wrapper: optional observation normalization + clipping, then the actor.

    Args:
        actor_module: Network that maps (normalized) flat observations to actions.
        obs_normalizer: Module exposing ``normalize_only``; consulted only when
            ``obs_norm_enabled`` is true.
        obs_norm_enabled: Whether to normalize observations first.
        obs_norm_clip: Symmetric clip applied after normalization; values
            ``<= 0`` disable clipping.
    """

    def __init__(
        self,
        actor_module: nn.Module,
        obs_normalizer: nn.Module,
        obs_norm_enabled: bool,
        obs_norm_clip: float,
    ):
        super().__init__()
        self.actor_module = actor_module
        self.obs_normalizer = obs_normalizer
        self.obs_norm_enabled = bool(obs_norm_enabled)
        self.obs_norm_clip = float(obs_norm_clip)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        if not self.obs_norm_enabled:
            return self.actor_module(obs)
        normed = self.obs_normalizer.normalize_only(obs)
        bound = self.obs_norm_clip
        if bound > 0.0:
            normed = torch.clamp(normed, -bound, bound)
        return self.actor_module(normed)
class PPOTFActorOnnxModule(nn.Module):
    """Export wrapper for the KV-cached transformer actor.

    Observations are optionally normalized and clipped. The actor is then
    called with the cache (``past_key_values``) and the step index (passed as
    ``current_pos``), and its output is returned unchanged.
    """

    def __init__(
        self,
        actor_module: nn.Module,
        obs_normalizer: nn.Module,
        obs_norm_enabled: bool,
        obs_norm_clip: float,
    ):
        super().__init__()
        self.actor_module = actor_module
        self.obs_normalizer = obs_normalizer
        self.obs_norm_enabled = bool(obs_norm_enabled)
        self.obs_norm_clip = float(obs_norm_clip)

    def forward(
        self,
        obs: torch.Tensor,
        past_key_values: torch.Tensor,
        step_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        model_in = obs
        if self.obs_norm_enabled:
            model_in = self.obs_normalizer.normalize_only(model_in)
            limit = self.obs_norm_clip
            if limit > 0.0:
                model_in = torch.clamp(model_in, -limit, limit)
        return self.actor_module(
            model_in,
            past_key_values=past_key_values,
            current_pos=step_idx,
        )
class PPOTFWoKVCacheActorOnnxModule(nn.Module):
    """Export wrapper for the transformer actor without a KV cache.

    Takes a fixed-length observation window ``[B, token_len, D]``, optionally
    normalizes and clips it, runs ``actor_module.sequence_mu`` over the whole
    window (no attention mask), and returns the last position's actions
    ``[B, A]``.

    Args:
        actor_module: Module exposing ``sequence_mu(obs, attn_mask=None)``.
        obs_normalizer: Module exposing ``normalize_only``.
        obs_norm_enabled: Whether to normalize observations first.
        obs_norm_clip: Symmetric clip after normalization; ``<= 0`` disables it.
        token_len: Required window length (default 32, the previously
            hard-coded value).

    Raises:
        ValueError: If ``token_len`` is not positive.
    """

    def __init__(
        self,
        actor_module: nn.Module,
        obs_normalizer: nn.Module,
        obs_norm_enabled: bool,
        obs_norm_clip: float,
        token_len: int = 32,
    ):
        super().__init__()
        self.actor_module = actor_module
        self.obs_normalizer = obs_normalizer
        self.obs_norm_enabled = bool(obs_norm_enabled)
        self.obs_norm_clip = float(obs_norm_clip)
        self.token_len = int(token_len)
        if self.token_len <= 0:
            raise ValueError(f"token_len must be positive, got {token_len}")

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        if obs.ndim != 3:
            raise ValueError(
                f"Expected obs [B, {self.token_len}, D] for no-kv ONNX path, got {obs.shape}"
            )
        if obs.shape[1] != self.token_len:
            raise ValueError(
                f"Expected fixed token length {self.token_len}, got {int(obs.shape[1])}"
            )
        actor_obs = obs
        if self.obs_norm_enabled:
            actor_obs = self.obs_normalizer.normalize_only(actor_obs)
            if self.obs_norm_clip > 0.0:
                actor_obs = torch.clamp(
                    actor_obs, -self.obs_norm_clip, self.obs_norm_clip
                )
        action_seq = self.actor_module.sequence_mu(actor_obs, attn_mask=None)
        # Only the most recent token's action is emitted.
        return action_seq[:, -1, :]
class PPOCondTFActorOnnxModule(nn.Module):
    """Export wrapper for the conditional transformer actor with KV cache.

    The single flat ``obs`` input has width
    ``state_dim + future_seq_len * future_token_dim``.
    - The first ``state_dim`` columns are the state observation, optionally
      normalized and clipped.
    - The remainder holds future-reference observations laid out term-major:
      all steps of term 1, then all steps of term 2, and so on.

    The future part is rebuilt into ``[B, future_seq_len, future_token_dim]``
    before calling ``actor_module._forward_inference_onnx_cond``.

    Raises:
        ValueError: If a ``future_term_dims`` entry is non-positive or they do
            not sum to ``future_token_dim``.
    """

    def __init__(
        self,
        actor_module: nn.Module,
        state_obs_normalizer: nn.Module,
        obs_norm_enabled: bool,
        obs_norm_clip: float,
        state_dim: int,
        future_seq_len: int,
        future_token_dim: int,
        future_term_dims: list[int],
    ):
        super().__init__()
        self.actor_module = actor_module
        self.state_obs_normalizer = state_obs_normalizer
        self.obs_norm_enabled = bool(obs_norm_enabled)
        self.obs_norm_clip = float(obs_norm_clip)
        self.state_dim = int(state_dim)
        self.future_seq_len = int(future_seq_len)
        self.future_token_dim = int(future_token_dim)
        self.future_term_dims = [int(x) for x in future_term_dims]
        if any(d <= 0 for d in self.future_term_dims):
            raise ValueError(
                f"future_term_dims must be all positive, got {self.future_term_dims}"
            )
        if sum(self.future_term_dims) != self.future_token_dim:
            raise ValueError(
                "future_term_dims sum mismatch: expected "
                f"{self.future_token_dim}, got {sum(self.future_term_dims)}"
            )

    def forward(
        self,
        obs: torch.Tensor,
        past_key_values: torch.Tensor,
        step_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        if obs.ndim != 2:
            raise ValueError(f"Expected obs [B, D], got {obs.shape}")
        state_obs = obs[:, : self.state_dim]
        future_flat = obs[:, self.state_dim :]
        expected_future_dim = self.future_seq_len * self.future_token_dim
        if future_flat.shape[-1] != expected_future_dim:
            raise ValueError(
                "Future obs dim mismatch for ONNX path: expected "
                f"{expected_future_dim}, got {future_flat.shape[-1]}"
            )
        if self.obs_norm_enabled:
            state_obs = self.state_obs_normalizer.normalize_only(state_obs)
            if self.obs_norm_clip > 0.0:
                state_obs = torch.clamp(
                    state_obs, -self.obs_norm_clip, self.obs_norm_clip
                )
        # Reconstruct [B, N_fut, D_fut] from term-major flattened layout:
        # [term1 (N_fut*d1), term2 (N_fut*d2), ...] -> per-step concat along last dim.
        # Keep the batch size as a shape value: wrapping it in int() would bake
        # the example batch size into a traced/exported graph as a constant.
        batch = obs.shape[0]
        offset = 0
        future_parts = []
        for d_term in self.future_term_dims:
            span = self.future_seq_len * d_term
            chunk = future_flat[:, offset : offset + span]
            future_parts.append(
                chunk.reshape(batch, self.future_seq_len, d_term)
            )
            offset += span
        if offset != int(future_flat.shape[-1]):
            raise ValueError(
                "Future flat slicing mismatch in ONNX path: "
                f"consumed={offset}, total={int(future_flat.shape[-1])}"
            )
        future_obs = torch.cat(future_parts, dim=-1)
        return self.actor_module._forward_inference_onnx_cond(
            state_obs,
            future_obs,
            past_key_values,
            step_idx,
        )
class PPOActor(TensorDictModuleBase):
    """Gaussian PPO actor over TensorDict observations.

    Pipeline:
    - The terms listed in ``obs_schema`` are assembled by a
      ``TensorDictAssembler`` (``"seq"`` layout for ``NM.ConvMLP``, ``"flat"``
      otherwise).
    - The assembled observation is optionally normalized and clipped.
    - It is fed to ``NM.<module_config_dict["type"]>`` (default ``"MLP"``),
      which produces the action mean.

    The action std is a learned, observation-independent vector of size
    ``num_actions``. It is stored either directly (``noise_std_type="scalar"``)
    or as log-std (``"log"``), and is clamped to ``[min_sigma, max_sigma]``
    when used.
    """

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Build the assembler, optional normalizer, network and std parameters.

        Args:
            obs_schema: ``{group: {"seq_len": int, "terms": [...]}}``; required.
            module_config_dict: Network config (``type``, ``output_dim``,
                ``obs_norm``, sigma options, ...). The placeholder
                ``output_dim="robot_action_dim"`` is replaced in place by
                ``num_actions``.
            num_actions: Action dimension (length of the std vector).
            init_noise_std: Initial action std.
            obs_example: Example observations used to infer the assembled
                input width; required.

        Raises:
            NotImplementedError: If ``type`` does not name an attribute of ``NM``.
            ValueError: If ``obs_schema``/``obs_example`` is missing or a config
                value is invalid.
        """
        super(PPOActor, self).__init__()
        self.use_logvar = module_config_dict.get("use_logvar", False)
        # A null ``obs_norm`` entry behaves like a missing one.
        obs_norm_cfg = module_config_dict.get("obs_norm", {}) or {}
        self.obs_norm_enabled = bool(obs_norm_cfg.get("enabled", False))
        if self.obs_norm_enabled:
            self.obs_norm_clip = float(obs_norm_cfg.get("clip_range", 0.0))
            self.obs_norm_eps = float(obs_norm_cfg.get("epsilon", 1.0e-8))
            self.obs_norm_update_method = str(
                obs_norm_cfg.get(
                    "update_method", obs_norm_cfg.get("method", "cumulative")
                )
            ).lower()
            # Optional: float(None) previously crashed configs that omit it
            # (e.g. cumulative statistics). When absent, the normalizer's own
            # default is used.
            ema_momentum = obs_norm_cfg.get("ema_momentum", None)
            self.obs_norm_ema_momentum = (
                None if ema_momentum is None else float(ema_momentum)
            )
        module_config_dict = self._process_module_config(
            module_config_dict,
            num_actions,
        )
        self.actor_net_type = module_config_dict.get("type", "MLP")
        logger.info(f"actor_net_type: {self.actor_net_type}")
        actor_net_class = getattr(NM, self.actor_net_type, None)
        if actor_net_class is None:
            available_classes = [
                name
                for name in dir(NM)
                if isinstance(getattr(NM, name, None), type)
            ]
            raise NotImplementedError(
                f"Unknown actor_net_type: {self.actor_net_type}. "
                f"Available classes: {available_classes}"
            )
        if actor_net_class is NM.MLP and obs_schema is None:
            raise ValueError(
                "PPOActor(Mlp) requires obs_schema so the agent module can assemble "
                "TensorDict observations into a flat tensor."
            )
        if obs_schema is not None:
            output_mode = "seq" if actor_net_class is NM.ConvMLP else "flat"
            self.assembler = TensorDictAssembler(
                obs_schema, output_mode=output_mode
            )
            if obs_example is not None:
                self.assembler.infer_output_dim(obs_example)
            if self.assembler.output_dim is None:
                raise ValueError(
                    "TensorDictAssembler could not infer output_dim"
                )
            input_dim_for_net = int(self.assembler.output_dim)
        else:
            raise ValueError("obs_schema can't be None!")
        # TensorDict keys consumed by this module, in schema order.
        actor_in_keys: list[str] = []
        for _, seq_cfg in obs_schema.items():
            if not isinstance(seq_cfg, dict):
                continue
            for term in seq_cfg.get("terms", []):
                actor_in_keys.append(str(term))
        self.in_keys = actor_in_keys
        self.out_keys = [
            "actions",
            "actions_log_prob",
            "mu",
            "sigma",
            "entropy",
        ]
        if self.obs_norm_enabled and self.assembler is not None:
            norm_kwargs = {
                "shape": self.assembler.output_dim,
                "eps": self.obs_norm_eps,
                "update_method": self.obs_norm_update_method,
            }
            if self.obs_norm_ema_momentum is not None:
                norm_kwargs["ema_momentum"] = self.obs_norm_ema_momentum
            self.obs_normalizer = EmpiricalNormalization(**norm_kwargs)
        else:
            self.obs_normalizer = nn.Identity()
        # Always pass obs_example if available
        if obs_example is not None:
            self.actor_module = actor_net_class(
                input_dim=input_dim_for_net,
                output_dim=int(module_config_dict["output_dim"]),
                module_config_dict=module_config_dict,
            )
        else:
            raise ValueError("Obs example can't be None!")
        if "output_head_init_scale" in module_config_dict:
            output_head_init_scale = float(
                module_config_dict["output_head_init_scale"]
            )
            if output_head_init_scale <= 0.0:
                raise ValueError("output_head_init_scale must be > 0.")
            output_head = self.actor_module.output_head
            if not isinstance(output_head, nn.Linear):
                raise ValueError(
                    "output_head_init_scale requires actor_module.output_head to be nn.Linear."
                )
            # Shrink the freshly initialized output layer in place.
            with torch.no_grad():
                output_head.weight.mul_(output_head_init_scale)
                if output_head.bias is not None:
                    output_head.bias.mul_(output_head_init_scale)
        # Networks exposing ``proprio_assembler`` consume the TensorDict directly.
        self._actor_schema_module = bool(
            getattr(self.actor_module, "proprio_assembler", None)
        )
        self.fix_sigma = module_config_dict.get("fix_sigma", False)
        self.max_sigma = module_config_dict.get("max_sigma", 1.0)
        self.min_sigma = module_config_dict.get("min_sigma", 0.1)
        if "noise_std_type" in module_config_dict:
            self.noise_std_type = str(
                module_config_dict["noise_std_type"]
            ).lower()
        elif self.use_logvar:
            self.noise_std_type = "log"
        else:
            self.noise_std_type = "scalar"
        # Action noise parameters (kept outside nets so optimizer updates them)
        if self.noise_std_type == "log":
            logger.info("Using log-std parameterization for action noise")
            self.log_std = nn.Parameter(
                torch.log(torch.ones(num_actions) * init_noise_std)
            )
            if self.fix_sigma:
                self.log_std.requires_grad = False
        else:  # scalar (default)
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            if self.fix_sigma:
                self.std.requires_grad = False
        self.distribution = None
        # Disable distribution argument validation for speed. This is a global
        # switch. It must be *called*: assigning to
        # ``Normal.set_default_validate_args`` only shadows the static method.
        torch.distributions.Distribution.set_default_validate_args(False)
        self.actor_obs_transforms: list[callable] = []
        if self.obs_norm_enabled:
            self.actor_obs_transforms.append(self._normalize_actor_obs)

    def _process_module_config(self, module_config_dict, num_actions):
        """Validate ``output_dim`` and resolve the ``robot_action_dim`` placeholder.

        Mutates and returns ``module_config_dict``.

        Raises:
            ValueError: If ``output_schema`` is set or ``output_dim`` is a list.
        """
        if module_config_dict.get("output_schema", None) is not None:
            raise ValueError(
                "PPOActor no longer supports module_config_dict.output_schema. "
                "Use scalar module_config_dict.output_dim instead."
            )
        # Resolve output_dim placeholders when present.
        if "output_dim" in module_config_dict:
            output_dim = module_config_dict["output_dim"]
            if isinstance(output_dim, list):
                raise ValueError(
                    "PPOActor expects module_config_dict.output_dim to be a scalar. "
                    "List-valued output_dim is not supported."
                )
            if output_dim == "robot_action_dim":
                module_config_dict["output_dim"] = num_actions
        return module_config_dict

    def _sigma_from_params(self) -> torch.Tensor:
        """Return the unclamped std vector from the active parameterization."""
        if self.noise_std_type == "log":
            return torch.exp(self.log_std)
        return self.std

    def _normalize_actor_obs(
        self, obs: torch.Tensor, update: bool
    ) -> torch.Tensor:
        """Normalize (and optionally update stats for) ``obs``, then clip.

        3-D inputs ``[B, T, D]`` are normalized per feature across B*T rows.
        """
        if not self.obs_norm_enabled:
            return obs
        clip = float(self.obs_norm_clip)
        if obs.ndim == 3:
            b, seq_len, d = obs.shape
            flat_obs = obs.reshape(b * seq_len, d)
            if update:
                self.obs_normalizer.update(flat_obs)
            flat_obs = self.obs_normalizer.normalize_only(flat_obs)
            obs = flat_obs.reshape(b, seq_len, d)
        else:
            if update:
                self.obs_normalizer.update(obs)
            obs = self.obs_normalizer.normalize_only(obs)
        if clip > 0.0:
            obs = torch.clamp(obs, -clip, clip)
        return obs

    def _sigma_like(self, like: torch.Tensor) -> torch.Tensor:
        """Return the clamped std broadcast to ``like``'s shape."""
        sigma_vec = self._sigma_from_params()
        sigma_vec = torch.clamp(
            sigma_vec,
            min=float(self.min_sigma),
            max=float(self.max_sigma),
        )
        if sigma_vec.ndim == 1 and like.ndim >= 2:
            view_shape = [1 for _ in range(like.ndim - 1)] + [
                sigma_vec.shape[0]
            ]
            return sigma_vec.view(*view_shape).expand_as(like)
        if sigma_vec.shape != like.shape:
            return sigma_vec.expand_as(like)
        return sigma_vec

    @property
    def actor(self):
        """The underlying policy network."""
        return self.actor_module

    @property
    def flat_obs_dim(self) -> int:
        """Width of the assembled observation fed to the network."""
        if self.assembler is None:
            raise ValueError(
                "PPOActor has no assembler; flat obs dim unavailable."
            )
        if self.assembler.output_dim is None:
            raise ValueError(
                "PPOActor assembler output_dim is not initialized."
            )
        return int(self.assembler.output_dim)

    def export_onnx(
        self,
        onnx_path: str | Path,
        *,
        opset_version: int = 17,
    ) -> str:
        """Export normalization + actor mean to ONNX (input ``obs`` -> ``actions``).

        CPU clones are exported, so the live (possibly on-device) modules are
        not moved.

        Returns:
            The written file path as a string.
        """
        if self._actor_schema_module:
            raise ValueError(
                "PPOActor export expects flat-obs actor modules, not schema-native modules."
            )
        export_path = Path(onnx_path)
        export_path.parent.mkdir(parents=True, exist_ok=True)
        if hasattr(self.actor_module, "clear_router_distribution_cache"):
            self.actor_module.clear_router_distribution_cache()
        actor_module = _clone_module_for_cpu_export(self.actor_module)
        if self.obs_norm_enabled:
            obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer)
        else:
            obs_normalizer = nn.Identity()
        exporter = PPOActorOnnxModule(
            actor_module=actor_module,
            obs_normalizer=obs_normalizer,
            obs_norm_enabled=self.obs_norm_enabled,
            obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0,
        ).to("cpu")
        exporter.eval()
        obs = torch.zeros(
            1, self.flat_obs_dim, device="cpu", dtype=torch.float32
        )
        torch.onnx.export(
            exporter,
            (obs,),
            str(export_path),
            export_params=True,
            opset_version=opset_version,
            verbose=False,
            dynamo=False,
            input_names=["obs"],
            output_names=["actions"],
        )
        return str(export_path)

    def forward(
        self,
        obs_td: TensorDict,
        actions: torch.Tensor | None = None,
        mode: str = "sampling",
        *,
        update_obs_norm: bool = True,
    ) -> TensorDict:
        """TensorDict-first forward for PPOActor.

        Modes:
        - ``"sampling"``: draw ``actions`` from the Gaussian.
        - ``"logp"``: score the provided ``actions`` (required).
        - ``"inference"``: ``actions = mu``; no log-prob or entropy.

        Returns a TensorDict with keys:
            - actions: [B, A]
            - actions_log_prob: [B] (sampling/logp only)
            - mu: [B, A]
            - sigma: [B, A]
            - entropy: [B] (sampling/logp only)
        """
        if mode not in ("sampling", "logp", "inference"):
            raise ValueError(f"Unsupported mode: {mode}")
        if not isinstance(obs_td, TensorDict):
            raise ValueError("PPOActor.forward expects TensorDict input.")
        # Shallow clone: copies the tree structure only, not the tensor data.
        td = obs_td.clone(recurse=False)
        if self._actor_schema_module:
            mu = self.actor_module(obs_td)
        else:
            if self.assembler is None:
                raise ValueError(
                    "Flat-tensor actor module requires obs_schema in PPOActor init."
                )
            actor_obs = self.assembler(obs_td)
            update = bool(update_obs_norm)
            for fn in self.actor_obs_transforms:
                actor_obs = fn(actor_obs, update)
            mu = self.actor_module(actor_obs)
        sigma = self._sigma_like(mu)
        td.set("mu", mu)
        td.set("sigma", sigma)
        if mode == "inference":
            actions_out = mu
            td.set("actions", actions_out)
            return td
        self.distribution = Normal(mu, sigma)
        if mode == "sampling":
            actions_out = self.distribution.sample()
        else:
            if actions is None:
                raise ValueError("actions must be provided when mode='logp'")
            actions_out = actions
        td.set("actions", actions_out)
        td.set(
            "actions_log_prob",
            self.distribution.log_prob(actions_out).sum(dim=-1),
        )
        td.set("entropy", self.distribution.entropy().sum(dim=-1))
        return td

    def update_distribution(self, actor_obs):
        """Set ``self.distribution`` from an already-assembled observation tensor."""
        mean = self.actor(actor_obs)
        # Resolve std according to parameterization
        std_val = self._sigma_from_params()
        std_val = torch.clamp(std_val, min=self.min_sigma, max=self.max_sigma)
        self.distribution = Normal(mean, std_val)

    def override_sigma(self, sigma_override: float | torch.Tensor) -> None:
        """Override actor sigma parameters (std) explicitly.

        Args:
            sigma_override: scalar or [A] tensor for sigma_theta (std).

        Raises:
            ValueError: On an unsupported parameterization, a shape mismatch,
                or non-positive values.
        """
        if self.noise_std_type not in ("scalar", "log"):
            raise ValueError(
                f"Unsupported noise_std_type for override: {self.noise_std_type}"
            )
        param = self.log_std if self.noise_std_type == "log" else self.std
        sigma_tensor = torch.as_tensor(
            sigma_override, device=param.device, dtype=param.dtype
        )
        if sigma_tensor.numel() == 1:
            sigma_tensor = sigma_tensor.expand_as(param)
        elif sigma_tensor.shape != param.shape:
            raise ValueError(
                f"sigma_override shape {tuple(sigma_tensor.shape)} does not match "
                f"actor sigma shape {tuple(param.shape)}."
            )
        if torch.any(sigma_tensor <= 0):
            raise ValueError("sigma_override must be > 0 for all dims.")
        if self.noise_std_type == "log":
            sigma_tensor = torch.log(sigma_tensor)
        with torch.no_grad():
            param.copy_(sigma_tensor)
class PPOCritic(TensorDictModuleBase):
def __init__(
self,
obs_schema: dict | None,
module_config_dict,
*,
obs_example: dict | None = None,
):
super(PPOCritic, self).__init__()
self.critic_net_type = module_config_dict.get("type", "MLP")
obs_norm_cfg = module_config_dict.get("obs_norm", {})
self.obs_norm_enabled = bool(obs_norm_cfg.get("enabled", False))
if self.obs_norm_enabled:
self.obs_norm_clip = float(obs_norm_cfg.get("clip_range", 0.0))
self.obs_norm_eps = float(obs_norm_cfg.get("epsilon", 1.0e-8))
self.obs_norm_update_method = str(
obs_norm_cfg.get(
"update_method", obs_norm_cfg.get("method", "cumulative")
)
).lower()
self.obs_norm_ema_momentum = float(
obs_norm_cfg.get("ema_momentum")
)
critic_net_class = getattr(NM, self.critic_net_type, None)
if critic_net_class is None:
critic_net_class = globals().get(self.critic_net_type, None)
if critic_net_class is None or not isinstance(critic_net_class, type):
available_classes = [
name
for name in dir(NM)
if isinstance(getattr(NM, name, None), type)
] + [
name
for name, obj in globals().items()
if isinstance(obj, type)
]
raise NotImplementedError(
f"Unknown critic_net_type: {self.critic_net_type}. "
f"Available classes: {available_classes}"
)
if critic_net_class is NM.MLP and obs_schema is None:
raise ValueError(
"PPOCritic(MLP) requires obs_schema so the agent module can assemble "
"TensorDict observations into a flat tensor."
)
# Build assembler for flat-tensor networks only
# Schema-based networks (e.g., MultiTaskCritic) don't need it
if obs_schema is not None:
output_mode = "seq" if critic_net_class is NM.ConvMLP else "flat"
self.assembler = TensorDictAssembler(
obs_schema, output_mode=output_mode
)
if obs_example is not None:
self.assembler.infer_output_dim(obs_example)
if self.assembler.output_dim is None:
raise ValueError(
"TensorDictAssembler could not infer output_dim; provide obs_example."
)
input_dim_for_net = int(self.assembler.output_dim)
else:
# Schema-based modules don't use wrapper's assembler
self.assembler = None
input_dim_for_net = 0
critic_in_keys: list[str] = []
if obs_schema is not None:
for _, seq_cfg in obs_schema.items():
if not isinstance(seq_cfg, dict):
continue
for term in seq_cfg.get("terms", []):
critic_in_keys.append(str(term))
self.in_keys = critic_in_keys
self.out_keys = ["values"]
if self.obs_norm_enabled and self.assembler is not None:
self.obs_normalizer = EmpiricalNormalization(
shape=self.assembler.output_dim,
eps=self.obs_norm_eps,
update_method=self.obs_norm_update_method,
ema_momentum=self.obs_norm_ema_momentum,
)
else:
self.obs_normalizer = nn.Identity()
# Always pass obs_example if available
if obs_example is not None:
self.critic_module = critic_net_class(
input_dim=input_dim_for_net,
output_dim=int(module_config_dict["output_dim"]),
module_config_dict=module_config_dict,
)
else:
raise ValueError("obs_schema can't be None!")
self._critic_schema_module = bool(
getattr(self.critic_module, "proprio_assembler", None)
)
self.critic_obs_transforms: list[callable] = []
if self.obs_norm_enabled:
self.critic_obs_transforms.append(self._normalize_critic_obs)
def _normalize_critic_obs(
self, obs: torch.Tensor, update: bool
) -> torch.Tensor:
if not self.obs_norm_enabled:
return obs
clip = float(self.obs_norm_clip)
if obs.ndim == 3:
b, seq_len, d = obs.shape
flat_obs = obs.reshape(b * seq_len, d)
if update:
self.obs_normalizer.update(flat_obs)
flat_obs = self.obs_normalizer.normalize_only(flat_obs)
obs = flat_obs.reshape(b, seq_len, d)
else:
if update:
self.obs_normalizer.update(obs)
obs = self.obs_normalizer.normalize_only(obs)
if clip > 0.0:
obs = torch.clamp(obs, -clip, clip)
return obs
def forward(
    self,
    obs_td: TensorDict,
    update_obs_norm: bool = True,
    **kwargs,
) -> TensorDict:
    """Compute state values from TensorDict observations.

    Schema-aware critic modules receive the TensorDict as-is. Flat-tensor
    critics receive the assembled tensor after every transform in
    ``critic_obs_transforms`` has been applied.

    Args:
        obs_td: TensorDict observations keyed by obs terms.
        update_obs_norm: Passed (as bool) to each critic obs transform;
            False skips updating running normalization stats.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A shallow clone of ``obs_td`` with ``"values"`` set; 1-D value
        outputs are given a trailing singleton dim.

    Raises:
        ValueError: If ``obs_td`` is not a TensorDict, or a flat-tensor
            critic was built without an assembler.
    """
    if not isinstance(obs_td, TensorDict):
        raise ValueError("PPOCritic.forward expects TensorDict input.")
    out_td = obs_td.clone(recurse=False)
    if self._critic_schema_module:
        values = self.critic_module(obs_td)
    else:
        if self.assembler is None:
            raise ValueError(
                "Flat-tensor critic module requires obs_schema in PPOCritic init."
            )
        critic_obs = self.assembler(obs_td)
        do_update = bool(update_obs_norm)
        for transform in self.critic_obs_transforms:
            critic_obs = transform(critic_obs, do_update)
        values = self.critic_module(critic_obs)
    if values.ndim == 1:
        values = values.unsqueeze(-1)
    out_td.set("values", values)
    return out_td
class PPOTFActor(PPOActor):
    """Transformer-based PPO actor wrapper compatible with PPOActor interface.

    - Uses NM.TransformerDecoderPolicy as actor_module
    - Provides KV-cache controls and ONNX export (with or without KV cache)
    - The sequence log-prob path and ``update_distribution`` use a learnable,
      parameter-based diagonal std (``_sigma_from_params``); the single-step
      ``forward`` path uses the parent's ``_sigma_like``.
    """

    # (td_key, aux_preds_key) pairs that are all written when the auxiliary
    # state head produced "base_lin_vel_loc".
    _AUX_STATE_TD_KEYS = (
        ("aux_base_lin_vel_loc", "base_lin_vel_loc"),
        ("aux_base_lin_vel_log_std", "base_lin_vel_log_std"),
        ("aux_root_height_loc", "root_height_loc"),
        ("aux_root_height_log_std", "root_height_log_std"),
        ("aux_keybody_contact_logits", "keybody_contact_logits"),
        ("aux_ref_keybody_rel_pos", "ref_keybody_rel_pos"),
        ("aux_robot_keybody_rel_pos", "robot_keybody_rel_pos"),
    )

    # Optional (td_key, aux_preds_key) pairs, each copied only if present.
    _AUX_OPTIONAL_TD_KEYS = (
        (
            "aux_denoise_ref_root_lin_vel_residual",
            "denoise_ref_root_lin_vel_residual",
        ),
        (
            "aux_denoise_ref_root_ang_vel_residual",
            "denoise_ref_root_ang_vel_residual",
        ),
        ("aux_denoise_ref_dof_pos_residual", "denoise_ref_dof_pos_residual"),
        ("aux_router_command_recon", "router_command_recon"),
        ("aux_router_future_recon", "router_future_recon"),
        ("router_features", "router_features"),
        ("router_temporal_features", "router_temporal_features"),
    )

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Build the actor and read auxiliary-loss switches from the config.

        Args:
            obs_schema: Observation schema forwarded to ``PPOActor``.
            module_config_dict: Actor module config. The ``enabled`` flags of
                the ``aux_state_pred``, ``aux_router_command_recon``,
                ``aux_router_switch_penalty`` and ``aux_router_future_recon``
                sub-dicts are read here.
            num_actions: Action dimension.
            init_noise_std: Initial action std. ``min_sigma``/``max_sigma``
                are widened if this value is not strictly inside them.
            obs_example: Example observations forwarded to ``PPOActor``.
        """
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=module_config_dict,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        # Ensure initial std is strictly inside [min_sigma, max_sigma] to
        # avoid boundary saturation.
        init_std_val = float(init_noise_std)
        if not (self.min_sigma < init_std_val < self.max_sigma):
            # Expand bounds conservatively if needed
            if init_std_val >= self.max_sigma:
                self.max_sigma = max(self.max_sigma, init_std_val * 2.0)
            if init_std_val <= self.min_sigma:
                self.min_sigma = min(self.min_sigma, init_std_val * 0.1)
        aux_cfg = module_config_dict.get("aux_state_pred", {})
        self.aux_state_pred_enabled = bool(aux_cfg.get("enabled", False))
        aux_cmd_cfg = module_config_dict.get("aux_router_command_recon", {})
        self.aux_router_command_recon_enabled = bool(
            aux_cmd_cfg.get("enabled", False)
        )
        aux_switch_cfg = module_config_dict.get(
            "aux_router_switch_penalty", {}
        )
        self.aux_router_switch_penalty_enabled = bool(
            aux_switch_cfg.get("enabled", False)
        )
        aux_router_future_cfg = module_config_dict.get(
            "aux_router_future_recon", {}
        )
        self.aux_router_future_recon_enabled = bool(
            aux_router_future_cfg.get("enabled", False)
        )
        # Assigned by subclasses (e.g. PPOTFRefRouterV3Actor); while None,
        # _maybe_update_aux_router_future_recon_norm is a no-op.
        self.aux_router_future_recon_assembler: TensorDictAssembler | None = (
            None
        )

    def _sigma_from_params(self) -> torch.Tensor:
        """Return the unclamped learnable action std.

        Uses ``exp(log_std)`` when a ``log_std`` attribute exists, otherwise
        ``softplus(std)`` so the result is positive.
        """
        if hasattr(self, "log_std"):
            return torch.exp(self.log_std)
        return F.softplus(self.std)

    def reset_kv_cache(self, num_envs: int, device):
        """Forward to ``actor_module.reset_kv_cache`` if the module has it."""
        if hasattr(self.actor_module, "reset_kv_cache"):
            self.actor_module.reset_kv_cache(num_envs, device)

    def clear_env_cache(self, env_ids: torch.Tensor):
        """Forward to ``actor_module.clear_env_cache`` if the module has it."""
        if hasattr(self.actor_module, "clear_env_cache"):
            self.actor_module.clear_env_cache(env_ids)

    def onnx_past_key_values_shape(
        self, *, batch_size: int = 1
    ) -> tuple[int, int, int, int, int, int]:
        """Shape of the ``past_key_values`` tensor used by the ONNX export.

        Returns ``(num_kv_layers, 2, batch_size, max_ctx_len, n_kv_heads,
        head_dim)``. ``num_kv_layers`` is ``actor_module.onnx_kv_layers``
        when defined, else ``actor_module.n_layers``. The size-2 axis
        presumably indexes key/value (TODO confirm against the exporter).
        """
        num_kv_layers = int(
            getattr(
                self.actor_module, "onnx_kv_layers", self.actor_module.n_layers
            )
        )
        return (
            num_kv_layers,
            2,
            int(batch_size),
            int(self.actor_module.max_ctx_len),
            int(self.actor_module.n_kv_heads),
            int(self.actor_module.head_dim),
        )

    def onnx_moe_layer_indices(self) -> list[int]:
        """Indices of ``actor_module.layers`` that are NM.GroupedMoEBlock."""
        layers = getattr(self.actor_module, "layers", None)
        if layers is None:
            return []
        return [
            layer_idx
            for layer_idx, layer in enumerate(layers)
            if isinstance(layer, NM.GroupedMoEBlock)
        ]

    def onnx_routing_output_names(self) -> list[str]:
        """ONNX output names for per-MoE-layer expert indices and logits."""
        output_names: list[str] = []
        for layer_idx in self.onnx_moe_layer_indices():
            output_names.extend(
                [
                    f"moe_layer_{layer_idx}_expert_indices",
                    f"moe_layer_{layer_idx}_expert_logits",
                ]
            )
        return output_names

    def _maybe_update_aux_router_future_recon_norm(
        self,
        obs_td: TensorDict,
        *,
        update: bool,
    ) -> None:
        """Feed assembled future-reference targets to the module normalizer.

        No-op unless ``update`` is True, future-recon is enabled, and a
        future-recon assembler has been assigned.
        """
        if (
            not update
            or not self.aux_router_future_recon_enabled
            or self.aux_router_future_recon_assembler is None
        ):
            return
        future_target = self.aux_router_future_recon_assembler(obs_td)
        self.actor_module.update_aux_router_future_recon_normalizer(
            future_target
        )

    def export_onnx(
        self,
        onnx_path: str | Path,
        *,
        opset_version: int = 17,
        use_kv_cache: bool = True,
        wo_kv_cache_seq_len: int = 32,
    ) -> str:
        """Export the actor (with its obs normalizer) to an ONNX file.

        Args:
            onnx_path: Destination path; parent directories are created.
            opset_version: ONNX opset passed to ``torch.onnx.export``.
            use_kv_cache: If True, export the single-step graph with inputs
                ``obs``, ``past_key_values`` and ``step_idx``. Otherwise
                export a cache-free graph that takes a window of obs.
            wo_kv_cache_seq_len: Number of steps in the dummy ``obs`` used to
                trace the cache-free graph. No dynamic axes are declared, so
                this fixes the exported window length. Defaults to the
                previously hard-coded value of 32.

        Returns:
            The export path as a string.

        Raises:
            ValueError: If ``use_kv_cache`` is False and
                ``wo_kv_cache_seq_len`` is not positive.
        """
        if not use_kv_cache and int(wo_kv_cache_seq_len) <= 0:
            raise ValueError(
                "wo_kv_cache_seq_len must be positive, got "
                f"{wo_kv_cache_seq_len}."
            )
        export_path = Path(onnx_path)
        export_path.parent.mkdir(parents=True, exist_ok=True)
        if hasattr(self.actor_module, "clear_router_distribution_cache"):
            self.actor_module.clear_router_distribution_cache()
        # Export from CPU copies produced by _clone_module_for_cpu_export
        # rather than from the training modules themselves.
        actor_module = _clone_module_for_cpu_export(self.actor_module)
        if self.obs_norm_enabled:
            obs_normalizer = _clone_module_for_cpu_export(self.obs_normalizer)
        else:
            obs_normalizer = nn.Identity()
        obs_norm_clip = self.obs_norm_clip if self.obs_norm_enabled else 0.0
        if use_kv_cache:
            exporter = PPOTFActorOnnxModule(
                actor_module=actor_module,
                obs_normalizer=obs_normalizer,
                obs_norm_enabled=self.obs_norm_enabled,
                obs_norm_clip=obs_norm_clip,
            ).to("cpu")
            exporter.eval()
            obs = torch.zeros(
                1, self.flat_obs_dim, device="cpu", dtype=torch.float32
            )
            cache_shape = self.onnx_past_key_values_shape(batch_size=1)
            past_key_values = torch.zeros(
                *cache_shape, device="cpu", dtype=torch.float32
            )
            step_idx = torch.tensor([0], dtype=torch.long, device="cpu")
            output_names = [
                "actions",
                "present_key_values",
                *self.onnx_routing_output_names(),
            ]
            torch.onnx.export(
                exporter,
                (obs, past_key_values, step_idx),
                str(export_path),
                export_params=True,
                opset_version=opset_version,
                verbose=False,
                dynamo=False,
                input_names=["obs", "past_key_values", "step_idx"],
                output_names=output_names,
            )
        else:
            exporter = PPOTFWoKVCacheActorOnnxModule(
                actor_module=actor_module,
                obs_normalizer=obs_normalizer,
                obs_norm_enabled=self.obs_norm_enabled,
                obs_norm_clip=obs_norm_clip,
            ).to("cpu")
            exporter.eval()
            obs = torch.zeros(
                1,
                int(wo_kv_cache_seq_len),
                self.flat_obs_dim,
                device="cpu",
                dtype=torch.float32,
            )
            torch.onnx.export(
                exporter,
                (obs,),
                str(export_path),
                export_params=True,
                opset_version=opset_version,
                verbose=False,
                dynamo=False,
                input_names=["obs"],
                output_names=["actions"],
            )
        return str(export_path)

    def update_distribution(self, actor_obs):
        """Set ``self.distribution`` from single-step mu and learnable std.

        The std comes from ``_sigma_from_params`` (exp(log_std) or
        softplus(std)), clamped to ``[min_sigma, max_sigma]``.

        Args:
            actor_obs: [B, D] normalized obs
        """
        mu = self.actor_module.single_step_mu(actor_obs)
        std = self._sigma_from_params()
        std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma)
        self.distribution = Normal(mu, std)

    def forward(
        self,
        obs_td: TensorDict | torch.Tensor,
        actions: torch.Tensor | None = None,
        mode: str = "sampling",
        attn_mask: torch.Tensor | None = None,
        *,
        update_obs_norm: bool = True,
        past_key_values: torch.Tensor | None = None,
        current_pos: torch.Tensor | None = None,
    ) -> TensorDict | tuple[torch.Tensor, torch.Tensor]:
        """TensorDict-first forward for PPOTFActor.

        Modes:
        - ``past_key_values`` given: return the raw ``actor_module`` output
          for the explicit-cache (ONNX-style) path; ``mode`` is ignored.
        - "sampling" / "logp" / "inference": single-step policy with
          KV-cache-aware mean prediction via ``actor_module.single_step_mu``.
        - "sequence_logp": sequence log-prob evaluation with attention mask
          support (requires ``actions`` and a [B, T] TensorDict).

        Raises:
            ValueError: On unsupported ``mode``, non-TensorDict input where a
                TensorDict is required, a missing assembler, or missing
                ``actions`` for the "logp"/"sequence_logp" modes.
        """
        if past_key_values is not None:
            if isinstance(obs_td, TensorDict):
                if self.assembler is None:
                    raise ValueError(
                        "PPOTFActor requires obs_schema/assembler for ONNX cache path."
                    )
                actor_obs = self.assembler(obs_td)
            else:
                actor_obs = obs_td
            # NOTE(review): actor_obs_transforms (e.g. obs normalization) are
            # not applied on this path; presumably callers pass obs that are
            # already normalized — confirm.
            return self.actor_module(
                actor_obs,
                past_key_values=past_key_values,
                current_pos=current_pos,
            )
        if mode == "sequence_logp":
            return self._forward_sequence_logp(
                obs_td, actions, attn_mask, update_obs_norm=update_obs_norm
            )
        if mode not in ("sampling", "logp", "inference"):
            raise ValueError(f"Unsupported mode: {mode}")
        if not isinstance(obs_td, TensorDict):
            raise ValueError("PPOTFActor.forward expects TensorDict input.")
        if self.assembler is None:
            raise ValueError(
                "Flat-tensor actor module requires obs_schema in PPOTFActor init."
            )
        td = obs_td.clone(recurse=False)
        actor_obs = self.assembler(obs_td)
        update = bool(update_obs_norm)
        for fn in self.actor_obs_transforms:
            actor_obs = fn(actor_obs, update)
        self._maybe_update_aux_router_future_recon_norm(obs_td, update=update)
        if hasattr(self.actor_module, "single_step_mu"):
            mu = self.actor_module.single_step_mu(actor_obs)
        else:
            mu = self.actor_module(actor_obs)
        # NOTE(review): sequence_forward_logp derives sigma from
        # _sigma_from_params(); confirm the parent's _sigma_like yields the
        # same std so rollout and update log-probs are consistent.
        sigma = self._sigma_like(mu)
        td.set("mu", mu)
        td.set("sigma", sigma)
        if mode == "inference":
            td.set("actions", mu)
            return td
        self.distribution = Normal(mu, sigma)
        if mode == "sampling":
            actions_out = self.distribution.sample()
        else:
            if actions is None:
                raise ValueError("actions must be provided when mode='logp'")
            actions_out = actions
        td.set("actions", actions_out)
        td.set(
            "actions_log_prob",
            self.distribution.log_prob(actions_out).sum(dim=-1),
        )
        td.set("entropy", self.distribution.entropy().sum(dim=-1))
        return td

    def _forward_sequence_logp(
        self,
        obs_td: TensorDict | torch.Tensor,
        actions: torch.Tensor | None,
        attn_mask: torch.Tensor | None,
        *,
        update_obs_norm: bool,
    ) -> TensorDict:
        """Implement ``forward(mode="sequence_logp")``.

        Flattens the [B, T] TensorDict, assembles and transforms the obs,
        reshapes back to [B, T, D], evaluates ``sequence_forward_logp``, and
        writes mu/sigma/actions/log-prob/entropy plus auxiliary outputs into
        a shallow clone of ``obs_td``.
        """
        if not isinstance(obs_td, TensorDict):
            raise ValueError(
                "PPOTFActor.forward(mode='sequence_logp') expects TensorDict input."
            )
        if obs_td.batch_dims != 2:
            raise ValueError(
                "PPOTFActor.forward(mode='sequence_logp') expects TensorDict with "
                f"batch_dims=2 [B, T], got batch_size={tuple(obs_td.batch_size)}"
            )
        if self.assembler is None:
            raise ValueError(
                "PPOTFActor requires obs_schema to assemble sequence observations."
            )
        if actions is None:
            raise ValueError(
                "actions must be provided when mode='sequence_logp'"
            )
        b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])
        flat_td = obs_td.flatten(0, 1)
        actor_obs_flat = self.assembler(flat_td)
        update = bool(update_obs_norm)
        for fn in self.actor_obs_transforms:
            actor_obs_flat = fn(actor_obs_flat, update)
        self._maybe_update_aux_router_future_recon_norm(
            flat_td, update=update
        )
        # reshape(b, t, -1) always yields a rank-3 [B, T, D] tensor.
        actor_obs_seq = actor_obs_flat.reshape(b, t, -1)
        mu, sigma, logp, entropy, aux_preds = self.sequence_forward_logp(
            actor_obs_seq, actions, attn_mask
        )
        td = obs_td.clone(recurse=False)
        td.set("mu", mu)
        td.set("sigma", sigma)
        td.set("actions", actions)
        td.set("actions_log_prob", logp)
        td.set("entropy", entropy)
        if aux_preds is not None:
            self._write_aux_preds(td, aux_preds)
        return td

    def _write_aux_preds(
        self, td: TensorDict, aux_preds: dict[str, torch.Tensor]
    ) -> None:
        """Copy auxiliary predictions into ``td`` under their td keys.

        The state-head group is written as a whole when "base_lin_vel_loc"
        is present (a missing member raises KeyError). Each optional entry
        is written only if its key exists in ``aux_preds``.
        """
        if "base_lin_vel_loc" in aux_preds:
            for td_key, pred_key in self._AUX_STATE_TD_KEYS:
                td.set(td_key, aux_preds[pred_key])
        for td_key, pred_key in self._AUX_OPTIONAL_TD_KEYS:
            if pred_key in aux_preds:
                td.set(td_key, aux_preds[pred_key])

    def _attach_router_aux(
        self,
        aux_preds: dict[str, torch.Tensor],
        *,
        router_features: torch.Tensor | None,
        router_temporal_features: torch.Tensor | None,
        ref_aux_hidden: torch.Tensor | None,
        future_recon_enabled: bool,
    ) -> dict[str, torch.Tensor]:
        """Add router features and router-based reconstructions to aux_preds.

        Mutates and returns ``aux_preds``.
        """
        if router_features is not None:
            aux_preds["router_features"] = router_features
        if router_temporal_features is not None:
            aux_preds["router_temporal_features"] = router_temporal_features
        if self.aux_router_command_recon_enabled:
            aux_preds["router_command_recon"] = (
                self.actor_module.predict_aux_router_command_from_router_features(
                    router_features
                )
            )
        if future_recon_enabled:
            # NOTE(review): ref_aux_hidden is None when the actor module does
            # not advertise supports_explicit_ref_aux_hidden; confirm the
            # predictor accepts None in that case.
            aux_preds["router_future_recon"] = (
                self.actor_module.predict_aux_router_future_recon_from_router_hidden(
                    ref_aux_hidden
                )
            )
        return aux_preds

    def sequence_forward_logp(
        self,
        obs_seq: torch.Tensor,
        actions: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        dict[str, torch.Tensor] | None,
    ]:
        """Sequence log-prob path with learnable per-action std.

        Requests extra outputs from ``actor_module.sequence_mu`` only when
        an auxiliary loss needs them. They are consumed in the order mu,
        pre-MoE hidden, ref-aux hidden, router features, router temporal
        features, skipping any that were not requested.

        Args:
            obs_seq: [B, T, D]
            actions: [B, T, A]
            attn_mask: [B, T, T] boolean (True if attend allowed)

        Returns:
            mu: [B, T, A], sigma: [B, T, A], logp: [B, T, 1],
            entropy: [B, T, 1], and aux_preds (dict, or None when no
            auxiliary output is enabled).
        """
        future_recon_enabled = bool(
            getattr(self, "aux_router_future_recon_enabled", False)
        )
        need_pre_moe_aux = self.aux_state_pred_enabled
        need_temporal = self.aux_router_switch_penalty_enabled
        need_router_features = (
            self.aux_router_command_recon_enabled
            or self.aux_router_switch_penalty_enabled
        )
        need_router_aux = need_router_features or future_recon_enabled
        need_ref_aux_hidden = bool(
            (need_pre_moe_aux or future_recon_enabled)
            and getattr(
                self.actor_module, "supports_explicit_ref_aux_hidden", False
            )
        )
        aux_preds = None
        if not (need_pre_moe_aux or need_router_aux):
            mu = self.actor_module.sequence_mu(obs_seq, attn_mask=attn_mask)
        else:
            sequence_mu_kwargs = {"attn_mask": attn_mask}
            if need_pre_moe_aux:
                sequence_mu_kwargs["return_pre_moe_hidden"] = True
            if need_router_aux:
                sequence_mu_kwargs["return_router_features"] = (
                    need_router_features
                )
                sequence_mu_kwargs["return_router_temporal_features"] = (
                    need_temporal
                )
            if need_ref_aux_hidden:
                sequence_mu_kwargs["return_ref_aux_hidden"] = True
            output_parts = list(
                self.actor_module.sequence_mu(obs_seq, **sequence_mu_kwargs)
            )
            mu = output_parts.pop(0)
            pre_moe_hidden = output_parts.pop(0) if need_pre_moe_aux else None
            ref_aux_hidden = (
                output_parts.pop(0) if need_ref_aux_hidden else None
            )
            router_features = None
            router_temporal_features = None
            if need_router_aux:
                router_features = (
                    output_parts.pop(0) if need_router_features else None
                )
                router_temporal_features = (
                    output_parts.pop(0) if need_temporal else None
                )
            if need_pre_moe_aux:
                aux_preds = self.actor_module.predict_aux_from_pre_moe(
                    pre_moe_hidden,
                    ref_aux_hidden=ref_aux_hidden,
                )
            else:
                aux_preds = {}
            if need_router_aux:
                self._attach_router_aux(
                    aux_preds,
                    router_features=router_features,
                    router_temporal_features=router_temporal_features,
                    ref_aux_hidden=ref_aux_hidden,
                    future_recon_enabled=future_recon_enabled,
                )
        # Match sampling-time clamping for stability and consistent KL/log-prob
        sigma_vec = self._sigma_from_params().clamp(
            self.min_sigma, self.max_sigma
        )
        sigma = sigma_vec[None, None, :].expand_as(mu)
        var = sigma * sigma
        logp = -0.5 * (
            ((actions - mu) ** 2) / (var + 1.0e-8)
            + 2.0 * torch.log(sigma + 1.0e-8)
            + math.log(2.0 * math.pi)
        ).sum(dim=-1, keepdim=True)
        entropy = (
            0.5 + 0.5 * math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)
        ).sum(dim=-1, keepdim=True)
        return mu, sigma, logp, entropy, aux_preds
class PPOTFRefRouterActor(PPOTFActor):
    """PPOTFActor whose module config carries reference-router indices.

    Every ``actor_ref_*`` observation term is located in the flat assembled
    observation. Its positions are injected into the actor module config as
    ``router_feature_indices``, with their count as ``router_input_dim``.
    """

    @staticmethod
    def _leaf_obs_name(term: str) -> str:
        """Return the last '/'-separated component of an obs term path."""
        return str(term).split("/")[-1]

    @classmethod
    def _infer_flat_term_dim(
        cls,
        *,
        obs_example: TensorDict,
        term: str,
        seq_len: int,
    ) -> int:
        """Flattened width of one obs term (seq_len * feature dim for 3-D).

        Raises:
            KeyError: If the term is absent from ``obs_example``.
            TypeError: If the stored value is not a tensor.
            ValueError: On a rank other than 2/3 or a seq_len mismatch.
        """
        value = TensorDictAssembler._get_from_data(obs_example, str(term))
        if value is None:
            raise KeyError(
                f"Missing obs term '{term}' in obs_example while inferring "
                "reference-router feature indices."
            )
        if not isinstance(value, torch.Tensor):
            raise TypeError(
                f"Obs term '{term}' must be a torch.Tensor, got {type(value)}."
            )
        if value.ndim == 2:
            if seq_len != 1:
                raise ValueError(
                    f"Obs term '{term}' expected seq_len={seq_len} but tensor "
                    f"is 2D with shape {tuple(value.shape)}."
                )
            return int(value.shape[-1])
        if value.ndim != 3:
            raise ValueError(
                f"Obs term '{term}' tensor ndim must be 2 or 3, got {value.ndim}."
            )
        steps = int(value.shape[1])
        if steps != seq_len:
            raise ValueError(
                f"Obs term '{term}' seq_len mismatch: expected {seq_len}, "
                f"got {steps}."
            )
        return int(value.shape[1] * value.shape[-1])

    @classmethod
    def infer_router_feature_indices(
        cls,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> list[int]:
        """Flat positions of all ``actor_ref_*`` terms, in schema order.

        Raises:
            ValueError: If ``obs_example`` is not a TensorDict or no
                ``actor_ref_*`` term is found.
        """
        if not isinstance(obs_example, TensorDict):
            raise ValueError(
                "PPOTFRefRouterActor requires TensorDict obs_example."
            )
        indices: list[int] = []
        cursor = 0
        for _, group_cfg in obs_schema.items():
            if not isinstance(group_cfg, dict):
                continue
            group_seq_len = int(group_cfg.get("seq_len", 1))
            for raw_term in group_cfg.get("terms", []):
                term_name = str(raw_term)
                width = cls._infer_flat_term_dim(
                    obs_example=obs_example,
                    term=term_name,
                    seq_len=group_seq_len,
                )
                if cls._leaf_obs_name(term_name).startswith("actor_ref_"):
                    indices.extend(range(cursor, cursor + width))
                cursor += width
        if not indices:
            raise ValueError(
                "PPOTFRefRouterActor could not infer any actor_ref_* features "
                "from obs_schema."
            )
        return indices

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Validate inputs, inject router indices, and build the actor.

        ``module_config_dict`` is deep-copied, so the caller's dict is not
        modified. ``router_embed_mlp_hidden`` defaults to
        ``obs_embed_mlp_hidden`` (or 1024) when unset.

        Raises:
            ValueError: If ``obs_schema``/``obs_example`` is missing, or
                ``use_future_cross_attn`` / ``aux_router_future_recon`` is
                enabled.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterActor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterActor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterActor does not support use_future_cross_attn=True."
            )
        module_cfg = copy.deepcopy(module_config_dict)
        if bool(module_cfg.get("aux_router_future_recon", {}).get("enabled", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicy does not support "
                "aux_router_future_recon."
            )
        indices = self.infer_router_feature_indices(obs_schema, obs_example)
        module_cfg["router_input_dim"] = int(len(indices))
        module_cfg["router_feature_indices"] = list(indices)
        if "router_embed_mlp_hidden" not in module_cfg:
            module_cfg["router_embed_mlp_hidden"] = int(
                module_cfg.get("obs_embed_mlp_hidden", 1024)
            )
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        self.router_feature_indices = list(indices)
class PPOTFRefRouterSeqActor(PPOTFActor):
    """PPOTFActor that splits the flat obs into state / ref-current / ref-future.

    The shared layout is inferred from ``obs_schema`` + ``obs_example``:
    - current ref terms (``REQUIRED_CURRENT_REF_TERMS``, seq_len 1),
    - future ref terms (``REQUIRED_FUTURE_REF_TERMS``, one shared seq_len),
    - everything else is state.

    The layout is written into the actor module config and stored on the
    instance.
    """

    REQUIRED_CURRENT_REF_TERMS = (
        "actor_ref_gravity_projection_cur",
        "actor_ref_base_linvel_cur",
        "actor_ref_base_angvel_cur",
        "actor_ref_dof_pos_cur",
        "actor_ref_root_height_cur",
    )
    REQUIRED_FUTURE_REF_TERMS = (
        "actor_ref_gravity_projection_fut",
        "actor_ref_base_linvel_fut",
        "actor_ref_base_angvel_fut",
        "actor_ref_dof_pos_fut",
        "actor_ref_root_height_fut",
    )
    SUPPORTED_AUX_WEIGHT_NAMES = {
        "w_base_lin_vel",
        "w_keybody_contact",
        "w_ref_keybody_rel_pos",
        "w_robot_keybody_rel_pos",
    }

    @staticmethod
    def _leaf_obs_name(term: str) -> str:
        """Return the last '/'-separated component of an obs term path."""
        return str(term).split("/")[-1]

    @classmethod
    def _infer_flat_term_dim(
        cls,
        *,
        obs_example: TensorDict,
        term: str,
        seq_len: int,
    ) -> int:
        """Per-step feature width of one obs term (last dim of the tensor).

        Raises:
            KeyError: If the term is absent from ``obs_example``.
            TypeError: If the stored value is not a tensor.
            ValueError: On a rank other than 2/3 or a seq_len mismatch.
        """
        value = TensorDictAssembler._get_from_data(obs_example, str(term))
        if value is None:
            raise KeyError(
                f"Missing obs term '{term}' in obs_example while inferring shared ref partitions."
            )
        if not isinstance(value, torch.Tensor):
            raise TypeError(
                f"Obs term '{term}' must be a torch.Tensor, got {type(value)}."
            )
        if value.ndim == 2:
            if seq_len != 1:
                raise ValueError(
                    f"Obs term '{term}' expected seq_len={seq_len} but tensor "
                    f"is 2D with shape {tuple(value.shape)}."
                )
            return int(value.shape[-1])
        if value.ndim != 3:
            raise ValueError(
                f"Obs term '{term}' tensor ndim must be 2 or 3, got {value.ndim}."
            )
        steps = int(value.shape[1])
        if steps != seq_len:
            raise ValueError(
                f"Obs term '{term}' seq_len mismatch: expected {seq_len}, "
                f"got {steps}."
            )
        return int(value.shape[-1])

    @classmethod
    def _validate_v2_aux_config(cls, module_config_dict: dict) -> None:
        """Reject auxiliary settings the V2 reference-routed policy lacks.

        Raises:
            ValueError: If ``aux_router_command_recon`` is enabled, or an
                enabled ``aux_state_pred`` has a positive ``w_*`` weight
                outside ``SUPPORTED_AUX_WEIGHT_NAMES``.
        """
        command_cfg = module_config_dict.get("aux_router_command_recon", {})
        if bool(command_cfg.get("enabled", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 does not support "
                "aux_router_command_recon."
            )
        state_cfg = module_config_dict.get("aux_state_pred", {})
        if not bool(state_cfg.get("enabled", False)):
            return
        for key, value in state_cfg.items():
            name = str(key)
            if not name.startswith("w_") or float(value) <= 0.0:
                continue
            if name not in cls.SUPPORTED_AUX_WEIGHT_NAMES:
                raise ValueError(
                    "ReferenceRoutedGroupedMoETransformerPolicyV2 only supports "
                    "aux_state_pred weights for "
                    "base_lin_vel, keybody_contact, ref_keybody_rel_pos, and "
                    "robot_keybody_rel_pos. Unsupported weight: "
                    f"{key}."
                )

    @classmethod
    def _build_aux_router_future_recon_schema(
        cls, obs_schema: dict
    ) -> dict[str, dict]:
        """Sub-schema holding only the required future ref terms per group.

        Raises:
            ValueError: If any required future ref term is not found.
        """
        wanted = set(cls.REQUIRED_FUTURE_REF_TERMS)
        seen: set[str] = set()
        schema_out: dict[str, dict] = {}
        for group_name, group_cfg in obs_schema.items():
            if not isinstance(group_cfg, dict):
                continue
            kept = [
                str(raw_term)
                for raw_term in group_cfg.get("terms", [])
                if cls._leaf_obs_name(str(raw_term)) in wanted
            ]
            if not kept:
                continue
            schema_out[str(group_name)] = {**group_cfg, "terms": kept}
            seen.update(cls._leaf_obs_name(name) for name in kept)
        missing = sorted(wanted.difference(seen))
        if missing:
            raise ValueError(
                "PPOTFRefRouterSeqActor could not infer all future ref terms "
                "for aux_router_future_recon. Missing: "
                + ", ".join(missing)
            )
        return schema_out

    @classmethod
    def _prepare_aux_router_future_recon(
        cls,
        *,
        actor_module_cfg: dict,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> TensorDictAssembler | None:
        """Normalize the future-recon config and build its target assembler.

        Always stores a deep copy of ``aux_router_future_recon`` back into
        ``actor_module_cfg``. When enabled, the copy also gets ``output_dim``
        and a flat assembler over the future ref terms is returned;
        otherwise returns None.
        """
        future_cfg = copy.deepcopy(
            actor_module_cfg.get("aux_router_future_recon", {})
        )
        assembler = None
        if bool(future_cfg.get("enabled", False)):
            assembler = TensorDictAssembler(
                cls._build_aux_router_future_recon_schema(obs_schema),
                output_mode="flat",
            )
            future_cfg["output_dim"] = int(
                assembler.infer_output_dim(obs_example)
            )
        actor_module_cfg["aux_router_future_recon"] = future_cfg
        return assembler

    @classmethod
    def _infer_shared_ref_layout(
        cls,
        obs_schema: dict,
        obs_example: TensorDict,
    ) -> dict[str, int | list[int] | list[tuple[int, int, int]]]:
        """Locate state, current-ref and future-ref features in the flat obs.

        Returns a dict with the flat/state/token dims, the future seq_len,
        the state and current-ref flat indices, and the future ref
        ``(start, end, per_step_dim)`` slices in REQUIRED_FUTURE_REF_TERMS
        order.

        Raises:
            ValueError: On a non-TensorDict example, duplicate or missing
                ref terms, a current ref term with seq_len != 1, mismatched
                future seq_lens, or no state features.
        """
        if not isinstance(obs_example, TensorDict):
            raise ValueError(
                "PPOTFRefRouterSeqActor requires TensorDict obs_example."
            )
        cur_names = set(cls.REQUIRED_CURRENT_REF_TERMS)
        fut_names = set(cls.REQUIRED_FUTURE_REF_TERMS)
        cur_found: dict[str, tuple[int, int]] = {}
        fut_found: dict[str, tuple[int, int, int]] = {}
        state_idx: list[int] = []
        cur_idx: list[int] = []
        fut_seq_len: int | None = None
        cursor = 0
        for _, group_cfg in obs_schema.items():
            if not isinstance(group_cfg, dict):
                continue
            seq_len = int(group_cfg.get("seq_len", 1))
            for raw_term in group_cfg.get("terms", []):
                term_name = str(raw_term)
                leaf = cls._leaf_obs_name(term_name)
                step_dim = cls._infer_flat_term_dim(
                    obs_example=obs_example,
                    term=term_name,
                    seq_len=seq_len,
                )
                start = cursor
                stop = cursor + int(seq_len * step_dim)
                cursor = stop
                if leaf in cur_names:
                    if seq_len != 1:
                        raise ValueError(
                            "current ref term "
                            f"'{leaf}' must have seq_len=1, got {seq_len}."
                        )
                    if leaf in cur_found:
                        raise ValueError(
                            f"duplicate current ref term '{leaf}' in obs_schema."
                        )
                    cur_found[leaf] = (start, step_dim)
                    cur_idx.extend(range(start, stop))
                elif leaf in fut_names:
                    if leaf in fut_found:
                        raise ValueError(
                            f"duplicate future ref term '{leaf}' in obs_schema."
                        )
                    if fut_seq_len is None:
                        fut_seq_len = seq_len
                    elif fut_seq_len != seq_len:
                        raise ValueError(
                            "future ref terms must share one seq_len, got "
                            f"{fut_seq_len} and {seq_len}."
                        )
                    fut_found[leaf] = (start, stop, step_dim)
                else:
                    state_idx.extend(range(start, stop))
        missing_cur = sorted(cur_names.difference(cur_found.keys()))
        if missing_cur:
            raise ValueError(
                "missing required current ref term(s): "
                + ", ".join(missing_cur)
            )
        missing_fut = sorted(fut_names.difference(fut_found.keys()))
        if missing_fut:
            raise ValueError(
                "missing required future ref term(s): "
                + ", ".join(missing_fut)
            )
        if fut_seq_len is None or fut_seq_len <= 0:
            raise ValueError(
                "missing required future ref terms in obs_schema."
            )
        if not state_idx:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicyV2 requires at least "
                "one non-reference actor state feature."
            )
        fut_slices = [fut_found[name] for name in cls.REQUIRED_FUTURE_REF_TERMS]
        fut_total = sum(stop - start for start, stop, _ in fut_slices)
        return {
            "full_obs_input_dim": int(cursor),
            "state_obs_input_dim": int(len(state_idx)),
            "ref_cur_token_dim": int(len(cur_idx)),
            "ref_fut_token_dim": int(fut_total // fut_seq_len),
            "ref_fut_seq_len": int(fut_seq_len),
            "state_feature_indices": state_idx,
            "ref_cur_feature_indices": cur_idx,
            "ref_fut_slices": fut_slices,
        }

    @staticmethod
    def _write_layout_to_module_cfg(actor_module_cfg: dict, layout: dict) -> None:
        """Copy the inferred layout into the actor config (in place).

        Also drops ``router_hist_obs_schema``/``router_fut_obs_schema``.
        """
        for key in (
            "state_obs_input_dim",
            "ref_cur_token_dim",
            "ref_fut_token_dim",
            "ref_fut_seq_len",
        ):
            actor_module_cfg[key] = int(layout[key])
        for key in ("state_feature_indices", "ref_cur_feature_indices"):
            actor_module_cfg[key] = list(layout[key])
        actor_module_cfg["ref_fut_slices"] = [
            list(item) for item in layout["ref_fut_slices"]
        ]
        for stale_key in ("router_hist_obs_schema", "router_fut_obs_schema"):
            actor_module_cfg.pop(stale_key, None)

    def _store_shared_ref_layout(self, layout: dict) -> None:
        """Record the inferred layout as instance attributes."""
        for key in (
            "full_obs_input_dim",
            "state_obs_input_dim",
            "ref_cur_token_dim",
            "ref_fut_token_dim",
            "ref_fut_seq_len",
        ):
            setattr(self, key, int(layout[key]))
        self.state_feature_indices = list(layout["state_feature_indices"])
        self.ref_cur_feature_indices = list(layout["ref_cur_feature_indices"])
        self.ref_fut_slices = [
            tuple(int(v) for v in item) for item in layout["ref_fut_slices"]
        ]

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Validate inputs, infer the shared ref layout, and build the actor.

        The actor module's ``input_dim_override`` is set to the state
        feature count. ``module_config_dict`` is deep-copied, not mutated.

        Raises:
            ValueError: If ``obs_schema``/``obs_example`` is missing,
                ``use_future_cross_attn`` is enabled, the aux config is
                unsupported, or the layout cannot be inferred.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterSeqActor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterSeqActor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterSeqActor does not support use_future_cross_attn=True."
            )
        self._validate_v2_aux_config(module_config_dict)
        layout = self._infer_shared_ref_layout(obs_schema, obs_example)
        module_cfg = copy.deepcopy(module_config_dict)
        module_cfg["input_dim_override"] = int(layout["state_obs_input_dim"])
        self._write_layout_to_module_cfg(module_cfg, layout)
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        self._store_shared_ref_layout(layout)
class PPOTFRefRouterV3Actor(PPOTFRefRouterSeqActor):
    """Shared-ref-layout actor that also wires up aux_router_future_recon.

    Compared with PPOTFRefRouterSeqActor, it does not set
    ``input_dim_override``. It prepares the future-reference reconstruction
    assembler, and it calls ``PPOTFActor.__init__`` directly instead of the
    parent class's ``__init__``.
    """

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Validate inputs, infer the layout, and build the V3 actor.

        ``module_config_dict`` is deep-copied, not mutated.

        Raises:
            ValueError: If ``obs_schema``/``obs_example`` is missing,
                ``use_future_cross_attn`` is enabled, the aux config is
                unsupported, or the layout cannot be inferred.
        """
        if obs_schema is None:
            raise ValueError(
                "PPOTFRefRouterV3Actor requires non-empty obs_schema."
            )
        if obs_example is None:
            raise ValueError("PPOTFRefRouterV3Actor requires obs_example.")
        if bool(module_config_dict.get("use_future_cross_attn", False)):
            raise ValueError(
                "PPOTFRefRouterV3Actor does not support use_future_cross_attn=True."
            )
        self._validate_v2_aux_config(module_config_dict)
        layout = self._infer_shared_ref_layout(obs_schema, obs_example)
        module_cfg = copy.deepcopy(module_config_dict)
        scalar_keys = (
            "state_obs_input_dim",
            "ref_cur_token_dim",
            "ref_fut_token_dim",
            "ref_fut_seq_len",
        )
        index_keys = ("state_feature_indices", "ref_cur_feature_indices")
        for key in scalar_keys:
            module_cfg[key] = int(layout[key])
        for key in index_keys:
            module_cfg[key] = list(layout[key])
        module_cfg["ref_fut_slices"] = [
            list(item) for item in layout["ref_fut_slices"]
        ]
        for stale_key in ("router_hist_obs_schema", "router_fut_obs_schema"):
            module_cfg.pop(stale_key, None)
        future_recon_assembler = self._prepare_aux_router_future_recon(
            actor_module_cfg=module_cfg,
            obs_schema=obs_schema,
            obs_example=obs_example,
        )
        PPOTFActor.__init__(
            self,
            obs_schema=obs_schema,
            module_config_dict=module_cfg,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        for key in ("full_obs_input_dim",) + scalar_keys:
            setattr(self, key, int(layout[key]))
        for key in index_keys:
            setattr(self, key, list(layout[key]))
        self.ref_fut_slices = [
            tuple(int(v) for v in item) for item in layout["ref_fut_slices"]
        ]
        self.aux_router_future_recon_assembler = future_recon_assembler
class PPOCondTFActor(PPOTFActor):
    """Transformer actor with flat state obs and seq future-token conditioning.

    Observations are viewed two ways:
      * ``flattened_obs`` is assembled into a flat state vector
        (``state_dim`` wide);
      * ``flattened_obs_fut`` is assembled into a token sequence of
        ``future_seq_len`` tokens, each ``future_token_dim`` wide.
    The base flat assembler width (``flat_obs_dim``) must equal
    ``state_dim + future_seq_len * future_token_dim``.
    """

    def __init__(
        self,
        obs_schema: dict | None,
        module_config_dict: dict,
        num_actions: int,
        init_noise_std: float,
        *,
        obs_example: dict | None = None,
    ):
        """Build the state/future assemblers and the state-only normalizer.

        Raises:
            ValueError: if ``obs_schema`` is None or lacks a required group,
                if ``obs_example`` is None, or if the assembled widths do not
                add up to the base flat assembler width.
        """
        super().__init__(
            obs_schema=obs_schema,
            module_config_dict=module_config_dict,
            num_actions=num_actions,
            init_noise_std=init_noise_std,
            obs_example=obs_example,
        )
        if obs_schema is None:
            raise ValueError("PPOCondTFActor requires non-empty obs_schema.")
        for required_group in ("flattened_obs", "flattened_obs_fut"):
            if required_group not in obs_schema:
                raise ValueError(f"obs_schema must contain '{required_group}'.")
        if obs_example is None:
            raise ValueError("PPOCondTFActor requires obs_example.")

        self.state_schema = {"flattened_obs": obs_schema["flattened_obs"]}
        self.future_schema = {
            "flattened_obs_fut": obs_schema["flattened_obs_fut"]
        }
        self.state_assembler = TensorDictAssembler(
            self.state_schema, output_mode="flat"
        )
        self.future_assembler = TensorDictAssembler(
            self.future_schema, output_mode="seq"
        )

        state_width = self.state_assembler.infer_output_dim(obs_example)
        token_width = self.future_assembler.infer_output_dim(obs_example)
        self.state_dim = int(state_width)
        self.future_token_dim = int(token_width)
        self.future_seq_len = int(self.future_assembler.seq_len)
        self.future_term_dims = self._infer_future_term_dims(obs_example)
        self.full_obs_dim = int(self.flat_obs_dim)

        # The flat obs layout is [state | future tokens flattened]; both
        # views must account for every column of the base assembler output.
        future_flat_width = self.future_seq_len * self.future_token_dim
        expected_full = self.state_dim + future_flat_width
        if self.full_obs_dim != expected_full:
            raise ValueError(
                "Assembled obs dim mismatch in PPOCondTFActor: "
                f"full={self.full_obs_dim}, expected={expected_full}"
            )

        # Only the state part is normalized; future tokens pass through.
        if not self.obs_norm_enabled:
            self.state_obs_normalizer = nn.Identity()
        else:
            self.state_obs_normalizer = EmpiricalNormalization(
                shape=self.state_dim,
                eps=self.obs_norm_eps,
                update_method=self.obs_norm_update_method,
                ema_momentum=self.obs_norm_ema_momentum,
            )
def _infer_future_term_dims(self, obs_example: TensorDict) -> list[int]:
if not isinstance(obs_example, TensorDict):
raise ValueError("PPOCondTFActor requires TensorDict obs_example.")
fut_cfg = self.future_schema.get("flattened_obs_fut", None)
if fut_cfg is None:
raise ValueError(
"Missing future schema group 'flattened_obs_fut'."
)
terms = fut_cfg.get("terms", [])
if not isinstance(terms, list) or len(terms) == 0:
raise ValueError("Future schema terms must be a non-empty list.")
dims: list[int] = []
for term in terms:
tensor = TensorDictAssembler._get_from_data(obs_example, str(term))
if tensor is None:
raise KeyError(
f"Missing future term '{term}' in obs_example TensorDict."
)
if not isinstance(tensor, torch.Tensor):
raise TypeError(
f"Future term '{term}' must be a torch.Tensor, got {type(tensor)}"
)
if tensor.ndim == 2:
dims.append(int(tensor.shape[-1]))
elif tensor.ndim == 3:
dims.append(int(tensor.shape[-1]))
else:
raise ValueError(
f"Future term '{term}' tensor ndim must be 2 or 3, got {tensor.ndim}"
)
if sum(dims) != int(self.future_token_dim):
raise ValueError(
"Inferred future_term_dims sum mismatch: expected "
f"{int(self.future_token_dim)}, got {sum(dims)} (dims={dims})"
)
return dims
@property
def flat_obs_dim(self) -> int:
if self.assembler is None:
raise ValueError(
"PPOCondTFActor requires the base flat assembler for ONNX."
)
if self.assembler.output_dim is None:
raise ValueError("Base assembler output_dim is not initialized.")
return int(self.assembler.output_dim)
def _normalize_state_obs(
self, state_obs: torch.Tensor, update: bool
) -> torch.Tensor:
if not self.obs_norm_enabled:
return state_obs
if state_obs.ndim != 2:
raise ValueError(
f"state_obs must be [B, D_state], got {tuple(state_obs.shape)}"
)
if update:
self.state_obs_normalizer.update(state_obs)
state_obs = self.state_obs_normalizer.normalize_only(state_obs)
if self.obs_norm_clip > 0.0:
state_obs = torch.clamp(
state_obs, -self.obs_norm_clip, self.obs_norm_clip
)
return state_obs
def _assemble_state_future(
self, obs_td: TensorDict
) -> tuple[torch.Tensor, torch.Tensor]:
if not isinstance(obs_td, TensorDict):
raise ValueError(
"PPOCondTFActor._assemble_state_future expects TensorDict input."
)
state_obs = self.state_assembler(obs_td)
future_obs = self.future_assembler(obs_td)
return state_obs, future_obs
def _split_flat_obs(
self, obs: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
if obs.ndim != 2:
raise ValueError(f"Expected [B, D], got {obs.shape}")
state_obs = obs[:, : self.state_dim]
future_flat = obs[:, self.state_dim :]
expected_dim = self.future_seq_len * self.future_token_dim
if future_flat.shape[-1] != expected_dim:
raise ValueError(
"Future flat obs dim mismatch: expected "
f"{expected_dim}, got {future_flat.shape[-1]}"
)
b = int(obs.shape[0])
offset = 0
future_parts = []
for d_term in self.future_term_dims:
span = int(self.future_seq_len * d_term)
chunk = future_flat[:, offset : offset + span]
future_parts.append(chunk.reshape(b, self.future_seq_len, d_term))
offset += span
if offset != int(future_flat.shape[-1]):
raise ValueError(
"Future flat slicing mismatch: "
f"consumed={offset}, total={int(future_flat.shape[-1])}"
)
future_obs = torch.cat(future_parts, dim=-1)
return state_obs, future_obs
def export_onnx(
self,
onnx_path: str | Path,
*,
opset_version: int = 17,
) -> str:
export_path = Path(onnx_path)
export_path.parent.mkdir(parents=True, exist_ok=True)
if hasattr(self.actor_module, "clear_router_distribution_cache"):
self.actor_module.clear_router_distribution_cache()
actor_module = _clone_module_for_cpu_export(self.actor_module)
if self.obs_norm_enabled:
state_obs_normalizer = _clone_module_for_cpu_export(
self.state_obs_normalizer
)
else:
state_obs_normalizer = nn.Identity()
exporter = PPOCondTFActorOnnxModule(
actor_module=actor_module,
state_obs_normalizer=state_obs_normalizer,
obs_norm_enabled=self.obs_norm_enabled,
obs_norm_clip=self.obs_norm_clip if self.obs_norm_enabled else 0.0,
state_dim=self.state_dim,
future_seq_len=self.future_seq_len,
future_token_dim=self.future_token_dim,
future_term_dims=self.future_term_dims,
).to("cpu")
exporter.eval()
cache_shape = self.onnx_past_key_values_shape(batch_size=1)
obs = torch.zeros(
1, self.flat_obs_dim, device="cpu", dtype=torch.float32
)
past_key_values = torch.zeros(
*cache_shape, device="cpu", dtype=torch.float32
)
step_idx = torch.tensor([0], dtype=torch.long, device="cpu")
output_names = [
"actions",
"present_key_values",
*self.onnx_routing_output_names(),
]
torch.onnx.export(
exporter,
(obs, past_key_values, step_idx),
str(export_path),
export_params=True,
opset_version=opset_version,
verbose=False,
dynamo=False,
input_names=["obs", "past_key_values", "step_idx"],
output_names=output_names,
)
return str(export_path)
def update_distribution(self, actor_obs):
if not isinstance(actor_obs, tuple) or len(actor_obs) != 2:
raise ValueError(
"PPOCondTFActor.update_distribution expects tuple(state_obs, future_obs)."
)
state_obs, future_obs = actor_obs
mu = self.actor_module.single_step_mu_cond(
state_obs,
future_obs,
future_mask=None,
)
std = self._sigma_from_params()
std = torch.clamp(std, min=self.min_sigma, max=self.max_sigma)
self.distribution = Normal(mu, std)
def forward(
self,
obs_td: TensorDict | torch.Tensor,
actions: torch.Tensor | None = None,
mode: str = "sampling",
attn_mask: torch.Tensor | None = None,
*,
update_obs_norm: bool = True,
past_key_values: torch.Tensor | None = None,
current_pos: torch.Tensor | None = None,
) -> TensorDict | tuple[torch.Tensor, torch.Tensor]:
if past_key_values is not None:
if isinstance(obs_td, TensorDict):
state_obs, future_obs = self._assemble_state_future(obs_td)
else:
state_obs, future_obs = self._split_flat_obs(obs_td)
state_obs = self._normalize_state_obs(state_obs, update=False)
return self.actor_module._forward_inference_onnx_cond(
state_obs,
future_obs,
past_key_values,
current_pos,
)
if mode == "sequence_logp":
if not isinstance(obs_td, TensorDict):
raise ValueError(
"PPOCondTFActor.forward(mode='sequence_logp') expects TensorDict input."
)
if obs_td.batch_dims != 2:
raise ValueError(
"PPOCondTFActor.forward(mode='sequence_logp') expects batch_dims=2 [B, T], "
f"got batch_size={tuple(obs_td.batch_size)}"
)
if actions is None:
raise ValueError(
"actions must be provided when mode='sequence_logp'"
)
b, t = int(obs_td.batch_size[0]), int(obs_td.batch_size[1])
future_mask = None
if "future_mask" in obs_td.keys():
future_mask = obs_td.get("future_mask")
if future_mask.shape != (b, t, self.future_seq_len):
raise ValueError(
"future_mask shape mismatch in sequence_logp: expected "
f"{(b, t, self.future_seq_len)}, got {tuple(future_mask.shape)}"
)
future_mask = future_mask.to(torch.bool)
flat_td = obs_td.flatten(0, 1)
state_flat, future_flat = self._assemble_state_future(flat_td)
update = bool(update_obs_norm)
state_flat = self._normalize_state_obs(state_flat, update=update)
state_seq = state_flat.reshape(b, t, -1)
future_seq = future_flat.reshape(
b, t, self.future_seq_len, self.future_token_dim
)
(
mu,
sigma,
logp,
entropy,
aux_preds,
) = self.sequence_forward_logp_cond(
state_seq,
future_seq,
actions,
attn_mask,
future_mask,
)
td = obs_td.clone(recurse=False)
td.set("mu", mu)
td.set("sigma", sigma)
td.set("actions", actions)
td.set("actions_log_prob", logp)
td.set("entropy", entropy)
if aux_preds is not None:
if "base_lin_vel_loc" in aux_preds:
td.set(
"aux_base_lin_vel_loc", aux_preds["base_lin_vel_loc"]
)
td.set(
"aux_base_lin_vel_log_std",
aux_preds["base_lin_vel_log_std"],
)
td.set("aux_root_height_loc", aux_preds["root_height_loc"])
td.set(
"aux_root_height_log_std",
aux_preds["root_height_log_std"],
)
td.set(
"aux_keybody_contact_logits",
aux_preds["keybody_contact_logits"],
)
td.set(
"aux_ref_keybody_rel_pos",
aux_preds["ref_keybody_rel_pos"],
)
td.set(
"aux_robot_keybody_rel_pos",
aux_preds["robot_keybody_rel_pos"],
)
if "denoise_ref_root_lin_vel_residual" in aux_preds:
td.set(
"aux_denoise_ref_root_lin_vel_residual",
aux_preds["denoise_ref_root_lin_vel_residual"],
)
if "denoise_ref_root_ang_vel_residual" in aux_preds:
td.set(
"aux_denoise_ref_root_ang_vel_residual",
aux_preds["denoise_ref_root_ang_vel_residual"],
)
if "denoise_ref_dof_pos_residual" in aux_preds:
td.set(
"aux_denoise_ref_dof_pos_residual",
aux_preds["denoise_ref_dof_pos_residual"],
)
if "router_command_recon" in aux_preds:
td.set(
"aux_router_command_recon",
aux_preds["router_command_recon"],
)
if "router_features" in aux_preds:
td.set("router_features", aux_preds["router_features"])
if "router_temporal_features" in aux_preds:
td.set(
"router_temporal_features",
aux_preds["router_temporal_features"],
)
return td
if mode not in ("sampling", "logp", "inference"):
raise ValueError(f"Unsupported mode: {mode}")
if not isinstance(obs_td, TensorDict):
raise ValueError(
"PPOCondTFActor.forward expects TensorDict input."
)
td = obs_td.clone(recurse=False)
state_obs, future_obs = self._assemble_state_future(obs_td)
update = bool(update_obs_norm)
state_obs = self._normalize_state_obs(state_obs, update=update)
future_mask = None
if "future_mask" in td.keys():
future_mask = td.get("future_mask")
if future_mask.shape != (state_obs.shape[0], self.future_seq_len):
raise ValueError(
"future_mask shape mismatch in single-step forward: expected "
f"{(state_obs.shape[0], self.future_seq_len)}, got {tuple(future_mask.shape)}"
)
future_mask = future_mask.to(torch.bool)
mu = self.actor_module.single_step_mu_cond(
state_obs, future_obs, future_mask=future_mask
)
sigma = self._sigma_like(mu)
td.set("mu", mu)
td.set("sigma", sigma)
if mode == "inference":
td.set("actions", mu)
return td
self.distribution = Normal(mu, sigma)
if mode == "sampling":
actions_out = self.distribution.sample()
else:
if actions is None:
raise ValueError("actions must be provided when mode='logp'")
actions_out = actions
td.set("actions", actions_out)
td.set(
"actions_log_prob",
self.distribution.log_prob(actions_out).sum(dim=-1),
)
td.set("entropy", self.distribution.entropy().sum(dim=-1))
return td
def sequence_forward_logp_cond(
self,
state_seq: torch.Tensor,
future_seq: torch.Tensor,
actions: torch.Tensor,
attn_mask: torch.Tensor | None,
future_mask: torch.Tensor | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
dict[str, torch.Tensor] | None,
]:
aux_preds = None
need_pre_moe_aux = self.aux_state_pred_enabled
need_router_aux = (
self.aux_router_command_recon_enabled
or self.aux_router_switch_penalty_enabled
)
if need_pre_moe_aux and need_router_aux:
actor_outputs = self.actor_module.sequence_mu_cond(
state_seq,
future_seq,
attn_mask=attn_mask,
future_mask=future_mask,
return_pre_moe_hidden=True,
return_router_features=True,
return_router_temporal_features=self.aux_router_switch_penalty_enabled,
)
if self.aux_router_switch_penalty_enabled:
(
mu,
pre_moe_hidden,
router_features,
router_temporal_features,
) = actor_outputs
else:
mu, pre_moe_hidden, router_features = actor_outputs
aux_preds = self.actor_module.predict_aux_from_pre_moe(
pre_moe_hidden
)
aux_preds["router_features"] = router_features
if self.aux_router_switch_penalty_enabled:
aux_preds["router_temporal_features"] = (
router_temporal_features
)
if self.aux_router_command_recon_enabled:
aux_preds["router_command_recon"] = (
self.actor_module.predict_aux_router_command_from_router_features(
router_features
)
)
elif need_pre_moe_aux:
mu, pre_moe_hidden = self.actor_module.sequence_mu_cond(
state_seq,
future_seq,
attn_mask=attn_mask,
future_mask=future_mask,
return_pre_moe_hidden=True,
)
aux_preds = self.actor_module.predict_aux_from_pre_moe(
pre_moe_hidden
)
elif need_router_aux:
actor_outputs = self.actor_module.sequence_mu_cond(
state_seq,
future_seq,
attn_mask=attn_mask,
future_mask=future_mask,
return_router_features=True,
return_router_temporal_features=self.aux_router_switch_penalty_enabled,
)
if self.aux_router_switch_penalty_enabled:
(
mu,
router_features,
router_temporal_features,
) = actor_outputs
else:
mu, router_features = actor_outputs
aux_preds = {"router_features": router_features}
if self.aux_router_switch_penalty_enabled:
aux_preds["router_temporal_features"] = (
router_temporal_features
)
if self.aux_router_command_recon_enabled:
aux_preds["router_command_recon"] = (
self.actor_module.predict_aux_router_command_from_router_features(
router_features
)
)
else:
mu = self.actor_module.sequence_mu_cond(
state_seq,
future_seq,
attn_mask=attn_mask,
future_mask=future_mask,
)
sigma_vec = self._sigma_from_params().clamp(
self.min_sigma, self.max_sigma
)
sigma = sigma_vec[None, None, :].expand_as(mu)
var = sigma * sigma
logp = -0.5 * (
((actions - mu) ** 2) / (var + 1.0e-8)
+ 2.0 * torch.log(sigma + 1.0e-8)
+ math.log(2.0 * math.pi)
).sum(dim=-1, keepdim=True)
entropy = (
0.5 + 0.5 * math.log(2.0 * math.pi) + torch.log(sigma + 1.0e-8)
).sum(dim=-1, keepdim=True)
return mu, sigma, logp, entropy, aux_preds
================================================
FILE: holomotion/src/modules/network_modules.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import math
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
def __init__(
self,
shape,
eps: float = 1e-2,
until: int | None = None,
*,
update_method: str = "cumulative",
ema_momentum: float | None = None,
):
"""Initialize EmpiricalNormalization module.
Args:
shape (int or tuple of int): Shape of input values except
batch axis.
eps (float): Small value for stability.
until (int or None): If this arg is specified, the link learns
input values until the sum of batch sizes
exceeds it.
update_method:
One of {"cumulative", "ema"}.
- "cumulative": count-based updates (legacy behavior).
- "ema": EMA updates of mean and second moment.
ema_momentum:
EMA momentum in (0, 1]. Required when update_method == "ema".
"""
super().__init__()
self.eps = eps
self.until = until
self.update_method = str(update_method).lower()
self.ema_momentum = (
float(ema_momentum) if ema_momentum is not None else None
)
if self.update_method in ("count", "cumulative"):
self.update_method = "cumulative"
elif self.update_method in ("ema", "exp", "exponential"):
self.update_method = "ema"
else:
raise ValueError(
f"update_method must be one of {{'cumulative','ema'}}, got {update_method}"
)
if self.update_method == "ema":
if self.ema_momentum is None:
raise ValueError(
"ema_momentum must be provided when update_method == 'ema'"
)
if not (0.0 < self.ema_momentum <= 1.0):
raise ValueError(
f"ema_momentum must be in (0, 1], got {self.ema_momentum}"
)
self.register_buffer("_mean", torch.zeros(shape)[None, ...])
self.register_buffer("_var", torch.ones(shape)[None, ...])
self.register_buffer("_std", torch.ones(shape)[None, ...])
self.register_buffer("_ex2", torch.ones(shape)[None, ...])
self.register_buffer("count", torch.tensor(0, dtype=torch.long))
self.register_buffer("_last_sync_mean", torch.zeros(shape)[None, ...])
self.register_buffer("_last_sync_var", torch.ones(shape)[None, ...])
self.register_buffer(
"_last_sync_count", torch.tensor(0, dtype=torch.long)
)
@property
def mean(self):
return self._mean.squeeze(0).clone()
@property
def std(self):
return self._std.squeeze(0).clone()
def forward(self, x):
"""Normalize mean and variance of values based on empirical values.
Args:
x (ndarray or Variable): Input values
Returns:
ndarray or Variable: Normalized output values
"""
if self.training:
self.update(x)
return (x - self._mean) / (self._std + self.eps)
def normalize_only(self, x):
return (x - self._mean) / (self._std + self.eps)
@torch.compiler.disable
@torch.jit.unused
def update(self, x):
"""Learn input values without computing the output values of them."""
if self.until is not None and self.count >= self.until:
return
count_x = x.shape[0]
self.count += count_x
if self.update_method == "ema":
m = float(self.ema_momentum)
mean_x = torch.mean(x, dim=0, keepdim=True)
ex2_x = torch.mean(x * x, dim=0, keepdim=True)
self._mean.mul_(1.0 - m).add_(mean_x, alpha=m)
self._ex2.mul_(1.0 - m).add_(ex2_x, alpha=m)
var = torch.clamp(self._ex2 - self._mean * self._mean, min=0.0)
self._var.copy_(var)
self._std.copy_(torch.sqrt(self._var))
return
rate = count_x / self.count
var_x = torch.var(x, dim=0, unbiased=False, keepdim=True)
mean_x = torch.mean(x, dim=0, keepdim=True)
delta_mean = mean_x - self._mean
self._mean += rate * delta_mean
self._var += rate * (
var_x - self._var + delta_mean * (mean_x - self._mean)
)
self._std = torch.sqrt(self._var)
@torch.jit.unused
def inverse(self, y):
return y * (self._std + self.eps) + self._mean
def sync_stats_across_processes(self, accelerator):
"""Synchronize normalization statistics across distributed processes."""
if accelerator.num_processes <= 1:
return
if self.update_method == "ema":
# EMA stats are already running estimates.
# Sync by averaging across ranks.
mean_g = accelerator.reduce(
self._mean.to(dtype=torch.float32), reduction="mean"
)
ex2_g = accelerator.reduce(
self._ex2.to(dtype=torch.float32), reduction="mean"
)
var_g = torch.clamp(ex2_g - mean_g * mean_g, min=0.0)
self._mean.copy_(mean_g.to(self._mean.dtype))
self._ex2.copy_(ex2_g.to(self._ex2.dtype))
self._var.copy_(var_g.to(self._var.dtype))
self._std.copy_(torch.sqrt(self._var))
return
# Weighted synchronization with correction to avoid double counting
device = self._mean.device
count_local = self.count.to(device=device, dtype=torch.float32)
mean_local = self._mean.to(device=device, dtype=torch.float32)
var_local = self._var.to(device=device, dtype=torch.float32)
# Local weighted sums
sum_count = accelerator.reduce(count_local, reduction="sum")
sum_mean_count = accelerator.reduce(
mean_local * count_local, reduction="sum"
)
sum_ex2_count = accelerator.reduce(
(var_local + mean_local * mean_local) * count_local,
reduction="sum",
)
# Correct for replication of previously-synced global stats
# across ranks.
last_c = self._last_sync_count.to(device=device, dtype=torch.float32)
if last_c.item() > 0:
w_minus_1 = float(accelerator.num_processes - 1)
last_mean = self._last_sync_mean.to(
device=device, dtype=torch.float32
)
last_var = self._last_sync_var.to(
device=device, dtype=torch.float32
)
sum_count = sum_count - w_minus_1 * last_c
sum_mean_count = sum_mean_count - w_minus_1 * (last_mean * last_c)
sum_ex2_count = sum_ex2_count - w_minus_1 * (
(last_var + last_mean * last_mean) * last_c
)
if sum_count.item() <= 0:
return
global_mean = sum_mean_count / sum_count
global_ex2 = sum_ex2_count / sum_count
global_var = torch.clamp(
global_ex2 - global_mean * global_mean, min=0.0
)
global_std = torch.sqrt(global_var)
# Copy back (keep original buffer shapes)
self._mean.copy_(global_mean.to(self._mean.dtype))
self._var.copy_(global_var.to(self._var.dtype))
self._std.copy_(global_std.to(self._std.dtype))
# Set global sample count and remember snapshot for next correction
self.count.copy_(sum_count.to(self.count.dtype))
self._last_sync_mean.copy_(global_mean.to(self._last_sync_mean.dtype))
self._last_sync_var.copy_(global_var.to(self._last_sync_var.dtype))
self._last_sync_count.copy_(self.count)
class MLP(nn.Module):
    """Feed-forward network: ``[Linear -> norm -> activation] * N -> Linear``.

    Config keys read from ``module_config_dict``:
      * ``layer_config.hidden_dims`` (list of ints, default empty);
      * ``layer_config.activation`` (name of a ``torch.nn`` class; required);
      * ``hidden_norm`` ("none"/"layernorm"/"rmsnorm", default "none");
      * ``hidden_norm_eps`` (default 1e-6).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        module_config_dict: dict,
    ):
        super().__init__()
        self.module_config_dict = module_config_dict
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        if self.input_dim <= 0:
            raise ValueError(
                f"MLP input_dim must be positive, got {self.input_dim}"
            )
        if self.output_dim <= 0:
            raise ValueError(
                f"MLP output_dim must be positive, got {self.output_dim}"
            )

        def _make_norm(
            norm_type: str,
            dim: int,
            *,
            eps: float,
        ) -> nn.Module:
            """Build the per-hidden-layer normalization module."""
            t = str(norm_type).lower()
            if t in ("none", "identity", "null"):
                return nn.Identity()
            if t in ("layernorm", "ln"):
                return nn.LayerNorm(dim, eps=eps)
            if t in ("rmsnorm", "rms"):
                return RMSNorm(dim, eps=eps)
            # Doubled braces: print a literal set of names instead of
            # evaluating a Python set expression inside the f-string.
            raise ValueError(
                f"Unknown norm '{t}'. Expected one of {{'none', 'layernorm', 'rmsnorm'}}."
            )

        self.hidden_norm_type = module_config_dict.get("hidden_norm", "none")
        self.hidden_norm_eps = float(
            module_config_dict.get("hidden_norm_eps", 1.0e-6)
        )
        layer_config = self.module_config_dict["layer_config"]
        hidden_dims: list[int] = list(layer_config.get("hidden_dims", []))
        # Resolve the class once (fails fast on a bad name), but instantiate
        # a fresh module per layer so parameterized activations (e.g. PReLU)
        # are not silently shared between layers.
        activation_cls = getattr(nn, str(layer_config["activation"]))
        layers: list[nn.Module] = []
        prev = self.input_dim
        for h in hidden_dims:
            h_i = int(h)
            layers.append(nn.Linear(prev, h_i))
            layers.append(
                _make_norm(
                    self.hidden_norm_type,
                    h_i,
                    eps=self.hidden_norm_eps,
                )
            )
            layers.append(activation_cls())
            prev = h_i
        self.trunk = nn.Sequential(*layers) if layers else nn.Identity()
        self.output_head = nn.Linear(prev, self.output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward.

        Args:
            x: [..., input_dim] assembled tensor observations.

        Returns:
            y: [..., output_dim]

        Raises:
            TypeError: if ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"MLP expects torch.Tensor input, got {type(x)}")
        h = self.trunk(x)
        return self.output_head(h)
class ConvMLP(nn.Module):
"""Conv1d + pooling history encoder with an MLP head."""
def __init__(
self,
input_dim: int,
output_dim: int,
module_config_dict: dict,
):
super().__init__()
self.module_config_dict = module_config_dict
self.input_dim = int(input_dim)
self.output_dim = int(output_dim)
layer_cfg = dict(module_config_dict.get("layer_config", {}))
activation = str(layer_cfg.get("activation", "SiLU"))
self.conv_channels = int(module_config_dict.get("conv_channels", 128))
self.conv_layers = int(module_config_dict.get("conv_layers", 2))
self.conv_kernel_size = int(
module_config_dict.get("conv_kernel_size", 3)
)
self.pool_type = str(
module_config_dict.get("pool_type", "avg")
).lower()
conv_modules: list[nn.Module] = []
padding = self.conv_kernel_size // 2
in_ch = int(self.input_dim)
for _ in range(self.conv_layers):
conv_modules.append(
nn.Conv1d(
in_channels=in_ch,
out_channels=self.conv_channels,
kernel_size=self.conv_kernel_size,
padding=padding,
bias=True,
)
)
conv_modules.append(getattr(nn, activation)())
in_ch = self.conv_channels
conv_modules.append(nn.AdaptiveAvgPool1d(1))
self.hist_encoder = nn.Sequential(*conv_modules)
fused_dim = int(self.conv_channels + self.input_dim)
self.mlp_head = MLP(
input_dim=fused_dim,
output_dim=int(self.output_dim),
module_config_dict=module_config_dict,
)
def forward(self, hist_seq: torch.Tensor) -> torch.Tensor:
ctx = self.hist_encoder(hist_seq.transpose(1, 2)).squeeze(-1)
latest = hist_seq[:, -1, :]
fused = torch.cat([ctx, latest], dim=-1)
return self.mlp_head(fused)
class ReferenceMotionConvRouterEncoder(nn.Module):
    """Conv1d encoder for reference-motion router sequences.

    ``[B, T, input_dim]`` is processed by (Conv1d + SiLU) x ``conv_layers``
    over time, then global avg/max pooling over time, then a
    Linear -> SiLU -> Linear projection, giving ``[B, output_dim]``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        *,
        conv_channels: int = 128,
        conv_layers: int = 2,
        conv_kernel_size: int = 3,
        pool_type: str = "avg",
    ):
        super().__init__()
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.conv_channels = int(conv_channels)
        self.conv_layers = int(conv_layers)
        self.conv_kernel_size = int(conv_kernel_size)
        self.pool_type = str(pool_type).lower()

        # Every structural size must be strictly positive (checked in order).
        for field_name in (
            "input_dim",
            "output_dim",
            "conv_channels",
            "conv_layers",
            "conv_kernel_size",
        ):
            field_value = getattr(self, field_name)
            if field_value <= 0:
                raise ValueError(
                    f"{field_name} must be positive, got {field_value}"
                )
        pool_classes = {
            "avg": nn.AdaptiveAvgPool1d,
            "max": nn.AdaptiveMaxPool1d,
        }
        if self.pool_type not in pool_classes:
            raise ValueError(
                f"pool_type must be one of {{'avg','max'}}, got {self.pool_type}"
            )

        # "Same"-style padding for odd kernel sizes.
        pad = self.conv_kernel_size // 2
        in_channels_per_layer = [self.input_dim] + [self.conv_channels] * (
            self.conv_layers - 1
        )
        trunk: list[nn.Module] = []
        for c_in in in_channels_per_layer:
            trunk += [
                nn.Conv1d(
                    c_in,
                    self.conv_channels,
                    self.conv_kernel_size,
                    padding=pad,
                    bias=True,
                ),
                nn.SiLU(),
            ]
        self.temporal_trunk = nn.Sequential(*trunk)
        self.pool = pool_classes[self.pool_type](1)
        self.out_proj = nn.Sequential(
            nn.Linear(self.conv_channels, self.output_dim),
            nn.SiLU(),
            nn.Linear(self.output_dim, self.output_dim),
        )

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        """Encode ``seq`` of shape ``[B, T, input_dim]`` to ``[B, output_dim]``.

        Raises:
            ValueError: if ``seq`` is not rank 3 or its last dim is not
                ``input_dim``.
        """
        if seq.ndim != 3:
            raise ValueError(
                f"Expected router seq with shape [B, T, D], got {tuple(seq.shape)}."
            )
        feat_dim = int(seq.shape[-1])
        if feat_dim != self.input_dim:
            raise ValueError(
                "Router seq dim mismatch: expected "
                f"{self.input_dim}, got {feat_dim}."
            )
        # Conv1d is channels-first: [B, D, T].
        hidden = self.temporal_trunk(seq.transpose(1, 2))
        pooled = self.pool(hidden).squeeze(-1)
        return self.out_proj(pooled)
class SingleQueryAttentionPool(nn.Module):
    """Pool a set of tokens into one vector via single-query dot-product attention.

    Supported layouts:
      * query ``[B, D]`` with tokens ``[B, N, D]`` gives ``[B, D]``;
      * query ``[B, T, D]`` with tokens ``[B, T, N, D]`` gives ``[B, T, D]``.
    Attention logits are scaled by ``d_model ** -0.5`` and softmaxed over
    the token axis ``N``. All projections are bias-free.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = int(d_model)
        self.scale = float(self.d_model) ** -0.5
        width = self.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)

    def forward(
        self,
        query: torch.Tensor,
        tokens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the attention-weighted sum of ``tokens`` for ``query``.

        Raises:
            ValueError: if ``query`` is not 2D/3D or ``tokens`` does not have
                exactly one more dimension than ``query``.
        """
        # query rank -> (required token rank, layout string for errors)
        layouts = {2: (3, "[B, N, D]"), 3: (4, "[B, T, N, D]")}
        if query.ndim not in layouts:
            raise ValueError(
                f"SingleQueryAttentionPool query must be 2D or 3D, got {query.ndim}."
            )
        tokens_ndim, layout = layouts[query.ndim]
        if tokens.ndim != tokens_ndim:
            raise ValueError(
                f"SingleQueryAttentionPool expected {layout} tokens for "
                f"{query.ndim}D query, got {tuple(tokens.shape)}."
            )
        q = self.q_proj(query).unsqueeze(-2)  # [..., 1, D]
        k = self.k_proj(tokens)  # [..., N, D]
        v = self.v_proj(tokens)
        logits = (q * k).sum(dim=-1, keepdim=True) * self.scale  # [..., N, 1]
        weights = torch.softmax(logits, dim=-2)
        return self.out_proj((weights * v).sum(dim=-2))
class GroupedMoETransformerPolicy(nn.Module):
"""Hybrid Modern Transformer decoder policy with SOTA improvements.
Structure:
- Layer 0: Dense MLP (ModernTransformerBlock)
- Optional final layer: Dense MLP when dense_layer_at_last=True
- Intermediate layers: MoE MLP (GroupedMoEBlock)
Features:
- RealRoPE.
- RMSNorm: Root Mean Square Normalization.
- GQA: Grouped Query Attention (configurable n_kv_heads).
- QK-Norm: RMSNorm on Queries and Keys.
- Gated Attention: Qwen-style element-wise sigmoid gating.
- SwiGLU MLP: DeepseekV3MLP for feed-forward.
- Flash Attention: via F.scaled_dot_product_attention.
- Gradient Checkpointing: optional for memory efficiency.
"""
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    module_config_dict: dict,
):
    """Build the policy from a flat configuration mapping.

    Args:
        input_dim: Flat observation dimension. It is used for the observation
            embedding unless a numeric ``input_dim_override`` is configured.
        output_dim: Action dimension produced by ``action_mu_head``. It is
            also the width of the optional dof-pos denoising head.
        module_config_dict: Architecture, routing and auxiliary-head options.
            ``num_fine_experts``, ``num_shared_experts`` and ``top_k`` are
            required; every other key is read with ``.get`` and a default.

    Raises:
        ValueError: On any of the following:
            - invalid routing options (score fn, scale, bias clip, margin
              target);
            - ``d_model`` not divisible by ``n_heads`` or an odd head
              dimension;
            - non-positive future-token sizes when ``use_future_cross_attn``
              is set;
            - invalid settings for the auxiliary router reconstruction heads.
    """
    super().__init__()
    self.input_dim = int(input_dim)
    self.output_dim = int(output_dim)
    self.module_config_dict = module_config_dict
    # MoE routing hyper-parameters; forwarded to every GroupedMoEBlock below.
    self.num_fine_experts = module_config_dict["num_fine_experts"]
    self.num_shared_experts = module_config_dict["num_shared_experts"]
    self.top_k = module_config_dict["top_k"]
    self.use_dynamic_bias = module_config_dict.get(
        "use_dynamic_bias", False
    )
    self.bias_update_rate = module_config_dict.get(
        "bias_update_rate", 0.001
    )
    self.routing_score_fn = str(
        module_config_dict.get("routing_score_fn", "softmax")
    ).lower()
    self.freeze_router = bool(
        module_config_dict.get("freeze_router", False)
    )
    self.routing_scale = float(
        module_config_dict.get("routing_scale", 1.0)
    )
    self.expert_bias_clip = float(
        module_config_dict.get("expert_bias_clip", 0.0)
    )
    # Optional router-margin settings. Each is a sub-dict with an "enabled"
    # flag; the selected-vs-unselected margin also carries a "target".
    dead_margin_cfg = module_config_dict.get(
        "dead_expert_margin_to_topk", {}
    )
    selected_margin_cfg = module_config_dict.get(
        "selected_expert_margin_to_unselected", {}
    )
    self.dead_expert_margin_to_topk_enabled = bool(
        dead_margin_cfg.get("enabled", False)
    )
    self.selected_expert_margin_to_unselected_enabled = bool(
        selected_margin_cfg.get("enabled", False)
    )
    self.selected_expert_margin_to_unselected_target = float(
        selected_margin_cfg.get("target", 0.0)
    )
    # Fail fast on invalid routing options.
    if self.routing_score_fn not in ("softmax", "sigmoid"):
        raise ValueError(
            f"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}"
        )
    if self.routing_scale <= 0.0:
        raise ValueError(
            f"routing_scale must be > 0, got {self.routing_scale}"
        )
    if self.expert_bias_clip < 0.0:
        raise ValueError(
            f"expert_bias_clip must be >= 0, got {self.expert_bias_clip}"
        )
    if self.selected_expert_margin_to_unselected_target < 0.0:
        raise ValueError(
            "selected_expert_margin_to_unselected.target must be >= 0, "
            f"got {self.selected_expert_margin_to_unselected_target}"
        )
    # Only a numeric override replaces input_dim for the observation
    # embedding. Any other value (including None) means "no override".
    _ov = module_config_dict.get("input_dim_override", None)
    self.obs_input_dim = (
        int(_ov) if isinstance(_ov, (int, float)) else None
    )
    # Transformer architecture hyper-parameters.
    self.obs_embed_mlp_hidden = int(
        module_config_dict.get("obs_embed_mlp_hidden", 1024)
    )
    self.d_model = int(module_config_dict.get("d_model", 256))
    self.n_layers = int(module_config_dict.get("n_layers", 4))
    self.dense_layer_at_last = bool(
        module_config_dict.get("dense_layer_at_last", False)
    )
    self.n_heads = int(module_config_dict.get("n_heads", 4))
    # GQA: by default there are half as many K/V heads as query heads.
    self.n_kv_heads = int(
        module_config_dict.get("n_kv_heads", self.n_heads // 2)
    )
    # ff_mult sizes the MoE blocks. The dense blocks use ff_mult_dense,
    # which defaults to 3 * ff_mult.
    self.ff_mult = float(module_config_dict.get("ff_mult", 4))
    self.ff_mult_dense = int(
        module_config_dict.get("ff_mult_dense", self.ff_mult * 3)
    )
    self.attn_dropout = float(module_config_dict.get("attn_dropout", 0.0))
    self.mlp_dropout = float(module_config_dict.get("mlp_dropout", 0.0))
    # Number of time slots per environment in the single-step KV cache
    # (see reset_kv_cache).
    self.max_ctx_len = int(module_config_dict.get("max_ctx_len", 64))
    self.use_qk_norm = module_config_dict.get("use_qk_norm", True)
    self.use_gated_attn = module_config_dict.get("use_gated_attn", True)
    self.gated_attn_type = module_config_dict.get(
        "gated_attn_type", "headwise"
    )
    self.use_checkpointing = module_config_dict.get(
        "use_checkpointing", False
    )
    # Conditional mode: state tokens get cross-attention memory built from
    # separately embedded future-reference tokens (see sequence_mu_cond).
    self.use_future_cross_attn = bool(
        module_config_dict.get("use_future_cross_attn", False)
    )
    self.state_obs_dim = int(
        module_config_dict.get(
            "state_obs_dim", self.obs_input_dim or self.input_dim
        )
    )
    self.future_seq_len = int(module_config_dict.get("future_seq_len", 0))
    self.future_token_dim = int(
        module_config_dict.get("future_token_dim", 0)
    )
    self.head_dim = self.d_model // self.n_heads
    if self.d_model % self.n_heads != 0:
        raise ValueError(
            f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
        )
    # Rotate-half RoPE splits each head into two halves.
    if self.head_dim % 2 != 0:
        raise ValueError(
            f"RoPE requires even head_dim, got head_dim={self.head_dim}"
        )
    # RoPE configuration (used in both sequence and KV-cached single-step inference)
    self.rope_theta = float(module_config_dict.get("rope_theta", 10000.0))
    self.inv_freq = 1.0 / (
        self.rope_theta
        ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float32)
            / self.head_dim
        )
    ) # [head_dim//2]
    # NOTE(review): self.inv_freq stays a plain attribute, so it does not
    # follow .to(device); only the non-persistent _rope_inv_freq buffer and
    # the cos/sin caches move with the module.
    self.register_buffer("_rope_inv_freq", self.inv_freq, persistent=False)
    # Precompute cos/sin tables for positions 0..8191 (indexed by get_cos_sin).
    self._set_cos_sin_cache(seq_len=8192)
    obs_in = self.obs_input_dim or self.input_dim
    if self.use_future_cross_attn:
        if self.future_seq_len <= 0:
            raise ValueError(
                "future_seq_len must be positive when use_future_cross_attn=True"
            )
        if self.future_token_dim <= 0:
            raise ValueError(
                "future_token_dim must be positive when use_future_cross_attn=True"
            )
        self.state_obs_embed = nn.Sequential(
            nn.Linear(self.state_obs_dim, self.obs_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
        )
        # Keep a single state embedding module so DDP doesn't see unused
        # parameters from an extra unused `obs_embed` in conditional mode.
        self.obs_embed = self.state_obs_embed
        self.future_obs_embed = nn.Sequential(
            nn.Linear(self.future_token_dim, self.obs_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
        )
        # Learned per-slot embedding for the future_seq_len future tokens.
        self.future_pos_embed = nn.Embedding(
            self.future_seq_len, self.d_model
        )
    else:
        self.obs_embed = nn.Sequential(
            nn.Linear(obs_in, self.obs_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
        )
        self.state_obs_embed = None
        self.future_obs_embed = None
        self.future_pos_embed = None
    # Stack of TransformerBlocks: first layer is always dense; the last
    # layer is also dense when dense_layer_at_last=True.
    self.layers = nn.ModuleList()
    for i in range(self.n_layers):
        use_dense_layer = i == 0 or (
            self.dense_layer_at_last and i == self.n_layers - 1
        )
        if use_dense_layer:
            layer = ModernTransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                n_kv_heads=self.n_kv_heads,
                ff_mult=self.ff_mult_dense,
                use_qk_norm=self.use_qk_norm,
                use_gated_attn=self.use_gated_attn,
                gated_attn_type=self.gated_attn_type,
                attn_dropout=self.attn_dropout,
                mlp_dropout=self.mlp_dropout,
                use_cross_attn=self.use_future_cross_attn,
            )
        else:
            layer = GroupedMoEBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                n_kv_heads=self.n_kv_heads,
                ff_mult=self.ff_mult,
                use_qk_norm=self.use_qk_norm,
                use_gated_attn=self.use_gated_attn,
                gated_attn_type=self.gated_attn_type,
                attn_dropout=self.attn_dropout,
                mlp_dropout=self.mlp_dropout,
                num_fine_experts=self.num_fine_experts,
                num_shared_experts=self.num_shared_experts,
                top_k=self.top_k,
                use_dynamic_bias=self.use_dynamic_bias,
                bias_update_rate=self.bias_update_rate,
                routing_score_fn=self.routing_score_fn,
                freeze_router=self.freeze_router,
                routing_scale=self.routing_scale,
                expert_bias_clip=self.expert_bias_clip,
                dead_expert_margin_to_topk_enabled=(
                    self.dead_expert_margin_to_topk_enabled
                ),
                selected_expert_margin_to_unselected_enabled=(
                    self.selected_expert_margin_to_unselected_enabled
                ),
                selected_expert_margin_to_unselected_target=(
                    self.selected_expert_margin_to_unselected_target
                ),
                use_cross_attn=self.use_future_cross_attn,
            )
        self.layers.append(layer)
    # Index of the last MoE block (None when there is none). Router-
    # distribution collection and router-shift tracking target this layer.
    self._last_moe_layer_idx = None
    for layer_idx, layer in enumerate(self.layers):
        if isinstance(layer, GroupedMoEBlock):
            self._last_moe_layer_idx = layer_idx
    self.norm_f = RMSNorm(self.d_model)
    self.action_mu_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.SiLU(),
        nn.Linear(self.d_model, self.output_dim),
    )
    # Optional auxiliary state-prediction heads, applied to the output of
    # layer 0 (see predict_aux_from_pre_moe). A denoising head is created
    # only when its loss weight in the config is > 0.
    aux_cfg = module_config_dict.get("aux_state_pred", {})
    self.aux_state_pred_enabled = bool(aux_cfg.get("enabled", False))
    self.aux_contact_dim = int(
        len(aux_cfg.get("keybody_contact_names", []))
    )
    self.aux_keybody_pos_dim = int(
        len(aux_cfg.get("keybody_rel_pos_names", []))
    )
    self.use_aux_denoise_ref_root_lin_vel = bool(
        float(aux_cfg.get("w_denoise_ref_root_lin_vel", 0.0)) > 0.0
    )
    self.use_aux_denoise_ref_root_ang_vel = bool(
        float(aux_cfg.get("w_denoise_ref_root_ang_vel", 0.0)) > 0.0
    )
    self.use_aux_denoise_ref_dof_pos = bool(
        float(aux_cfg.get("w_denoise_ref_dof_pos", 0.0)) > 0.0
    )
    if self.aux_state_pred_enabled:
        # 6 = (loc, log_std) for a 3-D velocity; 2 = (loc, log_std) for the
        # root height. Both are split with .chunk(2) downstream.
        self.aux_vel_head = nn.Linear(self.d_model, 6)
        self.aux_height_head = nn.Linear(self.d_model, 2)
        self.aux_denoise_ref_root_lin_vel_head = (
            nn.Linear(self.d_model, 3)
            if self.use_aux_denoise_ref_root_lin_vel
            else None
        )
        self.aux_denoise_ref_root_ang_vel_head = (
            nn.Linear(self.d_model, 3)
            if self.use_aux_denoise_ref_root_ang_vel
            else None
        )
        self.aux_contact_head = (
            nn.Linear(self.d_model, self.aux_contact_dim)
            if self.aux_contact_dim > 0
            else None
        )
        self.aux_ref_keybody_pos_head = (
            nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3)
            if self.aux_keybody_pos_dim > 0
            else None
        )
        self.aux_robot_keybody_pos_head = (
            nn.Linear(self.d_model, self.aux_keybody_pos_dim * 3)
            if self.aux_keybody_pos_dim > 0
            else None
        )
        self.aux_denoise_ref_dof_pos_head = (
            nn.Linear(self.d_model, self.output_dim)
            if self.use_aux_denoise_ref_dof_pos
            else None
        )
    else:
        self.aux_vel_head = None
        self.aux_height_head = None
        self.aux_denoise_ref_root_lin_vel_head = None
        self.aux_denoise_ref_root_ang_vel_head = None
        self.aux_contact_head = None
        self.aux_ref_keybody_pos_head = None
        self.aux_robot_keybody_pos_head = None
        self.aux_denoise_ref_dof_pos_head = None
    # True per-layer KV cache for single-step inference.
    # K/V shapes: [B, n_layers, max_ctx_len, n_kv_heads, head_dim]
    self._k_cache: torch.Tensor | None = None
    self._v_cache: torch.Tensor | None = None
    # Cache state per environment
    self._kv_cache_len: torch.Tensor | None = None # [B]
    self._kv_cache_write_idx: torch.Tensor | None = None # [B]
    self._kv_cache_abs_pos: torch.Tensor | None = None # [B]
    # Router-shift tracking for the last MoE layer. Allocated by
    # reset_kv_cache via _init_last_moe_router_shift_state.
    self._prev_last_moe_router_p: torch.Tensor | None = None
    self._prev_last_moe_router_valid: torch.Tensor | None = None
    self._last_moe_router_js_sum: torch.Tensor | None = None
    self._last_moe_router_js_count: torch.Tensor | None = None
    self._last_moe_router_top1_switch_sum: torch.Tensor | None = None
    self._last_moe_router_top1_switch_count: torch.Tensor | None = None
    # Optional auxiliary head mapping the concatenated router features of
    # all MoE layers (num_moe_layers * num_fine_experts values) to
    # output_dim targets.
    aux_cmd_cfg = module_config_dict.get("aux_router_command_recon", {})
    self.aux_router_command_recon_enabled = bool(
        aux_cmd_cfg.get("enabled", False)
    )
    self.aux_router_command_recon_output_dim = int(
        aux_cmd_cfg.get("output_dim", 0)
    )
    self.aux_router_command_recon_hidden_dim = int(
        aux_cmd_cfg.get("hidden_dim", self.d_model)
    )
    self._num_moe_layers = sum(
        1 for layer in self.layers if isinstance(layer, GroupedMoEBlock)
    )
    if self.aux_router_command_recon_enabled:
        if self._num_moe_layers <= 0:
            raise ValueError(
                "aux_router_command_recon requires at least one GroupedMoEBlock."
            )
        if self.aux_router_command_recon_output_dim <= 0:
            raise ValueError(
                "aux_router_command_recon.output_dim must be positive when enabled."
            )
        router_feature_dim = self._num_moe_layers * self.num_fine_experts
        self.aux_router_command_recon_head = nn.Sequential(
            nn.Linear(
                router_feature_dim,
                self.aux_router_command_recon_hidden_dim,
            ),
            nn.SiLU(),
            nn.Linear(
                self.aux_router_command_recon_hidden_dim,
                self.aux_router_command_recon_output_dim,
            ),
        )
    else:
        self.aux_router_command_recon_head = None
    # Optional auxiliary head mapping d_model features to output_dim
    # targets. Its EmpiricalNormalization is presumably applied to those
    # targets by the training loss (TODO confirm).
    aux_router_future_cfg = module_config_dict.get(
        "aux_router_future_recon", {}
    )
    self.aux_router_future_recon_enabled = bool(
        aux_router_future_cfg.get("enabled", False)
    )
    self.aux_router_future_recon_output_dim = int(
        aux_router_future_cfg.get("output_dim", 0)
    )
    self.aux_router_future_recon_hidden_dim = int(
        aux_router_future_cfg.get("hidden_dim", self.d_model)
    )
    aux_router_future_norm_cfg = aux_router_future_cfg.get(
        "target_norm", {}
    )
    self.aux_router_future_recon_norm_eps = float(
        aux_router_future_norm_cfg.get("epsilon", 1.0e-2)
    )
    self.aux_router_future_recon_norm_update_method = str(
        aux_router_future_norm_cfg.get("update_method", "cumulative")
    ).lower()
    aux_router_future_norm_ema = aux_router_future_norm_cfg.get(
        "ema_momentum", None
    )
    self.aux_router_future_recon_norm_ema_momentum = (
        float(aux_router_future_norm_ema)
        if aux_router_future_norm_ema is not None
        else None
    )
    if self.aux_router_future_recon_enabled:
        if self.aux_router_future_recon_output_dim <= 0:
            raise ValueError(
                "aux_router_future_recon.output_dim must be positive when enabled."
            )
        self.aux_router_future_recon_head = nn.Sequential(
            nn.Linear(
                self.d_model,
                self.aux_router_future_recon_hidden_dim,
            ),
            nn.SiLU(),
            nn.Linear(
                self.aux_router_future_recon_hidden_dim,
                self.aux_router_future_recon_output_dim,
            ),
        )
        self.aux_router_future_recon_normalizer = EmpiricalNormalization(
            shape=self.aux_router_future_recon_output_dim,
            eps=self.aux_router_future_recon_norm_eps,
            update_method=self.aux_router_future_recon_norm_update_method,
            ema_momentum=self.aux_router_future_recon_norm_ema_momentum,
        )
    else:
        self.aux_router_future_recon_head = None
        self.aux_router_future_recon_normalizer = None
    # Only the per-block router freeze is applied here.
    # NOTE(review): with freeze_router=True, aux_router_future_recon_head
    # stays trainable until a state dict is loaded (_load_from_state_dict
    # calls _apply_freeze_router_state); confirm this is intended.
    self._apply_base_freeze_router_state()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
if self.use_future_cross_attn:
# In conditional mode, `obs_embed` is tied to `state_obs_embed`.
# Older checkpoints may contain separate weights for both; ensure we
# always load the trained state embedding weights.
obs_prefix = prefix + "obs_embed."
state_prefix = prefix + "state_obs_embed."
for suffix in ("0.weight", "0.bias", "2.weight", "2.bias"):
s_key = state_prefix + suffix
o_key = obs_prefix + suffix
if s_key in state_dict:
state_dict[o_key] = state_dict[s_key]
legacy_aux_prefix = prefix + "aux_command_recon_head."
current_aux_prefix = prefix + "aux_router_command_recon_head."
legacy_aux_keys = [
key
for key in list(state_dict.keys())
if key.startswith(legacy_aux_prefix)
]
if legacy_aux_keys:
if self.aux_router_command_recon_head is not None:
for legacy_key in legacy_aux_keys:
suffix = legacy_key.removeprefix(legacy_aux_prefix)
current_key = current_aux_prefix + suffix
state_dict.setdefault(current_key, state_dict[legacy_key])
for legacy_key in legacy_aux_keys:
state_dict.pop(legacy_key, None)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
self._apply_freeze_router_state()
def _router_no_grad_context(self):
if self.freeze_router:
return torch.no_grad()
return nullcontext()
def _apply_base_freeze_router_state(self) -> None:
    """Ask every GroupedMoEBlock in the stack to (re)apply its router freeze."""
    moe_blocks = (blk for blk in self.layers if isinstance(blk, GroupedMoEBlock))
    for blk in moe_blocks:
        blk._apply_freeze_router_state()
def _apply_freeze_router_state(self) -> None:
self._apply_base_freeze_router_state()
if self.aux_router_future_recon_head is not None:
self.aux_router_future_recon_head.requires_grad_(
not self.freeze_router
)
def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
t = torch.arange(
seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
# outer product: [seq_len, head_dim/2]
freqs = torch.outer(t, self.inv_freq)
# Concatenate to match rotate_half: [seq_len, head_dim]
# Different from complex, here we just concat freqs to match the real-valued rotation logic
emb = torch.cat((freqs, freqs), dim=-1)
# [seq_len, head_dim]
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def get_cos_sin(self, x, position_ids):
    """Look up RoPE cos/sin rows for the given positions.

    Args:
        x: Reference tensor. Only its dtype is used (outputs are cast to it).
        position_ids: Integer row indices into the cached tables, e.g.
            ``[B, T]``.

    Returns:
        ``(cos, sin)``, each shaped ``position_ids.shape + (head_dim,)`` and
        cast to ``x.dtype``.
    """
    target_dtype = x.dtype
    cos, sin = (
        F.embedding(position_ids, table).to(target_dtype)
        for table in (self.cos_cached, self.sin_cached)
    )
    return cos, sin
def _init_last_moe_router_shift_state(self, num_envs: int, device) -> None:
if self._last_moe_layer_idx is None:
self._prev_last_moe_router_p = None
self._prev_last_moe_router_valid = None
self._last_moe_router_js_sum = None
self._last_moe_router_js_count = None
self._last_moe_router_top1_switch_sum = None
self._last_moe_router_top1_switch_count = None
return
self._prev_last_moe_router_p = torch.zeros(
num_envs,
self.num_fine_experts,
device=device,
dtype=torch.float32,
)
self._prev_last_moe_router_valid = torch.zeros(
num_envs, device=device, dtype=torch.bool
)
self._last_moe_router_js_sum = torch.zeros(
(), device=device, dtype=torch.float32
)
self._last_moe_router_js_count = torch.zeros(
(), device=device, dtype=torch.float32
)
self._last_moe_router_top1_switch_sum = torch.zeros(
(), device=device, dtype=torch.float32
)
self._last_moe_router_top1_switch_count = torch.zeros(
(), device=device, dtype=torch.float32
)
def _accumulate_last_moe_router_shift(
self, router_distribution: torch.Tensor
) -> None:
if (
self._prev_last_moe_router_p is None
or self._prev_last_moe_router_valid is None
or self._last_moe_router_js_sum is None
or self._last_moe_router_js_count is None
or self._last_moe_router_top1_switch_sum is None
or self._last_moe_router_top1_switch_count is None
):
return
if (
router_distribution.ndim != 3
or int(router_distribution.shape[1]) != 1
):
return
curr_p = router_distribution[:, 0, :].to(torch.float32)
if int(curr_p.shape[0]) != int(self._prev_last_moe_router_p.shape[0]):
return
prev_valid = self._prev_last_moe_router_valid
if torch.any(prev_valid):
prev_p = self._prev_last_moe_router_p[prev_valid]
curr_p_valid = curr_p[prev_valid]
mix_p = 0.5 * (curr_p_valid + prev_p)
eps = 1.0e-20
curr_safe = curr_p_valid.clamp_min(eps)
prev_safe = prev_p.clamp_min(eps)
mix_safe = mix_p.clamp_min(eps)
kl_curr = (
curr_p_valid * (torch.log(curr_safe) - torch.log(mix_safe))
).sum(dim=-1)
kl_prev = (
prev_p * (torch.log(prev_safe) - torch.log(mix_safe))
).sum(dim=-1)
js = 0.5 * (kl_curr + kl_prev)
self._last_moe_router_js_sum.add_(js.sum())
self._last_moe_router_js_count.add_(float(js.numel()))
curr_top1 = curr_p_valid.argmax(dim=-1)
prev_top1 = prev_p.argmax(dim=-1)
switch = (curr_top1 != prev_top1).to(torch.float32)
self._last_moe_router_top1_switch_sum.add_(switch.sum())
self._last_moe_router_top1_switch_count.add_(float(switch.numel()))
self._prev_last_moe_router_p.copy_(curr_p)
self._prev_last_moe_router_valid.fill_(True)
def get_last_moe_router_shift_stats(
self,
) -> dict[str, torch.Tensor | None]:
return {
"js_sum": self._last_moe_router_js_sum,
"js_count": self._last_moe_router_js_count,
"top1_switch_sum": self._last_moe_router_top1_switch_sum,
"top1_switch_count": self._last_moe_router_top1_switch_count,
}
def reset_kv_cache(self, num_envs: int, device):
    """Initialize per-environment KV cache for single-step inference.

    Allocates zeroed K/V buffers of shape
    ``[num_envs, n_layers, max_ctx_len, n_kv_heads, head_dim]``. They are
    float16 on CUDA devices and float32 otherwise. It also zeroes the int64
    per-env length / write-index / absolute-position counters and resets
    the last-MoE router-shift tracker.
    """
    on_cuda = torch.device(device).type == "cuda"
    cache_shape = (
        num_envs,
        self.n_layers,
        self.max_ctx_len,
        self.n_kv_heads,
        self.head_dim,
    )
    self._k_cache = torch.zeros(
        cache_shape,
        device=device,
        dtype=torch.float16 if on_cuda else torch.float32,
    )
    self._v_cache = torch.zeros_like(self._k_cache)

    def _zero_counters() -> torch.Tensor:
        return torch.zeros(num_envs, dtype=torch.long, device=device)

    self._kv_cache_len = _zero_counters()
    self._kv_cache_write_idx = _zero_counters()
    self._kv_cache_abs_pos = _zero_counters()
    self._init_last_moe_router_shift_state(num_envs, device)
def clear_env_cache(self, env_ids: torch.Tensor | None):
"""Reset KV cache state for specific environments."""
if self._k_cache is None:
return
if env_ids is None:
self._k_cache.zero_()
self._v_cache.zero_()
self._kv_cache_len.zero_()
self._kv_cache_write_idx.zero_()
self._kv_cache_abs_pos.zero_()
if self._prev_last_moe_router_p is not None:
self._prev_last_moe_router_p.zero_()
if self._prev_last_moe_router_valid is not None:
self._prev_last_moe_router_valid.zero_()
if self._last_moe_router_js_sum is not None:
self._last_moe_router_js_sum.zero_()
if self._last_moe_router_js_count is not None:
self._last_moe_router_js_count.zero_()
if self._last_moe_router_top1_switch_sum is not None:
self._last_moe_router_top1_switch_sum.zero_()
if self._last_moe_router_top1_switch_count is not None:
self._last_moe_router_top1_switch_count.zero_()
else:
self._k_cache[env_ids] = 0.0
self._v_cache[env_ids] = 0.0
self._kv_cache_len[env_ids] = 0
self._kv_cache_write_idx[env_ids] = 0
self._kv_cache_abs_pos[env_ids] = 0
if self._prev_last_moe_router_valid is not None:
self._prev_last_moe_router_valid[env_ids] = False
if self._prev_last_moe_router_p is not None:
self._prev_last_moe_router_p[env_ids] = 0.0
def set_collect_routing_stats(self, collect: bool) -> None:
    """Toggle routing-statistics collection on every MoE block.

    Router-distribution collection is enabled only on the last MoE block
    (``_last_moe_layer_idx``), and only when ``collect`` is truthy.
    """
    enabled = bool(collect)
    last_moe_idx = self._last_moe_layer_idx
    for idx, block in enumerate(self.layers):
        if not isinstance(block, GroupedMoEBlock):
            continue
        block.collect_routing_stats = enabled
        block.collect_router_distribution = enabled and idx == last_moe_idx
def reset_routing_stats(self) -> None:
    """Reset the routing statistics held by each MoE block."""
    for block in self.layers:
        if not isinstance(block, GroupedMoEBlock):
            continue
        block.reset_routing_stats()
def clear_router_distribution_cache(self) -> None:
    """Drop captured router outputs and turn capture off on every MoE block."""
    cleared_state = (
        ("last_router_distribution", None),
        ("last_router_logits", None),
        ("capture_router_distribution", False),
        ("capture_router_logits", False),
    )
    moe_blocks = [blk for blk in self.layers if isinstance(blk, GroupedMoEBlock)]
    for blk in moe_blocks:
        for attr_name, value in cleared_state:
            setattr(blk, attr_name, value)
def _set_capture_router_distributions(self, capture: bool) -> None:
self._set_capture_router_features(
capture_distributions=capture,
capture_logits=False,
)
def _set_capture_router_features(
    self,
    *,
    capture_distributions: bool,
    capture_logits: bool,
) -> None:
    """Set router-distribution and router-logit capture flags on MoE blocks."""
    flag_values = {
        "capture_router_distribution": bool(capture_distributions),
        "capture_router_logits": bool(capture_logits),
    }
    for block in self.layers:
        if isinstance(block, GroupedMoEBlock):
            for attr_name, value in flag_values.items():
                setattr(block, attr_name, value)
def apply_dynamic_bias_update_from_stats(self) -> None:
    """Have each MoE block update its routing bias from its collected counts."""
    moe_blocks = [blk for blk in self.layers if isinstance(blk, GroupedMoEBlock)]
    for blk in moe_blocks:
        blk.apply_bias_update_from_counts()
def _make_causal_mask(self, T: int, device) -> torch.Tensor:
"""Generate causal attention mask: shape [T, T], True where attend allowed."""
return torch.tril(torch.ones(T, T, device=device, dtype=torch.bool))
def _forward_layers_range(
    self,
    h: torch.Tensor,
    cos: torch.Tensor | None,
    sin: torch.Tensor | None,
    mask: torch.Tensor | None,
    memory: torch.Tensor | None = None,
    memory_mask: torch.Tensor | None = None,
    router_h: torch.Tensor | None = None,
    router_h_per_layer: list[torch.Tensor | None] | None = None,
    *,
    start_layer: int,
    end_layer: int,
    return_pre_moe_hidden: bool = False,
    return_router_features: bool = False,
    return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Forward through a contiguous layer range with optional checkpointing.

    Runs ``self.layers[start_layer:end_layer]`` on ``h``. MoE blocks also
    receive their router input as the ``router_x`` keyword on both the eager
    and the checkpointed path. That input is ``router_h_per_layer[layer_idx]``
    when a per-layer list is given, else ``router_h``. Non-reentrant
    activation checkpointing is used only when ``use_checkpointing`` is set
    and the module is in training mode.

    Args:
        h: Hidden states fed to the first layer of the range.
        cos, sin: RoPE tables forwarded to every layer.
        mask: Self-attention mask forwarded to every layer.
        memory, memory_mask: Optional cross-attention memory and its mask.
        router_h: Router input shared by all MoE blocks.
        router_h_per_layer: Optional router inputs indexed by absolute layer
            index. When given, they take precedence over ``router_h``.
        start_layer, end_layer: Half-open layer range ``[start, end)``.
        return_pre_moe_hidden: Also return the output of layer 0. This is only
            available when the range starts at layer 0.
        return_router_features: Also return the captured router
            distributions of the MoE blocks, concatenated on the last dim.
        return_router_temporal_features: Also return the captured router
            logits of the MoE blocks, concatenated on the last dim.

    Returns:
        The final hidden state alone, or ``(h, *extras)``. Extras come in the
        order pre-MoE hidden, router features, router temporal features, each
        present only when requested.

    Raises:
        ValueError: On an invalid layer range, or when a requested extra
            could not be captured.
    """
    if (
        start_layer < 0
        or end_layer < start_layer
        or end_layer > len(self.layers)
    ):
        raise ValueError(
            "Invalid layer range for _forward_layers_range: "
            f"start_layer={start_layer}, end_layer={end_layer}, "
            f"num_layers={len(self.layers)}."
        )
    pre_moe_hidden = None
    router_features = []
    router_temporal_features = []
    use_checkpoint = self.use_checkpointing and self.training
    self._set_capture_router_features(
        capture_distributions=return_router_features,
        capture_logits=return_router_temporal_features,
    )
    try:
        for layer_idx in range(start_layer, end_layer):
            layer = self.layers[layer_idx]
            is_moe = isinstance(layer, GroupedMoEBlock)
            layer_router_h = router_h
            if router_h_per_layer is not None:
                layer_router_h = router_h_per_layer[layer_idx]
            layer_args = (h, cos, sin, mask, memory, memory_mask)
            # The router input goes by keyword on BOTH paths, so the
            # checkpointed call cannot silently diverge from the eager one
            # if the block's positional signature changes. Non-reentrant
            # checkpointing forwards **kwargs to the wrapped callable.
            layer_kwargs = {"router_x": layer_router_h} if is_moe else {}
            if use_checkpoint:
                h = checkpoint.checkpoint(
                    layer,
                    *layer_args,
                    use_reentrant=False,
                    **layer_kwargs,
                )
            else:
                h = layer(*layer_args, **layer_kwargs)
            if return_pre_moe_hidden and layer_idx == 0:
                pre_moe_hidden = h
            if return_router_features and is_moe:
                if layer.last_router_distribution is None:
                    raise ValueError(
                        f"Missing router distribution for MoE layer {layer_idx}."
                    )
                router_features.append(layer.last_router_distribution)
            if return_router_temporal_features and is_moe:
                if layer.last_router_logits is None:
                    raise ValueError(
                        f"Missing router logits for MoE layer {layer_idx}."
                    )
                router_temporal_features.append(layer.last_router_logits)
    finally:
        # Always switch capture back off, even if a layer raised.
        self._set_capture_router_features(
            capture_distributions=False,
            capture_logits=False,
        )
    outputs: list[torch.Tensor] = [h]
    if return_pre_moe_hidden:
        if pre_moe_hidden is None:
            raise ValueError(
                "Missing pre-MoE hidden state from the leading dense layer."
            )
        outputs.append(pre_moe_hidden)
    if return_router_features:
        if len(router_features) == 0:
            raise ValueError(
                "Missing router features while return_router_features=True."
            )
        outputs.append(torch.cat(router_features, dim=-1))
    if return_router_temporal_features:
        if len(router_temporal_features) == 0:
            raise ValueError(
                "Missing router temporal features while "
                "return_router_temporal_features=True."
            )
        outputs.append(torch.cat(router_temporal_features, dim=-1))
    if len(outputs) == 1:
        return outputs[0]
    return tuple(outputs)
def _forward_layers(
self,
h: torch.Tensor,
cos: torch.Tensor | None,
sin: torch.Tensor | None,
mask: torch.Tensor | None,
memory: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
router_h: torch.Tensor | None = None,
router_h_per_layer: list[torch.Tensor | None] | None = None,
return_pre_moe_hidden: bool = False,
return_router_features: bool = False,
return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
return self._forward_layers_range(
h,
cos,
sin,
mask,
memory,
memory_mask,
router_h,
router_h_per_layer,
start_layer=0,
end_layer=len(self.layers),
return_pre_moe_hidden=return_pre_moe_hidden,
return_router_features=return_router_features,
return_router_temporal_features=return_router_temporal_features,
)
def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None:
return None
def sequence_mu(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
return_hidden: bool = False,
return_pre_moe_hidden: bool = False,
return_router_features: bool = False,
return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Compute per-token action mean for sequences.
Args:
x: [B, T, D] flat obs per token.
attn_mask: [B, T, T] boolean mask (True if attend allowed), or None for causal.
return_hidden: If True, also return the hidden states.
Returns:
mu: [B, T, A]
h: [B, T, d_model] (only if return_hidden=True)
"""
B, T, _ = x.shape
h = self.obs_embed(x) # [B, T, d_model]
router_h = self._compute_router_hidden(x)
# SDPA bool attention mask uses True = allowed (can attend).
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1) # [B, 1, T, T]
# Episode-aware positions: first attendable token is episode start.
start_idx = attn_mask.to(torch.int64).argmax(dim=-1) # [B, T]
t_idx = torch.arange(T, device=x.device, dtype=torch.long)[
None, :
].expand(B, T)
pos = t_idx - start_idx # [B, T]
else:
tgt_mask = None
pos = torch.arange(T, device=x.device, dtype=torch.long)[
None, :
].expand(B, T)
cos, sin = self.get_cos_sin(h, pos) # [B, T, head_dim//2]
if return_hidden and return_pre_moe_hidden:
raise ValueError(
"return_hidden and return_pre_moe_hidden cannot both be True."
)
forward_out = self._forward_layers(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
return_pre_moe_hidden=return_pre_moe_hidden,
return_router_features=return_router_features,
return_router_temporal_features=return_router_temporal_features,
)
extras: list[torch.Tensor] = []
if isinstance(forward_out, tuple):
h = forward_out[0]
extras = list(forward_out[1:])
else:
h = forward_out
h = self.norm_f(h)
mu = self.action_mu_head(h)
outputs: list[torch.Tensor] = [mu]
if return_pre_moe_hidden:
outputs.append(extras.pop(0))
if return_router_features:
outputs.append(extras.pop(0))
if return_router_temporal_features:
outputs.append(extras.pop(0))
if len(outputs) > 1:
return tuple(outputs)
if return_hidden:
return mu, h
return mu
def sequence_hidden(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute per-token latent features for sequences.
Args:
x: [B, T, D] flat obs per token.
attn_mask: [B, T, T] boolean mask (True if attend allowed).
Returns:
h_f: [B, T, d_model]
"""
B, T, _ = x.shape
h = self.obs_embed(x) # [B, T, d_model]
router_h = self._compute_router_hidden(x)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1) # [B, 1, T, T]
start_idx = attn_mask.to(torch.int64).argmax(dim=-1) # [B, T]
t_idx = torch.arange(T, device=x.device, dtype=torch.long)[
None, :
].expand(B, T)
pos = t_idx - start_idx # [B, T]
else:
tgt_mask = None
pos = torch.arange(T, device=x.device, dtype=torch.long)[
None, :
].expand(B, T)
cos, sin = self.get_cos_sin(h, pos)
h = self._forward_layers(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
)
h = self.norm_f(h)
return h
def _embed_future_tokens(
self, future_tokens: torch.Tensor
) -> torch.Tensor:
if not self.use_future_cross_attn:
raise ValueError(
"_embed_future_tokens requires use_future_cross_attn=True"
)
if future_tokens.ndim == 3:
b, n, d = future_tokens.shape
if n != self.future_seq_len:
raise ValueError(
f"future token length mismatch: expected {self.future_seq_len}, got {n}"
)
if d != self.future_token_dim:
raise ValueError(
f"future token dim mismatch: expected {self.future_token_dim}, got {d}"
)
pos = torch.arange(
n, device=future_tokens.device, dtype=torch.long
)
pos_emb = self.future_pos_embed(pos)[None, :, :]
return self.future_obs_embed(future_tokens) + pos_emb
if future_tokens.ndim == 4:
b, t, n, d = future_tokens.shape
if n != self.future_seq_len:
raise ValueError(
f"future token length mismatch: expected {self.future_seq_len}, got {n}"
)
if d != self.future_token_dim:
raise ValueError(
f"future token dim mismatch: expected {self.future_token_dim}, got {d}"
)
pos = torch.arange(
n, device=future_tokens.device, dtype=torch.long
)
pos_emb = self.future_pos_embed(pos)[None, None, :, :]
return self.future_obs_embed(future_tokens) + pos_emb
raise ValueError(
f"future_tokens must be 3D or 4D, got shape {tuple(future_tokens.shape)}"
)
def sequence_mu_cond(
self,
state_seq: torch.Tensor,
future_seq: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
future_mask: torch.Tensor | None = None,
return_pre_moe_hidden: bool = False,
return_router_features: bool = False,
return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
if not self.use_future_cross_attn:
raise ValueError(
"sequence_mu_cond requires use_future_cross_attn=True"
)
if state_seq.ndim != 3:
raise ValueError(
f"state_seq must have shape [B, T, D], got {tuple(state_seq.shape)}"
)
if future_seq.ndim != 4:
raise ValueError(
"future_seq must have shape [B, T, N_fut, D_fut], "
f"got {tuple(future_seq.shape)}"
)
b, t, d_state = state_seq.shape
bf, tf, n_fut, d_fut = future_seq.shape
if bf != b or tf != t:
raise ValueError(
"state_seq and future_seq batch/time mismatch: "
f"state={tuple(state_seq.shape)}, future={tuple(future_seq.shape)}"
)
if d_state != self.state_obs_dim:
raise ValueError(
f"state_seq dim mismatch: expected {self.state_obs_dim}, got {d_state}"
)
if n_fut != self.future_seq_len:
raise ValueError(
f"future_seq len mismatch: expected {self.future_seq_len}, got {n_fut}"
)
if d_fut != self.future_token_dim:
raise ValueError(
f"future_seq dim mismatch: expected {self.future_token_dim}, got {d_fut}"
)
h = self.state_obs_embed(state_seq)
memory = self._embed_future_tokens(future_seq)
if future_mask is None:
future_mask = torch.ones(
b,
t,
n_fut,
dtype=torch.bool,
device=state_seq.device,
)
if future_mask.shape != (b, t, n_fut):
raise ValueError(
"future_mask shape mismatch: expected "
f"{(b, t, n_fut)}, got {tuple(future_mask.shape)}"
)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1)
start_idx = attn_mask.to(torch.int64).argmax(dim=-1)
t_idx = torch.arange(t, device=state_seq.device, dtype=torch.long)[
None, :
].expand(b, t)
pos = t_idx - start_idx
else:
tgt_mask = None
pos = torch.arange(t, device=state_seq.device, dtype=torch.long)[
None, :
].expand(b, t)
cos, sin = self.get_cos_sin(h, pos)
forward_out = self._forward_layers(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
memory=memory,
memory_mask=future_mask,
return_pre_moe_hidden=return_pre_moe_hidden,
return_router_features=return_router_features,
return_router_temporal_features=return_router_temporal_features,
)
extras: list[torch.Tensor] = []
if isinstance(forward_out, tuple):
h = forward_out[0]
extras = list(forward_out[1:])
else:
h = forward_out
h = self.norm_f(h)
mu = self.action_mu_head(h)
outputs: list[torch.Tensor] = [mu]
if return_pre_moe_hidden:
outputs.append(extras.pop(0))
if return_router_features:
outputs.append(extras.pop(0))
if return_router_temporal_features:
outputs.append(extras.pop(0))
if len(outputs) > 1:
return tuple(outputs)
return mu
def predict_aux_from_pre_moe(
self,
pre_moe_hidden: torch.Tensor,
*,
ref_aux_hidden: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
if not self.aux_state_pred_enabled:
raise ValueError(
"predict_aux_from_pre_moe requires aux_state_pred.enabled=True."
)
if pre_moe_hidden.ndim != 3:
raise ValueError(
f"Expected pre_moe_hidden with shape [B, T, D], got {tuple(pre_moe_hidden.shape)}"
)
vel_params = self.aux_vel_head(pre_moe_hidden)
height_params = self.aux_height_head(pre_moe_hidden)
vel_loc, vel_log_std = vel_params.chunk(2, dim=-1)
height_loc, height_log_std = height_params.chunk(2, dim=-1)
aux_outputs = {
"base_lin_vel_loc": vel_loc,
"base_lin_vel_log_std": vel_log_std,
"root_height_loc": height_loc,
"root_height_log_std": height_log_std,
}
if self.aux_contact_head is not None:
aux_outputs["keybody_contact_logits"] = self.aux_contact_head(
pre_moe_hidden
)
else:
aux_outputs["keybody_contact_logits"] = pre_moe_hidden.new_zeros(
pre_moe_hidden.shape[0],
pre_moe_hidden.shape[1],
0,
)
if self.aux_denoise_ref_root_lin_vel_head is not None:
aux_outputs["denoise_ref_root_lin_vel_residual"] = (
self.aux_denoise_ref_root_lin_vel_head(pre_moe_hidden)
)
if self.aux_denoise_ref_root_ang_vel_head is not None:
aux_outputs["denoise_ref_root_ang_vel_residual"] = (
self.aux_denoise_ref_root_ang_vel_head(pre_moe_hidden)
)
if self.aux_ref_keybody_pos_head is not None:
aux_outputs["ref_keybody_rel_pos"] = self.aux_ref_keybody_pos_head(
pre_moe_hidden
).reshape(
pre_moe_hidden.shape[0],
pre_moe_hidden.shape[1],
self.aux_keybody_pos_dim,
3,
)
aux_outputs["robot_keybody_rel_pos"] = (
self.aux_robot_keybody_pos_head(pre_moe_hidden).reshape(
pre_moe_hidden.shape[0],
pre_moe_hidden.shape[1],
self.aux_keybody_pos_dim,
3,
)
)
else:
aux_outputs["ref_keybody_rel_pos"] = pre_moe_hidden.new_zeros(
pre_moe_hidden.shape[0],
pre_moe_hidden.shape[1],
0,
3,
)
aux_outputs["robot_keybody_rel_pos"] = pre_moe_hidden.new_zeros(
pre_moe_hidden.shape[0],
pre_moe_hidden.shape[1],
0,
3,
)
if self.aux_denoise_ref_dof_pos_head is not None:
aux_outputs["denoise_ref_dof_pos_residual"] = (
self.aux_denoise_ref_dof_pos_head(pre_moe_hidden)
)
return aux_outputs
def predict_aux_router_command_from_router_features(
    self, router_features: torch.Tensor
) -> torch.Tensor:
    """Reconstruct the router command from router features.

    Args:
        router_features: Rank-3 tensor ``[B, T, D]``.

    Returns:
        Output of ``aux_router_command_recon_head``.

    Raises:
        ValueError: If the auxiliary task is disabled, the input is not
            rank 3, or the head was never built (checked in that order).
    """
    enabled = self.aux_router_command_recon_enabled
    if not enabled:
        raise ValueError(
            "predict_aux_router_command_from_router_features requires "
            "aux_router_command_recon.enabled=True."
        )
    if router_features.ndim != 3:
        shape_str = f"{tuple(router_features.shape)}"
        raise ValueError(
            "Expected router_features with shape [B, T, D], got " + shape_str + "."
        )
    recon_head = self.aux_router_command_recon_head
    if recon_head is None:
        raise ValueError("aux_router_command_recon_head is not initialized.")
    return recon_head(router_features)
def update_aux_router_future_recon_normalizer(
    self, future_target: torch.Tensor
) -> None:
    """Feed future-reconstruction targets into the running normalizer.

    All leading axes are flattened, so the normalizer receives rows of
    shape ``[N, D]``. The target is detached first, so no gradient flows
    into the statistics update.

    Raises:
        ValueError: If the task is disabled, the normalizer is missing, or
            ``future_target`` has fewer than two dimensions.
    """
    if not self.aux_router_future_recon_enabled:
        raise ValueError(
            "update_aux_router_future_recon_normalizer requires "
            "aux_router_future_recon.enabled=True."
        )
    normalizer = self.aux_router_future_recon_normalizer
    if normalizer is None:
        raise ValueError(
            "aux_router_future_recon_normalizer is not initialized."
        )
    if future_target.ndim < 2:
        raise ValueError(
            "Expected future_target with shape [B, D] or [B, T, D], got "
            f"{tuple(future_target.shape)}."
        )
    feature_dim = future_target.shape[-1]
    rows = future_target.detach().reshape(-1, feature_dim)
    normalizer.update(rows)
def normalize_aux_router_future_recon_target(
    self, future_target: torch.Tensor
) -> torch.Tensor:
    """Normalize future-reconstruction targets without updating statistics.

    The last axis is treated as the feature axis. Leading axes are
    flattened for the normalizer call and restored afterwards, so the
    output has the same shape as ``future_target``.

    Raises:
        ValueError: If the task is disabled, the normalizer is missing, or
            ``future_target`` has fewer than two dimensions.
    """
    if not self.aux_router_future_recon_enabled:
        raise ValueError(
            "normalize_aux_router_future_recon_target requires "
            "aux_router_future_recon.enabled=True."
        )
    normalizer = self.aux_router_future_recon_normalizer
    if normalizer is None:
        raise ValueError(
            "aux_router_future_recon_normalizer is not initialized."
        )
    if future_target.ndim < 2:
        raise ValueError(
            "Expected future_target with shape [B, D] or [B, T, D], got "
            f"{tuple(future_target.shape)}."
        )
    rows = future_target.reshape(-1, future_target.shape[-1])
    return normalizer.normalize_only(rows).reshape_as(future_target)
def predict_aux_router_future_recon_from_router_hidden(
    self, router_hidden: torch.Tensor
) -> torch.Tensor:
    """Predict future-reconstruction targets from router hidden states.

    Args:
        router_hidden: Rank-3 tensor ``[B, T, D]``.

    Returns:
        Output of ``aux_router_future_recon_head``.

    Raises:
        ValueError: If the task is disabled, the input is not rank 3, or
            the head was never built (checked in that order).
    """
    enabled = self.aux_router_future_recon_enabled
    if not enabled:
        raise ValueError(
            "predict_aux_router_future_recon_from_router_hidden requires "
            "aux_router_future_recon.enabled=True."
        )
    if router_hidden.ndim != 3:
        shape_str = f"{tuple(router_hidden.shape)}"
        raise ValueError(
            "Expected router_hidden with shape [B, T, D], got " + shape_str + "."
        )
    recon_head = self.aux_router_future_recon_head
    if recon_head is None:
        raise ValueError("aux_router_future_recon_head is not initialized.")
    return recon_head(router_hidden)
def single_step_mu_cond(
self,
state_x: torch.Tensor,
future_tokens: torch.Tensor,
*,
future_mask: torch.Tensor | None = None,
) -> torch.Tensor:
if not self.use_future_cross_attn:
raise ValueError(
"single_step_mu_cond requires use_future_cross_attn=True"
)
if state_x.ndim != 2:
raise ValueError(f"Expected state_x [B, D], got {state_x.shape}")
if future_tokens.ndim != 3:
raise ValueError(
"Expected future_tokens [B, N_fut, D_fut], "
f"got {future_tokens.shape}"
)
b, d_state = state_x.shape
bf, n_fut, d_fut = future_tokens.shape
if bf != b:
raise ValueError(
f"Batch mismatch between state_x and future_tokens: {b} vs {bf}"
)
if d_state != self.state_obs_dim:
raise ValueError(
f"state_x dim mismatch: expected {self.state_obs_dim}, got {d_state}"
)
if n_fut != self.future_seq_len:
raise ValueError(
f"future len mismatch: expected {self.future_seq_len}, got {n_fut}"
)
if d_fut != self.future_token_dim:
raise ValueError(
f"future dim mismatch: expected {self.future_token_dim}, got {d_fut}"
)
if self._k_cache is None:
state_seq = state_x[:, None, :]
future_seq = future_tokens[:, None, :, :]
if future_mask is not None:
future_mask = future_mask[:, None, :]
mu_seq = self.sequence_mu_cond(
state_seq,
future_seq,
attn_mask=None,
future_mask=future_mask,
)
return mu_seq[:, 0, :]
if self._k_cache.device != state_x.device:
self._k_cache = self._k_cache.to(state_x.device)
self._v_cache = self._v_cache.to(state_x.device)
self._kv_cache_len = self._kv_cache_len.to(state_x.device)
self._kv_cache_write_idx = self._kv_cache_write_idx.to(
state_x.device
)
self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(state_x.device)
h = self.state_obs_embed(state_x)[:, None, :]
memory = self._embed_future_tokens(future_tokens)
if self._k_cache.dtype != h.dtype:
self._k_cache = self._k_cache.to(h.dtype)
self._v_cache = self._v_cache.to(h.dtype)
cache_len = self._kv_cache_len
insert_pos = self._kv_cache_write_idx
max_len = int(self.max_ctx_len)
new_len = torch.clamp(cache_len + 1, max=max_len)
self._kv_cache_len = new_len
self._kv_cache_write_idx = (insert_pos + 1) % max_len
pos = self._kv_cache_abs_pos
self._kv_cache_abs_pos = pos + 1
pos_ids = pos.unsqueeze(1)
cos, sin = self.get_cos_sin(h, pos_ids)
memory_mask = None
if future_mask is not None:
if future_mask.shape != (b, n_fut):
raise ValueError(
"future_mask shape mismatch for single-step path: expected "
f"{(b, n_fut)}, got {tuple(future_mask.shape)}"
)
memory_mask = future_mask[:, None, None, :]
for layer_idx, layer in enumerate(self.layers):
x_norm = layer.norm1(h)
k_cache_l = self._k_cache[:, layer_idx]
v_cache_l = self._v_cache[:, layer_idx]
attn_out, _, _ = layer.attn.forward_single_token(
x_norm,
cos,
sin,
k_cache_l,
v_cache_l,
new_len,
insert_pos,
)
h = h + attn_out
if layer.use_cross_attn:
h = h + layer.cross_attn(
layer.norm_cross(h), memory, memory_mask
)
h2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
ffn = layer.compute_moe_ffn(h2)
if (
layer_idx == self._last_moe_layer_idx
and layer.collect_routing_stats
and layer.last_router_distribution is not None
):
self._accumulate_last_moe_router_shift(
layer.last_router_distribution
)
else:
ffn = layer.mlp_dropout(layer.mlp(h2))
h = h + ffn
h = self.norm_f(h)
return self.action_mu_head(h[:, 0, :])
def forward(
self,
input: torch.Tensor,
past_key_values: torch.Tensor | None = None,
current_pos: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for single-step inference (no history)."""
if past_key_values is not None:
return self._forward_inference_onnx(
input, past_key_values, current_pos
)
if input.ndim != 2:
raise ValueError(f"Expected [B, D], got {input.shape}")
mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)
return mu_seq[:, 0, :]
def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the action mean for a single step using the per-layer KV cache.

    Uses a ring-buffer KV cache with per-env absolute positions for RoPE.
    If the cache has not been allocated (``self._k_cache is None``), this
    falls back to a length-1 ``sequence_mu`` call instead.

    Side effects on the cached path:
      * the cache tensors may be moved to ``x``'s device and cast to the
        embedding dtype;
      * ``_kv_cache_len`` is incremented and clamped to ``max_ctx_len``;
      * ``_kv_cache_write_idx`` is incremented modulo ``max_ctx_len``;
      * ``_kv_cache_abs_pos`` is incremented (used as the RoPE position).
    For the last MoE layer with ``collect_routing_stats`` enabled, the
    router distribution is passed to ``_accumulate_last_moe_router_shift``.

    Args:
        x: Flat observation of shape ``[B, D]``.

    Returns:
        Output of ``action_mu_head`` on the final normalized hidden state.

    Raises:
        ValueError: If ``x`` is not rank 2.
    """
    if x.ndim != 2:
        raise ValueError(f"Expected [B, D], got {x.shape}")
    B, _ = x.shape  # B is not used below.
    if self._k_cache is None:
        # No cache allocated: evaluate as a length-1 sequence.
        mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)
        return mu_seq[:, 0, :]
    # Ensure cache device matches
    if self._k_cache.device != x.device:
        self._k_cache = self._k_cache.to(x.device)
        self._v_cache = self._v_cache.to(x.device)
        self._kv_cache_len = self._kv_cache_len.to(x.device)
        self._kv_cache_write_idx = self._kv_cache_write_idx.to(x.device)
        self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(x.device)
    h = self.obs_embed(x)[:, None, :]  # [B, 1, d_model]
    # Optional separate router input (None means the MoE router uses h2).
    router_h = self._compute_router_hidden(x)
    if router_h is not None:
        router_h = router_h[:, None, :]
    # Ensure cache dtype matches compute dtype (convert once if needed)
    if self._k_cache.dtype != h.dtype:
        self._k_cache = self._k_cache.to(h.dtype)
        self._v_cache = self._v_cache.to(h.dtype)
    cache_len = self._kv_cache_len  # [B]
    insert_pos = self._kv_cache_write_idx  # [B]
    max_len = int(self.max_ctx_len)
    # Valid length saturates at max_len once the ring buffer is full.
    new_len = torch.clamp(cache_len + 1, max=max_len)  # [B]
    self._kv_cache_len = new_len
    self._kv_cache_write_idx = (insert_pos + 1) % max_len
    # RoPE frequencies for current absolute position
    pos = self._kv_cache_abs_pos  # [B]
    self._kv_cache_abs_pos = pos + 1
    pos_ids = pos.unsqueeze(1)  # [B, 1]
    cos, sin = self.get_cos_sin(h, pos_ids)
    for layer_idx, layer in enumerate(self.layers):
        x_norm = layer.norm1(h)
        k_cache_l = self._k_cache[
            :, layer_idx
        ]  # [B, L, n_kv_heads, head_dim]
        v_cache_l = self._v_cache[:, layer_idx]
        # The returned caches are discarded. NOTE(review): this presumably
        # relies on forward_single_token writing into the k_cache_l /
        # v_cache_l views in place; confirm.
        attn_out, _, _ = layer.attn.forward_single_token(
            x_norm,
            cos,
            sin,
            k_cache_l,
            v_cache_l,
            new_len,
            insert_pos,
        )
        h = h + attn_out
        # FFN/MoE path for single token (h: [B,1,D])
        h2 = layer.norm2(h)  # [B,1,D]
        if isinstance(layer, GroupedMoEBlock):
            ffn = layer.compute_moe_ffn(h2, router_x=router_h)
            if (
                layer_idx == self._last_moe_layer_idx
                and layer.collect_routing_stats
                and layer.last_router_distribution is not None
            ):
                self._accumulate_last_moe_router_shift(
                    layer.last_router_distribution
                )
        else:
            ffn = layer.mlp_dropout(layer.mlp(h2))  # keeps shape [B, 1, D]
        h = h + ffn
    h = self.norm_f(h)
    return self.action_mu_head(h[:, 0, :])
def _forward_inference_onnx(
    self,
    x: torch.Tensor,
    past_key_values: torch.Tensor,
    current_pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
    """Single-step inference compatible with ONNX export.

    Mirrors the `single_step_mu` logic using real-valued RoPE, but keeps the
    KV cache external. The cache is passed in, and the per-layer caches
    returned by ``forward_single_token`` are stacked and returned. The
    internal ``self._k_cache`` / ``self._v_cache`` state is not read or
    written, and no router-shift statistics are accumulated.

    Ring-buffer indices are derived directly from ``current_pos``:
    write slot ``current_pos % max_len`` and valid length
    ``min(current_pos + 1, max_len)``, where ``max_len`` is
    ``past_key_values.shape[3]``.

    Args:
        x: [B, D] (batch is usually 1 for ONNX export)
        past_key_values: [n_layers, 2, B, max_len, n_kv_heads, head_dim]
        current_pos: [B] or scalar, the absolute step index (0, 1, 2...)

    Returns:
        ``(action, present_key_values)``, where ``action`` is [B, A] and
        ``present_key_values`` has the same layout as ``past_key_values``.
        While ``torch.onnx.is_in_onnx_export()`` is true and the model has
        MoE layers, each MoE layer's ``topk_idx`` and ``router_logits`` are
        appended to the tuple as extra outputs.
    """
    # Embedding [B, D] -> [B, 1, D]
    h = self.obs_embed(x)[:, None, :]  # [B, 1, d_model]
    # Optional separate router input for the MoE layers.
    router_h = self._compute_router_hidden(x)
    if router_h is not None:
        router_h = router_h[:, None, :]
    B = h.shape[0]  # typically 1 for ONNX export
    # Calculate Cache Indices (Ring Buffer Logic)
    # past_key_values shape: [L, 2, B, T, H, D] -> T is index 3
    max_len = past_key_values.shape[3]  # cache capacity, read from the input
    # A scalar position is broadcast to every batch row.
    # NOTE(review): a [1]-shaped current_pos with B > 1 is not expanded here.
    if current_pos.ndim == 0:
        current_pos = current_pos.view(1).expand(B)
    # insert_pos: [B]
    insert_pos = current_pos % max_len
    # new_len: [B]
    new_len = torch.clamp(current_pos + 1, max=max_len)
    # position_ids: [B, 1]
    position_ids = current_pos.unsqueeze(1)
    # cos, sin shape: [B, 1, head_dim]
    cos, sin = self.get_cos_sin(h, position_ids)
    present_key_values_list = []
    routing_debug_outputs: list[torch.Tensor] = []
    # Extra routing outputs are only emitted while exporting to ONNX.
    export_routing_debug = torch.onnx.is_in_onnx_export()
    for i, layer in enumerate(self.layers):
        # Unpack Cache: [2, B, T, H, D]
        layer_past = past_key_values[i]
        k_cache = layer_past[0]
        v_cache = layer_past[1]
        h_norm = layer.norm1(h)
        # Attention
        attn_out, new_k_cache, new_v_cache = (
            layer.attn.forward_single_token(
                x=h_norm,
                cos=cos,
                sin=sin,
                k_cache=k_cache,
                v_cache=v_cache,
                new_len=new_len,
                insert_pos=insert_pos,
            )
        )
        h = h + attn_out
        # FFN / MoE
        h_norm2 = layer.norm2(h)
        if isinstance(layer, GroupedMoEBlock):
            if export_routing_debug:
                ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
                    h_norm2,
                    router_x=router_h,
                    return_routing_debug=True,
                )
                routing_debug_outputs.extend([topk_idx, router_logits])
            else:
                ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)
        else:
            # Dense MLP
            ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
        h = h + ffn_out
        # Re-pack this layer's updated cache as [2, B, T, H, D].
        current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
        present_key_values_list.append(current_layer_kv)
    h = self.norm_f(h)
    action = self.action_mu_head(h[:, 0, :])
    present_key_values = torch.stack(present_key_values_list, dim=0)
    if export_routing_debug and routing_debug_outputs:
        return (action, present_key_values, *routing_debug_outputs)
    return action, present_key_values
class ReferenceRoutedGroupedMoETransformerPolicy(GroupedMoETransformerPolicy):
    """Grouped-MoE transformer policy whose router reads a feature subset.

    The router input is a fixed subset of the flat actor observation,
    selected by ``router_feature_indices`` and embedded into ``d_model`` by
    a small MLP (``router_obs_embed``). ``_compute_router_hidden`` returns
    that embedding. The base single-step / ONNX inference paths pass it to
    ``compute_moe_ffn`` as ``router_x``.

    Config keys read here (``module_config_dict``):
        use_future_cross_attn: must be falsy; cross-attention is rejected.
        router_input_dim: required, positive, equal to
            ``len(router_feature_indices)``.
        router_feature_indices: required; each must be non-negative and
            below the flat actor obs dim.
        router_embed_mlp_hidden: optional router MLP hidden width; defaults
            to ``obs_embed_mlp_hidden``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        module_config_dict: dict,
    ):
        # Work on a shallow copy of the caller's config dict.
        module_config = dict(module_config_dict)
        if bool(module_config.get("use_future_cross_attn", False)):
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicy does not support "
                "use_future_cross_attn=True."
            )
        router_input_dim = module_config.get("router_input_dim", None)
        router_feature_indices = module_config.get(
            "router_feature_indices", None
        )
        if router_input_dim is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicy requires router_input_dim."
            )
        if router_feature_indices is None:
            raise ValueError(
                "ReferenceRoutedGroupedMoETransformerPolicy requires "
                "router_feature_indices."
            )
        self.router_input_dim = int(router_input_dim)
        self.router_feature_indices = tuple(
            int(idx) for idx in router_feature_indices
        )
        if self.router_input_dim <= 0:
            raise ValueError(
                f"router_input_dim must be positive, got {self.router_input_dim}."
            )
        if len(self.router_feature_indices) != self.router_input_dim:
            raise ValueError(
                "router_input_dim must match len(router_feature_indices): "
                f"{self.router_input_dim} vs {len(self.router_feature_indices)}."
            )
        if any(idx < 0 for idx in self.router_feature_indices):
            raise ValueError(
                f"router_feature_indices must be non-negative, got {self.router_feature_indices}."
            )
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            module_config_dict=module_config,
        )
        # The upper-bound check needs the obs width, presumably set by the
        # base __init__ (obs_input_dim, falling back to input_dim).
        obs_in = int(self.obs_input_dim or self.input_dim)
        if any(idx >= obs_in for idx in self.router_feature_indices):
            raise ValueError(
                "router_feature_indices exceed the flat actor obs dim "
                f"{obs_in}: {self.router_feature_indices}"
            )
        self.router_embed_mlp_hidden = int(
            module_config.get(
                "router_embed_mlp_hidden", self.obs_embed_mlp_hidden
            )
        )
        # Non-persistent: the indices come from config, not checkpoints.
        self.register_buffer(
            "_router_feature_indices",
            torch.tensor(self.router_feature_indices, dtype=torch.long),
            persistent=False,
        )
        self.router_obs_embed = nn.Sequential(
            nn.Linear(self.router_input_dim, self.router_embed_mlp_hidden),
            nn.SiLU(),
            nn.Linear(self.router_embed_mlp_hidden, self.d_model),
        )
        # Apply freezing now that router_obs_embed exists (the override
        # below touches it).
        self._apply_freeze_router_state()

    def _apply_freeze_router_state(self) -> None:
        """Extend base router freezing to the router observation MLP."""
        super()._apply_freeze_router_state()
        self.router_obs_embed.requires_grad_(not self.freeze_router)

    def _compute_router_hidden(self, x: torch.Tensor) -> torch.Tensor | None:
        """Gather the router feature subset from ``x`` and embed it.

        Args:
            x: Flat actor observation; its last dim must equal the flat obs
                width. Any leading dims are preserved.

        Returns:
            ``router_obs_embed`` applied to the gathered features (last dim
            ``d_model``).

        Raises:
            ValueError: If the last dim of ``x`` does not match.
        """
        if x.shape[-1] != int(self.obs_input_dim or self.input_dim):
            raise ValueError(
                "Reference-routed policy expected flat obs dim "
                f"{int(self.obs_input_dim or self.input_dim)}, got {x.shape[-1]}."
            )
        router_idx = self._router_feature_indices
        if router_idx.device != x.device:
            router_idx = router_idx.to(x.device)
        router_obs = torch.index_select(x, dim=x.ndim - 1, index=router_idx)
        return self.router_obs_embed(router_obs)

    def _forward_inference_onnx_cond(
        self,
        state_x: torch.Tensor,
        future_tokens: torch.Tensor,
        past_key_values: torch.Tensor,
        current_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """ONNX-style single-step inference with future cross-attention.

        NOTE(review): ``__init__`` rejects ``use_future_cross_attn=True``.
        Assuming the base class derives ``self.use_future_cross_attn`` from
        that config key, this method always raises at its first check.
        Also, unlike ``_forward_inference_onnx``, its MoE calls pass no
        ``router_x``. Confirm whether this override is intentionally
        unreachable.

        Returns:
            ``(action, present_key_values)``, plus each MoE layer's
            ``topk_idx`` and ``router_logits`` while exporting to ONNX.
        """
        if not self.use_future_cross_attn:
            raise ValueError(
                "_forward_inference_onnx_cond requires use_future_cross_attn=True"
            )
        if state_x.ndim != 2:
            raise ValueError(
                f"state_x must have shape [B, D_state], got {tuple(state_x.shape)}"
            )
        if future_tokens.ndim != 3:
            raise ValueError(
                "future_tokens must have shape [B, N_fut, D_fut], "
                f"got {tuple(future_tokens.shape)}"
            )
        h = self.state_obs_embed(state_x)[:, None, :]
        memory = self._embed_future_tokens(future_tokens)
        b = h.shape[0]
        # Cache capacity is read from the cache tensor (time axis is dim 3).
        max_len = past_key_values.shape[3]
        if current_pos.ndim == 0:
            current_pos = current_pos.view(1).expand(b)
        insert_pos = current_pos % max_len
        new_len = torch.clamp(current_pos + 1, max=max_len)
        position_ids = current_pos.unsqueeze(1)
        cos, sin = self.get_cos_sin(h, position_ids)
        present_key_values_list = []
        routing_debug_outputs: list[torch.Tensor] = []
        export_routing_debug = torch.onnx.is_in_onnx_export()
        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i]
            k_cache = layer_past[0]
            v_cache = layer_past[1]
            h_norm = layer.norm1(h)
            attn_out, new_k_cache, new_v_cache = (
                layer.attn.forward_single_token(
                    x=h_norm,
                    cos=cos,
                    sin=sin,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    new_len=new_len,
                    insert_pos=insert_pos,
                )
            )
            h = h + attn_out
            if layer.use_cross_attn:
                # No memory mask on this path: all future tokens are attended.
                h = h + layer.cross_attn(layer.norm_cross(h), memory, None)
            h_norm2 = layer.norm2(h)
            if isinstance(layer, GroupedMoEBlock):
                if export_routing_debug:
                    ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
                        h_norm2,
                        return_routing_debug=True,
                    )
                    routing_debug_outputs.extend([topk_idx, router_logits])
                else:
                    ffn_out = layer.compute_moe_ffn(h_norm2)
            else:
                ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
            h = h + ffn_out
            current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
            present_key_values_list.append(current_layer_kv)
        h = self.norm_f(h)
        action = self.action_mu_head(h[:, 0, :])
        present_key_values = torch.stack(present_key_values_list, dim=0)
        if export_routing_debug and routing_debug_outputs:
            return (action, present_key_values, *routing_debug_outputs)
        return action, present_key_values
class ReferenceRoutedGroupedMoETransformerPolicyV2(
GroupedMoETransformerPolicy
):
supports_explicit_ref_aux_hidden = True
def __init__(
self,
input_dim: int,
output_dim: int,
module_config_dict: dict,
):
module_config = dict(module_config_dict)
if bool(module_config.get("use_future_cross_attn", False)):
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 does not "
"support use_future_cross_attn=True."
)
state_obs_input_dim = module_config.get("state_obs_input_dim", None)
ref_cur_token_dim = module_config.get("ref_cur_token_dim", None)
ref_fut_token_dim = module_config.get("ref_fut_token_dim", None)
ref_fut_seq_len = module_config.get("ref_fut_seq_len", None)
state_feature_indices = module_config.get(
"state_feature_indices", None
)
ref_cur_feature_indices = module_config.get(
"ref_cur_feature_indices", None
)
ref_fut_slices = module_config.get("ref_fut_slices", None)
if state_obs_input_dim is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
"state_obs_input_dim."
)
if ref_cur_token_dim is None or ref_fut_token_dim is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
"ref_cur_token_dim and ref_fut_token_dim."
)
if ref_fut_seq_len is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
"ref_fut_seq_len."
)
if state_feature_indices is None or ref_cur_feature_indices is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
"state_feature_indices and ref_cur_feature_indices."
)
if ref_fut_slices is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 requires "
"ref_fut_slices."
)
self.full_obs_input_dim = int(input_dim)
self.state_obs_input_dim = int(state_obs_input_dim)
self.ref_cur_token_dim = int(ref_cur_token_dim)
self.ref_fut_token_dim = int(ref_fut_token_dim)
self.ref_fut_seq_len = int(ref_fut_seq_len)
self.state_feature_indices = tuple(
int(idx) for idx in state_feature_indices
)
self.ref_cur_feature_indices = tuple(
int(idx) for idx in ref_cur_feature_indices
)
self.ref_fut_slices = tuple(
(int(start), int(end), int(dim))
for start, end, dim in ref_fut_slices
)
if self.state_obs_input_dim <= 0:
raise ValueError(
"state_obs_input_dim must be positive, got "
f"{self.state_obs_input_dim}."
)
if self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0:
raise ValueError(
"ref token dims must be positive, got "
f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
)
if self.ref_cur_token_dim != self.ref_fut_token_dim:
raise ValueError(
"current/future ref token dims must match, got "
f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
)
if self.ref_fut_seq_len <= 0:
raise ValueError(
f"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}."
)
if len(self.state_feature_indices) != self.state_obs_input_dim:
raise ValueError(
"state_obs_input_dim must match len(state_feature_indices): "
f"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}."
)
if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim:
raise ValueError(
"ref_cur_token_dim must match len(ref_cur_feature_indices): "
f"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}."
)
fut_flat_dim = 0
for start, end, dim in self.ref_fut_slices:
if end <= start or dim <= 0:
raise ValueError(
f"Invalid ref_fut_slices entry {(start, end, dim)}."
)
if (end - start) != self.ref_fut_seq_len * dim:
raise ValueError(
"Future ref slice span must equal ref_fut_seq_len * dim, got "
f"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}."
)
fut_flat_dim += end - start
expected_full_input_dim = (
self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim
)
if self.full_obs_input_dim != expected_full_input_dim:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 expected full "
f"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}."
)
self.ref_hist_n_layers = int(module_config.get("ref_hist_n_layers", 1))
if self.ref_hist_n_layers != 1:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV2 currently supports "
"exactly one ref history attention layer."
)
self.ref_future_conv_channels = int(
module_config.get(
"ref_future_conv_channels", self.ref_cur_token_dim
)
)
self.ref_future_conv_layers = int(
module_config.get("ref_future_conv_layers", 2)
)
self.ref_future_conv_kernel_size = int(
module_config.get("ref_future_conv_kernel_size", 3)
)
self.ref_future_conv_stride = int(
module_config.get("ref_future_conv_stride", 2)
)
if self.ref_future_conv_layers <= 0:
raise ValueError(
"ref_future_conv_layers must be positive, got "
f"{self.ref_future_conv_layers}."
)
if self.ref_future_conv_kernel_size <= 0:
raise ValueError(
"ref_future_conv_kernel_size must be positive, got "
f"{self.ref_future_conv_kernel_size}."
)
if self.ref_future_conv_stride <= 0:
raise ValueError(
"ref_future_conv_stride must be positive, got "
f"{self.ref_future_conv_stride}."
)
module_config["input_dim_override"] = self.state_obs_input_dim
super().__init__(
input_dim=input_dim,
output_dim=output_dim,
module_config_dict=module_config,
)
self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers)
self.register_buffer(
"_state_feature_indices",
torch.tensor(self.state_feature_indices, dtype=torch.long),
persistent=False,
)
self.register_buffer(
"_ref_cur_feature_indices",
torch.tensor(self.ref_cur_feature_indices, dtype=torch.long),
persistent=False,
)
self.ref_frame_embed = nn.Sequential(
nn.Linear(self.ref_cur_token_dim, self.obs_embed_mlp_hidden),
nn.SiLU(),
nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
)
self.ref_hist_norm = RMSNorm(self.d_model)
self.ref_hist_attn = ModernAttention(
d_model=self.d_model,
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
use_qk_norm=self.use_qk_norm,
use_gated_attn=self.use_gated_attn,
gated_attn_type=self.gated_attn_type,
attn_dropout=self.attn_dropout,
)
self.ref_hist_out_norm = RMSNorm(self.d_model)
padding = self.ref_future_conv_kernel_size // 2
conv_modules: list[nn.Module] = []
in_ch = self.d_model
for layer_idx in range(self.ref_future_conv_layers):
out_ch = (
self.d_model
if layer_idx == self.ref_future_conv_layers - 1
else self.ref_future_conv_channels
)
conv_modules.append(
nn.Conv1d(
in_channels=in_ch,
out_channels=out_ch,
kernel_size=self.ref_future_conv_kernel_size,
stride=self.ref_future_conv_stride,
padding=padding,
bias=True,
)
)
conv_modules.append(nn.SiLU())
in_ch = out_ch
self.ref_future_conv = nn.Sequential(*conv_modules)
self.actor_ref_pool = SingleQueryAttentionPool(self.d_model)
self.router_ref_pool = SingleQueryAttentionPool(self.d_model)
self.router_query = nn.Parameter(torch.zeros(self.d_model))
self.actor_ref_ctx_norm = RMSNorm(self.d_model)
self.actor_film_hidden_norm = RMSNorm(self.d_model)
self.actor_ref_film = nn.Sequential(
nn.Linear(self.d_model, self.d_model),
nn.SiLU(),
nn.Linear(self.d_model, 2 * self.d_model),
)
nn.init.zeros_(self.actor_ref_film[-1].weight)
nn.init.zeros_(self.actor_ref_film[-1].bias)
self.actor_film_gain_max = float(
module_config.get("actor_film_gain_max", 1.0)
)
self.actor_film_gain_init = float(
module_config.get("actor_film_gain_init", 0.05)
)
if self.actor_film_gain_max <= 0.0:
raise ValueError(
"actor_film_gain_max must be positive, got "
f"{self.actor_film_gain_max}."
)
if not (0.0 < self.actor_film_gain_init < self.actor_film_gain_max):
raise ValueError(
"actor_film_gain_init must be in (0, actor_film_gain_max), "
f"got {self.actor_film_gain_init} with max "
f"{self.actor_film_gain_max}."
)
gain_init_ratio = self.actor_film_gain_init / self.actor_film_gain_max
self.actor_film_gain_raw = nn.Parameter(
torch.full(
(self.d_model,),
math.log(gain_init_ratio / (1.0 - gain_init_ratio)),
)
)
self.actor_film_scale_max = 0.5
self.actor_film_shift_max = 0.5
self.actor_film_delta_rms_eps = float(
module_config.get("actor_film_delta_rms_eps", 1.0e-6)
)
if self.actor_film_delta_rms_eps <= 0.0:
raise ValueError(
"actor_film_delta_rms_eps must be positive, got "
f"{self.actor_film_delta_rms_eps}."
)
self._ref_hist_k_cache: torch.Tensor | None = None
self._ref_hist_v_cache: torch.Tensor | None = None
self._apply_freeze_router_state()
def _apply_freeze_router_state(self) -> None:
    """Propagate ``freeze_router`` to the shared reference/router branch.

    The base-class handling runs first. Then ``requires_grad`` is toggled on
    the reference embed, the history attention and its norms, the future
    conv stack, the router pool and the router query. The actor-side
    modules (``actor_ref_pool``, ``actor_ref_film`` and its norms, the FiLM
    gain) are not in this list.
    """
    super()._apply_freeze_router_state()
    trainable = not self.freeze_router
    router_branch = (
        self.ref_frame_embed,
        self.ref_hist_norm,
        self.ref_hist_attn,
        self.ref_hist_out_norm,
        self.ref_future_conv,
        self.router_ref_pool,
        self.router_query,
    )
    for component in router_branch:
        component.requires_grad_(trainable)
def _build_shared_ref_tokens(
self,
ref_cur_x: torch.Tensor,
ref_fut_x: torch.Tensor,
pos: torch.Tensor,
tgt_mask: torch.Tensor | None,
) -> torch.Tensor:
with self._router_no_grad_context():
ref_cur_h = self.ref_frame_embed(ref_cur_x)
ref_hist_attn = self.ref_hist_attn(
self.ref_hist_norm(ref_cur_h),
*self.get_cos_sin(ref_cur_h, pos),
mask=tgt_mask,
)
ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)
ref_fut_tokens = self._encode_future_tokens(ref_fut_x)
return torch.cat([ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2)
def _build_shared_ref_tokens_single_step(
self,
ref_cur_x: torch.Tensor,
ref_fut_x: torch.Tensor,
pos_ids: torch.Tensor,
*,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
new_len: torch.Tensor,
insert_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with self._router_no_grad_context():
ref_cur_h = self.ref_frame_embed(ref_cur_x)[:, None, :]
ref_cos, ref_sin = self.get_cos_sin(ref_cur_h, pos_ids)
ref_hist_attn, ref_k_cache, ref_v_cache = (
self.ref_hist_attn.forward_single_token(
self.ref_hist_norm(ref_cur_h),
ref_cos,
ref_sin,
k_cache,
v_cache,
new_len,
insert_pos,
)
)
ref_hist_h = self.ref_hist_out_norm(ref_cur_h + ref_hist_attn)
ref_fut_tokens = self._encode_future_tokens(ref_fut_x)
shared_ref_tokens = torch.cat([ref_hist_h, ref_fut_tokens], dim=1)
return shared_ref_tokens, ref_k_cache, ref_v_cache
def _split_actor_ref_inputs(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if x.ndim not in (2, 3):
raise ValueError(
f"Expected full obs tensor with ndim 2 or 3, got {x.ndim}."
)
if int(x.shape[-1]) != self.full_obs_input_dim:
raise ValueError(
"Full obs dim mismatch for reference router V2: expected "
f"{self.full_obs_input_dim}, got {int(x.shape[-1])}."
)
state_idx = self._state_feature_indices.to(x.device)
ref_cur_idx = self._ref_cur_feature_indices.to(x.device)
state_x = torch.index_select(x, dim=x.ndim - 1, index=state_idx)
ref_cur_x = torch.index_select(x, dim=x.ndim - 1, index=ref_cur_idx)
fut_parts: list[torch.Tensor] = []
for start, end, dim in self.ref_fut_slices:
chunk = x[..., start:end]
if x.ndim == 2:
fut_parts.append(
chunk.reshape(int(x.shape[0]), self.ref_fut_seq_len, dim)
)
else:
fut_parts.append(
chunk.reshape(
int(x.shape[0]),
int(x.shape[1]),
self.ref_fut_seq_len,
dim,
)
)
ref_fut_x = torch.cat(fut_parts, dim=-1)
return state_x, ref_cur_x, ref_fut_x
def _encode_future_tokens(self, ref_fut_x: torch.Tensor) -> torch.Tensor:
if ref_fut_x.ndim == 3:
fut = self.ref_frame_embed(ref_fut_x)
return self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)
if ref_fut_x.ndim == 4:
batch, time, seq_len, dim = ref_fut_x.shape
fut = self.ref_frame_embed(
ref_fut_x.reshape(batch * time, seq_len, dim)
)
fut = self.ref_future_conv(fut.transpose(1, 2)).transpose(1, 2)
return fut.reshape(batch, time, fut.shape[1], self.d_model)
raise ValueError(
f"Expected ref_fut_x with ndim 3 or 4, got {ref_fut_x.ndim}."
)
def _pool_router_context(
self, shared_ref_tokens: torch.Tensor
) -> torch.Tensor:
with self._router_no_grad_context():
if shared_ref_tokens.ndim == 3:
query = self.router_query.to(
device=shared_ref_tokens.device,
dtype=shared_ref_tokens.dtype,
)[None, :].expand(int(shared_ref_tokens.shape[0]), -1)
elif shared_ref_tokens.ndim == 4:
query = self.router_query.to(
device=shared_ref_tokens.device,
dtype=shared_ref_tokens.dtype,
)[None, None, :].expand(
int(shared_ref_tokens.shape[0]),
int(shared_ref_tokens.shape[1]),
-1,
)
else:
raise ValueError(
"shared_ref_tokens must have ndim 3 or 4, got "
f"{shared_ref_tokens.ndim}."
)
return self.router_ref_pool(query, shared_ref_tokens)
def _apply_actor_ref_film(
self, state_hidden: torch.Tensor, actor_ref_ctx: torch.Tensor
) -> torch.Tensor:
ctx = self.actor_ref_ctx_norm(actor_ref_ctx)
scale_raw, shift_raw = self.actor_ref_film(ctx).chunk(2, dim=-1)
scale = self.actor_film_scale_max * torch.tanh(scale_raw)
shift = self.actor_film_shift_max * torch.tanh(shift_raw)
hidden_norm = self.actor_film_hidden_norm(state_hidden)
delta = scale * hidden_norm + shift
delta = self._normalize_actor_film_delta(delta)
gain = self._actor_film_gain().to(
device=state_hidden.device, dtype=state_hidden.dtype
)
expand_shape = [1] * (delta.ndim - 1) + [self.d_model]
return state_hidden + delta * gain.view(*expand_shape)
def _actor_film_gain(self) -> torch.Tensor:
return self.actor_film_gain_max * torch.sigmoid(
self.actor_film_gain_raw
)
def _normalize_actor_film_delta(self, delta: torch.Tensor) -> torch.Tensor:
rms = delta.pow(2).mean(dim=-1, keepdim=True)
return delta * torch.rsqrt(rms + self.actor_film_delta_rms_eps)
def _ensure_internal_cache_device(
self,
device,
*,
dtype: torch.dtype | None = None,
) -> None:
if self._k_cache is not None and self._k_cache.device != device:
self._k_cache = self._k_cache.to(device)
self._v_cache = self._v_cache.to(device)
self._ref_hist_k_cache = self._ref_hist_k_cache.to(device)
self._ref_hist_v_cache = self._ref_hist_v_cache.to(device)
self._kv_cache_len = self._kv_cache_len.to(device)
self._kv_cache_write_idx = self._kv_cache_write_idx.to(device)
self._kv_cache_abs_pos = self._kv_cache_abs_pos.to(device)
if (
dtype is not None
and self._k_cache is not None
and self._k_cache.dtype != dtype
):
self._k_cache = self._k_cache.to(dtype)
self._v_cache = self._v_cache.to(dtype)
self._ref_hist_k_cache = self._ref_hist_k_cache.to(dtype)
self._ref_hist_v_cache = self._ref_hist_v_cache.to(dtype)
def reset_kv_cache(self, num_envs: int, device):
cache_dtype = (
torch.float16
if torch.device(device).type == "cuda"
else torch.float32
)
self._k_cache = torch.zeros(
num_envs,
self.n_layers,
self.max_ctx_len,
self.n_kv_heads,
self.head_dim,
device=device,
dtype=cache_dtype,
)
self._v_cache = torch.zeros_like(self._k_cache)
self._ref_hist_k_cache = torch.zeros(
num_envs,
self.ref_hist_n_layers,
self.max_ctx_len,
self.n_kv_heads,
self.head_dim,
device=device,
dtype=cache_dtype,
)
self._ref_hist_v_cache = torch.zeros_like(self._ref_hist_k_cache)
self._kv_cache_len = torch.zeros(
num_envs, dtype=torch.long, device=device
)
self._kv_cache_write_idx = torch.zeros(
num_envs, dtype=torch.long, device=device
)
self._kv_cache_abs_pos = torch.zeros(
num_envs, dtype=torch.long, device=device
)
self._init_last_moe_router_shift_state(num_envs, device)
def clear_env_cache(self, env_ids: torch.Tensor | None):
if self._k_cache is None:
return
if env_ids is None:
self._k_cache.zero_()
self._v_cache.zero_()
self._ref_hist_k_cache.zero_()
self._ref_hist_v_cache.zero_()
self._kv_cache_len.zero_()
self._kv_cache_write_idx.zero_()
self._kv_cache_abs_pos.zero_()
if self._prev_last_moe_router_p is not None:
self._prev_last_moe_router_p.zero_()
if self._prev_last_moe_router_valid is not None:
self._prev_last_moe_router_valid.zero_()
if self._last_moe_router_js_sum is not None:
self._last_moe_router_js_sum.zero_()
if self._last_moe_router_js_count is not None:
self._last_moe_router_js_count.zero_()
if self._last_moe_router_top1_switch_sum is not None:
self._last_moe_router_top1_switch_sum.zero_()
if self._last_moe_router_top1_switch_count is not None:
self._last_moe_router_top1_switch_count.zero_()
return
self._k_cache[env_ids] = 0.0
self._v_cache[env_ids] = 0.0
self._ref_hist_k_cache[env_ids] = 0.0
self._ref_hist_v_cache[env_ids] = 0.0
self._kv_cache_len[env_ids] = 0
self._kv_cache_write_idx[env_ids] = 0
self._kv_cache_abs_pos[env_ids] = 0
if self._prev_last_moe_router_valid is not None:
self._prev_last_moe_router_valid[env_ids] = False
if self._prev_last_moe_router_p is not None:
self._prev_last_moe_router_p[env_ids] = 0.0
def predict_aux_from_pre_moe(
self,
pre_moe_hidden: torch.Tensor,
*,
ref_aux_hidden: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
aux_outputs = super().predict_aux_from_pre_moe(
pre_moe_hidden, ref_aux_hidden=ref_aux_hidden
)
if self.aux_ref_keybody_pos_head is not None:
if ref_aux_hidden is None:
raise ValueError(
"Missing shared-ref auxiliary hidden state for "
"ref_keybody_rel_pos prediction."
)
aux_outputs["ref_keybody_rel_pos"] = self.aux_ref_keybody_pos_head(
ref_aux_hidden
).reshape(
ref_aux_hidden.shape[0],
ref_aux_hidden.shape[1],
self.aux_keybody_pos_dim,
3,
)
return aux_outputs
def sequence_mu(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
return_hidden: bool = False,
return_pre_moe_hidden: bool = False,
return_ref_aux_hidden: bool = False,
return_router_features: bool = False,
return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
_, time, _ = x.shape
state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
state_h = self.obs_embed(state_x)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1)
start_idx = attn_mask.to(torch.int64).argmax(dim=-1)
t_idx = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(int(x.shape[0]), time)
pos = t_idx - start_idx
else:
tgt_mask = None
pos = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(int(x.shape[0]), time)
shared_ref_tokens = self._build_shared_ref_tokens(
ref_cur_x=ref_cur_x,
ref_fut_x=ref_fut_x,
pos=pos,
tgt_mask=tgt_mask,
)
actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)
router_h = self._pool_router_context(shared_ref_tokens)
cos, sin = self.get_cos_sin(state_h, pos)
if return_hidden and return_pre_moe_hidden:
raise ValueError(
"return_hidden and return_pre_moe_hidden cannot both be True."
)
block0_h = self._forward_layers_range(
state_h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
start_layer=0,
end_layer=1,
)
h = self._apply_actor_ref_film(block0_h, actor_ref_ctx)
forward_out = self._forward_layers_range(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
start_layer=1,
end_layer=len(self.layers),
return_router_features=return_router_features,
return_router_temporal_features=return_router_temporal_features,
)
extras: list[torch.Tensor] = []
if isinstance(forward_out, tuple):
h = forward_out[0]
extras = list(forward_out[1:])
else:
h = forward_out
h = self.norm_f(h)
mu = self.action_mu_head(h)
outputs: list[torch.Tensor] = [mu]
if return_pre_moe_hidden:
outputs.append(block0_h)
if return_ref_aux_hidden:
outputs.append(router_h)
if return_router_features:
outputs.append(extras.pop(0))
if return_router_temporal_features:
outputs.append(extras.pop(0))
if len(outputs) > 1:
return tuple(outputs)
if return_hidden:
return mu, h
return mu
def sequence_hidden(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
_, time, _ = x.shape
state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
state_h = self.obs_embed(state_x)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1)
start_idx = attn_mask.to(torch.int64).argmax(dim=-1)
t_idx = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(int(x.shape[0]), time)
pos = t_idx - start_idx
else:
tgt_mask = None
pos = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(int(x.shape[0]), time)
shared_ref_tokens = self._build_shared_ref_tokens(
ref_cur_x=ref_cur_x,
ref_fut_x=ref_fut_x,
pos=pos,
tgt_mask=tgt_mask,
)
actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)
router_h = self._pool_router_context(shared_ref_tokens)
cos, sin = self.get_cos_sin(state_h, pos)
h = self._forward_layers_range(
state_h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
start_layer=0,
end_layer=1,
)
h = self._apply_actor_ref_film(h, actor_ref_ctx)
h = self._forward_layers_range(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h=router_h,
start_layer=1,
end_layer=len(self.layers),
)
h = self.norm_f(h)
return h
def forward(
self,
input: torch.Tensor,
past_key_values: torch.Tensor | None = None,
current_pos: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if past_key_values is not None:
return self._forward_inference_onnx(
input, past_key_values, current_pos
)
if input.ndim != 2:
raise ValueError(f"Expected [B, D], got {input.shape}")
mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)
return mu_seq[:, 0, :]
def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim != 2:
raise ValueError(f"Expected [B, D], got {x.shape}")
state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
batch = int(state_x.shape[0])
if self._k_cache is None:
mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)
return mu_seq[:, 0, :]
state_h = self.obs_embed(state_x)
self._ensure_internal_cache_device(x.device, dtype=state_h.dtype)
cache_len = self._kv_cache_len
insert_pos = self._kv_cache_write_idx
max_len = int(self.max_ctx_len)
new_len = torch.clamp(cache_len + 1, max=max_len)
self._kv_cache_len = new_len
self._kv_cache_write_idx = (insert_pos + 1) % max_len
pos = self._kv_cache_abs_pos
self._kv_cache_abs_pos = pos + 1
pos_ids = pos.unsqueeze(1)
shared_ref_tokens, _, _ = self._build_shared_ref_tokens_single_step(
ref_cur_x=ref_cur_x,
ref_fut_x=ref_fut_x,
pos_ids=pos_ids,
k_cache=self._ref_hist_k_cache[:, 0],
v_cache=self._ref_hist_v_cache[:, 0],
new_len=new_len,
insert_pos=insert_pos,
)
actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[
:, None, :
]
router_h = self._pool_router_context(shared_ref_tokens)[:, None, :]
cos, sin = self.get_cos_sin(state_h[:, None, :], pos_ids)
h = state_h[:, None, :]
for layer_idx, layer in enumerate(self.layers[:1]):
x_norm = layer.norm1(h)
k_cache_l = self._k_cache[:, layer_idx]
v_cache_l = self._v_cache[:, layer_idx]
attn_out, _, _ = layer.attn.forward_single_token(
x_norm,
cos,
sin,
k_cache_l,
v_cache_l,
new_len,
insert_pos,
)
h = h + attn_out
h2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
ffn = layer.compute_moe_ffn(h2, router_x=router_h)
if (
layer_idx == self._last_moe_layer_idx
and layer.collect_routing_stats
and layer.last_router_distribution is not None
):
self._accumulate_last_moe_router_shift(
layer.last_router_distribution
)
else:
ffn = layer.mlp_dropout(layer.mlp(h2))
h = h + ffn
h = self._apply_actor_ref_film(h, actor_ref_ctx)
for layer_idx, layer in enumerate(self.layers[1:], start=1):
x_norm = layer.norm1(h)
k_cache_l = self._k_cache[:, layer_idx]
v_cache_l = self._v_cache[:, layer_idx]
attn_out, _, _ = layer.attn.forward_single_token(
x_norm,
cos,
sin,
k_cache_l,
v_cache_l,
new_len,
insert_pos,
)
h = h + attn_out
h2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
ffn = layer.compute_moe_ffn(h2, router_x=router_h)
if (
layer_idx == self._last_moe_layer_idx
and layer.collect_routing_stats
and layer.last_router_distribution is not None
):
self._accumulate_last_moe_router_shift(
layer.last_router_distribution
)
else:
ffn = layer.mlp_dropout(layer.mlp(h2))
h = h + ffn
h = self.norm_f(h)
return self.action_mu_head(h[:, 0, :]).reshape(batch, -1)
def _forward_inference_onnx(
self,
x: torch.Tensor,
past_key_values: torch.Tensor,
current_pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
state_x, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
state_h = self.obs_embed(state_x)
batch = state_h.shape[0]
max_len = past_key_values.shape[3]
if current_pos.ndim == 0:
current_pos = current_pos.view(1).expand(batch)
insert_pos = current_pos % max_len
new_len = torch.clamp(current_pos + 1, max=max_len)
position_ids = current_pos.unsqueeze(1)
ref_layer_past = past_key_values[0]
shared_ref_tokens, ref_k_cache, ref_v_cache = (
self._build_shared_ref_tokens_single_step(
ref_cur_x=ref_cur_x,
ref_fut_x=ref_fut_x,
pos_ids=position_ids,
k_cache=ref_layer_past[0],
v_cache=ref_layer_past[1],
new_len=new_len,
insert_pos=insert_pos,
)
)
actor_ref_ctx = self.actor_ref_pool(state_h, shared_ref_tokens)[
:, None, :
]
router_h = self._pool_router_context(shared_ref_tokens)[:, None, :]
cos, sin = self.get_cos_sin(state_h[:, None, :], position_ids)
present_key_values_list = [
torch.stack([ref_k_cache, ref_v_cache], dim=0)
]
routing_debug_outputs: list[torch.Tensor] = []
export_routing_debug = torch.onnx.is_in_onnx_export()
h = state_h[:, None, :]
for i, layer in enumerate(self.layers[:1]):
layer_past = past_key_values[self.ref_hist_n_layers + i]
k_cache = layer_past[0]
v_cache = layer_past[1]
h_norm = layer.norm1(h)
attn_out, new_k_cache, new_v_cache = (
layer.attn.forward_single_token(
x=h_norm,
cos=cos,
sin=sin,
k_cache=k_cache,
v_cache=v_cache,
new_len=new_len,
insert_pos=insert_pos,
)
)
h = h + attn_out
h_norm2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
if export_routing_debug:
ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
h_norm2,
router_x=router_h,
return_routing_debug=True,
)
routing_debug_outputs.extend([topk_idx, router_logits])
else:
ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)
else:
ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
h = h + ffn_out
current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
present_key_values_list.append(current_layer_kv)
h = self._apply_actor_ref_film(h, actor_ref_ctx)
for i, layer in enumerate(self.layers[1:], start=1):
layer_past = past_key_values[self.ref_hist_n_layers + i]
k_cache = layer_past[0]
v_cache = layer_past[1]
h_norm = layer.norm1(h)
attn_out, new_k_cache, new_v_cache = (
layer.attn.forward_single_token(
x=h_norm,
cos=cos,
sin=sin,
k_cache=k_cache,
v_cache=v_cache,
new_len=new_len,
insert_pos=insert_pos,
)
)
h = h + attn_out
h_norm2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
if export_routing_debug:
ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
h_norm2,
router_x=router_h,
return_routing_debug=True,
)
routing_debug_outputs.extend([topk_idx, router_logits])
else:
ffn_out = layer.compute_moe_ffn(h_norm2, router_x=router_h)
else:
ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
h = h + ffn_out
current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
present_key_values_list.append(current_layer_kv)
h = self.norm_f(h)
action = self.action_mu_head(h[:, 0, :])
present_key_values = torch.stack(present_key_values_list, dim=0)
if export_routing_debug and routing_debug_outputs:
return (action, present_key_values, *routing_debug_outputs)
return action, present_key_values
class ReferenceRoutedGroupedMoETransformerPolicyV3(
ReferenceRoutedGroupedMoETransformerPolicyV2
):
supports_explicit_ref_aux_hidden = True
def __init__(
self,
input_dim: int,
output_dim: int,
module_config_dict: dict,
):
module_config = dict(module_config_dict)
if bool(module_config.get("use_future_cross_attn", False)):
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 does not "
"support use_future_cross_attn=True."
)
state_obs_input_dim = module_config.get("state_obs_input_dim", None)
ref_cur_token_dim = module_config.get("ref_cur_token_dim", None)
ref_fut_token_dim = module_config.get("ref_fut_token_dim", None)
ref_fut_seq_len = module_config.get("ref_fut_seq_len", None)
state_feature_indices = module_config.get(
"state_feature_indices", None
)
ref_cur_feature_indices = module_config.get(
"ref_cur_feature_indices", None
)
ref_fut_slices = module_config.get("ref_fut_slices", None)
if state_obs_input_dim is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 requires "
"state_obs_input_dim."
)
if ref_cur_token_dim is None or ref_fut_token_dim is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 requires "
"ref_cur_token_dim and ref_fut_token_dim."
)
if ref_fut_seq_len is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 requires "
"ref_fut_seq_len."
)
if state_feature_indices is None or ref_cur_feature_indices is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 requires "
"state_feature_indices and ref_cur_feature_indices."
)
if ref_fut_slices is None:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 requires "
"ref_fut_slices."
)
self.full_obs_input_dim = int(input_dim)
self.state_obs_input_dim = int(state_obs_input_dim)
self.ref_cur_token_dim = int(ref_cur_token_dim)
self.ref_fut_token_dim = int(ref_fut_token_dim)
self.ref_fut_seq_len = int(ref_fut_seq_len)
self.state_feature_indices = tuple(
int(idx) for idx in state_feature_indices
)
self.ref_cur_feature_indices = tuple(
int(idx) for idx in ref_cur_feature_indices
)
self.ref_fut_slices = tuple(
(int(start), int(end), int(dim))
for start, end, dim in ref_fut_slices
)
if self.state_obs_input_dim <= 0:
raise ValueError(
"state_obs_input_dim must be positive, got "
f"{self.state_obs_input_dim}."
)
if self.ref_cur_token_dim <= 0 or self.ref_fut_token_dim <= 0:
raise ValueError(
"ref token dims must be positive, got "
f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
)
if self.ref_cur_token_dim != self.ref_fut_token_dim:
raise ValueError(
"current/future ref token dims must match, got "
f"{self.ref_cur_token_dim} and {self.ref_fut_token_dim}."
)
if self.ref_fut_seq_len <= 0:
raise ValueError(
f"ref_fut_seq_len must be positive, got {self.ref_fut_seq_len}."
)
if len(self.state_feature_indices) != self.state_obs_input_dim:
raise ValueError(
"state_obs_input_dim must match len(state_feature_indices): "
f"{self.state_obs_input_dim} vs {len(self.state_feature_indices)}."
)
if len(self.ref_cur_feature_indices) != self.ref_cur_token_dim:
raise ValueError(
"ref_cur_token_dim must match len(ref_cur_feature_indices): "
f"{self.ref_cur_token_dim} vs {len(self.ref_cur_feature_indices)}."
)
fut_flat_dim = 0
for start, end, dim in self.ref_fut_slices:
if end <= start or dim <= 0:
raise ValueError(
f"Invalid ref_fut_slices entry {(start, end, dim)}."
)
if (end - start) != self.ref_fut_seq_len * dim:
raise ValueError(
"Future ref slice span must equal ref_fut_seq_len * dim, got "
f"{(start, end, dim)} with ref_fut_seq_len={self.ref_fut_seq_len}."
)
fut_flat_dim += end - start
expected_full_input_dim = (
self.state_obs_input_dim + self.ref_cur_token_dim + fut_flat_dim
)
if self.full_obs_input_dim != expected_full_input_dim:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 expected full "
f"input dim {expected_full_input_dim}, got {self.full_obs_input_dim}."
)
self.ref_hist_n_layers = int(module_config.get("ref_hist_n_layers", 1))
if self.ref_hist_n_layers != 1:
raise ValueError(
"ReferenceRoutedGroupedMoETransformerPolicyV3 currently supports "
"exactly one ref history attention layer."
)
layer_proj_hidden_default = int(
module_config.get(
"router_layer_proj_hidden_dim",
module_config.get("d_model", 256),
)
)
self.ref_motion_input_dim = int(
self.ref_cur_token_dim
+ self.ref_fut_seq_len * self.ref_fut_token_dim
)
self.router_layer_proj_hidden_dim = int(layer_proj_hidden_default)
if self.router_layer_proj_hidden_dim <= 0:
raise ValueError(
"router_layer_proj_hidden_dim must be positive, got "
f"{self.router_layer_proj_hidden_dim}."
)
GroupedMoETransformerPolicy.__init__(
self,
input_dim=input_dim,
output_dim=output_dim,
module_config_dict=module_config,
)
self.onnx_kv_layers = int(self.ref_hist_n_layers + self.n_layers)
self.register_buffer(
"_state_feature_indices",
torch.tensor(self.state_feature_indices, dtype=torch.long),
persistent=False,
)
self.register_buffer(
"_ref_cur_feature_indices",
torch.tensor(self.ref_cur_feature_indices, dtype=torch.long),
persistent=False,
)
self.ref_frame_embed = nn.Sequential(
nn.Linear(self.ref_motion_input_dim, self.obs_embed_mlp_hidden),
nn.SiLU(),
nn.Linear(self.obs_embed_mlp_hidden, self.d_model),
)
self.ref_hist_norm = RMSNorm(self.d_model)
self.ref_hist_attn = ModernAttention(
d_model=self.d_model,
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
use_qk_norm=self.use_qk_norm,
use_gated_attn=self.use_gated_attn,
gated_attn_type=self.gated_attn_type,
attn_dropout=self.attn_dropout,
)
self.ref_hist_out_norm = RMSNorm(self.d_model)
self._moe_layer_indices = tuple(
i
for i, layer in enumerate(self.layers)
if isinstance(layer, GroupedMoEBlock)
)
self.router_layer_projections = nn.ModuleList(
[
nn.Sequential(
RMSNorm(self.d_model),
nn.Linear(self.d_model, self.router_layer_proj_hidden_dim),
nn.SiLU(),
nn.Linear(self.router_layer_proj_hidden_dim, self.d_model),
)
for _ in self._moe_layer_indices
]
)
self._ref_hist_k_cache: torch.Tensor | None = None
self._ref_hist_v_cache: torch.Tensor | None = None
self._apply_freeze_router_state()
def _apply_freeze_router_state(self) -> None:
GroupedMoETransformerPolicy._apply_freeze_router_state(self)
requires_grad = not self.freeze_router
self.ref_frame_embed.requires_grad_(requires_grad)
self.ref_hist_norm.requires_grad_(requires_grad)
self.ref_hist_attn.requires_grad_(requires_grad)
self.ref_hist_out_norm.requires_grad_(requires_grad)
self.router_layer_projections.requires_grad_(requires_grad)
def _build_router_ref_motion(
self,
ref_cur_x: torch.Tensor,
ref_fut_x: torch.Tensor,
) -> torch.Tensor:
if ref_cur_x.ndim not in (2, 3):
raise ValueError(
f"Expected ref_cur_x with ndim 2 or 3, got {ref_cur_x.ndim}."
)
if ref_fut_x.ndim != ref_cur_x.ndim + 1:
raise ValueError(
"Expected ref_fut_x to add one future-seq axis relative to "
f"ref_cur_x, got cur={tuple(ref_cur_x.shape)}, "
f"fut={tuple(ref_fut_x.shape)}."
)
ref_fut_flat = torch.flatten(ref_fut_x, start_dim=-2)
return torch.cat([ref_cur_x, ref_fut_flat], dim=-1)
def _build_shared_router_summary(
self,
ref_hist_h: torch.Tensor,
) -> torch.Tensor:
with self._router_no_grad_context():
return ref_hist_h
def _build_router_h_per_layer(
self,
shared_router_summary: torch.Tensor,
) -> list[torch.Tensor | None]:
with self._router_no_grad_context():
router_h_per_layer: list[torch.Tensor | None] = [
None for _ in self.layers
]
for proj, layer_idx in zip(
self.router_layer_projections, self._moe_layer_indices
):
router_h_per_layer[layer_idx] = proj(shared_router_summary)
return router_h_per_layer
def _build_ref_hist_hidden(
self,
ref_motion_x: torch.Tensor,
pos: torch.Tensor,
tgt_mask: torch.Tensor | None,
) -> torch.Tensor:
with self._router_no_grad_context():
ref_motion_h = self.ref_frame_embed(ref_motion_x)
ref_hist_attn = self.ref_hist_attn(
self.ref_hist_norm(ref_motion_h),
*self.get_cos_sin(ref_motion_h, pos),
mask=tgt_mask,
)
return self.ref_hist_out_norm(ref_motion_h + ref_hist_attn)
def _build_ref_hist_hidden_single_step(
self,
ref_motion_x: torch.Tensor,
pos_ids: torch.Tensor,
*,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
new_len: torch.Tensor,
insert_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
with self._router_no_grad_context():
ref_motion_h = self.ref_frame_embed(ref_motion_x)[:, None, :]
ref_cos, ref_sin = self.get_cos_sin(ref_motion_h, pos_ids)
ref_hist_attn, ref_k_cache, ref_v_cache = (
self.ref_hist_attn.forward_single_token(
self.ref_hist_norm(ref_motion_h),
ref_cos,
ref_sin,
k_cache,
v_cache,
new_len,
insert_pos,
)
)
ref_hist_h = self.ref_hist_out_norm(ref_motion_h + ref_hist_attn)
return ref_hist_h, ref_k_cache, ref_v_cache
def predict_aux_from_pre_moe(
self,
pre_moe_hidden: torch.Tensor,
*,
ref_aux_hidden: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
return GroupedMoETransformerPolicy.predict_aux_from_pre_moe(
self,
pre_moe_hidden,
ref_aux_hidden=ref_aux_hidden,
)
def sequence_mu(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
return_hidden: bool = False,
return_pre_moe_hidden: bool = False,
return_ref_aux_hidden: bool = False,
return_router_features: bool = False,
return_router_temporal_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
batch, time, _ = x.shape
h = self.obs_embed(x)
_, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1)
start_idx = attn_mask.to(torch.int64).argmax(dim=-1)
t_idx = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(batch, time)
pos = t_idx - start_idx
else:
tgt_mask = None
pos = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(batch, time)
ref_hist_h = self._build_ref_hist_hidden(
ref_motion_x=ref_motion_x,
pos=pos,
tgt_mask=tgt_mask,
)
shared_router_summary = self._build_shared_router_summary(ref_hist_h)
router_h_per_layer = self._build_router_h_per_layer(
shared_router_summary
)
cos, sin = self.get_cos_sin(h, pos)
if return_hidden and return_pre_moe_hidden:
raise ValueError(
"return_hidden and return_pre_moe_hidden cannot both be True."
)
forward_out = self._forward_layers(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h_per_layer=router_h_per_layer,
return_pre_moe_hidden=return_pre_moe_hidden,
return_router_features=return_router_features,
return_router_temporal_features=return_router_temporal_features,
)
extras: list[torch.Tensor] = []
if isinstance(forward_out, tuple):
h = forward_out[0]
extras = list(forward_out[1:])
else:
h = forward_out
h = self.norm_f(h)
mu = self.action_mu_head(h)
outputs: list[torch.Tensor] = [mu]
if return_pre_moe_hidden:
outputs.append(extras.pop(0))
if return_ref_aux_hidden:
outputs.append(shared_router_summary)
if return_router_features:
outputs.append(extras.pop(0))
if return_router_temporal_features:
outputs.append(extras.pop(0))
if len(outputs) > 1:
return tuple(outputs)
if return_hidden:
return mu, h
return mu
def sequence_hidden(
self,
x: torch.Tensor,
*,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
batch, time, _ = x.shape
h = self.obs_embed(x)
_, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)
if attn_mask is not None:
tgt_mask = attn_mask.unsqueeze(1)
start_idx = attn_mask.to(torch.int64).argmax(dim=-1)
t_idx = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(batch, time)
pos = t_idx - start_idx
else:
tgt_mask = None
pos = torch.arange(time, device=x.device, dtype=torch.long)[
None, :
].expand(batch, time)
ref_hist_h = self._build_ref_hist_hidden(
ref_motion_x=ref_motion_x,
pos=pos,
tgt_mask=tgt_mask,
)
shared_router_summary = self._build_shared_router_summary(ref_hist_h)
router_h_per_layer = self._build_router_h_per_layer(
shared_router_summary
)
cos, sin = self.get_cos_sin(h, pos)
h = self._forward_layers(
h,
cos=cos,
sin=sin,
mask=tgt_mask,
router_h_per_layer=router_h_per_layer,
)
h = self.norm_f(h)
return h
def forward(
self,
input: torch.Tensor,
past_key_values: torch.Tensor | None = None,
current_pos: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if past_key_values is not None:
return self._forward_inference_onnx(
input, past_key_values, current_pos
)
if input.ndim != 2:
raise ValueError(f"Expected [B, D], got {input.shape}")
mu_seq = self.sequence_mu(input[:, None, :], attn_mask=None)
return mu_seq[:, 0, :]
def single_step_mu(self, x: torch.Tensor) -> torch.Tensor:
if x.ndim != 2:
raise ValueError(f"Expected [B, D], got {x.shape}")
_, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)
batch = int(x.shape[0])
if self._k_cache is None:
mu_seq = self.sequence_mu(x[:, None, :], attn_mask=None)
return mu_seq[:, 0, :]
h = self.obs_embed(x)[:, None, :]
self._ensure_internal_cache_device(x.device, dtype=h.dtype)
cache_len = self._kv_cache_len
insert_pos = self._kv_cache_write_idx
max_len = int(self.max_ctx_len)
new_len = torch.clamp(cache_len + 1, max=max_len)
self._kv_cache_len = new_len
self._kv_cache_write_idx = (insert_pos + 1) % max_len
pos = self._kv_cache_abs_pos
self._kv_cache_abs_pos = pos + 1
pos_ids = pos.unsqueeze(1)
ref_hist_h, _, _ = self._build_ref_hist_hidden_single_step(
ref_motion_x=ref_motion_x,
pos_ids=pos_ids,
k_cache=self._ref_hist_k_cache[:, 0],
v_cache=self._ref_hist_v_cache[:, 0],
new_len=new_len,
insert_pos=insert_pos,
)
shared_router_summary = self._build_shared_router_summary(ref_hist_h)
router_h_per_layer = self._build_router_h_per_layer(
shared_router_summary
)
cos, sin = self.get_cos_sin(h, pos_ids)
for layer_idx, layer in enumerate(self.layers):
x_norm = layer.norm1(h)
k_cache_l = self._k_cache[:, layer_idx]
v_cache_l = self._v_cache[:, layer_idx]
attn_out, _, _ = layer.attn.forward_single_token(
x_norm,
cos,
sin,
k_cache_l,
v_cache_l,
new_len,
insert_pos,
)
h = h + attn_out
h2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
ffn = layer.compute_moe_ffn(
h2, router_x=router_h_per_layer[layer_idx]
)
if (
layer_idx == self._last_moe_layer_idx
and layer.collect_routing_stats
and layer.last_router_distribution is not None
):
self._accumulate_last_moe_router_shift(
layer.last_router_distribution
)
else:
ffn = layer.mlp_dropout(layer.mlp(h2))
h = h + ffn
h = self.norm_f(h)
return self.action_mu_head(h[:, 0, :]).reshape(batch, -1)
def _forward_inference_onnx(
self,
x: torch.Tensor,
past_key_values: torch.Tensor,
current_pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
_, ref_cur_x, ref_fut_x = self._split_actor_ref_inputs(x)
ref_motion_x = self._build_router_ref_motion(ref_cur_x, ref_fut_x)
h = self.obs_embed(x)[:, None, :]
batch = h.shape[0]
max_len = past_key_values.shape[3]
if current_pos.ndim == 0:
current_pos = current_pos.view(1).expand(batch)
insert_pos = current_pos % max_len
new_len = torch.clamp(current_pos + 1, max=max_len)
position_ids = current_pos.unsqueeze(1)
ref_layer_past = past_key_values[0]
ref_hist_h, ref_k_cache, ref_v_cache = (
self._build_ref_hist_hidden_single_step(
ref_motion_x=ref_motion_x,
pos_ids=position_ids,
k_cache=ref_layer_past[0],
v_cache=ref_layer_past[1],
new_len=new_len,
insert_pos=insert_pos,
)
)
shared_router_summary = self._build_shared_router_summary(ref_hist_h)
router_h_per_layer = self._build_router_h_per_layer(
shared_router_summary
)
cos, sin = self.get_cos_sin(h, position_ids)
present_key_values_list = [
torch.stack([ref_k_cache, ref_v_cache], dim=0)
]
routing_debug_outputs: list[torch.Tensor] = []
export_routing_debug = torch.onnx.is_in_onnx_export()
for i, layer in enumerate(self.layers):
layer_past = past_key_values[self.ref_hist_n_layers + i]
k_cache = layer_past[0]
v_cache = layer_past[1]
h_norm = layer.norm1(h)
attn_out, new_k_cache, new_v_cache = (
layer.attn.forward_single_token(
x=h_norm,
cos=cos,
sin=sin,
k_cache=k_cache,
v_cache=v_cache,
new_len=new_len,
insert_pos=insert_pos,
)
)
h = h + attn_out
h_norm2 = layer.norm2(h)
if isinstance(layer, GroupedMoEBlock):
if export_routing_debug:
ffn_out, topk_idx, router_logits = layer.compute_moe_ffn(
h_norm2,
router_x=router_h_per_layer[i],
return_routing_debug=True,
)
routing_debug_outputs.extend([topk_idx, router_logits])
else:
ffn_out = layer.compute_moe_ffn(
h_norm2, router_x=router_h_per_layer[i]
)
else:
ffn_out = layer.mlp_dropout(layer.mlp(h_norm2))
h = h + ffn_out
current_layer_kv = torch.stack([new_k_cache, new_v_cache], dim=0)
present_key_values_list.append(current_layer_kv)
h = self.norm_f(h)
action = self.action_mu_head(h[:, 0, :])
present_key_values = torch.stack(present_key_values_list, dim=0)
if export_routing_debug and routing_debug_outputs:
return (action, present_key_values, *routing_debug_outputs)
return action, present_key_values
class GroupedMoEBlock(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
num_fine_experts: int,
num_shared_experts: int,
top_k: int,
n_kv_heads: int | None = None,
ff_mult: float = 2,
use_qk_norm: bool = True,
use_gated_attn: bool = True,
gated_attn_type: str = "headwise",
attn_dropout: float = 0.0,
mlp_dropout: float = 0.0,
use_dynamic_bias: bool = False,
bias_update_rate: float = 0.001,
routing_score_fn: str = "softmax",
freeze_router: bool = False,
routing_scale: float = 1.0,
expert_bias_clip: float = 0.0,
dead_expert_margin_to_topk_enabled: bool = False,
selected_expert_margin_to_unselected_enabled: bool = False,
selected_expert_margin_to_unselected_target: float = 0.0,
use_cross_attn: bool = False,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.num_fine_experts = num_fine_experts
self.num_shared_experts = num_shared_experts
self.top_k = top_k
self.use_dynamic_bias = use_dynamic_bias
self.bias_update_rate = bias_update_rate
self.routing_score_fn = str(routing_score_fn).lower()
self.freeze_router = bool(freeze_router)
self.routing_scale = float(routing_scale)
self.expert_bias_clip = float(expert_bias_clip)
self.dead_expert_margin_to_topk_enabled = bool(
dead_expert_margin_to_topk_enabled
)
self.selected_expert_margin_to_unselected_enabled = bool(
selected_expert_margin_to_unselected_enabled
)
self.selected_expert_margin_to_unselected_target = float(
selected_expert_margin_to_unselected_target
)
if self.routing_score_fn not in ("softmax", "sigmoid"):
raise ValueError(
f"routing_score_fn must be one of {{'softmax','sigmoid'}}, got {self.routing_score_fn}"
)
if self.routing_scale <= 0.0:
raise ValueError(
f"routing_scale must be > 0, got {self.routing_scale}"
)
if self.expert_bias_clip < 0.0:
raise ValueError(
f"expert_bias_clip must be >= 0, got {self.expert_bias_clip}"
)
if self.selected_expert_margin_to_unselected_target < 0.0:
raise ValueError(
"selected_expert_margin_to_unselected_target must be >= 0, "
f"got {self.selected_expert_margin_to_unselected_target}"
)
self.register_buffer("expert_bias", torch.zeros(num_fine_experts))
self.register_buffer(
"routing_counts_accum",
torch.zeros(num_fine_experts, dtype=torch.long),
persistent=False,
)
self.register_buffer(
"last_routed_expert_usage",
torch.zeros(num_fine_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"last_routed_active_expert_count",
torch.tensor(0.0),
persistent=False,
)
self.register_buffer(
"last_routed_max_expert_frac",
torch.tensor(0.0),
persistent=False,
)
self.register_buffer(
"last_active_expert_ratio", torch.tensor(0.0), persistent=False
)
self.register_buffer(
"last_max_expert_frac", torch.tensor(0.0), persistent=False
)
self.register_buffer(
"last_expert_count_cv", torch.tensor(0.0), persistent=False
)
self.register_buffer(
"last_min_expert_frac", torch.tensor(0.0), persistent=False
)
self.register_buffer(
"last_dead_expert_ratio", torch.tensor(0.0), persistent=False
)
self.register_buffer(
"last_dense_expert_usage",
torch.zeros(num_fine_experts, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"last_dead_expert_margin_to_topk_loss_value",
torch.tensor(0.0),
persistent=False,
)
self.register_buffer(
"last_dead_expert_margin_to_topk_target",
torch.tensor(0.0),
persistent=False,
)
self.register_buffer(
"last_selected_expert_margin_to_unselected",
torch.tensor(0.0),
persistent=False,
)
self.register_buffer(
"last_selected_expert_margin_to_unselected_loss_value",
torch.tensor(0.0),
persistent=False,
)
self.collect_routing_stats = False
self.collect_router_distribution = False
self.capture_router_distribution = False
self.capture_router_logits = False
self.last_router_distribution: torch.Tensor | None = None
self.last_router_logits: torch.Tensor | None = None
self.last_dead_expert_margin_to_topk_loss: torch.Tensor | None = None
self.last_selected_expert_margin_to_unselected_loss: (
torch.Tensor | None
) = None
self.use_cross_attn = bool(use_cross_attn)
self.norm1 = RMSNorm(d_model)
self.attn = ModernAttention(
d_model=d_model,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
use_gated_attn=use_gated_attn,
gated_attn_type=gated_attn_type,
attn_dropout=attn_dropout,
)
if self.use_cross_attn:
self.norm_cross = RMSNorm(d_model)
self.cross_attn = ModernCrossAttention(
d_model=d_model,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
use_gated_attn=use_gated_attn,
gated_attn_type=gated_attn_type,
attn_dropout=attn_dropout,
)
else:
self.norm_cross = None
self.cross_attn = None
self.norm2 = RMSNorm(d_model)
self.intermediate_dim = int(d_model * ff_mult)
self.router = nn.Linear(d_model, num_fine_experts, bias=False)
self._apply_freeze_router_state()
# Gate + Up (Combined)
self.gate_up_proj = nn.Parameter(
torch.empty(
num_fine_experts, self.d_model, 2 * self.intermediate_dim
)
)
# Down
self.down_proj = nn.Parameter(
torch.empty(num_fine_experts, self.intermediate_dim, self.d_model)
)
self.shared_experts = DeepseekV3MLP(
hidden_size=d_model,
intermediate_size=int(d_model * ff_mult * num_shared_experts),
)
self.mlp_dropout = (
nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity()
)
self.reset_parameters()
def reset_parameters(self) -> None:
    """Re-initialise the stacked routed-expert weights in place.

    ``gate_up_proj`` is initialised first and ``down_proj`` second with
    ``nn.init.xavier_uniform_``; this order fixes how the global RNG is
    consumed.

    NOTE(review): both tensors are 3-D (``[num_experts, in, out]``). For
    tensors with more than two dims PyTorch's fan computation treats dim 0
    as output maps, dim 1 as input maps and the remaining dims as a
    receptive field, so the bound is not the per-expert Xavier bound.
    Kept unchanged to preserve training behaviour — confirm it is intended.
    """
    for expert_weight in (self.gate_up_proj, self.down_proj):
        nn.init.xavier_uniform_(expert_weight)
def _load_from_state_dict(
    self,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Load expert weights, transposing checkpoints in the legacy layout.

    Current layout: ``gate_up_proj`` is
    ``[num_experts, d_model, 2 * intermediate_dim]`` and ``down_proj`` is
    ``[num_experts, intermediate_dim, d_model]``. The legacy layout has the
    last two axes of each swapped.

    - A tensor whose shape matches only the legacy layout is transposed.
    - A tensor whose shape matches both layouts (square trailing axes) is
      transposed only if the checkpoint was judged legacy. That judgement
      comes from the first of ``gate_up_proj`` / ``down_proj`` whose shape
      is unambiguous; if none is, the current layout is assumed.

    ``state_dict`` is rewritten in place before delegating to the base
    implementation. The router freeze state is re-applied afterwards.
    """
    experts = self.num_fine_experts
    width = self.d_model
    inner = self.intermediate_dim
    # (state-dict key, current shape, legacy shape) per stacked weight.
    candidates = (
        (
            prefix + "gate_up_proj",
            tuple(self.gate_up_proj.shape),
            (experts, 2 * inner, width),
        ),
        (
            prefix + "down_proj",
            tuple(self.down_proj.shape),
            (experts, width, inner),
        ),
    )

    # Pass 1: infer the checkpoint layout from the first unambiguous tensor.
    saved_is_legacy = None
    for key, current_shape, legacy_shape in candidates:
        if key not in state_dict:
            continue
        saved_shape = tuple(state_dict[key].shape)
        matches_current = saved_shape == current_shape
        matches_legacy = saved_shape == legacy_shape
        if matches_current != matches_legacy:
            saved_is_legacy = matches_legacy
            break

    # Pass 2: transpose every tensor that is (or is inferred to be) legacy.
    for key, current_shape, legacy_shape in candidates:
        if key not in state_dict:
            continue
        saved_weight = state_dict[key]
        saved_shape = tuple(saved_weight.shape)
        if saved_shape != legacy_shape:
            continue
        if saved_shape != current_shape or saved_is_legacy:
            state_dict[key] = saved_weight.transpose(-2, -1).contiguous()

    super()._load_from_state_dict(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    )
    self._apply_freeze_router_state()
def _apply_freeze_router_state(self) -> None:
self.router.requires_grad_(not self.freeze_router)
def reset_routing_stats(self) -> None:
    """Clear the routing-count accumulator and every cached routing stat.

    Tensor buffers are zeroed in place. Cached tensor references (router
    distribution/logits and the last auxiliary losses) are reset to None.
    """
    zeroed_buffers = (
        "routing_counts_accum",
        "last_routed_expert_usage",
        "last_routed_active_expert_count",
        "last_routed_max_expert_frac",
        "last_dense_expert_usage",
        "last_dead_expert_margin_to_topk_loss_value",
        "last_dead_expert_margin_to_topk_target",
        "last_selected_expert_margin_to_unselected",
        "last_selected_expert_margin_to_unselected_loss_value",
    )
    for buffer_name in zeroed_buffers:
        getattr(self, buffer_name).zero_()
    dropped_refs = (
        "last_router_distribution",
        "last_router_logits",
        "last_dead_expert_margin_to_topk_loss",
        "last_selected_expert_margin_to_unselected_loss",
    )
    for ref_name in dropped_refs:
        setattr(self, ref_name, None)
def accumulate_routing_stats(self, topk_idx: torch.Tensor) -> None:
    """Add per-expert selection counts from ``topk_idx`` to the accumulator.

    ``topk_idx`` is flattened, so any shape of non-negative integer expert
    ids is accepted. Runs without autograd tracking.
    """
    with torch.no_grad():
        per_expert = topk_idx.reshape(-1).bincount(
            minlength=self.num_fine_experts
        )
        self.routing_counts_accum.add_(per_expert)
def _apply_bias_update_from_counts(self, counts: torch.Tensor) -> None:
with torch.no_grad():
if dist.is_available() and dist.is_initialized():
dist.all_reduce(counts, op=dist.ReduceOp.SUM)
total = counts.sum()
if int(total.item()) == 0:
self.last_active_expert_ratio.zero_()
self.last_max_expert_frac.zero_()
self.last_expert_count_cv.zero_()
self.last_min_expert_frac.zero_()
self.last_dead_expert_ratio.zero_()
return
if self.use_dynamic_bias:
avg = counts.float().mean()
error = avg - counts.float()
self.expert_bias.add_(
self.bias_update_rate * torch.sign(error)
)
total = total.clamp_min(1)
active_ratio = (counts > 0).to(torch.float32).mean()
max_expert_frac = counts.max().to(torch.float32) / total.to(
torch.float32
)
min_expert_frac = counts.min().to(torch.float32) / total.to(
torch.float32
)
dead_expert_ratio = (counts == 0).to(torch.float32).mean()
counts_f = counts.to(torch.float32)
counts_mean = counts_f.mean().clamp_min(1.0e-6)
counts_std = counts_f.std(unbiased=False)
expert_count_cv = counts_std / counts_mean
self.last_active_expert_ratio.copy_(active_ratio)
self.last_max_expert_frac.copy_(max_expert_frac)
self.last_expert_count_cv.copy_(expert_count_cv)
self.last_min_expert_frac.copy_(min_expert_frac)
self.last_dead_expert_ratio.copy_(dead_expert_ratio)
if self.use_dynamic_bias and self.expert_bias_clip > 0.0:
self.expert_bias.clamp_(
min=-self.expert_bias_clip, max=self.expert_bias_clip
)
def apply_bias_update_from_counts(self) -> None:
    """Consume the accumulated routing counts and apply the bias update.

    A snapshot of ``routing_counts_accum`` is taken and the accumulator is
    reset to zero. The snapshot (which may be all-reduced in place) is then
    passed to ``_apply_bias_update_from_counts``.
    """
    with torch.no_grad():
        accumulator = self.routing_counts_accum
        snapshot = accumulator.clone()
        accumulator.zero_()
        self._apply_bias_update_from_counts(snapshot)
def _update_routed_expert_stats_and_floor_loss(
self,
topk_idx: torch.Tensor,
dense_distribution: torch.Tensor,
choice_scores: torch.Tensor,
) -> torch.Tensor:
counts = torch.bincount(
topk_idx.reshape(-1), minlength=self.num_fine_experts
).to(torch.float32)
total_assignments = max(int(topk_idx.numel()), 1)
hard_usage = counts / float(total_assignments)
active_count = (counts > 0).to(torch.float32).sum()
max_frac = hard_usage.max() if hard_usage.numel() > 0 else counts.sum()
with torch.no_grad():
self.last_routed_expert_usage.copy_(
hard_usage.to(self.last_routed_expert_usage.dtype)
)
self.last_routed_active_expert_count.copy_(
active_count.to(self.last_routed_active_expert_count.dtype)
)
self.last_routed_max_expert_frac.copy_(
max_frac.to(self.last_routed_max_expert_frac.dtype)
)
dense_usage = dense_distribution.to(torch.float32).mean(dim=(0, 1))
with torch.no_grad():
self.last_dense_expert_usage.copy_(
dense_usage.detach().to(self.last_dense_expert_usage.dtype)
)
kth_choice_score = choice_scores.gather(-1, topk_idx)[..., -1:]
if self.top_k < self.num_fine_experts:
selected_mask = F.one_hot(
topk_idx, num_classes=self.num_fine_experts
).any(dim=-2)
best_unselected_score = (
choice_scores.masked_fill(
selected_mask, torch.finfo(choice_scores.dtype).min
)
.max(dim=-1, keepdim=True)
.values
)
selected_margin_gap = kth_choice_score - best_unselected_score
selected_margin = selected_margin_gap.mean()
else:
selected_margin_gap = choice_scores.new_zeros(
choice_scores.shape[:2] + (1,)
)
selected_margin = choice_scores.new_zeros(())
if self.selected_expert_margin_to_unselected_enabled:
selected_margin_loss = torch.relu(
self.selected_expert_margin_to_unselected_target
- selected_margin_gap
).mean()
else:
selected_margin_loss = choice_scores.new_zeros(())
with torch.no_grad():
self.last_selected_expert_margin_to_unselected.copy_(
selected_margin.detach().to(
self.last_selected_expert_margin_to_unselected.dtype
)
)
self.last_selected_expert_margin_to_unselected_loss_value.copy_(
selected_margin_loss.detach().to(
self.last_selected_expert_margin_to_unselected_loss_value.dtype
)
)
self.last_selected_expert_margin_to_unselected_loss = (
selected_margin_loss
)
if not self.dead_expert_margin_to_topk_enabled:
margin_loss = dense_distribution.new_zeros(())
with torch.no_grad():
self.last_dead_expert_margin_to_topk_loss_value.zero_()
self.last_dead_expert_margin_to_topk_target.zero_()
self.last_dead_expert_margin_to_topk_loss = margin_loss
return margin_loss
dead_mask = (counts == 0).to(choice_scores.dtype)
margin_gap = torch.relu(kth_choice_score - choice_scores)
dead_margin_sum = (
margin_gap * dead_mask.view(1, 1, self.num_fine_experts)
).sum()
dead_count = dead_mask.sum()
num_tokens = choice_scores.new_ones(
choice_scores.shape[:2], dtype=choice_scores.dtype
).sum()
normalizer = dead_count.clamp_min(1.0) * num_tokens
margin_loss = dead_margin_sum / normalizer
with torch.no_grad():
self.last_dead_expert_margin_to_topk_loss_value.copy_(
margin_loss.detach().to(
self.last_dead_expert_margin_to_topk_loss_value.dtype
)
)
self.last_dead_expert_margin_to_topk_target.copy_(
kth_choice_score.mean()
.detach()
.to(self.last_dead_expert_margin_to_topk_target.dtype)
)
self.last_dead_expert_margin_to_topk_loss = margin_loss
return margin_loss
@torch.compiler.disable
def _compute_sparse_experts(
self,
x: torch.Tensor,
topk_idx: torch.Tensor,
topk_scores: torch.Tensor,
) -> torch.Tensor:
B, T, D = x.size()
num_top_k = self.top_k
is_exporting = torch.onnx.is_in_onnx_export()
if is_exporting:
# ONNX/runtime path: compute only selected experts (top_k),
# avoiding O(num_experts) per-step overhead at bs=1.
return self._compute_with_topk_selection(x, topk_idx, topk_scores)
x_flat = x.view(-1, D)
expert_ids = topk_idx.view(-1)
scores = topk_scores.view(-1)
raw_token_indices = (
torch.arange(B * T, device=x.device)
.unsqueeze(1)
.expand(-1, num_top_k)
.reshape(-1)
)
sorted_expert_ids, perm = torch.sort(expert_ids)
sorted_token_indices = raw_token_indices[perm]
x_sorted = x_flat[sorted_token_indices]
scores_sorted = scores[perm]
# Path B: High-Performance Grouped GEMM
output_sorted = self._compute_with_grouped_mm(
x_sorted, sorted_expert_ids
)
output_sorted = output_sorted * scores_sorted.unsqueeze(-1)
inv_perm = torch.argsort(perm)
output_flat = output_sorted[inv_perm]
output_final = output_flat.view(B * T, num_top_k, D).sum(dim=1)
return output_final.view(B, T, D)
def _compute_with_grouped_mm(
    self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor
) -> torch.Tensor:
    """SwiGLU expert FFN over rows already grouped by expert id.

    Group boundaries passed to ``_grouped_linear`` (``torch._grouped_mm``)
    follow the upstream contract:
    - offsets are cumulative end indices (inclusive cumsum of group sizes);
    - there are exactly ``num_fine_experts`` offsets (NOT N+1);
    - the offsets dtype is int32.
    """
    experts = self.num_fine_experts
    group_sizes = torch.bincount(
        sorted_expert_ids.long(), minlength=experts
    )[:experts]
    group_ends = group_sizes.cumsum(dim=0, dtype=torch.int32)
    gate, up = _grouped_linear(
        x_sorted, self.gate_up_proj, offs=group_ends
    ).chunk(2, dim=-1)
    return _grouped_linear(F.silu(gate) * up, self.down_proj, offs=group_ends)
def _compute_with_topk_selection(
self,
x: torch.Tensor,
topk_idx: torch.Tensor,
topk_scores: torch.Tensor,
) -> torch.Tensor:
"""ONNX-friendly sparse expert compute that scales with top_k, not
num_fine_experts.
"""
B, T, D = x.shape
N = B * T
K = self.top_k
orig_dtype = x.dtype
x_tokens = x.reshape(N, D)
idx = topk_idx.reshape(N, K)
scores = topk_scores.reshape(N, K)
x_rep = x_tokens[:, None, :].expand(N, K, D).reshape(N * K, D)
idx_flat = idx.reshape(N * K)
compute_dtype = self.gate_up_proj.dtype
if x_rep.dtype != compute_dtype:
x_rep = x_rep.to(compute_dtype)
gate_up_w = self.gate_up_proj.index_select(0, idx_flat)
gate_up_out = torch.bmm(x_rep.unsqueeze(1), gate_up_w).squeeze(1)
x1, x2 = gate_up_out.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
down_w = self.down_proj.index_select(0, idx_flat)
sparse_flat = torch.bmm(hidden.unsqueeze(1), down_w).squeeze(1)
if sparse_flat.dtype != orig_dtype:
sparse_flat = sparse_flat.to(orig_dtype)
sparse = sparse_flat.view(N, K, D)
weighted = sparse * scores.to(sparse.dtype).unsqueeze(-1)
out = weighted.sum(dim=1)
return out.view(B, T, D)
def _compute_with_loop_fallback(
self, x_sorted: torch.Tensor, sorted_expert_ids: torch.Tensor
) -> torch.Tensor:
"""Path A: Loop Fallback (Compatible with F.linear and 3D Weights)"""
results = []
for i in range(self.num_fine_experts):
mask = sorted_expert_ids == i
inp_i = x_sorted[mask]
# Gate + Up
w_gate_up = self.gate_up_proj[i].transpose(0, 1)
gate_up_out = F.linear(inp_i, w_gate_up)
x1, x2 = gate_up_out.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
# Down
w_down = self.down_proj[i].transpose(0, 1)
out_i = F.linear(hidden, w_down)
results.append(out_i)
return torch.cat(results, dim=0)
def compute_moe_ffn(
self,
x: torch.Tensor,
router_x: torch.Tensor | None = None,
*,
return_routing_debug: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, D = x.shape
should_cache_router_distribution = (
self.collect_routing_stats and self.collect_router_distribution
) or self.capture_router_distribution
should_cache_router_logits = self.capture_router_logits
# 1. Shared Experts (Dense Path)
shared_out = self.shared_experts(x)
# 2. Router (Gating)
router_input = x if router_x is None else router_x
if router_input.shape != x.shape:
raise ValueError(
"router_x shape must match x shape in compute_moe_ffn: "
f"x={tuple(x.shape)}, router_x={tuple(router_input.shape)}"
)
if self.freeze_router:
with torch.no_grad():
logits = self.router(router_input)
else:
logits = self.router(router_input)
logits_fp32 = logits.to(torch.float32)
bias_fp32 = None
if self.use_dynamic_bias:
bias_fp32 = self.expert_bias.to(
device=logits.device, dtype=torch.float32
)
if self.routing_score_fn == "softmax":
choice_logits = logits_fp32
if bias_fp32 is not None:
# Keep dynamic bias as a selection correction, not a mixture-weight shaper.
choice_logits = choice_logits + bias_fp32
choice_scores = choice_logits
_, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1)
dense_distribution = torch.softmax(logits_fp32, dim=-1)
if torch.onnx.is_in_onnx_export():
selected_probs = dense_distribution.gather(-1, topk_idx)
else:
selected_logits = logits_fp32.gather(-1, topk_idx)
log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True)
selected_probs = torch.exp(selected_logits - log_z)
topk_scores = selected_probs / selected_probs.sum(
dim=-1, keepdim=True
).clamp_min(1.0e-20)
router_distribution = None
if should_cache_router_distribution:
router_distribution = dense_distribution
else: # sigmoid
scores = torch.sigmoid(logits_fp32)
dense_distribution = scores / scores.sum(
dim=-1, keepdim=True
).clamp_min(1.0e-20)
scores_for_choice = scores
if bias_fp32 is not None:
# DeepSeek-style correction bias for expert choice.
scores_for_choice = scores_for_choice + bias_fp32
choice_scores = scores_for_choice
_, topk_idx = torch.topk(choice_scores, self.top_k, dim=-1)
selected_scores = scores.gather(-1, topk_idx)
# Match DeepSeek-style routing: bias affects only expert choice,
# while the expert mixing weights come from the original sigmoid
# affinities normalized over the selected experts.
topk_scores = selected_scores / selected_scores.sum(
dim=-1, keepdim=True
).clamp_min(1.0e-20)
router_distribution = None
if should_cache_router_distribution:
router_distribution = dense_distribution
if self.collect_routing_stats:
self.accumulate_routing_stats(topk_idx)
if (
should_cache_router_distribution
and router_distribution is not None
):
self.last_router_distribution = router_distribution
else:
self.last_router_distribution = None
if should_cache_router_logits:
self.last_router_logits = logits_fp32
else:
self.last_router_logits = None
self._update_routed_expert_stats_and_floor_loss(
topk_idx=topk_idx,
dense_distribution=dense_distribution,
choice_scores=choice_scores,
)
if self.routing_scale != 1.0:
topk_scores = topk_scores * self.routing_scale
topk_scores = topk_scores.to(logits.dtype)
# 3. Sparse Experts Computation (Grouped MM / ONNX Loop)
sparse_out = self._compute_sparse_experts(x, topk_idx, topk_scores)
# 4. Combine
output = shared_out + sparse_out
output = self.mlp_dropout(output)
if return_routing_debug:
return output, topk_idx, logits_fp32
return output
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor = None,
sin: torch.Tensor = None,
mask: torch.Tensor | None = None,
memory: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
router_x: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass compatible with ONNX and Attention/Norm."""
norm_x = self.norm1(x)
attn_out = self.attn(norm_x, cos, sin, mask)
x = x + attn_out
if self.use_cross_attn and memory is not None:
x_cross = self.norm_cross(x)
if memory.ndim == 4:
b, t, d_model = x_cross.shape
_, _, n_fut, _ = memory.shape
q = x_cross.reshape(b * t, 1, d_model)
mem = memory.reshape(b * t, n_fut, d_model)
mem_mask = None
if memory_mask is not None:
if memory_mask.ndim != 3:
raise ValueError(
"memory_mask for 4D memory must have shape [B, T, N_fut]"
)
mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut)
cross = self.cross_attn(q, mem, mem_mask).reshape(
b, t, d_model
)
else:
cross = self.cross_attn(x_cross, memory, memory_mask)
x = x + cross
h = self.norm2(x)
ffn_out = self.compute_moe_ffn(h, router_x=router_x)
x = x + ffn_out
return x
class ModernTransformerBlock(nn.Module):
    """Modern Transformer block with pre-norm, SwiGLU MLP, and modern attention.

    Features:
    - Pre-normalization with RMSNorm.
    - ModernAttention (GQA, QK-Norm, RealRoPE, Gated Attention).
    - Optional ModernCrossAttention over an external ``memory``.
    - DeepseekV3MLP (SwiGLU) for feed-forward.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        ff_mult: int | float = 4,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
        mlp_dropout: float = 0.0,
        use_cross_attn: bool = False,
    ):
        """Build the block.

        Args:
            d_model: Model width.
            n_heads / n_kv_heads: Query heads / key-value heads (GQA).
            ff_mult: MLP width multiplier. The hidden size is
                ``int(d_model * ff_mult)`` (same rule as the MoE block), so
                fractional multipliers work and integers behave as before.
            use_qk_norm, use_gated_attn, gated_attn_type, attn_dropout:
                Forwarded to the attention modules.
            mlp_dropout: Dropout on the MLP output (Identity when 0).
            use_cross_attn: Also build a cross-attention sub-layer.
        """
        super().__init__()
        self.use_cross_attn = bool(use_cross_attn)
        self.norm1 = RMSNorm(d_model)
        self.attn = ModernAttention(
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            use_qk_norm=use_qk_norm,
            use_gated_attn=use_gated_attn,
            gated_attn_type=gated_attn_type,
            attn_dropout=attn_dropout,
        )
        if self.use_cross_attn:
            self.norm_cross = RMSNorm(d_model)
            self.cross_attn = ModernCrossAttention(
                d_model=d_model,
                n_heads=n_heads,
                n_kv_heads=n_kv_heads,
                use_qk_norm=use_qk_norm,
                use_gated_attn=use_gated_attn,
                gated_attn_type=gated_attn_type,
                attn_dropout=attn_dropout,
            )
        else:
            self.norm_cross = None
            self.cross_attn = None
        self.norm2 = RMSNorm(d_model)
        self.mlp = DeepseekV3MLP(
            hidden_size=d_model, intermediate_size=int(d_model * ff_mult)
        )
        self.mlp_dropout = (
            nn.Dropout(mlp_dropout) if mlp_dropout > 0.0 else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: torch.Tensor | None = None,
        memory: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with pre-norm residual connections.

        Args:
            x: Input tensor [B, T, d_model].
            cos: RoPE cosine table, forwarded to ``self.attn``.
            sin: RoPE sine table, forwarded to ``self.attn``.
            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).
            memory: Optional cross-attention memory, either [B, N, d_model]
                or per-step [B, T, N_fut, d_model]. Ignored unless the block
                was built with ``use_cross_attn=True``.
            memory_mask: Optional mask for ``memory``. With 4-D memory it
                must be [B, T, N_fut].

        Returns:
            out: Output tensor [B, T, d_model].

        Raises:
            ValueError: 4-D ``memory`` with a ``memory_mask`` that is not 3-D.
        """
        x = x + self.attn(self.norm1(x), cos, sin, mask)
        if self.use_cross_attn and memory is not None:
            x_cross = self.norm_cross(x)
            if memory.ndim == 4:
                # Per-step memory: fold (B, T) so each step attends to its own
                # N_fut memory tokens with a single query.
                b, t, d_model = x_cross.shape
                _, _, n_fut, _ = memory.shape
                q = x_cross.reshape(b * t, 1, d_model)
                mem = memory.reshape(b * t, n_fut, d_model)
                mem_mask = None
                if memory_mask is not None:
                    if memory_mask.ndim != 3:
                        raise ValueError(
                            "memory_mask for 4D memory must have shape [B, T, N_fut]"
                        )
                    mem_mask = memory_mask.reshape(b * t, 1, 1, n_fut)
                cross = self.cross_attn(q, mem, mem_mask).reshape(
                    b, t, d_model
                )
            else:
                cross = self.cross_attn(x_cross, memory, memory_mask)
            x = x + cross
        x = x + self.mlp_dropout(self.mlp(self.norm2(x)))
        return x
class DeepseekV3MLP(nn.Module):
    """SwiGLU MLP with fused gate+up projection for efficiency.

    ``forward(x) = down_proj(silu(gate) * up)``, where ``gate`` and ``up``
    are the two halves (along the last axis) of a single fused
    ``gate_up_proj`` output. Both linear layers have biases.
    """

    def __init__(self, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Registration order (gate_up_proj, then down_proj) fixes both the
        # state_dict key order and the RNG consumption during init.
        self.gate_up_proj = nn.Linear(
            hidden_size, 2 * intermediate_size, bias=True
        )
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [..., hidden] -> two [..., intermediate] halves -> [..., hidden]
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (used in Llama, DeepSeek, Qwen).

    ``y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``.

    The statistic and the normalisation are computed in at least float32
    (``torch.promote_types(x.dtype, float32)``). The normalised value is
    cast back to the input dtype before the learned scale is applied. As a
    result, fp16/bf16 inputs no longer overflow in ``x ** 2`` (for fp16
    that happens once |x| exceeds ~256, which collapsed the output to 0),
    while float32/float64 inputs are computed exactly as before.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_wide = x.to(compute_dtype)
        var = x_wide.pow(2).mean(-1, keepdim=True)
        x_normed = (x_wide * torch.rsqrt(var + self.eps)).to(input_dtype)
        # Output dtype follows normal promotion of input and weight dtypes.
        return x_normed * self.weight
class ModernAttention(nn.Module):
    """Modern attention with GQA, QK-Norm, RealRoPE, Flash Attention, and Gated Attention.

    Features:
    - GQA: Grouped Query Attention (n_kv_heads < n_heads).
    - QK-Norm: RMSNorm on queries and keys for stability.
    - RealRoPE: Real-valued Rotary Positional Embeddings.
    - Flash Attention: via F.scaled_dot_product_attention.
    - Gated Attention: Headwise or element-wise sigmoid gating (Qwen3-style).
    - Fused Projections: Q separate, KV fused for efficiency.

    Reference: https://github.com/qiuzh20/gated_attention
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
    ):
        """Initialize ModernAttention.

        Args:
            d_model: Model dimension.
            n_heads: Number of query heads.
            n_kv_heads: Number of key/value heads (for GQA). Defaults to n_heads.
            use_qk_norm: Apply RMSNorm to Q and K.
            use_gated_attn: Enable gated attention.
            gated_attn_type: "headwise" (Qwen3-style, one gate per head) or
                "elementwise" (one gate per element). Any value other than
                "headwise" selects element-wise gating.
            attn_dropout: Dropout probability for attention (training only).
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        self.head_dim = d_model // n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.use_qk_norm = use_qk_norm
        self.use_gated_attn = use_gated_attn
        self.gated_attn_type = gated_attn_type
        self.attn_dropout = attn_dropout
        # Fused projections: Q separate, KV fused (for GQA efficiency)
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.kv_proj = nn.Linear(
            d_model, 2 * self.n_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        if self.use_gated_attn:
            if self.gated_attn_type == "headwise":
                # Qwen3-style: one gate scalar per head [B, T, n_heads]
                self.gate_proj = nn.Linear(d_model, n_heads, bias=False)
            else:
                # Element-wise: gate each element [B, T, d_model]
                self.gate_proj = nn.Linear(d_model, d_model, bias=False)

    def _attend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None,
        is_causal: bool,
    ) -> torch.Tensor:
        """SDPA over ``[B, heads, T, head_dim]`` tensors with GQA support.

        During ONNX export K/V are expanded from ``n_kv_heads`` to
        ``n_heads`` with ``repeat_kv`` and SDPA's native GQA is disabled.
        Otherwise ``enable_gqa=True`` lets SDPA handle the head grouping.
        Dropout is applied only in training mode.
        """
        dropout_p = self.attn_dropout if self.training else 0.0
        if torch.onnx.is_in_onnx_export():
            k = repeat_kv(k, self.n_rep)
            v = repeat_kv(v, self.n_rep)
            enable_gqa = False
        else:
            enable_gqa = True
        return export_safe_scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            enable_gqa=enable_gqa,
        )

    def _gate_and_project(
        self, x: torch.Tensor, attn_out: torch.Tensor, seq_len: int
    ) -> torch.Tensor:
        """Apply optional gating to SDPA output and project back to d_model.

        Args:
            x: The attention input; the gate is computed from it.
            attn_out: SDPA output ``[B, n_heads, seq_len, head_dim]``.
            seq_len: Query length, used to merge heads to ``[B, seq_len, -1]``.
        """
        batch = x.shape[0]
        headwise = self.use_gated_attn and self.gated_attn_type == "headwise"
        if headwise:
            # Headwise gating: [B, T, n_heads] -> [B, n_heads, T, 1]
            g = torch.sigmoid(self.gate_proj(x))
            attn_out = attn_out * g.transpose(1, 2)[..., None]
        attn_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        if self.use_gated_attn and not headwise:
            # Element-wise gating on the merged heads: [B, T, d_model]
            attn_out = attn_out * torch.sigmoid(self.gate_proj(x))
        return self.o_proj(attn_out)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,  # [B, T, head_dim]
        sin: torch.Tensor,  # [B, T, head_dim]
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor [B, T, d_model].
            cos: RoPE cosine frequencies [B, T, head_dim].
            sin: RoPE sine frequencies [B, T, head_dim].
            mask: Attention mask [T, T] or [B, T, T], True = allowed (can attend).
                When None, causal attention is used.

        Returns:
            out: Output tensor [B, T, d_model].
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim)
        kv = self.kv_proj(x).view(B, T, 2, self.n_kv_heads, self.head_dim)
        k, v = kv[:, :, 0], kv[:, :, 1]  # each [B, T, n_kv_heads, head_dim]
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        # Transpose for SDPA: [B, heads, T, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # Flash Attention via SDPA (handles GQA internally)
        attn_out = self._attend(q, k, v, attn_mask=mask, is_causal=(mask is None))
        # attn_out: [B, n_heads, T, head_dim]
        return self._gate_and_project(x, attn_out, T)

    def forward_single_token(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        new_len: torch.Tensor,
        insert_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward for single token with per-environment KV cache update.

        Args:
            x: Input tensor [B, 1, d_model].
            cos: RoPE cosine frequencies [B, 1, head_dim].
            sin: RoPE sine frequencies [B, 1, head_dim].
            k_cache: K cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].
            v_cache: V cache for this layer [B, max_ctx_len, n_kv_heads, head_dim].
            new_len: New valid cache length per env AFTER inserting this token [B].
            insert_pos: Insert position per env [B].

        Returns:
            attn_out: [B, 1, d_model]
            k_cache: Updated K cache (a new tensor during ONNX export,
                otherwise the same tensor updated in place)
            v_cache: Updated V cache (same rule as k_cache)
        """
        B = x.shape[0]
        q = self.q_proj(x).view(B, 1, self.n_heads, self.head_dim)
        kv = self.kv_proj(x).view(B, 1, 2, self.n_kv_heads, self.head_dim)
        k_new, v_new = kv[:, :, 0], kv[:, :, 1]  # [B, 1, n_kv_heads, head_dim]
        if self.use_qk_norm:
            q = self.q_norm(q)
            k_new = self.k_norm(k_new)
        # Apply RoPE with per-environment position
        q = q.transpose(1, 2)
        k_new = k_new.transpose(1, 2)
        q, k_new = apply_rotary_pos_emb(q, k_new, cos, sin)
        q = q.transpose(1, 2)
        k_new = k_new.transpose(1, 2)
        # Scatter K, V into cache at per-env insert positions
        # insert_pos: [B] -> [B, 1, n_kv_heads, head_dim] scatter index
        idx = (
            insert_pos.view(B, 1, 1, 1)
            .expand(B, 1, self.n_kv_heads, self.head_dim)
            .to(torch.int64)
        )
        if torch.onnx.is_in_onnx_export():
            # ONNX export: out-of-place scatter (produces new tensors).
            k_cache = k_cache.scatter(1, idx, k_new.to(k_cache.dtype))
            v_cache = v_cache.scatter(1, idx, v_new.to(v_cache.dtype))
        else:
            # Rollout: in-place scatter modifies the given cache tensors.
            k_cache.scatter_(1, idx, k_new.to(k_cache.dtype))
            v_cache.scatter_(1, idx, v_new.to(v_cache.dtype))
        # Compute attention over cached keys/values
        # Mask out positions >= new_len (after insert)
        max_len = k_cache.shape[1]
        new_len = new_len.clamp(max=max_len)  # [B]
        # Build per-env mask: [B, max_len] where True = valid (can attend)
        pos_idx = torch.arange(max_len, device=x.device, dtype=torch.int64)
        valid_mask = pos_idx[None, :] < new_len[:, None]  # [B, max_len]
        # For SDPA bool mask: True = allowed (can attend)
        attn_mask = valid_mask[:, None, None, :]  # [B, 1, 1, max_len]
        # GQA: Use native SDPA broadcasting (no repeat_interleave)
        k_attn = k_cache.to(q.dtype)
        v_attn = v_cache.to(q.dtype)
        # Transpose for SDPA: [B, heads, T, head_dim]
        q_t = q.transpose(1, 2)  # [B, n_heads, 1, head_dim]
        k_t = k_attn.transpose(1, 2)  # [B, n_kv_heads, max_len, head_dim]
        v_t = v_attn.transpose(1, 2)
        attn_out = self._attend(
            q_t, k_t, v_t, attn_mask=attn_mask, is_causal=False
        )
        # attn_out: [B, n_heads, 1, head_dim]
        return self._gate_and_project(x, attn_out, 1), k_cache, v_cache
class ModernCrossAttention(nn.Module):
    """Cross-attention with GQA/QK-norm and optional gated attention.

    Queries come from ``x`` (``[B, T, D]``); keys and values come from
    ``memory`` (``[B, N, D]``). No rotary embedding is applied and attention
    is never causal. ``mask`` is passed to SDPA as ``attn_mask`` unchanged.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int | None = None,
        use_qk_norm: bool = True,
        use_gated_attn: bool = True,
        gated_attn_type: str = "headwise",
        attn_dropout: float = 0.0,
    ):
        super().__init__()
        self.d_model = int(d_model)
        self.n_heads = int(n_heads)
        self.n_kv_heads = int(n_heads if n_kv_heads is None else n_kv_heads)
        self.head_dim = self.d_model // self.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.use_qk_norm = bool(use_qk_norm)
        self.use_gated_attn = bool(use_gated_attn)
        self.gated_attn_type = str(gated_attn_type)
        self.attn_dropout = float(attn_dropout)

        query_width = self.n_heads * self.head_dim
        key_width = self.n_kv_heads * self.head_dim
        # Registration order (q, kv, o, norms, gate) fixes state_dict order
        # and RNG consumption at init.
        self.q_proj = nn.Linear(self.d_model, query_width, bias=False)
        self.kv_proj = nn.Linear(self.d_model, 2 * key_width, bias=False)
        self.o_proj = nn.Linear(query_width, self.d_model, bias=False)
        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        if self.use_gated_attn:
            # Headwise: one gate per head; otherwise one gate per channel.
            gate_width = (
                self.n_heads
                if self.gated_attn_type == "headwise"
                else self.d_model
            )
            self.gate_proj = nn.Linear(self.d_model, gate_width, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend from ``x`` [B, T, D] to ``memory`` [B, N, D].

        Raises:
            ValueError: non-3-D ``x`` or ``memory``, or mismatched batch sizes.
        """
        if x.ndim != 3:
            raise ValueError(f"x must be [B, T, D], got {tuple(x.shape)}")
        if memory.ndim != 3:
            raise ValueError(
                f"memory must be [B, N, D], got {tuple(memory.shape)}"
            )
        b, t, _ = x.shape
        bm, n, _ = memory.shape
        if bm != b:
            raise ValueError(
                f"batch mismatch between x and memory: {b} vs {bm}"
            )

        queries = self.q_proj(x).view(b, t, self.n_heads, self.head_dim)
        packed = self.kv_proj(memory).view(
            b, n, 2, self.n_kv_heads, self.head_dim
        )
        keys = packed[:, :, 0]
        values = packed[:, :, 1]
        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)
        # To SDPA layout: [B, heads, seq, head_dim]
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        if torch.onnx.is_in_onnx_export():
            # Exporter path: materialise GQA head sharing explicitly.
            keys = repeat_kv(keys, self.n_rep)
            values = repeat_kv(values, self.n_rep)
            native_gqa = False
        else:
            native_gqa = True
        attn_out = export_safe_scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=False,
            enable_gqa=native_gqa,
        )

        headwise = self.use_gated_attn and self.gated_attn_type == "headwise"
        if headwise:
            head_gate = torch.sigmoid(self.gate_proj(x)).transpose(1, 2)
            attn_out = attn_out * head_gate.unsqueeze(-1)
        attn_out = attn_out.transpose(1, 2).contiguous().view(b, t, -1)
        if self.use_gated_attn and not headwise:
            attn_out = attn_out * torch.sigmoid(self.gate_proj(x))
        return self.o_proj(attn_out)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Standard LLaMA GQA replication, written with ONNX-friendly ops.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    ``(batch, num_kv_heads, seqlen, head_dim)`` ->
    ``(batch, num_kv_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1``
    the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # [b, kv, s, d] -> [b, kv, 1, s, d] -> broadcast [b, kv, n_rep, s, d]
    widened = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def export_safe_scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
attn_mask: torch.Tensor | None,
dropout_p: float,
is_causal: bool,
enable_gqa: bool = False,
) -> torch.Tensor:
if (
not torch.onnx.is_in_onnx_export()
or attn_mask is None
or attn_mask.dtype != torch.bool
):
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
enable_gqa=enable_gqa,
)
# Use additive float bias during ONNX export so the legacy exporter
# does not emit the bool-mask SDPA cleanup path with IsNaN.
mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
mask_bias = mask_bias.masked_fill(~attn_mask, torch.finfo(q.dtype).min)
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask_bias,
dropout_p=dropout_p,
is_causal=is_causal,
enable_gqa=enable_gqa,
)
def rotate_half(x):
    """Rotate half the hidden dims: ``[x1, x2] -> [-x2, x1]`` on the last axis.

    ``x1`` is the first ``d // 2`` entries, so for an odd last dimension the
    second half is the larger one.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    The rotation ``x * cos + rotate_half(x) * sin`` is evaluated in float32 and
    both results are cast back to ``q.dtype`` (note: the key output also uses
    the query's dtype).

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; ``unsqueeze(unsqueeze_dim)`` is applied before use.
        sin: Sine table; ``unsqueeze(unsqueeze_dim)`` is applied before use.
        position_ids: Unused by this implementation; accepted so callers can
            pass it.
        unsqueeze_dim: Axis inserted into ``cos``/``sin`` so they broadcast
            against ``q``/``k``.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    out_dtype = q.dtype
    # Force fp32 for the rotation math.
    cos32 = cos.unsqueeze(unsqueeze_dim).to(torch.float32)
    sin32 = sin.unsqueeze(unsqueeze_dim).to(torch.float32)

    def _rotate(t):
        t32 = t.to(torch.float32)
        return (t32 * cos32) + (rotate_half(t32) * sin32)

    q_rot = _rotate(q)
    k_rot = _rotate(k)
    return q_rot.to(out_dtype), k_rot.to(out_dtype)
def _grouped_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
offs: torch.Tensor | None = None,
) -> torch.Tensor:
"""input: [Total_Tokens, In_Dim]
weight: [Num_Experts, In_Dim, Out_Dim]
"""
orig_dtype = input.dtype
if input.dtype != weight.dtype:
input = input.to(weight.dtype)
out = torch._grouped_mm(input, weight, offs=offs)
if out.dtype != orig_dtype:
out = out.to(orig_dtype)
if bias is not None:
out = out + bias
return out
================================================
FILE: holomotion/src/motion_retargeting/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/motion_retargeting/gmr_to_holomotion.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json, os, sys
from pathlib import Path
from typing import Dict, Tuple, Optional, List
import joblib
import numpy as np
import torch
import hydra
from omegaconf import OmegaConf, DictConfig, ListConfig
from tqdm import tqdm
from loguru import logger
from holomotion.src.utils import torch_utils
from holomotion.src.motion_retargeting.utils.torch_humanoid_batch import (
HumanoidBatch,
)
from holomotion.src.motion_retargeting.utils import (
rotation_conversions as rot_conv,
)
from holomotion.src.motion_retargeting.holomotion_preprocess import (
HoloMotionPreprocessor,
ProcessedClip,
)
import ray
import logging
def quaternion_to_axis_angle(q: torch.Tensor) -> torch.Tensor:
    """Convert quaternions stored as ``(x, y, z, w)`` into axis-angle vectors.

    Args:
        q: Tensor whose last dimension holds ``(x, y, z, w)``. It is normalized
            first, so it need not be unit length (but must be non-zero).

    Returns:
        Tensor of shape ``q.shape[:-1] + (3,)``: the unit rotation axis scaled
        by the rotation angle in radians. The angle is ``2 * atan2(|xyz|, w)``
        (mathematically ``2 * acos(w)``), so inputs with ``w < 0`` give angles
        in ``(pi, 2*pi]``. The exception is near-identity inputs with ``w < 0``:
        they return the equivalent small-angle vector.
    """
    q = q / torch.norm(q, dim=-1, keepdim=True)
    xyz = q[..., :3]
    w = q[..., 3]
    # |xyz| = sin(angle / 2). atan2 keeps precision for tiny rotations, where
    # acos(w) collapses to 0 as soon as w rounds to 1.0 in float32.
    sin_half = torch.norm(xyz, dim=-1)
    angle = 2.0 * torch.atan2(sin_half, w)
    small = sin_half < 1e-6
    # Guarded denominators: the branch torch.where does not select must not
    # divide by zero either (it would produce inf/NaN gradients).
    safe_sin = torch.where(small, torch.ones_like(sin_half), sin_half)
    safe_w = torch.where(small, w, torch.ones_like(w))
    # First-order expansion for tiny |xyz|: angle / sin_half -> 2 / w.
    scale = torch.where(small, 2.0 / safe_w, angle / safe_sin)
    return xyz * scale.unsqueeze(-1)
def dof_to_pose_aa(
    dof_pos: np.ndarray,
    robot_config_path: Optional[str],
    root_rot: Optional[np.ndarray],
) -> np.ndarray:
    """Compute pose_aa via FK; if no config is provided, return zero placeholders.

    Args:
        dof_pos: Joint positions, ``(T, J)`` or ``(T, J, 1)``.
        robot_config_path: Config file loaded with OmegaConf (its ``robot`` node
            is used). Falsy values skip FK and return zeros of shape
            ``(T, 27, 3)``.
        root_rot: Optional root orientation: ``(x, y, z, w)`` quaternions when
            the last dim is 4, otherwise used directly as axis-angle. ``None``
            means zero root rotation.

    Returns:
        float32 array ``(T, 1 + J + num_extend, 3)``: root axis-angle, per-joint
        ``dof_axis * dof``, then zeros for the extended bodies.
    """
    if not robot_config_path:
        return np.zeros((dof_pos.shape[0], 27, 3), dtype=np.float32)

    robot_cfg = OmegaConf.load(robot_config_path)
    logger.info(f"Loaded robot config for FK from: {robot_config_path}")
    fk = HumanoidBatch(robot_cfg.robot)
    num_aug = len(robot_cfg.robot.extend_config)

    dofs = torch.as_tensor(dof_pos, dtype=torch.float32)
    if dofs.dim() == 3 and dofs.shape[-1] == 1:
        dofs = dofs.squeeze(-1)
    num_frames = dofs.shape[0]

    if root_rot is None:
        root_aa = torch.zeros((num_frames, 3), dtype=torch.float32)
    else:
        root = torch.as_tensor(root_rot, dtype=torch.float32)
        root_aa = root if root.shape[-1] != 4 else quaternion_to_axis_angle(root)

    joint_aa = fk.dof_axis * dofs[..., None]
    pieces = [root_aa.unsqueeze(1), joint_aa, torch.zeros((num_frames, num_aug, 3))]
    return torch.cat(pieces, dim=1).numpy().astype(np.float32, copy=False)
def load_any_pkl(p: Path):
    """Load and return the object stored in the file at ``p`` via ``joblib.load``.

    Security note: ``joblib.load`` unpickles, which can execute arbitrary code;
    only use it on trusted files.
    """
    with open(p, mode="rb") as handle:
        payload = joblib.load(handle)
    return payload
def unwrap_source(obj) -> Dict[str, np.ndarray]:
    """Accept {top_key: inner} or flat dict (early GMR).

    A dict with exactly one entry whose value is itself a dict is unwrapped to
    that inner dict. Any other dict is returned unchanged.

    Raises:
        ValueError: If ``obj`` is not a dict.
    """
    if not isinstance(obj, dict):
        raise ValueError("Unsupported PKL structure")
    if len(obj) == 1:
        (sole_value,) = obj.values()
        if isinstance(sole_value, dict):
            return sole_value
    return obj
def make_motion_key(p: Path, src_dir: Path) -> str:
    """Return ``p`` relative to ``src_dir``, last suffix removed, joined with ``/``.

    Example: ``/data/a/b/c.pkl`` under ``/data`` -> ``"a/b/c"``.
    """
    stem_path = p.relative_to(src_dir).with_suffix("")
    return "/".join(part for part in stem_path.parts)
def key_to_filename(key: str) -> str:
    """Flatten a ``/``-separated motion key into an ``.npz`` file name."""
    return "_".join(key.split("/")) + ".npz"
def get_ref_schema(
    ref_dir: Path,
) -> Tuple[Dict[str, Tuple[Tuple[int, ...], np.dtype]], str]:
    """
    Read the reference output schema from ref_dir/_schema.json.

    Expected JSON structure:
    {
        "schema": {
            "root_trans_offset": {"shape": [T, 3], "dtype": "float64"},
            "pose_aa": {"shape": [T, 27, 3], "dtype": "float32"},
            ...
        },
        "sample_top_key": "xxx"
    }

    Args:
        ref_dir: Directory containing ``_schema.json`` (str or Path).

    Returns:
        ``(schema, sample_top_key)``. ``schema`` maps each key to
        ``(shape_tuple, numpy_dtype)``. ``sample_top_key`` defaults to ``""``.

    Raises:
        FileNotFoundError: If ``_schema.json`` does not exist.
        ValueError: If the file cannot be read or parsed, its top level is not
            a JSON object, the ``schema`` object is missing or empty, or an
            entry has a malformed shape/dtype.
    """
    cache_path = Path(ref_dir) / "_schema.json"
    if not cache_path.exists():
        raise FileNotFoundError(f"Schema JSON not found: {cache_path}")
    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Failed to parse _schema.json: {e}") from e
    if not isinstance(obj, dict):
        raise ValueError(
            f"_schema.json must contain a JSON object, got {type(obj).__name__}"
        )

    raw = obj.get("schema", {})
    if not isinstance(raw, dict) or not raw:
        raise ValueError("Schema JSON missing 'schema' object or it's empty.")

    schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]] = {}
    for k, v in raw.items():
        if not isinstance(v, dict) or "shape" not in v or "dtype" not in v:
            raise ValueError(f"Bad schema entry for key '{k}': {v}")
        try:
            shape = tuple(int(x) for x in v["shape"])
            dtype = np.dtype(v["dtype"])
        except (TypeError, ValueError) as e:
            # e.g. non-iterable shape or an unknown dtype string
            raise ValueError(f"Bad schema entry for key '{k}': {v}") from e
        schema[k] = (shape, dtype)

    sample_top_key = obj.get("sample_top_key", "")
    return schema, sample_top_key
def infer_T(src_inner: Dict[str, np.ndarray]) -> Optional[int]:
    """Guess the number of frames in a source motion dict.

    Checks well-known keys in priority order and returns the leading length of
    the first non-empty array found. Otherwise returns the longest leading
    length among all arrays. Returns ``None`` when no array has frames.
    """

    def _frames(value) -> int:
        if isinstance(value, np.ndarray) and value.ndim >= 1:
            return int(value.shape[0])
        return 0

    preferred = (
        "root_trans_offset",
        "root_pos",
        "pose_aa",
        "dof",
        "dof_pos",
        "root_rot",
        "smpl_joints",
    )
    for name in preferred:
        count = _frames(src_inner.get(name))
        if count > 0:
            return count
    longest = max((_frames(v) for v in src_inner.values()), default=0)
    return longest or None
def _extract_fps_value(
    src_inner: Dict[str, np.ndarray], candidates: List[str], default: int = 30
) -> int:
    """Return the first numeric frame rate found under ``candidates`` as an int.

    Accepts Python/NumPy ints and floats and size-1 NumPy arrays (e.g. a 0-d
    ``mocap_framerate`` array). Floats are rounded to the nearest integer.
    Booleans and non-finite values are skipped. Falls back to ``default``.
    """
    for cand in candidates:
        v = src_inner.get(cand)
        if isinstance(v, np.ndarray):
            if v.size != 1:
                continue
            v = v.item()
        if isinstance(v, (bool, np.bool_)):
            continue
        if isinstance(v, (int, np.integer)):
            return int(v)
        if isinstance(v, (float, np.floating)) and np.isfinite(v):
            return int(round(float(v)))
    return default


def _fit_trailing_dims(arr: np.ndarray, target_shape: Tuple[int, ...]) -> np.ndarray:
    """Crop or zero-pad axes 1.. of ``arr`` to match ``target_shape``.

    Only applied when ``arr`` and ``target_shape`` have the same rank >= 2.
    Otherwise ``arr`` is returned unchanged. Axis 0 (time) is never touched.
    """
    if arr.ndim != len(target_shape) or arr.ndim < 2:
        return arr
    crop = (slice(None),) + tuple(slice(0, n) for n in target_shape[1:])
    arr = arr[crop]
    pad_width = [(0, 0)] + [
        (0, want - have) for have, want in zip(arr.shape[1:], target_shape[1:])
    ]
    if any(after for _, after in pad_width):
        arr = np.pad(arr, pad_width, mode="constant", constant_values=0)
    return arr


def build_inner_from_source(
    src_inner: Dict[str, np.ndarray],
    schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]],
    T_default: int,
) -> Dict[str, object]:
    """Map a raw GMR source dict onto the keys, shapes and dtypes of ``schema``.

    For each schema key, the first non-empty array among its known alias names
    is used (e.g. ``root_pos``/``trans`` for ``root_trans_offset``). The time
    axis is cropped to ``T``, or padded by repeating the last frame, where ``T``
    is ``infer_T(src_inner)`` or ``T_default``. The remaining axes are cropped
    or zero-padded to the schema shape when the ranks match. ``dof`` keeps the
    source column count. A key with no usable source array becomes zeros.
    ``fps`` is read from ``fps``/``mocap_framerate``/``mocap_frame_rate`` (ints,
    floats or size-1 arrays) and defaults to 30.

    Args:
        src_inner: Unwrapped source dict (see ``unwrap_source``).
        schema: Key -> (shape, dtype), as returned by ``get_ref_schema``.
        T_default: Frame count used when none can be inferred from the source.

    Returns:
        Dict with one entry per schema key.
    """
    alt_map = {
        "root_trans_offset": [
            "root_trans_offset",
            "root_pos",
            "trans",
            "root_trans",
        ],
        "pose_aa": ["pose_aa"],
        "dof": ["dof", "dof_pos"],
        "root_rot": ["root_rot", "root_orient", "root_quat"],
        "smpl_joints": ["smpl_joints", "joints", "smpljoints"],
        "fps": ["fps", "mocap_framerate", "mocap_frame_rate"],
    }
    out: Dict[str, object] = {}
    T = infer_T(src_inner) or T_default
    for key, (shape, dtype) in schema.items():
        if key == "fps":
            out["fps"] = _extract_fps_value(src_inner, alt_map["fps"])
            continue

        # First non-empty alias wins; an empty array cannot be edge-padded.
        src_arr = None
        for cand in alt_map.get(key, []):
            v = src_inner.get(cand)
            if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] > 0:
                src_arr = v
                break

        # Target shape: override leading T; keep source column count for DOF
        if key == "dof" and src_arr is not None:
            target_shape = (T, src_arr.shape[1] if src_arr.ndim >= 2 else 1)
        else:
            ts = list(shape)
            if ts:
                ts[0] = T
            target_shape = tuple(ts)

        if src_arr is None:
            out[key] = np.zeros(target_shape, dtype=dtype)
            continue

        arr = src_arr.astype(dtype, copy=False)
        if key == "dof" and arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        # Time axis: crop, or repeat the last frame up to T.
        if arr.shape[0] > T:
            arr = arr[:T]
        elif arr.shape[0] < T:
            pad = np.repeat(arr[-1:], T - arr.shape[0], axis=0)
            arr = np.concatenate([arr, pad], axis=0)
        if key != "dof":
            arr = _fit_trailing_dims(arr, target_shape)
        out[key] = arr.astype(dtype, copy=False)
    return out
def to_torch(tensor):
    """Return ``tensor`` as a ``torch.Tensor``.

    Tensors are returned as-is (same object). Anything else is treated as a
    NumPy array: it is copied first, so the result does not share memory with
    the caller's array.
    """
    if not torch.is_tensor(tensor):
        tensor = torch.from_numpy(tensor.copy())
    return tensor
def batch_interpolate_tensor(
    tensor, orig_times, target_times, use_slerp=False
):
    """Resample ``tensor`` along its first (time) axis.

    Args:
        tensor: Tensor of shape ``(T, ...)`` with at least one dimension,
            sampled at ``orig_times``.
        orig_times: 1-D increasing tensor of length ``T``.
        target_times: 1-D tensor of query times.
        use_slerp: When True, 2-D slices whose second dim is 4 are blended with
            ``torch_utils.slerp`` (in chunks of 1000 rows) instead of linear
            interpolation.

    Returns:
        Tensor with the dtype/device of ``tensor`` and shape
        ``(len(target_times), ...)``. Query times at or before ``orig_times[0]``
        copy the first frame. Query times at or after ``orig_times[-1]`` copy
        the last frame. Integer tensors keep their dtype, so blended values are
        truncated on assignment.
    """
    if tensor.dim() == 1:
        # Per-frame scalar series: interpolate as a single channel. Previously
        # 1-D inputs fell through every branch and came back as all zeros.
        return batch_interpolate_tensor(
            tensor.unsqueeze(-1), orig_times, target_times, use_slerp
        ).squeeze(-1)

    target_num_frames = len(target_times)
    shape = list(tensor.shape)
    shape[0] = target_num_frames
    result = torch.zeros(shape, device=tensor.device, dtype=tensor.dtype)

    if tensor.dim() == 2:
        before_mask = target_times <= orig_times[0]
        after_mask = target_times >= orig_times[-1]
        valid_mask = ~(before_mask | after_mask)

        # Clamp query times outside the source range to the edge frames.
        if before_mask.any():
            result[before_mask] = tensor[0]
        if after_mask.any():
            result[after_mask] = tensor[-1]

        if valid_mask.any():
            valid_times = target_times[valid_mask]
            # Lower bracketing frame for each valid query time.
            indices = torch.searchsorted(orig_times, valid_times) - 1
            indices = torch.clamp(indices, 0, len(orig_times) - 2)
            next_indices = indices + 1
            alphas = (valid_times - orig_times[indices]) / (
                orig_times[next_indices] - orig_times[indices]
            )
            alphas = alphas.unsqueeze(-1)  # broadcast over channels

            if use_slerp and tensor.shape[1] == 4:  # Quaternion data
                # Row positions in ``result`` of the valid queries; computed
                # once rather than per chunk.
                valid_rows = torch.where(valid_mask)[0]
                num_valid = valid_rows.numel()
                batch_size = 1000  # bound SLERP temporaries
                for lo in range(0, num_valid, batch_size):
                    hi = min(lo + batch_size, num_valid)
                    result[valid_rows[lo:hi]] = torch_utils.slerp(
                        tensor[indices[lo:hi]],
                        tensor[next_indices[lo:hi]],
                        alphas[lo:hi],
                    )
            else:
                frames_low = tensor[indices]
                frames_high = tensor[next_indices]
                result[valid_mask] = (
                    frames_low * (1 - alphas) + frames_high * alphas
                )
    else:
        # 3-D and higher: recurse over axis 1 until each slice is 2-D.
        for j in range(tensor.shape[1]):
            result[:, j] = batch_interpolate_tensor(
                tensor[:, j], orig_times, target_times, use_slerp
            )
    return result
def fast_interpolate_motion(motion_dict, source_fps, target_fps):
    """Resample every tensor in ``motion_dict`` from ``source_fps`` to ``target_fps``.

    Args:
        motion_dict: Dict of per-frame tensors (time on axis 0) plus arbitrary
            non-tensor values.
        source_fps: Frame rate of the tensors in ``motion_dict``.
        target_fps: Desired output frame rate.

    Returns:
        ``motion_dict`` itself when the rates are equal or it holds no tensors.
        Otherwise a new dict: non-tensor values are passed through, and tensors
        are resampled with ``batch_interpolate_tensor``. Keys containing
        ``"quat"`` are resampled with ``use_slerp=True``. The frame count is
        taken from the first tensor, and the output covers the same duration
        ``(num_frames - 1) / source_fps``.
    """
    if source_fps == target_fps:
        return motion_dict

    orig_dt = 1.0 / source_fps
    target_dt = 1.0 / target_fps

    reference = next(
        (v for v in motion_dict.values() if torch.is_tensor(v)), None
    )
    if reference is None:
        return motion_dict  # No tensor data to interpolate
    num_frames = reference.shape[0]
    device = reference.device

    orig_times = torch.arange(0, num_frames, device=device) * orig_dt
    wallclock_len = orig_dt * (num_frames - 1)
    # The epsilon absorbs float round-off such as (1/49)*49 == 0.9999999999999999,
    # which would otherwise floor away the final frame of an exact-length clip.
    target_num_frames = int(wallclock_len * target_fps + 1e-6) + 1
    target_times = (
        torch.arange(0, target_num_frames, device=device) * target_dt
    )

    interp_motion = {}
    for k, v in motion_dict.items():
        if not torch.is_tensor(v):
            interp_motion[k] = v
            continue
        interp_motion[k] = batch_interpolate_tensor(
            v, orig_times, target_times, "quat" in k
        )
    return interp_motion
def _require_motion_key(motion_sample_dict, key: str) -> None:
    """Raise ``KeyError`` listing the available keys when ``key`` is missing."""
    if key not in motion_sample_dict:
        available_keys = list(motion_sample_dict.keys())
        raise KeyError(
            f"'{key}' not found in motion data. Available keys: {available_keys}"
        )


def _run_motion_step(debug_mode: bool, failure_label: str, failure_action: str, step_fn):
    """Run one step of ``process_single_motion`` with uniform error handling.

    Args:
        debug_mode: When True, exceptions from ``step_fn`` propagate unchanged.
        failure_label: Prefix of the error log line
            (``"<failure_label>: <error>"``).
        failure_action: Verb phrase for the wrapping error
            (``"Failed to <failure_action>: <error>"``).
        step_fn: Zero-argument callable that performs the step.

    Returns:
        Whatever ``step_fn`` returns.

    Raises:
        RuntimeError: In non-debug mode, chained from any exception of ``step_fn``.
    """
    if debug_mode:
        return step_fn()
    try:
        return step_fn()
    except Exception as e:
        # loguru does not recognize an ``exc_info`` keyword;
        # ``logger.exception`` is what attaches the active traceback.
        logger.exception(f"{failure_label}: {e}")
        raise RuntimeError(f"Failed to {failure_action}: {e}") from e


def process_single_motion(
    robot_cfg: dict,
    all_samples,  # Can be dict or LazyMotionLoader
    curr_key: str,
    target_fps: int = 50,
    fast_interpolate: bool = True,
    debug_mode: bool = False,
):
    """Run forward kinematics on one motion sample and package the result.

    Args:
        robot_cfg: Robot config node. Passed to ``HumanoidBatch`` and read for
            ``motion.dof_names``, ``motion.body_names`` and the optional
            ``motion.extend_config``.
        all_samples: Mapping-like container indexed by ``curr_key``. A sample
            holding a single entry is unwrapped one level.
        curr_key: Key of the sample to process; stored as ``motion_name``.
        target_fps: Output frame rate. FK outputs are resampled when it differs
            from the FK ``fps``.
        fast_interpolate: Accepted for interface compatibility; not used.
        debug_mode: When True, step failures propagate unchanged. Otherwise they
            are logged with a traceback and re-raised as ``RuntimeError``.

    Returns:
        Dict with scalar metadata (``motion_name``, ``motion_fps``,
        ``num_frames``, ``wallclock_len``, ``num_dofs``, ``num_bodies``,
        ``num_extended_bodies``) plus every FK tensor as a float32 numpy array.
        ``frame_flag`` marks frames as 0 (first), 1 (middle) or 2 (last).
    """
    logger.debug(f"Starting process_single_motion for key: {curr_key}")
    humanoid_fk = HumanoidBatch(robot_cfg)
    motion_sample_dict = all_samples[curr_key]
    if len(motion_sample_dict) == 1:
        motion_sample_dict = motion_sample_dict[
            list(motion_sample_dict.keys())[0]
        ]

    logger.debug("Step 3: Extracting sequence length")

    def _extract_seq_len():
        _require_motion_key(motion_sample_dict, "root_trans_offset")
        length = motion_sample_dict["root_trans_offset"].shape[0]
        logger.debug(f"Step 3 completed - seq_len: {length}")
        return length

    seq_len = _run_motion_step(
        debug_mode,
        "Step 3 failed - Extracting sequence length",
        "extract sequence length",
        _extract_seq_len,
    )
    start, end = 0, seq_len

    logger.debug("Step 4: Processing root translation")

    def _extract_trans():
        root_trans = to_torch(motion_sample_dict["root_trans_offset"]).clone()[
            start:end
        ]
        logger.debug(f"Step 4 completed - trans shape: {root_trans.shape}")
        return root_trans

    trans = _run_motion_step(
        debug_mode,
        "Step 4 failed - Processing root translation",
        "process root translation",
        _extract_trans,
    )

    logger.debug("Step 5: Processing pose_aa")

    def _extract_pose_aa():
        _require_motion_key(motion_sample_dict, "pose_aa")
        pose = to_torch(motion_sample_dict["pose_aa"][start:end]).clone()
        # If available, enforce root rotation from input quaternions (XYZW)
        if "root_rot" in motion_sample_dict:
            root_quat_xyzw = to_torch(
                motion_sample_dict["root_rot"][start:end]
            ).clone()
            root_quat_wxyz = rot_conv.xyzw_to_wxyz(root_quat_xyzw)
            pose[:, 0, :] = rot_conv.quaternion_to_axis_angle(root_quat_wxyz)
        logger.debug(f"Step 5 completed - pose_aa shape: {pose.shape}")
        return pose

    pose_aa = _run_motion_step(
        debug_mode,
        "Step 5 failed - Processing pose_aa",
        "process pose_aa",
        _extract_pose_aa,
    )

    logger.debug("Step 6: Calculating dt")

    def _compute_dt():
        _require_motion_key(motion_sample_dict, "fps")
        fps = motion_sample_dict["fps"]
        if fps <= 0:
            raise ValueError(f"Invalid fps value: {fps}")
        step_dt = 1 / fps
        logger.debug(f"Step 6 completed - fps: {fps}, dt: {step_dt}")
        return step_dt

    dt = _run_motion_step(
        debug_mode,
        "Step 6 failed - Calculating dt",
        "calculate dt",
        _compute_dt,
    )

    logger.debug("Step 8: Running forward kinematics")

    def _run_fk():
        motion = humanoid_fk.fk_batch(
            pose_aa[None,],
            trans[None,],
            return_full=True,
            dt=dt,
        )
        logger.debug("Step 8 completed")
        return motion

    curr_motion = _run_motion_step(
        debug_mode,
        "Step 8 failed - Forward kinematics",
        "run forward kinematics",
        _run_fk,
    )

    # squeeze() drops every size-1 axis, including the batch axis added above.
    curr_motion = {
        k: v.squeeze() if torch.is_tensor(v) else v
        for k, v in curr_motion.items()
    }
    motion_fps = curr_motion["fps"]
    num_dofs = len(robot_cfg.motion.dof_names)
    num_bodies = len(robot_cfg.motion.body_names)
    num_extended_bodies = num_bodies + len(
        robot_cfg.motion.get("extend_config", [])
    )

    # rename and pop some keys
    curr_motion["global_rotation_quat"] = curr_motion.pop("global_rotation")
    curr_motion["local_rotation_quat"] = curr_motion.pop("local_rotation")
    if "global_translation_extend" in curr_motion:
        curr_motion["global_rotation_quat_extend"] = curr_motion.pop(
            "global_rotation_extend"
        )
    curr_motion.pop("fps")
    curr_motion.pop("global_rotation_mat")
    if "global_rotation_mat_extend" in curr_motion:
        curr_motion.pop("global_rotation_mat_extend")
    # add root-only views (body index 0)
    curr_motion["global_root_translation"] = curr_motion["global_translation"][
        :, 0
    ]
    curr_motion["global_root_rotation_quat"] = curr_motion[
        "global_rotation_quat"
    ][:, 0]

    # Interpolate to target_fps if different from original fps
    if target_fps != motion_fps:
        curr_motion = fast_interpolate_motion(
            curr_motion, motion_fps, target_fps
        )
        motion_fps = target_fps
    motion_dt = 1.0 / motion_fps
    num_frames = curr_motion["global_root_translation"].shape[0]
    wallclock_len = motion_dt * (num_frames - 1)

    # frame_flag: start_of_motion 0, middle_of_motion 1, end_of_motion 2.
    # Built after resampling so the markers always sit on the final first/last
    # frames instead of being interpolated along with the motion tensors.
    frame_flag = torch.ones(num_frames).int()
    frame_flag[0] = 0
    frame_flag[-1] = 2
    curr_motion["frame_flag"] = frame_flag

    sample_dict = {
        "motion_name": curr_key,
        "motion_fps": motion_fps,
        "num_frames": num_frames,
        "wallclock_len": wallclock_len,
        "num_dofs": num_dofs,
        "num_bodies": num_bodies,
        "num_extended_bodies": num_extended_bodies,
    }
    sample_dict.update(
        {
            k: curr_motion[k].float().cpu().numpy()
            for k in sorted(curr_motion.keys())
        }
    )
    if debug_mode:
        for k, v in sample_dict.items():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                logger.debug(f"{k}: {v.shape}")
            else:
                logger.debug(f"{k}: {v}")
    return sample_dict
class InMemoryAlignedLoader:
    """Dict-backed loader offering the sample access ``process_single_motion`` uses.

    Supports ``loader[key]`` plus ``keys()``, ``load()``, ``get()`` and
    ``len()``. The wrapped mapping is stored by reference, not copied.
    """

    def __init__(self, mapping: Dict[str, Dict[str, object]]):
        self._map = mapping

    def __len__(self):
        return len(self._map)

    def __getitem__(self, k: str):
        """Return the sample stored under ``k`` (``KeyError`` if absent)."""
        return self._map[k]

    def keys(self) -> List[str]:
        """Return the sample keys as a fresh list, in insertion order."""
        return [name for name in self._map]

    def load(self, k: str):
        """Same as ``loader[k]``."""
        return self._map[k]

    def get(self, k: str, default=None):
        """Return the sample for ``k``, or ``default`` when it is absent."""
        return self._map.get(k, default)
def arrays_for_npz(
    sample: Dict, emit_prefixed: bool = True, emit_legacy: bool = False
) -> Dict[str, np.ndarray]:
    """
    Select the numpy arrays of ``sample`` that go into an NPZ clip.

    - ``frame_flag`` is copied whenever it is an ndarray.
    - emit_prefixed: write each mapped array under its ``ref_*`` name.
    - emit_legacy: also write it under its original, unprefixed name.

    Non-ndarray values are ignored. ``dof_vel`` and ``dof_vels`` both map to
    ``ref_dof_vel``; if both exist, ``dof_vels`` wins.
    """
    renames = (
        ("dof_pos", "ref_dof_pos"),
        ("dof_vel", "ref_dof_vel"),
        ("dof_vels", "ref_dof_vel"),
        ("global_translation", "ref_global_translation"),
        ("global_rotation_quat", "ref_global_rotation_quat"),
        ("global_velocity", "ref_global_velocity"),
        ("global_angular_velocity", "ref_global_angular_velocity"),
    )
    selected: Dict[str, np.ndarray] = {}
    flags = sample.get("frame_flag")
    if isinstance(flags, np.ndarray):
        selected["frame_flag"] = flags
    for source_name, prefixed_name in renames:
        value = sample.get(source_name, None)
        if not isinstance(value, np.ndarray):
            continue
        if emit_prefixed:
            selected[prefixed_name] = value
        if emit_legacy:
            selected[source_name] = value
    return selected
@ray.remote
class MotionProcessorActor:
    """
    Persistent Ray actor that loads robot config once and processes PKLs asynchronously.

    The robot config and a ``HumanoidBatch`` instance are created in
    ``__init__`` and reused by every ``process_pkl`` call on this actor.
    """

    def __init__(
        self,
        robot_cfg_path: str,
        schema: Dict[str, Tuple[Tuple[int, ...], np.dtype]],
    ):
        """
        Args:
            robot_cfg_path: Config file loaded with OmegaConf; its ``robot``
                node is used.
            schema: Key -> (shape, dtype), as returned by ``get_ref_schema``.
        """
        cfg = OmegaConf.load(robot_cfg_path)
        self.robot_cfg = cfg.robot
        self.schema = schema
        # Cached FK holder for DOF -> axis-angle conversion (uses dof_axis)
        self._fk_for_dof = HumanoidBatch(self.robot_cfg)

    def _dof_to_pose_aa_cached(
        self, dof_pos: np.ndarray, root_rot: Optional[np.ndarray]
    ) -> np.ndarray:
        """Build a float32 pose_aa array from joint positions.

        Layout along axis 1: root axis-angle, ``dof_axis * dof`` per joint, then
        zeros for each ``extend_config`` body (assumes ``dof_axis`` broadcasts
        as ``(J, 3)`` against ``(T, J, 1)``; TODO confirm). ``root_rot`` with
        last dim 4 is treated as ``(x, y, z, w)`` quaternions, any other shape
        as axis-angle, and ``None`` as zero rotation.
        """
        dof_t = torch.as_tensor(dof_pos, dtype=torch.float32)
        if dof_t.dim() == 3 and dof_t.shape[-1] == 1:
            dof_t = dof_t.squeeze(-1)
        T = int(dof_t.shape[0])
        if root_rot is None:
            root_aa = torch.zeros((T, 3), dtype=torch.float32)
        else:
            rr = torch.as_tensor(root_rot, dtype=torch.float32)
            root_aa = quaternion_to_axis_angle(rr) if rr.shape[-1] == 4 else rr
        num_aug = len(self.robot_cfg.extend_config)
        joint_aa = self._fk_for_dof.dof_axis * dof_t[:, :, None]
        pose_aa = torch.cat(
            [root_aa[:, None, :], joint_aa, torch.zeros((T, num_aug, 3))],
            dim=1,
        )
        return pose_aa.numpy().astype(np.float32, copy=False)

    def _convert_pkl(
        self,
        p_str: str,
        src_dir_str: str,
        target_fps: int,
        fast_interpolate: bool,
        debug_mode: bool,
    ) -> Dict[str, object]:
        """Load, align and FK-process one PKL. Returns ``{"flat_key", "sample"}``."""
        p = Path(p_str)
        src_dir = Path(src_dir_str)
        motion_key_rel = make_motion_key(p, src_dir)
        flat_key = motion_key_rel.replace("/", "_")
        obj = load_any_pkl(p)
        inner = unwrap_source(obj)
        T_default = infer_T(inner) or 1
        aligned = build_inner_from_source(inner, self.schema, T_default)
        # When joint positions exist, rebuild pose_aa from them instead of
        # trusting (or zero-filling) the source pose_aa.
        dof = aligned.get("dof")
        if isinstance(dof, np.ndarray) and dof.size > 0:
            root_rot = aligned.get("root_rot", None)
            aligned["pose_aa"] = self._dof_to_pose_aa_cached(dof, root_rot)
        loader = InMemoryAlignedLoader({flat_key: aligned})
        sample = process_single_motion(
            self.robot_cfg,
            loader,
            flat_key,
            int(target_fps),
            bool(fast_interpolate),
            bool(debug_mode),
        )
        return {"flat_key": flat_key, "sample": sample}

    def process_pkl(
        self,
        p_str: str,
        src_dir_str: str,
        target_fps: int,
        fast_interpolate: bool,
        debug_mode: bool,
    ) -> Tuple[bool, Dict[str, object]]:
        """
        Returns (success, payload). On success, payload contains:
            { "flat_key": str, "sample": Dict[str, np.ndarray|scalar] }
        On failure (debug_mode False), payload is
            { "file": p_str, "error": "<ExceptionType>: <message>" }
        so the driver can report it and keep processing the remaining files.
        With debug_mode True the exception propagates instead.
        """
        try:
            payload = self._convert_pkl(
                p_str, src_dir_str, target_fps, fast_interpolate, debug_mode
            )
        except Exception as e:
            if debug_mode:
                raise
            return False, {"file": p_str, "error": f"{type(e).__name__}: {e}"}
        return True, payload
def _collect_source_pkls(src_path: Path) -> Tuple[List[Path], Path]:
    """Return ``(sorted .pkl paths, root used to derive motion keys)``.

    A single ``.pkl`` file yields itself with its parent directory as root.
    Otherwise ``src_path`` is walked recursively (following symlinks).
    """
    if src_path.is_file() and src_path.suffix == ".pkl":
        return [src_path], src_path.parent
    found: List[Path] = []
    for dirpath, _, filenames in os.walk(src_path, followlinks=True):
        for filename in filenames:
            if filename.endswith(".pkl"):
                p = Path(dirpath) / filename
                if p.is_file():
                    found.append(p)
    return sorted(found), src_path


def _parse_preprocess_pipeline(cfg: DictConfig) -> List[str]:
    """Read ``cfg.preprocess.pipeline`` as a list of stage names ([] if unset).

    Accepts a list/ListConfig, a Python-literal string such as
    ``"['slicing', 'tagging']"``, a bare stage name, or a comma-separated string.
    """
    pipeline_cfg = cfg.get("preprocess", None)
    if pipeline_cfg is None:
        return []
    pipeline_val = pipeline_cfg.get("pipeline", None)
    if pipeline_val is None:
        return []
    if isinstance(pipeline_val, (list, tuple, ListConfig)):
        return [str(s) for s in pipeline_val]
    if isinstance(pipeline_val, str):
        import ast

        try:
            parsed = ast.literal_eval(pipeline_val)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. "slicing,filtering"): split on commas.
            parsed = [s.strip() for s in pipeline_val.split(",") if s.strip()]
        if isinstance(parsed, str):
            return [parsed]
        if isinstance(parsed, (list, tuple)):
            return [str(s) for s in parsed]
    logger.warning(
        f"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}"
    )
    return []


@hydra.main(
    config_path="../../config",
    config_name="motion_retargeting/gmr_to_holomotion",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Convert GMR PKL motions under ``cfg.io.src_dir`` into NPZ clips.

    PKLs are processed by a pool of ``MotionProcessorActor`` Ray actors.
    Results are post-processed with ``HoloMotionPreprocessor`` and written to
    ``<out_root>/clips``. Optional dataset-level tagging runs at the end.
    Failed files are logged and skipped unless ``processing.debug_mode`` is
    set, in which case the first failure aborts the run. Ray is always shut
    down on exit.
    """
    debug_mode = bool(cfg.processing.debug_mode)

    # Setup logging
    logger.remove()
    log_level = "DEBUG" if debug_mode else "INFO"
    logger.add(sys.stderr, level=log_level, colorize=True)

    src_path = Path(str(cfg.io.src_dir)).expanduser().resolve()
    ref_dir = Path(str(cfg.io.ref_dir)).expanduser().resolve()
    out_root = Path(str(cfg.io.out_root)).expanduser().resolve()
    clips_dir = out_root / "clips"
    # parents=True also creates out_root.
    clips_dir.mkdir(parents=True, exist_ok=True)

    # dump resolved config used
    with open(out_root / "config_used.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(cfg))

    # 1) schema from _schema.json
    schema, _ = get_ref_schema(ref_dir)

    # 2) gather PKLs
    src_pkls, root_for_keys = _collect_source_pkls(src_path)

    # 3) quiet third-party DEBUG logs (e.g., filelock/Ray)
    logging.getLogger("filelock").setLevel(logging.WARNING)
    logging.getLogger("ray").setLevel(logging.ERROR)
    os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", "error")

    # 4) initialize Ray
    ray_init_kwargs = dict(
        ignore_reinit_error=True,
        log_to_driver=False,
        include_dashboard=False,
        logging_level=logging.ERROR,
    )
    if str(cfg.ray.ray_address):
        ray.init(address=str(cfg.ray.ray_address), **ray_init_kwargs)
    else:
        num_cpus = (
            None if int(cfg.ray.num_workers) <= 0 else int(cfg.ray.num_workers)
        )
        ray.init(num_cpus=num_cpus, **ray_init_kwargs)

    try:
        # 5) build work list (skip existing if requested)
        skip_existing = bool(cfg.processing.skip_existing)
        work_list: List[Path] = []
        for p in src_pkls:
            out_name = key_to_filename(make_motion_key(p, root_for_keys))
            if skip_existing and (clips_dir / out_name).exists():
                continue
            work_list.append(p)
        if not work_list:
            logger.info("No tasks to run (all outputs exist or no PKLs found).")
            return

        # 6) create persistent actors (each loads robot config once)
        if int(cfg.ray.num_workers) > 0:
            num_actors = min(len(work_list), int(cfg.ray.num_workers))
        else:
            available_cpus = int(ray.available_resources().get("CPU", 1))
            num_actors = min(len(work_list), max(1, available_cpus))
        actors = [
            MotionProcessorActor.remote(str(cfg.io.robot_config), schema)
            for _ in range(num_actors)
        ]

        # Separate per-clip stages from dataset-level stages
        pipeline = _parse_preprocess_pipeline(cfg)
        per_clip_pipeline = [s for s in pipeline if s != "tagging"]
        tagging_enabled = "tagging" in pipeline

        logger.info("=" * 80)
        logger.info("Preprocessing Configuration:")
        if pipeline:
            logger.info(f"  Pipeline stages: {pipeline}")
            logger.info(f"  Number of stages: {len(pipeline)}")
            for i, stage in enumerate(pipeline, 1):
                logger.info(f"    {i}. {stage}")
            if tagging_enabled:
                logger.info(
                    "  Note: 'tagging' is a dataset-level operation and will run after all clips are processed"
                )
        else:
            logger.info(
                "  No preprocessing pipeline specified - no processors will be applied"
            )
        logger.info("=" * 80)

        preprocessor = HoloMotionPreprocessor(
            slicing_cfg=cfg.slicing,
            filtering_cfg=cfg.filtering,
            tagging_cfg=cfg.tagging,
            padding_cfg=cfg.get("padding", None),
            pipeline=per_clip_pipeline if per_clip_pipeline else None,
        )

        # 7) asynchronously schedule PKLs to actors (round-robin)
        target_fps = int(cfg.processing.target_fps)
        fast_interpolate = bool(cfg.processing.fast_interpolate)
        pending = {}
        next_idx = 0

        def _submit(actor_idx: int) -> None:
            # Hand the next unscheduled PKL to ``actors[actor_idx]``.
            nonlocal next_idx
            p = work_list[next_idx]
            next_idx += 1
            ref = actors[actor_idx].process_pkl.remote(
                str(p),
                str(root_for_keys),
                target_fps,
                fast_interpolate,
                debug_mode,
            )
            pending[ref] = actor_idx

        # prime the queue (num_actors <= len(work_list))
        for i in range(num_actors):
            _submit(i)

        # 8) collect results and keep feeding new tasks (post-process in-memory, then write)
        total_outputs = 0
        with tqdm(total=len(work_list), desc="Ray: PKL→NPZ (Hydra)") as pbar:
            while pending:
                done, _ = ray.wait(list(pending.keys()), num_returns=1)
                ref = done[0]
                actor_idx = pending.pop(ref)
                try:
                    ok, payload = ray.get(ref)
                except Exception as e:
                    # Errors raised inside the actor (or actor death) surface
                    # here; skip the file unless debugging.
                    if debug_mode:
                        raise
                    ok, payload = False, f"{type(e).__name__}: {e}"
                if ok:
                    flat_key: str = payload["flat_key"]  # type: ignore[assignment]
                    sample: Dict = payload["sample"]  # type: ignore[assignment]
                    arrays_ref = arrays_for_npz(
                        sample,
                        emit_prefixed=bool(cfg.naming.emit_prefixed),
                        emit_legacy=bool(cfg.naming.emit_legacy),
                    )
                    base_meta = {
                        "motion_key": flat_key,
                        "raw_motion_key": flat_key,
                        "motion_fps": float(sample["motion_fps"]),
                        "num_frames": int(sample["num_frames"]),
                        "wallclock_len": float(sample["wallclock_len"]),
                        "num_dofs": int(sample["num_dofs"]),
                        "num_bodies": int(sample["num_bodies"]),
                        "num_extended_bodies": int(sample["num_extended_bodies"]),
                        "slice_start": 0,
                        "slice_end": int(sample["num_frames"]),
                    }
                    base_clip = ProcessedClip(
                        motion_key=flat_key,
                        metadata=base_meta,
                        arrays=arrays_ref,
                    )
                    for clip in preprocessor.process_clip(base_clip):
                        out_path = clips_dir / f"{clip.motion_key}.npz"
                        np.savez_compressed(
                            out_path,
                            metadata=json.dumps(clip.metadata),
                            **clip.arrays,
                        )
                        total_outputs += 1
                else:
                    logger.warning(f"Processing failed: {payload}")
                pbar.update(1)
                if next_idx < len(work_list):
                    _submit(actor_idx)

        # 9) Optional kinematic tagging (write to out_root level)
        if tagging_enabled:
            tags_path = (
                Path(str(cfg.tagging.output_json_path)).expanduser().resolve()
                if str(cfg.tagging.output_json_path)
                else (out_root / "kinematic_tags.json")
            )
            preprocessor.tag_directory(clips_dir, tags_path)
        logger.info(
            f"Done. NPZ written to: {clips_dir} (total clips: {total_outputs})"
        )
    finally:
        ray.shutdown()
if __name__ == "__main__":
    # Script entry point: called with no arguments; the @hydra.main decorator
    # on main() supplies ``cfg``.
    main()
================================================
FILE: holomotion/src/motion_retargeting/holomotion_fk.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import os
import xml.etree.ElementTree as ETree
from typing import Dict, List, Tuple
import torch
import pytorch_kinematics as pk
from loguru import logger
from holomotion.src.utils import torch_utils
class MJCFParser:
def __init__(self, robot_file_path: str) -> None:
self._robot_file_path = robot_file_path
@staticmethod
def parse_vec(
text: str | None, size: int, default: List[float]
) -> List[float]:
if text is None:
return list(default)
values = [float(v) for v in text.strip().split()]
if len(values) != size:
raise ValueError(
f"Expected {size} values, got {len(values)} in '{text}'"
)
return values
@staticmethod
def _find_parent(
root: ETree.Element, child: ETree.Element
) -> ETree.Element | None:
for parent in root.iter():
for node in list(parent):
if node is child:
return parent
return None
@staticmethod
def _select_include_children(
parent: ETree.Element, inc_root: ETree.Element
) -> List[ETree.Element]:
if inc_root.tag == "mujoco":
if parent.tag != "mujoco":
sub = inc_root.find(parent.tag)
if sub is not None:
return list(sub)
return list(inc_root)
if inc_root.tag == parent.tag:
return list(inc_root)
return list(inc_root)
@classmethod
def _resolve_includes(cls, root: ETree.Element, base_dir: str) -> None:
includes = root.findall(".//include")
while includes:
for inc in includes:
inc_file = inc.attrib.get("file")
if inc_file is None:
raise ValueError("Include tag missing 'file' attribute")
inc_path = os.path.join(base_dir, inc_file)
inc_root = ETree.parse(inc_path).getroot()
cls._resolve_includes(inc_root, os.path.dirname(inc_path))
parent = cls._find_parent(root, inc)
if parent is None:
raise ValueError("Failed to resolve include parent")
insert_children = cls._select_include_children(
parent, inc_root
)
insert_index = list(parent).index(inc)
for child in list(insert_children):
parent.insert(insert_index, child)
insert_index += 1
parent.remove(inc)
includes = root.findall(".//include")
def load_root(self) -> ETree.Element:
root = ETree.parse(self._robot_file_path).getroot()
self._resolve_includes(root, os.path.dirname(self._robot_file_path))
return root
def parse(
self,
) -> Tuple[
List[str],
torch.Tensor,
torch.Tensor,
torch.Tensor,
List[List[str]],
Dict[str, int],
Dict[str, List[float]],
List[str],
torch.Tensor,
List[List[int]],
]:
root = self.load_root()
xml_world = root.find("worldbody")
if xml_world is None:
raise ValueError("MJCF missing worldbody")
xml_body_root = xml_world.find("body")
if xml_body_root is None:
raise ValueError("MJCF missing root body")
body_names: List[str] = []
parents: List[int] = []
local_translation: List[List[float]] = []
local_rotation: List[List[float]] = []
body_joint_order: List[List[str]] = []
joint_body_index: Dict[str, int] = {}
joint_axis: Dict[str, List[float]] = {}
def _add_body(xml_body: ETree.Element, parent_index: int) -> None:
body_idx = len(body_names)
body_names.append(xml_body.attrib.get("name", ""))
parents.append(parent_index)
local_translation.append(
self.parse_vec(xml_body.attrib.get("pos"), 3, [0.0, 0.0, 0.0])
)
local_rotation.append(
self.parse_vec(
xml_body.attrib.get("quat"), 4, [1.0, 0.0, 0.0, 0.0]
)
)
joints_in_body: List[str] = []
for joint in xml_body.findall("joint"):
joint_name = joint.attrib.get("name")
if joint_name is None:
raise ValueError("Joint missing name")
joint_type = joint.attrib.get("type", "hinge")
if joint_type == "free":
continue
if joint_type != "hinge":
raise ValueError(f"Unsupported joint type: {joint_type}")
axis = self.parse_vec(
joint.attrib.get("axis"), 3, [0.0, 0.0, 1.0]
)
joint_body_index[joint_name] = body_idx
joint_axis[joint_name] = axis
joints_in_body.append(joint_name)
body_joint_order.append(joints_in_body)
for child in xml_body.findall("body"):
_add_body(child, body_idx)
_add_body(xml_body_root, -1)
if local_translation:
local_translation[0] = [0.0, 0.0, 0.0]
local_rotation[0] = [1.0, 0.0, 0.0, 0.0]
dof_names: List[str] = []
for elem in root.iter():
if elem.tag == "actuator":
for child in list(elem):
joint_name = child.attrib.get("joint")
if joint_name is not None:
dof_names.append(joint_name)
if len(dof_names) == 0:
raise ValueError("No actuated joints found in MJCF")
dof_axis: List[List[float]] = []
for joint_name in dof_names:
if joint_name not in joint_body_index:
raise ValueError(f"Actuator joint not found: {joint_name}")
dof_axis.append(joint_axis[joint_name])
dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)}
body_dof_indices: List[List[int]] = []
for joints in body_joint_order:
indices: List[int] = []
for name in joints:
if name in dof_name_to_index:
indices.append(dof_name_to_index[name])
body_dof_indices.append(indices)
return (
body_names,
torch.tensor(parents, dtype=torch.long),
torch.tensor(local_translation, dtype=torch.float32),
torch.tensor(local_rotation, dtype=torch.float32),
body_joint_order,
joint_body_index,
joint_axis,
dof_names,
torch.tensor(dof_axis, dtype=torch.float32),
body_dof_indices,
)
class URDFParser:
    """Extract the kinematic tables used by FK from a URDF via pytorch_kinematics.

    Bodies are the chain root plus every frame that ``get_frame_names``
    reports when fixed joints are excluded. Frames attached through fixed
    joints are not bodies; their rest transforms are folded into the local
    transform of the next body below them.
    """

    def __init__(self, urdf_path: str) -> None:
        """Remember the URDF path; the file is only read in ``_load_chain``."""
        self._urdf_path = urdf_path

    @staticmethod
    def _as_tf(
        tf: torch.Tensor | None, identity: torch.Tensor
    ) -> torch.Tensor:
        """Normalize an optional offset transform to a single matrix.

        Returns ``identity`` when ``tf`` is None. When ``tf`` is 3-D, only
        its first entry along the leading (presumably batch) axis is kept.
        """
        if tf is None:
            return identity
        if tf.ndim == 3:
            return tf[0]
        return tf

    def _load_chain(self) -> pk.Chain:
        """Read the URDF text (UTF-8) and build a pytorch_kinematics chain."""
        with open(self._urdf_path, mode="r", encoding="utf-8") as f:
            urdf_text = f.read()
        return pk.build_chain_from_urdf(urdf_text)

    def parse(
        self,
    ) -> Tuple[
        List[str],
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        List[List[str]],
        Dict[str, int],
        Dict[str, List[float]],
        List[str],
        torch.Tensor,
        List[List[int]],
    ]:
        """Build body/joint tables from the URDF chain.

        Returns:
            A 10-tuple with the same layout as ``MJCFParser.parse``:
            body_names, parents (-1 for the root), local_translation
            ``(num_bodies, 3)``, local_rotation (quaternions produced by
            ``torch_utils.quat_from_matrix``, presumably wxyz like the MJCF
            path -- TODO confirm), body_joint_order, joint_body_index,
            joint_axis, dof_names (the chain's joint parameter names),
            dof_axis (``pk_chain.axes`` as float32) and body_dof_indices.

        Raises:
            ValueError: If the chain has no joint parameters, or a body's
                joint type index is not 1 (revolute; prismatic and other
                types are unsupported).
        """
        pk_chain = self._load_chain()
        dof_names = pk_chain.get_joint_parameter_names()
        if len(dof_names) == 0:
            raise ValueError("No actuated joints found in URDF")
        dof_axis = pk_chain.axes.to(dtype=torch.float32)
        root_name = pk_chain._root.name
        moving_frames = pk_chain.get_frame_names(exclude_fixed=True)
        # Root first, then every frame driven by a non-fixed joint.
        body_names = [root_name] + [
            name for name in moving_frames if name != root_name
        ]
        body_name_to_index = {name: idx for idx, name in enumerate(body_names)}
        num_frames = len(pk_chain.idx_to_frame)
        frame_name_to_index = {
            name: idx for idx, name in pk_chain.idx_to_frame.items()
        }
        # Parent frame index of every frame (bodies and fixed frames alike).
        # parents_indices[i] appears to list the ancestor chain ending at i,
        # so [-2] is the direct parent -- assumption about pk internals.
        full_parent_indices: List[int] = []
        for i in range(num_frames):
            chain_indices = pk_chain.parents_indices[i]
            if chain_indices.numel() <= 1:
                full_parent_indices.append(-1)
            else:
                full_parent_indices.append(int(chain_indices[-2].item()))
        identity = torch.eye(4, dtype=torch.float32)
        # Rest-pose transform of each frame relative to frame 0: only the
        # static link/joint offsets are composed, no joint motion is applied.
        # Assumes frames are indexed so a parent precedes its children.
        frame_transforms: List[torch.Tensor] = [identity] * num_frames
        for i in range(num_frames):
            link_offset = self._as_tf(pk_chain.link_offsets[i], identity)
            joint_offset = self._as_tf(pk_chain.joint_offsets[i], identity)
            if i == 0:
                link_offset = identity
                joint_offset = identity
            parent = full_parent_indices[i]
            if parent < 0:
                frame_tf = identity
            else:
                frame_tf = (
                    frame_transforms[parent] @ link_offset @ joint_offset
                )
            frame_transforms[i] = frame_tf
        parents: List[int] = []
        local_translation: List[List[float]] = []
        local_rotation_mat: List[torch.Tensor] = []
        body_joint_order: List[List[str]] = []
        joint_body_index: Dict[str, int] = {}
        joint_axis: Dict[str, List[float]] = {}
        for body_name in body_names:
            frame_idx = frame_name_to_index[body_name]
            # Walk up through fixed frames until the nearest ancestor body.
            parent_frame_idx = full_parent_indices[frame_idx]
            parent_body_idx = -1
            while parent_frame_idx >= 0:
                parent_name = pk_chain.idx_to_frame[parent_frame_idx]
                if parent_name in body_name_to_index:
                    parent_body_idx = body_name_to_index[parent_name]
                    break
                parent_frame_idx = full_parent_indices[parent_frame_idx]
            parents.append(parent_body_idx)
            # Body transform relative to its ancestor body; this absorbs
            # every fixed-joint offset in between.
            if parent_body_idx < 0:
                local_tf = identity
            else:
                local_tf = (
                    torch.linalg.inv(frame_transforms[parent_frame_idx])
                    @ frame_transforms[frame_idx]
                )
            local_translation.append(local_tf[:3, 3].tolist())
            local_rotation_mat.append(local_tf[:3, :3])
            joints_in_body: List[str] = []
            # A negative joint index means this frame has no joint parameter.
            joint_index = int(pk_chain.joint_indices[frame_idx].item())
            if joint_index >= 0:
                joint_type = int(pk_chain.joint_type_indices[frame_idx].item())
                if joint_type != 1:
                    raise ValueError(
                        f"Unsupported joint type index: {joint_type}"
                    )
                joint_name = dof_names[joint_index]
                joints_in_body.append(joint_name)
                joint_body_index[joint_name] = body_name_to_index[body_name]
                joint_axis[joint_name] = dof_axis[joint_index].tolist()
            body_joint_order.append(joints_in_body)
        local_rotation = torch_utils.quat_from_matrix(
            torch.stack(local_rotation_mat, dim=0)
        )
        dof_name_to_index = {name: idx for idx, name in enumerate(dof_names)}
        body_dof_indices: List[List[int]] = []
        for joints in body_joint_order:
            indices: List[int] = []
            for name in joints:
                if name in dof_name_to_index:
                    indices.append(dof_name_to_index[name])
            body_dof_indices.append(indices)
        return (
            body_names,
            torch.tensor(parents, dtype=torch.long),
            torch.tensor(local_translation, dtype=torch.float32),
            local_rotation.to(dtype=torch.float32),
            body_joint_order,
            joint_body_index,
            joint_axis,
            dof_names,
            dof_axis,
            body_dof_indices,
        )
# @torch.compile(dynamic=True)
class HoloMotionFK(torch.nn.Module):
    """Batched forward kinematics with smoothed finite-difference velocities.

    The kinematic tree is parsed once from a URDF (``.urdf``) or MJCF
    (``.xml``/``.mjcf``) file and stored as buffers. ``forward`` then maps
    root poses and joint angles shaped ``(B, T, ...)`` to world-frame body
    positions/orientations plus linear, angular and joint velocities.
    """

    def __init__(
        self,
        robot_file_path: str,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Parse ``robot_file_path`` and register the kinematic buffers.

        Args:
            robot_file_path: Path to a ``.urdf``, ``.xml`` or ``.mjcf`` file.
            device: Device on which the buffers are created.
            dtype: Floating-point dtype of the buffers and of all outputs.

        Raises:
            ValueError: For an unsupported file extension (errors raised by
                the parser propagate unchanged).
        """
        super().__init__()
        self.robot_file_path = robot_file_path
        _, ext = os.path.splitext(robot_file_path)
        ext = ext.lower()
        if ext == ".urdf":
            parser = URDFParser(robot_file_path)
        elif ext in [".xml", ".mjcf"]:
            parser = MJCFParser(robot_file_path)
        else:
            raise ValueError(f"Unsupported file extension: {ext}")
        logger.info(
            f"Parsing robot file for online forward kinematics: {robot_file_path}..."
        )
        (
            body_names,
            parents,
            local_translation,
            local_rotation,
            body_joint_order,
            joint_body_index,
            joint_axis,
            dof_names,
            dof_axis,
            body_dof_indices,
        ) = parser.parse()
        self.body_names = body_names
        self.dof_names = dof_names
        self.num_bodies = len(body_names)
        self.num_dof = len(dof_names)
        # Host-side copy of the static parent table. The per-body FK loop
        # reads it directly instead of calling ``.item()`` on a (possibly
        # CUDA) tensor, which would force a device sync for every body.
        self._parents_list: List[int] = [int(p) for p in parents.tolist()]
        parents = parents.to(device=device)
        local_translation = local_translation.to(device=device, dtype=dtype)
        local_rotation = local_rotation.to(device=device, dtype=dtype)
        local_rotation_mat = torch_utils.matrix_from_quat(local_rotation)
        dof_axis = dof_axis.to(device=device, dtype=dtype)
        # Padded (num_bodies, max_body_dofs) DOF index table plus validity
        # mask; registered as buffers but not read by forward().
        max_body_dofs = max(
            (len(indices) for indices in body_dof_indices), default=0
        )
        body_dof_index_tensor = torch.full(
            (self.num_bodies, max_body_dofs),
            -1,
            dtype=torch.long,
        )
        body_dof_mask = torch.zeros(
            (self.num_bodies, max_body_dofs), dtype=torch.bool
        )
        for body_idx, indices in enumerate(body_dof_indices):
            if not indices:
                continue
            body_dof_index_tensor[body_idx, : len(indices)] = torch.tensor(
                indices, dtype=torch.long
            )
            body_dof_mask[body_idx, : len(indices)] = True
        self.register_buffer("_parents", parents)
        self.register_buffer("_local_translation", local_translation)
        self.register_buffer("_local_rotation_mat", local_rotation_mat)
        self.register_buffer("_dof_axis", dof_axis)
        self.register_buffer("_body_dof_index_tensor", body_dof_index_tensor)
        self.register_buffer("_body_dof_mask", body_dof_mask)
        self._body_joint_order = body_joint_order
        self._joint_body_index = joint_body_index
        self._joint_axis = joint_axis
        self._body_dof_indices = body_dof_indices

    @torch.no_grad()
    def forward(
        self,
        root_pos: torch.Tensor,
        root_quat: torch.Tensor,
        dof_pos: torch.Tensor,
        fps: float,
        quat_format: str = "xyzw",
        sub_batch_size: int = 64,
        vel_smoothing_sigma: float = 2.0,
    ) -> Dict[str, torch.Tensor]:
        """Forward kinematics and smoothed velocities.

        Args:
            root_pos: (B, T, 3)
            root_quat: (B, T, 4), XYZW by default
            dof_pos: (B, T, ndof)
            fps: frames per second
            quat_format: "xyzw" or "wxyz" layout of ``root_quat``.
            sub_batch_size: split batch into chunks to reduce peak memory
                (None or <= 0 disables chunking).
            vel_smoothing_sigma: Gaussian sigma for smoothing velocity signals
                along the time axis (set <= 0 to disable).

        Returns:
            Dict with global_translation/global_rotation_quat (XYZW)/
            global_velocity/global_angular_velocity/dof_pos/dof_vel, all on
            the module's device and dtype.

        Raises:
            ValueError: On non-positive fps, mismatched or malformed input
                shapes, or an unsupported ``quat_format``.
        """
        if fps <= 0.0:
            raise ValueError(f"Invalid fps: {fps}")
        # Validate before any output buffers are allocated.
        if quat_format not in ("xyzw", "wxyz"):
            raise ValueError(f"Unsupported quat_format: {quat_format}")
        if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:
            raise ValueError("Inputs must be (B, T, ...)")
        if (
            root_pos.shape[:2] != root_quat.shape[:2]
            or root_pos.shape[:2] != dof_pos.shape[:2]
        ):
            raise ValueError("Mismatched batch/time shapes among inputs")
        if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:
            raise ValueError(
                "root_pos must be (B,T,3) and root_quat must be (B,T,4)"
            )
        if dof_pos.shape[-1] != self.num_dof:
            raise ValueError(
                f"dof_pos last dim {dof_pos.shape[-1]} does not match "
                f"{self.num_dof}"
            )
        device = self._local_translation.device
        dtype = self._local_translation.dtype
        root_pos = root_pos.to(device=device, dtype=dtype)
        root_quat = root_quat.to(device=device, dtype=dtype)
        dof_pos = dof_pos.to(device=device, dtype=dtype)
        batch_size, seq_len = root_pos.shape[:2]
        if (
            sub_batch_size is None
            or sub_batch_size <= 0
            or sub_batch_size >= batch_size
        ):
            return self._forward_impl(
                root_pos=root_pos,
                root_quat=root_quat,
                dof_pos=dof_pos,
                fps=fps,
                quat_format=quat_format,
                vel_smoothing_sigma=float(vel_smoothing_sigma),
            )
        global_translation = torch.empty(
            (batch_size, seq_len, self.num_bodies, 3),
            device=device,
            dtype=dtype,
        )
        global_rotation_quat = torch.empty(
            (batch_size, seq_len, self.num_bodies, 4),
            device=device,
            dtype=dtype,
        )
        global_velocity = torch.empty_like(global_translation)
        global_angular_velocity = torch.empty_like(global_translation)
        dof_pos_out = torch.empty_like(dof_pos)
        dof_vel = torch.empty_like(dof_pos)
        for start in range(0, batch_size, sub_batch_size):
            end = min(start + sub_batch_size, batch_size)
            out = self._forward_impl(
                root_pos=root_pos[start:end],
                root_quat=root_quat[start:end],
                dof_pos=dof_pos[start:end],
                fps=fps,
                quat_format=quat_format,
                vel_smoothing_sigma=float(vel_smoothing_sigma),
            )
            global_translation[start:end] = out["global_translation"]
            global_rotation_quat[start:end] = out["global_rotation_quat"]
            global_velocity[start:end] = out["global_velocity"]
            global_angular_velocity[start:end] = out["global_angular_velocity"]
            dof_pos_out[start:end] = out["dof_pos"]
            dof_vel[start:end] = out["dof_vel"]
        return {
            "global_translation": global_translation,
            "global_rotation_quat": global_rotation_quat,
            "global_velocity": global_velocity,
            "global_angular_velocity": global_angular_velocity,
            "dof_pos": dof_pos_out,
            "dof_vel": dof_vel,
        }

    def _forward_impl(
        self,
        root_pos: torch.Tensor,
        root_quat: torch.Tensor,
        dof_pos: torch.Tensor,
        fps: float,
        quat_format: str,
        vel_smoothing_sigma: float,
    ) -> Dict[str, torch.Tensor]:
        """Run FK and velocity estimation on one (sub-)batch.

        Inputs are assumed to already be on the module's device/dtype.
        Bodies are processed in index order, which requires every parent to
        precede its children (true for the depth-first parsers above).
        """
        device = self._local_translation.device
        dtype = self._local_translation.dtype
        if quat_format == "xyzw":
            root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat)
        elif quat_format == "wxyz":
            root_quat_wxyz = root_quat
        else:
            raise ValueError(f"Unsupported quat_format: {quat_format}")
        root_rotmat = torch_utils.matrix_from_quat(root_quat_wxyz)
        dof_rotmats = torch_utils.axis_angle_to_matrix(dof_pos, self._dof_axis)
        positions_world = torch.empty(
            (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3),
            device=device,
            dtype=dtype,
        )
        rotations_world = torch.empty(
            (dof_pos.shape[0], dof_pos.shape[1], self.num_bodies, 3, 3),
            device=device,
            dtype=dtype,
        )
        for i, parent in enumerate(self._parents_list):
            if parent < 0:
                positions_world[:, :, i] = root_pos
                rotations_world[:, :, i] = root_rotmat
                continue
            parent_pos = positions_world[:, :, parent]
            parent_rot = rotations_world[:, :, parent]
            offset = self._local_translation[i]
            pos = parent_pos + torch.einsum("btij,j->bti", parent_rot, offset)
            rot = torch.matmul(parent_rot, self._local_rotation_mat[i])
            # Joint rotations of this body are applied after its static one.
            for dof_idx in self._body_dof_indices[i]:
                rot = torch.matmul(rot, dof_rotmats[:, :, dof_idx])
            positions_world[:, :, i] = pos
            rotations_world[:, :, i] = rot
        global_translation = positions_world
        global_rotation_mat = rotations_world
        global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat)
        global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz)
        dt = 1.0 / fps
        # Forward differences; the final frame reuses the second-to-last
        # difference (or the only one when there are just two frames).
        if dof_pos.shape[1] < 2:
            dof_vel = torch.zeros_like(dof_pos)
        else:
            diff = (dof_pos[:, 1:] - dof_pos[:, :-1]) / dt
            pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:]
            dof_vel = torch.cat([diff, pad], dim=1)
            dof_vel = torch_utils.smooth_time_series(
                dof_vel, sigma=float(vel_smoothing_sigma), dim=1
            )
        global_velocity = torch_utils.grad_t(global_translation, dt)
        global_velocity = torch_utils.smooth_time_series(
            global_velocity, sigma=float(vel_smoothing_sigma), dim=1
        )
        # Angular velocity from the per-step relative rotation
        # q_{t+1} * conj(q_t); the last frame keeps an identity rotation
        # (zero angular velocity) before smoothing.
        if global_quat_xyzw.shape[1] < 2:
            global_angular_velocity = torch.zeros_like(global_translation)
        else:
            q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:])
            q0_inv = torch_utils.quat_conjugate(
                torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1])
            )
            q_rel = torch_utils.quat_mul(q1, q0_inv)
            q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True)
            q_rel = torch_utils.standardize_quaternion(q_rel)
            identity = torch.tensor(
                [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype
            )[None, None, None]
            q_rel_full = identity.expand(
                global_quat_xyzw.shape[0],
                global_quat_xyzw.shape[1],
                global_quat_xyzw.shape[2],
                4,
            ).clone()
            q_rel_full[:, :-1] = q_rel
            global_angular_velocity = (
                torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt
            )
            global_angular_velocity = torch_utils.smooth_time_series(
                global_angular_velocity,
                sigma=float(vel_smoothing_sigma),
                dim=1,
            )
        return {
            "global_translation": global_translation,
            "global_rotation_quat": global_quat_xyzw,
            "global_velocity": global_velocity,
            "global_angular_velocity": global_angular_velocity,
            "dof_pos": dof_pos,
            "dof_vel": dof_vel,
        }
# class HoloMotionFK_V2(torch.nn.Module):
# """
# Use pytorch_kinematics to compute FK.
# """
# def __init__(
# self,
# robot_file_path: str,
# device: torch.device | str = "cpu",
# dtype: torch.dtype = torch.float32,
# ) -> None:
# super().__init__()
# self.robot_file_path = robot_file_path
# urdf_path = os.path.splitext(robot_file_path)[0] + ".urdf"
# if not os.path.isfile(urdf_path):
# raise FileNotFoundError(f"URDF not found: {urdf_path}")
# with open(urdf_path, mode="r", encoding="utf-8") as f:
# urdf_text = f.read()
# pk_chain = pk.build_chain_from_urdf(urdf_text)
# pk_chain = pk_chain.to(dtype=dtype, device=device)
# self.dof_names = pk_chain.get_joint_parameter_names()
# self.num_dof = len(self.dof_names)
# root_name = pk_chain._root.name
# moving_frames = pk_chain.get_frame_names(exclude_fixed=True)
# self.body_names = [root_name] + [
# name for name in moving_frames if name != root_name
# ]
# self.num_bodies = len(self.body_names)
# body_frame_indices = pk_chain.get_frame_indices(*self.body_names)
# self.register_buffer("_body_frame_indices", body_frame_indices)
# num_frames = len(pk_chain.idx_to_frame)
# identity = torch.eye(4, device=device, dtype=dtype)
# link_offsets = []
# joint_offsets = []
# for i in range(num_frames):
# link_offset = pk_chain.link_offsets[i]
# joint_offset = pk_chain.joint_offsets[i]
# if link_offset is None:
# link_offset = identity
# if joint_offset is None:
# joint_offset = identity
# if link_offset.ndim == 3:
# link_offset = link_offset[0]
# if joint_offset.ndim == 3:
# joint_offset = joint_offset[0]
# link_offsets.append(link_offset)
# joint_offsets.append(joint_offset)
# if num_frames > 0:
# link_offsets[0] = identity
# joint_offsets[0] = identity
# parent_indices: List[int] = []
# for i in range(num_frames):
# chain_indices = pk_chain.parents_indices[i]
# if chain_indices.numel() <= 1:
# parent_indices.append(-1)
# else:
# parent_indices.append(int(chain_indices[-2].item()))
# self.register_buffer("_pk_axes", pk_chain.axes)
# self.register_buffer(
# "_pk_joint_type_indices", pk_chain.joint_type_indices
# )
# self.register_buffer("_pk_joint_indices", pk_chain.joint_indices)
# self.register_buffer(
# "_pk_link_offsets", torch.stack(link_offsets, dim=0)
# )
# self.register_buffer(
# "_pk_joint_offsets", torch.stack(joint_offsets, dim=0)
# )
# self.register_buffer(
# "_pk_parent_indices",
# torch.tensor(parent_indices, dtype=torch.long, device=device),
# )
# self._num_frames = num_frames
# def forward(
# self,
# root_pos: torch.Tensor,
# root_quat: torch.Tensor,
# dof_pos: torch.Tensor,
# fps: float,
# quat_format: str = "xyzw",
# ) -> Dict[str, torch.Tensor]:
# """
# Args:
# root_pos: (B, T, 3)
# root_quat: (B, T, 4), XYZW by default
# dof_pos: (B, T, ndof)
# fps: frames per second
# Returns:
# Dict with global_translation/global_rotation_quat/global_velocity/
# global_angular_velocity/dof_pos/dof_vel.
# """
# if fps <= 0.0:
# raise ValueError(f"Invalid fps: {fps}")
# if root_pos.ndim != 3 or root_quat.ndim != 3 or dof_pos.ndim != 3:
# raise ValueError("Inputs must be (B, T, ...)")
# if (
# root_pos.shape[:2] != root_quat.shape[:2]
# or root_pos.shape[:2] != dof_pos.shape[:2]
# ):
# raise ValueError("Mismatched batch/time shapes among inputs")
# if root_pos.shape[-1] != 3 or root_quat.shape[-1] != 4:
# raise ValueError(
# "root_pos must be (B,T,3) and root_quat must be (B,T,4)"
# )
# if dof_pos.shape[-1] != self.num_dof:
# raise ValueError(
# f"dof_pos last dim {dof_pos.shape[-1]} does not match {self.num_dof}"
# )
# device = self._pk_axes.device
# dtype = self._pk_axes.dtype
# root_pos = root_pos.to(device=device, dtype=dtype)
# root_quat = root_quat.to(device=device, dtype=dtype)
# dof_pos = dof_pos.to(device=device, dtype=dtype)
# if quat_format == "xyzw":
# root_quat_wxyz = torch_utils.xyzw_to_wxyz(root_quat)
# elif quat_format == "wxyz":
# root_quat_wxyz = root_quat
# else:
# raise ValueError(f"Unsupported quat_format: {quat_format}")
# batch_size, seq_len = root_pos.shape[:2]
# flat_size = batch_size * seq_len
# root_pos_flat = root_pos.reshape(flat_size, 3)
# root_quat_flat = root_quat_wxyz.reshape(flat_size, 4)
# dof_pos_flat = dof_pos.reshape(flat_size, self.num_dof)
# axes_expanded = self._pk_axes[None].expand(flat_size, -1, -1)
# revolute_tf = axis_and_angle_to_matrix_44(axes_expanded, dof_pos_flat)
# prismatic_tf = axis_and_d_to_pris_matrix(axes_expanded, dof_pos_flat)
# frame_transforms = torch.empty(
# (flat_size, self._num_frames, 4, 4), device=device, dtype=dtype
# )
# identity = torch.eye(4, device=device, dtype=dtype).repeat(
# flat_size, 1, 1
# )
# for i in range(self._num_frames):
# parent = int(self._pk_parent_indices[i].item())
# if parent < 0:
# frame_tf = identity
# else:
# frame_tf = frame_transforms[:, parent]
# frame_tf = frame_tf @ self._pk_link_offsets[i]
# frame_tf = frame_tf @ self._pk_joint_offsets[i]
# joint_type = int(self._pk_joint_type_indices[i].item())
# if joint_type == 1:
# joint_index = int(self._pk_joint_indices[i].item())
# frame_tf = frame_tf @ revolute_tf[:, joint_index]
# elif joint_type == 2:
# joint_index = int(self._pk_joint_indices[i].item())
# frame_tf = frame_tf @ prismatic_tf[:, joint_index]
# frame_transforms[:, i] = frame_tf
# chain_tf = torch.index_select(
# frame_transforms, 1, self._body_frame_indices
# )
# root_rotmat = torch_utils.matrix_from_quat(root_quat_flat)
# root_tf = torch.eye(4, device=device, dtype=dtype).repeat(
# flat_size, 1, 1
# )
# root_tf[:, :3, :3] = root_rotmat
# root_tf[:, :3, 3] = root_pos_flat
# world_tf = root_tf[:, None] @ chain_tf
# world_tf = world_tf.reshape(batch_size, seq_len, self.num_bodies, 4, 4)
# global_translation = world_tf[:, :, :, :3, 3]
# global_rotation_mat = world_tf[:, :, :, :3, :3]
# global_quat_wxyz = torch_utils.quat_from_matrix(global_rotation_mat)
# global_quat_xyzw = torch_utils.wxyz_to_xyzw(global_quat_wxyz)
# dt = 1.0 / fps
# if dof_pos.shape[1] < 2:
# dof_vel = torch.zeros_like(dof_pos)
# else:
# diff = (dof_pos[:, 1:] - dof_pos[:, :-1]) / dt
# pad = diff[:, -2:-1] if diff.shape[1] >= 2 else diff[:, -1:]
# dof_vel = torch.cat([diff, pad], dim=1)
# global_velocity = torch_utils.grad_t(global_translation, dt)
# global_velocity = torch_utils.gaussian_filter1d(
# global_velocity, sigma=2.0, dim=1
# )
# if global_quat_xyzw.shape[1] < 2:
# global_angular_velocity = torch.zeros_like(global_translation)
# else:
# q1 = torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, 1:])
# q0_inv = torch_utils.quat_conjugate(
# torch_utils.xyzw_to_wxyz(global_quat_xyzw[:, :-1])
# )
# q_rel = torch_utils.quat_mul(q1, q0_inv)
# q_rel = q_rel / torch.linalg.norm(q_rel, dim=-1, keepdim=True)
# q_rel = torch_utils.standardize_quaternion(q_rel)
# identity = torch.tensor(
# [1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype
# )[None, None, None]
# q_rel_full = identity.expand(
# global_quat_xyzw.shape[0],
# global_quat_xyzw.shape[1],
# global_quat_xyzw.shape[2],
# 4,
# ).clone()
# q_rel_full[:, :-1] = q_rel
# global_angular_velocity = (
# torch_utils.axis_angle_from_quat(q_rel_full, w_last=False) / dt
# )
# global_angular_velocity = torch_utils.gaussian_filter1d(
# global_angular_velocity,
# sigma=2.0,
# dim=1,
# )
# return {
# "global_translation": global_translation,
# "global_rotation_quat": global_quat_xyzw,
# "global_velocity": global_velocity,
# "global_angular_velocity": global_angular_velocity,
# "dof_pos": dof_pos,
# "dof_vel": dof_vel,
# }
================================================
FILE: holomotion/src/motion_retargeting/holomotion_preprocess.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import hydra
import numpy as np
import ray
import torch
from loguru import logger
from omegaconf import DictConfig, ListConfig, OmegaConf
from scipy.spatial.transform import Rotation as sRot
from scipy.spatial.transform import Slerp
from tqdm import tqdm
from holomotion.src.motion_retargeting.utils.torch_humanoid_batch import (
HumanoidBatch,
)
from holomotion.src.motion_retargeting.utils import (
rotation_conversions as rot_conv,
)
from holomotion.src.motion_retargeting.reference_filtering import (
butterworth_filter_ref_arrays as shared_butterworth_filter_ref_arrays,
)
def compute_slices(
    sequence_len: int, window_size: int, overlap: int
) -> List[Tuple[int, int]]:
    """Split ``[0, sequence_len)`` into windows that overlap by ``overlap``.

    Windows start every ``window_size - overlap`` frames. The last window is
    truncated at ``sequence_len`` and iteration stops once a window reaches
    the end.

    Args:
        sequence_len: Number of frames in the sequence.
        window_size: Window length in frames; must be positive.
        overlap: Frames shared by consecutive windows; must be smaller than
            ``window_size``.

    Returns:
        List of ``(start, end)`` half-open index pairs (empty when
        ``sequence_len <= 0``).

    Raises:
        ValueError: If ``window_size <= overlap`` or ``window_size <= 0``.
    """
    step = window_size - overlap
    if step <= 0:
        raise ValueError("window_size must be > overlap")
    if window_size <= 0:
        # With a negative overlap the step check above can still pass,
        # which would otherwise yield empty (start, start) slices.
        raise ValueError("window_size must be positive")
    slices: List[Tuple[int, int]] = []
    start = 0
    length = int(sequence_len)
    while start < length:
        end = min(start + window_size, length)
        slices.append((start, end))
        if end == length:
            break
        start += step
    return slices
def _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]:
shape = a.shape
t = shape[0]
return a.reshape(t, -1), shape
def _butterworth_lowpass_smooth_time(
a: np.ndarray, fps: float, cutoff_hz: float, order: int
) -> np.ndarray:
from scipy.signal import butter, filtfilt
t = a.shape[0]
if t < 3:
return a.astype(np.float32, copy=True)
if fps <= 0.0 or cutoff_hz <= 0.0:
return a.astype(np.float32, copy=True)
nyquist = 0.5 * float(fps)
wn = float(cutoff_hz) / nyquist
if wn >= 1.0:
wn = 0.999
if wn <= 0.0:
return a.astype(np.float32, copy=True)
flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False))
b, a_coefs = butter(int(order), wn, btype="low", analog=False)
maxlen = max(len(b), len(a_coefs))
padlen_required = max(3 * (maxlen - 1), 3 * maxlen)
if t <= padlen_required:
return a.astype(np.float32, copy=True)
filtered = filtfilt(b, a_coefs, flat, axis=0, method="pad")
return filtered.reshape(shape).astype(np.float32, copy=False)
def _quat_normalize(q: np.ndarray) -> np.ndarray:
norm = np.linalg.norm(q, axis=-1, keepdims=True)
norm = np.where(norm == 0.0, 1.0, norm)
return (q / norm).astype(np.float32, copy=False)
def _quat_hemisphere_align(q: np.ndarray) -> np.ndarray:
if q.shape[0] == 0:
return q
aligned = q.copy()
prev = aligned[0]
for t in range(1, aligned.shape[0]):
dots = np.sum(prev * aligned[t], axis=-1)
mask = dots < 0.0
if np.any(mask):
aligned[t, mask] = -aligned[t, mask]
prev = aligned[t]
return aligned
def _quat_conjugate(q: np.ndarray) -> np.ndarray:
conj = q.copy()
conj[..., :3] = -conj[..., :3]
return conj
def _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray:
av = a[..., :3]
aw = a[..., 3:4]
bv = b[..., :3]
bw = b[..., 3:4]
cross = np.cross(av, bv)
vec = aw * bv + bw * av + cross
scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True)
return np.concatenate([vec, scalar], axis=-1)
def _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray:
t = a.shape[0]
if t < 2 or dt <= 0.0:
return np.zeros_like(a, dtype=np.float32)
deriv = np.gradient(
a.astype(np.float64, copy=False),
dt,
axis=0,
edge_order=2 if t >= 3 else 1,
)
return deriv.astype(np.float32, copy=False)
def _angular_velocity_from_quat(
q: np.ndarray, q_dot: np.ndarray
) -> np.ndarray:
q_conj = _quat_conjugate(q)
prod = _quat_multiply(q_conj, q_dot)
omega = 2.0 * prod[..., :3]
return omega.astype(np.float32, copy=False)
def butterworth_filter_ref_arrays(
    arrays: Dict[str, np.ndarray], fps: float, cutoff_hz: float, order: int
) -> Dict[str, np.ndarray]:
    """Butterworth-filter a dict of reference arrays.

    Thin wrapper that forwards every argument unchanged to the shared
    implementation imported from
    ``holomotion.src.motion_retargeting.reference_filtering``.
    """
    forwarded = {
        "arrays": arrays,
        "fps": fps,
        "cutoff_hz": cutoff_hz,
        "order": order,
    }
    return shared_butterworth_filter_ref_arrays(**forwarded)
def _summary(arr: np.ndarray) -> Dict[str, float]:
if arr.size == 0:
return {
"mean": 0.0,
"std": 0.0,
"median": 0.0,
"min": 0.0,
"max": 0.0,
"q25": 0.0,
"q75": 0.0,
}
return {
"mean": float(arr.mean()),
"std": float(arr.std()),
"median": float(np.median(arr)),
"min": float(arr.min()),
"max": float(arr.max()),
"q25": float(np.quantile(arr, 0.25)),
"q75": float(np.quantile(arr, 0.75)),
}
def _ds_summary(arr: np.ndarray) -> Dict[str, float]:
    """Same statistics as ``_summary`` with every key prefixed by ``DS_``.

    Keys: DS_mean, DS_std, DS_median, DS_min, DS_max, DS_q25, DS_q75, in that
    order. All values are 0.0 for an empty array.
    """
    # Delegate to _summary so the two statistic sets cannot drift apart.
    return {f"DS_{name}": value for name, value in _summary(arr).items()}
def _interpolate_linear(
start: np.ndarray, end: np.ndarray, num_frames: int
) -> np.ndarray:
"""Linear interpolation between start and end over num_frames.
Returns array where result[0] == start and result[-1] == end.
"""
start = np.asarray(start, dtype=np.float32)
end = np.asarray(end, dtype=np.float32)
if num_frames <= 1:
return start[None, ...]
t = np.linspace(0.0, 1.0, num_frames, dtype=np.float32)
for _ in range(start.ndim):
t = t[..., None]
result = ((1.0 - t) * start + t * end).astype(np.float32)
result[0] = start
result[-1] = end
return result
def _interpolate_quaternions_slerp(
start_quat: np.ndarray, end_quat: np.ndarray, num_frames: int
) -> np.ndarray:
"""SLERP interpolation between two quaternions (XYZW format) over num_frames.
Args:
start_quat: shape [4] in XYZW format
end_quat: shape [4] in XYZW format
num_frames: number of interpolation frames
Returns:
shape [num_frames, 4] in XYZW format, with result[0] == start_quat
and result[-1] == end_quat.
"""
start_quat = np.asarray(start_quat, dtype=np.float32)
end_quat = np.asarray(end_quat, dtype=np.float32)
if num_frames <= 1:
return start_quat[None, ...]
rotations = sRot.from_quat([start_quat, end_quat])
slerp = Slerp([0.0, 1.0], rotations)
t = np.linspace(0.0, 1.0, num_frames)
result = slerp(t).as_quat().astype(np.float32)
result[0] = start_quat
result[-1] = end_quat
return result
def _extract_yaw_only_quat(quat: np.ndarray) -> np.ndarray:
"""Extract yaw-only quaternion (XYZW format) from a full quaternion.
Args:
quat: shape [4] in XYZW format
Returns:
shape [4] in XYZW format with only yaw rotation
"""
rot = sRot.from_quat(quat)
euler = rot.as_euler("xyz", degrees=False)
yaw_only_euler = np.array([0.0, 0.0, euler[2]])
yaw_only_rot = sRot.from_euler("xyz", yaw_only_euler, degrees=False)
return yaw_only_rot.as_quat().astype(np.float32)
def _dof_to_pose_aa(
    dof_pos: np.ndarray,
    root_rot_xyzw: np.ndarray,
    humanoid_fk: "HumanoidBatch",
    num_augment_joints: int,
) -> np.ndarray:
    """Build a per-frame axis-angle pose from joint angles and root rotation.

    Layout along axis 1: the root axis-angle first, then each DOF's axis
    scaled by its angle, then ``num_augment_joints`` zero rows.

    Args:
        dof_pos: shape [T, num_dofs]
        root_rot_xyzw: shape [T, 4] in XYZW format
        humanoid_fk: HumanoidBatch instance; its ``dof_axis`` presumably
            has shape [num_dofs, 3] (TODO confirm).
        num_augment_joints: number of augmented (zero) joints appended

    Returns:
        float32 array pose_aa of shape [T, 1 + num_dofs + num_augment_joints, 3]
    """
    dof_tensor = torch.as_tensor(dof_pos, dtype=torch.float32)
    num_frames = dof_tensor.shape[0]
    root_xyzw = torch.as_tensor(root_rot_xyzw, dtype=torch.float32)
    root_wxyz = rot_conv.xyzw_to_wxyz(root_xyzw)
    root_axis_angle = rot_conv.quaternion_to_axis_angle(root_wxyz)
    # Each revolute DOF becomes axis * angle.
    joint_axis_angle = humanoid_fk.dof_axis * dof_tensor[:, :, None]
    augment_rows = torch.zeros(
        (num_frames, num_augment_joints, 3), dtype=torch.float32
    )
    pieces = (root_axis_angle.unsqueeze(1), joint_axis_angle, augment_rows)
    pose = torch.cat(pieces, dim=1)
    return pose.numpy().astype(np.float32)
def _compute_fk_motion(
    dof_pos: np.ndarray,
    root_pos: np.ndarray,
    root_rot_xyzw: np.ndarray,
    humanoid_fk: "HumanoidBatch",
    num_augment_joints: int,
    fps: float,
) -> Dict[str, np.ndarray]:
    """Recompute all reference motion arrays from dof_pos / root pose via FK.

    Args:
        dof_pos: shape [T, num_dofs]
        root_pos: shape [T, 3]
        root_rot_xyzw: shape [T, 4] in XYZW format
        humanoid_fk: HumanoidBatch instance
        num_augment_joints: number of augmented joints
        fps: frames per second (``dt = 1 / fps`` is passed to ``fk_batch``)

    Returns:
        Dict with, in this order, ref_dof_pos, ref_dof_vel,
        ref_global_translation, ref_global_rotation_quat,
        ref_global_velocity, ref_global_angular_velocity (each the matching
        FK output with its leading batch axis removed, as float32), and
        frame_flag (int32 of length T: 0 on the first frame, 2 on the last,
        1 elsewhere).
    """
    num_frames = dof_pos.shape[0]
    dt = 1.0 / fps

    pose_aa = torch.as_tensor(
        _dof_to_pose_aa(dof_pos, root_rot_xyzw, humanoid_fk, num_augment_joints),
        dtype=torch.float32,
    )
    root_trans = torch.as_tensor(root_pos, dtype=torch.float32)
    # fk_batch is called with a leading batch axis of size 1.
    fk_result = humanoid_fk.fk_batch(
        pose_aa.unsqueeze(0),
        root_trans.unsqueeze(0),
        return_full=True,
        dt=dt,
    )

    flags = np.full(num_frames, 1, dtype=np.int32)
    flags[0] = 0
    flags[-1] = 2

    # Output key -> attribute of the FK result it is taken from.
    output_sources = (
        ("ref_dof_pos", "dof_pos"),
        ("ref_dof_vel", "dof_vels"),
        ("ref_global_translation", "global_translation"),
        ("ref_global_rotation_quat", "global_rotation"),
        ("ref_global_velocity", "global_velocity"),
        ("ref_global_angular_velocity", "global_angular_velocity"),
    )
    motion: Dict[str, np.ndarray] = {}
    for out_key, fk_attr in output_sources:
        batched = getattr(fk_result, fk_attr)
        motion[out_key] = batched.squeeze(0).numpy().astype(np.float32)
    motion["frame_flag"] = flags
    return motion
@dataclass
class ProcessedClip:
    """One motion clip as it flows through the preprocessing pipeline.

    Attributes:
        motion_key: Clip identifier; the preprocessor writes each clip to
            ``<motion_key>.npz``.
        metadata: Metadata dict; serialized with ``json.dumps`` into the
            ``metadata`` entry of the output NPZ, so it must stay
            JSON-serializable.
        arrays: Named arrays of the clip, keyed by NPZ entry name.
    """

    motion_key: str
    metadata: Dict[str, Any]
    arrays: Dict[str, np.ndarray]
class HoloMotionPreprocessor:
    """
    Composable preprocessing pipeline operating on standardized HoloMotion NPZ clips.

    Supports per-clip stages like slicing and Butterworth filtering,
    plus dataset-level kinematic tagging.

    Per-clip stages are applied by :meth:`process_clip` in ``pipeline`` order.
    Recognized stage names:

    - ``slicing`` / ``slice``: split a clip into windows (``slicing_cfg``).
    - ``apply_butterworth_filter`` / ``filtering`` / ``butterworth_filter``:
      low-pass filter the reference arrays (``filtering_cfg``).
    - ``filename_as_motionkey``: rename the clip after its source file stem.
    - ``legacy_to_ref_keys`` / ``add_legacy_keys``: copy arrays between the
      legacy unprefixed names and the ``ref_*`` names.
    - ``add_padding``: pad with stand-still and transition frames and
      recompute all states via FK (``padding_cfg``).

    Unknown stage names are logged and ignored. Dataset-level tagging is not
    a per-clip stage; it runs via :meth:`tag_directory`.
    """

    # ref_* array name -> deprecated legacy name. Both legacy-key stages use
    # this single table (one direction each), so the mappings cannot drift.
    _REF_TO_LEGACY_KEYS: Dict[str, str] = {
        "ref_dof_pos": "dof_pos",
        "ref_dof_vel": "dof_vels",
        "ref_global_translation": "global_translation",
        "ref_global_rotation_quat": "global_rotation_quat",
        "ref_global_velocity": "global_velocity",
        "ref_global_angular_velocity": "global_angular_velocity",
    }

    def __init__(
        self,
        slicing_cfg: Optional[DictConfig] = None,
        filtering_cfg: Optional[DictConfig] = None,
        tagging_cfg: Optional[DictConfig] = None,
        padding_cfg: Optional[DictConfig] = None,
        pipeline: Optional[List[str]] = None,
    ) -> None:
        self.slicing_cfg = slicing_cfg
        self.filtering_cfg = filtering_cfg
        self.tagging_cfg = tagging_cfg
        self.padding_cfg = padding_cfg
        self.pipeline = self._resolve_pipeline(pipeline)
        # Populated lazily by _get_humanoid_fk() the first time padding runs.
        self._humanoid_fk: Optional[HumanoidBatch] = None
        self._robot_cfg: Optional[DictConfig] = None

    def _resolve_pipeline(self, pipeline: Optional[List[str]]) -> List[str]:
        """Return a fresh list copy of ``pipeline``, or ``[]`` when None."""
        if pipeline is not None:
            return list(pipeline)
        return []

    # ------------------------------------------------------------------
    # Per-clip stages
    # ------------------------------------------------------------------

    def process_clip(self, clip: ProcessedClip) -> List[ProcessedClip]:
        """Run every configured per-clip stage on ``clip``.

        Slicing may turn one clip into several. Every other stage maps each
        clip to exactly one clip.

        Args:
            clip: The input clip.

        Returns:
            The clip(s) produced after the last stage.
        """
        clips = [clip]
        logger.debug(
            f"Processing clip '{clip.motion_key}' with pipeline: {self.pipeline}"
        )
        for stage in self.pipeline:
            logger.debug(f"Applying stage: {stage}")
            if stage in ("slicing", "slice"):
                next_clips: List[ProcessedClip] = []
                for c in clips:
                    next_clips.extend(self._apply_slicing(c))
                clips = next_clips
                logger.debug(f"After slicing: {len(clips)} clips")
            elif stage in (
                "apply_butterworth_filter",
                "filtering",
                "butterworth_filter",
            ):
                clips = [self._apply_filtering(c) for c in clips]
                logger.debug(
                    f"After apply_butterworth_filter: {len(clips)} clips"
                )
            elif stage == "filename_as_motionkey":
                clips = [self._apply_filename_as_motionkey(c) for c in clips]
                logger.debug(
                    f"After filename_as_motionkey: {len(clips)} clips"
                )
            elif stage == "legacy_to_ref_keys":
                clips = [self._apply_legacy_to_ref_keys(c) for c in clips]
                logger.debug(f"After legacy_to_ref_keys: {len(clips)} clips")
            elif stage == "add_legacy_keys":
                clips = [self._apply_add_legacy_keys(c) for c in clips]
                logger.debug(f"After add_legacy_keys: {len(clips)} clips")
            elif stage == "add_padding":
                clips = [self._apply_add_padding(c) for c in clips]
                logger.debug(f"After add_padding: {len(clips)} clips")
            else:
                logger.warning(
                    f"Unknown preprocessing stage '{stage}' ignored."
                )
        return clips

    def _apply_slicing(self, clip: ProcessedClip) -> List[ProcessedClip]:
        """Split ``clip`` into the windows given by ``compute_slices``.

        ``compute_slices`` is called as
        ``compute_slices(num_frames, window_size, overlap)`` and returns
        ``(start, end)`` pairs. Only arrays whose leading dimension equals the
        clip's ``num_frames`` are sliced; all other values are shared as-is.
        A window covering the whole clip keeps the original motion key.
        Other windows get ``_s{start}_e{end}`` appended and record
        ``slice_start``/``slice_end`` in their metadata.

        Returns ``[clip]`` unchanged when ``slicing_cfg`` is None,
        ``num_frames <= 0``, or no windows are produced.
        """
        cfg = self.slicing_cfg
        if cfg is None:
            logger.warning(
                "Slicing requested but slicing_cfg is None - skipping slicing"
            )
            return [clip]
        window_size = int(getattr(cfg, "window_size", 0))
        overlap = int(getattr(cfg, "overlap", 0))
        seq_len = int(clip.metadata.get("num_frames", 0))
        if seq_len <= 0:
            return [clip]
        slice_specs = compute_slices(seq_len, window_size, overlap)
        if not slice_specs:
            return [clip]
        fps = float(clip.metadata.get("motion_fps", 0.0))
        raw_motion_key = str(
            clip.metadata.get(
                "raw_motion_key", clip.metadata.get("motion_key", "")
            )
        )
        base_motion_key = str(clip.metadata.get("motion_key", raw_motion_key))
        arrays = clip.arrays
        out_clips: List[ProcessedClip] = []
        for s, e in slice_specs:
            arrays_window: Dict[str, np.ndarray] = {}
            for k, v in arrays.items():
                # Only per-frame arrays (leading dim == num_frames) are sliced.
                if (
                    isinstance(v, np.ndarray)
                    and v.ndim >= 1
                    and v.shape[0] == seq_len
                ):
                    arrays_window[k] = v[s:e]
                else:
                    arrays_window[k] = v
            num_frames = int(e - s)
            if num_frames <= 0:
                continue
            wallclock_len = float(num_frames - 1) / fps if fps > 0.0 else 0.0
            if s == 0 and e == seq_len:
                motion_key = base_motion_key
            else:
                motion_key = f"{base_motion_key}_s{s}_e{e}"
            meta = dict(clip.metadata)
            meta["motion_key"] = motion_key
            meta["raw_motion_key"] = raw_motion_key
            meta["num_frames"] = num_frames
            meta["wallclock_len"] = wallclock_len
            meta["slice_start"] = int(s)
            meta["slice_end"] = int(e)
            out_clips.append(
                ProcessedClip(
                    motion_key=motion_key,
                    metadata=meta,
                    arrays=arrays_window,
                )
            )
        return out_clips

    def _apply_filtering(self, clip: ProcessedClip) -> ProcessedClip:
        """Merge ``butterworth_filter_ref_arrays`` output into the clip arrays.

        The cutoff (``butter_cutoff_hz``, default 0.0) and order
        (``butter_order``, default 4) come from ``filtering_cfg``; fps comes
        from the clip metadata. Existing arrays are kept, and returned keys
        overwrite same-named ones. Returns the clip unchanged when
        ``filtering_cfg`` is None.
        """
        cfg = self.filtering_cfg
        if cfg is None:
            logger.warning(
                "Filtering requested but filtering_cfg is None - skipping filtering"
            )
            return clip
        fps = float(clip.metadata.get("motion_fps", 0.0))
        cutoff = float(getattr(cfg, "butter_cutoff_hz", 0.0))
        order = int(getattr(cfg, "butter_order", 4))
        ft = butterworth_filter_ref_arrays(
            clip.arrays, fps=fps, cutoff_hz=cutoff, order=order
        )
        arrays = dict(clip.arrays)
        arrays.update(ft)
        return ProcessedClip(
            motion_key=clip.motion_key,
            metadata=clip.metadata,
            arrays=arrays,
        )

    def _apply_filename_as_motionkey(
        self, clip: ProcessedClip
    ) -> ProcessedClip:
        """Rename the clip to its ``source_filename`` metadata value.

        The previous key is preserved as ``raw_motion_key`` unless one is
        already recorded. The clip is returned unchanged when
        ``source_filename`` is absent.
        """
        filename = clip.metadata.get("source_filename", None)
        if filename is None:
            logger.warning(
                "filename_as_motionkey requested but source_filename not found in metadata - skipping"
            )
            return clip
        new_motion_key = str(filename)
        meta = dict(clip.metadata)
        meta["motion_key"] = new_motion_key
        if "raw_motion_key" not in meta:
            meta["raw_motion_key"] = clip.motion_key
        return ProcessedClip(
            motion_key=new_motion_key,
            metadata=meta,
            arrays=clip.arrays,
        )

    def _apply_add_legacy_keys(self, clip: ProcessedClip) -> ProcessedClip:
        """Add deprecated legacy keys for backward compatibility.

        Maps ref_* keys to legacy unprefixed keys according to spec:
        - ref_dof_pos -> dof_pos
        - ref_dof_vel -> dof_vels
        - ref_global_translation -> global_translation
        - ref_global_rotation_quat -> global_rotation_quat
        - ref_global_velocity -> global_velocity
        - ref_global_angular_velocity -> global_angular_velocity

        Legacy keys that already exist are left untouched. New keys receive a
        copy, so later in-place edits do not alias the ref_* arrays.
        """
        arrays = dict(clip.arrays)
        for ref_key, legacy_key in self._REF_TO_LEGACY_KEYS.items():
            if ref_key in arrays:
                if legacy_key not in arrays:
                    arrays[legacy_key] = arrays[ref_key].copy()
                    logger.debug(
                        f"Added legacy key '{legacy_key}' from '{ref_key}'"
                    )
                else:
                    logger.debug(
                        f"Legacy key '{legacy_key}' already exists, skipping"
                    )
        return ProcessedClip(
            motion_key=clip.motion_key,
            metadata=clip.metadata,
            arrays=arrays,
        )

    def _apply_legacy_to_ref_keys(self, clip: ProcessedClip) -> ProcessedClip:
        """Add new ref_* keys from legacy unprefixed keys.

        Maps legacy keys to ref_* keys according to spec while keeping the
        original legacy arrays:
        - dof_pos -> ref_dof_pos
        - dof_vels -> ref_dof_vel
        - global_translation -> ref_global_translation
        - global_rotation_quat -> ref_global_rotation_quat
        - global_velocity -> ref_global_velocity
        - global_angular_velocity -> ref_global_angular_velocity

        ref_* keys that already exist are left untouched.
        """
        legacy_to_ref = {
            legacy: ref for ref, legacy in self._REF_TO_LEGACY_KEYS.items()
        }
        arrays = dict(clip.arrays)
        for legacy_key, ref_key in legacy_to_ref.items():
            if legacy_key in arrays:
                if ref_key not in arrays:
                    arrays[ref_key] = arrays[legacy_key].copy()
                    logger.debug(
                        f"Added ref key '{ref_key}' from legacy key '{legacy_key}'"
                    )
                else:
                    logger.debug(
                        f"Ref key '{ref_key}' already exists, skipping"
                    )
        return ProcessedClip(
            motion_key=clip.motion_key,
            metadata=clip.metadata,
            arrays=arrays,
        )

    # ------------------------------------------------------------------
    # Padding (FK-based)
    # ------------------------------------------------------------------

    def _get_humanoid_fk(self) -> HumanoidBatch:
        """Lazy-load and cache HumanoidBatch for FK computation.

        Loads ``padding_cfg.robot_config_path`` with OmegaConf, stores it in
        ``self._robot_cfg`` and builds ``HumanoidBatch(robot_cfg.robot)``.
        """
        if self._humanoid_fk is not None:
            return self._humanoid_fk
        cfg = self.padding_cfg
        robot_config_path = str(getattr(cfg, "robot_config_path", ""))
        self._robot_cfg = OmegaConf.load(robot_config_path)
        self._humanoid_fk = HumanoidBatch(self._robot_cfg.robot)
        return self._humanoid_fk

    def _get_default_dof_pos(self) -> np.ndarray:
        """Get default DOF positions from robot config.

        Reads ``robot.init_state.default_joint_angles`` in ``robot.dof_names``
        order, using 0.0 for joints without an entry. Requires
        :meth:`_get_humanoid_fk` to have loaded ``self._robot_cfg`` first.
        """
        robot_cfg = self._robot_cfg.robot
        dof_names = list(robot_cfg.dof_names)
        init_state = robot_cfg.get("init_state", {})
        default_angles = init_state.get("default_joint_angles", {})
        default_dof = np.zeros(len(dof_names), dtype=np.float32)
        for i, name in enumerate(dof_names):
            default_dof[i] = float(default_angles.get(name, 0.0))
        return default_dof

    def _apply_add_padding(self, clip: ProcessedClip) -> ProcessedClip:
        """Add transition and static padding to the motion clip.

        Adds stand-still padding at the default pose before and after the
        motion, with linear (DOF) and SLERP (root rotation) transitions
        between the default pose and the motion's first/last frames.

        The root position is held at the first/last frame's value during the
        padding. During stand-still the root rotation is reduced to yaw only.
        All states are then recomputed from root pos, root rot and dof pos via
        FK.

        Note that the returned clip contains ONLY the arrays produced by
        ``_compute_fk_motion``. Any other arrays (e.g. ``ft_ref_*`` from a
        prior filtering stage, or legacy keys) are dropped.

        Raises:
            ValueError: if ``robot_config_path`` is missing, ``motion_fps`` is
                not positive, or the required pose arrays are absent.
        """
        cfg = self.padding_cfg
        if cfg is None:
            logger.warning(
                "Padding requested but padding_cfg is None - skipping padding"
            )
            return clip
        fps = float(clip.metadata.get("motion_fps", 50.0))
        stand_still_time = float(getattr(cfg, "stand_still_time", 1.0))
        transition_time = float(getattr(cfg, "transition_time", 1.5))
        robot_config_path = str(getattr(cfg, "robot_config_path", ""))
        if not robot_config_path:
            raise ValueError(
                "robot_config_path must be specified in padding_cfg"
            )
        if fps <= 0.0:
            # FK needs dt = 1 / fps; fail clearly instead of ZeroDivisionError.
            raise ValueError(
                f"motion_fps must be positive for padding, got {fps} "
                f"(clip '{clip.motion_key}')"
            )

        arrays = clip.arrays
        dof_pos = arrays.get("ref_dof_pos", arrays.get("dof_pos"))
        global_trans = arrays.get(
            "ref_global_translation", arrays.get("global_translation")
        )
        global_rot = arrays.get(
            "ref_global_rotation_quat", arrays.get("global_rotation_quat")
        )
        if dof_pos is None or global_trans is None or global_rot is None:
            raise ValueError(
                "Missing required arrays for padding: ref_dof_pos, "
                "ref_global_translation, or ref_global_rotation_quat"
            )
        T_orig = dof_pos.shape[0]
        if T_orig == 0:
            # There is no first/last frame to transition from/to.
            logger.warning(
                f"Clip '{clip.motion_key}' has no frames - skipping padding"
            )
            return clip

        stand_still_frames = max(1, int(stand_still_time * fps))
        transition_frames = max(1, int(transition_time * fps))
        humanoid_fk = self._get_humanoid_fk()
        default_dof = self._get_default_dof_pos()
        extend_config = self._robot_cfg.robot.get("extend_config", [])
        num_augment = len(extend_config) if extend_config else 0

        dof_pos = dof_pos.astype(np.float32, copy=True)
        # Body 0 of the global arrays is taken as the root.
        root_pos = global_trans[:, 0, :].astype(np.float32, copy=True)
        root_rot = global_rot[:, 0, :].astype(np.float32, copy=True)
        first_dof = dof_pos[0].copy()
        last_dof = dof_pos[-1].copy()
        first_root_pos = root_pos[0].copy()
        last_root_pos = root_pos[-1].copy()
        first_root_rot = root_rot[0].copy()
        last_root_rot = root_rot[-1].copy()
        logger.debug(
            f"Padding: T_orig={T_orig}, first_root_pos={first_root_pos}, "
            f"last_root_pos={last_root_pos}"
        )
        logger.debug(
            f"Padding: first_dof[:3]={first_dof[:3]}, last_dof[:3]={last_dof[:3]}"
        )
        logger.debug(
            f"Padding: original dof_pos[-1][:3]={dof_pos[-1][:3]}, "
            f"original root_pos[-1]={root_pos[-1]}, original root_rot[-1]={root_rot[-1]}"
        )
        first_yaw_quat = _extract_yaw_only_quat(first_root_rot)
        last_yaw_quat = _extract_yaw_only_quat(last_root_rot)

        # DOF: default pose -> first frame ... last frame -> default pose.
        start_stand_dof = np.tile(default_dof, (stand_still_frames, 1))
        start_trans_dof = _interpolate_linear(
            default_dof, first_dof, transition_frames
        )
        end_trans_dof = _interpolate_linear(
            last_dof, default_dof, transition_frames
        )
        end_stand_dof = np.tile(default_dof, (stand_still_frames, 1))

        # Root position is held constant throughout the padding.
        start_stand_root_pos = np.tile(first_root_pos, (stand_still_frames, 1))
        start_trans_root_pos = _interpolate_linear(
            first_root_pos, first_root_pos, transition_frames
        )
        end_trans_root_pos = _interpolate_linear(
            last_root_pos, last_root_pos, transition_frames
        )
        end_stand_root_pos = np.tile(last_root_pos, (stand_still_frames, 1))

        # Root rotation: yaw-only while standing, SLERP during transitions.
        start_stand_root_rot = np.tile(first_yaw_quat, (stand_still_frames, 1))
        start_trans_root_rot = _interpolate_quaternions_slerp(
            first_yaw_quat, first_root_rot, transition_frames
        )
        end_trans_root_rot = _interpolate_quaternions_slerp(
            last_root_rot, last_yaw_quat, transition_frames
        )
        end_stand_root_rot = np.tile(last_yaw_quat, (stand_still_frames, 1))

        full_dof = np.concatenate(
            [
                start_stand_dof,
                start_trans_dof,
                dof_pos,
                end_trans_dof,
                end_stand_dof,
            ],
            axis=0,
        )
        full_root_pos = np.concatenate(
            [
                start_stand_root_pos,
                start_trans_root_pos,
                root_pos,
                end_trans_root_pos,
                end_stand_root_pos,
            ],
            axis=0,
        )
        full_root_rot = np.concatenate(
            [
                start_stand_root_rot,
                start_trans_root_rot,
                root_rot,
                end_trans_root_rot,
                end_stand_root_rot,
            ],
            axis=0,
        )
        # Compute FK for the entire sequence to ensure continuity
        new_arrays = _compute_fk_motion(
            full_dof,
            full_root_pos,
            full_root_rot,
            humanoid_fk,
            num_augment,
            fps,
        )
        T_new = full_dof.shape[0]
        wallclock_len = float(T_new - 1) / fps
        meta = dict(clip.metadata)
        meta["num_frames"] = T_new
        meta["wallclock_len"] = wallclock_len
        meta["padding_stand_still_frames"] = stand_still_frames
        meta["padding_transition_frames"] = transition_frames
        meta["original_num_frames"] = T_orig
        return ProcessedClip(
            motion_key=clip.motion_key,
            metadata=meta,
            arrays=new_arrays,
        )

    # ------------------------------------------------------------------
    # File / directory drivers
    # ------------------------------------------------------------------

    def process_npz_file(self, npz_path: Path) -> List[ProcessedClip]:
        """Load one standardized NPZ clip and run the per-clip pipeline on it.

        The NPZ must contain a JSON ``metadata`` entry with a ``motion_key``.
        The file stem is recorded as ``metadata["source_filename"]``.

        Raises:
            KeyError: if the NPZ has no ``metadata`` entry.
        """
        with np.load(npz_path, allow_pickle=False) as data:
            if "metadata" not in data:
                raise KeyError(f"'metadata' missing in NPZ: {npz_path}")
            meta_text = str(data["metadata"])
            metadata = json.loads(meta_text)
            motion_key = str(metadata["motion_key"])
            # np.asarray avoids the NumPy 2 semantics of np.array(copy=False),
            # which raises whenever a copy cannot be avoided.
            arrays: Dict[str, np.ndarray] = {
                k: np.asarray(data[k]) for k in data.files if k != "metadata"
            }
        metadata["source_filename"] = npz_path.stem
        base_clip = ProcessedClip(
            motion_key=motion_key,
            metadata=metadata,
            arrays=arrays,
        )
        return self.process_clip(base_clip)

    @staticmethod
    def _save_clip(clip: ProcessedClip, clips_dst: Path) -> None:
        """Write ``clip`` to ``clips_dst/<motion_key>.npz``.

        Metadata is stored as a JSON string under the ``metadata`` key.
        """
        out_path = clips_dst / f"{clip.motion_key}.npz"
        metadata_json = json.dumps(clip.metadata)
        np.savez_compressed(out_path, metadata=metadata_json, **clip.arrays)

    def run_on_directory(
        self,
        src_root: Path,
        out_root: Path,
        use_ray: bool = False,
        num_workers: int = 0,
    ) -> None:
        """Process every ``*.npz`` under ``src_root`` into ``out_root/clips``.

        If ``src_root/clips`` exists it is used as the input directory;
        otherwise ``src_root`` itself is searched recursively.

        Args:
            src_root: Input dataset directory.
            out_root: Output directory; ``clips/`` is created inside it.
            use_ray: Process files in parallel with Ray actors (Ray must
                already be initialized).
            num_workers: Number of Ray actors; ``<= 0`` means one per
                available CPU. Ignored when ``use_ray`` is False.

        Raises:
            ValueError: if ``src_root`` is not a directory.
        """
        if not src_root.is_dir():
            raise ValueError(f"Source root is not a directory: {src_root}")
        clips_src = (
            src_root / "clips" if (src_root / "clips").is_dir() else src_root
        )
        clips_dst = out_root / "clips"
        clips_dst.mkdir(parents=True, exist_ok=True)
        files = sorted([p for p in clips_src.rglob("*.npz") if p.is_file()])
        if not files:
            logger.info("No NPZ files found to process.")
            return
        if use_ray:
            # _run_on_directory_ray resolves num_workers <= 0 to the CPU count.
            self._run_on_directory_ray(files, clips_dst, num_workers)
        else:
            self._run_on_directory_sequential(files, clips_dst)

    def _run_on_directory_sequential(
        self, files: List[Path], clips_dst: Path
    ) -> None:
        """Process ``files`` one by one in this process and save all clips."""
        logger.info(f"Processing {len(files)} NPZ files sequentially")
        logger.info(f"Pipeline stages to apply: {self.pipeline}")
        total_input_clips = 0
        total_output_clips = 0
        for p in tqdm(files, desc="HoloMotion preprocess NPZ", unit="file"):
            clips = self.process_npz_file(p)
            total_input_clips += 1
            for clip in clips:
                total_output_clips += 1
                self._save_clip(clip, clips_dst)
        logger.info(
            f"Processed {total_input_clips} input files into {total_output_clips} output clips"
        )

    def _run_on_directory_ray(
        self, files: List[Path], clips_dst: Path, num_workers: int
    ) -> None:
        """Process ``files`` with a pool of ``PreprocessorActor`` instances.

        Each actor has at most one task in flight. When a task finishes, its
        clips are saved in this (driver) process and the actor receives the
        next file. ``num_workers <= 0`` means one actor per available CPU.
        The actor count never exceeds ``len(files)``.
        """
        if num_workers <= 0:
            available_cpus = int(ray.available_resources().get("CPU", 1))
            num_actors = min(len(files), max(1, available_cpus))
        else:
            num_actors = min(len(files), num_workers)
        actors = [
            PreprocessorActor.remote(
                slicing_cfg=self.slicing_cfg,
                filtering_cfg=self.filtering_cfg,
                tagging_cfg=self.tagging_cfg,
                padding_cfg=self.padding_cfg,
                pipeline=self.pipeline,
            )
            for _ in range(num_actors)
        ]
        # Object ref of an in-flight task -> index of the actor running it.
        pending: Dict[Any, int] = {}
        next_idx = 0
        for actor_idx in range(num_actors):
            ref = actors[actor_idx].process_npz_file.remote(
                str(files[next_idx])
            )
            next_idx += 1
            pending[ref] = actor_idx
        total_outputs = 0
        with tqdm(
            total=len(files), desc="Ray: HoloMotion preprocess NPZ"
        ) as pbar:
            while pending:
                done, _ = ray.wait(list(pending.keys()), num_returns=1)
                ref = done[0]
                actor_idx = pending.pop(ref)
                clips = ray.get(ref)
                for clip in clips:
                    self._save_clip(clip, clips_dst)
                    total_outputs += 1
                pbar.update(1)
                if next_idx < len(files):
                    new_ref = actors[actor_idx].process_npz_file.remote(
                        str(files[next_idx])
                    )
                    next_idx += 1
                    pending[new_ref] = actor_idx
        logger.info(f"Processed {total_outputs} clips total.")

    # ------------------------------------------------------------------
    # Dataset-level kinematic tagging
    # ------------------------------------------------------------------

    @staticmethod
    def _pick_ref_array(data: Any, name: str) -> np.ndarray:
        """Return ``ft_ref_<name>``, else ``ref_<name>``, else an empty array.

        Filtered (``ft_``) arrays take precedence over unfiltered ones.
        """
        for key in (f"ft_ref_{name}", f"ref_{name}"):
            if key in data:
                return np.asarray(data[key])
        return np.array([], dtype=np.float32)

    @staticmethod
    def _concat_or_empty(chunks: List[np.ndarray]) -> np.ndarray:
        """Concatenate non-empty ``chunks`` along axis 0.

        Returns an empty float array when there is nothing to concatenate.
        """
        non_empty = [c for c in chunks if c.size > 0]
        if not non_empty:
            return np.array([], dtype=float)
        return np.concatenate(non_empty, axis=0)

    def tag_directory(self, clips_dir: Path, tags_path: Path) -> None:
        """Compute per-clip and dataset-wide kinematic statistics as JSON.

        For each ``*.npz`` under ``clips_dir``, the root body (index 0) is
        used to compute:

        - linear speed (norm of ``global_velocity``),
        - angular speed (norm of ``global_angular_velocity``),
        - |z - z[0]| of ``global_translation``,
        - per-body jerk norm (second finite difference of
          ``global_velocity`` divided by dt^2).

        Each is summarized with ``_summary`` per clip and ``_ds_summary``
        across the dataset. The result ``{"dataset_stats", "clip_info"}`` is
        written to ``tags_path``; its parent directory is created if needed.
        """
        tags_path = Path(tags_path)
        files = sorted([p for p in clips_dir.rglob("*.npz") if p.is_file()])
        clip_info: Dict[str, Dict[str, Dict[str, float]]] = {}
        all_speed: List[np.ndarray] = []
        all_wnorm: List[np.ndarray] = []
        all_zrel: List[np.ndarray] = []
        all_jerk: List[np.ndarray] = []
        for f in tqdm(files, desc="Tagging kinematics", unit="file"):
            # NOTE(review): allow_pickle=True is only needed for object arrays,
            # and clips written by this class contain none. Consider False if
            # clips_dir may hold files from untrusted sources.
            with np.load(f, allow_pickle=True) as data:
                meta = json.loads(str(data["metadata"]))
                key = str(meta["motion_key"])
                fps = float(meta["motion_fps"])
                gv = self._pick_ref_array(data, "global_velocity")
                ga = self._pick_ref_array(data, "global_angular_velocity")
                gt = self._pick_ref_array(data, "global_translation")
            if gv.size > 0:
                speed = np.linalg.norm(gv[:, 0, :], axis=1)
            else:
                speed = np.array([], dtype=float)
            if ga.size > 0:
                wnorm = np.linalg.norm(ga[:, 0, :], axis=1)
            else:
                wnorm = np.array([], dtype=float)
            if gt.size > 0:
                root_pos_z = gt[:, 0, 2]
                z_rel = np.abs(root_pos_z - float(root_pos_z[0]))
            else:
                z_rel = np.array([], dtype=float)
            # Jerk needs at least 3 velocity samples for two differences.
            if gv.shape[0] >= 3:
                dt = 1.0 / fps if fps > 0.0 else 0.0
                a = (
                    np.diff(gv, axis=0) / dt
                    if dt > 0.0
                    else np.zeros_like(gv)
                )
                j = (
                    np.diff(a, axis=0) / dt
                    if dt > 0.0
                    else np.zeros_like(a)
                )
                jn = np.linalg.norm(j, axis=2)
            else:
                jn = np.array([], dtype=float)
            clip_info[key] = {
                "root_linear_speed": _summary(speed),
                "root_angular_speed": _summary(wnorm),
                "root_delta_z": _summary(z_rel),
                "jerk": _summary(jn),
            }
            if speed.size > 0:
                all_speed.append(speed.astype(float))
            if wnorm.size > 0:
                all_wnorm.append(wnorm.astype(float))
            if z_rel.size > 0:
                all_zrel.append(z_rel.astype(float))
            if jn.size > 0:
                all_jerk.append(jn.astype(float))
        result = {
            "dataset_stats": {
                "root_linear_speed": _ds_summary(
                    self._concat_or_empty(all_speed)
                ),
                "root_angular_speed": _ds_summary(
                    self._concat_or_empty(all_wnorm)
                ),
                "root_delta_z": _ds_summary(self._concat_or_empty(all_zrel)),
                "jerk": _ds_summary(self._concat_or_empty(all_jerk)),
            },
            "clip_info": clip_info,
        }
        tags_path.parent.mkdir(parents=True, exist_ok=True)
        with open(tags_path, "w") as fh:
            json.dump(result, fh, indent=2, sort_keys=True)
        logger.info(f"Wrote kinematic tags JSON to: {tags_path}")
@ray.remote
class PreprocessorActor:
    """Ray actor wrapping its own HoloMotionPreprocessor.

    Every actor builds a preprocessor from the same configs and pipeline, so
    NPZ files can be handed out to several actors and processed in parallel.
    """

    def __init__(
        self,
        slicing_cfg: Optional[DictConfig] = None,
        filtering_cfg: Optional[DictConfig] = None,
        tagging_cfg: Optional[DictConfig] = None,
        padding_cfg: Optional[DictConfig] = None,
        pipeline: Optional[List[str]] = None,
    ) -> None:
        wrapped = HoloMotionPreprocessor(
            slicing_cfg=slicing_cfg,
            filtering_cfg=filtering_cfg,
            tagging_cfg=tagging_cfg,
            padding_cfg=padding_cfg,
            pipeline=pipeline,
        )
        self.preprocessor = wrapped
        logger.debug(
            f"PreprocessorActor initialized with pipeline: {wrapped.pipeline}"
        )

    def process_npz_file(self, npz_path_str: str) -> List[ProcessedClip]:
        """Process the NPZ at ``npz_path_str`` and return the resulting clips.

        The path is passed as ``str`` and converted back to ``Path`` here.
        """
        return self.preprocessor.process_npz_file(Path(npz_path_str))
def _parse_preprocess_pipeline(
    pipeline_cfg: Optional[DictConfig],
) -> Optional[List[str]]:
    """Extract the ordered stage-name list from the ``preprocess`` config node.

    Accepts a list/tuple/``ListConfig`` of names, or a string holding either
    a Python list literal (``"['slicing', 'tagging']"``), a bracketed or
    comma-separated list (``"[slicing,tagging]"``), or a single stage name.

    Returns:
        The stage names. Returns ``None`` when no pipeline is configured, and
        ``[]`` for an unsupported value type (a warning is logged).
    """
    logger.debug(f"Raw preprocess config: {pipeline_cfg}")
    if pipeline_cfg is None:
        logger.debug("preprocess config is None")
        return None
    pipeline_val = pipeline_cfg.get("pipeline", None)
    logger.debug(
        f"Raw pipeline value: {pipeline_val} (type: {type(pipeline_val)})"
    )
    if pipeline_val is None:
        logger.debug("pipeline_val is None")
        return None
    if isinstance(pipeline_val, (list, tuple, ListConfig)):
        return [str(s) for s in pipeline_val]
    if isinstance(pipeline_val, str):
        import ast

        try:
            parsed = ast.literal_eval(pipeline_val)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. "slicing" or "[slicing,tagging]"):
            # fall back to splitting a bracketed, comma-separated list.
            inner = pipeline_val.strip().strip("[]")
            parsed = [tok.strip().strip("'\"") for tok in inner.split(",")]
            parsed = [tok for tok in parsed if tok]
        if isinstance(parsed, (list, tuple)):
            return [str(s) for s in parsed]
        # A literal scalar such as "'slicing'" is a single stage, not a
        # sequence of characters.
        return [str(parsed)]
    logger.warning(
        f"Unexpected pipeline type: {type(pipeline_val)}, value: {pipeline_val}"
    )
    return []


@hydra.main(
    config_path="../../config",
    config_name="motion_retargeting/holomotion_preprocess",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point for HoloMotion NPZ preprocessing.

    Runs the per-clip stages of ``cfg.preprocess.pipeline`` over
    ``cfg.io.src_root``, writes the results to ``cfg.io.out_root/clips``, and,
    if ``"tagging"`` is in the pipeline, writes kinematic tags JSON
    afterwards. The resolved config is dumped to
    ``out_root/config_used.yaml``. With ``cfg.ray.enabled``, files are
    processed by Ray actors and Ray is shut down even if processing fails.
    """
    logger.remove()
    logger.add(sys.stderr, level="INFO", colorize=True)
    src_root = Path(str(cfg.io.src_root)).expanduser().resolve()
    out_root = Path(str(cfg.io.out_root)).expanduser().resolve()
    out_root.mkdir(parents=True, exist_ok=True)
    # Dump resolved config used
    with open(out_root / "config_used.yaml", "w") as f:
        f.write(OmegaConf.to_yaml(cfg))

    pipeline = _parse_preprocess_pipeline(cfg.get("preprocess", None))
    # Separate per-clip stages from dataset-level stages
    per_clip_pipeline = (
        [s for s in pipeline if s != "tagging"] if pipeline else []
    )
    tagging_enabled = bool(pipeline) and "tagging" in pipeline

    logger.info("=" * 80)
    logger.info("Preprocessing Configuration:")
    logger.info(f"  Source directory: {src_root}")
    logger.info(f"  Output directory: {out_root}")
    if pipeline:
        logger.info(f"  Pipeline stages: {pipeline}")
        logger.info(f"  Number of stages: {len(pipeline)}")
        for i, stage in enumerate(pipeline, 1):
            logger.info(f"    {i}. {stage}")
        if tagging_enabled:
            logger.info(
                "  Note: 'tagging' is a dataset-level operation and will run after all clips are processed"
            )
    else:
        logger.warning(
            "  No preprocessing pipeline specified - no processors will be applied!"
        )
    logger.info("=" * 80)

    # A missing or null `ray` node behaves like an empty one.
    ray_cfg = cfg.get("ray", None) or {}
    use_ray = bool(ray_cfg.get("enabled", False))
    num_workers = int(ray_cfg.get("num_workers", 0))
    # `or ""` keeps a null address from becoming the string "None".
    ray_address = str(ray_cfg.get("ray_address", "") or "")
    if use_ray:
        logging.getLogger("filelock").setLevel(logging.WARNING)
        logging.getLogger("ray").setLevel(logging.ERROR)
        os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", "error")
        init_kwargs = dict(
            ignore_reinit_error=True,
            log_to_driver=False,
            include_dashboard=False,
            logging_level=logging.ERROR,
        )
        if ray_address:
            ray.init(address=ray_address, **init_kwargs)
        else:
            ray.init(
                num_cpus=num_workers if num_workers > 0 else None,
                **init_kwargs,
            )
        if num_workers <= 0:
            num_workers = int(ray.available_resources().get("CPU", 1))

    preprocessor = HoloMotionPreprocessor(
        slicing_cfg=cfg.get("slicing", None),
        filtering_cfg=cfg.get("filtering", None),
        tagging_cfg=cfg.get("tagging", None),
        padding_cfg=cfg.get("padding", None),
        pipeline=per_clip_pipeline if per_clip_pipeline else None,
    )
    logger.info(
        f"Preprocessor initialized with pipeline: {preprocessor.pipeline}"
    )
    logger.info(
        f"  Slicing config present: {preprocessor.slicing_cfg is not None}"
    )
    logger.info(
        f"  Filtering config present: {preprocessor.filtering_cfg is not None}"
    )
    logger.info(
        f"  Tagging config present: {preprocessor.tagging_cfg is not None}"
    )
    try:
        preprocessor.run_on_directory(
            src_root, out_root, use_ray=use_ray, num_workers=num_workers
        )
    finally:
        if use_ray:
            ray.shutdown()

    if tagging_enabled:
        tagging_cfg = cfg.get("tagging", None)
        tags_json = (
            tagging_cfg.get("output_json_path", None)
            if tagging_cfg is not None
            else None
        )
        # A null/empty path falls back to the default. The old str(None) check
        # wrote tags to a file literally named "None".
        if tags_json:
            tags_path = Path(str(tags_json)).expanduser().resolve()
        else:
            tags_path = out_root / "kinematic_tags.json"
        clips_dir = out_root / "clips"
        preprocessor.tag_directory(clips_dir, tags_path)


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/motion_retargeting/kinematic_filter.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import hydra
import yaml
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
def _eval_rule(val: float, op: str, thr: float) -> bool:
if op == ">":
return val > thr
if op == ">=":
return val >= thr
if op == "<":
return val < thr
if op == "<=":
return val <= thr
if op == "==":
return val == thr
if op == "!=":
return val != thr
raise ValueError(f"Unsupported op: {op}")
def _deep_get(container: Dict[str, Any], parts: List[str]) -> Optional[float]:
cur: Any = container
for p in parts:
if not isinstance(cur, dict) or p not in cur:
return None
cur = cur[p]
if isinstance(cur, (int, float)):
return float(cur)
return None
def _resolve_value(
    tags_root: Dict[str, Any],
    clip_group: Dict[str, Any],
    path: str,
) -> Optional[float]:
    """Resolve a dotted threshold path to a numeric value.

    - ``dataset_stats.<feature>.<stat>``: looked up in ``tags_root``. The
      ``dataset_stats`` segment is kept, so the lookup starts at
      ``tags_root["dataset_stats"]``.
    - ``kinematic_features.<feature>.<stat>``: looked up in ``clip_group``
      after dropping the ``kinematic_features`` prefix.
    - ``<feature>.<stat>`` (no prefix): also looked up in ``clip_group``,
      for convenience.

    Returns None when the path does not resolve to a number.
    """
    head, *rest = str(path).split(".")
    if head == "dataset_stats":
        return _deep_get(tags_root, [head, *rest])
    if head == "kinematic_features":
        return _deep_get(clip_group, rest)
    return _deep_get(clip_group, [head, *rest])
def filter_with_schema(
    tags: Dict[str, Any],
    schema: Dict[str, Any],
) -> Tuple[Set[str], Dict[str, int], Dict[str, int]]:
    """Select the motion keys to exclude according to a threshold schema.

    For each clip in ``tags["clip_info"]``, every rule in
    ``schema["thresholds"]`` (path -> ``{"op": ..., "value": ...}``, with
    ``op`` defaulting to ``">"``) whose path resolves to a number is
    evaluated. Rules whose path does not resolve are skipped. A clip is
    excluded when any rule hits (``across: union``, the default) or when
    all evaluated rules hit (``across: intersection``).

    Returns:
        ``(excluded_keys, path_counts, group_counts)``. ``path_counts`` counts
        excluded clips per hitting rule path. ``group_counts`` counts excluded
        clips per first path segment with at least one hit.

    Raises:
        ValueError: on an unknown ``across`` mode. This is only raised once a
            clip has at least one evaluated rule.
    """
    thresholds: Dict[str, Dict[str, Any]] = schema.get("thresholds", {}) or {}
    across_mode = str(schema.get("across", "union"))
    clips: Dict[str, Dict[str, Any]] = tags.get("clip_info", {}) or {}

    excluded: Set[str] = set()
    path_counts: Dict[str, int] = {}
    group_counts: Dict[str, int] = {}

    for motion_key, groups in tqdm(
        clips.items(), desc="Evaluating schema", unit="clip"
    ):
        # Outcome of every rule that could be resolved for this clip.
        outcomes: Dict[str, bool] = {}
        for path, spec in thresholds.items():
            val = _resolve_value(tags, groups, path)
            if val is None:
                continue
            outcomes[path] = _eval_rule(
                val, str(spec.get("op", ">")), float(spec["value"])
            )
        if not outcomes:
            continue

        if across_mode == "union":
            is_excluded = any(outcomes.values())
        elif across_mode == "intersection":
            is_excluded = all(outcomes.values())
        else:
            raise ValueError(f"Invalid across mode: {across_mode}")
        if not is_excluded:
            continue

        excluded.add(motion_key)
        hit_paths = [p for p, hit in outcomes.items() if hit]
        for p in hit_paths:
            path_counts[p] = path_counts.get(p, 0) + 1
        # Each group is counted at most once per clip, in first-hit order.
        for grp in dict.fromkeys(str(p).split(".")[0] for p in hit_paths):
            group_counts[grp] = group_counts.get(grp, 0) + 1

    return excluded, path_counts, group_counts
def _default_schema_path() -> Path:
# holomotion/src/motion_retargeting/kinematic_filter.py
# -> holomotion/config/motion_retargeting/kinematic_filtering_schema.yaml
this_file = Path(__file__).resolve()
holomotion_dir = this_file.parents[2]
return (
holomotion_dir
/ "config"
/ "motion_retargeting"
/ "kinematic_filtering_schema.yaml"
)
def run(
    dataset_root: str,
    schema_yaml_path: Optional[str] = None,
    output_yaml_path: Optional[str] = None,
    schema_obj: Optional[Dict[str, Any]] = None,
) -> Set[str]:
    """Execute kinematic filtering using tags and a schema.

    - dataset_root: directory containing 'kinematic_tags.json'
    - schema_yaml_path: external YAML with 'across' and 'thresholds'
      (optional; defaults to the packaged schema)
    - schema_obj: inline dict with 'across' and 'thresholds' (optional;
      takes precedence over schema_yaml_path)
    - output_yaml_path: where to write the excluded list YAML (optional;
      defaults to <dataset_root>/excluded_kinematic_motion_names.yaml)

    The output is a Hydra ``_global_`` package with a single
    ``excluded_motion_names`` list, which is ``[]`` when nothing is excluded.
    A copy of the used filter config is written, best effort, to
    <dataset_root>/kinematic_filter_config_used.yaml.

    Returns:
        The set of excluded motion keys.

    Raises:
        FileNotFoundError: if the tags JSON or the schema YAML is missing.
    """
    root = Path(dataset_root).expanduser().resolve()
    tags_path = root / "kinematic_tags.json"
    if not tags_path.is_file():
        raise FileNotFoundError(f"Missing kinematic tags JSON: {tags_path}")
    schema: Dict[str, Any]
    if schema_obj is not None:
        schema = dict(schema_obj)
    else:
        schema_path = (
            Path(schema_yaml_path).expanduser().resolve()
            if schema_yaml_path
            else _default_schema_path()
        )
        if not schema_path.is_file():
            raise FileNotFoundError(f"Missing schema YAML: {schema_path}")
        with open(schema_path, "r", encoding="utf-8") as fh:
            # An empty YAML document loads as None; treat it as an empty schema.
            schema = yaml.safe_load(fh) or {}
    out_yaml = (
        Path(output_yaml_path).expanduser().resolve()
        if output_yaml_path
        else (root / "excluded_kinematic_motion_names.yaml")
    )
    logger.info(f"Dataset root: {root}")
    logger.info(f"Reading tags from: {tags_path}")
    logger.info(
        "Using schema from: inline config"
        if schema_obj is not None
        else "Using schema from YAML file"
    )
    # Pretty-print resolved schema to console (informational only).
    try:
        logger.info(
            "Resolved schema:\n"
            + yaml.safe_dump(schema, sort_keys=True, default_flow_style=False)
        )
    except Exception as exc:
        logger.debug(f"Could not pretty-print schema: {exc}")
    with open(tags_path, "r", encoding="utf-8") as fh:
        tags = json.load(fh)
    # Dump the used filter config into dataset root. This is best effort: a
    # failure is reported but does not abort filtering.
    try:
        used_cfg = {
            "dataset_root": str(root),
            "output_yaml": str(out_yaml),
            "schema": schema,
        }
        with open(
            root / "kinematic_filter_config_used.yaml", "w", encoding="utf-8"
        ) as f:
            yaml.safe_dump(
                used_cfg, f, sort_keys=True, default_flow_style=False
            )
    except Exception as exc:
        logger.warning(
            f"Could not write kinematic_filter_config_used.yaml: {exc}"
        )
    excluded_keys, path_counts, group_counts = filter_with_schema(tags, schema)
    out_yaml.parent.mkdir(parents=True, exist_ok=True)
    with open(out_yaml, "w", encoding="utf-8") as f:
        f.write("# @package _global_\n\n")
        # safe_dump quotes keys that would otherwise break the YAML or load
        # as non-strings, and emits `[]` rather than null for an empty list.
        yaml.safe_dump(
            {"excluded_motion_names": sorted(excluded_keys)},
            f,
            default_flow_style=False,
            allow_unicode=True,
        )
    logger.info(f"Excluded by config: {len(excluded_keys)}")
    if len(group_counts) > 0:
        logger.info("Excluded counts by category:")
        for grp, cnt in sorted(
            group_counts.items(), key=lambda kv: kv[1], reverse=True
        ):
            logger.info(f"- {grp}: {cnt}")
    if len(path_counts) > 0:
        logger.info("Excluded counts by threshold path:")
        for pth, cnt in sorted(
            path_counts.items(), key=lambda kv: kv[1], reverse=True
        ):
            logger.info(f"- {pth}: {cnt}")
    logger.info(f"Wrote excluded list to: {out_yaml}")
    return excluded_keys
@hydra.main(
    config_path="../../config",
    config_name="motion_retargeting/kinematic_filter",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point for kinematic filtering.

    Reads ``io.dataset_root``, an optional inline ``schema`` node, and the
    optional ``filtering.schema_yaml`` / ``filtering.output_yaml`` overrides.
    Empty or null override values are treated as not given. All of these are
    forwarded to :func:`run`.
    """
    logger.remove()
    logger.add(sys.stderr, level="INFO", colorize=True)
    dataset_root = str(cfg.io.dataset_root)

    # Inline schema object; run() prefers it over any YAML path.
    schema_obj = OmegaConf.to_object(cfg.schema) if "schema" in cfg else None

    schema_val = ""
    out_val = ""
    if "filtering" in cfg:
        filtering = cfg.filtering
        if hasattr(filtering, "schema_yaml"):
            schema_val = str(filtering.get("schema_yaml", "") or "")
        out_val = str(filtering.get("output_yaml", "") or "")

    run(
        dataset_root=dataset_root,
        schema_yaml_path=schema_val or None,
        schema_obj=schema_obj,
        output_yaml_path=out_val or None,
    )


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/motion_retargeting/pack_hdf5_v2.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import h5py
import hydra
import numpy as np
from loguru import logger
from omegaconf import ListConfig, OmegaConf
from tqdm import tqdm
def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
@dataclass
class ArraySpec:
    """Layout of one time-major dataset created in an HDF5 shard.

    The leading (time) axis is resizable and is not stored here; only the
    per-frame trailing shape and the dataset dtype are recorded.
    """

    name: str  # dataset name inside the HDF5 file
    shape_tail: Tuple[int, ...]  # shape excluding time dim
    dtype: np.dtype  # element type passed to h5py create_dataset
@dataclass
class ClipEntry:
    """One discovered NPZ motion file (built by ``_discover_motion_entries``)."""

    clip_id: int  # index of this entry in the sorted list of motion keys
    name: str  # motion key "<root dir name>_<npz file stem>"
    path: str  # filesystem path of the NPZ file
class Hdf5ShardWriter:
    """Append-only writer for one HDF5 shard of concatenated motion clips.

    Each ``ArraySpec`` becomes a dataset whose leading (time) axis is
    resizable. Clips are appended back to back along that axis, and
    ``finalize`` writes a ``clips`` group (``start``, ``length``,
    ``motion_key_id``, ``metadata_json``) and closes the file.
    """

    def __init__(
        self,
        h5_path: str,
        array_specs: List[ArraySpec],
        chunks_t: int,
        compression: str,
    ) -> None:
        """Create (truncate) the shard file and one empty dataset per spec.

        Args:
            h5_path: Output file path; an existing file is overwritten.
            array_specs: Datasets to create.
            chunks_t: Chunk length along the time axis.
            compression: h5py compression filter name, or ``"none"`` to
                disable both compression and the shuffle filter.
        """
        self.h5_path = h5_path
        self.array_specs = array_specs
        self.chunks_t = int(chunks_t)
        self.compression = compression
        parent_dir = os.path.dirname(self.h5_path)
        if parent_dir:
            # os.makedirs("") would raise for a bare filename.
            _ensure_dir(parent_dir)
        self.h5 = h5py.File(self.h5_path, "w")
        self.datasets: Dict[str, h5py.Dataset] = {}
        use_compression = self.compression != "none"
        try:
            for spec in self.array_specs:
                self.datasets[spec.name] = self.h5.create_dataset(
                    spec.name,
                    shape=(0, *spec.shape_tail),
                    maxshape=(None, *spec.shape_tail),
                    chunks=(self.chunks_t, *spec.shape_tail),
                    compression=self.compression if use_compression else None,
                    dtype=spec.dtype,
                    shuffle=use_compression,
                )
        except Exception:
            # Do not leak an open file handle when dataset creation fails.
            self.h5.close()
            raise
        self._clip_starts: List[int] = []
        self._clip_lengths: List[int] = []
        self._clip_motion_ids: List[int] = []
        self._clip_metadata: List[str] = []
        self.t_cursor = 0

    def append_motion(
        self,
        motion_id: int,
        np_arrays: Dict[str, np.ndarray],
        metadata_json: str,
    ) -> Tuple[int, int]:
        """Append one clip to the end of every dataset.

        All arrays are validated before any dataset is resized, so a bad clip
        cannot leave the shard with datasets of different lengths.

        Args:
            motion_id: Integer id stored in ``clips/motion_key_id``.
            np_arrays: Arrays keyed by spec name; ``ref_dof_pos`` defines
                the clip length.
            metadata_json: JSON string stored in ``clips/metadata_json``.

        Returns:
            ``(start, length)`` of the clip along the time axis.

        Raises:
            KeyError: ``ref_dof_pos`` or a spec'd array is missing.
            ValueError: an array's shape is not ``(length, *shape_tail)``.
        """
        if "ref_dof_pos" not in np_arrays:
            raise KeyError("ref_dof_pos missing for HDF5 v2 packing")
        t_len = int(np_arrays["ref_dof_pos"].shape[0])
        for spec in self.array_specs:
            if spec.name not in np_arrays:
                raise KeyError(
                    f"Missing array '{spec.name}' for HDF5 v2 packing"
                )
            expected = (t_len, *spec.shape_tail)
            actual = tuple(np_arrays[spec.name].shape)
            if actual != expected:
                raise ValueError(
                    f"Array '{spec.name}' has shape {actual}, expected "
                    f"{expected} for HDF5 v2 packing"
                )
        start = self.t_cursor
        end = start + t_len
        for spec in self.array_specs:
            ds = self.datasets[spec.name]
            ds.resize((end, *spec.shape_tail))
            ds[start:end, ...] = np_arrays[spec.name]
        self._clip_starts.append(start)
        self._clip_lengths.append(t_len)
        self._clip_motion_ids.append(motion_id)
        self._clip_metadata.append(metadata_json)
        self.t_cursor = end
        return start, t_len

    def finalize(self) -> Dict[str, Any]:
        """Write the per-clip index group, flush and close the file.

        The file is closed even if writing the index fails.

        Returns:
            ``{"file", "num_clips", "num_frames"}`` summary of the shard.
        """
        try:
            g = self.h5.create_group("clips")
            g.create_dataset(
                "start", data=np.asarray(self._clip_starts, dtype=np.int64)
            )
            g.create_dataset(
                "length", data=np.asarray(self._clip_lengths, dtype=np.int64)
            )
            g.create_dataset(
                "motion_key_id",
                data=np.asarray(self._clip_motion_ids, dtype=np.int64),
            )
            vlen_str = h5py.string_dtype(encoding="utf-8")
            g.create_dataset(
                "metadata_json",
                data=np.asarray(self._clip_metadata, dtype=vlen_str),
            )
            self.h5.flush()
        finally:
            self.h5.close()
        return {
            "file": self.h5_path,
            "num_clips": len(self._clip_starts),
            "num_frames": int(self.t_cursor),
        }
def _normalize_root_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, (str, os.PathLike)):
return [str(value)]
if isinstance(value, (list, tuple, ListConfig)):
return [str(v) for v in list(value)]
return [str(value)]
def _discover_motion_entries(roots: List[str]) -> List[ClipEntry]:
    """Recursively collect ``*.npz`` files under the given roots.

    For each root, its ``clips`` sub-directory is scanned when present,
    otherwise the root itself. Every file is keyed as
    ``"<root dir name>_<file stem>"``; entries come back sorted by key, with
    ``clip_id`` equal to the position in that order.

    Raises:
        FileNotFoundError: the directory to scan does not exist.
        ValueError: two files produce the same key, or no NPZ is found.
    """
    path_by_key: Dict[str, str] = {}
    for root in roots:
        resolved_root = Path(root).expanduser().resolve()
        key_prefix = resolved_root.name
        clips_candidate = resolved_root / "clips"
        scan_dir = (
            clips_candidate if clips_candidate.is_dir() else resolved_root
        )
        if not scan_dir.is_dir():
            raise FileNotFoundError(f"NPZ directory not found: {scan_dir}")
        for current_dir, _subdirs, file_names in os.walk(str(scan_dir)):
            for file_name in (f for f in file_names if f.endswith(".npz")):
                stem = os.path.splitext(file_name)[0]
                key = f"{key_prefix}_{stem}"
                if key in path_by_key:
                    raise ValueError(f"Duplicate motion key: {key}")
                path_by_key[key] = os.path.join(current_dir, file_name)
    if not path_by_key:
        raise ValueError("No NPZ files found in input directories.")
    return [
        ClipEntry(clip_id=position, name=key, path=path_by_key[key])
        for position, key in enumerate(sorted(path_by_key))
    ]
def _load_metadata_json(npz_path: Path) -> Tuple[str, Dict[str, Any]]:
with np.load(npz_path, allow_pickle=False) as data:
if "metadata" not in data:
raise KeyError(f"'metadata' missing in NPZ: {npz_path}")
metadata_json = str(data["metadata"])
num_frames_from_dof = data["ref_dof_pos"].shape[0]
metadata = json.loads(metadata_json)
num_frames_from_metadata = metadata["num_frames"]
assert num_frames_from_dof == num_frames_from_metadata, (
f"num_frames_from_dof {num_frames_from_dof} != num_frames_from_metadata {num_frames_from_metadata} in {npz_path}"
)
if not isinstance(metadata, dict):
raise ValueError(f"metadata must be a JSON object in {npz_path}")
return metadata_json, metadata
def _cast_array(array: np.ndarray, name: str, npz_path: Path) -> np.ndarray:
if array.dtype == np.float32:
return array
if array.dtype.kind == "O":
raise ValueError(f"Array '{name}' in {npz_path} has object dtype.")
if np.issubdtype(array.dtype, np.integer):
logger.warning(
"Casting array '{}' in {} from {} to float32.",
name,
npz_path,
array.dtype,
)
return array.astype(np.float32, copy=False)
raise ValueError(
f"Array '{name}' in {npz_path} has dtype {array.dtype}, "
"expected float32 or integer."
)
def _discover_array_specs(first_npz: Path) -> List[ArraySpec]:
    """Derive the shard dataset layout from a representative NPZ file.

    The NPZ must contain ``ref_dof_pos`` (at least 2-D),
    ``ref_global_translation`` (at least 2-D, last axis 3) and
    ``ref_global_rotation_quat`` (at least 2-D, last axis 4). The resulting
    specs are float32 ``ref_dof_pos`` (trailing shape taken from the file),
    ``ref_root_pos`` ``(3,)`` and ``ref_root_rot`` ``(4,)``.

    Raises:
        KeyError: a required array is missing.
        ValueError: a required array has an unexpected rank or last axis.
    """
    required_keys = (
        "ref_dof_pos",
        "ref_global_translation",
        "ref_global_rotation_quat",
    )
    with np.load(first_npz, allow_pickle=False) as npz:
        for key in required_keys:
            if key not in npz:
                raise KeyError(f"'{key}' missing in NPZ: {first_npz}")
        dof_pos, global_trans, global_rot = (npz[key] for key in required_keys)
    if dof_pos.ndim < 2:
        raise ValueError(f"'ref_dof_pos' must be (T, ndof) in {first_npz}")
    for label, candidate, width in (
        ("ref_global_translation", global_trans, 3),
        ("ref_global_rotation_quat", global_rot, 4),
    ):
        if candidate.ndim < 2 or candidate.shape[-1] != width:
            raise ValueError(
                f"'{label}' must end with {width} in {first_npz}"
            )
    per_frame_dof_shape = tuple(dof_pos.shape[1:])
    return [
        ArraySpec(
            name="ref_dof_pos", shape_tail=per_frame_dof_shape, dtype=np.float32
        ),
        ArraySpec(name="ref_root_pos", shape_tail=(3,), dtype=np.float32),
        ArraySpec(name="ref_root_rot", shape_tail=(4,), dtype=np.float32),
    ]
def _load_npz_arrays(
    npz_path: Path,
    num_frames: int,
    dof_tail: Tuple[int, ...],
) -> Dict[str, np.ndarray]:
    """Load one clip's packed arrays and validate their shapes.

    ``ref_dof_pos``, ``ref_global_translation`` and
    ``ref_global_rotation_quat`` are read and converted via ``_cast_array``.
    A 2-D global array is used as-is; for 3-D or higher, index 0 on axis 1 is
    taken as the root.

    Returns:
        Dict with ``ref_dof_pos`` ``(num_frames, *dof_tail)``,
        ``ref_root_pos`` ``(num_frames, 3)`` and ``ref_root_rot``
        ``(num_frames, 4)``.

    Raises:
        ValueError: a global array has fewer than 2 dims, or a result shape
            differs from the expected one.
    """
    source_keys = (
        "ref_dof_pos",
        "ref_global_translation",
        "ref_global_rotation_quat",
    )
    with np.load(npz_path, allow_pickle=False) as npz:
        dof_pos, global_trans, global_rot = (
            _cast_array(npz[key], key, npz_path) for key in source_keys
        )

    def _take_root(array: np.ndarray, label: str, width: int) -> np.ndarray:
        # (T, C) is already per-root; (T, B, C) keeps body index 0.
        if array.ndim < 2:
            raise ValueError(
                f"{label} must be (T,{width}) or (T,B,{width}) in {npz_path}"
            )
        return array if array.ndim == 2 else array[:, 0, :]

    root_pos = _take_root(global_trans, "ref_global_translation", 3)
    root_rot = _take_root(global_rot, "ref_global_rotation_quat", 4)

    expectations = (
        ("ref_dof_pos", dof_pos, (num_frames, *dof_tail)),
        ("ref_root_pos", root_pos, (num_frames, 3)),
        ("ref_root_rot", root_rot, (num_frames, 4)),
    )
    for label, array, expected_shape in expectations:
        if array.shape != expected_shape:
            raise ValueError(
                f"{label} shape {array.shape} does not match {expected_shape} "
                f"in {npz_path}"
            )
    return {
        "ref_dof_pos": dof_pos,
        "ref_root_pos": root_pos,
        "ref_root_rot": root_rot,
    }
def _relative_npz_path(npz_path: Path, roots: List[str]) -> str:
npz_path = npz_path.expanduser().resolve()
for root in roots:
root_path = Path(root).expanduser().resolve()
try:
rel = npz_path.relative_to(root_path)
except ValueError:
continue
return str(Path(root_path.name) / rel)
return str(npz_path)
def _nan_array_names(arrays: Dict[str, np.ndarray]) -> List[str]:
nan_names: List[str] = []
for name, array in arrays.items():
if not np.issubdtype(array.dtype, np.floating):
continue
if np.isnan(array).any():
nan_names.append(name)
return nan_names
return []
def _estimate_bytes_for_motion(npz_path: Path, mode: str) -> int:
"""Estimate per-clip byte contribution for shard sizing.
Note:
- ``uncompressed_nbytes`` matches the in-memory float32 payload size and does
*not* correspond to on-disk shard size when compression is enabled.
- ``npz_filesize`` uses the compressed input file size as a cheap proxy for
on-disk shard size.
- ``h5_filesize`` does not use this estimator (it measures actual shard size
after writes).
"""
mode_norm = str(mode).lower().strip()
if mode_norm in ("uncompressed_nbytes", "nbytes", "uncompressed"):
with np.load(npz_path, allow_pickle=False) as data:
total = 0
for key in (
"ref_dof_pos",
"ref_global_translation",
"ref_global_rotation_quat",
):
if key in data:
total += int(data[key].nbytes)
return total
if mode_norm in ("npz_filesize", "npz_size", "npz_bytes"):
return int(npz_path.stat().st_size)
raise ValueError(
f"Unsupported shard_target_mode '{mode}'. Expected one of: "
"uncompressed_nbytes | npz_filesize | h5_filesize"
)
def _cfg_value(cfg: Any, key: str, default: Any) -> Any:
    """Return ``cfg[key]``, or ``default`` when the key is absent or null."""
    value = cfg.get(key, None)
    return default if value is None else value


def _shard_record(writer: Hdf5ShardWriter, hdf5_root: str) -> Dict[str, Any]:
    """Finalize ``writer`` and build its manifest entry (path relative to root)."""
    summary = writer.finalize()
    return {
        "file": os.path.relpath(summary["file"], hdf5_root),
        "num_clips": summary["num_clips"],
        "num_frames": summary["num_frames"],
    }


@hydra.main(
    config_path="../../config",
    config_name="motion_retargeting/pack_hdf5_v2",
    version_base=None,
)
def main(cfg: OmegaConf) -> None:
    """Pack retargeted NPZ motion clips into sharded HDF5 files.

    Input roots come from ``holomotion_npz_root``, falling back to
    ``holomotion_retargeted_dirs`` and then ``precomputed_npz_root``.
    Outputs under ``hdf5_root``:
        - ``shards/holomotion_XXX.h5``: clips concatenated along time,
        - ``manifest.json``: shard list, per-clip index and naming info,
        - ``nan_npz_paths.json``: inputs skipped because they contain NaN.

    Shards are rolled over either by measured file size
    (``shard_target_mode`` in h5_filesize/h5_size/output_filesize/disk) or by
    ``_estimate_bytes_for_motion`` for the other modes. Numeric/path options
    set to ``null`` fall back to their defaults.
    """
    roots = _normalize_root_list(cfg.get("holomotion_npz_root", None))
    if len(roots) == 0:
        roots = _normalize_root_list(
            cfg.get("holomotion_retargeted_dirs", None)
        )
    if len(roots) == 0:
        roots = _normalize_root_list(cfg.get("precomputed_npz_root", None))
    if len(roots) == 0:
        raise ValueError("holomotion_npz_root must be provided.")

    hdf5_root = str(
        _cfg_value(
            cfg,
            "hdf5_root",
            os.path.join(os.getcwd(), "holomotion_hdf5_v2"),
        )
    )
    chunks_t = int(_cfg_value(cfg, "chunks_t", 1024))
    # A null compression intentionally maps to "none" (no compression).
    compression = str(cfg.get("compression", "lzf")).lower()
    shard_target_gb = float(_cfg_value(cfg, "shard_target_gb", 2.0))
    shard_target_bytes = int(
        _cfg_value(cfg, "shard_target_bytes", shard_target_gb * (1 << 30))
    )
    shard_target_mode = str(
        _cfg_value(cfg, "shard_target_mode", "uncompressed_nbytes")
    )
    h5_size_modes = ("h5_filesize", "h5_size", "output_filesize", "disk")
    size_by_h5_file = shard_target_mode.lower().strip() in h5_size_modes

    for root in roots:
        if not os.path.isdir(root):
            raise FileNotFoundError(f"NPZ clips directory not found: {root}")

    entries = _discover_motion_entries(roots)
    array_specs = _discover_array_specs(Path(entries[0].path))
    array_names_created = [s.name for s in array_specs]
    dof_tail = next(
        spec.shape_tail for spec in array_specs if spec.name == "ref_dof_pos"
    )
    logger.info(
        "HDF5 v2 datasets: {} (dof_tail={})",
        array_names_created,
        dof_tail,
    )

    dof_names: List[str] = []
    body_names: List[str] = []
    extended_body_names: List[str] = []
    robot_cfg = cfg.get("robot", None)
    if robot_cfg is not None and "motion" in robot_cfg:
        motion_cfg = robot_cfg["motion"]
        dof_names = list(motion_cfg.get("dof_names", []))
        body_names = list(motion_cfg.get("body_names", []))
        extended_body_names = list(motion_cfg.get("body_names", [])) + [
            i.get("joint_name") for i in motion_cfg.get("extend_config", [])
        ]

    shard_dir = os.path.join(hdf5_root, "shards")
    _ensure_dir(shard_dir)

    def _open_writer(shard_idx: int) -> Hdf5ShardWriter:
        # One file per shard index, e.g. shards/holomotion_000.h5.
        return Hdf5ShardWriter(
            os.path.join(shard_dir, f"holomotion_{shard_idx:03d}.h5"),
            array_specs,
            chunks_t=chunks_t,
            compression=compression,
        )

    motion_keys: List[str] = []
    motion_key2id: Dict[str, int] = {}
    nan_npz_paths: List[str] = []
    hdf5_shards: List[Dict[str, Any]] = []
    clips_manifest: Dict[str, Dict[str, Any]] = {}
    curr_shard_idx = 0
    curr_shard_bytes = 0
    writer: Optional[Hdf5ShardWriter] = None
    pbar = tqdm(total=len(entries), desc="Packing HDF5 v2 shards")
    try:
        for entry in entries:
            npz_path = Path(entry.path)
            metadata_json, metadata = _load_metadata_json(npz_path)
            if "num_frames" not in metadata:
                raise KeyError(f"'num_frames' missing in metadata: {npz_path}")
            num_frames = int(metadata["num_frames"])
            if num_frames <= 0:
                raise ValueError(
                    f"Invalid num_frames {num_frames} in {npz_path}"
                )
            arrays_np = _load_npz_arrays(
                npz_path=npz_path, num_frames=num_frames, dof_tail=dof_tail
            )
            nan_arrays = _nan_array_names(arrays_np)
            if len(nan_arrays) > 0:
                nan_npz_paths.append(_relative_npz_path(npz_path, roots))
                logger.warning(
                    "NaN detected in NPZ (arrays: {}), skipping: {}",
                    nan_arrays,
                    npz_path,
                )
                pbar.update(1)
                continue

            if size_by_h5_file:
                if writer is None:
                    writer = _open_writer(curr_shard_idx)
            else:
                motion_bytes = _estimate_bytes_for_motion(
                    npz_path, shard_target_mode
                )
                if (
                    writer is None
                    or (curr_shard_bytes + motion_bytes) > shard_target_bytes
                ):
                    if writer is not None:
                        hdf5_shards.append(_shard_record(writer, hdf5_root))
                        curr_shard_idx += 1
                        curr_shard_bytes = 0
                    writer = _open_writer(curr_shard_idx)

            motion_id = motion_key2id.get(entry.name)
            if motion_id is None:
                motion_id = len(motion_keys)
                motion_key2id[entry.name] = motion_id
                motion_keys.append(entry.name)
            start, length = writer.append_motion(
                motion_id=motion_id,
                np_arrays=arrays_np,
                metadata_json=metadata_json,
            )
            clips_manifest[entry.name] = {
                "motion_key": entry.name,
                "shard": curr_shard_idx,
                "clip_idx": len(writer._clip_starts) - 1,
                "start": int(start),
                "length": int(length),
                "available_arrays": list(array_names_created),
                "metadata": metadata,
            }
            if size_by_h5_file:
                writer.h5.flush()
                curr_shard_bytes = int(os.path.getsize(writer.h5_path))
            else:
                curr_shard_bytes += motion_bytes
            pbar.update(1)

            # Measured-size mode: roll over once the shard reaches the target.
            if size_by_h5_file and curr_shard_bytes >= shard_target_bytes:
                hdf5_shards.append(_shard_record(writer, hdf5_root))
                curr_shard_idx += 1
                curr_shard_bytes = 0
                writer = None
    finally:
        pbar.close()

    if writer is not None:
        hdf5_shards.append(_shard_record(writer, hdf5_root))

    manifest = {
        "version": 1,
        "root": hdf5_root,
        "hdf5_shards": hdf5_shards,
        "clips": clips_manifest,
        "motion_keys": motion_keys,
        "dof_names": dof_names,
        "body_names": body_names,
        "extended_body_names": extended_body_names,
        "array_names": array_names_created,
        "chunks_t": int(chunks_t),
        "compression": compression,
        "shard_target_mode": str(shard_target_mode),
        "shard_target_bytes": int(shard_target_bytes),
    }
    _ensure_dir(hdf5_root)
    nan_paths_path = os.path.join(hdf5_root, "nan_npz_paths.json")
    with open(nan_paths_path, "w") as f:
        json.dump(nan_npz_paths, f, indent=2)
    if len(nan_npz_paths) > 0:
        logger.warning(
            "Skipped {} NPZ files due to NaNs. List: {}",
            len(nan_npz_paths),
            nan_paths_path,
        )
    else:
        logger.info("No NaN detected in NPZ inputs.")
    with open(os.path.join(hdf5_root, "manifest.json"), "w") as f:
        json.dump(manifest, f, indent=2)
    logger.info(
        "HDF5 v2 packing complete. Shards: {}. Root: {}",
        len(hdf5_shards),
        hdf5_root,
    )


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/motion_retargeting/reference_filtering.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Dict, Mapping, Tuple
import numpy as np
# This module keeps the offline preprocess filtering path and the online
# root/DoF-before-FK path aligned while still exposing helpers tailored to
# each tensor family.
def _reshape_time_flat(a: np.ndarray) -> Tuple[np.ndarray, Tuple[int, ...]]:
shape = a.shape
t = shape[0]
return a.reshape(t, -1), shape
def _butterworth_lowpass_smooth_time(
a: np.ndarray, fps: float, cutoff_hz: float, order: int
) -> np.ndarray:
from scipy.signal import butter, filtfilt
t = a.shape[0]
if t < 3:
return a.astype(np.float32, copy=True)
if fps <= 0.0 or cutoff_hz <= 0.0:
return a.astype(np.float32, copy=True)
nyquist = 0.5 * float(fps)
wn = float(cutoff_hz) / nyquist
if wn >= 1.0:
wn = 0.999
if wn <= 0.0:
return a.astype(np.float32, copy=True)
flat, shape = _reshape_time_flat(a.astype(np.float64, copy=False))
b, a_coefs = butter(int(order), wn, btype="low", analog=False)
maxlen = max(len(b), len(a_coefs))
padlen_required = max(3 * (maxlen - 1), 3 * maxlen)
if t <= padlen_required:
return a.astype(np.float32, copy=True)
filtered = filtfilt(b, a_coefs, flat, axis=0, method="pad")
return filtered.reshape(shape).astype(np.float32, copy=False)
def _quat_normalize(q: np.ndarray) -> np.ndarray:
norm = np.linalg.norm(q, axis=-1, keepdims=True)
norm = np.where(norm == 0.0, 1.0, norm)
return (q / norm).astype(np.float32, copy=False)
def _quat_hemisphere_align(q: np.ndarray) -> np.ndarray:
if q.shape[0] == 0:
return q
aligned = q.copy()
prev = aligned[0]
for t in range(1, aligned.shape[0]):
dots = np.sum(prev * aligned[t], axis=-1)
mask = dots < 0.0
if np.any(mask):
aligned[t, mask] = -aligned[t, mask]
prev = aligned[t]
return aligned
def _quat_conjugate(q: np.ndarray) -> np.ndarray:
conj = q.copy()
conj[..., :3] = -conj[..., :3]
return conj
def _quat_multiply(a: np.ndarray, b: np.ndarray) -> np.ndarray:
av = a[..., :3]
aw = a[..., 3:4]
bv = b[..., :3]
bw = b[..., 3:4]
cross = np.cross(av, bv)
vec = aw * bv + bw * av + cross
scalar = aw * bw - np.sum(av * bv, axis=-1, keepdims=True)
return np.concatenate([vec, scalar], axis=-1)
def _finite_difference_time(a: np.ndarray, dt: float) -> np.ndarray:
t = a.shape[0]
if t < 2 or dt <= 0.0:
return np.zeros_like(a, dtype=np.float32)
deriv = np.gradient(
a.astype(np.float64, copy=False),
dt,
axis=0,
edge_order=2 if t >= 3 else 1,
)
return deriv.astype(np.float32, copy=False)
def _angular_velocity_from_quat(
    q: np.ndarray, q_dot: np.ndarray
) -> np.ndarray:
    """Angular velocity ``2 * vec(conj(q) * q_dot)`` from xyzw quaternions.

    Args:
        q: Quaternions (x, y, z, w) along the last axis.
        q_dot: Their time derivative, same shape as ``q``.

    Returns:
        float32 array with the last axis reduced to 3.

    NOTE(review): for unit Hamilton quaternions that rotate body-frame
    vectors into the world frame, this product yields the *body*-frame
    angular velocity; the world-frame one would be
    ``2 * vec(q_dot * conj(q))``. Confirm which frame callers expect.
    """
    product = _quat_multiply(_quat_conjugate(q), q_dot)
    return (2.0 * product[..., :3]).astype(np.float32, copy=False)
def butterworth_filter_ref_arrays(
    arrays: Mapping[str, np.ndarray],
    fps: float,
    cutoff_hz: float,
    order: int,
) -> Dict[str, np.ndarray]:
    """Low-pass filter per-body reference arrays and derive their velocities.

    For each input key present in ``arrays``:
        - ``ref_dof_pos`` -> ``ft_ref_dof_pos``, ``ft_ref_dof_vel``
        - ``ref_global_translation`` -> ``ft_ref_global_translation``,
          ``ft_ref_global_velocity``
        - ``ref_global_rotation_quat`` -> ``ft_ref_global_rotation_quat``,
          ``ft_ref_global_angular_velocity``

    Quaternions are normalized and sign-aligned across time before filtering
    and re-normalized afterwards. Velocities are finite differences of the
    filtered signals with ``dt = 1 / fps``; ``fps <= 0`` gives ``dt = 0``,
    which yields zero velocities.
    """
    dt = 1.0 / float(fps) if float(fps) > 0.0 else 0.0

    def _smooth(signal: np.ndarray) -> np.ndarray:
        return _butterworth_lowpass_smooth_time(signal, fps, cutoff_hz, order)

    filtered: Dict[str, np.ndarray] = {}

    if "ref_dof_pos" in arrays:
        dof = _smooth(arrays["ref_dof_pos"].astype(np.float32, copy=True))
        filtered["ft_ref_dof_pos"] = dof
        filtered["ft_ref_dof_vel"] = _finite_difference_time(dof, dt)

    if "ref_global_translation" in arrays:
        translation = _smooth(
            arrays["ref_global_translation"].astype(np.float32, copy=True)
        )
        filtered["ft_ref_global_translation"] = translation
        filtered["ft_ref_global_velocity"] = _finite_difference_time(
            translation, dt
        )

    if "ref_global_rotation_quat" in arrays:
        rotation = arrays["ref_global_rotation_quat"].astype(
            np.float32, copy=True
        )
        # Remove sign flips first so the filter sees a continuous signal.
        rotation = _quat_hemisphere_align(_quat_normalize(rotation))
        rotation = _quat_normalize(_smooth(rotation))
        rotation_rate = _finite_difference_time(rotation, dt)
        filtered["ft_ref_global_rotation_quat"] = _quat_normalize(rotation)
        filtered["ft_ref_global_angular_velocity"] = (
            _angular_velocity_from_quat(rotation, rotation_rate)
        )

    return filtered
def butterworth_filter_root_dof_arrays(
    arrays: Mapping[str, np.ndarray],
    fps: float,
    cutoff_hz: float,
    order: int,
) -> Dict[str, np.ndarray]:
    """Low-pass filter root pose and DoF positions (no velocities).

    Produces ``ft_ref_root_pos``, ``ft_ref_root_rot`` and ``ft_ref_dof_pos``
    for whichever of ``ref_root_pos``, ``ref_root_rot`` and ``ref_dof_pos``
    are present. Root quaternions are normalized and sign-aligned across time
    before filtering and re-normalized afterwards.
    """

    def _smooth(signal: np.ndarray) -> np.ndarray:
        return _butterworth_lowpass_smooth_time(signal, fps, cutoff_hz, order)

    filtered: Dict[str, np.ndarray] = {}

    if "ref_root_pos" in arrays:
        filtered["ft_ref_root_pos"] = _smooth(
            arrays["ref_root_pos"].astype(np.float32, copy=True)
        )

    if "ref_root_rot" in arrays:
        root_quat = arrays["ref_root_rot"].astype(np.float32, copy=True)
        root_quat = _quat_hemisphere_align(_quat_normalize(root_quat))
        filtered["ft_ref_root_rot"] = _quat_normalize(_smooth(root_quat))

    if "ref_dof_pos" in arrays:
        filtered["ft_ref_dof_pos"] = _smooth(
            arrays["ref_dof_pos"].astype(np.float32, copy=True)
        )

    return filtered
================================================
FILE: holomotion/src/motion_retargeting/utils/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/motion_retargeting/utils/_schema.json
================================================
{
"schema": {
"root_trans_offset": {
"shape": [
682,
3
],
"dtype": "float64"
},
"pose_aa": {
"shape": [
682,
27,
3
],
"dtype": "float32"
},
"dof": {
"shape": [
682,
23
],
"dtype": "float32"
},
"root_rot": {
"shape": [
682,
4
],
"dtype": "float64"
},
"smpl_joints": {
"shape": [
682,
24,
3
],
"dtype": "float32"
},
"fps": {
"shape": [],
"dtype": "int64"
}
},
"sample_top_key": "2024-12-28 16.03.15-视频-2025年主播必学120支热门小舞蹈-帅帅的《电话卡点舞》 #电话卡点舞 #...舞 #抖音热歌 #舞蹈教学 #网红必学+p02_1_btws_pad"
}
================================================
FILE: holomotion/src/motion_retargeting/utils/rotation_conversions.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union
import torch
import torch.nn.functional as F
def wxyz_to_xyzw(quat):
    """Reorder quaternions from (w, x, y, z) to (x, y, z, w) on the last axis.

    Works for any array type that supports integer-list indexing
    (torch tensors, numpy arrays).
    """
    scalar_last_order = [1, 2, 3, 0]
    return quat[..., scalar_last_order]
def xyzw_to_wxyz(quat):
    """Reorder quaternions from (x, y, z, w) to (w, x, y, z) on the last axis.

    Works for any array type that supports integer-list indexing
    (torch tensors, numpy arrays).
    """
    scalar_first_order = [3, 0, 1, 2]
    return quat[..., scalar_first_order]
# Type alias for ``device`` arguments: a device string (e.g. "cpu") or a
# ``torch.device`` instance.
Device = Union[str, torch.device]
"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as
R = [
[Rxx, Rxy, Rxz],
[Ryx, Ryy, Ryz],
[Rzx, Rzy, Rzz],
] # (3, 3)
This matrix can be applied to column vectors by post multiplication
by the points e.g.
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
transformed_points = R * points
To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:
e.g.
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
transformed_points = points * R.transpose(1, 0)
"""
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert (w, x, y, z) quaternions to rotation matrices.

    The quaternions need not be unit length: the formula divides by the
    squared norm.

    Args:
        quaternions: Tensor of shape (..., 4) with the real part first.

    Returns:
        Rotation matrices of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)
    row_major_entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(row_major_entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Return a tensor of absolute value.
Return a tensor where each element has the absolute value taken from
the corresponding element of a, with sign taken from the corresponding
element of b. This is like the standard copysign floating-point operation,
but is not careful about negative 0 and NaN.
Args:
a: source tensor.
b: tensor whose signs will be used, of the same shape as a.
Returns:
Tensor of the same shape as a with the signs of b.
"""
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""Returns torch.sqrt(torch.max(0, x)).
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions in (w, x, y, z) order.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the last two dimensions of ``matrix`` are not 3x3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # For a proper rotation matrix these four terms equal 4w^2, 4x^2, 4y^2,
    # 4z^2, so q_abs is (2|w|, 2|x|, 2|y|, 2|z|); negatives from numerical
    # noise are clamped to 0 by _sqrt_positive_part.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
            ),
            torch.stack(
                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
            ),
            torch.stack(
                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
            ),
            torch.stack(
                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
            ),
        ],
        dim=-2,
    )
    # We floor here at 0.1 but the exact level is not important; if q_abs is
    # small, the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # if not for numerical problems, quat_candidates[i] should be same
    # (up to a sign), forall i; we pick the best-conditioned one
    # (with the largest denominator)
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :,  # pyre-ignore[16]
    ].reshape(batch_dim + (4,))
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
"""Return the rotation matrices for one of the rotations about an axis.
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
r_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
elif axis == "Y":
r_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
elif axis == "Z":
r_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError("letter must be either X, Y or Z.")
return torch.stack(r_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(
    euler_angles: torch.Tensor, convention: str
) -> torch.Tensor:
    """Convert Euler angles in radians to rotation matrices.

    The result is ``R(c0, a0) @ R(c1, a1) @ R(c2, a2)``, where ``c`` is the
    convention string and ``a`` the angles along the last axis.

    Args:
        euler_angles: Tensor of shape (..., 3).
        convention: Three uppercase letters from {"X", "Y", "Z"}; the middle
            letter must differ from both neighbours.

    Returns:
        Rotation matrices of shape (..., 3, 3).

    Raises:
        ValueError: on malformed angles or convention.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    invalid = [letter for letter in convention if letter not in ("X", "Y", "Z")]
    if invalid:
        raise ValueError(f"Invalid letter {invalid[0]} in convention string.")
    first, second, third = (
        _axis_angle_rotation(axis_label, angle)
        for axis_label, angle in zip(
            convention, torch.unbind(euler_angles, -1), strict=False
        )
    )
    return torch.matmul(torch.matmul(first, second), third)
def _angle_from_tan(
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
"""Extract the first or third Euler angle from the two members of.
the matrix which are positive constant times its sine and cosine.
Args:
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
convention.
data: Rotation matrices as tensor of shape (..., 3, 3).
horizontal: Whether we are looking for the angle for the third axis,
which means the relevant entries are in the same row of the
rotation matrix. If not, they are in the same column.
tait_bryan: Whether the first and third axes in the convention differ.
Returns:
Euler Angles in radians for each matrix in data as a tensor
of shape (...).
"""
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
if horizontal:
i2, i1 = i1, i2
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
if horizontal == even:
return torch.atan2(data[..., i1], data[..., i2])
if tait_bryan:
return torch.atan2(-data[..., i2], data[..., i1])
return torch.atan2(data[..., i2], -data[..., i1])
def _index_from_letter(letter: str) -> int:
if letter == "X":
return 0
if letter == "Y":
return 1
if letter == "Z":
return 2
raise ValueError("letter must be either X, Y or Z.")
def matrix_to_euler_angles(
    matrix: torch.Tensor, convention: str
) -> torch.Tensor:
    """Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", "Z"}; the middle letter must differ from both ends.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: If the convention string or the matrix shape is invalid.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    # Round-off can push a matrix entry of a valid rotation slightly past
    # +/-1, where asin/acos return NaN; clamp into their domain first.
    if tait_bryan:
        sign = -1.0 if i0 - i2 in [-1, 2] else 1.0
        central_angle = torch.asin(
            torch.clamp(matrix[..., i0, i2] * sign, -1.0, 1.0)
        )
    else:
        central_angle = torch.acos(
            torch.clamp(matrix[..., i0, i0], -1.0, 1.0)
        )
    # The first angle is read from column i2, the third from row i0.
    o = (
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2], False, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
        ),
    )
    return torch.stack(o, -1)
def random_quaternions(
    n: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Sample ``n`` random rotations as unit quaternions (versors).

    Each quaternion has a nonnegative real part. It is built by normalizing
    a standard-normal 4-vector.

    Args:
        n: Number of quaternions in the batch.
        dtype: Dtype of the returned tensor.
        device: Desired device of returned tensor (a string is converted to
            ``torch.device``). Default: uses the current device for the
            default tensor type.

    Returns:
        Quaternions, real part first, as tensor of shape (n, 4).
    """
    target = torch.device(device) if isinstance(device, str) else device
    gaussian = torch.randn((n, 4), dtype=dtype, device=target)
    norm = torch.sqrt((gaussian * gaussian).sum(1))
    # Presumably _copysign(a, b) gives |a| with the sign of b. Dividing by
    # it both normalizes and flips each sample so that w >= 0.
    signed_norm = _copysign(norm, gaussian[:, 0])
    return gaussian / signed_norm[:, None]
def random_rotations(
    n: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Device] = None,
) -> torch.Tensor:
    """Sample ``n`` random rotations as 3x3 rotation matrices.

    Random unit quaternions are drawn with ``random_quaternions`` and then
    converted with ``quaternion_to_matrix``.

    Args:
        n: Number of rotation matrices in the batch.
        dtype: Dtype of the returned tensor.
        device: Device of returned tensor. Default: if None, uses the
            current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    return quaternion_to_matrix(
        random_quaternions(n, dtype=dtype, device=device)
    )
def random_rotation(
    dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """Sample one random 3x3 rotation matrix.

    Args:
        dtype: Dtype of the returned tensor.
        device: Device of returned tensor. Default: if None, uses the
            current device for the default tensor type.

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    single_batch = random_rotations(1, dtype, device)
    return single_batch[0]
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """Return the representative of each quaternion with real part >= 0.

    ``q`` and ``-q`` describe the same rotation. Quaternions whose real
    part is negative are negated; all others are returned unchanged.

    Args:
        quaternions: Quaternions with real part first, as tensor of shape
            (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    has_negative_real = quaternions[..., :1] < 0
    return torch.where(has_negative_real, -quaternions, quaternions)
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hamilton product ``a * b`` of two quaternion tensors.

    No normalization or sign standardization is applied. The usual torch
    broadcasting rules apply between ``a`` and ``b``.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    w1, x1, y1, z1 = a.unbind(-1)
    w2, x2, y2, z2 = b.unbind(-1)
    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack([w, x, y, z], dim=-1)
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compose two rotations given as quaternions.

    Computes the Hamilton product ``a * b`` and then standardizes it so the
    result has a nonnegative real part. The usual torch broadcasting rules
    apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(quaternion_raw_multiply(a, b))
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
    """Get the quaternion representing the inverse rotation.

    For unit quaternions the inverse equals the conjugate, obtained by
    negating the vector part.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4), with the
        same dtype and device as ``quaternion``.
    """
    # Build the sign pattern in the input's dtype/device instead of a
    # default int64 tensor.
    scaling = quaternion.new_tensor([1, -1, -1, -1])
    return quaternion * scaling
def quaternion_apply(
    quaternion: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
    """Rotate 3D points by the rotations given as quaternions.

    Each point is embedded as a pure quaternion ``(0, p)`` and conjugated:
    ``q * p * q^-1``. The usual torch broadcasting rules apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: If the last dimension of ``point`` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")
    zero_real = point.new_zeros(point.shape[:-1] + (1,))
    pure_quaternion = torch.cat((zero_real, point), -1)
    left = quaternion_raw_multiply(quaternion, pure_quaternion)
    rotated = quaternion_raw_multiply(left, quaternion_invert(quaternion))
    # Drop the (zero) real part.
    return rotated[..., 1:]
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert axis/angle rotations to rotation matrices via quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    as_quaternion = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(as_quaternion)
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to axis/angle vectors via quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    as_quaternion = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(as_quaternion)
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert axis/angle rotations to quaternions (real part first).

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    near_zero = angles.abs() < 1e-6
    # Use a dummy denominator of 1 for near-zero angles so the exact
    # branch never divides by ~0; those entries are taken from the Taylor
    # branch below instead.
    safe_angles = torch.where(near_zero, torch.ones_like(angles), angles)
    exact_ratio = torch.sin(half_angles) / safe_angles
    # For small x, sin(x/2)/x ~= 1/2 - x^2/48 (Taylor expansion).
    taylor_ratio = 0.5 - (angles * angles) / 48
    ratio = torch.where(near_zero, taylor_ratio, exact_ratio)
    return torch.cat(
        [torch.cos(half_angles), axis_angle * ratio],
        dim=-1,
    )
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert quaternions (real part first) to axis/angle vectors.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    vector_part = quaternions[..., 1:]
    vector_norm = torch.norm(vector_part, p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(vector_norm, quaternions[..., :1])
    angles = 2 * half_angles
    near_zero = angles.abs() < 1e-6
    # Dummy denominator of 1 for near-zero angles; those entries come from
    # the Taylor branch below.
    safe_angles = torch.where(near_zero, torch.ones_like(angles), angles)
    exact_ratio = torch.sin(half_angles) / safe_angles
    # For small x, sin(x/2)/x ~= 1/2 - x^2/48 (Taylor expansion).
    taylor_ratio = 0.5 - (angles * angles) / 48
    ratio = torch.where(near_zero, taylor_ratio, exact_ratio)
    return vector_part / ratio
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert the 6D rotation representation of Zhou et al. [1] to matrices.

    The two 3-vectors are orthonormalized with Gram--Schmidt (Section B of
    [1]). The third row is their cross product.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    first, second = d6[..., :3], d6[..., 3:]
    row_x = F.normalize(first, dim=-1)
    # Remove the component of `second` along row_x, then renormalize.
    overlap = (row_x * second).sum(-1, keepdim=True)
    row_y = F.normalize(second - overlap * row_x, dim=-1)
    row_z = torch.cross(row_x, row_y, dim=-1)
    return torch.stack((row_x, row_y, row_z), dim=-2)
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to the 6D representation of Zhou et al. [1].

    The last row of each matrix is dropped and the first two rows are
    flattened. Note that the 6D representation is not unique.

    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    leading_shape = matrix.shape[:-2]
    top_two_rows = matrix[..., :2, :].clone()
    return top_two_rows.reshape(*leading_shape, 6)
================================================
FILE: holomotion/src/motion_retargeting/utils/torch_humanoid_batch.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally copied from the [PHC] repository:
# https://github.com/ZhengyiLuo/PHC
# Modifications have been made to fit the needs of this project.
#
import os
import os.path as osp
import sys
# Make modules under the current working directory importable. The result
# depends on where the process is launched from.
sys.path.append(os.getcwd())
import copy
import logging
import xml.etree.ElementTree as ETree
from collections import OrderedDict, defaultdict
from io import BytesIO
import numpy as np
import open3d as o3d
import scipy.ndimage.filters as filters
import smpl_sim.poselib.core.rotation3d as poselib_rotation3d
import smpl_sim.utils.rotation_conversions as torch_rotation_conversions
import torch
from easydict import EasyDict
from lxml.etree import XMLParser, parse
from omegaconf import DictConfig
from scipy.spatial.transform import Rotation as sRot
from tqdm import tqdm
# from loguru import logger
# Configure the root logger at DEBUG level when this module is imported.
# This is a process-wide side effect. Per the stdlib docs, basicConfig does
# nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
)
class HumanoidBatch:
    """Batched forward kinematics and mesh posing for an MJCF humanoid.

    The kinematic tree (body names, parent indices, local offsets and
    rotations, joint ranges) is parsed from ``cfg.asset.assetFileName``.
    Extra bodies listed in ``cfg.extend_config`` are appended to the tree.
    Meshes referenced by the MJCF are loaded with Open3D for
    :meth:`mesh_fk`.
    """

    def __init__(self, cfg, device=None):
        """Parse the MJCF file and build the (optionally extended) tree.

        Args:
            cfg: Robot config. Reads ``cfg.asset.assetFileName`` and
                ``cfg.extend_config`` (entries with ``joint_name``,
                ``parent_name``, ``pos`` and ``rot``).
            device: Torch device for offsets, local rotations and joint
                ranges; defaults to CPU. ``self._parents`` stays on the
                device it was created on (CPU).

        Raises:
            AssertionError: If the MJCF declares no actuators.
            ValueError: If a joint axis is not exactly X, Y or Z.
        """
        if device is None:
            device = torch.device("cpu")
        self.cfg = cfg
        self.mjcf_file = cfg.asset.assetFileName
        parser = XMLParser(remove_blank_text=True)
        # Read through a context manager so the file handle is closed.
        with open(self.mjcf_file, "rb") as mjcf_fp:
            tree = parse(BytesIO(mjcf_fp.read()), parser=parser)
        xml_root = tree.getroot()
        all_joints = xml_root.find("worldbody").findall(".//joint")
        joints = sorted(j.attrib["name"] for j in all_joints)
        # lxml also yields comments / processing instructions as children.
        # Their ``tag`` is not a string and they have no "name" attribute.
        motors = sorted(
            m.attrib["name"]
            for m in xml_root.find("actuator")
            if isinstance(m.tag, str)
        )
        assert len(motors) > 0, "No motors found in the mjcf file"
        self.num_dof = len(motors)
        self.num_extend_dof = self.num_dof
        self.mjcf_data = mjcf_data = self.from_mjcf(self.mjcf_file)
        self.body_names = copy.deepcopy(mjcf_data["node_names"])
        self._parents = mjcf_data["parent_indices"]
        self.body_names_augment = copy.deepcopy(mjcf_data["node_names"])
        self._proper_kinematic_structure = copy.deepcopy(
            mjcf_data["node_names"]
        )
        self._offsets = mjcf_data["local_translation"][None,].to(device)
        self._local_rotation = mjcf_data["local_rotation"][None,].to(device)
        self.actuated_joints_idx = np.array(
            [self.body_names.index(k) for k in mjcf_data["body_to_joint"]]
        )
        for m in motors:
            if m not in joints:
                logging.warning(
                    "Motor %s has no joint of the same name in %s",
                    m,
                    self.mjcf_file,
                )

        first_joint_attrib = all_joints[0].attrib
        if first_joint_attrib.get("type") == "free":
            # Free-floating root: skip the free joint itself.
            axis_joints = all_joints[1:]
            self.has_freejoint = True
        elif "type" not in first_joint_attrib:
            axis_joints = all_joints
            self.has_freejoint = True
        else:
            # Root presumably modeled by 6 explicit joints; skip them.
            axis_joints = all_joints[6:]
            self.has_freejoint = False
        # Axes are written as whitespace-separated integers, e.g. "0 1 0".
        self.dof_axis = [
            [int(v) for v in j.attrib["axis"].split()] for j in axis_joints
        ]
        unit_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
        for axis in self.dof_axis:
            if axis not in unit_axes:
                raise ValueError(f"Invalid axis: {axis}")
        self.dof_axis = torch.tensor(self.dof_axis)

        for extend_config in cfg.extend_config:
            self.body_names_augment += [extend_config.joint_name]
            # Create the new parent index on _parents' own device so the
            # concatenation also works when ``device`` is a GPU.
            parent_idx = torch.tensor(
                [self.body_names.index(extend_config.parent_name)]
            ).to(self._parents.device)
            self._parents = torch.cat([self._parents, parent_idx], dim=0)
            self._offsets = torch.cat(
                [
                    self._offsets,
                    torch.tensor([[extend_config.pos]]).to(device),
                ],
                dim=1,
            )
            self._local_rotation = torch.cat(
                [
                    self._local_rotation,
                    torch.tensor([[extend_config.rot]]).to(device),
                ],
                dim=1,
            )
            self.num_extend_dof += 1
            parent_id = self._proper_kinematic_structure.index(
                extend_config.parent_name
            )
            self._proper_kinematic_structure.insert(
                parent_id + 1, extend_config.joint_name
            )
        self.num_bodies = len(self.body_names)
        self.num_bodies_augment = len(self.body_names_augment)
        self.joints_range = mjcf_data["joints_range"].to(device)
        # Local rotations as matrices (quaternions are w, x, y, z).
        self._local_rotation_mat = (
            torch_rotation_conversions.quaternion_to_matrix(
                self._local_rotation
            ).float()
        )
        self.load_mesh()
        # extend_to_proper_mapping[i] is the index in body_names_augment of
        # the i-th body of the proper kinematic ordering; the second list is
        # the inverse lookup.
        self.extend_to_proper_mapping = [
            self.body_names_augment.index(name)
            for name in self._proper_kinematic_structure
        ]
        self.proper_to_extend_mapping = [
            self._proper_kinematic_structure.index(name)
            for name in self.body_names_augment
        ]

    def from_mjcf(self, path):
        """Parse the MJCF body tree into kinematic arrays (from Poselib).

        The tree is walked depth-first from the first ``body`` under
        ``worldbody``. For every body it records the name, the parent index
        (-1 for the root), the local ``pos`` (default "0 0 0") and the
        ``quat`` (default "1 0 0 0"). It also collects joint ranges and a
        body -> joint-name map.

        Args:
            path: Path to the MJCF XML file.

        Returns:
            Dict with ``node_names``, ``parent_indices`` (int32 tensor),
            ``local_translation`` / ``local_rotation`` (float32 tensors),
            ``joints_range`` (tensor) and ``body_to_joint`` (OrderedDict).

        Raises:
            ValueError: If ``worldbody`` or its root ``body`` is missing.
            AssertionError: If the number of joint ranges differs from
                ``self.num_dof``.
        """
        tree = ETree.parse(path)
        xml_doc_root = tree.getroot()
        xml_world_body = xml_doc_root.find("worldbody")
        if xml_world_body is None:
            raise ValueError("MJCF parsed incorrectly please verify it.")
        # assume this is the root
        xml_body_root = xml_world_body.find("body")
        if xml_body_root is None:
            raise ValueError("MJCF parsed incorrectly please verify it.")
        node_names = []
        parent_indices = []
        local_translation = []
        local_rotation = []
        joints_range = []
        body_to_joint = OrderedDict()

        # recursively adding all nodes into the skel_tree
        def _add_xml_node(xml_node, parent_index, node_index):
            node_name = xml_node.attrib.get("name")
            # parse the local translation into float list
            pos = np.fromstring(
                xml_node.attrib.get("pos", "0 0 0"), dtype=float, sep=" "
            )
            quat = np.fromstring(
                xml_node.attrib.get("quat", "1 0 0 0"), dtype=float, sep=" "
            )
            node_names.append(node_name)
            parent_indices.append(parent_index)
            local_translation.append(pos)
            local_rotation.append(quat)
            curr_index = node_index
            node_index += 1
            all_joints = xml_node.findall("joint")
            # A body with exactly 6 joints is presumably the root expressed
            # as 6 explicit DoFs; those joints get no range entry.
            if len(all_joints) == 6:
                all_joints = all_joints[6:]
            for joint in all_joints:
                if joint.attrib.get("range") is not None:
                    joints_range.append(
                        np.fromstring(
                            joint.attrib.get("range"), dtype=float, sep=" "
                        )
                    )
                elif not joint.attrib.get("type") == "free":
                    joints_range.append([-np.pi, np.pi])
            # Only the last joint of a body is kept in the map.
            for joint_node in xml_node.findall("joint"):
                body_to_joint[node_name] = joint_node.attrib.get("name")
            for next_node in xml_node.findall("body"):
                node_index = _add_xml_node(next_node, curr_index, node_index)
            return node_index

        _add_xml_node(xml_body_root, -1, 0)
        assert len(joints_range) == self.num_dof
        return {
            "node_names": node_names,
            "parent_indices": torch.from_numpy(
                np.array(parent_indices, dtype=np.int32)
            ),
            "local_translation": torch.from_numpy(
                np.array(local_translation, dtype=np.float32)
            ),
            "local_rotation": torch.from_numpy(
                np.array(local_rotation, dtype=np.float32)
            ),
            "joints_range": torch.from_numpy(np.array(joints_range)),
            "body_to_joint": body_to_joint,
        }

    def fk_batch(
        self, pose, trans, convert_to_mat=True, return_full=False, dt=1 / 30
    ):
        """Run batched forward kinematics.

        Args:
            pose: Per-body local rotations. With ``convert_to_mat=True``
                this is axis-angle of shape (B, T, J, 3); trailing bodies
                beyond the tree are dropped. Otherwise it holds rotation
                matrices reshapeable to (B, T, J, 3, 3).
            trans: Root translation, shape (B, T, 3).
            convert_to_mat: Whether ``pose`` is axis-angle and must be
                converted to rotation matrices.
            return_full: Also compute velocities, local rotations and DoF
                positions/velocities.
            dt: Time between consecutive frames (``fps = int(1 / dt)``).

        Returns:
            EasyDict with global translations / rotation matrices /
            quaternions (converted to xyzw). If ``return_full`` is set, it
            also holds velocity and DoF entries.
        """
        b, seq_len = pose.shape[:2]
        # H1 fitted joints might have extra joints
        pose = pose[..., : len(self._parents), :]
        pose_quat = None
        if convert_to_mat:
            pose_quat = torch_rotation_conversions.axis_angle_to_quaternion(
                pose.clone()
            )
            pose_mat = torch_rotation_conversions.quaternion_to_matrix(
                pose_quat
            )
        else:
            pose_mat = pose
        # Compare the rank (the old check compared a torch.Size to an int
        # and was therefore always true).
        if pose_mat.dim() != 5:
            pose_mat = pose_mat.reshape(b, seq_len, -1, 3, 3)
        wbody_pos, wbody_mat = self.forward_kinematics_batch(
            pose_mat[:, :, 1:], pose_mat[:, :, 0:1], trans
        )
        return_dict = EasyDict()
        wbody_rot = torch_rotation_conversions.wxyz_to_xyzw(
            torch_rotation_conversions.matrix_to_quaternion(wbody_mat)
        )
        if len(self.cfg.extend_config) > 0:
            if return_full:
                return_dict.global_velocity_extend = self._compute_velocity(
                    wbody_pos, dt
                )
                return_dict.global_angular_velocity_extend = (
                    self._compute_angular_velocity(wbody_rot, dt)
                )
            return_dict.global_translation_extend = wbody_pos.clone()
            return_dict.global_rotation_mat_extend = wbody_mat.clone()
            return_dict.global_rotation_extend = wbody_rot
            wbody_pos = wbody_pos[..., : self.num_bodies, :]
            wbody_mat = wbody_mat[..., : self.num_bodies, :, :]
            wbody_rot = wbody_rot[..., : self.num_bodies, :]
        return_dict.global_translation = wbody_pos
        return_dict.global_rotation_mat = wbody_mat
        return_dict.global_rotation = wbody_rot
        if return_full:
            rigidbody_linear_velocity = self._compute_velocity(wbody_pos, dt)
            # Isaac gym is [x, y, z, w]. All the previous functions are
            # [w, x, y, z]
            rigidbody_angular_velocity = self._compute_angular_velocity(
                wbody_rot, dt
            )
            if pose_quat is None:
                # Matrix input: derive the local quaternions from the
                # matrices (this path previously raised NameError).
                pose_quat = torch_rotation_conversions.matrix_to_quaternion(
                    pose_mat
                )
            return_dict.local_rotation = (
                torch_rotation_conversions.wxyz_to_xyzw(pose_quat)
            )
            return_dict.global_root_velocity = rigidbody_linear_velocity[
                ..., 0, :
            ]
            return_dict.global_root_angular_velocity = (
                rigidbody_angular_velocity[..., 0, :]
            )
            return_dict.global_angular_velocity = rigidbody_angular_velocity
            return_dict.global_velocity = rigidbody_linear_velocity
            # NOTE(review): summing the last axis as a DoF angle assumes
            # axis-angle input with one non-zero component per joint.
            if len(self.cfg.extend_config) > 0:
                return_dict.dof_pos = pose.sum(dim=-1)[
                    ..., 1 : self.num_bodies
                ]
                # you can sum it up since unitree's each joint has 1 dof.
                # Last two are for hands. doesn't really matter.
            else:
                if not len(self.actuated_joints_idx) == len(self.body_names):
                    return_dict.dof_pos = pose.sum(dim=-1)[
                        ..., self.actuated_joints_idx
                    ]
                else:
                    return_dict.dof_pos = pose.sum(dim=-1)[..., 1:]
            dof_vel = (
                return_dict.dof_pos[:, 1:] - return_dict.dof_pos[:, :-1]
            ) / dt
            # Repeat the last finite difference so the length matches T.
            return_dict.dof_vels = torch.cat(
                [dof_vel, dof_vel[:, -2:-1]], dim=1
            )
            return_dict.fps = int(1 / dt)
        return return_dict

    def convert_to_proper_kinematic(self, return_dict):
        """Reorder per-body entries into the proper kinematic ordering.

        With extended bodies, the ``*_extend`` entries are reordered
        (including the velocity entries, so ``fk_batch`` must have been run
        with ``return_full=True``). Otherwise the base entries, including
        velocities, are reordered.

        Args:
            return_dict: Result of :meth:`fk_batch`; modified in place.

        Returns:
            The same ``return_dict``.
        """
        if len(self.cfg.extend_config) > 0:
            return_dict.global_translation_extend = (
                return_dict.global_translation_extend[
                    ..., self.extend_to_proper_mapping, :
                ]
            )
            return_dict.global_rotation_mat_extend = (
                return_dict.global_rotation_mat_extend[
                    ..., self.extend_to_proper_mapping, :, :
                ]
            )
            return_dict.global_rotation_extend = (
                return_dict.global_rotation_extend[
                    ..., self.extend_to_proper_mapping, :
                ]
            )
            return_dict.global_velocity_extend = (
                return_dict.global_velocity_extend[
                    ..., self.extend_to_proper_mapping, :
                ]
            )
            return_dict.global_angular_velocity_extend = (
                return_dict.global_angular_velocity_extend[
                    ..., self.extend_to_proper_mapping, :
                ]
            )
        else:
            return_dict.global_translation = return_dict.global_translation[
                ..., self.extend_to_proper_mapping, :
            ]
            return_dict.global_rotation_mat = return_dict.global_rotation_mat[
                ..., self.extend_to_proper_mapping, :, :
            ]
            return_dict.global_rotation = return_dict.global_rotation[
                ..., self.extend_to_proper_mapping, :
            ]
            return_dict.global_velocity = return_dict.global_velocity[
                ..., self.extend_to_proper_mapping, :
            ]
            return_dict.global_angular_velocity = (
                return_dict.global_angular_velocity[
                    ..., self.extend_to_proper_mapping, :
                ]
            )
        return return_dict

    def forward_kinematics_batch(
        self, rotations, root_rotations, root_positions
    ):
        """Compose local rotations and offsets down the kinematic tree.

        Args:
            rotations: Local rotation matrices of the non-root bodies, shape
                (B, T, >= J - 1, 3, 3); body ``i`` uses entry ``i - 1``.
            root_rotations: Root rotation matrices, shape (B, T, 1, 3, 3).
            root_positions: Root positions, shape (B, T, 3).

        Returns:
            ``(positions_world, rotations_world)`` of shapes (B, T, J, 3)
            and (B, T, J, 3, 3). Here J = ``self._offsets.shape[1]``, which
            includes any extended bodies.

        Reference:
            https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/
            torch_humanoid_batch.py
        """
        device, dtype = root_rotations.device, root_rotations.dtype
        b, seq_len = rotations.size()[0:2]
        j = self._offsets.shape[1]
        positions_world = []
        rotations_world = []
        expanded_offsets = (
            self._offsets[:, None]
            .expand(b, seq_len, j, 3)
            .to(device)
            .type(dtype)
        )
        # Parents precede children in body order, so their world transforms
        # are already available when a child is visited.
        for i in range(j):
            if self._parents[i] == -1:
                positions_world.append(root_positions)
                rotations_world.append(root_rotations)
            else:
                jpos = (
                    torch.matmul(
                        rotations_world[self._parents[i]][:, :, 0],
                        expanded_offsets[:, :, i, :, None],
                    ).squeeze(-1)
                    + positions_world[self._parents[i]]
                )
                # world = parent_world @ (body rest rotation @ joint rotation)
                rot_mat = torch.matmul(
                    rotations_world[self._parents[i]],
                    torch.matmul(
                        self._local_rotation_mat[:, (i) : (i + 1)],
                        rotations[:, :, (i - 1) : i, :],
                    ),
                )
                positions_world.append(jpos)
                rotations_world.append(rot_mat)
        positions_world = torch.stack(positions_world, dim=2)
        rotations_world = torch.cat(rotations_world, dim=2)
        return positions_world, rotations_world

    @staticmethod
    def _compute_velocity(p, time_delta, guassian_filter=True):
        """Finite-difference velocity along the time axis (dim -3).

        Args:
            p: Positions with time at dim -3, e.g. (B, T, J, 3).
            time_delta: Time between consecutive frames.
            guassian_filter: Smooth along time with a sigma=2 Gaussian.

        Returns:
            Velocities with the shape, dtype and device of ``p``.
        """
        # numpy needs host memory without autograd history.
        velocity = np.gradient(p.detach().cpu().numpy(), axis=-3) / time_delta
        if guassian_filter:
            velocity = filters.gaussian_filter1d(
                velocity, 2, axis=-3, mode="nearest"
            )
        return torch.from_numpy(velocity).to(p)

    @staticmethod
    def _compute_angular_velocity(r, time_delta: float, guassian_filter=True):
        """Angular velocity from consecutive quaternions along dim -3.

        Args:
            r: Quaternions with time at dim -3, e.g. (B, T, J, 4), in the
                layout expected by ``poselib_rotation3d``.
            time_delta: Time between consecutive frames.
            guassian_filter: Smooth along time with a sigma=2 Gaussian.

        Returns:
            Angular velocities of shape (..., T, J, 3) on the device and
            dtype of ``r``. The last frame uses an identity rotation delta
            before smoothing.
        """
        # The time axis is dim -3 (third from last).
        diff_quat_data = poselib_rotation3d.quat_identity_like(r).to(r)
        diff_quat_data[..., :-1, :, :] = poselib_rotation3d.quat_mul_norm(
            r[..., 1:, :, :],
            poselib_rotation3d.quat_inverse(r[..., :-1, :, :]),
        )
        diff_angle, diff_axis = poselib_rotation3d.quat_angle_axis(
            diff_quat_data
        )
        angular_velocity = diff_axis * diff_angle.unsqueeze(-1) / time_delta
        if guassian_filter:
            # Filter on the host, then move back so the result matches r.
            angular_velocity = torch.from_numpy(
                filters.gaussian_filter1d(
                    angular_velocity.detach().cpu().numpy(),
                    2,
                    axis=-3,
                    mode="nearest",
                ),
            ).to(r)
        return angular_velocity

    def load_mesh(self):
        """Load the MJCF mesh assets with Open3D and index them by body.

        Populates:
            ``self.tree``: the parsed MJCF tree.
            ``self.mesh_dict``: mesh name -> Open3D triangle mesh.
            ``self.body_to_mesh``: body name -> set of mesh names.
            ``self.mesh_to_body``: mesh geom element -> parent body element.
            ``self.geom_transform``: body name -> {"pos", "quat"}, taken from
                the last mesh geom of that body that has a pos or quat.
        """
        xml_base = os.path.dirname(self.mjcf_file)
        # Parse once and keep the tree. mesh_fk looks geoms up in
        # mesh_to_body by element identity, so it must use this same tree.
        self.tree = tree = ETree.parse(self.mjcf_file)
        xml_doc_root = tree.getroot()
        # Honor <compiler meshdir="..."> when resolving mesh file paths.
        compiler_tag = xml_doc_root.find("compiler")
        if compiler_tag is not None and "meshdir" in compiler_tag.attrib:
            mesh_base = os.path.join(xml_base, compiler_tag.attrib["meshdir"])
        else:
            mesh_base = xml_base
        xml_world_body = xml_doc_root.find("worldbody")
        xml_assets = xml_doc_root.find("asset")
        all_mesh = xml_assets.findall(".//mesh")
        geoms = xml_world_body.findall(".//geom")
        # ElementTree elements have no parent pointer. Build the child ->
        # parent map once instead of rescanning the tree for every geom.
        parent_of = {
            child: parent
            for parent in xml_doc_root.iter()
            for child in parent
        }
        mesh_dict = {}
        for mesh_file_node in all_mesh:
            mesh_name = mesh_file_node.attrib["name"]
            mesh_full_file = osp.join(mesh_base, mesh_file_node.attrib["file"])
            mesh_dict[mesh_name] = o3d.io.read_triangle_mesh(mesh_full_file)
        geom_transform = {}
        body_to_mesh = defaultdict(set)
        mesh_to_body = {}
        for geom_node in geoms:
            if "mesh" not in geom_node.attrib:
                continue
            parent = parent_of[geom_node]
            body_name = parent.attrib["name"]
            body_to_mesh[body_name].add(geom_node.attrib["mesh"])
            mesh_to_body[geom_node] = parent
            if "pos" in geom_node.attrib or "quat" in geom_node.attrib:
                transform = {
                    "pos": np.array([0.0, 0.0, 0.0]),
                    "quat": np.array([1.0, 0.0, 0.0, 0.0]),
                }
                if "pos" in geom_node.attrib:
                    transform["pos"] = np.array(
                        [float(f) for f in geom_node.attrib["pos"].split()]
                    )
                if "quat" in geom_node.attrib:
                    transform["quat"] = np.array(
                        [float(f) for f in geom_node.attrib["quat"].split()]
                    )
                geom_transform[body_name] = transform
        self.geom_transform = geom_transform
        self.mesh_dict = mesh_dict
        self.body_to_mesh = body_to_mesh
        self.mesh_to_body = mesh_to_body

    def mesh_fk(self, pose=None, trans=None):
        """Pose the robot meshes with forward kinematics and merge them.

        Args:
            pose: Pose passed to :meth:`fk_batch`. If None, a single frame
                of zero rotations and zero root translation is used, and
                ``trans`` is ignored.
            trans: Root translation passed to :meth:`fk_batch`.

        Returns:
            One merged Open3D triangle mesh.

        Raises:
            ValueError: If the MJCF contains no mesh geoms.

        Reference:
            https://github.com/ZhengyiLuo/PHC/blob/master/phc/utils/
            torch_humanoid_batch.py
        """
        if pose is None:
            fk_res = self.fk_batch(
                torch.zeros(1, 1, len(self.body_names_augment), 3),
                torch.zeros(1, 1, 3),
            )
        else:
            fk_res = self.fk_batch(pose, trans)
        g_trans = fk_res.global_translation.squeeze()
        g_rot = fk_res.global_rotation_mat.squeeze()
        geoms = self.tree.find("worldbody").findall(".//geom")
        joined_mesh_obj = []
        processed_bodies = set()
        for geom in geoms:
            if "mesh" not in geom.attrib:
                continue
            k = self.mesh_to_body[geom].attrib["name"]
            # All meshes of a body are emitted together below. A second
            # mesh geom on the same body would otherwise duplicate them.
            if k in processed_bodies:
                continue
            processed_bodies.add(k)
            body_idx = self.body_names.index(k)
            body_trans = g_trans[body_idx].detach().cpu().numpy().copy()
            body_rot = g_rot[body_idx].detach().cpu().numpy().copy()
            for mesh_name in self.body_to_mesh[k]:
                mesh_obj = copy.deepcopy(self.mesh_dict[mesh_name])
                if k in self.geom_transform:
                    pos = self.geom_transform[k]["pos"]
                    quat = self.geom_transform[k]["quat"]
                    # Compute the offset per mesh rather than accumulating it
                    # into body_trans, so sibling meshes are not re-shifted.
                    mesh_trans = body_trans + body_rot @ pos
                    # MJCF quats are (w, x, y, z); scipy wants (x, y, z, w).
                    mesh_rot = (
                        body_rot
                        @ sRot.from_quat(quat[[1, 2, 3, 0]]).as_matrix()
                    )
                else:
                    mesh_trans = body_trans
                    mesh_rot = body_rot
                mesh_obj.rotate(mesh_rot, center=(0, 0, 0))
                mesh_obj.translate(mesh_trans)
                joined_mesh_obj.append(mesh_obj)
        if not joined_mesh_obj:
            raise ValueError(f"No mesh geoms found in {self.mjcf_file}.")
        # Merge all meshes into a single mesh
        merged_mesh = joined_mesh_obj[0]
        for mesh in joined_mesh_obj[1:]:
            merged_mesh += mesh
        return merged_mesh
# @hydra.main(version_base=None, config_path="../../phc/data/cfg",
#             config_name="config")
def main(cfg: DictConfig):
    """Build a HumanoidBatch on CPU from ``cfg.robot`` and pose its meshes.

    Args:
        cfg: Config whose ``robot`` node is passed to ``HumanoidBatch``.
    """
    device = torch.device("cpu")
    humanoid_fk = HumanoidBatch(cfg.robot, device)
    humanoid_fk.mesh_fk()


if __name__ == "__main__":
    # The hydra decorator above is disabled, so the config must be supplied
    # explicitly. Calling main() without arguments always raised TypeError.
    from omegaconf import OmegaConf

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <config.yaml>")
    main(OmegaConf.load(sys.argv[1]))
================================================
FILE: holomotion/src/motion_retargeting/utils/visualize_with_mujoco.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# This file was originally copied from the [PHC] repository:
# https://github.com/ZhengyiLuo/PHC
# Modifications have been made to fit the needs of this project.
import glob
import os
from typing import Any, Dict, List, Tuple
import cv2
import hydra
import mujoco
import numpy as np
import ray
from omegaconf import DictConfig
from tqdm.auto import tqdm
class OffscreenRenderer:
    """Offscreen MuJoCo renderer (no SMPL markers or joint spheres).

    Holds a MuJoCo GL context, a scene/camera/option triple and an
    ``MjrContext`` sized ``width`` x ``height``. Reference body positions
    can optionally be overlaid as spheres in :meth:`render`. Call
    :meth:`close` when done to release the rendering resources.
    """

    def __init__(self, model, height, width):
        """Create the GL context, scene, free camera and read buffer.

        Args:
            model: MuJoCo model to render.
            height: Image height in pixels.
            width: Image width in pixels.
        """
        self.model = model
        self.height = height
        self.width = width
        # Create OpenGL context
        self.ctx = mujoco.GLContext(width, height)
        self.ctx.make_current()
        # Scene and camera setup
        self.scene = mujoco.MjvScene(model, maxgeom=1000)
        self.cam = mujoco.MjvCamera()
        self.opt = mujoco.MjvOption()
        self.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        self.cam.distance = 4.0
        self.cam.azimuth = 60.0
        self.cam.elevation = -20
        self.cam.lookat = np.array([0.0, 0.0, 1.0])
        # Rendering context
        self.con = mujoco.MjrContext(
            model, mujoco.mjtFontScale.mjFONTSCALE_100
        )
        # Buffers
        self.rgb_buffer = np.zeros((height, width, 3), dtype=np.uint8)
        self.viewport = mujoco.MjrRect(0, 0, width, height)

    def render(
        self,
        data,
        ref_body_positions: np.ndarray | None = None,
        ref_marker_radius: float = 0.03,
        ref_marker_rgba: np.ndarray | None = None,
    ):
        """Render ``data`` and return an RGB frame.

        Args:
            data: MuJoCo data holding the state to draw.
            ref_body_positions: Optional positions drawn as spheres via
                ``_draw_body_spheres_to_scene``; None draws nothing extra.
            ref_marker_radius: Sphere radius for the reference markers.
            ref_marker_rgba: Optional RGBA color for the markers.

        Returns:
            A new uint8 array of shape (height, width, 3), top row first.
        """
        mujoco.mjv_updateScene(
            self.model,
            data,
            self.opt,
            None,
            self.cam,
            mujoco.mjtCatBit.mjCAT_ALL.value,
            self.scene,
        )
        _draw_body_spheres_to_scene(
            scene=self.scene,
            body_positions=ref_body_positions,
            radius=ref_marker_radius,
            rgba=ref_marker_rgba,
        )
        mujoco.mjr_render(self.viewport, self.scene, self.con)
        mujoco.mjr_readPixels(self.rgb_buffer, None, self.viewport, self.con)
        # Pixels are read bottom row first; flip vertically. Copy so the
        # returned frame does not alias self.rgb_buffer, which the next
        # render() call overwrites in place.
        return np.flipud(self.rgb_buffer).copy()

    def close(self):
        """Free the MuJoCo rendering context, then the GL context.

        Safe to call more than once; later calls do nothing.
        """
        if getattr(self, "con", None) is not None:
            self.con.free()
            self.con = None
        if getattr(self, "ctx", None) is not None:
            self.ctx.free()
            self.ctx = None
def _get_key_prefix_order(cfg: DictConfig) -> List[str]:
"""
Determine the key prefix order used to extract arrays from NPZ files.
Priority:
1) cfg.key_prefix_order (list or single value)
2) cfg.key_prefix (single value)
3) default ["ref_", "", "robot_"]
"""
configured = cfg.get("key_prefix_order", None)
if configured is not None:
order_list = (
[str(p) for p in configured]
if isinstance(configured, (list, tuple))
else [str(configured)]
)
else:
single = cfg.get("key_prefix", None)
if single is not None:
order_list = [str(single)]
else:
order_list = ["ref_", "", "robot_"]
print(f"Using key_prefix_order: {order_list}")
return order_list
def _get_ref_key_prefix_order(cfg: DictConfig) -> List[str]:
"""Determine the prefix order used to read reference overlay arrays."""
configured = cfg.get("ref_key_prefix_order", None)
if configured is not None:
order_list = (
[str(p) for p in configured]
if isinstance(configured, (list, tuple))
else [str(configured)]
)
else:
single = cfg.get("ref_key_prefix", None)
if single is not None:
order_list = [str(single)]
else:
order_list = ["ref_"]
print(f"Using ref_key_prefix_order: {order_list}")
return order_list
def _pick_with_prefixes(
arrays: Dict[str, np.ndarray],
base_name: str,
prefixes: List[str],
) -> np.ndarray | None:
"""
Return arrays[prefix + base_name] for the first matching prefix in order.
For non-empty prefixes, also attempts "_".
"""
for prefix in prefixes:
if prefix == "":
candidate = base_name
if candidate in arrays:
return arrays[candidate]
else:
cand1 = f"{prefix}{base_name}"
if cand1 in arrays:
return arrays[cand1]
cand2 = f"{prefix.rstrip('_')}_{base_name}"
if cand2 in arrays:
return arrays[cand2]
return None
def _resolve_visualization_arrays(
    arrays: Dict[str, np.ndarray],
    key_prefix_order: List[str],
    draw_ref_body_spheres: bool = False,
    ref_key_prefix_order: List[str] | None = None,
) -> Dict[str, np.ndarray | None]:
    """Look up the playback arrays and, optionally, the overlay positions.

    Returns a dict with keys ``dof_pos``, ``global_translation``,
    ``global_rotation_quat`` (resolved with ``key_prefix_order``) and
    ``ref_body_positions``. The last is the ``global_translation`` array
    resolved with ``ref_key_prefix_order`` (default ``["ref_"]``), and is
    only looked up when ``draw_ref_body_spheres`` is true; otherwise it is
    ``None``. Any missing entry is ``None``.
    """
    resolved = {
        name: _pick_with_prefixes(arrays, name, key_prefix_order)
        for name in ("dof_pos", "global_translation", "global_rotation_quat")
    }
    overlay = None
    if draw_ref_body_spheres:
        overlay_prefixes = (
            ["ref_"] if ref_key_prefix_order is None else ref_key_prefix_order
        )
        overlay = _pick_with_prefixes(
            arrays, "global_translation", overlay_prefixes
        )
    resolved["ref_body_positions"] = overlay
    return resolved
def _draw_body_spheres_to_scene(
scene,
body_positions: np.ndarray | None,
radius: float,
rgba: np.ndarray | None,
) -> None:
"""Append sphere markers for body positions to the current MuJoCo scene."""
if body_positions is None:
return
sphere_rgba = (
np.array([0.8, 0.0, 0.0, 1.0], dtype=np.float32)
if rgba is None
else np.asarray(rgba, dtype=np.float32)
)
size = np.array([radius, 0.0, 0.0], dtype=np.float32)
mat = np.eye(3, dtype=np.float32).reshape(-1)
start = int(scene.ngeom)
idx = 0
for pos in body_positions:
geom_id = start + idx
if geom_id >= scene.maxgeom:
break
mujoco.mjv_initGeom(
scene.geoms[geom_id],
mujoco.mjtGeom.mjGEOM_SPHERE,
size,
pos.astype(np.float32),
mat,
sphere_rgba,
)
idx += 1
scene.ngeom = start + idx
def _load_npz_as_motion(
npz_path: str,
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], str]:
"""
Load a single .npz file and return (arrays_dict, metadata_dict, motion_name)
- metadata: parsed from JSON
- motion_name: file name without extension
"""
with np.load(npz_path) as z:
arrays = {k: z[k] for k in z.files if k != "metadata"}
meta_raw = z.get("metadata", None)
if meta_raw is None:
metadata = {}
else:
metadata = {}
try:
metadata = dict(np.atleast_1d(meta_raw).tolist())
except Exception:
try:
metadata = {**(dict()), **(eval(str(meta_raw)))}
except Exception:
pass
# Parse metadata as JSON string
try:
import json
metadata = json.loads(str(np.atleast_1d(meta_raw)[0]))
except Exception:
pass
motion_name = os.path.splitext(os.path.basename(npz_path))[0]
return arrays, metadata, motion_name
def _collect_all_npz(
    npz_root: str, motion_name: str
) -> List[Tuple[Dict[str, np.ndarray], Dict[str, Any], str]]:
    """Collect and load the NPZ clips selected by ``motion_name``.

    Clips are looked up under ``<npz_root>/clips`` when that directory
    exists, otherwise directly under ``npz_root``.

    - ``"all"``: every ``*.npz`` found recursively under that base, sorted
      so runs are reproducible.
    - anything else: ``<base>/<motion_name>.npz``, falling back to
      ``<npz_root>/<motion_name>.npz`` when the former does not exist.

    Files that fail to load are reported and skipped.

    Returns:
        A list of ``(arrays, metadata, motion_name)`` tuples.
    """
    print("Collecting NPZ files...", npz_root, motion_name)
    clips_dir = os.path.join(npz_root, "clips")
    base = clips_dir if os.path.isdir(clips_dir) else npz_root
    if motion_name == "all":
        npz_files = sorted(
            glob.glob(os.path.join(base, "**", "*.npz"), recursive=True)
        )
    else:
        file_name = f"{motion_name}.npz"
        candidates = [os.path.join(base, file_name)]
        root_candidate = os.path.join(npz_root, file_name)
        if root_candidate not in candidates:
            candidates.append(root_candidate)
        existing = [c for c in candidates if os.path.isfile(c)]
        # If nothing exists, keep the primary candidate so the load error
        # below reports the expected path.
        npz_files = existing[:1] or candidates[:1]
    motions = []
    for f in tqdm(npz_files, desc="Loading npz files"):
        try:
            arrays, metadata, name = _load_npz_as_motion(f)
            motions.append((arrays, metadata, name))
        except Exception as e:
            print(f"Failed to load {f}: {e}")
    return motions
def _infer_fps_from_meta(
metadata: Dict[str, Any], default_fps: float = 50.0
) -> float:
"""Infer FPS value from metadata."""
try:
return float(metadata.get("motion_fps", default_fps))
except Exception:
return float(default_fps)
def _time_length(*arrays) -> int:
"""Return the smallest time dimension length among given arrays, ignoring None."""
T = None
for a in arrays:
if isinstance(a, np.ndarray) and a.ndim >= 1:
t = a.shape[0]
T = t if T is None else min(T, t)
return T if T is not None else 0
def _render_motion_npz_to_video(
    arrays: Dict[str, np.ndarray],
    metadata: Dict[str, Any],
    motion_name: str,
    cfg: DictConfig,
    show_progress: bool = False,
) -> str:
    """Render one NPZ motion clip to ``<cfg.video_dir>/<motion_name>.mp4``.

    This is the shared implementation behind
    ``process_single_motion_remote_npz`` (Ray) and
    ``MotionRendererNPZ.process_single_motion`` (single process).

    Every ``cfg.skip_frames``-th frame is rendered as follows:
      * the root pose is taken from body 0 of ``global_translation`` /
        ``global_rotation_quat``;
      * the quaternion is stored XYZW and reordered to WXYZ for qpos;
      * ``dof_pos[t]`` fills the remaining qpos entries;
      * the camera target follows the root.
    Input arrays are resolved and validated before any GL context or video
    file is created. Invalid clips therefore neither leak a renderer nor
    leave an empty mp4 behind.

    Args:
        arrays: Clip arrays; keys are resolved via the configured prefixes.
        metadata: Clip metadata; ``motion_fps`` sets the source frame rate.
        motion_name: Output file stem.
        cfg: Config providing ``robot.asset.assetFileName``, ``video_dir``,
            and optional prefix / overlay / ``skip_frames`` settings.
        show_progress: Wrap the frame loop in a tqdm progress bar.

    Returns:
        ``motion_name``.

    Raises:
        ValueError: Required arrays are missing or no frames are available.
        RuntimeError: The OpenCV video writer could not be opened.
    """
    prefix_order = _get_key_prefix_order(cfg)
    draw_ref_body_spheres = bool(getattr(cfg, "draw_ref_body_spheres", False))
    ref_prefix_order = _get_ref_key_prefix_order(cfg)
    resolved = _resolve_visualization_arrays(
        arrays=arrays,
        key_prefix_order=prefix_order,
        draw_ref_body_spheres=draw_ref_body_spheres,
        ref_key_prefix_order=ref_prefix_order,
    )
    dof_pos = resolved["dof_pos"]
    gpos = resolved["global_translation"]
    grot = resolved["global_rotation_quat"]
    ref_body_positions = resolved["ref_body_positions"]
    if (
        not isinstance(dof_pos, np.ndarray)
        or not isinstance(gpos, np.ndarray)
        or not isinstance(grot, np.ndarray)
    ):
        raise ValueError(
            "Missing required NPZ keys: dof_pos / global_translation / global_rotation_quat"
        )
    # Time dimension alignment: render only frames present in every array.
    T = _time_length(dof_pos, gpos, grot, ref_body_positions)
    if T == 0:
        raise ValueError("No valid frames found.")

    mj_model = mujoco.MjModel.from_xml_path(cfg.robot.asset.assetFileName)
    mj_data = mujoco.MjData(mj_model)

    width, height = 1280, 720
    frame_step = max(1, int(getattr(cfg, "skip_frames", 1)))
    src_fps = _infer_fps_from_meta(metadata, default_fps=50.0)
    actual_fps = src_fps / frame_step

    out_path = os.path.join(cfg.video_dir, f"{motion_name}.mp4")
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Lowest camera target height (world z) used while following the root.
    min_lookat_height = 1.0
    renderer = OffscreenRenderer(mj_model, height, width)
    try:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, actual_fps, (width, height))
        try:
            if not out.isOpened():
                raise RuntimeError(
                    f"Could not open video writer for {out_path}"
                )
            frame_indices = range(0, T, frame_step)
            if show_progress:
                frame_indices = tqdm(
                    frame_indices, desc=f"Rendering {motion_name}"
                )
            for t in frame_indices:
                # Root position and quaternion: take from body 0.
                root_pos = gpos[t, 0]
                root_quat_xyzw = grot[t, 0]
                root_quat_wxyz = root_quat_xyzw[[3, 0, 1, 2]]
                mj_data.qpos[:3] = root_pos
                mj_data.qpos[3:7] = root_quat_wxyz
                mj_data.qpos[7:] = dof_pos[t]
                mujoco.mj_forward(mj_model, mj_data)

                # Copy the current camera target, then follow the root in
                # XY while keeping the target at least min_lookat_height.
                lookat = np.array(renderer.cam.lookat)
                lookat[0] = root_pos[0]
                lookat[1] = root_pos[1]
                lookat[2] = max(root_pos[2], min_lookat_height)
                renderer.cam.lookat[:] = lookat

                frame_ref_body_positions = (
                    ref_body_positions[t]
                    if isinstance(ref_body_positions, np.ndarray)
                    else None
                )
                frame = renderer.render(
                    mj_data,
                    ref_body_positions=frame_ref_body_positions,
                )
                # Convert RGB (MuJoCo) -> BGR (OpenCV) before writing.
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            out.release()
    finally:
        renderer.close()
    return motion_name


@ray.remote
def process_single_motion_remote_npz(
    arrays: Dict[str, np.ndarray],
    metadata: Dict[str, Any],
    motion_name: str,
    cfg_dict: dict,
) -> str:
    """Ray task: render one clip to video.

    Returns ``motion_name`` on success, or ``"ERROR_<motion_name>: <msg>"``
    on failure. Errors are returned rather than raised so the driver can
    count failures and keep going.
    """
    try:
        cfg = DictConfig(cfg_dict)
        return _render_motion_npz_to_video(arrays, metadata, motion_name, cfg)
    except Exception as e:
        return f"ERROR_{motion_name}: {str(e)}"


class MotionRendererNPZ:
    """Single-process renderer used when one specific clip is requested."""

    def process_single_motion(
        self,
        arrays: Dict[str, np.ndarray],
        metadata: Dict[str, Any],
        motion_name: str,
        cfg: DictConfig,
    ):
        """Render one clip to video with a per-frame progress bar.

        Returns ``motion_name``; exceptions propagate to the caller.
        """
        return _render_motion_npz_to_video(
            arrays, metadata, motion_name, cfg, show_progress=True
        )
@hydra.main(
    version_base=None,
    config_path="../../../config/motion_retargeting",
    config_name="unitree_G1_29dof_retargeting",
)
def main(cfg: DictConfig) -> None:
    """
    Render NPZ motion clips to MP4 videos.

    Required config fields:
      - cfg.robot.asset.assetFileName : Path to the MuJoCo XML file
      - cfg.video_dir : Output video directory
      - cfg.motion_npz_root : Directory containing NPZ files
      - cfg.motion_name : "all" or a specific clip name (without extension)
      - cfg.skip_frames : Frame step size (>=1)
    Optional:
      - cfg.key_prefix_order : List[str] or str for key prefix matching order
      - cfg.key_prefix : Single prefix to use (overridden by key_prefix_order)
      - cfg.draw_ref_body_spheres : Overlay reference body positions as spheres
      - cfg.ref_key_prefix_order / cfg.ref_key_prefix : Prefixes for the overlay arrays
      - cfg.max_workers : Upper bound on Ray CPUs when motion_name == "all" (default 8)

    With motion_name == "all", clips are rendered in parallel as Ray tasks.
    Otherwise they are rendered in-process. Errors are printed, not raised,
    and Ray is shut down on every exit path.
    """
    try:
        # NPZ input
        motions = _collect_all_npz(cfg.motion_npz_root, cfg.motion_name)
        if not motions:
            print("No NPZ motions found.")
            return
        if cfg.motion_name == "all":
            if not ray.is_initialized():
                # os.cpu_count() may return None. max_workers <= 0 would
                # give Ray zero CPUs, and the tasks would never be scheduled.
                num_cpus = max(
                    1,
                    min(os.cpu_count() or 1, int(cfg.get("max_workers", 8))),
                )
                ray.init(num_cpus=num_cpus)
                print(f"Initialized Ray with {num_cpus} workers")
            cfg_dict = dict(cfg)
            tasks = [
                process_single_motion_remote_npz.remote(
                    arr, meta, name, cfg_dict
                )
                for (arr, meta, name) in motions
            ]
            completed, failed = [], []
            with tqdm(total=len(tasks), desc="Processing Motions") as pbar:
                remaining = list(tasks)
                while remaining:
                    ready, remaining = ray.wait(
                        remaining, num_returns=1, timeout=1.0
                    )
                    for t in ready:
                        try:
                            res = ray.get(t)
                        except Exception as e:
                            failed.append(f"Task exception: {e}")
                        else:
                            # Tasks report their own failures as "ERROR_..." strings.
                            if isinstance(res, str) and res.startswith(
                                "ERROR_"
                            ):
                                failed.append(res)
                                print(f"Failed: {res}")
                            else:
                                completed.append(res)
                                print(f"Completed: {res}")
                        pbar.update(1)
            print("\nProcessing complete!")
            print(f"Success: {len(completed)}; Failed: {len(failed)}")
            if failed:
                for f in failed:
                    print(" -", f)
        else:
            renderer = MotionRendererNPZ()
            for arr, meta, name in motions:
                res = renderer.process_single_motion(arr, meta, name, cfg)
                print(f"Processed: {res}")
    except Exception as e:
        print(f"Error during processing: {e}")
    finally:
        if ray.is_initialized():
            ray.shutdown()


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/training/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/training/h5_dataloader.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Simplified HDF5 motion cache backed by a PyTorch ``DataLoader``.
This module provides two core utilities:
* ``Hdf5MotionDataset`` – loads contiguous motion windows directly from HDF5
shards using metadata stored in ``manifest.json``.
* ``MotionClipBatchCache`` – maintains a double-buffered cache of motion clips
with deterministic swapping semantics suitable for high-throughput
reinforcement learning.
Compared to the legacy slot-based prefetcher, this implementation keeps the
pipeline intentionally simple:
* A dataset-worker keeps shard handles open locally; no Ray dependency.
* Each cached batch has a fixed shape
``[max_num_clips, max_frame_length, feature_dims]``.
* Swapping a batch is handled via an O(1) pointer flip once the next batch is
staged on the desired device (CPU or GPU).
The cache exposes helper methods that mirror the data access patterns required
by ``RefMotionCommand``:
* ``sample_env_assignments`` for initial clip/frame sampling.
* ``gather_tensor`` to fetch exactly one tensor field for ``1 + n_future``
frames per environment.
All tensors returned by this module are ``torch.float32`` unless stated
otherwise; tensor shapes are noted explicitly in type annotations.
"""
from __future__ import annotations
import json
import math
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
)
import h5py
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
from loguru import logger
from tabulate import tabulate
from tqdm import tqdm
from holomotion.src.motion_retargeting.reference_filtering import (
butterworth_filter_root_dof_arrays,
)
from holomotion.src.utils import torch_utils
from holomotion.src.motion_retargeting.holomotion_fk import HoloMotionFK
Tensor = torch.Tensor
def _cpu_only_dataloader_worker_init_fn(worker_id: int) -> None:
"""Keep cache workers lightweight without mutating CUDA visibility."""
del worker_id
torch.set_num_threads(1)
def _allocate_batch_counts(
raw_counts: List[float], target_total: int
) -> List[int]:
"""Allocate integer counts that sum exactly to target_total."""
total = int(max(0, target_total))
if len(raw_counts) == 0:
return []
base_counts = [max(0, int(c)) for c in raw_counts]
residuals = [float(c) - float(int(c)) for c in raw_counts]
remaining = total - int(sum(base_counts))
if remaining > 0:
order = sorted(
range(len(residuals)),
key=lambda i: residuals[i],
reverse=True,
)
idx_pos = 0
while remaining > 0:
j = order[idx_pos % len(order)]
base_counts[j] += 1
remaining -= 1
idx_pos += 1
elif remaining < 0:
order = sorted(range(len(residuals)), key=lambda i: residuals[i])
idx_pos = 0
while remaining < 0:
j = order[idx_pos % len(order)]
if base_counts[j] > 0:
base_counts[j] -= 1
remaining += 1
idx_pos += 1
if sum(base_counts) != total:
raise RuntimeError(
"Internal error: integer batch-count allocation did not preserve total."
)
return [max(0, int(c)) for c in base_counts]
def _configure_weighted_bins(
    keys: List[str],
    cfg: Mapping[str, Any],
    batch_size_for_log: int,
) -> Tuple[List[List[int]], List[float], List[Dict[str, Any]]]:
    """Parse the weighted-bin config, assign keys to bins, and compute batch fractions.

    Each entry of ``cfg["bin_regex_patterns"]`` is a mapping with a
    ``regex`` and a ``ratio`` in [0, 1], plus an optional ``name``. The
    misspelled aliases ``bin_regrex_patterns`` / ``regrex`` are also
    accepted. Every key goes to the FIRST pattern whose regex
    ``search``-matches it. Unmatched keys go to an implicit trailing
    "others" bin whose ratio is ``1 - sum(explicit ratios)``. If "others"
    ends up empty, the explicit ratios are rescaled to sum to 1.

    Args:
        keys: Motion clip keys to bin.
        cfg: Weighted-bin configuration mapping.
        batch_size_for_log: Batch size used to turn ratios into the integer
            per-bin counts reported in ``specs``; values <= 0 are treated
            as 1.

    Returns:
        ``(bin_indices, all_ratios, specs)``. There is one entry per bin:
        explicit bins in config order, then "others". For each bin:
          * ``bin_indices``: the indices into ``keys`` it holds;
          * ``all_ratios``: its (possibly rescaled) ratio;
          * ``specs``: a summary dict (name, regex, ratio, count,
            dataset_fraction, batch_fraction). In ``specs``, ``ratio`` and
            ``batch_fraction`` are the integer-rounded batch fractions.

    Raises:
        ValueError: On any of:
          * a missing or invalid pattern list or entry;
          * ratios summing to more than 1;
          * an empty ``keys``;
          * a bin with a positive ratio that matches no keys.
    """
    # Guard the divisor used for the logged batch fractions below.
    if batch_size_for_log <= 0:
        batch_size_for_log = 1
    cfg_local: Dict[str, Any] = dict(cfg or {})
    # "bin_regrex_patterns" is a misspelled alias kept for existing configs.
    patterns_cfg = cfg_local.get("bin_regex_patterns")
    if patterns_cfg is None:
        patterns_cfg = cfg_local.get("bin_regrex_patterns")
    if not patterns_cfg:
        raise ValueError(
            "weighted_bin configuration requires 'bin_regex_patterns' "
            "(list of {regex, ratio}) to be configured"
        )
    compiled_patterns: List[Dict[str, Any]] = []
    ratios: List[float] = []
    for idx, entry in enumerate(patterns_cfg):
        if not isinstance(entry, Mapping):
            raise ValueError(
                f"Entry {idx} in bin_regex_patterns must be a mapping, "
                f"got {type(entry)}"
            )
        # "regrex" is likewise accepted as a misspelled alias of "regex".
        regex_str = entry.get("regex", entry.get("regrex", None))
        if not isinstance(regex_str, str) or not regex_str:
            raise ValueError(
                f"Entry {idx} in bin_regex_patterns is missing a non-empty "
                f"'regex' field"
            )
        ratio_val = entry.get("ratio", None)
        if ratio_val is None:
            raise ValueError(
                f"Entry {idx} in bin_regex_patterns is missing 'ratio'"
            )
        ratio_f = float(ratio_val)
        if ratio_f < 0.0 or ratio_f > 1.0:
            raise ValueError(
                f"Entry {idx} in bin_regex_patterns has invalid ratio "
                f"{ratio_f:.6f}; expected in [0.0, 1.0]"
            )
        compiled_patterns.append(
            {
                "name": str(entry.get("name", f"bin_{idx}")),
                "regex": regex_str,
                "compiled": re.compile(regex_str),
            }
        )
        ratios.append(ratio_f)
    # Explicit ratios may exceed 1.0 only by float tolerance; the remainder
    # (if any) becomes the implicit "others" ratio.
    sum_explicit = float(sum(ratios))
    if sum_explicit > 1.0 + 1.0e-6:
        raise ValueError(
            f"Sum of weighted-bin ratios is {sum_explicit:.6f} (> 1.0). "
            "Please reduce the ratios so that their sum is <= 1.0."
        )
    if sum_explicit > 1.0:
        sum_explicit = 1.0
    others_ratio = max(0.0, 1.0 - sum_explicit)
    if len(keys) == 0:
        raise ValueError(
            "weighted_bin configuration received an empty key set"
        )
    num_items_total = float(len(keys))
    num_explicit = len(compiled_patterns)
    # One index list per explicit pattern, plus a trailing "others" list.
    bin_indices: List[List[int]] = [[] for _ in range(num_explicit + 1)]
    for idx, motion_key in enumerate(keys):
        assigned = False
        # First matching pattern wins; patterns are tried in config order.
        for b_idx, pat in enumerate(compiled_patterns):
            if pat["compiled"].search(motion_key):
                bin_indices[b_idx].append(idx)
                assigned = True
                break
        if not assigned:
            bin_indices[-1].append(idx)
    # Combine explicit ratios with implicit "others" ratio
    all_ratios: List[float] = list(ratios)
    all_ratios.append(others_ratio)
    # If all motion keys are covered by explicit regex bins, but the specified
    # ratios sum to less than 1.0, linearly reweight explicit ratios so that
    # they sum to 1.0 and disable the implicit "others" bin.
    others_count = len(bin_indices[-1])
    if others_count == 0 and others_ratio > 0.0 and sum_explicit > 0.0:
        scale = 1.0 / sum_explicit
        ratios = [r * scale for r in ratios]
        others_ratio = 0.0
        all_ratios = list(ratios)
        all_ratios.append(others_ratio)
        logger.info(
            "Weighted-bin: all regex bins cover the dataset; "
            "linearly reweighted explicit ratios to sum to 1.0 and disabled "
            "the implicit 'others' bin."
        )
    # Validate non-empty bins for any positive ratio (including others)
    for b_idx, r in enumerate(all_ratios):
        if r > 0.0 and len(bin_indices[b_idx]) == 0:
            if b_idx < num_explicit:
                name = compiled_patterns[b_idx]["name"]
                regex_s = compiled_patterns[b_idx]["regex"]
                raise ValueError(
                    f"Weighted-bin '{name}' (regex='{regex_s}') has ratio "
                    f"{r:.6f} but matched no motion keys"
                )
            raise ValueError(
                f"Weighted-bin 'others' has ratio {r:.6f} but matched no motion keys"
            )
    # Prepare logging summary using the configured cache batch size
    raw_counts_log = [ratio * batch_size_for_log for ratio in all_ratios]
    base_counts_log = _allocate_batch_counts(
        raw_counts=raw_counts_log,
        target_total=batch_size_for_log,
    )
    batch_fractions_log = [
        float(c) / float(batch_size_for_log) for c in base_counts_log
    ]
    # Build specs using the final, actually used batch fractions
    specs: List[Dict[str, Any]] = []
    total_items = float(max(1, num_items_total))
    for b_idx in range(num_explicit):
        name = compiled_patterns[b_idx]["name"]
        regex_s = compiled_patterns[b_idx]["regex"]
        n = len(bin_indices[b_idx])
        ds_frac = float(n) / total_items
        bf = batch_fractions_log[b_idx]
        specs.append(
            {
                "name": name,
                "regex": regex_s,
                "ratio": bf,
                "count": n,
                "dataset_fraction": ds_frac,
                "batch_fraction": bf,
            }
        )
    # Others bin
    others_name = "others"
    others_regex = ""
    n_o = len(bin_indices[-1])
    ds_frac_o = float(n_o) / total_items
    bf_o = batch_fractions_log[-1]
    specs.append(
        {
            "name": others_name,
            "regex": others_regex,
            "ratio": bf_o,
            "count": n_o,
            "dataset_fraction": ds_frac_o,
            "batch_fraction": bf_o,
        }
    )
    return bin_indices, all_ratios, specs
def _collect_manifest_keys(
manifest_path: str | Sequence[str],
) -> Tuple[List[str], Dict[str, str], List[str]]:
if isinstance(manifest_path, (str, os.PathLike)):
manifest_paths: List[str] = [str(manifest_path)]
else:
manifest_paths = [str(p) for p in manifest_path]
if len(manifest_paths) == 0:
raise ValueError("Expected at least one manifest path")
key_source: Dict[str, str] = {}
for mp in manifest_paths:
if not os.path.exists(mp):
raise FileNotFoundError(
f"HDF5 manifest not found at {mp}. "
"Please set robot.motion.hdf5_root/train_hdf5_roots "
"to the correct path."
)
with open(mp, "r", encoding="utf-8") as handle:
manifest = json.load(handle)
clips = manifest.get("clips", {})
if not clips:
raise ValueError(
f"Manifest at {mp} contains no clips; cannot preview sampling."
)
for key in clips.keys():
if key in key_source:
raise ValueError(
f"Duplicate motion clip key '{key}' found in multiple "
"manifests; clip keys must be globally unique."
)
key_source[key] = mp
return list(key_source.keys()), key_source, manifest_paths
def _normalize_online_filter_cfg(
cfg: Optional[Mapping[str, Any]],
*,
default_vel_smoothing_sigma: float = 2.0,
) -> Dict[str, Any]:
cfg_local = dict(cfg or {})
enabled = bool(cfg_local.get("enabled", False))
cutoff_pool_cfg = cfg_local.get("butter_cutoff_hz_pool", [])
cutoff_pool = tuple(float(v) for v in cutoff_pool_cfg)
ref_vel_smoothing_sigma = float(
cfg_local.get("ref_vel_smoothing_sigma", default_vel_smoothing_sigma)
)
ft_ref_vel_smoothing_sigma = float(
cfg_local.get(
"ft_ref_vel_smoothing_sigma", default_vel_smoothing_sigma
)
)
if enabled and len(cutoff_pool) == 0:
raise ValueError(
"online_filter.enabled=True requires butter_cutoff_hz_pool to "
"contain at least one cutoff value"
)
butter_order = int(cfg_local.get("butter_order", 4))
if butter_order <= 0:
raise ValueError("online_filter.butter_order must be positive")
return {
"enabled": enabled,
"butter_order": butter_order,
"butter_cutoff_hz_pool": cutoff_pool,
"ref_vel_smoothing_sigma": ref_vel_smoothing_sigma,
"ft_ref_vel_smoothing_sigma": ft_ref_vel_smoothing_sigma,
}
def preview_weighted_bin_from_manifest(
    manifest_path: str | Sequence[str],
    batch_size: int,
    cfg: Mapping[str, Any],
) -> None:
    """Log a manifest-only preview of weighted-bin sampling.

    This is meant to run at configuration time, before any
    MotionClipBatchCache/DataLoader exists. Invalid regex or ratio settings
    therefore fail fast, without paying for cache setup. Only the
    ``manifest.json`` files are read. ``batch_size <= 0`` is treated as 1.
    """
    effective_batch = batch_size if batch_size > 0 else 1
    keys, _, _ = _collect_manifest_keys(manifest_path=manifest_path)
    _, _, specs = _configure_weighted_bins(
        keys=keys,
        cfg=cfg,
        batch_size_for_log=effective_batch,
    )
    headers = [
        "bin",
        "regex",
        "final_ratio",
        "num_clips",
        "clip_fraction",
        "batch_fraction",
    ]
    table_rows = [
        [
            spec["name"],
            spec["regex"],
            f"{spec['ratio']:.4f}",
            int(spec["count"]),
            f"{spec['dataset_fraction']:.4f}",
            f"{spec['batch_fraction']:.4f}",
        ]
        for spec in specs
    ]
    table = tabulate(table_rows, headers=headers, tablefmt="simple_outline")
    logger.info("Weighted-bin config preview (manifest-level):\n" + table)
def preview_uniform_from_manifest(
    manifest_path: str | Sequence[str],
    batch_size: int,
    *,
    max_frame_length: int,
    min_window_length: int,
    handpicked_motion_names: Optional[Sequence[str]] = None,
    excluded_motion_names: Optional[Sequence[str]] = None,
) -> None:
    """Manifest-level preview table for uniform/curriculum sampling.

    For each manifest, counts the clip windows that would be sampled. It
    then logs one row per dataset root: window count and fraction, frame
    count, and duration in hours.

    Window counting:
      * Each clip of ``length`` frames is cut into consecutive windows of
        ``max_frame_length`` frames; the last window may be shorter.
      * Windows shorter than ``min_window_length`` are not counted.

    Clip filtering:
      * A clip is skipped when ``handpicked_motion_names`` is given and
        none of its aliases appears in it.
      * A clip is skipped when any alias appears in
        ``excluded_motion_names``.
      * Aliases are the manifest key, the clip entry's ``motion_key``, and
        ``motion_key`` / ``raw_motion_key`` from the entry's nested
        ``metadata``.

    NOTE(review): ``batch_size`` is normalized but not otherwise used, and
    the ``batch_fraction`` column repeats ``window_fraction``.

    Raises:
        ValueError: On any of:
          * a non-positive ``max_frame_length`` / ``min_window_length``;
          * a manifest with no clips;
          * a missing or non-positive ``motion_fps`` for a counted clip;
          * no countable windows overall.
        FileNotFoundError: A manifest path does not exist.
    """
    if batch_size <= 0:
        batch_size = 1
    if max_frame_length <= 0:
        raise ValueError("max_frame_length must be positive")
    if min_window_length <= 0:
        raise ValueError("min_window_length must be positive")
    # Only the normalized path list is used. This call also checks that
    # every manifest exists, has clips, and has globally unique keys.
    _, _, manifest_paths = _collect_manifest_keys(manifest_path=manifest_path)
    handpicked_set = (
        set(handpicked_motion_names)
        if handpicked_motion_names is not None
        else None
    )
    excluded_set = (
        set(excluded_motion_names)
        if excluded_motion_names is not None
        else None
    )

    def _normalize_key(value: Any) -> Optional[str]:
        # Coerce a candidate alias to a non-empty str (bytes are decoded as
        # UTF-8); None and "" map to None.
        if value is None:
            return None
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        key = value if isinstance(value, str) else str(value)
        if not key:
            return None
        return key

    def _build_aliases(motion_key: str, meta: Mapping[str, Any]) -> List[str]:
        # Distinct names a clip may be referred to by, in order: manifest
        # key, entry "motion_key", then nested metadata "motion_key" /
        # "raw_motion_key".
        aliases: List[str] = []

        def _add(value: Any) -> None:
            key = _normalize_key(value)
            if key is None or key in aliases:
                return
            aliases.append(key)

        _add(motion_key)
        if isinstance(meta, Mapping):
            _add(meta.get("motion_key"))
            metadata = meta.get("metadata")
            if isinstance(metadata, Mapping):
                _add(metadata.get("motion_key"))
                _add(metadata.get("raw_motion_key"))
        return aliases

    def _count_windows(clip_length: int) -> Tuple[int, int]:
        # Walk the clip in consecutive windows of up to max_frame_length
        # frames. Return (window count, total frames), counting only
        # windows at least min_window_length long.
        remaining = clip_length
        offset = 0
        num_windows = 0
        num_frames = 0
        while remaining > 0:
            window_length = min(max_frame_length, remaining)
            if window_length >= min_window_length:
                num_windows += 1
                num_frames += int(window_length)
            offset += int(window_length)
            remaining = max(0, clip_length - offset)
        return num_windows, num_frames

    stats_by_manifest: Dict[str, Dict[str, float]] = {}
    # NOTE: `mp` is a manifest path here; it shadows the module-level
    # `torch.multiprocessing as mp` alias inside this function only.
    for mp in manifest_paths:
        with open(mp, "r", encoding="utf-8") as handle:
            manifest = json.load(handle)
        clips = manifest.get("clips", {})
        if not clips:
            raise ValueError(
                f"Manifest at {mp} contains no clips; cannot preview sampling."
            )
        num_windows = 0
        num_frames = 0
        duration_s = 0.0
        for key, meta in clips.items():
            if isinstance(meta, Mapping):
                aliases = _build_aliases(key, meta)
            else:
                aliases = [key]
            # Handpicked names act as an allow-list matched against any alias.
            if handpicked_set is not None and not any(
                alias in handpicked_set for alias in aliases
            ):
                continue
            if excluded_set is not None and any(
                alias in excluded_set for alias in aliases
            ):
                continue
            length = (
                int(meta.get("length", 0)) if isinstance(meta, Mapping) else 0
            )
            if length <= 0:
                continue
            metadata = (
                meta.get("metadata") if isinstance(meta, Mapping) else None
            )
            # Prefer metadata["motion_fps"], then the entry's own motion_fps.
            motion_fps_val = None
            if isinstance(metadata, Mapping):
                motion_fps_val = metadata.get("motion_fps")
            if motion_fps_val is None and isinstance(meta, Mapping):
                motion_fps_val = meta.get("motion_fps")
            if motion_fps_val is None:
                raise ValueError(
                    f"motion_fps missing for clip {key} in manifest {mp}"
                )
            motion_fps = float(motion_fps_val)
            if motion_fps <= 0.0:
                raise ValueError(
                    f"Invalid motion_fps {motion_fps} for clip {key} in {mp}"
                )
            clip_windows, clip_frames = _count_windows(length)
            num_windows += int(clip_windows)
            num_frames += int(clip_frames)
            duration_s += float(clip_frames) / float(motion_fps)
        stats_by_manifest[mp] = {
            "num_windows": float(num_windows),
            "num_frames": float(num_frames),
            "duration_s": float(duration_s),
        }
    total_windows = int(
        sum(stats["num_windows"] for stats in stats_by_manifest.values())
    )
    if total_windows == 0:
        raise ValueError(
            "No motion windows satisfy the requested frame length constraints"
        )
    table_rows = []
    denom = float(max(1, total_windows))
    for mp in manifest_paths:
        stats = stats_by_manifest.get(mp, {})
        count = int(stats.get("num_windows", 0))
        frames = int(stats.get("num_frames", 0))
        duration_h = float(stats.get("duration_s", 0.0)) / 3600.0
        frac = float(count) / denom
        # Uniform sampling: the batch fraction equals the window fraction.
        table_rows.append(
            [
                os.path.dirname(mp),
                count,
                f"{frac:.4f}",
                frames,
                f"{duration_h:.2f}",
                f"{frac:.4f}",
            ]
        )
    headers = [
        "dataset_root",
        "num_windows",
        "window_fraction",
        "num_frames",
        "duration_h",
        "batch_fraction",
    ]
    logger.info(
        "Uniform sampling preview (manifest-level):\n"
        + tabulate(table_rows, headers=headers, tablefmt="simple_outline")
    )
def preview_sampling_from_cfg(motion_cfg: Mapping[str, Any]) -> None:
    """Preview manifest-level sampling table for uniform/weighted-bin.

    This is a no-op in any of these cases:
      * the sampling strategy is not uniform / weighted_bin / curriculum;
      * the backend is not an HDF5 variant;
      * no training root is configured.
    ``curriculum`` uses the uniform preview. The optional ``cache`` and
    ``weighted_bin`` sections may be missing or null; a null section is
    treated as empty instead of crashing.
    """
    sampling_strategy_cfg = motion_cfg.get("sampling_strategy", None)
    if sampling_strategy_cfg is None:
        sampling_strategy = "uniform"
    else:
        sampling_strategy = str(sampling_strategy_cfg).lower()
    if sampling_strategy not in ("uniform", "weighted_bin", "curriculum"):
        return
    backend = str(motion_cfg.get("backend", "hdf5")).lower()
    if backend not in ("hdf5", "hdf5_simple", "hdf5_v2"):
        return
    train_roots = _normalize_root_list(
        motion_cfg.get("train_hdf5_roots", None)
    )
    if len(train_roots) == 0:
        hdf5_root = motion_cfg.get("hdf5_root", None)
        if not hdf5_root:
            return
        train_roots = [str(hdf5_root)]
    manifest_paths = [
        os.path.join(str(root), "manifest.json") for root in train_roots
    ]
    manifest_arg = (
        manifest_paths if len(manifest_paths) > 1 else manifest_paths[0]
    )
    cache_cfg = motion_cfg.get("cache", None) or {}
    # A null max_num_clips falls back to 1 like a missing one.
    batch_size = int(cache_cfg.get("max_num_clips", 1) or 1)
    if sampling_strategy == "weighted_bin":
        weighted_bin_cfg = dict(motion_cfg.get("weighted_bin", None) or {})
        preview_weighted_bin_from_manifest(
            manifest_path=manifest_arg,
            batch_size=batch_size,
            cfg=weighted_bin_cfg,
        )
        return
    max_frame_length = int(motion_cfg.get("max_frame_length", 1))
    min_window_length = int(motion_cfg.get("min_frame_length", 1))
    handpicked_motion_names = motion_cfg.get("handpicked_motion_names", None)
    excluded_motion_names = motion_cfg.get("excluded_motion_names", None)
    preview_uniform_from_manifest(
        manifest_path=manifest_arg,
        batch_size=batch_size,
        max_frame_length=max_frame_length,
        min_window_length=min_window_length,
        handpicked_motion_names=handpicked_motion_names,
        excluded_motion_names=excluded_motion_names,
    )
# Maps internal motion-field names (keys, e.g. "rg_pos") to the dataset
# names used for them in stored motion files (values, e.g.
# "global_translation", matching the NPZ key names used elsewhere).
# NOTE(review): presumably every clip must provide these datasets, as the
# name suggests; confirm against the loader that consumes this table.
MANDATORY_DATASETS = {
    "dof_pos": "dof_pos",
    "dof_vel": "dof_vel",
    "rg_pos": "global_translation",
    "rb_rot": "global_rotation_quat",
    "body_vel": "global_velocity",
    "body_ang_vel": "global_angular_velocity",
}
class _WorldFrameNormalizeTransform:
    """Normalize motion tensors into a canonical z-up world frame in-place.

    The transform uses the reference root (body 0) at frame 0. It:
      * subtracts that root's XY position from all body positions;
      * rotates positions, linear/angular velocities and orientations about
        +Z by the negative of the root's initial yaw.
    After this, the clip starts at the XY origin with zero heading. The
    ``ref_`` tensors are always transformed. The ``ft_ref_`` tensors are
    transformed with the same offset and rotation when they are present
    with matching shapes.
    """

    @staticmethod
    def _apply_prefix(
        arrays: Dict[str, Tensor],
        prefix: str,
        *,
        offset_xy: Tensor,
        q_flat_wxyz: Tensor,
        ref_rg_pos_shape: torch.Size,
        ref_rb_rot_shape: torch.Size,
    ) -> None:
        """Apply the XY offset and heading rotation to ``{prefix}*`` tensors.

        Does nothing when any of ``{prefix}rg_pos`` / ``rb_rot`` /
        ``body_vel`` / ``body_ang_vel`` is missing. It also does nothing
        when the position or rotation shapes differ from the reference
        shapes, because ``q_flat_wxyz`` was built for those shapes.

        Args:
            arrays: Tensor dict, modified in place.
            prefix: Key prefix such as ``"ref_"`` or ``"ft_ref_"``.
            offset_xy: Position offset; the caller zeroes its z component.
            q_flat_wxyz: Heading quaternion (WXYZ) repeated once per
                (frame, body) entry, shape ``[T * B, 4]``.
            ref_rg_pos_shape: Shape of the reference position tensor.
            ref_rb_rot_shape: Shape of the reference rotation tensor.
        """
        pos_key = f"{prefix}rg_pos"
        rot_key = f"{prefix}rb_rot"
        vel_key = f"{prefix}body_vel"
        ang_key = f"{prefix}body_ang_vel"
        if (
            pos_key not in arrays
            or rot_key not in arrays
            or vel_key not in arrays
            or ang_key not in arrays
        ):
            return
        pos = arrays[pos_key]
        rot = arrays[rot_key]
        vel = arrays[vel_key]
        ang = arrays[ang_key]
        if pos.shape != ref_rg_pos_shape or rot.shape != ref_rb_rot_shape:
            return
        # Center XY using canonical offset.
        pos[..., 0] -= offset_xy[0]
        pos[..., 1] -= offset_xy[1]
        # Rotate vectors using shared quaternion utilities (WXYZ convention).
        pos_flat = pos.reshape(-1, 3)
        vel_flat = vel.reshape(-1, 3)
        ang_flat = ang.reshape(-1, 3)
        pos[:] = torch_utils.quat_apply(q_flat_wxyz, pos_flat).reshape_as(pos)
        vel[:] = torch_utils.quat_apply(q_flat_wxyz, vel_flat).reshape_as(vel)
        ang[:] = torch_utils.quat_apply(q_flat_wxyz, ang_flat).reshape_as(ang)
        # Rotate orientations: q' = q_heading_inv * q. Stored quaternions
        # are XYZW and are converted to WXYZ for the multiply.
        rot_flat_xyzw = rot.reshape(-1, 4)
        rot_flat_wxyz = torch_utils.xyzw_to_wxyz(rot_flat_xyzw)
        rot_out_wxyz = torch_utils.quat_mul(q_flat_wxyz, rot_flat_wxyz)
        rot[:] = torch_utils.wxyz_to_xyzw(rot_out_wxyz).reshape_as(rot)

    def __call__(self, arrays: Dict[str, Tensor]) -> None:
        """Normalize ``arrays`` in place.

        Raises:
            ValueError: If any ``ref_rg_pos`` / ``ref_rb_rot`` /
                ``ref_body_vel`` / ``ref_body_ang_vel`` entry is missing.
        """
        if "ref_rg_pos" not in arrays or "ref_rb_rot" not in arrays:
            raise ValueError("ref_rg_pos and ref_rb_rot are required")
        if "ref_body_vel" not in arrays or "ref_body_ang_vel" not in arrays:
            raise ValueError("ref_body_vel and ref_body_ang_vel are required")
        rg_pos = arrays["ref_rg_pos"]
        rb_rot = arrays["ref_rb_rot"]
        # Root pose at frame 0, body 0 (XYZW quaternion, z-up).
        p_root0 = rg_pos[0, 0]  # [3]
        q_root0 = rb_rot[0, 0]  # [4]
        # XY offset from the root at frame 0, applied in _apply_prefix.
        # Cloned because rg_pos is modified in place below.
        offset_xy = p_root0.clone()
        offset_xy[2] = 0.0
        # Extract yaw from q_root0 (XYZW) using z-up convention.
        x = q_root0[0]
        y = q_root0[1]
        z = q_root0[2]
        w = q_root0[3]
        siny_cosp = 2.0 * (w * z + x * y)
        cosy_cosp = w * w + x * x - y * y - z * z
        yaw0 = torch.atan2(siny_cosp, cosy_cosp)
        # Quaternion for rotation around +Z by -yaw0 (remove initial heading).
        half = -0.5 * yaw0
        sin_half = torch.sin(half)
        cos_half = torch.cos(half)
        q_heading_inv = torch.stack(
            [
                torch.zeros_like(sin_half),
                torch.zeros_like(sin_half),
                sin_half,
                cos_half,
            ],
            dim=-1,
        )  # [4], XYZW
        # Repeat the heading quaternion once per (frame, body): [T * B, 4].
        t, b, _ = rg_pos.shape
        q_flat = q_heading_inv.view(1, 1, 4).expand(t, b, 4).reshape(-1, 4)
        q_flat_wxyz = torch_utils.xyzw_to_wxyz(q_flat)
        for pfx in ("ref_", "ft_ref_"):
            self._apply_prefix(
                arrays,
                pfx,
                offset_xy=offset_xy,
                q_flat_wxyz=q_flat_wxyz,
                ref_rg_pos_shape=rg_pos.shape,
                ref_rb_rot_shape=rb_rot.shape,
            )
class _CpuFKTransform:
"""Compute FK on CPU and write ref_* tensors in-place."""
def __init__(self, robot_file_path: str) -> None:
self._fk = HoloMotionFK(
robot_file_path=str(robot_file_path), device=torch.device("cpu")
)
self._fk = self._fk.to(torch.device("cpu"))
def __call__(
self,
arrays: Dict[str, Tensor],
fps: float,
prefix: str = "ref_",
vel_smoothing_sigma: float = 2.0,
) -> None:
root_pos_key = f"{prefix}root_pos"
root_rot_key = f"{prefix}root_rot"
dof_pos_key = f"{prefix}dof_pos"
if (
root_pos_key not in arrays
or root_rot_key not in arrays
or dof_pos_key not in arrays
):
raise KeyError(f"Missing {prefix}root_* or {prefix}dof_pos for FK")
with torch.no_grad():
fk_out = self._fk(
root_pos=arrays[root_pos_key][None, ...],
root_quat=arrays[root_rot_key][None, ...],
dof_pos=arrays[dof_pos_key][None, ...],
fps=float(fps),
vel_smoothing_sigma=float(vel_smoothing_sigma),
quat_format="xyzw",
)
arrays[f"{prefix}rg_pos"] = fk_out["global_translation"][0]
arrays[f"{prefix}rb_rot"] = fk_out["global_rotation_quat"][0]
arrays[f"{prefix}body_vel"] = fk_out["global_velocity"][0]
arrays[f"{prefix}body_ang_vel"] = fk_out["global_angular_velocity"][0]
arrays[f"{prefix}dof_vel"] = fk_out["dof_vel"][0]
@dataclass
class MotionWindow:
"""Metadata describing a contiguous motion window within an HDF5 shard."""
motion_key: str # unique per window
shard_index: int
start: int
length: int
raw_motion_key: str # original clip key
window_index: int
@dataclass
class MotionClipSample:
"""In-memory representation of a motion window.
Attributes:
motion_key: Unique window identifier (includes slice info).
raw_motion_key: Original clip identifier from manifest.
tensors: Mapping from tensor name to data tensor of shape
``[window_length, ...]`` (float32 unless specified otherwise).
length: Number of valid frames contained in the sample (``<=``
``max_frame_length``).
"""
motion_key: str
raw_motion_key: str
window_index: int
tensors: Dict[str, Tensor]
length: int
@dataclass
class ClipBatch:
"""Batch of motion clips ready for consumption by the environment.
Attributes:
tensors: Mapping from tensor name to tensor with shape
``[batch_size, max_frame_length, ...]`` placed on the staging
device.
lengths: Valid frame counts per clip ``[batch_size]``.
motion_keys: List of motion keys corresponding to each clip.
max_frame_length: Fixed length configured for the cache.
"""
tensors: Dict[str, Tensor]
lengths: Tensor
motion_keys: List[str]
raw_motion_keys: List[str]
window_indices: Tensor
max_frame_length: int
@staticmethod
def collate_fn(samples: List[MotionClipSample]) -> "ClipBatch":
if len(samples) == 0:
raise ValueError(
"ClipBatch collate_fn received an empty sample list"
)
max_frame_length = max(
sample.tensors["ref_dof_pos"].shape[0] for sample in samples
)
max_frame_length = int(max_frame_length)
batched_tensors: Dict[str, Tensor] = {}
lengths = torch.zeros(len(samples), dtype=torch.long)
motion_keys = []
raw_motion_keys = []
window_indices = torch.zeros(len(samples), dtype=torch.long)
for batch_idx, sample in enumerate(samples):
lengths[batch_idx] = sample.length
motion_keys.append(sample.motion_key)
raw_motion_keys.append(sample.raw_motion_key)
window_indices[batch_idx] = int(sample.window_index)
for name, tensor in sample.tensors.items():
if name not in batched_tensors:
pad_shape = (
len(samples),
max_frame_length,
) + tensor.shape[1:]
batched_tensors[name] = torch.zeros(
pad_shape,
dtype=tensor.dtype,
device=tensor.device,
)
target = batched_tensors[name]
valid_frames = sample.length
target[batch_idx, :valid_frames] = tensor
if valid_frames < max_frame_length and valid_frames > 0:
target[batch_idx, valid_frames:] = tensor[valid_frames - 1]
return ClipBatch(
tensors=batched_tensors,
lengths=lengths,
motion_keys=motion_keys,
raw_motion_keys=raw_motion_keys,
window_indices=window_indices,
max_frame_length=max_frame_length,
)
class Hdf5RootDofDataset(Dataset[MotionClipSample]):
"""HDF5 dataset reading ref_root_* + ref_dof_pos only."""
def __init__(
self,
manifest_path: str | Sequence[str],
max_frame_length: int,
min_window_length: int = 1,
handpicked_motion_names: Optional[List[str]] = None,
excluded_motion_names: Optional[List[str]] = None,
fk_robot_file_path: Optional[str] = None,
fk_vel_smoothing_sigma: float = 2.0,
fk_world_frame_normalization: bool = True,
online_filter_cfg: Optional[Mapping[str, Any]] = None,
allowed_prefixes: Optional[Sequence[str]] = None,
) -> None:
super().__init__()
if max_frame_length <= 0:
raise ValueError("max_frame_length must be positive")
if min_window_length <= 0:
raise ValueError("min_window_length must be positive")
self.max_frame_length = int(max_frame_length)
self.min_window_length = int(min_window_length)
self.handpicked_motion_names = (
set(handpicked_motion_names)
if handpicked_motion_names is not None
else None
)
self.excluded_motion_names = (
set(excluded_motion_names)
if excluded_motion_names is not None
else None
)
self._fk_robot_file_path = (
str(fk_robot_file_path) if fk_robot_file_path is not None else ""
)
if not self._fk_robot_file_path:
raise ValueError("fk_robot_file_path is required for hdf5_v2 FK")
self._fk_world_frame_normalization = bool(fk_world_frame_normalization)
self._fk_transform = _CpuFKTransform(self._fk_robot_file_path)
self._world_frame_transform = (
_WorldFrameNormalizeTransform()
if self._fk_world_frame_normalization
else None
)
self._fk_vel_smoothing_sigma = float(fk_vel_smoothing_sigma)
self._online_filter_cfg = _normalize_online_filter_cfg(
online_filter_cfg,
default_vel_smoothing_sigma=self._fk_vel_smoothing_sigma,
)
self._online_filter_enabled = bool(self._online_filter_cfg["enabled"])
self._online_filter_butter_order = int(
self._online_filter_cfg["butter_order"]
)
self._online_filter_cutoff_hz_pool = tuple(
float(v) for v in self._online_filter_cfg["butter_cutoff_hz_pool"]
)
self._ref_vel_smoothing_sigma = float(
self._online_filter_cfg["ref_vel_smoothing_sigma"]
)
self._ft_ref_vel_smoothing_sigma = float(
self._online_filter_cfg["ft_ref_vel_smoothing_sigma"]
)
if allowed_prefixes is None:
self._allowed_prefixes = ("ref_", "ft_ref_")
else:
self._allowed_prefixes = tuple(str(v) for v in allowed_prefixes)
if "ref_" not in self._allowed_prefixes:
raise ValueError(
"Hdf5RootDofDataset requires 'ref_' in allowed_prefixes"
)
if isinstance(manifest_path, (str, os.PathLike)):
manifest_paths: List[str] = [str(manifest_path)]
else:
manifest_paths = [str(p) for p in manifest_path]
if len(manifest_paths) == 0:
raise ValueError("At least one manifest_path must be provided")
self.hdf5_root = os.path.dirname(manifest_paths[0])
self._manifest_paths: List[str] = manifest_paths
self._shard_paths: List[str] = []
self.shards: List[Dict[str, Any]] = []
self.clips: Dict[str, Dict[str, Any]] = {}
for mp in manifest_paths:
if not os.path.exists(mp):
raise FileNotFoundError(
f"HDF5 manifest not found at {mp}. "
"Please set robot.motion.hdf5_root/train_hdf5_roots "
"to the correct path."
)
with open(mp, "r", encoding="utf-8") as handle:
manifest = json.load(handle)
root = os.path.dirname(mp)
shards_local = list(manifest.get("hdf5_shards", []))
clips_local = manifest.get("clips", {})
shard_offset = len(self.shards)
for shard_meta in shards_local:
self.shards.append(shard_meta)
rel = shard_meta.get("file", None)
if not isinstance(rel, str) or not rel:
raise ValueError(
f"Shard entry in manifest {mp} is missing a valid 'file' field"
)
self._shard_paths.append(os.path.join(root, rel))
for key, meta in clips_local.items():
if key in self.clips:
raise ValueError(
f"Duplicate motion clip key '{key}' found in multiple "
"manifests; clip keys must be globally unique."
)
meta_global = dict(meta)
meta_global["shard"] = (
int(meta_global.get("shard", 0)) + shard_offset
)
self.clips[key] = meta_global
if len(self.shards) == 0:
raise ValueError(
f"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}"
)
self.windows: List[MotionWindow] = self._enumerate_windows()
if len(self.windows) == 0:
raise ValueError(
"No motion windows satisfy the requested frame length constraints"
)
# Setting up hdf5 file handles management for bounded host-memory usage
self._file_handles: "OrderedDict[int, h5py.File]" = OrderedDict()
max_open_env = os.getenv("HOLOMOTION_HDF5_MAX_OPEN_SHARDS")
if max_open_env is None:
self._h5_max_open_files = 16
else:
self._h5_max_open_files = max(1, int(max_open_env))
self._h5_access_counter = 0
self._h5_cleanup_interval = int(
1.0e6
) # clean h5 handles every 1 million samples
def set_progress_counter(self, counter: Optional[mp.Value]) -> None:
self._progress_counter = counter
@staticmethod
def _normalize_motion_key(value: Any) -> Optional[str]:
if value is None:
return None
if isinstance(value, bytes):
value = value.decode("utf-8")
if isinstance(value, str):
key = value
else:
key = str(value)
if not key:
return None
return key
def _build_motion_key_aliases(
self, motion_key: str, meta: Mapping[str, Any]
) -> Tuple[str, ...]:
aliases: List[str] = []
def _add(value: Any) -> None:
key = self._normalize_motion_key(value)
if key is None:
return
if key in aliases:
return
aliases.append(key)
_add(motion_key)
if isinstance(meta, Mapping):
_add(meta.get("motion_key"))
metadata = meta.get("metadata")
if isinstance(metadata, Mapping):
_add(metadata.get("motion_key"))
_add(metadata.get("raw_motion_key"))
return tuple(aliases)
def _enumerate_windows(self) -> List[MotionWindow]:
windows: List[MotionWindow] = []
for motion_key, meta in self.clips.items():
aliases = self._build_motion_key_aliases(motion_key, meta)
if self.handpicked_motion_names is not None and not any(
alias in self.handpicked_motion_names for alias in aliases
):
continue
if self.excluded_motion_names is not None and any(
alias in self.excluded_motion_names for alias in aliases
):
continue
shard_index = int(meta.get("shard", 0))
start = int(meta.get("start", 0))
length = int(meta.get("length", 0))
if length <= 0:
continue
remaining = length
offset = 0
window_index = 0
while remaining > 0:
window_length = min(self.max_frame_length, remaining)
if window_length >= self.min_window_length:
win_start = start + offset
unique_key = (
f"{motion_key}__start_{win_start}_len_{window_length}"
)
windows.append(
MotionWindow(
motion_key=unique_key,
shard_index=shard_index,
start=win_start,
length=window_length,
raw_motion_key=motion_key,
window_index=window_index,
)
)
window_index += 1
offset += window_length
remaining = max(0, length - offset)
return windows
def __len__(self) -> int:
return len(self.windows)
@staticmethod
def _cast_motion_np(np_array: np.ndarray, name: str) -> Tensor:
if np_array.dtype == np.float32:
pass
elif np_array.dtype.kind == "O":
raise ValueError(f"{name} has object dtype")
elif np.issubdtype(np_array.dtype, np.integer):
logger.warning(
"Casting {} from {} to float32.", name, np_array.dtype
)
np_array = np_array.astype(np.float32, copy=False)
else:
raise ValueError(
f"{name} has dtype {np_array.dtype}, expected float32 or integer."
)
return torch.from_numpy(np_array).to(torch.float32)
@staticmethod
def _make_scalar_metadata_tensor(value: float, length: int) -> Tensor:
return torch.full((int(length), 1), float(value), dtype=torch.float32)
def _sample_online_filter_cutoff_hz(self) -> float:
if not self._online_filter_enabled:
return 0.0
cutoff_pool = self._online_filter_cutoff_hz_pool
if len(cutoff_pool) == 0:
raise ValueError(
"Online filter is enabled but butter_cutoff_hz_pool is empty"
)
if len(cutoff_pool) == 1:
return cutoff_pool[0]
sample_idx = int(torch.randint(len(cutoff_pool), size=(1,)).item())
return cutoff_pool[sample_idx]
def _add_online_filtered_reference_tensors(
self,
arrays: Dict[str, Tensor],
fps: float,
cutoff_hz: float,
) -> None:
filtered_inputs_np = butterworth_filter_root_dof_arrays(
arrays={
"ref_root_pos": arrays["ref_root_pos"].cpu().numpy(),
"ref_root_rot": arrays["ref_root_rot"].cpu().numpy(),
"ref_dof_pos": arrays["ref_dof_pos"].cpu().numpy(),
},
fps=float(fps),
cutoff_hz=float(cutoff_hz),
order=self._online_filter_butter_order,
)
for tensor_name, np_array in filtered_inputs_np.items():
arrays[tensor_name] = torch.from_numpy(np_array).to(torch.float32)
self._fk_transform(
arrays,
fps,
prefix="ft_ref_",
vel_smoothing_sigma=self._ft_ref_vel_smoothing_sigma,
)
@staticmethod
def _derive_root_state_tensors(
arrays: Dict[str, Tensor],
prefix: str = "ref_",
) -> None:
rg_pos_key = f"{prefix}rg_pos"
rb_rot_key = f"{prefix}rb_rot"
body_vel_key = f"{prefix}body_vel"
body_ang_vel_key = f"{prefix}body_ang_vel"
if (
rg_pos_key not in arrays
or rb_rot_key not in arrays
or body_vel_key not in arrays
or body_ang_vel_key not in arrays
):
return
# Keep root-level tensors consistent with the FK-derived body tensors.
arrays[f"{prefix}root_pos"] = arrays[rg_pos_key][:, 0, :]
arrays[f"{prefix}root_rot"] = arrays[rb_rot_key][:, 0, :]
arrays[f"{prefix}root_vel"] = arrays[body_vel_key][:, 0, :]
arrays[f"{prefix}root_ang_vel"] = arrays[body_ang_vel_key][:, 0, :]
def __getitem__(self, index: int) -> MotionClipSample:
window = self.windows[index]
shard_handle = self._get_shard_handle(window.shard_index)
start, end = window.start, window.start + window.length
arrays: Dict[str, Tensor] = {}
for dataset_name in ("ref_root_pos", "ref_root_rot", "ref_dof_pos"):
if dataset_name not in shard_handle:
raise KeyError(
f"Missing mandatory dataset '{dataset_name}' in shard index "
f"{window.shard_index}"
)
np_array = np.asarray(shard_handle[dataset_name][start:end, ...])
arrays[dataset_name] = self._cast_motion_np(np_array, dataset_name)
if "frame_flag" in shard_handle:
frame_flag_np = shard_handle["frame_flag"][start:end]
if frame_flag_np.dtype.kind == "O":
raise ValueError("frame_flag has object dtype")
frame_flag = torch.from_numpy(frame_flag_np).to(torch.long)
else:
frame_flag = torch.ones(window.length, dtype=torch.long)
if window.length > 1:
frame_flag[0] = 0
frame_flag[-1] = 2
elif window.length == 1:
frame_flag[0] = 2
arrays["frame_flag"] = frame_flag
clip_meta = self.clips.get(window.raw_motion_key, {})
metadata = clip_meta.get("metadata", {})
motion_fps_val = metadata.get(
"motion_fps", clip_meta.get("motion_fps")
)
if motion_fps_val is None:
raise ValueError(
f"motion_fps missing for clip {window.raw_motion_key}"
)
motion_fps = float(motion_fps_val)
if motion_fps <= 0.0:
raise ValueError(
f"Invalid motion_fps {motion_fps} for clip {window.raw_motion_key}"
)
arrays["motion_fps"] = self._make_scalar_metadata_tensor(
motion_fps, window.length
)
cutoff_hz = self._sample_online_filter_cutoff_hz()
arrays["filter_cutoff_hz"] = self._make_scalar_metadata_tensor(
cutoff_hz, window.length
)
self._fk_transform(
arrays,
motion_fps,
vel_smoothing_sigma=self._ref_vel_smoothing_sigma,
)
if self._online_filter_enabled and "ft_ref_" in self._allowed_prefixes:
self._add_online_filtered_reference_tensors(
arrays,
motion_fps,
cutoff_hz,
)
if self._world_frame_transform is not None:
self._world_frame_transform(arrays)
self._derive_root_state_tensors(arrays, prefix="ref_")
self._derive_root_state_tensors(arrays, prefix="ft_ref_")
if self._progress_counter is not None:
with self._progress_counter.get_lock():
self._progress_counter.value += 1
return MotionClipSample(
motion_key=window.motion_key,
raw_motion_key=window.raw_motion_key,
window_index=int(index),
tensors=arrays,
length=window.length,
)
def _get_shard_handle(self, shard_index: int) -> h5py.File:
# periodically clean up the file handles
self._h5_access_counter += 1
if self._h5_access_counter >= self._h5_cleanup_interval:
self.close()
self._h5_access_counter = 0
if shard_index in self._file_handles:
handle = self._file_handles.pop(shard_index)
if handle.id:
self._file_handles[shard_index] = handle
return handle
if shard_index < 0 or shard_index >= len(self._shard_paths):
raise IndexError(
f"Shard index {shard_index} out of range for "
f"{len(self._shard_paths)} available shards"
)
shard_path = self._shard_paths[shard_index]
rdcc_nbytes_env = os.getenv("HOLOMOTION_HDF5_RDCC_NBYTES")
if rdcc_nbytes_env is None:
rdcc_nbytes = 4 * 1024 * 1024
else:
rdcc_nbytes = int(rdcc_nbytes_env)
handle = h5py.File(
shard_path,
"r",
libver="latest",
swmr=True,
rdcc_nbytes=rdcc_nbytes,
rdcc_w0=0.75,
)
if (
self._h5_max_open_files is not None
and len(self._file_handles) >= self._h5_max_open_files
):
old_index, old_handle = self._file_handles.popitem(last=False)
old_handle.close()
self._file_handles[shard_index] = handle
return handle
def close(self) -> None:
logger.info("Clearing HDF5 file handles ...")
for handle in self._file_handles.values():
if handle.id:
handle.close()
self._file_handles.clear()
def __del__(self) -> None:
self.close()
def _normalize_root_list(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, (str, os.PathLike)):
return [str(value)]
return [str(v) for v in value]
def build_motion_datasets_from_cfg(
motion_cfg: Mapping[str, Any],
*,
max_frame_length: int,
min_window_length: int,
world_frame_normalization: bool = True,
handpicked_motion_names: Optional[List[str]] = None,
excluded_motion_names: Optional[List[str]] = None,
allowed_prefixes: Optional[Sequence[str]] = None,
) -> Tuple[
Dataset[MotionClipSample],
Optional[Dataset[MotionClipSample]],
Dict[str, Any],
]:
preview_sampling_from_cfg(motion_cfg=motion_cfg)
backend = str(motion_cfg.get("backend", "hdf5")).lower()
if backend in ("hdf5", "hdf5_simple"):
train_roots = _normalize_root_list(
motion_cfg.get("train_hdf5_roots", None)
)
if len(train_roots) == 0:
hdf5_root = motion_cfg.get("hdf5_root", None)
if not hdf5_root:
raise ValueError(
"HDF5 backend requires train_hdf5_roots or hdf5_root"
)
train_roots = [str(hdf5_root)]
manifest_paths = [
os.path.join(str(root), "manifest.json") for root in train_roots
]
train_dataset = Hdf5MotionDataset(
manifest_path=manifest_paths
if len(manifest_paths) > 1
else manifest_paths[0],
max_frame_length=max_frame_length,
min_window_length=min_window_length,
handpicked_motion_names=handpicked_motion_names,
excluded_motion_names=excluded_motion_names,
world_frame_normalization=world_frame_normalization,
allowed_prefixes=allowed_prefixes,
)
val_roots = _normalize_root_list(
motion_cfg.get("val_hdf5_roots", motion_cfg.get("val_hdf5_root"))
)
val_dataset = None
if len(val_roots) > 0:
val_manifest_paths = [
os.path.join(str(root), "manifest.json") for root in val_roots
]
val_dataset = Hdf5MotionDataset(
manifest_path=val_manifest_paths
if len(val_manifest_paths) > 1
else val_manifest_paths[0],
max_frame_length=max_frame_length,
min_window_length=min_window_length,
handpicked_motion_names=handpicked_motion_names,
excluded_motion_names=excluded_motion_names,
world_frame_normalization=world_frame_normalization,
allowed_prefixes=allowed_prefixes,
)
return train_dataset, val_dataset, {}
if backend == "hdf5_v2":
fk_robot_file_path = motion_cfg.get("fk_robot_file_path")
fk_vel_smoothing_sigma = float(
motion_cfg.get("fk_vel_smoothing_sigma", 2.0)
)
fk_world_frame_normalization = bool(
motion_cfg.get("online_fk_world_frame_normalization", True)
)
cache_cfg = motion_cfg.get("cache", {})
allowed_prefixes = cache_cfg.get(
"allowed_prefixes",
["ref_", "ft_ref_"],
)
online_filter_cfg = motion_cfg.get("online_filter", {})
train_roots = _normalize_root_list(
motion_cfg.get("train_hdf5_roots", None)
)
if len(train_roots) == 0:
hdf5_root = motion_cfg.get("hdf5_root", None)
if not hdf5_root:
raise ValueError(
"HDF5 v2 backend requires train_hdf5_roots or hdf5_root"
)
train_roots = [str(hdf5_root)]
train_manifest_paths = [
os.path.join(str(root), "manifest.json") for root in train_roots
]
train_dataset = Hdf5RootDofDataset(
manifest_path=train_manifest_paths
if len(train_manifest_paths) > 1
else train_manifest_paths[0],
max_frame_length=max_frame_length,
min_window_length=min_window_length,
handpicked_motion_names=handpicked_motion_names,
excluded_motion_names=excluded_motion_names,
fk_robot_file_path=fk_robot_file_path,
fk_vel_smoothing_sigma=fk_vel_smoothing_sigma,
fk_world_frame_normalization=fk_world_frame_normalization,
online_filter_cfg=online_filter_cfg,
allowed_prefixes=allowed_prefixes,
)
val_roots = _normalize_root_list(
motion_cfg.get("val_hdf5_roots", motion_cfg.get("val_hdf5_root"))
)
val_dataset = None
if len(val_roots) > 0:
val_manifest_paths = [
os.path.join(str(root), "manifest.json") for root in val_roots
]
val_dataset = Hdf5RootDofDataset(
manifest_path=val_manifest_paths
if len(val_manifest_paths) > 1
else val_manifest_paths[0],
max_frame_length=max_frame_length,
min_window_length=min_window_length,
handpicked_motion_names=handpicked_motion_names,
excluded_motion_names=excluded_motion_names,
fk_robot_file_path=fk_robot_file_path,
fk_vel_smoothing_sigma=fk_vel_smoothing_sigma,
fk_world_frame_normalization=fk_world_frame_normalization,
online_filter_cfg=online_filter_cfg,
allowed_prefixes=allowed_prefixes,
)
cache_kwargs = {
"stage_on_swap_only": bool(
motion_cfg.get("stage_on_swap_only", True)
)
}
return train_dataset, val_dataset, cache_kwargs
raise ValueError(f"Unsupported motion backend: {backend}")
def _cache_collate_fn(
samples: List[MotionClipSample],
mode: str,
batch_size: int,
) -> ClipBatch:
"""Collate function for motion cache DataLoader (supports validation padding)."""
if mode == "val" and batch_size > len(samples) and len(samples) > 0:
extra = batch_size - len(samples)
gen = torch.Generator()
idx = torch.randint(0, len(samples), size=(extra,), generator=gen)
padded = list(samples)
for i in idx.tolist():
padded.append(samples[i])
return ClipBatch.collate_fn(padded)
return ClipBatch.collate_fn(samples)
class InfiniteDistributedSampler(DistributedSampler):
"""Distributed sampler that yields an infinite stream by cycling epochs."""
def __iter__(self):
# Infinite stream by cycling epochs
while True:
self.set_epoch(getattr(self, "_epoch", 0))
for idx in super().__iter__():
yield idx
self._epoch = getattr(self, "_epoch", 0) + 1
class InfiniteRandomSampler(Sampler[int]):
"""Random sampler that yields infinite reshuffled passes over the dataset."""
def __init__(self, data_source: Dataset, seed: int = 0) -> None:
self.data_source = data_source
self.seed = int(seed)
self.epoch = 0
def __iter__(self):
# Yield infinite permutations of indices
while True:
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
perm = torch.randperm(len(self.data_source), generator=g)
for idx in perm.tolist():
yield int(idx)
self.epoch += 1
def __len__(self) -> int:
# Large sentinel to satisfy components that query length
return 2**31 - 1
class WeightedBinInfiniteSampler(Sampler[int]):
"""Infinite sampler that respects regex-based weighted bins over indices."""
def __init__(
self,
dataset_len: int,
bin_indices: List[List[int]],
ratios: List[float],
batch_size: int,
seed: int,
) -> None:
self._ds_len = int(max(0, dataset_len))
self._bins = [torch.tensor(b, dtype=torch.long) for b in bin_indices]
self._ratios = list(ratios)
self._batch_size = int(max(1, batch_size))
self._seed = int(seed)
self._epoch = 0
raw_counts = [r * float(self._batch_size) for r in self._ratios]
self._counts = _allocate_batch_counts(
raw_counts=raw_counts,
target_total=self._batch_size,
)
def __iter__(self):
while True:
g = torch.Generator()
g.manual_seed(self._seed + self._epoch)
batch: List[int] = []
for bin_idx, count in zip(self._bins, self._counts):
if count <= 0 or bin_idx.numel() == 0:
continue
choice = torch.randint(
0,
int(bin_idx.numel()),
size=(count,),
generator=g,
)
selected = bin_idx[choice].tolist()
batch.extend(int(x) for x in selected)
if not batch:
# Fallback: uniform over dataset indices
if self._ds_len == 0:
raise ValueError(
"WeightedBinInfiniteSampler cannot sample from an empty dataset"
)
all_idx = torch.randint(
0,
self._ds_len,
size=(self._batch_size,),
generator=g,
)
batch = [int(x) for x in all_idx.tolist()]
if len(batch) > self._batch_size:
batch = batch[: self._batch_size]
elif len(batch) < self._batch_size:
pad = self._batch_size - len(batch)
if pad > 0:
batch.extend(batch[:pad])
perm = torch.randperm(len(batch), generator=g)
for idx in perm.tolist():
yield int(batch[idx])
self._epoch += 1
def __len__(self) -> int:
return 2**31 - 1
class PrioritizedInfiniteSampler(Sampler[int]):
"""Infinite sampler with persistent prioritized and fresh uniform pools."""
def __init__(
self,
dataset_len: int,
batch_size: int,
seed: int,
*,
p_a_ratio: float = 0.2,
ema_alpha_signal: float = 0.2,
ema_alpha_rel_improve: float = 0.2,
relative_eps: float = 1.0e-6,
) -> None:
self._ds_len = int(max(0, dataset_len))
self._batch_size = int(max(1, batch_size))
self._seed = int(seed)
self._epoch = 0
self._p_a_ratio = float(min(1.0, max(0.0, p_a_ratio)))
self._ema_alpha_signal = float(min(1.0, max(0.0, ema_alpha_signal)))
self._ema_alpha_rel_improve = float(
min(1.0, max(0.0, ema_alpha_rel_improve))
)
self._relative_eps = float(max(1.0e-12, relative_eps))
if self._ds_len <= 0:
self._ema_completion_rate = torch.zeros(0, dtype=torch.float32)
self._ema_completion_rate_sq = torch.zeros(0, dtype=torch.float32)
self._ema_completion_rel_improve = torch.zeros(
0, dtype=torch.float32
)
self._selection_counts = torch.zeros(0, dtype=torch.long)
self._seen_mask = torch.zeros(0, dtype=torch.bool)
self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long)
self._prioritized_pool_mask = torch.zeros(0, dtype=torch.bool)
else:
self._ema_completion_rate = torch.zeros(
self._ds_len, dtype=torch.float32
)
self._ema_completion_rate_sq = torch.zeros(
self._ds_len, dtype=torch.float32
)
self._ema_completion_rel_improve = torch.zeros(
self._ds_len, dtype=torch.float32
)
self._selection_counts = torch.zeros(
self._ds_len, dtype=torch.long
)
self._seen_mask = torch.zeros(self._ds_len, dtype=torch.bool)
self._prioritized_pool_indices = torch.zeros(0, dtype=torch.long)
self._prioritized_pool_mask = torch.zeros(
self._ds_len, dtype=torch.bool
)
self._state_version = 0
self._last_updated_swap = -1
self._last_prioritized_pool_mean_score = 0.0
self._last_uniform_pool_mean_score = 0.0
self._last_entered_prioritized_pool_count = 0
self._last_exited_prioritized_pool_count = 0
self._uniform_cycle_start = 0
self._uniform_cycle_step = 1
self._uniform_cycle_offset = self._ds_len
self._uniform_cycle_epoch = 0
@property
def state_version(self) -> int:
return int(self._state_version)
def get_pool_statistics(self) -> Optional[Dict[str, float]]:
if self._ds_len <= 0:
return None
return self._pool_metric_stats()
@staticmethod
def _aggregate_by_index(
window_indices: Tensor,
values: Tensor,
counts: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
if window_indices.numel() == 0:
return (
torch.zeros(0, dtype=torch.long),
torch.zeros(0, dtype=torch.float32),
torch.zeros(0, dtype=torch.float32),
)
unique_indices, inverse = torch.unique(
window_indices.to(dtype=torch.long),
sorted=False,
return_inverse=True,
)
out_weighted_sum = torch.zeros(
unique_indices.numel(), dtype=torch.float32
)
out_count = torch.zeros(unique_indices.numel(), dtype=torch.float32)
out_weighted_sum.scatter_add_(0, inverse, values * counts)
out_count.scatter_add_(0, inverse, counts)
return unique_indices, out_weighted_sum, out_count
def _pool_batch_sizes(self) -> Tuple[int, int]:
if self._ds_len <= 0:
return 0, 0
uniform_count = int(round(self._p_a_ratio * float(self._batch_size)))
uniform_count = max(0, min(self._batch_size, uniform_count))
prioritized_count = max(0, self._batch_size - uniform_count)
return uniform_count, prioritized_count
def _priority_scores_for_indices(self, indices: Tensor) -> Tensor:
if indices.numel() == 0 or self._ds_len <= 0:
return torch.zeros(0, dtype=torch.float32)
idx = indices.to(dtype=torch.long)
progress = torch.clamp(
self._ema_completion_rel_improve.index_select(0, idx),
min=0.0,
max=1.0,
)
remaining_difficulty = torch.clamp(
1.0 - self._ema_completion_rate.index_select(0, idx),
min=0.0,
max=1.0,
)
seen = self._seen_mask.index_select(0, idx).to(dtype=torch.float32)
return progress * remaining_difficulty * seen
def _pool_metric_stats(self) -> Dict[str, float]:
prioritized_pool_size = int(self._prioritized_pool_indices.numel())
return {
"prioritized_pool_size": float(prioritized_pool_size),
"prioritized_pool_mean_score": float(
self._last_prioritized_pool_mean_score
),
"uniform_pool_mean_score": float(
self._last_uniform_pool_mean_score
),
"entered_prioritized_pool_count": float(
self._last_entered_prioritized_pool_count
),
"exited_prioritized_pool_count": float(
self._last_exited_prioritized_pool_count
),
}
def get_window_state_for_indices(
self, window_indices: Tensor
) -> Dict[str, Tensor]:
if self._ds_len <= 0:
empty_bool = torch.zeros(0, dtype=torch.bool)
empty_float = torch.zeros(0, dtype=torch.float32)
return {
"ema_completion_rate": empty_float,
"completion_rate_rel_improve": empty_float,
"selection_count": torch.zeros(0, dtype=torch.long),
"seen": empty_bool,
"in_prioritized_pool": empty_bool,
}
idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()
if idx.numel() == 0:
empty_bool = torch.zeros(0, dtype=torch.bool)
empty_float = torch.zeros(0, dtype=torch.float32)
return {
"ema_completion_rate": empty_float,
"completion_rate_rel_improve": empty_float,
"selection_count": torch.zeros(0, dtype=torch.long),
"seen": empty_bool,
"in_prioritized_pool": empty_bool,
}
return {
"ema_completion_rate": self._ema_completion_rate.index_select(
0, idx
).to(dtype=torch.float32),
"completion_rate_rel_improve": (
self._ema_completion_rel_improve.index_select(0, idx).to(
dtype=torch.float32
)
),
"selection_count": self._selection_counts.index_select(0, idx),
"seen": self._seen_mask.index_select(0, idx),
"in_prioritized_pool": self._prioritized_pool_mask.index_select(
0, idx
),
}
def _rebuild_prioritized_pool(self, candidate_indices: Tensor) -> None:
if self._ds_len <= 0:
return
_, prioritized_count = self._pool_batch_sizes()
previous_indices = self._prioritized_pool_indices
selected = torch.zeros(0, dtype=torch.long)
if prioritized_count > 0:
candidates = torch.cat(
[
previous_indices.to(dtype=torch.long),
candidate_indices.to(dtype=torch.long).reshape(-1),
]
)
candidates = torch.unique(candidates, sorted=False)
scores = self._priority_scores_for_indices(candidates)
positive = scores > 0.0
if bool(positive.any().item()):
candidates = candidates[positive]
scores = scores[positive]
order = torch.argsort(scores, descending=True)
selected = candidates.index_select(
0, order[: min(prioritized_count, candidates.numel())]
)
scores = scores.index_select(
0, order[: min(prioritized_count, scores.numel())]
)
self._last_prioritized_pool_mean_score = float(
scores.mean().item()
)
else:
self._last_prioritized_pool_mean_score = 0.0
if candidates.numel() > selected.numel():
selected_mask = torch.zeros(
candidates.numel(), dtype=torch.bool
)
if selected.numel() > 0:
matches = candidates[:, None] == selected[None, :]
selected_mask = matches.any(dim=1)
nonselected_scores = self._priority_scores_for_indices(
candidates[~selected_mask]
)
self._last_uniform_pool_mean_score = (
float(nonselected_scores.mean().item())
if nonselected_scores.numel() > 0
else 0.0
)
else:
self._last_uniform_pool_mean_score = 0.0
else:
self._last_prioritized_pool_mean_score = 0.0
self._last_uniform_pool_mean_score = 0.0
if previous_indices.numel() > 0:
self._prioritized_pool_mask[previous_indices] = False
if selected.numel() > 0:
self._prioritized_pool_mask[selected] = True
previous_set = set(previous_indices.tolist())
selected_set = set(selected.tolist())
self._last_entered_prioritized_pool_count = len(
selected_set - previous_set
)
self._last_exited_prioritized_pool_count = len(
previous_set - selected_set
)
self._prioritized_pool_indices = selected
def maybe_update_from_observations(
self,
*,
window_indices: Tensor,
mpkpe_signal_means: Tensor,
completion_rate_means: Tensor,
counts: Tensor,
swap_index: int,
) -> bool:
if self._ds_len <= 0:
return False
swap_idx = int(swap_index)
if swap_idx <= 0:
return False
if self._last_updated_swap == swap_idx:
return False
indices = (
window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()
)
# Keep validating the MPKPE tensor shape so the command-side
# curriculum aggregation stays aligned with completion-rate updates.
mpkpe_signal_numel = int(mpkpe_signal_means.numel())
completion_rate = (
completion_rate_means.detach()
.to(dtype=torch.float32)
.reshape(-1)
.cpu()
)
cnt = counts.detach().to(dtype=torch.float32).reshape(-1).cpu()
if not (
indices.numel() == mpkpe_signal_numel
and mpkpe_signal_numel == completion_rate.numel()
and completion_rate.numel() == cnt.numel()
):
raise ValueError(
"Prioritized sampler update tensors must have matching shape."
)
valid_dataset_idx = (indices >= 0) & (indices < self._ds_len)
valid = (
valid_dataset_idx & torch.isfinite(completion_rate) & (cnt > 0.0)
)
current_batch_indices = torch.unique(
indices[valid_dataset_idx], sorted=False
)
if not bool(valid.any().item()):
self._last_entered_prioritized_pool_count = 0
self._last_exited_prioritized_pool_count = 0
self._last_updated_swap = swap_idx
return False
idx_valid = indices[valid]
completion_rate_valid = completion_rate[valid]
cnt_valid = cnt[valid]
touched_idx, completion_rate_sum, completion_rate_count_sum = (
self._aggregate_by_index(
idx_valid,
completion_rate_valid,
cnt_valid,
)
)
if touched_idx.numel() == 0:
self._last_entered_prioritized_pool_count = 0
self._last_exited_prioritized_pool_count = 0
self._last_updated_swap = swap_idx
return False
completion_rate_obs = (
completion_rate_sum / completion_rate_count_sum.clamp_min(1.0e-12)
)
completion_rate_obs = torch.clamp(
completion_rate_obs, min=0.0, max=1.0
)
prev_seen = self._seen_mask[touched_idx]
prev_completion_rate = self._ema_completion_rate[touched_idx]
prev_completion_rate_sq = self._ema_completion_rate_sq[touched_idx]
prev_completion_rate_var = torch.clamp(
prev_completion_rate_sq
- prev_completion_rate * prev_completion_rate,
min=1.0e-6,
)
prev_completion_rate_std = torch.sqrt(prev_completion_rate_var)
next_completion_rate = torch.where(
prev_seen,
(1.0 - self._ema_alpha_signal) * prev_completion_rate
+ self._ema_alpha_signal * completion_rate_obs,
completion_rate_obs,
)
next_completion_rate_sq = torch.where(
prev_seen,
(1.0 - self._ema_alpha_signal) * prev_completion_rate_sq
+ self._ema_alpha_signal
* (completion_rate_obs * completion_rate_obs),
completion_rate_obs * completion_rate_obs,
)
completion_rel_improve_obs = torch.zeros_like(next_completion_rate)
completion_rel_improve_obs[prev_seen] = torch.tanh(
(completion_rate_obs[prev_seen] - prev_completion_rate[prev_seen])
/ (prev_completion_rate_std[prev_seen] + self._relative_eps)
)
prev_completion_rel = self._ema_completion_rel_improve[touched_idx]
next_completion_rel = torch.where(
prev_seen,
(1.0 - self._ema_alpha_rel_improve) * prev_completion_rel
+ self._ema_alpha_rel_improve * completion_rel_improve_obs,
completion_rel_improve_obs,
)
self._ema_completion_rate[touched_idx] = next_completion_rate
self._ema_completion_rate_sq[touched_idx] = next_completion_rate_sq
self._ema_completion_rel_improve[touched_idx] = next_completion_rel
self._seen_mask[touched_idx] = True
self._rebuild_prioritized_pool(touched_idx)
self._state_version += 1
self._last_updated_swap = swap_idx
return True
def _reset_uniform_cycle(self) -> None:
if self._ds_len <= 0:
self._uniform_cycle_start = 0
self._uniform_cycle_step = 1
self._uniform_cycle_offset = 0
return
generator = torch.Generator()
generator.manual_seed(self._seed + self._uniform_cycle_epoch * 1000003)
self._uniform_cycle_epoch += 1
self._uniform_cycle_start = int(
torch.randint(
low=0,
high=self._ds_len,
size=(1,),
generator=generator,
).item()
)
if self._ds_len <= 1:
self._uniform_cycle_step = 1
else:
step = int(
torch.randint(
low=1,
high=self._ds_len,
size=(1,),
generator=generator,
).item()
)
while math.gcd(step, self._ds_len) != 1:
step += 1
if step >= self._ds_len:
step = 1
self._uniform_cycle_step = step
self._uniform_cycle_offset = 0
def _next_uniform_index(self) -> int:
if self._uniform_cycle_offset >= self._ds_len:
self._reset_uniform_cycle()
next_index = (
self._uniform_cycle_start
+ self._uniform_cycle_offset * self._uniform_cycle_step
) % self._ds_len
self._uniform_cycle_offset += 1
return int(next_index)
def _sample_uniform_indices(
self,
generator: torch.Generator,
count: int,
*,
exclude: Optional[Tensor] = None,
) -> Tensor:
del generator
if count <= 0 or self._ds_len <= 0:
return torch.zeros(0, dtype=torch.long)
blocked = set()
if exclude is not None and exclude.numel() > 0:
blocked.update(
exclude.detach().to(dtype=torch.long).reshape(-1).tolist()
)
take = min(int(count), max(0, self._ds_len - len(blocked)))
if take <= 0:
return torch.zeros(0, dtype=torch.long)
selected: List[int] = []
stagnant_steps = 0
while len(selected) < take and stagnant_steps < self._ds_len:
next_index = self._next_uniform_index()
if next_index in blocked:
stagnant_steps += 1
continue
selected.append(next_index)
blocked.add(next_index)
stagnant_steps = 0
return torch.tensor(selected, dtype=torch.long)
def _sample_prioritized_indices(
self, generator: torch.Generator, count: int
) -> Tensor:
if count <= 0 or self._prioritized_pool_indices.numel() == 0:
return torch.zeros(0, dtype=torch.long)
perm = torch.randperm(
self._prioritized_pool_indices.numel(), generator=generator
)
take = min(count, int(self._prioritized_pool_indices.numel()))
return self._prioritized_pool_indices.index_select(0, perm[:take])
def _sample_batch_indices(self, generator: torch.Generator) -> Tensor:
uniform_count, prioritized_count = self._pool_batch_sizes()
prioritized_indices = self._sample_prioritized_indices(
generator, prioritized_count
)
uniform_indices = self._sample_uniform_indices(
generator,
uniform_count,
exclude=prioritized_indices,
)
sampled_indices = torch.cat(
[uniform_indices, prioritized_indices], dim=0
)
if sampled_indices.numel() < self._batch_size:
extra_indices = self._sample_uniform_indices(
generator,
self._batch_size - int(sampled_indices.numel()),
exclude=sampled_indices,
)
sampled_indices = torch.cat(
[sampled_indices, extra_indices], dim=0
)
if sampled_indices.numel() != self._batch_size:
raise ValueError(
"Prioritized sampler failed to assemble a full cache batch."
)
if sampled_indices.numel() > 0:
self._selection_counts[sampled_indices] += 1
return sampled_indices
def get_scores_for_indices(self, window_indices: Tensor) -> Tensor:
if self._ds_len <= 0:
return torch.zeros_like(window_indices, dtype=torch.float32)
idx = window_indices.detach().to(dtype=torch.long).reshape(-1).cpu()
if idx.numel() == 0:
return torch.zeros(0, dtype=torch.float32)
scores = self._priority_scores_for_indices(idx)
return scores.to(dtype=torch.float32)
def __iter__(self):
while True:
if self._ds_len <= 0:
raise ValueError(
"PrioritizedInfiniteSampler cannot sample from "
"an empty dataset."
)
g = torch.Generator()
g.manual_seed(self._seed + self._epoch)
sampled_indices = self._sample_batch_indices(generator=g)
perm = torch.randperm(sampled_indices.numel(), generator=g)
yielded_indices = sampled_indices.index_select(0, perm)
for idx in yielded_indices.tolist():
yield int(idx)
self._epoch += 1
def __len__(self) -> int:
return 2**31 - 1
class Hdf5MotionDataset(Dataset[MotionClipSample]):
    """Dataset that materializes fixed-length motion windows from HDF5 shards.

    One or more JSON manifests are merged into a single logical dataset. Each
    manifest provides ``hdf5_shards`` (entries whose ``file`` is relative to
    the manifest's directory) and ``clips`` (clip key -> metadata with
    ``shard``, ``start`` and ``length``). Clip keys must be unique across all
    manifests. Every clip is cut into consecutive windows of at most
    ``max_frame_length`` frames; a window shorter than ``min_window_length``
    is skipped.
    """

    def __init__(
        self,
        manifest_path: str | Sequence[str],
        max_frame_length: int,
        min_window_length: int = 1,
        handpicked_motion_names: Optional[List[str]] = None,
        excluded_motion_names: Optional[List[str]] = None,
        world_frame_normalization: bool = True,
        allowed_prefixes: Optional[Sequence[str]] = None,
    ) -> None:
        """Load and merge the manifests, then enumerate windows.

        Args:
            manifest_path: Path, or sequence of paths, to manifest JSON files.
            max_frame_length: Maximum number of frames per window (> 0).
            min_window_length: Windows shorter than this are skipped.
            handpicked_motion_names: If given, only these clip keys are used.
            excluded_motion_names: Clip keys to skip.
            world_frame_normalization: Apply ``_WorldFrameNormalizeTransform``
                to every loaded window.
            allowed_prefixes: Stored as ``self._allowed_prefixes``. Defaults
                to ``("ref_", "ft_ref_")``. Not read inside this class.

        Raises:
            ValueError: on a non-positive ``max_frame_length``, no manifests,
                a malformed shard entry, duplicate clip keys, no shards, or no
                windows.
            FileNotFoundError: if a manifest path does not exist.
        """
        super().__init__()
        if max_frame_length <= 0:
            raise ValueError("max_frame_length must be positive")
        self.max_frame_length = int(max_frame_length)
        self.min_window_length = int(min_window_length)
        self.handpicked_motion_names = (
            set(handpicked_motion_names)
            if handpicked_motion_names is not None
            else None
        )
        self.excluded_motion_names = (
            set(excluded_motion_names)
            if excluded_motion_names is not None
            else None
        )
        self._world_frame_transform = (
            _WorldFrameNormalizeTransform()
            if bool(world_frame_normalization)
            else None
        )
        # Honour caller-supplied prefixes; keep the historical default.
        self._allowed_prefixes: Tuple[str, ...] = (
            tuple(allowed_prefixes)
            if allowed_prefixes is not None
            else ("ref_", "ft_ref_")
        )
        self._progress_counter: Optional[mp.Value] = None

        # Normalize manifest path(s) to a list for aggregation.
        if isinstance(manifest_path, (str, os.PathLike)):
            manifest_paths: List[str] = [str(manifest_path)]
        else:
            manifest_paths = [str(p) for p in manifest_path]
        if len(manifest_paths) == 0:
            raise ValueError("At least one manifest_path must be provided")

        # Aggregate shards and clips across one or many manifests into a single
        # logical dataset. Clip keys must be globally unique.
        self.hdf5_root = os.path.dirname(manifest_paths[0])
        self._manifest_paths: List[str] = manifest_paths
        self._shard_paths: List[str] = []
        self.shards: List[Dict[str, Any]] = []
        self.clips: Dict[str, Dict[str, Any]] = {}
        # The loop variable is deliberately not named ``mp``. That would
        # shadow the module-level ``mp`` name (used as ``mp.Value``).
        for manifest_file in manifest_paths:
            self._load_manifest(manifest_file)
        if len(self.shards) == 0:
            raise ValueError(
                f"No HDF5 shards listed in manifests: {', '.join(manifest_paths)}"
            )

        self.windows: List[MotionWindow] = self._enumerate_windows()
        if len(self.windows) == 0:
            raise ValueError(
                "No motion windows satisfy the requested frame length constraints"
            )

        # LRU cache of open HDF5 shard handles; size is bounded to avoid
        # unbounded host-memory usage from per-file raw chunk caches.
        self._file_handles: "OrderedDict[int, h5py.File]" = OrderedDict()
        max_open_env = os.getenv("HOLOMOTION_HDF5_MAX_OPEN_SHARDS")
        self._max_open_files = (
            64 if max_open_env is None else max(1, int(max_open_env))
        )

    def _load_manifest(self, manifest_file: str) -> None:
        """Append the shards and clips listed in one manifest file.

        A clip's ``shard`` value is local to its manifest. It is shifted by
        the number of shards already loaded so it indexes ``self.shards``.
        """
        if not os.path.exists(manifest_file):
            raise FileNotFoundError(
                f"HDF5 manifest not found at {manifest_file}. "
                "Please set robot.motion.hdf5_root/train_hdf5_roots "
                "to the correct path."
            )
        with open(manifest_file, "r", encoding="utf-8") as handle:
            manifest = json.load(handle)
        root = os.path.dirname(manifest_file)
        shards_local = list(manifest.get("hdf5_shards", []))
        clips_local = manifest.get("clips", {})
        shard_offset = len(self.shards)
        for shard_meta in shards_local:
            self.shards.append(shard_meta)
            rel = shard_meta.get("file", None)
            if not isinstance(rel, str) or not rel:
                raise ValueError(
                    f"Shard entry in manifest {manifest_file} is missing a valid 'file' field"
                )
            self._shard_paths.append(os.path.join(root, rel))
        for key, meta in clips_local.items():
            if key in self.clips:
                raise ValueError(
                    f"Duplicate motion clip key '{key}' found in multiple "
                    "manifests; clip keys must be globally unique."
                )
            meta_global = dict(meta)
            meta_global["shard"] = (
                int(meta_global.get("shard", 0)) + shard_offset
            )
            self.clips[key] = meta_global

    def set_progress_counter(self, counter: Optional[mp.Value]) -> None:
        """Set (or clear) a shared counter incremented after each load."""
        self._progress_counter = counter

    def _enumerate_windows(self) -> List[MotionWindow]:
        """Split every eligible clip into consecutive windows.

        Clips outside ``handpicked_motion_names`` (when set), clips in
        ``excluded_motion_names``, and clips with non-positive length are
        skipped. Windows are at most ``max_frame_length`` frames long; a
        window shorter than ``min_window_length`` is not emitted.
        """
        windows: List[MotionWindow] = []
        for motion_key, meta in self.clips.items():
            if (
                self.handpicked_motion_names is not None
                and motion_key not in self.handpicked_motion_names
            ):
                continue
            if (
                self.excluded_motion_names is not None
                and motion_key in self.excluded_motion_names
            ):
                continue
            shard_index = int(meta.get("shard", 0))
            start = int(meta.get("start", 0))
            length = int(meta.get("length", 0))
            if length <= 0:
                continue
            remaining = length
            offset = 0
            window_index = 0
            while remaining > 0:
                window_length = min(self.max_frame_length, remaining)
                if window_length >= self.min_window_length:
                    win_start = start + offset
                    unique_key = (
                        f"{motion_key}__start_{win_start}_len_{window_length}"
                    )
                    windows.append(
                        MotionWindow(
                            motion_key=unique_key,
                            shard_index=shard_index,
                            start=win_start,
                            length=window_length,
                            raw_motion_key=motion_key,
                            window_index=window_index,
                        )
                    )
                    window_index += 1
                offset += window_length
                remaining = max(0, length - offset)
        return windows

    def __len__(self) -> int:
        """Number of enumerated windows."""
        return len(self.windows)

    def __getitem__(self, index: int) -> MotionClipSample:
        """Load window ``index`` from its shard as float32/long tensors.

        - Every ``MANDATORY_DATASETS`` entry is read from ``ref_*`` datasets;
          a missing one raises ``KeyError``.
        - Matching ``ft_ref_*`` datasets are read when present.
        - ``frame_flag`` is read from the shard when present. Otherwise it
          is synthesised: 0 for the first frame, 2 for the last, 1 elsewhere,
          and 2 for a single-frame window.
        - After optional world-frame normalization, ``*_root_*`` tensors are
          slices of body index 0.
        """
        window = self.windows[index]
        shard_handle = self._get_shard_handle(window.shard_index)
        start, end = window.start, window.start + window.length
        arrays: Dict[str, Tensor] = {}
        # Mandatory reference source: ref_*
        for logical_name, dataset_name in MANDATORY_DATASETS.items():
            dname = f"ref_{dataset_name}"
            if dname not in shard_handle:
                raise KeyError(
                    f"Missing mandatory dataset '{dname}' in shard index {window.shard_index}"
                )
            np_array = shard_handle[dname][start:end]
            arrays[f"ref_{logical_name}"] = torch.from_numpy(np_array).to(
                torch.float32
            )
        # Optional filtered reference source: ft_ref_*
        for logical_name, dataset_name in MANDATORY_DATASETS.items():
            dname = f"ft_ref_{dataset_name}"
            if dname in shard_handle:
                np_array = shard_handle[dname][start:end]
                arrays[f"ft_ref_{logical_name}"] = torch.from_numpy(
                    np_array
                ).to(torch.float32)
        if "frame_flag" in shard_handle:
            frame_flag_np = shard_handle["frame_flag"][start:end]
            frame_flag = torch.from_numpy(frame_flag_np).to(torch.long)
        else:
            frame_flag = torch.ones(window.length, dtype=torch.long)
            if window.length > 1:
                frame_flag[0] = 0
                frame_flag[-1] = 2
            elif window.length == 1:
                # Single-frame window: mark as both start and end (use 2 for end)
                frame_flag[0] = 2
        arrays["frame_flag"] = frame_flag
        if self._world_frame_transform is not None:
            self._world_frame_transform(arrays)
        # Derived root_* for ref_* (after normalization)
        arrays["ref_root_pos"] = arrays["ref_rg_pos"][:, 0, :]
        arrays["ref_root_rot"] = arrays["ref_rb_rot"][:, 0, :]
        arrays["ref_root_vel"] = arrays["ref_body_vel"][:, 0, :]
        arrays["ref_root_ang_vel"] = arrays["ref_body_ang_vel"][:, 0, :]
        # Derived root_* for optional ft_ref_* (after normalization)
        if (
            "ft_ref_rg_pos" in arrays
            and "ft_ref_rb_rot" in arrays
            and "ft_ref_body_vel" in arrays
            and "ft_ref_body_ang_vel" in arrays
        ):
            arrays["ft_ref_root_pos"] = arrays["ft_ref_rg_pos"][:, 0, :]
            arrays["ft_ref_root_rot"] = arrays["ft_ref_rb_rot"][:, 0, :]
            arrays["ft_ref_root_vel"] = arrays["ft_ref_body_vel"][:, 0, :]
            arrays["ft_ref_root_ang_vel"] = arrays["ft_ref_body_ang_vel"][
                :, 0, :
            ]
        if self._progress_counter is not None:
            with self._progress_counter.get_lock():
                self._progress_counter.value += 1
        return MotionClipSample(
            motion_key=window.motion_key,
            raw_motion_key=window.raw_motion_key,
            window_index=int(index),
            tensors=arrays,
            length=window.length,
        )

    def _get_shard_handle(self, shard_index: int) -> h5py.File:
        """Return an open read-only handle for ``shard_index``.

        Handles are kept in an LRU cache bounded by ``self._max_open_files``.
        A cached handle that has been closed is dropped and the file is
        reopened.

        Raises:
            IndexError: if ``shard_index`` is out of range.
        """
        cached = self._file_handles.pop(shard_index, None)
        if cached is not None and cached.id:
            # Re-insert to mark as most recently used.
            self._file_handles[shard_index] = cached
            return cached
        if shard_index < 0 or shard_index >= len(self._shard_paths):
            raise IndexError(
                f"Shard index {shard_index} out of range for "
                f"{len(self._shard_paths)} available shards"
            )
        shard_path = self._shard_paths[shard_index]
        # Open with SWMR and a configurable raw chunk cache to speed up repeated reads.
        # The default cache size (in bytes) can be overridden via the
        # HOLOMOTION_HDF5_RDCC_NBYTES environment variable.
        rdcc_nbytes_env = os.getenv("HOLOMOTION_HDF5_RDCC_NBYTES")
        if rdcc_nbytes_env is None:
            rdcc_nbytes = 256 * 1024 * 1024  # 256MB default
        else:
            rdcc_nbytes = int(rdcc_nbytes_env)
        handle = h5py.File(
            shard_path,
            "r",
            libver="latest",
            swmr=True,
            rdcc_nbytes=rdcc_nbytes,
            rdcc_w0=0.75,
        )
        # Enforce LRU limit on the number of simultaneously open shard files.
        if len(self._file_handles) >= self._max_open_files:
            _, oldest_handle = self._file_handles.popitem(last=False)
            oldest_handle.close()
        self._file_handles[shard_index] = handle
        return handle

    def close(self) -> None:
        """Close all open HDF5 shard handles for this dataset."""
        for handle in self._file_handles.values():
            if handle.id:
                handle.close()
        self._file_handles.clear()
class MotionClipBatchCache:
"""Double-buffered motion cache for RL training and evaluation."""
@staticmethod
def _infer_cuda_device_index() -> int:
device_count = int(torch.cuda.device_count())
local_rank_env = os.environ.get("LOCAL_RANK")
if local_rank_env is not None:
local_rank = int(local_rank_env)
if 0 <= local_rank < device_count:
return local_rank
return int(torch.cuda.current_device())
@classmethod
def _normalize_stage_device(
cls, stage_device: Optional[object]
) -> Optional[torch.device]:
if stage_device is None:
return None
if isinstance(stage_device, torch.device):
if stage_device.type == "cpu":
return None
if stage_device.type != "cuda":
raise ValueError(
f"Unsupported stage_device type: {stage_device.type}"
)
if not torch.cuda.is_available():
raise RuntimeError(
"stage_device requested CUDA but CUDA is not available"
)
if stage_device.index is not None:
return stage_device
return torch.device("cuda", cls._infer_cuda_device_index())
if isinstance(stage_device, str):
stage_device_str = stage_device.strip().lower()
if stage_device_str in ("none", "cpu"):
return None
if stage_device_str == "cuda":
if not torch.cuda.is_available():
raise RuntimeError(
"stage_device requested CUDA but CUDA is not available"
)
return torch.device("cuda", cls._infer_cuda_device_index())
if stage_device_str.startswith("cuda:"):
if not torch.cuda.is_available():
raise RuntimeError(
"stage_device requested CUDA but CUDA is not available"
)
return torch.device(stage_device_str)
raise ValueError(
f"Unsupported stage_device string: {stage_device}"
)
raise TypeError(
f"Unsupported stage_device value type: {type(stage_device)}"
)
    def __init__(
        self,
        train_dataset: Dataset[MotionClipSample],
        *,
        val_dataset: Optional[Dataset[MotionClipSample]] = None,
        batch_size: int,
        stage_device: Optional[torch.device] = None,
        num_workers: int = 4,
        prefetch_factor: int = 2,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        sampler_rank: int = 0,
        sampler_world_size: int = 1,
        allowed_prefixes: Optional[Sequence[str]] = None,
        swap_interval_steps: Optional[int] = None,
        force_timeout_on_swap: bool = True,
        stage_on_swap_only: bool = False,
        batch_progress_bar: bool = False,
        seed: Optional[int] = None,
        loader_timeout: float = 0.0,
    ) -> None:
        """Configure the cache, build the dataloader and prime both buffers.

        Args:
            train_dataset: Dataset used in ``"train"`` mode. Its
                ``max_frame_length`` attribute is read when
                ``swap_interval_steps`` is None.
            val_dataset: Dataset used in ``"val"`` mode; falls back to
                ``train_dataset`` when None.
            batch_size: Number of clips per cache batch; must be positive.
            stage_device: Staging device, normalized by
                ``_normalize_stage_device`` (None / ``"cpu"`` disable staging).
            num_workers: Loader worker count, clamped to ``>= 0``.
            prefetch_factor: Stored on the instance; presumably forwarded to
                the DataLoader by ``_build_dataloader`` (not visible here).
            pin_memory: Stored on the instance for the dataloader.
            persistent_workers: Only honoured when ``num_workers > 0``.
            sampler_rank: Rank of this process; also used in curriculum dump
                file names.
            sampler_world_size: Number of ranks, clamped to ``>= 1``.
            allowed_prefixes: Optional prefixes, stored as a tuple.
            swap_interval_steps: Steps between swaps; defaults to
                ``train_dataset.max_frame_length``.
            force_timeout_on_swap: Stored as ``self.force_timeout_on_swap``.
            stage_on_swap_only: Keep prefetched batches on CPU and stage them
                only inside ``advance``.
            batch_progress_bar: Stored flag. A spawn-context shared counter
                is created when ``_should_use_batch_progress()`` is true.
            seed: Sampling seed; defaults to a time-derived 31-bit value.
            loader_timeout: Loader timeout; must be ``>= 0`` (units presumably
                seconds, as for ``DataLoader.timeout`` — TODO confirm).

        Raises:
            ValueError: if ``batch_size <= 0`` or ``loader_timeout < 0``.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if float(loader_timeout) < 0.0:
            raise ValueError("loader_timeout must be >= 0")
        # The validation split falls back to the training dataset.
        self._datasets = {
            "train": train_dataset,
            "val": val_dataset if val_dataset is not None else train_dataset,
        }
        self._mode = "train"
        # Default seed: low 31 bits of the current time in nanoseconds.
        self._seed = (
            int(seed) if seed is not None else int(time.time_ns() & 0x7FFFFFFF)
        )
        self._stage_device = self._normalize_stage_device(stage_device)
        self._sampler_rank = int(sampler_rank)
        self._sampler_world_size = int(max(1, sampler_world_size))
        self._batch_size = int(batch_size)
        self._allowed_prefixes: Optional[Tuple[str, ...]] = (
            tuple(allowed_prefixes) if allowed_prefixes is not None else None
        )
        # If enabled, keep the prefetched batch on CPU (FK on CPU) and stage to GPU
        # only during cache swapping (advance).
        self._stage_on_swap_only = bool(stage_on_swap_only)
        self._batch_progress_bar = bool(batch_progress_bar)
        self._loader_timeout = float(loader_timeout)
        self.force_timeout_on_swap = bool(force_timeout_on_swap)
        self._batch_progress_counter: Optional[mp.Value] = None
        # NOTE(review): _should_use_batch_progress() runs before _num_workers
        # and friends are assigned below; it must not depend on them.
        if self._should_use_batch_progress():
            ctx = mp.get_context("spawn")
            self._batch_progress_counter = ctx.Value("i", 0)
        self.swap_interval_steps = (
            swap_interval_steps
            if swap_interval_steps is not None
            else train_dataset.max_frame_length
        )
        self._num_workers = int(max(0, num_workers))
        self._prefetch_factor = (
            prefetch_factor if prefetch_factor is not None else None
        )
        self._pin_memory = bool(pin_memory)
        # Persistent workers only make sense with at least one worker.
        self._persistent_workers = bool(persistent_workers and num_workers > 0)
        self._dataloader: Optional[DataLoader] = None
        self._sampler: Optional[Sampler[int]] = None
        self._iterator: Optional[Iterator[ClipBatch]] = None
        self._current_batch: Optional[ClipBatch] = None
        self._next_batch: Optional[ClipBatch] = None
        self._swap_index = 0
        self._effective_batch_size: Optional[int] = None
        self._num_batches: Optional[int] = None
        # Weighted-bin sampling state
        self._weighted_bin_enabled: bool = False
        self._weighted_bin_bins: Optional[List[List[int]]] = None
        self._weighted_bin_ratios: Optional[List[float]] = None
        self._weighted_bin_specs: Optional[List[Dict[str, Any]]] = None
        # Cache-curriculum (prioritized sampling) state; see
        # enable_cache_curriculum_sampling for how these are configured.
        self._cache_curriculum_enabled: bool = False
        self._cache_curriculum_cfg: Dict[str, Any] = {}
        self._cache_curriculum_sampler: Optional[
            PrioritizedInfiniteSampler
        ] = None
        self._cache_curriculum_dump_enabled: bool = False
        self._cache_curriculum_dump_every_swaps: int = 10
        self._cache_curriculum_dump_chunk_size: int = 4096
        self._cache_curriculum_dump_dir: Path = Path(
            "cache_curriculum_window_scores"
        )
        self._cache_curriculum_last_dump_swap: int = -1
        # Async GPU staging helpers
        self._copy_stream = None
        self._pending_ready_event = None
        self._current_ready_event = None
        self._next_ready_event = None
        self._build_dataloader()
        # A dedicated CUDA stream is created only when staging to CUDA.
        if (
            self._stage_device is not None
            and self._stage_device.type == "cuda"
        ):
            self._copy_stream = torch.cuda.Stream(device=self._stage_device)
        self._prime_buffers()
@property
def current_batch(self) -> ClipBatch:
assert self._current_batch is not None
return self._current_batch
@property
def max_frame_length(self) -> int:
return self.current_batch.max_frame_length
@property
def clip_count(self) -> int:
return self.current_batch.lengths.shape[0]
@property
def mode(self) -> str:
return self._mode
@property
def swap_index(self) -> int:
return self._swap_index
@property
def num_batches(self) -> int:
if self._num_batches is None:
raise RuntimeError("DataLoader is not initialised")
return int(self._num_batches)
def set_mode(self, mode: str) -> None:
if mode == self._mode:
return
if mode not in self._datasets:
raise ValueError(f"Unknown cache mode: {mode}")
self._mode = mode
self._build_dataloader()
self._prime_buffers()
def set_seed(self, seed: int, *, reinitialize: bool = True) -> None:
self._seed = int(seed)
if reinitialize:
self._build_dataloader()
self._prime_buffers()
def advance(self) -> None:
if self._stage_on_swap_only:
if self._next_batch is None:
self._next_batch = self._fetch_next_batch()
# Stage the prefetched CPU batch to GPU only at swap time.
staged = self._stage_batch_blocking(self._next_batch)
self._current_batch = staged
self._next_batch = self._fetch_next_batch()
self._swap_index += 1
return
if self._next_batch is None:
self._next_batch = self._fetch_next_batch()
# Ensure asynchronous staging finished before swapping in next batch
if (
self._next_ready_event is not None
and self._stage_device is not None
and self._stage_device.type == "cuda"
):
torch.cuda.current_stream(self._stage_device).wait_event(
self._next_ready_event
)
self._current_batch = self._next_batch
self._next_batch = self._fetch_next_batch()
self._swap_index += 1
# -------------------------
# Weighted-bin configuration
# -------------------------
def enable_weighted_bin_sampling(
self, cfg: Optional[Dict[str, Any]] = None
) -> None:
"""Enable regex-based weighted-bin sampling over manifest motion keys.
The configuration must provide a list under ``bin_regex_patterns`` (or the
legacy name ``bin_regrex_patterns``), where each element is a mapping with:
- ``regex`` (or ``regrex``): Python regular expression applied to the
manifest clip key (e.g., ``AMASS_.*``, ``VR_pico_.*``).
- ``ratio``: Target sampling ratio in [0, 1].
The sum of explicit bin ratios must be <= 1.0. Any remaining mass is
assigned to an implicit ``others`` bin that collects all clips not
matched by any regex.
"""
cfg_local: Dict[str, Any] = dict(cfg or {})
if self._cache_curriculum_enabled:
raise ValueError(
"weighted-bin and cache curriculum sampling cannot be enabled together."
)
dataset = self._datasets.get("train")
if dataset is None:
raise ValueError(
"Weighted-bin sampling requires a training dataset"
)
# Collect manifest-level motion keys for all windows in order
window_keys: List[str] = []
for window in dataset.windows:
motion_key = getattr(window, "raw_motion_key", None)
if motion_key is None:
full_key = getattr(window, "motion_key", "")
if "__start_" in full_key:
motion_key = full_key.split("__start_", 1)[0]
else:
motion_key = full_key
window_keys.append(motion_key)
bin_indices, all_ratios, specs = _configure_weighted_bins(
keys=window_keys,
cfg=cfg_local,
batch_size_for_log=int(self._batch_size),
)
# Log summary in terms of windows
table_rows = []
for item in specs:
table_rows.append(
[
item["name"],
item["regex"],
f"{item['ratio']:.4f}",
int(item["count"]),
f"{item['dataset_fraction']:.4f}",
f"{item['batch_fraction']:.4f}",
]
)
headers = [
"bin",
"regex",
"final_ratio",
"num_windows",
"dataset_fraction",
"batch_fraction",
]
logger.info(
"Motion cache weighted-bin sampling configured:\n"
+ tabulate(table_rows, headers=headers, tablefmt="simple_outline")
)
# Activate weighted-bin sampling and rebuild dataloader/cache
self._weighted_bin_enabled = True
self._weighted_bin_bins = bin_indices
self._weighted_bin_ratios = all_ratios
self._weighted_bin_specs = specs
self._build_dataloader()
self._prime_buffers()
def enable_cache_curriculum_sampling(
self, cfg: Optional[Dict[str, Any]] = None
) -> None:
if self._weighted_bin_enabled:
raise ValueError(
"cache curriculum and weighted-bin sampling cannot be enabled together."
)
self._cache_curriculum_enabled = True
self._cache_curriculum_cfg = dict(cfg or {})
self._cache_curriculum_dump_enabled = bool(
self._cache_curriculum_cfg.get(
"dump_whole_window_scores_json", True
)
)
self._cache_curriculum_dump_every_swaps = max(
1,
int(
self._cache_curriculum_cfg.get(
"dump_whole_window_scores_every_swaps", 10
)
),
)
self._cache_curriculum_dump_chunk_size = max(
1,
int(
self._cache_curriculum_cfg.get(
"dump_whole_window_scores_chunk_size", 4096
)
),
)
self._cache_curriculum_dump_dir = Path(
str(
self._cache_curriculum_cfg.get(
"dump_whole_window_scores_dir",
"cache_curriculum_window_scores",
)
)
)
self._cache_curriculum_last_dump_swap = -1
self._prepare_cache_curriculum_dump_dir(
self._cache_curriculum_dump_dir,
reason="enabled",
)
self._build_dataloader()
self._prime_buffers()
def _prepare_cache_curriculum_dump_dir(
self, dump_dir: Path, *, reason: str
) -> None:
self._cache_curriculum_dump_dir = Path(str(dump_dir))
if not self._cache_curriculum_dump_enabled:
return
self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True)
logger.info(
"Cache curriculum whole-window score dump "
f"{reason}: dir={self._cache_curriculum_dump_dir}, "
f"every_swaps={self._cache_curriculum_dump_every_swaps}, "
f"rank={self._sampler_rank}"
)
def set_cache_curriculum_dump_dir(self, dump_dir: str) -> None:
self._prepare_cache_curriculum_dump_dir(
Path(str(dump_dir)),
reason="directory set",
)
def update_cache_curriculum(
self,
*,
window_indices: Tensor,
mpkpe_signal_means: Tensor,
completion_rate_means: Tensor,
counts: Tensor,
swap_index: int,
) -> bool:
if self._cache_curriculum_sampler is None:
return False
updated = (
self._cache_curriculum_sampler.maybe_update_from_observations(
window_indices=window_indices,
mpkpe_signal_means=mpkpe_signal_means,
completion_rate_means=completion_rate_means,
counts=counts,
swap_index=swap_index,
)
)
if updated:
self._refresh_prefetched_batch()
self._maybe_dump_cache_curriculum_scores_json(swap_index=swap_index)
return updated
def _refresh_prefetched_batch(self) -> None:
if self._next_batch is None:
return
self._next_batch = self._fetch_next_batch()
def _maybe_dump_cache_curriculum_scores_json(
self, *, swap_index: int
) -> None:
if not self._cache_curriculum_dump_enabled:
return
if self._cache_curriculum_sampler is None:
return
swap_idx = int(swap_index)
if swap_idx <= 0:
return
if swap_idx % self._cache_curriculum_dump_every_swaps != 0:
return
if self._cache_curriculum_last_dump_swap == swap_idx:
return
dataset = self._datasets["train"]
ds_len = int(len(dataset))
if ds_len <= 0:
return
self._cache_curriculum_dump_dir.mkdir(parents=True, exist_ok=True)
output_path = self._cache_curriculum_dump_dir / (
"whole_window_scores_"
f"rank_{self._sampler_rank:04d}_swap_{swap_idx:06d}.json"
)
sampler_version = int(self._cache_curriculum_sampler.state_version)
windows = dataset.windows
score_values: List[float] = []
completion_values: List[float] = []
rel_improve_values: List[float] = []
selection_count_values: List[int] = []
seen_values: List[bool] = []
in_pool_values: List[bool] = []
chunk_size = max(1, int(self._cache_curriculum_dump_chunk_size))
for chunk_start in range(0, ds_len, chunk_size):
chunk_end = min(ds_len, chunk_start + chunk_size)
chunk_indices = torch.arange(
chunk_start, chunk_end, dtype=torch.long
)
chunk_scores = (
self._cache_curriculum_sampler.get_scores_for_indices(
chunk_indices
)
)
chunk_state = (
self._cache_curriculum_sampler.get_window_state_for_indices(
chunk_indices
)
)
if chunk_scores.numel() != chunk_indices.numel():
raise ValueError(
"Whole-window score dump shape mismatch for "
"cache curriculum sampler."
)
score_values.extend(chunk_scores.tolist())
completion_values.extend(
chunk_state["ema_completion_rate"].tolist()
)
rel_improve_values.extend(
chunk_state["completion_rate_rel_improve"].tolist()
)
selection_count_values.extend(
chunk_state["selection_count"].tolist()
)
seen_values.extend(chunk_state["seen"].tolist())
in_pool_values.extend(chunk_state["in_prioritized_pool"].tolist())
rows: List[Dict[str, Any]] = []
for window_index in range(ds_len):
window = windows[window_index]
rows.append(
{
"swap_index": int(swap_idx),
"rank": int(self._sampler_rank),
"sampler_state_version": sampler_version,
"window_index": int(window_index),
"raw_motion_key": str(window.raw_motion_key),
"motion_key": str(window.motion_key),
"start": int(window.start),
"length": int(window.length),
"score": float(score_values[window_index]),
"selection_count": int(
selection_count_values[window_index]
),
"ema_completion_rate": float(
completion_values[window_index]
),
"completion_rate_rel_improve": float(
rel_improve_values[window_index]
),
"seen": bool(seen_values[window_index]),
"in_prioritized_pool": bool(in_pool_values[window_index]),
}
)
payload: Dict[str, Any] = {
"swap_index": int(swap_idx),
"rank": int(self._sampler_rank),
"sampler_state_version": sampler_version,
"num_windows": int(ds_len),
"pool_metrics": self._cache_curriculum_sampler.get_pool_statistics()
or {},
"rows": rows,
}
with output_path.open("w", encoding="utf-8") as handle:
json.dump(payload, handle, indent=2)
handle.write("\n")
self._cache_curriculum_last_dump_swap = swap_idx
def cache_curriculum_scores_for_window_indices(
self, window_indices: Tensor
) -> Optional[Tuple[Tensor, Dict[str, Tensor], int]]:
if self._cache_curriculum_sampler is None:
return None
scores = self._cache_curriculum_sampler.get_scores_for_indices(
window_indices
)
state = self._cache_curriculum_sampler.get_window_state_for_indices(
window_indices
)
version = self._cache_curriculum_sampler.state_version
return scores, state, version
def cache_curriculum_pool_statistics(
self,
) -> Optional[Dict[str, float]]:
if self._cache_curriculum_sampler is None:
return None
return self._cache_curriculum_sampler.get_pool_statistics()
def sample_env_assignments(
self,
num_envs: int,
n_future_frames: int,
device: torch.device,
*,
deterministic_start: bool = False,
) -> Tuple[Tensor, Tensor]:
batch = self.current_batch
lengths = batch.lengths.to(device)
if num_envs <= 0:
raise ValueError("num_envs must be positive")
total = int(lengths.shape[0])
if total == 0:
raise ValueError(
"Cannot sample from an empty batch. Ensure the cache contains "
"at least one motion clip before calling sample_env_assignments."
)
clip_indices = torch.randint(
low=0, high=total, size=(num_envs,), device=device
)
max_start = torch.clamp(
lengths[clip_indices] - 1 - n_future_frames, min=0
)
if deterministic_start:
frame_starts = torch.zeros_like(max_start)
else:
rand = torch.rand_like(max_start, dtype=torch.float32)
frame_starts = torch.floor(rand * (max_start + 1).float()).to(
torch.long
)
return clip_indices, frame_starts
def _prepare_gather_indices(
    self,
    *,
    clip_indices: Tensor,
    frame_indices: Tensor,
    n_future_frames: int,
) -> Tuple[Tensor, Tensor]:
    """Build (clip, timestep) index tensors for a look-ahead gather.

    Indices are moved to the device of ``current_batch.lengths``. For every
    clip ``c`` starting at frame ``s``, the timesteps are
    ``s, s + 1, ..., s + n_future_frames``, each clamped to the clip's last
    valid frame (``length - 1``, floored at 0).

    Returns:
        ``(clips, timesteps)``. ``clips`` is a fresh long tensor (never an
        alias of the caller's input) of shape ``(N,)``; ``timesteps`` has
        shape ``(N, 1 + n_future_frames)``.
    """
    lengths = self.current_batch.lengths
    device = lengths.device
    clips = clip_indices.to(device, dtype=torch.long).clone()
    starts = frame_indices.to(device, dtype=torch.long)
    offsets = torch.arange(
        1 + int(n_future_frames), device=device, dtype=torch.long
    )
    last_valid = (lengths.index_select(0, clips) - 1).clamp(min=0)
    timesteps = torch.minimum(
        starts.unsqueeze(1) + offsets.unsqueeze(0), last_valid.unsqueeze(1)
    )
    return clips, timesteps
def gather_tensor(
    self,
    tensor_name: str,
    *,
    clip_indices: Tensor,
    frame_indices: Tensor,
    n_future_frames: int,
) -> Tensor:
    """Gather a look-ahead window of ``tensor_name`` for each requested clip.

    The result is indexed ``[env, step, ...]``: row ``i`` holds frames
    ``frame_indices[i] .. frame_indices[i] + n_future_frames`` (clamped to
    the clip end, see ``_prepare_gather_indices``) of clip
    ``clip_indices[i]``.

    Raises:
        KeyError: If ``tensor_name`` is not a tensor of ``current_batch``.
    """
    batch_tensors = self.current_batch.tensors
    if tensor_name not in batch_tensors:
        raise KeyError(
            f"Tensor '{tensor_name}' is not present in current_batch"
        )
    clips, timesteps = self._prepare_gather_indices(
        clip_indices=clip_indices,
        frame_indices=frame_indices,
        n_future_frames=n_future_frames,
    )
    source = batch_tensors[tensor_name]
    return source[clips.unsqueeze(-1), timesteps]
def lengths_for_indices(self, clip_indices: Tensor) -> Tensor:
    """Return the frame length of each clip in ``clip_indices``.

    The lengths are moved to the device of ``clip_indices`` before
    selection.
    """
    all_lengths = self.current_batch.lengths.to(clip_indices.device)
    return torch.index_select(all_lengths, 0, clip_indices.long())
def motion_keys_for_indices(self, clip_indices: Tensor) -> List[str]:
    """Map each clip index to its motion key in ``current_batch``."""
    keys = self.current_batch.motion_keys
    return [keys[int(position)] for position in clip_indices.tolist()]
def window_indices_for_indices(self, clip_indices: Tensor) -> Tensor:
    """Return the dataset window index of each clip in ``clip_indices``.

    The window indices are moved to the device of ``clip_indices`` before
    selection.
    """
    windows = self.current_batch.window_indices.to(clip_indices.device)
    return torch.index_select(windows, 0, clip_indices.long())
def _prime_buffers(self) -> None:
    """Fill the current and next batch slots before the first swap.

    With ``_stage_on_swap_only`` set:
      * the current batch is staged synchronously by
        ``_stage_batch_blocking``, which copies on the current stream;
      * the next batch is kept as ``_fetch_next_batch`` returned it. In this
        mode that method skips staging, so the batch stays on the CPU.
      * all readiness events are cleared, since no side-stream copy was
        issued.

    Otherwise both slots are filled by ``_fetch_next_batch``, which stages
    each batch and stores its readiness event. ``_current_batch`` is still
    ``None`` while the first fetch runs, so that event goes to
    ``_current_ready_event`` and the second goes to ``_next_ready_event``.
    NOTE(review): this relies on ``_current_batch`` being ``None`` on entry
    (``_build_dataloader`` resets it); confirm this for any other caller.
    """
    if self._stage_on_swap_only:
        # Prefetch on CPU; stage to GPU only for current batch.
        cpu_current = self._fetch_next_batch()
        self._current_batch = self._stage_batch_blocking(cpu_current)
        self._next_batch = self._fetch_next_batch()
        self._pending_ready_event = None
        self._current_ready_event = None
        self._next_ready_event = None
        return
    self._current_batch = self._fetch_next_batch()
    # Ensure first staged batch is ready before consumption
    if (
        self._current_ready_event is not None
        and self._stage_device is not None
        and self._stage_device.type == "cuda"
    ):
        # NOTE(review): Stream.wait_event only enqueues a dependency and
        # does not block the host. This timing therefore likely measures the
        # enqueue, not the copy itself.
        t0 = time.time()
        torch.cuda.current_stream(self._stage_device).wait_event(
            self._current_ready_event
        )
        t1 = time.time()
        logger.info(
            f"Perf/Cache/cuda_wait_event_ms={((t1 - t0) * 1e3):.2f} (first)"
        )
    self._next_batch = self._fetch_next_batch()
def _fetch_next_batch(self) -> ClipBatch:
    """Load one batch and, unless staging only on swap, stage it.

    In swap-only mode the raw, unstaged batch is returned. Otherwise the
    batch is staged with a readiness event. That event goes to
    ``_current_ready_event`` while no current batch exists yet, and to
    ``_next_ready_event`` after that. ``_pending_ready_event`` is always
    cleared.
    """
    raw_batch = self._load_next_batch()
    if self._stage_on_swap_only:
        # Prefetch raw batch on CPU.
        return raw_batch
    staged = self._stage_batch(raw_batch, record_event=True)
    event = self._pending_ready_event
    if self._current_batch is not None:
        self._next_ready_event = event
    else:
        self._current_ready_event = event
    self._pending_ready_event = None
    return staged
def _load_next_batch(self) -> ClipBatch:
    """Load the next batch, showing a progress bar when that is allowed."""
    loader = (
        self._load_next_batch_with_progress
        if self._should_use_batch_progress()
        else self._load_next_batch_raw
    )
    return loader()
def _load_next_batch_raw(self) -> ClipBatch:
    """Pull the next batch from the DataLoader iterator.

    The iterator is created lazily. When it is exhausted, it is rebuilt once
    with ``reset_epoch=True`` and read again; exhaustion of the rebuilt
    iterator propagates as ``StopIteration``.
    """
    if self._iterator is None:
        self._iterator = self._build_iterator()
    exhausted = object()
    batch = next(self._iterator, exhausted)
    if batch is exhausted:
        self._iterator = self._build_iterator(reset_epoch=True)
        batch = next(self._iterator)
    return batch
def _load_next_batch_with_progress(self) -> ClipBatch:
    """Load the next batch while rendering a tqdm progress bar.

    ``_load_next_batch_raw`` runs on a helper thread while this thread polls
    ``self._batch_progress_counter``, an object exposing ``get_lock()`` and
    ``value``, about every 50 ms. The counter is reset to 0 first; it is
    presumably incremented by the dataset as samples are collected. Without
    a counter this falls back to ``_load_next_batch_raw``.

    Returns:
        The batch produced by ``_load_next_batch_raw``.

    Raises:
        Whatever ``_load_next_batch_raw`` raises. The progress bar is closed
        on every exit path.
    """
    if self._iterator is None:
        self._iterator = self._build_iterator()
    expected = int(self._effective_batch_size or self._batch_size)
    counter = self._batch_progress_counter
    if counter is None:
        return self._load_next_batch_raw()
    with counter.get_lock():
        counter.value = 0
    pbar = tqdm(
        total=expected,
        desc="Collecting motion batch",
        leave=False,
        dynamic_ncols=True,
    )
    shown = 0

    def _sync_progress() -> None:
        # Advance the bar to the counter value, capped at ``expected`` so an
        # over-reporting counter cannot push the bar past its total.
        nonlocal shown
        with counter.get_lock():
            value = counter.value
        step = min(value, expected) - shown
        if step > 0:
            pbar.update(step)
            shown += step

    try:
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self._load_next_batch_raw)
            while not future.done():
                _sync_progress()
                time.sleep(0.05)
            batch = future.result(timeout=self._result_timeout())
        _sync_progress()
    finally:
        # Without this, an exception from the loader leaves the bar open.
        pbar.close()
    return batch
def _stage_batch_blocking(self, batch: ClipBatch) -> ClipBatch:
    """Move a CPU batch to the stage device on the current stream.

    Used when ``stage_on_swap_only=True``, so only the current cache batch
    lives on the GPU. Copies are non-blocking only when memory is pinned
    and the target is CUDA. Returns ``batch`` unchanged when no stage device
    is configured.
    """
    device = self._stage_device
    if device is None:
        return batch
    non_blocking = bool(self._pin_memory and device.type == "cuda")

    def _move(tensor):
        return tensor.to(device, non_blocking=non_blocking)

    moved_tensors = {name: _move(value) for name, value in batch.tensors.items()}
    moved_lengths = _move(batch.lengths)
    moved_windows = _move(batch.window_indices)
    return ClipBatch(
        tensors=moved_tensors,
        lengths=moved_lengths,
        motion_keys=batch.motion_keys,
        raw_motion_keys=getattr(batch, "raw_motion_keys", batch.motion_keys),
        window_indices=moved_windows,
        max_frame_length=batch.max_frame_length,
    )
def _stage_batch(
    self,
    batch: ClipBatch,
    record_event: bool = False,
) -> ClipBatch:
    """Copy ``batch`` onto the stage device, asynchronously when on CUDA.

    For a CUDA stage device, the copies are issued on ``self._copy_stream``
    (created lazily on first use) so they can overlap with work on the
    consuming stream. For any other device, the tensors are moved with plain
    ``.to`` calls.

    Args:
        batch: Batch to stage.
        record_event: When true and a copy stream is in use, a CUDA event is
            recorded on the copy stream after the copies and stored in
            ``self._pending_ready_event``. Consumers must make their stream
            wait on it before reading the staged tensors.

    Returns:
        ``batch`` itself when no stage device is configured, otherwise a new
        ``ClipBatch`` whose tensors live on the stage device.
    """
    if self._stage_device is None:
        return batch
    # If CUDA, copy on a dedicated stream and record readiness
    if self._copy_stream is None and self._stage_device.type == "cuda":
        self._copy_stream = torch.cuda.Stream(device=self._stage_device)
        logger.info(
            f"Perf/Cache: created CUDA copy stream lazily on {self._stage_device}"
        )

    def _copy(tensor):
        return tensor.to(self._stage_device, non_blocking=True)

    if self._copy_stream is not None:
        # Stream that will read the staged tensors. Captured before switching
        # to the copy stream; assumes consumers run on the stream that is
        # current at staging time.
        consumer_stream = torch.cuda.current_stream(self._stage_device)
        with torch.cuda.stream(self._copy_stream):
            tensors = {
                name: _copy(tensor) for name, tensor in batch.tensors.items()
            }
            lengths = _copy(batch.lengths)
            window_indices = _copy(batch.window_indices)
            if record_event:
                ev = torch.cuda.Event()
                ev.record(self._copy_stream)
                self._pending_ready_event = ev
        # These tensors were allocated on the copy stream but are consumed on
        # ``consumer_stream``. Recording that use stops the caching allocator
        # from handing their memory to a later copy while consumer kernels
        # may still be reading it.
        for staged in (*tensors.values(), lengths, window_indices):
            staged.record_stream(consumer_stream)
    else:
        tensors = {
            name: _copy(tensor) for name, tensor in batch.tensors.items()
        }
        lengths = _copy(batch.lengths)
        window_indices = _copy(batch.window_indices)
    return ClipBatch(
        tensors=tensors,
        lengths=lengths,
        motion_keys=batch.motion_keys,
        raw_motion_keys=getattr(
            batch, "raw_motion_keys", batch.motion_keys
        ),
        window_indices=window_indices,
        max_frame_length=batch.max_frame_length,
    )
def _build_iterator(
    self, *, reset_epoch: bool = False
) -> Iterator[ClipBatch]:
    """Return a fresh iterator over ``self._dataloader``.

    Args:
        reset_epoch: When true and the sampler is a ``DistributedSampler``,
            set its epoch to ``self._swap_index + 1`` first. This changes the
            ordering only when that sampler shuffles.

    Raises:
        RuntimeError: If the DataLoader has not been built yet.
    """
    loader = self._dataloader
    if loader is None:
        raise RuntimeError("DataLoader is not initialised")
    if reset_epoch and isinstance(self._sampler, DistributedSampler):
        self._sampler.set_epoch(self._swap_index + 1)
    return iter(loader)
def _build_dataloader(self) -> None:
    """(Re)build the sampler and DataLoader for ``self._mode``.

    Also resets the iterator, both batch slots and the swap counter, and
    records ``self._effective_batch_size`` and ``self._num_batches``.

    Sampler selection:
      * ``"val"``: a non-shuffling ``DistributedSampler`` when
        ``_sampler_world_size > 1``, otherwise no sampler.
      * any other mode, first match wins:
          - ``PrioritizedInfiniteSampler`` when the cache curriculum is
            enabled;
          - ``WeightedBinInfiniteSampler`` when weighted-bin sampling is
            enabled and bins/ratios are set;
          - a shuffling ``InfiniteDistributedSampler`` for multi-process
            runs;
          - ``InfiniteRandomSampler`` otherwise.
    """
    dataset = self._datasets[self._mode]
    # Hand the shared counter to the dataset; _load_next_batch_with_progress
    # polls it (presumably the dataset increments it per collected sample).
    dataset.set_progress_counter(self._batch_progress_counter)
    # Clamp batch size to dataset length to avoid empty iterator when drop_last is disabled
    effective_batch_size = self._batch_size
    ds_len = len(dataset)
    if isinstance(ds_len, int) and ds_len > 0:
        effective_batch_size = max(1, min(self._batch_size, ds_len))
    # Sampler selection: validation uses standard distributed/sequential samplers;
    # training can optionally use weighted-bin sampling.
    if self._mode == "val":
        if self._sampler_world_size > 1:
            self._sampler = DistributedSampler(
                dataset,
                num_replicas=self._sampler_world_size,
                rank=self._sampler_rank,
                shuffle=False,
                drop_last=False,
            )
        else:
            self._sampler = None
        self._cache_curriculum_sampler = None
    else:
        if self._cache_curriculum_enabled:
            # Rank-dependent seed offset, presumably so ranks draw different
            # index streams from the full dataset.
            seed = self._seed + self._sampler_rank * 100003
            cfg = dict(self._cache_curriculum_cfg)
            self._cache_curriculum_sampler = PrioritizedInfiniteSampler(
                dataset_len=ds_len,
                batch_size=effective_batch_size,
                seed=seed,
                p_a_ratio=float(cfg.get("p_a_ratio", 0.2)),
                ema_alpha_signal=float(cfg.get("ema_alpha_signal", 0.2)),
                ema_alpha_rel_improve=float(
                    cfg.get("ema_alpha_rel_improve", 0.2)
                ),
                relative_eps=float(cfg.get("relative_eps", 1.0e-6)),
            )
            self._cache_curriculum_last_dump_swap = -1
            self._sampler = self._cache_curriculum_sampler
        elif (
            self._weighted_bin_enabled
            and self._weighted_bin_bins is not None
            and self._weighted_bin_ratios is not None
        ):
            seed = self._seed + self._sampler_rank * 100003
            self._sampler = WeightedBinInfiniteSampler(
                dataset_len=ds_len,
                bin_indices=self._weighted_bin_bins,
                ratios=self._weighted_bin_ratios,
                batch_size=effective_batch_size,
                seed=seed,
            )
            self._cache_curriculum_sampler = None
        else:
            if self._sampler_world_size > 1:
                # Infinite sampler for training: no epoch boundaries
                self._sampler = InfiniteDistributedSampler(
                    dataset,
                    num_replicas=self._sampler_world_size,
                    rank=self._sampler_rank,
                    shuffle=True,
                    drop_last=False,
                )
            else:
                # Infinite sampler for single-process training
                self._sampler = InfiniteRandomSampler(dataset)
            self._cache_curriculum_sampler = None
    # Only pass prefetch_factor when using workers
    pf = (
        self._prefetch_factor
        if (self._num_workers and self._num_workers > 0)
        else None
    )
    # persistent_workers is likewise only enabled when worker processes exist.
    pw = (
        self._persistent_workers
        if (self._num_workers and self._num_workers > 0)
        else False
    )
    # Collate wrapper: in validation, pad the batch up to cache size by
    # uniformly repeating samples when dataset is smaller than batch size.
    collate = partial(
        _cache_collate_fn,
        mode=self._mode,
        batch_size=self._batch_size,
    )
    mp_ctx = None
    if self._num_workers and self._num_workers > 0:
        # Workers use the "spawn" start method (NOTE(review): presumably to
        # avoid forking a CUDA-initialised parent process).
        mp_ctx = mp.get_context("spawn")
    worker_init_fn = None
    # With a CUDA stage device, workers run _cpu_only_dataloader_worker_init_fn
    # (defined elsewhere). NOTE(review): unlike the checks above, this
    # comparison assumes self._num_workers is not None.
    if (
        self._num_workers > 0
        and self._stage_device is not None
        and self._stage_device.type == "cuda"
    ):
        worker_init_fn = _cpu_only_dataloader_worker_init_fn
    self._dataloader = DataLoader(
        dataset,
        batch_size=effective_batch_size,
        sampler=self._sampler,
        # Every non-val branch above assigns a sampler, so in practice this
        # evaluates to False.
        shuffle=(self._sampler is None and self._mode != "val"),
        num_workers=self._num_workers,
        prefetch_factor=pf,
        pin_memory=self._pin_memory,
        timeout=self._loader_timeout_seconds(),
        persistent_workers=pw,
        collate_fn=collate,
        drop_last=False,
        multiprocessing_context=mp_ctx,
        worker_init_fn=worker_init_fn,
    )
    # Drop iteration state tied to any previous DataLoader.
    self._iterator = None
    self._current_batch = None
    self._next_batch = None
    self._swap_index = 0
    # Compute number of batches only for validation; training is infinite
    local_len = ds_len
    if self._mode == "val":
        if self._sampler is not None:
            # DistributedSampler(drop_last=False) yields
            # ceil(ds_len / world_size) indices per rank.
            local_len = (
                ds_len + self._sampler_world_size - 1
            ) // self._sampler_world_size
        self._effective_batch_size = int(effective_batch_size)
        self._num_batches = (
            local_len + self._effective_batch_size - 1
        ) // self._effective_batch_size
    else:
        self._effective_batch_size = int(effective_batch_size)
        self._num_batches = 2**31  # effectively infinite for logging
def close(self) -> None:
    """Drop DataLoader, iteration and CUDA-copy state, then close all datasets.

    Does nothing when ``_datasets`` was never set, e.g. when ``__init__``
    failed early. The attribute is looked up through ``__dict__`` rather
    than normal attribute access.
    """
    datasets = self.__dict__.get("_datasets")
    if datasets is None:
        return
    for attr_name in (
        "_iterator",
        "_current_batch",
        "_next_batch",
        "_dataloader",
        "_copy_stream",
        "_pending_ready_event",
        "_current_ready_event",
        "_next_ready_event",
    ):
        setattr(self, attr_name, None)
    for dataset in datasets.values():
        if dataset is not None:
            dataset.close()


def __del__(self) -> None:
    # Finalizer: delegate to close() so a garbage-collected loader releases
    # its workers and datasets.
    self.close()
def _loader_timeout_seconds(self) -> float:
    """DataLoader ``timeout`` to use.

    Returns ``self._loader_timeout`` when ``force_timeout_on_swap`` is
    truthy, otherwise ``0.0`` (no timeout).
    """
    return self._loader_timeout if self.force_timeout_on_swap else 0.0
def _result_timeout(self) -> Optional[float]:
    """Seconds to wait for a batch future: the loader timeout plus one.

    Returns ``None`` (wait indefinitely) when the loader timeout is not
    positive.
    """
    loader_timeout = self._loader_timeout_seconds()
    return None if loader_timeout <= 0.0 else loader_timeout + 1.0
def _should_use_batch_progress(self) -> bool:
    """Decide whether batch loading should display a tqdm progress bar.

    The bar is used only when all three hold: it is enabled, the sampler
    world size is not greater than 1, and no positive loader timeout is in
    force.
    """
    return (
        bool(self._batch_progress_bar)
        and not self._sampler_world_size > 1
        and not self._loader_timeout_seconds() > 0.0
    )
================================================
FILE: holomotion/src/training/reference_filter_export.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import tempfile
from pathlib import Path
from typing import Mapping, Sequence
import matplotlib
import numpy as np
from omegaconf import DictConfig, ListConfig
from scipy.spatial.transform import Rotation as Rotation3D
from holomotion.src.training.h5_dataloader import (
MotionClipSample,
build_motion_datasets_from_cfg,
)
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def _to_numpy(array_like) -> np.ndarray:
    """Convert a tensor-like or array-like value to a float32 NumPy array.

    Objects exposing ``detach`` (e.g. torch tensors) are detached and moved
    to the CPU before conversion.
    """
    value = (
        array_like.detach().cpu().numpy()
        if hasattr(array_like, "detach")
        else array_like
    )
    return np.asarray(value, dtype=np.float32)
def _require_tensor(
    tensors: Mapping[str, object], tensor_name: str, error_message: str
) -> np.ndarray:
    """Fetch ``tensors[tensor_name]`` as a float32 array.

    Raises:
        ValueError: With ``error_message`` when the key is absent.
    """
    if tensor_name in tensors:
        return _to_numpy(tensors[tensor_name])
    raise ValueError(error_message)
def _quat_xyzw_to_rpy(quat_xyzw: np.ndarray) -> np.ndarray:
    """Convert ``(..., 4)`` xyzw quaternions to ``(..., 3)`` Euler angles.

    Uses SciPy's ``"xyz"`` sequence, in radians, and returns float32.
    """
    quats = np.asarray(quat_xyzw, dtype=np.float32)
    leading_shape = quats.shape[:-1]
    angles = Rotation3D.from_quat(quats.reshape(-1, 4)).as_euler(
        "xyz", degrees=False
    )
    return angles.reshape(leading_shape + (3,)).astype(np.float32, copy=False)
def _write_npz(output_path: Path, payload: Mapping[str, np.ndarray]) -> None:
    """Save ``payload`` as an uncompressed ``.npz`` archive, one entry per key.

    ``np.savez`` appends ``.npz`` when the file name lacks that suffix.
    """
    target = str(output_path)
    np.savez(target, **payload)
def _plot_series_groups(
    output_path: Path,
    title: str,
    groups: Sequence[tuple[str, np.ndarray, np.ndarray]],
    axis_labels: Sequence[str] = ("x", "y", "z"),
) -> None:
    """Save a grid of raw-vs-filtered line plots, one row per signal group.

    Args:
        output_path: Image file to write.
        title: Figure title.
        groups: ``(name, raw_values, filtered_values)`` triples. Both arrays
            are indexed ``[frame, component]`` and need at least
            ``len(axis_labels)`` components. Lengths may differ between
            groups and between raw and filtered series.
        axis_labels: Component labels, one subplot column per label.
    """
    nrows = len(groups)
    ncols = len(axis_labels)
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(4.0 * ncols, 2.8 * max(1, nrows)),
        squeeze=False,
    )
    # Always release the figure, even when layout/saving fails; otherwise it
    # stays registered with pyplot.
    try:
        for row_idx, (group_name, ref_values, ft_values) in enumerate(groups):
            # Each series gets its own frame axis, so series of different
            # lengths do not make ``ax.plot`` raise.
            ref_steps = np.arange(ref_values.shape[0], dtype=np.int32)
            ft_steps = np.arange(ft_values.shape[0], dtype=np.int32)
            for col_idx, axis_name in enumerate(axis_labels):
                ax = axes[row_idx, col_idx]
                ax.plot(
                    ref_steps,
                    ref_values[:, col_idx],
                    label="raw",
                    linewidth=1.4,
                )
                ax.plot(
                    ft_steps,
                    ft_values[:, col_idx],
                    label="filtered",
                    linewidth=1.2,
                )
                ax.set_title(f"{group_name} {axis_name}")
                ax.grid(True, alpha=0.3)
                if row_idx == 0 and col_idx == 0:
                    ax.legend(loc="best")
        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
def _plot_dof_matrix(
    output_path: Path,
    title: str,
    dof_names: Sequence[str],
    ref_values: np.ndarray,
    ft_values: np.ndarray,
) -> None:
    """Save one raw-vs-filtered subplot per DOF, stacked vertically.

    Args:
        output_path: Image file to write.
        title: Figure title.
        dof_names: Subplot titles. Column ``i`` of the arrays is plotted
            under ``dof_names[i]``.
        ref_values: Raw values indexed ``[frame, dof]``.
        ft_values: Filtered values indexed ``[frame, dof]``. The length may
            differ from ``ref_values``.
    """
    num_dofs = len(dof_names)
    fig, axes = plt.subplots(
        nrows=num_dofs,
        ncols=1,
        figsize=(14.0, max(2.8 * num_dofs, 3.5)),
        squeeze=False,
    )
    # Always release the figure, even when layout/saving fails.
    try:
        ref_steps = np.arange(ref_values.shape[0], dtype=np.int32)
        ft_steps = np.arange(ft_values.shape[0], dtype=np.int32)
        for idx, dof_name in enumerate(dof_names):
            ax = axes[idx, 0]
            ax.plot(ref_steps, ref_values[:, idx], label="raw", linewidth=1.4)
            ax.plot(
                ft_steps, ft_values[:, idx], label="filtered", linewidth=1.2
            )
            ax.set_title(dof_name)
            ax.grid(True, alpha=0.3)
            if idx == 0:
                ax.legend(loc="best")
        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
def export_reference_filter_debug_artifacts(
    *,
    sample: MotionClipSample,
    output_dir: str | Path,
    body_names: Sequence[str],
    dof_names: Sequence[str],
    selected_body_links: Sequence[str],
) -> Path:
    """Dump raw vs. online-filtered reference signals of one clip.

    Writes into ``output_dir`` (created if needed):
      * ``root_signals.npz`` and ``root_comparison.png``: root position,
        ``"xyz"`` Euler angles of the root rotation, and linear and angular
        velocity;
      * ``bodylink_signals.npz`` and one ``<link>_comparison.png`` per entry
        of ``selected_body_links``: position, linear and angular velocity;
      * ``dof_signals.npz``, ``dof_pos_comparison.png`` and
        ``dof_vel_comparison.png``: joint positions and velocities;
      * ``metadata.json``: clip identifiers, filter cutoff and name lists.

    Args:
        sample: Clip whose ``tensors`` hold raw ``ref_*`` and filtered
            ``ft_ref_*`` signals.
        output_dir: Destination directory.
        body_names: Names indexing axis 1 of the per-body tensors.
        dof_names: DOF names used as subplot titles in the DOF figures.
        selected_body_links: Subset of ``body_names`` to export.

    Returns:
        ``output_dir`` as a ``Path``.

    Raises:
        ValueError: If the filtered tensors are unavailable, a required
            tensor is missing, or a selected link is not in ``body_names``.
    """
    tensors = sample.tensors
    if "ft_ref_rg_pos" not in tensors or "ft_ref_dof_pos" not in tensors:
        raise ValueError(
            "Filtered reference tensors are unavailable. Ensure online filtering "
            "is enabled and ft_ref_* tensors are materialized."
        )
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load each raw/filtered pair (raw first) in a fixed order, so the first
    # missing tensor is the one reported.
    signal_keys = (
        "root_pos",
        "root_rot",
        "root_vel",
        "root_ang_vel",
        "rg_pos",
        "body_vel",
        "body_ang_vel",
        "dof_pos",
        "dof_vel",
    )
    raw: dict[str, np.ndarray] = {}
    filtered: dict[str, np.ndarray] = {}
    for key in signal_keys:
        for store, tensor_name in (
            (raw, f"ref_{key}"),
            (filtered, f"ft_ref_{key}"),
        ):
            store[key] = _require_tensor(
                tensors,
                tensor_name,
                f"Missing {tensor_name} tensor in sampled clip.",
            )

    body_index = {name: idx for idx, name in enumerate(body_names)}
    unknown_links = [
        link_name
        for link_name in selected_body_links
        if link_name not in body_index
    ]
    if unknown_links:
        raise ValueError(
            f"Requested body links are missing from robot.body_names: {unknown_links}"
        )

    raw_rpy = _quat_xyzw_to_rpy(raw["root_rot"])
    filtered_rpy = _quat_xyzw_to_rpy(filtered["root_rot"])
    root_groups = [
        ("global_pos", raw["root_pos"], filtered["root_pos"]),
        ("rpy", raw_rpy, filtered_rpy),
        ("lin_vel", raw["root_vel"], filtered["root_vel"]),
        ("ang_vel", raw["root_ang_vel"], filtered["root_ang_vel"]),
    ]
    root_payload: dict[str, np.ndarray] = {}
    for label, raw_values, filtered_values in root_groups:
        root_payload[f"ref_{label}"] = raw_values
        root_payload[f"ft_ref_{label}"] = filtered_values
    _write_npz(output_dir / "root_signals.npz", root_payload)

    # (label used in payload keys / plot rows, signal key) for each link.
    link_signals = (
        ("global_pos", "rg_pos"),
        ("lin_vel", "body_vel"),
        ("ang_vel", "body_ang_vel"),
    )
    body_payload: dict[str, np.ndarray] = {}
    for link_name in selected_body_links:
        body_idx = body_index[link_name]
        for label, key in link_signals:
            body_payload[f"{link_name}__ref_{label}"] = raw[key][
                :, body_idx, :
            ]
            body_payload[f"{link_name}__ft_ref_{label}"] = filtered[key][
                :, body_idx, :
            ]
    _write_npz(output_dir / "bodylink_signals.npz", body_payload)

    dof_payload: dict[str, np.ndarray] = {}
    for key in ("dof_pos", "dof_vel"):
        dof_payload[f"ref_{key}"] = raw[key]
        dof_payload[f"ft_ref_{key}"] = filtered[key]
    _write_npz(output_dir / "dof_signals.npz", dof_payload)

    # The first element of the optional cutoff tensor, if present, is
    # recorded in the metadata.
    filter_cutoff_hz = None
    cutoff_source = tensors.get("filter_cutoff_hz")
    if cutoff_source is not None:
        cutoff_flat = _to_numpy(cutoff_source).reshape(-1)
        if cutoff_flat.size:
            filter_cutoff_hz = float(cutoff_flat[0])
    metadata = {
        "motion_key": sample.motion_key,
        "raw_motion_key": sample.raw_motion_key,
        "window_index": int(sample.window_index),
        "length": int(sample.length),
        "filter_cutoff_hz": filter_cutoff_hz,
        "selected_body_links": list(selected_body_links),
        "body_names": list(body_names),
        "dof_names": list(dof_names),
    }
    (output_dir / "metadata.json").write_text(
        json.dumps(metadata, indent=2, sort_keys=True),
        encoding="utf-8",
    )

    _plot_series_groups(
        output_dir / "root_comparison.png",
        title="Root Raw vs Filtered Reference Signals",
        groups=root_groups,
    )
    for link_name in selected_body_links:
        _plot_series_groups(
            output_dir / f"{link_name}_comparison.png",
            title=f"{link_name} Raw vs Filtered Reference Signals",
            groups=[
                (
                    label,
                    body_payload[f"{link_name}__ref_{label}"],
                    body_payload[f"{link_name}__ft_ref_{label}"],
                )
                for label, _ in link_signals
            ],
        )
    for file_name, plot_title, key in (
        ("dof_pos_comparison.png", "DOF Position Raw vs Filtered", "dof_pos"),
        ("dof_vel_comparison.png", "DOF Velocity Raw vs Filtered", "dof_vel"),
    ):
        _plot_dof_matrix(
            output_dir / file_name,
            title=plot_title,
            dof_names=dof_names,
            ref_values=raw[key],
            ft_values=filtered[key],
        )
    return output_dir
def _to_plain_sequence(values) -> list[str]:
    """Normalise a config value into a list of strings.

    ``None`` becomes ``[]``. A ``ListConfig``, tuple or list has each item
    stringified. Any other value becomes a one-element list.
    """
    if values is None:
        return []
    items = values if isinstance(values, (ListConfig, tuple, list)) else [values]
    return [str(item) for item in items]
def export_reference_filter_artifacts_from_config(config) -> Path:
    """Export filter debug artifacts for the first training clip.

    Requires ``debug_reference_filter_export.enabled`` and
    ``robot.motion.online_filter.enabled``. Artifacts go to
    ``debug_reference_filter_export.output_dir``, or to a fresh temporary
    directory when that is unset or empty.

    Returns:
        The directory the artifacts were written to.

    Raises:
        ValueError: If either required flag is not enabled.
    """
    debug_cfg = getattr(config, "debug_reference_filter_export", None)
    export_enabled = debug_cfg is not None and bool(
        debug_cfg.get("enabled", False)
    )
    if not export_enabled:
        raise ValueError("debug_reference_filter_export.enabled must be true.")
    motion_cfg = config.robot.motion
    filter_enabled = bool(
        motion_cfg.get("online_filter", {}).get("enabled", False)
    )
    if not filter_enabled:
        raise ValueError(
            "Reference filter debug export requires robot.motion.online_filter.enabled=true."
        )
    output_dir = debug_cfg.get("output_dir", None)
    if output_dir in (None, ""):
        output_dir = tempfile.mkdtemp(prefix="motrack-ref-filter-")
    train_dataset, _, _ = build_motion_datasets_from_cfg(
        motion_cfg=motion_cfg,
        max_frame_length=int(motion_cfg.max_frame_length),
        min_window_length=int(motion_cfg.min_frame_length),
        world_frame_normalization=bool(
            motion_cfg.get("world_frame_normalization", True)
        ),
    )
    first_sample = train_dataset[0]
    selected_links = _to_plain_sequence(
        debug_cfg.get("selected_body_links", [])
    )
    return export_reference_filter_debug_artifacts(
        sample=first_sample,
        output_dir=Path(str(output_dir)),
        body_names=_to_plain_sequence(config.robot.body_names),
        dof_names=_to_plain_sequence(config.robot.dof_names),
        selected_body_links=selected_links,
    )
================================================
FILE: holomotion/src/training/train.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import os
from pathlib import Path
import sys
import hydra
from hydra.utils import get_class
from omegaconf import ListConfig, OmegaConf
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from loguru import logger
from holomotion.src.training.reference_filter_export import (
export_reference_filter_artifacts_from_config,
)
from holomotion.src.utils.config import compile_config
def _resolve_mujoco_eval_onnx_names(
    exported_dir: Path, ckpt_onnx_names
) -> list[str]:
    """Pick which exported ONNX checkpoints the MuJoCo evaluation should run.

    Args:
        exported_dir: Directory holding the exported ``*.onnx`` files.
        ckpt_onnx_names: ``None`` to select every exported checkpoint, or a
            list/tuple/``ListConfig`` of requested names. Only the final path
            component of each entry is used. Blank entries are ignored, and a
            repeated name is kept once (first occurrence wins).

    Returns:
        The sorted names of all exported checkpoints when nothing is
        requested. Otherwise, the requested names that exist, in request
        order. Requested names that were not exported are logged as a
        warning and skipped.

    Raises:
        FileNotFoundError: If ``exported_dir`` is missing or contains no
            ``*.onnx`` files.
        TypeError: If ``ckpt_onnx_names`` is neither ``None`` nor a
            list/tuple/``ListConfig``.
        ValueError: If none of the requested checkpoints exist.
    """
    if not exported_dir.is_dir():
        raise FileNotFoundError(
            f"Exported ONNX directory not found: {exported_dir}"
        )
    # Only regular files count; a directory named "*.onnx" is not a model.
    existing = sorted(
        p.name for p in exported_dir.glob("*.onnx") if p.is_file()
    )
    if not existing:
        raise FileNotFoundError(f"No .onnx files found under {exported_dir}")
    if ckpt_onnx_names is None:
        return existing
    if not isinstance(ckpt_onnx_names, (ListConfig, list, tuple)):
        raise TypeError(
            "mujoco_eval.ckpt_onnx_names must be a list/tuple, "
            f"got {type(ckpt_onnx_names)}"
        )
    # dict.fromkeys removes duplicates while preserving request order, so the
    # same checkpoint is never evaluated twice.
    requested_norm = list(
        dict.fromkeys(
            Path(name_str).name
            for name_str in (str(name).strip() for name in ckpt_onnx_names)
            if name_str
        )
    )
    if not requested_norm:
        return existing
    existing_set = set(existing)
    selected = [name for name in requested_norm if name in existing_set]
    if not selected:
        raise ValueError(
            "No requested ONNX checkpoints exist under exported directory. "
            f"exported_dir={exported_dir}, requested={requested_norm}, "
            f"existing={existing}"
        )
    missing = [name for name in requested_norm if name not in existing_set]
    if missing:
        logger.warning(
            "Skipping requested ONNX checkpoints that were not exported: "
            f"{missing} (exported_dir={exported_dir})"
        )
    return selected
def _format_hydra_override_value(value) -> str:
    """Render ``value`` as a Hydra command-line override value.

    ``None`` becomes ``null``, bools become ``true``/``false``, and lists or
    tuples become ``[a,b,...]`` with each element formatted recursively.
    Anything else is rendered with ``str``.
    """
    if value is None:
        return "null"
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, (list, tuple)):
        inner = ",".join(_format_hydra_override_value(v) for v in value)
        return f"[{inner}]"
    return str(value)


def _exec_mujoco_eval(eval_override_dict: dict) -> None:
    """Replace the current process with the MuJoCo sim2sim evaluation.

    Each non-``None`` entry of ``eval_override_dict`` becomes a
    ``key=value`` Hydra override, emitted in sorted key order. This function
    does not return on success.
    """
    cli_args = [
        f"{key}={_format_hydra_override_value(eval_override_dict[key])}"
        for key in sorted(eval_override_dict)
        if eval_override_dict[key] is not None
    ]
    argv = [
        sys.executable,
        "-m",
        "holomotion.src.evaluation.eval_mujoco_sim2sim",
        *cli_args,
    ]
    # os.execv replaces the process without flushing Python's buffered
    # stdio, so push pending log output out first.
    for stream in (sys.stdout, sys.stderr):
        if stream is not None:
            stream.flush()
    os.execv(sys.executable, argv)
def _maybe_export_reference_filter_artifacts(config: OmegaConf) -> None:
    """Export reference-filter debug artifacts when enabled.

    Runs only if ``debug_reference_filter_export.enabled`` is set and this
    is the main process. ``main_process`` defaults to true when absent.
    """
    debug_cfg = getattr(config, "debug_reference_filter_export", None)
    wants_export = debug_cfg is not None and bool(
        debug_cfg.get("enabled", False)
    )
    if not wants_export or not bool(getattr(config, "main_process", True)):
        return
    export_dir = export_reference_filter_artifacts_from_config(config)
    logger.info(f"Exported reference filter debug artifacts to: {export_dir}")
@hydra.main(
    config_path="../../config",
    config_name="training/train_base",
    version_base=None,
)
def main(config: OmegaConf):
    """Train the motion tracking model.

    When ``mujoco_eval.enabled`` is set, the main process then replaces
    itself with the MuJoCo sim2sim evaluation of the exported ONNX policies,
    and the other ranks exit.

    Args:
        config: OmegaConf object containing the configuration.

    Raises:
        ValueError: If ``mujoco_eval.enabled`` is set without
            ``algo.config.export_policy``. This is checked before training
            starts, so a misconfigured run fails immediately instead of
            after training completes.
    """
    config = compile_config(config, accelerator=None)
    # Validate post-training evaluation settings up front.
    mujoco_eval_enabled = bool(config.mujoco_eval.get("enabled", False))
    if mujoco_eval_enabled and not bool(
        config.algo.config.get("export_policy", False)
    ):
        msg = (
            "mujoco_eval.enabled=true requires "
            "algo.config.export_policy=true to export ONNX "
            "before post-training evaluation."
        )
        raise ValueError(msg)
    dist = None
    # In distributed runs, Hydra resolves ${now:...} per process so experiment_save_dir
    # can differ by rank (e.g. staggered startup). Use Accelerator to init the process
    # group, then broadcast rank 0's path so all ranks write to the same directory.
    if getattr(config, "num_processes", 1) > 1:
        project_config = ProjectConfiguration(
            project_dir=config.experiment_save_dir,
            logging_dir=config.experiment_save_dir,
        )
        _accelerator = Accelerator(project_config=project_config)
        import torch.distributed as dist

        path_list = (
            [config.experiment_save_dir]
            if _accelerator.is_main_process
            else [None]
        )
        dist.broadcast_object_list(path_list, src=0)
        config.experiment_save_dir = path_list[0]
    _maybe_export_reference_filter_artifacts(config)
    if dist is not None:
        # Keep other ranks from starting before the main process finishes
        # the optional debug export.
        dist.barrier()
    log_dir = config.experiment_save_dir
    headless = config.headless
    algo_class = get_class(config.algo._target_)
    algo = algo_class(
        env_config=config.env,
        config=config.algo.config,
        log_dir=log_dir,
        headless=headless,
    )
    algo.load(config.checkpoint)
    algo.learn()
    if not mujoco_eval_enabled:
        return
    if not bool(algo.is_main_process):
        os._exit(0)
    exported_dir = Path(log_dir) / "exported"
    selected_onnx_names = _resolve_mujoco_eval_onnx_names(
        exported_dir, config.mujoco_eval.get("ckpt_onnx_names", None)
    )
    eval_override_dict = OmegaConf.to_container(
        config.mujoco_eval, resolve=True
    )
    eval_override_dict.pop("enabled", None)
    eval_override_dict["ckpt_onnx_root_dir"] = str(exported_dir)
    eval_override_dict["ckpt_onnx_names"] = selected_onnx_names
    _exec_mujoco_eval(eval_override_dict)


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/src/utils/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/utils/config.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import copy
import math
import os
from pathlib import Path
import torch
from accelerate import Accelerator
from loguru import logger
from omegaconf import OmegaConf
def setup_hydra_resolvers():
    """Set up custom resolvers for OmegaConf.

    Registers resolvers that allow calculations, conditional logic and
    similar operations directly in the YAML configuration files. For
    example, ``${sqrt:4}`` resolves to ``2.0``.

    The registered resolvers are:
    - ``eval``: Evaluates a Python expression. This runs arbitrary code from
      the config; only use it with trusted configuration files.
    - ``if``: Conditional logic (if-else).
    - ``eq``: Case-insensitive string comparison.
    - ``sqrt``: Calculates the square root.
    - ``sum``: Sums a list of numbers.
    - ``ceil``: Computes the ceiling of a number.
    - ``int``: Casts a value to an integer.
    - ``len``: Returns the length of a list or string.
    - ``sum_list``: Sums a list of numbers.

    Safe to call repeatedly (``compile_config`` calls it every time). Each
    name is registered independently, and names that already have a
    resolver are left untouched.
    """
    resolvers = {
        "eval": eval,
        "if": lambda pred, a, b: a if pred else b,
        "eq": lambda x, y: x.lower() == y.lower(),
        "sqrt": lambda x: math.sqrt(float(x)),
        "sum": lambda x: sum(x),
        "ceil": lambda x: math.ceil(x),
        "int": lambda x: int(x),
        "len": lambda x: len(x),
        "sum_list": lambda lst: sum(lst),
    }
    for name, resolver in resolvers.items():
        # Previously one try block wrapped every registration, so a single
        # pre-registered name aborted all later registrations.
        if OmegaConf.has_resolver(name):
            continue
        try:
            OmegaConf.register_new_resolver(name, resolver)
        except Exception as e:
            logger.warning(
                f"Warning: Resolver '{name}' could not be registered: {e}"
            )
def compile_config(
    config: DictConfig,
    accelerator: Accelerator = None,
    eval: bool = False,
) -> DictConfig:
    """Fill in the runtime-dependent fields of an experiment configuration.

    The function registers the custom OmegaConf resolvers and deep-copies
    ``config``. It then applies three steps in order: distributed/device
    metadata (``compile_config_hf_accelerate``), experiment directories
    (``compile_config_directories``) and per-environment device fields
    (``compile_config_devices``).

    Args:
        config: Unresolved configuration. It is not modified; a copy is
            compiled.
        accelerator: Optional HF Accelerate instance. When None, process and
            device information is read from launcher environment variables.
        eval: If True, skip creating the experiment directory and writing
            ``config.yaml``. (The name shadows the builtin but is kept for
            caller compatibility.)

    Returns:
        The compiled configuration.
    """
    setup_hydra_resolvers()
    config = copy.deepcopy(config)
    config = compile_config_hf_accelerate(config, accelerator)
    config = compile_config_directories(config, eval)
    config = compile_config_devices(config)
    return config
def compile_config_hf_accelerate(
    config,
    accelerator: Accelerator = None,
) -> DictConfig:
    """Record distributed-process and device information on ``config``.

    The function writes ``process_id``, ``num_processes`` and ``main_process``.
    If the config already has a ``device`` key, it also overwrites that key
    with the selected device. ``config`` is modified in place and returned.

    Args:
        config: Configuration (an OmegaConf container).
        accelerator: Accelerator instance. When None, values are taken on a
            best-effort basis from ``RANK``/``WORLD_SIZE``/``LOCAL_RANK``,
            falling back to the ``ACCELERATE_*`` variables. The device is
            ``cuda:<local_rank>`` when CUDA is available, otherwise ``cpu``.

    Returns:
        The same ``config`` object, updated.
    """
    if accelerator is not None:
        device = accelerator.device
        is_main_process = accelerator.is_main_process
        process_idx = accelerator.process_index
        total_processes = accelerator.num_processes
    else:
        # Best-effort distributed metadata when running under torchrun /
        # Accelerate launch, even if no Accelerator instance exists yet.
        process_idx = int(
            os.environ.get(
                "RANK", os.environ.get("ACCELERATE_PROCESS_INDEX", "0")
            )
        )
        total_processes = int(
            os.environ.get(
                "WORLD_SIZE", os.environ.get("ACCELERATE_NUM_PROCESSES", "1")
            )
        )
        local_rank = int(
            os.environ.get(
                "LOCAL_RANK",
                os.environ.get("ACCELERATE_LOCAL_PROCESS_INDEX", "0"),
            )
        )
        is_main_process = process_idx == 0
        if torch.cuda.is_available():
            device = torch.device("cuda", local_rank)
        else:
            device = torch.device("cpu")
    # open_dict allows adding keys that a struct-mode (Hydra) config does not
    # declare. It restores the original struct flag afterwards.
    with open_dict(config):
        config.process_id = process_idx
        config.num_processes = total_processes
        config.main_process = is_main_process
        if hasattr(config, "device"):
            config.device = str(device)
    logger.info(f"Using device: {device}")
    if is_main_process:
        logger.info(f"Process {process_idx} on device: {device}")
    return config
def compile_config_devices(config):
    """Propagate device and process metadata into the environment config.

    The device string is taken from ``config.device``. When that key is
    absent, it falls back to ``cuda`` if CUDA is available and ``cpu``
    otherwise. The process info comes from ``num_processes``/``process_id``/
    ``main_process``, with defaults 1/0/True.

    If ``config.env.config`` exists, it receives these fields:
        - ``num_processes``, ``process_id``, ``main_process``
        - ``simulation_device``, ``sim_device``, ``rl_device``,
          ``compute_device``, ``physx_device``
        - ``simulation.device``, ``simulation.compute_device`` and
          ``simulation.rl_device``, when a ``simulation`` section exists.

    Args:
        config: Configuration (an OmegaConf container). It is not modified.

    Returns:
        A deep copy of ``config`` with the fields above set.
    """
    config = copy.deepcopy(config)
    if hasattr(config, "device"):
        device_str = str(config.device)
    else:
        device_str = str(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
    world_size = getattr(config, "num_processes", 1)
    process_rank = getattr(config, "process_id", 0)
    is_main_process = getattr(config, "main_process", True)
    if not (hasattr(config, "env") and hasattr(config.env, "config")):
        return config
    env_cfg = config.env.config
    # open_dict saves and restores each node's *own* struct flag, including
    # "inherit from parent", even if an assignment raises. The previous
    # is_struct()/set_struct() pairing read the inherited flag. For the
    # nested ``simulation`` node that value was already False at read time,
    # so the node was left permanently non-struct.
    with open_dict(env_cfg):
        env_cfg.num_processes = world_size
        env_cfg.process_id = process_rank
        env_cfg.main_process = is_main_process
        env_cfg.simulation_device = device_str
        for key in (
            "sim_device",
            "rl_device",
            "compute_device",
            "physx_device",
        ):
            setattr(env_cfg, key, device_str)
        if hasattr(env_cfg, "simulation"):
            sim_cfg = env_cfg.simulation
            with open_dict(sim_cfg):
                for sim_key in ("device", "compute_device", "rl_device"):
                    setattr(sim_cfg, sim_key, device_str)
    return config
def compile_config_directories(config, eval: bool = False) -> DictConfig:
    """Create the experiment directory and save the unresolved config there.

    Args:
        config: Configuration. Must define ``experiment_dir``.
        eval: When True, ``config`` is returned unchanged. No directory is
            created and nothing is written.

    Returns:
        A deep copy of ``config`` with these fields set:
            - ``experiment_save_dir``
            - ``env.config.save_rendering_dir``, when an ``env.config``
              section exists.
        When ``eval`` is True, ``config`` itself is returned.
        On the main process, ``<experiment_dir>/config.yaml`` is also
        written, with interpolations left unresolved.
    """
    if eval:
        return config
    config = copy.deepcopy(config)
    experiment_save_dir = Path(config.experiment_dir)
    experiment_save_dir.mkdir(exist_ok=True, parents=True)
    # open_dict allows adding keys that a struct-mode (Hydra) config does not
    # declare, and restores the original flag afterwards.
    with open_dict(config):
        config.experiment_save_dir = str(experiment_save_dir)
        if hasattr(config, "env") and hasattr(config.env, "config"):
            with open_dict(config.env.config):
                config.env.config.save_rendering_dir = str(
                    experiment_save_dir / "renderings_training"
                )
    # Default to True (as compile_config_devices does) so that this function
    # also works when compile_config_hf_accelerate has not run.
    if getattr(config, "main_process", True):
        logger.info(f"Saving config file to {experiment_save_dir}")
        unresolved_conf = OmegaConf.to_container(config, resolve=False)
        config_path = experiment_save_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as file:
            OmegaConf.save(unresolved_conf, file)
    return config
================================================
FILE: holomotion/src/utils/frame_utils.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import isaaclab.utils.math as isaaclab_math
import torch
def positions_world_to_env_frame(
    positions_w: torch.Tensor,
    env_origins: torch.Tensor,
) -> torch.Tensor:
    """Shift simulator-world positions into IsaacLab's per-environment frame.

    IsaacLab's MDP root-position helpers report positions relative to each
    environment's origin, i.e. world coordinates minus ``env_origins``. This
    helper subtracts the same per-environment origin from an arbitrary
    position tensor, so that later arithmetic mixes only env-frame positions.

    Args:
        positions_w: World positions of shape ``[B, ..., 3]``.
        env_origins: Per-environment origins of shape ``[B, 3]``.

    Returns:
        ``positions_w`` minus the matching origin, with the same shape.

    Raises:
        ValueError: On malformed shapes or mismatched batch sizes.
    """
    if positions_w.ndim < 2 or positions_w.shape[-1] != 3:
        raise ValueError(
            f"positions_w must have shape [B, ..., 3], "
            f"got {tuple(positions_w.shape)}."
        )
    if env_origins.ndim != 2 or env_origins.shape[-1] != 3:
        raise ValueError(
            f"env_origins must have shape [B, 3], "
            f"got {tuple(env_origins.shape)}."
        )
    num_envs = env_origins.shape[0]
    if positions_w.shape[0] != num_envs:
        raise ValueError(
            "Batch size mismatch between positions_w and env_origins: "
            f"{positions_w.shape[0]} vs {num_envs}."
        )
    # One singleton axis per middle dimension so the origin broadcasts.
    broadcast_shape = (num_envs,) + (1,) * (positions_w.ndim - 2) + (3,)
    return positions_w - env_origins.view(broadcast_shape)
def root_relative_positions_from_env_frame(
    body_pos_env: torch.Tensor,
    root_pos_env: torch.Tensor,
    root_quat_w: torch.Tensor,
) -> torch.Tensor:
    """Express environment-frame body positions in each robot's root frame.

    Inputs must already be in IsaacLab's environment frame, not raw
    simulator-world coordinates. ``env_origins`` is a pure translation, so
    the articulation's world-frame root quaternion applies unchanged.

    Args:
        body_pos_env: Body positions, shape ``[B, ..., 3]`` (at least 3-D).
        root_pos_env: Root positions, shape ``[B, 3]``.
        root_quat_w: Root orientations, shape ``[B, 4]``, in whatever layout
            ``isaaclab_math.quat_apply_inverse`` expects.

    Returns:
        Root-relative positions with the same shape as ``body_pos_env``.

    Raises:
        ValueError: On malformed shapes or mismatched batch sizes.
    """
    if body_pos_env.ndim < 3 or body_pos_env.shape[-1] != 3:
        raise ValueError(
            "body_pos_env must have shape [B, ..., 3], "
            f"got {tuple(body_pos_env.shape)}."
        )
    per_env_inputs = (
        ("root_pos_env", root_pos_env, 3),
        ("root_quat_w", root_quat_w, 4),
    )
    for label, tensor, width in per_env_inputs:
        if tensor.ndim != 2 or tensor.shape[-1] != width:
            raise ValueError(
                f"{label} must have shape [B, {width}], "
                f"got {tuple(tensor.shape)}."
            )
    for label, tensor, _ in per_env_inputs:
        if body_pos_env.shape[0] != tensor.shape[0]:
            raise ValueError(
                f"Batch size mismatch between body_pos_env and {label}: "
                f"{body_pos_env.shape[0]} vs {tensor.shape[0]}."
            )
    singleton_axes = (1,) * (body_pos_env.ndim - 2)
    root_pos_b = root_pos_env.view(root_pos_env.shape[0], *singleton_axes, 3)
    # quat_apply_inverse needs the quaternion expanded to every body point.
    root_quat_b = root_quat_w.view(
        root_quat_w.shape[0], *singleton_axes, 4
    ).expand(*body_pos_env.shape[:-1], 4)
    offsets = body_pos_env - root_pos_b
    return isaaclab_math.quat_apply_inverse(root_quat_b, offsets)
def root_relative_positions_from_mixed_position_frames(
    body_pos_w: torch.Tensor,
    root_pos_env: torch.Tensor,
    root_quat_w: torch.Tensor,
    env_origins: torch.Tensor,
) -> torch.Tensor:
    """Root-relative body positions when bodies and root use different frames.

    This covers the common IsaacLab pattern where body positions come from
    ``robot.data.body_pos_w`` in simulator-world coordinates, while the root
    position from ``isaaclab_mdp.root_pos_w(env)`` is already in the
    environment frame. The bodies are first moved into the environment frame
    by subtracting ``env_origins``, then rotated into the root frame.
    """
    return root_relative_positions_from_env_frame(
        body_pos_env=positions_world_to_env_frame(body_pos_w, env_origins),
        root_pos_env=root_pos_env,
        root_quat_w=root_quat_w,
    )
================================================
FILE: holomotion/src/utils/isaac_utils/__init__.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
================================================
FILE: holomotion/src/utils/isaac_utils/maths.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# This file was originally copied from the [ASAP] repository:
# https://github.com/LeCAR-Lab/ASAP
# Modifications have been made to fit the needs of this project.
import os
import random
import numpy as np
import torch
@torch.jit.script
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is clamped below at ``eps``, so all-zero vectors stay zero
    instead of producing NaN.
    """
    length = x.norm(p=2, dim=-1, keepdim=True).clamp(min=eps)
    return x / length
@torch.jit.script
def torch_rand_float(lower, upper, shape, device):
    # type: (float, float, Tuple[int, int], str) -> Tensor
    """Uniform samples in ``[lower, upper)`` with a 2-D ``shape`` on ``device``."""
    span = upper - lower
    samples = torch.rand([shape[0], shape[1]], device=device)
    return span * samples + lower
@torch.jit.script
def copysign(a, b):
    # type: (float, Tensor) -> Tensor
    """Return ``|a|`` carrying the sign of each element of ``b``.

    The result has ``b``'s shape, so ``b`` may be 0-d, 1-d or n-d. The
    previous ``.repeat(b.shape[0])`` form only worked for 1-d input. The
    magnitude is a float32 tensor of ``b``'s shape on ``b``'s device, which
    keeps the original dtype-promotion behaviour for 1-d input.

    Note: this is ``|a| * sign(b)``, so where ``b == 0`` the result is 0,
    unlike ``math.copysign``.
    """
    magnitude = torch.abs(
        torch.full(b.shape, a, device=b.device, dtype=torch.float)
    )
    return magnitude * torch.sign(b)
def set_seed(seed, torch_deterministic=False):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs.

    Args:
        seed: Seed value. ``-1`` means "pick one": 42 when
            ``torch_deterministic`` is set, otherwise a random int in
            ``[0, 10000)`` from NumPy's global RNG.
        torch_deterministic: If true, configure cuBLAS/cuDNN and PyTorch for
            deterministic kernels. Otherwise enable cuDNN autotuning
            (``benchmark``) and non-deterministic cuDNN.

    Returns:
        The seed actually used. ``PYTHONHASHSEED`` is also exported (this
        only affects child processes).
    """
    if seed == -1:
        seed = 42 if torch_deterministic else np.random.randint(0, 10000)
    print(f"Setting seed: {seed}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    if not torch_deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
        return seed
    # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    return seed
================================================
FILE: holomotion/src/utils/isaac_utils/rotations.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# This file was originally copied from the [ASAP] repository:
# https://github.com/LeCAR-Lab/ASAP
# Modifications have been made to fit the needs of this project.
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from holomotion.src.utils.isaac_utils.maths import (
copysign,
normalize,
)
@torch.jit.script
def quat_unit(a):
    """Rescale quaternion(s) ``a`` to unit length along the last dimension."""
    unit_quat = normalize(a)
    return unit_quat
@torch.jit.script
def quat_apply(a: Tensor, b: Tensor, w_last: bool) -> Tensor:
    """Rotate 3-vectors ``b`` by quaternions ``a``.

    Both inputs are flattened to ``(-1, 4)`` / ``(-1, 3)``, and the result is
    reshaped back to ``b.shape``. ``w_last`` selects the (x, y, z, w) layout
    when True and (w, x, y, z) otherwise. The rotation uses the form
    ``v + w*t + q_xyz x t`` with ``t = 2 * (q_xyz x v)``.
    """
    out_shape = b.shape
    q = a.reshape(-1, 4)
    v = b.reshape(-1, 3)
    if w_last:
        q_xyz, q_w = q[:, :3], q[:, 3:]
    else:
        q_xyz, q_w = q[:, 1:], q[:, :1]
    t = q_xyz.cross(v, dim=-1) * 2
    rotated = v + q_w * t + q_xyz.cross(t, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_apply_yaw(quat: Tensor, vec: Tensor, w_last: bool) -> Tensor:
    """Rotate ``vec`` by only the yaw (z-axis) part of ``quat``.

    The x and y components of the quaternion are zeroed and the result is
    renormalized, which leaves a pure rotation about z. It is then applied
    with ``quat_apply``.

    ``w_last`` selects the quaternion layout: (x, y, z, w) when True and
    (w, x, y, z) when False. The previous version always zeroed columns 0-1.
    In the w-first layout those are (w, x), which produced a wrong rotation.
    """
    quat_yaw = quat.clone().view(-1, 4)
    if w_last:
        quat_yaw[:, 0:2] = 0.0  # x, y
    else:
        quat_yaw[:, 1:3] = 0.0  # x, y
    quat_yaw = normalize(quat_yaw)
    return quat_apply(quat_yaw, vec, w_last)
@torch.jit.script
def wrap_to_pi(angles):
    """Wrap angles (radians) into the interval ``(-pi, pi]``.

    A new tensor is returned and ``angles`` is left untouched. The previous
    implementation used in-place ``%=`` / ``-=``, which silently overwrote
    the caller's tensor.
    """
    two_pi = 2 * np.pi
    wrapped = angles % two_pi
    return wrapped - two_pi * (wrapped > np.pi)
@torch.jit.script
def quat_conjugate(a: Tensor, w_last: bool) -> Tensor:
    """Negate the vector part of quaternion(s) ``a``; the output keeps ``a.shape``."""
    flat = a.reshape(-1, 4)
    if w_last:
        conj = torch.cat((-flat[:, :3], flat[:, -1:]), dim=-1)
    else:
        conj = torch.cat((flat[:, 0:1], -flat[:, 1:]), dim=-1)
    return conj.view(a.shape)
# NOTE: a byte-for-byte duplicate of ``quat_apply`` used to be redefined here.
# It was removed because the identical definition earlier in this module is
# the one ``quat_apply_yaw`` is scripted against. The redefinition only
# shadowed it without changing behaviour.
@torch.jit.script
def quat_rotate(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    """Rotate vectors ``v`` (N, 3) by quaternions ``q`` (N, 4).

    The result is ``v*(2w^2 - 1) + 2w*(u x v) + 2u*(u . v)``, where ``u`` is
    the vector part. ``w_last`` selects (x, y, z, w) vs (w, x, y, z).
    """
    n = q.shape[0]
    if w_last:
        real = q[:, -1]
        imag = q[:, :3]
    else:
        real = q[:, 0]
        imag = q[:, 1:]
    dot = torch.bmm(imag.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_v = v * (2.0 * real**2 - 1.0).unsqueeze(-1)
    term_cross = torch.cross(imag, v, dim=-1) * real.unsqueeze(-1) * 2.0
    term_imag = imag * dot * 2.0
    return term_v + term_cross + term_imag
@torch.jit.script
def quat_rotate_inverse(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    """Rotate vectors ``v`` (N, 3) by the inverse of unit quaternions ``q`` (N, 4).

    This is ``quat_rotate`` with the cross-product term negated.
    ``w_last`` selects (x, y, z, w) vs (w, x, y, z).
    """
    n = q.shape[0]
    if w_last:
        real = q[:, -1]
        imag = q[:, :3]
    else:
        real = q[:, 0]
        imag = q[:, 1:]
    dot = torch.bmm(imag.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_v = v * (2.0 * real**2 - 1.0).unsqueeze(-1)
    term_cross = torch.cross(imag, v, dim=-1) * real.unsqueeze(-1) * 2.0
    term_imag = imag * dot * 2.0
    return term_v - term_cross + term_imag
@torch.jit.script
def quat_angle_axis(x: Tensor, w_last: bool) -> Tuple[Tensor, Tensor]:
    """The (angle, axis) representation of the rotation.

    The axis is normalized to unit length (with the norm clamped at 1e-9).
    The angle comes from ``arccos(2w^2 - 1)`` and is therefore in ``[0, pi]``.

    The input is not modified. The previous version divided the axis in
    place (``axis /= ...``). Because ``axis`` is a slice view of ``x``, that
    overwrote the caller's quaternion.
    """
    if w_last:
        w = x[..., -1]
        vec = x[..., :3]
    else:
        w = x[..., 0]
        vec = x[..., 1:]
    cos_angle = 2 * (w**2) - 1
    angle = cos_angle.clamp(-1, 1).arccos()  # clamp guards round-off
    axis = vec / vec.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-9)
    return angle, axis
@torch.jit.script
def quat_from_angle_axis(angle: Tensor, axis: Tensor, w_last: bool) -> Tensor:
    """Unit quaternion for a rotation of ``angle`` radians about ``axis``.

    ``axis`` is normalized first. The output layout is (x, y, z, w) when
    ``w_last`` is True and (w, x, y, z) otherwise.
    """
    half = (angle / 2).unsqueeze(-1)
    imag = normalize(axis) * half.sin()
    real = half.cos()
    if w_last:
        parts = [imag, real]
    else:
        parts = [real, imag]
    return quat_unit(torch.cat(parts, dim=-1))
@torch.jit.script
def vec_to_heading(h_vec):
    """Planar heading angle ``atan2(y, x)`` of vectors ``h_vec`` (last dim >= 2)."""
    x_comp = h_vec[..., 0]
    y_comp = h_vec[..., 1]
    return torch.atan2(y_comp, x_comp)
@torch.jit.script
def heading_to_quat(h_theta, w_last: bool):
    """Yaw-only quaternion(s) rotating by ``h_theta`` radians about +z."""
    z_axis = torch.zeros(h_theta.shape + [3], device=h_theta.device)
    z_axis[..., 2] = 1
    return quat_from_angle_axis(h_theta, z_axis, w_last=w_last)
@torch.jit.script
def quat_axis(q: Tensor, axis: int, w_last: bool) -> Tensor:
    """Rotate the unit basis vector ``e_axis`` (0=x, 1=y, 2=z) by ``q`` (N, 4).

    The basis vector is created with ``q``'s dtype and device. Previously it
    was always float32, so a float64 ``q`` failed inside ``quat_rotate``'s
    ``bmm`` with a dtype mismatch.
    """
    basis_vec = torch.zeros(q.shape[0], 3, device=q.device, dtype=q.dtype)
    basis_vec[:, axis] = 1
    return quat_rotate(q, basis_vec, w_last)
@torch.jit.script
def normalize_angle(x):
    """Map angles (radians) to ``[-pi, pi]`` via ``atan2(sin x, cos x)``."""
    return torch.atan2(x.sin(), x.cos())
@torch.jit.script
def get_basis_vector(q: Tensor, v: Tensor, w_last: bool) -> Tensor:
    """Vector ``v`` expressed after rotation by ``q``; a thin alias of ``quat_rotate``."""
    rotated = quat_rotate(q, v, w_last)
    return rotated
@torch.jit.script
def quat_to_angle_axis(q):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    # Axis-angle form of unit quaternions laid out as (x, y, z, w).
    # The angle is wrapped with normalize_angle. Near-identity rotations,
    # where sin(theta/2) <= 1e-5, map to angle 0 about the +z axis.
    # The input is assumed to be normalized; the original author flagged this
    # routine as possibly fragile.
    sin_eps = 1e-5
    real = q[..., 3]
    imag = q[..., 0:3]
    sin_half = torch.sqrt(1 - real * real)
    raw_angle = normalize_angle(2 * torch.acos(real))
    raw_axis = imag / sin_half.unsqueeze(-1)
    valid = torch.abs(sin_half) > sin_eps
    fallback_axis = torch.zeros_like(raw_axis)
    fallback_axis[..., -1] = 1
    angle = torch.where(valid, raw_angle, torch.zeros_like(raw_angle))
    axis = torch.where(valid.unsqueeze(-1), raw_axis, fallback_axis)
    return angle, axis
@torch.jit.script
def slerp(q0, q1, t):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    ``q1`` is flipped where needed so that the shorter arc is taken.
    ``t`` must broadcast against ``(..., 1)``. Nearly parallel inputs
    (``|sin(half angle)| < 1e-3``) fall back to the mean of the endpoints,
    and exactly aligned inputs return ``q0``.
    """
    cos_half = torch.sum(q0 * q1, dim=-1)
    flip = cos_half < 0
    q1 = q1.clone()
    q1[flip] = -q1[flip]
    cos_half = torch.abs(cos_half).unsqueeze(-1)
    half = torch.acos(cos_half)
    sin_half = torch.sqrt(1.0 - cos_half * cos_half)
    weight0 = torch.sin((1 - t) * half) / sin_half
    weight1 = torch.sin(t * half) / sin_half
    blended = weight0 * q0 + weight1 * q1
    nearly_parallel = torch.abs(sin_half) < 0.001
    blended = torch.where(nearly_parallel, 0.5 * q0 + 0.5 * q1, blended)
    aligned = torch.abs(cos_half) >= 1
    return torch.where(aligned, q0, blended)
@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    # type: (Tensor, Tensor) -> Tensor
    """Exponential-map vector ``angle * axis`` (``angle`` gains a trailing axis)."""
    return axis * angle.unsqueeze(-1)
@torch.jit.script
def my_quat_rotate(q, v):
    """Rotate vectors ``v`` (N, 3) by (x, y, z, w) quaternions ``q`` (N, 4).

    This is the same computation as ``quat_rotate(q, v, w_last=True)``.
    """
    n = q.shape[0]
    real = q[:, -1]
    imag = q[:, :3]
    dot = torch.bmm(imag.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_v = v * (2.0 * real**2 - 1.0).unsqueeze(-1)
    term_cross = torch.cross(imag, v, dim=-1) * real.unsqueeze(-1) * 2.0
    term_imag = imag * dot * 2.0
    return term_v + term_cross + term_imag
@torch.jit.script
def calc_heading(q):
    # type: (Tensor) -> Tensor
    """Heading (yaw) of (x, y, z, w) quaternions ``q`` (N, 4).

    The local +x axis is rotated by ``q`` and the result's angle in the xy
    plane is returned. ``q`` is expected to be normalized.
    """
    x_axis = torch.zeros_like(q[..., 0:3])
    x_axis[..., 0] = 1
    forward = my_quat_rotate(q, x_axis)
    return torch.atan2(forward[..., 1], forward[..., 0])
@torch.jit.script
def quat_to_exp_map(q):
    # type: (Tensor) -> Tensor
    """Exponential map of normalized (x, y, z, w) quaternions ``q``."""
    return angle_axis_to_exp_map(*quat_to_angle_axis(q))
@torch.jit.script
def calc_heading_quat(q, w_last):
    # type: (Tensor, bool) -> Tensor
    """Yaw-only quaternion matching the heading of normalized quaternions ``q``.

    NOTE(review): ``calc_heading`` (via ``my_quat_rotate``) always reads
    ``q`` as (x, y, z, w). ``w_last`` here only chooses the layout of the
    *returned* quaternion. Confirm that callers never pass w-first input
    with ``w_last=False``.
    """
    up = torch.zeros_like(q[..., 0:3])
    up[..., 2] = 1
    yaw = calc_heading(q)
    return quat_from_angle_axis(yaw, up, w_last=w_last)
@torch.jit.script
def calc_heading_quat_inv(q, w_last):
    # type: (Tensor, bool) -> Tensor
    """Inverse of the yaw-only heading rotation of normalized quaternions ``q``.

    The heading is always computed from ``q`` read as (x, y, z, w);
    ``w_last`` selects the output layout.
    """
    up = torch.zeros_like(q[..., 0:3])
    up[..., 2] = 1
    neg_yaw = -calc_heading(q)
    return quat_from_angle_axis(neg_yaw, up, w_last=w_last)
@torch.jit.script
def quat_inverse(x, w_last):
    # type: (Tensor, bool) -> Tensor
    """Inverse rotation, i.e. the conjugate (valid for unit quaternions)."""
    inverse = quat_conjugate(x, w_last=w_last)
    return inverse
@torch.jit.script
def get_euler_xyz(q: Tensor, w_last: bool) -> Tuple[Tensor, Tensor, Tensor]:
    """Roll, pitch and yaw of quaternions ``q`` (N, 4), each wrapped to ``[0, 2*pi)``.

    Pitch is clamped to ``+/- pi/2`` where ``|sin(pitch)| >= 1``.
    ``w_last`` selects (x, y, z, w) vs (w, x, y, z).
    """
    if w_last:
        x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    else:
        w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Roll: rotation about the x axis.
    roll = torch.atan2(
        2.0 * (w * x + y * z),
        w * w - x * x - y * y + z * z,
    )
    # Pitch: rotation about the y axis.
    sin_pitch = 2.0 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sin_pitch) >= 1,
        copysign(np.pi / 2.0, sin_pitch),
        torch.asin(sin_pitch),
    )
    # Yaw: rotation about the z axis.
    yaw = torch.atan2(
        2.0 * (w * z + x * y),
        w * w + x * x - y * y - z * z,
    )
    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn
# Left in eager mode (not TorchScript-compiled).
def get_euler_xyz_in_tensor(q):
    """Roll, pitch and yaw of (x, y, z, w) quaternions ``q`` (N, 4), stacked as (N, 3).

    Unlike ``get_euler_xyz``, the angles are not wrapped (the ``atan2``
    ranges are kept). Pitch is clamped to ``+/- pi/2`` where
    ``|sin(pitch)| >= 1``.
    """
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Roll: rotation about the x axis.
    roll = torch.atan2(
        2.0 * (w * x + y * z),
        w * w - x * x - y * y + z * z,
    )
    # Pitch: rotation about the y axis.
    sin_pitch = 2.0 * (w * y - z * x)
    pitch = torch.where(
        torch.abs(sin_pitch) >= 1,
        copysign(np.pi / 2.0, sin_pitch),
        torch.asin(sin_pitch),
    )
    # Yaw: rotation about the z axis.
    yaw = torch.atan2(
        2.0 * (w * z + x * y),
        w * w + x * x - y * y - z * z,
    )
    return torch.stack((roll, pitch, yaw), dim=-1)
@torch.jit.script
def quat_pos(x):
    """Flip the sign of quaternions whose last component is negative.

    This makes the real part non-negative assuming the (x, y, z, w) layout.
    """
    negative = (x[..., 3:] < 0).float()
    return x * (1 - 2 * negative)
@torch.jit.script
def is_valid_quat(q):
    """True if every quaternion in ``q`` has (approximately) unit squared norm."""
    w = q[..., 3]
    norm_sq = w * w + q[..., 0] * q[..., 0] + q[..., 1] * q[..., 1] + q[..., 2] * q[..., 2]
    return norm_sq.allclose(torch.ones_like(w))
@torch.jit.script
def quat_normalize(q):
    """Canonical unit quaternion: make the last (w) component non-negative, then rescale.

    ``q`` does not need to be normalized beforehand; the (x, y, z, w)
    layout is assumed.
    """
    non_negative_w = quat_pos(q)
    return quat_unit(non_negative_w)
@torch.jit.script
def quat_mul(a, b, w_last: bool):
    """Hamilton product ``a * b`` of equally shaped quaternion tensors.

    The computation uses a factored form of the product with fewer
    multiplications. ``w_last`` selects the (x, y, z, w) layout for both
    inputs and the output; otherwise (w, x, y, z) is used.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    lhs = a.reshape(-1, 4)
    rhs = b.reshape(-1, 4)
    if w_last:
        ix, iy, iz, iw = 0, 1, 2, 3
    else:
        ix, iy, iz, iw = 1, 2, 3, 0
    x1, y1, z1, w1 = lhs[:, ix], lhs[:, iy], lhs[:, iz], lhs[:, iw]
    x2, y2, z2, w2 = rhs[:, ix], rhs[:, iy], rhs[:, iz], rhs[:, iw]
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)
    if w_last:
        product = torch.stack([x, y, z, w], dim=-1)
    else:
        product = torch.stack([w, x, y, z], dim=-1)
    return product.view(out_shape)
@torch.jit.script
def quat_mul_norm(x, y, w_last):
    # type: (Tensor, Tensor, bool) -> Tensor
    """Multiply two sets of quaternions and rescale the product to unit length.

    ``x`` and ``y`` must have identical shapes (``quat_mul`` asserts this);
    they are not broadcast. ``w_last`` selects the (x, y, z, w) layout.

    This module previously defined ``quat_mul_norm`` twice. The first
    definition (via ``quat_normalize``, which assumes w-last) was immediately
    shadowed by this one and never reachable, so it has been dropped.
    """
    return quat_unit(quat_mul(x, y, w_last))
@torch.jit.script
def quat_identity(shape: List[int]):
    """Identity rotation (0, 0, 0, 1) in (x, y, z, w) layout, shaped ``shape + [4]``.

    The result uses the default dtype and device.
    """
    imag = torch.zeros(shape + [3])
    real = torch.ones(shape + [1])
    return quat_normalize(torch.cat([imag, real], dim=-1))
@torch.jit.script
def quat_identity_like(x):
    """Identity rotations (x, y, z, w) with shape ``x.shape[:-1] + [4]``.

    As with ``torch.zeros_like``, the result matches ``x``'s device and
    dtype. Previously it was always created on the CPU with the default
    dtype, which breaks arithmetic against CUDA or float64 inputs.
    """
    return quat_identity(x.shape[:-1]).to(device=x.device, dtype=x.dtype)
@torch.jit.script
def transform_from_rotation_translation(
    r: Optional[torch.Tensor] = None, t: Optional[torch.Tensor] = None
):
    """Pack a rotation (..., 4) and translation (..., 3) into a transform (..., 7).

    Exactly one of ``r`` / ``t`` may be None:
        - A missing rotation becomes the identity quaternion.
        - A missing translation becomes zeros.
    Either default takes the other input's leading shape, device and dtype.
    Previously the defaults were built from the full shape, e.g.
    ``t.shape + [4]``, so the final ``cat`` always failed.
    """
    assert r is not None or t is not None, (
        "rotation and translation can't be all None"
    )
    if r is None:
        assert t is not None
        r = quat_identity(list(t.shape[:-1])).to(device=t.device, dtype=t.dtype)
    if t is None:
        t = torch.zeros(
            list(r.shape[:-1]) + [3], device=r.device, dtype=r.dtype
        )
    return torch.cat([r, t], dim=-1)
@torch.jit.script
def transform_rotation(x):
    """Rotation quaternion: the first four entries of a (..., 7) transform."""
    rotation = x[..., 0:4]
    return rotation
@torch.jit.script
def transform_translation(x):
    """Translation: everything after the first four entries of a transform."""
    translation = x[..., 4:]
    return translation
@torch.jit.script
def transform_mul(x, y):
    """Compose transforms ``x * y`` (apply ``y`` first, then ``x``).

    Rotations are (x, y, z, w) quaternions. The composed rotation is
    renormalized, and the composed translation is ``R_x t_y + t_x``.
    """
    rot_x = transform_rotation(x)
    rot_y = transform_rotation(y)
    trans_x = transform_translation(x)
    trans_y = transform_translation(y)
    composed_rot = quat_mul_norm(rot_x, rot_y, w_last=True)
    composed_trans = quat_rotate(rot_x, trans_y, w_last=True) + trans_x
    return transform_from_rotation_translation(r=composed_rot, t=composed_trans)
@torch.compile
def quaternion_to_matrix(
    quaternions: torch.Tensor,
    w_last: bool = True,
) -> torch.Tensor:
    """Convert quaternions to 3x3 rotation matrices.

    Args:
        quaternions: Tensor of shape (..., 4). The layout is (x, y, z, w)
            when ``w_last`` is True (default) and (w, x, y, z) otherwise.
            The quaternions need not be unit length; the formula divides by
            the squared norm.
        w_last: Quaternion layout flag, as described above.

    Returns:
        Rotation matrices of shape (..., 3, 3).
    """
    if w_last:
        x, y, z, w = torch.unbind(quaternions, -1)
    else:
        w, x, y, z = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)
    entries = (
        # row 0
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        # row 1
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        # row 2
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
@torch.jit.script
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle vectors (..., 3) to quaternions (..., 4) in (w, x, y, z) order.

    The vector's norm is the counter-clockwise rotation angle in radians and
    its direction is the axis. For angles below 1e-6, ``sin(a/2)/a`` is
    replaced by its Taylor expansion so that zero vectors map to the identity.
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    tiny = angles.abs() < 1e-6
    regular = ~tiny
    # Masked writes (rather than torch.where) keep gradients free of the
    # 0/0 NaN that would come from the unused branch.
    scale = torch.empty_like(angles)
    scale[regular] = torch.sin(half_angles[regular]) / angles[regular]
    # sin(a/2)/a ~= 1/2 - a^2/48 for small a.
    scale[tiny] = 0.5 - (angles[tiny] * angles[tiny]) / 48
    return torch.cat([torch.cos(half_angles), axis_angle * scale], dim=-1)
# Left in eager mode (not TorchScript-compiled).
def wxyz_to_xyzw(quat):
    """Reorder quaternions from (w, x, y, z) to (x, y, z, w) along the last axis."""
    order = [1, 2, 3, 0]
    return quat[..., order]
# Left in eager mode (not TorchScript-compiled).
def xyzw_to_wxyz(quat):
    """Reorder quaternions from (x, y, z, w) to (w, x, y, z) along the last axis."""
    order = [3, 0, 1, 2]
    return quat[..., order]
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to quaternions with the real part first.

    The output layout is (w, x, y, z). For each matrix, four candidate
    quaternions are built, each divided by a different component magnitude.
    The candidate with the largest divisor (the best-conditioned one) is
    returned.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Quaternions (w, x, y, z) as tensor of shape (..., 4). The chosen
        candidate's pivot component is non-negative, but ``w`` itself is not
        forced to be non-negative.

    Raises:
        ValueError: If the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # For an orthonormal rotation matrix these are 2|w|, 2|x|, 2|y|, 2|z|.
    # _sqrt_positive_part maps negative round-off values to 0 instead of NaN.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # we produce the desired quaternion multiplied by each of r, i, j, k
    # (row k is, for a rotation matrix, 4 * component_k * quaternion)
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
            ),
            torch.stack(
                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
            ),
            torch.stack(
                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
            ),
            torch.stack(
                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
            ),
        ],
        dim=-2,
    )
    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator).
    # The one-hot mask selects exactly one candidate row per batch element.
    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :, # pyre-ignore[16]
    ].reshape(batch_dim + (4,))
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def quat_w_first(rot):
    """Move the last component to the front, e.g. (x, y, z, w) -> (w, x, y, z)."""
    last = rot[..., [-1]]
    rest = rot[..., :-1]
    return torch.cat([last, rest], -1)
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    """(x, y, z, w) quaternion for the given roll, pitch and yaw (radians).

    The three angle tensors are combined element-wise, and the result stacks
    the components along a new last axis.
    """
    c_roll, s_roll = torch.cos(roll * 0.5), torch.sin(roll * 0.5)
    c_pitch, s_pitch = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5)
    c_yaw, s_yaw = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5)
    qw = c_yaw * c_roll * c_pitch + s_yaw * s_roll * s_pitch
    qx = c_yaw * s_roll * c_pitch - s_yaw * c_roll * s_pitch
    qy = c_yaw * c_roll * s_pitch + s_yaw * s_roll * c_pitch
    qz = s_yaw * c_roll * c_pitch - c_yaw * s_roll * s_pitch
    return torch.stack([qx, qy, qz, qw], dim=-1)
@torch.compile
def remove_yaw_component(
    quat_raw: Tensor,
    quat_init: Tensor,
    w_last: bool = True,
) -> Tensor:
    """Remove the heading (yaw) of ``quat_init`` from ``quat_raw``.

    The heading of ``quat_init`` is the xy-plane angle of its rotated +x
    axis. The function builds the inverse pure-z rotation for that heading
    and left-multiplies it onto ``quat_raw``. The result is the orientation
    expressed in a frame aligned with the initial heading. Any yaw change
    since ``quat_init`` is kept; only the initial offset is removed.

    Args:
        quat_raw: Current quaternion (e.g. from an IMU), shape (..., 4).
        quat_init: Quaternion whose heading should be removed, shape (..., 4).
            It is also used as the template for the inverse-heading tensor.
        w_last: If True, quaternion format is (x, y, z, w).
                If False, quaternion format is (w, x, y, z). Default: True.

    Returns:
        ``heading_inv(quat_init) * quat_raw``, renormalized (norm clamped at
        1e-8), with the same shape as ``quat_raw``.

    Note:
        The heading is undefined when ``quat_init``'s x axis points
        (nearly) straight up or down, because ``atan2`` then receives
        near-zero inputs.

    Example:
        >>> # Initial robot orientation (roll=0°, pitch=0°, yaw=45°)
        >>> quat_init = quat_from_euler_xyz(
        ...     torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.7854)
        ... )
        >>> # Current IMU reading (roll=10°, pitch=20°, yaw=60°)
        >>> quat_raw = quat_from_euler_xyz(
        ...     torch.tensor(0.1745),
        ...     torch.tensor(0.3491),
        ...     torch.tensor(1.0472),
        ... )
        >>> quat_norm = remove_yaw_component(quat_raw, quat_init)
        >>> # quat_norm: roll=10°, pitch=20°, and the remaining 15° of yaw
        >>> # (60° - 45°); only the initial 45° heading was removed.
    """
    # Extract quaternion components based on format
    if w_last:
        q_w = quat_init[..., -1]
        q_vec = quat_init[..., :3]
    else:
        q_w = quat_init[..., 0]
        q_vec = quat_init[..., 1:]
    # Calculate heading by rotating x-axis with quaternion
    # ref_dir = [1, 0, 0] (x-axis)
    ref_dir = torch.zeros_like(q_vec)
    ref_dir[..., 0] = 1.0
    # Quaternion rotation: v' = v + 2 * w * (q_vec × v) + 2 * q_vec × (q_vec × v)
    cross1 = torch.cross(q_vec, ref_dir, dim=-1)
    cross2 = torch.cross(q_vec, cross1, dim=-1)
    rot_dir = ref_dir + 2.0 * q_w.unsqueeze(-1) * cross1 + 2.0 * cross2
    # Extract heading angle from rotated x-axis
    heading = torch.atan2(rot_dir[..., 1], rot_dir[..., 0])
    # Create inverse heading quaternion (rotation about negative z-axis)
    half_heading = (-heading) * 0.5
    # zeros_like already zeroes x/y; the explicit 0.0 writes just spell out
    # the layout.
    heading_q_inv = torch.zeros_like(quat_init)
    if w_last:
        heading_q_inv[..., 0] = 0.0  # x
        heading_q_inv[..., 1] = 0.0  # y
        heading_q_inv[..., 2] = torch.sin(half_heading)  # z
        heading_q_inv[..., 3] = torch.cos(half_heading)  # w
    else:
        heading_q_inv[..., 0] = torch.cos(half_heading)  # w
        heading_q_inv[..., 1] = 0.0  # x
        heading_q_inv[..., 2] = 0.0  # y
        heading_q_inv[..., 3] = torch.sin(half_heading)  # z
    # Quaternion multiplication: heading_q_inv * quat_raw
    # (inlined copy of the factored Hamilton product used by quat_mul;
    # requires heading_q_inv and quat_raw to flatten to the same row count)
    shape = quat_raw.shape
    a = heading_q_inv.reshape(-1, 4)
    b = quat_raw.reshape(-1, 4)
    if w_last:
        x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
        x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    else:
        w1, x1, y1, z1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3]
        w2, x2, y2, z2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3]
    # Quaternion multiplication formula
    ww = (z1 + x1) * (x2 + y2)
    yy = (w1 - y1) * (w2 + z2)
    zz = (w1 + y1) * (w2 - z2)
    xx = ww + yy + zz
    qq = 0.5 * (xx + (z1 - x1) * (x2 - y2))
    w = qq - ww + (z1 - y1) * (y2 - z2)
    x = qq - xx + (x1 + w1) * (x2 + w2)
    y = qq - yy + (w1 - x1) * (y2 + z2)
    z = qq - zz + (z1 + y1) * (w2 - x2)
    if w_last:
        quat_result = torch.stack([x, y, z, w], dim=-1).view(shape)
    else:
        quat_result = torch.stack([w, x, y, z], dim=-1).view(shape)
    # Normalize the result quaternion
    norm = torch.norm(quat_result, p=2, dim=-1, keepdim=True)
    quat_norm = quat_result / norm.clamp(min=1e-8)
    return quat_norm
================================================
FILE: holomotion/src/utils/isaac_utils/setup.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# This file was originally copied from the [ASAP] repository:
# https://github.com/LeCAR-Lab/ASAP
# Modifications have been made to fit the needs of this project.
from setuptools import setup
# Packaging metadata for the ``isaac_utils`` helper package (keyword order
# is irrelevant to setuptools).
setup(
    name="isaac_utils",
    version="0.0.1",
    description="Unified torch env_utils for IsaacGym and IsaacSim",
    author="",
    packages=["isaac_utils"],
    classifiers=[],
)
================================================
FILE: holomotion/src/utils/onnx_export.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
import inspect
import re
from pathlib import Path
from loguru import logger
def _list_to_csv_str(arr, *, decimals: int = 3, delimiter: str = ",") -> str:
fmt = f"{{:.{decimals}f}}"
return delimiter.join(
fmt.format(x) if isinstance(x, (int, float)) else str(x) for x in arr
)
def attach_onnx_metadata_holomotion(env, onnx_path: str) -> None:
    """Write robot/action metadata into the ``metadata_props`` of an ONNX file.

    The joint names, default stiffness/damping/position of the first env
    instance and the ``dof_pos`` action scale are stored as string entries.
    List values are encoded with ``_list_to_csv_str``; anything else uses
    ``str()``. The model at ``onnx_path`` is loaded, updated and saved in
    place.
    """
    import onnx

    robot_data = env.scene["robot"].data
    metadata = {
        "joint_names": robot_data.joint_names,
        "joint_stiffness": robot_data.default_joint_stiffness[0].cpu().tolist(),
        "joint_damping": robot_data.default_joint_damping[0].cpu().tolist(),
        "default_joint_pos": robot_data.default_joint_pos[0].cpu().tolist(),
        "action_scale": env.action_manager.get_term("dof_pos")
        ._scale[0]
        .cpu()
        .tolist(),
    }

    model = onnx.load(onnx_path)
    for key, value in metadata.items():
        entry = onnx.StringStringEntryProto()
        entry.key = key
        if isinstance(value, list):
            entry.value = _list_to_csv_str(value)
        else:
            entry.value = str(value)
        model.metadata_props.append(entry)
    onnx.save(model, onnx_path)
def export_policy_to_onnx(
    algo,
    checkpoint_path: str,
    *,
    onnx_name_suffix: str | None = None,
    use_kv_cache: bool = True,
) -> str:
    """Export the (unwrapped) actor of ``algo`` to ONNX and attach metadata.

    The target path handed to ``actor.export_onnx`` is
    ``<checkpoint dir>/exported/<stem>[_<suffix>].onnx``, where ``<stem>`` is
    the checkpoint file name without its extension
    (``model_100.pt`` -> ``model_100.onnx``).

    Args:
        algo: Algorithm object whose ``actor``, ``critic``, ``accelerator``
            and ``env._env`` attributes are used.
        checkpoint_path: Checkpoint the exported file is named after.
        onnx_name_suffix: Optional tag appended to the file name; whitespace,
            ``+`` and path separators are replaced by ``_``.
        use_kv_cache: Forwarded to ``actor.export_onnx`` only if that method
            declares a ``use_kv_cache`` parameter.

    Returns:
        The path string returned by ``actor.export_onnx``.

    The train/eval mode of actor and critic is restored afterwards, even when
    the export raises.
    """
    checkpoint = Path(checkpoint_path)
    export_dir = checkpoint.parent / "exported"
    export_dir.mkdir(exist_ok=True)

    # Build the name from the stem instead of str.replace(".pt", ".onnx"):
    # that turned "model.pth" into "model.onnxh", rewrote ".pt" occurring
    # mid-name, and left e.g. "model.ckpt" without an .onnx extension.
    onnx_stem = checkpoint.stem
    if onnx_name_suffix is not None:
        # Path separators are sanitised as well so the suffix cannot send the
        # output into a (non-existent) sub-directory of export_dir.
        suffix = re.sub(r"[\s+/\\]", "_", str(onnx_name_suffix))
        onnx_stem = f"{onnx_stem}_{suffix}"
    onnx_path = export_dir / f"{onnx_stem}.onnx"

    logger.info("Starting ONNX minimal policy export (actions-only)...")

    # Remember the modes to restore later; None means no ``training`` attr.
    actor_was_training = getattr(algo.actor, "training", None)
    critic_was_training = getattr(algo.critic, "training", None)
    algo.actor.eval()
    algo.critic.eval()
    try:
        actor_for_export = algo.accelerator.unwrap_model(algo.actor)
        # ``_orig_mod`` is presumably set by a torch.compile wrapper; export
        # the underlying module when present.
        orig_mod = getattr(actor_for_export, "_orig_mod", None)
        if orig_mod is not None:
            actor_for_export = orig_mod

        export_signature = inspect.signature(actor_for_export.export_onnx)
        export_kwargs = {"onnx_path": onnx_path, "opset_version": 17}
        if "use_kv_cache" in export_signature.parameters:
            export_kwargs["use_kv_cache"] = bool(use_kv_cache)
        onnx_path_str = actor_for_export.export_onnx(**export_kwargs)

        attach_onnx_metadata_holomotion(algo.env._env, onnx_path=onnx_path_str)
        logger.info(
            f"Successfully exported minimal policy to: {onnx_path_str}"
        )
        return onnx_path_str
    finally:
        if actor_was_training is not None:
            algo.actor.train(actor_was_training)
        if critic_was_training is not None:
            algo.critic.train(critic_was_training)
================================================
FILE: holomotion/src/utils/reference_prefix.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Mapping
def resolve_reference_tensor_key(
    batch_tensors: Mapping[str, object],
    base_key: str,
    prefix: str = "ref_",
) -> str:
    """Choose which key of ``batch_tensors`` holds the tensor for ``base_key``.

    With a non-empty ``prefix``, ``prefix + base_key`` wins when present.
    Otherwise ``base_key`` is used as a fallback, except for the
    ``"ft_ref_"`` prefix, which never falls back. With an empty prefix only
    ``base_key`` itself is considered.

    Raises:
        KeyError: When no acceptable key is present.
    """
    if not prefix:
        if base_key not in batch_tensors:
            raise KeyError(
                f"Tensor '{base_key}' is not present in the current motion cache batch."
            )
        return base_key

    prefixed_key = f"{prefix}{base_key}"
    if prefixed_key in batch_tensors:
        return prefixed_key
    if prefix == "ft_ref_":
        raise KeyError(
            f"Filtered tensor '{prefixed_key}' is not present in the "
            "current motion cache batch. Ensure online filtering is "
            "enabled and 'ft_ref_' is materialized in allowed_prefixes."
        )
    if base_key not in batch_tensors:
        raise KeyError(
            f"Neither '{prefixed_key}' nor '{base_key}' is present in "
            "the current motion cache batch."
        )
    return base_key
================================================
FILE: holomotion/src/utils/torch_utils.py
================================================
"""Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""
import numpy as np
import torch
import torch.nn.functional as F
def to_torch(x, dtype=torch.float, device="cpu", requires_grad=False):
    """Create a new tensor from ``x`` (array-like or scalar).

    Thin wrapper around :func:`torch.tensor` with float32/CPU defaults.
    """
    tensor_kwargs = {
        "dtype": dtype,
        "device": device,
        "requires_grad": requires_grad,
    }
    return torch.tensor(x, **tensor_kwargs)
@torch.jit.script
def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Hamilton product ``q1 * q2`` of quaternions in (w, x, y, z) layout.

    Args:
        q1: The first quaternion in (w, x, y, z). Shape is (..., 4).
        q2: The second quaternion in (w, x, y, z). Shape is (..., 4).

    Returns:
        The product of the two quaternions in (w, x, y, z). Shape is (..., 4).

    Raises:
        ValueError: Input shapes of ``q1`` and ``q2`` are not matching.
    """
    if q1.shape != q2.shape:
        msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}."
        raise ValueError(msg)
    out_shape = q1.shape
    aw, ax, ay, az = torch.unbind(q1.reshape(-1, 4), dim=-1)
    bw, bx, by, bz = torch.unbind(q2.reshape(-1, 4), dim=-1)

    # Factored form of the product, sharing intermediate terms.
    t_ww = (az + ax) * (bx + by)
    t_yy = (aw - ay) * (bw + bz)
    t_zz = (aw + ay) * (bw - bz)
    t_xx = t_ww + t_yy + t_zz
    t_qq = 0.5 * (t_xx + (az - ax) * (bx - by))

    res_w = t_qq - t_ww + (az - ay) * (by - bz)
    res_x = t_qq - t_xx + (ax + aw) * (bx + bw)
    res_y = t_qq - t_yy + (aw - ax) * (by + bz)
    res_z = t_qq - t_zz + (az + ay) * (bw - bx)
    return torch.stack([res_w, res_x, res_y, res_z], dim=-1).view(out_shape)
@torch.jit.script
def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is floored at ``eps``, so all-zero vectors stay zero instead of
    becoming NaN.
    """
    length = x.norm(p=2, dim=-1, keepdim=True)
    return x / length.clamp(min=eps)
@torch.jit.script
def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate ``vec`` by ``quat``.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z), with the same shape as ``vec``.
    """
    out_shape = vec.shape
    q = quat.reshape(-1, 4)
    v = vec.reshape(-1, 3)
    q_w = q[:, 0:1]
    q_xyz = q[:, 1:]
    # v' = v + w * t + xyz x t, with t = 2 * (xyz x v)
    t = torch.cross(q_xyz, v, dim=-1) * 2
    rotated = v + q_w * t + torch.cross(q_xyz, t, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
    """Rotate ``vec`` by the inverse of the unit quaternion ``quat``.

    Args:
        quat: The quaternion in (w, x, y, z). Shape is (..., 4).
        vec: The vector in (x, y, z). Shape is (..., 3).

    Returns:
        The rotated vector in (x, y, z), with the same shape as ``vec``.
    """
    out_shape = vec.shape
    q = quat.reshape(-1, 4)
    v = vec.reshape(-1, 3)
    q_w = q[:, 0:1]
    q_xyz = q[:, 1:]
    # Same formula as quat_apply with the sign of the w-term flipped.
    t = torch.cross(q_xyz, v, dim=-1) * 2
    rotated = v - q_w * t + torch.cross(q_xyz, t, dim=-1)
    return rotated.view(out_shape)
@torch.jit.script
def quat_rotate(q, v):
    """Rotate vectors ``v`` (N, 3) by quaternions ``q`` (N, 4) in (x, y, z, w).

    Both inputs must be 2D with matching first dimension, because the dot
    product is computed with a batched matrix multiply.
    """
    n = q.shape[0]
    w = q[:, -1]
    u = q[:, :3]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(u, v, dim=-1) * w.unsqueeze(-1) * 2.0
    u_dot_v = torch.bmm(u.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = u * u_dot_v * 2.0
    return term_a + term_b + term_c
# @torch.jit.script
def quat_rotate_inverse(q, v):
    """Rotate vectors ``v`` (N, 3) by the inverse of unit quaternions ``q``.

    ``q`` is (N, 4) in (x, y, z, w). This is ``quat_rotate`` with the sign of
    the cross-product term flipped. The function is not TorchScript-compiled
    (the decorator above is commented out).
    """
    n = q.shape[0]
    w = q[:, -1]
    u = q[:, :3]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(u, v, dim=-1) * w.unsqueeze(-1) * 2.0
    u_dot_v = torch.bmm(u.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = u * u_dot_v * 2.0
    return term_a - term_b + term_c
@torch.jit.script
def quat_conjugate(a):
    """Conjugate of (w, x, y, z) quaternions: keep w, negate x, y, z.

    Accepts any shape ``(..., 4)`` and returns the same shape.
    """
    flat = a.reshape(-1, 4)
    real = flat[:, 0:1]
    imag = flat[:, 1:]
    # (x, y, z, w) variant would be: torch.cat((-flat[:, :3], flat[:, -1:]), dim=-1)
    return torch.cat((real, -imag), dim=-1).view(a.shape)
@torch.jit.script
def quat_unit(a):
    """Return ``a`` scaled to unit length along its last dimension."""
    unit_quat = normalize(a)
    return unit_quat
@torch.jit.script
def quat_from_angle_axis(angle, axis):
    """Unit quaternion in (x, y, z, w) for a rotation of ``angle`` about ``axis``.

    ``axis`` is normalized first. ``angle`` has the shape of ``axis`` minus
    its last dimension.
    """
    half = (angle / 2).unsqueeze(-1)
    imag = normalize(axis) * torch.sin(half)
    real = torch.cos(half)
    return quat_unit(torch.cat([imag, real], dim=-1))
@torch.jit.script
def normalize_angle(x):
    """Wrap angles in radians into [-pi, pi] via ``atan2(sin x, cos x)``."""
    sin_x = torch.sin(x)
    cos_x = torch.cos(x)
    return torch.atan2(sin_x, cos_x)
@torch.jit.script
def tf_inverse(q, t):
    """Invert the rigid transform (q, t); ``q`` is (w, x, y, z) as used by quat_apply.

    Returns ``(q^-1, -(q^-1 applied to t))``. The conjugate is used as the
    inverse, so ``q`` is assumed to be unit length.
    """
    inv_rot = quat_conjugate(q)
    inv_trans = -quat_apply(inv_rot, t)
    return inv_rot, inv_trans
@torch.jit.script
def tf_apply(q, t, v):
    """Apply the rigid transform (q, t) to points ``v``: rotate by ``q``, then add ``t``."""
    rotated = quat_apply(q, v)
    return rotated + t
@torch.jit.script
def tf_vector(q, v):
    """Transform direction vectors ``v``: rotation by ``q`` only, no translation."""
    rotated = quat_apply(q, v)
    return rotated
@torch.jit.script
def tf_combine(q1, t1, q2, t2):
    """Compose rigid transforms: (q1, t1) applied after (q2, t2).

    Returns ``(q1 * q2, q1 applied to t2 + t1)``.
    """
    rot = quat_mul(q1, q2)
    trans = quat_apply(q1, t2) + t1
    return rot, trans
@torch.jit.script
def get_basis_vector(q, v):
    """Rotate basis vector(s) ``v`` by ``q``.

    Delegates to ``quat_rotate``, so ``q`` is (N, 4) in (x, y, z, w).
    """
    basis = quat_rotate(q, v)
    return basis
def get_axis_params(value, axis_idx, x_value=0.0, dtype=np.float64, n_dims=3):
    """Construct arguments to `Vec` according to axis index.

    Builds a length-``n_dims`` vector holding ``value`` at ``axis_idx`` and 0
    elsewhere. Element 0 is then always overwritten with ``x_value``, even
    when ``axis_idx == 0``. Returns the elements cast to ``dtype`` as a list.
    """
    assert axis_idx < n_dims, (
        "the axis dim should be within the vector dimensions"
    )
    mask = np.zeros((n_dims,))
    mask[axis_idx] = 1.0
    params = np.where(mask == 1.0, value, mask)
    params[0] = x_value
    return list(params.astype(dtype))
@torch.jit.script
def copysign(a, b):
    """Return ``|a|`` carrying the sign of each element of 1D tensor ``b``.

    ``a`` is a scalar, presumably a 0-dim tensor that TorchScript converts
    implicitly. It is repeated ``b.shape[0]`` times. Because
    ``torch.sign(0) == 0``, zero entries of ``b`` yield 0, unlike
    ``math.copysign``.
    """
    magnitude = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
    return torch.abs(magnitude) * torch.sign(b)
@torch.jit.script
def get_euler_xyz(q: torch.Tensor) -> tuple:
    """Roll, pitch, yaw in radians of (x, y, z, w) quaternions, wrapped to [0, 2*pi).

    Args:
        q: (N, 4) quaternions with the scalar part last.

    Returns:
        Tuple ``(roll, pitch, yaw)`` of (N,) tensors.
    """
    x = q[:, 0]
    y = q[:, 1]
    z = q[:, 2]
    w = q[:, 3]

    # roll (x-axis rotation)
    roll = torch.atan2(
        2.0 * (w * x + y * z),
        w * w - x * x - y * y + z * z,
    )

    # pitch (y-axis rotation); asin is undefined outside [-1, 1], so saturate
    # to +-pi/2 there.
    sin_pitch = 2.0 * (w * y - z * x)
    saturated = copysign(torch.tensor(np.pi / 2.0, device=sin_pitch.device), sin_pitch)
    pitch = torch.where(torch.abs(sin_pitch) >= 1, saturated, torch.asin(sin_pitch))

    # yaw (z-axis rotation)
    yaw = torch.atan2(
        2.0 * (w * z + x * y),
        w * w + x * x - y * y - z * z,
    )

    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn
@torch.jit.script
def quat_from_euler_xyz(roll, pitch, yaw):
    """Quaternion in (x, y, z, w) from roll/pitch/yaw angles in radians.

    All three inputs share one shape; the output appends a last dim of 4.
    """
    cr, sr = torch.cos(roll * 0.5), torch.sin(roll * 0.5)
    cp, sp = torch.cos(pitch * 0.5), torch.sin(pitch * 0.5)
    cy, sy = torch.cos(yaw * 0.5), torch.sin(yaw * 0.5)

    w = cy * cr * cp + sy * sr * sp
    x = cy * sr * cp - sy * cr * sp
    y = cy * cr * sp + sy * sr * cp
    z = sy * cr * cp - cy * sr * sp
    return torch.stack([x, y, z, w], dim=-1)
def torch_rand_float(lower, upper, shape, device):
    """Uniform samples in ``[lower, upper)`` with the given ``shape`` and ``device``."""
    span = upper - lower
    return span * torch.rand(*shape, device=device) + lower
# @torch.jit.script
@torch.compile
def torch_random_dir_2(shape, device):
    """Random 2D unit vectors ``(cos t, sin t)`` with t uniform in [-pi, pi).

    ``shape`` is presumably expected to end in 1 (e.g. ``(N, 1)``); that last
    dim is squeezed before the (cos, sin) pair is stacked on.
    """
    theta = torch_rand_float(-np.pi, np.pi, shape, device).squeeze(-1)
    return torch.stack([theta.cos(), theta.sin()], dim=-1)
@torch.jit.script
def tensor_clamp(t, min_t, max_t):
    """Elementwise clamp of ``t`` into ``[min_t, max_t]`` with tensor bounds.

    The upper bound is applied first, then the lower bound.
    """
    upper_bounded = torch.minimum(t, max_t)
    return torch.maximum(upper_bounded, min_t)
@torch.jit.script
def scale(x, lower, upper):
    """Map ``x`` from [-1, 1] linearly onto [lower, upper]."""
    span = upper - lower
    return 0.5 * (x + 1.0) * span + lower
@torch.jit.script
def unscale(x, lower, upper):
    """Map ``x`` from [lower, upper] linearly onto [-1, 1] (inverse of ``scale``)."""
    numerator = 2.0 * x - upper - lower
    denominator = upper - lower
    return numerator / denominator
def unscale_np(x, lower, upper):
    """NumPy/scalar version of ``unscale``: map [lower, upper] onto [-1, 1]."""
    numerator = 2.0 * x - upper - lower
    denominator = upper - lower
    return numerator / denominator
@torch.jit.script
def quat_to_angle_axis(q):
    """Angle/axis form of normalized (x, y, z, w) quaternions.

    Returns ``(angle, axis)``:
    - ``angle`` is wrapped to [-pi, pi].
    - ``axis`` is a unit vector.

    When ``|sin(angle / 2)| <= 1e-5`` (no meaningful axis), angle is 0 and
    axis is (0, 0, 1).
    """
    min_theta = 1e-5
    w = q[..., 3]
    sin_theta = torch.sqrt(1 - w * w)
    angle = normalize_angle(2 * torch.acos(w))
    axis = q[..., 0:3] / sin_theta.unsqueeze(-1)

    valid = torch.abs(sin_theta) > min_theta
    fallback_axis = torch.zeros_like(axis)
    fallback_axis[..., -1] = 1
    angle = torch.where(valid, angle, torch.zeros_like(angle))
    axis = torch.where(valid.unsqueeze(-1), axis, fallback_axis)
    return angle, axis
@torch.jit.script
def angle_axis_to_exp_map(angle, axis):
    """Exponential-map vector ``axis * angle`` (angle broadcast over the last dim)."""
    return axis * angle.unsqueeze(-1)
@torch.jit.script
def quat_to_exp_map(q):
    """Exponential-map (axis * angle) vector of normalized (x, y, z, w) quaternions."""
    angle, axis = quat_to_angle_axis(q)
    return angle_axis_to_exp_map(angle, axis)
@torch.jit.script
def slerp(q0, q1, t):
    """Spherical linear interpolation between quaternions ``q0`` and ``q1``.

    ``q1`` is sign-flipped where ``q0 . q1 < 0`` so the shorter arc is taken.
    ``t`` must broadcast against a trailing singleton dim (presumably
    ``(..., 1)``). Behaviour in degenerate cases:
    - If ``|sin(half angle)| < 1e-3``, the plain average of q0 and q1 is
      returned.
    - If ``|cos(half angle)| >= 1``, ``q0`` is returned.
    """
    dot = torch.sum(q0 * q1, dim=-1)
    flip = dot < 0
    q1 = q1.clone()
    q1[flip] = -q1[flip]

    cos_half = torch.abs(dot).unsqueeze(-1)
    half = torch.acos(cos_half)
    sin_half = torch.sqrt(1.0 - cos_half * cos_half)

    w0 = torch.sin((1 - t) * half) / sin_half
    w1 = torch.sin(t * half) / sin_half
    blended = w0 * q0 + w1 * q1
    blended = torch.where(torch.abs(sin_half) < 0.001, 0.5 * q0 + 0.5 * q1, blended)
    return torch.where(torch.abs(cos_half) >= 1, q0, blended)
@torch.jit.script
def my_quat_rotate(q, v):
    """Rotate vectors ``v`` (N, 3) by quaternions ``q`` (N, 4) in (x, y, z, w).

    Same computation as ``quat_rotate``; both inputs must be 2D.
    """
    n = q.shape[0]
    w = q[:, -1]
    u = q[:, :3]
    term_a = v * (2.0 * w**2 - 1.0).unsqueeze(-1)
    term_b = torch.cross(u, v, dim=-1) * w.unsqueeze(-1) * 2.0
    u_dot_v = torch.bmm(u.view(n, 1, 3), v.view(n, 3, 1)).squeeze(-1)
    term_c = u * u_dot_v * 2.0
    return term_a + term_b + term_c
@torch.jit.script
def calc_heading(q):
    """Heading (yaw) angle of normalized (x, y, z, w) quaternions ``q`` (N, 4).

    The body x-axis is rotated by ``q``; the result is the angle of its
    projection on the xy plane, ``atan2(y, x)``.
    """
    x_axis = torch.zeros_like(q[..., 0:3])
    x_axis[..., 0] = 1
    rotated = my_quat_rotate(q, x_axis)
    return torch.atan2(rotated[..., 1], rotated[..., 0])
@torch.jit.script
def calc_heading_quat(q):
    """Pure-yaw (x, y, z, w) quaternion matching the heading of normalized ``q``."""
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(calc_heading(q), z_axis)
@torch.jit.script
def calc_heading_quat_inv(q):
    """Inverse of ``calc_heading_quat``: a rotation by minus the heading about z.

    ``q`` is a normalized (x, y, z, w) quaternion; the result uses the same
    layout.
    """
    z_axis = torch.zeros_like(q[..., 0:3])
    z_axis[..., 2] = 1
    return quat_from_angle_axis(-calc_heading(q), z_axis)
@torch.compiler.disable
def axis_angle_from_quat(
quat: torch.Tensor,
w_last: bool = True,
) -> torch.Tensor:
"""Compute axis-angle (log map) vector from a quaternion.
Args:
quat (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].
w_last (bool): Whether the scalar part w is the last element.
Returns:
torch.Tensor: (..., 3) axis-angle vector (axis * angle), with angle in radians in [0, pi].
Notes:
- The quaternion is sign-adjusted to ensure w >= 0 and normalized to unit length for numerical stability.
- Uses a stable small-angle handling to avoid NaNs and gradient issues.
"""
# Handle different quaternion formats
if w_last:
# Quaternion is [q_x, q_y, q_z, q_w]
quat_w_orig = quat[..., -1:]
else:
# Quaternion is [q_w, q_x, q_y, q_z]
quat_w_orig = quat[..., 0:1]
# Normalize quaternion to have w > 0
quat = quat * (1.0 - 2.0 * (quat_w_orig < 0.0))
# Ensure unit quaternion for stability
quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True).clamp_min(
1.0e-9
)
# Recompute quat_xyz and quat_w after potential sign flip
if w_last:
quat_w = quat[..., -1:]
quat_xyz = quat[..., :3]
else:
quat_w = quat[..., 0:1]
quat_xyz = quat[..., 1:4]
mag = torch.linalg.norm(quat_xyz, dim=-1)
half_angle = torch.atan2(mag, quat_w.squeeze(-1))
angle = 2.0 * half_angle
# check whether to apply Taylor approximation
use_taylor = angle.abs() <= 1.0e-6
# To prevent NaN gradients with torch.where, we compute both branches and blend
# based on the condition.
# See: https://pytorch.org/docs/1.9.0/generated/torch.where.html#torch-where
# "However, if you need the gradients to flow through the branches, please use torch.lerp"
# Although we are not using lerp, the principle of avoiding sharp branches is the same.
sin_half_angles_over_angles_approx = 0.5 - angle * angle / 48
# Clamp angle to avoid division by zero in the non-taylor branch when angle is exactly 0.
angle_safe = torch.where(use_taylor, torch.ones_like(angle), angle)
sin_half_angles_over_angles_exact = torch.sin(half_angle) / angle_safe
sin_half_angles_over_angles = torch.where(
use_taylor,
sin_half_angles_over_angles_approx,
sin_half_angles_over_angles_exact,
)
return quat_xyz / sin_half_angles_over_angles[..., None]
@torch.compile
def quat_box_minus(
    q1: torch.Tensor,
    q2: torch.Tensor,
    w_last: bool = True,
) -> torch.Tensor:
    """Right-invariant quaternion difference mapped to so(3) via log map.

    Computes log(q1 * q2^{-1}) using the shortest rotation convention.

    ``quat_mul`` and ``quat_conjugate`` in this module operate on the
    (w, x, y, z) layout and take no layout flag. Inputs are therefore
    converted to WXYZ first. (Passing ``w_last=`` to ``quat_mul``, or
    conjugating XYZW data with the WXYZ ``quat_conjugate``, fails or gives
    wrong results.)

    Args:
        q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].
        q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`.
        w_last (bool): Whether the scalar part w is the last element.

    Returns:
        torch.Tensor: (..., 3) axis-angle error vector.
    """
    if w_last:
        # (x, y, z, w) -> (w, x, y, z)
        q1_wxyz = torch.cat([q1[..., 3:4], q1[..., 0:3]], dim=-1)
        q2_wxyz = torch.cat([q2[..., 3:4], q2[..., 0:3]], dim=-1)
    else:
        q1_wxyz = q1
        q2_wxyz = q2
    # The conjugate equals the inverse up to a positive scale; that scale is
    # removed by the renormalization inside axis_angle_from_quat.
    quat_diff = quat_mul(q1_wxyz, quat_conjugate(q2_wxyz))  # q1 * q2^-1
    return axis_angle_from_quat(quat_diff, w_last=False)  # log(qd)
@torch.compile
def quat_error_magnitude(
    q1: torch.Tensor,
    q2: torch.Tensor,
    w_last: bool = True,
) -> torch.Tensor:
    """Geodesic angle between two orientations given as quaternions.

    This is the norm of ``quat_box_minus(q1, q2)``.

    Args:
        q1 (torch.Tensor): (..., 4) quaternion. If `w_last` is True, format is [x, y, z, w]; otherwise [w, x, y, z].
        q2 (torch.Tensor): (..., 4) quaternion with the same format as `q1`.
        w_last (bool): Whether the scalar part w is the last element.

    Returns:
        torch.Tensor: (...,) rotation angle in radians.
    """
    err_vec = quat_box_minus(q1, q2, w_last=w_last)
    return err_vec.norm(dim=-1)
@torch.jit.script
def quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Inverse of (w, x, y, z) quaternions: ``conj(q) / |q|^2``.

    Args:
        q: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
        eps: Floor for the squared norm, avoiding division by zero. Defaults to 1e-9.

    Returns:
        The inverse quaternion in (w, x, y, z). Shape is (N, 4).
    """
    squared_norm = q.pow(2).sum(dim=-1, keepdim=True).clamp(min=eps)
    return quat_conjugate(q) / squared_norm
# --------------------- WXYZ helpers (torch) ---------------------
def xyzw_to_wxyz(q: torch.Tensor) -> torch.Tensor:
    """
    Reorder a quaternion from XYZW to WXYZ (move the scalar part to the front).

    Args:
        q (torch.Tensor): (..., 4) quaternion in XYZW.

    Returns:
        torch.Tensor: (..., 4) quaternion in WXYZ.
    """
    scalar = q[..., 3:4]
    vector = q[..., 0:3]
    return torch.cat((scalar, vector), dim=-1)
def wxyz_to_xyzw(q: torch.Tensor) -> torch.Tensor:
    """
    Reorder a quaternion from WXYZ to XYZW (move the scalar part to the back).

    Args:
        q (torch.Tensor): (..., 4) quaternion in WXYZ.

    Returns:
        torch.Tensor: (..., 4) quaternion in XYZW.
    """
    vector = q[..., 1:4]
    scalar = q[..., 0:1]
    return torch.cat((vector, scalar), dim=-1)
@torch.compile
def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """
    Hamilton product ``q1 * q2`` in WXYZ layout.

    ``quat_mul`` in this module already works on (w, x, y, z) and accepts no
    layout flag, so it is called directly. A ``w_last=False`` keyword is not
    part of its signature and would make the call fail.

    Args:
        q1 (torch.Tensor): (..., 4) WXYZ.
        q2 (torch.Tensor): (..., 4) WXYZ, same shape as ``q1``.

    Returns:
        torch.Tensor: (..., 4) WXYZ.
    """
    return quat_mul(q1, q2)
def rotate_vec_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vector v by quaternion q (WXYZ).

    ``quat_apply`` in this module expects (w, x, y, z), so the quaternion is
    passed through as-is. Converting it to XYZW first would make
    ``quat_apply`` read x as the scalar part and produce wrong rotations.

    Args:
        q_wxyz (torch.Tensor): (..., 4) WXYZ.
        v (torch.Tensor): (..., 3).

    Returns:
        torch.Tensor: (..., 3) rotated vector.
    """
    # Support single-vector inputs by promoting to batch
    single = q_wxyz.ndim == 1
    q_in = q_wxyz[None, :] if single else q_wxyz
    v_in = v[None, :] if single else v
    out = quat_apply(q_in, v_in)
    return out[0] if single else out
def rotate_vec_inv_wxyz(q_wxyz: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vector v by inverse of quaternion q (WXYZ).

    Uses ``quat_apply_inverse``, which takes (w, x, y, z) directly and
    matches ``quat_apply`` with the conjugate quaternion. Converting to XYZW
    and then using the WXYZ ``quat_conjugate``/``quat_apply`` scrambles the
    components and gives wrong results.

    Args:
        q_wxyz (torch.Tensor): (..., 4) WXYZ.
        v (torch.Tensor): (..., 3).

    Returns:
        torch.Tensor: (..., 3) rotated vector in inverse rotation.
    """
    # Support single-vector inputs by promoting to batch
    single = q_wxyz.ndim == 1
    q_in = q_wxyz[None, :] if single else q_wxyz
    v_in = v[None, :] if single else v
    out = quat_apply_inverse(q_in, v_in)
    return out[0] if single else out
def subtract_frame_transforms(
    t01: torch.Tensor,
    q01: torch.Tensor,
    t02: torch.Tensor | None = None,
    q02: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Express frame 2 relative to frame 1, both given w.r.t. frame 0.

    Computes :math:`T_{12} = T_{01}^{-1} \times T_{02}`, where :math:`T_{AB}`
    is the homogeneous transformation matrix from frame A to B.

    Args:
        t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3).
        q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
        t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3).
            Defaults to None, in which case the position is assumed to be zero.
        q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4).
            Defaults to None, in which case the orientation is assumed to be identity.

    Returns:
        A tuple containing the position and orientation of frame 2 w.r.t. frame 1.
        Shape of the tensors are (N, 3) and (N, 4) respectively.
    """
    q10 = quat_inv(q01)
    q12 = q10 if q02 is None else quat_mul(q10, q02)
    offset = -t01 if t02 is None else t02 - t01
    t12 = quat_apply(q10, offset)
    return t12, q12
@torch.compile
def quat_normalize_wxyz(q_wxyz: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions in WXYZ layout to unit length.

    The norm is floored at 1e-9, so zero quaternions stay zero.

    Args:
        q_wxyz (torch.Tensor): (..., 4) WXYZ.

    Returns:
        torch.Tensor: (..., 4) normalized WXYZ.
    """
    norms = torch.linalg.norm(q_wxyz, dim=-1, keepdim=True)
    return q_wxyz / norms.clamp_min(1.0e-9)
# @torch.compile
@torch.jit.script
def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert (w, x, y, z) quaternions to rotation matrices.

    The ``2 / |q|^2`` factor makes the result independent of the quaternion's
    length, so inputs need not be unit-normalized.

    Args:
        quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4).

    Returns:
        Rotation matrices. The shape is (..., 3, 3).

    Reference:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    # Row-major entries of the 3x3 rotation matrix.
    entries = [
        1 - two_s * (y * y + z * z),
        two_s * (x * y - z * w),
        two_s * (x * z + y * w),
        two_s * (x * y + z * w),
        1 - two_s * (x * x + z * z),
        two_s * (y * z - x * w),
        two_s * (x * z - y * w),
        two_s * (y * z + x * w),
        1 - two_s * (x * x + y * y),
    ]
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
@torch.jit.script
def rot6d_from_quat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert (w, x, y, z) quaternions to the 6D rotation representation (Zhou et al., CVPR 2019).

    The output is the (3, 2) block ``R[..., :, :2]`` flattened row-major,
    i.e. ``[R00, R01, R10, R11, R20, R21]``: the first two columns,
    interleaved.

    NOTE(review): ``matrix_from_rot6d`` in this module reads its input as
    ``[column0 (3 values), column1 (3 values)]``. As a result,
    ``matrix_from_rot6d(rot6d_from_quat(q))`` does not reproduce ``q``'s
    rotation in general. Left unchanged because trained models or
    observations may depend on this exact layout; confirm before aligning
    the two.

    Args:
        quaternions: (..., 4) quaternion in (w, x, y, z).

    Returns:
        (..., 6) 6D rotation representation.
    """
    rotmat = matrix_from_quat(quaternions)  # (..., 3, 3)
    leading = rotmat.shape[:-2]
    first_two_cols = rotmat[..., :, :2]
    return first_two_cols.reshape(leading + (6,))
@torch.jit.script
def matrix_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor:
    """
    Convert a 6D rotation representation to a rotation matrix.

    The input is read as ``[column0 (3 values), column1 (3 values)]``. Gram-Schmidt
    orthonormalization rebuilds the first two columns, and the third column
    is their cross product.

    NOTE(review): ``rot6d_from_quat`` in this module emits the interleaved
    layout ``[R00, R01, R10, R11, R20, R21]``, which is not what this
    function expects. Confirm which layout callers use before changing
    either function.

    Args:
        rot6d: (..., 6) 6D rotation representation.

    Returns:
        (..., 3, 3) rotation matrix.
    """
    col0 = torch.nn.functional.normalize(rot6d[..., :3], dim=-1)
    raw1 = rot6d[..., 3:]
    # Remove the col0 component from the second column, then normalize.
    col1 = raw1 - (col0 * raw1).sum(-1, keepdim=True) * col0
    col1 = torch.nn.functional.normalize(col1, dim=-1)
    col2 = torch.cross(col0, col1, dim=-1)
    return torch.stack((col0, col1, col2), dim=-1)  # (..., 3, 3)
@torch.jit.script
def quat_from_matrix(mat: torch.Tensor) -> torch.Tensor:
    """
    Convert rotation matrices to (w, x, y, z) quaternions.

    Follows pytorch3d's ``matrix_to_quaternion``:
    1. Build four candidate quaternions, each derived from one of
       ``|w|, |x|, |y|, |z|``.
    2. Keep the candidate whose magnitude term is largest (best
       conditioned).

    The sign of the result is not standardized.

    Args:
        mat: (..., 3, 3) rotation matrix.

    Returns:
        (..., 4) quaternion in (w, x, y, z).
    """
    batch_dim = mat.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        mat.reshape(batch_dim + (9,)), dim=-1
    )
    # For a proper rotation these equal 4 * (w^2, x^2, y^2, z^2); rounding
    # can push them slightly below zero.
    trace_terms = torch.stack(
        [
            1.0 + m00 + m11 + m22,
            1.0 + m00 - m11 - m22,
            1.0 - m00 + m11 - m22,
            1.0 - m00 - m11 + m22,
        ],
        dim=-1,
    )
    # sqrt of the positive part. sqrt(clamp(t, min=0)) gives the same values
    # but a NaN gradient wherever t <= 0 (0 * inf), e.g. for the identity
    # matrix. Feeding 1 into the unused sqrt branch keeps the gradient finite.
    # NaN inputs still propagate, because NaN <= 0 is False.
    nonpositive = trace_terms <= 0.0
    q_abs = torch.where(
        nonpositive,
        torch.zeros_like(trace_terms),
        torch.sqrt(torch.where(nonpositive, torch.ones_like(trace_terms), trace_terms)),
    )

    # Compute quaternion candidates for each branch
    quat_by_rijk = torch.stack(
        [
            torch.stack(
                [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
            ),
            torch.stack(
                [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
            ),
            torch.stack(
                [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
            ),
            torch.stack(
                [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
            ),
        ],
        dim=-2,
    )

    # Normalize candidates (floor at 0.1 for numerical stability)
    flr = torch.tensor(0.1, dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clamp(min=flr))

    # Pick the best-conditioned candidate (largest denominator)
    return quat_candidates[
        torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5,
        :,
    ].reshape(batch_dim + (4,))
@torch.jit.script
def quat_from_rot6d(rot6d: torch.Tensor) -> torch.Tensor:
    """
    Convert a 6D rotation representation to (w, x, y, z) quaternions.

    The input layout is whatever ``matrix_from_rot6d`` expects (two columns
    concatenated).

    Args:
        rot6d: (..., 6) 6D rotation representation.

    Returns:
        (..., 4) quaternion in (w, x, y, z).
    """
    return quat_from_matrix(matrix_from_rot6d(rot6d))
@torch.jit.script
def euler_xyz_from_quat(
    quat: torch.Tensor, wrap_to_2pi: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert (w, x, y, z) quaternions to roll/pitch/yaw in radians.

    Note:
        The euler angles are assumed in XYZ extrinsic convention.

    Args:
        quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4).
        wrap_to_2pi (bool): Whether to wrap output Euler angles into [0, 2π). If
            False, angles are returned in the default range (−π, π]. Defaults to
            False.

    Returns:
        A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,).

    Reference:
        https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles
    """
    w, x, y, z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

    roll = torch.atan2(2.0 * (w * x + y * z), 1 - 2 * (x * x + y * y))

    # asin is undefined outside [-1, 1]; saturate to +-pi/2 there.
    sin_pitch = 2.0 * (w * y - z * x)
    half_pi = torch.tensor(torch.pi / 2.0, device=quat.device, dtype=quat.dtype)
    pitch = torch.where(
        torch.abs(sin_pitch) >= 1,
        torch.copysign(half_pi, sin_pitch),
        torch.asin(sin_pitch),
    )

    yaw = torch.atan2(2.0 * (w * z + x * y), 1 - 2 * (y * y + z * z))

    if not wrap_to_2pi:
        return roll, pitch, yaw
    full_turn = 2 * torch.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn
@torch.jit.script
def yaw_quat(quat: torch.Tensor) -> torch.Tensor:
    """Keep only the rotation about z of (w, x, y, z) quaternions.

    Args:
        quat: The orientation in (w, x, y, z). Shape is (..., 4); must be
            viewable as (-1, 4).

    Returns:
        A (w, 0, 0, z) quaternion with the same yaw, same shape as the input.
    """
    flat = quat.view(-1, 4)
    w = flat[:, 0]
    x = flat[:, 1]
    y = flat[:, 2]
    z = flat[:, 3]
    half_yaw = torch.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z)) / 2

    out = torch.zeros_like(flat)
    out[:, 0] = torch.cos(half_yaw)
    out[:, 3] = torch.sin(half_yaw)
    out = normalize(out)
    return out.view(quat.shape)
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Flip the sign of quaternions whose real part is negative.

    q and -q encode the same rotation; this picks the representative with
    a non-negative real part.

    Args:
        quaternions: (..., 4) quaternions with the real part first.

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    negative_real = quaternions[..., 0:1] < 0
    return torch.where(negative_real, -quaternions, quaternions)
@torch.compiler.disable
def gaussian_kernel1d(
sigma: float, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
if sigma <= 0.0:
raise ValueError(f"Invalid sigma: {sigma}")
radius = int(4.0 * sigma + 0.5)
x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
kernel = torch.exp(-0.5 * (x / sigma).square())
return kernel / kernel.sum()
@torch.compiler.disable
def gaussian_filter1d(x: torch.Tensor, sigma: float, dim: int) -> torch.Tensor:
    """Gaussian-smooth ``x`` along ``dim``, replicating edge values as padding.

    Inputs with fewer than 2 samples along ``dim`` are returned unchanged.
    The output has the same shape as ``x``.
    """
    if x.shape[dim] < 2:
        return x
    kernel = gaussian_kernel1d(sigma, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
    moved = x.movedim(dim, -1)
    # Flatten every other dim into the batch so conv1d sees (B, 1, L).
    signals = moved.reshape(-1, 1, moved.shape[-1])
    half = kernel.shape[-1] // 2
    padded = F.pad(signals, (half, half), mode="replicate")
    smoothed = F.conv1d(padded, kernel).reshape(moved.shape)
    return smoothed.movedim(-1, dim)
def smooth_time_series(
    x: torch.Tensor, sigma: float, dim: int
) -> torch.Tensor:
    """Gaussian smooth along a time dimension.

    Thin wrapper around :func:`gaussian_filter1d`. A non-positive ``sigma``
    returns ``x`` untouched, which makes smoothing easy to ablate.
    """
    if sigma > 0.0:
        return gaussian_filter1d(x, sigma=float(sigma), dim=int(dim))
    return x
@torch.compiler.disable
def grad_t(x: torch.Tensor, dt: float) -> torch.Tensor:
if dt <= 0.0:
raise ValueError(f"Invalid dt: {dt}")
if x.shape[1] < 2:
return torch.zeros_like(x)
grad = torch.empty_like(x)
inv_dt = 1.0 / dt
grad[:, 0] = (x[:, 1] - x[:, 0]) * inv_dt
grad[:, -1] = (x[:, -1] - x[:, -2]) * inv_dt
if x.shape[1] > 2:
grad[:, 1:-1] = (x[:, 2:] - x[:, :-2]) * (0.5 * inv_dt)
return grad
def axis_angle_to_matrix(
    angles: torch.Tensor, axes: torch.Tensor
) -> torch.Tensor:
    """Rodrigues rotation matrices for per-axis angles.

    ``R = cos(a) I + (1 - cos(a)) n n^T + sin(a) [n]_x``, where ``n`` is the
    normalized axis.

    Args:
        angles: Tensor whose trailing dims broadcast against the N axes; it
            is expanded with two trailing singleton dims.
        axes: (N, 3) rotation axes (normalized here; must be non-zero).

    Returns:
        Rotation matrices broadcast from ``angles[..., None, None]`` against
        (1, 1, N, 3, 3) terms, so the result has at least 5 dims
        (e.g. (1, 1, N, 3, 3) for 1D ``angles``).

    Raises:
        ValueError: If ``axes`` has a wrong last dim or a zero-norm row.
    """
    if axes.shape[-1] != 3:
        raise ValueError("Axes must have shape (N, 3)")
    lengths = torch.linalg.norm(axes, dim=-1)
    if torch.any(lengths <= 0):
        raise ValueError("Axis vector has zero norm")
    unit = axes / lengths[:, None]

    outer = unit[:, :, None] * unit[:, None, :]  # n n^T, (N, 3, 3)
    ux, uy, uz = unit[:, 0], unit[:, 1], unit[:, 2]
    zero = torch.zeros_like(ux)
    cross_mat = torch.stack(
        [zero, -uz, uy, uz, zero, -ux, -uy, ux, zero], dim=-1
    ).reshape(unit.shape[0], 3, 3)  # [n]_x

    cos_a = torch.cos(angles)[..., None, None]
    sin_a = torch.sin(angles)[..., None, None]
    identity = torch.eye(3, device=axes.device, dtype=axes.dtype)[None, None, None]
    return (
        cos_a * identity
        + (1.0 - cos_a) * outer[None, None]
        + sin_a * cross_mat[None, None]
    )
================================================
FILE: holomotion/src/utils/unitree_g1_actuator_calculator.py
================================================
# Project HoloMotion
#
# Copyright (c) 2024-2026 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class MotorFamily:
    """Immutable, unscaled parameter set for one motor type.

    ``build_single_group_cfg`` copies these values into the generated
    actuator config. ``armature`` feeds ``compute_pd_gains``,
    ``x1``/``x2``/``y1``/``y2`` become ``X1``/``X2``/``Y1``/``Y2``, and
    ``fs``/``fd``/``va`` become ``Fs``/``Fd``/``Va``. Per-joint scaling is
    applied there, not here.
    """

    # Label for the motor family (matches the module-level constant name).
    name: str
    # Rotor armature. Used as the effective inertia in compute_pd_gains
    # (kp = armature * wn**2). Presumably kg*m^2 -- TODO confirm units.
    armature: float
    # Copied unchanged to X1/X2. Presumably the velocity breakpoints of the
    # torque-speed envelope (rad/s) -- confirm against UnitreeActuatorCfg.
    x1: float
    x2: float
    # Torque-envelope values (scaled by JointSpec.envelope_scale).
    # Presumably the torque available at x1/x2 (N*m) -- confirm.
    y1: float
    y2: float
    # Friction parameters (scaled by JointSpec.friction_scale); presumably
    # static (fs) and dynamic (fd) friction -- confirm.
    fs: float
    fd: float
    # Copied unchanged to Va; its meaning is defined by UnitreeActuatorCfg
    # (NOTE(review): document once confirmed).
    va: float = 0.01
@dataclass(frozen=True)
class JointSpec:
    """Actuator specification for every joint matched by ``joint_expr``."""

    # Joint-name pattern; the values used in ALL_JOINT_SPECS are regular
    # expressions such as ".*_hip_yaw_joint". Also used as the key of every
    # per-joint mapping in the generated config, so it should be unique.
    joint_expr: str
    # Motor type driving the joint.
    motor: MotorFamily
    # Copied verbatim to effort_limit / velocity_limit; effort_limit also
    # sets action_scale = 0.25 * effort_limit / kp.
    effort_limit: float
    velocity_limit: float
    # Multiplies motor.armature (and therefore kp/kd) on the servo side.
    servo_scale: float = 1.0
    # Multiplies the torque envelope values y1/y2.
    envelope_scale: float = 1.0
    # Multiplies the friction values fs/fd.
    friction_scale: float = 1.0
# -----------------------------------------------------------------------------
# Base actuator families
# -----------------------------------------------------------------------------
# One MotorFamily per motor type. Joints reference these through
# JointSpec.motor; per-joint scaling (servo / envelope / friction) is applied
# in build_single_group_cfg, so these are the raw per-motor values. The names
# presumably encode motor model and gear ratio (e.g. "14P3" ~ 14.3:1) --
# TODO confirm against the motor datasheets.
N5020_16 = MotorFamily(
    name="N5020_16",
    armature=0.003609725,
    x1=30.86,
    x2=40.13,
    y1=24.8,
    y2=31.9,
    fs=0.6,
    fd=0.06,
)
N7520_14P3 = MotorFamily(
    name="N7520_14P3",
    armature=0.010177520,
    x1=22.63,
    x2=35.52,
    y1=71.0,
    y2=83.3,
    fs=1.6,
    fd=0.16,
)
N7520_22P5 = MotorFamily(
    name="N7520_22P5",
    armature=0.025101925,
    x1=14.5,
    x2=22.7,
    y1=111.0,
    y2=131.0,
    fs=2.4,
    fd=0.24,
)
W4010_25 = MotorFamily(
    name="W4010_25",
    armature=0.00425,
    x1=15.3,
    x2=24.76,
    y1=4.8,
    y2=8.6,
    fs=0.6,
    fd=0.06,
)
# -----------------------------------------------------------------------------
# Design constants
# -----------------------------------------------------------------------------
# Target natural frequency (Hz) and damping ratio that compute_pd_gains uses
# to derive kp/kd from each joint's servo-side armature.
NATURAL_FREQ_HZ = 10.0
# A ratio above 1 gives an over-damped second-order response.
DAMPING_RATIO = 2.0
# Set this to your actual physics dt before running the generator.
PHYSICS_DT = 1.0 / 200.0
# Desired action delay budget: at most 2 * (1 / 50) = 0.04 s
# (presumably two control steps at 50 Hz -- confirm the control rate).
MIN_DELAY_SECONDS = 0.0
MAX_DELAY_SECONDS = 2.0 / 50.0
def seconds_to_delay_steps(delay_seconds: float, physics_dt: float) -> int:
    """Convert a delay in seconds to a whole number of physics steps.

    The ratio is floored, so a delay that is not an exact multiple of
    ``physics_dt`` rounds down. A small epsilon (1e-12) is added first so
    that ratios which should be integral but land just below due to float
    error (e.g. ``0.3 / 0.1 == 2.9999999999999996``) are not truncated one
    step short.

    Args:
        delay_seconds: Non-negative delay in seconds.
        physics_dt: Physics step in seconds; must be positive and finite.

    Returns:
        The delay in physics steps.

    Raises:
        ValueError: If ``physics_dt`` is not positive and finite, or
            ``delay_seconds`` is negative or NaN.
    """
    # Previously dt == 0 raised a bare ZeroDivisionError, and negative or
    # NaN inputs produced meaningless step counts.
    if not (0.0 < physics_dt < math.inf):
        raise ValueError(
            f"physics_dt must be positive and finite, got {physics_dt}"
        )
    if not delay_seconds >= 0.0:
        raise ValueError(
            f"delay_seconds must be non-negative, got {delay_seconds}"
        )
    # math.floor already returns an int for float input.
    return math.floor(delay_seconds / physics_dt + 1e-12)
# Delay bounds in whole physics steps (floored). With the defaults above
# (0.0 s and 0.04 s at dt = 1/200 s) these are 0 and 8 steps.
MIN_DELAY = seconds_to_delay_steps(MIN_DELAY_SECONDS, PHYSICS_DT)
MAX_DELAY = seconds_to_delay_steps(MAX_DELAY_SECONDS, PHYSICS_DT)
# -----------------------------------------------------------------------------
# Single-group mapping
#
# ankle pitch/roll and waist roll/pitch (servo_scale=2.0 below;
# waist_yaw_joint is NOT scaled):
# - servo-side armature/gains are doubled
# - torque envelope is NOT doubled
# - friction is NOT doubled
# -----------------------------------------------------------------------------
# Order here is the order of joint_names_expr and of the rendered
# G1_HIFI_ACTION_SCALE entries.
ALL_JOINT_SPECS: list[JointSpec] = [
    # legs
    JointSpec(
        ".*_hip_yaw_joint", N7520_14P3, effort_limit=88.0, velocity_limit=32.0
    ),
    JointSpec(
        ".*_hip_roll_joint",
        N7520_22P5,
        effort_limit=139.0,
        velocity_limit=20.0,
    ),
    JointSpec(
        ".*_hip_pitch_joint",
        N7520_14P3,
        effort_limit=88.0,
        velocity_limit=32.0,
    ),
    JointSpec(
        ".*_knee_joint", N7520_22P5, effort_limit=139.0, velocity_limit=20.0
    ),
    # feet
    JointSpec(
        ".*_ankle_pitch_joint",
        N5020_16,
        effort_limit=50.0,
        velocity_limit=37.0,
        servo_scale=2.0,
    ),
    JointSpec(
        ".*_ankle_roll_joint",
        N5020_16,
        effort_limit=50.0,
        velocity_limit=37.0,
        servo_scale=2.0,
    ),
    # waist
    JointSpec(
        "waist_roll_joint",
        N5020_16,
        effort_limit=50.0,
        velocity_limit=37.0,
        servo_scale=2.0,
    ),
    JointSpec(
        "waist_pitch_joint",
        N5020_16,
        effort_limit=50.0,
        velocity_limit=37.0,
        servo_scale=2.0,
    ),
    JointSpec(
        "waist_yaw_joint", N7520_14P3, effort_limit=88.0, velocity_limit=32.0
    ),
    # arms
    JointSpec(
        ".*_shoulder_pitch_joint",
        N5020_16,
        effort_limit=25.0,
        velocity_limit=37.0,
    ),
    JointSpec(
        ".*_shoulder_roll_joint",
        N5020_16,
        effort_limit=25.0,
        velocity_limit=37.0,
    ),
    JointSpec(
        ".*_shoulder_yaw_joint",
        N5020_16,
        effort_limit=25.0,
        velocity_limit=37.0,
    ),
    JointSpec(
        ".*_elbow_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0
    ),
    JointSpec(
        ".*_wrist_roll_joint", N5020_16, effort_limit=25.0, velocity_limit=37.0
    ),
    JointSpec(
        ".*_wrist_pitch_joint", W4010_25, effort_limit=5.0, velocity_limit=22.0
    ),
    JointSpec(
        ".*_wrist_yaw_joint", W4010_25, effort_limit=5.0, velocity_limit=22.0
    ),
]
def compute_pd_gains(
    armature: float, natural_freq_hz: float, damping_ratio: float
) -> tuple[float, float]:
    """Compute PD gains that give a pure-inertia joint a target response.

    Modelling the joint as ``J * q'' = kp * (q* - q) - kd * q'`` gives
    ``kp = J * wn**2`` and ``kd = 2 * zeta * J * wn``, where
    ``wn = 2 * pi * natural_freq_hz``.

    Returns:
        ``(kp, kd)``.
    """
    omega_n = (natural_freq_hz * 2.0) * math.pi
    # Multiply in the same order as J * wn * wn so rounding is unchanged.
    inertia_times_omega = armature * omega_n
    kp = inertia_times_omega * omega_n
    kd = ((2.0 * damping_ratio) * armature) * omega_n
    return (kp, kd)
def fmt_float(x: float) -> str:
    """Format ``x`` as a float with at most 12 significant digits (``g``)."""
    as_float = float(x)
    return f"{as_float:.12g}"
def fmt_value(value: Any, indent: int = 0) -> str:
    """Render ``value`` as Python-literal source text.

    Dicts and lists are expanded one entry per line, each with a trailing
    comma, and nested values are rendered recursively at ``indent + 4``.
    The closing bracket is prefixed by ``indent`` spaces, and empty
    containers collapse to ``{}`` / ``[]``. Floats go through
    ``fmt_float``; anything else uses ``repr``.
    """
    pad = " " * indent
    child_indent = indent + 4
    if isinstance(value, dict):
        if not value:
            return "{}"
        entries = [
            f"{pad} {key!r}: {fmt_value(item, child_indent)},"
            for key, item in value.items()
        ]
        return "\n".join(["{", *entries, f"{pad}}}"])
    if isinstance(value, list):
        if not value:
            return "[]"
        entries = [f"{pad} {fmt_value(item, child_indent)}," for item in value]
        return "\n".join(["[", *entries, f"{pad}]"])
    return fmt_float(value) if isinstance(value, float) else repr(value)
def build_single_group_cfg(
    specs: list[JointSpec],
    natural_freq_hz: float = NATURAL_FREQ_HZ,
    damping_ratio: float = DAMPING_RATIO,
    min_delay: int = MIN_DELAY,
    max_delay: int = MAX_DELAY,
) -> dict[str, Any]:
    """Assemble one actuator-config dict covering every joint in ``specs``.

    For each spec:

    * the servo-side armature is ``motor.armature * servo_scale``;
    * kp/kd come from ``compute_pd_gains`` applied to that armature;
    * ``Y1``/``Y2`` are scaled by ``envelope_scale``;
    * ``Fs``/``Fd`` are scaled by ``friction_scale``;
    * ``X1``, ``X2`` and ``Va`` are copied unchanged;
    * ``action_scale`` is ``0.25 * effort_limit / kp``.

    Args:
        specs: Joint specifications. Each ``joint_expr`` must be unique,
            because it keys every per-joint mapping.
        natural_freq_hz: Natural frequency passed to ``compute_pd_gains``.
        damping_ratio: Damping ratio passed to ``compute_pd_gains``.
        min_delay: Minimum actuator delay, in physics steps.
        max_delay: Maximum actuator delay, in physics steps.

    Returns:
        A dict containing ``joint_names_expr`` (in ``specs`` order), the
        delay bounds, per-joint mappings keyed by ``joint_expr``, and scalar
        ``friction``/``dynamic_friction``/``viscous_friction`` set to 0.0.

    Raises:
        ValueError: If two specs share a ``joint_expr``, or a joint's kp is
            not positive.
    """
    joint_names_expr = [spec.joint_expr for spec in specs]
    # Without this check a repeated joint_expr would make the later spec
    # silently overwrite the earlier one in every mapping, while
    # joint_names_expr (and the rendered action-scale table) kept both.
    seen: set[str] = set()
    for joint_expr in joint_names_expr:
        if joint_expr in seen:
            raise ValueError(f"Duplicate joint_expr in specs: {joint_expr!r}")
        seen.add(joint_expr)

    effort_limit: dict[str, float] = {}
    velocity_limit: dict[str, float] = {}
    stiffness: dict[str, float] = {}
    damping: dict[str, float] = {}
    armature: dict[str, float] = {}
    x1: dict[str, float] = {}
    x2: dict[str, float] = {}
    y1: dict[str, float] = {}
    y2: dict[str, float] = {}
    fs: dict[str, float] = {}
    fd: dict[str, float] = {}
    va: dict[str, float] = {}
    action_scale: dict[str, float] = {}
    for spec in specs:
        name = spec.joint_expr
        motor = spec.motor
        servo_armature = motor.armature * spec.servo_scale
        kp, kd = compute_pd_gains(
            servo_armature, natural_freq_hz, damping_ratio
        )
        # action_scale divides by kp; fail with context rather than a bare
        # ZeroDivisionError, and reject a negative kp, which would flip the
        # sign of the action scale.
        if not kp > 0.0:
            raise ValueError(
                f"Non-positive stiffness {kp} for {name!r}; check armature, "
                "servo_scale and natural_freq_hz"
            )
        effort_limit[name] = spec.effort_limit
        velocity_limit[name] = spec.velocity_limit
        stiffness[name] = kp
        damping[name] = kd
        armature[name] = servo_armature
        x1[name] = motor.x1
        x2[name] = motor.x2
        y1[name] = motor.y1 * spec.envelope_scale
        y2[name] = motor.y2 * spec.envelope_scale
        fs[name] = motor.fs * spec.friction_scale
        fd[name] = motor.fd * spec.friction_scale
        va[name] = motor.va
        action_scale[name] = 0.25 * spec.effort_limit / kp
    return {
        "joint_names_expr": joint_names_expr,
        "min_delay": min_delay,
        "max_delay": max_delay,
        "effort_limit": effort_limit,
        "velocity_limit": velocity_limit,
        "stiffness": stiffness,
        "damping": damping,
        "armature": armature,
        "friction": 0.0,
        "dynamic_friction": 0.0,
        "viscous_friction": 0.0,
        "X1": x1,
        "X2": x2,
        "Y1": y1,
        "Y2": y2,
        "Fs": fs,
        "Fd": fd,
        "Va": va,
        "action_scale": action_scale,
    }
def render_single_group_cfg(
    cfg: dict[str, Any], group_name: str = "all_joints"
) -> str:
    """Render ``cfg`` as Python source defining the G1 actuator tables.

    The generated text imports ``UnitreeActuatorCfg`` and then defines two
    tables. ``G1_HIFI_ACTUATORS`` has a single ``group_name`` entry whose
    keyword arguments are the actuator fields of ``cfg`` (``action_scale``
    is not one of them). ``G1_HIFI_ACTION_SCALE`` maps each joint
    expression to its action scale.
    """
    actuator_fields = (
        "joint_names_expr",
        "min_delay",
        "max_delay",
        "effort_limit",
        "velocity_limit",
        "stiffness",
        "damping",
        "armature",
        "friction",
        "dynamic_friction",
        "viscous_friction",
        "X1",
        "X2",
        "Y1",
        "Y2",
        "Fs",
        "Fd",
        "Va",
    )
    source_lines: list[str] = [
        "from unitree_actuators import UnitreeActuatorCfg",
        "",
        "G1_HIFI_ACTUATORS = {",
        f" {group_name!r}: UnitreeActuatorCfg(",
    ]
    source_lines.extend(
        f" {field}={fmt_value(cfg[field], indent=8)},"
        for field in actuator_fields
    )
    source_lines += [" )", "}", "", "G1_HIFI_ACTION_SCALE = {"]
    source_lines.extend(
        f" {expr!r}: {fmt_float(cfg['action_scale'][expr])},"
        for expr in cfg["joint_names_expr"]
    )
    source_lines.append("}")
    return "\n".join(source_lines)
def print_summary(cfg: dict[str, Any]) -> None:
    """Print a ``#``-commented summary table of the generated config.

    Prints the physics dt and delay bounds, then one ``|``-separated row
    per joint expression, then a blank line.
    """
    row_fields = (
        "effort_limit",
        "velocity_limit",
        "armature",
        "stiffness",
        "damping",
        "X1",
        "X2",
        "Y1",
        "Y2",
        "Fs",
        "Fd",
        "action_scale",
    )
    header = (
        "# joint_expr | effort_limit | velocity_limit | armature | kp | kd | "
        "X1 | X2 | Y1 | Y2 | Fs | Fd | action_scale"
    )
    print("# === SUMMARY ===")
    print(f"# physics_dt = {fmt_float(PHYSICS_DT)}")
    print(f"# min_delay = {cfg['min_delay']}")
    print(f"# max_delay = {cfg['max_delay']}")
    print(header)
    for joint_expr in cfg["joint_names_expr"]:
        cells = [f"{joint_expr}"]
        cells.extend(fmt_float(cfg[field][joint_expr]) for field in row_fields)
        print("# " + " | ".join(cells))
    print()
def main() -> None:
    """Build the all-joints config, print its summary, then the source."""
    config = build_single_group_cfg(ALL_JOINT_SPECS)
    print_summary(config)
    rendered = render_single_group_cfg(config, group_name="all_joints")
    print(rendered)


if __name__ == "__main__":
    main()
================================================
FILE: holomotion/tests/__init__.py
================================================
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=64.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "holomotion"
version = "1.2.0"
description = "HoloMotion"
authors = [
    {name = "Horizon Robotics"},
]
readme = "README.md"
requires-python = ">=3.10"
dependencies = []

[project.urls]
# URLs are published verbatim in the package metadata: no surrounding
# whitespace, and point at the actual repository rather than the GitHub root.
Homepage = "https://horizonrobotics.github.io/robot_lab/holomotion/"
Repository = "https://github.com/HorizonRobotics/HoloMotion"

[tool.setuptools.packages.find]
where = ["."]
include = ["holomotion*"]
[tool.ruff]
# Paths ruff never lints or formats.
exclude = [
    # common
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".git-rewrite",
    ".hg",
    ".ipynb_checkpoints",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pyenv",
    ".pytest_cache",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    ".vscode",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "site-packages",
    "venv",
    # project
    "3rdparty/*",
    "dummy/*",
    "*.pyi",
    "*_pb2.py",
]
# PEP 8's 79-column limit; stricter than Black's default of 88.
line-length = 79
indent-width = 4
# Lint and format as Python 3.11 code.
# NOTE(review): [project] requires-python is ">=3.10"; confirm whether the
# minimum supported version is 3.10 (then use "py310" here) or 3.11 (then
# raise requires-python).
target-version = "py311"
[tool.ruff.lint]
select = [
    "E",  # pycodestyle errors
    "F",  # Pyflakes
    "I",  # isort
    "B",  # flake8-bugbear
    "TID",  # flake8-tidy-imports
    "D",  # pydocstyle
    "Q",  # flake8-quotes
    "W",  # pycodestyle warnings
    "N",  # pep8-naming
]
ignore = [
    "D104",  # missing docstring in public package
    "D107",  # missing docstring in __init__
    "D202",  # no blank lines allowed after function docstring
    "D105",  # missing docstring in magic method
    "D100",  # missing docstring in public module
    "D102",  # missing docstring in public method
    "D103",  # missing docstring in public function
    "D101",  # missing docstring in public class
    "D301",  # use r""" if any backslashes in a docstring
    "F403",  # `from module import *` used
    "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling
    "B028", # No explicit `stacklevel` keyword argument found
    "D417", # requires documentation for every function parameter.
]
[tool.ruff.lint.isort]
known-third-party = []
no-lines-before = ["future", "standard-library"]
combine-as-imports = true
force-wrap-aliases = true
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
ban-relative-imports = "all"
[tool.ruff.lint.flake8-quotes]
avoid-escape = false
[tool.ruff.lint.mccabe]
# Only enforced when the C90 (mccabe) rules are selected; "C90" is not in
# `select` above, so this limit currently has no effect.
max-complexity = 18
[tool.ruff.lint.per-file-ignores]
# Package __init__ files may use relative imports (TID252) and re-export
# names without using them (F401).
"__init__.py" = ["TID252", "F401"]
[tool.ruff.lint.pep8-naming]
classmethod-decorators = [
    # Allow Pydantic's `@validator` decorator to trigger class method treatment.
    "pydantic.validator",
    # Allow SQLAlchemy's dynamic decorators, like `@field.expression`, to trigger class method treatment.
    "declared_attr",
    "expression",
    "comparator",
]
# Names exempt from pep8-naming checks.
ignore-names = [
    # ruff default (https://docs.astral.sh/ruff/settings/#lintpep8-naming)
    "setUp",
    "tearDown",
    "setUpClass",
    "tearDownClass",
    "setUpModule",
    "tearDownModule",
    "asyncSetUp",
    "asyncTearDown",
    "setUpTestData",
    "failureException",
    "longMessage",
    "maxDiff",
    # project
    "PROJECT_ROOT", # project test environment fixture
    "ROBO_ORCHARD_TEST_WORKSPACE", # project test fixture
    "F", # import torch.nn.functional as F
]
[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"
# Like Black, indent with spaces, rather than tabs.
indent-style = "space"
# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false
# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"
# Also format Python code examples embedded in docstrings.
docstring-code-format = true
================================================
FILE: tests/benchmark_legacy_onnx_attention.py
================================================
import sys
import tempfile
import time
from pathlib import Path
import numpy as np
import onnx
import onnxruntime
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.modules.network_modules import (
export_safe_scaled_dot_product_attention,
)
class _RawAttentionModule(nn.Module):
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor:
return F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
class _SafeAttentionModule(nn.Module):
    """Same attention as the baseline, via the project's export-safe SDPA."""

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Arguments mirror _RawAttentionModule so the two graphs are
        # directly comparable.
        attended = export_safe_scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        return attended
def _export_model(
    module: nn.Module,
    export_path: Path,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor,
) -> None:
    """Export ``module`` in eval mode to ONNX opset 17 at ``export_path``.

    Uses the legacy TorchScript exporter (``dynamo=False``). Inputs are
    named ``q``/``k``/``v``/``mask`` and the single output ``out``.
    """
    example_inputs = (q, k, v, mask)
    torch.onnx.export(
        module.eval(),
        example_inputs,
        str(export_path),
        input_names=["q", "k", "v", "mask"],
        output_names=["out"],
        opset_version=17,
        dynamo=False,
        verbose=False,
    )
def _benchmark_session(
    model_path: Path,
    provider,
    feed: dict[str, np.ndarray],
    *,
    warmup_iters: int = 50,
    measure_iters: int = 300,
) -> float:
    """Return the mean latency in ms of one ``session.run`` on ``model_path``.

    Non-CPU providers are paired with ``CPUExecutionProvider`` as a
    fallback. ``warmup_iters`` untimed runs precede ``measure_iters`` timed
    runs.
    """
    if provider == "CPUExecutionProvider":
        providers = [provider]
    else:
        providers = [provider, "CPUExecutionProvider"]
    session = onnxruntime.InferenceSession(
        str(model_path), providers=providers
    )
    fetches = ["out"]
    for _ in range(warmup_iters):
        session.run(fetches, feed)
    started = time.perf_counter()
    for _ in range(measure_iters):
        session.run(fetches, feed)
    total_ms = (time.perf_counter() - started) * 1000.0
    return total_ms / measure_iters
def main() -> None:
    """Compare raw and export-safe SDPA: ONNX graph ops and ORT latency.

    Exports both attention variants with a padded boolean key mask, reports
    whether each graph contains an ``IsNaN`` node, and benchmarks both on
    the CPU execution provider, plus CUDA when onnxruntime offers it.
    """
    torch.manual_seed(0)
    q = torch.randn(4, 8, 1, 64)
    k = torch.randn(4, 8, 32, 64)
    v = torch.randn(4, 8, 32, 64)
    # Per-batch key lengths; key positions >= the length are masked out.
    valid_lengths = torch.tensor([32, 24, 16, 8], dtype=torch.int64)
    key_positions = torch.arange(32, dtype=torch.int64)
    mask = (key_positions[None, :] < valid_lengths[:, None])[:, None, None, :]
    feed = {
        name: tensor.numpy()
        for name, tensor in (("q", q), ("k", k), ("v", v), ("mask", mask))
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        raw_path = Path(tmp_dir) / "raw_attention.onnx"
        safe_path = Path(tmp_dir) / "safe_attention.onnx"
        _export_model(_RawAttentionModule(), raw_path, q, k, v, mask)
        _export_model(_SafeAttentionModule(), safe_path, q, k, v, mask)
        raw_ops = [n.op_type for n in onnx.load(str(raw_path)).graph.node]
        safe_ops = [n.op_type for n in onnx.load(str(safe_path)).graph.node]
        print(
            "Graph ops: "
            f"raw_has_isnan={'IsNaN' in raw_ops}, "
            f"safe_has_isnan={'IsNaN' in safe_ops}"
        )
        providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
            providers.append("CUDAExecutionProvider")
        for provider in providers:
            raw_ms = _benchmark_session(raw_path, provider, feed)
            safe_ms = _benchmark_session(safe_path, provider, feed)
            print(
                f"{provider}: raw={raw_ms:.4f} ms, "
                f"safe={safe_ms:.4f} ms, "
                f"delta={(safe_ms - raw_ms) / raw_ms * 100.0:.2f}%"
            )


if __name__ == "__main__":
    main()
================================================
FILE: tests/benchmark_moe_router_export.py
================================================
import re
import time
from pathlib import Path
import torch
def _extract_int_setting(config_path: Path, key: str) -> int:
pattern = re.compile(rf"^\s*{re.escape(key)}:\s*([0-9]+)\s*$")
for line in config_path.read_text().splitlines():
match = pattern.match(line)
if match:
return int(match.group(1))
raise ValueError(
f"Could not find integer setting {key!r} in {config_path}"
)
def _load_b0310_shape_config() -> dict[str, int]:
    """Read the MoE router shape settings from ``tf_motrack_v3.yaml``."""
    repo_root = Path(__file__).resolve().parents[1]
    module_cfg = repo_root.joinpath(
        "holomotion",
        "config",
        "modules",
        "motion_tracking",
        "tf_motrack_v3.yaml",
    )
    wanted = ("num_fine_experts", "top_k", "max_ctx_len")
    return {key: _extract_int_setting(module_cfg, key) for key in wanted}
def _router_scores_training(
logits_fp32: torch.Tensor,
*,
top_k: int,
bias_fp32: torch.Tensor | None = None,
) -> torch.Tensor:
choice_logits = (
logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32
)
_, topk_idx = torch.topk(choice_logits, top_k, dim=-1)
selected_logits = logits_fp32.gather(-1, topk_idx)
log_z = torch.logsumexp(logits_fp32, dim=-1, keepdim=True)
selected_probs = torch.exp(selected_logits - log_z)
return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min(
1.0e-20
)
def _router_scores_export_safe(
logits_fp32: torch.Tensor,
*,
top_k: int,
bias_fp32: torch.Tensor | None = None,
) -> torch.Tensor:
choice_logits = (
logits_fp32 if bias_fp32 is None else logits_fp32 + bias_fp32
)
_, topk_idx = torch.topk(choice_logits, top_k, dim=-1)
selected_probs = torch.softmax(logits_fp32, dim=-1).gather(-1, topk_idx)
return selected_probs / selected_probs.sum(dim=-1, keepdim=True).clamp_min(
1.0e-20
)
def _benchmark(
fn,
logits_fp32: torch.Tensor,
*,
top_k: int,
bias_fp32: torch.Tensor | None = None,
warmup_iters: int = 200,
measure_iters: int = 2000,
) -> float:
is_cuda = logits_fp32.is_cuda
with torch.inference_mode():
for _ in range(warmup_iters):
fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32)
if is_cuda:
torch.cuda.synchronize(logits_fp32.device)
start = time.perf_counter()
for _ in range(measure_iters):
fn(logits_fp32, top_k=top_k, bias_fp32=bias_fp32)
if is_cuda:
torch.cuda.synchronize(logits_fp32.device)
elapsed_s = time.perf_counter() - start
return (elapsed_s * 1000.0) / measure_iters
def _run_case(
    device: torch.device,
    *,
    case_name: str,
    batch_size: int,
    seq_len: int,
    num_fine_experts: int,
    top_k: int,
) -> None:
    """Benchmark both router formulas on one shape and print a result line.

    Logits are drawn from a CPU generator seeded with 0 and then moved to
    ``device``, so every device sees identical inputs. The line also
    reports the max absolute difference between the two formulas.
    """
    generator = torch.Generator(device="cpu")
    generator.manual_seed(0)
    logits_fp32 = torch.randn(
        (batch_size, seq_len, num_fine_experts),
        generator=generator,
        dtype=torch.float32,
    ).to(device)
    eager_scores = _router_scores_training(logits_fp32, top_k=top_k)
    export_scores = _router_scores_export_safe(logits_fp32, top_k=top_k)
    max_abs_diff = (eager_scores - export_scores).abs().max().item()
    eager_ms, export_ms = (
        _benchmark(router_fn, logits_fp32, top_k=top_k)
        for router_fn in (_router_scores_training, _router_scores_export_safe)
    )
    delta_pct = (export_ms - eager_ms) / eager_ms * 100.0
    print(
        f"{device.type}:{case_name}: "
        f"shape=({batch_size}, {seq_len}, {num_fine_experts}), "
        f"top_k={top_k}, "
        f"training={eager_ms:.6f} ms, "
        f"export_safe={export_ms:.6f} ms, "
        f"delta={delta_pct:.2f}%, "
        f"max_abs_diff={max_abs_diff:.3e}"
    )
def main() -> None:
    """Benchmark the router formulas on CPU (and CUDA if available).

    Shapes come from the tf_motrack_v3 module config: a single-step export
    case and a training-like ``(16, max_ctx_len)`` sequence case.
    """
    shape_cfg = _load_b0310_shape_config()
    num_fine_experts = shape_cfg["num_fine_experts"]
    top_k = shape_cfg["top_k"]
    max_ctx_len = shape_cfg["max_ctx_len"]
    cases = (
        ("single_step_export", 1, 1),
        ("training_like_sequence", 16, max_ctx_len),
    )
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    print(
        "Benchmarking MoE router formulas with "
        f"num_fine_experts={num_fine_experts}, top_k={top_k}, "
        f"max_ctx_len={max_ctx_len}"
    )
    for device in devices:
        for name, batch, length in cases:
            _run_case(
                device,
                case_name=name,
                batch_size=batch,
                seq_len=length,
                num_fine_experts=num_fine_experts,
                top_k=top_k,
            )


if __name__ == "__main__":
    main()
================================================
FILE: tests/test_actor_export_config.py
================================================
import importlib
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

# Check the optional ONNX/torch dependencies BEFORE importing project modules.
# The holomotion modules imported below operate on torch modules and tensors
# (e.g. GroupedMoEBlock), so they presumably import torch at load time. With
# the guard placed after them, a missing torch would surface as a collection
# ImportError instead of this clean skip.
try:
    onnx = importlib.import_module("onnx")
    torch = importlib.import_module("torch")
    nn = importlib.import_module("torch.nn")
except ModuleNotFoundError as exc:
    raise unittest.SkipTest(
        f"Optional ONNX test dependency missing: {exc.name}"
    ) from exc

from holomotion.src.modules.agent_modules import (
    PPOTFActor,
    PPOTFRefRouterActor,
    PPOTFRefRouterSeqActor,
    PPOTFRefRouterV3Actor,
    _clone_module_for_cpu_export,
)
from holomotion.src.modules.network_modules import (
    GroupedMoEBlock,
    GroupedMoETransformerPolicy,
    ReferenceRoutedGroupedMoETransformerPolicy,
    ReferenceRoutedGroupedMoETransformerPolicyV2,
    ReferenceRoutedGroupedMoETransformerPolicyV3,
    export_safe_scaled_dot_product_attention,
)
from holomotion.src.utils.onnx_export import export_policy_to_onnx
from tensordict import TensorDict
class _DummyTFModule(nn.Module):
def __init__(self):
super().__init__()
self.n_layers = 1
self.max_ctx_len = 4
self.n_kv_heads = 1
self.head_dim = 2
def forward(
self,
obs: torch.Tensor,
past_key_values: torch.Tensor,
current_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return obs[:, :2], past_key_values
class _DummyAttentionTFModule(nn.Module):
    """Tiny transformer stand-in whose ``forward`` runs a masked attention.

    The attention goes through ``export_safe_scaled_dot_product_attention``.
    Keys are zeros and values are ones. The key mask keeps positions
    ``0..current_pos``, with the valid length clamped to the cache length
    read from ``past_key_values.shape[3]``. The cache is passed through
    unchanged.
    """

    def __init__(self):
        super().__init__()
        self.n_layers = 1
        self.max_ctx_len = 4
        self.n_kv_heads = 1
        self.head_dim = 2

    def forward(
        self,
        obs: torch.Tensor,
        past_key_values: torch.Tensor,
        current_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size = obs.shape[0]
        max_len = past_key_values.shape[3]
        valid_len = (current_pos + 1).clamp(max=max_len)
        pos_idx = torch.arange(max_len, device=obs.device, dtype=torch.int64)
        mask = (pos_idx[None, :] < valid_len[:, None])[:, None, None, :]
        q = obs[:, :2].reshape(batch_size, 1, 1, 2)
        # Create k/v in obs's dtype so q, k and v agree for non-float32
        # observations; matmul-based attention rejects mixed dtypes.
        k = torch.zeros(
            batch_size, 1, max_len, 2, device=obs.device, dtype=obs.dtype
        )
        v = torch.ones(
            batch_size, 1, max_len, 2, device=obs.device, dtype=obs.dtype
        )
        attn_out = export_safe_scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=False,
        )
        actions = attn_out.reshape(batch_size, 2)
        return actions, past_key_values
class _RecordingDeviceModule(nn.Module):
def __init__(self):
super().__init__()
self.current_device = "cuda:0"
self.to_calls = []
def to(self, device):
self.to_calls.append(str(device))
self.current_device = str(device)
return self
def _make_minimal_real_transformer_actor(
    *,
    n_layers: int = 1,
    routing_score_fn: str = "softmax",
    num_fine_experts: int = 1,
    top_k: int = 1,
    use_dynamic_bias: bool = False,
    dense_layer_at_last: bool = False,
    selected_expert_margin_to_unselected_enabled: bool = False,
    selected_expert_margin_to_unselected_target: float = 0.0,
) -> PPOTFActor:
    """Build a ``PPOTFActor`` around a tiny real ``GroupedMoETransformerPolicy``.

    ``PPOTFActor.__init__`` is bypassed. Only the attributes the export path
    is expected to read are set: no obs normalisation, and an assembler with
    ``output_dim=6`` to match ``input_dim``.
    """
    module_config = {
        "type": "GroupedMoETransformerPolicy",
        "num_fine_experts": num_fine_experts,
        "num_shared_experts": 0,
        "top_k": top_k,
        "obs_embed_mlp_hidden": 8,
        "d_model": 8,
        "n_layers": n_layers,
        "n_heads": 2,
        "n_kv_heads": 1,
        "ff_mult": 1.0,
        "ff_mult_dense": 1,
        "attn_dropout": 0.0,
        "mlp_dropout": 0.0,
        "max_ctx_len": 4,
        "dense_layer_at_last": dense_layer_at_last,
        "use_gated_attn": False,
        "use_qk_norm": True,
        "routing_score_fn": routing_score_fn,
        "use_dynamic_bias": use_dynamic_bias,
        "selected_expert_margin_to_unselected": {
            "enabled": selected_expert_margin_to_unselected_enabled,
            "target": selected_expert_margin_to_unselected_target,
        },
    }
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = GroupedMoETransformerPolicy(
        input_dim=6,
        output_dim=2,
        module_config_dict=module_config,
    )
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=6)
    return actor
def _capture_moe_router_outputs(
    monkeypatch,
    *,
    export_mode: bool,
    top_k: int,
    use_dynamic_bias: bool,
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_x: torch.Tensor | None = None,
    expert_bias: torch.Tensor | None = None,
):
    """Run one ``GroupedMoEBlock.compute_moe_ffn`` call and record routing.

    The block's router weights are set to ``router_weight``, and its
    ``expert_bias`` to ``expert_bias`` when one is given.
    ``torch.onnx.is_in_onnx_export`` is patched to return ``export_mode``.
    The sparse-expert computation is replaced by a stub that records the
    ``topk_idx``/``topk_scores`` it receives and returns zeros.

    Returns:
        Dict with ``topk_idx`` and ``topk_scores`` (detached clones, present
        only if the stub ran) and the block ``output``.
    """
    block = GroupedMoEBlock(
        d_model=x.shape[-1],
        n_heads=2,
        n_kv_heads=1,
        num_fine_experts=router_weight.shape[0],
        num_shared_experts=1,
        top_k=top_k,
        ff_mult=1.0,
        use_qk_norm=True,
        use_gated_attn=False,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        use_dynamic_bias=use_dynamic_bias,
        routing_score_fn="softmax",
    )
    block.eval()
    with torch.no_grad():
        block.router.weight.copy_(router_weight)
        if expert_bias is not None:
            block.expert_bias.copy_(expert_bias)

    captured = {}

    def _record_sparse_experts(
        x_input: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_scores: torch.Tensor,
    ) -> torch.Tensor:
        captured.update(
            topk_idx=topk_idx.detach().clone(),
            topk_scores=topk_scores.detach().clone(),
        )
        return torch.zeros_like(x_input)

    monkeypatch.setattr(torch.onnx, "is_in_onnx_export", lambda: export_mode)
    monkeypatch.setattr(block, "_compute_sparse_experts", _record_sparse_experts)
    captured["output"] = block.compute_moe_ffn(x, router_x=router_x)
    return captured
def _make_minimal_ref_router_actor() -> PPOTFRefRouterActor:
    """Build a ``PPOTFRefRouterActor`` around a tiny reference-routed policy.

    ``__init__`` is bypassed. The router reads observation features
    ``[0, 1, 4, 5]``, and obs normalisation is disabled.
    """
    module_config = {
        "type": "ReferenceRoutedGroupedMoETransformerPolicy",
        "num_fine_experts": 4,
        "num_shared_experts": 0,
        "top_k": 2,
        "obs_embed_mlp_hidden": 8,
        "router_embed_mlp_hidden": 8,
        "router_input_dim": 4,
        "router_feature_indices": [0, 1, 4, 5],
        "d_model": 8,
        "n_layers": 2,
        "n_heads": 2,
        "n_kv_heads": 1,
        "ff_mult": 1.0,
        "ff_mult_dense": 1,
        "attn_dropout": 0.0,
        "mlp_dropout": 0.0,
        "max_ctx_len": 4,
        "use_gated_attn": False,
        "use_qk_norm": True,
        "routing_score_fn": "softmax",
        "use_dynamic_bias": False,
    }
    actor = PPOTFRefRouterActor.__new__(PPOTFRefRouterActor)
    nn.Module.__init__(actor)
    actor.actor_module = ReferenceRoutedGroupedMoETransformerPolicy(
        input_dim=8,
        output_dim=2,
        module_config_dict=module_config,
    )
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=8)
    return actor
def _make_ref_router_v2_obs_schema() -> dict:
return {
"flattened_obs": {
"seq_len": 1,
"terms": [
"unified/actor_ref_gravity_projection_cur",
"unified/actor_ref_base_linvel_cur",
"unified/actor_ref_base_angvel_cur",
"unified/actor_ref_dof_pos_cur",
"unified/actor_projected_gravity",
"unified/actor_rel_robot_root_ang_vel",
"unified/actor_dof_vel",
"unified/actor_dof_pos",
"unified/actor_ref_root_height_cur",
"unified/actor_last_action",
],
},
"flattened_obs_fut": {
"seq_len": 5,
"terms": [
"unified/actor_ref_gravity_projection_fut",
"unified/actor_ref_base_linvel_fut",
"unified/actor_ref_base_angvel_fut",
"unified/actor_ref_dof_pos_fut",
"unified/actor_ref_root_height_fut",
],
},
}
def _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict:
    """Random observations for the terms of the v2 observation schema.

    Current-step terms have shape ``(*batch_size, dim)`` and future terms
    ``(*batch_size, 5, dim)``. Tensors are drawn with ``torch.randn`` in a
    fixed order, so results are reproducible under a fixed seed.
    """
    batch_shape = list(batch_size)
    future_shape = [*batch_shape, 5]
    current_dims = (
        ("actor_ref_gravity_projection_cur", 3),
        ("actor_ref_base_linvel_cur", 3),
        ("actor_ref_base_angvel_cur", 3),
        ("actor_ref_dof_pos_cur", 2),
        ("actor_projected_gravity", 3),
        ("actor_rel_robot_root_ang_vel", 3),
        ("actor_dof_vel", 3),
        ("actor_dof_pos", 3),
        ("actor_ref_root_height_cur", 1),
        ("actor_last_action", 2),
    )
    future_dims = (
        ("actor_ref_gravity_projection_fut", 3),
        ("actor_ref_base_linvel_fut", 3),
        ("actor_ref_base_angvel_fut", 3),
        ("actor_ref_dof_pos_fut", 2),
        ("actor_ref_root_height_fut", 1),
    )
    fields = {
        name: torch.randn(*batch_shape, dim) for name, dim in current_dims
    }
    for name, dim in future_dims:
        fields[name] = torch.randn(*future_shape, dim)
    unified = TensorDict(fields, batch_size=batch_shape)
    return TensorDict({"unified": unified}, batch_size=batch_shape)
def _make_minimal_ref_router_v2_actor() -> PPOTFRefRouterSeqActor:
    """Build a small ``PPOTFRefRouterSeqActor`` using the v2 obs schema."""
    module_config = {
        "type": "ReferenceRoutedGroupedMoETransformerPolicyV2",
        "num_fine_experts": 4,
        "num_shared_experts": 0,
        "top_k": 2,
        "obs_embed_mlp_hidden": 8,
        "d_model": 8,
        "n_layers": 2,
        "n_heads": 2,
        "n_kv_heads": 1,
        "ff_mult": 1.0,
        "ff_mult_dense": 1,
        "attn_dropout": 0.0,
        "mlp_dropout": 0.0,
        "max_ctx_len": 4,
        "use_gated_attn": False,
        "use_qk_norm": True,
        "routing_score_fn": "softmax",
        "use_dynamic_bias": False,
        "ref_hist_n_layers": 1,
        "ref_future_conv_channels": 8,
        "ref_future_conv_layers": 2,
        "ref_future_conv_kernel_size": 3,
        "ref_future_conv_stride": 2,
        "obs_norm": {"enabled": False},
        "output_dim": 2,
    }
    schema = _make_ref_router_v2_obs_schema()
    example = _make_ref_router_v2_obs([2])
    return PPOTFRefRouterSeqActor(
        obs_schema=schema,
        module_config_dict=module_config,
        num_actions=2,
        init_noise_std=0.2,
        obs_example=example,
    )
def _make_minimal_ref_router_v3_actor() -> PPOTFRefRouterV3Actor:
    """Build a small ``PPOTFRefRouterV3Actor`` using the v2 obs schema."""
    module_config = {
        "type": "ReferenceRoutedGroupedMoETransformerPolicyV3",
        "num_fine_experts": 4,
        "num_shared_experts": 0,
        "top_k": 2,
        "obs_embed_mlp_hidden": 8,
        "d_model": 8,
        "n_layers": 2,
        "n_heads": 2,
        "n_kv_heads": 1,
        "ff_mult": 1.0,
        "ff_mult_dense": 1,
        "attn_dropout": 0.0,
        "mlp_dropout": 0.0,
        "max_ctx_len": 4,
        "use_gated_attn": False,
        "use_qk_norm": True,
        "routing_score_fn": "softmax",
        "use_dynamic_bias": False,
        "ref_hist_n_layers": 1,
        "router_future_hidden_dim": 12,
        "router_layer_proj_hidden_dim": 10,
        "obs_norm": {"enabled": False},
        "output_dim": 2,
    }
    schema = _make_ref_router_v2_obs_schema()
    example = _make_ref_router_v2_obs([2])
    return PPOTFRefRouterV3Actor(
        obs_schema=schema,
        module_config_dict=module_config,
        num_actions=2,
        init_noise_std=0.2,
        obs_example=example,
    )
def test_export_policy_to_onnx_uses_opset_17(monkeypatch, tmp_path):
    """export_policy_to_onnx requests opset 17 and forwards use_kv_cache."""
    captured = {}

    class _FakeActor:
        def eval(self):
            return self

        def export_onnx(
            self,
            *,
            onnx_path,
            opset_version,
            use_kv_cache=True,
        ):
            captured.update(
                onnx_path=onnx_path,
                opset_version=opset_version,
                use_kv_cache=use_kv_cache,
            )
            return str(onnx_path)

    monkeypatch.setattr(
        "holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion",
        lambda env, onnx_path: None,
    )
    checkpoint_path = tmp_path / "model.pt"
    checkpoint_path.write_bytes(b"")
    algo = SimpleNamespace(
        actor=_FakeActor(),
        critic=SimpleNamespace(eval=lambda: None),
        accelerator=SimpleNamespace(unwrap_model=lambda model: model),
        env=SimpleNamespace(_env=object()),
    )
    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)
    assert captured["opset_version"] == 17
    assert captured["use_kv_cache"] is False
def test_export_policy_to_onnx_restores_training_mode(monkeypatch, tmp_path):
    """Actor and critic return to training mode after the export."""

    class _ModeTracking:
        # Mimics nn.Module's train()/eval() bookkeeping of `training`.
        def __init__(self):
            self.training = True

        def eval(self):
            self.training = False
            return self

        def train(self, mode: bool = True):
            self.training = bool(mode)
            return self

    class _FakeActor(_ModeTracking):
        def export_onnx(
            self,
            *,
            onnx_path,
            opset_version,
            use_kv_cache=True,
        ):
            return str(onnx_path)

    class _FakeCritic(_ModeTracking):
        pass

    actor, critic = _FakeActor(), _FakeCritic()
    monkeypatch.setattr(
        "holomotion.src.utils.onnx_export.attach_onnx_metadata_holomotion",
        lambda env, onnx_path: None,
    )
    checkpoint_path = tmp_path / "model.pt"
    checkpoint_path.write_bytes(b"")
    algo = SimpleNamespace(
        actor=actor,
        critic=critic,
        accelerator=SimpleNamespace(unwrap_model=lambda model: model),
        env=SimpleNamespace(_env=object()),
    )
    export_policy_to_onnx(algo, str(checkpoint_path), use_kv_cache=False)
    assert actor.training is True
    assert critic.training is True
def test_clone_module_for_cpu_export_does_not_move_live_module(monkeypatch):
    """Cloning for CPU export must not call ``to`` on the live module."""
    live = _RecordingDeviceModule()
    monkeypatch.setattr(
        "holomotion.src.modules.agent_modules._module_device",
        lambda _: torch.device("cuda:0"),
    )
    clone = _clone_module_for_cpu_export(live)
    # The live module was never asked to move...
    assert live.to_calls == []
    assert live.current_device == "cuda:0"
    # ...and a distinct object of the same type came back.
    assert isinstance(clone, _RecordingDeviceModule)
    assert clone is not live
def test_ppotf_actor_export_uses_legacy_torchscript(monkeypatch, tmp_path):
    """PPOTFActor.export_onnx calls torch.onnx.export once, opset 17, no dynamo."""
    recorded_kwargs = []
    monkeypatch.setattr(
        torch.onnx,
        "export",
        lambda *args, **kwargs: recorded_kwargs.append(kwargs),
    )
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = _DummyTFModule()
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=6)
    PPOTFActor.export_onnx(
        actor,
        tmp_path / "policy.onnx",
        opset_version=17,
        use_kv_cache=True,
    )
    assert len(recorded_kwargs) == 1
    assert recorded_kwargs[0]["opset_version"] == 17
    assert recorded_kwargs[0]["dynamo"] is False
def test_ppotf_actor_export_onnx_avoids_isnan(tmp_path):
    """Exporting a masked-attention dummy actor yields no IsNaN nodes."""
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.actor_module = _DummyAttentionTFModule()
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.assembler = SimpleNamespace(output_dim=2)
    onnx_file = tmp_path / "policy.onnx"
    PPOTFActor.export_onnx(
        actor,
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )
    graph_ops = {node.op_type for node in onnx.load(str(onnx_file)).graph.node}
    assert "IsNaN" not in graph_ops
def test_ppotf_real_transformer_export_onnx_avoids_isnan(tmp_path):
    """A minimal real transformer actor must export without ``IsNaN`` ops."""
    onnx_file = tmp_path / "policy_real_tf.onnx"
    PPOTFActor.export_onnx(
        _make_minimal_real_transformer_actor(),
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert all(node.op_type != "IsNaN" for node in graph.node)
def test_ppotf_real_moe_transformer_export_reaches_router_ops(tmp_path):
    """The exported MoE transformer graph must contain a ``TopK`` node."""
    moe_actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    onnx_file = tmp_path / "policy_real_moe_tf.onnx"
    PPOTFActor.export_onnx(
        moe_actor,
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert any(node.op_type == "TopK" for node in graph.node)
def test_ppotf_real_moe_transformer_export_exposes_routing_outputs(tmp_path):
    """Graph outputs are actions, KV cache, then per-MoE-layer routing tensors."""
    moe_actor = _make_minimal_real_transformer_actor(
        n_layers=3,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    onnx_file = tmp_path / "policy_real_moe_tf_outputs.onnx"
    PPOTFActor.export_onnx(
        moe_actor,
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    expected_outputs = [
        "actions",
        "present_key_values",
        "moe_layer_1_expert_indices",
        "moe_layer_1_expert_logits",
        "moe_layer_2_expert_indices",
        "moe_layer_2_expert_logits",
    ]
    graph = onnx.load(str(onnx_file)).graph
    assert [output.name for output in graph.output] == expected_outputs
def test_ppotf_real_moe_transformer_export_dense_last_uses_actual_moe_indices(
    tmp_path,
):
    """With a dense final layer, routing outputs still use MoE layers 1 and 2."""
    moe_actor = _make_minimal_real_transformer_actor(
        n_layers=4,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
        dense_layer_at_last=True,
    )
    onnx_file = tmp_path / "policy_real_moe_tf_dense_last_outputs.onnx"
    PPOTFActor.export_onnx(
        moe_actor,
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    expected_outputs = [
        "actions",
        "present_key_values",
        "moe_layer_1_expert_indices",
        "moe_layer_1_expert_logits",
        "moe_layer_2_expert_indices",
        "moe_layer_2_expert_logits",
    ]
    graph = onnx.load(str(onnx_file)).graph
    assert [output.name for output in graph.output] == expected_outputs
def test_ppotf_real_moe_transformer_export_avoids_reduce_log_sum_exp(
    tmp_path,
):
    """The exported MoE graph must not contain ``ReduceLogSumExp``."""
    moe_actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        routing_score_fn="softmax",
    )
    onnx_file = tmp_path / "policy_real_moe_tf_no_rlse.onnx"
    PPOTFActor.export_onnx(
        moe_actor,
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert all(node.op_type != "ReduceLogSumExp" for node in graph.node)
def test_export_safe_moe_router_matches_training_scores_for_topk1(monkeypatch):
    """Export-mode and eager routing agree (indices and scores) for top_k=1."""
    router_inputs = {
        "top_k": 1,
        "use_dynamic_bias": False,
        "x": torch.tensor([[[1.0, -0.5, 0.25, 2.0]]], dtype=torch.float32),
        "router_weight": torch.tensor(
            [
                [0.1, 0.3, -0.2, 0.5],
                [0.2, -0.4, 0.1, 0.7],
                [-0.3, 0.6, 0.2, -0.1],
                [0.4, 0.1, -0.5, 0.2],
            ],
            dtype=torch.float32,
        ),
    }

    eager = _capture_moe_router_outputs(
        monkeypatch, export_mode=False, **router_inputs
    )
    export = _capture_moe_router_outputs(
        monkeypatch, export_mode=True, **router_inputs
    )

    assert torch.equal(export["topk_idx"], eager["topk_idx"])
    torch.testing.assert_close(
        export["topk_scores"],
        eager["topk_scores"],
        atol=1.0e-6,
        rtol=1.0e-5,
    )
def test_export_safe_moe_router_matches_training_scores_with_dynamic_bias(
    monkeypatch,
):
    """Export-mode and eager routing agree for top_k=2 with an expert bias."""
    router_inputs = {
        "top_k": 2,
        "use_dynamic_bias": True,
        "x": torch.tensor(
            [
                [[0.2, -1.0, 0.5, 1.1], [0.4, 0.3, -0.7, 0.9]],
                [[-0.6, 0.8, 1.0, -0.2], [0.1, -0.4, 0.6, 0.7]],
            ],
            dtype=torch.float32,
        ),
        "router_weight": torch.tensor(
            [
                [0.2, -0.1, 0.5, 0.3],
                [-0.4, 0.7, 0.2, 0.1],
                [0.6, 0.2, -0.3, 0.4],
                [0.1, 0.5, 0.4, -0.6],
            ],
            dtype=torch.float32,
        ),
        "expert_bias": torch.tensor(
            [0.0, 0.4, -0.3, 0.2], dtype=torch.float32
        ),
    }

    eager = _capture_moe_router_outputs(
        monkeypatch, export_mode=False, **router_inputs
    )
    export = _capture_moe_router_outputs(
        monkeypatch, export_mode=True, **router_inputs
    )

    assert torch.equal(export["topk_idx"], eager["topk_idx"])
    torch.testing.assert_close(
        export["topk_scores"],
        eager["topk_scores"],
        atol=1.0e-6,
        rtol=1.0e-5,
    )
def test_grouped_moe_router_x_keeps_topk_when_main_input_changes(monkeypatch):
    """Changing ``x`` alone keeps the expert choice but changes the output."""
    router_weight = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=torch.float32,
    )
    router_x = torch.tensor([[[4.0, 1.0, 0.0, 0.0]]], dtype=torch.float32)

    def _route(main_input):
        # Same router input for every call; only the main input varies.
        return _capture_moe_router_outputs(
            monkeypatch,
            export_mode=False,
            top_k=1,
            use_dynamic_bias=False,
            x=main_input,
            router_x=router_x,
            router_weight=router_weight,
        )

    out_a = _route(torch.tensor([[[0.0, 0.5, 1.0, 1.5]]], dtype=torch.float32))
    out_b = _route(
        torch.tensor([[[3.0, -2.0, -1.0, 6.0]]], dtype=torch.float32)
    )

    assert torch.equal(out_a["topk_idx"], out_b["topk_idx"])
    assert not torch.allclose(out_a["output"], out_b["output"])
def test_grouped_moe_router_x_changes_topk_when_router_input_changes(
    monkeypatch,
):
    """Changing ``router_x`` with ``x`` fixed changes the selected expert."""
    router_weight = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ],
        dtype=torch.float32,
    )
    main_input = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32)

    def _route(router_input):
        # Same main input for every call; only the router input varies.
        return _capture_moe_router_outputs(
            monkeypatch,
            export_mode=False,
            top_k=1,
            use_dynamic_bias=False,
            x=main_input,
            router_x=router_input,
            router_weight=router_weight,
        )

    out_a = _route(torch.tensor([[[3.0, 0.0, 0.0, 0.0]]], dtype=torch.float32))
    out_b = _route(torch.tensor([[[0.0, 5.0, 0.0, 0.0]]], dtype=torch.float32))

    assert not torch.equal(out_a["topk_idx"], out_b["topk_idx"])
def test_ref_router_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """The ref-router actor exports with inputs (obs, KV cache, step) and TopK."""
    onnx_file = tmp_path / "policy_ref_router.onnx"
    PPOTFActor.export_onnx(
        _make_minimal_ref_router_actor(),
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert [inp.name for inp in graph.input] == [
        "obs",
        "past_key_values",
        "step_idx",
    ]
    assert "TopK" in [node.op_type for node in graph.node]
def test_ref_router_v2_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """v2 ref-router actor: KV-cache shape, single obs input, and a TopK node."""
    actor = _make_minimal_ref_router_v2_actor()
    expected_kv_shape = (3, 2, 1, 4, 1, 4)
    assert actor.onnx_past_key_values_shape(batch_size=1) == expected_kv_shape

    onnx_file = tmp_path / "policy_ref_router_v2.onnx"
    actor.export_onnx(
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert [inp.name for inp in graph.input] == [
        "obs",
        "past_key_values",
        "step_idx",
    ]
    assert "TopK" in [node.op_type for node in graph.node]
def test_ref_router_v3_actor_export_keeps_single_obs_input_and_reaches_moe(
    tmp_path,
):
    """v3 ref-router actor: KV-cache shape, single obs input, and a TopK node."""
    actor = _make_minimal_ref_router_v3_actor()
    expected_kv_shape = (3, 2, 1, 4, 1, 4)
    assert actor.onnx_past_key_values_shape(batch_size=1) == expected_kv_shape

    onnx_file = tmp_path / "policy_ref_router_v3.onnx"
    actor.export_onnx(
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert [inp.name for inp in graph.input] == [
        "obs",
        "past_key_values",
        "step_idx",
    ]
    assert "TopK" in [node.op_type for node in graph.node]
def test_real_transformer_actor_export_supports_selected_expert_margin(
    tmp_path,
):
    """Enabling the selected-expert margin option must not break ONNX export."""
    actor = _make_minimal_real_transformer_actor(
        n_layers=2,
        num_fine_experts=4,
        top_k=2,
        selected_expert_margin_to_unselected_enabled=True,
        selected_expert_margin_to_unselected_target=0.4,
    )
    onnx_file = tmp_path / "policy_selected_expert_margin.onnx"
    actor.export_onnx(
        onnx_file,
        opset_version=17,
        use_kv_cache=True,
    )

    graph = onnx.load(str(onnx_file)).graph
    assert [inp.name for inp in graph.input] == [
        "obs",
        "past_key_values",
        "step_idx",
    ]
    assert "TopK" in [node.op_type for node in graph.node]
================================================
FILE: tests/test_algo_base_iteration_logging.py
================================================
from pathlib import Path
import sys
from unittest import mock
# Put the repository root (the parent of tests/) first on sys.path so the
# `holomotion` import below resolves without the package being installed.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.algo.algo_base import BaseOnpolicyRL
def test_log_iteration_uses_checkpoint_start_for_total_iterations():
    """Resuming at iteration 123 with 10 planned iterations reports 133 total."""
    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)
    # __init__ is bypassed; only the attributes this test needs are set.
    fields = {
        "log_dir": "/tmp/holomotion-test",
        "gpu_world_size": 1,
        "num_steps_per_env": 8,
        "num_envs": 16,
        "num_learning_iterations": 10,
        "current_learning_iteration": 123,
        "rewbuffer": [],
        "lenbuffer": [],
        "_aggregate_episode_log_metrics": lambda: {},
        "_get_additional_log_metrics": lambda: {},
        "algo_logger": mock.Mock(),
    }
    for attr_name, attr_value in fields.items():
        setattr(algo, attr_name, attr_value)

    BaseOnpolicyRL._log_iteration(
        algo,
        it=123,
        loss_dict={"policy": 1.5},
        collection_time=2.0,
        learn_time=2.0,
    )

    log_call = algo.algo_logger.log_iteration
    log_call.assert_called_once()
    kwargs = log_call.call_args.kwargs
    assert kwargs["step"] == 123
    assert kwargs["total_learning_iterations"] == 133
    metrics = kwargs["metrics"]
    assert metrics["0-Train/iteration"] == 123
    assert metrics["0-Train/iterations_total"] == 133
================================================
FILE: tests/test_build_quantization_dataset.py
================================================
import importlib.util
from pathlib import Path
def _load_module():
    """Load ``not_for_commit/build_quantization_dataset.py`` by file path.

    The script is imported via ``importlib`` from the repository root rather
    than through a package import.

    Returns:
        The freshly executed module object.

    Raises:
        AssertionError: if no import spec/loader can be built for the path.
    """
    module_path = (
        Path(__file__).resolve().parents[1]
        / "not_for_commit"
        / "build_quantization_dataset.py"
    )
    spec = importlib.util.spec_from_file_location(
        "build_quantization_dataset", module_path
    )
    # Validate the spec before using it: module_from_spec(None) would fail
    # with an opaque AttributeError instead of this explicit assertion.
    assert spec is not None and spec.loader is not None, (
        f"cannot build an import spec for {module_path}"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def test_allocate_sample_counts_normalizes_weights_and_matches_total():
    """Weights summing to 0.85 are split into counts that add up to 17."""
    module = _load_module()
    weights = {
        "AMASS": 0.2,
        "lafan1": 0.2,
        "MotionMillion-ft": 0.4,
        "pico_train": 0.05,
    }
    total = 17

    counts = module.allocate_sample_counts(weights, total)

    expected = {
        "AMASS": 4,
        "lafan1": 4,
        "MotionMillion-ft": 8,
        "pico_train": 1,
    }
    assert counts == expected
    assert sum(counts.values()) == total
def test_build_quantization_dataset_creates_symlinks(tmp_path):
    """The output dir ``<root>/<date>_quant_dataset`` holds .npz symlinks."""
    module = _load_module()
    npz_root = tmp_path / "retargeted"
    npz_root.mkdir()
    clips_per_dataset = {"AMASS": 3, "lafan1": 2}
    for name, n_clips in clips_per_dataset.items():
        dataset_dir = npz_root / name
        dataset_dir.mkdir()
        for idx in range(n_clips):
            (dataset_dir / f"clip_{idx}.npz").write_text(
                f"{name}-{idx}", encoding="utf-8"
            )

    output_dir = module.build_quantization_dataset(
        npz_root=npz_root,
        dataset_ratios={"AMASS": 2.0, "lafan1": 1.0},
        num_clips=3,
        seed=0,
        current_date="20260324",
    )

    assert output_dir == npz_root / "20260324_quant_dataset"
    links = sorted(output_dir.iterdir())
    assert len(links) == 3
    assert all(path.is_symlink() for path in links)
    # Link names are prefixed with "<dataset>__".
    assert {path.name.split("__", 1)[0] for path in links} == {
        "AMASS",
        "lafan1",
    }
    assert all(
        path.resolve().is_file() and path.suffix == ".npz" for path in links
    )
def test_build_quantization_dataset_caps_each_dataset_at_available_clips(
    tmp_path,
):
    """A dataset with fewer clips than its share contributes only what it has."""
    module = _load_module()
    npz_root = tmp_path / "retargeted"
    npz_root.mkdir()
    clips_per_dataset = {"AMASS": 1, "lafan1": 5}
    for name, n_clips in clips_per_dataset.items():
        dataset_dir = npz_root / name
        dataset_dir.mkdir()
        for idx in range(n_clips):
            (dataset_dir / f"clip_{idx}.npz").write_text(
                f"{name}-{idx}", encoding="utf-8"
            )

    output_dir = module.build_quantization_dataset(
        npz_root=npz_root,
        dataset_ratios={"AMASS": 2.0, "lafan1": 1.0},
        num_clips=6,
        seed=0,
        current_date="20260324",
    )

    links = sorted(output_dir.iterdir())
    assert len(links) == 3
    amass_links = [p for p in links if p.name.startswith("AMASS__")]
    lafan_links = [p for p in links if p.name.startswith("lafan1__")]
    assert len(amass_links) == 1
    assert len(lafan_links) == 2
================================================
FILE: tests/test_cache_curriculum_sampler.py
================================================
import json
from types import SimpleNamespace
import torch
import holomotion.src.training.h5_dataloader as h5_dataloader
from holomotion.src.algo.ppo import PPO
from holomotion.src.training.h5_dataloader import (
MotionClipBatchCache,
ClipBatch,
PrioritizedInfiniteSampler,
)
def _update_sampler(
    sampler: PrioritizedInfiniteSampler,
    completion_rates: list[float],
    *,
    swap_index: int,
) -> bool:
    """Report one observation for every window ``0..len(rates)-1``.

    Each window gets its completion rate from ``completion_rates``, an
    all-zero MPKPE signal and a count of 1.

    Returns:
        Whatever ``sampler.maybe_update_from_observations`` returns.
    """
    n = len(completion_rates)
    observations = {
        "window_indices": torch.arange(n, dtype=torch.long),
        "mpkpe_signal_means": torch.zeros(n, dtype=torch.float32),
        "completion_rate_means": torch.tensor(
            completion_rates, dtype=torch.float32
        ),
        "counts": torch.ones(n, dtype=torch.float32),
    }
    return sampler.maybe_update_from_observations(
        swap_index=swap_index, **observations
    )
def _update_sampler_subset(
    sampler: PrioritizedInfiniteSampler,
    *,
    window_indices: list[int],
    completion_rates: list[float],
    swap_index: int,
    counts: list[float] | None = None,
) -> bool:
    """Report observations for an explicit subset of windows.

    Args:
        sampler: Sampler under test.
        window_indices: Window ids being reported.
        completion_rates: Completion rate for each entry of ``window_indices``.
        swap_index: Cache swap index forwarded to the sampler.
        counts: Per-window observation counts; defaults to 1.0 each.

    Returns:
        Whatever ``sampler.maybe_update_from_observations`` returns.
    """
    if counts is None:
        counts = [1.0] * len(window_indices)
    return sampler.maybe_update_from_observations(
        window_indices=torch.tensor(window_indices, dtype=torch.long),
        mpkpe_signal_means=torch.zeros(
            len(window_indices), dtype=torch.float32
        ),
        completion_rate_means=torch.tensor(
            completion_rates, dtype=torch.float32
        ),
        counts=torch.tensor(counts, dtype=torch.float32),
        swap_index=swap_index,
    )
class _ChunkLimitedSampler:
def __init__(self, *, max_query_size: int) -> None:
self.max_query_size = int(max_query_size)
self.state_version = 7
self.query_sizes: list[int] = []
def _checked_indices(self, window_indices: torch.Tensor) -> torch.Tensor:
indices = window_indices.to(dtype=torch.long).reshape(-1)
size = int(indices.numel())
self.query_sizes.append(size)
if size > self.max_query_size:
raise AssertionError(
f"expected chunked access <= {self.max_query_size}, got {size}"
)
return indices
def get_scores_for_indices(
self, window_indices: torch.Tensor
) -> torch.Tensor:
indices = self._checked_indices(window_indices)
return indices.to(dtype=torch.float32)
def get_window_state_for_indices(
self, window_indices: torch.Tensor
) -> dict[str, torch.Tensor]:
indices = self._checked_indices(window_indices)
count = int(indices.numel())
return {
"ema_completion_rate": torch.zeros(count, dtype=torch.float32),
"completion_rate_rel_improve": torch.zeros(
count, dtype=torch.float32
),
"selection_count": indices + 10,
"seen": torch.zeros(count, dtype=torch.bool),
"in_prioritized_pool": torch.zeros(count, dtype=torch.bool),
}
def get_pool_statistics(self) -> dict[str, float]:
return {
"prioritized_pool_size": 0.0,
"prioritized_pool_mean_score": 0.0,
"uniform_pool_mean_score": 0.0,
"entered_prioritized_pool_count": 0.0,
"exited_prioritized_pool_count": 0.0,
}
class _FakeTrainDataset:
def __init__(self, windows: list[SimpleNamespace]) -> None:
self.windows = windows
def __len__(self) -> int:
return len(self.windows)
def close(self) -> None:
return
def test_sampler_uses_configured_uniform_ratio_immediately():
    """p_a_ratio=0.3 splits a batch of 10 into (3, 7) before any update."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=8,
        batch_size=10,
        seed=0,
        p_a_ratio=0.3,
    )
    first_pool, second_pool = sampler._pool_batch_sizes()
    assert (first_pool, second_pool) == (3, 7)
def test_completion_rate_relative_improvement_alone_drives_scores():
    """Only the window whose completion rate improved gets a positive score."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=3,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=0.5,
        ema_alpha_rel_improve=1.0,
    )
    history = ([0.2, 0.2, 0.2], [0.8, 0.2, 0.2])
    for swap_index, rates in enumerate(history, start=1):
        assert _update_sampler(sampler, rates, swap_index=swap_index)

    scores = sampler.get_scores_for_indices(torch.arange(3, dtype=torch.long))
    improved, flat_a, flat_b = (s.item() for s in scores[:3])
    assert improved > 0.0
    assert flat_a == 0.0
    assert flat_b == 0.0
def test_sampler_weights_progress_by_remaining_difficulty():
    """The same +0.1 gain scores higher for the harder (lower-rate) window."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=2,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    history = ([0.1, 0.8], [0.2, 0.9])
    for swap_index, rates in enumerate(history, start=1):
        assert _update_sampler(sampler, rates, swap_index=swap_index)

    scores = sampler.get_scores_for_indices(torch.arange(2, dtype=torch.long))
    hard_score, easy_score = scores[0].item(), scores[1].item()
    assert hard_score > easy_score
def test_sampler_tracks_cumulative_selection_counts():
    """selection_count equals a histogram of the indices actually drawn."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=4,
        batch_size=2,
        seed=0,
        p_a_ratio=0.5,
    )
    stream = iter(sampler)
    drawn = []
    for _ in range(4):
        drawn.append(next(stream))

    counts = sampler.get_window_state_for_indices(torch.arange(4))[
        "selection_count"
    ]
    histogram = torch.bincount(
        torch.tensor(drawn, dtype=torch.long), minlength=4
    )
    assert int(counts.sum().item()) == 4
    assert torch.equal(counts, histogram)
def test_low_completion_plateau_drops_from_prioritized_replay():
    """A window that stalls at a low completion rate leaves the priority pool.

    Afterwards it must still be reachable through uniform sampling.
    """
    sampler = PrioritizedInfiniteSampler(
        dataset_len=3,
        batch_size=3,
        seed=0,
        p_a_ratio=1.0 / 3.0,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    history = (
        [0.1, 0.2, 0.2],
        [0.4, 0.2, 0.2],
        [0.4, 0.8, 0.8],
    )
    for swap_index, rates in enumerate(history, start=1):
        assert _update_sampler(sampler, rates, swap_index=swap_index)

    membership = sampler.get_window_state_for_indices(
        torch.arange(3, dtype=torch.long)
    )["in_prioritized_pool"]
    assert not bool(membership[0].item())

    uniform_pick = sampler._sample_uniform_indices(
        torch.Generator().manual_seed(0), 3
    )
    assert 0 in uniform_pick.tolist()
def test_prioritized_windows_persist_beyond_immediate_batch():
    """Windows 0/1 stay prioritized after a later update covers only 2/3."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=6,
        batch_size=4,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    rounds = (
        ([0, 1], [0.2, 0.2]),
        ([0, 1], [0.8, 0.7]),
        ([2, 3], [0.3, 0.3]),
    )
    for swap_index, (windows, rates) in enumerate(rounds, start=1):
        assert _update_sampler_subset(
            sampler,
            window_indices=windows,
            completion_rates=rates,
            swap_index=swap_index,
        )

    membership = sampler.get_window_state_for_indices(torch.tensor([0, 1]))[
        "in_prioritized_pool"
    ]
    assert torch.equal(
        membership, torch.tensor([True, True], dtype=torch.bool)
    )
def test_sampler_reports_pool_means_and_membership_churn():
    """Pool statistics expose size, mean scores and enter/exit counts."""
    sampler = PrioritizedInfiniteSampler(
        dataset_len=4,
        batch_size=4,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=0.5,
        ema_alpha_rel_improve=1.0,
    )
    history = (
        [0.2, 0.2, 0.2, 0.2],
        [0.9, 0.8, 0.2, 0.2],
        [0.1, 0.2, 0.9, 0.8],
    )
    for swap_index, rates in enumerate(history, start=1):
        assert _update_sampler(sampler, rates, swap_index=swap_index)
    next(iter(sampler))

    stats = sampler.get_pool_statistics()
    assert stats is not None
    expected_keys = {
        "prioritized_pool_size",
        "prioritized_pool_mean_score",
        "uniform_pool_mean_score",
        "entered_prioritized_pool_count",
        "exited_prioritized_pool_count",
    }
    assert set(stats) == expected_keys
    assert stats["prioritized_pool_size"] == 2.0
    assert stats["entered_prioritized_pool_count"] == 2.0
    assert stats["exited_prioritized_pool_count"] == 2.0
    prioritized_mean = stats["prioritized_pool_mean_score"]
    uniform_mean = stats["uniform_pool_mean_score"]
    assert prioritized_mean > uniform_mean
def test_sampler_hot_path_avoids_full_dataset_temporaries(monkeypatch):
    """Updating and sampling must not allocate dataset-length temporaries.

    ``torch.zeros`` / ``torch.arange`` / ``torch.randperm`` are wrapped so that
    any call whose first argument equals the dataset length fails the test.
    """
    sampler = PrioritizedInfiniteSampler(
        dataset_len=1_000_000,
        batch_size=8,
        seed=0,
        p_a_ratio=0.5,
        ema_alpha_signal=1.0,
        ema_alpha_rel_improve=1.0,
    )
    real_zeros = h5_dataloader.torch.zeros
    real_arange = h5_dataloader.torch.arange
    real_randperm = h5_dataloader.torch.randperm

    def _shape_length(arg) -> int | None:
        """Length of an int or a 1-tuple of int; None for anything else."""
        if isinstance(arg, int):
            return arg
        is_int_1tuple = (
            isinstance(arg, tuple) and len(arg) == 1 and isinstance(arg[0], int)
        )
        return arg[0] if is_int_1tuple else None

    def _first_arg(arg):
        return arg

    def _guard(real_fn, message, size_of):
        """Wrap ``real_fn`` to raise ``message`` on dataset-length requests."""

        def _guarded(*args, **kwargs):
            if args and size_of(args[0]) == sampler._ds_len:
                raise AssertionError(message)
            return real_fn(*args, **kwargs)

        return _guarded

    monkeypatch.setattr(
        h5_dataloader.torch,
        "zeros",
        _guard(real_zeros, "full-dataset zeros in hot path", _shape_length),
    )
    monkeypatch.setattr(
        h5_dataloader.torch,
        "arange",
        _guard(real_arange, "full-dataset arange in hot path", _first_arg),
    )
    monkeypatch.setattr(
        h5_dataloader.torch,
        "randperm",
        _guard(real_randperm, "full-dataset randperm in hot path", _first_arg),
    )

    assert _update_sampler_subset(
        sampler,
        window_indices=[5, 25, 125, 625],
        completion_rates=[0.1, 0.2, 0.3, 0.4],
        swap_index=1,
    )
    batch = sampler._sample_batch_indices(torch.Generator().manual_seed(0))
    assert int(batch.numel()) == 8
def test_ppo_logs_only_core_curriculum_metrics():
    """PPO forwards the cache pool statistics and drops legacy metrics."""
    algo = PPO.__new__(PPO)
    algo.actor_learning_rate = 1.0e-4
    algo.critic_learning_rate = 2.0e-4
    algo._last_update_metrics = {}
    algo.command_name = "ref_motion"
    algo._get_mean_policy_std = lambda: torch.tensor(0.0)

    pool_stats = {
        "prioritized_pool_size": 2.0,
        "prioritized_pool_mean_score": 0.8,
        "uniform_pool_mean_score": 0.1,
        "entered_prioritized_pool_count": 1.0,
        "exited_prioritized_pool_count": 1.0,
    }
    cache = SimpleNamespace(
        swap_index=12,
        # Hand out a fresh dict on every call.
        cache_curriculum_pool_statistics=lambda: dict(pool_stats),
    )
    motion_cmd = SimpleNamespace(_motion_cache=cache)
    command_manager = SimpleNamespace(get_term=lambda name: motion_cmd)
    algo.env = SimpleNamespace(
        _env=SimpleNamespace(command_manager=command_manager)
    )

    metrics = algo._get_additional_log_metrics()

    expected = {
        "1-Perf/Cache/swap_index": 12.0,
        "1-Perf/Cache/prioritized_pool_size": 2.0,
        "1-Perf/Cache/prioritized_pool_mean_score": 0.8,
        "1-Perf/Cache/uniform_pool_mean_score": 0.1,
        "1-Perf/Cache/entered_prioritized_pool_count": 1.0,
        "1-Perf/Cache/exited_prioritized_pool_count": 1.0,
    }
    for key, value in expected.items():
        assert metrics[key] == value
    removed = (
        "1-Perf/Cache/curriculum_probability_coefficient_of_variation",
        "1-Perf/Cache/curriculum_max_probability_over_uniform",
        "1-Perf/Cache/uniform_floor_ratio",
    )
    for key in removed:
        assert key not in metrics
def test_cache_curriculum_dumps_on_scheduled_swap_even_without_state_update():
    """The score dump is attempted even when the sampler reports no change."""
    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)
    cache._datasets = {}
    cache._cache_curriculum_sampler = SimpleNamespace(
        maybe_update_from_observations=lambda **kwargs: False,
    )
    dumped_swaps = []

    def _record_dump(*, swap_index):
        dumped_swaps.append(swap_index)

    cache._maybe_dump_cache_curriculum_scores_json = _record_dump

    updated = cache.update_cache_curriculum(
        window_indices=torch.tensor([0], dtype=torch.long),
        mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32),
        completion_rate_means=torch.tensor([0.0], dtype=torch.float32),
        counts=torch.tensor([1.0], dtype=torch.float32),
        swap_index=5,
    )

    assert updated is False
    assert dumped_swaps == [5]
def test_update_cache_curriculum_refreshes_prefetched_batch_when_state_changes():
    """A sampler state change must replace the already-prefetched batch."""
    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)
    cache._datasets = {}
    cache._cache_curriculum_sampler = SimpleNamespace(
        maybe_update_from_observations=lambda **kwargs: True,
    )
    cache._cache_curriculum_dump_enabled = False

    def _single_clip_batch(key, window_index):
        """Build a one-clip, one-frame ClipBatch tagged with ``key``."""
        return ClipBatch(
            tensors={},
            lengths=torch.tensor([1], dtype=torch.long),
            motion_keys=[key],
            raw_motion_keys=[key],
            window_indices=torch.tensor([window_index], dtype=torch.long),
            max_frame_length=1,
        )

    cache._next_batch = _single_clip_batch("stale", 0)
    refreshed_batch = _single_clip_batch("fresh", 1)
    cache._fetch_next_batch = lambda: refreshed_batch

    updated = cache.update_cache_curriculum(
        window_indices=torch.tensor([0], dtype=torch.long),
        mpkpe_signal_means=torch.tensor([0.0], dtype=torch.float32),
        completion_rate_means=torch.tensor([0.0], dtype=torch.float32),
        counts=torch.tensor([1.0], dtype=torch.float32),
        swap_index=5,
    )

    assert updated is True
    assert cache._next_batch.motion_keys == ["fresh"]
def test_cache_curriculum_whole_window_dump_streams_rows_in_chunks(
    tmp_path,
):
    """The whole-window JSON dump covers every window in bounded chunks.

    A ``_ChunkLimitedSampler`` with ``max_query_size=2`` backs a 5-window
    fake dataset. The dump must contain all 5 rows, and no sampler query
    may exceed the configured chunk size of 2.
    """
    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)
    cache._datasets = {
        "train": _FakeTrainDataset(
            [
                SimpleNamespace(
                    raw_motion_key=f"raw_{idx}",
                    motion_key=f"motion_{idx}",
                    start=idx,
                    length=idx + 1,
                )
                for idx in range(5)
            ]
        )
    }
    sampler = _ChunkLimitedSampler(max_query_size=2)
    cache._cache_curriculum_sampler = sampler
    cache._cache_curriculum_dump_enabled = True
    cache._cache_curriculum_dump_every_swaps = 1
    cache._cache_curriculum_dump_chunk_size = 2
    cache._cache_curriculum_last_dump_swap = -1
    cache._cache_curriculum_dump_dir = tmp_path
    cache._sampler_rank = 3

    cache._maybe_dump_cache_curriculum_scores_json(swap_index=1)

    output_path = tmp_path / "whole_window_scores_rank_0003_swap_000001.json"
    # Check existence before reading so a missing dump fails on this
    # assertion rather than with FileNotFoundError from read_text().
    assert output_path.exists()
    payload = json.loads(output_path.read_text(encoding="utf-8"))
    assert payload["swap_index"] == 1
    assert payload["rank"] == 3
    assert payload["sampler_state_version"] == 7
    assert payload["num_windows"] == 5
    assert len(payload["rows"]) == 5
    # The fake sampler reports selection_count = window_index + 10.
    assert payload["rows"][0]["window_index"] == 0
    assert payload["rows"][0]["selection_count"] == 10
    assert payload["rows"][-1]["window_index"] == 4
    assert payload["rows"][-1]["selection_count"] == 14
    assert "probability" not in payload["rows"][0]
    assert max(sampler.query_sizes) == 2
================================================
FILE: tests/test_domain_rand_config_builder.py
================================================
import importlib.util
import sys
from pathlib import Path
from types import ModuleType
# Absolute path of the domain-randomisation source file under test. The test
# loads it by file path so its isaaclab imports can be stubbed first.
MODULE_PATH = Path(__file__).resolve().parents[1].joinpath(
    "holomotion",
    "src",
    "env",
    "isaaclab_components",
    "isaaclab_domain_rand.py",
)
class _DummyEventTermCfg:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class _DummySceneEntityCfg:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def resolve(self, _scene):
return None
def _load_domain_rand_module(monkeypatch):
    """Execute ``isaaclab_domain_rand.py`` against stubbed ``isaaclab`` modules.

    Minimal fake ``isaaclab`` packages, exposing only the names stubbed here,
    are placed in ``sys.modules``. The source file at ``MODULE_PATH`` is then
    executed under a private module name. Every ``sys.modules`` change goes
    through ``monkeypatch``, so nothing outlives the calling test.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture.

    Returns:
        The executed module object.

    Raises:
        AssertionError: if no import spec/loader can be built for the file.
    """
    isaaclab = ModuleType("isaaclab")
    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = lambda cls: cls
    isaaclab_utils_math = ModuleType("isaaclab.utils.math")
    isaaclab_assets = ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedEnv = object
    isaaclab_envs_mdp = ModuleType("isaaclab.envs.mdp")
    isaaclab_envs_mdp.events = ModuleType("isaaclab.envs.mdp.events")
    isaaclab_envs_mdp.events._randomize_prop_by_op = (
        lambda *args, **kwargs: None
    )
    isaaclab_managers = ModuleType("isaaclab.managers")
    isaaclab_managers.SceneEntityCfg = _DummySceneEntityCfg
    isaaclab_managers.EventTermCfg = _DummyEventTermCfg
    # Wire up parent -> child attributes so dotted attribute access works too.
    isaaclab.envs = isaaclab_envs
    isaaclab.assets = isaaclab_assets
    isaaclab.utils = isaaclab_utils
    isaaclab_envs.mdp = isaaclab_envs_mdp
    isaaclab_utils.math = isaaclab_utils_math
    for name, module in {
        "isaaclab": isaaclab,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.math": isaaclab_utils_math,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_envs_mdp,
        "isaaclab.envs.mdp.events": isaaclab_envs_mdp.events,
        "isaaclab.managers": isaaclab_managers,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)

    module_name = "_test_domain_rand_builder"
    spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
    # Validate the spec before using it: module_from_spec(None) would raise
    # an opaque AttributeError instead of these assertions.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register under its name before executing, so it is importable by name
    # while it runs. monkeypatch removes the entry after the test, alongside
    # the stubs it was built against.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def test_build_domain_rand_config_skips_non_event_metadata(monkeypatch):
    """Only real event terms become attributes of the built events config.

    ``default_dof_pos_bias`` (with ``mode``/``params``) is kept. The erfi,
    action-delay, init-perturbation and obs-noise sections are non-event
    metadata and must be skipped.
    """
    module = _load_domain_rand_module(monkeypatch)
    metadata_sections = {
        "erfi": {
            "enabled": True,
            "rfi_probability": 0.5,
            "rfi_lim": 0.1,
            "randomize_rfi_lim": True,
            "rfi_lim_range": [0.5, 1.5],
            "rao_lim": 0.1,
        },
        "action_delay": {
            "enabled": True,
            "min_delay": 1,
            "max_delay": 3,
        },
        "motion_init_perturb": {
            "root_pose_perturb_range": {"x": [-0.1, 0.1]}
        },
        "obs_noise": {"actor_dof_pos": {"n_min": -0.01, "n_max": 0.01}},
    }
    event_sections = {
        "default_dof_pos_bias": {
            "mode": "startup",
            "params": {
                "joint_names": [".*"],
                "pos_distribution_params": [-0.01, 0.01],
                "operation": "add",
                "distribution": "uniform",
            },
        },
    }

    events_cfg = module.build_domain_rand_config(
        {**metadata_sections, **event_sections}
    )

    assert hasattr(events_cfg, "default_dof_pos_bias")
    assert events_cfg.default_dof_pos_bias.mode == "startup"
    for section in metadata_sections:
        assert not hasattr(events_cfg, section)
================================================
FILE: tests/test_eval_mujoco_action_delay.py
================================================
from collections import deque
import numpy as np
from omegaconf import OmegaConf
import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim
def test_action_delay_cfg_defaults_to_disabled_episode():
    """An empty config yields zero delay steps and the "episode" delay type."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.config = OmegaConf.create({})

    max_delay_step, delay_type = evaluator._get_action_delay_cfg()

    assert (max_delay_step, delay_type) == (0, "episode")
def test_action_delay_cfg_rejects_invalid_delay_type():
    """An unknown ``action_delay_type`` must raise a ValueError naming it."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.config = OmegaConf.create(
        {
            "policy_action_delay_step": 2,
            "action_delay_type": "frame",
        }
    )

    caught = None
    try:
        evaluator._get_action_delay_cfg()
    except ValueError as exc:
        caught = exc
    if caught is None:
        raise AssertionError("Expected ValueError for invalid delay type.")
    assert "action_delay_type" in str(caught)
def test_apply_action_delay_passthrough_when_disabled():
    """With zero delay configured, the action is returned unchanged."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.policy_action_delay_step = 0
    evaluator.action_delay_type = "episode"
    evaluator._policy_action_delay_buffer = deque(maxlen=1)
    evaluator._current_policy_action_delay_step = 0

    action = np.array([1.0, -1.0], dtype=np.float32)
    delayed = evaluator._apply_action_delay(action)

    np.testing.assert_allclose(
        delayed, np.array([1.0, -1.0], dtype=np.float32)
    )
def test_apply_action_delay_episode_reuses_single_sample(monkeypatch):
    """Episode mode samples the delay once at reset and reuses it every step."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.policy_action_delay_step = 1
    evaluator.action_delay_type = "episode"

    randint_calls = []

    def fake_randint(low, high):
        randint_calls.append((low, high))
        return 1

    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, "randint", fake_randint)
    evaluator._reset_action_delay_randomization()

    outputs = [
        evaluator._apply_action_delay(np.array([value], dtype=np.float32))
        for value in (1.0, 2.0, 3.0)
    ]

    assert randint_calls == [(0, 2)]
    assert evaluator._current_policy_action_delay_step == 1
    # A one-step delay: the first output repeats the first action, then lags.
    for got, want in zip(outputs, (1.0, 1.0, 2.0)):
        np.testing.assert_allclose(got, np.array([want], dtype=np.float32))
def test_apply_action_delay_step_resamples_each_policy_step(monkeypatch):
    """Step mode draws a new delay from [0, 3) on every policy step."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.policy_action_delay_step = 2
    evaluator.action_delay_type = "step"
    evaluator._policy_action_delay_buffer = deque(maxlen=3)
    evaluator._current_policy_action_delay_step = 0

    scripted_delays = iter([2, 0, 1])
    randint_calls = []

    def fake_randint(low, high):
        randint_calls.append((low, high))
        return next(scripted_delays)

    monkeypatch.setattr(eval_mujoco_sim2sim.np.random, "randint", fake_randint)

    outputs = [
        evaluator._apply_action_delay(np.array([value], dtype=np.float32))
        for value in (1.0, 2.0, 3.0)
    ]

    assert randint_calls == [(0, 3)] * 3
    assert evaluator._current_policy_action_delay_step == 1
    for got, want in zip(outputs, (1.0, 2.0, 2.0)):
        np.testing.assert_allclose(got, np.array([want], dtype=np.float32))
================================================
FILE: tests/test_eval_mujoco_action_ema.py
================================================
import numpy as np
from omegaconf import OmegaConf
import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim
def test_action_ema_filter_cfg_reads_erfi_settings():
    """For ``unitree_erfi`` actuators the configured EMA flag/alpha are used."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    actuator_cfg = dict(
        actuator_type="unitree_erfi",
        ema_filter_enabled=True,
        ema_filter_alpha=0.37,
    )
    evaluator.config = OmegaConf.create({"robot": {"actuators": actuator_cfg}})

    ema_enabled, ema_alpha = evaluator._get_action_ema_filter_cfg()

    assert ema_enabled is True
    assert ema_alpha == 0.37
def test_action_ema_filter_defaults_to_disabled_for_non_erfi():
    """Non-ERFI actuators ignore the EMA settings: disabled with alpha 1.0."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    # EMA is requested in the config, but the actuator type is plain "unitree".
    actuator_cfg = dict(
        actuator_type="unitree",
        ema_filter_enabled=True,
        ema_filter_alpha=0.37,
    )
    evaluator.config = OmegaConf.create({"robot": {"actuators": actuator_cfg}})

    ema_enabled, ema_alpha = evaluator._get_action_ema_filter_cfg()

    assert ema_enabled is False
    assert ema_alpha == 1.0
def test_apply_action_ema_filter_uses_previous_filtered_action():
    """The first action passes through; later ones blend with the last output.

    With alpha = 0.25: 0.25 * [3, 1] + 0.75 * [1, -1] == [1.5, -0.5].
    """
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.action_ema_filter_enabled = True
    evaluator.action_ema_filter_alpha = 0.25
    evaluator._filtered_actions_onnx = None

    raw_actions = (
        np.array([1.0, -1.0], dtype=np.float32),
        np.array([3.0, 1.0], dtype=np.float32),
    )
    filtered = [evaluator._apply_action_ema_filter(a) for a in raw_actions]

    expected = (
        np.array([1.0, -1.0], dtype=np.float32),
        np.array([1.5, -0.5], dtype=np.float32),
    )
    for got, want in zip(filtered, expected):
        np.testing.assert_allclose(got, want)
def test_reset_action_ema_filter_clears_state():
    """Resetting the EMA filter drops the stored previous output."""
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator._filtered_actions_onnx = np.array([1.0], dtype=np.float32)

    evaluator._reset_action_ema_filter()

    assert evaluator._filtered_actions_onnx is None
================================================
FILE: tests/test_eval_mujoco_contact_export.py
================================================
import json
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim
def _build_export_evaluator(tmp_path: Path):
    """Return a bare ``MujocoEvaluator`` pre-filled with a tiny recorded rollout.

    ``__init__`` is bypassed via ``__new__``. The per-step robot buffers hold
    two frames of 2-DoF / 2-body data. The low-level torque and foot-contact
    buffers hold four samples each (``simulation_dt`` is 0.005), so exporters
    are exercised on traces whose length differs from the 2-frame clip. A
    motion NPZ holding only JSON metadata (``clip_length`` = 2) is written to
    ``tmp_path`` and referenced from ``config``.

    Args:
        tmp_path: Scratch directory for ``motion.npz`` and the model path.

    Returns:
        The pre-populated evaluator.
    """
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.simulation_dt = 0.005
    # Stub: report no MoE routing tensors.
    evaluator._get_stacked_moe_routing_tensors = lambda: (None, None)

    def as_f32(*rows):
        """Convert each row into its own float32 vector (one per sample)."""
        return [np.array(row, dtype=np.float32) for row in rows]

    def zeros_then_ones():
        """Two fresh (2, 3) frames: all zeros, then all ones."""
        return [
            np.zeros((2, 3), dtype=np.float32),
            np.ones((2, 3), dtype=np.float32),
        ]

    identity_xyzw = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)

    # Per-step joint and action traces: 2 frames x 2 DoF.
    evaluator._robot_dof_pos_seq = as_f32([0.0, 1.0], [0.5, 1.5])
    evaluator._robot_dof_vel_seq = as_f32([0.1, 0.2], [0.3, 0.4])
    evaluator._robot_dof_acc_seq = as_f32([1.0, 2.0], [3.0, 4.0])
    evaluator._robot_dof_torque_seq = as_f32([5.0, 6.0], [7.0, 8.0])
    evaluator._robot_action_rate_seq = [np.float32(0.0), np.float32(1.0)]
    evaluator._robot_actions_seq = as_f32([0.11, 0.22], [0.33, 0.44])

    # Low-level traces: 4 samples x 2 entries each.
    evaluator._robot_low_level_dof_torque_seq = as_f32(
        [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]
    )
    evaluator._robot_low_level_foot_contact_seq = as_f32(
        [1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 0.0]
    )
    evaluator._robot_low_level_foot_normal_force_seq = as_f32(
        [50.0, 0.0], [60.0, 55.0], [0.0, 45.0], [0.0, 0.0]
    )
    evaluator._robot_low_level_foot_tangent_speed_seq = as_f32(
        [0.02, 0.0], [0.03, 0.04], [0.0, 0.05], [0.0, 0.0]
    )

    # Per-step body traces: 2 frames x 2 bodies.
    evaluator._robot_global_translation_seq = zeros_then_ones()
    evaluator._robot_global_rotation_quat_seq = [
        np.tile(identity_xyzw, (2, 1)) for _ in range(2)
    ]
    evaluator._robot_global_velocity_seq = zeros_then_ones()
    evaluator._robot_global_angular_velocity_seq = zeros_then_ones()

    # Reference motion: all zeros with identity rotations.
    evaluator.ref_dof_pos = np.zeros((2, 2), dtype=np.float32)
    evaluator.ref_dof_vel = np.zeros((2, 2), dtype=np.float32)
    evaluator.ref_global_translation = np.zeros((2, 2, 3), dtype=np.float32)
    evaluator.ref_global_rotation_quat_xyzw = np.tile(identity_xyzw, (2, 2, 1))
    evaluator.ref_global_velocity = np.zeros((2, 2, 3), dtype=np.float32)
    evaluator.ref_global_angular_velocity = np.zeros(
        (2, 2, 3), dtype=np.float32
    )

    motion_npz_path = tmp_path / "motion.npz"
    metadata_json = json.dumps({"clip_length": 2})
    np.savez_compressed(
        motion_npz_path,
        metadata=np.array(metadata_json, dtype=np.str_),
    )
    evaluator.config = OmegaConf.create(
        {
            "motion_npz_path": str(motion_npz_path),
            "ckpt_onnx_path": str(tmp_path / "model.onnx"),
        }
    )
    return evaluator
def test_save_batch_result_exports_low_level_contact_traces(tmp_path: Path):
    """``save_batch_result`` writes actions, low-level contact traces and dt."""
    evaluator = _build_export_evaluator(tmp_path)
    result_path = tmp_path / "batch_result.npz"
    evaluator.save_batch_result(str(result_path), {"clip_length": 2})

    required_keys = (
        "robot_actions",
        "robot_low_level_foot_contact",
        "robot_low_level_foot_normal_force",
        "robot_low_level_foot_tangent_speed",
        "robot_low_level_contact_dt",
    )
    expected_actions = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

    with np.load(result_path, allow_pickle=True) as saved:
        for key in required_keys:
            assert key in saved.files, key
        np.testing.assert_allclose(saved["robot_actions"], expected_actions)
        # All four low-level samples are kept, not truncated to the clip length.
        assert saved["robot_low_level_foot_contact"].shape == (4, 2)
        np.testing.assert_allclose(
            saved["robot_low_level_contact_dt"], np.array(0.005, np.float32)
        )
def test_dump_robot_augmented_npz_exports_low_level_contact_traces(
    tmp_path: Path,
):
    """The robot-augmented NPZ dump also carries actions and contact traces.

    The dump lands in ``mujoco_output_<onnx stem>/<motion stem>_robot.npz``
    next to the motion file.
    """
    evaluator = _build_export_evaluator(tmp_path)
    evaluator._dump_robot_augmented_npz()

    motion_stem = Path(evaluator.config.motion_npz_path).stem
    dump_path = tmp_path.joinpath("mujoco_output_model", f"{motion_stem}_robot.npz")

    required_keys = (
        "robot_actions",
        "robot_low_level_foot_contact",
        "robot_low_level_foot_normal_force",
        "robot_low_level_foot_tangent_speed",
        "robot_low_level_contact_dt",
    )
    expected_actions = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

    with np.load(dump_path, allow_pickle=True) as saved:
        for key in required_keys:
            assert key in saved.files, key
        np.testing.assert_allclose(saved["robot_actions"], expected_actions)
        assert saved["robot_low_level_foot_normal_force"].shape == (4, 2)
def test_init_low_level_foot_contact_logging_falls_back_to_ankle_roll_bodies(
    monkeypatch,
):
    """With no named foot geoms, collidable geoms on the ankle-roll bodies are used.

    In the fake model, geoms 1 and 2 sit on body 6 (``left_ankle_roll_link``)
    and geom 4 on body 10 (``right_ankle_roll_link``). Geoms 0 and 3 are
    non-collidable and attached to other bodies, so they must not be picked.
    """
    mj = eval_mujoco_sim2sim.mujoco
    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.config = OmegaConf.create({"robot": {}})

    fake_model_attrs = {
        "geom_bodyid": np.array([5, 6, 6, 9, 10], dtype=np.int32),
        "geom_contype": np.array([0, 1, 1, 0, 1], dtype=np.int32),
        "geom_conaffinity": np.array([0, 1, 1, 0, 1], dtype=np.int32),
    }
    evaluator.m = type("FakeModel", (), fake_model_attrs)()

    ankle_body_ids = {"left_ankle_roll_link": 6, "right_ankle_roll_link": 10}

    def fake_name2id(model, obj_type, name):
        # Only the ankle-roll body lookups resolve. Every geom lookup, and
        # anything else, reports "not found".
        if obj_type == mj.mjtObj.mjOBJ_BODY:
            return ankle_body_ids.get(name, -1)
        return -1

    monkeypatch.setattr(mj, "mj_name2id", fake_name2id)

    evaluator._init_low_level_foot_contact_logging()

    assert evaluator._foot_contact_logging_enabled is True
    assert evaluator._foot_geom_id_groups == [[1, 2], [4]]
    assert evaluator._foot_geom_id_to_side == {1: 0, 2: 0, 4: 1}
================================================
FILE: tests/test_eval_mujoco_s100_horizon_ptq.py
================================================
import sys
from pathlib import Path
from types import SimpleNamespace
import numpy as np
import pytest
from omegaconf import OmegaConf
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import holomotion.src.evaluation.eval_mujoco_sim2sim_s100 as eval_mujoco_sim2sim_s100
class _FakeIoNode:
def __init__(self, name, shape):
self.name = name
self.shape = shape
def _make_value_info(name, shape):
dims = [SimpleNamespace(dim_value=dim) for dim in shape]
tensor_shape = SimpleNamespace(dim=dims)
tensor_type = SimpleNamespace(shape=tensor_shape)
return SimpleNamespace(
name=name, type=SimpleNamespace(tensor_type=tensor_type)
)
def _make_fake_onnx_model():
    """Fake ONNX model whose graph I/O mirrors a KV-cache policy export.

    Inputs: ``obs`` [1, 16], ``past_key_values`` [1, 2, 3, 4], ``step_idx`` [1].
    Outputs: ``action`` [1, 12], ``present_key_values`` [1, 2, 3, 4].
    """
    input_specs = [
        ("obs", [1, 16]),
        ("past_key_values", [1, 2, 3, 4]),
        ("step_idx", [1]),
    ]
    output_specs = [
        ("action", [1, 12]),
        ("present_key_values", [1, 2, 3, 4]),
    ]
    graph = SimpleNamespace(
        input=[_make_value_info(n, dims) for n, dims in input_specs],
        output=[_make_value_info(n, dims) for n, dims in output_specs],
    )
    return SimpleNamespace(graph=graph)
def _make_evaluator(model_path: Path, bc_path: Path | None = None):
    """Create an s100 ``MujocoEvaluator`` with just enough state for ``load_policy``.

    ``__init__`` is bypassed. The config holds the ONNX path, ``use_gpu=False``,
    ``gpu_id=0`` and, only when given, an explicit ``bc_path``. MoE output
    discovery is stubbed to a no-op.

    Args:
        model_path: Path stored as ``ckpt_onnx_path``.
        bc_path: Optional Horizon runtime artifact stored as ``bc_path``.
    """
    cfg = {"ckpt_onnx_path": str(model_path), "use_gpu": False, "gpu_id": 0}
    if bc_path is not None:
        cfg["bc_path"] = str(bc_path)

    evaluator_cls = eval_mujoco_sim2sim_s100.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    evaluator.config = OmegaConf.create(cfg)
    evaluator.max_context_len = 0
    evaluator._discover_policy_moe_outputs = lambda: None
    return evaluator
def test_load_policy_falls_back_to_horizon_quantized_bc_for_ptq_onnx(
    monkeypatch, tmp_path
):
    """A PTQ ONNX that ORT cannot open falls back to the sibling quantized ``.bc``.

    ORT raises the HzCalibration custom-op error, so ``HBRuntime`` must be
    created on ``demo_quantized_model.bc``. I/O names and context length
    still come from the (patched) ``onnx.load`` graph.
    """
    onnx_file = tmp_path / "demo_ptq_model.onnx"
    bc_file = tmp_path / "demo_quantized_model.bc"
    onnx_file.write_bytes(b"onnx")
    bc_file.write_bytes(b"bc")
    seen = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            seen["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _raise_hz_calibration(*args, **kwargs):
        # Simulates ORT lacking the Horizon calibration custom op.
        raise RuntimeError("Failed to load custom op HzCalibration")

    ort = eval_mujoco_sim2sim_s100.onnxruntime
    monkeypatch.setattr(
        ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
    )
    monkeypatch.setattr(ort, "InferenceSession", _raise_hz_calibration)
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100, "HBRuntime", _FakeHBRuntime, raising=False
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx, "load", lambda _: _make_fake_onnx_model()
    )

    evaluator = _make_evaluator(onnx_file)
    evaluator.load_policy()

    assert seen["hb_model_path"] == str(bc_file)
    expected_io = {
        "policy_input_name": "obs",
        "policy_kv_input_name": "past_key_values",
        "policy_step_input_name": "step_idx",
        "policy_output_name": "action",
        "policy_kv_output_name": "present_key_values",
    }
    for attr, expected in expected_io.items():
        assert getattr(evaluator, attr) == expected, attr
    assert evaluator.policy_model_context_len == 4
@pytest.mark.parametrize(
    "runtime_name",
    [
        "demo_model_16000_ptq_model.bc",
        "demo_model_16000_ptq_model.hbm",
        "demo_model_16000_quantized_model.hbm",
    ],
)
def test_load_policy_resolves_common_horizon_runtime_artifact_names(
    monkeypatch, tmp_path, runtime_name
):
    """Each supported Horizon artifact name beside a PTQ ONNX is used as fallback.

    Covers ``*_ptq_model.bc``, ``*_ptq_model.hbm`` and
    ``*_quantized_model.hbm`` when ORT fails with the HzCalibration error.
    """
    onnx_file = tmp_path / "demo_model_16000_ptq_model.onnx"
    onnx_file.write_bytes(b"onnx")
    artifact = tmp_path / runtime_name
    artifact.write_bytes(b"runtime")
    seen = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            seen["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _raise_hz_calibration(*args, **kwargs):
        raise RuntimeError("Failed to load custom op HzCalibration")

    ort = eval_mujoco_sim2sim_s100.onnxruntime
    monkeypatch.setattr(
        ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
    )
    monkeypatch.setattr(ort, "InferenceSession", _raise_hz_calibration)
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100, "HBRuntime", _FakeHBRuntime, raising=False
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx, "load", lambda _: _make_fake_onnx_model()
    )

    evaluator = _make_evaluator(onnx_file)
    evaluator.load_policy()

    assert seen["hb_model_path"] == str(artifact)
    assert evaluator.policy_model_context_len == 4
def test_load_policy_raises_original_error_when_ptq_fallback_bc_missing(
    monkeypatch, tmp_path
):
    """With no Horizon runtime artifact next to the PTQ ONNX, the ORT error surfaces.

    ``load_policy`` must raise a ``RuntimeError`` mentioning HzCalibration.
    """
    onnx_file = tmp_path / "demo_ptq_model.onnx"
    onnx_file.write_bytes(b"onnx")

    def _raise_hz_calibration(*args, **kwargs):
        raise RuntimeError("Failed to load custom op HzCalibration")

    ort = eval_mujoco_sim2sim_s100.onnxruntime
    patches = (
        ("get_available_providers", lambda: ["CPUExecutionProvider"]),
        ("InferenceSession", _raise_hz_calibration),
    )
    for attr, replacement in patches:
        monkeypatch.setattr(ort, attr, replacement)

    evaluator = _make_evaluator(onnx_file)
    with pytest.raises(RuntimeError, match="HzCalibration"):
        evaluator.load_policy()
def test_load_policy_keeps_standard_onnxruntime_path_for_regular_onnx(
    monkeypatch, tmp_path
):
    """A regular (non-PTQ) ONNX is opened with ``InferenceSession`` on CPU.

    ``policy_model_context_len`` must resolve to 4 from the fake I/O shapes.
    """
    onnx_file = tmp_path / "demo_model.onnx"
    onnx_file.write_bytes(b"onnx")
    seen = {}

    class _FakeInferenceSession:
        def __init__(self, model_path, sess_options, providers):
            seen["model_path"] = model_path
            seen["providers"] = providers

        def get_providers(self):
            return ["CPUExecutionProvider"]

        def get_inputs(self):
            specs = [
                ("obs", [1, 16]),
                ("past_key_values", [1, 2, 3, 4]),
                ("step_idx", [1]),
            ]
            return [_FakeIoNode(n, dims) for n, dims in specs]

        def get_outputs(self):
            specs = [
                ("action", [1, 12]),
                ("present_key_values", [1, 2, 3, 4]),
            ]
            return [_FakeIoNode(n, dims) for n, dims in specs]

    ort = eval_mujoco_sim2sim_s100.onnxruntime
    monkeypatch.setattr(
        ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
    )
    monkeypatch.setattr(ort, "InferenceSession", _FakeInferenceSession)

    evaluator = _make_evaluator(onnx_file)
    evaluator.load_policy()

    assert seen["model_path"] == str(onnx_file)
    assert seen["providers"] == ["CPUExecutionProvider"]
    assert evaluator.policy_model_context_len == 4
def test_load_policy_prefers_explicit_bc_path_for_inference_and_onnx_for_metadata(
    monkeypatch, tmp_path
):
    """An explicit ``bc_path`` is used for inference without trying ORT.

    ORT's ``InferenceSession`` must never be built. I/O names and context
    length are still read from the (patched) ONNX graph.
    """
    onnx_file = tmp_path / "demo_model.onnx"
    onnx_file.write_bytes(b"onnx")
    bc_file = tmp_path / "demo_quantized_model.bc"
    bc_file.write_bytes(b"bc")
    seen = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            seen["hb_model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            raise AssertionError("run should not be called in this test")

    def _unexpected_ort_session(*args, **kwargs):
        raise AssertionError(
            "InferenceSession should not be created when bc_path is set"
        )

    ort = eval_mujoco_sim2sim_s100.onnxruntime
    monkeypatch.setattr(
        ort, "get_available_providers", lambda: ["CPUExecutionProvider"]
    )
    monkeypatch.setattr(ort, "InferenceSession", _unexpected_ort_session)
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100, "HBRuntime", _FakeHBRuntime, raising=False
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100.onnx, "load", lambda _: _make_fake_onnx_model()
    )

    evaluator = _make_evaluator(onnx_file, bc_path=bc_file)
    evaluator.load_policy()

    assert seen["hb_model_path"] == str(bc_file)
    expected_io = {
        "policy_input_name": "obs",
        "policy_kv_input_name": "past_key_values",
        "policy_step_input_name": "step_idx",
        "policy_output_name": "action",
        "policy_kv_output_name": "present_key_values",
    }
    for attr, expected in expected_io.items():
        assert getattr(evaluator, attr) == expected, attr
    assert evaluator.policy_model_context_len == 4
def test_bc_runtime_run_normalizes_inputs_for_hbruntime(monkeypatch, tmp_path):
    """``_HbSessionWrapper.run`` hands HBRuntime C-contiguous float32/int64 arrays.

    The feed deliberately contains float64 data, a non-contiguous transpose
    and an int32 step index.
    """
    runtime_path = tmp_path / "demo_quantized_model.bc"
    runtime_path.write_bytes(b"bc")
    seen = {}

    class _FakeHBRuntime:
        def __init__(self, model_path):
            seen["model_path"] = model_path
            self.input_names = ["obs", "past_key_values", "step_idx"]
            self.output_names = ["action", "present_key_values"]

        def run(self, output_names, input_feed):
            seen["output_names"] = list(output_names)
            seen["input_feed"] = input_feed
            return ["ok"]

    monkeypatch.setattr(
        eval_mujoco_sim2sim_s100, "HBRuntime", _FakeHBRuntime, raising=False
    )

    wrapper = eval_mujoco_sim2sim_s100._HbSessionWrapper(runtime_path)
    raw_feed = {
        "obs": np.arange(6, dtype=np.float64).reshape(2, 3).T,
        "past_key_values": np.arange(24, dtype=np.float64).reshape(2, 3, 4),
        "step_idx": np.array([7], dtype=np.int32),
    }
    result = wrapper.run(["action"], raw_feed)

    assert result == ["ok"]
    assert seen["model_path"] == str(runtime_path)
    assert seen["output_names"] == ["action"]
    forwarded = seen["input_feed"]
    expected_dtypes = {
        "obs": np.float32,
        "past_key_values": np.float32,
        "step_idx": np.int64,
    }
    for key, dtype in expected_dtypes.items():
        assert forwarded[key].dtype == dtype, key
        assert forwarded[key].flags["C_CONTIGUOUS"], key
def test_update_policy_raises_clear_error_before_runtime_on_obs_dim_mismatch():
    """A 425-dim observation for a 786-dim policy fails before running the model.

    ``_update_policy`` must raise ``ValueError`` with a clear message. The
    runtime session is rigged to fail loudly if it is reached.
    """
    evaluator_cls = eval_mujoco_sim2sim_s100.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)

    def _runtime_must_not_run(*args, **kwargs):
        raise AssertionError("runtime should not be called on shape mismatch")

    evaluator._record_robot_states = lambda: None
    evaluator.obs_builder = SimpleNamespace(
        build_policy_obs=lambda: np.zeros(425, dtype=np.float32)
    )
    plain_attrs = {
        "policy_input_name": "obs",
        "policy_output_name": "action",
        "policy_obs_expected_dim": 786,
        "use_kv_cache": False,
        "policy_step_input_name": None,
        "policy_kv_output_name": None,
        "policy_moe_layer_output_names": [],
        "dump_onnx_io_npy": False,
        "counter": 0,
        "command_mode": "velocity_tracking",
    }
    for attr, value in plain_attrs.items():
        setattr(evaluator, attr, value)
    evaluator.config = OmegaConf.create(
        {"motion_npz_dir": "", "motion_npz_path": ""}
    )
    evaluator.policy_session = SimpleNamespace(run=_runtime_must_not_run)

    with pytest.raises(
        ValueError, match="expects 786 features but evaluator built 425"
    ):
        evaluator._update_policy()
================================================
FILE: tests/test_eval_mujoco_use_gpu.py
================================================
import sys
from pathlib import Path
from omegaconf import OmegaConf
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import holomotion.src.evaluation.eval_mujoco_sim2sim as eval_mujoco_sim2sim
class _FakeIoNode:
def __init__(self, name, shape):
self.name = name
self.shape = shape
def test_load_policy_treats_false_string_use_gpu_as_cpu(monkeypatch):
    """``use_gpu`` given as the string ``"false"`` selects only the CPU provider.

    CUDA is reported as available, so a naive truthiness check on the string
    would wrongly enable it.
    """
    seen = {}

    class _FakeInferenceSession:
        def __init__(self, model_path, sess_options, providers):
            seen["model_path"] = model_path
            seen["providers"] = providers

        def get_providers(self):
            return ["CPUExecutionProvider"]

        def get_inputs(self):
            return [_FakeIoNode("obs", [1, 16])]

        def get_outputs(self):
            return [_FakeIoNode("action", [1, 12])]

    ort = eval_mujoco_sim2sim.onnxruntime
    monkeypatch.setattr(
        ort,
        "get_available_providers",
        lambda: ["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    monkeypatch.setattr(ort, "InferenceSession", _FakeInferenceSession)

    evaluator_cls = eval_mujoco_sim2sim.MujocoEvaluator
    evaluator = evaluator_cls.__new__(evaluator_cls)
    # Deliberately the string "false", not a boolean.
    evaluator.config = OmegaConf.create(
        {"ckpt_onnx_path": "model.onnx", "use_gpu": "false", "gpu_id": 3}
    )
    evaluator.max_context_len = 0
    evaluator.load_policy()

    assert seen["model_path"] == "model.onnx"
    assert seen["providers"] == ["CPUExecutionProvider"]
def test_create_ray_evaluator_preserves_use_gpu_false(monkeypatch):
    """``_create_ray_evaluator`` passes ``use_gpu=False`` and ``gpu_id`` through."""
    seen = {}

    class _FakeEvaluator:
        def __init__(self, config):
            seen["use_gpu"] = config.use_gpu
            seen["gpu_id"] = config.gpu_id

    monkeypatch.setattr(eval_mujoco_sim2sim, "MujocoEvaluator", _FakeEvaluator)

    raw_config = {"use_gpu": False, "gpu_id": 5}
    eval_mujoco_sim2sim._create_ray_evaluator(raw_config, "holomotion")

    assert seen["use_gpu"] is False
    assert seen["gpu_id"] == 5
def test_run_mujoco_sim2sim_eval_preserves_use_gpu_false(
    monkeypatch, tmp_path
):
    """The single-run entry point keeps ``use_gpu=False`` and drives the evaluator.

    Both ``setup`` and ``run_simulation`` must be called on the evaluator.
    """
    seen = {}

    class _FakeEvaluator:
        def __init__(self, config):
            seen["use_gpu"] = config.use_gpu

        def setup(self):
            seen["setup"] = True

        def run_simulation(self):
            seen["run_simulation"] = True

    processed = {"use_gpu": False, "model_type": "holomotion"}
    monkeypatch.setattr(
        eval_mujoco_sim2sim.hydra.utils,
        "get_original_cwd",
        lambda: str(tmp_path),
    )
    monkeypatch.setattr(
        eval_mujoco_sim2sim,
        "process_config",
        lambda _: OmegaConf.create(processed),
    )
    monkeypatch.setattr(eval_mujoco_sim2sim, "MujocoEvaluator", _FakeEvaluator)

    eval_mujoco_sim2sim.run_mujoco_sim2sim_eval(OmegaConf.create({}))

    assert seen["use_gpu"] is False
    assert seen["setup"] is True
    assert seen["run_simulation"] is True
================================================
FILE: tests/test_eval_onnx_io_dump.py
================================================
import json
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import numpy as np
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.evaluation.eval_mujoco_sim2sim import (
MujocoEvaluator,
write_onnx_io_dump_readme,
)
from holomotion.src.evaluation.ray_evaluator_actor import RayEvaluatorActor
class _Config(SimpleNamespace):
def get(self, key, default=None):
return getattr(self, key, default)
def test_save_onnx_io_dump_stacks_per_frame_inputs_and_outputs(tmp_path):
    """Recorded ONNX I/O frames are stacked along a new leading frame axis.

    The saved ``.npy`` holds a pickled dict with I/O names, stacked
    inputs/outputs and the extra metadata passed to ``save_onnx_io_dump``.
    """
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator._reset_onnx_io_dump_buffers()

    recorded_frames = [
        (
            {
                "obs": np.array([[1.0, 2.0]], dtype=np.float32),
                "step": np.array([0], dtype=np.int64),
            },
            [
                np.array([[0.1, 0.2]], dtype=np.float32),
                np.array([[[3.0, 4.0]]], dtype=np.float32),
            ],
        ),
        (
            {
                "obs": np.array([[5.0, 6.0]], dtype=np.float32),
                "step": np.array([1], dtype=np.int64),
            },
            [
                np.array([[0.3, 0.4]], dtype=np.float32),
                np.array([[[7.0, 8.0]]], dtype=np.float32),
            ],
        ),
    ]
    for feed, outputs in recorded_frames:
        evaluator._record_onnx_io_frame(
            input_feed=feed,
            output_names=["action", "kv_cache"],
            onnx_output=outputs,
        )

    dump_path = tmp_path / "clip_onnx_io.npy"
    evaluator.save_onnx_io_dump(
        dump_path,
        {
            "source_npz": "clip.npz",
            "onnx_model": "model.onnx",
        },
    )
    dump = np.load(dump_path, allow_pickle=True).item()

    assert dump["input_names"] == ["obs", "step"]
    assert dump["output_names"] == ["action", "kv_cache"]
    np.testing.assert_allclose(
        dump["inputs"]["obs"],
        np.array([[[1.0, 2.0]], [[5.0, 6.0]]], dtype=np.float32),
    )
    np.testing.assert_array_equal(
        dump["inputs"]["step"],
        np.array([[0], [1]], dtype=np.int64),
    )
    np.testing.assert_allclose(
        dump["outputs"]["action"],
        np.array([[[0.1, 0.2]], [[0.3, 0.4]]], dtype=np.float32),
    )
    np.testing.assert_allclose(
        dump["outputs"]["kv_cache"],
        np.array([[[[3.0, 4.0]]], [[[7.0, 8.0]]]], dtype=np.float32),
    )
    assert dump["source_npz"] == "clip.npz"
    assert dump["onnx_model"] == "model.onnx"
def test_write_onnx_io_dump_readme_creates_chinese_loading_instructions(
    tmp_path,
):
    """The generated README.md (in Chinese) documents the per-clip ``.npy`` dumps.

    It must state that each clip yields one ``.npy`` file and show how to
    load it with ``allow_pickle=True``.
    """
    readme = write_onnx_io_dump_readme(tmp_path)
    assert readme == tmp_path / "README.md"

    text = readme.read_text(encoding="utf-8")
    expected_snippets = (
        # "Each motion clip generates one `.npy` file"
        "每个动作片段会生成一个 `.npy` 文件",
        "allow_pickle=True",
        "np.load(npy_path, allow_pickle=True).item()",
    )
    for snippet in expected_snippets:
        assert snippet in text, snippet
def test_save_batch_result_persists_low_level_torque_dump_and_dt(tmp_path):
    """``save_batch_result`` stores the low-level torque trace and its dt.

    Metadata must keep caller fields and add ``robot_low_level_torque_dt``,
    which equals ``simulation_dt``.
    """
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.policy_dt = 0.02
    evaluator.simulation_dt = 0.005

    # One policy frame, but two low-level torque samples.
    robot_traces = {
        "_robot_dof_pos_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_vel_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_acc_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_torque_seq": [np.ones(2, dtype=np.float32)],
        "_robot_low_level_dof_torque_seq": [
            np.array([1.0, -1.0], dtype=np.float32),
            np.array([-1.0, 1.0], dtype=np.float32),
        ],
        "_robot_action_rate_seq": [np.float32(0.0)],
        "_robot_global_translation_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_rotation_quat_seq": [
            np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
        ],
        "_robot_global_velocity_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_angular_velocity_seq": [
            np.zeros((1, 3), dtype=np.float32)
        ],
    }
    reference = {
        "ref_dof_pos": np.zeros((1, 2), dtype=np.float32),
        "ref_dof_vel": np.zeros((1, 2), dtype=np.float32),
        "ref_global_translation": np.zeros((1, 1, 3), dtype=np.float32),
        "ref_global_rotation_quat_xyzw": np.array(
            [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32
        ),
        "ref_global_velocity": np.zeros((1, 1, 3), dtype=np.float32),
        "ref_global_angular_velocity": np.zeros((1, 1, 3), dtype=np.float32),
    }
    for attr, value in {**robot_traces, **reference}.items():
        setattr(evaluator, attr, value)

    result_path = tmp_path / "demo_eval.npz"
    evaluator.save_batch_result(
        result_path, {"source_file": "clip.npz", "clip_length": 1}
    )

    with np.load(result_path, allow_pickle=True) as saved:
        meta = json.loads(saved["metadata"].item())
        np.testing.assert_allclose(
            saved["robot_low_level_dof_torque"],
            np.array([[1.0, -1.0], [-1.0, 1.0]], dtype=np.float32),
        )
    assert meta["source_file"] == "clip.npz"
    assert meta["clip_length"] == 1
    assert meta["robot_low_level_torque_dt"] == 0.005
def test_save_batch_result_persists_moe_routing_tensors(tmp_path):
    """Recorded MoE expert indices/logits are stacked per frame in the result NPZ."""
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.policy_dt = 0.02
    evaluator.simulation_dt = 0.005

    robot_traces = {
        "_robot_dof_pos_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_vel_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_acc_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_torque_seq": [np.ones(2, dtype=np.float32)],
        "_robot_low_level_dof_torque_seq": [
            np.array([0.5, -0.5], dtype=np.float32)
        ],
        "_robot_action_rate_seq": [np.float32(0.0)],
        "_robot_global_translation_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_rotation_quat_seq": [
            np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
        ],
        "_robot_global_velocity_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_angular_velocity_seq": [
            np.zeros((1, 3), dtype=np.float32)
        ],
        # One frame of routing data for two MoE layers.
        "_robot_moe_expert_indices_seq": [
            np.array([[1, 3], [0, 2]], dtype=np.int64)
        ],
        "_robot_moe_expert_logits_seq": [
            np.array(
                [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]],
                dtype=np.float32,
            )
        ],
    }
    reference = {
        "ref_dof_pos": np.zeros((1, 2), dtype=np.float32),
        "ref_dof_vel": np.zeros((1, 2), dtype=np.float32),
        "ref_global_translation": np.zeros((1, 1, 3), dtype=np.float32),
        "ref_global_rotation_quat_xyzw": np.array(
            [[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32
        ),
        "ref_global_velocity": np.zeros((1, 1, 3), dtype=np.float32),
        "ref_global_angular_velocity": np.zeros((1, 1, 3), dtype=np.float32),
    }
    for attr, value in {**robot_traces, **reference}.items():
        setattr(evaluator, attr, value)

    result_path = tmp_path / "demo_eval_moe.npz"
    evaluator.save_batch_result(result_path, {"source_file": "clip.npz"})

    with np.load(result_path, allow_pickle=True) as saved:
        np.testing.assert_array_equal(
            saved["robot_moe_expert_indices"],
            np.array([[[1, 3], [0, 2]]], dtype=np.int64),
        )
        np.testing.assert_allclose(
            saved["robot_moe_expert_logits"],
            np.array(
                [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]],
                dtype=np.float32,
            ),
        )
def test_dump_robot_augmented_npz_persists_moe_routing_tensors(tmp_path):
    """The robot-augmented NPZ dump also carries stacked MoE routing tensors.

    Output goes to ``mujoco_output_model/clip_robot.npz`` under ``tmp_path``.
    """
    source_npz = tmp_path / "clip.npz"
    np.savez(source_npz, ref=np.array([1], dtype=np.int32))
    onnx_path = tmp_path / "model.onnx"
    onnx_path.write_bytes(b"")

    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.simulation_dt = 0.005
    evaluator.config = _Config(
        motion_npz_path=str(source_npz),
        ckpt_onnx_path=str(onnx_path),
    )

    robot_traces = {
        "_robot_dof_pos_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_vel_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_acc_seq": [np.zeros(2, dtype=np.float32)],
        "_robot_dof_torque_seq": [np.ones(2, dtype=np.float32)],
        "_robot_low_level_dof_torque_seq": [
            np.array([0.5, -0.5], dtype=np.float32)
        ],
        "_robot_action_rate_seq": [np.float32(0.0)],
        "_robot_global_translation_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_rotation_quat_seq": [
            np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
        ],
        "_robot_global_velocity_seq": [np.zeros((1, 3), dtype=np.float32)],
        "_robot_global_angular_velocity_seq": [
            np.zeros((1, 3), dtype=np.float32)
        ],
        "_robot_moe_expert_indices_seq": [
            np.array([[1, 3], [0, 2]], dtype=np.int64)
        ],
        "_robot_moe_expert_logits_seq": [
            np.array(
                [[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]],
                dtype=np.float32,
            )
        ],
    }
    for attr, value in robot_traces.items():
        setattr(evaluator, attr, value)

    evaluator._dump_robot_augmented_npz()

    dump_path = tmp_path.joinpath("mujoco_output_model", "clip_robot.npz")
    with np.load(dump_path, allow_pickle=True) as saved:
        np.testing.assert_array_equal(
            saved["robot_moe_expert_indices"],
            np.array([[[1, 3], [0, 2]]], dtype=np.int64),
        )
        np.testing.assert_allclose(
            saved["robot_moe_expert_logits"],
            np.array(
                [[[0.1, 0.2, 0.3, 0.4], [1.0, 1.1, 1.2, 1.3]]],
                dtype=np.float32,
            ),
        )
def test_ray_actor_run_clip_overwrites_existing_outputs_and_sidecar(tmp_path):
    """``run_clip`` replaces a stale ``*_eval.npz`` and writes the ONNX I/O sidecar.

    The ``onnx_io_npy`` directory already exists beforehand, and the policy
    must be updated once per motion frame (2 here).
    """

    class _FakeEvaluator:
        def __init__(self):
            self.n_motion_frames = 2
            self.calls = []
            self.counter = 0

        def load_specific_motion(self, file_path):
            self.calls.append(("load", file_path))

        def reset_state_teleport(self):
            self.calls.append(("reset",))

        def _update_policy(self):
            self.calls.append(("update",))

        def _apply_control(self, sleep=False):
            self.calls.append(("apply", sleep))

        def save_batch_result(self, output_path, meta_info):
            self.calls.append(("save_batch", output_path, meta_info))
            Path(output_path).write_text("fresh-npz", encoding="utf-8")

        def save_onnx_io_dump(self, output_path, meta_info):
            self.calls.append(("save_onnx", output_path, meta_info))
            np.save(
                output_path,
                {"source_npz": meta_info["source_file"]},
                allow_pickle=True,
            )

    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)
    actor.output_dir = str(tmp_path)
    actor.config_dict = {
        "ckpt_onnx_path": "model.onnx",
        "dump_onnx_io_npy": True,
    }
    actor.evaluator = _FakeEvaluator()

    clip = tmp_path / "demo_clip.npz"
    np.savez(clip, dummy=np.array([1], dtype=np.int32))
    stale_result = tmp_path / "demo_clip_eval.npz"
    stale_result.write_text("stale", encoding="utf-8")
    sidecar_dir = tmp_path / "onnx_io_npy"
    sidecar_dir.mkdir()

    assert actor.run_clip(str(clip)) == "success"
    assert stale_result.read_text(encoding="utf-8") == "fresh-npz"

    sidecar = sidecar_dir / "demo_clip_onnx_io.npy"
    assert sidecar.is_file()
    sidecar_payload = np.load(sidecar, allow_pickle=True).item()
    assert sidecar_payload["source_npz"] == "demo_clip.npz"

    call_log = actor.evaluator.calls
    assert ("load", str(clip)) in call_log
    assert ("reset",) in call_log
    assert call_log.count(("update",)) == 2
def test_ray_actor_skips_sidecar_for_non_default_model_type(tmp_path):
    """With ``model_type="gmt"`` no ONNX I/O sidecar is written, even if requested."""

    class _FakeEvaluator:
        def __init__(self):
            self.n_motion_frames = 1
            self.calls = []
            self.counter = 0

        def load_specific_motion(self, file_path):
            self.calls.append(("load", file_path))

        def reset_state_teleport(self):
            self.calls.append(("reset",))

        def _update_policy(self):
            self.calls.append(("update",))

        def _apply_control(self, sleep=False):
            self.calls.append(("apply", sleep))

        def save_batch_result(self, output_path, meta_info):
            self.calls.append(("save_batch", output_path, meta_info))
            Path(output_path).write_text("fresh-npz", encoding="utf-8")

        def save_onnx_io_dump(self, output_path, meta_info):
            self.calls.append(("save_onnx", output_path, meta_info))

    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)
    actor.output_dir = str(tmp_path)
    actor.config_dict = {
        "ckpt_onnx_path": "model.onnx",
        "dump_onnx_io_npy": True,
        "model_type": "gmt",
    }
    actor.evaluator = _FakeEvaluator()

    clip = tmp_path / "demo_clip.npz"
    np.savez(clip, dummy=np.array([1], dtype=np.int32))

    assert actor.run_clip(str(clip)) == "success"
    recorded_kinds = [call[0] for call in actor.evaluator.calls]
    assert "save_onnx" not in recorded_kinds
def test_ray_actor_treats_empty_model_type_as_default_holomotion(tmp_path):
    """An empty ``model_type`` behaves like the default and writes the sidecar."""

    class _FakeEvaluator:
        def __init__(self):
            self.n_motion_frames = 1
            self.calls = []
            self.counter = 0

        def load_specific_motion(self, file_path):
            self.calls.append(("load", file_path))

        def reset_state_teleport(self):
            self.calls.append(("reset",))

        def _update_policy(self):
            self.calls.append(("update",))

        def _apply_control(self, sleep=False):
            self.calls.append(("apply", sleep))

        def save_batch_result(self, output_path, meta_info):
            self.calls.append(("save_batch", output_path, meta_info))
            Path(output_path).write_text("fresh-npz", encoding="utf-8")

        def save_onnx_io_dump(self, output_path, meta_info):
            self.calls.append(("save_onnx", output_path, meta_info))
            np.save(output_path, {"source_npz": meta_info["source_file"]})

    actor = RayEvaluatorActor.__new__(RayEvaluatorActor)
    actor.output_dir = str(tmp_path)
    actor.config_dict = {
        "ckpt_onnx_path": "model.onnx",
        "dump_onnx_io_npy": True,
        "model_type": "",
    }
    actor.evaluator = _FakeEvaluator()

    clip = tmp_path / "demo_clip.npz"
    np.savez(clip, dummy=np.array([1], dtype=np.int32))

    assert actor.run_clip(str(clip)) == "success"
    recorded_kinds = [call[0] for call in actor.evaluator.calls]
    assert "save_onnx" in recorded_kinds
def test_ray_actor_init_uses_configured_evaluator_module(
    monkeypatch, tmp_path
):
    """``ray_evaluator_module`` overrides the default evaluator factory.

    The default ``_create_ray_evaluator`` in ``eval_mujoco_sim2sim`` is
    replaced by a function that fails if called. A fake module exposing its
    own ``_create_ray_evaluator`` is registered under the configured module
    name. The actor must build its evaluator through the fake factory, pass
    along the full config and ``model_type``, and call ``setup()`` on the
    result.
    """

    class _FakeEvaluator:
        def __init__(self):
            self.setup_called = False

        def setup(self):
            self.setup_called = True

    captured = {}
    fake_evaluator = _FakeEvaluator()

    def _unexpected_default_factory(*args, **kwargs):
        raise AssertionError("default evaluator factory should not be used")

    def _fake_override_factory(config_dict, model_type):
        captured["config_dict"] = config_dict
        captured["model_type"] = model_type
        return fake_evaluator

    monkeypatch.setattr(
        "holomotion.src.evaluation.eval_mujoco_sim2sim._create_ray_evaluator",
        _unexpected_default_factory,
    )
    # Register the fake module through monkeypatch so the entry is removed
    # from sys.modules at teardown. A plain assignment would leak the fake
    # module into every test that runs afterwards.
    monkeypatch.setitem(
        sys.modules,
        "holomotion.src.evaluation.fake_eval_module",
        types.SimpleNamespace(_create_ray_evaluator=_fake_override_factory),
    )
    actor = RayEvaluatorActor(
        {
            "ckpt_onnx_path": "model.onnx",
            "model_type": "holomotion",
            "ray_evaluator_module": "holomotion.src.evaluation.fake_eval_module",
        },
        str(tmp_path),
    )
    assert actor.evaluator is fake_evaluator
    assert fake_evaluator.setup_called is True
    assert captured["model_type"] == "holomotion"
    assert (
        captured["config_dict"]["ray_evaluator_module"]
        == "holomotion.src.evaluation.fake_eval_module"
    )
================================================
FILE: tests/test_evaluation_metrics.py
================================================
import csv
import json
from pathlib import Path
import numpy as np
from holomotion.src.evaluation.metrics import (
_compute_clip_stability_summary,
_per_frame_metrics_from_npz,
offline_evaluate_dumped_npzs,
)
def _make_eval_data(
robot_dof_torque: np.ndarray,
*,
robot_dof_vel: np.ndarray | None = None,
robot_dof_acc: np.ndarray | None = None,
robot_action_rate: np.ndarray | None = None,
robot_low_level_dof_torque: np.ndarray | None = None,
robot_global_angular_velocity: np.ndarray | None = None,
robot_low_level_foot_contact: np.ndarray | None = None,
robot_low_level_foot_normal_force: np.ndarray | None = None,
robot_low_level_foot_tangent_speed: np.ndarray | None = None,
robot_moe_expert_logits: np.ndarray | None = None,
):
num_frames = int(robot_dof_torque.shape[0])
num_dofs = int(robot_dof_torque.shape[1])
root = np.zeros((num_frames, 1, 3), dtype=np.float32)
child = np.tile(
np.array([[[0.0, 0.0, 1.0]]], dtype=np.float32), (num_frames, 1, 1)
)
global_translation = np.concatenate([root, child], axis=1)
global_rotation = np.tile(
np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32),
(num_frames, 2, 1),
)
zeros_dof = np.zeros((num_frames, num_dofs), dtype=np.float32)
payload = {
"ref_dof_pos": zeros_dof.copy(),
"robot_dof_pos": zeros_dof.copy(),
"ref_dof_vel": zeros_dof.copy(),
"ref_global_translation": global_translation.copy(),
"robot_global_translation": global_translation.copy(),
"ref_global_rotation_quat": global_rotation.copy(),
"robot_global_rotation_quat": global_rotation.copy(),
"ref_global_velocity": np.zeros((num_frames, 2, 3), dtype=np.float32),
"ref_global_angular_velocity": np.zeros(
(num_frames, 2, 3), dtype=np.float32
),
"robot_global_velocity": np.zeros(
(num_frames, 2, 3), dtype=np.float32
),
"robot_global_angular_velocity": (
np.zeros((num_frames, 2, 3), dtype=np.float32)
if robot_global_angular_velocity is None
else robot_global_angular_velocity.astype(np.float32)
),
"robot_dof_vel": (
zeros_dof.copy()
if robot_dof_vel is None
else robot_dof_vel.astype(np.float32)
),
"robot_dof_acc": (
zeros_dof.copy()
if robot_dof_acc is None
else robot_dof_acc.astype(np.float32)
),
"robot_dof_torque": robot_dof_torque.astype(np.float32),
"robot_action_rate": (
np.zeros((num_frames,), dtype=np.float32)
if robot_action_rate is None
else robot_action_rate.astype(np.float32)
),
}
if robot_low_level_dof_torque is not None:
payload["robot_low_level_dof_torque"] = (
robot_low_level_dof_torque.astype(np.float32)
)
if robot_low_level_foot_contact is not None:
payload["robot_low_level_foot_contact"] = (
robot_low_level_foot_contact.astype(np.float32)
)
if robot_low_level_foot_normal_force is not None:
payload["robot_low_level_foot_normal_force"] = (
robot_low_level_foot_normal_force.astype(np.float32)
)
if robot_low_level_foot_tangent_speed is not None:
payload["robot_low_level_foot_tangent_speed"] = (
robot_low_level_foot_tangent_speed.astype(np.float32)
)
if robot_moe_expert_logits is not None:
payload["robot_moe_expert_logits"] = robot_moe_expert_logits.astype(
np.float32
)
return payload
def test_per_frame_metrics_include_torque_jump_diagnostics():
    """Per-frame torque-jump columns are NaN first, zero for steady torque.

    A constant torque yields all-zero jump columns once NaN is mapped to 0.
    Flipping DoF 0 from +1 to -1 at frame 2 (dt = 0.5) must give a jump norm
    above 3.9 and a jump ratio above 1.9 at that frame.
    """
    jump_columns = ("mean_torque_jump_norm", "mean_torque_jump_ratio")

    steady_df = _per_frame_metrics_from_npz(
        motion_key="constant",
        data=_make_eval_data(np.ones((4, 2), dtype=np.float32)),
        robot_control_dt=0.5,
    )
    for column in jump_columns:
        assert column in steady_df.columns
    for column in jump_columns:
        # Row 0 is expected to be NaN (presumably: no previous frame).
        assert np.isnan(steady_df[column].iloc[0])
    for column in jump_columns:
        np.testing.assert_allclose(
            np.nan_to_num(steady_df[column].to_numpy()),
            np.zeros(4, dtype=np.float64),
        )

    flip_torque = np.zeros((4, 2), dtype=np.float32)
    flip_torque[:2, 0] = 1.0
    flip_torque[2:, 0] = -1.0
    flip_df = _per_frame_metrics_from_npz(
        motion_key="jump",
        data=_make_eval_data(flip_torque),
        robot_control_dt=0.5,
    )
    assert flip_df["mean_torque_jump_norm"].iloc[2] > 3.9
    assert flip_df["mean_torque_jump_ratio"].iloc[2] > 1.9
def test_offline_evaluate_dumped_npzs_exports_torque_jump_summary_metrics(
    tmp_path: Path,
):
    """Offline evaluation exports torque-jump stats to JSON and CSV.

    The policy-rate torque is constant. The low-level torque (metadata
    ``robot_low_level_torque_dt`` = 0.005) is steady for 8 samples and then
    flips sign every sample, so the p95 jump norm and ratio must be large.
    """
    eval_dir = tmp_path / "eval"
    eval_dir.mkdir()

    policy_torque = np.zeros((4, 2), dtype=np.float32)
    policy_torque[:, 0] = 1.0
    # 8 steady samples of +1, then 8 samples alternating -1, +1.
    dof0 = np.array([1.0] * 8 + [-1.0, 1.0] * 4, dtype=np.float32)
    low_level_torque = np.stack([dof0, np.zeros_like(dof0)], axis=1)

    payload = _make_eval_data(
        policy_torque,
        robot_low_level_dof_torque=low_level_torque,
    )
    payload["metadata"] = np.array(
        json.dumps({"clip_length": 4, "robot_low_level_torque_dt": 0.005}),
        dtype=np.str_,
    )
    np.savez_compressed(eval_dir / "demo_clip.npz", **payload)

    summary_path = eval_dir / "summary.json"
    result = offline_evaluate_dumped_npzs(
        npz_dir=str(eval_dir),
        output_json_path=str(summary_path),
    )

    per_clip = result["per_clip"][0]
    dataset_mean = result["dataset"]["mean"]
    summary_keys = (
        "mean_torque_jump_norm",
        "p95_torque_jump_norm",
        "mean_torque_jump_ratio",
        "p95_torque_jump_ratio",
    )
    for key in summary_keys:
        assert key in per_clip
        assert key in dataset_mean
    assert per_clip["mean_dof_torque"] == 1.0
    assert per_clip["p95_torque_jump_norm"] > 300.0
    assert per_clip["p95_torque_jump_ratio"] > 1.0

    written = json.loads(summary_path.read_text(encoding="utf-8"))
    assert "p95_torque_jump_ratio" in written["dataset"]["mean"]

    with (eval_dir / "per_clip_metrics.csv").open(
        "r", encoding="utf-8", newline=""
    ) as handle:
        first_row = next(csv.DictReader(handle))
    assert "p95_torque_jump_ratio" in first_row
    assert "mean_torque_jump_norm" in first_row
def test_compute_clip_stability_summary_detects_chatter_and_support_events():
# Builds a "smooth" and an "unstable" synthetic clip and checks that every
# stability metric from _compute_clip_stability_summary is strictly larger
# for the unstable clip.
num_frames = 50
num_low_level = 200
policy_dt = 0.02
low_level_dt = 0.005
# Both time grids span the same 1 s: 50 * 0.02 s and 200 * 0.005 s.
t_policy = np.arange(num_frames, dtype=np.float32) * policy_dt
t_low = np.arange(num_low_level, dtype=np.float32) * low_level_dt
# Angular velocity of shape (frames, 2, 3). The smooth clip has a gentle
# 1 Hz wave on [:, 0, 0]. The unstable copy adds 8 Hz on [:, 0, 0] and
# 6 Hz on [:, 0, 1]. Presumably these are the torso roll/pitch components;
# TODO confirm the body/axis mapping against _compute_clip_stability_summary.
smooth_ang_vel = np.zeros((num_frames, 2, 3), dtype=np.float32)
smooth_ang_vel[:, 0, 0] = 0.2 * np.sin(2.0 * np.pi * 1.0 * t_policy)
unstable_ang_vel = smooth_ang_vel.copy()
unstable_ang_vel[:, 0, 0] += 0.7 * np.sin(
2.0 * np.pi * 8.0 * t_policy
).astype(np.float32)
unstable_ang_vel[:, 0, 1] += 0.4 * np.sin(
2.0 * np.pi * 6.0 * t_policy
).astype(np.float32)
# Low-level torque on DoF 0 is a 1 Hz sine. The unstable copy adds a 15 Hz
# ripple plus two step bursts: +2.5 on samples 80-84 and -2.5 on 120-122.
smooth_low_level_torque = np.zeros((num_low_level, 2), dtype=np.float32)
smooth_low_level_torque[:, 0] = np.sin(2.0 * np.pi * 1.0 * t_low)
unstable_low_level_torque = smooth_low_level_torque.copy()
unstable_low_level_torque[:, 0] += 0.8 * np.sin(
2.0 * np.pi * 15.0 * t_low
).astype(np.float32)
unstable_low_level_torque[80:85, 0] += 2.5
unstable_low_level_torque[120:123, 0] -= 2.5
# Stable gait: one contact handover, from foot 0 (first 100 samples) to
# foot 1. Normal force is steady (80 / 75) and tangential speed is tiny.
stable_contact = np.zeros((num_low_level, 2), dtype=np.float32)
stable_contact[:100, 0] = 1.0
stable_contact[100:, 1] = 1.0
stable_normal_force = stable_contact * np.array(
[[80.0, 75.0]], dtype=np.float32
)
stable_tangent_speed = stable_contact * 0.01
# Unstable gait: contact alternates between the two feet every 5 samples
# (25 ms), which produces many contact toggles.
unstable_contact = np.zeros((num_low_level, 2), dtype=np.float32)
for start in range(0, num_low_level, 10):
unstable_contact[start : start + 5, 0] = 1.0
unstable_contact[start + 5 : start + 10, 1] = 1.0
unstable_normal_force = unstable_contact * 60.0
# Touchdown mask = rising edges of contact. Row 0 keeps the initial contact
# state, so contact at the very first sample also counts as a touchdown.
# A force spike of 120 is added at each touchdown to emulate impacts.
touchdown_mask = unstable_contact.copy()
touchdown_mask[1:] = np.clip(
unstable_contact[1:] - unstable_contact[:-1], a_min=0.0, a_max=None
)
unstable_normal_force += touchdown_mask * 120.0
# In-contact slip speed is 25x the stable clip's (0.25 vs 0.01).
unstable_tangent_speed = unstable_contact * 0.25
# The policy-rate torque is all zeros in both clips. The torque metrics are
# presumably driven by the low-level torque; TODO confirm.
smooth_metrics = _compute_clip_stability_summary(
data=_make_eval_data(
np.zeros((num_frames, 2), dtype=np.float32),
robot_low_level_dof_torque=smooth_low_level_torque,
robot_global_angular_velocity=smooth_ang_vel,
robot_low_level_foot_contact=stable_contact,
robot_low_level_foot_normal_force=stable_normal_force,
robot_low_level_foot_tangent_speed=stable_tangent_speed,
),
robot_control_dt=policy_dt,
low_level_contact_dt=low_level_dt,
)
unstable_metrics = _compute_clip_stability_summary(
data=_make_eval_data(
np.zeros((num_frames, 2), dtype=np.float32),
robot_low_level_dof_torque=unstable_low_level_torque,
robot_global_angular_velocity=unstable_ang_vel,
robot_low_level_foot_contact=unstable_contact,
robot_low_level_foot_normal_force=unstable_normal_force,
robot_low_level_foot_tangent_speed=unstable_tangent_speed,
),
robot_control_dt=policy_dt,
low_level_contact_dt=low_level_dt,
)
# Each stability metric must be strictly larger for the unstable clip.
assert (
unstable_metrics["torque_chatter_hf_ratio"]
> smooth_metrics["torque_chatter_hf_ratio"]
)
assert (
unstable_metrics["torque_jump_burst_max"]
> smooth_metrics["torque_jump_burst_max"]
)
assert (
unstable_metrics["torso_rp_hf_ratio"]
> smooth_metrics["torso_rp_hf_ratio"]
)
assert (
unstable_metrics["torso_rp_angacc_p95"]
> smooth_metrics["torso_rp_angacc_p95"]
)
assert (
unstable_metrics["foot_contact_toggle_rate"]
> smooth_metrics["foot_contact_toggle_rate"]
)
assert (
unstable_metrics["foot_impact_force_p95"]
> smooth_metrics["foot_impact_force_p95"]
)
assert (
unstable_metrics["stance_slip_speed_p95"]
> smooth_metrics["stance_slip_speed_p95"]
)
def test_compute_clip_stability_summary_reports_expert_switching_js_div():
    """Expert-switching JS divergence is ~0 for steady routing, larger when it flips.

    Logits have shape ``(frames, 2, 3)``. Each of the two rows per frame
    strongly prefers one of three experts. In the switching clip, every odd
    frame moves each row's preference to a different expert.
    """
    num_frames = 8
    favored, other = 8.0, -4.0

    def _preferring(expert):
        # Logit row with one dominant expert and the rest strongly suppressed.
        row = np.full(3, other, dtype=np.float32)
        row[expert] = favored
        return row

    steady_frame = np.stack([_preferring(0), _preferring(1)])
    stable_logits = np.repeat(steady_frame[None], num_frames, axis=0)
    switching_logits = stable_logits.copy()
    switching_logits[1::2, 0, :] = _preferring(1)
    switching_logits[1::2, 1, :] = _preferring(2)

    def _summarize(logits):
        return _compute_clip_stability_summary(
            data=_make_eval_data(
                np.zeros((num_frames, 2), dtype=np.float32),
                robot_moe_expert_logits=logits,
            ),
            robot_control_dt=0.02,
            low_level_contact_dt=0.02,
        )

    stable_js = _summarize(stable_logits)["expert_switching_js_div"]
    switching_js = _summarize(switching_logits)["expert_switching_js_div"]
    assert stable_js < 1e-6
    assert switching_js > stable_js
def test_offline_evaluate_dumped_npzs_reports_nan_contact_metrics_for_legacy_npz(
    tmp_path: Path,
):
    """Legacy NPZs without contact/MoE signals still report the new keys.

    The dump has no low-level foot or expert-logit arrays. Every stability
    key must still appear per clip and in the dataset mean. The contact and
    expert metrics must be NaN rather than missing.
    """
    eval_dir = tmp_path / "legacy_eval"
    eval_dir.mkdir()

    legacy_payload = _make_eval_data(np.ones((8, 2), dtype=np.float32))
    legacy_payload["metadata"] = np.array(
        json.dumps({"clip_length": 8, "robot_low_level_torque_dt": 0.005}),
        dtype=np.str_,
    )
    np.savez_compressed(eval_dir / "legacy_clip.npz", **legacy_payload)

    result = offline_evaluate_dumped_npzs(
        npz_dir=str(eval_dir),
        output_json_path=str(eval_dir / "summary.json"),
    )

    per_clip = result["per_clip"][0]
    dataset_mean = result["dataset"]["mean"]
    stability_keys = (
        "torque_chatter_hf_ratio",
        "torque_jump_burst_max",
        "torso_rp_hf_ratio",
        "torso_rp_angacc_p95",
        "foot_contact_toggle_rate",
        "foot_impact_force_p95",
        "stance_slip_speed_p95",
        "expert_switching_js_div",
    )
    for key in stability_keys:
        assert key in per_clip
        assert key in dataset_mean

    nan_expected = (
        "foot_contact_toggle_rate",
        "foot_impact_force_p95",
        "stance_slip_speed_p95",
        "expert_switching_js_div",
    )
    for key in nan_expected:
        assert np.isnan(per_clip[key])
================================================
FILE: tests/test_isaaclab_termination.py
================================================
import importlib.util
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import pytest
import torch
# Source file loaded directly by the tests below, located relative to the
# repository root (one directory above this test file's directory).
MODULE_PATH = Path(__file__).resolve().parents[1].joinpath(
    "holomotion",
    "src",
    "env",
    "isaaclab_components",
    "isaaclab_termination.py",
)
# Dotted names of the sibling component modules that get stubbed out.
MOTION_COMMAND_MODULE_NAME = (
    "holomotion.src.env.isaaclab_components.isaaclab_motion_tracking_command"
)
ISAACLAB_UTILS_MODULE_NAME = (
    "holomotion.src.env.isaaclab_components.isaaclab_utils"
)
class _Scene(SimpleNamespace):
def __getitem__(self, key):
return getattr(self, key)
def _load_isaaclab_termination_module(module_name: str):
    """Import ``isaaclab_termination.py`` in isolation with stubbed deps.

    The ``isaaclab`` packages and the two sibling ``holomotion`` component
    modules listed in ``fake_modules`` are replaced by lightweight stand-ins
    in ``sys.modules``, only while the module executes. Afterwards the
    previous ``sys.modules`` state is restored exactly:
    - keys that were absent are removed again;
    - keys that existed, including ones mapped to ``None``, get their old
      value back.

    Args:
        module_name: ``__name__`` given to the freshly executed module.

    Returns:
        The executed module object. It is not registered in ``sys.modules``.
    """
    isaaclab_module = types.ModuleType("isaaclab")
    isaaclab_envs = types.ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedRLEnv = object
    # Native termination functions; each returns a 1-element False tensor.
    isaaclab_terminations = types.SimpleNamespace(
        time_out=lambda env: torch.zeros(1, dtype=torch.bool),
        bad_orientation=lambda env, limit_angle: torch.zeros(
            1, dtype=torch.bool
        ),
        root_height_below_minimum=lambda env, minimum_height: torch.zeros(
            1, dtype=torch.bool
        ),
        native_only_term=lambda env, margin: torch.zeros(1, dtype=torch.bool),
    )
    isaaclab_envs_mdp = types.ModuleType("isaaclab.envs.mdp")
    isaaclab_envs_mdp.terminations = isaaclab_terminations
    isaaclab_managers = types.ModuleType("isaaclab.managers")

    class _TerminationTermCfg:
        # Minimal record of the fields the tests inspect.
        def __init__(self, func, params=None, time_out=False):
            self.func = func
            self.params = {} if params is None else params
            self.time_out = time_out

    isaaclab_managers.TerminationTermCfg = _TerminationTermCfg
    isaaclab_managers.SceneEntityCfg = object
    isaaclab_utils = types.ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = lambda cls: cls
    isaaclab_utils_math = types.ModuleType("isaaclab.utils.math")
    isaaclab_utils_math.quat_apply_inverse = (
        lambda quat, vec: torch.zeros_like(vec)
    )
    isaaclab_utils.math = isaaclab_utils_math
    isaaclab_assets = types.ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_components_package = types.ModuleType(
        "holomotion.src.env.isaaclab_components"
    )
    motion_command_module = types.ModuleType(MOTION_COMMAND_MODULE_NAME)
    motion_command_module.RefMotionCommand = object
    isaaclab_utils_module = types.ModuleType(ISAACLAB_UTILS_MODULE_NAME)
    isaaclab_utils_module._get_body_indices = lambda robot, keybody_names: None
    isaaclab_utils_module.resolve_holo_config = lambda cfg: cfg
    isaaclab_components_package.isaaclab_motion_tracking_command = (
        motion_command_module
    )
    isaaclab_components_package.isaaclab_utils = isaaclab_utils_module
    fake_modules = {
        "isaaclab": isaaclab_module,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_envs_mdp,
        "isaaclab.managers": isaaclab_managers,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.math": isaaclab_utils_math,
        "isaaclab.assets": isaaclab_assets,
        "holomotion.src.env.isaaclab_components": isaaclab_components_package,
        MOTION_COMMAND_MODULE_NAME: motion_command_module,
        ISAACLAB_UTILS_MODULE_NAME: isaaclab_utils_module,
    }
    # Sentinel distinguishes "key absent" from "key present with value None";
    # sys.modules.get(name) alone would conflate the two on restore.
    absent = object()
    saved_modules = {
        name: sys.modules.get(name, absent) for name in fake_modules
    }
    sys.modules.update(fake_modules)
    try:
        spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
        # Validate the spec before using it; otherwise a bad path surfaces
        # as an opaque AttributeError inside module_from_spec.
        assert spec is not None and spec.loader is not None, (
            f"cannot create an import spec for {MODULE_PATH}"
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    finally:
        for name, previous in saved_modules.items():
            if previous is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def test_wholebody_mpjpe_far_flags_envs_above_mean_error_threshold():
    """Only the env whose joint error is large enough is flagged.

    Against an all-zero reference, env 0 has joint errors (0, 0.2, 0.6) and
    env 1 has (0, 0.1, 0.2). With threshold 0.2 the expected flags are
    ``[True, False]`` as a bool tensor.
    """
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test"
    )
    joint_pos = torch.tensor(
        [
            [0.0, 0.2, 0.6],
            [0.0, 0.1, 0.2],
        ]
    )
    reference = torch.zeros_like(joint_pos)

    def _get_term(name):
        return command

    command = SimpleNamespace(
        robot=SimpleNamespace(data=SimpleNamespace(joint_pos=joint_pos)),
        get_ref_motion_dof_pos_cur=lambda prefix="ref_": reference,
        get_ref_motion_dof_pos_immediate_next=lambda prefix="ref_": reference,
    )
    env = SimpleNamespace(command_manager=SimpleNamespace(get_term=_get_term))

    flags = module.wholebody_mpjpe_far(env, threshold=0.2)

    expected = torch.tensor([True, False])
    assert flags.dtype == torch.bool
    assert torch.equal(flags, expected)
def test_wholebody_mpjpe_far_uses_immediate_next_reference():
    """The DoF error is measured against the immediate-next reference.

    The current-frame accessor raises if called. The next-frame accessor
    returns the robot's own pose, so no env may be flagged.
    """
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_next_dof"
    )
    joint_pos = torch.tensor([[0.0, 0.1, 0.2]])

    def _current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def _get_term(name):
        return command

    command = SimpleNamespace(
        robot=SimpleNamespace(data=SimpleNamespace(joint_pos=joint_pos)),
        get_ref_motion_dof_pos_cur=_current_reference_forbidden,
        get_ref_motion_dof_pos_immediate_next=lambda prefix="ref_": joint_pos,
    )
    env = SimpleNamespace(command_manager=SimpleNamespace(get_term=_get_term))

    flags = module.wholebody_mpjpe_far(env, threshold=0.05)
    assert torch.equal(flags, torch.tensor([False]))
def test_keybody_ref_pos_far_uses_immediate_next_reference():
    """Key-body position error uses the immediate-next reference frame.

    The current-frame accessor raises if called. The next-frame accessor
    echoes the robot's body positions, so the "target" key body is on
    reference and the env is not flagged.
    """
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_next_keybody"
    )
    body_pos = torch.tensor(
        [[[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]], dtype=torch.float32
    )

    def _current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def _get_term(name):
        return command

    robot = SimpleNamespace(
        body_names=["anchor", "target"],
        data=SimpleNamespace(body_pos_w=body_pos),
    )
    command = SimpleNamespace(
        robot=robot,
        get_ref_motion_bodylink_global_pos_cur=_current_reference_forbidden,
        get_ref_motion_bodylink_global_pos_immediate_next=(
            lambda prefix="ref_": body_pos
        ),
    )
    env = SimpleNamespace(command_manager=SimpleNamespace(get_term=_get_term))

    flags = module.keybody_ref_pos_far(
        env,
        threshold=0.1,
        keybody_names=["target"],
    )
    assert torch.equal(flags, torch.tensor([False]))
def test_ref_gravity_projection_far_uses_immediate_next_reference():
    """Gravity-projection check reads the immediate-next anchor rotation.

    The current-frame accessor raises if called. The next-frame accessor
    returns the robot's own anchor quaternion, so the env is not flagged.
    """
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_next_gravity"
    )
    gravity = torch.tensor([[0.0, 0.0, -1.0]], dtype=torch.float32)
    anchor_quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)

    def _current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def _get_term(name):
        return command

    robot = SimpleNamespace(
        data=SimpleNamespace(
            GRAVITY_VEC_W=gravity,
            body_quat_w=anchor_quat[:, None, :],
        )
    )
    command = SimpleNamespace(
        robot=robot,
        anchor_bodylink_idx=0,
        get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=(
            _current_reference_forbidden
        ),
        get_ref_motion_anchor_bodylink_global_rot_wxyz_immediate_next=(
            lambda prefix="ref_": anchor_quat
        ),
    )
    env = SimpleNamespace(
        scene=_Scene(robot=robot),
        command_manager=SimpleNamespace(get_term=_get_term),
    )

    flags = module.ref_gravity_projection_far(env, threshold=0.1)
    assert torch.equal(flags, torch.tensor([False]))
def test_build_terminations_config_registers_wholebody_mpjpe_far():
    """A module-local term name resolves to the module's own function."""
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_for_cfg"
    )
    term_cfg = module.build_terminations_config(
        {"wholebody_mpjpe_far": {"params": {"threshold": 0.3}}}
    ).wholebody_mpjpe_far

    assert term_cfg.func is module.wholebody_mpjpe_far
    assert term_cfg.params == {"threshold": 0.3}
    assert term_cfg.time_out is False
def test_build_terminations_config_resolves_native_isaaclab_termination():
    """A name found only in ``isaaclab_mdp.terminations`` resolves there."""
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_for_native_cfg"
    )
    term_cfg = module.build_terminations_config(
        {"native_only_term": {"params": {"margin": 0.3}}}
    ).native_only_term

    native_func = module.isaaclab_mdp.terminations.native_only_term
    assert term_cfg.func is native_func
    assert term_cfg.params == {"margin": 0.3}
    assert term_cfg.time_out is False
def test_build_terminations_config_raises_on_unknown_termination():
    """An unresolvable term name is rejected with a ValueError."""
    module = _load_isaaclab_termination_module(
        "isaaclab_termination_under_test_for_unknown_cfg"
    )
    unknown_cfg = {"missing_term": {}}
    with pytest.raises(ValueError, match="Unknown termination function"):
        module.build_terminations_config(unknown_cfg)
================================================
FILE: tests/test_mean_process_5metrics.py
================================================
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.scripts.evaluation import mean_process_5metrics
# Per-clip metric columns that predate the torque-jump diagnostics, in the
# order process_data is expected to emit them after "Method" and "Dataset".
LEGACY_METRICS = [
"mpjpe_g",
"mpjpe_l",
"whole_body_joints_dist",
"root_vel_error",
"root_r_error",
"root_p_error",
"root_y_error",
"root_height_error",
"mean_dof_vel",
"mean_dof_acc",
"mean_dof_torque",
"mean_action_rate",
"success",
]
# Torque-jump summary columns, expected to be appended after LEGACY_METRICS.
TORQUE_JUMP_METRICS = [
"mean_torque_jump_norm",
"p95_torque_jump_norm",
"mean_torque_jump_ratio",
"p95_torque_jump_ratio",
]
def test_macro_report_appends_torque_jump_columns_to_legacy_tables(
    tmp_path, monkeypatch
):
    """Torque-jump metrics come after the legacy columns in every output.

    Checks, in order:
    - the column order of the mean DataFrame from ``process_data``;
    - the header row of the TSV written next to the JSON inputs;
    - the headers handed to ``tabulate``;
    - that the text report mentions the torque-jump columns after
      "Success Rate".
    """
    clip_row = {
        "motion_key": "clips_AMASS_demo",
        "mpjpe_g": 1.0,
        "mpjpe_l": 2.0,
        "whole_body_joints_dist": 3.0,
        "root_vel_error": 4.0,
        "root_r_error": 5.0,
        "root_p_error": 6.0,
        "root_y_error": 7.0,
        "root_height_error": 8.0,
        "mean_dof_vel": 9.0,
        "mean_dof_acc": 10.0,
        "mean_dof_torque": 11.0,
        "mean_action_rate": 12.0,
        "success": 1.0,
        "mean_torque_jump_norm": 13.0,
        "p95_torque_jump_norm": 14.0,
        "mean_torque_jump_ratio": 15.0,
        "p95_torque_jump_ratio": 16.0,
    }
    (tmp_path / "model_a.json").write_text(
        json.dumps({"per_clip": [clip_row]}), encoding="utf-8"
    )

    mean_df, _ = mean_process_5metrics.process_data(str(tmp_path))
    expected_columns = [
        "Method",
        "Dataset",
        *LEGACY_METRICS,
        *TORQUE_JUMP_METRICS,
    ]
    assert mean_df.columns.tolist() == expected_columns

    # Capture the headers passed to tabulate instead of rendering a table.
    tabulate_headers = []

    def _fake_tabulate(_rows, headers, **_kwargs):
        tabulate_headers.append(headers)
        return "fake-table"

    monkeypatch.setattr(mean_process_5metrics, "tabulate", _fake_tabulate)
    report_path = (
        mean_process_5metrics.generate_macro_mean_report_from_json_dir(
            str(tmp_path)
        )
    )

    tsv_text = (tmp_path / "sub_dataset_macro_mean_metrics.tsv").read_text(
        encoding="utf-8"
    )
    header = tsv_text.splitlines()[0].split("\t")
    assert header == [
        "Dataset",
        "Global Bodylink Pos Err",
        "Local Bodylink Pos Err",
        "Dof Position Err",
        "Root Vel Err",
        "Root Roll Err",
        "Root Pitch Err",
        "Root Yaw Err",
        "Root Height Err",
        "Mean Dof Vel",
        "Mean Dof Acc",
        "Mean Dof Torque",
        "Mean Action Rate",
        "Success Rate",
        "Mean Torque Jump Norm",
        "P95 Torque Jump Norm",
        "Mean Torque Jump Ratio",
        "P95 Torque Jump Ratio",
    ]
    assert tabulate_headers[-1] == header

    report_text = Path(report_path).read_text(encoding="utf-8")
    success_pos = report_text.index("Success Rate")
    torque_pos = report_text.index("Mean Torque Jump Norm")
    assert torque_pos > success_pos
================================================
FILE: tests/test_motion_cache_gather_state.py
================================================
import sys
import unittest
from unittest import mock
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.training.h5_dataloader import (
ClipBatch,
Hdf5RootDofDataset,
MotionClipBatchCache,
MotionWindow,
_CpuFKTransform,
_normalize_online_filter_cfg,
build_motion_datasets_from_cfg,
)
def _expected_field(
tensor: torch.Tensor,
clip_indices: torch.Tensor,
frame_indices: torch.Tensor,
n_future_frames: int,
lengths: torch.Tensor,
) -> torch.Tensor:
temporal_span = 1 + int(n_future_frames)
time_offsets = torch.arange(temporal_span, dtype=torch.long)
gather_timesteps = frame_indices[:, None] + time_offsets[None, :]
max_valid = torch.clamp(lengths.index_select(0, clip_indices) - 1, min=0)
gather_timesteps = torch.minimum(gather_timesteps, max_valid[:, None])
return tensor[clip_indices[:, None], gather_timesteps]
class MotionCacheGatherStateTests(unittest.TestCase):
def test_normalize_online_filter_cfg_includes_velocity_smoothing_sigmas(
    self,
):
    """Both sigma keys default to 2.0; explicit values beat the fallback."""
    sigma_keys = ("ref_vel_smoothing_sigma", "ft_ref_vel_smoothing_sigma")

    defaults = _normalize_online_filter_cfg({})
    for key in sigma_keys:
        self.assertEqual(defaults[key], 2.0)

    user_cfg = {
        "enabled": True,
        "butter_cutoff_hz_pool": [3.0],
        "ref_vel_smoothing_sigma": 0.0,
        "ft_ref_vel_smoothing_sigma": 2.0,
    }
    explicit = _normalize_online_filter_cfg(
        user_cfg,
        default_vel_smoothing_sigma=0.5,
    )
    expected = {
        "ref_vel_smoothing_sigma": 0.0,
        "ft_ref_vel_smoothing_sigma": 2.0,
    }
    for key, value in expected.items():
        self.assertEqual(explicit[key], value)
def test_normalize_online_filter_cfg_uses_fk_sigma_fallback_defaults(self):
    """Without explicit sigmas, both keys inherit the supplied fallback."""
    normalized = _normalize_online_filter_cfg(
        {}, default_vel_smoothing_sigma=0.5
    )
    for key in ("ref_vel_smoothing_sigma", "ft_ref_vel_smoothing_sigma"):
        self.assertEqual(normalized[key], 0.5)
def test_build_motion_datasets_from_cfg_passes_fk_sigma_fallback(self):
    """An explicit ``fk_vel_smoothing_sigma`` reaches the dataset unchanged."""
    motion_cfg = {
        "backend": "hdf5_v2",
        "hdf5_root": "/tmp/train",
        "fk_robot_file_path": "robot.xml",
        "fk_vel_smoothing_sigma": 0.5,
        "cache": {"allowed_prefixes": ["ref_", "ft_ref_"]},
        "online_filter": {"enabled": False},
    }
    # Patch out sampling preview and dataset construction; only the
    # constructor kwargs matter here.
    with mock.patch(
        "holomotion.src.training.h5_dataloader.preview_sampling_from_cfg"
    ):
        with mock.patch(
            "holomotion.src.training.h5_dataloader.Hdf5RootDofDataset"
        ) as dataset_cls:
            build_motion_datasets_from_cfg(
                motion_cfg,
                max_frame_length=16,
                min_window_length=4,
            )
            self.assertEqual(dataset_cls.call_count, 1)
            forwarded = dataset_cls.call_args.kwargs["fk_vel_smoothing_sigma"]
            self.assertEqual(forwarded, 0.5)
def test_build_motion_datasets_from_cfg_defaults_fk_sigma_fallback(self):
    """Without ``fk_vel_smoothing_sigma`` the dataset receives 2.0."""
    motion_cfg = {
        "backend": "hdf5_v2",
        "hdf5_root": "/tmp/train",
        "fk_robot_file_path": "robot.xml",
        "cache": {"allowed_prefixes": ["ref_", "ft_ref_"]},
        "online_filter": {"enabled": False},
    }
    # Patch out sampling preview and dataset construction; only the
    # constructor kwargs matter here.
    with mock.patch(
        "holomotion.src.training.h5_dataloader.preview_sampling_from_cfg"
    ):
        with mock.patch(
            "holomotion.src.training.h5_dataloader.Hdf5RootDofDataset"
        ) as dataset_cls:
            build_motion_datasets_from_cfg(
                motion_cfg,
                max_frame_length=16,
                min_window_length=4,
            )
            self.assertEqual(dataset_cls.call_count, 1)
            forwarded = dataset_cls.call_args.kwargs["fk_vel_smoothing_sigma"]
            self.assertEqual(forwarded, 2.0)
def test_gather_tensor_returns_expected_values(self):
    """``gather_tensor`` matches the reference gather for 3-D and 4-D fields.

    Clip 1 has only 4 valid frames, so its future-frame reads get clamped.
    """
    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)
    fields = {
        "ref_dof_pos": torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(
            2, 6, 3
        ),
        "ref_rg_pos": torch.arange(
            2 * 6 * 2 * 3, dtype=torch.float32
        ).reshape(2, 6, 2, 3),
    }
    lengths = torch.tensor([6, 4], dtype=torch.long)
    cache._current_batch = ClipBatch(
        tensors=dict(fields),
        lengths=lengths,
        motion_keys=["clip-a", "clip-b"],
        raw_motion_keys=["clip-a", "clip-b"],
        window_indices=torch.tensor([10, 11], dtype=torch.long),
        max_frame_length=6,
    )
    clip_indices = torch.tensor([1, 0, 1, 1], dtype=torch.long)
    frame_indices = torch.tensor([0, 2, 3, 1], dtype=torch.long)
    expected_shapes = {
        "ref_dof_pos": (4, 3, 3),
        "ref_rg_pos": (4, 3, 2, 3),
    }

    for name, source in fields.items():
        gathered = cache.gather_tensor(
            name,
            clip_indices=clip_indices,
            frame_indices=frame_indices,
            n_future_frames=2,
        )
        reference = _expected_field(
            source,
            clip_indices,
            frame_indices,
            n_future_frames=2,
            lengths=lengths,
        )
        torch.testing.assert_close(gathered, reference)
        self.assertEqual(tuple(gathered.shape), expected_shapes[name])
def test_gather_tensor_reflects_updated_indices_without_cached_state(self):
    """Consecutive gathers with different indices each return fresh data.

    Both gathers run before any comparison, so a result cached by the first
    call and wrongly reused by the second would be caught.
    """
    cache = MotionClipBatchCache.__new__(MotionClipBatchCache)
    ref_dof_pos = torch.arange(3 * 6 * 3, dtype=torch.float32).reshape(
        3, 6, 3
    )
    lengths = torch.tensor([6, 5, 4], dtype=torch.long)
    clip_names = ["clip-a", "clip-b", "clip-c"]
    cache._current_batch = ClipBatch(
        tensors={"ref_dof_pos": ref_dof_pos},
        lengths=lengths,
        motion_keys=list(clip_names),
        raw_motion_keys=list(clip_names),
        window_indices=torch.tensor([10, 11, 12], dtype=torch.long),
        max_frame_length=6,
    )
    index_sets = [
        (
            torch.tensor([0, 1, 2, 1], dtype=torch.long),
            torch.tensor([0, 1, 2, 0], dtype=torch.long),
        ),
        (
            torch.tensor([0, 2, 1, 0], dtype=torch.long),
            torch.tensor([1, 0, 3, 2], dtype=torch.long),
        ),
    ]

    results = [
        cache.gather_tensor(
            "ref_dof_pos",
            clip_indices=clips,
            frame_indices=frames,
            n_future_frames=2,
        )
        for clips, frames in index_sets
    ]

    for (clips, frames), actual in zip(index_sets, results):
        reference = _expected_field(
            ref_dof_pos,
            clips,
            frames,
            n_future_frames=2,
            lengths=lengths,
        )
        torch.testing.assert_close(actual, reference)
def test_cpu_fk_transform_forwards_explicit_vel_smoothing_sigma(self):
    """An explicit ``vel_smoothing_sigma`` is passed through to the FK call."""
    transform = _CpuFKTransform.__new__(_CpuFKTransform)
    fk_output_shapes = {
        "global_translation": (1, 4, 2, 3),
        "global_rotation_quat": (1, 4, 2, 4),
        "global_velocity": (1, 4, 2, 3),
        "global_angular_velocity": (1, 4, 2, 3),
        "dof_vel": (1, 4, 2),
    }
    # Stub FK returns zero tensors so only the call arguments are checked.
    transform._fk = mock.Mock(
        return_value={
            key: torch.zeros(shape) for key, shape in fk_output_shapes.items()
        }
    )
    input_shapes = (
        ("ref_root_pos", (4, 3)),
        ("ref_root_rot", (4, 4)),
        ("ref_dof_pos", (4, 2)),
    )
    arrays = {name: torch.zeros(shape) for name, shape in input_shapes}

    transform(
        arrays,
        fps=60.0,
        prefix="ref_",
        vel_smoothing_sigma=0.0,
    )

    forwarded = transform._fk.call_args.kwargs["vel_smoothing_sigma"]
    self.assertEqual(forwarded, 0.0)
def test_cpu_fk_transform_defaults_vel_smoothing_sigma_to_two(self):
    """Without an explicit sigma, the FK call receives 2.0."""
    transform = _CpuFKTransform.__new__(_CpuFKTransform)
    fk_output_shapes = {
        "global_translation": (1, 4, 2, 3),
        "global_rotation_quat": (1, 4, 2, 4),
        "global_velocity": (1, 4, 2, 3),
        "global_angular_velocity": (1, 4, 2, 3),
        "dof_vel": (1, 4, 2),
    }
    # Stub FK returns zero tensors so only the call arguments are checked.
    transform._fk = mock.Mock(
        return_value={
            key: torch.zeros(shape) for key, shape in fk_output_shapes.items()
        }
    )
    input_shapes = (
        ("ref_root_pos", (4, 3)),
        ("ref_root_rot", (4, 4)),
        ("ref_dof_pos", (4, 2)),
    )
    arrays = {name: torch.zeros(shape) for name, shape in input_shapes}

    transform(arrays, fps=60.0)

    forwarded = transform._fk.call_args.kwargs["vel_smoothing_sigma"]
    self.assertEqual(forwarded, 2.0)
def test_hdf5_v2_sample_exposes_zero_cutoff_metadata_when_disabled(self):
    """With the online filter off, every frame reports a 0 Hz cutoff."""
    sample = self._make_stub_root_dof_dataset()[0]
    tensors = sample.tensors
    self.assertIn("filter_cutoff_hz", tensors)
    expected_cutoff = torch.zeros(4, 1, dtype=torch.float32)
    torch.testing.assert_close(tensors["filter_cutoff_hz"], expected_cutoff)
def test_hdf5_v2_sample_exposes_sampled_cutoff_metadata(self):
    """With a single-entry cutoff pool, every frame reports that cutoff."""
    dataset = self._make_stub_root_dof_dataset(
        cutoff_pool=(3.0,), online_filter_enabled=True
    )
    tensors = dataset[0].tensors
    self.assertIn("filter_cutoff_hz", tensors)
    expected_cutoff = torch.full((4, 1), 3.0, dtype=torch.float32)
    torch.testing.assert_close(tensors["filter_cutoff_hz"], expected_cutoff)
def test_hdf5_v2_sample_generates_filtered_reference_family(self):
    """An enabled online filter adds the full ``ft_ref_*`` tensor family."""
    dataset = self._make_stub_root_dof_dataset(
        cutoff_pool=(3.0,), online_filter_enabled=True
    )
    tensors = dataset[0].tensors
    filtered_suffixes = (
        "root_pos",
        "root_rot",
        "dof_pos",
        "rg_pos",
        "rb_rot",
        "body_vel",
        "body_ang_vel",
        "dof_vel",
        "root_vel",
        "root_ang_vel",
    )
    for suffix in filtered_suffixes:
        self.assertIn(f"ft_ref_{suffix}", tensors)
def test_hdf5_v2_sample_uses_split_fk_smoothing_sigmas(self):
    """Raw and filtered families each run FK with their own smoothing sigma."""
    dataset = self._make_stub_root_dof_dataset(
        cutoff_pool=(3.0,),
        online_filter_enabled=True,
        ref_vel_smoothing_sigma=0.0,
        ft_ref_vel_smoothing_sigma=2.0,
    )
    tensors = dataset[0].tensors
    for key in ("ref_root_vel", "ft_ref_root_vel"):
        self.assertIn(key, tensors)
    # The stub FK records calls in order: raw "ref_" first, then "ft_ref_".
    expected_calls = [("ref_", 0.0), ("ft_ref_", 2.0)]
    self.assertEqual(dataset._fk_calls, expected_calls)
def test_hdf5_v2_sample_skips_filtered_reference_family_when_disabled(
    self,
):
    """Without ``ft_ref_`` in the allowed prefixes, no filtered tensors appear."""
    dataset = self._make_stub_root_dof_dataset(
        cutoff_pool=(3.0,),
        online_filter_enabled=True,
        allowed_prefixes=("ref_",),
    )
    tensors = dataset[0].tensors
    for key in ("ft_ref_root_pos", "ft_ref_rg_pos"):
        self.assertNotIn(key, tensors)
@staticmethod
def _make_stub_root_dof_dataset(
    *,
    cutoff_pool=(0.0,),
    online_filter_enabled=False,
    allowed_prefixes=("ref_", "ft_ref_"),
    ref_vel_smoothing_sigma=2.0,
    ft_ref_vel_smoothing_sigma=2.0,
):
    """Build an ``Hdf5RootDofDataset`` that needs neither HDF5 files nor FK.

    The instance is created with ``__new__`` (``__init__`` is skipped) and
    the private attributes read on the sample path are assigned by hand. It
    exposes one 4-frame window of clip ``clip-a`` at 60 fps with 2 DoFs,
    served from an in-memory shard dict. ``_fk_transform`` is replaced by a
    stub that appends ``(prefix, vel_smoothing_sigma)`` to
    ``dataset._fk_calls`` and fills the FK outputs for 2 fake bodies:
    both mirror the root pose, and all velocities are zero.

    Args:
        cutoff_pool: Values for ``_online_filter_cutoff_hz_pool``.
        online_filter_enabled: Value for ``_online_filter_enabled``.
        allowed_prefixes: Reference families the dataset may emit.
        ref_vel_smoothing_sigma: FK smoothing sigma for the ``ref_`` family.
        ft_ref_vel_smoothing_sigma: FK smoothing sigma for ``ft_ref_``.

    Returns:
        The stubbed dataset.
    """
    # Single source of truth for the fixture sizes.
    num_frames = 4
    num_dofs = 2
    num_fake_bodies = 2

    dataset = Hdf5RootDofDataset.__new__(Hdf5RootDofDataset)
    dataset.windows = [
        MotionWindow(
            motion_key="clip-a__start_0_len_4",
            shard_index=0,
            start=0,
            length=num_frames,
            raw_motion_key="clip-a",
            window_index=0,
        )
    ]
    dataset.clips = {
        "clip-a": {
            "metadata": {
                "motion_fps": 60.0,
            }
        }
    }
    # Private state presumably set up by __init__ in real use.
    dataset._progress_counter = None
    dataset._world_frame_transform = None
    dataset._file_handles = {}
    dataset._h5_access_counter = 0
    dataset._h5_cleanup_interval = int(1e6)
    dataset._online_filter_enabled = bool(online_filter_enabled)
    dataset._online_filter_cutoff_hz_pool = tuple(cutoff_pool)
    dataset._online_filter_butter_order = 4
    dataset._allowed_prefixes = tuple(allowed_prefixes)
    dataset._ref_vel_smoothing_sigma = float(ref_vel_smoothing_sigma)
    dataset._ft_ref_vel_smoothing_sigma = float(ft_ref_vel_smoothing_sigma)
    # Test-only log of (prefix, vel_smoothing_sigma) per stub FK call.
    dataset._fk_calls = []

    # In-memory shard: numpy arrays keyed by reference tensor name.
    shard_handle = {
        "ref_root_pos": torch.arange(num_frames * 3, dtype=torch.float32)
        .reshape(num_frames, 3)
        .numpy(),
        "ref_root_rot": torch.tensor(
            [[0.0, 0.0, 0.0, 1.0]] * num_frames, dtype=torch.float32
        ).numpy(),
        "ref_dof_pos": torch.arange(num_frames * num_dofs, dtype=torch.float32)
        .reshape(num_frames, num_dofs)
        .numpy(),
    }

    def fake_fk_transform(
        arrays,
        fps,
        prefix="ref_",
        vel_smoothing_sigma=2.0,
    ):
        """Record the call and write mirrored-root FK outputs into ``arrays``."""
        del fps  # the stub output does not depend on the frame rate
        dataset._fk_calls.append((prefix, float(vel_smoothing_sigma)))
        root_pos = arrays[f"{prefix}root_pos"]
        root_rot = arrays[f"{prefix}root_rot"]
        # Size outputs from the actual input rather than a hard-coded 4, so
        # the stub stays shape-consistent with whatever window it receives.
        frames = root_pos.shape[0]
        arrays[f"{prefix}rg_pos"] = torch.stack(
            [root_pos] * num_fake_bodies, dim=1
        )
        arrays[f"{prefix}rb_rot"] = torch.stack(
            [root_rot] * num_fake_bodies, dim=1
        )
        arrays[f"{prefix}body_vel"] = torch.zeros(
            frames, num_fake_bodies, 3, dtype=torch.float32
        )
        arrays[f"{prefix}body_ang_vel"] = torch.zeros(
            frames, num_fake_bodies, 3, dtype=torch.float32
        )
        arrays[f"{prefix}dof_vel"] = torch.zeros(
            frames, num_dofs, dtype=torch.float32
        )

    dataset._fk_transform = fake_fk_transform
    dataset._get_shard_handle = lambda shard_index: shard_handle
    return dataset
================================================
FILE: tests/test_motion_cache_startup.py
================================================
from pathlib import Path
import sys
from types import SimpleNamespace
import unittest
from unittest import mock
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.algo.algo_base import BaseOnpolicyRL
import holomotion.src.training.h5_dataloader as h5_dataloader_module
from holomotion.src.training.h5_dataloader import MotionClipBatchCache
class _FakeDataset:
def __init__(self, length: int = 8) -> None:
self._length = int(length)
self.max_frame_length = 16
self.progress_counter = None
def __len__(self) -> int:
return self._length
def __getitem__(self, index: int):
raise AssertionError("__getitem__ should not be called in these tests")
def set_progress_counter(self, counter) -> None:
self.progress_counter = counter
def close(self) -> None:
return
class MotionCacheStartupTests(unittest.TestCase):
    """Construction-time behaviour of ``MotionClipBatchCache`` and seeding.

    Most tests build the cache via ``_make_cache``. That helper stubs out
    ``_build_dataloader`` and ``_prime_buffers`` so no workers are spawned
    and no data is read.
    """

    def _make_cache(self, **overrides):
        """Construct a cache with dataloader build and buffer priming stubbed.

        The two methods are no-ops only while the constructor runs.
        ``overrides`` are merged over a minimal single-process configuration
        that uses ``_FakeDataset``.

        Returns:
            The constructed ``MotionClipBatchCache``.
        """
        kwargs = {
            "train_dataset": _FakeDataset(),
            "batch_size": 2,
            "num_workers": 0,
            "pin_memory": False,
            "persistent_workers": False,
        }
        kwargs.update(overrides)
        with (
            mock.patch.object(
                MotionClipBatchCache, "_build_dataloader", lambda _cache: None
            ),
            mock.patch.object(
                MotionClipBatchCache, "_prime_buffers", lambda _cache: None
            ),
        ):
            return MotionClipBatchCache(**kwargs)

    def test_motion_cache_uses_explicit_constructor_seed(self):
        cache = self._make_cache(seed=1234)
        self.assertEqual(cache._seed, 1234)

    def test_setup_seeding_does_not_reinitialize_motion_cache(self):
        """``_setup_seeding`` seeds env and motion cache without reinitializing.

        With ``config["seed"] == 100`` and rank 2, the fixture expects the
        effective seed 102 (presumably base + rank) to reach both the env and
        the motion cache, with ``reinitialize=False``.
        """
        algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)
        algo.config = {"seed": 100}
        algo.process_rank = 2
        algo.command_name = "ref_motion"
        env_seed_calls = []
        motion_cache_seed_calls = []

        def record_motion_cache_seed(seed, reinitialize):
            motion_cache_seed_calls.append((seed, reinitialize))

        algo.env = SimpleNamespace(
            seed=lambda seed: env_seed_calls.append(seed)
        )
        algo.command_term = SimpleNamespace(
            cfg=SimpleNamespace(seed=102),
            set_motion_cache_seed=record_motion_cache_seed,
        )
        BaseOnpolicyRL._setup_seeding(algo)
        self.assertEqual(algo.base_seed, 100)
        self.assertEqual(algo.seed, 102)
        self.assertEqual(env_seed_calls, [102])
        self.assertEqual(motion_cache_seed_calls, [(102, False)])

    def test_motion_cache_passes_loader_timeout_to_dataloader(self):
        """``loader_timeout`` is forwarded to ``DataLoader`` as ``timeout``."""
        captured_kwargs = {}

        class _FakeLoader:
            def __init__(self, *args, **kwargs) -> None:
                del args
                captured_kwargs.update(kwargs)

        with (
            mock.patch.object(h5_dataloader_module, "DataLoader", _FakeLoader),
            mock.patch.object(
                MotionClipBatchCache, "_prime_buffers", lambda _cache: None
            ),
        ):
            MotionClipBatchCache(
                train_dataset=_FakeDataset(),
                batch_size=2,
                num_workers=0,
                pin_memory=False,
                persistent_workers=False,
                loader_timeout=17,
            )
        self.assertEqual(captured_kwargs["timeout"], 17)

    def test_motion_cache_disables_progress_bar_in_distributed_runs(self):
        cache = self._make_cache(sampler_world_size=8, batch_progress_bar=True)
        self.assertIs(cache._should_use_batch_progress(), False)
        self.assertIsNone(cache._batch_progress_counter)

    def test_motion_cache_keeps_progress_bar_for_local_runs(self):
        cache = self._make_cache(sampler_world_size=1, batch_progress_bar=True)
        self.assertIs(cache._should_use_batch_progress(), True)
        self.assertIsNotNone(cache._batch_progress_counter)

    def test_motion_cache_requires_positive_loader_timeout(self):
        """Negative ``loader_timeout`` is rejected.

        Despite the historical name, the constraint is non-negative
        (``>= 0``), matching the error message.
        """
        with self.assertRaisesRegex(ValueError, "loader_timeout must be >= 0"):
            self._make_cache(loader_timeout=-1)

    def test_motion_cache_accepts_zero_loader_timeout(self):
        """0 is the inclusive lower bound stated by the ``>= 0`` check."""
        self._make_cache(loader_timeout=0)
================================================
FILE: tests/test_motion_tracking_command_reference_prefix.py
================================================
import sys
import unittest
from pathlib import Path
from types import SimpleNamespace
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.utils.reference_prefix import resolve_reference_tensor_key
class MotionTrackingCommandReferencePrefixTests(unittest.TestCase):
    """Key-resolution rules of ``resolve_reference_tensor_key``.

    ``ft_ref_`` keys must be present verbatim (a ``KeyError`` otherwise).
    ``ref_`` keys prefer the prefixed tensor and fall back to the bare key.
    """

    @staticmethod
    def _resolve(batch_keys, prefix):
        """Resolve ``root_pos`` against a batch containing exactly ``batch_keys``."""
        batch = {key: SimpleNamespace() for key in batch_keys}
        return resolve_reference_tensor_key(
            batch_tensors=batch,
            base_key="root_pos",
            prefix=prefix,
        )

    def test_ft_ref_prefix_uses_filtered_tensor_when_present(self):
        resolved = self._resolve(["ft_ref_root_pos"], "ft_ref_")
        self.assertEqual(resolved, "ft_ref_root_pos")

    def test_ft_ref_prefix_requires_filtered_tensor(self):
        with self.assertRaises(KeyError):
            self._resolve(["root_pos"], "ft_ref_")

    def test_ref_prefix_falls_back_to_unprefixed_tensor(self):
        resolved = self._resolve(["root_pos"], "ref_")
        self.assertEqual(resolved, "root_pos")

    def test_ref_prefix_prefers_prefixed_tensor_when_present(self):
        resolved = self._resolve(["root_pos", "ref_root_pos"], "ref_")
        self.assertEqual(resolved, "ref_root_pos")
================================================
FILE: tests/test_motion_tracking_timing.py
================================================
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
import torch
# Source file under test. It is executed directly from this path (not
# imported as a package) after its dependencies are faked in sys.modules.
MODULE_PATH = Path(__file__).resolve().parents[1].joinpath(
    "holomotion",
    "src",
    "env",
    "isaaclab_components",
    "isaaclab_motion_tracking_command.py",
)
class _DummyConfig:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
def _install_fake_motion_command_deps(monkeypatch):
    """Register fake modules so the motion command file can be executed.

    Builds one ``ModuleType`` per dependency name listed in the table at the
    bottom. Each fake carries only the attributes assigned here: config-like
    classes become ``_DummyConfig``, base classes become ``object``, and
    helper functions become trivial lambdas. Every fake is registered via
    ``monkeypatch.setitem(sys.modules, ...)``, so pytest undoes the
    registration at teardown.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture.
    """
    # Empty module: any mdp attribute the code under test needs must be
    # added by the individual test.
    isaaclab_mdp = ModuleType("isaaclab.envs.mdp")
    isaaclab_sim = ModuleType("isaaclab.sim")
    isaaclab_sim.PreviewSurfaceCfg = _DummyConfig
    isaaclab_sim.PhysxCfg = _DummyConfig
    isaaclab_sim.SimulationCfg = _DummyConfig
    # Math helpers are placeholders (they return an input unchanged). Tests
    # that depend on real rotation math cannot rely on these.
    isaaclab_math = ModuleType("isaaclab.utils.math")
    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec
    isaaclab_math.quat_apply = lambda quat, vec: vec
    isaaclab_math.yaw_quat = lambda quat: quat
    isaaclab_math.quat_inv = lambda quat: quat
    isaaclab_math.quat_mul = lambda lhs, rhs: lhs
    # Deterministic "sampling": always zeros of the requested shape.
    isaaclab_math.sample_uniform = (
        lambda low, high, shape, device=None: torch.zeros(
            *shape, device=device
        )
    )
    isaaclab_actuators = ModuleType("isaaclab.actuators")
    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig
    isaaclab_assets = ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_assets.ArticulationCfg = _DummyConfig
    isaaclab_assets.AssetBaseCfg = _DummyConfig
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedRLEnv = object
    isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig
    isaaclab_envs.ViewerCfg = _DummyConfig
    isaaclab_envs_mdp_actions = ModuleType("isaaclab.envs.mdp.actions")
    isaaclab_envs_mdp_actions.JointEffortActionCfg = _DummyConfig
    isaaclab_managers = ModuleType("isaaclab.managers")
    isaaclab_managers.ActionTermCfg = _DummyConfig
    # Plain ``object`` base so subclasses can be built with ``__new__`` and
    # no isaaclab manager machinery.
    isaaclab_managers.CommandTerm = object
    isaaclab_managers.CommandTermCfg = _DummyConfig
    isaaclab_managers.EventTermCfg = _DummyConfig
    isaaclab_managers.ObservationGroupCfg = _DummyConfig
    isaaclab_managers.ObservationTermCfg = _DummyConfig
    isaaclab_managers.RewardTermCfg = _DummyConfig
    isaaclab_managers.TerminationTermCfg = _DummyConfig
    isaaclab_markers = ModuleType("isaaclab.markers")
    isaaclab_markers.VisualizationMarkers = _DummyConfig
    isaaclab_markers.VisualizationMarkersCfg = _DummyConfig
    isaaclab_markers_config = ModuleType("isaaclab.markers.config")
    # ``replace(**kw)`` yields a namespace with a ``sphere`` marker whose
    # radius can be assigned, plus whatever keywords were passed.
    isaaclab_markers_config.SPHERE_MARKER_CFG = SimpleNamespace(
        replace=lambda **kwargs: SimpleNamespace(
            markers={"sphere": SimpleNamespace(radius=None)},
            **kwargs,
        )
    )
    isaaclab_scene = ModuleType("isaaclab.scene")
    isaaclab_scene.InteractiveSceneCfg = _DummyConfig
    isaaclab_sensors = ModuleType("isaaclab.sensors")
    isaaclab_sensors.ContactSensorCfg = _DummyConfig
    isaaclab_sensors.RayCasterCfg = _DummyConfig
    isaaclab_sensors.patterns = _DummyConfig
    isaaclab_terrains = ModuleType("isaaclab.terrains")
    isaaclab_terrains.TerrainImporterCfg = _DummyConfig
    isaaclab_utils = ModuleType("isaaclab.utils")
    # Identity decorator in place of isaaclab's configclass.
    isaaclab_utils.configclass = lambda cls: cls
    isaaclab_noise = ModuleType("isaaclab.utils.noise")
    isaaclab_noise.AdditiveUniformNoiseCfg = _DummyConfig
    # Project-internal modules are faked too, so loading the command file
    # does not pull in the real dataloader or rotation utilities.
    h5_dataloader = ModuleType("holomotion.src.training.h5_dataloader")
    h5_dataloader.Hdf5MotionDataset = object
    h5_dataloader.Hdf5RootDofDataset = object
    h5_dataloader.MotionClipBatchCache = object
    h5_dataloader.build_motion_datasets_from_cfg = lambda *args, **kwargs: None
    rotations = ModuleType("holomotion.src.utils.isaac_utils.rotations")
    rotations.calc_heading_quat_inv = lambda *args, **kwargs: None
    rotations.get_euler_xyz = lambda *args, **kwargs: None
    rotations.my_quat_rotate = lambda *args, **kwargs: None
    rotations.quat_inverse = lambda *args, **kwargs: None
    rotations.quat_mul = lambda *args, **kwargs: None
    rotations.quat_rotate = lambda *args, **kwargs: None
    rotations.quat_rotate_inverse = lambda *args, **kwargs: None
    rotations.quaternion_to_matrix = lambda *args, **kwargs: None
    rotations.wrap_to_pi = lambda *args, **kwargs: None
    rotations.wxyz_to_xyzw = lambda x: x
    rotations.xyzw_to_wxyz = lambda x: x
    # Always returns the prefixed key; unlike the real helper, it has no
    # fallback to the bare key.
    reference_prefix = ModuleType("holomotion.src.utils.reference_prefix")
    reference_prefix.resolve_reference_tensor_key = (
        lambda batch_tensors, base_key, prefix="ref_": f"{prefix}{base_key}"
    )
    omegaconf = ModuleType("omegaconf")
    omegaconf.OmegaConf = SimpleNamespace(
        to_container=lambda value, resolve=True: value
    )
    # Logger exposes only a silent ``info``.
    loguru = ModuleType("loguru")
    loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None)
    # tqdm passes the iterable through without a progress bar.
    tqdm = ModuleType("tqdm")
    tqdm.tqdm = lambda iterable, *args, **kwargs: iterable
    scipy = ModuleType("scipy")
    scipy_spatial = ModuleType("scipy.spatial")
    scipy_transform = ModuleType("scipy.spatial.transform")
    scipy_transform.Rotation = object
    # Register every fake under its import name; monkeypatch restores the
    # previous sys.modules state after the test.
    for name, module in {
        "isaaclab.envs.mdp": isaaclab_mdp,
        "isaaclab.sim": isaaclab_sim,
        "isaaclab.utils.math": isaaclab_math,
        "isaaclab.actuators": isaaclab_actuators,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp.actions": isaaclab_envs_mdp_actions,
        "isaaclab.managers": isaaclab_managers,
        "isaaclab.markers": isaaclab_markers,
        "isaaclab.markers.config": isaaclab_markers_config,
        "isaaclab.scene": isaaclab_scene,
        "isaaclab.sensors": isaaclab_sensors,
        "isaaclab.terrains": isaaclab_terrains,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.noise": isaaclab_noise,
        "holomotion.src.training.h5_dataloader": h5_dataloader,
        "holomotion.src.utils.isaac_utils.rotations": rotations,
        "holomotion.src.utils.reference_prefix": reference_prefix,
        "omegaconf": omegaconf,
        "loguru": loguru,
        "tqdm": tqdm,
        "scipy": scipy,
        "scipy.spatial": scipy_spatial,
        "scipy.spatial.transform": scipy_transform,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)
def _load_motion_command_module(monkeypatch):
    """Execute the motion-tracking command file against faked dependencies.

    Installs the fakes from ``_install_fake_motion_command_deps``, then runs
    ``MODULE_PATH`` as a standalone module named
    ``_test_motion_tracking_timing``.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture. Every ``sys.modules``
            entry added here, including the loaded module itself, is undone
            at teardown.

    Returns:
        The freshly executed module object.
    """
    _install_fake_motion_command_deps(monkeypatch)
    module_name = "_test_motion_tracking_timing"
    spec = importlib.util.spec_from_file_location(module_name, MODULE_PATH)
    # Validate before use. Calling module_from_spec(None) would raise an
    # unhelpful AttributeError and make these assertions unreachable.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # The module must be in sys.modules while it executes. Registering via
    # monkeypatch keeps this module, built on fake deps, from leaking into
    # later tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def test_immediate_next_reference_getters_use_slot_one(monkeypatch):
    """``*_immediate_next`` getters read slot 1 of the cached reference.

    Each fixture tensor is shaped [1 env, 3 slots, ...] with distinct values
    per slot, so reading the wrong slot is caught. DoF and body axes are
    remapped through ``urdf2sim_*_idx = [1, 0]`` (a swap). The root position
    is offset by the env origin (10, 20, 30).
    """
    module = _load_motion_command_module(monkeypatch)
    command_cls = module.RefMotionCommand
    command = command_cls.__new__(command_cls)
    command.urdf2sim_dof_idx = torch.tensor([1, 0], dtype=torch.long)
    command.urdf2sim_body_idx = torch.tensor([1, 0], dtype=torch.long)
    command._env_origins = torch.tensor(
        [[10.0, 20.0, 30.0]], dtype=torch.float32
    )

    ref_state = {
        "ref_dof_pos": torch.tensor(
            [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]], dtype=torch.float32
        ),
        "ref_root_pos": torch.tensor(
            [[[0.0, 1.0, 2.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
            dtype=torch.float32,
        ),
        # Slot s, body b holds the constant 2 * s + b on every axis.
        "ref_body_vel": torch.tensor(
            [
                [
                    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
                    [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
                    [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0]],
                ]
            ],
            dtype=torch.float32,
        ),
    }

    def fake_get_ref_state_array(base_key, prefix="ref_"):
        return ref_state[prefix + base_key]

    command._get_ref_state_array = fake_get_ref_state_array

    dof_pos = command.get_ref_motion_dof_pos_immediate_next()
    root_pos = command.get_ref_motion_root_global_pos_immediate_next()
    body_lin_vel = (
        command.get_ref_motion_bodylink_global_lin_vel_immediate_next()
    )

    # Slot-1 DoFs [3, 4] after the swap.
    assert torch.allclose(dof_pos, torch.tensor([[4.0, 3.0]]))
    # Slot-1 root (7, 8, 9) plus origin (10, 20, 30).
    assert torch.allclose(root_pos, torch.tensor([[17.0, 28.0, 39.0]]))
    # Slot-1 bodies (2, 3) after the swap.
    assert torch.allclose(
        body_lin_vel,
        torch.tensor([[[3.0, 3.0, 3.0], [2.0, 2.0, 2.0]]]),
    )
def test_update_command_skips_just_reset_envs(monkeypatch):
    """Envs with ``episode_length_buf == 0`` do not advance this step.

    Env 1 has episode length 0 and stays on frame 20. Envs 0 and 2 advance
    by one frame. The swap step counter ticks once for the whole step.
    """
    module = _load_motion_command_module(monkeypatch)
    command = module.RefMotionCommand.__new__(module.RefMotionCommand)

    # State provided for _update_command.
    command.device = torch.device("cpu")
    command.num_envs = 3
    command._frame_indices = torch.tensor([10, 20, 30], dtype=torch.long)
    command._env = SimpleNamespace(
        episode_length_buf=torch.tensor([5, 0, 2], dtype=torch.long)
    )
    command._swap_step_counter = 0
    command._swap_pending = False
    command._motion_cache = SimpleNamespace(swap_interval_steps=100)

    # Collaborators stubbed out: identity env filter, no resampling, and no
    # reference-state refresh.
    command._filter_env_ids_for_motion_task = lambda env_ids: env_ids
    command._resample_when_motion_end_cache = lambda: None
    command._update_ref_motion_state_from_cache = lambda env_ids=None: None

    command._update_command()

    expected_frames = torch.tensor([11, 20, 31])
    assert torch.equal(command._frame_indices, expected_frames)
    assert command._swap_step_counter == 1
def test_update_command_resumes_advancing_after_reset_step(monkeypatch):
    """A just-reset env holds its frame for one step, then advances again."""
    module = _load_motion_command_module(monkeypatch)
    command = module.RefMotionCommand.__new__(module.RefMotionCommand)
    command.device = torch.device("cpu")
    command.num_envs = 1
    command._frame_indices = torch.tensor([20], dtype=torch.long)
    command._swap_step_counter = 0
    command._swap_pending = False
    command._motion_cache = SimpleNamespace(swap_interval_steps=100)
    command._env = SimpleNamespace()
    command._filter_env_ids_for_motion_task = lambda env_ids: env_ids
    command._resample_when_motion_end_cache = lambda: None
    command._update_ref_motion_state_from_cache = lambda env_ids=None: None

    # (episode length seen this step, expected frame index afterwards)
    for episode_length, expected_frame in ((0, 20), (1, 21)):
        command._env.episode_length_buf = torch.tensor([episode_length])
        command._update_command()
        assert torch.equal(
            command._frame_indices, torch.tensor([expected_frame])
        )
def test_mpjpe_metrics_use_immediate_next_reference(monkeypatch):
    """MPJPE compares the robot against the immediate-next reference frame.

    The current-frame getter raises if touched. The next-frame getter matches
    the robot's joint positions exactly, so the whole-body error must be 0.
    """
    module = _load_motion_command_module(monkeypatch)
    command = module.RefMotionCommand.__new__(module.RefMotionCommand)
    command.device = torch.device("cpu")
    command.num_envs = 1
    command.metrics = {}
    command.arm_dof_indices = [0]
    command.torso_dof_indices = [1]
    command.leg_dof_indices = [2]
    joint_pos = [[0.1, 0.2, 0.3]]
    command.robot = SimpleNamespace(
        data=SimpleNamespace(
            joint_pos=torch.tensor(joint_pos, dtype=torch.float32)
        )
    )

    def current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def next_reference(prefix="ref_"):
        return torch.tensor(joint_pos, dtype=torch.float32)

    command.get_ref_motion_dof_pos_cur = current_reference_forbidden
    command.get_ref_motion_dof_pos_immediate_next = next_reference

    command._update_mpjpe_metrics()

    assert torch.allclose(
        command.metrics["Task/MPJPE_WholeBody"], torch.zeros(1)
    )
def test_mpkpe_metrics_use_immediate_next_reference(monkeypatch):
    """MPKPE compares robot bodies against the immediate-next reference.

    The current-frame getter raises if touched. The next-frame getter matches
    the robot's body positions exactly, so the whole-body error must be 0.
    """
    module = _load_motion_command_module(monkeypatch)
    command = module.RefMotionCommand.__new__(module.RefMotionCommand)
    command.device = torch.device("cpu")
    command.num_envs = 1
    command.metrics = {}
    command.arm_body_indices = [0]
    command.torso_body_indices = [1]
    command.leg_body_indices = [2]
    body_pos = [[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]
    command.robot = SimpleNamespace(
        data=SimpleNamespace(
            body_pos_w=torch.tensor(body_pos, dtype=torch.float32)
        )
    )

    def current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def next_reference(prefix="ref_"):
        return torch.tensor(body_pos, dtype=torch.float32)

    command.get_ref_motion_bodylink_global_pos_cur = (
        current_reference_forbidden
    )
    command.get_ref_motion_bodylink_global_pos_immediate_next = next_reference

    command._update_mpkpe_metrics()

    assert torch.allclose(
        command.metrics["Task/MPKPE_WholeBody"], torch.zeros(1)
    )
================================================
FILE: tests/test_mujoco_filtered_ref_compat.py
================================================
import tempfile
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
from holomotion.src.evaluation.eval_mujoco_sim2sim import MujocoEvaluator
from holomotion.src.evaluation.obs.obs_builder import PolicyObsBuilder
# Repository root (this file lives in <root>/tests/).
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Observation and module configs for the "with_freq" (cutoff-aware) policy.
OBS_CONFIG_PATH = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "env",
    "observations",
    "motion_tracking",
    "obs_motrack_tf_ref_v3_with_freq.yaml",
)
MODULE_CONFIG_PATH = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "modules",
    "motion_tracking",
    "tf_motrack_v3_with_ft.yaml",
)

# Observation and module configs for the router-v2 policy.
OBS_CONFIG_PATH_V2 = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "env",
    "observations",
    "motion_tracking",
    "obs_motrack_tf_ref_v3_sonic_router_v2.yaml",
)
MODULE_CONFIG_PATH_V2 = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "modules",
    "motion_tracking",
    "tf_motrack_v3_wo_eepos_ref_route_v2.yaml",
)

DOMAIN_RAND_CONFIG_PATH = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "env",
    "domain_randomization",
    "domain_rand_strong.yaml",
)
def _make_minimal_motion_npz(path: Path, *, include_cutoff: bool) -> None:
payload = {
"ref_global_translation": np.zeros((2, 1, 3), dtype=np.float32),
"ref_global_rotation_quat": np.tile(
np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32),
(2, 1, 1),
),
"ref_global_velocity": np.zeros((2, 1, 3), dtype=np.float32),
"ref_global_angular_velocity": np.zeros((2, 1, 3), dtype=np.float32),
"ref_dof_pos": np.zeros((2, 2), dtype=np.float32),
"ref_dof_vel": np.zeros((2, 2), dtype=np.float32),
}
if include_cutoff:
payload["filter_cutoff_hz"] = np.array(
[[2.0], [3.0]], dtype=np.float32
)
np.savez(path, **payload)
def test_policy_obs_list_accepts_shared_cutoff_term():
    """The "with_freq" policy schema leads with the shared cutoff term."""
    config_paths = (OBS_CONFIG_PATH, MODULE_CONFIG_PATH)
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.config = OmegaConf.merge(
        *[OmegaConf.load(path) for path in config_paths]
    )
    term_names = [
        str(list(term.keys())[0])
        for term in evaluator._get_policy_atomic_obs_list()
    ]
    assert term_names[0] == "ref_motion_filter_cutoff_hz"
    assert "actor_ref_gravity_projection_cur" in term_names
def test_cutoff_obs_getters_use_current_frame_and_default_zero():
    """Cutoff observations read the row at ``motion_frame_idx``, else 0.0."""
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.motion_frame_idx = 1
    evaluator.filter_cutoff_hz = np.array([[2.0], [3.0]], dtype=np.float32)
    expected = np.float32(3.0)
    assert evaluator._get_obs_ref_motion_filter_cutoff_hz() == expected
    assert evaluator._get_obs_actor_ref_motion_filter_cutoff_hz() == expected

    # An evaluator with no ``filter_cutoff_hz`` attribute reports zero.
    no_cutoffs = MujocoEvaluator.__new__(MujocoEvaluator)
    no_cutoffs.motion_frame_idx = 0
    assert no_cutoffs._get_obs_ref_motion_filter_cutoff_hz() == 0.0
def test_policy_obs_list_v2_uses_only_actor_schema_terms():
    """The router-v2 policy obs list omits ``actor_moe_router_*`` terms."""
    config_paths = (
        OBS_CONFIG_PATH_V2,
        MODULE_CONFIG_PATH_V2,
        DOMAIN_RAND_CONFIG_PATH,
    )
    evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
    evaluator.config = OmegaConf.merge(
        *[OmegaConf.load(path) for path in config_paths]
    )
    term_names = [
        str(list(term.keys())[0])
        for term in evaluator._get_policy_atomic_obs_list()
    ]
    router_terms = [
        name for name in term_names if name.startswith("actor_moe_router_")
    ]
    assert not router_terms
    assert "actor_ref_gravity_projection_cur" in term_names
    assert "actor_ref_base_linvel_fut" in term_names
def test_load_specific_motion_loads_cutoff_metadata_with_zero_fallback():
    """Loading a motion picks up ``filter_cutoff_hz``, or zeros when absent."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        with_cutoff = tmp_path / "with_cutoff.npz"
        without_cutoff = tmp_path / "without_cutoff.npz"
        _make_minimal_motion_npz(with_cutoff, include_cutoff=True)
        _make_minimal_motion_npz(without_cutoff, include_cutoff=False)

        # One evaluator is reused, so the second load must replace the
        # cutoffs left over from the first.
        evaluator = MujocoEvaluator.__new__(MujocoEvaluator)
        cases = (
            (with_cutoff, np.array([[2.0], [3.0]], dtype=np.float32)),
            (without_cutoff, np.zeros((2, 1), dtype=np.float32)),
        )
        for motion_path, expected_cutoff in cases:
            evaluator.load_specific_motion(motion_path)
            np.testing.assert_allclose(
                evaluator.filter_cutoff_hz, expected_cutoff
            )
================================================
FILE: tests/test_obs_norm_compile.py
================================================
import torch
import torch.nn as nn
from holomotion.src.modules.agent_modules import PPOTFActor
from holomotion.src.modules.network_modules import EmpiricalNormalization
def _make_actor_with_obs_norm(obs_dim: int = 16) -> PPOTFActor:
    """Build a bare ``PPOTFActor`` carrying only observation normalization.

    ``PPOTFActor.__init__`` is bypassed. ``nn.Module.__init__`` runs directly
    so that module bookkeeping exists before the normalizer submodule is
    attached.
    """
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)
    actor.obs_norm_enabled = True
    actor.obs_norm_clip = 10.0
    normalizer_shape = (obs_dim,)
    actor.obs_normalizer = EmpiricalNormalization(shape=normalizer_shape)
    return actor
def test_obs_norm_update_is_not_captured_by_dynamo():
    """Normalizer stat updates stay out of Dynamo graphs but still run.

    ``torch._dynamo.explain`` exposes the captured graphs. Their code must
    not contain ``torch.var`` or ``torch.mean`` (presumably the running-stat
    update). Compiling with the eager backend must still advance the
    normalizer ``count`` by the batch size.
    """
    actor = _make_actor_with_obs_norm()
    obs = torch.randn(8, 16)

    def normalize_and_update(x: torch.Tensor) -> torch.Tensor:
        return actor._normalize_actor_obs(x, True)

    explanation = torch._dynamo.explain(normalize_and_update)(obs)
    graph_code = "\n".join(g.code for g in explanation.graphs)
    for forbidden_op in ("torch.var", "torch.mean"):
        assert forbidden_op not in graph_code

    count_before = actor.obs_normalizer.count.item()
    normalized = torch.compile(normalize_and_update, backend="eager")(obs)
    assert normalized.shape == obs.shape
    count_delta = actor.obs_normalizer.count.item() - count_before
    assert count_delta == obs.shape[0]
================================================
FILE: tests/test_observation_frames.py
================================================
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
import torch
# Observation source file under test. It is executed from this path after
# its isaaclab dependencies are faked in sys.modules.
OBSERVATION_PATH = Path(__file__).resolve().parents[1].joinpath(
    "holomotion",
    "src",
    "env",
    "isaaclab_components",
    "isaaclab_observation.py",
)
class _DummyConfig:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class _Scene(SimpleNamespace):
def __getitem__(self, key):
return getattr(self, key)
def _identity_quat(*shape: int) -> torch.Tensor:
quat = torch.zeros(*shape, 4, dtype=torch.float32)
quat[..., 0] = 1.0
return quat
def _build_fake_observation_deps():
    """Return ``{module_name: fake_module}`` for the observation file's deps.

    Config classes become ``_DummyConfig``. Several fakes define a
    module-level ``__getattr__`` (PEP 562), so any attribute not set
    explicitly still resolves to a stub. Parent/child links
    (``isaaclab.envs``, ``isaaclab.utils.math``, ...) are wired as attributes
    so dotted access works.
    """
    isaaclab = ModuleType("isaaclab")
    isaaclab_mdp = ModuleType("isaaclab.envs.mdp")
    isaaclab_math = ModuleType("isaaclab.utils.math")
    # Placeholder math: rotations are identities and frame subtraction is a
    # plain translation difference.
    isaaclab_math.quat_apply = lambda quat, vec: vec
    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec
    isaaclab_math.quat_inv = lambda quat: quat
    isaaclab_math.matrix_from_quat = lambda quat: torch.zeros(
        *quat.shape[:-1], 3, 3, dtype=quat.dtype, device=quat.device
    )
    isaaclab_math.subtract_frame_transforms = lambda t01, q01, t02, q02: (
        t02 - t01,
        q02,
    )
    # Any other math helper resolves to a function returning None.
    isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None)
    isaaclab_noise = ModuleType("isaaclab.utils.noise")
    isaaclab_noise.__getattr__ = lambda name: _DummyConfig
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedRLEnv = object
    isaaclab_envs.ManagerBasedRLEnvCfg = _DummyConfig
    isaaclab_envs.ViewerCfg = _DummyConfig
    isaaclab_sim = ModuleType("isaaclab.sim")
    isaaclab_sim.__getattr__ = lambda name: _DummyConfig
    isaaclab_actuators = ModuleType("isaaclab.actuators")
    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfig
    isaaclab_assets = ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_assets.ArticulationCfg = _DummyConfig
    isaaclab_assets.AssetBaseCfg = _DummyConfig
    isaaclab_managers = ModuleType("isaaclab.managers")
    isaaclab_managers.__getattr__ = lambda name: _DummyConfig
    isaaclab_markers = ModuleType("isaaclab.markers")
    isaaclab_markers.VisualizationMarkers = _DummyConfig
    isaaclab_markers.VisualizationMarkersCfg = _DummyConfig
    isaaclab_markers_config = ModuleType("isaaclab.markers.config")
    isaaclab_markers_config.FRAME_MARKER_CFG = _DummyConfig
    isaaclab_scene = ModuleType("isaaclab.scene")
    isaaclab_scene.InteractiveSceneCfg = _DummyConfig
    isaaclab_sensors = ModuleType("isaaclab.sensors")
    isaaclab_sensors.ContactSensorCfg = _DummyConfig
    isaaclab_sensors.RayCasterCfg = _DummyConfig
    isaaclab_sensors.patterns = _DummyConfig
    isaaclab_terrains = ModuleType("isaaclab.terrains")
    isaaclab_terrains.TerrainImporterCfg = _DummyConfig
    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = lambda cls: cls
    omegaconf = ModuleType("omegaconf")
    omegaconf.DictConfig = dict
    omegaconf.ListConfig = list
    omegaconf.OmegaConf = SimpleNamespace(
        to_container=lambda value, resolve=True: value
    )
    fake_utils_module = ModuleType(
        "holomotion.src.env.isaaclab_components.isaaclab_utils"
    )
    fake_utils_module.resolve_holo_config = lambda value: value

    # Wire children onto parents so attribute-style access works as well.
    isaaclab.envs = isaaclab_envs
    isaaclab.sim = isaaclab_sim
    isaaclab.actuators = isaaclab_actuators
    isaaclab.assets = isaaclab_assets
    isaaclab.managers = isaaclab_managers
    isaaclab.markers = isaaclab_markers
    isaaclab.scene = isaaclab_scene
    isaaclab.sensors = isaaclab_sensors
    isaaclab.terrains = isaaclab_terrains
    isaaclab.utils = isaaclab_utils
    isaaclab_envs.mdp = isaaclab_mdp
    isaaclab_utils.math = isaaclab_math
    isaaclab_utils.noise = isaaclab_noise

    return {
        "isaaclab": isaaclab,
        "isaaclab.envs.mdp": isaaclab_mdp,
        "isaaclab.utils.math": isaaclab_math,
        "isaaclab.utils.noise": isaaclab_noise,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.sim": isaaclab_sim,
        "isaaclab.actuators": isaaclab_actuators,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.managers": isaaclab_managers,
        "isaaclab.markers": isaaclab_markers,
        "isaaclab.markers.config": isaaclab_markers_config,
        "isaaclab.scene": isaaclab_scene,
        "isaaclab.sensors": isaaclab_sensors,
        "isaaclab.terrains": isaaclab_terrains,
        "isaaclab.utils": isaaclab_utils,
        "omegaconf": omegaconf,
        (
            "holomotion.src.env.isaaclab_components.isaaclab_utils"
        ): fake_utils_module,
    }


def _load_observation_module(monkeypatch):
    """Execute the observation source file against faked dependencies.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture. All ``sys.modules``
            entries added here, including the loaded module, are undone at
            teardown.

    Returns:
        The freshly executed module, registered as ``_test_observation_frames``.
    """
    for name, fake in _build_fake_observation_deps().items():
        monkeypatch.setitem(sys.modules, name, fake)
    module_name = "_test_observation_frames"
    spec = importlib.util.spec_from_file_location(
        module_name, OBSERVATION_PATH
    )
    # Validate before use. module_from_spec(None) would fail first with an
    # AttributeError, which made these assertions unreachable.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Must be present in sys.modules during execution. Using monkeypatch
    # keeps the fake-backed module from leaking into other tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def test_ref_future_observations_can_limit_num_frames(monkeypatch):
    """``num_frames`` truncates every future-reference observation.

    Each fake getter returns 4 future frames (``arange(24)`` viewed as
    [2 envs, 4 frames, 3]). With ``num_frames=2``, only the first two frames
    may remain. ``dof_vel`` is expected flattened to [envs, frames * 3]; the
    other terms keep a frame axis. Root height is expected to come from the
    z column of the root position.
    """
    observation = _load_observation_module(monkeypatch)
    obs_fns = observation.ObservationFunctions

    def future_arange():
        return torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)

    class _Command:
        def get_ref_motion_dof_pos_fut(self, prefix="ref_"):
            return future_arange()

        def get_ref_motion_dof_vel_fut(self, prefix="ref_"):
            return future_arange()

        def get_ref_motion_gravity_projection_fut(self, prefix="ref_"):
            return future_arange()

        def get_ref_motion_base_linvel_fut(self, prefix="ref_"):
            return future_arange()

        def get_ref_motion_base_angvel_fut(self, prefix="ref_"):
            return future_arange()

        def get_ref_motion_root_global_pos_fut(self, prefix="ref_"):
            root_pos = future_arange()
            root_pos[..., 2] = torch.tensor(
                [[1.0, 2.0, 3.0, 4.0], [0.5, 1.5, 2.5, 3.5]],
                dtype=torch.float32,
            )
            return root_pos

    class _CommandManager:
        def get_term(self, name):
            return _Command()

    env = SimpleNamespace(
        command_manager=_CommandManager(),
        scene=SimpleNamespace(env_origins=torch.zeros(2, 3)),
    )

    expected_shapes = {
        "dof_pos": (2, 2, 3),
        "dof_vel": (2, 6),
        "gravity_projection": (2, 2, 3),
        "base_linvel": (2, 2, 3),
        "base_angvel": (2, 2, 3),
        "root_height": (2, 2, 1),
    }
    observed = {
        term: getattr(obs_fns, f"_get_obs_ref_{term}_fut")(env, num_frames=2)
        for term in expected_shapes
    }
    for term, shape in expected_shapes.items():
        assert observed[term].shape == shape, term

    first_two_frames = torch.tensor(
        [
            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
            [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],
        ],
        dtype=torch.float32,
    )
    torch.testing.assert_close(observed["dof_pos"], first_two_frames)
    torch.testing.assert_close(
        observed["dof_vel"], first_two_frames.reshape(2, 6)
    )
    torch.testing.assert_close(
        observed["root_height"][..., 0],
        torch.tensor([[1.0, 2.0], [0.5, 1.5]]),
    )
def _make_env():
    """Build a single-env stub whose environment origin sits at x=10.

    The robot has an ``anchor`` body at world (10, 0, 0) and a ``target``
    body at world (11, 0, 0), with orientations from ``_identity_quat``.
    The motion command reports the reference anchor at world (10, 0, 0).
    """

    def _ref_anchor_pos(prefix="ref_"):
        return torch.tensor([[10.0, 0.0, 0.0]])

    def _ref_anchor_rot(prefix="ref_"):
        return _identity_quat(1)

    body_positions = torch.tensor(
        [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32
    )
    robot = SimpleNamespace(
        body_names=["anchor", "target"],
        data=SimpleNamespace(
            body_pos_w=body_positions,
            body_quat_w=_identity_quat(1, 2),
        ),
    )
    motion_command = SimpleNamespace(
        anchor_bodylink_name="anchor",
        get_ref_motion_anchor_bodylink_global_pos_cur=_ref_anchor_pos,
        get_ref_motion_anchor_bodylink_global_rot_wxyz_cur=_ref_anchor_rot,
    )

    def _get_term(name):
        # Every command lookup resolves to the same motion command stub.
        return motion_command

    origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32)
    return SimpleNamespace(
        num_envs=1,
        scene=_Scene(env_origins=origins, robot=robot),
        command_manager=SimpleNamespace(get_term=_get_term),
    )
def test_global_robot_bodylink_pos_is_in_environment_frame(monkeypatch):
    """The target body (world x=11) is reported at x=1 in the env frame."""
    obs_mod = _load_observation_module(monkeypatch)
    stub_env = _make_env()
    get_pos = obs_mod.ObservationFunctions._get_obs_global_robot_bodylink_pos
    target_pos = get_pos(stub_env, keybody_names=["target"])
    expected = torch.tensor([[[1.0, 0.0, 0.0]]])
    assert torch.allclose(target_pos, expected)
def test_root_rel_robot_bodylink_pos_uses_consistent_env_frame(monkeypatch):
    """Root-relative body position is computed in a consistent frame."""
    obs_mod = _load_observation_module(monkeypatch)
    stub_env = _make_env()

    # Stub the root pose source: zero position, identity quaternion.
    def _root_pos(_env):
        return torch.zeros(1, 3, dtype=torch.float32)

    def _root_quat(_env):
        return _identity_quat(1)

    obs_mod.isaaclab_mdp.root_pos_w = _root_pos
    obs_mod.isaaclab_mdp.root_quat_w = _root_quat
    funcs = obs_mod.ObservationFunctions
    rel_pos = funcs._get_obs_root_rel_robot_bodylink_pos(
        stub_env, keybody_names=["target"]
    )
    assert torch.allclose(rel_pos, torch.tensor([[[1.0, 0.0, 0.0]]]))
def test_global_anchor_pos_diff_uses_environment_frame_consistently(
    monkeypatch,
):
    """Robot and reference anchors share world (10, 0, 0): diff is zero."""
    obs_mod = _load_observation_module(monkeypatch)
    funcs = obs_mod.ObservationFunctions
    diff = funcs._get_obs_global_anchor_pos_diff(_make_env())
    assert torch.allclose(diff, torch.zeros(1, 3))
def test_build_additive_uniform_noise_cfg_supports_optional_z_override(
    monkeypatch,
):
    """``n_min_z``/``n_max_z`` expand the bounds into per-axis xyz tensors."""
    obs_mod = _load_observation_module(monkeypatch)
    params = {"n_min": -0.1, "n_max": 0.1, "n_min_z": -0.02, "n_max_z": 0.03}
    noise = obs_mod._build_noise_cfg(
        {"type": "AdditiveUniformNoiseCfg", "params": params}
    )
    expected_min = torch.tensor([-0.1, -0.1, -0.02])
    expected_max = torch.tensor([0.1, 0.1, 0.03])
    assert torch.equal(noise.kwargs["n_min"], expected_min)
    assert torch.equal(noise.kwargs["n_max"], expected_max)
def test_build_additive_uniform_noise_cfg_keeps_scalar_bounds_without_z_override(
    monkeypatch,
):
    """Without ``*_z`` params the bounds still equal the scalar inputs."""
    obs_mod = _load_observation_module(monkeypatch)
    cfg = {
        "type": "AdditiveUniformNoiseCfg",
        "params": {"n_min": -0.1, "n_max": 0.1},
    }
    kwargs = obs_mod._build_noise_cfg(cfg).kwargs
    assert kwargs["n_min"] == pytest.approx(-0.1)
    assert kwargs["n_max"] == pytest.approx(0.1)
================================================
FILE: tests/test_onnx_attention_export.py
================================================
import sys
import tempfile
from pathlib import Path
import numpy as np
import onnx
import onnxruntime
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.modules.network_modules import (
export_safe_scaled_dot_product_attention,
)
class _ExportAttentionModule(nn.Module):
    """Non-causal export-safe attention driven by an explicit mask input."""

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Dropout is disabled so the forward pass is deterministic.
        options = {"attn_mask": mask, "dropout_p": 0.0, "is_causal": False}
        return export_safe_scaled_dot_product_attention(q, k, v, **options)
class _ExportCausalAttentionModule(nn.Module):
    """Causal export-safe attention with no explicit mask input."""

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        # Causality comes from ``is_causal``; dropout is disabled.
        options = {"attn_mask": None, "dropout_p": 0.0, "is_causal": True}
        return export_safe_scaled_dot_product_attention(q, k, v, **options)
def _export_model(
    export_path: Path,
    module: nn.Module,
    inputs: tuple[torch.Tensor, ...],
    input_names: list[str],
) -> None:
    """Export ``module`` (switched to eval mode) to ONNX at ``export_path``.

    Uses opset 17 and the non-dynamo exporter; the single graph output is
    named ``"out"``.
    """
    export_options = dict(
        opset_version=17,
        input_names=input_names,
        output_names=["out"],
        dynamo=False,
        verbose=False,
    )
    eval_module = module.eval()
    torch.onnx.export(eval_module, inputs, str(export_path), **export_options)
def _export_op_types(
    module: nn.Module,
    *inputs: torch.Tensor,
    input_names: list[str],
) -> list[str]:
    """Export ``module`` to a temporary ONNX file and list its node ops."""
    with tempfile.TemporaryDirectory() as scratch:
        onnx_file = Path(scratch) / "attention.onnx"
        _export_model(onnx_file, module, inputs, input_names)
        loaded = onnx.load(str(onnx_file))
    # The model is fully in memory, so the temp dir can already be gone.
    return [node.op_type for node in loaded.graph.node]
def _run_onnx(
    module: nn.Module,
    *inputs: torch.Tensor,
    input_names: list[str],
) -> np.ndarray:
    """Export ``module`` and run it once with ONNX Runtime on CPU.

    ``inputs`` are bound, in order, to ``input_names`` (lengths must
    match); returns the array produced for the ``"out"`` output.
    """
    with tempfile.TemporaryDirectory() as scratch:
        onnx_file = Path(scratch) / "attention.onnx"
        _export_model(onnx_file, module, inputs, input_names)
        session = onnxruntime.InferenceSession(
            str(onnx_file), providers=["CPUExecutionProvider"]
        )
        feed = {}
        for name, tensor in zip(input_names, inputs, strict=True):
            feed[name] = tensor.detach().cpu().numpy()
        return session.run(["out"], feed)[0]
def test_export_safe_attention_uses_native_bool_mask_outside_export(
    monkeypatch,
):
    """Outside ONNX export, the boolean mask reaches native SDPA as bool."""
    captured = {}
    original_sdpa = F.scaled_dot_product_attention

    def _spy_sdpa(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        attn_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Record the mask dtype, then forward every option untouched so the
        # spy tolerates extra SDPA arguments (``scale``, ``enable_gqa``, ...)
        # and does not require torch versions that know ``enable_gqa``.
        captured["mask_dtype"] = None if attn_mask is None else attn_mask.dtype
        return original_sdpa(q, k, v, attn_mask=attn_mask, **kwargs)

    monkeypatch.setattr(torch.onnx, "is_in_onnx_export", lambda: False)
    monkeypatch.setattr(F, "scaled_dot_product_attention", _spy_sdpa)
    q = torch.randn(1, 2, 3, 4)
    k = torch.randn(1, 2, 5, 4)
    v = torch.randn(1, 2, 5, 4)
    mask = torch.ones(1, 1, 3, 5, dtype=torch.bool)
    export_safe_scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
    )
    assert "mask_dtype" in captured, "native SDPA was never called"
    assert captured["mask_dtype"] == torch.bool
def test_export_safe_attention_matches_sdpa_for_valid_masks():
    """Eager wrapper equals native SDPA when every query row keeps a key."""
    torch.manual_seed(0)
    # Tuple elements evaluate left to right: q, then k, then v.
    q, k, v = (
        torch.randn(2, 4, 3, 8),
        torch.randn(2, 4, 5, 8),
        torch.randn(2, 4, 5, 8),
    )
    key_keep = torch.tensor(
        [
            [[[True, True, False, False, False]]],
            [[[True, False, False, False, False]]],
        ],
        dtype=torch.bool,
    )
    mask = key_keep.expand(2, 1, 3, 5)
    sdpa_kwargs = {"attn_mask": mask, "dropout_p": 0.0, "is_causal": False}
    expected = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
    actual = export_safe_scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
    torch.testing.assert_close(actual, expected, atol=1.0e-6, rtol=1.0e-5)
def test_legacy_attention_export_avoids_isnan():
    """The non-dynamo export of masked attention contains no IsNaN node."""
    torch.manual_seed(2)
    shapes = [(1, 4, 1, 8), (1, 4, 16, 8), (1, 4, 16, 8)]
    q, k, v = (torch.randn(*shape) for shape in shapes)
    all_keys_visible = torch.ones(1, 1, 1, 16, dtype=torch.bool)
    op_types = _export_op_types(
        _ExportAttentionModule(),
        q,
        k,
        v,
        all_keys_visible,
        input_names=["q", "k", "v", "mask"],
    )
    assert "IsNaN" not in op_types
def test_legacy_attention_export_ort_matches_pytorch_for_future_mask():
    """ORT output of the exported graph matches eager for a key mask."""
    torch.manual_seed(3)
    q = torch.randn(2, 4, 3, 8)
    k, v = torch.randn(2, 4, 5, 8), torch.randn(2, 4, 5, 8)
    rows = [
        [True, True, True, False, False],
        [True, False, False, False, False],
    ]
    per_batch = torch.tensor(rows, dtype=torch.bool)[:, None, None, :]
    mask = per_batch.expand(2, 1, 3, 5)
    expected = export_safe_scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    actual = _run_onnx(
        _ExportAttentionModule(),
        q,
        k,
        v,
        mask,
        input_names=["q", "k", "v", "mask"],
    )
    np.testing.assert_allclose(
        actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5
    )
def test_legacy_attention_export_ort_matches_pytorch_for_causal_path():
    """ORT output matches eager for causal attention without a mask."""
    torch.manual_seed(4)
    q, k, v = [torch.randn(2, 4, 6, 8) for _ in range(3)]
    reference = export_safe_scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )
    onnx_out = _run_onnx(
        _ExportCausalAttentionModule(), q, k, v, input_names=["q", "k", "v"]
    )
    np.testing.assert_allclose(
        onnx_out, reference.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5
    )
def test_legacy_attention_export_ort_matches_pytorch_for_kv_mask():
    """ORT output matches eager for a per-batch key-length padding mask."""
    torch.manual_seed(5)
    q = torch.randn(2, 4, 1, 8)
    k = torch.randn(2, 4, 16, 8)
    v = torch.randn(2, 4, 16, 8)
    num_keys = 16
    lengths = torch.tensor([16, 5], dtype=torch.int64)
    positions = torch.arange(num_keys, dtype=torch.int64)
    # True where the key index is inside that batch element's valid length.
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)
    mask = valid.reshape(-1, 1, 1, num_keys)
    expected = export_safe_scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    actual = _run_onnx(
        _ExportAttentionModule(),
        q,
        k,
        v,
        mask,
        input_names=["q", "k", "v", "mask"],
    )
    np.testing.assert_allclose(
        actual, expected.detach().cpu().numpy(), atol=1.0e-6, rtol=1.0e-5
    )
================================================
FILE: tests/test_onnx_export.py
================================================
import sys
from types import SimpleNamespace
from holomotion.src.utils.onnx_export import attach_onnx_metadata_holomotion
class _FakeEntry:
def __init__(self):
self.key = ""
self.value = ""
class _FakeTensor:
def __init__(self, values):
self._values = values
def __getitem__(self, index):
return _FakeTensor(self._values[index])
def cpu(self):
return self
def tolist(self):
return self._values
def test_attach_onnx_metadata_uses_default_joint_gains(monkeypatch):
    """Exported gains come from the default joint gains.

    The live ``joint_stiffness``/``joint_damping`` are all zeros here, so
    the asserted non-zero values can only come from the defaults.
    """
    fake_model = SimpleNamespace(metadata_props=[])

    def _load(path):
        return fake_model

    def _save(loaded_model, path):
        return None

    monkeypatch.setitem(
        sys.modules,
        "onnx",
        SimpleNamespace(
            load=_load, save=_save, StringStringEntryProto=_FakeEntry
        ),
    )
    robot_data = SimpleNamespace(
        joint_names=["joint_a", "joint_b"],
        joint_stiffness=_FakeTensor([[0.0, 0.0]]),
        joint_damping=_FakeTensor([[0.0, 0.0]]),
        default_joint_stiffness=_FakeTensor([[10.0, 20.0]]),
        default_joint_damping=_FakeTensor([[1.0, 2.0]]),
        default_joint_pos=_FakeTensor([[0.1, -0.2]]),
    )
    scaled_action = SimpleNamespace(_scale=_FakeTensor([[0.5, 0.25]]))

    def _get_term(name):
        return scaled_action

    env = SimpleNamespace(
        scene={"robot": SimpleNamespace(data=robot_data)},
        action_manager=SimpleNamespace(get_term=_get_term),
    )
    attach_onnx_metadata_holomotion(env, "dummy.onnx")
    metadata = dict(
        (entry.key, entry.value) for entry in fake_model.metadata_props
    )
    assert metadata["joint_stiffness"] == "10.000,20.000"
    assert metadata["joint_damping"] == "1.000,2.000"
================================================
FILE: tests/test_plot_moe_expert_heatmap.py
================================================
import importlib.util
from pathlib import Path
from unittest.mock import MagicMock
import numpy as np
# Absolute path of the standalone MoE heatmap plotting script under test.
# NOTE(review): it lives under ``not_for_commit/``; confirm that directory
# is present wherever these tests are expected to run.
SCRIPT_PATH = Path(__file__).resolve().parents[1].joinpath(
    "not_for_commit", "plot_moe_expert_heatmap.py"
)
def _load_plot_moe_expert_heatmap_module():
spec = importlib.util.spec_from_file_location(
"plot_moe_expert_heatmap", SCRIPT_PATH
)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def _write_eval_npz(path: Path) -> None:
np.savez(
path,
robot_moe_expert_logits=np.array(
[
[[0.0, 1.0, 2.0, 3.0], [1.5, 0.5, -0.5, -1.5]],
[[0.1, 1.1, 2.1, 3.1], [1.0, 0.0, -1.0, -2.0]],
[[0.2, 1.2, 2.2, 3.2], [0.5, -0.5, -1.5, -2.5]],
[[0.3, 1.3, 2.3, 3.3], [0.0, -1.0, -2.0, -3.0]],
[[0.4, 1.4, 2.4, 3.4], [-0.5, -1.5, -2.5, -3.5]],
],
dtype=np.float32,
),
robot_moe_expert_indices=np.array(
[
[[3, 2], [0, 1]],
[[3, 2], [0, 1]],
[[3, 2], [0, 1]],
[[3, 2], [0, 1]],
[[3, 2], [0, 1]],
],
dtype=np.int64,
),
robot_dof_torque=np.linspace(-1.0, 1.0, 15, dtype=np.float32).reshape(
5, 3
),
robot_actions=np.linspace(-0.5, 0.5, 15, dtype=np.float32).reshape(
5, 3
),
robot_low_level_dof_torque=np.zeros((20, 3), dtype=np.float32),
robot_low_level_torque_dt=np.array(0.01, dtype=np.float32),
)
def test_plot_dump_exports_moe_heatmap_pdf(tmp_path):
    """plot_dump returns the heatmap PDF and also writes line-plot PDFs."""
    plotter = _load_plot_moe_expert_heatmap_module()
    dump = tmp_path / "demo_eval.npz"
    _write_eval_npz(dump)
    heatmap_pdf = plotter.plot_dump(dump)
    expected_heatmap = tmp_path / "demo_eval_moe_expert_probability_heatmap.pdf"
    assert heatmap_pdf == expected_heatmap
    assert heatmap_pdf.is_file()
    for series in ("robot_dof_torque", "robot_actions"):
        assert (tmp_path / f"demo_eval_{series}_line_plot.pdf").is_file()
def test_selected_expert_weights_are_renormalized_within_selected_ids():
    """Probabilities of the chosen experts are rescaled to sum to one."""
    plotter = _load_plot_moe_expert_heatmap_module()
    probs = np.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]],
            [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]],
        ],
        dtype=np.float32,
    )
    chosen = np.array(
        [
            [[1, 3], [0, 2]],
            [[0, 2], [1, 3]],
        ],
        dtype=np.int64,
    )
    # e.g. experts 1 and 3 at (0.2, 0.4) renormalise to (1/3, 2/3).
    expected = np.array(
        [
            [[1.0 / 3.0, 2.0 / 3.0], [0.875, 0.125]],
            [[0.5, 0.5], [0.23076923, 0.7692308]],
        ],
        dtype=np.float32,
    )
    weights = plotter.compute_selected_expert_weights(probs, chosen)
    np.testing.assert_allclose(weights, expected)
def test_selected_expert_heatmap_only_colors_activated_experts():
    """Unselected experts are zero; selected ones hold renormalised weight."""
    plotter = _load_plot_moe_expert_heatmap_module()
    probs = np.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]],
            [[0.25, 0.25, 0.25, 0.25], [0.05, 0.15, 0.3, 0.5]],
        ],
        dtype=np.float32,
    )
    chosen = np.array(
        [
            [[1, 3], [0, 2]],
            [[0, 2], [1, 3]],
        ],
        dtype=np.int64,
    )
    expected = np.array(
        [
            [
                [0.0, 1.0 / 3.0, 0.0, 2.0 / 3.0],
                [0.875, 0.0, 0.125, 0.0],
            ],
            [
                [0.5, 0.0, 0.5, 0.0],
                [0.0, 0.23076923, 0.0, 0.7692308],
            ],
        ],
        dtype=np.float32,
    )
    heatmap = plotter.build_selected_expert_heatmap(probs, chosen)
    np.testing.assert_allclose(heatmap, expected)
def test_collect_npz_paths_recursively_sorts_directory_entries(tmp_path):
    """Only ``.npz`` files are collected, recursively, in sorted order."""
    plotter = _load_plot_moe_expert_heatmap_module()
    root = tmp_path / "evals"
    z_dump = root / "z_branch" / "clip_z.npz"
    a_dump = root / "a_branch" / "nested" / "clip_a.npz"
    for dump in (a_dump, z_dump):
        dump.parent.mkdir(parents=True)
        _write_eval_npz(dump)
    # A non-npz file that must be ignored.
    (root / "ignore.txt").write_text("ignore", encoding="utf-8")
    assert plotter.collect_npz_paths(root) == [a_dump, z_dump]
def test_plot_input_path_directory_generates_all_heatmaps_with_tqdm(
    tmp_path,
):
    """A directory input plots every dump, in sorted order, via tqdm."""
    plotter = _load_plot_moe_expert_heatmap_module()
    root = tmp_path / "evals"
    z_dir = root / "z_branch"
    a_dir = root / "a_branch" / "nested"
    npz_paths = [z_dir / "clip_z.npz", a_dir / "clip_a.npz"]
    for dump in npz_paths:
        dump.parent.mkdir(parents=True, exist_ok=True)
        _write_eval_npz(dump)

    def _expected(suffix):
        # Outputs sit next to their dumps, listed a_branch first.
        return [a_dir / f"clip_a_{suffix}.pdf", z_dir / f"clip_z_{suffix}.pdf"]

    # Pass-through tqdm double so the call arguments can be inspected.
    fake_tqdm = MagicMock(side_effect=lambda iterable, **_: iterable)
    real_tqdm = plotter.tqdm
    plotter.tqdm = fake_tqdm
    try:
        output_paths = plotter.plot_input_path(root)
    finally:
        plotter.tqdm = real_tqdm

    assert output_paths == _expected("moe_expert_probability_heatmap")
    for suffix in (
        "moe_expert_probability_heatmap",
        "robot_dof_torque_line_plot",
        "robot_actions_line_plot",
    ):
        assert all(path.is_file() for path in _expected(suffix))
    assert list(fake_tqdm.call_args.args[0]) == sorted(npz_paths)
    assert fake_tqdm.call_args.kwargs == {
        "desc": "Generating plot PDFs",
        "unit": "file",
        "dynamic_ncols": True,
    }
def test_plot_dump_requires_2d_robot_dof_torque(tmp_path):
    """A 1-D ``robot_dof_torque`` array is rejected with a ValueError."""
    plotter = _load_plot_moe_expert_heatmap_module()
    bad_dump = tmp_path / "bad_eval.npz"
    np.savez(
        bad_dump,
        robot_moe_expert_logits=np.zeros((2, 1, 3), dtype=np.float32),
        robot_dof_torque=np.zeros((2,), dtype=np.float32),
    )
    raised = None
    try:
        plotter.plot_dump(bad_dump)
    except ValueError as exc:
        raised = exc
    if raised is None:
        raise AssertionError(
            "Expected plot_dump to reject 1-D robot_dof_torque"
        )
    assert "robot_dof_torque must have shape [frames, dofs]" in str(raised)
================================================
FILE: tests/test_plot_state_series.py
================================================
import importlib.util
from pathlib import Path
import numpy as np
# Absolute path of the standalone state-series plotting script under test.
# NOTE(review): it lives under ``not_for_commit/``; confirm that directory
# is present wherever these tests are expected to run.
SCRIPT_PATH = Path(__file__).resolve().parents[1].joinpath(
    "not_for_commit", "plot_state_series.py"
)
def _load_plot_state_series_module():
spec = importlib.util.spec_from_file_location(
"plot_state_series", SCRIPT_PATH
)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def test_plot_dump_exports_time_matched_scalar_series(tmp_path):
module = _load_plot_state_series_module()
robot_config_path = tmp_path / "robot.yaml"
robot_config_path.write_text(
"robot:\n dof_names:\n - joint_a\n - joint_b\n",
encoding="utf-8",
)
npz_path = tmp_path / "demo_eval.npz"
np.savez(
npz_path,
robot_dof_torque=np.arange(10, dtype=np.float32).reshape(5, 2),
robot_dof_acc=np.arange(10, 20, dtype=np.float32).reshape(5, 2),
robot_action_rate=np.linspace(0.0, 1.0, 5, dtype=np.float32),
reward=np.linspace(1.0, 2.0, 5, dtype=np.float32),
bad_scalar=np.array([1.0, 2.0], dtype=np.float32),
metadata=np.array("demo", dtype=" events.index("actor_optimizer")
assert events.index("override_sigma") > events.index("critic_optimizer")
algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)
def test_ppo_load_skips_optimizer_restore_during_offline_eval():
    """Offline eval restores iteration/infos but no optimizer state."""
    algo = PPO.__new__(PPO)
    # Minimal attribute surface read by PPO.load, set without __init__.
    stub_attrs = {
        "is_main_process": False,
        "is_offline_eval": True,
        "device": torch.device("cpu"),
        "actor": nn.Linear(1, 1),
        "critic": nn.Linear(1, 1),
        "accelerator": SimpleNamespace(unwrap_model=lambda model: model),
        "config": {},
        "_load_extra_checkpoint_state": mock.Mock(),
        "_resolve_model_file_path": (
            lambda ckpt_path, model_name: f"{ckpt_path}:{model_name}"
        ),
        "_load_accelerate_model": mock.Mock(),
        "_maybe_override_loaded_actor_sigma": mock.Mock(),
        "actor_optimizer": mock.Mock(),
        "critic_optimizer": mock.Mock(),
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(algo, attr_name, attr_value)
    loaded_dict = {
        "actor_optimizer_state_dict": {"state": {"stale": {}}},
        "critic_optimizer_state_dict": {"state": {"stale": {}}},
        "iter": 321,
        "infos": {"source": "offline-eval"},
    }
    patched_load = mock.patch(
        "holomotion.src.algo.ppo.torch.load", return_value=loaded_dict
    )
    with patched_load:
        infos = algo.load("checkpoint.pt")
    assert infos == {"source": "offline-eval"}
    assert algo.current_learning_iteration == 321
    for optimizer in (algo.actor_optimizer, algo.critic_optimizer):
        optimizer.load_state_dict.assert_not_called()
    algo._maybe_override_loaded_actor_sigma.assert_called_once_with()
    algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)
def test_ppo_load_skips_incompatible_optimizer_state_restore():
    """Only optimizer state whose param groups match is restored.

    The saved actor state has params [0, 1] against a live [0] (skipped);
    the critic's [0] matches its live [0] (restored).
    """
    actor_optimizer = mock.Mock()
    actor_optimizer.state_dict.return_value = {
        "state": {},
        "param_groups": [{"params": [0]}],
    }
    actor_optimizer.load_state_dict.side_effect = AssertionError(
        "incompatible actor optimizer state should be skipped"
    )
    critic_optimizer = mock.Mock()
    critic_optimizer.state_dict.return_value = {
        "state": {},
        "param_groups": [{"params": [0]}],
    }
    algo = PPO.__new__(PPO)
    stub_attrs = {
        "is_main_process": False,
        "is_offline_eval": False,
        "device": torch.device("cpu"),
        "actor": nn.Linear(1, 1),
        "critic": nn.Linear(1, 1),
        "accelerator": SimpleNamespace(unwrap_model=lambda model: model),
        "config": {},
        "_load_extra_checkpoint_state": mock.Mock(),
        "_resolve_model_file_path": (
            lambda ckpt_path, model_name: f"{ckpt_path}:{model_name}"
        ),
        "_load_accelerate_model": mock.Mock(),
        "_maybe_override_loaded_actor_sigma": mock.Mock(),
        "actor_optimizer": actor_optimizer,
        "critic_optimizer": critic_optimizer,
    }
    for attr_name, attr_value in stub_attrs.items():
        setattr(algo, attr_name, attr_value)
    loaded_dict = {
        "actor_optimizer_state_dict": {
            "state": {0: {"step": torch.tensor(1)}},
            "param_groups": [{"params": [0, 1]}],
        },
        "critic_optimizer_state_dict": {
            "state": {0: {"step": torch.tensor(2)}},
            "param_groups": [{"params": [0]}],
        },
        "iter": 77,
        "infos": {"source": "resume-training"},
    }
    patched_load = mock.patch(
        "holomotion.src.algo.ppo.torch.load", return_value=loaded_dict
    )
    with patched_load:
        infos = algo.load("checkpoint.pt")
    assert infos == {"source": "resume-training"}
    assert algo.current_learning_iteration == 77
    actor_optimizer.load_state_dict.assert_not_called()
    critic_optimizer.load_state_dict.assert_called_once_with(
        loaded_dict["critic_optimizer_state_dict"]
    )
    algo._maybe_override_loaded_actor_sigma.assert_called_once_with()
    algo._load_extra_checkpoint_state.assert_called_once_with(loaded_dict)
def test_checkpoint_state_to_cpu_moves_nested_tensors():
    """Nested tensors come back on CPU, without grad, in new containers."""
    exp_avg = torch.tensor([1.0, 2.0], requires_grad=True)
    source = {
        "state": {
            0: {
                "exp_avg": exp_avg,
                "exp_avg_sq": torch.tensor([3.0, 4.0]),
            }
        },
        "param_groups": [{"lr": 1.0e-3}],
        "step_tensor": torch.tensor([5]),
    }
    converted = _checkpoint_state_to_cpu(source)
    assert converted is not source
    assert converted["state"] is not source["state"]
    moved_state = converted["state"][0]
    for key in ("exp_avg", "exp_avg_sq"):
        assert moved_state[key].device.type == "cpu"
    assert converted["step_tensor"].device.type == "cpu"
    assert moved_state["exp_avg"].requires_grad is False
    torch.testing.assert_close(
        moved_state["exp_avg"], source["state"][0]["exp_avg"].detach()
    )
================================================
FILE: tests/test_ppo_entropy_annealing.py
================================================
from pathlib import Path
import sys
from types import SimpleNamespace
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.algo.algo_base import BaseOnpolicyRL
from holomotion.src.algo.ppo import PPO
def _build_entropy_algo(
    *,
    initial_entropy_coef: float,
    anneal_entropy: bool,
    zero_entropy_point: float,
    current_learning_iteration: int,
    total_learning_iterations: int,
    num_learning_iterations: int = 0,
):
    """Create a PPO via ``__new__`` (no ``__init__``) holding only the
    entropy-schedule attributes, coerced to their expected Python types."""
    fields = {
        "initial_entropy_coef": float(initial_entropy_coef),
        "anneal_entropy": bool(anneal_entropy),
        "zero_entropy_point": float(zero_entropy_point),
        "current_learning_iteration": int(current_learning_iteration),
        "total_learning_iterations": int(total_learning_iterations),
        "num_learning_iterations": int(num_learning_iterations),
    }
    algo = PPO.__new__(PPO)
    for name, value in fields.items():
        setattr(algo, name, value)
    return algo
def test_entropy_coef_is_constant_when_annealing_disabled():
    """With annealing off the coefficient ignores training progress."""
    settings = dict(
        initial_entropy_coef=5.0e-3,
        anneal_entropy=False,
        zero_entropy_point=1.0,
        current_learning_iteration=50,
        total_learning_iterations=100,
    )
    algo = _build_entropy_algo(**settings)
    assert algo._get_effective_entropy_coef() == pytest.approx(5.0e-3)
def test_entropy_coef_decays_and_respects_resumed_total_iterations():
    """Decay progress is measured against the resumed total (133), not
    against this run's own 10 iterations."""
    current_it, total_its = 123, 133
    algo = _build_entropy_algo(
        initial_entropy_coef=5.0e-3,
        anneal_entropy=True,
        zero_entropy_point=1.0,
        current_learning_iteration=current_it,
        total_learning_iterations=total_its,
        num_learning_iterations=10,
    )
    progress = current_it / total_its
    expected = 5.0e-3 * max(0.0, 1.0 - progress)
    assert algo._get_effective_entropy_coef() == pytest.approx(expected)
def test_entropy_coef_clamps_to_zero_at_and_after_zero_point():
    """At 75% (the zero point) and beyond, the coefficient is zero."""
    algo = _build_entropy_algo(
        initial_entropy_coef=5.0e-3,
        anneal_entropy=True,
        zero_entropy_point=0.75,
        current_learning_iteration=75,
        total_learning_iterations=100,
    )
    for iteration in (75, 90):
        algo.current_learning_iteration = iteration
        assert algo._get_effective_entropy_coef() == pytest.approx(0.0)
@pytest.mark.parametrize(
    ("initial_entropy_coef", "anneal_entropy", "zero_entropy_point"),
    [
        (-1.0, False, 1.0),
        (1.0, True, 0.0),
        (1.0, True, -0.1),
        (1.0, True, 1.1),
    ],
)
def test_validate_entropy_schedule_config_rejects_invalid_values(
    initial_entropy_coef: float,
    anneal_entropy: bool,
    zero_entropy_point: float,
):
    """A negative coefficient, or a zero point of 0, below 0 or above 1
    while annealing, raises ValueError."""
    config = {
        "initial_entropy_coef": initial_entropy_coef,
        "anneal_entropy": anneal_entropy,
        "zero_entropy_point": zero_entropy_point,
    }
    with pytest.raises(ValueError):
        PPO._validate_entropy_schedule_config(**config)
def test_learn_sets_current_iteration_before_each_update(tmp_path):
    """``learn`` advances ``current_learning_iteration`` between updates.

    Starting at iteration 5 with 3 requested iterations, each update must
    see its own iteration (5, 6, 7) and the resumed total (5 + 3 = 8).
    ``log_dir`` uses pytest's per-test ``tmp_path`` rather than a fixed,
    shared, POSIX-only ``/tmp`` path.
    """
    algo = BaseOnpolicyRL.__new__(BaseOnpolicyRL)
    algo.env = SimpleNamespace(reset_all=lambda: ({},))
    algo._wrap_obs_dict = lambda obs_dict: obs_dict
    algo._ensure_storage = lambda obs_td: None
    algo.train_mode = lambda: None
    algo.rollout_policy = lambda obs_td: obs_td
    algo.log_dir = str(tmp_path)
    algo.num_learning_iterations = 3
    algo.current_learning_iteration = 5
    algo.total_learning_iterations = 0
    algo.log_interval = 100
    algo.save_interval = 100
    algo.is_main_process = False
    algo.ep_infos = []
    algo._post_iteration_hook = lambda it: None
    algo._post_training_hook = lambda: None
    algo._release_cuda_cache = lambda: None
    algo.save = lambda *args, **kwargs: None
    algo.accelerator = SimpleNamespace(
        wait_for_everyone=lambda: None,
        end_training=lambda: None,
    )
    observed_iterations = []
    observed_totals = []

    def _update():
        # Snapshot the counters as seen at the moment of each update.
        observed_iterations.append(algo.current_learning_iteration)
        observed_totals.append(algo.total_learning_iterations)
        return {}

    algo.update = _update
    BaseOnpolicyRL.learn(algo)
    assert observed_iterations == [5, 6, 7]
    assert observed_totals == [8, 8, 8]
================================================
FILE: tests/test_ppo_symmetry_loss.py
================================================
from contextlib import nullcontext
from types import ModuleType, SimpleNamespace
import sys
import pytest
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from tensordict import TensorDict
from holomotion.src.algo.ppo import PPO
class _DummyAccelerator:
def autocast(self):
return nullcontext()
def backward(self, loss):
loss.backward()
def clip_grad_norm_(self, parameters, max_norm):
torch.nn.utils.clip_grad_norm_(list(parameters), max_norm)
def reduce(self, tensor, reduction="mean"):
return tensor
class _DummyActor(nn.Module):
    """Actor stub with a learnable constant mean (0.25) and log-std (0).

    In ``"inference"`` mode the returned actions are ``mu`` shifted by
    ``mirror_offset``. Otherwise the given actions (or ``mu`` when none are
    given) are echoed, with log-prob and entropy that are zero-valued but
    still differentiable with respect to ``mu_param``.
    """

    def __init__(self, num_actions: int, mirror_offset: float):
        super().__init__()
        self.mu_param = nn.Parameter(torch.full((num_actions,), 0.25))
        self.log_std = nn.Parameter(torch.zeros(num_actions))
        self.mirror_offset = float(mirror_offset)

    def forward(
        self,
        obs_td: TensorDict,
        actions: torch.Tensor | None = None,
        mode: str = "sampling",
        *,
        update_obs_norm: bool = True,
    ) -> TensorDict:
        del obs_td, update_obs_norm
        rows = 2 if actions is None else int(actions.shape[0])
        mu = self.mu_param.unsqueeze(0).expand(rows, -1)
        sigma = self.log_std.exp().unsqueeze(0).expand(rows, -1)
        result = TensorDict({}, batch_size=[rows])
        result.set("mu", mu)
        result.set("sigma", sigma)
        if mode == "inference":
            result.set("actions", mu + self.mirror_offset)
            return result
        result.set("actions", mu if actions is None else actions)
        zero_with_grad = mu.sum(dim=-1) * 0.0
        result.set("actions_log_prob", zero_with_grad)
        result.set("entropy", zero_with_grad)
        return result
class _DummyCritic(nn.Module):
    """Critic stub: one learnable scalar (init 0.1) broadcast to 2 rows."""

    def __init__(self):
        super().__init__()
        self.value = nn.Parameter(torch.tensor([0.1], dtype=torch.float32))

    def forward(self, obs_td: TensorDict, *, update_obs_norm: bool = True):
        del obs_td, update_obs_norm
        rows = 2  # hard-coded batch size; the observations are ignored
        result = TensorDict({}, batch_size=[rows])
        result.set("values", self.value.view(1, 1).expand(rows, 1))
        return result
class _SingleBatchStorage:
def __init__(self, batch):
self._batch = batch
self.data = {
"returns": torch.zeros(2, 1, 1, dtype=torch.float32),
"values": torch.zeros(2, 1, 1, dtype=torch.float32),
}
self.num_envs = 1
self.num_transitions_per_env = 2
self.cleared = False
def iter_minibatches(self, num_mini_batches: int, num_epochs: int):
del num_mini_batches, num_epochs
yield self._batch
def clear(self):
self.cleared = True
def _install_mirror_stub():
module = ModuleType(
"holomotion.src.env.isaaclab_components.isaaclab_observation"
)
class MirrorFunctions:
@staticmethod
def mirror_dof(
x: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
):
perm = perm.to(device=x.device, dtype=torch.long)
sign = sign.to(device=x.device, dtype=x.dtype)
mirrored = torch.index_select(x, dim=x.ndim - 1, index=perm)
view_shape = [1] * (mirrored.ndim - 1) + [int(sign.numel())]
return mirrored * sign.view(*view_shape)
@staticmethod
def mirror_action(
actions: torch.Tensor, *, perm: torch.Tensor, sign: torch.Tensor
):
return MirrorFunctions.mirror_dof(actions, perm=perm, sign=sign)
@staticmethod
def mirror_vec3(x: torch.Tensor):
sign = torch.tensor(
[1.0, -1.0, 1.0], device=x.device, dtype=x.dtype
)
view_shape = [1] * (x.ndim - 1) + [3]
return x * sign.view(*view_shape)
@staticmethod
def mirror_axial_vec3(x: torch.Tensor):
sign = torch.tensor(
[-1.0, 1.0, -1.0], device=x.device, dtype=x.dtype
)
view_shape = [1] * (x.ndim - 1) + [3]
return x * sign.view(*view_shape)
@staticmethod
def mirror_velocity_command(x: torch.Tensor):
if x.shape[-1] == 3:
sign = torch.tensor(
[1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
)
else:
sign = torch.tensor(
[1.0, 1.0, -1.0, -1.0], device=x.device, dtype=x.dtype
)
view_shape = [1] * (x.ndim - 1) + [int(sign.numel())]
return x * sign.view(*view_shape)
module.MirrorFunctions = MirrorFunctions
sys.modules[module.__name__] = module
def test_setup_symmetry_builds_expected_dof_permutation_and_signs():
    """Left/right joints swap in the permutation; signs follow the config."""
    _install_mirror_stub()
    joint_names = [
        "left_hip_pitch_joint",
        "right_hip_pitch_joint",
        "waist_yaw_joint",
        "left_knee_joint",
        "right_knee_joint",
    ]
    dof_signs = [1.0, 1.0, -1.0, 1.0, 1.0]
    sign_by_name = dict(zip(joint_names, dof_signs))

    algo = PPO.__new__(PPO)
    algo.device = torch.device("cpu")
    algo.num_actions = len(joint_names)
    algo.command_name = "base_velocity"
    algo.symmetry_loss_enabled = True
    algo.is_main_process = False
    algo.config = OmegaConf.create(
        {
            "module_dict": {
                "actor": {
                    "obs_schema": {
                        "flattened_obs": {
                            "seq_len": 2,
                            "terms": ["unified/actor_dof_pos"],
                        }
                    }
                }
            },
            "symmetry_loss": {
                "enabled": True,
                "coef": 0.1,
                "dof_sign_by_name": dict(sign_by_name),
            },
        }
    )
    robot = SimpleNamespace(joint_names=list(joint_names))
    algo.env = SimpleNamespace(_env=SimpleNamespace(scene={"robot": robot}))
    obs_groups = {
        "unified": {
            "atomic_obs_list": [
                {"actor_dof_pos": {"mirror_func": "mirror_dof"}}
            ]
        }
    }
    algo.env_config = OmegaConf.create(
        {
            "config": {
                "robot": {"dof_sign_by_name": dict(sign_by_name)},
                "obs": {"obs_groups": obs_groups},
            }
        }
    )
    algo._setup_symmetry()
    assert algo._sym_dof_perm.tolist() == [1, 0, 2, 4, 3]
    assert algo._sym_dof_sign.tolist() == dof_signs
def test_mirror_actor_obs_uses_slash_qualified_actor_terms_only():
    """Terms keyed ``group/term`` in ``_obs_mirror_map`` are transformed;
    a term missing from the map (critic_dof_pos) is left unchanged."""
    _install_mirror_stub()
    algo = PPO.__new__(PPO)
    algo.command_name = "base_velocity"
    algo.symmetry_loss_enabled = True
    algo.symmetry_loss_coef = 0.1

    def _double(x):
        return x * 2.0

    def _shift(x):
        return x + 1.0

    algo._obs_mirror_map = {
        "unified/actor_velocity_command": _double,
        "unified/actor_dof_pos": _shift,
    }
    f32 = torch.float32
    obs_td = TensorDict.from_dict(
        {
            "unified": {
                "actor_velocity_command": torch.tensor(
                    [[[1.0, 2.0, 3.0]]], dtype=f32
                ),
                "actor_dof_pos": torch.tensor([[[0.1, 0.2]]], dtype=f32),
                "critic_dof_pos": torch.tensor([[9.0, 8.0]], dtype=f32),
            }
        },
        batch_size=[1],
        device="cpu",
    )
    mirrored = algo._mirror_actor_obs(obs_td)
    checks = [
        (
            "actor_velocity_command",
            torch.tensor([[[2.0, 4.0, 6.0]]], dtype=f32),
        ),
        ("actor_dof_pos", torch.tensor([[[1.1, 1.2]]], dtype=f32)),
        ("critic_dof_pos", obs_td["unified", "critic_dof_pos"]),
    ]
    for term, expected in checks:
        torch.testing.assert_close(mirrored["unified", term], expected)
def test_update_reports_symmetry_loss_only_for_velocity_tracking():
    """``symmetry_loss`` is reported only for the base_velocity command.

    The dummy actor's inference actions are its mean offset by 1.0 and both
    mirror hooks are identities; presumably that is why the expected
    symmetry loss is exactly 1.0.
    """
    actor = _DummyActor(num_actions=2, mirror_offset=1.0)
    critic = _DummyCritic()
    batch = SimpleNamespace(
        obs=TensorDict.from_dict(
            {
                "unified": {
                    "actor_dof_pos": torch.zeros(2, 1, 2),
                    "critic_dof_pos": torch.zeros(2, 2),
                }
            },
            batch_size=[2],
            device="cpu",
        ),
        actions=torch.zeros(2, 2),
        values=torch.zeros(2, 1),
        advantages=torch.zeros(2, 1),
        returns=torch.zeros(2, 1),
        actions_log_prob=torch.zeros(2, 1),
        mu=torch.zeros(2, 2),
        sigma=torch.ones(2, 2),
    )
    hyperparams = {
        "value_loss_coef": 1.0,
        "clip_param": 0.2,
        "max_grad_norm": 1.0,
        "schedule": "fixed",
        "desired_kl": None,
        "distributed_update_mode": "legacy",
        "num_mini_batches": 1,
        "num_learning_epochs": 1,
        "configured_num_mini_batches": 1,
        "requested_num_mini_batches": 1,
        "distributed_lr_scale_factor": 1.0,
        "entropy_coef": 0.0,
        "initial_entropy_coef": 0.0,
        "anneal_entropy": False,
        "use_clipped_value_loss": False,
        "actor_learning_rate": 1.0e-3,
        "critic_learning_rate": 1.0e-3,
        "global_advantage_norm": True,
        "is_distributed": False,
        "symmetry_loss_enabled": True,
        "symmetry_loss_coef": 0.5,
    }
    algo = PPO.__new__(PPO)
    algo.device = torch.device("cpu")
    algo.accelerator = _DummyAccelerator()
    algo.actor = actor
    algo.critic = critic
    algo.actor_optimizer = torch.optim.SGD(actor.parameters(), lr=0.01)
    algo.critic_optimizer = torch.optim.SGD(critic.parameters(), lr=0.01)
    algo.storage = _SingleBatchStorage(batch)
    for name, value in hyperparams.items():
        setattr(algo, name, value)
    algo._mirror_actor_obs = lambda obs_td: obs_td
    algo._mirror_env_action = lambda actions: actions
    algo._post_update_hook = lambda loss_dict: None

    algo.command_name = "base_velocity"
    velocity_loss = algo.update()
    assert velocity_loss["symmetry_loss"] == pytest.approx(1.0)

    # Fresh storage with the same batch, now for a motion-tracking command.
    algo.storage = _SingleBatchStorage(algo.storage._batch)
    algo.command_name = "ref_motion"
    non_velocity_loss = algo.update()
    assert "symmetry_loss" not in non_velocity_loss
================================================
FILE: tests/test_ppo_tf_aux_keybody.py
================================================
import copy
import sys
from types import ModuleType, SimpleNamespace
from unittest import mock
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from holomotion.src.algo.algo_utils import PpoAuxTransition, RolloutStorage
from holomotion.src.algo.ppo import PPO
from holomotion.src.algo.ppo_tf import PPOTF
from holomotion.src.modules.agent_modules import (
PPOTFActor,
TensorDictAssembler,
)
from holomotion.src.modules.network_modules import GroupedMoETransformerPolicy
from holomotion.src.modules.network_modules import GroupedMoEBlock
from holomotion.src.modules.network_modules import ModernTransformerBlock
from tensordict import TensorDict
def _make_aux_policy(
    *,
    denoise_root_lin_vel_weight: float = 1.0e-2,
    denoise_root_ang_vel_weight: float = 1.0e-2,
    denoise_dof_pos_weight: float = 1.0e-2,
) -> GroupedMoETransformerPolicy:
    """Build a tiny single-layer MoE policy with ``aux_state_pred`` enabled.

    The three ``denoise_*`` weights are forwarded verbatim into the
    ``aux_state_pred`` section. Both the keybody-contact and the
    keybody-relative-position heads track the left and right knee links.
    """
    knee_links = ["left_knee_link", "right_knee_link"]
    aux_state_pred = {
        "enabled": True,
        "w_denoise_ref_root_lin_vel": denoise_root_lin_vel_weight,
        "w_denoise_ref_root_ang_vel": denoise_root_ang_vel_weight,
        "w_denoise_ref_dof_pos": denoise_dof_pos_weight,
        "keybody_contact_names": list(knee_links),
        "keybody_rel_pos_names": list(knee_links),
    }
    routing = dict(
        num_fine_experts=2,
        num_shared_experts=1,
        top_k=1,
        routing_score_fn="softmax",
        routing_scale=1.0,
        use_dynamic_bias=False,
        bias_update_rate=0.001,
        expert_bias_clip=0.0,
    )
    transformer = dict(
        obs_embed_mlp_hidden=16,
        d_model=8,
        n_heads=2,
        n_kv_heads=1,
        use_gated_attn=False,
        n_layers=1,
        ff_mult=1.0,
        ff_mult_dense=2,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        max_ctx_len=4,
    )
    module_config = {
        **routing,
        **transformer,
        "aux_state_pred": aux_state_pred,
    }
    return GroupedMoETransformerPolicy(
        input_dim=6,
        output_dim=4,
        module_config_dict=module_config,
    )
def _make_aux_actor() -> PPOTFActor:
    """Assemble a ``PPOTFActor`` around ``_make_aux_policy()`` without ``__init__``.

    Only state-prediction aux is enabled. Router aux terms and observation
    normalization are off. Observations arrive as a single ``flat_obs`` term,
    and the 4-dim action head starts from ``log_std == 0``.
    """
    actor = PPOTFActor.__new__(PPOTFActor)
    # Bypassing PPOTFActor.__init__ still requires nn.Module's own setup
    # before submodules and parameters can be attached.
    nn.Module.__init__(actor)
    actor.actor_module = _make_aux_policy()

    flags = {
        "aux_state_pred_enabled": True,
        "aux_router_command_recon_enabled": False,
        "aux_router_switch_penalty_enabled": False,
        "obs_norm_enabled": False,
    }
    for flag_name, flag_value in flags.items():
        setattr(actor, flag_name, flag_value)

    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.actor_obs_transforms = []
    actor.assembler = TensorDictAssembler(
        {"flat_obs": {"seq_len": 1, "terms": ["flat_obs"]}},
        output_mode="flat",
    )
    actor.min_sigma, actor.max_sigma = 0.01, 1.0
    actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32))
    return actor
def _make_aux_command_policy(
    *,
    n_layers: int = 3,
    dense_layer_at_last: bool = False,
    enable_aux_router_command_recon: bool = True,
    freeze_router: bool = False,
) -> GroupedMoETransformerPolicy:
    """Build a small MoE policy configured for router-command reconstruction.

    Args:
        n_layers: Number of transformer layers.
        dense_layer_at_last: Forwarded as ``dense_layer_at_last``.
        enable_aux_router_command_recon: Toggles the
            ``aux_router_command_recon`` section (output_dim 5, hidden_dim 7).
        freeze_router: Forwarded as ``freeze_router``.
    """
    recon_cfg = {
        "enabled": enable_aux_router_command_recon,
        "output_dim": 5,
        "hidden_dim": 7,
    }
    module_config = dict(
        num_fine_experts=2,
        num_shared_experts=1,
        top_k=1,
        routing_score_fn="softmax",
        routing_scale=1.0,
        use_dynamic_bias=False,
        bias_update_rate=0.001,
        expert_bias_clip=0.0,
        obs_embed_mlp_hidden=16,
        d_model=8,
        n_heads=2,
        n_kv_heads=1,
        use_gated_attn=False,
        n_layers=n_layers,
        ff_mult=1.0,
        ff_mult_dense=2,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        max_ctx_len=4,
        dense_layer_at_last=dense_layer_at_last,
        freeze_router=freeze_router,
        aux_router_command_recon=recon_cfg,
    )
    return GroupedMoETransformerPolicy(
        input_dim=6, output_dim=4, module_config_dict=module_config
    )
def _make_temporal_aux_actor() -> PPOTFActor:
    """Assemble a ``PPOTFActor`` for the router aux losses, skipping ``__init__``.

    Wraps ``_make_aux_command_policy()`` with state-prediction aux disabled and
    both router aux terms (command recon and switch penalty) enabled.
    Observation normalization is off and the input is a single ``flat_obs``
    term.
    """
    actor = PPOTFActor.__new__(PPOTFActor)
    nn.Module.__init__(actor)  # Module registries without PPOTFActor.__init__
    actor.actor_module = _make_aux_command_policy()
    (
        actor.aux_state_pred_enabled,
        actor.aux_router_command_recon_enabled,
        actor.aux_router_switch_penalty_enabled,
    ) = (False, True, True)
    actor.obs_norm_enabled = False
    actor.obs_normalizer = nn.Identity()
    actor.obs_norm_clip = 0.0
    actor.actor_obs_transforms = []
    flat_schema = {"flat_obs": {"seq_len": 1, "terms": ["flat_obs"]}}
    actor.assembler = TensorDictAssembler(flat_schema, output_mode="flat")
    actor.min_sigma, actor.max_sigma = 0.01, 1.0
    actor.log_std = nn.Parameter(torch.zeros(4, dtype=torch.float32))
    return actor
def _make_temporal_only_aux_actor() -> PPOTFActor:
    """Variant of ``_make_temporal_aux_actor`` with command recon disabled.

    Only the router switch-penalty aux flag remains enabled on the actor.
    """
    temporal_actor = _make_temporal_aux_actor()
    # The wrapped policy still carries the recon head; only the actor-level
    # flag is turned off.
    temporal_actor.aux_router_command_recon_enabled = False
    return temporal_actor
def test_rollout_storage_allocates_ref_and_robot_keybody_targets():
    """``RolloutStorage`` allocates aux keybody/denoise buffers for ``PpoAuxTransition``.

    The ``C`` and ``K`` shape tokens are overridden (to 2 and 8) only while
    the storage is built. Every buffer has the leading
    ``(num_transitions_per_env, num_envs) == (3, 2)`` dims, and the keybody
    targets add ``(K, 3)``.
    """
    obs_template = TensorDict(
        {"flat_obs": torch.zeros(2, 5)},
        batch_size=[2],
    )
    # Patch SHAPE_TOKENS in place: mock.patch.dict restores the *same* dict
    # object on exit, so any other reference to it (for example a base class
    # the attribute may be inherited from) is restored as well. Rebinding
    # ``PpoAuxTransition.SHAPE_TOKENS`` to a copy in a try/finally left every
    # other holder of the original dict with C=2/K=8, leaking into later tests.
    with mock.patch.dict(PpoAuxTransition.SHAPE_TOKENS, {"C": 2, "K": 8}):
        storage = RolloutStorage(
            num_envs=2,
            num_transitions_per_env=3,
            obs_template=obs_template,
            actions_shape=(4,),
            transition_cls=PpoAuxTransition,
        )
    expected_shapes = {
        "gt_ref_keybody_rel_pos": (3, 2, 8, 3),
        "gt_robot_keybody_rel_pos": (3, 2, 8, 3),
        "gt_denoise_ref_root_lin_vel": (3, 2, 3),
        "gt_denoise_ref_root_ang_vel": (3, 2, 3),
        "gt_denoise_ref_dof_pos": (3, 2, 4),
    }
    for key, shape in expected_shapes.items():
        assert storage.data[key].shape == shape, key
def test_grouped_moe_policy_returns_keybody_position_predictions():
    """``predict_aux_from_pre_moe`` emits every aux head with per-token shapes."""
    policy = _make_aux_policy()
    batch, tokens = 2, 3
    pre_moe_hidden = torch.randn(batch, tokens, policy.d_model)
    outputs = policy.predict_aux_from_pre_moe(pre_moe_hidden)

    # Trailing dims per head, after the (batch, tokens) prefix.
    expected_trailing = {
        "base_lin_vel_loc": (3,),
        "base_lin_vel_log_std": (3,),
        "root_height_loc": (1,),
        "root_height_log_std": (1,),
        "keybody_contact_logits": (2,),
        "ref_keybody_rel_pos": (2, 3),
        "robot_keybody_rel_pos": (2, 3),
        "denoise_ref_root_lin_vel_residual": (3,),
        "denoise_ref_root_ang_vel_residual": (3,),
        "denoise_ref_dof_pos_residual": (4,),
    }
    for head_name, trailing in expected_trailing.items():
        assert outputs[head_name].shape == (batch, tokens, *trailing), head_name
def test_grouped_moe_policy_omits_denoise_predictions_when_weights_zero():
    """Zero denoise weights remove the matching residual outputs entirely."""
    policy = _make_aux_policy(
        denoise_root_lin_vel_weight=0.0,
        denoise_root_ang_vel_weight=0.0,
        denoise_dof_pos_weight=0.0,
    )
    outputs = policy.predict_aux_from_pre_moe(
        torch.randn(2, 3, policy.d_model)
    )
    for residual_key in (
        "denoise_ref_root_lin_vel_residual",
        "denoise_ref_root_ang_vel_residual",
        "denoise_ref_dof_pos_residual",
    ):
        assert residual_key not in outputs, residual_key
def test_ppotf_actor_sequence_logp_emits_actor_facing_dof_denoise_keys():
    """``sequence_logp`` exposes the dof denoise residual under an ``aux_`` key.

    The keybody relative-position denoise loc/log_std keys must not appear.
    """
    actor = _make_aux_actor()
    num_seqs, seq_len = 2, 3
    obs_td = TensorDict(
        {"flat_obs": torch.randn(num_seqs, seq_len, 6, dtype=torch.float32)},
        batch_size=[num_seqs, seq_len],
    )
    actions = torch.randn(num_seqs, seq_len, 4, dtype=torch.float32)
    # Lower-triangular (causal) mask broadcast over the batch.
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    attn_mask = causal.expand(num_seqs, -1, -1)
    outputs = actor(
        obs_td,
        actions=actions,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=False,
    )
    assert outputs["aux_denoise_ref_dof_pos_residual"].shape == (2, 3, 4)
    # Membership goes through .keys(): outputs may be a TensorDict, which can
    # reject a bare ``in`` check.
    output_keys = outputs.keys()
    for absent_key in (
        "aux_denoise_ref_keybody_rel_pos_loc",
        "aux_denoise_ref_keybody_rel_pos_log_std",
    ):
        assert absent_key not in output_keys, absent_key
def test_grouped_moe_policy_default_layout_keeps_dense_first_and_moe_tail():
    """By default layer 0 is dense and every later layer is a MoE block."""
    policy = _make_aux_command_policy(
        enable_aux_router_command_recon=False,
    )
    layers = list(policy.layers)
    assert len(layers) == 3
    dense_head, *moe_tail = layers
    assert isinstance(dense_head, ModernTransformerBlock)
    assert all(isinstance(layer, GroupedMoEBlock) for layer in moe_tail)
    assert policy._num_moe_layers == 2
    assert policy._last_moe_layer_idx == 2
def test_grouped_moe_policy_dense_layer_at_last_keeps_only_middle_layers_moe():
    """``dense_layer_at_last`` bookends the MoE layers with dense blocks.

    With four layers, only layers 1 and 2 are MoE. Router features have 4
    channels (presumably 2 MoE layers x 2 fine experts).
    """
    policy = _make_aux_command_policy(n_layers=4, dense_layer_at_last=True)
    seq_len = 3
    obs_seq = torch.randn(2, seq_len, 6, dtype=torch.float32)
    attn_mask = (
        torch.ones(seq_len, seq_len, dtype=torch.bool)
        .tril()
        .expand(2, -1, -1)
    )
    _, router_features = policy.sequence_mu(
        obs_seq, attn_mask=attn_mask, return_router_features=True
    )
    expected_layout = (
        ModernTransformerBlock,
        GroupedMoEBlock,
        GroupedMoEBlock,
        ModernTransformerBlock,
    )
    for layer_idx, block_type in enumerate(expected_layout):
        assert isinstance(policy.layers[layer_idx], block_type), layer_idx
    assert policy._num_moe_layers == 2
    assert policy._last_moe_layer_idx == 2
    assert router_features.shape == (2, 3, 4)
def test_grouped_moe_policy_dense_layer_at_last_allows_shallow_fully_dense():
    """Two layers plus ``dense_layer_at_last`` yields a fully dense stack."""
    policy = _make_aux_command_policy(
        n_layers=2,
        dense_layer_at_last=True,
        enable_aux_router_command_recon=False,
    )
    layers = policy.layers
    assert len(layers) == 2
    dense_flags = [
        isinstance(layer, ModernTransformerBlock) for layer in layers
    ]
    assert all(dense_flags)
    # No MoE layer at all, hence no "last MoE" index either.
    assert policy._num_moe_layers == 0
    assert policy._last_moe_layer_idx is None
def test_grouped_moe_policy_command_recon_uses_live_router_features():
    """Router features stay in the autograd graph for command reconstruction.

    Back-propagating the reconstruction must reach the first MoE block's
    router weights, and its cached router distribution must require grad.
    """
    policy = _make_aux_command_policy()
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features = policy.sequence_mu(
        obs_seq, attn_mask=attn_mask, return_router_features=True
    )
    pred = policy.predict_aux_router_command_from_router_features(
        router_features
    )

    assert policy._num_moe_layers == 2
    assert router_features.shape == (2, 3, 4)
    assert pred.shape == (2, 3, 5)
    assert router_features.requires_grad

    pred.sum().backward()
    moe_blocks = [
        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)
    ]
    first_moe = moe_blocks[0]
    cached_distribution = first_moe.last_router_distribution
    assert cached_distribution is not None
    assert cached_distribution.requires_grad
    assert first_moe.router.weight.grad is not None
def test_grouped_moe_policy_freeze_router_detaches_router_features_and_params():
    """``freeze_router=True`` freezes router weights and detaches router outputs."""
    policy = _make_aux_command_policy(freeze_router=True)
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features, router_temporal_features = policy.sequence_mu(
        obs_seq,
        attn_mask=attn_mask,
        return_router_features=True,
        return_router_temporal_features=True,
    )
    pred = policy.predict_aux_router_command_from_router_features(
        router_features
    )
    first_moe = next(
        layer for layer in policy.layers if isinstance(layer, GroupedMoEBlock)
    )

    # Frozen state is visible before any backward pass.
    assert first_moe.freeze_router is True
    assert first_moe.router.weight.requires_grad is False
    for features in (router_features, router_temporal_features):
        assert features.requires_grad is False

    pred.sum().backward()

    # Cached router tensors are detached and the router receives no grad.
    for cached in (
        first_moe.last_router_distribution,
        first_moe.last_router_logits,
    ):
        assert cached is not None
        assert cached.requires_grad is False
    assert first_moe.router.weight.grad is None
def test_grouped_moe_policy_loads_legacy_aux_command_recon_head_keys():
    """Checkpoints using the legacy ``aux_command_recon_head.`` prefix load strictly.

    Every ``aux_router_command_recon_head.*`` entry is renamed to its legacy
    name and given fresh random values. After a strict load, the current keys
    must hold exactly those values.
    """
    current_prefix = "aux_router_command_recon_head."
    legacy_prefix = "aux_command_recon_head."
    policy = _make_aux_command_policy(enable_aux_router_command_recon=True)
    state_dict = copy.deepcopy(policy.state_dict())
    expected_tensors = {}
    for key in list(state_dict.keys()):
        if current_prefix not in key:
            continue
        legacy_key = key.replace(current_prefix, legacy_prefix)
        legacy_value = torch.randn_like(state_dict[key])
        expected_tensors[key] = legacy_value
        state_dict[legacy_key] = legacy_value
        del state_dict[key]
    # Guard against a vacuous pass: with no recon-head keys, nothing would be
    # renamed and the strict load below would trivially succeed.
    assert expected_tensors, f"no '{current_prefix}' entries in state_dict"

    result = policy.load_state_dict(state_dict, strict=True)
    assert result.missing_keys == []
    assert result.unexpected_keys == []
    loaded = policy.state_dict()
    for key, expected in expected_tensors.items():
        assert torch.allclose(loaded[key], expected), key
def test_grouped_moe_policy_ignores_legacy_aux_command_recon_head_keys_when_disabled():
    """Legacy recon-head keys are tolerated when the recon head is disabled.

    Weights from a policy that *has* the head, saved under the legacy
    ``aux_command_recon_head.`` prefix, must load strictly into a policy
    without it. The policy's ``aux_router_command_recon_head`` stays ``None``.
    """
    current_prefix = "aux_router_command_recon_head."
    legacy_prefix = "aux_command_recon_head."
    policy = _make_aux_command_policy(enable_aux_router_command_recon=False)
    legacy_policy = _make_aux_command_policy(
        enable_aux_router_command_recon=True
    )
    state_dict = copy.deepcopy(policy.state_dict())
    injected_keys = []
    for key, value in legacy_policy.state_dict().items():
        if current_prefix not in key:
            continue
        legacy_key = key.replace(current_prefix, legacy_prefix)
        state_dict[legacy_key] = value.clone()
        injected_keys.append(legacy_key)
    # Without at least one injected legacy key the strict load below would
    # pass trivially and prove nothing about legacy-key handling.
    assert injected_keys, f"legacy policy has no '{current_prefix}' entries"

    result = policy.load_state_dict(state_dict, strict=True)
    assert result.missing_keys == []
    assert result.unexpected_keys == []
    assert policy.aux_router_command_recon_head is None
def test_grouped_moe_policy_clears_router_cache_before_deepcopy():
    """After ``clear_router_distribution_cache`` the copy has no cached routing.

    A forward/backward pass populates the first MoE block's
    ``last_router_distribution``. Once the cache is cleared, a deep copy of the
    policy carries ``None`` there.
    """
    policy = _make_aux_command_policy()
    obs_seq = torch.randn(2, 3, 6, dtype=torch.float32, requires_grad=True)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    _, router_features = policy.sequence_mu(
        obs_seq, attn_mask=attn_mask, return_router_features=True
    )
    policy.predict_aux_router_command_from_router_features(
        router_features
    ).sum().backward()

    def _first_moe(model):
        return next(
            layer
            for layer in model.layers
            if isinstance(layer, GroupedMoEBlock)
        )

    assert _first_moe(policy).last_router_distribution is not None
    policy.clear_router_distribution_cache()
    copied = copy.deepcopy(policy)
    assert _first_moe(copied).last_router_distribution is None
def test_grouped_moe_block_tracks_least_utilized_expert_stats():
    """Per-expert counts drive the active/max/min/dead utilization stats."""
    block_kwargs = dict(
        d_model=8,
        n_heads=2,
        n_kv_heads=1,
        num_fine_experts=4,
        num_shared_experts=1,
        top_k=1,
        ff_mult=1.0,
        use_qk_norm=True,
        use_gated_attn=False,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        use_dynamic_bias=False,
        routing_score_fn="softmax",
    )
    block = GroupedMoEBlock(**block_kwargs)

    # Counts sum to 10; expert 2 gets nothing.
    block._apply_bias_update_from_counts(torch.tensor([5, 3, 0, 2]))
    first_expected = {
        "last_active_expert_ratio": 0.75,
        "last_max_expert_frac": 0.5,
        "last_min_expert_frac": 0.0,
        "last_dead_expert_ratio": 0.25,
    }
    for stat_name, value in first_expected.items():
        assert getattr(block, stat_name).item() == pytest.approx(value), (
            stat_name
        )

    # Every expert now has a non-zero count.
    block._apply_bias_update_from_counts(torch.tensor([5, 3, 1, 1]))
    assert block.last_min_expert_frac.item() == pytest.approx(0.1)
    assert block.last_dead_expert_ratio.item() == pytest.approx(0.0)
def test_grouped_moe_block_tracks_dead_expert_margin_to_topk_loss():
    """Dead-expert margin loss and its cached stats match a hand-built reference.

    The reference hinges ``relu(kth_selected_score - score)`` over experts 1
    and 2 (expert 0 is the pick for both tokens) and divides by 4, i.e.
    2 tokens x 2 unselected experts.
    """
    block = GroupedMoEBlock(
        d_model=4,
        n_heads=2,
        n_kv_heads=1,
        num_fine_experts=3,
        num_shared_experts=1,
        top_k=1,
        ff_mult=1.0,
        use_qk_norm=True,
        use_gated_attn=False,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        use_dynamic_bias=False,
        routing_score_fn="softmax",
        dead_expert_margin_to_topk_enabled=True,
    )
    # One sequence of two tokens, both routed to expert 0.
    topk_idx = torch.tensor([[[0], [0]]], dtype=torch.long)
    dense_distribution = torch.tensor(
        [[[0.8, 0.15, 0.05], [0.7, 0.2, 0.1]]], dtype=torch.float32
    )
    choice_scores = dense_distribution.log()
    loss = block._update_routed_expert_stats_and_floor_loss(
        topk_idx=topk_idx,
        dense_distribution=dense_distribution,
        choice_scores=choice_scores,
    )

    kth_score = choice_scores.gather(-1, topk_idx)[..., -1:]
    hinge = torch.relu(kth_score - choice_scores)
    num_tokens = topk_idx.shape[0] * topk_idx.shape[1]
    num_unselected = dense_distribution.shape[-1] - topk_idx.shape[-1]
    expected = hinge[..., 1:].sum() / float(num_tokens * num_unselected)

    torch.testing.assert_close(loss, expected)
    torch.testing.assert_close(
        block.last_dead_expert_margin_to_topk_loss, expected
    )
    torch.testing.assert_close(
        block.last_dead_expert_margin_to_topk_loss_value, expected.detach()
    )
    torch.testing.assert_close(
        block.last_dead_expert_margin_to_topk_target, kth_score.mean()
    )
    torch.testing.assert_close(
        block.last_dense_expert_usage, dense_distribution.mean(dim=(0, 1))
    )
def test_grouped_moe_block_tracks_selected_expert_margin_to_unselected():
    """Selected-vs-unselected routing margin and its hinge loss are cached.

    The expected values match, per token, (lowest selected score - highest
    unselected score): 0.8 - 0.3 and 0.9 - 0.7. The loss is the mean shortfall
    of those margins below the 0.4 target.
    """
    block_kwargs = dict(
        d_model=4,
        n_heads=2,
        n_kv_heads=1,
        num_fine_experts=4,
        num_shared_experts=1,
        top_k=2,
        ff_mult=1.0,
        use_qk_norm=True,
        use_gated_attn=False,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        use_dynamic_bias=False,
        routing_score_fn="softmax",
        selected_expert_margin_to_unselected_enabled=True,
        selected_expert_margin_to_unselected_target=0.4,
    )
    block = GroupedMoEBlock(**block_kwargs)
    token_rows = [
        [0.42, 0.21, 0.28, 0.09],
        [0.27, 0.36, 0.22, 0.15],
    ]
    score_rows = [
        [1.0, 0.3, 0.8, 0.1],
        [0.9, 1.2, 0.7, 0.4],
    ]
    block._update_routed_expert_stats_and_floor_loss(
        topk_idx=torch.tensor([[[0, 2], [1, 0]]], dtype=torch.long),
        dense_distribution=torch.tensor([token_rows], dtype=torch.float32),
        choice_scores=torch.tensor([score_rows], dtype=torch.float32),
    )

    expected_margin = torch.tensor((0.5 + 0.2) / 2.0)
    expected_loss = torch.tensor((0.0 + 0.2) / 2.0)
    checks = (
        ("last_selected_expert_margin_to_unselected", expected_margin),
        ("last_selected_expert_margin_to_unselected_loss", expected_loss),
        ("last_selected_expert_margin_to_unselected_loss_value", expected_loss),
    )
    for attr_name, expected in checks:
        torch.testing.assert_close(getattr(block, attr_name), expected)
def test_ppotf_summarize_moe_layer_stats_includes_least_utilized_metrics():
    """``_summarize_moe_layer_stats`` reports the cross-layer mean of each stat."""
    per_layer_stats = [
        {
            "last_active_expert_ratio": 0.75,
            "last_max_expert_frac": 0.50,
            "last_min_expert_frac": 0.00,
            "last_dead_expert_ratio": 0.25,
            "last_expert_count_cv": 1.20,
            "last_selected_expert_margin_to_unselected": 0.30,
        },
        {
            "last_active_expert_ratio": 0.50,
            "last_max_expert_frac": 0.30,
            "last_min_expert_frac": 0.05,
            "last_dead_expert_ratio": 0.50,
            "last_expert_count_cv": 0.80,
            "last_selected_expert_margin_to_unselected": 0.10,
        },
    ]
    moe_layers = [
        SimpleNamespace(
            **{name: torch.tensor(value) for name, value in stats.items()}
        )
        for stats in per_layer_stats
    ]
    metrics = PPOTF._summarize_moe_layer_stats(moe_layers)

    expected_means = {
        "moe_active_expert_ratio": 0.625,
        "moe_max_expert_frac": 0.40,
        "moe_least_expert_frac": 0.025,
        "moe_dead_expert_ratio": 0.375,
        "moe_expert_count_cv": 1.0,
        "moe_selected_expert_margin_to_unselected": 0.20,
    }
    for metric_name, value in expected_means.items():
        assert metrics[metric_name] == pytest.approx(value), metric_name
def test_compute_routed_expert_orthogonal_loss_uses_active_experts_only():
    """Only experts meeting ``min_active_usage`` enter the orthogonality loss.

    Expert 2 (usage 0.05 < 0.1) is excluded. The reference loss is the sum of
    squared off-diagonal cosine similarities between experts 0 and 1.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.router_expert_orthogonal_min_active_usage = 0.1
    algo.router_expert_orthogonal_eps = 1.0e-8
    expert_vectors = torch.tensor(
        [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float32
    )
    moe_layer = SimpleNamespace(
        last_routed_expert_usage=torch.tensor(
            [0.2, 0.12, 0.05], dtype=torch.float32
        ),
        # (num_experts, 1, 2); cloned so the reference vectors stay independent.
        down_proj=expert_vectors[:, None, :].clone(),
    )
    loss, active_count, mean_offdiag = (
        algo._compute_routed_expert_orthogonal_loss(
            moe_layer,
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
    )

    unit_vecs = F.normalize(expert_vectors[:2], p=2.0, dim=-1, eps=1.0e-8)
    gram = unit_vecs @ unit_vecs.T
    off_diagonal = gram[~torch.eye(2, dtype=torch.bool)]
    torch.testing.assert_close(active_count, torch.tensor(2.0))
    torch.testing.assert_close(loss, off_diagonal.pow(2).sum())
    torch.testing.assert_close(mean_offdiag, off_diagonal.abs().mean())
def test_compute_routed_expert_orthogonal_loss_returns_zero_below_two_active():
    """With a single active expert, loss and mean off-diagonal are both zero."""
    algo = PPOTF.__new__(PPOTF)
    algo.router_expert_orthogonal_min_active_usage = 0.1
    algo.router_expert_orthogonal_eps = 1.0e-8
    # Only expert 0 clears the 0.1 usage threshold. down_proj is random
    # because the expected result does not depend on it.
    usage = torch.tensor([0.2, 0.05, 0.01], dtype=torch.float32)
    moe_layer = SimpleNamespace(
        last_routed_expert_usage=usage,
        down_proj=torch.randn(3, 1, 2, dtype=torch.float32),
    )
    loss, active_count, mean_offdiag = (
        algo._compute_routed_expert_orthogonal_loss(
            moe_layer,
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
    )
    zero = torch.tensor(0.0)
    torch.testing.assert_close(loss, zero)
    torch.testing.assert_close(active_count, torch.tensor(1.0))
    torch.testing.assert_close(mean_offdiag, zero)
def test_ppotf_actor_sequence_logp_emits_router_features_for_aux_router_losses():
    """With both router aux terms on, ``sequence_logp`` returns router outputs.

    Both router feature tensors (4 channels) and the 5-dim command
    reconstruction are present for every token.
    """
    actor = _make_temporal_aux_actor()
    n_seq, n_tok = 2, 3
    obs_td = TensorDict(
        {"flat_obs": torch.randn(n_seq, n_tok, 6, dtype=torch.float32)},
        batch_size=[n_seq, n_tok],
    )
    actions = torch.randn(n_seq, n_tok, 4, dtype=torch.float32)
    attn_mask = (
        torch.ones(n_tok, n_tok, dtype=torch.bool)
        .tril()
        .expand(n_seq, -1, -1)
    )
    outputs = actor(
        obs_td,
        actions=actions,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=False,
    )
    expected_shapes = {
        "router_features": (n_seq, n_tok, 4),
        "router_temporal_features": (n_seq, n_tok, 4),
        "aux_router_command_recon": (n_seq, n_tok, 5),
    }
    for key, shape in expected_shapes.items():
        assert outputs[key].shape == shape, key
def test_ppotf_actor_sequence_logp_emits_only_router_features_for_temporal_only_aux():
    """Temporal-only aux keeps router features but drops command recon output."""
    actor = _make_temporal_only_aux_actor()
    seq_shape = (2, 3)
    obs_td = TensorDict(
        {"flat_obs": torch.randn(*seq_shape, 6, dtype=torch.float32)},
        batch_size=list(seq_shape),
    )
    actions = torch.randn(*seq_shape, 4, dtype=torch.float32)
    attn_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).expand(
        2, -1, -1
    )
    outputs = actor(
        obs_td,
        actions=actions,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=False,
    )
    for feature_key in ("router_features", "router_temporal_features"):
        assert outputs[feature_key].shape == (*seq_shape, 4), feature_key
    # Checked via .keys(); outputs may be a TensorDict, which can reject ``in``.
    assert "aux_router_command_recon" not in outputs.keys()
def test_masked_adjacent_router_js_averages_only_valid_adjacent_tokens():
    """Adjacent-token Jensen-Shannon divergence skips pairs touching padding.

    Each token holds two MoE layers x two experts (columns 0-1 and 2-3).
    Token 2 is invalid, so only the (0, 1) pair contributes, and the result is
    the mean of the two per-layer JS divergences for that pair.
    """
    router_features = torch.tensor(
        [
            [
                [0.8, 0.2, 0.6, 0.4],
                [0.6, 0.4, 0.5, 0.5],
                [0.1, 0.9, 0.4, 0.6],
            ]
        ],
        dtype=torch.float32,
    )
    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)
    loss = PPOTF._masked_adjacent_router_js(
        router_features=router_features,
        valid_tok=valid_tok,
        num_moe_layers=2,
        num_fine_experts=2,
    )

    def _js(prev, curr):
        mix = 0.5 * (prev + curr)
        kl_prev = (prev * (torch.log(prev) - torch.log(mix))).sum()
        kl_curr = (curr * (torch.log(curr) - torch.log(mix))).sum()
        return 0.5 * (kl_prev + kl_curr)

    def _probs(*values):
        return torch.tensor(values, dtype=torch.float32)

    js_layer0 = _js(_probs(0.8, 0.2), _probs(0.6, 0.4))
    js_layer1 = _js(_probs(0.6, 0.4), _probs(0.5, 0.5))
    expected = 0.5 * (js_layer0 + js_layer1)
    assert torch.isclose(loss, expected)
def test_masked_adjacent_router_normed_smooth_l1_averages_only_valid_adjacent_tokens():
    """Smooth-L1 between centered, L2-normalized logits of adjacent tokens.

    Token 2 is masked out, so only the (0, 1) pair contributes.
    """
    router_temporal_features = torch.tensor(
        [
            [
                [3.0, 1.0, 0.0],
                [2.0, 0.0, 2.0],
                [1.0, 1.0, 1.0],
            ]
        ],
        dtype=torch.float32,
    )
    valid_tok = torch.tensor([[1.0, 1.0, 0.0]], dtype=torch.float32)
    loss = PPOTF._masked_adjacent_router_normed_smooth_l1(
        router_temporal_features=router_temporal_features,
        valid_tok=valid_tok,
        num_moe_layers=1,
        num_fine_experts=3,
    )

    def _center_and_normalize(logits):
        centered = logits - logits.mean(dim=-1, keepdim=True)
        return F.normalize(centered, p=2.0, dim=-1, eps=1.0e-5)

    prev_norm = _center_and_normalize(
        router_temporal_features[:, :1].reshape(1, 1, 1, 3)
    )
    curr_norm = _center_and_normalize(
        router_temporal_features[:, 1:2].reshape(1, 1, 1, 3)
    )
    elementwise = F.smooth_l1_loss(
        curr_norm, prev_norm, reduction="none", beta=1.0
    )
    assert torch.isclose(loss, elementwise.mean())
def test_masked_aux_keybody_mse_averages_only_valid_tokens():
    """Masked keybody MSE averages squared error over valid tokens only."""
    pred = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])
    valid_tok = torch.tensor([[1.0, 0.0]])
    loss = PPOTF._masked_aux_keybody_mse(
        pred, torch.zeros_like(pred), valid_tok
    )
    # Only token 0 counts against a zero target: mean of 1^2, 2^2, 3^2.
    expected = torch.tensor(sum(v * v for v in (1.0, 2.0, 3.0)) / 3.0)
    assert torch.isclose(loss, expected)
def test_masked_aux_huber_averages_only_valid_tokens():
    """Masked Huber loss (beta=1) averages over valid tokens only."""
    target = torch.tensor([[[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]])
    pred = torch.zeros(1, 2, 1, 3)
    loss = PPOTF._masked_aux_huber(
        pred=pred,
        target=target,
        valid_tok=torch.tensor([[1.0, 0.0]]),
        beta=1.0,
    )
    # Token 0 errors 1, 2, 3 under Huber(beta=1) give 0.5, 1.5, 2.5.
    expected = torch.tensor((0.5 + 1.5 + 2.5) / 3.0)
    assert torch.isclose(loss, expected)
def test_setup_configs_rejects_router_aux_terms_outside_motion_tracking():
    """Enabling the router switch penalty with ``command_name="velocity"`` raises."""
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "velocity"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": True, "weight": 1.0},
    }
    # Stub the parent PPO._setup_configs so only the PPOTF override runs.
    parent_stub = mock.patch.object(PPO, "_setup_configs", return_value=None)
    expect_error = pytest.raises(ValueError, match="aux_router_switch_penalty")
    with parent_stub, expect_error:
        algo._setup_configs()
def test_setup_configs_rejects_unknown_router_switch_penalty_metric():
    """An unrecognized ``aux_router_switch_penalty.metric`` raises ValueError."""
    switch_penalty_cfg = {
        "enabled": True,
        "weight": 1.0,
        "metric": "not_a_metric",
    }
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": switch_penalty_cfg,
    }
    parent_stub = mock.patch.object(PPO, "_setup_configs", return_value=None)
    expect_error = pytest.raises(
        ValueError, match="aux_router_switch_penalty.metric"
    )
    with parent_stub, expect_error:
        algo._setup_configs()
def test_setup_configs_reads_dead_expert_margin_to_topk_only():
    """``dead_expert_margin_to_topk`` is parsed when present and off when absent."""

    def _configured(**extra_sections):
        # Fresh config dicts per call, so no state is shared between runs.
        algo = PPOTF.__new__(PPOTF)
        algo.command_name = "ref_motion"
        algo.config = {
            "aux_state_pred": {"enabled": False},
            "aux_router_command_recon": {"enabled": False},
            "aux_router_switch_penalty": {"enabled": False},
            **extra_sections,
        }
        with mock.patch.object(PPO, "_setup_configs", return_value=None):
            algo._setup_configs()
        return algo

    enabled = _configured(
        dead_expert_margin_to_topk={"enabled": True, "weight": 0.7}
    )
    assert enabled.use_dead_expert_margin_to_topk is True
    assert enabled.dead_expert_margin_to_topk_weight == pytest.approx(0.7)

    # Section omitted: the term is disabled with zero weight.
    default = _configured()
    assert default.use_dead_expert_margin_to_topk is False
    assert default.dead_expert_margin_to_topk_weight == pytest.approx(0.0)
def test_setup_configs_reads_selected_expert_margin_to_unselected():
    """``selected_expert_margin_to_unselected`` is parsed, defaulting to off/0.0."""

    def _configured(**extra_sections):
        algo = PPOTF.__new__(PPOTF)
        algo.command_name = "ref_motion"
        algo.config = {
            "aux_state_pred": {"enabled": False},
            "aux_router_command_recon": {"enabled": False},
            "aux_router_switch_penalty": {"enabled": False},
            **extra_sections,
        }
        with mock.patch.object(PPO, "_setup_configs", return_value=None):
            algo._setup_configs()
        return algo

    enabled = _configured(
        selected_expert_margin_to_unselected={
            "enabled": True,
            "weight": 0.9,
            "target": 0.3,
        }
    )
    assert enabled.use_selected_expert_margin_to_unselected is True
    assert enabled.selected_expert_margin_to_unselected_weight == (
        pytest.approx(0.9)
    )
    assert enabled.selected_expert_margin_to_unselected_target == (
        pytest.approx(0.3)
    )

    # Section omitted: disabled, with weight and target both 0.0.
    default = _configured()
    assert default.use_selected_expert_margin_to_unselected is False
    assert default.selected_expert_margin_to_unselected_weight == (
        pytest.approx(0.0)
    )
    assert default.selected_expert_margin_to_unselected_target == (
        pytest.approx(0.0)
    )
def test_setup_configs_reads_aux_router_future_recon():
    """``aux_router_future_recon`` settings are copied onto the algorithm.

    ``module_dict`` names the ``ReferenceRoutedGroupedMoETransformerPolicyV3``
    actor, presumably a prerequisite for this aux term.
    """
    future_recon_cfg = {
        "enabled": True,
        "weight": 0.7,
        "hidden_dim": 13,
        "huber_beta": 0.3,
    }
    actor_cfg = {"type": "ReferenceRoutedGroupedMoETransformerPolicyV3"}
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "aux_router_future_recon": future_recon_cfg,
        "module_dict": {"actor": actor_cfg},
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()

    assert algo.use_aux_router_future_recon is True
    assert algo.aux_router_future_recon_weight == pytest.approx(0.7)
    assert algo.aux_router_future_recon_hidden_dim == 13
    assert algo.aux_router_future_recon_huber_beta == pytest.approx(0.3)
def test_setup_configs_reads_router_expert_orthogonal():
    """``router_expert_orthogonal`` weight/min_active_usage/eps are parsed.

    ``dead_expert_margin_to_topk`` is enabled alongside it, since the
    orthogonality term is rejected without it.
    """
    orthogonal_cfg = {
        "enabled": True,
        "weight": 0.9,
        "min_active_usage": 0.2,
        "eps": 1.0e-6,
    }
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "dead_expert_margin_to_topk": {"enabled": True, "weight": 0.7},
        "router_expert_orthogonal": orthogonal_cfg,
    }
    with mock.patch.object(PPO, "_setup_configs", return_value=None):
        algo._setup_configs()

    assert algo.use_router_expert_orthogonal is True
    assert algo.router_expert_orthogonal_weight == pytest.approx(0.9)
    assert algo.router_expert_orthogonal_min_active_usage == pytest.approx(
        0.2
    )
    assert algo.router_expert_orthogonal_eps == pytest.approx(1.0e-6)
def test_setup_configs_rejects_router_expert_orthogonal_without_dead_margin():
    """``router_expert_orthogonal`` alone (no dead-expert margin) raises ValueError."""
    algo = PPOTF.__new__(PPOTF)
    algo.command_name = "ref_motion"
    algo.config = {
        "aux_state_pred": {"enabled": False},
        "aux_router_command_recon": {"enabled": False},
        "aux_router_switch_penalty": {"enabled": False},
        "router_expert_orthogonal": {"enabled": True, "weight": 0.9},
    }
    parent_stub = mock.patch.object(PPO, "_setup_configs", return_value=None)
    expect_error = pytest.raises(ValueError, match="requires.*dead_expert")
    with parent_stub, expect_error:
        algo._setup_configs()
def test_build_transition_uses_filtered_residual_targets_for_denoise_outputs():
    """Denoise targets equal the ``ft_ref_`` reference minus the ``ref_`` one.

    The stub command returns each quantity unchanged for ``ft_ref_`` and
    shifted by +100 for any other prefix, so every denoise target must be a
    tensor of -100.
    """
    algo = PPOTF.__new__(PPOTF)
    algo_attrs = {
        "use_aux_state_pred": True,
        "use_aux_root_height": False,
        "use_aux_denoise_ref_root_lin_vel": True,
        "use_aux_denoise_ref_root_ang_vel": True,
        "use_aux_denoise_ref_dof_pos": True,
        "aux_state_pred_num_contact_bodies": 0,
        "aux_state_pred_num_keybody_bodies": 0,
        "command_name": "ref_motion",
        "num_envs": 2,
        "device": torch.device("cpu"),
        "transition_cls": PpoAuxTransition,
    }
    for attr_name, attr_value in algo_attrs.items():
        setattr(algo, attr_name, attr_value)

    world_lin_vel = torch.tensor(
        [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], dtype=torch.float32
    )
    world_ang_vel = torch.tensor(
        [[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]], dtype=torch.float32
    )
    base_lin_vel = torch.tensor(
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32
    )
    base_ang_vel = torch.tensor(
        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float32
    )
    dof_pos = torch.tensor(
        [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], dtype=torch.float32
    )

    def _ft_ref_getter(filtered):
        # "ft_ref_" returns the filtered value; any other prefix adds 100.
        def _get(prefix="ref_"):
            if prefix == "ft_ref_":
                return filtered
            return filtered + 100.0

        return _get

    command = SimpleNamespace(
        get_ref_motion_root_global_lin_vel_cur=_ft_ref_getter(world_lin_vel),
        get_ref_motion_root_global_ang_vel_cur=_ft_ref_getter(world_ang_vel),
        get_ref_motion_base_linvel_cur=_ft_ref_getter(base_lin_vel),
        get_ref_motion_base_angvel_cur=_ft_ref_getter(base_ang_vel),
        get_ref_motion_dof_pos_cur=_ft_ref_getter(dof_pos),
    )
    command_manager = SimpleNamespace(get_term=lambda name: command)
    algo.env = SimpleNamespace(
        _env=SimpleNamespace(command_manager=command_manager)
    )

    obs_td = TensorDict(
        {"flat_obs": torch.zeros(2, 5, dtype=torch.float32)},
        batch_size=[2],
    )
    actor_out = TensorDict(
        {
            "actions": torch.zeros(2, 4, dtype=torch.float32),
            "actions_log_prob": torch.zeros(2, dtype=torch.float32),
            "mu": torch.zeros(2, 4, dtype=torch.float32),
            "sigma": torch.ones(2, 4, dtype=torch.float32),
        },
        batch_size=[2],
    )
    critic_out = TensorDict(
        {"values": torch.zeros(2, 1, dtype=torch.float32)},
        batch_size=[2],
    )

    # Minimal stand-in for isaaclab.envs.mdp (presumably imported lazily by
    # _build_transition); only base_lin_vel is provided.
    isaaclab_mdp = ModuleType("isaaclab.envs.mdp")
    isaaclab_mdp.base_lin_vel = lambda env: torch.zeros(
        2, 3, dtype=torch.float32
    )
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.mdp = isaaclab_mdp
    isaaclab_pkg = ModuleType("isaaclab")
    isaaclab_pkg.envs = isaaclab_envs
    fake_modules = {
        "isaaclab": isaaclab_pkg,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_mdp,
    }
    with mock.patch.dict(sys.modules, fake_modules):
        transition = algo._build_transition(obs_td, actor_out, critic_out)

    residual = -100.0
    checks = (
        (transition.gt_denoise_ref_root_lin_vel, base_lin_vel),
        (transition.gt_denoise_ref_root_ang_vel, base_ang_vel),
        (transition.gt_denoise_ref_dof_pos, dof_pos),
    )
    for actual, like in checks:
        torch.testing.assert_close(actual, torch.full_like(like, residual))
def test_build_transition_rejects_mismatched_denoise_dof_target_shape():
    """A dof denoise target of the wrong width is rejected with ValueError.

    Only ``use_aux_denoise_ref_dof_pos`` is enabled. The stub command yields a
    ``(2, 5)`` dof reference while the actor outputs are 4 wide, and
    ``_build_transition`` must raise a ValueError mentioning
    ``gt_denoise_ref_dof_pos``.
    """
    algo = PPOTF.__new__(PPOTF)
    algo_attrs = {
        "use_aux_state_pred": True,
        "use_aux_root_height": False,
        "use_aux_denoise_ref_root_lin_vel": False,
        "use_aux_denoise_ref_root_ang_vel": False,
        "use_aux_denoise_ref_dof_pos": True,
        "aux_state_pred_num_contact_bodies": 0,
        "aux_state_pred_num_keybody_bodies": 0,
        "command_name": "ref_motion",
        "num_envs": 2,
        "device": torch.device("cpu"),
        "transition_cls": PpoAuxTransition,
    }
    for attr_name, attr_value in algo_attrs.items():
        setattr(algo, attr_name, attr_value)

    def _ref_dof_pos(prefix="ref_"):
        # Same (2, 5) zeros regardless of prefix.
        return torch.zeros(2, 5, dtype=torch.float32)

    command = SimpleNamespace(get_ref_motion_dof_pos_cur=_ref_dof_pos)
    command_manager = SimpleNamespace(get_term=lambda name: command)
    algo.env = SimpleNamespace(
        _env=SimpleNamespace(command_manager=command_manager)
    )

    obs_td = TensorDict(
        {"flat_obs": torch.zeros(2, 5, dtype=torch.float32)},
        batch_size=[2],
    )
    actor_out = TensorDict(
        {
            "actions": torch.zeros(2, 4, dtype=torch.float32),
            "actions_log_prob": torch.zeros(2, dtype=torch.float32),
            "mu": torch.zeros(2, 4, dtype=torch.float32),
            "sigma": torch.ones(2, 4, dtype=torch.float32),
        },
        batch_size=[2],
    )
    critic_out = TensorDict(
        {"values": torch.zeros(2, 1, dtype=torch.float32)},
        batch_size=[2],
    )

    # Minimal isaaclab.envs.mdp stand-in exposing only base_lin_vel.
    isaaclab_mdp = ModuleType("isaaclab.envs.mdp")
    isaaclab_mdp.base_lin_vel = lambda env: torch.zeros(
        2, 3, dtype=torch.float32
    )
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.mdp = isaaclab_mdp
    isaaclab_pkg = ModuleType("isaaclab")
    isaaclab_pkg.envs = isaaclab_envs
    fake_modules = {
        "isaaclab": isaaclab_pkg,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_mdp,
    }
    expect_error = pytest.raises(ValueError, match="gt_denoise_ref_dof_pos")
    with mock.patch.dict(sys.modules, fake_modules), expect_error:
        algo._build_transition(obs_td, actor_out, critic_out)
def test_compute_aux_router_future_recon_loss_uses_normalized_future_targets():
    """The future-recon loss must be computed against *normalized* targets.

    The dummy policy's ``normalize_aux_router_future_recon_target`` scales
    the assembled raw future target by 0.25. Predictions are that normalized
    target plus a per-step offset of 0.0 / 0.5 / 1.0. The middle step is
    masked out through ``valid_tok``. The loss must equal
    ``PPOTF._masked_aux_huber`` on the normalized target with the configured
    Huber beta of 0.5.
    """
    algo = PPOTF.__new__(PPOTF)
    algo.aux_router_future_recon_huber_beta = 0.5

    obs_schema = {
        "flattened_obs_fut": {
            "seq_len": 2,
            "terms": [
                "unified/actor_ref_base_linvel_fut",
                "unified/actor_ref_dof_pos_fut",
            ],
        }
    }
    # Shape (1, 3, 2, 3) holding 1.0 .. 18.0 in row-major order.
    linvel_fut = torch.arange(1.0, 19.0, dtype=torch.float32).reshape(
        1, 3, 2, 3
    )
    dof_pos_fut = torch.tensor(
        [
            [
                [[0.1, 0.2], [0.3, 0.4]],
                [[0.5, 0.6], [0.7, 0.8]],
                [[0.9, 1.0], [1.1, 1.2]],
            ]
        ],
        dtype=torch.float32,
    )
    unified = TensorDict(
        {
            "actor_ref_base_linvel_fut": linvel_fut,
            "actor_ref_dof_pos_fut": dof_pos_fut,
        },
        batch_size=[1, 3],
    )
    obs_b = TensorDict({"unified": unified}, batch_size=[1, 3])
    assembler = TensorDictAssembler(obs_schema, output_mode="flat")

    class _DummyPolicy(nn.Module):
        def normalize_aux_router_future_recon_target(
            self, future_target: torch.Tensor
        ) -> torch.Tensor:
            return future_target * 0.25

    actor_wrapper = SimpleNamespace(
        aux_router_future_recon_assembler=assembler,
        actor_module=_DummyPolicy(),
    )

    raw_target = assembler(obs_b.flatten(0, 1)).reshape(1, 3, -1)
    normalized_target = raw_target * 0.25
    # One constant offset per step, broadcast over every target feature.
    step_offsets = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32)
    pred = normalized_target + step_offsets.view(1, 3, 1)
    actor_out = TensorDict(
        {"aux_router_future_recon": pred},
        batch_size=[1, 3],
    )
    valid_tok = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)

    loss = algo._compute_aux_router_future_recon_loss(
        actor_wrapper=actor_wrapper,
        actor_out=actor_out,
        obs_b=obs_b,
        valid_tok=valid_tok,
    )
    expected = PPOTF._masked_aux_huber(
        pred=pred,
        target=normalized_target,
        valid_tok=valid_tok,
        beta=0.5,
    )
    assert torch.isclose(loss, expected)
def test_root_relative_body_pos_uses_consistent_environment_frame():
    """World-frame body positions and an env-frame root must be reconciled.

    Bodies sit at world x = 10 and x = 11. The root is at the env-local
    origin, and the env origin is at world x = 10. With an identity root
    rotation (assuming w-first quaternion order), the root-relative body
    positions must be x = 0 and x = 1.
    """
    env_origins = torch.tensor([[10.0, 0.0, 0.0]], dtype=torch.float32)
    identity_quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)
    body_pos_w = torch.tensor(
        [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]], dtype=torch.float32
    )
    rel = PPOTF._root_relative_body_pos_from_mixed_position_frames(
        body_pos_w=body_pos_w,
        root_pos_env=torch.zeros(1, 3, dtype=torch.float32),
        root_quat_w=identity_quat,
        env_origins=env_origins,
    )
    want = torch.tensor(
        [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], dtype=torch.float32
    )
    assert torch.allclose(rel, want)
================================================
FILE: tests/test_ref_router_actor.py
================================================
from __future__ import annotations
import pytest
import torch
from holomotion.src.algo.ppo_tf import PPOTF
from holomotion.src.modules.agent_modules import PPOTFRefRouterActor
from tensordict import TensorDict
def _make_ref_router_obs_schema() -> dict:
return {
"flattened_obs": {
"seq_len": 1,
"terms": [
"unified/actor_ref_dof_pos_cur",
"unified/actor_dof_pos",
"unified/actor_ref_root_height_cur",
"unified/actor_last_action",
],
},
"flattened_obs_fut": {
"seq_len": 2,
"terms": [
"unified/actor_ref_dof_pos_fut",
"unified/actor_ref_root_height_fut",
],
},
}
def _make_ref_router_obs(batch_size: list[int]) -> TensorDict:
    """Build random observations matching ``_make_ref_router_obs_schema``.

    Current-step terms have shape ``(*batch_size, width)``. Future terms get
    an extra window axis of length 2: ``(*batch_size, 2, width)``.
    """
    lead_dims = list(batch_size)
    window_dims = lead_dims + [2]
    term_layout = {
        "actor_ref_dof_pos_cur": (lead_dims, 2),
        "actor_dof_pos": (lead_dims, 3),
        "actor_ref_root_height_cur": (lead_dims, 1),
        "actor_last_action": (lead_dims, 2),
        "actor_ref_dof_pos_fut": (window_dims, 2),
        "actor_ref_root_height_fut": (window_dims, 1),
    }
    unified = TensorDict(
        {
            term: torch.randn(*dims, width)
            for term, (dims, width) in term_layout.items()
        },
        batch_size=lead_dims,
    )
    return TensorDict({"unified": unified}, batch_size=lead_dims)
def _make_ref_router_actor(
    *,
    num_actions: int = 4,
    freeze_router: bool = False,
    aux_router_future_recon: dict | None = None,
) -> PPOTFRefRouterActor:
    """Construct a tiny ``PPOTFRefRouterActor`` (V1 policy) for CPU tests.

    Args:
        num_actions: Action dimensionality, also used as ``output_dim``.
        freeze_router: Forwarded as the module's ``freeze_router`` flag.
        aux_router_future_recon: Aux future-recon config; any falsy value is
            replaced with ``{"enabled": False}``.
    """
    module_config = dict(
        type="ReferenceRoutedGroupedMoETransformerPolicy",
        num_fine_experts=3,
        num_shared_experts=1,
        top_k=1,
        routing_score_fn="softmax",
        routing_scale=1.0,
        use_dynamic_bias=False,
        bias_update_rate=0.001,
        expert_bias_clip=0.0,
        obs_embed_mlp_hidden=16,
        d_model=8,
        n_layers=2,
        n_heads=2,
        n_kv_heads=1,
        use_gated_attn=False,
        use_qk_norm=True,
        ff_mult=1.0,
        ff_mult_dense=2,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        max_ctx_len=4,
        freeze_router=freeze_router,
        obs_norm={"enabled": False},
        output_dim=num_actions,
        aux_router_future_recon=aux_router_future_recon or {"enabled": False},
    )
    return PPOTFRefRouterActor(
        obs_schema=_make_ref_router_obs_schema(),
        module_config_dict=module_config,
        num_actions=num_actions,
        init_noise_std=0.2,
        obs_example=_make_ref_router_obs([2]),
    )
def test_ref_router_actor_infers_only_actor_ref_feature_indices():
    """Router feature indices are exactly the ``actor_ref_*`` flat columns.

    Flat layout (width 14): ref_dof_pos_cur -> [0, 1], dof_pos -> [2..4],
    ref_root_height_cur -> [5], last_action -> [6, 7], and the two future
    reference terms -> [8..13].
    """
    expected = [0, 1, 5, *range(8, 14)]
    inferred = PPOTFRefRouterActor.infer_router_feature_indices(
        _make_ref_router_obs_schema(), _make_ref_router_obs([2])
    )
    assert inferred == expected
def test_ref_router_actor_single_step_and_sequence_logp_match_contract():
    """Check output shapes of the three V1 actor call paths.

    1. ``mode="inference"`` on a (2,) batch gives actions/mu/sigma of (2, 4).
    2. The cached forward with explicit ``past_key_values``/``current_pos``
       gives actions of (2, 4) and a ``present`` cache shaped like the input.
    3. ``mode="sequence_logp"`` on a (2, 3) batch with a causal mask gives
       mu/sigma of (2, 3, 4) and log-prob/entropy of (2, 3, 1).
    """
    num_envs, seq_len, num_actions = 2, 3, 4
    actor = _make_ref_router_actor()
    obs_td = _make_ref_router_obs([num_envs])

    inference_out = actor(obs_td, mode="inference", update_obs_norm=False)
    for key in ("actions", "mu", "sigma"):
        assert inference_out[key].shape == (num_envs, num_actions)

    cache_shape = actor.onnx_past_key_values_shape(batch_size=num_envs)
    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)
    step_idx = torch.zeros(num_envs, dtype=torch.long)
    with torch.no_grad():
        actions, present = actor(
            obs_td, past_key_values=past_key_values, current_pos=step_idx
        )
    assert actions.shape == (num_envs, num_actions)
    assert present.shape == cache_shape

    obs_seq = _make_ref_router_obs([num_envs, seq_len])
    actions_seq = torch.randn(num_envs, seq_len, num_actions)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    seq_out = actor(
        obs_seq,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=causal.expand(num_envs, -1, -1),
        update_obs_norm=False,
    )
    for key in ("mu", "sigma"):
        assert seq_out[key].shape == (num_envs, seq_len, num_actions)
    for key in ("actions_log_prob", "entropy"):
        assert seq_out[key].shape == (num_envs, seq_len, 1)
def test_ref_router_actor_rejects_aux_router_future_recon():
    """The V1 reference-router actor must refuse an enabled future-recon aux."""
    recon_cfg = {"enabled": True, "weight": 1.0}
    expected_msg = "does not support aux_router_future_recon"
    with pytest.raises(ValueError, match=expected_msg):
        _make_ref_router_actor(aux_router_future_recon=recon_cfg)
def test_ppotf_select_actor_wrapper_rejects_ref_router_cross_attn():
    """Future cross-attention is not allowed with the V1 reference router."""
    module_cfg = {
        "type": "ReferenceRoutedGroupedMoETransformerPolicy",
        "use_future_cross_attn": True,
    }
    expected_msg = "ReferenceRoutedGroupedMoETransformerPolicy"
    with pytest.raises(ValueError, match=expected_msg):
        PPOTF._select_actor_wrapper_cls(module_cfg)
def test_ref_router_actor_freeze_router_freezes_router_obs_embed():
    """``freeze_router=True`` makes the router obs-embed linears untrainable.

    Indices 0 and 2 of ``router_obs_embed`` are the two layers carrying
    weight and bias (presumably Linear layers around an activation; confirm).
    """
    actor = _make_ref_router_actor(freeze_router=True)
    module = actor.actor_module
    assert module.freeze_router is True
    for idx in (0, 2):
        layer = module.router_obs_embed[idx]
        assert layer.weight.requires_grad is False
        assert layer.bias.requires_grad is False
def test_ref_router_actor_freeze_router_reapplies_after_load_state_dict():
    """Loading a state dict must re-freeze router parameters.

    Router parameters are manually unfrozen, then a strict round-trip
    ``load_state_dict`` must load cleanly and leave them frozen again.
    """
    actor = _make_ref_router_actor(freeze_router=True)
    module = actor.actor_module
    snapshot = module.state_dict()
    routed_layers = [
        layer for layer in module.layers if hasattr(layer, "router")
    ]

    # Undo the freeze by hand so the re-freeze on load is observable.
    module.router_obs_embed.requires_grad_(True)
    for layer in routed_layers:
        layer.router.requires_grad_(True)

    load_result = module.load_state_dict(snapshot, strict=True)
    assert load_result.missing_keys == []
    assert load_result.unexpected_keys == []
    for idx in (0, 2):
        layer = module.router_obs_embed[idx]
        assert layer.weight.requires_grad is False
        assert layer.bias.requires_grad is False
    for layer in routed_layers:
        assert layer.router.weight.requires_grad is False
================================================
FILE: tests/test_ref_router_seq_actor.py
================================================
from __future__ import annotations
import pytest
import torch
from holomotion.src.algo.ppo_tf import PPOTF
from holomotion.src.modules.agent_modules import (
PPOTFRefRouterSeqActor,
PPOTFRefRouterV3Actor,
)
from tensordict import TensorDict
# Feature width of each current-frame reference observation term used by the
# V2/V3 router fixtures in this module (insertion order matches the schema).
REF_CUR_TERM_DIMS = dict(
    actor_ref_gravity_projection_cur=3,
    actor_ref_base_linvel_cur=3,
    actor_ref_base_angvel_cur=3,
    actor_ref_dof_pos_cur=2,
    actor_ref_root_height_cur=1,
)
# Per-step feature width of each future-window reference observation term.
REF_FUT_TERM_DIMS = dict(
    actor_ref_gravity_projection_fut=3,
    actor_ref_base_linvel_fut=3,
    actor_ref_base_angvel_fut=3,
    actor_ref_dof_pos_fut=2,
    actor_ref_root_height_fut=1,
)
def _make_ref_router_v2_obs_schema(
*,
include_ref_cur: bool = True,
include_ref_fut: bool = True,
) -> dict:
flat_terms = []
if include_ref_cur:
flat_terms.extend(
[
"unified/actor_ref_gravity_projection_cur",
"unified/actor_ref_base_linvel_cur",
"unified/actor_ref_base_angvel_cur",
"unified/actor_ref_dof_pos_cur",
"unified/actor_ref_root_height_cur",
]
)
flat_terms.extend(
[
"unified/actor_projected_gravity",
"unified/actor_rel_robot_root_ang_vel",
"unified/actor_dof_pos",
"unified/actor_dof_vel",
"unified/actor_last_action",
]
)
schema = {
"flattened_obs": {"seq_len": 1, "terms": flat_terms},
}
if include_ref_fut:
schema["flattened_obs_fut"] = {
"seq_len": 5,
"terms": [
"unified/actor_ref_gravity_projection_fut",
"unified/actor_ref_base_linvel_fut",
"unified/actor_ref_base_angvel_fut",
"unified/actor_ref_dof_pos_fut",
"unified/actor_ref_root_height_fut",
],
}
return schema
def _make_ref_router_v2_obs(batch_size: list[int]) -> TensorDict:
    """Build random V2/V3 observations matching ``_make_ref_router_v2_obs_schema``.

    Reference-term widths are taken from ``REF_CUR_TERM_DIMS`` and
    ``REF_FUT_TERM_DIMS``. Other tests assert the actor's token dims against
    those same constants, so deriving the widths here keeps the fixture and
    the assertions from drifting apart.

    Current-step terms have shape ``(*batch_size, width)``. Future terms get an
    extra window axis of length 5 (the ``seq_len`` of ``flattened_obs_fut``):
    ``(*batch_size, 5, width)``.
    """
    shape = list(batch_size)
    fut_shape = shape + [5]
    # Non-reference (robot state) terms and their widths.
    state_term_dims = {
        "actor_projected_gravity": 3,
        "actor_rel_robot_root_ang_vel": 3,
        "actor_dof_pos": 4,
        "actor_dof_vel": 4,
        "actor_last_action": 2,
    }
    # Keep the original term order: current refs, state, then future refs.
    terms = {}
    for name, width in REF_CUR_TERM_DIMS.items():
        terms[name] = torch.randn(*shape, width)
    for name, width in state_term_dims.items():
        terms[name] = torch.randn(*shape, width)
    for name, width in REF_FUT_TERM_DIMS.items():
        terms[name] = torch.randn(*fut_shape, width)
    unified = TensorDict(terms, batch_size=shape)
    return TensorDict({"unified": unified}, batch_size=shape)
def _make_ref_router_v2_actor(
    *,
    obs_schema: dict | None = None,
    num_actions: int = 4,
    aux_state_pred: dict | None = None,
    aux_router_command_recon: dict | None = None,
    freeze_router: bool = False,
) -> PPOTFRefRouterSeqActor:
    """Construct a tiny ``PPOTFRefRouterSeqActor`` (V2 policy) for CPU tests.

    Args:
        obs_schema: Observation schema; defaults to the full V2 schema.
        num_actions: Action dimensionality, also used as ``output_dim``.
        aux_state_pred: Aux state-prediction config; falsy means disabled.
        aux_router_command_recon: Router command-recon config; falsy means
            disabled.
        freeze_router: Forwarded as the module's ``freeze_router`` flag.
    """
    if obs_schema is None:
        obs_schema = _make_ref_router_v2_obs_schema()
    module_config = dict(
        type="ReferenceRoutedGroupedMoETransformerPolicyV2",
        num_fine_experts=3,
        num_shared_experts=1,
        top_k=1,
        routing_score_fn="softmax",
        routing_scale=1.0,
        use_dynamic_bias=False,
        bias_update_rate=0.001,
        expert_bias_clip=0.0,
        obs_embed_mlp_hidden=16,
        d_model=8,
        n_layers=2,
        n_heads=2,
        n_kv_heads=1,
        use_gated_attn=False,
        use_qk_norm=True,
        ff_mult=1.0,
        ff_mult_dense=2,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        max_ctx_len=4,
        freeze_router=freeze_router,
        ref_hist_n_layers=1,
        ref_future_conv_channels=8,
        ref_future_conv_layers=2,
        ref_future_conv_kernel_size=3,
        ref_future_conv_stride=2,
        obs_norm={"enabled": False},
        output_dim=num_actions,
        aux_state_pred=aux_state_pred or {"enabled": False},
        aux_router_command_recon=(
            aux_router_command_recon or {"enabled": False}
        ),
        aux_router_switch_penalty={"enabled": False},
    )
    return PPOTFRefRouterSeqActor(
        obs_schema=obs_schema,
        module_config_dict=module_config,
        num_actions=num_actions,
        init_noise_std=0.2,
        obs_example=_make_ref_router_v2_obs([2]),
    )
def _make_ref_router_v3_actor(
    *,
    obs_schema: dict | None = None,
    num_actions: int = 4,
    freeze_router: bool = False,
    aux_router_future_recon: dict | None = None,
) -> PPOTFRefRouterV3Actor:
    """Construct a tiny ``PPOTFRefRouterV3Actor`` (V3 policy) for CPU tests.

    Args:
        obs_schema: Observation schema; defaults to the full V2 schema.
        num_actions: Action dimensionality, also used as ``output_dim``.
        freeze_router: Forwarded as the module's ``freeze_router`` flag.
        aux_router_future_recon: Aux future-recon config; falsy means
            disabled.
    """
    if obs_schema is None:
        obs_schema = _make_ref_router_v2_obs_schema()
    module_config = dict(
        type="ReferenceRoutedGroupedMoETransformerPolicyV3",
        num_fine_experts=3,
        num_shared_experts=1,
        top_k=1,
        routing_score_fn="softmax",
        routing_scale=1.0,
        use_dynamic_bias=False,
        bias_update_rate=0.001,
        expert_bias_clip=0.0,
        obs_embed_mlp_hidden=16,
        d_model=8,
        n_layers=2,
        n_heads=2,
        n_kv_heads=1,
        use_gated_attn=False,
        use_qk_norm=True,
        ff_mult=1.0,
        ff_mult_dense=2,
        attn_dropout=0.0,
        mlp_dropout=0.0,
        max_ctx_len=4,
        freeze_router=freeze_router,
        ref_hist_n_layers=1,
        router_future_hidden_dim=12,
        router_layer_proj_hidden_dim=10,
        obs_norm={"enabled": False},
        output_dim=num_actions,
        aux_state_pred={"enabled": False},
        aux_router_command_recon={"enabled": False},
        aux_router_future_recon=(
            aux_router_future_recon or {"enabled": False}
        ),
        aux_router_switch_penalty={"enabled": False},
    )
    return PPOTFRefRouterV3Actor(
        obs_schema=obs_schema,
        module_config_dict=module_config,
        num_actions=num_actions,
        init_noise_std=0.2,
        obs_example=_make_ref_router_v2_obs([2]),
    )
def test_ppotf_select_actor_wrapper_uses_ref_router_seq_actor():
    """V2 module configs resolve to ``PPOTFRefRouterSeqActor``."""
    module_cfg = {"type": "ReferenceRoutedGroupedMoETransformerPolicyV2"}
    selected = PPOTF._select_actor_wrapper_cls(module_cfg)
    assert selected is PPOTFRefRouterSeqActor
def test_ppotf_select_actor_wrapper_uses_ref_router_v3_actor():
    """V3 module configs resolve to ``PPOTFRefRouterV3Actor``."""
    module_cfg = {"type": "ReferenceRoutedGroupedMoETransformerPolicyV3"}
    selected = PPOTF._select_actor_wrapper_cls(module_cfg)
    assert selected is PPOTFRefRouterV3Actor
def test_ref_router_seq_actor_infers_shared_ref_partitions_without_router_schemas():
    """The V2 actor splits the flat obs into state / ref-cur / ref-fut parts.

    The ref token widths must equal the summed term widths, the future window
    must be 5 steps, and the ONNX KV-cache shape must be 6-D. The last
    cache axis is presumably head_dim = d_model / n_heads = 8 / 2 = 4;
    confirm against ``onnx_past_key_values_shape``.
    """
    actor = _make_ref_router_v2_actor()
    assert actor.state_obs_input_dim > 0
    assert actor.ref_cur_token_dim == sum(REF_CUR_TERM_DIMS.values())
    assert actor.ref_fut_token_dim == sum(REF_FUT_TERM_DIMS.values())
    assert actor.ref_fut_seq_len == 5

    cache_shape = actor.onnx_past_key_values_shape(batch_size=2)
    assert len(cache_shape) == 6
    observed = (cache_shape[0], cache_shape[1], cache_shape[2], cache_shape[-1])
    assert observed == (actor.actor_module.onnx_kv_layers, 2, 2, 4)
def test_ref_router_v3_actor_keeps_full_obs_backbone_and_layer_router_adapters():
    """V3 embeds the full observation and adds one router adapter per routed layer.

    "Routed layer" here means a layer of the same type as ``layers[1]``
    (assumed to be the MoE block type; confirm).
    """
    actor = _make_ref_router_v3_actor()
    module = actor.actor_module
    full_dim = actor.full_obs_input_dim
    assert full_dim > actor.state_obs_input_dim
    assert module.full_obs_input_dim == full_dim
    assert module.obs_embed[0].in_features == full_dim

    routed_type = type(module.layers[1])
    num_routed = sum(
        1 for layer in module.layers if isinstance(layer, routed_type)
    )
    assert len(module.router_layer_projections) == num_routed
def test_ref_router_v3_history_backbone_consumes_flat_ref_motion():
    """V3 concatenates current and flattened future refs into one frame token.

    ``_build_router_ref_motion`` must yield width
    ``ref_cur_token_dim + ref_fut_seq_len * ref_fut_token_dim``. That width
    must be the input width of ``ref_frame_embed``, and the router-summary
    submodules must not exist on the V3 module.
    """
    actor = _make_ref_router_v3_actor()
    module = actor.actor_module
    batch, steps = 2, 3
    x = torch.randn(batch, steps, module.full_obs_input_dim)
    _, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x)

    assert hasattr(module, "_build_router_ref_motion")
    ref_motion_x = module._build_router_ref_motion(ref_cur_x, ref_fut_x)
    flat_ref_dim = (
        actor.ref_cur_token_dim
        + actor.ref_fut_seq_len * actor.ref_fut_token_dim
    )
    assert ref_motion_x.shape == (batch, steps, flat_ref_dim)
    assert module.ref_frame_embed[0].in_features == ref_motion_x.shape[-1]

    for absent in (
        "router_future_obs_embed",
        "router_future_pool",
        "router_summary_fusion",
        "router_summary_norm",
    ):
        assert not hasattr(module, absent)
def test_ref_router_v3_actor_sequence_logp_emits_aux_router_future_recon():
    """With future-recon enabled, ``sequence_logp`` emits its prediction.

    The prediction has shape ``(batch, steps, ref_fut_seq_len * ref_fut_token_dim)``.
    """
    recon_cfg = {"enabled": True, "hidden_dim": 9, "weight": 1.0}
    actor = _make_ref_router_v3_actor(aux_router_future_recon=recon_cfg)
    batch, steps = 2, 3
    obs_td = _make_ref_router_v2_obs([batch, steps])
    actions_seq = torch.randn(batch, steps, 4)
    causal = torch.tril(torch.ones(steps, steps, dtype=torch.bool))
    seq_out = actor(
        obs_td,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=causal.expand(batch, -1, -1),
        update_obs_norm=False,
    )
    recon_dim = actor.ref_fut_seq_len * actor.ref_fut_token_dim
    assert seq_out["aux_router_future_recon"].shape == (
        batch,
        steps,
        recon_dim,
    )
def test_ref_router_v3_actor_updates_future_recon_empirical_normalizer():
    """A normalizing sequence pass feeds the future-recon target normalizer.

    The normalizer starts empty. One ``sequence_logp`` call over a (2, 3)
    batch with ``update_obs_norm=True`` must add 2 * 3 = 6 samples.
    """
    recon_cfg = {"enabled": True, "hidden_dim": 9, "weight": 1.0}
    actor = _make_ref_router_v3_actor(aux_router_future_recon=recon_cfg)
    batch, steps = 2, 3

    def _recon_norm_count() -> int:
        normalizer = actor.actor_module.aux_router_future_recon_normalizer
        return int(normalizer.count)

    obs_td = _make_ref_router_v2_obs([batch, steps])
    actions_seq = torch.randn(batch, steps, 4)
    causal = torch.tril(torch.ones(steps, steps, dtype=torch.bool))
    attn_mask = causal.expand(batch, -1, -1)

    assert actor.aux_router_future_recon_assembler is not None
    assert _recon_norm_count() == 0
    actor(
        obs_td,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=attn_mask,
        update_obs_norm=True,
    )
    assert _recon_norm_count() == batch * steps
def test_ref_router_v2_freeze_router_freezes_reference_router_path():
    """``freeze_router=True`` freezes every V2 reference-router parameter.

    The router outputs returned by ``sequence_mu`` must not require grad.
    After backprop through ``mu``, none of the frozen parameters may have a
    ``.grad``.
    """
    actor = _make_ref_router_v2_actor(freeze_router=True)
    module = actor.actor_module
    x = torch.randn(2, 3, module.full_obs_input_dim, requires_grad=True)
    mu, router_h, router_temporal_features = module.sequence_mu(
        x,
        return_ref_aux_hidden=True,
        return_router_temporal_features=True,
    )
    mu.sum().backward()

    first_moe = next(
        layer for layer in module.layers if hasattr(layer, "router")
    )
    frozen_params = (
        module.ref_frame_embed[0].weight,
        module.ref_hist_attn.q_proj.weight,
        module.ref_future_conv[0].weight,
        module.router_ref_pool.q_proj.weight,
        module.router_query,
        first_moe.router.weight,
    )
    assert module.freeze_router is True
    for param in frozen_params:
        assert param.requires_grad is False
    assert router_h.requires_grad is False
    assert router_temporal_features.requires_grad is False
    for param in frozen_params:
        assert param.grad is None
def test_ref_router_v3_freeze_router_reapplies_after_load_state_dict():
    """Loading a state dict must re-freeze the V3 reference-router path.

    The ref-frame embed, ref-history attention, per-layer router projections
    and MoE routers are manually unfrozen. A strict round-trip
    ``load_state_dict`` must then load cleanly and leave them frozen again.
    """
    actor = _make_ref_router_v3_actor(freeze_router=True)
    module = actor.actor_module
    snapshot = module.state_dict()
    routed_layers = [
        layer for layer in module.layers if hasattr(layer, "router")
    ]

    # Undo the freeze by hand so the re-freeze on load is observable.
    for submodule in (
        module.ref_frame_embed,
        module.ref_hist_attn,
        module.router_layer_projections,
    ):
        submodule.requires_grad_(True)
    for layer in routed_layers:
        layer.router.requires_grad_(True)

    load_result = module.load_state_dict(snapshot, strict=True)
    assert load_result.missing_keys == []
    assert load_result.unexpected_keys == []
    for param in (
        module.ref_frame_embed[0].weight,
        module.ref_hist_attn.q_proj.weight,
        module.router_layer_projections[0][1].weight,
    ):
        assert param.requires_grad is False
    for absent in (
        "router_future_obs_embed",
        "router_future_pool",
        "router_summary_fusion",
        "router_summary_norm",
    ):
        assert not hasattr(module, absent)
    for layer in routed_layers:
        assert layer.router.weight.requires_grad is False
def test_ref_router_v2_film_head_starts_near_identity():
    """At init the FiLM head leaves the hidden state unchanged.

    The per-channel gains start at ~0.05 and the final FiLM linear layer is
    all zeros, so applying FiLM returns the hidden state as-is.
    """
    actor = _make_ref_router_v2_actor()
    module = actor.actor_module
    assert hasattr(module, "_actor_film_gain")

    gains = module._actor_film_gain().detach()
    assert gains.shape == (module.d_model,)
    assert torch.allclose(gains, torch.full_like(gains, 0.05), atol=1.0e-5)

    film_out = module.actor_ref_film[-1]
    for tensor in (film_out.weight, film_out.bias):
        assert torch.count_nonzero(tensor.detach()) == 0

    hidden = torch.randn(2, 1, module.d_model)
    actor_ref_ctx = torch.randn(2, 1, module.d_model)
    assert torch.allclose(
        module._apply_actor_ref_film(hidden, actor_ref_ctx), hidden
    )
def test_ref_router_v2_pre_moe_hidden_precedes_film_modulation():
    """``return_pre_moe_hidden`` must expose the block-0 hidden state before FiLM.

    The FiLM head is pushed far from identity: large raw gain, zero final
    weight, constant final bias of 2.0. The block-0 output is then rebuilt by
    hand from the module's own submodules without applying FiLM.
    ``sequence_mu`` must return exactly that tensor as ``pre_moe_hidden``. A
    FiLM-modulated tensor would be expected to differ.
    """
    actor = _make_ref_router_v2_actor()
    module = actor.actor_module
    # Make FiLM strongly non-identity so any leak into pre_moe_hidden shows up.
    with torch.no_grad():
        module.actor_film_gain_raw.fill_(100.0)
        module.actor_ref_film[-1].weight.zero_()
        module.actor_ref_film[-1].bias.fill_(2.0)
    # One-step sequence of shape (batch=2, seq=1, full_obs_input_dim).
    x = torch.randn(2, module.full_obs_input_dim)
    x_seq = x[:, None, :]
    state_x, ref_cur_x, ref_fut_x = module._split_actor_ref_inputs(x_seq)
    state_h = module.obs_embed(state_x)
    # Reference-history branch: embed the current ref frame, attend at
    # position 0 (cos/sin from get_cos_sin, presumably rotary embeddings),
    # then apply the residual and output norm.
    ref_cur_h = module.ref_frame_embed(ref_cur_x)
    ref_hist_attn = module.ref_hist_attn(
        module.ref_hist_norm(ref_cur_h),
        *module.get_cos_sin(ref_cur_h, torch.zeros(2, 1, dtype=torch.long)),
        mask=None,
    )
    ref_hist_h = module.ref_hist_out_norm(ref_cur_h + ref_hist_attn)
    # Future tokens are appended after the history token along dim 2, then
    # pooled into the router context.
    ref_fut_tokens = module._encode_future_tokens(ref_fut_x)
    shared_ref_tokens = torch.cat(
        [ref_hist_h.unsqueeze(2), ref_fut_tokens], dim=2
    )
    router_h = module._pool_router_context(shared_ref_tokens)
    cos, sin = module.get_cos_sin(state_h, torch.zeros(2, 1, dtype=torch.long))
    # Run only the first backbone block (start_layer=0, end_layer=1;
    # end presumably exclusive) on the state embedding, without FiLM.
    block0_hidden = module._forward_layers_range(
        state_h,
        cos=cos,
        sin=sin,
        mask=None,
        router_h=router_h,
        start_layer=0,
        end_layer=1,
    )
    _, pre_moe_hidden = module.sequence_mu(
        x_seq,
        return_pre_moe_hidden=True,
    )
    assert torch.allclose(pre_moe_hidden, block0_hidden)
def test_ref_router_v2_film_gain_is_bounded_per_channel():
    """FiLM gains stay in ``[0, actor_film_gain_max]`` for extreme raw values.

    A linear sweep of raw gains from -100 to 100 must map to per-channel
    gains inside the bound that are not all identical.
    """
    actor = _make_ref_router_v2_actor()
    module = actor.actor_module
    assert hasattr(module, "actor_film_gain_raw")

    raw_sweep = torch.linspace(-100.0, 100.0, module.d_model)
    with torch.no_grad():
        module.actor_film_gain_raw.copy_(raw_sweep)
    gains = module._actor_film_gain()
    upper = module.actor_film_gain_max + 1.0e-6

    assert gains.shape == (module.d_model,)
    assert torch.all(gains >= 0.0)
    assert torch.all(gains <= upper)
    assert torch.unique(gains).numel() > 1
def test_ref_router_v2_film_perturbation_rms_stays_bounded():
    """The per-token RMS of the FiLM perturbation never exceeds the gain cap.

    The gains are saturated and the final FiLM layer is given a huge constant
    bias. Even so, ``||FiLM(h) - h||_rms`` must stay at or below
    ``actor_film_gain_max``.
    """
    actor = _make_ref_router_v2_actor()
    module = actor.actor_module
    assert hasattr(module, "actor_film_gain_raw")

    film_out = module.actor_ref_film[-1]
    with torch.no_grad():
        film_out.weight.zero_()
        film_out.bias.fill_(100.0)
        module.actor_film_gain_raw.fill_(100.0)

    hidden = torch.randn(4, 3, module.d_model)
    actor_ref_ctx = torch.randn(4, 3, module.d_model)
    conditioned = module._apply_actor_ref_film(hidden, actor_ref_ctx)
    per_token_rms = (conditioned - hidden).pow(2).mean(dim=-1).sqrt()
    assert torch.all(per_token_rms <= module.actor_film_gain_max + 1.0e-5)
def test_ref_router_v2_aux_prediction_stays_bound_to_returned_pre_moe_hidden():
    """Aux predictions must depend only on the tensors explicitly passed in.

    Predictions for input A are computed, then a forward for input B is run,
    then A's predictions are recomputed from A's saved hidden states. The
    recomputed values must match the originals, so no state leaks from the
    latest ``sequence_mu`` call. B's ref-keybody prediction must differ from
    A's, and the returned ``pre_a`` tensor must not carry a stashed
    ``_ref_aux_hidden`` attribute.
    """
    actor = _make_ref_router_v2_actor(
        aux_state_pred={
            "enabled": True,
            "w_base_lin_vel": 1.0,
            "w_keybody_contact": 1.0,
            "w_ref_keybody_rel_pos": 1.0,
            "w_robot_keybody_rel_pos": 1.0,
            "keybody_contact_names": ["knee"],
            "keybody_rel_pos_names": ["knee"],
        }
    )
    module = actor.actor_module
    module.eval()
    # Two distinct inputs of shape (batch=1, seq=1, full_obs_input_dim).
    x_a = torch.randn(1, 1, module.full_obs_input_dim)
    x_b = x_a + 0.5
    with torch.no_grad():
        _, pre_a, ref_aux_a = module.sequence_mu(
            x_a,
            return_pre_moe_hidden=True,
            return_ref_aux_hidden=True,
        )
        aux_a = module.predict_aux_from_pre_moe(
            pre_a, ref_aux_hidden=ref_aux_a
        )
        # Forward on x_b in between, to expose any cached-state dependence.
        _, pre_b, ref_aux_b = module.sequence_mu(
            x_b,
            return_pre_moe_hidden=True,
            return_ref_aux_hidden=True,
        )
        # Recompute A's predictions after the B forward; must be unchanged.
        aux_a_late = module.predict_aux_from_pre_moe(
            pre_a, ref_aux_hidden=ref_aux_a
        )
        aux_b = module.predict_aux_from_pre_moe(
            pre_b, ref_aux_hidden=ref_aux_b
        )
    assert torch.allclose(
        aux_a_late["ref_keybody_rel_pos"], aux_a["ref_keybody_rel_pos"]
    )
    assert torch.allclose(
        aux_a_late["base_lin_vel_loc"], aux_a["base_lin_vel_loc"]
    )
    assert not torch.allclose(
        aux_a["ref_keybody_rel_pos"], aux_b["ref_keybody_rel_pos"]
    )
    # The ref-aux hidden must travel explicitly, not as a tensor attribute.
    assert not hasattr(pre_a, "_ref_aux_hidden")
def test_ref_router_v2_sequence_single_step_and_cached_onnx_agree():
    """Three V2 inference paths must produce the same ``mu`` for a 2-step input.

    The paths are:

    * a full-sequence ``sequence_mu`` with a causal mask;
    * stateful ``single_step_mu`` calls using the internal KV cache after
      ``reset_kv_cache``;
    * stateless ``forward`` calls that thread ``past_key_values`` /
      ``current_pos`` explicitly (the ONNX-export path).

    The returned ``present`` caches must keep the advertised cache shape.
    """
    actor = _make_ref_router_v2_actor()
    module = actor.actor_module
    module.eval()
    # (batch=1, seq=2, full_obs_input_dim) with a lower-triangular causal mask.
    x_seq = torch.randn(1, 2, module.full_obs_input_dim)
    attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0)
    with torch.no_grad():
        mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask)
        # Stateful path: the order of single_step_mu calls matters.
        module.reset_kv_cache(num_envs=1, device=x_seq.device)
        mu_step_0 = module.single_step_mu(x_seq[:, 0, :])
        mu_step_1 = module.single_step_mu(x_seq[:, 1, :])
        mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1)
        # Stateless path: start from an all-zero cache and feed each
        # returned ``present`` back in as the next ``past_key_values``.
        cache_shape = actor.onnx_past_key_values_shape(batch_size=1)
        past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype)
        step_0 = torch.zeros(1, dtype=torch.long)
        step_1 = torch.ones(1, dtype=torch.long)
        mu_onnx_0, present_0 = module.forward(
            x_seq[:, 0, :],
            past_key_values=past_key_values,
            current_pos=step_0,
        )
        mu_onnx_1, present_1 = module.forward(
            x_seq[:, 1, :],
            past_key_values=present_0,
            current_pos=step_1,
        )
        mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1)
    assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4)
    assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4)
    assert present_0.shape == cache_shape
    assert present_1.shape == cache_shape
def test_ref_router_v3_sequence_single_step_and_cached_onnx_agree():
    """Three V3 inference paths must produce the same ``mu`` for a 2-step input.

    The paths are:

    * a full-sequence ``sequence_mu`` with a causal mask;
    * stateful ``single_step_mu`` calls using the internal KV cache after
      ``reset_kv_cache``;
    * stateless ``forward`` calls that thread ``past_key_values`` /
      ``current_pos`` explicitly (the ONNX-export path).

    The returned ``present`` caches must keep the advertised cache shape.
    """
    actor = _make_ref_router_v3_actor()
    module = actor.actor_module
    module.eval()
    # (batch=1, seq=2, full_obs_input_dim) with a lower-triangular causal mask.
    x_seq = torch.randn(1, 2, module.full_obs_input_dim)
    attn_mask = torch.tril(torch.ones(2, 2, dtype=torch.bool)).unsqueeze(0)
    with torch.no_grad():
        mu_seq = module.sequence_mu(x_seq, attn_mask=attn_mask)
        # Stateful path: the order of single_step_mu calls matters.
        module.reset_kv_cache(num_envs=1, device=x_seq.device)
        mu_step_0 = module.single_step_mu(x_seq[:, 0, :])
        mu_step_1 = module.single_step_mu(x_seq[:, 1, :])
        mu_single_step = torch.stack([mu_step_0, mu_step_1], dim=1)
        # Stateless path: start from an all-zero cache and feed each
        # returned ``present`` back in as the next ``past_key_values``.
        cache_shape = actor.onnx_past_key_values_shape(batch_size=1)
        past_key_values = torch.zeros(*cache_shape, dtype=x_seq.dtype)
        step_0 = torch.zeros(1, dtype=torch.long)
        step_1 = torch.ones(1, dtype=torch.long)
        mu_onnx_0, present_0 = module.forward(
            x_seq[:, 0, :],
            past_key_values=past_key_values,
            current_pos=step_0,
        )
        mu_onnx_1, present_1 = module.forward(
            x_seq[:, 1, :],
            past_key_values=present_0,
            current_pos=step_1,
        )
        mu_onnx = torch.stack([mu_onnx_0, mu_onnx_1], dim=1)
    assert torch.allclose(mu_single_step, mu_seq, atol=1.0e-5, rtol=1.0e-4)
    assert torch.allclose(mu_onnx, mu_seq, atol=1.0e-5, rtol=1.0e-4)
    assert present_0.shape == cache_shape
    assert present_1.shape == cache_shape
def test_ref_router_seq_actor_single_step_and_sequence_logp_match_contract():
    """Check output shapes of the three V2 actor call paths.

    1. ``mode="inference"`` on a (2,) batch gives actions/mu/sigma of (2, 4).
    2. The cached forward with explicit ``past_key_values``/``current_pos``
       gives actions of (2, 4) and a ``present`` cache shaped like the input.
    3. ``mode="sequence_logp"`` on a (2, 3) batch with a causal mask gives
       mu/sigma of (2, 3, 4) and log-prob/entropy of (2, 3, 1).
    """
    num_envs, seq_len, num_actions = 2, 3, 4
    actor = _make_ref_router_v2_actor()
    obs_td = _make_ref_router_v2_obs([num_envs])

    inference_out = actor(obs_td, mode="inference", update_obs_norm=False)
    for key in ("actions", "mu", "sigma"):
        assert inference_out[key].shape == (num_envs, num_actions)

    cache_shape = actor.onnx_past_key_values_shape(batch_size=num_envs)
    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)
    step_idx = torch.zeros(num_envs, dtype=torch.long)
    with torch.no_grad():
        actions, present = actor(
            obs_td, past_key_values=past_key_values, current_pos=step_idx
        )
    assert actions.shape == (num_envs, num_actions)
    assert present.shape == cache_shape

    obs_seq = _make_ref_router_v2_obs([num_envs, seq_len])
    actions_seq = torch.randn(num_envs, seq_len, num_actions)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    seq_out = actor(
        obs_seq,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=causal.expand(num_envs, -1, -1),
        update_obs_norm=False,
    )
    for key in ("mu", "sigma"):
        assert seq_out[key].shape == (num_envs, seq_len, num_actions)
    for key in ("actions_log_prob", "entropy"):
        assert seq_out[key].shape == (num_envs, seq_len, 1)
def test_ref_router_v3_actor_single_step_and_sequence_logp_match_contract():
    """Check output shapes of the three V3 actor call paths.

    1. ``mode="inference"`` on a (2,) batch gives actions/mu/sigma of (2, 4).
    2. The cached forward with explicit ``past_key_values``/``current_pos``
       gives actions of (2, 4) and a ``present`` cache shaped like the input.
    3. ``mode="sequence_logp"`` on a (2, 3) batch with a causal mask gives
       mu/sigma of (2, 3, 4) and log-prob/entropy of (2, 3, 1).
    """
    num_envs, seq_len, num_actions = 2, 3, 4
    actor = _make_ref_router_v3_actor()
    obs_td = _make_ref_router_v2_obs([num_envs])

    inference_out = actor(obs_td, mode="inference", update_obs_norm=False)
    for key in ("actions", "mu", "sigma"):
        assert inference_out[key].shape == (num_envs, num_actions)

    cache_shape = actor.onnx_past_key_values_shape(batch_size=num_envs)
    past_key_values = torch.zeros(*cache_shape, dtype=torch.float32)
    step_idx = torch.zeros(num_envs, dtype=torch.long)
    with torch.no_grad():
        actions, present = actor(
            obs_td, past_key_values=past_key_values, current_pos=step_idx
        )
    assert actions.shape == (num_envs, num_actions)
    assert present.shape == cache_shape

    obs_seq = _make_ref_router_v2_obs([num_envs, seq_len])
    actions_seq = torch.randn(num_envs, seq_len, num_actions)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    seq_out = actor(
        obs_seq,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=causal.expand(num_envs, -1, -1),
        update_obs_norm=False,
    )
    for key in ("mu", "sigma"):
        assert seq_out[key].shape == (num_envs, seq_len, num_actions)
    for key in ("actions_log_prob", "entropy"):
        assert seq_out[key].shape == (num_envs, seq_len, 1)
def test_ref_router_seq_actor_sequence_logp_emits_aux_preds_without_metadata():
    """With aux state prediction on, ``sequence_logp`` emits the aux heads.

    Both keybody-position heads are shaped ``(batch, steps, 1, 3)``. The
    single keybody is "knee", and the trailing 3 is presumably an xyz
    position.
    """
    aux_cfg = {
        "enabled": True,
        "w_base_lin_vel": 1.0,
        "w_keybody_contact": 1.0,
        "w_ref_keybody_rel_pos": 1.0,
        "w_robot_keybody_rel_pos": 1.0,
        "keybody_contact_names": ["knee"],
        "keybody_rel_pos_names": ["knee"],
    }
    actor = _make_ref_router_v2_actor(aux_state_pred=aux_cfg)
    batch, steps = 2, 3
    obs_seq = _make_ref_router_v2_obs([batch, steps])
    actions_seq = torch.randn(batch, steps, 4)
    causal = torch.tril(torch.ones(steps, steps, dtype=torch.bool))
    seq_out = actor(
        obs_seq,
        actions=actions_seq,
        mode="sequence_logp",
        attn_mask=causal.expand(batch, -1, -1),
        update_obs_norm=False,
    )
    for key in (
        "aux_ref_keybody_rel_pos",
        "aux_robot_keybody_rel_pos",
        "aux_base_lin_vel_loc",
    ):
        assert key in seq_out.keys()
    for key in ("aux_ref_keybody_rel_pos", "aux_robot_keybody_rel_pos"):
        assert seq_out[key].shape == (batch, steps, 1, 3)
def test_ref_router_seq_actor_requires_all_shared_ref_terms():
    """A schema without the current-frame ref terms must be rejected."""
    schema_without_ref_cur = _make_ref_router_v2_obs_schema(
        include_ref_cur=False
    )
    with pytest.raises(
        ValueError, match="missing required current ref term"
    ):
        _make_ref_router_v2_actor(obs_schema=schema_without_ref_cur)
def test_ref_router_seq_actor_rejects_aux_router_command_recon():
    """The V2 actor must refuse an enabled router command-recon aux."""
    recon_cfg = {"enabled": True, "hidden_dim": 8}
    with pytest.raises(ValueError, match="aux_router_command_recon"):
        _make_ref_router_v2_actor(aux_router_command_recon=recon_cfg)
def test_ref_router_seq_actor_rejects_unsupported_aux_state_pred_weights():
    """A non-zero root-height aux weight is unsupported by the V2 actor.

    Every other aux weight is zero, so the error must come from
    ``w_root_height``.
    """
    aux_cfg = {
        "enabled": True,
        **dict.fromkeys(
            (
                "w_base_lin_vel",
                "w_keybody_contact",
                "w_ref_keybody_rel_pos",
                "w_robot_keybody_rel_pos",
            ),
            0.0,
        ),
        "w_root_height": 1.0,
        "keybody_contact_names": [],
        "keybody_rel_pos_names": [],
    }
    with pytest.raises(ValueError, match="root_height"):
        _make_ref_router_v2_actor(aux_state_pred=aux_cfg)
================================================
FILE: tests/test_reference_filter_export.py
================================================
import json
import sys
import tempfile
import unittest
from unittest import mock
from pathlib import Path
import numpy as np
import torch
from omegaconf import OmegaConf
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.training.h5_dataloader import MotionClipSample
from holomotion.src.training.reference_filter_export import (
export_reference_filter_artifacts_from_config,
export_reference_filter_debug_artifacts,
)
def _quat_xyzw_from_rpy(roll: float, pitch: float, yaw: float) -> torch.Tensor:
cr = np.cos(roll * 0.5)
sr = np.sin(roll * 0.5)
cp = np.cos(pitch * 0.5)
sp = np.sin(pitch * 0.5)
cy = np.cos(yaw * 0.5)
sy = np.sin(yaw * 0.5)
return torch.tensor(
[
sr * cp * cy - cr * sp * sy,
cr * sp * cy + sr * cp * sy,
cr * cp * sy - sr * sp * cy,
cr * cp * cy + sr * sp * sy,
],
dtype=torch.float32,
)
def _make_sample(*, include_filtered: bool = True) -> MotionClipSample:
    """Build a tiny synthetic ``MotionClipSample`` (4 frames, 5 bodies, 3 DOFs).

    Body positions and DOF positions are deterministic ``arange`` ramps.
    Body linear and angular velocities are the positions offset by +100 and
    +200; DOF velocities are DOF positions offset by +50. All bodies share
    one orientation per frame, built from a fixed roll/pitch/yaw sequence.
    Root entries are slices of body 0.

    Args:
        include_filtered: Also add ``ft_*`` counterparts with small constant
            offsets (rotations are cloned unchanged).
    """
    timesteps, num_bodies, num_dofs = 4, 5, 3

    ref_rg_pos = torch.arange(
        timesteps * num_bodies * 3, dtype=torch.float32
    ).reshape(timesteps, num_bodies, 3)
    ref_body_vel = ref_rg_pos + 100.0
    ref_body_ang_vel = ref_rg_pos + 200.0
    ref_dof_pos = torch.arange(
        timesteps * num_dofs, dtype=torch.float32
    ).reshape(timesteps, num_dofs)
    ref_dof_vel = ref_dof_pos + 50.0

    frame_rpy = (
        (0.0, 0.0, 0.0),
        (0.1, -0.2, 0.3),
        (0.2, -0.1, 0.4),
        (0.3, 0.0, 0.5),
    )
    frame_quats = torch.stack(
        [_quat_xyzw_from_rpy(r, p, y) for r, p, y in frame_rpy], dim=0
    )
    ref_rb_rot = frame_quats.unsqueeze(1).repeat(1, num_bodies, 1)

    tensors = {
        "ref_rg_pos": ref_rg_pos,
        "ref_rb_rot": ref_rb_rot,
        "ref_body_vel": ref_body_vel,
        "ref_body_ang_vel": ref_body_ang_vel,
        "ref_root_pos": ref_rg_pos[:, 0, :],
        "ref_root_rot": ref_rb_rot[:, 0, :],
        "ref_root_vel": ref_body_vel[:, 0, :],
        "ref_root_ang_vel": ref_body_ang_vel[:, 0, :],
        "ref_dof_pos": ref_dof_pos,
        "ref_dof_vel": ref_dof_vel,
        "filter_cutoff_hz": torch.full(
            (timesteps, 1), 2.0, dtype=torch.float32
        ),
    }
    if include_filtered:
        tensors["ft_ref_rg_pos"] = ref_rg_pos + 0.5
        tensors["ft_ref_rb_rot"] = ref_rb_rot.clone()
        tensors["ft_ref_body_vel"] = ref_body_vel + 0.25
        tensors["ft_ref_body_ang_vel"] = ref_body_ang_vel + 0.25
        tensors["ft_ref_root_pos"] = ref_rg_pos[:, 0, :] + 0.5
        tensors["ft_ref_root_rot"] = ref_rb_rot[:, 0, :].clone()
        tensors["ft_ref_root_vel"] = ref_body_vel[:, 0, :] + 0.25
        tensors["ft_ref_root_ang_vel"] = ref_body_ang_vel[:, 0, :] + 0.25
        tensors["ft_ref_dof_pos"] = ref_dof_pos + 0.75
        tensors["ft_ref_dof_vel"] = ref_dof_vel + 0.75

    return MotionClipSample(
        motion_key="clip-a__start_0_len_4",
        raw_motion_key="clip-a",
        window_index=0,
        tensors=tensors,
        length=timesteps,
    )
class ReferenceFilterExportTests(unittest.TestCase):
    """Tests for the reference-filter debug export entry points.

    All tests use the synthetic clip from ``_make_sample`` (4 frames,
    5 bodies, 3 DoFs) and write into a throw-away temporary directory.
    """

    # Names aligned index-by-index with the 5 bodies / 3 DoFs of _make_sample.
    _BODY_NAMES = (
        "root_link",
        "torso_link",
        "left_wrist_yaw_link",
        "right_wrist_yaw_link",
        "left_ankle_roll_link",
    )
    _DOF_NAMES = (
        "waist_yaw_joint",
        "left_wrist_yaw_joint",
        "left_ankle_roll_joint",
    )

    def test_export_reference_filter_artifacts_from_config_builds_dataset(
        self,
    ):
        """The config-driven export builds the dataset once and writes metadata.

        Dataset construction is mocked to return the single synthetic sample.
        """
        sample = _make_sample()
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = OmegaConf.create(
                {
                    "robot": {
                        "body_names": list(self._BODY_NAMES),
                        "dof_names": list(self._DOF_NAMES),
                        "motion": {
                            "online_filter": {"enabled": True},
                            "max_frame_length": 4,
                            "min_frame_length": 1,
                            "world_frame_normalization": True,
                        },
                    },
                    "debug_reference_filter_export": {
                        "enabled": True,
                        "output_dir": tmp_dir,
                        "selected_body_links": [
                            "left_wrist_yaw_link",
                            "left_ankle_roll_link",
                        ],
                    },
                }
            )
            with mock.patch(
                "holomotion.src.training.reference_filter_export."
                "build_motion_datasets_from_cfg",
                return_value=([sample], None, {}),
            ) as build_mock:
                output_dir = export_reference_filter_artifacts_from_config(
                    config
                )
            self.assertEqual(output_dir, Path(tmp_dir))
            build_mock.assert_called_once()
            # Checked inside the context: the directory vanishes on exit.
            self.assertTrue((Path(tmp_dir) / "metadata.json").is_file())

    def test_export_reference_filter_debug_artifacts_writes_outputs(self):
        """The exporter writes metadata, signal archives and comparison plots."""
        sample = _make_sample()
        dof_names = list(self._DOF_NAMES)
        selected_body_links = [
            "left_wrist_yaw_link",
            "right_wrist_yaw_link",
            "left_ankle_roll_link",
        ]
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = Path(tmp_dir)
            export_reference_filter_debug_artifacts(
                sample=sample,
                output_dir=output_dir,
                body_names=list(self._BODY_NAMES),
                dof_names=dof_names,
                selected_body_links=selected_body_links,
            )

            for file_name in (
                "metadata.json",
                "root_signals.npz",
                "bodylink_signals.npz",
                "dof_signals.npz",
                "root_comparison.png",
                "left_wrist_yaw_link_comparison.png",
                "dof_pos_comparison.png",
                "dof_vel_comparison.png",
            ):
                self.assertTrue(
                    (output_dir / file_name).is_file(), msg=file_name
                )

            metadata = json.loads(
                (output_dir / "metadata.json").read_text(encoding="utf-8")
            )
            self.assertEqual(metadata["filter_cutoff_hz"], 2.0)
            self.assertEqual(
                metadata["selected_body_links"], selected_body_links
            )
            self.assertEqual(metadata["dof_names"], dof_names)

            # np.load on an .npz returns an NpzFile that keeps the archive
            # open until closed; close it before the temporary directory is
            # removed (cleanup fails on Windows otherwise).
            with np.load(output_dir / "root_signals.npz") as root_payload:
                self.assertIn("ref_global_pos", root_payload.files)
                self.assertIn("ft_ref_rpy", root_payload.files)
                self.assertEqual(root_payload["ref_global_pos"].shape, (4, 3))
                self.assertEqual(root_payload["ft_ref_rpy"].shape, (4, 3))
            with np.load(output_dir / "dof_signals.npz") as dof_payload:
                self.assertEqual(dof_payload["ref_dof_pos"].shape, (4, 3))
                self.assertEqual(dof_payload["ft_ref_dof_vel"].shape, (4, 3))

    def test_export_reference_filter_debug_artifacts_requires_filtered_tensors(
        self,
    ):
        """Exporting a sample without ``ft_`` tensors raises ``ValueError``."""
        sample = _make_sample(include_filtered=False)
        with tempfile.TemporaryDirectory() as tmp_dir, self.assertRaisesRegex(
            ValueError, "Filtered reference tensors are unavailable"
        ):
            export_reference_filter_debug_artifacts(
                sample=sample,
                output_dir=Path(tmp_dir),
                body_names=["root_link", "left_wrist_yaw_link"],
                dof_names=["waist_yaw_joint"],
                selected_body_links=["left_wrist_yaw_link"],
            )
================================================
FILE: tests/test_reference_motion_config_wiring.py
================================================
import sys
import unittest
from pathlib import Path
from omegaconf import OmegaConf
# Repository root (this test file lives in <root>/tests/); made importable.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))

_ENV_CONFIG_DIR = PROJECT_ROOT / "holomotion" / "config" / "env"
_OBS_CONFIG_DIR = _ENV_CONFIG_DIR / "observations" / "motion_tracking"
_TERMINATION_CONFIG_DIR = _ENV_CONFIG_DIR / "terminations"

# Observation configs that must expose the reference-filter cutoff term.
OBS_CONFIG_PATHS = [
    _OBS_CONFIG_DIR / file_name
    for file_name in (
        "obs_motion_tracking.yaml",
        "obs_motrack_tf_ref_v3.yaml",
        "obs_motrack_tf_more_info.yaml",
        "obs_motrack_mlp_20260210.yaml",
        "obs_motrack_tf_20260210.yaml",
        "obs_motrack_teacher.yaml",
    )
]

# Termination configs whose terms must forward a ``ref_prefix`` parameter.
TERMINATION_CONFIG_PATHS = [
    _TERMINATION_CONFIG_DIR / file_name
    for file_name in (
        "termination_motion_tracking.yaml",
        "termination_motion_tracking_simple.yaml",
        "termination_motrack_with_kpe.yaml",
        "termination_motrack_with_kpe_jpe.yaml",
    )
]

ROBOT_TRAINING_CONFIG_PATH = PROJECT_ROOT.joinpath(
    "holomotion",
    "config",
    "robot",
    "unitree",
    "G1",
    "29dof",
    "29dof_training_isaaclab.yaml",
)
REWARD_CONFIG_PATH = _ENV_CONFIG_DIR.joinpath(
    "rewards", "motion_tracking", "rew_motrack_robust.yaml"
)
class ReferenceMotionConfigWiringTests(unittest.TestCase):
def test_motion_tracking_observation_configs_expose_cutoff_term(self):
for config_path in OBS_CONFIG_PATHS:
with self.subTest(config_path=str(config_path)):
config = OmegaConf.load(config_path)
self.assertTrue(
self._config_has_obs_term(
config,
term_name="ref_motion_filter_cutoff_hz",
)
)
def test_motion_tracking_termination_configs_forward_ref_prefix(self):
for config_path in TERMINATION_CONFIG_PATHS:
with self.subTest(config_path=str(config_path)):
config = OmegaConf.load(config_path)
for term_name, term_cfg in config.terminations.items():
if term_name == "time_out":
continue
self.assertIn("params", term_cfg)
self.assertIn("ref_prefix", term_cfg.params)
def test_robot_training_config_uses_hdf5_v2_backend(self):
config = OmegaConf.load(ROBOT_TRAINING_CONFIG_PATH)
self.assertEqual(config.robot.motion.backend, "hdf5_v2")
def test_termination_ref_prefix_resolves_with_reward_config(self):
reward_config = OmegaConf.load(REWARD_CONFIG_PATH)
for config_path in TERMINATION_CONFIG_PATHS:
with self.subTest(config_path=str(config_path)):
termination_config = OmegaConf.load(config_path)
merged = OmegaConf.merge(reward_config, termination_config)
for term_name, term_cfg in merged.terminations.items():
if term_name == "time_out":
continue
resolved_ref_prefix = OmegaConf.select(
merged,
f"terminations.{term_name}.params.ref_prefix",
)
self.assertIsInstance(resolved_ref_prefix, str)
self.assertTrue(resolved_ref_prefix.endswith("ref_"))
@staticmethod
def _config_has_obs_term(config, term_name: str) -> bool:
for group_cfg in config.obs.obs_groups.values():
for term_cfg in group_cfg.atomic_obs_list:
for obs_name, obs_value in term_cfg.items():
if obs_name == term_name:
return True
if obs_value.get("func") == term_name:
return True
return False
================================================
FILE: tests/test_root_rel_rewards.py
================================================
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import torch
# Source directory of the environment code under test (<root>/holomotion/src/env).
_ENV_SRC_DIR = Path(__file__).resolve().parents[1].joinpath(
    "holomotion", "src", "env"
)
REWARDS_PATH = _ENV_SRC_DIR / "isaaclab_components" / "isaaclab_rewards.py"
MOTION_TRACKING_PATH = _ENV_SRC_DIR / "motion_tracking.py"
class _DummyConfig:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
if args:
self.name = args[0]
for key, value in kwargs.items():
setattr(self, key, value)
if not hasattr(self, "params"):
self.params = {}
class _DummyManagerTermBase:
def __init__(self, cfg, env):
self.cfg = cfg
self._env = env
@property
def num_envs(self):
return self._env.num_envs
@property
def device(self):
return self._env.device
def reset(self, env_ids=None):
pass
def _identity_quat(*shape: int) -> torch.Tensor:
quat = torch.zeros(*shape, 4, dtype=torch.float32)
quat[..., 0] = 1.0
return quat
def _load_rewards_module(monkeypatch):
    """Execute ``isaaclab_rewards.py`` against lightweight fake dependencies.

    The following are replaced by minimal fake modules so the rewards file
    imports without a simulator:
    - Isaac Lab (``isaaclab`` and its ``assets``, ``envs``, ``envs.mdp``,
      ``managers``, ``sensors``, ``utils``, ``utils.math`` submodules);
    - ``hydra.utils``, ``omegaconf`` and ``loguru``;
    - two sibling holomotion modules.

    Every entry added to ``sys.modules`` goes through ``monkeypatch.setitem``,
    so it is removed again when the calling test finishes. This covers both
    the fakes and the loaded rewards module itself.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture of the calling test.

    Returns:
        The executed rewards module, registered as ``_test_root_rel_rewards``.
    """
    isaaclab_root = ModuleType("isaaclab")
    isaaclab_assets = ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedRLEnv = object
    isaaclab_mdp = ModuleType("isaaclab.envs.mdp")
    # Module-level __getattr__ (PEP 562): any attribute not set explicitly
    # resolves to a no-op callable returning None.
    isaaclab_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None)
    isaaclab_managers = ModuleType("isaaclab.managers")
    isaaclab_managers.ManagerTermBase = _DummyManagerTermBase
    isaaclab_managers.RewardTermCfg = _DummyConfig
    isaaclab_managers.SceneEntityCfg = _DummyConfig
    isaaclab_sensors = ModuleType("isaaclab.sensors")
    isaaclab_sensors.ContactSensor = object
    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = lambda cls: cls
    isaaclab_math = ModuleType("isaaclab.utils.math")
    # Rotation helpers are identities, so frame transforms leave vectors and
    # quaternions unchanged in these tests.
    isaaclab_math.quat_apply = lambda quat, vec: vec
    isaaclab_math.quat_apply_inverse = lambda quat, vec: vec
    isaaclab_math.quat_inv = lambda quat: quat
    isaaclab_math.quat_mul = lambda lhs, rhs: lhs
    isaaclab_math.yaw_quat = lambda quat: quat
    isaaclab_math.quat_error_magnitude = lambda lhs, rhs: torch.linalg.norm(
        lhs - rhs, dim=-1
    )
    isaaclab_math.__getattr__ = lambda name: (lambda *args, **kwargs: None)
    hydra_utils = ModuleType("hydra.utils")
    hydra_utils.instantiate = lambda value, *args, **kwargs: value
    omegaconf = ModuleType("omegaconf")
    omegaconf.DictConfig = dict
    omegaconf.ListConfig = list
    omegaconf.OmegaConf = SimpleNamespace(
        to_container=lambda value, resolve=True: value
    )
    loguru = ModuleType("loguru")
    loguru.logger = SimpleNamespace(
        info=lambda *args, **kwargs: None,
        warning=lambda *args, **kwargs: None,
    )
    fake_command_module = ModuleType(
        "holomotion.src.env.isaaclab_components."
        "isaaclab_motion_tracking_command"
    )
    fake_command_module.RefMotionCommand = object
    fake_utils_module = ModuleType(
        "holomotion.src.env.isaaclab_components.isaaclab_utils"
    )
    fake_utils_module._get_body_indices = lambda robot, keybody_names: [
        robot.body_names.index(name) for name in keybody_names
    ]
    fake_utils_module._get_dof_indices = lambda robot, key_dofs: []
    fake_utils_module.resolve_holo_config = lambda value: value
    for name, module in {
        "isaaclab": isaaclab_root,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_mdp,
        "isaaclab.managers": isaaclab_managers,
        "isaaclab.sensors": isaaclab_sensors,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.math": isaaclab_math,
        "hydra.utils": hydra_utils,
        "omegaconf": omegaconf,
        "loguru": loguru,
        (
            "holomotion.src.env.isaaclab_components."
            "isaaclab_motion_tracking_command"
        ): fake_command_module,
        (
            "holomotion.src.env.isaaclab_components.isaaclab_utils"
        ): fake_utils_module,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)
    # Expose submodules as attributes of their parents, mirroring a real
    # package import.
    isaaclab_root.assets = isaaclab_assets
    isaaclab_root.envs = isaaclab_envs
    isaaclab_root.managers = isaaclab_managers
    isaaclab_root.sensors = isaaclab_sensors
    isaaclab_root.utils = isaaclab_utils
    isaaclab_envs.mdp = isaaclab_mdp
    isaaclab_utils.math = isaaclab_math
    module_name = "_test_root_rel_rewards"
    spec = importlib.util.spec_from_file_location(module_name, REWARDS_PATH)
    # Validate the spec before using it; module_from_spec(None) would fail
    # with an unrelated AttributeError instead of these assertions.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing (the module may look itself up while it is
    # imported), but through monkeypatch so it does not leak into other tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def _load_motion_tracking_module(monkeypatch):
    """Execute ``motion_tracking.py`` against lightweight fake dependencies.

    The following are replaced by minimal fake modules:
    - Isaac Lab;
    - ``easydict``, ``omegaconf`` and ``loguru``;
    - the holomotion ``isaaclab_components`` package, with its config classes
      and ``build_*`` helpers;
    - two sibling holomotion modules.

    Every entry added to ``sys.modules`` goes through ``monkeypatch.setitem``,
    so it is undone when the calling test finishes. This covers both the
    fakes and the loaded module itself.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture of the calling test.

    Returns:
        The executed module, registered as ``_test_motion_tracking``.
    """

    class _DummyConfigClass:
        # Accepts any constructor arguments and just stores them.
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    isaaclab_root = ModuleType("isaaclab")
    isaaclab_actuators = ModuleType("isaaclab.actuators")
    isaaclab_actuators.ImplicitActuatorCfg = _DummyConfigClass
    isaaclab_assets = ModuleType("isaaclab.assets")
    isaaclab_assets.Articulation = object
    isaaclab_envs = ModuleType("isaaclab.envs")
    isaaclab_envs.ManagerBasedEnv = object
    isaaclab_envs.ManagerBasedRLEnv = object
    isaaclab_envs.ManagerBasedRLEnvCfg = object
    isaaclab_envs.ViewerCfg = _DummyConfigClass
    isaaclab_envs_mdp = ModuleType("isaaclab.envs.mdp")
    # Module-level __getattr__ (PEP 562): unknown attributes resolve to a
    # no-op callable returning None.
    isaaclab_envs_mdp.__getattr__ = lambda name: (lambda *args, **kwargs: None)
    isaaclab_envs_mdp_events = ModuleType("isaaclab.envs.mdp.events")
    isaaclab_envs_mdp_events._randomize_prop_by_op = (
        lambda *args, **kwargs: None
    )
    isaaclab_managers = ModuleType("isaaclab.managers")
    isaaclab_managers.EventTermCfg = _DummyConfigClass
    isaaclab_managers.SceneEntityCfg = _DummyConfig
    isaaclab_sim = ModuleType("isaaclab.sim")
    isaaclab_sim.PhysxCfg = _DummyConfigClass
    isaaclab_sim.SimulationCfg = _DummyConfigClass
    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = lambda cls: cls
    isaaclab_utils_io = ModuleType("isaaclab.utils.io")
    isaaclab_utils_io.dump_yaml = lambda *args, **kwargs: None
    isaaclab_utils_math = ModuleType("isaaclab.utils.math")
    isaaclab_utils_math.__getattr__ = lambda name: (
        lambda *args, **kwargs: None
    )
    easydict = ModuleType("easydict")
    easydict.EasyDict = lambda value=None: value if value is not None else {}
    omegaconf = ModuleType("omegaconf")
    omegaconf.OmegaConf = SimpleNamespace(
        to_container=lambda value, resolve=True: value
    )
    loguru = ModuleType("loguru")
    loguru.logger = SimpleNamespace(
        info=lambda *args, **kwargs: None,
        warning=lambda *args, **kwargs: None,
    )
    isaaclab_components = ModuleType("holomotion.src.env.isaaclab_components")
    for name in [
        "ActionsCfg",
        "VelTrack_CommandsCfg",
        "MoTrack_CommandsCfg",
        "EventsCfg",
        "MotionTrackingSceneCfg",
        "ObservationsCfg",
        "RewardsCfg",
        "TerminationsCfg",
        "CurriculumCfg",
    ]:
        setattr(isaaclab_components, name, _DummyConfigClass)
    for name in [
        "build_actions_config",
        "build_motion_tracking_commands_config",
        "build_velocity_commands_config",
        "build_domain_rand_config",
        "build_curriculum_config",
        "build_observations_config",
        "build_rewards_config",
        "build_scene_config",
        "build_terminations_config",
    ]:
        setattr(isaaclab_components, name, lambda *args, **kwargs: None)
    fake_observation_module = ModuleType(
        "holomotion.src.env.isaaclab_components.isaaclab_observation"
    )
    fake_observation_module.ObservationFunctions = object
    fake_utils_module = ModuleType(
        "holomotion.src.env.isaaclab_components.isaaclab_utils"
    )
    fake_utils_module.resolve_holo_config = lambda value: value
    for name, module in {
        "isaaclab": isaaclab_root,
        "isaaclab.actuators": isaaclab_actuators,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.envs": isaaclab_envs,
        "isaaclab.envs.mdp": isaaclab_envs_mdp,
        "isaaclab.envs.mdp.events": isaaclab_envs_mdp_events,
        "isaaclab.managers": isaaclab_managers,
        "isaaclab.sim": isaaclab_sim,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.io": isaaclab_utils_io,
        "isaaclab.utils.math": isaaclab_utils_math,
        "easydict": easydict,
        "omegaconf": omegaconf,
        "loguru": loguru,
        "holomotion.src.env.isaaclab_components": isaaclab_components,
        (
            "holomotion.src.env.isaaclab_components.isaaclab_observation"
        ): fake_observation_module,
        (
            "holomotion.src.env.isaaclab_components.isaaclab_utils"
        ): fake_utils_module,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)
    # Expose submodules as attributes of their parents, mirroring a real
    # package import.
    isaaclab_root.actuators = isaaclab_actuators
    isaaclab_root.assets = isaaclab_assets
    isaaclab_root.envs = isaaclab_envs
    isaaclab_root.managers = isaaclab_managers
    isaaclab_root.sim = isaaclab_sim
    isaaclab_root.utils = isaaclab_utils
    isaaclab_envs.mdp = isaaclab_envs_mdp
    isaaclab_utils.io = isaaclab_utils_io
    isaaclab_utils.math = isaaclab_utils_math
    module_name = "_test_motion_tracking"
    spec = importlib.util.spec_from_file_location(
        module_name, MOTION_TRACKING_PATH
    )
    # Validate the spec before using it; module_from_spec(None) would fail
    # with an unrelated AttributeError instead of these assertions.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing, but through monkeypatch so the module does
    # not leak into other tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def _make_env():
    """Build a single-env fake environment for the root-relative reward tests.

    The robot has two bodies, ``anchor`` (``anchor_bodylink_idx`` 0) and
    ``target``. The simulated robot state and every reference getter of the
    fake motion command report the same values:
    - bodies at x=10 and x=11, with identity orientations;
    - ``target`` moving with linear velocity (0, 1, 0);
    - every body with angular velocity (0, 0, 1).

    The env origin is at x=10. Reference getters accept a ``prefix`` keyword
    and return a fresh tensor on every call.
    """

    def _constant(data, dtype=torch.float32):
        # Getter factory; dtype=None lets torch infer its default float dtype.
        return lambda prefix="ref_": torch.tensor(data, dtype=dtype)

    def _identity(*shape):
        return lambda prefix="ref_": _identity_quat(*shape)

    root_pos = [[10.0, 0.0, 0.0]]
    root_lin_vel = [[0.0, 0.0, 0.0]]
    root_ang_vel = [[0.0, 0.0, 1.0]]
    body_pos = [[[10.0, 0.0, 0.0], [11.0, 0.0, 0.0]]]
    body_lin_vel = [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]
    body_ang_vel = [[[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]]

    robot = SimpleNamespace(
        body_names=["anchor", "target"],
        data=SimpleNamespace(
            body_pos_w=torch.tensor(body_pos, dtype=torch.float32),
            body_quat_w=_identity_quat(1, 2),
            body_lin_vel_w=torch.tensor(body_lin_vel, dtype=torch.float32),
            body_ang_vel_w=torch.tensor(body_ang_vel, dtype=torch.float32),
        ),
    )
    command = SimpleNamespace(
        robot=robot,
        anchor_bodylink_idx=0,
        get_ref_motion_root_global_pos_cur=_constant(root_pos),
        get_ref_motion_root_global_pos_immediate_next=_constant(root_pos),
        get_ref_motion_root_global_rot_quat_wxyz_cur=_identity(1),
        get_ref_motion_root_global_rot_quat_wxyz_immediate_next=_identity(1),
        get_ref_motion_root_global_lin_vel_cur=_constant(root_lin_vel),
        get_ref_motion_root_global_lin_vel_immediate_next=_constant(
            root_lin_vel
        ),
        get_ref_motion_root_global_ang_vel_cur=_constant(
            root_ang_vel, dtype=None
        ),
        get_ref_motion_root_global_ang_vel_immediate_next=_constant(
            root_ang_vel, dtype=None
        ),
        get_ref_motion_bodylink_global_pos_cur=_constant(body_pos),
        get_ref_motion_bodylink_global_pos_immediate_next=_constant(body_pos),
        get_ref_motion_bodylink_global_lin_vel_cur=_constant(body_lin_vel),
        get_ref_motion_bodylink_global_lin_vel_immediate_next=_constant(
            body_lin_vel
        ),
        get_ref_motion_bodylink_global_rot_wxyz_immediate_next=_identity(1, 2),
        get_ref_motion_bodylink_global_ang_vel_immediate_next=_constant(
            body_ang_vel
        ),
    )
    return SimpleNamespace(
        command_manager=SimpleNamespace(get_term=lambda name: command),
        scene=SimpleNamespace(
            env_origins=torch.tensor(root_pos, dtype=torch.float32)
        ),
    )
def _make_torque_rate_env(
applied_torque: torch.Tensor,
actuators: dict,
joint_vel: torch.Tensor | None = None,
joint_vel_limits: torch.Tensor | None = None,
):
class _Scene(dict):
pass
if joint_vel is None:
joint_vel = torch.zeros_like(applied_torque)
if joint_vel_limits is None:
joint_vel_limits = torch.ones_like(applied_torque)
asset = SimpleNamespace(
data=SimpleNamespace(
applied_torque=applied_torque.clone(),
joint_vel=joint_vel.clone(),
joint_vel_limits=joint_vel_limits.clone(),
),
actuators=actuators,
)
scene = _Scene(robot=asset)
return SimpleNamespace(
scene=scene,
num_envs=applied_torque.shape[0],
device=applied_torque.device,
episode_length_buf=torch.zeros(
applied_torque.shape[0],
dtype=torch.long,
device=applied_torque.device,
),
)
def _make_action_acc_env(action: torch.Tensor):
return SimpleNamespace(
action_manager=SimpleNamespace(action=action.clone()),
num_envs=action.shape[0],
device=action.device,
episode_length_buf=torch.zeros(
action.shape[0], dtype=torch.long, device=action.device
),
)
def test_root_rel_keybody_pos_reward_uses_true_root_frame(monkeypatch):
    """The root-relative key-body position reward is 1 for the matching fake state."""
    rewards = _load_rewards_module(monkeypatch)
    mdp = rewards.isaaclab_mdp
    # Robot root pinned at the world origin with identity orientation.
    mdp.root_pos_w = lambda _env: torch.zeros(1, 3, dtype=torch.float32)
    mdp.root_quat_w = lambda _env: _identity_quat(1)

    result = rewards.root_rel_keybodylink_pos_tracking_l2_exp(
        _make_env(), std=1.0, keybody_names=["target"]
    )

    assert torch.allclose(result, torch.ones(1))
def test_root_rel_keybody_pos_bydmmc_reward_uses_true_root_frame(
    monkeypatch,
):
    """The BeyondMimic-style variant also yields 1 for the matching fake state."""
    rewards = _load_rewards_module(monkeypatch)
    mdp = rewards.isaaclab_mdp
    # Robot root pinned at the world origin with identity orientation.
    mdp.root_pos_w = lambda _env: torch.zeros(1, 3, dtype=torch.float32)
    mdp.root_quat_w = lambda _env: _identity_quat(1)

    result = rewards.root_rel_keybodylink_pos_tracking_l2_exp_bydmmc_style(
        _make_env(), std=1.0, keybody_names=["target"]
    )

    assert torch.allclose(result, torch.ones(1))
def test_root_rel_keybody_lin_vel_reward_uses_true_root_frame(monkeypatch):
    """The root-relative key-body velocity reward is 1 for the matching fake state."""
    rewards = _load_rewards_module(monkeypatch)
    mdp = rewards.isaaclab_mdp
    # Robot root: at the origin, identity orientation, no linear velocity
    # and unit angular velocity about z (same as the reference root).
    mdp.root_pos_w = lambda _env: torch.zeros(1, 3, dtype=torch.float32)
    mdp.root_quat_w = lambda _env: _identity_quat(1)
    mdp.root_lin_vel_w = lambda _env: torch.zeros(1, 3, dtype=torch.float32)
    mdp.root_ang_vel_w = lambda _env: torch.tensor(
        [[0.0, 0.0, 1.0]], dtype=torch.float32
    )

    result = rewards.root_rel_keybodylink_lin_vel_tracking_l2_exp(
        _make_env(), std=1.0, keybody_names=["target"]
    )

    assert torch.allclose(result, torch.ones(1))
def test_root_pos_xy_tracking_uses_immediate_next_reference(monkeypatch):
    """``root_pos_xy_tracking_exp`` reads the next-frame reference only.

    The robot and the next-frame reference share x/y and differ only in z.
    The current-frame getter raises if it is called.
    """
    rewards = _load_rewards_module(monkeypatch)

    def _current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    def _next_reference(prefix="ref_"):
        return torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)

    robot = SimpleNamespace(
        data=SimpleNamespace(
            root_pos_w=torch.tensor([[1.0, 2.0, 0.0]], dtype=torch.float32)
        )
    )
    command = SimpleNamespace(
        robot=robot,
        get_ref_motion_root_global_pos_cur=_current_reference_forbidden,
        get_ref_motion_root_global_pos_immediate_next=_next_reference,
    )
    env = SimpleNamespace(
        command_manager=SimpleNamespace(get_term=lambda name: command)
    )

    assert torch.allclose(
        rewards.root_pos_xy_tracking_exp(env, std=1.0), torch.ones(1)
    )
def test_global_keybody_lin_vel_tracking_uses_immediate_next_reference(
    monkeypatch,
):
    """The global key-body velocity reward reads the next-frame reference only.

    The robot's body velocities equal the next-frame reference. The
    current-frame getter raises if it is called.
    """
    rewards = _load_rewards_module(monkeypatch)
    body_lin_vel = [[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]]]

    def _current_reference_forbidden(prefix="ref_"):
        raise AssertionError("current reference should not be used")

    robot = SimpleNamespace(
        body_names=["anchor", "target"],
        data=SimpleNamespace(
            body_lin_vel_w=torch.tensor(body_lin_vel, dtype=torch.float32)
        ),
    )
    command = SimpleNamespace(
        robot=robot,
        get_ref_motion_bodylink_global_lin_vel_cur=_current_reference_forbidden,
        get_ref_motion_bodylink_global_lin_vel_immediate_next=(
            lambda prefix="ref_": torch.tensor(
                body_lin_vel, dtype=torch.float32
            )
        ),
    )
    env = SimpleNamespace(
        command_manager=SimpleNamespace(get_term=lambda name: command)
    )

    result = rewards.global_keybodylink_lin_vel_tracking_l2_exp(
        env, std=1.0, keybody_names=["target"]
    )
    assert torch.allclose(result, torch.ones(1))
def test_normed_torque_rate_matches_selected_joint_math(monkeypatch):
    """The torque-rate penalty sums (delta_tau / effort_limit)^2 over selected joints."""
    rewards = _load_rewards_module(monkeypatch)
    effort_limits = torch.tensor(
        [[10.0, 20.0, 40.0], [10.0, 20.0, 40.0]], dtype=torch.float32
    )
    env = _make_torque_rate_env(
        applied_torque=torch.zeros(2, 3, dtype=torch.float32),
        actuators={
            "all_joints": SimpleNamespace(
                joint_indices=slice(None), effort_limit=effort_limits
            )
        },
    )
    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)
    # Only joints 0 and 2 are selected; joint 1 must not contribute.
    selected = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 2], dtype=torch.long)
    )

    # The first call is expected to return zeros.
    assert torch.allclose(term(env, asset_cfg=selected), torch.zeros(2))

    env.episode_length_buf[:] = 1
    env.scene["robot"].data.applied_torque = torch.tensor(
        [[1.0, 9.0, 4.0], [2.0, 7.0, 8.0]], dtype=torch.float32
    )
    expected = torch.tensor(
        [
            (1.0 / 10.0) ** 2 + (4.0 / 40.0) ** 2,
            (2.0 / 10.0) ** 2 + (8.0 / 40.0) ** 2,
        ],
        dtype=torch.float32,
    )
    assert torch.allclose(term(env, asset_cfg=selected), expected)
def test_normed_torque_rate_assembles_limits_across_actuator_groups(
    monkeypatch,
):
    """Per-joint effort limits are gathered from several actuator groups."""
    rewards = _load_rewards_module(monkeypatch)
    # Joints 0 and 2 share one group (limits 10 and 20); joint 1 has limit 5.
    actuator_groups = {
        "implicit_group": SimpleNamespace(
            joint_indices=torch.tensor([0, 2], dtype=torch.long),
            effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),
        ),
        "unitree_group": SimpleNamespace(
            joint_indices=torch.tensor([1], dtype=torch.long),
            effort_limit=torch.tensor([[5.0]], dtype=torch.float32),
        ),
    }
    env = _make_torque_rate_env(
        applied_torque=torch.zeros(1, 3, dtype=torch.float32),
        actuators=actuator_groups,
    )
    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)
    every_joint = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long)
    )

    term(env, asset_cfg=every_joint)  # first call; its value is not checked

    env.episode_length_buf[:] = 1
    env.scene["robot"].data.applied_torque = torch.tensor(
        [[1.0, 1.0, 2.0]], dtype=torch.float32
    )
    expected = torch.tensor(
        [(1.0 / 10.0) ** 2 + (1.0 / 5.0) ** 2 + (2.0 / 20.0) ** 2],
        dtype=torch.float32,
    )
    assert torch.allclose(term(env, asset_cfg=every_joint), expected)
def test_normed_torque_rate_resets_first_step_history(monkeypatch):
    """After ``reset`` only the reset env reports a zero torque rate."""
    rewards = _load_rewards_module(monkeypatch)
    env = _make_torque_rate_env(
        applied_torque=torch.tensor(
            [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32
        ),
        actuators={
            "all_joints": SimpleNamespace(
                joint_indices=slice(None),
                effort_limit=torch.full((2, 2), 10.0, dtype=torch.float32),
            )
        },
    )
    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)
    joints = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long)
    )
    robot_data = env.scene["robot"].data

    assert torch.allclose(term(env, asset_cfg=joints), torch.zeros(2))

    env.episode_length_buf[:] = 1
    robot_data.applied_torque = torch.tensor(
        [[2.0, 4.0], [5.0, 8.0]], dtype=torch.float32
    )
    assert torch.all(term(env, asset_cfg=joints) > 0.0)

    # Reset env 0 only; episode_length_buf is deliberately left at 1.
    term.reset(env_ids=[0])
    robot_data.applied_torque = torch.tensor(
        [[7.0, 9.0], [6.0, 10.0]], dtype=torch.float32
    )
    post_reset = term(env, asset_cfg=joints)
    assert torch.isclose(post_reset[0], torch.tensor(0.0))
    assert post_reset[1] > 0.0
def test_normed_torque_rate_reuses_cached_normalization(monkeypatch):
    """Effort limits read on the first call are kept even if the actuator changes."""
    rewards = _load_rewards_module(monkeypatch)
    actuator = SimpleNamespace(
        joint_indices=slice(None),
        effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),
    )
    env = _make_torque_rate_env(
        applied_torque=torch.zeros(1, 2, dtype=torch.float32),
        actuators={"all_joints": actuator},
    )
    term = rewards.normed_torque_rate(_DummyConfig(params={}), env)
    joints = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long)
    )

    term(env, asset_cfg=joints)  # first call; its value is not checked
    # Swap in very different limits; the expected value still uses 10 and 20.
    actuator.effort_limit = torch.tensor(
        [[1000.0, 1000.0]], dtype=torch.float32
    )

    env.episode_length_buf[:] = 1
    env.scene["robot"].data.applied_torque = torch.tensor(
        [[2.0, 4.0]], dtype=torch.float32
    )
    expected = torch.tensor(
        [(2.0 / 10.0) ** 2 + (4.0 / 20.0) ** 2], dtype=torch.float32
    )
    assert torch.allclose(term(env, asset_cfg=joints), expected)
def test_normed_positive_work_matches_selected_joint_math(monkeypatch):
    """Positive work is the sum of (tau/limit)*(qd/qd_limit) over selected joints.

    Only joints where torque and velocity share a sign contribute in the
    expected values.
    """
    rewards = _load_rewards_module(monkeypatch)
    torques = torch.tensor(
        [[2.0, 5.0, 8.0], [3.0, -4.0, 6.0]], dtype=torch.float32
    )
    velocities = torch.tensor(
        [[1.0, -5.0, 2.0], [2.0, 3.0, -2.0]], dtype=torch.float32
    )
    velocity_limits = torch.tensor(
        [[4.0, 10.0, 8.0], [4.0, 10.0, 8.0]], dtype=torch.float32
    )
    effort_limits = torch.tensor(
        [[4.0, 10.0, 16.0], [4.0, 10.0, 16.0]], dtype=torch.float32
    )
    env = _make_torque_rate_env(
        applied_torque=torques,
        joint_vel=velocities,
        joint_vel_limits=velocity_limits,
        actuators={
            "all_joints": SimpleNamespace(
                joint_indices=slice(None), effort_limit=effort_limits
            )
        },
    )
    term = rewards.normed_positive_work(_DummyConfig(params={}), env)
    selected = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 2], dtype=torch.long)
    )

    # Env 1, joint 2 has torque 6 against velocity -2 and is excluded.
    expected = torch.tensor(
        [
            (2.0 / 4.0) * (1.0 / 4.0) + (8.0 / 16.0) * (2.0 / 8.0),
            (3.0 / 4.0) * (2.0 / 4.0),
        ],
        dtype=torch.float32,
    )
    assert torch.allclose(term(env, asset_cfg=selected), expected)
def test_normed_positive_work_assembles_effort_limits_across_actuators(
    monkeypatch,
):
    """Positive work gathers per-joint effort limits from several actuator groups."""
    rewards = _load_rewards_module(monkeypatch)
    actuator_groups = {
        "implicit_group": SimpleNamespace(
            joint_indices=torch.tensor([0, 2], dtype=torch.long),
            effort_limit=torch.tensor([[4.0, 20.0]], dtype=torch.float32),
        ),
        "unitree_group": SimpleNamespace(
            joint_indices=torch.tensor([1], dtype=torch.long),
            effort_limit=torch.tensor([[6.0]], dtype=torch.float32),
        ),
    }
    env = _make_torque_rate_env(
        applied_torque=torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float32),
        joint_vel=torch.tensor([[5.0, 2.0, -1.0]], dtype=torch.float32),
        joint_vel_limits=torch.tensor([[10.0, 4.0, 8.0]], dtype=torch.float32),
        actuators=actuator_groups,
    )
    term = rewards.normed_positive_work(_DummyConfig(params={}), env)
    every_joint = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1, 2], dtype=torch.long)
    )

    # Joint 2 (torque 4, velocity -1) does negative work and is excluded.
    expected = torch.tensor(
        [(2.0 / 4.0) * (5.0 / 10.0) + (3.0 / 6.0) * (2.0 / 4.0)],
        dtype=torch.float32,
    )
    assert torch.allclose(term(env, asset_cfg=every_joint), expected)
def test_normed_positive_work_reuses_cached_effort_limits(monkeypatch):
    """Effort limits read on the first call are reused on later calls."""
    rewards = _load_rewards_module(monkeypatch)
    actuator = SimpleNamespace(
        joint_indices=slice(None),
        effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),
    )
    env = _make_torque_rate_env(
        applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32),
        joint_vel=torch.tensor([[5.0, 10.0]], dtype=torch.float32),
        joint_vel_limits=torch.tensor([[10.0, 20.0]], dtype=torch.float32),
        actuators={"all_joints": actuator},
    )
    term = rewards.normed_positive_work(_DummyConfig(params={}), env)
    joints = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long)
    )
    # Both calls must produce this value computed with the original limits.
    expected = torch.tensor(
        [(2.0 / 10.0) * (5.0 / 10.0) + (4.0 / 20.0) * (10.0 / 20.0)],
        dtype=torch.float32,
    )

    assert torch.allclose(term(env, asset_cfg=joints), expected)

    actuator.effort_limit = torch.tensor(
        [[1000.0, 1000.0]], dtype=torch.float32
    )
    assert torch.allclose(term(env, asset_cfg=joints), expected)
def test_normed_positive_work_requires_positive_finite_velocity_limits(
    monkeypatch,
):
    """Zero or infinite joint velocity limits make the term raise ``ValueError``."""
    rewards = _load_rewards_module(monkeypatch)
    env = _make_torque_rate_env(
        applied_torque=torch.tensor([[2.0, 4.0]], dtype=torch.float32),
        joint_vel=torch.tensor([[1.0, 1.0]], dtype=torch.float32),
        joint_vel_limits=torch.tensor([[0.0, torch.inf]], dtype=torch.float32),
        actuators={
            "all_joints": SimpleNamespace(
                joint_indices=slice(None),
                effort_limit=torch.tensor([[10.0, 20.0]], dtype=torch.float32),
            )
        },
    )
    term = rewards.normed_positive_work(_DummyConfig(params={}), env)
    joints = SimpleNamespace(
        name="robot", joint_ids=torch.tensor([0, 1], dtype=torch.long)
    )

    caught = None
    try:
        term(env, asset_cfg=joints)
    except ValueError as exc:
        caught = exc

    if caught is None:
        raise AssertionError(
            "expected normed_positive_work to reject invalid limits"
        )
    assert (
        "normed_positive_work requires finite, strictly positive"
        in str(caught)
    )
def test_action_acc_matches_second_order_action_change(monkeypatch):
    """``action_acc`` matches sum((a_t - 2*a_{t-1} + a_{t-2})**2) once history exists.

    The first two steps are expected to return zeros.
    """
    rewards = _load_rewards_module(monkeypatch)
    env = _make_action_acc_env(torch.zeros(2, 2, dtype=torch.float32))
    term = rewards.action_acc(_DummyConfig(params={}), env)

    # (episode step, action applied before the call, expected reward); the
    # first entry keeps the env's initial step and action untouched.
    schedule = (
        (None, None, torch.zeros(2)),
        (1, [[1.0, 2.0], [2.0, 1.0]], torch.zeros(2)),
        (
            2,
            [[3.0, 1.0], [5.0, 1.0]],
            torch.tensor([10.0, 2.0], dtype=torch.float32),
        ),
    )
    for step, action, expected in schedule:
        if step is not None:
            env.episode_length_buf[:] = step
            env.action_manager.action = torch.tensor(
                action, dtype=torch.float32
            )
        assert torch.allclose(term(env), expected)
def test_action_acc_reset_clears_history(monkeypatch):
    """After ``reset`` the action-acceleration penalty starts again from zero."""
    rewards = _load_rewards_module(monkeypatch)
    env = _make_action_acc_env(torch.zeros(1, 2, dtype=torch.float32))
    term = rewards.action_acc(_DummyConfig(params={}), env)

    assert torch.allclose(term(env), torch.zeros(1))
    for step, action, expected in (
        (1, [1.0, 1.0], [0.0]),
        (2, [2.0, 0.0], [4.0]),
    ):
        env.episode_length_buf[:] = step
        env.action_manager.action = torch.tensor(
            [action], dtype=torch.float32
        )
        assert torch.allclose(term(env), torch.tensor(expected))

    term.reset(env_ids=[0])
    # episode_length_buf stays at 2: only the reset can clear the history.
    env.action_manager.action = torch.tensor([[7.0, 7.0]], dtype=torch.float32)
    assert torch.allclose(term(env), torch.zeros(1))
def test_build_rewards_config_exposes_action_acc_term(monkeypatch):
    """``build_rewards_config`` maps an ``action_acc`` entry to the ``action_acc`` term."""
    rewards = _load_rewards_module(monkeypatch)
    term_spec = {"weight": -2.5, "params": {}}

    action_acc_cfg = rewards.build_rewards_config(
        {"action_acc": term_spec}
    ).action_acc

    assert action_acc_cfg.func is rewards.action_acc
    assert action_acc_cfg.weight == -2.5
    assert action_acc_cfg.params == {}
def test_motion_tracking_logs_normed_torque_rate_metric(monkeypatch):
    """_update_robot_metrics logs Metrics/Robot/Normed_Torque_Rate.

    The first call logs 0. After the applied torque changes, the expected
    value is the env-mean of sum_j (torque_j / effort_limit_j) ** 2.
    """
    motion_tracking = _load_motion_tracking_module(monkeypatch)
    env_cls = motion_tracking.MotionTrackingEnv
    # Bypass __init__ (no simulator) and seed the cached state explicitly.
    env = env_cls.__new__(env_cls)
    env.metrics = {}
    for cached_attr in (
        "_robot_prev_joint_vel",
        "_robot_prev_applied_torque",
        "_robot_torque_rate_inv_effort_limit",
        "_robot_torque_rate_needs_reseed",
    ):
        setattr(env, cached_attr, None)

    def zeros():
        return torch.zeros(2, 2, dtype=torch.float32)

    limit_row = [10.0, 20.0]
    all_joints = SimpleNamespace(
        joint_indices=slice(None),
        effort_limit=torch.tensor(
            [limit_row, limit_row], dtype=torch.float32
        ),
    )
    robot = SimpleNamespace(
        data=SimpleNamespace(joint_vel=zeros(), applied_torque=zeros()),
        actuators={"all_joints": all_joints},
    )
    env._env = SimpleNamespace(
        step_dt=0.5,
        action_manager=SimpleNamespace(action=zeros(), prev_action=zeros()),
        scene={"robot": robot},
        episode_length_buf=torch.zeros(2, dtype=torch.long),
        num_envs=2,
        device=torch.device("cpu"),
    )
    metric_key = "Metrics/Robot/Normed_Torque_Rate"
    infos = {"log": {}}

    env._update_robot_metrics(infos)
    assert torch.isclose(infos["log"][metric_key], torch.tensor(0.0))

    torque_rows = [[1.0, 4.0], [2.0, 8.0]]
    env._env.episode_length_buf[:] = 1
    robot.data.applied_torque = torch.tensor(torque_rows, dtype=torch.float32)
    env._update_robot_metrics(infos)

    per_env = [
        sum((torque / limit) ** 2 for torque, limit in zip(row, limit_row))
        for row in torque_rows
    ]
    expected = torch.tensor(per_env, dtype=torch.float32).mean()
    assert torch.allclose(infos["log"][metric_key], expected)
================================================
FILE: tests/test_unitree_actuators.py
================================================
import importlib.util
import json
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
import torch
# Source files under test, resolved relative to the repository root
# (the parent of this tests/ directory).
_ISAACLAB_COMPONENTS_DIR = Path(__file__).resolve().parents[1].joinpath(
    "holomotion", "src", "env", "isaaclab_components"
)
ACTUATOR_MODULE_PATH = _ISAACLAB_COMPONENTS_DIR / "unitree_actuators.py"
SCENE_MODULE_PATH = _ISAACLAB_COMPONENTS_DIR / "isaaclab_scene.py"
class _DummyArticulationActions:
    """Plain stand-in for isaaclab's ``ArticulationActions`` container.

    Stores the four optional command fields verbatim as attributes.
    """

    def __init__(
        self,
        joint_positions=None,
        joint_velocities=None,
        joint_efforts=None,
        joint_indices=None,
    ):
        (
            self.joint_positions,
            self.joint_velocities,
            self.joint_efforts,
            self.joint_indices,
        ) = (joint_positions, joint_velocities, joint_efforts, joint_indices)
class _DummyDelayedPDActuatorCfg:
    """Stand-in for ``DelayedPDActuatorCfg`` that accepts any keyword field."""

    # Class-level defaults for the action-delay window (in steps).
    min_delay = 0
    max_delay = 0

    def __init__(self, **kwargs):
        for field_name in kwargs:
            setattr(self, field_name, kwargs[field_name])
class _DummyDelayedPDActuator:
    """Recording stand-in for isaaclab's ``DelayedPDActuator``.

    No PD control is performed. ``compute`` records the joint efforts and the
    joint position targets it receives, so tests can inspect what a subclass
    forwarded to ``super().compute``. It echoes supplied efforts back as
    ``computed_effort`` / ``applied_effort``.
    """

    def __init__(self, cfg, *args, **kwargs):
        self.cfg = cfg
        self._num_envs = kwargs.get("num_envs", 4)
        self._device = kwargs.get("device", "cpu")
        self.num_joints = len(
            kwargs.get("joint_names", ["joint_a", "joint_b"])
        )
        self.computed_effort = torch.zeros(
            self._num_envs, self.num_joints, device=self._device
        )
        self.applied_effort = torch.zeros_like(self.computed_effort)
        effort_limit = kwargs.get("effort_limit", 100.0)
        if isinstance(effort_limit, torch.Tensor):
            self.effort_limit = effort_limit.clone().to(device=self._device)
        else:
            # Scalar limit broadcast to a (num_envs, num_joints) tensor.
            self.effort_limit = torch.full_like(
                self.computed_effort, float(effort_limit)
            )
        # One entry per compute() call (None when the field was None).
        self.super_compute_inputs = []
        self.super_compute_joint_positions = []
        # env_ids passed to each reset() call.
        self.reset_calls = []

    def _parse_joint_parameter(self, value, default):
        """Expand a per-joint parameter to a (num_envs, num_joints) tensor.

        Accepts a tensor (cloned), a dict (its values in insertion order,
        repeated for every env) or a scalar. ``None`` falls back to
        ``default``.

        Raises:
            TypeError: for any other value type.
        """
        if value is None:
            value = default
        if isinstance(value, torch.Tensor):
            return value.clone().to(device=self._device)
        if isinstance(value, dict):
            # NOTE(review): dict order is assumed to match joint order.
            values = list(value.values())
            tensor = torch.tensor(
                values, dtype=torch.float32, device=self._device
            )
            return tensor.unsqueeze(0).repeat(self._num_envs, 1)
        if isinstance(value, (float, int)):
            return torch.full_like(self.computed_effort, float(value))
        raise TypeError(f"Unsupported parameter type: {type(value)}")

    def reset(self, env_ids):
        """Record the reset request; no other state is touched."""
        self.reset_calls.append(env_ids)

    def compute(self, control_action, joint_pos, joint_vel):
        """Record the incoming action and return it unchanged.

        ``joint_pos`` and ``joint_vel`` are accepted for signature
        compatibility and ignored.
        """
        efforts = control_action.joint_efforts
        positions = control_action.joint_positions
        self.super_compute_inputs.append(
            None if efforts is None else efforts.clone()
        )
        self.super_compute_joint_positions.append(
            None if positions is None else positions.clone()
        )
        # Missing efforts leave the previous effort buffers in place
        # instead of raising AttributeError on None.clone().
        if efforts is not None:
            self.computed_effort = efforts.clone()
            self.applied_effort = efforts.clone()
        return control_action
def _configclass(cls):
    """Lightweight replacement for isaaclab's ``configclass`` decorator.

    Installs an ``__init__`` that first copies every annotated field with a
    class-level default onto the instance, then applies keyword overrides.
    Overrides may include names that are not annotated.
    """
    field_defaults = {}
    for field_name in getattr(cls, "__annotations__", {}):
        if hasattr(cls, field_name):
            field_defaults[field_name] = getattr(cls, field_name)

    def __init__(self, **kwargs):
        for source in (field_defaults, kwargs):
            for key, value in source.items():
                setattr(self, key, value)

    cls.__init__ = __init__
    return cls
def _load_unitree_actuator_module(monkeypatch):
    """Execute ``unitree_actuators.py`` in isolation against stub isaaclab.

    Stub ``isaaclab`` packages (backed by the ``_Dummy*`` helpers above) and
    the module under test are registered in ``sys.modules`` through
    ``monkeypatch``, so every entry is removed when the test finishes.

    Returns:
        The freshly executed actuator module.
    """
    isaaclab_root = ModuleType("isaaclab")
    isaaclab_actuators = ModuleType("isaaclab.actuators")
    isaaclab_actuators.DelayedPDActuator = _DummyDelayedPDActuator
    isaaclab_actuators.DelayedPDActuatorCfg = _DummyDelayedPDActuatorCfg
    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = _configclass
    isaaclab_utils_types = ModuleType("isaaclab.utils.types")
    isaaclab_utils_types.ArticulationActions = _DummyArticulationActions
    for name, module in {
        "isaaclab": isaaclab_root,
        "isaaclab.actuators": isaaclab_actuators,
        "isaaclab.utils": isaaclab_utils,
        "isaaclab.utils.types": isaaclab_utils_types,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)
    # Also expose submodules as package attributes (``isaaclab.utils`` etc.).
    isaaclab_root.actuators = isaaclab_actuators
    isaaclab_root.utils = isaaclab_utils
    isaaclab_utils.types = isaaclab_utils_types

    module_name = "_test_unitree_actuators"
    spec = importlib.util.spec_from_file_location(
        module_name, ACTUATOR_MODULE_PATH
    )
    # Validate the spec before using it; module_from_spec(None) would fail
    # with an unhelpful AttributeError instead of a clear assertion.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing, in case import-time code looks the module
    # up by name. Going through monkeypatch avoids leaking into later tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def _make_erfi_actuator(module, *, cfg_kwargs=None, num_envs=4, num_joints=3):
    """Build a ``UnitreeErfiActuator`` with test-friendly defaults.

    Args:
        module: Loaded actuator module that provides ``UnitreeErfiActuatorCfg``
            and ``UnitreeErfiActuator``.
        cfg_kwargs: Overrides merged over the default ERFI/EMA config fields.
        num_envs: Number of parallel environments.
        num_joints: Number of joints, named ``joint_0`` .. ``joint_{n-1}``.

    Returns:
        The constructed actuator. Gains are zero; effort and velocity limits
        are 100.
    """
    cfg_fields = dict(
        Y1=100.0,
        Y2=120.0,
        erfi_enabled=True,
        ema_filter_enabled=False,
        ema_filter_alpha=1.0,
        ema_filter_debug_dump_path=None,
        ema_filter_debug_stop_after_dump=False,
        rfi_probability=0.5,
        rfi_lim=0.1,
        randomize_rfi_lim=True,
        rfi_lim_range=(0.5, 1.5),
        rao_lim=0.1,
    )
    if cfg_kwargs is not None:
        cfg_fields.update(cfg_kwargs)

    return module.UnitreeErfiActuator(
        module.UnitreeErfiActuatorCfg(**cfg_fields),
        joint_names=[f"joint_{j}" for j in range(num_joints)],
        joint_ids=torch.arange(num_joints),
        num_envs=num_envs,
        device="cpu",
        stiffness=0.0,
        damping=0.0,
        armature=0.0,
        friction=0.0,
        dynamic_friction=0.0,
        viscous_friction=0.0,
        effort_limit=100.0,
        velocity_limit=100.0,
    )
def _make_action(actuator):
    """Return an all-zero action shaped like ``actuator.computed_effort``."""
    def zeros():
        return torch.zeros_like(actuator.computed_effort)

    return _DummyArticulationActions(
        joint_positions=zeros(),
        joint_velocities=zeros(),
        joint_efforts=zeros(),
    )
def test_unitree_erfi_reset_samples_all_rfi(monkeypatch):
    """With rfi_probability=1 every env resets into RFI mode with zero RAO scale."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(module, cfg_kwargs={"rfi_probability": 1.0})

    actuator.reset(torch.arange(4, dtype=torch.long))

    assert actuator._mode_is_rfi.all()
    rao_scale = actuator._rao_scale
    assert torch.allclose(rao_scale, torch.zeros_like(rao_scale))
def test_unitree_erfi_reset_samples_all_rao(monkeypatch):
    """With rfi_probability=0 no env is in RFI mode and some RAO scale is set."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(module, cfg_kwargs={"rfi_probability": 0.0})

    actuator.reset(torch.arange(4, dtype=torch.long))

    assert not actuator._mode_is_rfi.any()
    assert (actuator._rao_scale != 0.0).any()
def test_unitree_erfi_rfi_without_randomized_limit_uses_effort_limit_ratio(
    monkeypatch,
):
    """Without limit randomization, injected RFI stays within rfi_lim * 100."""
    module = _load_unitree_actuator_module(monkeypatch)
    overrides = {
        "rfi_probability": 1.0,
        "randomize_rfi_lim": False,
        "rfi_lim": 0.1,
    }
    actuator = _make_erfi_actuator(
        module, cfg_kwargs=overrides, num_envs=2, num_joints=2
    )
    actuator.reset(torch.arange(2, dtype=torch.long))

    torch.manual_seed(0)
    template = actuator.computed_effort
    actuator.compute(
        _make_action(actuator),
        joint_pos=torch.zeros_like(template),
        joint_vel=torch.zeros_like(template),
    )

    injected = actuator.super_compute_inputs[-1]
    assert (injected.abs() <= 10.0 + 1.0e-6).all()
def test_unitree_erfi_reset_randomizes_rfi_scale_within_range(monkeypatch):
    """The sampled RFI limit scale lies inside rfi_lim_range."""
    low, high = 0.5, 1.5
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={"rfi_probability": 1.0, "rfi_lim_range": (low, high)},
        num_envs=2,
        num_joints=2,
    )

    actuator.reset(torch.arange(2, dtype=torch.long))

    scale = actuator._rfi_lim_scale
    assert ((scale >= low) & (scale <= high)).all()
def test_unitree_erfi_rao_bias_stays_constant_between_resets(monkeypatch):
    """In RAO mode the injected offset does not change between compute calls."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={"rfi_probability": 0.0, "rao_lim": 0.1},
        num_envs=2,
        num_joints=2,
    )
    actuator.reset(torch.arange(2, dtype=torch.long))
    action = _make_action(actuator)

    recorded = []
    for _ in range(2):
        actuator.compute(
            action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )
        recorded.append(actuator.super_compute_inputs[-1].clone())

    first, second = recorded
    assert torch.allclose(first, second)
def test_unitree_erfi_rfi_changes_each_compute(monkeypatch):
    """In RFI mode, different RNG states produce different injected efforts."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={"rfi_probability": 1.0, "randomize_rfi_lim": False},
        num_envs=2,
        num_joints=2,
    )
    actuator.reset(torch.arange(2, dtype=torch.long))
    action = _make_action(actuator)

    injected = []
    for seed in (0, 1):
        torch.manual_seed(seed)
        actuator.compute(
            action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )
        injected.append(actuator.super_compute_inputs[-1].clone())

    assert not torch.allclose(*injected)
def test_unitree_erfi_disabled_matches_plain_unitree(monkeypatch):
    """With erfi_enabled=False the zero efforts are forwarded unchanged."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module, cfg_kwargs={"erfi_enabled": False}, num_envs=2, num_joints=2
    )
    action = _make_action(actuator)
    actuator.reset(torch.arange(2, dtype=torch.long))

    template = actuator.computed_effort
    actuator.compute(
        action,
        joint_pos=torch.zeros_like(template),
        joint_vel=torch.zeros_like(template),
    )

    forwarded = actuator.super_compute_inputs[-1]
    assert torch.allclose(forwarded, torch.zeros_like(forwarded))
def test_unitree_erfi_ema_filters_joint_positions(monkeypatch):
    """The first target passes through; later targets are EMA-blended.

    Expected second target: alpha * new + (1 - alpha) * previous.
    """
    alpha = 0.25
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={
            "erfi_enabled": False,
            "ema_filter_enabled": True,
            "ema_filter_alpha": alpha,
        },
        num_envs=2,
        num_joints=2,
    )

    def positions_action(rows):
        template = actuator.computed_effort
        return _DummyArticulationActions(
            joint_positions=torch.tensor(rows, dtype=torch.float32),
            joint_velocities=torch.zeros_like(template),
            joint_efforts=torch.zeros_like(template),
        )

    first_action = positions_action([[1.0, -1.0], [0.5, -0.5]])
    second_action = positions_action([[3.0, 1.0], [1.5, 0.5]])
    for step_action in (first_action, second_action):
        actuator.compute(
            step_action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )

    forwarded = actuator.super_compute_joint_positions
    assert torch.allclose(forwarded[0], first_action.joint_positions)
    expected_second = (
        alpha * second_action.joint_positions
        + (1.0 - alpha) * first_action.joint_positions
    )
    assert torch.allclose(forwarded[1], expected_second)
def test_unitree_erfi_ema_reset_clears_only_selected_envs(monkeypatch):
    """reset() clears EMA state only for the listed envs."""
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={
            "erfi_enabled": False,
            "ema_filter_enabled": True,
            "ema_filter_alpha": 0.5,
        },
        num_envs=2,
        num_joints=1,
    )

    def positions_action(positions):
        template = actuator.computed_effort
        return _DummyArticulationActions(
            joint_positions=positions,
            joint_velocities=torch.zeros_like(template),
            joint_efforts=torch.zeros_like(template),
        )

    def run(step_action):
        actuator.compute(
            step_action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )

    one_action = positions_action(
        torch.tensor([[1.0], [1.0]], dtype=torch.float32)
    )
    two_action = positions_action(
        torch.tensor([[2.0], [2.0]], dtype=torch.float32)
    )
    zero_action = positions_action(torch.zeros_like(actuator.computed_effort))

    run(one_action)
    run(two_action)
    actuator.reset(torch.tensor([1], dtype=torch.long))
    run(zero_action)

    history = actuator.super_compute_joint_positions
    assert torch.allclose(
        history[1], torch.tensor([[1.5], [1.5]], dtype=torch.float32)
    )
    # env 0 keeps blending (0.5 * 0 + 0.5 * 1.5); env 1 passes 0 through.
    assert torch.allclose(
        history[2], torch.tensor([[0.75], [0.0]], dtype=torch.float32)
    )
def test_unitree_erfi_ema_debug_dump_records_formula(monkeypatch, tmp_path):
    """The EMA debug dump records env 0's raw, previous and filtered targets."""
    alpha = 0.25
    dump_path = tmp_path / "ema_verify.json"
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={
            "erfi_enabled": False,
            "ema_filter_enabled": True,
            "ema_filter_alpha": alpha,
            "ema_filter_debug_dump_path": str(dump_path),
        },
        num_envs=2,
        num_joints=2,
    )

    def positions_action(rows):
        template = actuator.computed_effort
        return _DummyArticulationActions(
            joint_positions=torch.tensor(rows, dtype=torch.float32),
            joint_velocities=torch.zeros_like(template),
            joint_efforts=torch.zeros_like(template),
        )

    first_action = positions_action([[1.0, -1.0], [0.5, -0.5]])
    second_action = positions_action([[3.0, 1.0], [1.5, 0.5]])
    for step_action in (first_action, second_action):
        actuator.compute(
            step_action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )

    assert dump_path.is_file()
    payload = json.loads(dump_path.read_text())
    expected_env0 = (
        alpha * second_action.joint_positions[0]
        + (1.0 - alpha) * first_action.joint_positions[0]
    ).tolist()

    assert payload["alpha"] == 0.25
    assert payload["matched"] is True
    assert payload["env_index"] == 0
    assert payload["raw_joint_positions"] == [3.0, 1.0]
    assert payload["previous_filtered_joint_positions"] == [1.0, -1.0]
    for key in (
        "expected_filtered_joint_positions",
        "actual_filtered_joint_positions",
    ):
        assert payload[key] == pytest.approx(expected_env0)
def test_unitree_erfi_ema_debug_stop_after_dump(monkeypatch, tmp_path):
    """With stop_after_dump, the compute call that writes the dump raises."""
    dump_path = tmp_path / "ema_verify.json"
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={
            "erfi_enabled": False,
            "ema_filter_enabled": True,
            "ema_filter_alpha": 0.5,
            "ema_filter_debug_dump_path": str(dump_path),
            "ema_filter_debug_stop_after_dump": True,
        },
        num_envs=1,
        num_joints=1,
    )

    def positions_action(value):
        template = actuator.computed_effort
        return _DummyArticulationActions(
            joint_positions=torch.tensor([[value]], dtype=torch.float32),
            joint_velocities=torch.zeros_like(template),
            joint_efforts=torch.zeros_like(template),
        )

    def run(step_action):
        actuator.compute(
            step_action,
            joint_pos=torch.zeros_like(actuator.computed_effort),
            joint_vel=torch.zeros_like(actuator.computed_effort),
        )

    first_action = positions_action(1.0)
    second_action = positions_action(3.0)
    run(first_action)
    with pytest.raises(RuntimeError, match="EMA verification dump written"):
        run(second_action)
    assert dump_path.is_file()
def test_unitree_erfi_ema_debug_dump_records_skip_reason(
    monkeypatch, tmp_path
):
    """Without joint position targets the dump still records why EMA was skipped."""
    dump_path = tmp_path / "ema_verify_skip.json"
    module = _load_unitree_actuator_module(monkeypatch)
    actuator = _make_erfi_actuator(
        module,
        cfg_kwargs={
            "erfi_enabled": False,
            "ema_filter_enabled": True,
            "ema_filter_debug_dump_path": str(dump_path),
            "ema_filter_debug_stop_after_dump": True,
        },
        num_envs=1,
        num_joints=1,
    )
    template = actuator.computed_effort
    efforts_only = _DummyArticulationActions(
        joint_positions=None,
        joint_velocities=torch.zeros_like(template),
        joint_efforts=torch.zeros_like(template),
    )

    with pytest.raises(RuntimeError, match="EMA verification dump written"):
        actuator.compute(
            efforts_only,
            joint_pos=torch.zeros_like(template),
            joint_vel=torch.zeros_like(template),
        )

    payload = json.loads(dump_path.read_text())
    assert payload["applied"] is False
    assert payload["reason"] == "joint_positions_none"
def _load_scene_module(monkeypatch):
    """Execute ``isaaclab_scene.py`` in isolation against stub dependencies.

    The actuator module is loaded first, against its own stubs. Its
    ``UnitreeActuator*`` / ``UnitreeErfiActuator*`` classes are re-exported
    through a fake ``holomotion...unitree_actuators`` module, so the scene
    builder sees the real actuator classes. ``isaaclab``, ``loguru`` and the
    terrain helper are replaced by lightweight fakes. Every ``sys.modules``
    edit goes through ``monkeypatch`` and is undone after the test.

    Returns:
        The freshly executed scene module.
    """
    actuator_module = _load_unitree_actuator_module(monkeypatch)

    def _cfg_factory():
        # Fresh callable standing in for a config class: kwargs -> namespace.
        # Each stub gets its own function, so attributes attached to one
        # (e.g. ArticulationCfg.InitialStateCfg) do not leak onto others.
        def _build(**kwargs):
            return SimpleNamespace(**kwargs)

        return _build

    isaaclab_root = ModuleType("isaaclab")

    isaaclab_sim = ModuleType("isaaclab.sim")
    isaaclab_sim.UrdfFileCfg = _cfg_factory()
    isaaclab_sim.RigidBodyPropertiesCfg = _cfg_factory()
    isaaclab_sim.ArticulationRootPropertiesCfg = _cfg_factory()
    isaaclab_sim.UrdfConverterCfg = SimpleNamespace(
        JointDriveCfg=SimpleNamespace(PDGainsCfg=_cfg_factory())
    )

    isaaclab_actuators = ModuleType("isaaclab.actuators")
    isaaclab_actuators.ImplicitActuatorCfg = _cfg_factory()

    isaaclab_assets = ModuleType("isaaclab.assets")
    # ArticulationCfg must be callable AND expose ``InitialStateCfg``. Two
    # separate assignments would let the second discard the nested stub.
    articulation_cfg = _cfg_factory()
    articulation_cfg.InitialStateCfg = _cfg_factory()
    isaaclab_assets.ArticulationCfg = articulation_cfg
    isaaclab_assets.AssetBaseCfg = _cfg_factory()

    isaaclab_scene = ModuleType("isaaclab.scene")
    isaaclab_scene.InteractiveSceneCfg = object

    isaaclab_sensors = ModuleType("isaaclab.sensors")
    isaaclab_sensors.ContactSensorCfg = _cfg_factory()
    isaaclab_sensors.RayCasterCfg = SimpleNamespace(OffsetCfg=_cfg_factory())
    isaaclab_sensors.patterns = SimpleNamespace(
        GridPatternCfg=_cfg_factory()
    )

    isaaclab_terrains = ModuleType("isaaclab.terrains")
    isaaclab_terrains.TerrainImporterCfg = object

    isaaclab_utils = ModuleType("isaaclab.utils")
    isaaclab_utils.configclass = _configclass

    # Expose submodules as package attributes, matching the actuator loader.
    isaaclab_root.sim = isaaclab_sim
    isaaclab_root.actuators = isaaclab_actuators
    isaaclab_root.assets = isaaclab_assets
    isaaclab_root.scene = isaaclab_scene
    isaaclab_root.sensors = isaaclab_sensors
    isaaclab_root.terrains = isaaclab_terrains
    isaaclab_root.utils = isaaclab_utils

    loguru = ModuleType("loguru")
    loguru.logger = SimpleNamespace(info=lambda *args, **kwargs: None)

    terrain_name = "holomotion.src.env.isaaclab_components.isaaclab_terrain"
    fake_terrain = ModuleType(terrain_name)
    fake_terrain.build_terrain_config = lambda *args, **kwargs: None

    unitree_name = "holomotion.src.env.isaaclab_components.unitree_actuators"
    fake_unitree = ModuleType(unitree_name)
    for attr in (
        "UnitreeActuator",
        "UnitreeActuatorCfg",
        "UnitreeErfiActuator",
        "UnitreeErfiActuatorCfg",
    ):
        setattr(fake_unitree, attr, getattr(actuator_module, attr))

    for name, module in {
        "isaaclab": isaaclab_root,
        "isaaclab.sim": isaaclab_sim,
        "isaaclab.actuators": isaaclab_actuators,
        "isaaclab.assets": isaaclab_assets,
        "isaaclab.scene": isaaclab_scene,
        "isaaclab.sensors": isaaclab_sensors,
        "isaaclab.terrains": isaaclab_terrains,
        "isaaclab.utils": isaaclab_utils,
        "loguru": loguru,
        terrain_name: fake_terrain,
        unitree_name: fake_unitree,
    }.items():
        monkeypatch.setitem(sys.modules, name, module)

    module_name = "_test_isaaclab_scene"
    spec = importlib.util.spec_from_file_location(
        module_name, SCENE_MODULE_PATH
    )
    # Validate before use; module_from_spec(None) raises an opaque error.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Registered through monkeypatch so it does not leak into later tests.
    monkeypatch.setitem(sys.modules, module_name, module)
    spec.loader.exec_module(module)
    return module
def test_scene_builder_selects_unitree_erfi_cfg(monkeypatch):
    """actuator_type='unitree_erfi' yields an ERFI cfg with domain-rand options."""
    module = _load_scene_module(monkeypatch)
    domain_rand = {"erfi": {"enabled": True, "rfi_lim": 0.2}}
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree_erfi"}, domain_rand
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeErfiActuatorCfg)
    assert all_joints.erfi_enabled is True
    assert all_joints.rfi_lim == 0.2
def test_scene_builder_keeps_plain_unitree_cfg(monkeypatch):
    """actuator_type='unitree' yields a plain cfg without ERFI fields."""
    module = _load_scene_module(monkeypatch)
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree"}, {}
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeActuatorCfg)
    assert not hasattr(all_joints, "rfi_lim")
def test_scene_builder_disables_erfi_when_domain_rand_missing(monkeypatch):
    """An ERFI actuator without an 'erfi' domain-rand section is disabled."""
    module = _load_scene_module(monkeypatch)
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree_erfi"}, {}
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeErfiActuatorCfg)
    assert all_joints.erfi_enabled is False
def test_scene_builder_applies_domain_rand_action_delay_to_unitree(
    monkeypatch,
):
    """Domain-rand action_delay bounds are copied onto the plain cfg."""
    module = _load_scene_module(monkeypatch)
    delay = {"enabled": True, "min_delay": 1, "max_delay": 3}
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree"}, {"action_delay": delay}
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeActuatorCfg)
    assert (all_joints.min_delay, all_joints.max_delay) == (1, 3)
def test_scene_builder_applies_domain_rand_action_delay_to_unitree_erfi(
    monkeypatch,
):
    """Domain-rand action_delay bounds are copied onto the ERFI cfg."""
    module = _load_scene_module(monkeypatch)
    domain_rand = {
        "erfi": {"enabled": True},
        "action_delay": {"enabled": True, "min_delay": 2, "max_delay": 4},
    }
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree_erfi"}, domain_rand
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeErfiActuatorCfg)
    assert (all_joints.min_delay, all_joints.max_delay) == (2, 4)
def test_scene_builder_applies_erfi_ema_filter_config(monkeypatch):
    """EMA filter settings from the robot config reach the ERFI cfg."""
    module = _load_scene_module(monkeypatch)
    robot_actuator_cfg = {
        "actuator_type": "unitree_erfi",
        "ema_filter_enabled": True,
        "ema_filter_alpha": 0.37,
    }
    all_joints = module._build_unitree_actuator_cfg(
        robot_actuator_cfg, {"erfi": {"enabled": True}}
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeErfiActuatorCfg)
    assert all_joints.class_type.__name__ == "UnitreeErfiActuator"
    assert all_joints.ema_filter_enabled is True
    assert all_joints.ema_filter_alpha == 0.37
def test_scene_builder_disables_action_delay_when_domain_rand_missing(
    monkeypatch,
):
    """Without an action_delay section both delay bounds are zero."""
    module = _load_scene_module(monkeypatch)
    all_joints = module._build_unitree_actuator_cfg(
        {"actuator_type": "unitree"}, {}
    )["all_joints"]

    assert isinstance(all_joints, module.UnitreeActuatorCfg)
    assert (all_joints.min_delay, all_joints.max_delay) == (0, 0)
================================================
FILE: tests/test_visualize_with_mujoco.py
================================================
import sys
from pathlib import Path
import numpy as np
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from holomotion.src.motion_retargeting.utils.visualize_with_mujoco import (
_resolve_visualization_arrays,
)
def test_resolve_visualization_arrays_uses_robot_for_pose_and_ref_for_overlay():
    """Pose arrays come from the robot_ prefix; overlay spheres from ref_."""
    quat_row = [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]
    arrays = {
        "robot_dof_pos": np.array([[1.0, 2.0]], dtype=np.float32),
        "robot_global_translation": np.array(
            [[[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]], dtype=np.float32
        ),
        "robot_global_rotation_quat": np.array([quat_row], dtype=np.float32),
        "ref_global_translation": np.array(
            [[[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]]], dtype=np.float32
        ),
    }

    resolved = _resolve_visualization_arrays(
        arrays=arrays,
        key_prefix_order=["robot_"],
        draw_ref_body_spheres=True,
        ref_key_prefix_order=["ref_"],
    )

    # resolved key -> source key it must mirror
    expected_sources = {
        "dof_pos": "robot_dof_pos",
        "global_translation": "robot_global_translation",
        "global_rotation_quat": "robot_global_rotation_quat",
        "ref_body_positions": "ref_global_translation",
    }
    for resolved_key, source_key in expected_sources.items():
        np.testing.assert_allclose(resolved[resolved_key], arrays[source_key])
================================================
FILE: train.env
================================================
# This is the environment file for running HoloMotion scripts.
# It exports the settings below and prints a summary of them.

# Assign and export separately so a failing `conda info` is not hidden behind
# export's own (always successful) exit status (ShellCheck SC2155).
if ! CONDA_BASE=$(conda info --base) || [[ -z "${CONDA_BASE}" ]]; then
    echo "Warning: 'conda info --base' failed; CONDA_BASE and Train_CONDA_PREFIX are incomplete." >&2
fi
export CONDA_BASE
export Train_CONDA_PREFIX="$CONDA_BASE/envs/holomotion_train"
# export CUDA_HOME=$Train_CONDA_PREFIX
export CUDA_HOME=/usr/local/cuda
# export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$Train_CONDA_PREFIX/lib/:$Train_CONDA_PREFIX/lib/stubs"
# export LIBRARY_PATH="$Train_CONDA_PREFIX/lib/stubs:$Train_CONDA_PREFIX/lib:$LIBRARY_PATH"
export HYDRA_FULL_ERROR=1
export OMNI_KIT_ACCEPT_EULA="YES"
export ACCEPT_EULA="YES"
# export CUDA_LAUNCH_BLOCKING=1
export USE_NVRTC=1
export HDF5_USE_FILE_LOCKING=FALSE
export HOLOMOTION_ISAAC_STAGGER_SEC=1
export HOLOMOTION_HDF5_RDCC_NBYTES=$((4 * 1024 * 1024)) # 4 MiB
export HOLOMOTION_HDF5_MAX_OPEN_SHARDS=16 # 16 shards
export TORCH_DISTRIBUTED_DEBUG=INFO
# export TORCHDYNAMO_VERBOSE=0
echo "--------------------------------"
echo "Train_CONDA_PREFIX: $Train_CONDA_PREFIX"
echo "CUDA_HOME: $CUDA_HOME"
# These two may legitimately be unset; `:-` keeps `set -u` callers working.
echo "LD_LIBRARY_PATH: ${LD_LIBRARY_PATH:-}"
echo "LIBRARY_PATH: ${LIBRARY_PATH:-}"
echo "HYDRA_FULL_ERROR: $HYDRA_FULL_ERROR"
echo "OMNI_KIT_ACCEPT_EULA: $OMNI_KIT_ACCEPT_EULA"
echo "HDF5_USE_FILE_LOCKING: $HDF5_USE_FILE_LOCKING"
echo "HOLOMOTION_ISAAC_STAGGER_SEC: $HOLOMOTION_ISAAC_STAGGER_SEC"
echo "HOLOMOTION_HDF5_RDCC_NBYTES: $HOLOMOTION_HDF5_RDCC_NBYTES"
echo "HOLOMOTION_HDF5_MAX_OPEN_SHARDS: $HOLOMOTION_HDF5_MAX_OPEN_SHARDS"
echo "--------------------------------"
#######################################
# Graceful shutdown handler for training scripts.
# Usage: the calling script sets TRAIN_PID and installs
#   trap cleanup SIGINT SIGTERM
# Globals:
#   TRAIN_PID (read) - PID of the training process; may be empty or unset.
# Outputs:
#   Progress messages on stdout. Signal-delivery errors are discarded per
#   command, so the caller's stderr is left untouched.
#######################################
cleanup() {
    local pid="${TRAIN_PID:-}"
    local -a child_pids=()
    local killed_children=0

    echo ""
    echo "🛑 Cleanup triggered - shutting down training process ${pid}..."

    if [[ -n "${pid}" ]]; then
        # Snapshot direct children BEFORE signalling the parent: once it
        # exits they are re-parented and `pkill -P` can no longer match them.
        mapfile -t child_pids < <(pgrep -P "${pid}" 2>/dev/null)
        if kill -TERM "${pid}" 2>/dev/null; then
            echo " ✓ Sent TERM signal to process ${pid}"
        fi
    fi

    # Give the training process a moment to shut down on its own.
    sleep 2

    if [[ -n "${pid}" ]]; then
        # Children spawned after the snapshot (while the parent is alive).
        if pkill -P "${pid}" 2>/dev/null; then
            killed_children=1
        fi
        # Children captured before the parent was signalled.
        if (( ${#child_pids[@]} > 0 )) \
            && kill -TERM "${child_pids[@]}" 2>/dev/null; then
            killed_children=1
        fi
        if (( killed_children )); then
            echo " ✓ Killed child processes"
        fi
    fi

    echo " ✓ Cleanup complete"
}